Unverified Commit f98035b0 by Tianqi Chen Committed by GitHub

[ARITH] cleanup the indexmod/div on python side (#4028)

parent bbf82e0e
......@@ -350,7 +350,9 @@ def compute_flop(sch):
return _count_flop(exp.value)
if isinstance(exp, expr.Var):
return 0
if isinstance(exp, (expr.Add, expr.Sub, expr.Mul, expr.Div, expr.Mod,
if isinstance(exp, (expr.Add, expr.Sub, expr.Mul,
expr.Div, expr.Mod,
expr.FloorDiv, expr.FloorMod,
expr.Max, expr.Min,
expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE,
expr.And, expr.Or, expr.Not)):
......
......@@ -72,23 +72,23 @@ class ExprOp(object):
return _generic.multiply(other, self)
def __div__(self, other):
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
if _dtype_is_int(self) and _dtype_is_int(other):
raise div_ambiguity_error()
return _generic.divide(self, other)
def __rdiv__(self, other):
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
if _dtype_is_int(self) and _dtype_is_int(other):
raise div_ambiguity_error()
return _generic.divide(other, self)
def __truediv__(self, other):
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
if _dtype_is_int(self) and _dtype_is_int(other):
raise div_ambiguity_error()
return _generic.divide(self, other)
def __rtruediv__(self, other):
# if _dtype_is_int(self) and _dtype_is_int(other):
# raise div_ambiguity_error()
if _dtype_is_int(self) and _dtype_is_int(other):
raise div_ambiguity_error()
return _generic.divide(other, self)
def __floordiv__(self, other):
......@@ -100,8 +100,8 @@ class ExprOp(object):
return _generic.divide(other, self)
def __mod__(self, other):
# raise div_ambiguity_error()
return _make._OpMod(self, other)
raise div_ambiguity_error()
# return _make._OpMod(self, other)
def __neg__(self):
neg_one = _api_internal._const(-1, self.dtype)
......
......@@ -64,6 +64,8 @@ class UnsafeExprDetector : public ExprFunctor<bool(const Expr& n)> {
bool VisitExpr_(const Mul* op) final { return BinaryOp(op); }
bool VisitExpr_(const Div* op) final { return BinaryOp(op); }
bool VisitExpr_(const Mod* op) final { return BinaryOp(op); }
bool VisitExpr_(const FloorDiv* op) final { return BinaryOp(op); }
bool VisitExpr_(const FloorMod* op) final { return BinaryOp(op); }
bool VisitExpr_(const Min* op) final { return BinaryOp(op); }
bool VisitExpr_(const Max* op) final { return BinaryOp(op); }
bool VisitExpr_(const EQ* op) final { return BinaryOp(op); }
......
......@@ -373,6 +373,8 @@ def test_split_infer_type():
yy = run_infer_type(y.astuple())
assert yy.checked_type == ret_type
idxd = tvm.indexdiv
d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
axis = tvm.var("axis")
verify_split((5, 5, 2, 2), 5,
......@@ -393,15 +395,15 @@ def test_split_infer_type():
axis=0)
verify_split((d1, d2, d3, d4), 4,
relay.ty.TupleType(tvm.convert([
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
relay.ty.TensorType((d1, d2, d3/4, d4), "float32")])),
relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32"),
relay.ty.TensorType((d1, d2, idxd(d3, 4), d4), "float32")])),
axis=2)
verify_split((d1, d2, d3, d4), 2,
relay.ty.TupleType(tvm.convert([
relay.ty.TensorType((d1/2, d2, d3, d4), "float32"),
relay.ty.TensorType((d1/2, d2, d3, d4), "float32")])),
relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32"),
relay.ty.TensorType((idxd(d1, 2), d2, d3, d4), "float32")])),
axis=0)
verify_split((d1, d2, d3, d4), (2, 4, 7),
relay.ty.TupleType(tvm.convert([
......
......@@ -487,8 +487,9 @@ def test_yolo_reorg_infer_shape():
assert zz.checked_type == relay.ty.TensorType(out_shape, "float32")
n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w")
idxd = tvm.indexdiv
verify_yolo_reorg((n, c, 20, 20), 10, (n, c*10*10, 2, 2))
verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, h/2, w/2))
verify_yolo_reorg((n, c, h, w), 2, (n, c*2*2, idxd(h, 2), idxd(w, 2)))
def test_yolo_reorg():
def verify_yolo_reorg(shape, stride):
......
......@@ -60,14 +60,14 @@ def test_pack_gemm():
k = tvm.reduce_axis((0, L))
bn = 4
fld = tvm.floordiv
flm = tvm.floormod
idxd = tvm.indexdiv
idxm = tvm.indexmod
A_pack = tvm.compute((N // bn, L, bn), lambda i, j, k: A[i * bn + k][j])
B_pack = tvm.compute((M // bn, L, bn), lambda i, j, k: B[i * bn + k][j])
C_pack = tvm.compute((N // bn, M // bn, bn, bn), lambda i, j, ii, jj:
tvm.sum(A_pack[i, k, ii].astype(acc_dtype) * B_pack[j, k, jj].astype(acc_dtype), axis=[k]))
C = tvm.compute((N, M), lambda i, j: C_pack[fld(i, bn)][fld(j, bn)][flm(i, bn)][flm(j, bn)])
C = tvm.compute((N, M), lambda i, j: C_pack[idxd(i, bn)][idxd(j, bn)][idxm(i, bn)][idxm(j, bn)])
s = tvm.create_schedule([C.op])
assert compute_flop(s) == 2 * N * L * M
......
......@@ -37,7 +37,7 @@ def test_cuda_vectorize_add():
print("skip because gpu does not support int8")
return
A = tvm.placeholder((n,), name='A', dtype="%sx%d" % (dtype, lanes))
B = tvm.compute((n,), lambda i: A[i]+tvm.const(1, A.dtype), name='B')
B = tvm.compute((n,), lambda i: A[i] + tvm.const(1, A.dtype), name='B')
s = tvm.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=num_thread)
s[B].bind(xo, bx)
......@@ -165,9 +165,10 @@ def test_cuda_shuffle():
print("skip because cuda is not enabled..")
return
idxm = tvm.indexmod
a = tvm.placeholder((64, ), 'int32')
b = tvm.placeholder((64, ), 'int32')
c = tvm.compute((64, ), lambda x: a[x] + b[x - (x % 4) + (3 - x % 4)])
c = tvm.compute((64, ), lambda x: a[x] + b[x - idxm(x, 4) + (3 - idxm(x, 4))])
sch = tvm.create_schedule(c.op)
x = c.op.axis[0]
xo, xi = sch[c].split(x, 4)
......
......@@ -109,14 +109,15 @@ def test_gpu():
dtype = "float32"
A = tvm.placeholder((n,), name='A')
B = tvm.placeholder((n,), name='B')
fld = tvm.floordiv
idxd = tvm.indexdiv
def test_device_ir(A, B, C):
n = A.shape[0]
max_threads = 32
ib = tvm.ir_builder.create()
bx = tvm.thread_axis("blockIdx.x")
tx = tvm.thread_axis("threadIdx.x")
ib.scope_attr(bx, "thread_extent", fld(n+max_threads-1, max_threads))
ib.scope_attr(bx, "thread_extent", idxd(n+max_threads-1, max_threads))
ib.scope_attr(tx, "thread_extent", max_threads)
idx = bx.var * max_threads + tx.var
Aptr = ib.buffer_ptr(A)
......
......@@ -94,31 +94,31 @@ def test_buffer_index_merge_mult_mod():
def assert_simplified_equal(index_simplified, index_direct):
assert tvm.ir_pass.Equal(index_simplified, index_direct),\
"index_simplified=%s, index_direct=%s" %(index_simplified, index_direct)
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
idxd = tvm.indexdiv
idxm = tvm.indexmod
# Test Case1
index_simplified = A_stride.vload(
(idxdiv(idxmod(k0, k1), s), idxmod(idxmod(k0, k1), s) + idxdiv(k0, k1) * k1))
(idxd(idxm(k0, k1), s), idxm(idxm(k0, k1), s) + idxd(k0, k1) * k1))
index_direct = A_stride.vload((0, k0))
assert_simplified_equal(index_simplified, index_direct)
# Test Case2
index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n),
idxmod(idxmod(k0, idxdiv(k1, s)), n) + idxmod(k0, k1)))
index_direct = A.vload((0, idxmod(k0, k1) + idxmod(k0, idxdiv(k1, s))))
index_simplified = A.vload((idxd(idxm(k0, idxd(k1, s)), n),
idxm(idxm(k0, idxd(k1, s)), n) + idxm(k0, k1)))
index_direct = A.vload((0, idxm(k0, k1) + idxm(k0, idxd(k1, s))))
assert_simplified_equal(index_simplified, index_direct)
# Test Case3
index_simplified = A.vload((idxdiv((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) +
idxdiv(idxmod(k0, idxdiv(k1, s)), n),
idxmod((idxdiv(k0, idxdiv(k1, s)) * idxdiv(k1, s)), n) +
idxmod(idxmod(k0, idxdiv(k1, s)), n)))
index_simplified = A.vload((idxd((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) +
idxd(idxm(k0, idxd(k1, s)), n),
idxm((idxd(k0, idxd(k1, s)) * idxd(k1, s)), n) +
idxm(idxm(k0, idxd(k1, s)), n)))
index_direct = A.vload((0, k0))
assert_simplified_equal(index_simplified, index_direct)
# Test Case4 (not able to simplify)
index_simplified = A.vload((idxdiv(idxmod(k0, idxdiv(k1, s)), n),
idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1)))
index_direct = A.vload((0, idxdiv(idxmod(k0, idxdiv(k1, s)), n) * n +
(idxmod(idxmod(k0, idxdiv(k1, n)), n) + idxmod(k0, k1))))
index_simplified = A.vload((idxd(idxm(k0, idxd(k1, s)), n),
idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1)))
index_direct = A.vload((0, idxd(idxm(k0, idxd(k1, s)), n) * n +
(idxm(idxm(k0, idxd(k1, n)), n) + idxm(k0, k1))))
assert_simplified_equal(index_simplified, index_direct)
......
......@@ -28,7 +28,7 @@ def test_rewrite_Select():
tvm.expr.Select(i > 1, A[i-1], 1.0) > 0.0, A[i], 0.1)
zz = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(z)).value
a = tvm.expr.Select(i>10, y, z)
a = tvm.expr.Select(tvm.floordiv(i, 4) > 10, y, z)
aa = tvm.ir_pass.RewriteUnsafeSelect(tvm.make.Evaluate(a)).value
assert yy.name == "tvm_if_then_else"
assert zz.name == "tvm_if_then_else"
......
......@@ -221,14 +221,15 @@ def test_tensorize_matmul():
# This tests whether algorithm and intrinsics expressions are simplified
# as much as possible first and then checked for equality. See Issue #696
def test_tensorize_op():
tdiv = tvm.truncdiv
tmod = tvm.truncmod
idxd = tvm.indexdiv
idxm = tvm.indexmod
def op_intrin():
bh = 9
bw = 9
x = tvm.placeholder((5, 5), name='A')
y = tvm.compute((bh, bw),
lambda i, j: x[tdiv(j,3) + tmod(i,3), tmod(j,3)+ tdiv(i,3)])
lambda i, j: x[idxd(j,3) + idxm(i,3), idxm(j,3)+ idxd(i,3)])
def intrin_func(ins, outs):
xx, = ins
......@@ -239,7 +240,7 @@ def test_tensorize_op():
return tvm.decl_tensor_intrin(y.op, intrin_func)
A = tvm.placeholder((5, 5), name='A')
B = tvm.compute((9,9), lambda i, j: A[tdiv(j,3) + tmod(i,3), tmod(j,3) + tdiv(i,3)])
B = tvm.compute((9,9), lambda i, j: A[idxd(j,3) + idxm(i,3), idxm(j,3) + idxd(i,3)])
bt = op_intrin()
s = tvm.create_schedule(B.op)
......
......@@ -70,6 +70,9 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh
OW = (PAD_W - KW) // WSTR + 1
oshape = (1, OH, OW, CO)
idxd = tvm.indexdiv
idxm = tvm.indexmod
# Pad input channels of weights and data when it is not a multiple of 8
if CI_packed % 8 != 0:
CI_PAD = CI_packed % 8
......@@ -106,7 +109,8 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh
data_q = bitpack(data, activation_bits, pack_axis=3, bit_axis=3, pack_type='uint8')
kernel_vec = _kernel_vec_spatial_pack_nhwc(kernel, weight_bits, VC, len(kernel.shape) == 4)
if kernel_vec.shape[-1] % 8 != 0 and CI_PAD != 0:
idxm = tvm.indexmod
if idxm(kernel_vec.shape[-1], 8) != 0 and CI_PAD != 0:
kernel_vec = pad(kernel_vec, [0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, CI_PAD])
N, H, W, IB, CI = data_q.shape
......@@ -147,8 +151,12 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, activation_bits, weigh
else:
conv_vec = tvm.compute(ovshape, _bipolar_conv, name='conv_vec', tag='bipolar')
conv = tvm.compute(oshape, lambda n, h, w, co:
conv_vec[n][h//VH][w//VW][co//VC][h%VH][w%VW][co%VC].astype(out_dtype),
conv = tvm.compute(oshape,
lambda n, h, w, co:
conv_vec[n,
idxd(h, VH), idxd(w, VW), idxd(co, VC),
idxm(h, VH), idxm(w, VW), idxm(co, VC)].astype(out_dtype),
name='conv', tag='spatial_bitserial_conv_nhwc')
return conv
......
......@@ -171,6 +171,9 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
assert KH == 3 and KW == 3 and HSTR == 1 and WSTR == 1
data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
idxd = tvm.indexdiv
idxm = tvm.indexmod
r = KW
m = tile_size
alpha = m + r - 1
......@@ -190,10 +193,11 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
VK = cfg['tile_k'].size[-1]
# pack input tile
input_tile = tvm.compute((C, P // VP, alpha, alpha, VP),
input_tile = tvm.compute((C, idxd(P, VP), alpha, alpha, VP),
lambda c, b, eps, nu, bb:
data_pad[(b*VP+bb) // (nH*nW)][c][(b*VP+bb) // nW % nH * m + eps]
[(b*VP+bb) % nW * m + nu],
data_pad[idxd(b*VP + bb, nH*nW), c,
idxm(idxd(b*VP + bb, nW), nH) * m + eps,
idxm(b*VP + bb, nW) * m + nu],
name='d')
# transform kernel
......@@ -202,22 +206,22 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
else:
r_kh = tvm.reduce_axis((0, KH), 'r_kh')
r_kw = tvm.reduce_axis((0, KW), 'r_kw')
U = tvm.compute((alpha, alpha, K // VK, C, VK), lambda eps, nu, k, c, kk:
U = tvm.compute((alpha, alpha, idxd(K, VK), C, VK), lambda eps, nu, k, c, kk:
tvm.sum(kernel[k * VK + kk][c][r_kh][r_kw].astype(out_dtype) *
G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]), name='U')
# transform image
r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
r_nu = tvm.reduce_axis((0, alpha), 'r_nu')
V = tvm.compute((alpha, alpha, P // VP, C, VP), lambda eps, nu, b, c, bb:
V = tvm.compute((alpha, alpha, idxd(P, VP), C, VP), lambda eps, nu, b, c, bb:
tvm.sum(input_tile[c][b][r_eps][r_nu][bb].astype(out_dtype) *
B[r_eps][eps] * B[r_nu][nu], axis=[r_eps, r_nu]), name='V')
# batch gemm
c = tvm.reduce_axis((0, C), name='c')
M = tvm.compute((alpha, alpha, K, P), lambda eps, nu, k, b:
tvm.sum(U[eps][nu][k // VK][c][k % VK] *
V[eps][nu][b // VP][c][b % VP], axis=c), name='M')
tvm.sum(U[eps][nu][idxd(k, VK)][c][idxm(k, VK)] *
V[eps][nu][idxd(b, VP)][c][idxm(b, VP)], axis=c), name='M')
# inverse transform
r_eps = tvm.reduce_axis((0, alpha), 'r_eps')
......@@ -228,7 +232,8 @@ def _decl_winograd(cfg, data, kernel, strides, padding, dilation, layout, out_dt
# unpack output
output = tvm.compute((N, K, H, W), lambda n, k, h, w:
Y[k][n * nH * nW + (h//m) * nW + w//m][h % m][w % m],
Y[k][n * nH * nW + idxd(h, m) * nW + idxd(w, m),
idxm(h, m), idxm(w, m)],
name='output', tag='winograd_conv2d_output')
# we have to manually assign effective GFLOP for winograd
......@@ -517,6 +522,8 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
N, CI, H, W = get_const_tuple(data.shape)
CO, _, KH, KW = get_const_tuple(kernel.shape)
idxd = tvm.indexdiv
if groups == 1:
# query config of this workload
workload = autotvm.task.args_to_workload(
......@@ -535,7 +542,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
# Store the same config for the altered operator (workload)
new_data = data
new_kernel = tvm.placeholder((CO // VC, CI, KH, KW, VC), dtype=kernel.dtype)
new_kernel = tvm.placeholder((idxd(CO, VC), CI, KH, KW, VC), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, 'NCHW', out_dtype], conv2d)
dispatch_ctx.update(target, new_workload, cfg)
......@@ -553,7 +560,9 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
weight = F.nn.contrib_conv2d_winograd_weight_transform(copy_inputs[1],
tile_size=tile_size)
weight = F.reshape(weight,
newshape=(KH + tile_size - 1, KW + tile_size - 1, CO // VC, VC, CI))
newshape=(KH + tile_size - 1,
KW + tile_size - 1,
idxd(CO, VC), VC, CI))
weight = F.transpose(weight, axes=[0, 1, 2, 4, 3])
copy_inputs[1] = weight
......@@ -561,7 +570,9 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
# Store the same config for the altered operator (workload)
new_data = data
new_weight = tvm.placeholder((KH + tile_size - 1, KH + tile_size -1, CO // VC, CI, VC),
new_weight = tvm.placeholder((KH + tile_size - 1,
KH + tile_size -1,
idxd(CO, VC), CI, VC),
kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_weight, strides, padding, dilation,
......@@ -612,7 +623,7 @@ def _alter_conv2d_layout_arm(attrs, inputs, tinfos, F):
# Store the same config for the altered operator (workload)
new_data = data
CO, M, KH, KW = get_const_tuple(kernel.shape)
new_kernel = tvm.placeholder((CO // VC, M, KH, KW, VC), dtype=kernel.dtype)
new_kernel = tvm.placeholder((idxd(CO, VC), M, KH, KW, VC), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, out_dtype],
depthwise_conv2d_nchw)
......
......@@ -243,14 +243,16 @@ def get_valid_counts_downsweep(data, idx_in, partial, idx):
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
new_range = num_anchors // elem_per_thread + 1
idxd = tvm.indexdiv
idxm = tvm.indexmod
# Scan: Downsweep:
with ib. if_scope(tid < batch_size * num_anchors):
i = tid // num_anchors # number of batches
j = tid % num_anchors # number of anchors
i = idxd(tid, num_anchors) # number of batches
j = idxm(tid, num_anchors) # number of anchors
with ib.if_scope(j < elem_per_thread):
idx[tid] = idx_in[tid]
with ib.else_scope():
idx[tid] = idx_in[tid] + partial[i * new_range + j // elem_per_thread - 1]
idx[tid] = idx_in[tid] + partial[i * new_range + idxd(j, elem_per_thread) - 1]
return ib.get()
......@@ -303,9 +305,12 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
idxd = tvm.indexdiv
idxm = tvm.indexmod
with ib.if_scope(tid < batch_size * num_anchors):
i = tid // num_anchors
j = tid % num_anchors
i = idxd(tid, num_anchors)
j = idxm(tid, num_anchors)
base_idx = i * num_anchors * elem_length
with ib.if_scope(flag[tid] > 0):
with ib.for_range(0, elem_length) as k:
......
......@@ -79,10 +79,13 @@ def predict_bbox_ir(cls_prob_buf, bbox_pred_buf, im_info_buf, out_buf, scales, r
p_im_info = ib.buffer_ptr(im_info_buf)
p_out = ib.buffer_ptr(out_buf)
idxm = tvm.indexmod
idxd = tvm.indexdiv
with ib.if_scope(tid < batch * height * width):
w = tid % width
h = (tid // width) % height
b = tid // width // height
w = idxm(tid, width)
h = idxm(idxd(tid, width), height)
b = idxd(idxd(tid, width), height)
for k in range(num_anchors):
out_index = tid * num_anchors + k
......@@ -163,6 +166,8 @@ def argsort_ir(data_buf, out_index_buf):
temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local")
temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
idxm = tvm.indexmod
with ib.for_range(0, batch, for_type="unroll") as b:
start = b * num_bbox
for i in range(2):
......@@ -170,7 +175,7 @@ def argsort_ir(data_buf, out_index_buf):
with ib.if_scope(bbox_id < num_bbox):
index_out[start + bbox_id] = bbox_id
with ib.for_range(0, num_bbox) as k:
offset = start + 2 * tid + (k % 2)
offset = start + 2 * tid + idxm(k, 2)
with ib.if_scope(
tvm.all(offset + 1 < num_bbox, p_data[offset] < p_data[offset + 1])):
temp_data[0] = p_data[offset]
......
......@@ -115,6 +115,8 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
tvm.convert(['shared']),
tvm.expr.Call.Intrinsic, None, 0))
idxd = tvm.indexdiv
idxm = tvm.indexmod
with ib.for_range(0, axis_mul_before) as i:
with ib.for_range(0, axis_mul_after) as j:
......@@ -122,13 +124,13 @@ def sort_ir(data, values_out, axis, is_ascend, indices_out=None):
base_idx = i * shape[axis] * axis_mul_after + j
# OddEvenTransposeSort
with ib.for_range(0, current_sort_num) as k:
with ib.if_scope(tid < (current_sort_num + 1) // 2):
offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after
with ib.if_scope(tid < idxd(current_sort_num + 1, 2)):
offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after
if is_ascend:
cond = tvm.all(2 * tid + (k % 2) + 1 < current_sort_num,
cond = tvm.all(2 * tid + idxm(k, 2) + 1 < current_sort_num,
values_out[offset] > values_out[offset + axis_mul_after])
else:
cond = tvm.all(2 * tid + (k % 2) + 1 < current_sort_num,
cond = tvm.all(2 * tid + idxm(k, 2) + 1 < current_sort_num,
values_out[offset] < values_out[offset + axis_mul_after])
with ib.if_scope(cond):
temp_data[0] = values_out[offset]
......@@ -199,6 +201,9 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local")
is_ascend = tvm.make.node("IntImm", dtype="int32", value=is_ascend)
idxd = tvm.indexdiv
idxm = tvm.indexmod
with ib.for_range(0, axis_mul_before) as i:
with ib.for_range(0, axis_mul_after) as j:
current_sort_num = valid_count[i * axis_mul_after + j]
......@@ -207,10 +212,10 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
output[base_idx + tid * axis_mul_after] = tid
# OddEvenTransposeSort
with ib.for_range(0, current_sort_num) as k:
with ib.if_scope(tid < (current_sort_num + 1) // 2):
offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after
with ib.if_scope(tid < idxd(current_sort_num + 1, 2)):
offset = base_idx + (2 * tid + idxm(k, 2)) * axis_mul_after
with ib.if_scope(tvm.all(is_ascend == 1, \
2 * tid + (k % 2) + 1 < current_sort_num, \
2 * tid + idxm(k, 2) + 1 < current_sort_num, \
data[offset] > data[offset + axis_mul_after])):
temp_data[0] = data[offset]
data[offset] = data[offset + axis_mul_after]
......@@ -219,7 +224,7 @@ def sort_nms_ir(data, valid_count, output, axis, is_ascend):
output[offset] = output[offset + axis_mul_after]
output[offset + axis_mul_after] = temp_index[0]
with ib.if_scope(tvm.all(is_ascend == 0, \
2 * tid + (k % 2) + 1 < current_sort_num, \
2 * tid + idxm(k, 2) + 1 < current_sort_num, \
data[offset] < data[offset + axis_mul_after])):
temp_data[0] = data[offset]
data[offset] = data[offset + axis_mul_after]
......
......@@ -95,8 +95,8 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets):
for k in range(num_sizes + num_ratios - 1):
w = if_then_else(k < num_sizes,
size_ratio_concat[k] * in_height / in_width / 2.0,
size_ratio_concat[0] * in_height / in_width *
float(size_ratio_concat[k]) * in_height / in_width / 2.0,
float(size_ratio_concat[0]) * in_height / in_width *
math.sqrt(size_ratio_concat[k + 1]) / 2.0)
h = if_then_else(
k < num_sizes, size_ratio_concat[k] / 2.0,
......@@ -204,10 +204,12 @@ def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
idxd = tvm.indexdiv
idxm = tvm.indexmod
with ib.if_scope(tid < batch_size * num_anchors):
i = tid / num_anchors
j = tid % num_anchors
i = idxd(tid, num_anchors)
j = idxm(tid, num_anchors)
valid_count[i] = 0
score[tid] = -1.0
cls_id[tid] = 0
......@@ -314,9 +316,13 @@ def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx
idxd = tvm.indexdiv
idxm = tvm.indexmod
with ib.if_scope(tid < batch_size * num_anchors):
i = tid // num_anchors
j = tid % num_anchors
i = idxd(tid, num_anchors)
j = idxm(tid, num_anchors)
with ib.if_scope(cls_id[tid] > 0):
with ib.if_scope(tid == 0):
out_base_idx = i * num_anchors * 6
......
......@@ -313,13 +313,14 @@ def spatial_pack_nchw(cfg, data, kernel, stride, padding, in_bits, weight_bits,
axis=[ci, dh, dw, b1, b2])
conv = tvm.compute(ovshape, _conv, name='conv_out')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
idxd = tvm.indexdiv
idxm = tvm.indexmod
return tvm.compute(
oshape, lambda n, co, h, w:
conv[n][idxdiv(co, VC)][idxdiv(h, VH)][idxdiv(
w, VW)][idxmod(h, VH)][idxmod(w, VW)][idxmod(co, VC)],
conv[n,
idxd(co, VC), idxd(h, VH), idxd(w, VW),
idxm(h, VH), idxm(w, VW), idxm(co, VC)],
name='conv_vec', tag='spatial_bitserial_conv_nchw')
@autotvm.register_topi_compute(bitserial_conv2d_nhwc, 'cpu', 'direct')
......@@ -419,12 +420,13 @@ def spatial_pack_nhwc(cfg, data, kernel, stride, padding, in_bits, weight_bits,
conv = tvm.compute(ovshape, _conv, name='conv')
idxdiv = tvm.indexdiv
idxmod = tvm.indexmod
idxd = tvm.indexdiv
idxm = tvm.indexmod
return tvm.compute(
oshape, lambda n, h, w, co:
conv[n][idxdiv(h, VH)][idxdiv(w, VW)][idxdiv(
co, VC)][idxmod(h, VH)][idxmod(w, VW)][idxmod(co, VC)],
conv[n,
idxd(h, VH), idxd(w, VW), idxd(co, VC),
idxm(h, VH), idxm(w, VW), idxm(co, VC)],
name='output_unpack', tag='spatial_bitserial_conv_nhwc')
@tvm.target.generic_func
......
......@@ -94,12 +94,15 @@ def _sparse_dense_bsrmm(data, weight_data, weight_indices, weight_indptr):
x_val = data[i, bs_c * block_j + c]
return tvm.sum(block_ij_val * x_val, axis=[elem_idx, c])
idxd = tvm.indexdiv
idxm = tvm.indexmod
bsrmm_block = tvm.compute(
(m, num_blocks, bs_r), _compute_block,
tag="sparse_dense_bsrmm_block")
return tvm.compute(
(m, num_blocks * bs_r),
lambda m, n: bsrmm_block[m, n // bs_r, n % bs_r],
lambda m, n: bsrmm_block[m, idxd(n, bs_r), idxm(n, bs_r)],
tag="sparse_dense_bsrmm")
@tvm.target.generic_func
......
......@@ -232,10 +232,12 @@ def unravel_index(idx, shape):
indices : tuple of int or tvm.expr.IntImm
Corresponding coordinate of the 1D index
"""
idxd = tvm.indexdiv
idxm = tvm.indexmod
indices = []
for i in range(len(shape) - 1, -1, -1):
indices.append(idx % shape[i])
idx = idx // shape[i]
indices.append(idxm(idx, shape[i]))
idx = idxd(idx, shape[i])
indices = indices[::-1]
return indices
......@@ -257,12 +259,13 @@ def const_matrix(matrix, name="const_matrix"):
"""
row, col = matrix.shape
dtype = str(matrix.dtype)
idxm = tvm.indexmod
def select_array(i, j):
now = tvm.const(0.0, dtype)
for ii in range(row):
for jj in range(col):
now = tvm.expr.Select(tvm.all(i % row == ii, j % col == jj),
now = tvm.expr.Select(tvm.all(idxm(i, row) == ii, idxm(j, col) == jj),
tvm.const(matrix[ii][jj], dtype),
now)
return now
......
......@@ -73,10 +73,10 @@ def hybrid_multibox_prior(data, sizes, ratios, steps, offsets):
center_w = (j + offset_w) * steps_w
for k in const_range(num_sizes + num_ratios - 1):
if k < num_sizes:
w = sizes[k] * in_height / in_width / 2.0
w = float32(sizes[k] * in_height) / in_width / 2.0
h = sizes[k] / 2.0
else:
w = sizes[0] * in_height / in_width \
w = float32(sizes[0] * in_height) / in_width \
* sqrt(ratios[k - num_sizes + 1] * 1.0) / 2.0
h = sizes[0] / sqrt(ratios[k - num_sizes + 1] * 1.0) / 2.0
count = i * in_width * (num_sizes + num_ratios - 1) \
......
......@@ -309,8 +309,15 @@ def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, o
# packing the Filter to let memory access be consecutive for AVX512 intrinsic
# Done in pre-compute stage
packw_shape = (kernel_h, kernel_w, num_filter/16, 16*(channel/4), 4)
PackW = tvm.compute(packw_shape, lambda a, b, c, d, e: Filter[a][b][c*16+d%16][d/16*4+e],
idxd = tvm.indexdiv
idxm = tvm.indexmod
packw_shape = (kernel_h, kernel_w, idxd(num_filter, 16), 16 * idxd(channel, 4), 4)
PackW = tvm.compute(packw_shape,
lambda a, b, c, d, e:
Filter[a, b,
c*16 + idxm(d, 16),
idxd(d, 16) * 4 + e],
name="packed_filter")
rc = tvm.reduce_axis((0, in_channel), name='rc')
......@@ -321,7 +328,9 @@ def _declaration_conv_nhwc_pack(cfg, Input, Filter, stride, padding, dilation, o
lambda nn, yy, xx, ff: tvm.sum(
PaddedInput[nn, yy * stride_h + ry * dilation_h,
xx * stride_w + rx * dilation_w, rc].astype(out_dtype) *
PackW[ry, rx, ff/16, (rc/4)*16+ff%16, rc%4].astype(out_dtype), axis=[ry, rx, rc]),
PackW[ry, rx, idxd(ff, 16),
idxd(rc, 4) * 16 + idxm(ff, 16),
idxm(rc, 4)].astype(out_dtype), axis=[ry, rx, rc]),
name="Conv2d_1x1_Output_int8", tag="conv2d_nhwc_pack_int8")
return Output
......
......@@ -247,7 +247,7 @@ print(tvm.lower(s, [A, B, C], simple_mode=True))
# We have to re-write the algorithm slightly.
packedB = tvm.compute((N / bn, K, bn), lambda x, y, z: B[y, x * bn + z], name='packedB')
C = tvm.compute((M, N),
lambda x, y: tvm.sum(A[x, k] * packedB[y / bn, k, y % bn], axis=k),
lambda x, y: tvm.sum(A[x, k] * packedB[y // bn, k, tvm.indexmod(y, bn)], axis=k),
name = 'C')
s = tvm.create_schedule(C.op)
......
......@@ -335,6 +335,9 @@ def inject_dma_intrin(stmt_in):
Transformed statement
"""
env = get_env()
idxd = tvm.indexdiv
idxm = tvm.indexmod
def _check_compact(buf):
ndim = len(buf.shape)
size = tvm.const(1, buf.shape[0].dtype)
......@@ -369,7 +372,7 @@ def inject_dma_intrin(stmt_in):
x_size = 1
x_stride = buf.strides[ndim - base]
next_base = base
if not util.equal_const_int(x_stride % elem_block, 0):
if not util.equal_const_int(idxm(x_stride, elem_block), 0):
raise RuntimeError(
"scope %s need to have block=%d, shape=%s, strides=%s" % (
scope, elem_block, buf.shape, buf.strides))
......@@ -394,7 +397,7 @@ def inject_dma_intrin(stmt_in):
raise RuntimeError("Expect buffer type to be %s instead of %s" %
(dtype, buf.dtype))
shape, strides = buf.shape, buf.strides
if not util.equal_const_int(buf.elem_offset % elem_block, 0):
if not util.equal_const_int(idxm(buf.elem_offset, elem_block), 0):
raise RuntimeError("scope %s need to have block=%d" % (scope, elem_block))
if allow_fold:
shape, strides = _fold_buffer_dim(buf, scope, elem_block)
......@@ -421,7 +424,7 @@ def inject_dma_intrin(stmt_in):
x_size = 1
x_stride = 1
y_size = 1
return x_size, y_size, x_stride, buf.elem_offset / elem_block
return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
if not util.equal_const_int(strides[-2] - elem_block, 0):
raise_error()
......@@ -429,15 +432,15 @@ def inject_dma_intrin(stmt_in):
x_size = shape[-2]
x_stride = shape[-2]
y_size = 1
return x_size, y_size, x_stride, buf.elem_offset / elem_block
if not util.equal_const_int(strides[-3] % elem_block, 0):
return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
if not util.equal_const_int(idxm(strides[-3], elem_block), 0):
raise_error()
if ndim == 3:
x_size = shape[-2]
x_stride = strides[-3] / elem_block
x_stride = idxd(strides[-3], elem_block)
y_size = shape[-3]
return x_size, y_size, x_stride, buf.elem_offset / elem_block
return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
else:
if not util.equal_const_int(strides[-1], 1):
......@@ -451,7 +454,7 @@ def inject_dma_intrin(stmt_in):
x_size = 1
x_stride = 1
y_size = 1
return x_size, y_size, x_stride, buf.elem_offset / elem_block
return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
if not util.equal_const_int(strides[-3], elem_block):
raise_error()
......@@ -459,15 +462,15 @@ def inject_dma_intrin(stmt_in):
x_size = shape[-3]
x_stride = shape[-3]
y_size = 1
return x_size, y_size, x_stride, buf.elem_offset / elem_block
if not util.equal_const_int(strides[-4] % elem_block, 0):
return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
if not util.equal_const_int(idxm(strides[-4], elem_block), 0):
raise_error()
if ndim == 4:
x_size = shape[-3]
x_stride = strides[-4] / elem_block
x_stride = idxd(strides[-4], elem_block)
y_size = shape[-4]
return x_size, y_size, x_stride, buf.elem_offset / elem_block
return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
raise_error()
......@@ -765,6 +768,8 @@ def inject_alu_intrin(stmt_in):
Transformed statement
"""
env = get_env()
idxm = tvm.indexmod
def _do_fold(stmt):
def _equal(x, y):
return tvm.ir_pass.Equal(tvm.ir_pass.Simplify(x - y), 0)
......@@ -910,10 +915,10 @@ def inject_alu_intrin(stmt_in):
assert len(extents) != 0
assert tvm.ir_pass.Equal(
tvm.ir_pass.Simplify(
src_coeff[-1] % (env.BATCH * env.BLOCK_OUT)), 0)
idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
assert tvm.ir_pass.Equal(
tvm.ir_pass.Simplify(
dst_coeff[-1] % (env.BATCH * env.BLOCK_OUT)), 0)
idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
assert tvm.ir_pass.Equal(src_coeff[-2], 1)
assert tvm.ir_pass.Equal(dst_coeff[-2], 1)
if env.BATCH > 1:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment