Commit 3b3b8cbe by Jian Weng Committed by Tianqi Chen

allows constant param in op construct (#2257)

parent d50f7b66
......@@ -144,14 +144,14 @@ class HybridParser(ast.NodeVisitor):
def visit_Name(self, node):
_id = node.id
if _id in self._args.keys() and isinstance(self._args[_id], _expr.Var):
if _id in self._args.keys() and isinstance(self._args[_id], (_expr.Var, _expr.ConstExpr)):
return self._args[_id]
elif _id in self.loops_above.keys():
return self.loops_above[_id]
_internal_assert(_id not in self._args.keys(), \
"This id %s should be handled in visit_Subscript!" % _id)
"This id %s should be handled in visit_Subscript!" % _id)
_internal_assert(_id in self.usage.keys(), \
"This id %s is expected to be a defined variable!" % _id)
"This id %s is expected to be a defined variable!" % _id)
# Buffer
if _id in self.alloc_buffers.keys():
_buf, _ = self.alloc_buffers[_id]
......@@ -166,6 +166,15 @@ class HybridParser(ast.NodeVisitor):
return _api.const(node.n)
def visit_AugAssign(self, node):
lhs = self.visit(node.target)
rhs = self.visit(node.value)
rhs = HybridParser._binop_maker[type(node.op)](lhs, rhs)
_internal_assert(isinstance(lhs, _expr.Call), \
"The LHS of an AugAssign is supposed to be a call!")
return _make.Provide(lhs.func, 0, rhs, lhs.args)
def visit_Assign(self, node):
_internal_assert(len(node.targets) == 1, "So far only one-valued assignment is supported!")
lhs = node.targets[0]
......@@ -177,7 +186,7 @@ class HybridParser(ast.NodeVisitor):
lhs_ = lhs
lhs = lhs.id
_internal_assert(lhs not in self.loops_above.keys(), \
"Loop variable cannot be overwritten!")
"Loop variable cannot be overwritten!")
decl, _, rw = self.usage[lhs]
if decl == lhs_:
_internal_assert(lhs not in self.var_consts.keys(), \
......@@ -227,16 +236,16 @@ class HybridParser(ast.NodeVisitor):
return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, 0)
_internal_assert(isinstance(node.value, ast.Attribute), \
"Only variable and attribute's subscript supported so far")
"Only variable and attribute's subscript supported so far")
_internal_assert(isinstance(node.value.value, ast.Name), \
"The root of array access is expect to be a id!")
"The root of array access is expect to be a id!")
_internal_assert(node.value.attr == "shape", \
"Attribute access so far only 'shape' is supported!")
"Attribute access so far only 'shape' is supported!")
_internal_assert(len(args) == 1, "For 'shape' access the argument should be only one!")
args = args[0]
#TODO: maybe support non-constant value later?
_internal_assert(isinstance(args, (_expr.IntImm, _expr.UIntImm)), \
"So far only constant shape access supported!")
"So far only constant shape access supported!")
buf = self._get_buffer_from_id(node.value.value.id)
return buf.shape[args.value]
......@@ -294,7 +303,7 @@ class HybridParser(ast.NodeVisitor):
def visit_Call(self, node):
# Yet, no function pointer supported
_internal_assert(isinstance(node.func, ast.Name), \
"Only id-function function call is supported so far!")
"Only id-function function call is supported so far!")
func_id = node.func.id
n = len(node.args)
if func_id in LOOP_INTRIN.keys() and func_id != 'bind':
......@@ -311,7 +320,7 @@ class HybridParser(ast.NodeVisitor):
elif func_id == 'bind':
_internal_assert(n == 2, "A loop bind should only have 2 arguments!")
_internal_assert(isinstance(node.args[0], ast.Str), \
"A loop bind's first argument should be a string!")
"A loop bind's first argument should be a string!")
_vn = node.args[0].s
iter_var = thread_axis(node.args[0].s)
low, ext = _api.const(0, dtype='int32'), self.visit(node.args[1])
......@@ -321,11 +330,11 @@ class HybridParser(ast.NodeVisitor):
return getattr(intrin, func_id)(*[self.visit(arg) for arg in node.args])
elif func_id in ['allocate', 'output_tensor']:
_internal_assert(isinstance(node.args[0], ast.Tuple), \
"allocate's first argument should be a tuple of shape!")
"allocate's first argument should be a tuple of shape!")
shape = tuple(self.visit(i) for i in node.args[0].elts)
if func_id == 'output_tensor':
_internal_assert(not self.loops_above, \
"Are you sure to allocate a output buffer multiple times?")
"Are you sure to allocate a output buffer multiple times?")
for i in shape:
_internal_assert(isinstance(i, _expr.Expr), "The shape should be an expression")
if n > 1:
......@@ -333,18 +342,18 @@ class HybridParser(ast.NodeVisitor):
dtype = node.args[1].s
else:
_internal_assert(isinstance(node.args[1], ast.Attribute), \
"Unable to evaluate to get data type")
"Unable to evaluate to get data type")
to_eval = node.args[1]
_internal_assert(isinstance(to_eval.value, ast.Name), \
"Unable to evaluate the attribute to get data type")
"Unable to evaluate the attribute to get data type")
_internal_assert(to_eval.attr == 'dtype', \
"Only dtype attribute is supported so far")
"Only dtype attribute is supported so far")
dtype = self._get_buffer_from_id(to_eval.value.id).dtype
else:
dtype = 'float32'
if n > 2:
_internal_assert(isinstance(node.args[2], ast.Str), \
"The data scope should be an string")
"The data scope should be an string")
_internal_assert(func_id != 'output_tensor', "Output tensor cannot specify scope")
scope = node.args[2].s
else:
......@@ -361,7 +370,7 @@ class HybridParser(ast.NodeVisitor):
def visit_For(self, node):
iter_var, low, ext, for_type = self.visit(node.iter)
_internal_assert(isinstance(node.target, ast.Name), \
"The loop iterator should be a variable!")
"The loop iterator should be a variable!")
_name = node.target.id
if iter_var is None:
_internal_assert(for_type is not None, "The loop bind function parse error!")
......@@ -389,7 +398,7 @@ class HybridParser(ast.NodeVisitor):
ids.append(node.value.id)
else:
_internal_assert(isinstance(node.value, ast.Tuple), \
"You should return either a single tensor or a tuple")
"You should return either a single tensor or a tuple")
for i in node.value.elts:
_internal_assert(isinstance(i, ast.Name), "What do you return?")
ids.append(i.id)
......
......@@ -15,9 +15,14 @@ from ..tensor import Tensor
#pylint: disable=invalid-name
np_arg_types = tuple(list(numeric_types) + [numpy.ndarray])
tvm_arg_types = (Tensor, _expr.Var)
tvm_arg_types = (Tensor, _expr.Var, _expr.ConstExpr)
halide_imm_types = (_expr.IntImm, _expr.FloatImm, _expr.UIntImm)
def _internal_assert(cond, err):
"""Simplify the code segment like if not XXX then raise an error"""
if not cond:
raise ValueError(err)
# Useful constants. In avoid of runtime dependences, we use function calls to return them.
def make_nop():
......@@ -50,14 +55,16 @@ def _is_tvm_arg_types(args):
If neither is true, raise a value error."""
if isinstance(args[0], tvm_arg_types):
for elem in args[1:]:
if not isinstance(elem, tvm_arg_types):
raise ValueError("Expect a Var or Tensor instance but % get!" % str(type(elem)))
_internal_assert(isinstance(elem, tvm_arg_types),
"Expecting a Var, Tensor or ConstExpr instance but %s get!" \
% str(type(elem)))
return True
if not isinstance(args[0], np_arg_types):
raise ValueError("Expect a numpy type but % get!" % str(type(args[0])))
_internal_assert(isinstance(args[0], np_arg_types), \
"Expect a numpy type but %s get!" % str(type(args[0])))
for elem in args[1:]:
if not isinstance(elem, np_arg_types):
raise ValueError("Expect a numpy type but % get!" % str(type(elem)))
_internal_assert(isinstance(elem, np_arg_types), \
"Expect a numpy type but %s get!" % str(type(elem)))
return False
......@@ -79,12 +86,3 @@ def _restore_runtime(func, intersect):
_globals.pop(elem)
for k, v in intersect:
_globals[k] = v
def _internal_assert(cond, err):
"""Simplify the code segment like if not XXX then raise an error"""
if not cond:
raise ValueError(err)
# Almost the same functionality as the one above, but in this case,
# the error is caused by users inproper usage.
_user_assert = _internal_assert
......@@ -15,6 +15,7 @@ class PyVariableUsage(ast.NodeVisitor):
self.scope_level = []
self._args = {}
self.args = args
self.aug_assign_ = False
def visit_FunctionDef(self, node):
......@@ -48,6 +49,12 @@ class PyVariableUsage(ast.NodeVisitor):
self.visit(elem)
def visit_AugAssign(self, node):
self.aug_assign_ = True
self.generic_visit(node)
self.aug_assign_ = False
def visit_Name(self, node):
# If it is from the argument list or loop variable, we do not worry about it!
if node.id in self._args.keys():
......@@ -61,7 +68,9 @@ class PyVariableUsage(ast.NodeVisitor):
if node.id not in self.status.keys():
_internal_assert(isinstance(node.ctx, ast.Store), \
'Undeclared variable %s' % node.id)
'Undeclared variable %s' % node.id)
if self.aug_assign_:
raise ValueError('"First store" cannot be an AugAssign')
self.status[node.id] = (node, self.scope_level[-1], set())
else:
decl, loop, usage = self.status[node.id]
......
......@@ -115,7 +115,7 @@ def test_fanout():
for i in range(a.shape[0] - 3):
sigma = 0.0
for j in range(3):
sigma = sigma + a[i + j]
sigma += a[i + j]
sigma = sigma / three
b[i] = sigma
return b
......@@ -246,7 +246,7 @@ def test_bind():
def vec_add(a, b):
c = output_tensor((1000, ), dtype='float32')
for tx in bind('threadIdx.x', 1000):
c[tx] = b[tx] + c[tx]
c[tx] = a[tx] + b[tx]
return c
a = tvm.placeholder((1000, ), dtype='float32', name='a')
......@@ -308,7 +308,7 @@ def test_non_zero():
s = 0.0
for di in range(3):
for dj in range(3):
s = s + a[i-di, j-dj]
s += a[i-di, j-dj]
b[i-2, j-2] = s / 9.0
return b
......@@ -419,6 +419,32 @@ def test_downstream():
module(tvm_a, tvm_c)
tvm.testing.assert_allclose(tvm_c.asnumpy(), ref, 1e-5, 1e-5)
def test_const_param():
@tvm.hybrid.script
def add_something(a, b):
c = output_tensor((11, ), 'int32')
for i in range(11):
c[i] = a[i] + b
return c
a = tvm.placeholder((11, ), dtype='int32', name='a')
b = tvm.const(11, 'int32')
c = add_something(a, b)
sch = tvm.create_schedule(c.op)
module = tvm.build(sch, [a, c], 'llvm')
assert(module)
np_a = numpy.arange(11).astype('int32')
np_b = 11
np_c = numpy.zeros((11, )).astype('int32')
nd_a = tvm.ndarray.array(np_a)
nd_c = tvm.ndarray.array(numpy.zeros((11, )).astype('int32'))
module(nd_a, nd_c)
ref = add_something(np_a, 11)
tvm.testing.assert_allclose(nd_c.asnumpy(), ref, 1e-5, 1e-5)
if __name__ == "__main__":
test_outer_product()
......@@ -432,5 +458,6 @@ if __name__ == "__main__":
#test_inplace()
test_upstream()
test_downstream()
test_const_param()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment