Unverified Commit f98035b0 by Tianqi Chen Committed by GitHub

[ARITH] cleanup the indexmod/div on python side (#4028)

parent bbf82e0e
...@@ -350,7 +350,9 @@ def compute_flop(sch): ...@@ -350,7 +350,9 @@ def compute_flop(sch):
return _count_flop(exp.value) return _count_flop(exp.value)
if isinstance(exp, expr.Var): if isinstance(exp, expr.Var):
return 0 return 0
if isinstance(exp, (expr.Add, expr.Sub, expr.Mul, expr.Div, expr.Mod, if isinstance(exp, (expr.Add, expr.Sub, expr.Mul,
expr.Div, expr.Mod,
expr.FloorDiv, expr.FloorMod,
expr.Max, expr.Min, expr.Max, expr.Min,
expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE, expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE,
expr.And, expr.Or, expr.Not)): expr.And, expr.Or, expr.Not)):
......
...@@ -72,23 +72,23 @@ class ExprOp(object): ...@@ -72,23 +72,23 @@ class ExprOp(object):
return _generic.multiply(other, self) return _generic.multiply(other, self)
def __div__(self, other): def __div__(self, other):
# if _dtype_is_int(self) and _dtype_is_int(other): if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error() raise div_ambiguity_error()
return _generic.divide(self, other) return _generic.divide(self, other)
def __rdiv__(self, other): def __rdiv__(self, other):
# if _dtype_is_int(self) and _dtype_is_int(other): if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error() raise div_ambiguity_error()
return _generic.divide(other, self) return _generic.divide(other, self)
def __truediv__(self, other): def __truediv__(self, other):
# if _dtype_is_int(self) and _dtype_is_int(other): if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error() raise div_ambiguity_error()
return _generic.divide(self, other) return _generic.divide(self, other)
def __rtruediv__(self, other): def __rtruediv__(self, other):
# if _dtype_is_int(self) and _dtype_is_int(other): if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error() raise div_ambiguity_error()
return _generic.divide(other, self) return _generic.divide(other, self)
def __floordiv__(self, other): def __floordiv__(self, other):
...@@ -100,8 +100,8 @@ class ExprOp(object): ...@@ -100,8 +100,8 @@ class ExprOp(object):
return _generic.divide(other, self) return _generic.divide(other, self)
def __mod__(self, other): def __mod__(self, other):
# raise div_ambiguity_error() raise div_ambiguity_error()
return _make._OpMod(self, other) # return _make._OpMod(self, other)
def __neg__(self): def __neg__(self):
neg_one = _api_internal._const(-1, self.dtype) neg_one = _api_internal._const(-1, self.dtype)
......
...@@ -64,6 +64,8 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> { ...@@ -64,6 +64,8 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
bool VisitExpr_(const Mul* op) final { return BinaryOp(op); } bool VisitExpr_(const Mul* op) final { return BinaryOp(op); }
bool VisitExpr_(const Div* op) final { return BinaryOp(op); } bool VisitExpr_(const Div* op) final { return BinaryOp(op); }
bool VisitExpr_(const Mod* op) final { return BinaryOp(op); } bool VisitExpr_(const Mod* op) final { return BinaryOp(op); }
bool VisitExpr_(const FloorDiv* op) final { return BinaryOp(op); }
bool VisitExpr_(const FloorMod* op) final { return BinaryOp(op); }
bool VisitExpr_(const Min* op) final { return BinaryOp(op); } bool VisitExpr_(const Min* op) final { return BinaryOp(op); }
bool VisitExpr_(const Max* op) final { return BinaryOp(op); } bool VisitExpr_(const Max* op) final { return BinaryOp(op); }
bool VisitExpr_(const EQ* op) final { return BinaryOp(op); } bool VisitExpr_(const EQ* op) final { return BinaryOp(op); }
......
...@@ -373,6 +373,8 @@ def test_split_infer_type(): ...@@ -373,6 +373,8 @@ def test_split_infer_type():
yy = run_infer_type(y.astuple()) yy = run_infer_type(y.astuple())
assert yy.checked_type == ret_type assert yy.checked_type == ret_type
idxd = tvm.indexdiv
d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
axis = tvm.var("axis") axis = tvm.var("axis")
verify_split((5, 5, 2, 2), 5, verify_split((5, 5, 2, 2), 5,
...@@ -393,15 +395,15 @@ def test_split_infer_type(): ...@@ -393,15 +395,15 @@ def test_split_infer_type():
axis=0) axis=0)
verify_split((d1, d2, d3, d4), 4, verify_split((d1, d2, d3, d4), 4,
relay.ty.TupleType(tvm.convert([ relay.ty.TupleType(tvm.convert([
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"), relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"), relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"), relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
relay.ty.TensorType((d1, d2, d3/4, d4), "float32")])), relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32")])),
axis=2) axis=2)
verify_split((d1, d2, d3, d4), 2, verify_split((d1, d2, d3, d4), 2,
relay.ty.TupleType(tvm.convert([ relay.ty.TupleType(tvm.convert([
relay.ty.TensorType((d1/2, d2, d3, d4), "float32"), relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32"),
relay.ty.TensorType((d1/2, d2, d3, d4), "float32")])), relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32")])),
axis=0) axis=0)
verify_split((d1, d2, d3, d4), (2, 4, 7), verify_split((d1, d2, d3, d4), (2, 4, 7),
relay.ty.TupleType(tvm.convert([ relay.ty.TupleType(tvm.convert([
......
...@@ -487,8 +487,9 @@ def test_yolo_reorg_infer_shape(): ...@@ -487,8 +487,9 @@ def test_yolo_reorg_infer_shape():
assert zz.checked_type == relay.ty.TensorType(out_shape, "float32") assert zz.checked_type == relay.ty.TensorType(out_shape, "float32")
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
idxd = tvm.indexdiv
verify_yolo_reorg((n, c, 20, 20), 10, (n, c*10*10, 2, 2)) verify_yolo_reorg((n, c, 20, 20), 10, (n, c*10*10, 2, 2))
verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, h/2, w/2)) verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, idxd(h, 2), idxd(w, 2)))
def test_yolo_reorg(): def test_yolo_reorg():
def verify_yolo_reorg(shape, stride): def verify_yolo_reorg(shape, stride):
......
...@@ -60,14 +60,14 @@ def test_pack_gemm(): ...@@ -60,14 +60,14 @@ def test_pack_gemm():
k = tvm.reduce_axis((0, L)) k = tvm.reduce_axis((0, L))
bn = 4 bn = 4
fld = tvm.floordiv idxd = tvm.indexdiv
flm = tvm.floormod idxm = tvm.indexmod
A_pack = tvm.compute((N // bn, L, bn), lambda i, j, k: A[i * bn + k][j]) A_pack = tvm.compute((N // bn, L, bn), lambda i, j, k: A[i * bn + k][j])
B_pack = tvm.compute((M // bn, L, bn), lambda i, j, k: B[i * bn + k][j]) B_pack = tvm.compute((M // bn, L, bn), lambda i, j, k: B[i * bn + k][j])
C_pack = tvm.compute((N // bn, M // bn, bn, bn), lambda i, j, ii, jj: C_pack = tvm.compute((N // bn, M // bn, bn, bn), lambda i, j, ii, jj:
tvm.sum(A_pack[i, k, ii].astype(acc_dtype) * B_pack[j, k, jj].astype(acc_dtype), axis=[k])) tvm.sum(A_pack[i, k, ii].astype(acc_dtype) * B_pack[j, k, jj].astype(acc_dtype), axis=[k]))
C = tvm.compute((N, M), lambda i, j: C_pack[fld(i, bn)][fld(j, bn)][flm(i, bn)][flm(j, bn)]) C = tvm.compute((N, M), lambda i, j: C_pack[idxd(i, bn)][idxd(j, bn)][idxm(i, bn)][idxm(j, bn)])
s = tvm.create_schedule([C.op]) s = tvm.create_schedule([C.op])
assert compute_flop(s) == 2 * N * L * M assert compute_flop(s) == 2 * N * L * M
......
...@@ -37,7 +37,7 @@ def test_cuda_vectorize_add(): ...@@ -37,7 +37,7 @@ def test_cuda_vectorize_add():
print("skip because gpu does not support int8") print("skip because gpu does not support int8")
return return
A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes)) A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
B = tvm.compute((n,), lambda i: A[i]+tvm.const(1, A.dtype), name='B') B = tvm.compute((n,), lambda i: A[i] + tvm.const(1, A.dtype), name='B')
s = tvm.create_schedule(B.op) s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=num_thread) xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(xo, bx) s[B].bind(xo, bx)
...@@ -165,9 +165,10 @@ def test_cuda_shuffle(): ...@@ -165,9 +165,10 @@ def test_cuda_shuffle():
print("skip because cuda is not enabled..") print("skip because cuda is not enabled..")
return return
idxm = tvm.indexmod
a = tvm.placeholder((64, ), 'int32') a = tvm.placeholder((64, ), 'int32')
b = tvm.placeholder((64, ), 'int32') b = tvm.placeholder((64, ), 'int32')
c = tvm.compute((64, ), lambda x: a[x] + b[x - (x % 4) + (3 - x % 4)]) c = tvm.compute((64, ), lambda x: a[x] + b[x - idxm(x, 4) + (3 - idxm(x, 4))])
sch = tvm.create_schedule(c.op) sch = tvm.create_schedule(c.op)
x = c.op.axis[0] x = c.op.axis[0]
xo, xi = sch[c].split(x, 4) xo, xi = sch[c].split(x, 4)
......
...@@ -109,14 +109,15 @@ def test_gpu(): ...@@ -109,14 +109,15 @@ def test_gpu():
dtype = "float32" dtype = "float32"
A = tvm.placeholder((n,), name='A') A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B') B = tvm.placeholder((n,), name='B')
fld = tvm.floordiv idxd = tvm.indexdiv
def test_device_ir(A, B, C): def test_device_ir(A, B, C):
n = A.shape[0] n = A.shape[0]
max_threads = 32 max_threads = 32
ib = tvm.ir_builder.create() ib = tvm.ir_builder.create()
bx = tvm.thread_axis("blockIdx.x") bx = tvm.thread_axis("blockIdx.x")
tx = tvm.thread_axis("threadIdx.x") tx = tvm.thread_axis("threadIdx.x")
ib.scope_attr(bx, "thread_extent", fld(n+max_threads-1, max_threads)) ib.scope_attr(bx, "thread_extent", idxd(n+max_threads-1, max_threads))
ib.scope_attr(tx, "thread_extent", max_threads) ib.scope_attr(tx, "thread_extent", max_threads)
idx = bx.var * max_threads + tx.var idx = bx.var * max_threads + tx.var
Aptr = ib.buffer_ptr(A) Aptr = ib.buffer_ptr(A)
......
...@@ -94,31 +94,31 @@ def test_buffer_index_merge_mult_mod(): ...@@ -94,31 +94,31 @@ def test_buffer_index_merge_mult_mod():
def assert_simplified_equal(index_simplified, index_direct): def assert_simplified_equal(index_simplified, index_direct):
assert tvm.ir_pass.Equal(index_simplified, index_direct),\ assert tvm.ir_pass.Equal(index_simplified, index_direct),\
"index_simplified=%s, index_direct=%s" %(index_simplified, index_direct) "index_simplified=%s, index_direct=%s" %(index_simplified, index_direct)
idxdiv = tvm.indexdiv idxd = tvm.indexdiv
idxmod = tvm.indexmod idxm = tvm.indexmod
# Test Case1 # Test Case1
index_simplified = A_stride.vload( index_simplified = A_stride.vload(
(idxdiv(idxmod(k0, k1), s), idxmod(idxmod(k0, k1), s) + idxdiv(k0, k1) * k1)) (idxd(idxm(k0, k1), s), idxm(idxm(k0, k1), s) + idxd(k0, k1) * k1))
index_direct = A_stride.vload((0, k0)) index_direct = A_stride.vload((0, k0))
assert_simplified_equal(index_simplified, index_direct) assert_simplified_equal(index_simplified, index_direct)
# Test Case2 # Test Case2
index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n), index_simplified = A.vload((idxd(idxm(k0, idxd(k1, s)), n),
idxmod(idxmod(k0, idxdiv(k1, s)), n) + idxmod(k0, k1))) idxm(idxm(k0, idxd(k1, s)), n) + idxm(k0, k1)))
index_direct = A.vload((0, idxmod(k0, k1) + idxmod(k0, idxdiv(k1, s)))) index_direct = A.vload((0, idxm(k0, k1) + idxm(k0, idxd(k1, s))))
assert_simplified_equal(index_simplified, index_direct) assert_simplified_equal(index_simplified, index_direct)
# Test Case3 # Test Case3
index_simplified = A.vload((idxdiv((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) + index_simplified = A.vload((idxd((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) +
idxdiv(idxmod(k0, idxdiv(k1, s)), n), idxd(idxm(k0, idxd(k1, s)), n),
idxmod((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) + idxm((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) +
idxmod(idxmod(k0, idxdiv(k1, s)), n))) idxm(idxm(k0, idxd(k1, s)), n)))
index_direct = A.vload((0, k0)) index_direct = A.vload((0, k0))
assert_simplified_equal(index_simplified, index_direct) assert_simplified_equal(index_simplified, index_direct)
# Test Case4 (not able to simplify) # Test Case4 (not able to simplify)
index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n), index_simplified = A.vload((idxd(idxm(k0, idxd(k1, s)), n),
idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1))) idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1)))
index_direct = A.vload((0, idxdiv(idxmod(k0, idxdiv(k1, s)), n) * n + index_direct = A.vload((0, idxd(idxm(k0, idxd(k1, s)), n) * n +
(idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1)))) (idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1))))
assert_simplified_equal(index_simplified, index_direct) assert_simplified_equal(index_simplified, index_direct)
......
...@@ -28,7 +28,7 @@ def test_rewrite_Select(): ...@@ -28,7 +28,7 @@ def test_rewrite_Select():
tvm.expr.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1) tvm.expr.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1)
zz = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(z)).value zz = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(z)).value
a = tvm.expr.Select(i>10, y, z) a = tvm.expr.Select(tvm.floordiv(i, 4) > 10, y, z)
aa = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(a)).value aa = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(a)).value
assert yy.name == "tvm_if_then_else" assert yy.name == "tvm_if_then_else"
assert zz.name == "tvm_if_then_else" assert zz.name == "tvm_if_then_else"
......
...@@ -221,14 +221,15 @@ def test_tensorize_matmul(): ...@@ -221,14 +221,15 @@ def test_tensorize_matmul():
# This tests whether algorithm and intrinsics expressions are simplified # This tests whether algorithm and intrinsics expressions are simplified
# as much as possible first and then checked for equality. See Issue #696 # as much as possible first and then checked for equality. See Issue #696
def test_tensorize_op(): def test_tensorize_op():
tdiv = tvm.truncdiv idxd = tvm.indexdiv
tmod = tvm.truncmod idxm = tvm.indexmod
def op_intrin(): def op_intrin():
bh = 9 bh = 9
bw = 9 bw = 9
x = tvm.placeholder((5, 5), name='A') x = tvm.placeholder((5, 5), name='A')
y = tvm.compute((bh, bw), y = tvm.compute((bh, bw),
lambda i, j: x[tdiv(j,3) + tmod(i,3), tmod(j,3)+ tdiv(i,3)]) lambda i, j: x[idxd(j,3) + idxm(i,3), idxm(j,3)+ idxd(i,3)])
def intrin_func(ins, outs): def intrin_func(ins, outs):
xx, = ins xx, = ins
...@@ -239,7 +240,7 @@ def test_tensorize_op(): ...@@ -239,7 +240,7 @@ def test_tensorize_op():
return tvm.decl_tensor_intrin(y.op, intrin_func) return tvm.decl_tensor_intrin(y.op, intrin_func)
A = tvm.placeholder((5, 5), name='A') A = tvm.placeholder((5, 5), name='A')
B = tvm.compute((9,9), lambda i, j: A[tdiv(j,3) + tmod(i,3), tmod(j,3) + tdiv(i,3)]) B = tvm.compute((9,9), lambda i, j: A[idxd(j,3) + idxm(i,3), idxm(j,3) + idxd(i,3)])
bt = op_intrin() bt = op_intrin()
s = tvm.create_schedule(B.op) s = tvm.create_schedule(B.op)
......
...@@ -70,6 +70,9 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh ...@@ -70,6 +70,9 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh
OW = (PAD_W - KW) // WSTR + 1 OW = (PAD_W - KW) // WSTR + 1
oshape = (1, OH, OW, CO) oshape = (1, OH, OW, CO)
idxd = tvm.indexdiv
idxm = tvm.indexmod
# Pad input channels of weights and data when it is not a multiple of 8 # Pad input channels of weights and data when it is not a multiple of 8
if CI_packed % 8 != 0: if CI_packed % 8 != 0:
CI_PAD = CI_packed % 8 CI_PAD = CI_packed % 8
...@@ -106,7 +109,8 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh ...@@ -106,7 +109,8 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh
data_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=3, pack_type='uint8') data_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=3, pack_type='uint8')
kernel_vec = _kernel_vec_spatial_pack_nhwc(kernel, weight_bits, VC, len(kernel.shape) == 4) kernel_vec = _kernel_vec_spatial_pack_nhwc(kernel, weight_bits, VC, len(kernel.shape) == 4)
if kernel_vec.shape[-1] % 8 != 0 and CI_PAD != 0: idxm = tvm.indexmod
if idxm(kernel_vec.shape[-1], 8) != 0 and CI_PAD != 0:
kernel_vec = pad(kernel_vec, [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, CI_PAD]) kernel_vec = pad(kernel_vec, [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, CI_PAD])
N, H, W, IB, CI = data_q.shape N, H, W, IB, CI = data_q.shape
...@@ -147,8 +151,12 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh ...@@ -147,8 +151,12 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh
else: else:
conv_vec = tvm.compute(ovshape, _bipolar_conv, name='conv_vec', tag='bipolar') conv_vec = tvm.compute(ovshape, _bipolar_conv, name='conv_vec', tag='bipolar')
conv = tvm.compute(oshape, lambda n, h, w, co:
conv_vec[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC].astype(out_dtype), conv = tvm.compute(oshape,
lambda n, h, w, co:
conv_vec[n,
idxd(h, VH), idxd(w, VW), idxd(co, VC),
idxm(h, VH), idxm(w, VW), idxm(co, VC)].astype(out_dtype),
name='conv', tag='spatial_bitserial_conv_nhwc') name='conv', tag='spatial_bitserial_conv_nhwc')
return conv return conv
......
...@@ -171,6 +171,9 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt ...@@ -171,6 +171,9 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1 assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad") data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
idxd = tvm.indexdiv
idxm = tvm.indexmod
r = KW r = KW
m = tile_size m = tile_size
alpha = m + r - 1 alpha = m + r - 1
...@@ -190,10 +193,11 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt ...@@ -190,10 +193,11 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
VK = cfg['tile_k'].size[-1] VK = cfg['tile_k'].size[-1]
# pack input tile # pack input tile
input_tile = tvm.compute((C, P // VP, alpha, alpha, VP), input_tile = tvm.compute((C, idxd(P, VP), alpha, alpha, VP),
lambda c, b, eps, nu, bb: lambda c, b, eps, nu, bb:
data_pad[(b*VP+bb) // (nH*nW)][c][(b*VP+bb) // nW % nH * m + eps] data_pad[idxd(b*VP + bb, nH*nW), c,
[(b*VP+bb) % nW * m + nu], idxm(idxd(b*VP + bb, nW), nH) * m + eps,
idxm(b*VP + bb, nW) * m + nu],
name='d') name='d')
# transform kernel # transform kernel
...@@ -202,22 +206,22 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt ...@@ -202,22 +206,22 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
else: else:
r_kh = tvm.reduce_axis((0, KH), 'r_kh') r_kh = tvm.reduce_axis((0, KH), 'r_kh')
r_kw = tvm.reduce_axis((0, KW), 'r_kw') r_kw = tvm.reduce_axis((0, KW), 'r_kw')
U = tvm.compute((alpha, alpha, K // VK, C, VK), lambda eps, nu, k, c, kk: U = tvm.compute((alpha, alpha, idxd(K, VK), C, VK), lambda eps, nu, k, c, kk:
tvm.sum(kernel[k * VK + kk][c][r_kh][r_kw].astype(out_dtype) * tvm.sum(kernel[k * VK + kk][c][r_kh][r_kw].astype(out_dtype) *
G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), name='U') G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), name='U')
# transform image # transform image
r_eps = tvm.reduce_axis((0, alpha), 'r_eps') r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
r_nu = tvm.reduce_axis((0, alpha), 'r_nu') r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
V = tvm.compute((alpha, alpha, P // VP, C, VP), lambda eps, nu, b, c, bb: V = tvm.compute((alpha, alpha, idxd(P, VP), C, VP), lambda eps, nu, b, c, bb:
tvm.sum(input_tile[c][b][r_eps][r_nu][bb].astype(out_dtype) * tvm.sum(input_tile[c][b][r_eps][r_nu][bb].astype(out_dtype) *
B[r_eps][eps] * B[r_nu][nu], axis=[r_eps, r_nu]), name='V') B[r_eps][eps] * B[r_nu][nu], axis=[r_eps, r_nu]), name='V')
# batch gemm # batch gemm
c = tvm.reduce_axis((0, C), name='c') c = tvm.reduce_axis((0, C), name='c')
M = tvm.compute((alpha, alpha, K, P), lambda eps, nu, k, b: M = tvm.compute((alpha, alpha, K, P), lambda eps, nu, k, b:
tvm.sum(U[eps][nu][k // VK][c][k % VK] * tvm.sum(U[eps][nu][idxd(k, VK)][c][idxm(k, VK)] *
V[eps][nu][b // VP][c][b % VP], axis=c), name='M') V[eps][nu][idxd(b, VP)][c][idxm(b, VP)], axis=c), name='M')
# inverse transform # inverse transform
r_eps = tvm.reduce_axis((0, alpha), 'r_eps') r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
...@@ -228,7 +232,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt ...@@ -228,7 +232,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
# unpack output # unpack output
output = tvm.compute((N, K, H, W), lambda n, k, h, w: output = tvm.compute((N, K, H, W), lambda n, k, h, w:
Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m], Y[k][n * nH * nW + idxd(h, m) * nW + idxd(w, m),
idxm(h, m), idxm(w, m)],
name='output', tag='winograd_conv2d_output') name='output', tag='winograd_conv2d_output')
# we have to manually assign effective GFLOP for winograd # we have to manually assign effective GFLOP for winograd
...@@ -517,6 +522,8 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): ...@@ -517,6 +522,8 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
N, CI, H, W = get_const_tuple(data.shape) N, CI, H, W = get_const_tuple(data.shape)
CO, _, KH, KW = get_const_tuple(kernel.shape) CO, _, KH, KW = get_const_tuple(kernel.shape)
idxd = tvm.indexdiv
if groups == 1: if groups == 1:
# query config of this workload # query config of this workload
workload = autotvm.task.args_to_workload( workload = autotvm.task.args_to_workload(
...@@ -535,7 +542,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): ...@@ -535,7 +542,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
# Store the same config for the altered operator (workload) # Store the same config for the altered operator (workload)
new_data = data new_data = data
new_kernel = tvm.placeholder((CO // VC, CI, KH, KW, VC), dtype=kernel.dtype) new_kernel = tvm.placeholder((idxd(CO, VC), CI, KH, KW, VC), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload( new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d) [new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
dispatch_ctx.update(target, new_workload, cfg) dispatch_ctx.update(target, new_workload, cfg)
...@@ -553,7 +560,9 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): ...@@ -553,7 +560,9 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1], weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1],
tile_size=tile_size) tile_size=tile_size)
weight = F.reshape(weight, weight = F.reshape(weight,
newshape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI)) newshape=(KH + tile_size - 1,
KW + tile_size - 1,
idxd(CO, VC), VC, CI))
weight = F.transpose(weight, axes=[0, 1, 2, 4, 3]) weight = F.transpose(weight, axes=[0, 1, 2, 4, 3])
copy_inputs[1] = weight copy_inputs[1] = weight
...@@ -561,7 +570,9 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): ...@@ -561,7 +570,9 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
# Store the same config for the altered operator (workload) # Store the same config for the altered operator (workload)
new_data = data new_data = data
new_weight = tvm.placeholder((KH + tile_size - 1, KH + tile_size -1, CO // VC, CI, VC), new_weight = tvm.placeholder((KH + tile_size - 1,
KH + tile_size -1,
idxd(CO, VC), CI, VC),
kernel.dtype) kernel.dtype)
new_workload = autotvm.task.args_to_workload( new_workload = autotvm.task.args_to_workload(
[new_data, new_weight, strides, padding, dilation, [new_data, new_weight, strides, padding, dilation,
...@@ -612,7 +623,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F): ...@@ -612,7 +623,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
# Store the same config for the altered operator (workload) # Store the same config for the altered operator (workload)
new_data = data new_data = data
CO, M, KH, KW = get_const_tuple(kernel.shape) CO, M, KH, KW = get_const_tuple(kernel.shape)
new_kernel = tvm.placeholder((CO // VC, M, KH, KW, VC), dtype=kernel.dtype) new_kernel = tvm.placeholder((idxd(CO, VC), M, KH, KW, VC), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload( new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, out_dtype], [new_data, new_kernel, strides, padding, dilation, out_dtype],
depthwise_conv2d_nchw) depthwise_conv2d_nchw)
......
...@@ -243,14 +243,16 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx): ...@@ -243,14 +243,16 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx):
ib.scope_attr(bx, "thread_extent", nthread_bx) ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx tid = bx * max_threads + tx
new_range = num_anchors // elem_per_thread + 1 new_range = num_anchors // elem_per_thread + 1
idxd = tvm.indexdiv
idxm = tvm.indexmod
# Scan: Downsweep: # Scan: Downsweep:
with ib. if_scope(tid < batch_size * num_anchors): with ib. if_scope(tid < batch_size * num_anchors):
i = tid // num_anchors # number of batches i = idxd(tid, num_anchors) # number of batches
j = tid % num_anchors # number of anchors j = idxm(tid, num_anchors) # number of anchors
with ib.if_scope(j < elem_per_thread): with ib.if_scope(j < elem_per_thread):
idx[tid] = idx_in[tid] idx[tid] = idx_in[tid]
with ib.else_scope(): with ib.else_scope():
idx[tid] = idx_in[tid] + partial[i * new_range + j // elem_per_thread - 1] idx[tid] = idx_in[tid] + partial[i * new_range + idxd(j, elem_per_thread) - 1]
return ib.get() return ib.get()
...@@ -303,9 +305,12 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out): ...@@ -303,9 +305,12 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
ib.scope_attr(bx, "thread_extent", nthread_bx) ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx tid = bx * max_threads + tx
idxd = tvm.indexdiv
idxm = tvm.indexmod
with ib.if_scope(tid < batch_size * num_anchors): with ib.if_scope(tid < batch_size * num_anchors):
i = tid // num_anchors i = idxd(tid, num_anchors)
j = tid % num_anchors j = idxm(tid, num_anchors)
base_idx = i * num_anchors * elem_length base_idx = i * num_anchors * elem_length
with ib.if_scope(flag[tid] > 0): with ib.if_scope(flag[tid] > 0):
with ib.for_range(0, elem_length) as k: with ib.for_range(0, elem_length) as k:
......
...@@ -79,10 +79,13 @@ def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, r ...@@ -79,10 +79,13 @@ def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, r
p_im_info = ib.buffer_ptr(im_info_buf) p_im_info = ib.buffer_ptr(im_info_buf)
p_out = ib.buffer_ptr(out_buf) p_out = ib.buffer_ptr(out_buf)
idxm = tvm.indexmod
idxd = tvm.indexdiv
with ib.if_scope(tid < batch * height * width): with ib.if_scope(tid < batch * height * width):
w = tid % width w = idxm(tid, width)
h = (tid // width) % height h = idxm(idxd(tid, width), height)
b = tid // width // height b = idxd(idxd(tid, width), height)
for k in range(num_anchors): for k in range(num_anchors):
out_index = tid * num_anchors + k out_index = tid * num_anchors + k
...@@ -163,6 +166,8 @@ def argsort_ir(data_buf, out_index_buf): ...@@ -163,6 +166,8 @@ def argsort_ir(data_buf, out_index_buf):
temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local") temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
idxm = tvm.indexmod
with ib.for_range(0, batch, for_type="unroll") as b: with ib.for_range(0, batch, for_type="unroll") as b:
start = b * num_bbox start = b * num_bbox
for i in range(2): for i in range(2):
...@@ -170,7 +175,7 @@ def argsort_ir(data_buf, out_index_buf): ...@@ -170,7 +175,7 @@ def argsort_ir(data_buf, out_index_buf):
with ib.if_scope(bbox_id < num_bbox): with ib.if_scope(bbox_id < num_bbox):
index_out[start + bbox_id] = bbox_id index_out[start + bbox_id] = bbox_id
with ib.for_range(0, num_bbox) as k: with ib.for_range(0, num_bbox) as k:
offset = start + 2 * tid + (k % 2) offset = start + 2 * tid + idxm(k, 2)
with ib.if_scope( with ib.if_scope(
tvm.all(offset + 1 < num_bbox, p_data[offset] < p_data[offset + 1])): tvm.all(offset + 1 < num_bbox, p_data[offset] < p_data[offset + 1])):
temp_data[0] = p_data[offset] temp_data[0] = p_data[offset]
......
...@@ -115,6 +115,8 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): ...@@ -115,6 +115,8 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
ib.emit(tvm.make.Call(None, 'tvm_storage_sync', ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']), tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0)) tvm.expr.Call.Intrinsic, None, 0))
idxd = tvm.indexdiv
idxm = tvm.indexmod
with ib.for_range(0, axis_mul_before) as i: with ib.for_range(0, axis_mul_before) as i:
with ib.for_range(0, axis_mul_after) as j: with ib.for_range(0, axis_mul_after) as j:
...@@ -122,13 +124,13 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None): ...@@ -122,13 +124,13 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
base_idx = i * shape[axis] * axis_mul_after + j base_idx = i * shape[axis] * axis_mul_after + j
# OddEvenTransposeSort # OddEvenTransposeSort
with ib.for_range(0, current_sort_num) as k: with ib.for_range(0, current_sort_num) as k:
with ib.if_scope(tid < (current_sort_num + 1) // 2): with ib.if_scope(tid < idxd(current_sort_num + 1, 2)):
offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after
if is_ascend: if is_ascend:
cond = tvm.all(2 * tid + (k % 2) + 1 < current_sort_num, cond = tvm.all(2 * tid + idxm(k, 2) + 1 < current_sort_num,
values_out[offset] > values_out[offset + axis_mul_after]) values_out[offset] > values_out[offset + axis_mul_after])
else: else:
cond = tvm.all(2 * tid + (k % 2) + 1 < current_sort_num, cond = tvm.all(2 * tid + idxm(k, 2) + 1 < current_sort_num,
values_out[offset] < values_out[offset + axis_mul_after]) values_out[offset] < values_out[offset + axis_mul_after])
with ib.if_scope(cond): with ib.if_scope(cond):
temp_data[0] = values_out[offset] temp_data[0] = values_out[offset]
...@@ -199,6 +201,9 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): ...@@ -199,6 +201,9 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local") temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
is_ascend = tvm.make.node("IntImm", dtype="int32", value=is_ascend) is_ascend = tvm.make.node("IntImm", dtype="int32", value=is_ascend)
idxd = tvm.indexdiv
idxm = tvm.indexmod
with ib.for_range(0, axis_mul_before) as i: with ib.for_range(0, axis_mul_before) as i:
with ib.for_range(0, axis_mul_after) as j: with ib.for_range(0, axis_mul_after) as j:
current_sort_num = valid_count[i * axis_mul_after + j] current_sort_num = valid_count[i * axis_mul_after + j]
...@@ -207,10 +212,10 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): ...@@ -207,10 +212,10 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
output[base_idx + tid * axis_mul_after] = tid output[base_idx + tid * axis_mul_after] = tid
# OddEvenTransposeSort # OddEvenTransposeSort
with ib.for_range(0, current_sort_num) as k: with ib.for_range(0, current_sort_num) as k:
with ib.if_scope(tid < (current_sort_num + 1) // 2): with ib.if_scope(tid < idxd(current_sort_num + 1, 2)):
offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after
with ib.if_scope(tvm.all(is_ascend == 1, \ with ib.if_scope(tvm.all(is_ascend == 1, \
2 * tid + (k % 2) + 1 < current_sort_num, \ 2 * tid + idxm(k, 2) + 1 < current_sort_num, \
data[offset] > data[offset + axis_mul_after])): data[offset] > data[offset + axis_mul_after])):
temp_data[0] = data[offset] temp_data[0] = data[offset]
data[offset] = data[offset + axis_mul_after] data[offset] = data[offset + axis_mul_after]
...@@ -219,7 +224,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend): ...@@ -219,7 +224,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
output[offset] = output[offset + axis_mul_after] output[offset] = output[offset + axis_mul_after]
output[offset + axis_mul_after] = temp_index[0] output[offset + axis_mul_after] = temp_index[0]
with ib.if_scope(tvm.all(is_ascend == 0, \ with ib.if_scope(tvm.all(is_ascend == 0, \
2 * tid + (k % 2) + 1 < current_sort_num, \ 2 * tid + idxm(k, 2) + 1 < current_sort_num, \
data[offset] < data[offset + axis_mul_after])): data[offset] < data[offset + axis_mul_after])):
temp_data[0] = data[offset] temp_data[0] = data[offset]
data[offset] = data[offset + axis_mul_after] data[offset] = data[offset + axis_mul_after]
......
...@@ -95,8 +95,8 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets): ...@@ -95,8 +95,8 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
for k in range(num_sizes + num_ratios - 1): for k in range(num_sizes + num_ratios - 1):
w = if_then_else(k < num_sizes, w = if_then_else(k < num_sizes,
size_ratio_concat[k] * in_height / in_width / 2.0, float(size_ratio_concat[k]) * in_height / in_width / 2.0,
size_ratio_concat[0] * in_height / in_width * float(size_ratio_concat[0]) * in_height / in_width *
math.sqrt(size_ratio_concat[k + 1]) / 2.0) math.sqrt(size_ratio_concat[k + 1]) / 2.0)
h = if_then_else( h = if_then_else(
k < num_sizes, size_ratio_concat[k] / 2.0, k < num_sizes, size_ratio_concat[k] / 2.0,
...@@ -204,10 +204,12 @@ def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp ...@@ -204,10 +204,12 @@ def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp
ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx) ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx tid = bx * max_threads + tx
idxd = tvm.indexdiv
idxm = tvm.indexmod
with ib.if_scope(tid < batch_size * num_anchors): with ib.if_scope(tid < batch_size * num_anchors):
i = tid / num_anchors i = idxd(tid, num_anchors)
j = tid % num_anchors j = idxm(tid, num_anchors)
valid_count[i] = 0 valid_count[i] = 0
score[tid] = -1.0 score[tid] = -1.0
cls_id[tid] = 0 cls_id[tid] = 0
...@@ -314,9 +316,13 @@ def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score ...@@ -314,9 +316,13 @@ def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score
ib.scope_attr(bx, "thread_extent", nthread_bx) ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx tid = bx * max_threads + tx
idxd = tvm.indexdiv
idxm = tvm.indexmod
with ib.if_scope(tid < batch_size * num_anchors): with ib.if_scope(tid < batch_size * num_anchors):
i = tid // num_anchors i = idxd(tid, num_anchors)
j = tid % num_anchors j = idxm(tid, num_anchors)
with ib.if_scope(cls_id[tid] > 0): with ib.if_scope(cls_id[tid] > 0):
with ib.if_scope(tid == 0): with ib.if_scope(tid == 0):
out_base_idx = i * num_anchors * 6 out_base_idx = i * num_anchors * 6
......
...@@ -313,13 +313,14 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits, ...@@ -313,13 +313,14 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
axis=[ci, dh, dw, b1, b2]) axis=[ci, dh, dw, b1, b2])
conv = tvm.compute(ovshape, _conv, name='conv_out') conv = tvm.compute(ovshape, _conv, name='conv_out')
idxdiv = tvm.indexdiv idxd = tvm.indexdiv
idxmod = tvm.indexmod idxm = tvm.indexmod
return tvm.compute( return tvm.compute(
oshape, lambda n, co, h, w: oshape, lambda n, co, h, w:
conv[n][idxdiv(co, VC)][idxdiv(h, VH)][idxdiv( conv[n,
w, VW)][idxmod(h, VH)][idxmod(w, VW)][idxmod(co, VC)], idxd(co, VC), idxd(h, VH), idxd(w, VW),
idxm(h, VH), idxm(w, VW), idxm(co, VC)],
name='conv_vec', tag='spatial_bitserial_conv_nchw') name='conv_vec', tag='spatial_bitserial_conv_nchw')
@autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'cpu', 'direct') @autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'cpu', 'direct')
...@@ -419,12 +420,13 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits, ...@@ -419,12 +420,13 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
conv = tvm.compute(ovshape, _conv, name='conv') conv = tvm.compute(ovshape, _conv, name='conv')
idxdiv = tvm.indexdiv idxd = tvm.indexdiv
idxmod = tvm.indexmod idxm = tvm.indexmod
return tvm.compute( return tvm.compute(
oshape, lambda n, h, w, co: oshape, lambda n, h, w, co:
conv[n][idxdiv(h, VH)][idxdiv(w, VW)][idxdiv( conv[n,
co, VC)][idxmod(h, VH)][idxmod(w, VW)][idxmod(co, VC)], idxd(h, VH), idxd(w, VW), idxd(co, VC),
idxm(h, VH), idxm(w, VW), idxm(co, VC)],
name='output_unpack', tag='spatial_bitserial_conv_nhwc') name='output_unpack', tag='spatial_bitserial_conv_nhwc')
@tvm.target.generic_func @tvm.target.generic_func
......
...@@ -94,12 +94,15 @@ def _sparse_dense_bsrmm(data, weight_data, weight_indices, weight_indptr): ...@@ -94,12 +94,15 @@ def _sparse_dense_bsrmm(data, weight_data, weight_indices, weight_indptr):
x_val = data[i, bs_c * block_j + c] x_val = data[i, bs_c * block_j + c]
return tvm.sum(block_ij_val * x_val, axis=[elem_idx, c]) return tvm.sum(block_ij_val * x_val, axis=[elem_idx, c])
idxd = tvm.indexdiv
idxm = tvm.indexmod
bsrmm_block = tvm.compute( bsrmm_block = tvm.compute(
(m, num_blocks, bs_r), _compute_block, (m, num_blocks, bs_r), _compute_block,
tag="sparse_dense_bsrmm_block") tag="sparse_dense_bsrmm_block")
return tvm.compute( return tvm.compute(
(m, num_blocks * bs_r), (m, num_blocks * bs_r),
lambda m, n: bsrmm_block[m, n // bs_r, n % bs_r], lambda m, n: bsrmm_block[m, idxd(n, bs_r), idxm(n, bs_r)],
tag="sparse_dense_bsrmm") tag="sparse_dense_bsrmm")
@tvm.target.generic_func @tvm.target.generic_func
......
...@@ -232,10 +232,12 @@ def unravel_index(idx, shape): ...@@ -232,10 +232,12 @@ def unravel_index(idx, shape):
indices : tuple of int or tvm.expr.IntImm indices : tuple of int or tvm.expr.IntImm
Corresponding coordinate of the 1D index Corresponding coordinate of the 1D index
""" """
idxd = tvm.indexdiv
idxm = tvm.indexmod
indices = [] indices = []
for i in range(len(shape) - 1, -1, -1): for i in range(len(shape) - 1, -1, -1):
indices.append(idx % shape[i]) indices.append(idxm(idx, shape[i]))
idx = idx // shape[i] idx = idxd(idx, shape[i])
indices = indices[::-1] indices = indices[::-1]
return indices return indices
...@@ -257,12 +259,13 @@ def const_matrix(matrix, name="const_matrix"): ...@@ -257,12 +259,13 @@ def const_matrix(matrix, name="const_matrix"):
""" """
row, col = matrix.shape row, col = matrix.shape
dtype = str(matrix.dtype) dtype = str(matrix.dtype)
idxm = tvm.indexmod
def select_array(i, j): def select_array(i, j):
now = tvm.const(0.0, dtype) now = tvm.const(0.0, dtype)
for ii in range(row): for ii in range(row):
for jj in range(col): for jj in range(col):
now = tvm.expr.Select(tvm.all(i % row == ii, j % col == jj), now = tvm.expr.Select(tvm.all(idxm(i, row) == ii, idxm(j, col) == jj),
tvm.const(matrix[ii][jj], dtype), tvm.const(matrix[ii][jj], dtype),
now) now)
return now return now
......
...@@ -73,10 +73,10 @@ def hybrid_multibox_prior(data, sizes, ratios, steps, offsets): ...@@ -73,10 +73,10 @@ def hybrid_multibox_prior(data, sizes, ratios, steps, offsets):
center_w = (j + offset_w) * steps_w center_w = (j + offset_w) * steps_w
for k in const_range(num_sizes + num_ratios - 1): for k in const_range(num_sizes + num_ratios - 1):
if k < num_sizes: if k < num_sizes:
w = sizes[k] * in_height / in_width / 2.0 w = float32(sizes[k] * in_height) / in_width / 2.0
h = sizes[k] / 2.0 h = sizes[k] / 2.0
else: else:
w = sizes[0] * in_height / in_width \ w = float32(sizes[0] * in_height) / in_width \
* sqrt(ratios[k - num_sizes + 1] * 1.0) / 2.0 * sqrt(ratios[k - num_sizes + 1] * 1.0) / 2.0
h = sizes[0] / sqrt(ratios[k - num_sizes + 1] * 1.0) / 2.0 h = sizes[0] / sqrt(ratios[k - num_sizes + 1] * 1.0) / 2.0
count = i * in_width * (num_sizes + num_ratios - 1) \ count = i * in_width * (num_sizes + num_ratios - 1) \
......
...@@ -309,8 +309,15 @@ def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, o ...@@ -309,8 +309,15 @@ def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, o
# packing the Filter to let memory access be consecutive for AVX512 intrinsic # packing the Filter to let memory access be consecutive for AVX512 intrinsic
# Done in pre-compute stage # Done in pre-compute stage
packw_shape = (kernel_h, kernel_w, num_filter/16, 16*(channel/4), 4) idxd = tvm.indexdiv
PackW = tvm.compute(packw_shape, lambda a, b, c, d, e: Filter[a][b][c*16+d%16][d/16*4+e], idxm = tvm.indexmod
packw_shape = (kernel_h, kernel_w, idxd(num_filter, 16), 16 * idxd(channel, 4), 4)
PackW = tvm.compute(packw_shape,
lambda a, b, c, d, e:
Filter[a, b,
c*16 + idxm(d, 16),
idxd(d, 16) * 4 + e],
name="packed_filter") name="packed_filter")
rc = tvm.reduce_axis((0, in_channel), name='rc') rc = tvm.reduce_axis((0, in_channel), name='rc')
...@@ -321,7 +328,9 @@ def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, o ...@@ -321,7 +328,9 @@ def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, o
lambda nn, yy, xx, ff: tvm.sum( lambda nn, yy, xx, ff: tvm.sum(
PaddedInput[nn, yy * stride_h + ry * dilation_h, PaddedInput[nn, yy * stride_h + ry * dilation_h,
xx * stride_w + rx * dilation_w, rc].astype(out_dtype) * xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
PackW[ry, rx, ff/16, (rc/4)*16+ff%16, rc%4].astype(out_dtype), axis=[ry, rx, rc]), PackW[ry, rx, idxd(ff, 16),
idxd(rc, 4) * 16 + idxm(ff, 16),
idxm(rc, 4)].astype(out_dtype), axis=[ry, rx, rc]),
name="Conv2d_1x1_Output_int8", tag="conv2d_nhwc_pack_int8") name="Conv2d_1x1_Output_int8", tag="conv2d_nhwc_pack_int8")
return Output return Output
......
...@@ -247,7 +247,7 @@ print(tvm.lower(s, [A, B, C], simple_mode=True)) ...@@ -247,7 +247,7 @@ print(tvm.lower(s, [A, B, C], simple_mode=True))
# We have to re-write the algorithm slightly. # We have to re-write the algorithm slightly.
packedB = tvm.compute((N / bn, K, bn), lambda x, y, z: B[y, x * bn + z], name='packedB') packedB = tvm.compute((N / bn, K, bn), lambda x, y, z: B[y, x * bn + z], name='packedB')
C = tvm.compute((M, N), C = tvm.compute((M, N),
lambda x, y: tvm.sum(A[x, k] * packedB[y / bn, k, y % bn], axis=k), lambda x, y: tvm.sum(A[x, k] * packedB[y // bn, k, tvm.indexmod(y, bn)], axis=k),
name = 'C') name = 'C')
s = tvm.create_schedule(C.op) s = tvm.create_schedule(C.op)
......
...@@ -335,6 +335,9 @@ def inject_dma_intrin(stmt_in): ...@@ -335,6 +335,9 @@ def inject_dma_intrin(stmt_in):
Transformed statement Transformed statement
""" """
env = get_env() env = get_env()
idxd = tvm.indexdiv
idxm = tvm.indexmod
def _check_compact(buf): def _check_compact(buf):
ndim = len(buf.shape) ndim = len(buf.shape)
size = tvm.const(1, buf.shape[0].dtype) size = tvm.const(1, buf.shape[0].dtype)
...@@ -369,7 +372,7 @@ def inject_dma_intrin(stmt_in): ...@@ -369,7 +372,7 @@ def inject_dma_intrin(stmt_in):
x_size = 1 x_size = 1
x_stride = buf.strides[ndim - base] x_stride = buf.strides[ndim - base]
next_base = base next_base = base
if not util.equal_const_int(x_stride % elem_block, 0): if not util.equal_const_int(idxm(x_stride, elem_block), 0):
raise RuntimeError( raise RuntimeError(
"scope %s need to have block=%d, shape=%s, strides=%s" % ( "scope %s need to have block=%d, shape=%s, strides=%s" % (
scope, elem_block, buf.shape, buf.strides)) scope, elem_block, buf.shape, buf.strides))
...@@ -394,7 +397,7 @@ def inject_dma_intrin(stmt_in): ...@@ -394,7 +397,7 @@ def inject_dma_intrin(stmt_in):
raise RuntimeError("Expect buffer type to be %s instead of %s" % raise RuntimeError("Expect buffer type to be %s instead of %s" %
(dtype, buf.dtype)) (dtype, buf.dtype))
shape, strides = buf.shape, buf.strides shape, strides = buf.shape, buf.strides
if not util.equal_const_int(buf.elem_offset % elem_block, 0): if not util.equal_const_int(idxm(buf.elem_offset, elem_block), 0):
raise RuntimeError("scope %s need to have block=%d" % (scope, elem_block)) raise RuntimeError("scope %s need to have block=%d" % (scope, elem_block))
if allow_fold: if allow_fold:
shape, strides = _fold_buffer_dim(buf, scope, elem_block) shape, strides = _fold_buffer_dim(buf, scope, elem_block)
...@@ -421,7 +424,7 @@ def inject_dma_intrin(stmt_in): ...@@ -421,7 +424,7 @@ def inject_dma_intrin(stmt_in):
x_size = 1 x_size = 1
x_stride = 1 x_stride = 1
y_size = 1 y_size = 1
return x_size, y_size, x_stride, buf.elem_offset / elem_block return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
if not util.equal_const_int(strides[-2] - elem_block, 0): if not util.equal_const_int(strides[-2] - elem_block, 0):
raise_error() raise_error()
...@@ -429,15 +432,15 @@ def inject_dma_intrin(stmt_in): ...@@ -429,15 +432,15 @@ def inject_dma_intrin(stmt_in):
x_size = shape[-2] x_size = shape[-2]
x_stride = shape[-2] x_stride = shape[-2]
y_size = 1 y_size = 1
return x_size, y_size, x_stride, buf.elem_offset / elem_block return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
if not util.equal_const_int(strides[-3] % elem_block, 0): if not util.equal_const_int(idxm(strides[-3], elem_block), 0):
raise_error() raise_error()
if ndim == 3: if ndim == 3:
x_size = shape[-2] x_size = shape[-2]
x_stride = strides[-3] / elem_block x_stride = idxd(strides[-3], elem_block)
y_size = shape[-3] y_size = shape[-3]
return x_size, y_size, x_stride, buf.elem_offset / elem_block return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
else: else:
if not util.equal_const_int(strides[-1], 1): if not util.equal_const_int(strides[-1], 1):
...@@ -451,7 +454,7 @@ def inject_dma_intrin(stmt_in): ...@@ -451,7 +454,7 @@ def inject_dma_intrin(stmt_in):
x_size = 1 x_size = 1
x_stride = 1 x_stride = 1
y_size = 1 y_size = 1
return x_size, y_size, x_stride, buf.elem_offset / elem_block return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
if not util.equal_const_int(strides[-3], elem_block): if not util.equal_const_int(strides[-3], elem_block):
raise_error() raise_error()
...@@ -459,15 +462,15 @@ def inject_dma_intrin(stmt_in): ...@@ -459,15 +462,15 @@ def inject_dma_intrin(stmt_in):
x_size = shape[-3] x_size = shape[-3]
x_stride = shape[-3] x_stride = shape[-3]
y_size = 1 y_size = 1
return x_size, y_size, x_stride, buf.elem_offset / elem_block return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
if not util.equal_const_int(strides[-4] % elem_block, 0): if not util.equal_const_int(idxm(strides[-4], elem_block), 0):
raise_error() raise_error()
if ndim == 4: if ndim == 4:
x_size = shape[-3] x_size = shape[-3]
x_stride = strides[-4] / elem_block x_stride = idxd(strides[-4], elem_block)
y_size = shape[-4] y_size = shape[-4]
return x_size, y_size, x_stride, buf.elem_offset / elem_block return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
raise_error() raise_error()
...@@ -765,6 +768,8 @@ def inject_alu_intrin(stmt_in): ...@@ -765,6 +768,8 @@ def inject_alu_intrin(stmt_in):
Transformed statement Transformed statement
""" """
env = get_env() env = get_env()
idxm = tvm.indexmod
def _do_fold(stmt): def _do_fold(stmt):
def _equal(x, y): def _equal(x, y):
return tvm.ir_pass.Equal(tvm.ir_pass.Simplify(x - y), 0) return tvm.ir_pass.Equal(tvm.ir_pass.Simplify(x - y), 0)
...@@ -910,10 +915,10 @@ def inject_alu_intrin(stmt_in): ...@@ -910,10 +915,10 @@ def inject_alu_intrin(stmt_in):
assert len(extents) != 0 assert len(extents) != 0
assert tvm.ir_pass.Equal( assert tvm.ir_pass.Equal(
tvm.ir_pass.Simplify( tvm.ir_pass.Simplify(
src_coeff[-1] % (env.BATCH * env.BLOCK_OUT)), 0) idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
assert tvm.ir_pass.Equal( assert tvm.ir_pass.Equal(
tvm.ir_pass.Simplify( tvm.ir_pass.Simplify(
dst_coeff[-1] % (env.BATCH * env.BLOCK_OUT)), 0) idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
assert tvm.ir_pass.Equal(src_coeff[-2], 1) assert tvm.ir_pass.Equal(src_coeff[-2], 1)
assert tvm.ir_pass.Equal(dst_coeff[-2], 1) assert tvm.ir_pass.Equal(dst_coeff[-2], 1)
if env.BATCH > 1: if env.BATCH > 1:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment