Unverified Commit 4a3abb94 by Tianqi Chen Committed by GitHub

Revert "Added tesnorizeation for avx2 based gemm. (#3982)" (#4007)

This reverts commit 23727eb4.
parent 5f19e5a8
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition
import tvm
import numpy as np
from topi.x86.tensor_intrin import dot_16x1x16_int8_int8_int32_vnni
from topi.x86.tensor_intrin import dot_1x4x16_int8_int8_int32_avx2
def test_avx2_int8_gemm_acc32():
m = 1024
n = 1024
k = 1024
X = tvm.placeholder((m, k), name='X', dtype="uint8")
W = tvm.placeholder((n, k), name='W', dtype="int8")
memory_ops = m * k + n * k + 2 * m * n
gops_per_mm = 2 * m * n * k
def verify(target="llvm -mcpu=core-avx2"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return
ctx = tvm.context(target, 0)
pc = dot_1x4x16_int8_int8_int32_avx2()
ak = tvm.reduce_axis((0, k), name='k')
packedW = tvm.placeholder(
(n // 16, 16 * (k // 4), 4), name='packedW', dtype="int8")
t_fc = tvm.compute((m, n), lambda i, j: tvm.sum(X[i, ak].astype(
"int32") * packedW[j // 16, (ak // 4) * 16 + j % 16, ak % 4].astype("int32"), axis=ak), name="F")
t_sch = tvm.create_schedule(t_fc.op)
a_x, a_y = t_fc.op.axis
a_k, = t_fc.op.reduce_axis
a_yo, a_yi = t_sch[t_fc].split(a_y, factor=16)
a_xo, a_xi = t_sch[t_fc].split(a_x, factor=32)
a_ko, a_ki = t_sch[t_fc].split(a_k, factor=4)
a_koo, a_koi = t_sch[t_fc].split(a_ko, factor=4)
t_sch[t_fc].reorder(a_yo, a_xo, a_xi, a_koo, a_koi, a_yi, a_ki)
t_sch[t_fc].unroll(a_koi)
t_sch[t_fc].tensorize(a_yi, pc)
t_func = tvm.build(t_sch, [X, packedW, t_fc], target, name="intrinsic")
t_evaluator = t_func.time_evaluator(t_func.entry_name, ctx, number=10)
# generate the plain data
a_ = np.random.uniform(1, 10, size=(m, k)).astype("uint8")
b_ = np.random.uniform(1, 10, size=(n, k)).astype("int8")
packW = np.random.uniform(1, 10, size=(
n // 16, 16 * (k // 4), 4)).astype("int8")
# This occurs in pre_compute stage
for r_idx in range(n // 16):
for s_idx in range(16 * (k // 4)):
for t_idx in range(4):
packW[r_idx][s_idx][t_idx] = b_[r_idx * 16 + s_idx %
16][(s_idx // 16) * 4 + t_idx]
x = tvm.nd.array(a_, ctx)
w = tvm.nd.array(packW, ctx)
y = tvm.nd.array(np.zeros((m, n), dtype="int32"), ctx)
result = t_evaluator(x, w, y)
gops_per_sec = gops_per_mm / result.mean / 1e9
# verify the correctness
tvm.testing.assert_allclose(y.asnumpy(), np.dot(a_, b_.T), rtol=0)
print('Tensorization: running time: {:.3f} ms, {:.2f} Gops/s'.format(
result.mean * 1000, gops_per_sec))
verify()
if __name__ == "__main__":
test_avx2_int8_gemm_acc32()
pass
......@@ -275,97 +275,3 @@ def dot_16x1x16_int8_int8_int32_vnni():
with tvm.build_config(offset_factor=1, partition_const_loop=True):
return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer})
def dot_1x4x16_int8_int8_int32_avx2():
"""
Int8 dot product by every 4 elements using x86 AVX2 instructions.
This function takes two arrays of int8 datatype -- data[4] and
kernel[16][4] -- and computes a dot product of data[4] with every
4 elements of kernels, resulting in output[16] of int32 datatype.
The pseudo code is as follows.
.. code-block:: c
void dot_1x4x16_int8_int8_int32(int8 data[4], int8 kernel[16][4],
int32 output[16]){
for (int i = 0; i < 16; i++){
out[i] = 0;
for (int k = 0; k < 4; k++){
out[i] += data[k] * kernel[i][k]
}
}
}
Physically, the kernel array sits in two AVX2 vector registers and
the data[4] is broadcasted to AVX2 vector register. This
function returns a TensorIntrin that can be used to tensorize
a schedule.
Returns
-------
intrin : TensorIntrin
The AVX2 int8 TensorIntrin that can be used in tensorizing schedule
"""
int32_lanes = 16 # 16 int32 lanes in AVX2
num_int8_elements = 4 # 4 int8 elements in int32
data = tvm.placeholder((num_int8_elements,), dtype='uint8', name='data')
kernel = tvm.placeholder((int32_lanes, num_int8_elements), dtype='int8', name='kernel')
k = tvm.reduce_axis((0, num_int8_elements), name='k')
C = tvm.compute((int32_lanes,),
lambda i: tvm.sum(data[k].astype('int32') *
kernel[i, k].astype('int32'),
axis=k),
name="C")
a_buffer = tvm.decl_buffer(data.shape, dtype='uint8', name="a_buffer",
offset_factor=1,
strides=[1])
b_buffer = tvm.decl_buffer(kernel.shape, dtype='int8', name="b_buffer",
offset_factor=1,
strides=[tvm.var('ldw'), 1])
def _intrin_func(ins, outs):
def _instr(index):
ib = tvm.ir_builder.create()
if index == 1:
ib.emit(outs[0].vstore(0, tvm.const(0, 'int32x16')))
return ib.get()
a_int8 = ins[0].vload([0], "uint8x4")
re_int32 = tvm.call_pure_intrin('int32', 'reinterpret', a_int8)
vec_ai32 = re_int32.astype('int32x8')
vec_a = tvm.call_pure_intrin('int8x32', 'reinterpret', vec_ai32)
vec_b_0 = ins[1].vload([0, 0], "int8x32")
vec_b_1 = ins[1].vload([8, 0], "int8x32")
vec_one = tvm.const(1, "int16x16")
pair_reduction_0 = tvm.call_llvm_intrin('int16x16',
'llvm.x86.avx2.pmadd.ub.sw',
tvm.const(0, 'uint32'),
vec_a, vec_b_0)
quad_reduction_0 = tvm.call_llvm_intrin('int32x8',
'llvm.x86.avx2.pmadd.wd',
tvm.const(0, 'uint32'),
pair_reduction_0, vec_one)
pair_reduction_1 = tvm.call_llvm_intrin('int16x16',
'llvm.x86.avx2.pmadd.ub.sw',
tvm.const(0, 'uint32'),
vec_a, vec_b_1)
quad_reduction_1 = tvm.call_llvm_intrin('int32x8',
'llvm.x86.avx2.pmadd.wd',
tvm.const(0, 'uint32'),
pair_reduction_1, vec_one)
if index == 0:
ib.emit(outs[0].vstore([0], quad_reduction_0))
ib.emit(outs[0].vstore([8], quad_reduction_1))
else:
ib.emit(outs[0].vstore([0], quad_reduction_0 + \
outs[0].vload([0], 'int32x8')))
ib.emit(outs[0].vstore([8], quad_reduction_1 + \
outs[0].vload([8], 'int32x8')))
return ib.get()
# body, reset, update
return _instr(0), _instr(1), _instr(2)
with tvm.build_config(offset_factor=1, partition_const_loop=True):
return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer})
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment