Commit b8fedfb1 by Jian Weng Committed by Tianqi Chen

[FRONTEND] [HYBRID] Augmented assign operator supported! (#1459)

parent 84eea572
...@@ -15,7 +15,7 @@ from .. import ir_pass as _ir_pass ...@@ -15,7 +15,7 @@ from .. import ir_pass as _ir_pass
def list_to_block(visit, lst): def list_to_block(visit, lst):
"""Convert a list of Python IR nodes to HalideIR Block""" """Convert a list of Python IR nodes to HalideIR Block"""
lst = list(map(visit, lst)) lst = [visit(i) for i in lst]
lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, make_nop())] lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, make_nop())]
if not lst: if not lst:
return make_nop() return make_nop()
...@@ -162,6 +162,13 @@ class HybridParser(ast.NodeVisitor): ...@@ -162,6 +162,13 @@ class HybridParser(ast.NodeVisitor):
def visit_Num(self, node): def visit_Num(self, node):
return _api.const(node.n) return _api.const(node.n)
def visit_AugAssign(self, node):
    """Lower an augmented assignment (``target op= value``) to a Provide.

    Both sides are visited, the lowered target is checked to be a
    ``_expr.Call``, and only then are the two combined with the binary
    operator registered for ``type(node.op)`` in
    ``HybridParser._binop_maker``.  The combined value is written back to
    the same function/arguments the target call refers to.

    Raises
    ------
    ValueError
        If the lowered target is not an ``_expr.Call``.
    """
    lhs = self.visit(node.target)
    rhs = self.visit(node.value)
    # Validate the target *before* using it as an operand, so an invalid
    # target reports this error rather than whatever the operator helper
    # would raise when handed an unexpected object.
    if not isinstance(lhs, _expr.Call):
        raise ValueError("The LHS of an AugAssign is supposed to be a call!")
    rhs = HybridParser._binop_maker[type(node.op)](lhs, rhs)
    return _make.Provide(lhs.func, 0, rhs, lhs.args)
def visit_Assign(self, node): def visit_Assign(self, node):
if len(node.targets) != 1: if len(node.targets) != 1:
......
...@@ -14,6 +14,7 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -14,6 +14,7 @@ class PyVariableUsage(ast.NodeVisitor):
self.scope_level = [] self.scope_level = []
self._args = {} self._args = {}
self.args = args self.args = args
self.aug_assign_ = False
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
...@@ -48,6 +49,12 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -48,6 +49,12 @@ class PyVariableUsage(ast.NodeVisitor):
self.visit(elem) self.visit(elem)
def visit_AugAssign(self, node):
    """Visit an augmented assignment with the ``aug_assign_`` flag raised.

    The flag is set while the node's children are visited so that other
    visitor methods can tell they are inside an AugAssign.  It is cleared
    in a ``finally`` clause: visiting the children may raise, and the flag
    must not stay set on this visitor after a failed visit.
    """
    self.aug_assign_ = True
    try:
        self.generic_visit(node)
    finally:
        self.aug_assign_ = False
def visit_Name(self, node): def visit_Name(self, node):
# If it is from the argument list or loop variable, we do not worry about it! # If it is from the argument list or loop variable, we do not worry about it!
if node.id in self._args.keys(): if node.id in self._args.keys():
...@@ -62,6 +69,8 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -62,6 +69,8 @@ class PyVariableUsage(ast.NodeVisitor):
if node.id not in self.status.keys(): if node.id not in self.status.keys():
if not isinstance(node.ctx, ast.Store): if not isinstance(node.ctx, ast.Store):
raise ValueError('In Python, "first store" indicates "declaration"') raise ValueError('In Python, "first store" indicates "declaration"')
if self.aug_assign_:
raise ValueError('"First store" cannot be an AugAssign')
self.status[node.id] = (node, self.scope_level[-1], set()) self.status[node.id] = (node, self.scope_level[-1], set())
else: else:
decl, loop, usage = self.status[node.id] decl, loop, usage = self.status[node.id]
......
...@@ -38,7 +38,9 @@ def run_and_check(func, args, outs, var_dict={}, target='llvm'): ...@@ -38,7 +38,9 @@ def run_and_check(func, args, outs, var_dict={}, target='llvm'):
module(*nd_args) module(*nd_args)
for nd, np in to_check: for nd, np in to_check:
numpy.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-5, atol=1e-5) numpy.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-3, atol=1e-3)
return module
@script @script
...@@ -83,7 +85,7 @@ def test_outer_product(): ...@@ -83,7 +85,7 @@ def test_outer_product():
func = tvm.lower(ir, [n, m, a, b, c]) func = tvm.lower(ir, [n, m, a, b, c])
func = tvm.build(func) func = tvm.build(func)
run_and_check(outer_product, [n, m, a, b, c], [c], {n: 999, m: 1001}) run_and_check(outer_product, [n, m, a, b, c], [c], {n: 99, m: 101})
for key, _ in HYBRID_GLOBALS.items(): for key, _ in HYBRID_GLOBALS.items():
assert key not in globals().keys() assert key not in globals().keys()
...@@ -165,20 +167,32 @@ def test_fanout(): ...@@ -165,20 +167,32 @@ def test_fanout():
run_and_check(fanout, [n, a, b], [b], {n: 10}) run_and_check(fanout, [n, a, b], [b], {n: 10})
@script
def failure():
for i in range(1, 100):
i = 0
def test_failure(): def test_failure():
try: try:
@script
def failure():
for i in range(1, 100):
i = 0
tvm.hybrid.parse(failure, []) tvm.hybrid.parse(failure, [])
except IOError as err: except IOError as err:
assert sys.version_info[0] == 2 assert sys.version_info[0] == 2
print('[Warning] Case test_failure is skipped by Python2 because "%s"' % str(err)) print('[Warning] Case test_failure.0 is skipped by Python2 because "%s"' % str(err))
except Exception as err: except ValueError as err:
assert str(err) == 'You CAN NEVER overwrite a loop variable!' assert str(err) == 'You CAN NEVER overwrite a loop variable!'
try:
@tvm.hybrid.script
def augdefine():
for i in range(10):
es += 0
tvm.hybrid.parse(augdefine, [])
except IOError as err:
assert sys.version_info[0] == 2
print('[Warning] Case test_failure.1 is skipped by Python2 because "%s"' % str(err))
except ValueError as err:
assert str(err) == '"First store" cannot be an AugAssign'
def test_looptype(): def test_looptype():
@script @script
...@@ -280,7 +294,7 @@ def test_non_zero(): ...@@ -280,7 +294,7 @@ def test_non_zero():
s = 0.0 s = 0.0
for di in range(3): for di in range(3):
for dj in range(3): for dj in range(3):
s = s + a[i-di, j-dj] s += a[i-di, j-dj]
b[i-2, j-2] = s / 9.0 b[i-2, j-2] = s / 9.0
try: try:
a = tvm.placeholder((32, 32), 'float32', 'a') a = tvm.placeholder((32, 32), 'float32', 'a')
...@@ -315,29 +329,39 @@ def test_allocate(): ...@@ -315,29 +329,39 @@ def test_allocate():
a = tvm.placeholder((32, 32), 'float32', 'a') a = tvm.placeholder((32, 32), 'float32', 'a')
b = tvm.placeholder((30, 30), 'float32', 'b') b = tvm.placeholder((30, 30), 'float32', 'b')
run_and_check(blur2d, [a, b], [b]) run_and_check(blur2d, [a, b], [b])
if tvm.gpu().exist: if tvm.gpu().exist:
@tvm.hybrid.script @tvm.hybrid.script
def share_vec_add(a, b, c): def shared_gemm(a, b, c):
shared = allocate((256, ), 'float32', 'shared') for io in bind('blockIdx.x', 8):
for i in bind("threadIdx.x", 256): for ii in bind('blockIdx.y', 8):
shared[i] = a[i] shared_b = allocate((64, 64), 'float32', 'shared')
local = allocate((256, ), 'float32', 'local') for k in range(64):
for i in bind("threadIdx.x", 256): shared_b[io * 8 + ii, k] = b[io * 8 + ii, k]
local[i] = b[i] for jo in bind('threadIdx.y', 8):
for i in bind("threadIdx.x", 256): for ji in bind('threadIdx.x', 8):
c[i] = shared[i] + local[i] for k in range(64):
c[io*8+ii, jo*8+ji] += a[io*8+ii, k] * shared_b[k, jo*8+ji]
a = tvm.placeholder((256, ), dtype='float32', name='a')
b = tvm.placeholder((256, ), dtype='float32', name='b') a = tvm.placeholder((64, 64), dtype='float32', name='a')
c = tvm.placeholder((256, ), dtype='float32', name='c') b = tvm.placeholder((64, 64), dtype='float32', name='b')
run_and_check(share_vec_add, [a, b, c], [c], target='cuda') c = tvm.placeholder((64, 64), dtype='float32', name='c')
module = run_and_check(shared_gemm, [a, b, c], [c], target='cuda')
assert "__syncthreads()" in module.imported_modules[0].get_source()
else: else:
print('[Warning] No GPU found! Skip shared mem test!') print('[Warning] No GPU found! Skip shared mem test!')
def test_augassign():
    """Check that ``+=`` on a buffer element works in a hybrid script."""
    # The kernel body is left exactly as written: its source text is what
    # the hybrid frontend parses.
    @tvm.hybrid.script
    def augassign(a):
        for i in range(a.shape[0]):
            a[i] += 1.0

    # The same placeholder is passed both as the argument and as the
    # output to check, because the kernel updates it in place.
    vec = tvm.placeholder((16, ), dtype='float32', name='a')
    run_and_check(augassign, [vec], [vec])
if __name__ == "__main__": if __name__ == "__main__":
test_outer_product() test_outer_product()
test_fanout() test_fanout()
...@@ -348,4 +372,5 @@ if __name__ == "__main__": ...@@ -348,4 +372,5 @@ if __name__ == "__main__":
test_math_intrin() test_math_intrin()
test_non_zero() test_non_zero()
test_allocate() test_allocate()
test_augassign()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment