Commit ba15a729 by Lianmin Zheng Committed by Tianqi Chen

[HybridScript] Capture constant external python variables (#3157)

parent 654192de
...@@ -31,6 +31,8 @@ HalideIR. ...@@ -31,6 +31,8 @@ HalideIR.
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
import inspect
from .._ffi.base import decorate from .._ffi.base import decorate
from .._ffi.function import _init_api from .._ffi.function import _init_api
from ..build_module import form_body from ..build_module import form_body
...@@ -55,7 +57,9 @@ def script(pyfunc): ...@@ -55,7 +57,9 @@ def script(pyfunc):
from .util import _is_tvm_arg_types from .util import _is_tvm_arg_types
if _is_tvm_arg_types(args): if _is_tvm_arg_types(args):
src = _pruned_source(func) src = _pruned_source(func)
return source_to_op(src, func.__globals__, args) closure_vars = inspect.getclosurevars(func).nonlocals
closure_vars.update(inspect.getclosurevars(func).globals)
return source_to_op(src, args, func.__globals__, closure_vars)
from .runtime import _enter_hybrid_runtime, _restore_runtime from .runtime import _enter_hybrid_runtime, _restore_runtime
intersect = _enter_hybrid_runtime(func) intersect = _enter_hybrid_runtime(func)
......
...@@ -62,7 +62,7 @@ class HybridModule(object): ...@@ -62,7 +62,7 @@ class HybridModule(object):
def __call__(self, *args): def __call__(self, *args):
if _is_tvm_arg_types(args): if _is_tvm_arg_types(args):
return source_to_op(self.root_, globals(), args) return source_to_op(self.root_, args, globals(), {})
return self.func_(*args) return self.func_(*args)
......
...@@ -25,7 +25,7 @@ import numbers ...@@ -25,7 +25,7 @@ import numbers
from enum import Enum from enum import Enum
from .util import _internal_assert from .util import _internal_assert, _apply_indices
from . import calls from . import calls
from . import util from . import util
from .preprocessor import determine_variable_usage from .preprocessor import determine_variable_usage
...@@ -112,7 +112,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -112,7 +112,7 @@ class HybridParser(ast.NodeVisitor):
} }
def __init__(self, args, usage, symbols, func_name=None): def __init__(self, args, usage, symbols, closure_vars, func_name=None):
""" """
Parameters Parameters
---------- ----------
...@@ -122,6 +122,12 @@ class HybridParser(ast.NodeVisitor): ...@@ -122,6 +122,12 @@ class HybridParser(ast.NodeVisitor):
usage: A dict of variables used in last in this function usage: A dict of variables used in last in this function
Provided by last lower pass, which collects this information Provided by last lower pass, which collects this information
symbols : list of str
The symbol list of the global context of the function.
closure_vars: dict
A dict of external name reference captured by this function.
Returns Returns
------- -------
func_name: str func_name: str
...@@ -136,6 +142,8 @@ class HybridParser(ast.NodeVisitor): ...@@ -136,6 +142,8 @@ class HybridParser(ast.NodeVisitor):
if isinstance(v, types.FunctionType): if isinstance(v, types.FunctionType):
self.add_symbol(k, Symbol.Callable, v) self.add_symbol(k, Symbol.Callable, v)
self.closure_vars = closure_vars
self.binds = {} # Thread binds self.binds = {} # Thread binds
self.device = 0 # Is it generating device self.device = 0 # Is it generating device
...@@ -236,7 +244,11 @@ class HybridParser(ast.NodeVisitor): ...@@ -236,7 +244,11 @@ class HybridParser(ast.NodeVisitor):
def visit_Name(self, node): def visit_Name(self, node):
name = node.id name = node.id
if sys.version_info[0] == 2 and name in ['True', 'False']: if sys.version_info[0] == 2 and name in ['True', 'False']:
return _api.convert(eval(name)) #pylint: disable=eval-used return _api.convert(ast.literal_eval(name))
if name in self.closure_vars:
return _api.convert(self.closure_vars[name])
ty, entry = self.symbols[name] ty, entry = self.symbols[name]
_internal_assert(name in self.symbols, "Unknown symbol %s!" % name) _internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]: if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
...@@ -356,10 +368,12 @@ class HybridParser(ast.NodeVisitor): ...@@ -356,10 +368,12 @@ class HybridParser(ast.NodeVisitor):
buf = self.visit(node.value) buf = self.visit(node.value)
return getattr(buf, node.attr) return getattr(buf, node.attr)
def visit_Subscript(self, node): def visit_Subscript(self, node):
args = self.visit(node.slice) args = self.visit(node.slice)
if isinstance(node.value, ast.Name): if isinstance(node.value, ast.Name):
if node.value.id in self.closure_vars:
args = ast.literal_eval(str(args))
return _api.convert(_apply_indices(self.closure_vars[node.value.id], args))
buf = self.visit(node.value) buf = self.visit(node.value)
if isinstance(buf, Array): if isinstance(buf, Array):
...@@ -576,7 +590,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -576,7 +590,7 @@ class HybridParser(ast.NodeVisitor):
return _make.AssertStmt(test, mesg, util.make_nop()) return _make.AssertStmt(test, mesg, util.make_nop())
def parse_python(src, symbols, args): def parse_python(src, args, symbols, closure_vars):
"""The helper function of calling the AST visitor """The helper function of calling the AST visitor
Parameters Parameters
...@@ -585,14 +599,17 @@ def parse_python(src, symbols, args): ...@@ -585,14 +599,17 @@ def parse_python(src, symbols, args):
If an ast.node, then directly lower it. If an ast.node, then directly lower it.
If a str, then parse it to ast and lower it. If a str, then parse it to ast and lower it.
symbols : str
The symbol list of the global context of the function.
args : list of Tensors or Vars args : list of Tensors or Vars
The argument lists to the function. The argument lists to the function.
It is NOT encouraged to write a function without arguments. It is NOT encouraged to write a function without arguments.
It is NOT encouraged to write a function with side effect. It is NOT encouraged to write a function with side effect.
symbols : list of str
The symbol list of the global context of the function.
closure_vars: dict
A dict of external name reference captured by this function.
Returns Returns
------- -------
root : Stmt root : Stmt
...@@ -600,14 +617,14 @@ def parse_python(src, symbols, args): ...@@ -600,14 +617,14 @@ def parse_python(src, symbols, args):
""" """
root = ast.parse(src) if isinstance(src, str) else src root = ast.parse(src) if isinstance(src, str) else src
_internal_assert(root, ast.AST) _internal_assert(root, ast.AST)
var_usage = determine_variable_usage(root, args, symbols) var_usage = determine_variable_usage(root, args, symbols, closure_vars)
parser = HybridParser(args, var_usage, symbols) parser = HybridParser(args, var_usage, symbols, closure_vars)
parser.parsed_body = parser.visit(root) parser.parsed_body = parser.visit(root)
_internal_assert(parser.returned, 'No valid return found in the function body!') _internal_assert(parser.returned, 'No valid return found in the function body!')
return parser return parser
def source_to_op(src, symbols, args): def source_to_op(src, args, symbols, closure_vars):
"""Another level of wrapper """Another level of wrapper
Parameters Parameters
...@@ -616,20 +633,23 @@ def source_to_op(src, symbols, args): ...@@ -616,20 +633,23 @@ def source_to_op(src, symbols, args):
If an ast.node, then directly lower it. If an ast.node, then directly lower it.
If a str, then parse it to ast and lower it. If a str, then parse it to ast and lower it.
symbols : str
The symbol list of the global context of the function.
args : list of Tensors or Vars args : list of Tensors or Vars
The argument lists to the function. The argument lists to the function.
It is NOT encouraged to write a function without arguments. It is NOT encouraged to write a function without arguments.
It is NOT encouraged to write a function with side effect. It is NOT encouraged to write a function with side effect.
symbols : list of str
The symbol list of the global context of the function.
closure_vars: dict
A dict of external name reference captured by this function.
Returns Returns
------- -------
res : list of output tensors res : list of output tensors
The result of output tensors of the formed OpNode. The result of output tensors of the formed OpNode.
""" """
parser = parse_python(src, symbols, args) parser = parse_python(src, args, symbols, closure_vars)
input_tensors = [] input_tensors = []
for i in args: for i in args:
......
...@@ -26,14 +26,14 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -26,14 +26,14 @@ class PyVariableUsage(ast.NodeVisitor):
"""The vistor class to determine the declaration, r/w status, and last use of each variable""" """The vistor class to determine the declaration, r/w status, and last use of each variable"""
#pylint: disable=invalid-name #pylint: disable=invalid-name
#pylint: disable=missing-docstring #pylint: disable=missing-docstring
def __init__(self, args, symbols): def __init__(self, args, symbols, closure_vars):
self.status = {} self.status = {}
self.scope_level = [] self.scope_level = []
self._args = {} self._args = {}
self.args = args self.args = args
self.aug_assign_ = False self.aug_assign_ = False
self.symbols = symbols self.symbols = symbols
self.closure_vars = closure_vars
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
self.scope_level.append(node) self.scope_level.append(node)
...@@ -89,6 +89,14 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -89,6 +89,14 @@ class PyVariableUsage(ast.NodeVisitor):
"Iter var cannot be overwritten") "Iter var cannot be overwritten")
if node.id not in self.status.keys(): if node.id not in self.status.keys():
# It is a captured value in closure
if node.id in self.closure_vars:
try:
ast.literal_eval(str(self.closure_vars[node.id]))
except ValueError:
raise ValueError("Only support capturing constant values in closure")
return
_internal_assert(isinstance(node.ctx, ast.Store), \ _internal_assert(isinstance(node.ctx, ast.Store), \
'Undeclared variable %s' % node.id) 'Undeclared variable %s' % node.id)
if self.aug_assign_: if self.aug_assign_:
...@@ -102,8 +110,8 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -102,8 +110,8 @@ class PyVariableUsage(ast.NodeVisitor):
self.status[node.id] = (decl, loop, usage) self.status[node.id] = (decl, loop, usage)
def determine_variable_usage(root, args, symbols): def determine_variable_usage(root, args, symbols, closure_vars):
"""The helper function for calling the dedicated visitor.""" """The helper function for calling the dedicated visitor."""
visitor = PyVariableUsage(args, symbols) visitor = PyVariableUsage(args, symbols, closure_vars)
visitor.visit(root) visitor.visit(root)
return visitor.status return visitor.status
...@@ -101,3 +101,9 @@ def _is_tvm_arg_types(args): ...@@ -101,3 +101,9 @@ def _is_tvm_arg_types(args):
_internal_assert(isinstance(elem, np_arg_types), \ _internal_assert(isinstance(elem, np_arg_types), \
"Expect a numpy type but %s get!" % str(type(elem))) "Expect a numpy type but %s get!" % str(type(elem)))
return False return False
def _apply_indices(value, indices):
"""Apply multidimensional index"""
if indices:
return _apply_indices(value[indices[0]], indices[1:])
return value
...@@ -768,6 +768,24 @@ def test_schedule(): ...@@ -768,6 +768,24 @@ def test_schedule():
# Test loop binds # Test loop binds
def test_capture():
n = 8
constant_tuple = (10, n)
constant_list = [[1, 2], [3, n]]
const_value = 1
@tvm.hybrid.script
def add_something(a):
c = output_tensor((constant_tuple[1],), 'int32')
for i in range(constant_tuple[1]):
c[i] = a[i] + constant_list[1][const_value]
return c
a = tvm.placeholder((n, ), dtype='int32', name='a')
func, ins, outs = run_and_check(add_something, [a])
run_and_check(func, ins, outs=outs)
if __name__ == "__main__": if __name__ == "__main__":
test_outer_product() test_outer_product()
...@@ -786,5 +804,6 @@ if __name__ == "__main__": ...@@ -786,5 +804,6 @@ if __name__ == "__main__":
test_bool() test_bool()
test_const_range() test_const_range()
test_schedule() test_schedule()
test_capture()
# TODO: # TODO:
# test_inplace() # test_inplace()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment