Commit ba15a729 by Lianmin Zheng Committed by Tianqi Chen

[HybridScript] Capture constant external python variables (#3157)

parent 654192de
......@@ -31,6 +31,8 @@ HalideIR.
from __future__ import absolute_import as _abs
import inspect
from .._ffi.base import decorate
from .._ffi.function import _init_api
from ..build_module import form_body
......@@ -55,7 +57,9 @@ def script(pyfunc):
from .util import _is_tvm_arg_types
if _is_tvm_arg_types(args):
src = _pruned_source(func)
return source_to_op(src, func.__globals__, args)
closure_vars = inspect.getclosurevars(func).nonlocals
closure_vars.update(inspect.getclosurevars(func).globals)
return source_to_op(src, args, func.__globals__, closure_vars)
from .runtime import _enter_hybrid_runtime, _restore_runtime
intersect = _enter_hybrid_runtime(func)
......
......@@ -62,7 +62,7 @@ class HybridModule(object):
def __call__(self, *args):
if _is_tvm_arg_types(args):
return source_to_op(self.root_, globals(), args)
return source_to_op(self.root_, args, globals(), {})
return self.func_(*args)
......
......@@ -25,7 +25,7 @@ import numbers
from enum import Enum
from .util import _internal_assert
from .util import _internal_assert, _apply_indices
from . import calls
from . import util
from .preprocessor import determine_variable_usage
......@@ -112,7 +112,7 @@ class HybridParser(ast.NodeVisitor):
}
def __init__(self, args, usage, symbols, func_name=None):
def __init__(self, args, usage, symbols, closure_vars, func_name=None):
"""
Parameters
----------
......@@ -122,6 +122,12 @@ class HybridParser(ast.NodeVisitor):
usage: A dict of variables used in last in this function
Provided by last lower pass, which collects this information
symbols : list of str
The symbol list of the global context of the function.
closure_vars: dict
A dict of external name reference captured by this function.
Returns
-------
func_name: str
......@@ -136,6 +142,8 @@ class HybridParser(ast.NodeVisitor):
if isinstance(v, types.FunctionType):
self.add_symbol(k, Symbol.Callable, v)
self.closure_vars = closure_vars
self.binds = {} # Thread binds
self.device = 0 # Is it generating device
......@@ -236,7 +244,11 @@ class HybridParser(ast.NodeVisitor):
def visit_Name(self, node):
name = node.id
if sys.version_info[0] == 2 and name in ['True', 'False']:
return _api.convert(eval(name)) #pylint: disable=eval-used
return _api.convert(ast.literal_eval(name))
if name in self.closure_vars:
return _api.convert(self.closure_vars[name])
ty, entry = self.symbols[name]
_internal_assert(name in self.symbols, "Unknown symbol %s!" % name)
if ty in [Symbol.LoopVar, Symbol.Input, Symbol.ConstLoopVar]:
......@@ -356,10 +368,12 @@ class HybridParser(ast.NodeVisitor):
buf = self.visit(node.value)
return getattr(buf, node.attr)
def visit_Subscript(self, node):
args = self.visit(node.slice)
if isinstance(node.value, ast.Name):
if node.value.id in self.closure_vars:
args = ast.literal_eval(str(args))
return _api.convert(_apply_indices(self.closure_vars[node.value.id], args))
buf = self.visit(node.value)
if isinstance(buf, Array):
......@@ -576,7 +590,7 @@ class HybridParser(ast.NodeVisitor):
return _make.AssertStmt(test, mesg, util.make_nop())
def parse_python(src, symbols, args):
def parse_python(src, args, symbols, closure_vars):
"""The helper function of calling the AST visitor
Parameters
......@@ -585,14 +599,17 @@ def parse_python(src, symbols, args):
If an ast.node, then directly lower it.
If a str, then parse it to ast and lower it.
symbols : str
The symbol list of the global context of the function.
args : list of Tensors or Vars
The argument lists to the function.
It is NOT encouraged to write a function without arguments.
It is NOT encouraged to write a function with side effect.
symbols : list of str
The symbol list of the global context of the function.
closure_vars: dict
A dict of external name reference captured by this function.
Returns
-------
root : Stmt
......@@ -600,14 +617,14 @@ def parse_python(src, symbols, args):
"""
root = ast.parse(src) if isinstance(src, str) else src
_internal_assert(root, ast.AST)
var_usage = determine_variable_usage(root, args, symbols)
parser = HybridParser(args, var_usage, symbols)
var_usage = determine_variable_usage(root, args, symbols, closure_vars)
parser = HybridParser(args, var_usage, symbols, closure_vars)
parser.parsed_body = parser.visit(root)
_internal_assert(parser.returned, 'No valid return found in the function body!')
return parser
def source_to_op(src, symbols, args):
def source_to_op(src, args, symbols, closure_vars):
"""Another level of wrapper
Parameters
......@@ -616,20 +633,23 @@ def source_to_op(src, symbols, args):
If an ast.node, then directly lower it.
If a str, then parse it to ast and lower it.
symbols : str
The symbol list of the global context of the function.
args : list of Tensors or Vars
The argument lists to the function.
It is NOT encouraged to write a function without arguments.
It is NOT encouraged to write a function with side effect.
symbols : list of str
The symbol list of the global context of the function.
closure_vars: dict
A dict of external name reference captured by this function.
Returns
-------
res : list of output tensors
The result of output tensors of the formed OpNode.
"""
parser = parse_python(src, symbols, args)
parser = parse_python(src, args, symbols, closure_vars)
input_tensors = []
for i in args:
......
......@@ -26,14 +26,14 @@ class PyVariableUsage(ast.NodeVisitor):
"""The vistor class to determine the declaration, r/w status, and last use of each variable"""
#pylint: disable=invalid-name
#pylint: disable=missing-docstring
def __init__(self, args, symbols):
def __init__(self, args, symbols, closure_vars):
self.status = {}
self.scope_level = []
self._args = {}
self.args = args
self.aug_assign_ = False
self.symbols = symbols
self.closure_vars = closure_vars
def visit_FunctionDef(self, node):
self.scope_level.append(node)
......@@ -89,6 +89,14 @@ class PyVariableUsage(ast.NodeVisitor):
"Iter var cannot be overwritten")
if node.id not in self.status.keys():
# It is a captured value in closure
if node.id in self.closure_vars:
try:
ast.literal_eval(str(self.closure_vars[node.id]))
except ValueError:
raise ValueError("Only support capturing constant values in closure")
return
_internal_assert(isinstance(node.ctx, ast.Store), \
'Undeclared variable %s' % node.id)
if self.aug_assign_:
......@@ -102,8 +110,8 @@ class PyVariableUsage(ast.NodeVisitor):
self.status[node.id] = (decl, loop, usage)
def determine_variable_usage(root, args, symbols):
def determine_variable_usage(root, args, symbols, closure_vars):
"""The helper function for calling the dedicated visitor."""
visitor = PyVariableUsage(args, symbols)
visitor = PyVariableUsage(args, symbols, closure_vars)
visitor.visit(root)
return visitor.status
......@@ -101,3 +101,9 @@ def _is_tvm_arg_types(args):
_internal_assert(isinstance(elem, np_arg_types), \
"Expect a numpy type but %s get!" % str(type(elem)))
return False
def _apply_indices(value, indices):
"""Apply multidimensional index"""
if indices:
return _apply_indices(value[indices[0]], indices[1:])
return value
......@@ -768,6 +768,24 @@ def test_schedule():
# Test loop binds
def test_capture():
n = 8
constant_tuple = (10, n)
constant_list = [[1, 2], [3, n]]
const_value = 1
@tvm.hybrid.script
def add_something(a):
c = output_tensor((constant_tuple[1],), 'int32')
for i in range(constant_tuple[1]):
c[i] = a[i] + constant_list[1][const_value]
return c
a = tvm.placeholder((n, ), dtype='int32', name='a')
func, ins, outs = run_and_check(add_something, [a])
run_and_check(func, ins, outs=outs)
if __name__ == "__main__":
test_outer_product()
......@@ -786,5 +804,6 @@ if __name__ == "__main__":
test_bool()
test_const_range()
test_schedule()
test_capture()
# TODO:
# test_inplace()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment