Commit 475158f6 by Zhi, committed by masahi

[relay][refactor] Cache Op::Get in passes to reduce lookup overhead (#4594)

* Refactor to use IsOp utility

* retrigger CI
parent 35af4c8b
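Note on the pattern applied throughout this diff: each pass used to re-resolve operators from the global registry on every visit (either via `Op::Get("...")` inline or a function-local static), and now caches an `Op` reference once, either as a class member initialized in the pass constructor or as a file-scope static, and compares against it with `==`. A minimal sketch of the idea follows; the class name `ExampleRewriter` and the choice of the "cast" op are illustrative assumptions, not code from this commit.

// Sketch only: one registry lookup at construction, cheap comparisons afterwards.
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>

namespace tvm {
namespace relay {

class ExampleRewriter : public ExprMutator {
 public:
  // Pay the registry lookup once, when the pass object is built.
  ExampleRewriter() : cast_op_(Op::Get("cast")) {}

  Expr VisitExpr_(const CallNode* call) final {
    // Before the refactor, a visitor typically re-resolved the op here, e.g.
    //   if (call->op.same_as(Op::Get("cast"))) { ... }
    // which performs a registry lookup for every call node visited.
    if (call->op == cast_op_) {
      // ... rewrite the cast call ...
    }
    return ExprMutator::VisitExpr_(call);
  }

 private:
  // Registered ops live in a global registry for the lifetime of the
  // process, so holding a const reference to the entry is safe.
  const Op& cast_op_;
};

}  // namespace relay
}  // namespace tvm

This is why caching a `const Op&` member (or a file-scope static, as in fuse_ops.cc below) is sound: the referenced registry entry is never destroyed while the pass runs.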
@@ -594,12 +594,11 @@ inline ValueType OpMap<ValueType>::get(const Expr& expr,
   return map_.get<ValueType>(expr, def_value);
 }
 /*!
- * \brief Check that an expression is a "primtive operator".
+ * \brief Check that an expression is a "primitive operator".
  *
  * Will return true if the expression is an operator which
- * matches the form of primtive operators registered directly
+ * matches the form of primitive operators registered directly
  * by the Relay codebase.
  *
  * That is the arguments are all type variables, and there is a single
...
@@ -21,6 +21,8 @@
  * \file relay/backend/compile_engine.cc
  * \brief Internal compilation engine.
  */
+#include "compile_engine.h"
+
 #include <tvm/schedule.h>
 #include <tvm/packed_func_ext.h>
 #include <tvm/operation.h>
@@ -29,6 +31,7 @@
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
 #include <topi/tags.h>
 #include <utility>
@@ -38,7 +41,6 @@
 #include <vector>
 #include <unordered_map>
 #include "../ir/type_functor.h"
-#include "compile_engine.h"
 namespace tvm {
 namespace relay {
@@ -102,7 +104,7 @@ class ScheduleGetter :
     public ExprFunctor<Array<Tensor>(const Expr&)> {
  public:
   explicit ScheduleGetter(Target target)
-      : target_(target) {}
+      : target_(target), device_copy_op_(Op::Get("device_copy")) {}
   std::pair<Schedule, CachedFunc> Create(const Function& prim_func) {
     static auto fschedule =
@@ -250,11 +252,9 @@ class ScheduleGetter :
     CHECK(call_node->op.as<OpNode>())
       << "Primitive function only allows call into primitive ops";
     Op op = Downcast<Op>(call_node->op);
-    // Check if the op is a device copy op.
-    bool is_copy_op = op.same_as(Op::Get("device_copy"));
     Array<Tensor> outputs;
     // Skip fcompute for device copy operators as it is not registered.
-    if (is_copy_op) {
+    if (op == device_copy_op_) {
       const auto* copy_input = inputs[0].operator->();
       outputs.push_back(TensorNode::make(copy_input->shape, copy_input->dtype,
                                          Operation(), 0));
@@ -282,7 +282,7 @@ class ScheduleGetter :
     }
     // Set the name to `__copy`. It will be detected in graph runtime to perform
     // data copy across devices.
-    if (is_copy_op) {
+    if (op == device_copy_op_) {
       readable_name_stream_.str(std::string());
       readable_name_stream_ << "__copy";
     } else {
@@ -332,6 +332,9 @@ class ScheduleGetter :
   std::ostringstream readable_name_stream_;
   std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> memo_;
   Array<Operation> scalars_;
+  // Cache device copy op for equivalence checking to reduce registry lookup
+  // overhead for each invocation of call node when retrieving schedules.
+  const Op& device_copy_op_;
 };
 // Creates shape function from functor.
...
@@ -246,10 +246,12 @@ class Interpreter :
     public ExprFunctor<Value(const Expr& n)>,
     PatternFunctor<bool(const Pattern& p, const Value& v)> {
  public:
-  Interpreter(Module mod,
-              DLContext context,
-              Target target)
-      : mod_(mod), context_(context), target_(target) {
+  Interpreter(Module mod, DLContext context, Target target)
+      : mod_(mod),
+        context_(context),
+        target_(target),
+        debug_op_(Op::Get("debug")),
+        shape_of_op_(Op::Get("shape_of")) {
     engine_ = CompileEngine::Global();
   }
@@ -263,7 +265,7 @@ class Interpreter :
     stack_.current_frame().locals.Set(id, v);
   }
-  inline Value Lookup(const Var& local) {
+  Value Lookup(const Var& local) {
     return stack_.Lookup(local);
   }
@@ -307,7 +309,7 @@ class Interpreter :
     return TupleValueNode::make(values);
   }
-  inline Value MakeClosure(const Function& func, Var letrec_name = Var()) {
+  Value MakeClosure(const Function& func, Var letrec_name = Var()) {
     tvm::Map<Var, Value> captured_mod;
     Array<Var> free_vars = FreeVars(func);
@@ -454,9 +456,9 @@ class Interpreter :
   Value InvokePrimitiveOp(const Function& func,
                           const Array<Value>& args) {
-    auto call_node = func->body.as<CallNode>();
-    if (call_node && call_node->op == Op::Get("debug")) {
+    const auto* call_node = func->body.as<CallNode>();
+    if (call_node && call_node->op == debug_op_) {
       auto dattrs = call_node->attrs.as<DebugAttrs>();
       auto interp_state = this->get_state(call_node->args[0]);
@@ -540,7 +542,7 @@ class Interpreter :
     Array<Shape> out_shapes;
     auto ret_type = func->body->checked_type();
     bool is_dyn = IsDynamic(func->checked_type());
-    if (call_node->op == Op::Get("shape_of")) {
+    if (call_node->op == shape_of_op_) {
       // The output shape of shape_of must be static since Relay doesn't support
       // dynamic rank tensors.
       is_dyn = false;
@@ -782,6 +784,9 @@ class Interpreter :
   Stack stack_;
   // Backend compile engine.
   CompileEngine engine_;
+  // Cache ops that need to be frequently used later to reduce lookup overhead.
+  const Op& debug_op_;
+  const Op& shape_of_op_;
 };
...
@@ -62,6 +62,8 @@ namespace relay {
 // \endcode
 class CastCanonicalizer : public ExprMutator {
  public:
+  CastCanonicalizer() : cast_op_(Op::Get("cast")) {}
+
   Expr VisitExpr_(const CallNode* call) {
     static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
@@ -91,15 +93,17 @@ class CastCanonicalizer : public ExprMutator {
  private:
   std::unordered_map<const Node*, size_t> ref_counter_;
+  // cast op is frequently checked for equivalence. Therefore, we cache it to
+  // reduce lookup overhead.
+  const Op& cast_op_;
   Expr GetNewCallArg(const Expr& e) {
     // if e is an upcast and ref count > 1, create a copy; otherwise call the default visitor
-    static auto& cast = Op::Get("cast");
     Expr new_expr = this->VisitExpr(e);
     if (const CallNode* call = e.as<CallNode>()) {
-      if (call->op.same_as(cast)) {
+      if (call->op == cast_op_) {
         auto attrs = call->attrs.as<CastAttrs>();
         const auto* from_type = call->args[0]->type_as<TensorTypeNode>();
         CHECK(from_type);
@@ -108,7 +112,7 @@ class CastCanonicalizer : public ExprMutator {
       if (++ref_counter_[call] > 1) {
         const CallNode* new_call = new_expr.as<CallNode>();
         CHECK(new_call);
-        CHECK(new_call->op.same_as(cast));
+        CHECK(new_call->op == cast_op_);
         return CallNode::make(new_call->op, new_call->args, new_call->attrs,
                               new_call->type_args);
       }
...
@@ -24,6 +24,7 @@
  */
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/transform.h>
 #include "pattern_util.h"
@@ -33,10 +34,11 @@ namespace relay {
 class BiasAddSimplifier : public ExprMutator {
  public:
+  BiasAddSimplifier() : bias_add_op_(Op::Get("nn.bias_add")) {}
+
   Expr VisitExpr_(const CallNode* n) {
-    static const Op& bias_add = Op::Get("nn.bias_add");
     auto new_n = ExprMutator::VisitExpr_(n);
-    if (n->op.same_as(bias_add)) {
+    if (n->op == bias_add_op_) {
       Call call = Downcast<Call>(new_n);
       CHECK_EQ(call->args.size(), 2);
       const BiasAddAttrs* param = call->attrs.as<BiasAddAttrs>();
@@ -54,6 +56,10 @@ class BiasAddSimplifier : public ExprMutator {
     }
     return new_n;
   }
+
+ private:
+  // Cache the bias_add op for equivalence checking.
+  const Op& bias_add_op_;
 };
 Expr CanonicalizeOps(const Expr& e) {
...
@@ -27,29 +27,30 @@
 #include <tvm/relay/expr_functor.h>
 #include <tvm/relay/attrs/nn.h>
 #include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/transform.h>
+#include <algorithm>
+#include <utility>
 #include <unordered_map>
 #include <unordered_set>
-#include "./expr_subst.h"
-#include "./pattern_util.h"
-#include "./combine_parallel_op.h"
+#include "expr_subst.h"
+#include "pattern_util.h"
+#include "combine_parallel_op.h"
 namespace tvm {
 namespace relay {
-BranchGroupFinder::BranchGroupFinder(const std::string& op_name,
+BranchGroupFinder::BranchGroupFinder(const Op& op,
                                      FIsSupportedOp fis_supported_op,
                                      FAreCompatibleOps fare_compatible_ops)
-    : op_name_(op_name),
+    : cached_op_(op),
       fis_supported_op_(fis_supported_op),
       fare_compatible_ops_(fare_compatible_ops) {
 }
 std::vector<Group> BranchGroupFinder::Find(const Expr& expr) {
-  const Op& op = Op::Get(op_name_);
   this->VisitExpr(expr);
   std::vector<Group> groups;
@@ -57,7 +58,7 @@ std::vector<Group> BranchGroupFinder::Find(const Expr& expr) {
     const auto& children = children_map_.at(root);
     size_t ngroups = groups.size();
     for (const CallNode* child : children) {
-      if (!child->op.same_as(op)) continue;
+      if (child->op != cached_op_) continue;
       auto&& branch = CreateBranch(child);
       // add the branch to a group, or create a new group
@@ -97,9 +98,8 @@ Branch BranchGroupFinder::CreateBranch(const CallNode* op) {
 }
 void BranchGroupFinder::VisitExpr_(const CallNode* n) {
-  const Op& op = Op::Get(op_name_);
   ExprVisitor::VisitExpr_(n);
-  if (n->op.same_as(op) && fis_supported_op_(n)) {
+  if (n->op == cached_op_ && fis_supported_op_(n)) {
     op_roots_.insert(n->args[0]);
     children_map_[n->args[0]].push_back(n);
   } else {
@@ -110,12 +110,12 @@ void BranchGroupFinder::VisitExpr_(const CallNode* n) {
 }
 ParallelOpCombiner::ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches)
-    : op_name_(op_name),
+    : cached_op_(Op::Get(op_name)),
       min_num_branches_(min_num_branches) {
 }
 Expr ParallelOpCombiner::Combine(const Expr& expr) {
-  auto groups = BranchGroupFinder(op_name_,
+  auto groups = BranchGroupFinder(cached_op_,
                                   [&](const CallNode* n) {
                                     return IsSupportedOp(n);
                                   },
...
@@ -68,13 +68,13 @@ class BranchGroupFinder : private ExprVisitor {
  public:
   /*
    * \brief Constructor
-   * \param op_name name of op to start each group
+   * \param op The op that indicates the start of each group
    * \param fis_supported_op function that returns true if op
    *                         is supported for combining
    * \param fare_compatible_ops function that returns true if
    *                            two ops are compatible for combining
    */
-  BranchGroupFinder(const std::string& op_name,
+  BranchGroupFinder(const Op& op,
                     FIsSupportedOp fis_supported_op,
                     FAreCompatibleOps fare_compatible_ops);
@@ -87,8 +87,8 @@ class BranchGroupFinder : private ExprVisitor {
   std::vector<Group> Find(const Expr& expr);
  private:
-  /* \brief name of op to find parallel branches for */
-  std::string op_name_;
+  /* \brief Cache the op for finding parallel branches */
+  const Op& cached_op_;
   /* \brief function to return true if op is eligible to be combined,
    * false otherwise
@@ -205,8 +205,8 @@ class ParallelOpCombiner {
                           ExprSubstMap* subst_map) = 0;
  private:
-  /* \brief name of op to be combined */
-  std::string op_name_;
+  /* \brief Cache the op to be combined */
+  const Op& cached_op_;
   /* \brief minimum number of parallel branches to combine */
   uint64_t min_num_branches_;
...
@@ -22,6 +22,7 @@
  */
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/interpreter.h>
 #include <tvm/relay/attrs/transform.h>
@@ -33,7 +34,6 @@ namespace relay {
 using FInterpreter = runtime::TypedPackedFunc<Value(Expr)>;
 class ConstantChecker : private ExprVisitor {
  public:
   // Check whether an expression is constant. The results are memoized.
@@ -78,8 +78,14 @@ TVM_REGISTER_API("relay._analysis.check_constant")
 class ConstantFolder : public ExprMutator {
  public:
   explicit ConstantFolder(FInterpreter executor, Module module)
-      : executor_(executor), module_(module) {
-  }
+      : executor_(executor),
+        module_(module),
+        shape_of_op_(Op::Get("shape_of")),
+        invoke_tvm_op_(Op::Get("memory.invoke_tvm_op")),
+        shape_func_op_(Op::Get("memory.shape_func")),
+        alloc_tensor_op_(Op::Get("memory.alloc_tensor")),
+        alloc_storage_op_(Op::Get("memory.alloc_storage")),
+        cast_op_(Op::Get("cast")) {}
   Expr VisitExpr_(const LetNode* op) final {
     Expr value = this->Mutate(op->value);
@@ -119,15 +125,15 @@ class ConstantFolder : public ExprMutator {
     // skip stateful ops.
     if (op_stateful.get(GetRef<Op>(op), false)) return res;
     // Try to evaluate shape_of op
-    if (call->op.same_as(Op::Get("shape_of"))) {
+    if (call->op == shape_of_op_) {
       return EvaluateShapeOf(res, origin_args, call->attrs);
     }
     // We should think about potentially constant evaluation over these ops too.
-    if (call->op.same_as(Op::Get("memory.invoke_tvm_op")) ||
-        call->op.same_as(Op::Get("memory.shape_func")) ||
-        call->op.same_as(Op::Get("memory.alloc_tensor")) ||
-        call->op.same_as(Op::Get("memory.alloc_storage"))) {
+    if (call->op == invoke_tvm_op_ ||
+        call->op == shape_func_op_ ||
+        call->op == alloc_tensor_op_ ||
+        call->op == alloc_storage_op_) {
       return GetRef<Call>(call);
     }
@@ -162,6 +168,14 @@ class ConstantFolder : public ExprMutator {
   // Module
   Module module_;
+
+  // Cache the following ops for equivalence checking in this pass.
+  const Op& shape_of_op_;
+  const Op& invoke_tvm_op_;
+  const Op& shape_func_op_;
+  const Op& alloc_tensor_op_;
+  const Op& alloc_storage_op_;
+  const Op& cast_op_;
   // Convert value to expression.
   Expr ValueToExpr(Value value) {
     if (const auto* val = value.as<TensorValueNode>()) {
@@ -254,8 +268,7 @@ class ConstantFolder : public ExprMutator {
     // Cast the constant into correct dtype
     auto cast_attrs = make_node<CastAttrs>();
     cast_attrs->dtype = param->dtype;
-    static const Op& cast_op = Op::Get("cast");
-    Expr ret = CallNode::make(cast_op, { shape }, Attrs(cast_attrs), {});
+    Expr ret = CallNode::make(cast_op_, { shape }, Attrs(cast_attrs), {});
     return ConstEvaluate(ret);
   }
 };
...
@@ -78,6 +78,8 @@ using common::LinkedList;
 constexpr uint32_t kMaxFusedOps = 256;
+static const Op& stop_fusion_op = Op::Get("annotation.stop_fusion");
+
 /*!
  * \brief Indexed data flow graph in forward direction.
  *  This is a temporary data structure used for operator fusion analysis.
@@ -860,7 +862,6 @@ class FuseMutator : private ExprMutator {
   // Transform calls.
   Expr VisitExpr_(const CallNode* call) {
-    static const Op& stop_fusion = Op::Get("annotation.stop_fusion");
     if (call->op.as<OpNode>()) {
       static auto fnoncomputational =
         Op::GetAttr<TNonComputational>("TNonComputational");
@@ -872,7 +873,7 @@ class FuseMutator : private ExprMutator {
       // If it is a primitive op call
       // then we must have a group assignment for it already.
       CHECK(gmap_.count(call));
-      if (call->op.same_as(stop_fusion)) {
+      if (call->op == stop_fusion_op) {
         return ExprMutator::VisitExpr(call->args[0]);
       }
       auto* ret_group = gmap_.at(call)->FindRoot();
...
@@ -559,30 +559,28 @@ struct WithFuncIdAttrs : public tvm::AttrsNode<WithFuncIdAttrs> {
 TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs);
-Op WithFuncIdOp() {
-  static const Op& op = Op::Get("annotation.with_funcid");
-  return op;
-}
-Expr MkWithFuncId(const Expr& expr, FuncId fid) {
-  auto attrs = make_node<WithFuncIdAttrs>();
-  attrs->fid = fid;
-  return CallNode::make(WithFuncIdOp(), {expr}, Attrs(attrs), {});
-}
 RELAY_REGISTER_OP("annotation.with_funcid")
 .describe(R"code(Annotate a function with a funcid.)code"
           TVM_ADD_FILELINE)
 .set_num_inputs(1)
 .add_argument("func", "Function", "The input data.");
+// Cache with_funcid op to reduce lookup overhead during traversal.
+static const Op& with_funcid_op = Op::Get("annotation.with_funcid");
+
+Expr MkWithFuncId(const Expr& expr, FuncId fid) {
+  auto attrs = make_node<WithFuncIdAttrs>();
+  attrs->fid = fid;
+  return CallNode::make(with_funcid_op, {expr}, Attrs(attrs), {});
+}
 Expr StripWithFuncId(const Expr& e);
 Function AsFunc(const Expr& e) {
   if (e.as<FunctionNode>()) {
     return Downcast<Function>(e);
   } else if (const CallNode* c = e.as<CallNode>()) {
-    CHECK(c->op.same_as(WithFuncIdOp()));
+    CHECK(c->op == with_funcid_op);
     CHECK_EQ(c->args.size(), 1);
     return AsFunc(c->args[0]);
   } else {
@@ -604,7 +602,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
   PStatic VisitExpr(const Expr& e, LetList* ll, const Var& name) {
     if (const CallNode* c = e.as<CallNode>()) {
-      if (c->op.same_as(WithFuncIdOp())) {
+      if (c->op == with_funcid_op) {
         CHECK_EQ(c->args.size(), 1);
         return VisitExpr(c->args[0], ll, name);
       }
@@ -722,7 +720,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
   }
   PStatic VisitExpr_(const CallNode* op, LetList* ll) final {
-    if (op->op.same_as(WithFuncIdOp())) {
+    if (op->op == with_funcid_op) {
       CHECK_EQ(op->args.size(), 1);
       return VisitExpr(op->args[0], ll);
     }
@@ -1096,7 +1094,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
   explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { }
   void VisitExpr_(const CallNode* op) final {
-    if (op->op.same_as(WithFuncIdOp())) {
+    if (op->op == with_funcid_op) {
       CHECK_EQ(op->args.size(), 1);
       CHECK(op->attrs.defined());
       CHECK(op->attrs.as<WithFuncIdAttrs>());
@@ -1194,7 +1192,7 @@ Expr Remap(const Expr& e) {
 Expr StripWithFuncId(const Expr& e) {
   struct StripWithFuncIdMutator : ExprMutator, PatternMutator {
     Expr VisitExpr_(const CallNode* op) final {
-      if (op->op.same_as(WithFuncIdOp())) {
+      if (op->op == with_funcid_op) {
         CHECK_EQ(op->args.size(), 1);
         return VisitExpr(op->args[0]);
       } else {
...
@@ -25,15 +25,17 @@
  */
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
 #include "./quantize.h"
 namespace tvm {
 namespace relay {
 namespace quantize {
 class StatsCollector : private ExprMutator {
  public:
+  StatsCollector() : simulated_quantize_op_(Op::Get("relay.op.annotation.simulated_quantize")) {}
+
   Expr Collect(const Expr& expr) {
     auto new_e = this->Mutate(expr);
     const FunctionNode* func = new_e.as<FunctionNode>();
@@ -45,13 +47,13 @@ class StatsCollector : private ExprMutator {
  private:
   Array<Expr> profile_data_;
+  const Op& simulated_quantize_op_;
   Expr VisitExpr_(const CallNode* call) {
-    static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
     Expr new_e = ExprMutator::VisitExpr_(call);
     const CallNode* new_call = new_e.as<CallNode>();
     CHECK(new_call);
-    if (new_call->op.same_as(simulated_quantize)) {
+    if (new_call->op == simulated_quantize_op_) {
       auto attrs = new_call->attrs.as<SimulatedQuantizeAttrs>();
       // rewrite the annotation
       auto new_attrs = make_node<SimulatedQuantizeAttrs>();
...
@@ -91,7 +91,6 @@ Expr LayerNormToInferUnpack(const Attrs attrs,
   return out;
 }
 Expr InstanceNormToInferUnpack(const Attrs attrs,
                                Expr data,
                                Expr gamma,
@@ -125,23 +124,25 @@ Expr InstanceNormToInferUnpack(const Attrs attrs,
   return out;
 }
 class InferenceSimplifier : public ExprMutator {
  public:
-  Expr VisitExpr_(const TupleGetItemNode* n) final {
-    static const Op& batch_norm = Op::Get("nn.batch_norm");
-    static const Op& dropout = Op::Get("nn.dropout");
+  InferenceSimplifier()
+      : batch_norm_op_(Op::Get("nn.batch_norm")),
+        dropout_op_(Op::Get("nn.dropout")),
+        instance_norm_op_(Op::Get("nn.instance_norm")),
+        layer_norm_op_(Op::Get("nn.layer_norm")) {}
+
+  Expr VisitExpr_(const TupleGetItemNode* n) final {
     Expr new_e = ExprMutator::VisitExpr_(n);
     const auto* new_n = new_e.as<TupleGetItemNode>();
     if (new_n->index != 0) {
       return new_e;
     }
     if (const auto* call = new_n->tuple.as<CallNode>()) {
-      if (call->op.same_as(batch_norm)) {
+      if (call->op == batch_norm_op_) {
         return BatchNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
                                       call->args[3], call->args[4], ty_map_.at(call->args[0]));
-      } else if (call->op.same_as(dropout)) {
+      } else if (call->op == dropout_op_) {
         return call->args[0];
       }
     }
@@ -149,17 +150,14 @@ class InferenceSimplifier : public ExprMutator {
   }
   Expr VisitExpr_(const CallNode* n) {
-    static const Op& batch_norm = Op::Get("nn.batch_norm");
-    static const Op& instance_norm = Op::Get("nn.instance_norm");
-    static const Op& layer_norm = Op::Get("nn.layer_norm");
     auto new_n = ExprMutator::VisitExpr_(n);
-    if (n->op.same_as(batch_norm)) {
+    if (n->op == batch_norm_op_) {
       ty_map_[new_n.as<CallNode>()->args[0]] = n->args[0]->checked_type();
-    } else if (n->op.same_as(layer_norm)) {
+    } else if (n->op == layer_norm_op_) {
       const auto* call = new_n.as<CallNode>();
       return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1],
                                     call->args[2], n->args[0]->checked_type());
-    } else if (n->op.same_as(instance_norm)) {
+    } else if (n->op == instance_norm_op_) {
       const auto* call = new_n.as<CallNode>();
       return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1],
                                        call->args[2], n->args[0]->checked_type());
@@ -168,6 +166,13 @@ class InferenceSimplifier : public ExprMutator {
   }
  private:
+  // Cache the following ops. They will be used in the passes repeatedly for
+  // operator equivalence checking so that the registry lookup overhead can be
+  // reduced.
+  const Op& batch_norm_op_;
+  const Op& dropout_op_;
+  const Op& instance_norm_op_;
+  const Op& layer_norm_op_;
   std::unordered_map<Expr, Type, NodeHash, NodeEqual> ty_map_;
 };
...
@@ -25,6 +25,7 @@
  */
 #include <tvm/relay/analysis.h>
 #include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
 #include <tvm/relay/pattern_functor.h>
 #include "pass_util.h"
 #include "../ir/type_functor.h"
@@ -360,13 +361,14 @@ bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
   return true;
 }
+// Cache the operators that are checked recursively to reduce lookup overhead.
+static const auto& expand_dims_op = Op::Get("expand_dims");
+static const auto& reshape_op = Op::Get("reshape");
+static const auto& transpose_op = Op::Get("transpose");
+static const auto& squeeze_op = Op::Get("squeeze");
 bool IsAllPositiveConstant(const Expr& expr) {
   // peel through a few common transform ops.
-  static const auto& expand_dims = Op::Get("expand_dims");
-  static const auto& reshape = Op::Get("reshape");
-  static const auto& transpose = Op::Get("transpose");
-  static const auto& squeeze = Op::Get("squeeze");
   if (const auto* constant = expr.as<ConstantNode>()) {
     const auto& tensor = constant->data;
     const auto& dtype = tensor->dtype;
@@ -389,10 +391,10 @@ bool IsAllPositiveConstant(const Expr& expr) {
     }
   } else if (const auto* op = expr.as<CallNode>()) {
     // tail recursion.
-    if (op->op.same_as(expand_dims) ||
-        op->op.same_as(reshape) ||
-        op->op.same_as(transpose) ||
-        op->op.same_as(squeeze)) {
+    if (op->op == expand_dims_op ||
+        op->op == reshape_op ||
+        op->op == transpose_op ||
+        op->op == squeeze_op) {
       return IsAllPositiveConstant(op->args[0]);
     } else {
       return false;
...