Commit 3b3b8cbe by Jian Weng Committed by Tianqi Chen

allows constant param in op construct (#2257)

parent d50f7b66
......@@ -144,7 +144,7 @@ class HybridParser(ast.NodeVisitor):
def visit_Name(self, node):
_id = node.id
if _id in self._args.keys() and isinstance(self._args[_id], _expr.Var):
if _id in self._args.keys() and isinstance(self._args[_id], (_expr.Var, _expr.ConstExpr)):
return self._args[_id]
elif _id in self.loops_above.keys():
return self.loops_above[_id]
......@@ -166,6 +166,15 @@ class HybridParser(ast.NodeVisitor):
return _api.const(node.n)
def visit_AugAssign(self, node):
lhs = self.visit(node.target)
rhs = self.visit(node.value)
rhs = HybridParser._binop_maker[type(node.op)](lhs, rhs)
_internal_assert(isinstance(lhs, _expr.Call), \
"The LHS of an AugAssign is supposed to be a call!")
return _make.Provide(lhs.func, 0, rhs, lhs.args)
def visit_Assign(self, node):
_internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")
lhs = node.targets[0]
......
......@@ -15,9 +15,14 @@ from ..tensor import Tensor
#pylint: disable=invalid-name
np_arg_types = tuple(list(numeric_types) + [numpy.ndarray])
tvm_arg_types = (Tensor, _expr.Var)
tvm_arg_types = (Tensor, _expr.Var, _expr.ConstExpr)
halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm)
def _internal_assert(cond, err):
"""Simplify the code segment like if not XXX then raise an error"""
if not cond:
raise ValueError(err)
# Useful constants. In avoid of runtime dependences, we use function calls to return them.
def make_nop():
......@@ -50,14 +55,16 @@ def _is_tvm_arg_types(args):
If neither is true, raise a value error."""
if isinstance(args[0], tvm_arg_types):
for elem in args[1:]:
if not isinstance(elem, tvm_arg_types):
raise ValueError("Expect a Var or Tensor instance but % get!" % str(type(elem)))
_internal_assert(isinstance(elem, tvm_arg_types),
"Expecting a Var, Tensor or ConstExpr instance but %s get!" \
% str(type(elem)))
return True
if not isinstance(args[0], np_arg_types):
raise ValueError("Expect a numpy type but % get!" % str(type(args[0])))
_internal_assert(isinstance(args[0], np_arg_types), \
"Expect a numpy type but %s get!" % str(type(args[0])))
for elem in args[1:]:
if not isinstance(elem, np_arg_types):
raise ValueError("Expect a numpy type but % get!" % str(type(elem)))
_internal_assert(isinstance(elem, np_arg_types), \
"Expect a numpy type but %s get!" % str(type(elem)))
return False
......@@ -79,12 +86,3 @@ def _restore_runtime(func, intersect):
_globals.pop(elem)
for k, v in intersect:
_globals[k] = v
def _internal_assert(cond, err):
"""Simplify the code segment like if not XXX then raise an error"""
if not cond:
raise ValueError(err)
# Almost the same functionality as the one above, but in this case,
# the error is caused by users inproper usage.
_user_assert = _internal_assert
......@@ -15,6 +15,7 @@ class PyVariableUsage(ast.NodeVisitor):
self.scope_level = []
self._args = {}
self.args = args
self.aug_assign_ = False
def visit_FunctionDef(self, node):
......@@ -48,6 +49,12 @@ class PyVariableUsage(ast.NodeVisitor):
self.visit(elem)
def visit_AugAssign(self, node):
self.aug_assign_ = True
self.generic_visit(node)
self.aug_assign_ = False
def visit_Name(self, node):
# If it is from the argument list or loop variable, we do not worry about it!
if node.id in self._args.keys():
......@@ -62,6 +69,8 @@ class PyVariableUsage(ast.NodeVisitor):
if node.id not in self.status.keys():
_internal_assert(isinstance(node.ctx, ast.Store), \
'Undeclared variable %s' % node.id)
if self.aug_assign_:
raise ValueError('"First store" cannot be an AugAssign')
self.status[node.id] = (node, self.scope_level[-1], set())
else:
decl, loop, usage = self.status[node.id]
......
......@@ -115,7 +115,7 @@ def test_fanout():
for i in range(a.shape[0] - 3):
sigma = 0.0
for j in range(3):
sigma = sigma + a[i + j]
sigma += a[i + j]
sigma = sigma / three
b[i] = sigma
return b
......@@ -246,7 +246,7 @@ def test_bind():
def vec_add(a, b):
c = output_tensor((1000, ), dtype='float32')
for tx in bind('threadIdx.x', 1000):
c[tx] = b[tx] + c[tx]
c[tx] = a[tx] + b[tx]
return c
a = tvm.placeholder((1000, ), dtype='float32', name='a')
......@@ -308,7 +308,7 @@ def test_non_zero():
s = 0.0
for di in range(3):
for dj in range(3):
s = s + a[i-di, j-dj]
s += a[i-di, j-dj]
b[i-2, j-2] = s / 9.0
return b
......@@ -419,6 +419,32 @@ def test_downstream():
module(tvm_a, tvm_c)
tvm.testing.assert_allclose(tvm_c.asnumpy(), ref, 1e-5, 1e-5)
def test_const_param():
@tvm.hybrid.script
def add_something(a, b):
c = output_tensor((11, ), 'int32')
for i in range(11):
c[i] = a[i] + b
return c
a = tvm.placeholder((11, ), dtype='int32', name='a')
b = tvm.const(11, 'int32')
c = add_something(a, b)
sch = tvm.create_schedule(c.op)
module = tvm.build(sch, [a, c], 'llvm')
assert(module)
np_a = numpy.arange(11).astype('int32')
np_b = 11
np_c = numpy.zeros((11, )).astype('int32')
nd_a = tvm.ndarray.array(np_a)
nd_c = tvm.ndarray.array(numpy.zeros((11, )).astype('int32'))
module(nd_a, nd_c)
ref = add_something(np_a, 11)
tvm.testing.assert_allclose(nd_c.asnumpy(), ref, 1e-5, 1e-5)
if __name__ == "__main__":
test_outer_product()
......@@ -432,5 +458,6 @@ if __name__ == "__main__":
#test_inplace()
test_upstream()
test_downstream()
test_const_param()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment