Commit b8fedfb1 by Jian Weng Committed by Tianqi Chen

[FRONTEND] [HYBRID] Augmented assign operator supported! (#1459)

parent 84eea572
......@@ -15,7 +15,7 @@ from .. import ir_pass as _ir_pass
def list_to_block(visit, lst):
"""Convert a list of Python IR nodes to HalideIR Block"""
lst = list(map(visit, lst))
lst = [visit(i) for i in lst]
lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, make_nop())]
if not lst:
return make_nop()
......@@ -162,6 +162,13 @@ class HybridParser(ast.NodeVisitor):
def visit_Num(self, node):
return _api.const(node.n)
def visit_AugAssign(self, node):
lhs = self.visit(node.target)
rhs = self.visit(node.value)
rhs = HybridParser._binop_maker[type(node.op)](lhs, rhs)
if not isinstance(lhs, _expr.Call):
raise ValueError("The LHS of an AugAssign is supposed to be a call!")
return _make.Provide(lhs.func, 0, rhs, lhs.args)
def visit_Assign(self, node):
if len(node.targets) != 1:
......
......@@ -14,6 +14,7 @@ class PyVariableUsage(ast.NodeVisitor):
self.scope_level = []
self._args = {}
self.args = args
self.aug_assign_ = False
def visit_FunctionDef(self, node):
......@@ -48,6 +49,12 @@ class PyVariableUsage(ast.NodeVisitor):
self.visit(elem)
def visit_AugAssign(self, node):
self.aug_assign_ = True
self.generic_visit(node)
self.aug_assign_ = False
def visit_Name(self, node):
# If it is from the argument list or loop variable, we do not worry about it!
if node.id in self._args.keys():
......@@ -62,6 +69,8 @@ class PyVariableUsage(ast.NodeVisitor):
if node.id not in self.status.keys():
if not isinstance(node.ctx, ast.Store):
raise ValueError('In Python, "first store" indicates "declaration"')
if self.aug_assign_:
raise ValueError('"First store" cannot be an AugAssign')
self.status[node.id] = (node, self.scope_level[-1], set())
else:
decl, loop, usage = self.status[node.id]
......
......@@ -38,7 +38,9 @@ def run_and_check(func, args, outs, var_dict={}, target='llvm'):
module(*nd_args)
for nd, np in to_check:
numpy.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-5, atol=1e-5)
numpy.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-3, atol=1e-3)
return module
@script
......@@ -83,7 +85,7 @@ def test_outer_product():
func = tvm.lower(ir, [n, m, a, b, c])
func = tvm.build(func)
run_and_check(outer_product, [n, m, a, b, c], [c], {n: 999, m: 1001})
run_and_check(outer_product, [n, m, a, b, c], [c], {n: 99, m: 101})
for key, _ in HYBRID_GLOBALS.items():
assert key not in globals().keys()
......@@ -165,20 +167,32 @@ def test_fanout():
run_and_check(fanout, [n, a, b], [b], {n: 10})
@script
def failure():
for i in range(1, 100):
i = 0
def test_failure():
try:
@script
def failure():
for i in range(1, 100):
i = 0
tvm.hybrid.parse(failure, [])
except IOError as err:
assert sys.version_info[0] == 2
print('[Warning] Case test_failure is skipped by Python2 because "%s"' % str(err))
except Exception as err:
print('[Warning] Case test_failure.0 is skipped by Python2 because "%s"' % str(err))
except ValueError as err:
assert str(err) == 'You CAN NEVER overwrite a loop variable!'
try:
@tvm.hybrid.script
def augdefine():
for i in range(10):
es += 0
tvm.hybrid.parse(augdefine, [])
except IOError as err:
assert sys.version_info[0] == 2
print('[Warning] Case test_failure.1 is skipped by Python2 because "%s"' % str(err))
except ValueError as err:
assert str(err) == '"First store" cannot be an AugAssign'
def test_looptype():
@script
......@@ -280,7 +294,7 @@ def test_non_zero():
s = 0.0
for di in range(3):
for dj in range(3):
s = s + a[i-di, j-dj]
s += a[i-di, j-dj]
b[i-2, j-2] = s / 9.0
try:
a = tvm.placeholder((32, 32), 'float32', 'a')
......@@ -315,29 +329,39 @@ def test_allocate():
a = tvm.placeholder((32, 32), 'float32', 'a')
b = tvm.placeholder((30, 30), 'float32', 'b')
run_and_check(blur2d, [a, b], [b])
if tvm.gpu().exist:
@tvm.hybrid.script
def share_vec_add(a, b, c):
shared = allocate((256, ), 'float32', 'shared')
for i in bind("threadIdx.x", 256):
shared[i] = a[i]
local = allocate((256, ), 'float32', 'local')
for i in bind("threadIdx.x", 256):
local[i] = b[i]
for i in bind("threadIdx.x", 256):
c[i] = shared[i] + local[i]
a = tvm.placeholder((256, ), dtype='float32', name='a')
b = tvm.placeholder((256, ), dtype='float32', name='b')
c = tvm.placeholder((256, ), dtype='float32', name='c')
run_and_check(share_vec_add, [a, b, c], [c], target='cuda')
def shared_gemm(a, b, c):
for io in bind('blockIdx.x', 8):
for ii in bind('blockIdx.y', 8):
shared_b = allocate((64, 64), 'float32', 'shared')
for k in range(64):
shared_b[io * 8 + ii, k] = b[io * 8 + ii, k]
for jo in bind('threadIdx.y', 8):
for ji in bind('threadIdx.x', 8):
for k in range(64):
c[io*8+ii, jo*8+ji] += a[io*8+ii, k] * shared_b[k, jo*8+ji]
a = tvm.placeholder((64, 64), dtype='float32', name='a')
b = tvm.placeholder((64, 64), dtype='float32', name='b')
c = tvm.placeholder((64, 64), dtype='float32', name='c')
module = run_and_check(shared_gemm, [a, b, c], [c], target='cuda')
assert "__syncthreads()" in module.imported_modules[0].get_source()
else:
print('[Warning] No GPU found! Skip shared mem test!')
def test_augassign():
@tvm.hybrid.script
def augassign(a):
for i in range(a.shape[0]):
a[i] += 1.0
a = tvm.placeholder((16, ), dtype='float32', name='a')
run_and_check(augassign, [a], [a])
if __name__ == "__main__":
test_outer_product()
test_fanout()
......@@ -348,4 +372,5 @@ if __name__ == "__main__":
test_math_intrin()
test_non_zero()
test_allocate()
test_augassign()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment