-
-
Save yinhm/4084683 to your computer and use it in GitHub Desktop.
Rough draft, datafeed adapter for zipline
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
# | |
# Copyright 2012 fawce@zipline | |
# Copyright 2012 yinhm | |
""" | |
Generator-style datafeed adapter for zipline.
Based on https://gist.github.com/4057021 | |
""" | |
import datetime | |
import pytz | |
from zipline import ndict | |
from zipline.gens.utils import hash_args, \ | |
assert_trade_protocol | |
import zipline.protocol as zp | |
class ZiplineAdapter(object):
    """Iterator that turns datafeed day bars into zipline trade events.

    For every requested symbol the adapter calls ``client.get_day`` and
    keeps the rows whose timestamp lies between ``start_date`` and
    ``end_date`` (both inclusive). Each kept row is yielded as an
    ``ndict`` tagged with the TRADE datasource type and a ``source_id``
    string that downstream code can use to break sorting ties.
    """

    def __init__(self, client, symbols, start_date, end_date):
        """Validate the arguments and remember the query window.

        client     -- object exposing ``get_day(symbol, count)``.
        symbols    -- list of symbols to fetch.
        start_date -- datetime whose tzinfo is ``pytz.utc``; earlier
                      rows are skipped.
        end_date   -- datetime whose tzinfo is ``pytz.utc``; later
                      rows are dropped.

        Raises AssertionError when an argument has the wrong type or a
        date is not tagged with ``pytz.utc``.
        """
        assert isinstance(symbols, list)
        assert isinstance(start_date, datetime.datetime)
        assert isinstance(end_date, datetime.datetime)
        assert start_date.tzinfo == pytz.utc
        assert end_date.tzinfo == pytz.utc

        self.client = client
        self.symbols = frozenset(symbols)
        self.start_date = start_date
        self.end_date = end_date

        # Create unique identifier string that can be used to break
        # sorting ties deterministically.
        # NOTE(review): the symbols are not part of the hashed arguments,
        # so adapters that differ only in their symbols share an id --
        # confirm whether that is intended.
        self.argstring = hash_args(str(client), start_date, end_date)
        self.namestring = self.__class__.__name__ + self.argstring

        # The event generator is created lazily on the first next() call.
        self.iterator = None

    def __iter__(self):
        """Return the adapter itself; it is its own iterator."""
        return self

    def next(self):
        """Return the next event, creating the generator on first use.

        Raises StopIteration once the underlying generator is exhausted.
        """
        if self.iterator is None:
            self.iterator = self._gen()
        # The builtin next() works on both Python 2.6+ and Python 3,
        # unlike the Python 2-only ``.next()`` method.
        return next(self.iterator)

    # Python 3 spells the iterator protocol method ``__next__``.
    __next__ = next

    def rewind(self):
        """Restart iteration from the beginning with a fresh generator."""
        self.iterator = self._gen()

    def get_hash(self):
        """Return the identifier string stamped on events as source_id."""
        return self.namestring

    def _gen(self):
        """Yield protocol-checked ``ndict`` events built from raw rows."""
        # Set up internal iterator. This outputs raw dictionaries.
        raw_events = self._create_iterator(
            self.symbols,
            self.start_date,
            self.end_date
        )

        for event in raw_events:
            # Construct a new event that fulfills the datasource protocol.
            event['type'] = zp.DATASOURCE_TYPE.TRADE
            event['source_id'] = self.namestring
            payload = ndict(event)
            assert_trade_protocol(payload)
            yield payload

    def _create_iterator(self, symbols, start_date, end_date):
        """
        Returns an iterator that spits out raw objects retrieved from
        datafeed server.

        Each yielded dict holds the record fields returned by
        ``get_day`` (minus ``time`` and ``amount``) plus ``sid``, ``dt``
        and ``price``.
        """
        for symbol in symbols:
            # @FIXME: get_day not support date range yet.
            rows = self.client.get_day(symbol, 1000)
            names = rows.dtype.names

            for row in rows:
                data = dict(zip(names, row.tolist()))

                # @FIXME: int sid???
                # Zipline specific: integer sid?? Every symbol currently
                # gets the same placeholder id.
                data['sid'] = 123

                # NOTE(review): fromtimestamp() without a tz argument
                # converts to local time, and that value is then labelled
                # UTC -- confirm this matches how the datafeed server
                # encodes 'time' (assumed to be seconds since the epoch).
                data['dt'] = datetime.datetime.fromtimestamp(
                    data['time']).replace(tzinfo=pytz.utc)

                if data['dt'] < start_date:
                    continue
                if data['dt'] > end_date:
                    # Stop reading this symbol and move on to the next.
                    # ``raise StopIteration`` here would abort the whole
                    # stream and is a RuntimeError under PEP 479.
                    # Assumes rows arrive in ascending time order, as the
                    # original early exit did.
                    break

                # Zipline specific
                data['price'] = data['close']
                # Zipline specific: integer Volume
                data['volume'] = int(data['volume'])

                # Drop the fields that are not passed downstream.
                del data['time']
                data.pop('amount', None)

                yield data
if __name__ == '__main__':
    # Demo: stream the events of the 100 days before now for one symbol
    # from a datafeed server and print them.
    from datafeed.client import Client

    client = Client()
    symbols = ['SH600036']

    # The adapter asserts that both dates carry pytz.utc as tzinfo.
    date_end = datetime.datetime.now()
    date_end = date_end.replace(tzinfo=pytz.utc)
    date_start = date_end - datetime.timedelta(days=100)

    zdata = ZiplineAdapter(client, symbols, date_start, date_end)
    for row in zdata:
        # The parenthesised form prints the same on Python 2 and is
        # valid syntax on Python 3.
        print(row)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment