Last active
April 10, 2018 21:14
-
-
Save amitu/f6241ed802afa3312713efd03ef10509 to your computer and use it in GitHub Desktop.
Acko's API framework
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
# License: BSD | |
import json | |
import os | |
import re | |
import time | |
from django import forms | |
from django.conf import settings | |
from django.contrib.auth import authenticate, get_user_model, login | |
from django.contrib.postgres.forms import JSONField, SimpleArrayField | |
from django.core.exceptions import ObjectDoesNotExist | |
from django.core.paginator import EmptyPage, Paginator | |
from django.db.models import Model, Q | |
from django.http import Http404, JsonResponse | |
from django.shortcuts import render | |
from django.template.loader import render_to_string | |
import jsonschema | |
import structlog | |
from encrypted_id.models import EncryptedIDDecodeError, EncryptedIDModel | |
import r2d2.utils as r2d2_utils | |
from acko import constants, helpers, utils | |
from acko.models import Quote as QuoteModel | |
from acko.models import Asset, Policy | |
from acko.utils import JSONEncoder, get_point_from_latlong | |
from masters.models import RTO, Pincode, Variant | |
from pas import pas_cloud as pas | |
from pas.api import exceptions as pas_exceptions | |
from users.models import Phone, User, UserProfile | |
from users.utils import sanitize_phone | |
# Module-level structlog logger for this API framework.
logger = structlog.get_logger()
class ApiJsonEncoder(JSONEncoder):
    """
    JSON encoder used for API payloads.

    Encrypted-id models are encoded as their ``ekey``, any other Django
    model as its primary key; everything else goes to the parent encoder.
    """

    def default(self, obj):
        if isinstance(obj, EncryptedIDModel):
            encoded = obj.ekey
        elif isinstance(obj, Model):
            encoded = obj.pk
        else:
            encoded = super().default(obj)
        return encoded
class IntegerField(forms.IntegerField):
    """
    Integer form field that falls back to a default for missing input.

    If some integer field exists in the schema but is not passed by the
    client in the request, the cleaned value becomes ``default`` (0 unless
    overridden). To override the default for some field, add the following
    to your API subclass:

        class MyAPI(api.API):
            FIELD_MAPPING = {
                'answer': (api.IntegerField, {"default": 42})
            }
    """

    def __init__(self, default=0, *args, **kw):
        self._default = default
        super().__init__(*args, **kw)

    def to_python(self, value):
        converted = super().to_python(value=value)
        # Django returns None for empty input; substitute the default.
        return self._default if converted is None else converted
class CharField(forms.CharField):
    """Char form field whose cleaned value is never None (None becomes "")."""

    def to_python(self, value):
        text = super().to_python(value=value)
        # TODO: this will screw with django's form validation framework a
        # bit; have to verify what exactly the problem is and what the
        # solution is.
        return "" if text is None else text
class EkeyField(forms.ModelChoiceField):
    """
    Model choice field whose client-facing value is an encrypted id (ekey).

    The ekey is decoded with ``<model>.objects.get_by_ekey`` and the
    resulting object must also be part of ``queryset``; otherwise an
    ``invalid_choice`` validation error is raised.
    """

    def __init__(self, queryset, *args, **kw):
        self._model = queryset.model
        super().__init__(queryset, *args, **kw)

    def to_python(self, value):
        """
        Return the model instance identified by ``value``.

        :returns: None for empty input, else the matching instance.
        :raises ValidationError: the ekey cannot be decoded, no such object
            exists, or the object is excluded by ``self.queryset``.
        """
        if value in self.empty_values:
            return None
        model = self.queryset.model
        try:
            obj = model.objects.get_by_ekey(value)
        except (model.DoesNotExist, EncryptedIDDecodeError):
            obj = None
        # get_by_ekey goes through the model's default manager, so it ignores
        # any filtering applied to self.queryset. Enforce the queryset
        # explicitly, as ModelChoiceField does for plain primary keys.
        if obj is None or not self.queryset.filter(pk=obj.pk).exists():
            raise ValidationError(
                str(self.error_messages['invalid_choice']),
                code='invalid_choice'
            )
        return obj
class LatLongField(forms.CharField):
    """
    Char field whose cleaned value is the result of
    ``get_point_from_latlong``; text the helper rejects with ``ValueError``
    cleans to None.
    """

    def to_python(self, value):
        raw = super().to_python(value)
        try:
            point = get_point_from_latlong(raw)
        except ValueError:
            # Unparseable coordinates are treated as "no value".
            point = None
        return point
# Project-wide defaults consulted by API._get_field: schema type (or field
# name) -> (form field class, constructor kwargs).
settings.FIELD_MAPPING.update(dict(
    string=(CharField, {}),
    integer=(IntegerField, {}),
    object=(JSONField, {}),
    boolean=(forms.BooleanField, {}),
    latlong=(LatLongField, {}),
    user_id=(EkeyField, {'queryset': get_user_model().objects.all()}),
))

# Error code -> error class; ErrorMeta adds every class it creates.
error_registry = dict()
# noinspection PyInitNewSignature,PyMethodParameters
class ErrorMeta(type):
    """
    Metaclass giving each error class a ``_code`` and recording it in
    ``error_registry``.

    The code is ``<module>.<ClassName>``, with a trailing ``.api`` module
    component dropped: a class ``Foo`` defined in module ``acko.api`` gets
    the code ``acko.Foo``. A later class with the same code replaces the
    earlier registry entry.
    """

    def __new__(cls, clsname, bases, clsdict):
        new = super().__new__(cls, clsname, bases, clsdict)
        # type.__new__ always sets __module__, even when clsdict lacks it.
        module = new.__module__
        suffix = ".api"
        # Only strip ".api" when it is actually there; blindly slicing off
        # four characters mangled codes for classes declared elsewhere
        # (e.g. "users.errors" -> "users.er").
        if module.endswith(suffix):
            module = module[:-len(suffix)]
        new._code = module + "." + clsname
        error_registry[new._code] = new
        return new
# noinspection PyMethodMayBeStatic
class APIError(forms.ValidationError, metaclass=ErrorMeta):
    """
    Base class for errors reported back to API clients.

    ``error()`` serialises the error as ``{"human", "code", "context"}``:
    the message, the class-level ``_code`` assigned by ``ErrorMeta``, and the
    extra keyword arguments given to the constructor. ``field`` is the form
    field the error is reported under.

    Add template = "name of template" if you want to overwrite the template
    to be used for this error. Consider overriding .context() if it makes
    sense.
    """

    def __init__(self, message, code=None, params=None, field=None, **context):
        self.human, self._context, self._field = message, context, field
        super().__init__(message, code, params)

    def error(self):
        """Return the JSON-serialisable representation of this error."""
        return dict(human=self.human, code=self._code, context=self._context)

    @classmethod
    def context(cls, request, human, code, context):
        """
        Build a template context for this error.

        Override to embed extra information, e.g. data fetched from the
        database that helps the user understand what went wrong, or which
        other values they can try.
        """
        return dict(human=human, code=code, context=context, request=request)
class APIGone(object):
    """
    Marker base class for removed APIs.

    When version lookup reaches a spec whose class derives from this, the
    API is reported as not found.
    """
class SchemaError(APIError):
    """
    Reported when cleaned input fails JSON-schema validation.

    Schema errors are treated as "code bugs": it is not the customer's
    fault, it is the developer's.
    """
class DjValidationError(APIError):
    """
    Wrapper for plain Django validation errors.

    Normally this is not used directly. It only appears when some stock
    Django field raises a validation error; all ``.clean_xxx()`` methods
    should raise more specialised errors instead.
    """
class ValidationError(APIError):
    """Generic form validation error, raised by our ``.clean_xxx()`` methods."""
def camel_to_snake(camel):
    """
    Convert a CamelCase identifier to snake_case (``UserInfo`` -> ``user_info``).

    FIXME: runs of capitals split awkwardly: CamelCaseURLx -> camel_case_ur_lx
    """
    # Pass 1: underscore before every Capitalised word.
    words_split = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', camel)
    # Pass 2: underscore between a lowercase letter/digit and a capital.
    boundaries_split = re.sub('([a-z0-9])([A-Z])', r'\1_\2', words_split)
    return boundaries_split.lower()
def snake_to_camel(snake):
    """Convert snake_case to CamelCase by title-casing each ``_``-separated word."""
    return ''.join(map(str.title, snake.split('_')))
class APIRegistry(object):
    """
    Registry of API form classes keyed by ``(name, version)``.

    Instances are also used as the Django view for API URLs (see
    ``__call__``).
    """

    # Django's CsrfViewMiddleware skips view callables with csrf_exempt set.
    csrf_exempt = True

    def __init__(self):
        self._registry = {}

    def register(self, mod, cls, name, v):
        """Register ``cls`` as version ``v`` of API ``name`` (from app ``mod``)."""
        key = (name, v)
        self._registry[key] = APISpec(mod, cls, name, v)

    def handle_one(self, request, method, name, version, data):
        """
        Validate ``data`` against API ``name`` and run it.

        ``request`` MUST only be used for auth etc., not for data.

        :return: ``(result, cookies)``. ``cookies`` is a list of
            ``(args, kwargs)`` pairs for ``set_cookie``. ``result`` is the
            success/error envelope dict, except for a GET carrying
            ``__doc__``, where it is the rendered documentation response.
        :raises Http404: no (non-gone) spec exists for ``name``/``version``.
        """
        spec = self._find_spec(name, version)
        if spec is None:
            raise Http404
        form = spec.cls(request, spec, data)

        if method == "GET" and "__doc__" in data:
            # TODO
            # Keep the (result, cookies) shape every caller unpacks.
            return (
                render(request, "acko/api.html", {"form": form}),
                form.cookies,
            )

        # We consider this request an API request if either a. it is POST or
        # b. it is GET but __doc__ is not passed.
        if not form.is_valid():
            d = self._form_errors(form, method, name, version, data)
            try:
                return spec.error(**d), form.cookies
            except jsonschema.ValidationError as e:
                # NOTE: if validator fails here it is a BUG
                logger.error(
                    "api.error_schema_mismatch", error=str(e), data=d,
                    schema=spec._error_schema,
                )
                raise

        try:
            # TODO remove the round-trip
            # Don't modify the data in clean_xxx (except basic type casting).
            # If any change is required, add it as a property.
            # Directly validate the cleaned_data.
            # Example: self.cleaned_pincode should be the model instance,
            # and self.cleaned_data['pincode'] should be the pincode string.
            form.serialized_data = json.loads(
                json.dumps(form.cleaned_data, cls=ApiJsonEncoder)
            )
            spec.input(form.serialized_data)
        except jsonschema.ValidationError as e:
            field, error = self._schema_error_location(e)
            errors = {
                field: [
                    SchemaError(
                        error,
                        cleaned_data=form.cleaned_data,
                        schema=spec.input_schema,
                        method=method,
                        name=name,
                        version=version,
                        data=data,
                    ).error()
                ]
            }
            return spec.error(**errors), form.cookies

        if "validate_only" in data:
            # Same (result, cookies) shape as every other path; a bare dict
            # would be unpacked by the caller into its two keys.
            return {"success": True, "result": None}, form.cookies

        try:
            d = form.save()
        except APIError as e:
            # APIError always sets _field, but it defaults to None, which is
            # not a valid keyword for spec.error(**...).
            field = getattr(e, '_field', None) or '__all__'
            logger.info("api.save_error", error=str(e), field=field)
            return spec.error(**{field: [e.error()]}), form.cookies

        j = form.json(d)
        try:
            return spec.success(j), form.cookies
        except jsonschema.ValidationError as e:
            # NOTE: if validator fails here it is a BUG
            logger.error(
                "api.success_schema_mismatch", error=str(e), data=j,
                schema=spec._success_schema,
            )
            raise

    @staticmethod
    def _form_errors(form, method, name, version, data):
        """Convert ``form.errors`` into ``{field: [error_dict, ...]}``."""
        d = {}
        for k, v in form.errors.as_data().items():
            errors = []
            for e in v:
                if isinstance(e, APIError):
                    # TODO: convert it to new style error
                    errors.append(e.error())
                    continue
                # A plain Django error may carry several messages; wrap each
                # so the response format stays uniform.
                for message in e.messages:
                    errors.append(
                        DjValidationError(
                            str(message),
                            key=k,
                            cleaned_data=form.cleaned_data,
                            method=method,
                            name=name,
                            version=version,
                            data=data,
                            error_class=e.__class__.__name__,
                        ).error()
                    )
            d[k] = errors
        return d

    @staticmethod
    def _schema_error_location(e):
        """
        Return ``(field, message)`` for a ``jsonschema.ValidationError``.

        The first path element names the field. Any deeper path is prefixed
        to the message (``"a -> 0: <message>"``). Errors without a path are
        reported under ``validator_value[1]`` when ``validator_value`` is a
        sequence starting with ``'name'``, else under ``'__all__'``.
        """
        error = e.message
        if e.path:
            field = e.path.popleft()
            if e.path:
                error = "%s: %s" % (' -> '.join(map(str, e.path)), e.message)
            return field, error
        value = e.validator_value
        # validator_value may be a scalar (int, bool, ...) for many
        # validators; only index into real sequences.
        if (isinstance(value, (list, tuple)) and len(value) > 1
                and value[0] == 'name'):
            return value[1], error
        return '__all__', error

    def __call__(self, request, api, v=settings.API_VERSION, internal=False):
        """
        View entry point.

        :param request: incoming request; ``request.data`` is used as the
            payload.
        :param api: API name, or ``"bulk"`` to run several APIs at once.
        :param v: requested API version.
        :param internal: Flag if API is called internally by some other API.
        :return: dict if internal else JsonResponse. The rendered
            documentation response is returned as-is for a GET with
            ``__doc__``.
        """
        from django.http import HttpResponse

        v = int(v)
        if api == "bulk":
            # in case of bulk, data looks like this:
            #
            # {
            #     "one": {"name": "api", "method": "GET/POST", "data": {}}
            #     "two": {"name": "api", "method": "GET/POST", "data": {}}
            # }
            #
            # and the response must look like:
            #
            # {
            #     "one": {"success": True, "result": "foo"}
            #     "two": {
            #         "success": False, "errors": {"__all__": ["yo"]}
            #     }
            # }
            logger.debug("api.bulk", data=request.data)
            result = {}
            cookies = []
            for key, payload in request.data.items():
                result[key], api_cookies = self.handle_one(
                    request, payload["method"], payload["name"], v,
                    payload["data"])
                cookies.extend(api_cookies)
        else:
            result, cookies = self.handle_one(
                request, request.method, api, v, request.data)
            if isinstance(result, HttpResponse):
                # The __doc__ page is already a complete response.
                return result
        if internal:
            return result
        response = JsonResponse(result, encoder=ApiJsonEncoder)
        for args, kw in cookies:
            response.set_cookie(*args, **kw)
        return response

    def _find_spec(self, name, v):
        """
        Return the newest spec for ``name`` at or below version ``v``
        (capped at ``settings.API_VERSION``). Return None if there is none,
        or if the newest one found is an ``APIGone`` marker.
        """
        v = min(v, settings.API_VERSION)
        for version in range(v, 0, -1):
            spec = self._registry.get((name, version))
            if spec is None:
                continue
            if issubclass(spec.cls, APIGone):
                return None
            return spec
        return None
