69 lines
1.6 KiB
Python
69 lines
1.6 KiB
Python
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
import itertools
|
|
from collections.abc import Callable, Iterable, Iterator
|
|
from typing import Generic, Protocol, Self, TypeVar
|
|
|
|
|
|
class SupportsLessThan(Protocol):
|
|
def __lt__(self, other: Self, /) -> bool:
|
|
raise NotImplementedError()
|
|
|
|
|
|
T = TypeVar("T")
|
|
KeyT = TypeVar("KeyT", bound=SupportsLessThan)
|
|
GroupT = TypeVar("GroupT")
|
|
|
|
|
|
def groupby(
    it: Iterable[T],
    key_fn: Callable[[T], KeyT],
    group_fn: Callable[[Iterable[T]], GroupT] = lambda x: x,
) -> Iterator[tuple[KeyT, GroupT]]:
    """Group the items of *it* by ``key_fn`` without requiring pre-sorted input.

    Unlike :func:`itertools.groupby`, all items are first collected and stably
    sorted by key. For keys with a consistent ordering, each distinct key
    therefore appears once, in ascending key order. Within a group, items keep
    their original relative order.

    Args:
        it: Items to group. Fully consumed on the first ``next()`` call.
        key_fn: Computes the grouping key of an item; called exactly once per
            item.
        group_fn: Applied to each group's items, passed as a ``list``, to
            produce the value yielded for that group. The default identity
            yields each group as that list.

    Yields:
        ``(key, group_fn(items_with_that_key))`` pairs.
    """
    from operator import itemgetter

    first = itemgetter(0)
    # Pair each item with its key once, so key_fn is not re-run by both
    # sorted() and itertools.groupby(). Sorting on the key alone keeps the sort
    # stable and never compares the items themselves.
    keyed = sorted(((key_fn(item), item) for item in it), key=first)
    for key, pairs in itertools.groupby(keyed, key=first):
        # Materialise the group before handing it out: itertools.groupby's
        # sub-iterators are invalidated as soon as the outer iterator advances.
        # Yielding one lazily left callers such as list(groupby(...)) or
        # dict(groupby(...)) holding empty groups.
        yield key, group_fn([item for _, item in pairs])
|
|
|
|
|
|
U = TypeVar("U")
V = TypeVar("V")


def map_none(fn: Callable[[U], V], value: U | None) -> V | None:
    """Apply *fn* to *value*, propagating ``None`` instead of calling *fn*.

    Returns ``None`` when *value* is ``None``; otherwise returns ``fn(value)``.
    Falsy values that are not ``None`` (``0``, ``""``, ...) are still passed
    to *fn*.
    """
    return None if value is None else fn(value)


class FnChain(Generic[U]):
    """Lazily composed pipeline of single-argument functions.

    ``FnChain.transform(x) | f | g`` builds a chain that evaluates ``g(f(x))``
    only when :meth:`result` is called. Each :meth:`result` call re-runs the
    whole pipeline; nothing is cached.

    Every stage wraps the previous one in a nested call, so evaluating a very
    long chain may hit Python's recursion limit.
    """

    def __init__(self, fn: Callable[[], U], /):
        # Zero-argument thunk that produces this stage's value on demand.
        self._fn: Callable[[], U] = fn

    def __or__(self, next_fn: Callable[[U], V], /) -> FnChain[V]:
        """Return a new chain that feeds this chain's value into *next_fn*.

        ``self`` is not modified, and *next_fn* is only called when
        :meth:`result` is evaluated on the returned chain.

        A non-callable operand yields ``NotImplemented``, so Python raises
        ``TypeError`` at the ``|`` (or defers to the operand's ``__ror__``).
        Without this, the failure would only surface inside :meth:`result`.
        """
        if not callable(next_fn):
            return NotImplemented
        # type(self) keeps subclasses closed under ``|``; assumes subclasses
        # accept a single thunk in their constructor, like FnChain itself.
        return type(self)(lambda: next_fn(self._fn()))

    def result(self) -> U:
        """Evaluate the chain and return its final value."""
        return self._fn()

    @classmethod
    def transform(cls, x: V, /) -> FnChain[V]:
        """Start a chain whose initial value is *x*.

        Builds an instance of ``cls``, so calling this on a subclass returns
        that subclass rather than a plain ``FnChain``.
        """
        return cls(lambda: x)
|
|
|
|
|
|
def batched(iterable: Iterable[T], n: int) -> Iterator[tuple[T, ...]]:
    """Split *iterable* into consecutive tuples of length *n*.

    The last tuple is shorter when the input length is not a multiple of *n*;
    no padding is added. This matches :func:`itertools.batched` (Python 3.12+)
    without its ``strict`` option.

    Raises:
        ValueError: if ``n < 1``. Raised at call time, not on first iteration,
            so the traceback points at the faulty call.
    """
    # Validate before creating the generator: inside a generator body this
    # check would only run on the first next(), far from the caller.
    if n < 1:
        raise ValueError('n must be at least one')
    return _batched_unchecked(iterable, n)


def _batched_unchecked(iterable: Iterable[T], n: int) -> Iterator[tuple[T, ...]]:
    """Yield successive *n*-sized tuples from *iterable*; assumes ``n >= 1``."""
    iterator = iter(iterable)
    # islice on the shared iterator consumes the next n items; an empty tuple
    # means the input is exhausted.
    while batch := tuple(itertools.islice(iterator, n)):
        yield batch
|
|
|
|
|
|
ResultT = TypeVar("ResultT")
ErrorT = TypeVar("ErrorT")


@dataclasses.dataclass()
class Result(Generic[ResultT, ErrorT]):
    """Mutable pair of two optional fields, ``result`` and ``error``.

    Both default to ``None``, and nothing here enforces that exactly one of
    them is set. Presumably callers fill ``result`` on success and ``error``
    on failure; verify against call sites.
    """

    # Optional payload typed by ResultT; defaults to None.
    result: ResultT | None = dataclasses.field(default=None)
    # Optional payload typed by ErrorT; defaults to None.
    error: ErrorT | None = dataclasses.field(default=None)
|