Created
February 26, 2014 06:16
-
-
Save DasIch/9224484 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Demo module for the ``nonlocal_`` import hook.  The ``nonlocal_('a')`` call
# below is a marker that the hook rewrites while this module is being
# imported; if the call is ever executed as-is, ``nonlocal_`` raises
# RuntimeError.
from nonlocal_ import nonlocal_


def foo():
    # ``a`` is bound here, in the enclosing function ...
    a = 1

    def bar():
        # ... and rebound from the nested function via the marker call.
        nonlocal_('a')
        a = 2

    bar()
    return a


# The rebinding performed inside ``bar`` must be visible in ``foo``; the
# second ``foo()`` call only supplies the value shown on failure.
assert foo() == 2, foo()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import nonlocal_ | |
import bar |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sys | |
import imp | |
import types | |
import ast | |
if sys.version_info[0] == 2:
    class NonLocalTransformer(ast.NodeTransformer):
        """Emulate ``nonlocal`` on Python 2.

        An expression statement of the form ``nonlocal_('name', ...)`` is
        replaced by ``pass``.  For every listed name, the closest enclosing
        function that assigns the name is rewritten with
        ``ListWrapTransformation`` so the value lives in a one-element list
        that nested functions can mutate.
        """

        def __init__(self):
            ast.NodeTransformer.__init__(self)
            # Function definitions currently being visited, outermost first.
            self.function_stack = []

        def visit_FunctionDef(self, node):
            """Visit the body of ``node`` while tracking it on the stack."""
            self.function_stack.append(node)
            for i, statement in enumerate(node.body):
                node.body[i] = self.visit(statement)
            # visit_Expr may have stored a rewritten node in this stack slot,
            # so return whatever is on the stack rather than ``node``.
            return self.function_stack.pop()

        def visit_Expr(self, node):
            """Replace a ``nonlocal_(...)`` statement with ``pass``.

            Any other expression statement is returned after its value has
            been visited.
            """
            node.value = self.generic_visit(node.value)
            call = node.value
            is_marker = (
                isinstance(call, ast.Call)
                and isinstance(call.func, ast.Name)
                and call.func.id == 'nonlocal_'
            )
            if not is_marker:
                return node
            for argument in call.args:
                found = self.find_defining_function(argument.s)
                if found is not None:
                    index, function_def = found
                    self.function_stack[index] = ListWrapTransformation(
                        argument.s
                    ).visit(function_def)
            return ast.Pass()

        def find_defining_function(self, name):
            """Locate the closest enclosing function that assigns ``name``.

            The innermost function on the stack (the one containing the
            ``nonlocal_`` call) is never considered, and the bodies of nested
            functions and classes are not searched because assignments there
            bind names in a different scope.

            Returns an ``(index, function_def)`` pair, where ``index`` is the
            position in ``self.function_stack``, or ``None`` when no
            enclosing function assigns the name.

            Raises ``NotImplementedError`` for an assignment target of an
            unsupported node type.
            """
            def _name_in_expression(expression):
                if isinstance(expression, (ast.Attribute, ast.Subscript)):
                    return _name_in_expression(expression.value)
                if isinstance(expression, ast.Name):
                    return expression.id == name
                if isinstance(expression, (ast.List, ast.Tuple)):
                    return any(map(_name_in_expression, expression.elts))
                # ``NotImplemented`` is a constant, not an exception class;
                # calling it would raise an unrelated TypeError.
                raise NotImplementedError(expression)

            def _assignment_in_statement(statement):
                if isinstance(statement, ast.Assign):
                    return any(map(_name_in_expression, statement.targets))
                if isinstance(statement, (ast.FunctionDef, ast.ClassDef)):
                    # New scope: its assignments do not bind our name.
                    return False
                if hasattr(statement, 'body'):
                    return any(map(_assignment_in_statement, statement.body))
                return False

            # Everything except the innermost function, searched from the
            # inside out.
            enclosing = self.function_stack[:-1]
            for index in reversed(range(len(enclosing))):
                function_def = enclosing[index]
                if any(map(_assignment_in_statement, function_def.body)):
                    return index, function_def
            return None

    class ListWrapTransformation(ast.NodeTransformer):
        """Rewrite uses of ``name`` so the value is kept in a list.

        The first plain assignment ``name = value`` becomes
        ``name = [value]``; references visited afterwards become
        ``name[0]`` with the original expression context.
        """

        def __init__(self, name):
            ast.NodeTransformer.__init__(self)
            self.name = name
            # True until the first occurrence of ``name`` has been rewritten.
            self.first = True

        def visit_Assign(self, node):
            """Wrap the first assignment to ``name``; rewrite later ones."""
            if self.first:
                target = node.targets[0]
                if isinstance(target, ast.Name) and target.id == self.name:
                    assert len(node.targets) == 1
                    node.value = ast.List([node.value], ast.Load())
                    self.first = False
                # Assignments seen before the first binding of ``name`` are
                # kept untouched; the node is always returned so that no
                # statement is dropped from the tree.
                return node
            node.targets = [self.visit(target) for target in node.targets]
            node.value = self.visit(node.value)
            return node

        def visit_Name(self, node):
            """Turn a reference to ``name`` into list form."""
            if node.id != self.name:
                return node
            if self.first:
                self.first = False
                return ast.List([node], node.ctx)
            return ast.Subscript(
                ast.Name(node.id, ast.Load()),
                ast.Index(ast.Num(0)),
                node.ctx
            )
else:
    class NonLocalTransformer(ast.NodeTransformer):
        """Turn ``nonlocal_('a', ...)`` statements into ``nonlocal a, ...``."""

        @staticmethod
        def _literal(argument):
            # String literals are ast.Constant nodes on Python 3.8+; the
            # ``.s`` alias is deprecated there and kept only as a fallback
            # for the older ast.Str nodes.
            if isinstance(argument, getattr(ast, 'Constant', ())):
                return argument.value
            return argument.s

        def visit_Expr(self, node):
            """Replace a ``nonlocal_(...)`` statement with ``ast.Nonlocal``.

            Any other expression statement is returned unchanged.
            """
            call = node.value
            is_marker = (
                isinstance(call, ast.Call)
                and isinstance(call.func, ast.Name)
                and call.func.id == 'nonlocal_'
            )
            if not is_marker:
                return node
            names = [self._literal(argument) for argument in call.args]
            # Give the new statement the position of the call it replaces.
            return ast.copy_location(ast.Nonlocal(names=names), node)
class NonLocalImporter(object):
    """Meta-path finder/loader that applies ``NonLocalTransformer``.

    Modules located by ``imp.find_module`` as Python source, compiled Python
    or a package directory are read as source text, parsed, transformed and
    executed in a fresh module object.  Every other module kind is handed to
    ``imp.load_module`` unchanged.
    """

    def __init__(self):
        # name -> ((file, filename, description), path), filled by
        # find_module and consumed by load_module.
        self._found_modules = {}

    def find_module(self, name, path=None):
        """Return ``self`` if ``imp.find_module`` locates ``name``, else None.

        NOTE(review): ``name`` is passed to ``imp.find_module`` as given, so
        dotted submodule names are presumably never found here and fall
        through to the remaining finders -- confirm whether that is intended.
        """
        try:
            found = imp.find_module(name, path)
        except ImportError:
            return None
        self._found_modules[name] = (found, path)
        return self

    def load_module(self, name):
        """Load ``name``, rewriting ``nonlocal_`` calls in its source.

        Returns the module object.  Exceptions raised while reading,
        compiling or executing the source propagate to the caller; if the
        module body raises, the entry is removed from ``sys.modules`` again.
        """
        (file, filename, description), path = self._found_modules[name]
        kind = description[2]

        if kind not in (imp.PY_SOURCE, imp.PY_COMPILED, imp.PKG_DIRECTORY):
            # Not something we can transform; the caller of imp.load_module
            # is responsible for closing the file returned by find_module.
            try:
                return imp.load_module(name, file, filename, description)
            finally:
                if file is not None:
                    file.close()

        newpath = None
        if kind == imp.PY_SOURCE:
            with file:
                code = file.read()
        else:
            # The handle from find_module (the compiled file, or None for a
            # package) is not what we read from; release it.
            if file is not None:
                file.close()
            if kind == imp.PY_COMPILED:
                filename = filename[:-1]  # .pyc or .pyo
            else:
                # __path__ must name the package directory, not __init__.py.
                newpath = [filename]
                filename = os.path.join(filename, '__init__.py')
            # 'U' is needed for universal newlines on Python 2 only; newer
            # Python 3 versions reject it.
            mode = 'U' if sys.version_info[0] == 2 else 'r'
            with open(filename, mode) as source:
                code = source.read()

        module = types.ModuleType(name)
        module.__file__ = filename
        if newpath:
            module.__path__ = newpath

        tree = NonLocalTransformer().visit(ast.parse(code, filename))
        ast.fix_missing_locations(tree)
        compiled = compile(tree, filename, 'exec')

        # Register before executing so imports of this name made while the
        # body runs find the module object.
        sys.modules[name] = module
        try:
            exec(compiled, module.__dict__)
        except BaseException:
            # Do not leave a half-initialised module behind.
            sys.modules.pop(name, None)
            raise
        return module
def nonlocal_(*names):
    """Marker function; always raises ``RuntimeError`` when called.

    Statements calling this name are meant to be replaced by the import
    hook's AST transformer, so reaching this body means the calling module
    was not rewritten.
    """
    message = (
        'nonlocal_ needs to be imported before a module using it is imported'
    )
    raise RuntimeError(message)
# Install the hook at the front of sys.meta_path so it is consulted before
# any finder that is already registered there.
sys.meta_path.insert(0, NonLocalImporter())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment