from __future__ import annotations import dataclasses import itertools from collections.abc import Callable, Iterable, Iterator from typing import Generic, Protocol, Self, TypeVar class SupportsLessThan(Protocol): def __lt__(self, other: Self, /) -> bool: raise NotImplementedError() T = TypeVar("T") KeyT = TypeVar("KeyT", bound=SupportsLessThan) GroupT = TypeVar("GroupT") def groupby( it: Iterable[T], key_fn: Callable[[T], KeyT], group_fn: Callable[[Iterable[T]], GroupT] = lambda x: x, ) -> Iterator[tuple[KeyT, GroupT]]: for k, g in itertools.groupby(sorted(it, key=key_fn), key=key_fn): yield k, group_fn(g) def identity(t: T) -> T: return t U = TypeVar("U") V = TypeVar("V") def map_none(fn: Callable[[U], V], value: U | None) -> V | None: if value is None: return None return fn(value) class FnChain(Generic[U]): def __init__(self, fn: Callable[[], U], /): self._fn: Callable[[], U] = fn def __or__(self, next_fn: Callable[[U], V], /) -> FnChain[V]: return FnChain(lambda: next_fn(self._fn())) def result(self) -> U: return self._fn() @classmethod def transform(cls, x: V, /) -> FnChain[V]: return FnChain(lambda: x) def batched(iterable: Iterable[T], n: int) -> Iterator[tuple[T, ...]]: if n < 1: raise ValueError('n must be at least one') iterator = iter(iterable) while batch := tuple(itertools.islice(iterator, n)): yield batch ResultT = TypeVar("ResultT") ErrorT = TypeVar("ErrorT") @dataclasses.dataclass class Result(Generic[ResultT, ErrorT]): result: ResultT | None = None error: ErrorT | None = None