# Module-level registry: register_api() adds API classes to it, and it is
# also mounted as the view for API URLs.
registry = APIRegistry()
# in urls.py we have: surl('/api/v<int:v>/<something:api>/', api.registry),
class APISpec(object):
    """
    JSON schemas for one version of one API.

    Loads ``<name>_input.json``, ``<name>_success.json`` and
    ``<name>_error.json`` from ``<BASE_DIR>/../schema/v<version>/``. It
    validates cleaned input and response envelopes against them.
    """

    def __init__(self, mod, cls, name, v):
        self.api = name
        self.cls = cls
        self.module = mod
        self.version = v
        self._input_schema = self._read_schema("%s_input" % name, v)
        self._success_schema = self._read_schema("%s_success" % name, v)
        self._error_schema = self._read_schema("%s_error" % name, v)

    @property
    def input_schema(self):
        """Schema the serialised cleaned input must satisfy."""
        return self._input_schema

    @property
    def success_schema(self):
        """Schema for ``{"success": True, "result": ...}`` envelopes."""
        return self._success_schema

    @property
    def error_schema(self):
        """Schema for ``{"success": False, "errors": ...}`` envelopes."""
        return self._error_schema

    def input(self, data):
        """
        Validate ``data`` against the input schema and return it unchanged.

        :raises jsonschema.ValidationError: ``data`` does not match.
        """
        return self._validate(data, self._input_schema)

    def error(self, **errors):
        """
        Build and validate an error envelope.

        Keywords are field names (``'__all__'`` for form-wide errors). A bare
        string value is wrapped into a one-element list.
        """
        errors = {
            k: [v] if isinstance(v, str) else v for k, v in errors.items()
        }
        return self._validate(
            {"success": False, "errors": errors}, self._error_schema
        )

    def success(self, result):
        """Build and validate a success envelope around ``result``."""
        return self._validate(
            {"success": True, "result": result}, self._success_schema
        )

    @classmethod
    def _read_schema(cls, name, version):
        """Load ``schema/v<version>/<name>.json``; an empty file yields {}."""
        path = os.path.join(
            settings.BASE_DIR, "../schema/v%d/%s.json" % (version, name)
        )
        # JSON text is UTF-8; don't depend on the platform locale encoding.
        with open(path, encoding="utf-8") as f:
            text = f.read()
        if not text:
            return {}
        schema = json.loads(text)
        cls._remove_empty_required(schema)
        return schema

    @staticmethod
    def _validate(obj, schema):
        """Validate ``obj`` against ``schema`` (with format checks); return it."""
        jsonschema.validate(
            obj, schema, format_checker=jsonschema.FormatChecker()
        )
        return obj

    @classmethod
    def _remove_empty_required(cls, schema):
        """
        Recursively delete empty ``"required"`` entries from ``schema`` in place.

        If all the fields in an object are optional, the required array in
        the schema will be empty, and jsonschema then raises an error. This
        walks every sub-schema position: ``properties``,
        ``patternProperties``, ``definitions``, ``items`` (single or list),
        ``additionalItems``, ``additionalProperties``, ``not``, ``anyOf``,
        ``oneOf`` and ``allOf``. Data-valued keywords such as ``default`` or
        ``enum`` are left untouched. Non-dict values (e.g. boolean schemas)
        are ignored.
        """
        if not isinstance(schema, dict):
            return
        if 'required' in schema and not schema['required']:
            del schema['required']

        subschemas = []
        for keyword in ('properties', 'patternProperties', 'definitions'):
            mapping = schema.get(keyword)
            if isinstance(mapping, dict):
                subschemas.extend(mapping.values())
        for keyword in ('items', 'additionalItems', 'additionalProperties',
                        'not'):
            sub = schema.get(keyword)
            subschemas.extend(sub if isinstance(sub, list) else [sub])
        for keyword in ('anyOf', 'oneOf', 'allOf'):
            sub = schema.get(keyword)
            if isinstance(sub, list):
                subschemas.extend(sub)

        for sub in subschemas:
            cls._remove_empty_required(sub)

    def __repr__(self):
        return str(self.__dict__)
# noinspection PyMethodMayBeStatic
class API(forms.Form):
    """
    Base class for API endpoints.

    Form fields are generated from the spec's input JSON schema: one field
    per entry of ``properties``, required when listed in ``required``. The
    field class is looked up in this class's ``FIELD_MAPPING`` (by field
    name, then by schema type), then in ``settings.FIELD_MAPPING`` (same
    order), defaulting to ``forms.CharField``.
    """

    FIELD_MAPPING = {}
    # if PERMS is not empty, the current user must have all of these
    # permissions; checked in clean(), so subclasses overriding clean() must
    # call super().clean().
    PERMS = []

    def __init__(self, request, spec, *args, **kw):
        self.request = request
        self.spec = spec
        # (args, kwargs) pairs queued by set_cookie().
        self.cookies = []
        super().__init__(*args, **kw)
        schema = spec.input_schema
        required = set(schema.get("required", []))
        # An empty schema file is loaded as {}; treat it as "no fields".
        for name, field in schema.get('properties', {}).items():
            self.fields[name] = self._get_field(name, field, name in required)

    def clean(self):
        """Run the regular form-wide cleaning, then enforce ``PERMS``."""
        cleaned_data = super().clean()
        if self.PERMS and not self.request.user.has_perms(self.PERMS):
            raise ValidationError(
                message="You do not have permission to access this API."
            )
        return cleaned_data

    def set_cookie(self, *args, **kw):
        """Queue a cookie; the arguments are later passed to ``set_cookie``."""
        self.cookies.append((args, kw))

    def d(self, name, default=None):
        """Return ``cleaned_data[name]``, or ``default`` if missing or None."""
        val = self.cleaned_data.get(name)
        if val is None:
            return default
        return val

    def i(self, name, default=0):
        """Return ``cleaned_data[name]`` as int; falsy values give ``default``."""
        return int(self.cleaned_data.get(name) or default)

    @staticmethod
    def _schema_type(field):
        """
        Return the schema type name used to pick a form field class.

        TODO -- current assumptions:
        * it's a oneOf / allOf / anyOf field if "type" is not a key
        * first entry of the oneOf / allOf / anyOf is the real field
        * oneOf / anyOf: remaining entries are placeholders for null / blank

        :raises ValueError: no type can be determined.
        """
        if "type" in field:
            tipe = field["type"]
        else:
            for keyword in ('anyOf', 'oneOf', 'allOf'):
                if keyword in field:
                    tipe = field[keyword][0]["type"]
                    break
            else:
                raise ValueError("Unknown field: %s" % field)
        # Nullable fields may declare e.g. "type": ["string", "null"]; a list
        # is unhashable, so use the first non-null type for the lookup.
        if isinstance(tipe, list):
            tipe = next((t for t in tipe if t != "null"), "null")
        return tipe

    def _get_field(self, name, field, required):
        """
        Instantiate the form field for schema property ``name``.

        Lookup order: ``self.FIELD_MAPPING`` by name, then by type, then
        ``settings.FIELD_MAPPING`` by name, then by type, then
        ``forms.CharField``. ``minLength``/``maxLength`` become
        ``min_length``/``max_length``.
        """
        tipe = self._schema_type(field)
        for mapping, key in (
            (self.FIELD_MAPPING, name),
            (self.FIELD_MAPPING, tipe),
            (settings.FIELD_MAPPING, name),
            (settings.FIELD_MAPPING, tipe),
        ):
            if key in mapping:
                cls, kw = mapping[key]
                break
        else:
            cls, kw = forms.CharField, {}

        kw = dict(kw, required=required)
        if "minLength" in field:
            kw["min_length"] = field["minLength"]
        if "maxLength" in field:
            kw["max_length"] = field["maxLength"]
        # TODO: add a regex validator for "pattern"; it is currently only
        # checked by the JSON-schema validation of the serialised input.
        return cls(**kw)
class GETyAPI(API):
    """Read-only API: ``save()`` performs no action and returns None."""

    def save(self):
        return None
class AnonymousUser(APIError):
    """Raised by AGETyAPI / SGETyAPI when the request's user may not use the API."""
class AGETyAPI(GETyAPI):
    """Read-only API that requires an authenticated user."""

    def clean(self):
        super().clean()
        is_authenticated = self.request.user.is_authenticated
        # Django >= 1.10 exposes is_authenticated as an attribute (a plain
        # bool since 2.0, where calling it raises TypeError); older versions
        # define it as a method. Support both.
        if callable(is_authenticated):
            is_authenticated = is_authenticated()
        if not is_authenticated:
            raise AnonymousUser(message="No user is logged in.")
class SGETyAPI(AGETyAPI):
    """Read-only API restricted to logged-in staff users."""

    def clean(self):
        super().clean()
        user = self.request.user
        if user.is_staff:
            return
        raise AnonymousUser(message="Only staff can access this API.")
class ListAPI(GETyAPI):
    """
    Read-only API whose result is a list of serialised objects.

    Subclasses provide ``model`` (or override ``object_list()``) and
    ``obj2json(obj)``; neither ``model`` nor ``obj2json`` is defined here.
    """

    def object_list(self):
        return self.model.objects.all()

    def json(self, _):
        return list(map(self.obj2json, self.object_list()))
# noinspection PyMethodMayBeStatic
class PaginatedAPI(ListAPI):
    """
    List API returning one page of results plus pagination metadata.

    Reads ``page``, ``per_page`` and ``orphans`` from the cleaned data.
    """

    # Page size used when the client sends no (or a zero) ``per_page``.
    DEFAULT_PER_PAGE = 25
    # Upper bound on ``per_page``; None means unbounded. Set this in a
    # subclass to stop clients from requesting arbitrarily large pages.
    MAX_PER_PAGE = None

    def extra(self):
        """Hook for additional response data, returned under ``"extra"``."""
        return None

    def _per_page(self):
        """Return the client's page size, clamped to [1, MAX_PER_PAGE]."""
        # A negative page size used to reach the queryset as a negative
        # slice, which Django rejects.
        per_page = max(1, self.i("per_page", self.DEFAULT_PER_PAGE))
        if self.MAX_PER_PAGE is not None:
            per_page = min(per_page, self.MAX_PER_PAGE)
        return per_page

    def json(self, _):
        pager = Paginator(
            self.object_list(), self._per_page(),
            # Negative orphans are meaningless; treat them as 0.
            orphans=max(0, self.i("orphans", 0)),
            allow_empty_first_page=True,
        )
        try:
            page = pager.page(self.i("page", 1))
        except EmptyPage:
            # If page is out of range (e.g. 9999), deliver last page of
            # results
            page = pager.page(pager.num_pages)
        return {
            "object_list": [self.obj2json(o) for o in page.object_list],
            "num_pages": pager.num_pages, "page": page.number,
            "has_next": page.has_next(),
            "has_previous": page.has_previous(),
            "has_other_pages": page.has_other_pages(),
            "next_page_number": (
                page.next_page_number() if page.has_next() else 0
            ),
            "previous_page_number": (
                page.previous_page_number() if page.has_previous() else 0
            ),
            "start_index": page.start_index(),
            "end_index": page.end_index(),
            "extra": self.extra(),
        }
class FlatAPIError(APIError):
    """Raised by FlatAPI when the wrapped API call reports errors."""
class FlatAPI(GETyAPI):
    """
    Adapter exposing an API with nested input through flat form fields.

    Field names use ``__`` for nesting: ``a__b`` becomes ``{"a": {"b": ...}}``.
    The resulting tree is passed to the API named by ``api`` through
    ``helpers.call_api``.
    """

    api = None
    tree_data = None
    result = None

    def clean(self):
        super().clean()
        tree = {}
        for flat_key, value in self.cleaned_data.items():
            *parents, leaf = flat_key.split('__')
            node = tree
            for part in parents:
                node = node.setdefault(part, {})
            node[leaf] = value
        self.tree_data = tree

    def save(self):
        self.call_api(self.tree_data)

    def call_api(self, data):
        """
        Call ``self.api`` with ``data`` and store its result.

        :raises FlatAPIError: the call failed. The error is raised under
            ``'__all__'`` with a JSON message mapping each field to its list
            of human-readable messages.
        """
        out = helpers.call_api(self.request, self.api, data)
        if not out['success']:
            errors = {
                field: [e['human'] for e in errs]
                for field, errs in out['errors'].items()
            }
            raise FlatAPIError(field='__all__', message=json.dumps(errors))
        self.result = out['result']

    def json(self, _):
        return self.result
def register_api(cls):
    """
    Class decorator registering ``cls`` in the module-level ``registry``.

    The API name is the snake_case form of the class name. If the class name
    ends with Vn, then n is assumed to be the version number (and ``Vn`` is
    not part of the name). Otherwise the class is assumed to implement a
    version 1 API. The name of the directory containing the calling module
    is recorded as the API's module.
    """
    import inspect

    version_split = re.split(r'(V\d+)$', cls.__name__)
    name = camel_to_snake(version_split[0])
    v = int(version_split[1][1:]) if len(version_split) == 3 else 1

    # inspect.stack() builds a FrameInfo (with source context read from
    # disk) for every frame on the stack; only the caller's file is needed.
    frame = inspect.currentframe()
    caller = None
    try:
        caller = frame.f_back if frame is not None else None
        if caller is not None:
            filename = caller.f_code.co_filename
        else:
            # currentframe() may return None on interpreters without
            # Python stack frame support.
            filename = inspect.stack(0)[1].filename
    finally:
        # Drop frame references promptly to avoid reference cycles.
        del frame, caller
    mod = filename.split(os.path.sep)[-2]
    registry.register(mod, cls, name, v)
    return cls
class UserDoesNotExist(APIError):
    """Raised when an ekey does not resolve to an existing user."""
@register_api
class UserInfo(GETyAPI):
    """Version 1 ``user_info`` API: describe the user identified by ``ekey``."""

    def __init__(self, *args, **kw):
        super().__init__(*args, **kw)
        # Set by clean_ekey() once the ekey has been resolved.
        self.user = None

    def clean_ekey(self):
        """
        Resolve the submitted ekey to a user.

        :raises UserDoesNotExist: the ekey cannot be decoded or matches no user.
        """
        # Forms have no ``ekey`` attribute; the value lives in cleaned_data.
        ekey = self.cleaned_data.get("ekey")
        try:
            self.user = User.objects.get_by_ekey(ekey)
        except (User.DoesNotExist, EncryptedIDDecodeError):
            raise UserDoesNotExist(message="UserDoesNotExist", ekey=ekey)
        return ekey

    def json(self, _=None):
        """
        Serialise ``self.user``.

        The argument receives the ``save()`` result passed in by the registry
        (``form.json(d)``) and is ignored.
        """
        return {
            "id": self.user.ekey,
            "informal": self.user.informal,
            "formal": self.user.formal,
            "phone": self.user.phone,
            "is_staff": self.user.is_staff,
            "is_superuser": self.user.is_superuser,
        }
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment