Skip to content

Instantly share code, notes, and snippets.

@72squared
Created May 22, 2016 15:28
Show Gist options
  • Save 72squared/435ba452b295cafeb3ae575e39228213 to your computer and use it in GitHub Desktop.
Save 72squared/435ba452b295cafeb3ae575e39228213 to your computer and use it in GitHub Desktop.
import threading
import unittest
import sys
from functools import wraps
import time
class Async(threading.Thread):
    """Run a callable on a background thread and hand back its outcome later.

    The thread acts as a minimal future: call the instance (or read
    ``.result``) to block until the target finishes, then either get the
    value it returned or have the exception it raised re-raised in the
    calling thread, with the original traceback.

    Args:
        target: callable to run in the thread.
        args: positional arguments passed to ``target``.
        kwargs: keyword arguments passed to ``target``; ``None`` means none.
    """

    def __init__(self, target, args=(), kwargs=None):
        super(Async, self).__init__()
        self._target = target
        self._args = args
        # ``**None`` raises TypeError, so normalise the default to an empty
        # mapping; otherwise ``Async(target=f)`` could never succeed.
        self._kwargs = {} if kwargs is None else kwargs
        self._exc_info = None
        self._result = None

    def run(self):
        """Invoke the target, recording its return value or its exception."""
        # noinspection PyBroadException
        try:
            self._result = self._target(*self._args, **self._kwargs)
        except Exception:
            # Keep the full (type, value, traceback) triple so the failure can
            # be re-raised in the waiting thread with its original traceback.
            self._exc_info = sys.exc_info()
        finally:
            # Avoid a refcycle if the thread is running a function with
            # an argument that has a member that points to the thread.
            del self._target, self._args, self._kwargs

    def __call__(self):
        """Wait for the thread to finish and return the target's result.

        If the thread was never started, this returns ``None`` without
        waiting.

        Raises:
            Exception: whatever the target raised, with its original traceback.
        """
        if self.is_alive():
            self.join()
        if self._exc_info is not None:
            exc_type, exc_value, exc_tb = self._exc_info
            if sys.version_info >= (3,):
                raise exc_value.with_traceback(exc_tb)
            # Python 2: the three-argument ``raise`` form is a SyntaxError on
            # Python 3, so it is compiled lazily and only on Python 2.
            exec('raise exc_type, exc_value, exc_tb',
                 {'exc_type': exc_type, 'exc_value': exc_value, 'exc_tb': exc_tb})
        return self._result

    @property
    def result(self):
        """Block until done and return the target's result (same as calling)."""
        return self.__call__()
def run_async(fn):
    """Decorate ``fn`` so that each call runs it on a new ``Async`` thread.

    The decorated function starts the thread immediately and returns the
    ``Async`` instance. Call that instance (or read its ``.result``) to wait
    for the value ``fn`` returned, or to re-raise the exception it raised.
    """
    @wraps(fn)
    def run(*args, **kwargs):
        t = Async(target=fn, args=args, kwargs=kwargs)
        t.start()
        return t
    return run


# ``async`` is a reserved keyword from Python 3.7 onwards, so ``def async``
# would make this whole module a SyntaxError there. The decorator is defined
# under a legal name, and the historical name is kept as a module attribute
# so existing ``@async`` / ``from ... import async`` callers still work.
globals()['async'] = run_async
class TestFuture(unittest.TestCase):
    """Exercise the ``async`` decorator and the ``Async`` futures it returns."""

    def test_async(self):
        # ``async`` is a reserved keyword on Python 3.7+, so the decorator is
        # looked up by name rather than written as ``@async``.
        async_ = globals()['async']

        @async_
        def foo(n):
            time.sleep(0.2)
            return n + n

        @async_
        def bar(n):
            time.sleep(0.2)
            return n * n

        @async_
        def bazz():
            raise Exception('oops')

        start = time.time()
        foo_future = foo(1)
        bar_future = bar(1)
        bazz_future = bazz()

        # Both access styles -- calling the future and reading ``.result`` --
        # must yield the target's return value.
        self.assertEqual(foo_future(), 2)
        self.assertEqual(bar_future.result, 1)

        # The exception raised in the worker thread surfaces in this thread.
        with self.assertRaises(Exception) as ctx:
            bazz_future()
        self.assertEqual(str(ctx.exception), 'oops')

        # foo and bar each sleep 0.2s; run one after the other they would take
        # at least 0.4s, so a shorter total shows the threads overlapped.
        elapsed = time.time() - start
        self.assertLess(elapsed, 0.4)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment