Created
September 17, 2023 15:49
-
-
Save mdbenito/7a964f3cdc2ec96380c75cb85e68618b to your computer and use it in GitHub Desktop.
Batched iterators
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Type, Callable, Iterable, List, TypeVar, Generator, cast | |
import pytest | |
# Type of the items produced by a decorated iterable; its batches are List[T].
T = TypeVar("T")
def class_wraps(original_cls: Type) -> Callable[[Type], Type]:
    """A decorator to update the wrapped class's docstring with that of the
    original class (plus that of the wrapped class).

    If only one of the two classes has a docstring, that docstring is used
    as is. If both have one, each is normalized with `inspect.cleandoc` and
    they are joined by a blank line, original first. The normalization is
    needed because the two docstrings are usually written at different
    indentation depths (e.g. a top-level class and a class nested inside a
    function). Concatenated raw, the second part would stay indented relative
    to the first even after dedenting, which Markdown-based doc renderers may
    display as a code block instead of regular text.

    Args:
        original_cls: The original class whose docstring should be copied.

    Returns:
        A decorator function that updates the docstring of the wrapped class.
    """
    import inspect

    def wrapper(decorated_cls: Type) -> Type:
        original_doc = original_cls.__doc__
        decorated_doc = decorated_cls.__doc__
        if decorated_doc is None:
            # Nothing to merge: inherit the original docstring (may be None).
            decorated_cls.__doc__ = original_doc
        elif original_doc is not None:
            decorated_cls.__doc__ = (
                inspect.cleandoc(original_doc)
                + "\n\n"
                + inspect.cleandoc(decorated_doc)
            )
        # If only the decorated class has a docstring, it is left untouched.
        return decorated_cls

    return wrapper
def _as_positive_int(value: int, message: str) -> int:
    """Validate a batch size and return it as a built-in ``int``.

    Integer-like objects implementing ``__index__`` (e.g. ``numpy.int64``) are
    accepted and converted. Non-integral values such as ``2.5`` or ``"3"``
    are rejected, as are zero and negative numbers.

    Args:
        value: The candidate batch size.
        message: Error message used when ``value`` is rejected.

    Returns:
        ``value`` as a positive built-in ``int``.

    Raises:
        ValueError: If ``value`` is not an integer or is not positive.
    """
    import operator

    try:
        size = operator.index(value)
    except TypeError:
        # A non-integral size would otherwise be accepted here and only fail
        # later, with a TypeError, when it is used as a slice bound.
        raise ValueError(message) from None
    if size <= 0:
        raise ValueError(message)
    return size


def batched(initial_size: int) -> Callable[[Type[Iterable[T]]], Type[Iterable[List[T]]]]:
    """
    Decorates an iterable class so that iterating over it yields batches.

    This decorator subclasses the decorated class and overrides its
    `__iter__` method so that items are yielded in lists ("batches"). The
    size of the batches can be adjusted during iteration by sending an
    integer *increment* to the generator with its `send` method. It can also
    be set directly through the `batch_size` property of the iterable.

    The returned class keeps the decorated class's name, qualified name and
    module, so reprs and pickling refer to the decorated class.

    Args:
        initial_size: The initial size of the batches. Must be a positive
            integer.

    Returns:
        A class with an `__iter__` method that yields items in batches.

    Raises:
        ValueError: If the `initial_size` is not a positive integer.

    !!! example "Example usage"
        ```python
        @batched(3)
        class MyIterable:
            ...

        iterable_instance = MyIterable(...)
        generator = iter(iterable_instance)
        next(generator)  # Yields first batch of size 3
        generator.send(2)  # Adjusts batch size to 5 and yields next batch
        generator.send(-1)  # Adjusts batch size to 4 and yields next batch
        ```
    """
    initial_size = _as_positive_int(
        initial_size, "Initial batch size should be a positive integer"
    )

    def decorator(cls: Type[Iterable[T]]) -> Type[Iterable[List[T]]]:
        @class_wraps(cls)
        class BatchedWrapper(cls):  # type: ignore # Can't use vars as types
            # This docstring is appended to the decorated class's own
            # docstring by `class_wraps`, so it is written from the point of
            # view of the decorated class.
            """
            ## Batched iteration

            The iterable can yield its items in batches.
            The size of the batches can be dynamically adjusted during iteration
            by sending an integer increment with the `send` method of the
            generator. It can also be adjusted by setting the `batch_size`
            property of the iterable.
            """

            def __init__(self, *args, **kwargs):
                # Set the size before the wrapped constructor runs, in case
                # that constructor already iterates over `self`.
                self.__batch_size = initial_size
                super().__init__(*args, **kwargs)

            @property
            def batch_size(self) -> int:
                """Get the current batch size."""
                return self.__batch_size

            @batch_size.setter
            def batch_size(self, new_size: int):
                """Set a new batch size (must be a positive integer)."""
                self.__batch_size = _as_positive_int(
                    new_size, "Batch size should be a positive integer"
                )

            def __iter__(self) -> Generator[List[T], int, None]:
                iterator = super().__iter__()
                batch = []
                for item in iterator:
                    batch.append(item)
                    while len(batch) >= self.__batch_size:
                        to_yield = batch[:self.__batch_size]
                        # Yield the batch; a value passed to `send` arrives
                        # here and is added to the current batch size (the
                        # setter validates the result).
                        adjust_by = yield to_yield
                        if adjust_by is not None:
                            self.batch_size += adjust_by
                        # Drop the items that were just yielded.
                        batch = batch[len(to_yield):]
                # Whatever is left over forms a final, shorter batch.
                if batch:
                    yield batch

        # Present the wrapper under the decorated class's identity, so that
        # reprs, error messages and pickle refer to the user's class rather
        # than to a local class named "BatchedWrapper". Name mangling of
        # `__batch_size` happened at compile time and is unaffected.
        BatchedWrapper.__name__ = cls.__name__
        BatchedWrapper.__qualname__ = cls.__qualname__
        BatchedWrapper.__module__ = cls.__module__
        return BatchedWrapper

    return decorator
def test_batched_decorator_basic_functionality():
    """Iterating yields batches of the initial size; any remainder comes last."""

    @batched(3)
    class SampleIterable:
        def __init__(self, data):
            self.data = data

        def __iter__(self):
            return iter(self.data)

    source = SampleIterable(list(range(10)))
    expected = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
    # Ten items in batches of three leave a single item for the last batch.
    result = list(source)
    assert result == expected
def test_batched_decorator_dynamic_batch_size():
    """Sending an increment to the generator resizes all following batches."""

    @batched(3)
    class SampleIterable:
        def __init__(self, data):
            self.data = data

        def __iter__(self):
            return iter(self.data)

    source = SampleIterable(list(range(10)))
    batches = iter(source)
    assert next(batches) == [0, 1, 2]
    # 3 + 2: batches now hold five items.
    assert batches.send(2) == [3, 4, 5, 6, 7]
    # Only two items remain, so the last batch is short.
    assert next(batches) == [8, 9]
def test_batched_decorator_invalid_batch_size():
    """Non-positive batch sizes are rejected everywhere a size can be set:
    the decorator argument, the `batch_size` setter and `send`."""
    # `batched(...)` validates its argument when called, before any class is
    # decorated, so no class body is needed for these two checks.
    for bad_size in (0, -1):
        with pytest.raises(ValueError):
            batched(bad_size)

    @batched(1)
    class SampleIterable:
        def __init__(self, data):
            self.data = data

        def __iter__(self):
            return iter(self.data)

    source = SampleIterable(list(range(10)))
    batches = iter(source)
    assert next(batches) == [0]
    with pytest.raises(ValueError):
        source.batch_size = -1
    # An increment of -1 would take the batch size from 1 down to 0.
    with pytest.raises(ValueError):
        batches.send(-1)
def test_batched_decorator_batch_size_property():
    """Assigning `batch_size` between batches sets the size of the next one."""

    @batched(1)
    class SampleIterable:
        def __init__(self, data):
            self.data = data

        def __iter__(self):
            return iter(self.data)

    source = SampleIterable(list(range(10)))
    batches = iter(source)
    assert next(batches) == [0]
    # Each new size applies to the batch requested right after it is set;
    # the last request can only be served with the four remaining items.
    resizes = ((3, [1, 2, 3]), (2, [4, 5]), (5, [6, 7, 8, 9]))
    for new_size, expected in resizes:
        source.batch_size = new_size
        assert next(batches) == expected
if __name__ == "__main__":
    # Allow running this file directly as a quick self-check. Guarding the
    # calls keeps importing this module (including pytest's collection of
    # the `test_*` functions) free of side effects, so the tests are not run
    # once at import time and again as collected tests.
    test_batched_decorator_dynamic_batch_size()
    test_batched_decorator_basic_functionality()
    test_batched_decorator_batch_size_property()
    test_batched_decorator_invalid_batch_size()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment