Commit 3b3b8cbe by Jian Weng Committed by Tianqi Chen

allows constant param in op construct (#2257)

parent d50f7b66
...@@ -144,14 +144,14 @@ class HybridParser(ast.NodeVisitor): ...@@ -144,14 +144,14 @@ class HybridParser(ast.NodeVisitor):
def visit_Name(self, node): def visit_Name(self, node):
_id = node.id _id = node.id
if _id in self._args.keys() and isinstance(self._args[_id], _expr.Var): if _id in self._args.keys() and isinstance(self._args[_id], (_expr.Var, _expr.ConstExpr)):
return self._args[_id] return self._args[_id]
elif _id in self.loops_above.keys(): elif _id in self.loops_above.keys():
return self.loops_above[_id] return self.loops_above[_id]
_internal_assert(_id not in self._args.keys(), \ _internal_assert(_id not in self._args.keys(), \
"This id %s should be handled in visit_Subscript!" % _id) "This id %s should be handled in visit_Subscript!" % _id)
_internal_assert(_id in self.usage.keys(), \ _internal_assert(_id in self.usage.keys(), \
"This id %s is expected to be a defined variable!" % _id) "This id %s is expected to be a defined variable!" % _id)
# Buffer # Buffer
if _id in self.alloc_buffers.keys(): if _id in self.alloc_buffers.keys():
_buf, _ = self.alloc_buffers[_id] _buf, _ = self.alloc_buffers[_id]
...@@ -166,6 +166,15 @@ class HybridParser(ast.NodeVisitor): ...@@ -166,6 +166,15 @@ class HybridParser(ast.NodeVisitor):
return _api.const(node.n) return _api.const(node.n)
def visit_AugAssign(self, node):
lhs = self.visit(node.target)
rhs = self.visit(node.value)
rhs = HybridParser._binop_maker[type(node.op)](lhs, rhs)
_internal_assert(isinstance(lhs, _expr.Call), \
"The LHS of an AugAssign is supposed to be a call!")
return _make.Provide(lhs.func, 0, rhs, lhs.args)
def visit_Assign(self, node): def visit_Assign(self, node):
_internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!") _internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")
lhs = node.targets[0] lhs = node.targets[0]
...@@ -177,7 +186,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -177,7 +186,7 @@ class HybridParser(ast.NodeVisitor):
lhs_ = lhs lhs_ = lhs
lhs = lhs.id lhs = lhs.id
_internal_assert(lhs not in self.loops_above.keys(), \ _internal_assert(lhs not in self.loops_above.keys(), \
"Loop variable cannot be overwritten!") "Loop variable cannot be overwritten!")
decl, _, rw = self.usage[lhs] decl, _, rw = self.usage[lhs]
if decl == lhs_: if decl == lhs_:
_internal_assert(lhs not in self.var_consts.keys(), \ _internal_assert(lhs not in self.var_consts.keys(), \
...@@ -227,16 +236,16 @@ class HybridParser(ast.NodeVisitor): ...@@ -227,16 +236,16 @@ class HybridParser(ast.NodeVisitor):
return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, 0) return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, 0)
_internal_assert(isinstance(node.value, ast.Attribute), \ _internal_assert(isinstance(node.value, ast.Attribute), \
"Only variable and attribute's subscript supported so far") "Only variable and attribute's subscript supported so far")
_internal_assert(isinstance(node.value.value, ast.Name), \ _internal_assert(isinstance(node.value.value, ast.Name), \
"The root of array access is expect to be a id!") "The root of array access is expect to be a id!")
_internal_assert(node.value.attr == "shape", \ _internal_assert(node.value.attr == "shape", \
"Attribute access so far only 'shape' is supported!") "Attribute access so far only 'shape' is supported!")
_internal_assert(len(args) == 1, "For 'shape' access the argument should be only one!") _internal_assert(len(args) == 1, "For 'shape' access the argument should be only one!")
args = args[0] args = args[0]
#TODO: maybe support non-constant value later? #TODO: maybe support non-constant value later?
_internal_assert(isinstance(args, (_expr.IntImm, _expr.UIntImm)), \ _internal_assert(isinstance(args, (_expr.IntImm, _expr.UIntImm)), \
"So far only constant shape access supported!") "So far only constant shape access supported!")
buf = self._get_buffer_from_id(node.value.value.id) buf = self._get_buffer_from_id(node.value.value.id)
return buf.shape[args.value] return buf.shape[args.value]
...@@ -294,7 +303,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -294,7 +303,7 @@ class HybridParser(ast.NodeVisitor):
def visit_Call(self, node): def visit_Call(self, node):
# Yet, no function pointer supported # Yet, no function pointer supported
_internal_assert(isinstance(node.func, ast.Name), \ _internal_assert(isinstance(node.func, ast.Name), \
"Only id-function function call is supported so far!") "Only id-function function call is supported so far!")
func_id = node.func.id func_id = node.func.id
n = len(node.args) n = len(node.args)
if func_id in LOOP_INTRIN.keys() and func_id != 'bind': if func_id in LOOP_INTRIN.keys() and func_id != 'bind':
...@@ -311,7 +320,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -311,7 +320,7 @@ class HybridParser(ast.NodeVisitor):
elif func_id == 'bind': elif func_id == 'bind':
_internal_assert(n == 2, "A loop bind should only have 2 arguments!") _internal_assert(n == 2, "A loop bind should only have 2 arguments!")
_internal_assert(isinstance(node.args[0], ast.Str), \ _internal_assert(isinstance(node.args[0], ast.Str), \
"A loop bind's first argument should be a string!") "A loop bind's first argument should be a string!")
_vn = node.args[0].s _vn = node.args[0].s
iter_var = thread_axis(node.args[0].s) iter_var = thread_axis(node.args[0].s)
low, ext = _api.const(0, dtype='int32'), self.visit(node.args[1]) low, ext = _api.const(0, dtype='int32'), self.visit(node.args[1])
...@@ -321,11 +330,11 @@ class HybridParser(ast.NodeVisitor): ...@@ -321,11 +330,11 @@ class HybridParser(ast.NodeVisitor):
return getattr(intrin, func_id)(*[self.visit(arg) for arg in node.args]) return getattr(intrin, func_id)(*[self.visit(arg) for arg in node.args])
elif func_id in ['allocate', 'output_tensor']: elif func_id in ['allocate', 'output_tensor']:
_internal_assert(isinstance(node.args[0], ast.Tuple), \ _internal_assert(isinstance(node.args[0], ast.Tuple), \
"allocate's first argument should be a tuple of shape!") "allocate's first argument should be a tuple of shape!")
shape = tuple(self.visit(i) for i in node.args[0].elts) shape = tuple(self.visit(i) for i in node.args[0].elts)
if func_id == 'output_tensor': if func_id == 'output_tensor':
_internal_assert(not self.loops_above, \ _internal_assert(not self.loops_above, \
"Are you sure to allocate a output buffer multiple times?") "Are you sure to allocate a output buffer multiple times?")
for i in shape: for i in shape:
_internal_assert(isinstance(i, _expr.Expr), "The shape should be an expression") _internal_assert(isinstance(i, _expr.Expr), "The shape should be an expression")
if n > 1: if n > 1:
...@@ -333,18 +342,18 @@ class HybridParser(ast.NodeVisitor): ...@@ -333,18 +342,18 @@ class HybridParser(ast.NodeVisitor):
dtype = node.args[1].s dtype = node.args[1].s
else: else:
_internal_assert(isinstance(node.args[1], ast.Attribute), \ _internal_assert(isinstance(node.args[1], ast.Attribute), \
"Unable to evaluate to get data type") "Unable to evaluate to get data type")
to_eval = node.args[1] to_eval = node.args[1]
_internal_assert(isinstance(to_eval.value, ast.Name), \ _internal_assert(isinstance(to_eval.value, ast.Name), \
"Unable to evaluate the attribute to get data type") "Unable to evaluate the attribute to get data type")
_internal_assert(to_eval.attr == 'dtype', \ _internal_assert(to_eval.attr == 'dtype', \
"Only dtype attribute is supported so far") "Only dtype attribute is supported so far")
dtype = self._get_buffer_from_id(to_eval.value.id).dtype dtype = self._get_buffer_from_id(to_eval.value.id).dtype
else: else:
dtype = 'float32' dtype = 'float32'
if n > 2: if n > 2:
_internal_assert(isinstance(node.args[2], ast.Str), \ _internal_assert(isinstance(node.args[2], ast.Str), \
"The data scope should be an string") "The data scope should be an string")
_internal_assert(func_id != 'output_tensor', "Output tensor cannot specify scope") _internal_assert(func_id != 'output_tensor', "Output tensor cannot specify scope")
scope = node.args[2].s scope = node.args[2].s
else: else:
...@@ -361,7 +370,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -361,7 +370,7 @@ class HybridParser(ast.NodeVisitor):
def visit_For(self, node): def visit_For(self, node):
iter_var, low, ext, for_type = self.visit(node.iter) iter_var, low, ext, for_type = self.visit(node.iter)
_internal_assert(isinstance(node.target, ast.Name), \ _internal_assert(isinstance(node.target, ast.Name), \
"The loop iterator should be a variable!") "The loop iterator should be a variable!")
_name = node.target.id _name = node.target.id
if iter_var is None: if iter_var is None:
_internal_assert(for_type is not None, "The loop bind function parse error!") _internal_assert(for_type is not None, "The loop bind function parse error!")
...@@ -389,7 +398,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -389,7 +398,7 @@ class HybridParser(ast.NodeVisitor):
ids.append(node.value.id) ids.append(node.value.id)
else: else:
_internal_assert(isinstance(node.value, ast.Tuple), \ _internal_assert(isinstance(node.value, ast.Tuple), \
"You should return either a single tensor or a tuple") "You should return either a single tensor or a tuple")
for i in node.value.elts: for i in node.value.elts:
_internal_assert(isinstance(i, ast.Name), "What do you return?") _internal_assert(isinstance(i, ast.Name), "What do you return?")
ids.append(i.id) ids.append(i.id)
......
...@@ -15,9 +15,14 @@ from ..tensor import Tensor ...@@ -15,9 +15,14 @@ from ..tensor import Tensor
#pylint: disable=invalid-name #pylint: disable=invalid-name
np_arg_types = tuple(list(numeric_types) + [numpy.ndarray]) np_arg_types = tuple(list(numeric_types) + [numpy.ndarray])
tvm_arg_types = (Tensor, _expr.Var) tvm_arg_types = (Tensor, _expr.Var, _expr.ConstExpr)
halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm) halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm)
def _internal_assert(cond, err):
"""Simplify the code segment like if not XXX then raise an error"""
if not cond:
raise ValueError(err)
# Useful constants. In avoid of runtime dependences, we use function calls to return them. # Useful constants. In avoid of runtime dependences, we use function calls to return them.
def make_nop(): def make_nop():
...@@ -50,14 +55,16 @@ def _is_tvm_arg_types(args): ...@@ -50,14 +55,16 @@ def _is_tvm_arg_types(args):
If neither is true, raise a value error.""" If neither is true, raise a value error."""
if isinstance(args[0], tvm_arg_types): if isinstance(args[0], tvm_arg_types):
for elem in args[1:]: for elem in args[1:]:
if not isinstance(elem, tvm_arg_types): _internal_assert(isinstance(elem, tvm_arg_types),
raise ValueError("Expect a Var or Tensor instance but % get!" % str(type(elem))) "Expecting a Var, Tensor or ConstExpr instance but %s get!" \
% str(type(elem)))
return True return True
if not isinstance(args[0], np_arg_types):
raise ValueError("Expect a numpy type but % get!" % str(type(args[0]))) _internal_assert(isinstance(args[0], np_arg_types), \
"Expect a numpy type but %s get!" % str(type(args[0])))
for elem in args[1:]: for elem in args[1:]:
if not isinstance(elem, np_arg_types): _internal_assert(isinstance(elem, np_arg_types), \
raise ValueError("Expect a numpy type but % get!" % str(type(elem))) "Expect a numpy type but %s get!" % str(type(elem)))
return False return False
...@@ -79,12 +86,3 @@ def _restore_runtime(func, intersect): ...@@ -79,12 +86,3 @@ def _restore_runtime(func, intersect):
_globals.pop(elem) _globals.pop(elem)
for k, v in intersect: for k, v in intersect:
_globals[k] = v _globals[k] = v
def _internal_assert(cond, err):
"""Simplify the code segment like if not XXX then raise an error"""
if not cond:
raise ValueError(err)
# Almost the same functionality as the one above, but in this case,
# the error is caused by users inproper usage.
_user_assert = _internal_assert
...@@ -15,6 +15,7 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -15,6 +15,7 @@ class PyVariableUsage(ast.NodeVisitor):
self.scope_level = [] self.scope_level = []
self._args = {} self._args = {}
self.args = args self.args = args
self.aug_assign_ = False
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
...@@ -48,6 +49,12 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -48,6 +49,12 @@ class PyVariableUsage(ast.NodeVisitor):
self.visit(elem) self.visit(elem)
def visit_AugAssign(self, node):
self.aug_assign_ = True
self.generic_visit(node)
self.aug_assign_ = False
def visit_Name(self, node): def visit_Name(self, node):
# If it is from the argument list or loop variable, we do not worry about it! # If it is from the argument list or loop variable, we do not worry about it!
if node.id in self._args.keys(): if node.id in self._args.keys():
...@@ -61,7 +68,9 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -61,7 +68,9 @@ class PyVariableUsage(ast.NodeVisitor):
if node.id not in self.status.keys(): if node.id not in self.status.keys():
_internal_assert(isinstance(node.ctx, ast.Store), \ _internal_assert(isinstance(node.ctx, ast.Store), \
'Undeclared variable %s' % node.id) 'Undeclared variable %s' % node.id)
if self.aug_assign_:
raise ValueError('"First store" cannot be an AugAssign')
self.status[node.id] = (node, self.scope_level[-1], set()) self.status[node.id] = (node, self.scope_level[-1], set())
else: else:
decl, loop, usage = self.status[node.id] decl, loop, usage = self.status[node.id]
......
...@@ -115,7 +115,7 @@ def test_fanout(): ...@@ -115,7 +115,7 @@ def test_fanout():
for i in range(a.shape[0] - 3): for i in range(a.shape[0] - 3):
sigma = 0.0 sigma = 0.0
for j in range(3): for j in range(3):
sigma = sigma + a[i + j] sigma += a[i + j]
sigma = sigma / three sigma = sigma / three
b[i] = sigma b[i] = sigma
return b return b
...@@ -246,7 +246,7 @@ def test_bind(): ...@@ -246,7 +246,7 @@ def test_bind():
def vec_add(a, b): def vec_add(a, b):
c = output_tensor((1000, ), dtype='float32') c = output_tensor((1000, ), dtype='float32')
for tx in bind('threadIdx.x', 1000): for tx in bind('threadIdx.x', 1000):
c[tx] = b[tx] + c[tx] c[tx] = a[tx] + b[tx]
return c return c
a = tvm.placeholder((1000, ), dtype='float32', name='a') a = tvm.placeholder((1000, ), dtype='float32', name='a')
...@@ -308,7 +308,7 @@ def test_non_zero(): ...@@ -308,7 +308,7 @@ def test_non_zero():
s = 0.0 s = 0.0
for di in range(3): for di in range(3):
for dj in range(3): for dj in range(3):
s = s + a[i-di, j-dj] s += a[i-di, j-dj]
b[i-2, j-2] = s / 9.0 b[i-2, j-2] = s / 9.0
return b return b
...@@ -419,6 +419,32 @@ def test_downstream(): ...@@ -419,6 +419,32 @@ def test_downstream():
module(tvm_a, tvm_c) module(tvm_a, tvm_c)
tvm.testing.assert_allclose(tvm_c.asnumpy(), ref, 1e-5, 1e-5) tvm.testing.assert_allclose(tvm_c.asnumpy(), ref, 1e-5, 1e-5)
def test_const_param():
@tvm.hybrid.script
def add_something(a, b):
c = output_tensor((11, ), 'int32')
for i in range(11):
c[i] = a[i] + b
return c
a = tvm.placeholder((11, ), dtype='int32', name='a')
b = tvm.const(11, 'int32')
c = add_something(a, b)
sch = tvm.create_schedule(c.op)
module = tvm.build(sch, [a, c], 'llvm')
assert(module)
np_a = numpy.arange(11).astype('int32')
np_b = 11
np_c = numpy.zeros((11, )).astype('int32')
nd_a = tvm.ndarray.array(np_a)
nd_c = tvm.ndarray.array(numpy.zeros((11, )).astype('int32'))
module(nd_a, nd_c)
ref = add_something(np_a, 11)
tvm.testing.assert_allclose(nd_c.asnumpy(), ref, 1e-5, 1e-5)
if __name__ == "__main__": if __name__ == "__main__":
test_outer_product() test_outer_product()
...@@ -432,5 +458,6 @@ if __name__ == "__main__": ...@@ -432,5 +458,6 @@ if __name__ == "__main__":
#test_inplace() #test_inplace()
test_upstream() test_upstream()
test_downstream() test_downstream()
test_const_param()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment