Commit d29b1c9e by Jian Weng Committed by Tianqi Chen

[FRONTEND] [HYBRID] Non-zero starting supported; Buffer AttrStmt add! (#1330)

parent 9b8cb1b6
...@@ -29,7 +29,7 @@ class bind(_range): #pylint: disable=invalid-name ...@@ -29,7 +29,7 @@ class bind(_range): #pylint: disable=invalid-name
unroll = vectorize = parallel = _range #pylint: disable=invalid-name unroll = vectorize = parallel = _range #pylint: disable=invalid-name
def allocate(shape, dtype='float32'): def allocate(shape, dtype='float32', scope='global'): #pylint: disable=unused-argument
"""Allocate a buffer with given shape """Allocate a buffer with given shape
Parameters Parameters
...@@ -38,6 +38,8 @@ def allocate(shape, dtype='float32'): ...@@ -38,6 +38,8 @@ def allocate(shape, dtype='float32'):
The shape of the tensor to be allocated The shape of the tensor to be allocated
dtype: string dtype: string
The data type of the tensor The data type of the tensor
scope: string
The storage scope of the tensor
Returns Returns
------- -------
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import ast import ast
import operator import operator
import sys import sys
from .util import make_nop, make_const_true, make_range_one, halide_imm_types from .util import make_nop, halide_imm_types
from .intrin import LOOP_INTRIN, MATH_INTRIN from .intrin import LOOP_INTRIN, MATH_INTRIN
from .var_decl import determine_variable_usage from .var_decl import determine_variable_usage
from ..api import thread_axis from ..api import thread_axis
...@@ -75,7 +75,8 @@ class HybridParser(ast.NodeVisitor): ...@@ -75,7 +75,8 @@ class HybridParser(ast.NodeVisitor):
self.args = args[:] self.args = args[:]
self.usage = usage.copy() self.usage = usage.copy()
self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer) self._args = {} # Dict maps arg name to actual arg instance (either a var or a buffer)
self.buffers = {} self.var_buffers = {} # Buffers formed by mutable variables
self.alloc_buffers = {} # Buffers formed by allocate instructions
self.loops_above = {} # State variable that indicates loop levels above the current node self.loops_above = {} # State variable that indicates loop levels above the current node
self.var_consts = {} # Variables that are determined as readonly in previous stage self.var_consts = {} # Variables that are determined as readonly in previous stage
self.func_name = func_name # The name of the function to be lowered self.func_name = func_name # The name of the function to be lowered
...@@ -87,19 +88,30 @@ class HybridParser(ast.NodeVisitor): ...@@ -87,19 +88,30 @@ class HybridParser(ast.NodeVisitor):
for key, val in self.usage.items(): for key, val in self.usage.items():
if key in self.var_consts.keys(): if key in self.var_consts.keys():
continue continue
_, scope, _ = val _, level, _ = val
if scope == node: if level == node:
_buf = self.buffers[key] if key in self.var_buffers.keys():
_buf = self.var_buffers[key]
_scope = 'global'
else:
_buf, _scope = self.alloc_buffers[key]
_domain = [_make.range_by_min_extent(0, i) for i in _buf.shape]
_dtype = _buf.dtype _dtype = _buf.dtype
_one = make_range_one() _true = _api.convert(True)
_true = make_const_true() body = _make.Realize(_buf.op, 0, _dtype, _domain, _true, body)
body = _make.Realize(_buf.op, 0, _dtype, [_one], _true, body) body = _make.AttrStmt(_buf.op, 'realize_scope', _api.convert(_scope), body)
return body return body
def _check_id_a_buffer(self, s): def _get_buffer_from_id(self, s):
if s not in self._args.keys(): if s not in self._args.keys() and s not in self.alloc_buffers.keys():
raise ValueError("This %s is expected to be in argument list or allocated buffer!" % s) raise ValueError("This %s is expected to be in argument list or allocated buffer!" % s)
if s in self._args.keys() and s in self.alloc_buffers.keys():
raise ValueError("%s, a buffer cannot be both argument and allocated!" % s)
if s in self._args.keys():
return self._args[s]
return self.alloc_buffers[s][0]
#pylint: disable=invalid-name, missing-docstring #pylint: disable=invalid-name, missing-docstring
...@@ -138,8 +150,8 @@ class HybridParser(ast.NodeVisitor): ...@@ -138,8 +150,8 @@ class HybridParser(ast.NodeVisitor):
if _id not in self.usage.keys(): if _id not in self.usage.keys():
raise ValueError("This id %s is expected to be a defined variable!" % _id) raise ValueError("This id %s is expected to be a defined variable!" % _id)
# Buffer # Buffer
if _id in self.buffers.keys(): if _id in self.var_buffers.keys():
_buf = self.buffers[_id] _buf = self.var_buffers[_id]
return _make.Call(_buf.dtype, _id, [_api.const(0)], _expr.Call.Halide, _buf.op, 0) return _make.Call(_buf.dtype, _id, [_api.const(0)], _expr.Call.Halide, _buf.op, 0)
# Compilation time constant # Compilation time constant
if _id not in self.var_consts.keys(): if _id not in self.var_consts.keys():
...@@ -155,7 +167,9 @@ class HybridParser(ast.NodeVisitor): ...@@ -155,7 +167,9 @@ class HybridParser(ast.NodeVisitor):
if len(node.targets) != 1: if len(node.targets) != 1:
raise ValueError("So far only one-valued assignment is supported!") raise ValueError("So far only one-valued assignment is supported!")
lhs = node.targets[0] lhs = node.targets[0]
rhs = _ir_pass.Simplify(self.visit(node.value)) rhs = self.visit(node.value)
if isinstance(rhs, _expr.Expr):
rhs = _ir_pass.Simplify(rhs)
if isinstance(lhs, ast.Name): if isinstance(lhs, ast.Name):
#TODO: support defined intermediate buffer later #TODO: support defined intermediate buffer later
lhs_ = lhs lhs_ = lhs
...@@ -166,25 +180,31 @@ class HybridParser(ast.NodeVisitor): ...@@ -166,25 +180,31 @@ class HybridParser(ast.NodeVisitor):
if decl == lhs_: if decl == lhs_:
if lhs in self.var_consts.keys(): if lhs in self.var_consts.keys():
raise ValueError("BUG: A constant cannot be overwritten!") raise ValueError("BUG: A constant cannot be overwritten!")
if lhs in self.buffers.keys(): if lhs in self.var_buffers.keys() or lhs in self.alloc_buffers.keys():
raise ValueError("BUG: This value should not be defined before this point!") raise ValueError("BUG: This value should not be defined before this point!")
if isinstance(rhs, tuple):
shape, dtype, scope = rhs
ph = _api.placeholder(shape, dtype=dtype, name=lhs)
self.alloc_buffers[lhs] = (ph, scope)
return make_nop()
if isinstance(rhs, halide_imm_types) and ast.Store not in rw: if isinstance(rhs, halide_imm_types) and ast.Store not in rw:
self.var_consts[lhs] = rhs self.var_consts[lhs] = rhs
else: else:
self.buffers[lhs] = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs) self.var_buffers[lhs] = _api.placeholder((1, ), dtype=rhs.dtype, name=lhs)
if lhs in self.var_consts.keys(): if lhs in self.var_consts.keys():
return make_nop() return make_nop()
else: else:
if lhs not in self.buffers.keys(): if lhs not in self.var_buffers.keys():
raise ValueError("BUG: This value should be defined before!") raise ValueError("BUG: This variable should be defined before!")
return _make.Provide(self.buffers[lhs].op, 0, rhs, [_api.const(0, dtype=rhs.dtype)]) tgt = self.var_buffers[lhs]
return _make.Provide(tgt.op, 0, rhs, [_api.const(0, dtype=rhs.dtype)])
else: else:
lhs = self.visit(lhs) lhs = self.visit(lhs)
if not isinstance(lhs, _expr.Call): if not isinstance(lhs, _expr.Call):
raise ValueError("An array access's LHS is expected to be a expr.Call!") raise ValueError("An array access's LHS is expected to be a expr.Call!")
#TODO: support slice later #TODO: support slice later
self._check_id_a_buffer(lhs.name) buf = self._get_buffer_from_id(lhs.name)
return _make.Provide(self._args[lhs.name].op, 0, rhs, lhs.args) return _make.Provide(buf.op, 0, rhs, lhs.args)
def visit_Index(self, node): def visit_Index(self, node):
...@@ -197,8 +217,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -197,8 +217,7 @@ class HybridParser(ast.NodeVisitor):
args = self.visit(node.slice) args = self.visit(node.slice)
if isinstance(node.value, ast.Name): if isinstance(node.value, ast.Name):
array = node.value.id array = node.value.id
self._check_id_a_buffer(array) _buf = self._get_buffer_from_id(array)
_buf = self._args[array]
return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, 0) return _make.Call(_buf.dtype, array, args, _expr.Call.Halide, _buf.op, 0)
elif isinstance(node.value, ast.Attribute): elif isinstance(node.value, ast.Attribute):
if not isinstance(node.value.value, ast.Name): if not isinstance(node.value.value, ast.Name):
...@@ -211,8 +230,8 @@ class HybridParser(ast.NodeVisitor): ...@@ -211,8 +230,8 @@ class HybridParser(ast.NodeVisitor):
#TODO: maybe support non-constant value later? #TODO: maybe support non-constant value later?
if not isinstance(args, (_expr.IntImm, _expr.UIntImm)): if not isinstance(args, (_expr.IntImm, _expr.UIntImm)):
raise ValueError("So far only constant shape access supported!") raise ValueError("So far only constant shape access supported!")
self._check_id_a_buffer(node.value.value.id) buf = self._get_buffer_from_id(node.value.value.id)
return self._args[node.value.value.id].shape[args.value] return buf.shape[args.value]
else: else:
raise ValueError("Not supported yet!") raise ValueError("Not supported yet!")
...@@ -303,8 +322,30 @@ class HybridParser(ast.NodeVisitor): ...@@ -303,8 +322,30 @@ class HybridParser(ast.NodeVisitor):
elif func_id in MATH_INTRIN: elif func_id in MATH_INTRIN:
return getattr(intrin, func_id)(*[self.visit(arg) for arg in node.args]) return getattr(intrin, func_id)(*[self.visit(arg) for arg in node.args])
elif func_id == 'allocate': elif func_id == 'allocate':
#TODO: Support it later! if not isinstance(node.args[0], ast.Tuple):
return make_nop() raise ValueError("allocate's first argument should be a tuple of shape!")
shape = tuple(self.visit(i) for i in node.args[0].elts)
for i in shape:
if not isinstance(i, _expr.Expr):
raise ValueError("The shape should be an expression")
if n > 1:
if not isinstance(node.args[1], ast.Str):
raise ValueError("The data type should be an string")
dtype = node.args[1].s
else:
dtype = 'float32'
if n > 2:
if not isinstance(node.args[2], ast.Str):
raise ValueError("The data type should be an string")
scope = node.args[2].s
else:
scope = 'global'
return (shape, dtype, scope)
elif func_id == 'max' or func_id == 'min':
if n != 2:
raise ValueError("Max/Min function should have 2 elements")
a, b = self.visit(node.args[0]), self.visit(node.args[1])
return getattr(_make, func_id.title())(a, b)
else: else:
raise ValueError("Function call not supported yet!") raise ValueError("Function call not supported yet!")
...@@ -317,8 +358,10 @@ class HybridParser(ast.NodeVisitor): ...@@ -317,8 +358,10 @@ class HybridParser(ast.NodeVisitor):
if iter_var is None: if iter_var is None:
if for_type is None: if for_type is None:
raise ValueError("The loop bind function parse error!") raise ValueError("The loop bind function parse error!")
iter_var = _api.var(_name) offset = iter_var = _api.var(_name)
self.loops_above[_name] = iter_var if not _ir_pass.Equal(low, _api.const(0, dtype='int32')):
offset = iter_var + low
self.loops_above[_name] = offset
else: else:
if for_type is not None: if for_type is not None:
raise ValueError("The loop iterating function parse error!") raise ValueError("The loop iterating function parse error!")
...@@ -328,7 +371,7 @@ class HybridParser(ast.NodeVisitor): ...@@ -328,7 +371,7 @@ class HybridParser(ast.NodeVisitor):
if for_type is None: if for_type is None:
res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body) res = _make.AttrStmt(iter_var, 'thread_extent', ext, _body)
else: else:
res = _make.For(iter_var, low, ext, for_type, 0, _body) res = _make.For(iter_var, _api.const(0, dtype='int32'), ext, for_type, 0, _body)
self.loops_above.pop(_name) self.loops_above.pop(_name)
return res return res
......
...@@ -22,16 +22,6 @@ def make_nop(): ...@@ -22,16 +22,6 @@ def make_nop():
return _make.Evaluate(_api.const(0, dtype='int32')) return _make.Evaluate(_api.const(0, dtype='int32'))
def make_range_one():
"""Returns a [0, 1] range node in HalideIR."""
return _make.range_by_min_extent(0, 1)
def make_const_true():
"""Returns a constant True node in HalideIR."""
return _api.convert(True)
def _pruned_source(func): def _pruned_source(func):
"""Prune source code's extra leading spaces""" """Prune source code's extra leading spaces"""
lines = inspect.getsource(func).split('\n') lines = inspect.getsource(func).split('\n')
......
...@@ -41,7 +41,8 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -41,7 +41,8 @@ class PyVariableUsage(ast.NodeVisitor):
#No function pointer supported so far #No function pointer supported so far
if not isinstance(node.func, ast.Name): if not isinstance(node.func, ast.Name):
raise ValueError("Function call should be an id") raise ValueError("Function call should be an id")
if (node.func.id not in HYBRID_GLOBALS.keys()) and node.func.id != 'range': func_id = node.func.id
if func_id not in list(HYBRID_GLOBALS.keys()) + ['range', 'max', 'min']:
raise ValueError("Function call id not in intrinsics' list") raise ValueError("Function call id not in intrinsics' list")
for elem in node.args: for elem in node.args:
self.visit(elem) self.visit(elem)
...@@ -64,7 +65,6 @@ class PyVariableUsage(ast.NodeVisitor): ...@@ -64,7 +65,6 @@ class PyVariableUsage(ast.NodeVisitor):
self.status[node.id] = (node, self.scope_level[-1], set()) self.status[node.id] = (node, self.scope_level[-1], set())
else: else:
decl, loop, usage = self.status[node.id] decl, loop, usage = self.status[node.id]
loop = self.scope_level[-1]
usage.add(type(node.ctx)) usage.add(type(node.ctx))
self.status[node.id] = (decl, loop, usage) self.status[node.id] = (decl, loop, usage)
......
...@@ -2,6 +2,7 @@ import tvm, inspect, sys, traceback, numpy ...@@ -2,6 +2,7 @@ import tvm, inspect, sys, traceback, numpy
from tvm.hybrid import script from tvm.hybrid import script
from tvm.hybrid.intrin import HYBRID_GLOBALS from tvm.hybrid.intrin import HYBRID_GLOBALS
@script @script
def outer_product(n, m, a, b, c): def outer_product(n, m, a, b, c):
for i in range(n): for i in range(n):
...@@ -56,6 +57,7 @@ def test_outer_product(): ...@@ -56,6 +57,7 @@ def test_outer_product():
tvm_c = tvm.ndarray.array(numpy.zeros((_n, _m), dtype='float32')) tvm_c = tvm.ndarray.array(numpy.zeros((_n, _m), dtype='float32'))
func(_n, _m, tvm_a, tvm_b, tvm_c) func(_n, _m, tvm_a, tvm_b, tvm_c)
numpy.testing.assert_allclose(tvm_c.asnumpy(), c_python, rtol=1e-5) numpy.testing.assert_allclose(tvm_c.asnumpy(), c_python, rtol=1e-5)
for key, _ in HYBRID_GLOBALS.items(): for key, _ in HYBRID_GLOBALS.items():
assert key not in globals().keys() assert key not in globals().keys()
assert key not in outer_product.__globals__.keys() assert key not in outer_product.__globals__.keys()
...@@ -74,8 +76,8 @@ def test_fanout(): ...@@ -74,8 +76,8 @@ def test_fanout():
b[i] = sigma b[i] = sigma
n = tvm.var('n') n = tvm.var('n')
a = tvm.placeholder((n, ), name='a') a = tvm.placeholder((n, ), 'float32', name='a')
b = tvm.placeholder((n-3, ), name='b') b = tvm.placeholder((n-3, ), 'float32', name='b')
ir = fanout(n, a, b) ir = fanout(n, a, b)
#Check for i in (0, n-3) #Check for i in (0, n-3)
...@@ -85,12 +87,14 @@ def test_fanout(): ...@@ -85,12 +87,14 @@ def test_fanout():
assert tvm.ir_pass.Equal(ir.extent, n - 3) assert tvm.ir_pass.Equal(ir.extent, n - 3)
#Check loopbody #Check loopbody
ibody = ir.body ibody = ir.body
assert isinstance(ibody, tvm.stmt.Realize) assert isinstance(ibody, tvm.stmt.AttrStmt)
assert ibody.bounds[0].min.value == 0 abody = ibody.body
assert ibody.bounds[0].extent.value == 1 assert isinstance(abody, tvm.stmt.Realize)
assert ibody.func.name == 'sigma' assert abody.bounds[0].min.value == 0
assert abody.bounds[0].extent.value == 1
assert abody.func.name == 'sigma'
#Check i loop body #Check i loop body
rbody = ibody.body rbody = abody.body
assert isinstance(rbody.first, tvm.stmt.Provide) assert isinstance(rbody.first, tvm.stmt.Provide)
assert rbody.first.func.name == 'sigma' assert rbody.first.func.name == 'sigma'
assert len(rbody.first.args) == 1 assert len(rbody.first.args) == 1
...@@ -131,6 +135,21 @@ def test_fanout(): ...@@ -131,6 +135,21 @@ def test_fanout():
assert len(write.value.args) == 1 assert len(write.value.args) == 1
assert write.value.args[0].value == 0 assert write.value.args[0].value == 0
func = tvm.build(tvm.lower(ir, [n, a, b]))
assert func
np_a = numpy.random.randn(10).astype('float32')
np_b = numpy.zeros(7).astype('float32')
nd_a = tvm.ndarray.array(np_a)
nd_b = tvm.ndarray.array(np_b)
fanout(10, np_a, np_b)
func(10, nd_a, nd_b)
numpy.testing.assert_allclose(nd_b.asnumpy(), np_b, rtol=1e-5, atol=1e-5)
@script @script
def failure(): def failure():
for i in range(1, 100): for i in range(1, 100):
...@@ -148,15 +167,18 @@ def test_failure(): ...@@ -148,15 +167,18 @@ def test_failure():
def test_looptype(): def test_looptype():
@script @script
def looptype(a): def looptype(a, b, c):
for i in parallel(6): for i in parallel(8):
a[i] = i a[i] = i
for j in vectorize(6): for j in vectorize(8):
a[j] = j b[j] = j
for k in unroll(6): for k in unroll(8):
a[k] = k c[k] = k
a = tvm.placeholder((6, ), name='a')
ir = looptype(a) a = tvm.placeholder((8, ), name='a', dtype='int32')
b = tvm.placeholder((8, ), name='b', dtype='int32')
c = tvm.placeholder((8, ), name='c', dtype='int32')
ir = looptype(a, b, c)
iloop = ir.first iloop = ir.first
jloop = ir.rest.first jloop = ir.rest.first
kloop = ir.rest.rest kloop = ir.rest.rest
...@@ -164,6 +186,24 @@ def test_looptype(): ...@@ -164,6 +186,24 @@ def test_looptype():
assert jloop.for_type == tvm.stmt.For.Vectorized assert jloop.for_type == tvm.stmt.For.Vectorized
assert kloop.for_type == tvm.stmt.For.Unrolled assert kloop.for_type == tvm.stmt.For.Unrolled
func = tvm.build(tvm.lower(ir, [a, b, c]))
np_a = numpy.zeros((8, )).astype('int32')
np_b = numpy.zeros((8, )).astype('int32')
np_c = numpy.zeros((8, )).astype('int32')
nd_a = tvm.ndarray.array(np_a)
nd_b = tvm.ndarray.array(np_b)
nd_c = tvm.ndarray.array(np_c)
looptype(np_a, np_b, np_c)
func(nd_a, nd_b, nd_c)
numpy.testing.assert_allclose(np_a, nd_a.asnumpy())
numpy.testing.assert_allclose(np_b, nd_b.asnumpy())
numpy.testing.assert_allclose(np_c, nd_c.asnumpy())
def test_if(): def test_if():
@script @script
def if_then_else(a, b): def if_then_else(a, b):
...@@ -234,12 +274,14 @@ def test_math_intrin(): ...@@ -234,12 +274,14 @@ def test_math_intrin():
a[3] = sigmoid(a[3]) a[3] = sigmoid(a[3])
a[4] = power(a[4], a[5]) a[4] = power(a[4], a[5])
a[5] = tanh(a[5]) a[5] = tanh(a[5])
a[6] = min(a[4], a[5])
a[7] = max(a[5], a[6])
a6 = tvm.placeholder((6, ), dtype='float32', name='a') a6 = tvm.placeholder((8, ), dtype='float32', name='a')
ir = intrin_real(a6) ir = intrin_real(a6)
func = tvm.build(tvm.lower(ir, [a6])) func = tvm.build(tvm.lower(ir, [a6]))
assert func assert func
a = numpy.arange(2, 8).astype('float32') a = numpy.arange(2, 10).astype('float32')
tvm_a = tvm.ndarray.array(a) tvm_a = tvm.ndarray.array(a)
func(tvm_a) func(tvm_a)
intrin_real(a) intrin_real(a)
...@@ -259,22 +301,87 @@ def test_math_intrin(): ...@@ -259,22 +301,87 @@ def test_math_intrin():
func(tvm_a) func(tvm_a)
assert tvm_a.asnumpy()[0] == a[0] assert tvm_a.asnumpy()[0] == a[0]
def test_allocate_buffer(): def test_non_zero():
def blur(a): @tvm.hybrid.script
for i in serail(32): def blur(a, b):
h_blur = allocate((4, 36)) for i in range(2, 32):
for j in serail(4): for j in range(2, 32):
for k in serail(36): s = 0.0
s = allocate((1, ), 'float32') for di in range(3):
for dj in serail(4): for dj in range(3):
s[0] = s[0] + a[i, j + dj] s = s + a[i-di, j-dj]
h_blur[j, k] = s[0] / 4. b[i-2, j-2] = s / 9.0
for j in serail(32): try:
s = 0. np_a = numpy.random.randn(32, 32).astype('float32')
for di in serail(4): np_b = numpy.zeros((30, 30), dtype='float32')
s = s + h_blur[di, j] blur(np_a, np_b)
h_blur[i, j] = s / 4.
ph_a = tvm.placeholder((32, 32), 'float32', 'a')
ph_b = tvm.placeholder((30, 30), 'float32', 'b')
ir = tvm.hybrid.parse(blur, [ph_a, ph_b])
func = tvm.lower(ir, [ph_a, ph_b])
func = tvm.build(func)
nd_a = tvm.ndarray.array(np_a)
nd_b = tvm.ndarray.array(np_b)
func(nd_a, nd_b)
numpy.testing.assert_allclose(np_b, nd_b.asnumpy(), atol=1e-5, rtol=1e-5)
except IOError:
print('[Warning] Non-zero first test skipped by Python2')
@tvm.hybrid.script
def triangle(a, b, c):
for i in range(10):
for j in range(i, 10):
c[i, j] = a[i] * b[j]
a = tvm.placeholder((10, ), dtype='float32', name='a')
b = tvm.placeholder((10, ), dtype='float32', name='b')
c = tvm.placeholder((10, 10), dtype='float32', name='c')
np_a = numpy.random.randn(10).astype('float32')
np_b = numpy.random.randn(10).astype('float32')
np_c = numpy.zeros((10, 10)).astype('float32')
nd_a = tvm.ndarray.array(np_a)
nd_b = tvm.ndarray.array(np_b)
nd_c = tvm.ndarray.array(np_c)
triangle(np_a, np_b, np_c)
func = tvm.build(tvm.lower(triangle(a, b, c), [a, b, c]))
assert func
func(nd_a, nd_b, nd_c)
numpy.testing.assert_allclose(nd_c.asnumpy(), np_c)
def test_allocate():
@tvm.hybrid.script
def blur2d(a, b):
for i in range(30):
ha = allocate((3, 30), 'float32')
for j in range(3):
for k in range(30):
ha[j, k] = a[i+j, k] + a[i+j, k+1] + a[i+j, k+2]
for j in range(30):
b[i, j] = (ha[0, j] + ha[1, j] + ha[2, j]) / 9.0
a = tvm.placeholder((32, 32), 'float32', 'a')
b = tvm.placeholder((30, 30), 'float32', 'b')
func = tvm.build(tvm.lower(blur2d(a, b), [a, b]))
assert func
np_a = numpy.random.randn(32, 32).astype('float32')
np_b = numpy.zeros((30, 30)).astype('float32')
nd_a = tvm.ndarray.array(np_a)
nd_b = tvm.ndarray.array(np_b)
func(nd_a, nd_b)
blur2d(np_a, np_b)
numpy.testing.assert_allclose(nd_b.asnumpy(), np_b, atol=1e-5, rtol=1e-5)
if __name__ == "__main__": if __name__ == "__main__":
test_outer_product() test_outer_product()
...@@ -284,4 +391,6 @@ if __name__ == "__main__": ...@@ -284,4 +391,6 @@ if __name__ == "__main__":
test_if() test_if()
test_bind() test_bind()
test_math_intrin() test_math_intrin()
test_non_zero()
test_allocate()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment