Commit b410df8c by Kimish Patel, committed by Tianqi Chen

Changes to make tensorize work. These changes also fix the previously broken test. (#3981)

* Changes to make tensorize work. These changes also fix the previously
broken test.

Summary:
Tensorize was breaking for a few reasons.

1) Assert at src/op/tensorize.cc:234: CHECK(is_one(e.region[j]->extent))
In some cases this cannot be proven, e.g.:
expected shape=[16, 4], given region=[range(min=((ax1.outer*16)/16), ext=(((((ax1.outer*16) + 15)/16) + 1) - ax1.outer)), range(min=((k.outer*4)/4), ext=(((((k.outer*4) + 3)/4) + 1) - k.outer)), range(min=0, ext=16), range(min=0, ext=4)]
The unprovable one is: ext=(((((ax1.outer*16) + 15)/16) + 1) - ax1.outer).
This can be simplified, but it is not, because to simplify the divide the
simplifier must prove ax1.outer > 0, and since it is a free variable it cannot.
The fix is to find all the vars in the expression and replace them with some
const value.
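
A minimal sketch of the same idea, using the arith::Analyzer this patch builds on (the function name and the bound extent of 8 are illustrative, not part of the commit): once a range is bound for ax1.outer, the divide becomes provable and the extent simplifies to 1.

#include <tvm/arithmetic.h>
#include <tvm/expr_operator.h>

// Hedged sketch, not part of this commit: bind a range for ax1.outer and
// the problematic extent above becomes provably 1.
tvm::Expr SimplifyExtentSketch() {
  tvm::arith::Analyzer analyzer;
  tvm::Var ax1_outer("ax1.outer");
  // The extent 8 is an arbitrary illustrative loop bound.
  analyzer.Bind(ax1_outer, tvm::Range::make_by_min_extent(0, 8));
  tvm::Expr extent = ((ax1_outer * 16 + 15) / 16 + 1) - ax1_outer;
  return analyzer.Simplify(extent);  // expected result: 1
}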

2) Equivalence between the tensorized expr and the one being asked to
tensorize. For example, the error would be:
TVMError: Check failed: Equal(lhs, rhs):
Failed to match the compute with TensorIntrin tensor_intrin's declaration
provided= reduce(combiner=comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[(int16)0]), source=[(int16(data(k))*int16(kernel(((((((((k.outer.outer*64) + (k.outer.inner*2)) + k)/2)*128) + i) - (k.outer.inner*128)) - (k.outer.outer*4096)), ((((k.outer.outer*64) + (k.outer.inner*2)) + k) % 2))))], axis=[iter_var(k, range(min=0, ext=2))], where=(bool)1, value_index=0),
intrin=  reduce(combiner=comm_reducer(result=[(x + y)], lhs=[x], rhs=[y], identity_element=[(int16)0]), source=[(int16(data(k))*int16(kernel(i, k)))], axis=[iter_var(k, range(min=0, ext=2))], where=(bool)1, value_index=0)
The difference is mainly in the source part:
source=[(int16(data(k))*int16(kernel(((((((((k.outer.outer*64) + (k.outer.inner*2)) + k)/2)*128) + i) - (k.outer.inner*128)) - (k.outer.outer*4096)), ((((k.outer.outer*64) + (k.outer.inner*2)) + k) % 2))))]
source=[(int16(data(k))*int16(kernel(i, k)))], axis=[iter_var(k, range(min=0, ext=2))]
This was not being simplified because compute_intrin_iter_space (the map from
iter var to range) did not contain the leaf iter vars.
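
A hedged sketch of what the fix enables (variable names mirror the error message, the function name is made up, and only the leaf reduction variable k needs its range here): with k's range in the vrange map, Simplify folds the provided kernel index back to i, so the two sources match.

#include <tvm/ir_pass.h>
#include <tvm/expr_operator.h>

// Hedged sketch, not from the patch: with the leaf iter var k's range known,
// the convoluted kernel index from the error above collapses to i.
tvm::Expr CanonicalizeKernelIndexSketch() {
  tvm::Var i("i"), k("k"), koo("k.outer.outer"), koi("k.outer.inner");
  tvm::Map<tvm::Var, tvm::Range> vrange;
  vrange.Set(k, tvm::Range::make_by_min_extent(0, 2));
  tvm::Expr idx = ((koo * 64 + koi * 2 + k) / 2) * 128 + i
                  - koi * 128 - koo * 4096;
  return tvm::ir::Simplify(idx, vrange);  // expected result: i
}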

3) Here it fails with:
Check failed: is_one(Simplify(value->shape[i])): Argument b_buffer shape mismatch[16, 4] vs [(((((ax1.outer*16) + 15)/16) + 1) - ax1.outer), (((((k.outer*4) + 3)/4) + 1) - k.outer), 16, 4]
This is in buffer binding, where the expected shape and the bound buffer shape
appear different. The offending extents are the same unprovable expressions as
in case 1; if we could simplify them, the shapes would match.

Test Plan:
On a Skylake AVX-512 machine:
python tests/python/contrib/test_gemm_acc16.py

* Implemented a bounded analyzer which traverses the tree and, for Reduce/For
statements, binds their bounds in the analyzer. Later this is used to
simplify expressions. Inspired by ir_mutator_with_analyzer; a short usage
sketch follows below.
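
A brief usage sketch (it mirrors how the StorageFlatten change below wires the visitor up; the wrapper function is illustrative):

#include "../arithmetic/ir_visitor_with_analyzer.h"

// Usage sketch: walk the statement once so For/AttrStmt/Reduce ranges are
// bound in the analyzer, then reuse the same visitor to simplify expressions.
tvm::Expr SimplifyWithBounds(const tvm::Stmt& stmt, const tvm::Expr& e) {
  tvm::ir::IRVisitorWithAnalyzer bounded_analyzer;
  bounded_analyzer.Visit(stmt);      // binds loop/thread/reduce var ranges
  return bounded_analyzer.Simplify(e);
}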

* Addressed comments.


* Added ASF header + include-guard macro for the header file: TVM_ARITHMETIC_IR_VISITOR_WITH_ANALYZER_H_.
Some lint fixes as well.

* Relax the assumption that dom_map must always contain all leaf itervars.


* Disable copy constructor and move to raw ptr.

parent d1830964
@@ -471,6 +471,11 @@ class IntSetAnalyzer {
  */
 class Analyzer {
  public:
+  /*
+   * Disable copy constructor.
+   */
+  Analyzer(const Analyzer&) = delete;
+  Analyzer& operator=(const Analyzer&) = delete;
   /*! \brief sub-analyzer: const integer bound */
   ConstIntBoundAnalyzer const_int_bound;
   /*! \brief sub-analyzer: modular set */
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/arithmetic/ir_visitor_with_analyzer.h
* \brief IR visitor class with an analyzer context.
*/
#ifndef TVM_ARITHMETIC_IR_VISITOR_WITH_ANALYZER_H_
#define TVM_ARITHMETIC_IR_VISITOR_WITH_ANALYZER_H_
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
namespace tvm {
namespace ir {
class IRVisitorWithAnalyzer final : public IRVisitor {
public:
Expr Simplify(const Expr& expr) {
return analyzer_.Simplify(expr);
}
void Visit_(const For* op) {
analyzer_.Bind(op->loop_var,
Range::make_by_min_extent(op->min, op->extent));
return IRVisitor::Visit_(op);
}
void Visit_(const AttrStmt* op) {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::virtual_thread) {
IterVar iv(op->node.node_);
CHECK_NE(iv->thread_tag.length(), 0U);
analyzer_.Bind(iv->var,
Range::make_by_min_extent(0, op->value));
IRVisitor::Visit_(op);
} else {
IRVisitor::Visit_(op);
}
}
void Visit_(const Reduce* op) {
// Setup the domain information before simplification.
for (const IterVar& iv : op->axis) {
analyzer_.Bind(iv->var, iv->dom);
}
// Recursively call simplification when necessary.
IRVisitor::Visit_(op);
}
protected:
/*! \brief internal analyzer field. */
arith::Analyzer analyzer_;
};
} // namespace ir
} // namespace tvm
#endif // TVM_ARITHMETIC_IR_VISITOR_WITH_ANALYZER_H_
@@ -157,7 +157,6 @@ void VerifyTensorizeLoopNest(const ComputeOpNode* self,
     }
   }
 // Remap the tensor placeholder, index and inline things.
 class TensorIntrinMatcher final : public IRMutator {
  public:
@@ -207,11 +206,22 @@ class TensorIntrinMatcher final : public IRMutator {
   void Init(const ComputeOpNode* self,
             const Stage& stage,
+            const std::unordered_map<IterVar, Range>& dom_map,
             const std::unordered_map<IterVar, Range>& out_dom,
             const std::unordered_map<Tensor, Array<Range> >& in_region,
             const TensorIntrin& intrin,
             Map<Var, Range>* compute_intrin_iter_space) {
     CHECK(self == stage->op.get());
+    for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
+      IterVar iv = stage->leaf_iter_vars[i];
+      auto vit = dom_map.find(iv);
+      if (vit != dom_map.end()) {
+        const Range vrange = vit->second;
+        compute_intrin_iter_space->Set(iv->var, vrange);
+      }
+    }
     // input remap.
     Array<Tensor> inputs = self->InputTensors();
     CHECK_EQ(inputs.size(), intrin->inputs.size());
@@ -222,8 +232,9 @@ class TensorIntrinMatcher final : public IRMutator {
       CHECK_GE(e.region.size(), e.tensor.ndim());
       // Enable fuzzy matching, to match [1, n, m] to [n, m]
       e.start = e.region.size() - e.tensor.ndim();
-      for (size_t i = 0; i < e.start; ++i) {
-        CHECK(is_one(e.region[i]->extent))
+      for (size_t j = 0; j < e.start; ++j) {
+        auto canonical_extent = Simplify(e.region[j]->extent, *compute_intrin_iter_space);
+        CHECK(is_one(canonical_extent))
             << "Tensorize " << intrin->name << ":"
             << " Input dimension mismatch with tensor intrin "
             << " expected shape=" << e.tensor->shape
@@ -298,12 +309,13 @@ class TensorIntrinMatcher final : public IRMutator {
 Array<Expr> MatchTensorizeBody(
     const ComputeOpNode* self,
     const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
     const std::unordered_map<IterVar, Range>& out_dom,
     const std::unordered_map<Tensor, Array<Range> >& in_region,
     const TensorIntrin& intrin,
     Map<Var, Range>* compute_intrin_iter_space) {
   TensorIntrinMatcher matcher;
-  matcher.Init(self, stage, out_dom, in_region, intrin, compute_intrin_iter_space);
+  matcher.Init(self, stage, dom_map, out_dom, in_region, intrin, compute_intrin_iter_space);
   Array<Expr> ret;
   for (Expr expr : self->body) {
     ret.push_back(matcher.Mutate(expr));
@@ -314,11 +326,12 @@ Array<Expr> MatchTensorizeBody(
 void VerifyTensorizeBody(
     const ComputeOpNode* self,
     const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
     const std::unordered_map<IterVar, Range>& out_dom,
     const std::unordered_map<Tensor, Array<Range> >& in_region,
     const TensorIntrin& intrin) {
   Map<Var, Range> compute_intrin_iter_space;
-  Array<Expr> body = MatchTensorizeBody(self, stage, out_dom, in_region, intrin,
+  Array<Expr> body = MatchTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin,
                                         &compute_intrin_iter_space);
   const ComputeOpNode* intrin_compute = intrin->op.as<ComputeOpNode>();
   CHECK(intrin_compute) << "Only support compute intrinsic for now";
@@ -356,7 +369,7 @@ Stmt MakeTensorize(const ComputeOpNode* self,
   CHECK(intrin.defined());
   ComputeLoopNest n = ComputeLoopNest::make(self, stage, dom_map, debug_keep_trivial_loop);
   VerifyTensorizeLoopNest(self, stage, n, tloc);
-  VerifyTensorizeBody(self, stage, out_dom, in_region, intrin);
+  VerifyTensorizeBody(self, stage, dom_map, out_dom, in_region, intrin);
   // Start bind data.
   Stmt nop = Evaluate::make(0);
   std::vector<Stmt> input_bind_nest, output_bind_nest;
@@ -509,6 +522,7 @@ TVM_REGISTER_API("test.op.MatchTensorizeBody")
     CHECK(stage->op.as<ComputeOpNode>());
     *ret = MatchTensorizeBody(stage->op.as<ComputeOpNode>(),
                               stage,
+                              {},
                               as_unordered_map(out_dom),
                               as_unordered_map(in_region),
                               intrin,
......
@@ -128,7 +128,7 @@ void ArgBinder::BindBuffer(const Buffer& arg,
     CHECK(fuzzy_match) << "Argument " << arg_name << " size mismatch";
     size_t diff = value->shape.size() - arg->shape.size();
     for (size_t i = 0; i < diff; ++i) {
-      CHECK(is_one(value->shape[i]))
+      CHECK(is_one(Simplify(value->shape[i])))
           << "Argument " << arg_name << " shape mismatch"
           << arg->shape << " vs " << value->shape;
     }
......
@@ -23,10 +23,12 @@
  */
 // Flattens storage from multi-dimensional array to 1D
 // buffer access as in Halide pipeline.
+#include <tvm/arithmetic.h>
 #include <tvm/ir.h>
 #include <tvm/expr.h>
 #include <tvm/operation.h>
 #include <tvm/ir_mutator.h>
+#include <tvm/ir_visitor.h>
 #include <tvm/expr_operator.h>
 #include <tvm/ir_pass.h>
 #include <tvm/buffer.h>
@@ -36,6 +38,7 @@
 #include "ir_util.h"
 #include "arg_binder.h"
 #include "../arithmetic/compute_expr.h"
+#include "../arithmetic/ir_visitor_with_analyzer.h"
 #include "../runtime/thread_storage_scope.h"

 namespace tvm {
@@ -49,8 +52,10 @@ using intrinsic::tvm_address_of;
 class StorageFlattener : public IRMutator {
  public:
   explicit StorageFlattener(Map<Tensor, Buffer> extern_buffer,
-                            int cache_line_size, bool create_bound_attributes)
-      : create_bound_attributes_(create_bound_attributes) {
+                            int cache_line_size, bool create_bound_attributes,
+                            IRVisitorWithAnalyzer* bounded_analyzer)
+      : bounded_analyzer_(bounded_analyzer),
+        create_bound_attributes_(create_bound_attributes) {
     for (auto kv : extern_buffer) {
       BufferEntry e;
       e.buffer = kv.second;
@@ -419,7 +424,8 @@ class StorageFlattener : public IRMutator {
     } else {
       for (size_t i = 0; i < tuple->args.size(); i += 2) {
         begins.push_back(tuple->args[i]);
-        extents.push_back(tuple->args[i + 1]);
+        auto new_extent = bounded_analyzer_->Simplify(tuple->args[i+1]);
+        extents.push_back(new_extent);
       }
     }
     Buffer slice = be.buffer.MakeSlice(begins, extents);
@@ -510,6 +516,9 @@ class StorageFlattener : public IRMutator {
   std::vector<ThreadScope> curr_thread_scope_;
   // Collects shapes.
   std::vector<std::pair<VarExpr, Array<Expr>>> shape_collector_;
+  // bounds populator. We really need the analyzer from it.
+  // However
+  IRVisitorWithAnalyzer* bounded_analyzer_;
   // The size of cacheline
   int cache_line_size_;
   // The current stage is an OpenGL shader.
@@ -520,9 +529,11 @@ class StorageFlattener : public IRMutator {
 Stmt StorageFlatten(Stmt stmt, Map<Tensor, Buffer> extern_buffer,
                     int cache_line_size, bool create_bound_attributes) {
+  IRVisitorWithAnalyzer bounded_analyzer;
+  bounded_analyzer.Visit(stmt);
   stmt =
-      StorageFlattener(extern_buffer, cache_line_size, create_bound_attributes)
-          .Mutate(stmt);
+      StorageFlattener(extern_buffer, cache_line_size,
+                       create_bound_attributes, &bounded_analyzer).Mutate(stmt);
   return stmt;
 }
......
@@ -43,8 +43,8 @@ def benchmark_fc_int8_acc16():
     pc = dot_16x1x16_int8_int8_int16()
     ak = tvm.reduce_axis((0, k), name='k')

-    packedW = tvm.placeholder((n/128, 128*(k/2), 2), name='packedW', dtype="int8")
-    t_fc = tvm.compute((m, n), lambda i, j: tvm.sum(X[i, ak].astype("int16") * packedW[j/128, (ak/2)*128+j%128, ak%2].astype("int16"), axis=ak), name="F")
+    packedW = tvm.placeholder((n//128, 128*(k//2), 2), name='packedW', dtype="int8")
+    t_fc = tvm.compute((m, n), lambda i, j: tvm.sum(X[i, ak].astype("int16") * packedW[j//128, (ak//2)*128+j%128, ak%2].astype("int16"), axis=ak), name="F")
     t_sch = tvm.create_schedule(t_fc.op)
     a_x, a_y = t_fc.op.axis
@@ -66,12 +66,12 @@ def benchmark_fc_int8_acc16():
     a_ = np.random.uniform(1, 10, size=(m, k)).astype("uint8")
     b_ = np.random.uniform(1, 10, size=(n, k)).astype("int8")

-    packW = np.random.uniform(1, 10, size=(n/128, 128*(k/2), 2)).astype("int8")
+    packW = np.random.uniform(1, 10, size=(n//128, 128*(k//2), 2)).astype("int8")
     # This occurs in pre_compute stage
-    for r_idx in range(n/128):
-        for s_idx in range(128*(k/2)):
+    for r_idx in range(n//128):
+        for s_idx in range(128*(k//2)):
             for t_idx in range(2):
-                packW[r_idx][s_idx][t_idx] = b_[r_idx*128+s_idx%128][s_idx/128*2+t_idx]
+                packW[r_idx][s_idx][t_idx] = b_[r_idx*128+s_idx%128][s_idx//128*2+t_idx]

     x = tvm.nd.array(a_, ctx)
     w = tvm.nd.array(packW, ctx)
@@ -82,7 +82,7 @@ def benchmark_fc_int8_acc16():
     tvm.testing.assert_allclose(
         y.asnumpy(), np.dot(a_, b_.T), rtol=1e-5)
     print('Tensorization: running time: {:.3f} ms, {:.2f} Gops/s, effiency: {:.2f}.'.format(result.mean*1000, gops_per_sec, gops_per_sec/peak))
-    t_func.export_library("gemm_tensorize.o")
+    #t_func.export_library("gemm_tensorize.o")

     verify()
......