Skip to content

Instantly share code, notes, and snippets.

@mdbenito
Created September 17, 2023 15:49
Show Gist options
  • Save mdbenito/7a964f3cdc2ec96380c75cb85e68618b to your computer and use it in GitHub Desktop.
Save mdbenito/7a964f3cdc2ec96380c75cb85e68618b to your computer and use it in GitHub Desktop.
Batched iterators
from typing import Type, Callable, Iterable, List, TypeVar, Generator, cast
import pytest
# Type variable for the items produced by an iterable that `batched`
# decorates; the decorated iterable then produces `List[T]` batches.
T = TypeVar("T")
def class_wraps(original_cls: Type) -> Callable[[Type], Type]:
    """A decorator that makes a wrapper class present itself as `original_cls`.

    This is the class analogue of :func:`functools.wraps`. The decorated
    class takes over the ``__module__``, ``__name__`` and ``__qualname__``
    of ``original_cls``. Its docstring becomes the original class's
    docstring followed by the decorated class's own docstring, separated
    by a blank line.

    Each docstring is normalised with :func:`inspect.cleandoc` before the
    two are joined. Docstrings written at different nesting levels carry
    different indentation. Concatenating them verbatim produces text whose
    indentation cannot be removed consistently by ``help()`` or by
    documentation generators.

    Args:
        original_cls: The class whose metadata and docstring should be copied.

    Returns:
        A decorator that updates the decorated class in place and returns it.
        If neither class has a docstring, the result's docstring is ``None``.
    """
    import inspect

    def wrapper(decorated_cls: Type) -> Type:
        # Copy the identifying metadata so that reprs and error messages
        # show the public class name instead of the wrapper's internal name.
        for attr in ("__module__", "__name__", "__qualname__"):
            setattr(decorated_cls, attr, getattr(original_cls, attr))

        # Clean each docstring on its own, then join them with a blank line.
        docs = [
            inspect.cleandoc(doc)
            for doc in (original_cls.__doc__, decorated_cls.__doc__)
            if doc is not None
        ]
        decorated_cls.__doc__ = "\n\n".join(docs) if docs else None
        return decorated_cls

    return wrapper
def batched(initial_size: int) -> Callable[[Type[Iterable[T]]], Type[Iterable[List[T]]]]:
    """
    Decorates an iterable class so that its instances yield items in batches.

    This decorator wraps the `__iter__` method of an iterable class. Iterating
    over an instance then yields lists of consecutive items instead of single
    items. The batch size can be adjusted during iteration by sending a
    (possibly negative) increment into the generator with its `send` method.
    It can also be set directly through the `batch_size` property of the
    instance. Either kind of change is stored on the instance, so it also
    applies to later iterations over the same instance.

    The last batch holds the remaining items and may be smaller than the
    current batch size.

    Args:
        initial_size: The initial size of the batches. Must be a positive
            integer.

    Returns:
        A class decorator. Applied to a class, it returns a subclass of it
        whose `__iter__` method yields items in batches.

    Raises:
        ValueError: If `initial_size` is not a positive integer. The
            `batch_size` setter raises the same error for invalid sizes.
            This includes a value sent to the generator that would make
            the size non-positive.

    !!! example "Example usage"
        ```python
        @batched(3)
        class MyIterable:
            ...

        iterable_instance = MyIterable(...)
        generator = iter(iterable_instance)
        next(generator)  # Yields first batch of size 3
        generator.send(2)  # Adjusts batch size to 5 and yields next batch
        generator.send(-1)  # Adjusts batch size to 4 and yields next batch
        ```
    """
    import numbers

    def _is_positive_int(value: object) -> bool:
        # numbers.Integral accepts int and integer types registered with the
        # numeric ABCs. Floats and strings are rejected here, instead of
        # failing later with a TypeError when used as a slice bound.
        return isinstance(value, numbers.Integral) and value > 0

    if not _is_positive_int(initial_size):
        raise ValueError("Initial batch size should be a positive integer")

    def decorator(cls: Type[Iterable[T]]) -> Type[Iterable[List[T]]]:
        @class_wraps(cls)
        class BatchedWrapper(cls):  # type: ignore # Can't use vars as types
            # class_wraps appends this docstring to the decorated class's
            # docstring, so it is written from that class's point of view.
            """
            ## Batched iteration

            The iterable can yield its items in batches.
            The size of the batches can be dynamically adjusted during iteration
            using the `send` method on the generator. It can also be adjusted by
            setting the `batch_size` property of the iterable.
            """

            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.__batch_size = initial_size

            @property
            def batch_size(self) -> int:
                """The current batch size."""
                return self.__batch_size

            @batch_size.setter
            def batch_size(self, new_size: int):
                """Set a new batch size, which must be a positive integer."""
                if not _is_positive_int(new_size):
                    raise ValueError("Batch size should be a positive integer")
                self.__batch_size = new_size

            def __iter__(self) -> Generator[List[T], int, None]:
                iterator = super().__iter__()
                batch = []
                for item in iterator:
                    batch.append(item)
                    # A `while` rather than an `if`, so the buffer is fully
                    # drained even if it ever holds more than one batch.
                    while len(batch) >= self.__batch_size:
                        to_yield = batch[:self.__batch_size]
                        # `send(n)` resumes here with n; `next()` resumes
                        # with None, which leaves the size unchanged.
                        adjust_by = yield to_yield
                        if adjust_by is not None:
                            # Goes through the setter, so an invalid
                            # resulting size raises ValueError from `send`.
                            self.batch_size += adjust_by
                        # Drop the items that were just yielded.
                        batch = batch[len(to_yield):]
                # Leftover items form a final, possibly shorter, batch.
                if batch:
                    yield batch

        return BatchedWrapper

    return decorator
def test_batched_decorator_basic_functionality():
    """Iterating a decorated instance yields full batches, then the remainder."""

    @batched(3)
    class Numbers:
        def __init__(self, values):
            self.values = values

        def __iter__(self):
            return iter(self.values)

    source = Numbers([*range(10)])
    expected = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
    assert [chunk for chunk in iter(source)] == expected
def test_batched_decorator_dynamic_batch_size():
    """Sending an increment into the generator resizes the following batches."""

    @batched(3)
    class Numbers:
        def __init__(self, values):
            self.values = values

        def __iter__(self):
            return iter(self.values)

    chunks = iter(Numbers([*range(10)]))
    first = next(chunks)
    assert first == [0, 1, 2]
    # Batch size goes from 3 to 3 + 2 = 5.
    second = chunks.send(2)
    assert second == [3, 4, 5, 6, 7]
    # Only two items remain, so the last batch is short.
    third = next(chunks)
    assert third == [8, 9]
def test_batched_decorator_invalid_batch_size():
    """Non-positive batch sizes are rejected by the decorator, by the
    ``batch_size`` setter, and when sent into the generator."""

    def make_sample_class(size):
        # Build a fresh minimal iterable class decorated with `size`.
        @batched(size)
        class SampleIterable:
            def __init__(self, data):
                self.data = data

            def __iter__(self):
                return iter(self.data)

        return SampleIterable

    for bad_size in (0, -1):
        with pytest.raises(ValueError):
            make_sample_class(bad_size)

    instance = make_sample_class(1)([*range(10)])
    chunks = iter(instance)
    assert next(chunks) == [0]
    with pytest.raises(ValueError):
        instance.batch_size = -1
    with pytest.raises(ValueError):
        chunks.send(-1)
def test_batched_decorator_batch_size_property():
    """Assigning ``batch_size`` between batches sets the size of the next one."""

    @batched(1)
    class Numbers:
        def __init__(self, values):
            self.values = values

        def __iter__(self):
            return iter(self.values)

    source = Numbers([*range(10)])
    chunks = iter(source)
    assert next(chunks) == [0]

    schedule = [
        (3, [1, 2, 3]),
        (2, [4, 5]),
        (5, [6, 7, 8, 9]),  # fewer items left than the new size
    ]
    for size, expected in schedule:
        source.batch_size = size
        assert next(chunks) == expected
if __name__ == "__main__":
    # Run the checks only when this file is executed as a script. Importing
    # the module (e.g. during pytest collection) must not run them as a
    # side effect; pytest collects and runs the test_* functions itself.
    test_batched_decorator_dynamic_batch_size()
    test_batched_decorator_basic_functionality()
    test_batched_decorator_batch_size_property()
    test_batched_decorator_invalid_batch_size()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment