Example Parallel Task API based on Celery
# (c) Copyright 2018 Zymergen, Inc.
# All Rights Reserved
"""
The following is example code used for a technology blog post:
https://medium.com/@ZymergenTechBlog/building-a-parallel-task-api-with-celery-dbae5ced4e28

The ForkJoin class can be used to generate a ZWork task that contains a single
distributed processing step. Your job should have three parts: an initial setup step
responsible for splitting inputs into workable chunks, a process step that processes
each chunk in a forked worker process, and a join step that puts it all together and
returns the final result.
"""
import copy
import functools
from types import GeneratorType

import cerberus
from celery import chain, group, shared_task

# Parallelization settings
PARALLEL_DEFAULT_GROUP_SIZE = 32


class InputValidationError(TypeError):
    """Exception representing a failure to validate job inputs against a Cerberus schema.

    Attributes:
        inputs: The inputs that failed validation
        schema: The Cerberus schema
        errors: The list of validation errors
    """

    def __init__(self, inputs, schema, errors):
        self.inputs = inputs
        self.schema = schema
        self.errors = errors
        self.warnings = None
        err_template = "Inputs did not match schema.\n\tInputs: {}\n\tSchema: {}\n\tErrors: {}"
        err_msg = err_template.format(self.inputs, self.schema, self.errors)
        super(InputValidationError, self).__init__(err_msg)


def chunks(l, n):
    """A utility function for splitting a list into n roughly equal-sized chunks."""
    return [l[i::n] for i in range(n)]
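

# Illustrative example (not part of the original gist): the strided slicing above
# produces interleaved rather than contiguous chunks, e.g.
#   chunks([1, 2, 3, 4, 5], 2) -> [[1, 3, 5], [2, 4]]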


def flatten(xs):
    """Flatten a list.

    Turns a nested list into a flat list. Only goes one layer deep, not fully nested.

    Args:
        xs: a list of elements and lists
    Returns:
        A generator over the same elements, with any first-level lists or
        generators unpacked.
    """
    for x in xs:
        if isinstance(x, (list, GeneratorType)):
            for elem in x:
                yield elem
        else:
            yield x
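

# Illustrative example (not part of the original gist): flatten() is lazy, so wrap
# it in list() to materialize the results, e.g.
#   list(flatten([[1, 2], 3, [4]])) -> [1, 2, 3, 4]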


def validate_inputs(ins, schema):
    """Validates JSON inputs against a Cerberus schema.

    Raises:
        InputValidationError: If the inputs do not conform to the schema.
    """
    validator = cerberus.Validator()
    is_valid = validator.validate(ins, schema)
    if not is_valid:
        raise InputValidationError(ins, schema, validator.errors)
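

# Illustrative example (not part of the original gist); the schema key below is
# hypothetical:
#   schema = {'sample_count': {'type': 'integer', 'required': True}}
#   validate_inputs({'sample_count': 5}, schema)       # passes silently
#   validate_inputs({'sample_count': 'five'}, schema)  # raises InputValidationError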


# def strips_extra_args(fn):
#     """Makes a function able to accept arbitrary extra keyword arguments.
#
#     Returns a version of fn that can safely receive extra arguments, and raises
#     an exception upon incorrect argumentation that is more useful than Python's
#     default.
#
#     Details omitted
#
#     Args:
#         fn: any function
#     Returns:
#         fn, but able to accept extra keyword arguments
#     Raises:
#         TypeError is raised whenever too few arguments are passed in.
#     """
#     return fn
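

# The original strips_extra_args implementation is omitted above. The function
# below is a minimal stand-in sketch (not Zymergen's version) so the rest of this
# module is importable and callable: it simply drops keyword arguments the wrapped
# function does not declare, rather than raising a friendlier TypeError.
import inspect


def strips_extra_args(fn):
    """Return a wrapper for fn that silently discards unexpected keyword arguments."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            argspec = inspect.getfullargspec(fn)
        except TypeError:
            # Some callables cannot be introspected; pass everything through.
            return fn(*args, **kwargs)
        if argspec.varkw is not None:
            # fn already accepts **kwargs; nothing needs stripping.
            return fn(*args, **kwargs)
        allowed = set(argspec.args) | set(argspec.kwonlyargs)
        filtered = {key: val for key, val in kwargs.items() if key in allowed}
        return fn(*args, **filtered)
    return wrapper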


class ForkJoinTask(object):
    """Represents a task that can be run in parallel.

    To implement your own parallel task:
    1. Write a new class, MyTask, that inherits from this one
    2. Override the command and job_inputs properties
    3. Write business logic functions outside of the MyTask class
       - a "setup" for dividing inputs into units of work; each unit is a dictionary
         with keys that can be used as arg names for the process step
       - a "processor" for performing a single unit of work defined by the setup step
       - a "joiner" for recombining results from the work units
    4. Decorate those functions with MyTask.setup(), MyTask.process(), and MyTask.join()
       - see their docstrings below for details, restrictions, and expectations
    5. Input args are saved as 'orig_input_args', a dict of the original args
       that can be accessed in individual method signatures (e.g., process, join).

    A hypothetical end-to-end sketch is included at the bottom of this file.

    Why do I need to implement a new static singleton? Can't this be more dynamic?
    No. This limitation comes from Celery, not us.
    """

    queue = 'parallel'

    def run_without_celery(self, job_args):
        """Run this task as a job without using the Celery backend.

        The entire task runs in a single process, and there is no serialization/
        deserialization to/from JSON between steps, but it otherwise mimics the
        Celery case exactly.

        Args:
            job_args: The input dictionary to the job
        Returns:
            The job results dictionary.
        """
        # avoid changing the global state of original job args
        nested_job_args = job_args.copy()
        nested_job_args['orig_input_args'] = job_args.copy()
        setup_results = self.setup_step((), **nested_job_args)
        process_results = [self.process_step(setup_results, i, **nested_job_args)
                           for i in range(len(setup_results))]
        return self.join_step(process_results, **nested_job_args)

    @classmethod
    def get_input_schema(cls, *args, **kwargs):
        """This method provides a Cerberus schema for job input validation.

        Returns:
            A Cerberus schema that will be used to validate job inputs.
        """
        raise NotImplementedError("Tasks MUST implement a schema")

    @classmethod
    def split(cls, work_units):
        """Override this to explicitly control which workers do which work.

        You can almost certainly use this default.

        Args:
            work_units (list<dict>): The inputs for each possible process() step.
                This is the output from the setup() step.
        Returns:
            A list of list<dict>s. Each parent list corresponds to work for one worker.
        """
        expected_workers = PARALLEL_DEFAULT_GROUP_SIZE
        return chunks(work_units, min(expected_workers, len(work_units)))

    @classmethod
    def setup(cls, task):
        """Decorates a function that divides job inputs into units of work.

        This decorator MUST be used EXACTLY ONCE per job, like this:

            @MyTask.setup
            def divide_job_inputs(arg0_from_schema, arg1_from_schema, ...):
                return [{'arg': 'val0'}, {'arg': 'val1'}]

        Args (passed into your decorated function):
            **kwargs (dict): Corresponds to parsed and validated job inputs.
                These should match keys from your job's get_input_schema().
        Return (expected from your decorated function):
            A list of dicts. Each dict represents a subset of inputs needed to do one
            unit of parallel work. The keys in this dict should match the params
            from your process() step function. Remaining inputs to process() are
            provided by parsed/validated job inputs.
        """
        @functools.wraps(task)
        def setter_upper(*args, **kwargs):
            # We use cerberus for schema validation
            schema = strips_extra_args(cls.get_input_schema)(**kwargs)
            validate_inputs(kwargs, schema)
            normalized_inputs = cerberus.Validator().normalized(kwargs, schema)
            result = strips_extra_args(task)(**normalized_inputs)
            split_results = cls.split(result)
            return split_results
        cls.setup_step = lifecycle_task(setter_upper, queue=cls.queue)
        return cls.setup_step

    @classmethod
    def process(cls, task):
        """Decorates a function that performs one unit of parallel work.

        This decorator MUST be used EXACTLY ONCE per job, like this:

            @MyTask.process
            def do_parallel_work(param0, param1, ...):
                return "whatever I want"

        Args (passed into your decorated function):
            **kwargs (dict): All the parsed/validated job inputs, PLUS all the
                arguments from one of setup()'s outputs. Your function will be
                passed only the arguments explicitly claimed in its signature.
                In the event of a key collision, arguments from one of setup()'s
                outputs take precedence over parsed/validated job inputs.
        Return (expected from your decorated function):
            Your function can return any serializable object.
        """
        @functools.wraps(task)
        def process_inputs(divided_inputs, group_index, **kwargs):
            try:
                my_share_of_work = divided_inputs[group_index]
            except IndexError:
                return []
            outs = []
            schema = strips_extra_args(cls.get_input_schema)(**kwargs)
            normalized_kwargs = cerberus.Validator().normalized(kwargs, schema)
            for work_unit in my_share_of_work:
                work_unit_kwargs = copy.deepcopy(normalized_kwargs)
                work_unit_kwargs.update(work_unit)
                result = strips_extra_args(task)(**work_unit_kwargs)
                outs.append(result)
            return outs
        cls.process_step = lifecycle_task(process_inputs, queue=cls.queue)
        return cls.process_step

    @classmethod
    def join(cls, task):
        """Decorates a function that recombines parallel results.

        This decorator MUST be used EXACTLY ONCE per job, like this:

            @MyTask.join
            def recombine_parallel_results(parallel_results, **kwargs):
                return "whatever I want"

        Args (passed into your decorated function):
            parallel_results (list): A list of all values that were returned from
                an invocation of the process() step.
            **kwargs (dict): All the parsed/validated job inputs.
        Return (expected from your decorated function):
            Any serializable object. This object is the job's final result.
        """
        @functools.wraps(task)
        def joiner(distributed_results, *args, **kwargs):
            schema = strips_extra_args(cls.get_input_schema)(**kwargs)
            normalized_kwargs = cerberus.Validator().normalized(kwargs, schema)
            # The process step nested its results. Unnest them.
            unnested_results = list(flatten(distributed_results))
            results = strips_extra_args(task)(unnested_results, *args, **normalized_kwargs)
            return results
        # Place the final step on a different queue so it's not blocked by other
        # running parallel jobs.
        cls.join_step = lifecycle_task(joiner, queue='celery')
        return cls.join_step

    @classmethod
    def signature(cls, options):
        """Returns a Celery signature to describe the complete job-step sequence.

        Args:
            options (dict): The shared options for the job (auth, request, etc.)
        Returns:
            A Celery signature.
        """
        return fork_join_task(cls.setup_step, cls.process_step, cls.join_step, options)


def fork_join_task(setup_step, process_step, join_step, bound_args):
    """Creates a parallel Celery fork/join task from provided functions.

    Args:
        setup_step (celery task): A "setup" step for the whole job
        process_step (celery task): A "process" step that runs in parallel after setup
        join_step (celery task): A "join" step to recombine parallel results
        bound_args (dict): Any bound arguments that can be accessed by all steps
    Returns:
        A new Celery job that performs a setup/process/join work pattern.
        The returned job's steps will all be partially applied over bound_args.
    """
    setup_sig = setup_step.signature(**bound_args)
    process_sig = parallel_processing_step(process_step, bound_args)
    join_sig = join_step.signature(**bound_args)
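    # Chaining a group followed by another task causes Celery to upgrade the pair
    # to a chord: the setup step's return value (the split work units) is prepended
    # to each process signature's arguments, and the collected list of group
    # results is passed to the join step.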
    return chain(setup_sig, process_sig, join_sig)


def parallel_processing_step(
        process_step, bound_args, group_size=PARALLEL_DEFAULT_GROUP_SIZE):
    """Returns a "group" signature for a distributed application of process_step.

    Every member of the group receives the full list of split work units from the
    setup step and selects its own share via the group_index bound here; members
    whose index exceeds the number of chunks simply return an empty list.
    """
    signatures = [process_step.signature(group_index=i, **bound_args) for i in range(group_size)]
    return group(signatures)


def lifecycle_task(task, queue):
    """Makes a Celery task from a Python function.

    This runner can act as a single step within a parallel job.

    Args:
        task (fn): the function to run
        queue (str): The name of the Celery queue to use.
    Returns:
        A Celery task that runs the provided function
    """
    name = "{}.{}".format(task.__module__, task.__name__)

    @shared_task(ignore_result=False,
                 name=name, queue=queue, options={'queue': queue})
    def internal_runner(*args, **kwargs):
        return task(*args, **kwargs)
    return internal_runner
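

# ---------------------------------------------------------------------------
# Hypothetical end-to-end usage sketch (added for illustration; not part of the
# original gist). Names such as SumOfSquaresTask, split_numbers, square and
# add_up are invented, and the sketch assumes the full strips_extra_args
# behavior described above. It is left commented out so that importing this
# module does not register example tasks.
#
# class SumOfSquaresTask(ForkJoinTask):
#
#     @classmethod
#     def get_input_schema(cls, *args, **kwargs):
#         # run_without_celery() injects 'orig_input_args' into the kwargs that
#         # get validated, so the schema has to allow that key too.
#         return {
#             'numbers': {'type': 'list', 'required': True},
#             'orig_input_args': {'type': 'dict', 'required': False},
#         }
#
#
# @SumOfSquaresTask.setup
# def split_numbers(numbers):
#     # One work unit per number; the key matches the process() arg name below.
#     return [{'number': n} for n in numbers]
#
#
# @SumOfSquaresTask.process
# def square(number):
#     return number * number
#
#
# @SumOfSquaresTask.join
# def add_up(parallel_results):
#     return sum(parallel_results)
#
#
# # Run in a single process without a Celery backend:
# #   SumOfSquaresTask().run_without_celery({'numbers': [1, 2, 3]})  # -> 14
# # Or build the Celery canvas and dispatch it to workers:
# #   SumOfSquaresTask.signature({}).delay()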