Created
December 5, 2018 14:41
-
-
Save sgryjp/f408b9d1028bf7649cfe291fed785a0a to your computer and use it in GitHub Desktop.
Stopwatch to measure code execution time inside a `with` statement.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import statistics | |
class Stopwatch(object):
    """Stopwatch to measure code execution time.

    A Stopwatch object measures the execution time of code inside one or
    more `with` statements. Reusing a single object stores one timing (a
    "lap") for every `with` block it guards.

    If a callable is given as the `print_fn` parameter, it is called with
    the stringified Stopwatch instance (i.e. `str(stopwatch)`) on each exit
    from a `with` statement. Typical suitable functions are `print` or
    `logging.debug`.

    Parameters
    ----------
    print_fn : callable [optional]
        Called with `str(self)` every time a `with` block exits. Nothing
        is called when it is None (default) or otherwise falsy.

    Attributes
    ----------
    intervals : list of float
        Elapsed time, in seconds, of every `with` block exited so far,
        in the order the blocks finished.

    Example
    -------
    >>> sw = Stopwatch()
    >>> for _ in range(3):
    ...     with sw:
    ...         total = sum(n for n in range(1_000_000))
    >>> str(sw)  # doctest: +SKIP
    '607 ms ± 8.82 ms per laps (mean ± std. dev. of 3 laps)'
    >>> sw.mean(), sw.stdev()  # doctest: +SKIP
    (0.6070484847993166, 0.00882254275134833)
    >>> sw.intervals  # doctest: +SKIP
    [0.6147285757252803, 0.6090049394119035, 0.5974119392607662]
    """

    # Multipliers converting seconds into each supported unit. Any unit
    # not listed here (including 's') is treated as seconds.
    _UNIT_FACTORS = {'ms': 10**3, 'us': 10**6, 'ns': 10**9}

    def __init__(self, print_fn=None):
        self._t0 = 0.
        self.intervals = []
        self.print_fn = print_fn

    def __enter__(self):
        """Start timing a lap.

        Returns
        -------
        self : Stopwatch
            The stopwatch itself, so that `with Stopwatch() as sw:` binds
            a usable object to `sw`.
        """
        self._t0 = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Record the elapsed time of the lap and report it if requested.

        The lap is recorded whether or not the `with` body raised. The
        method returns None, so an exception raised inside the `with`
        block is not suppressed.
        """
        # Read the clock first so the bookkeeping below is not measured.
        elapsed = time.perf_counter() - self._t0
        self.intervals.append(elapsed)
        if self.print_fn:
            self.print_fn(str(self))

    def __str__(self):
        """Summarize the laps as 'mean ± std. dev.' in a readable unit.

        The unit is the largest of s, ms, μs and ns in which the mean is
        at least 1. When the mean is NaN (no lap recorded yet) every
        comparison fails and 'ns' is used.
        """
        laps = len(self.intervals)
        mean, stdev = self.mean(), self.stdev()
        for threshold, k, unit in ((1, 1, 's'),
                                   (1e-3, 10**3, 'ms'),
                                   (1e-6, 10**6, 'μs')):
            if mean >= threshold:
                break
        else:
            # Sub-microsecond means and NaN both end up here.
            k, unit = 10**9, 'ns'
        return ('{mean:.3g} {unit} ± {stdev:.3g} {unit} per laps'
                ' (mean ± std. dev. of {laps} laps)'.format(
                    mean=mean * k, stdev=stdev * k, unit=unit, laps=laps))

    def mean(self, unit='s'):
        """Calculate mean of the intervals.

        Parameters
        ----------
        unit : str [optional]
            Time unit. Allowed values are: 'ms', 'us', 'ns' and 's'
            (default). Any other value is treated as 's'.

        Returns
        -------
        mean : float
            Mean of the measured intervals, or NaN if no intervals are
            measured.
        """
        try:
            seconds = statistics.mean(self.intervals)
        except statistics.StatisticsError:
            return float('nan')
        return seconds * self._UNIT_FACTORS.get(unit, 1)

    def stdev(self, unit='s'):
        """Calculate standard deviation of the intervals.

        Parameters
        ----------
        unit : str [optional]
            Time unit. Allowed values are: 'ms', 'us', 'ns' and 's'
            (default). Any other value is treated as 's'.

        Returns
        -------
        stdev : float
            Sample standard deviation of the measured intervals, or NaN if
            fewer than two intervals are measured.
        """
        try:
            seconds = statistics.stdev(self.intervals)
        except statistics.StatisticsError:
            return float('nan')
        return seconds * self._UNIT_FACTORS.get(unit, 1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment