Commit 475158f6 by Zhi Committed by masahi

[relay][refactor] Cache Op::Get in passes to reduce lookup overhead (#4594)

* Refactor to use IsOp utility

* retrigger CI
parent 35af4c8b
......@@ -594,12 +594,11 @@ inline ValueType OpMap<ValueType>::get(const Expr& expr,
return map_.get<ValueType>(expr, def_value);
}
/*!
* \brief Check that an expression is a "primtive operator".
* \brief Check that an expression is a "primitive operator".
*
* Will return true if the expression is an operator which
* matches the form of primtive operators registered directly
* matches the form of primitive operators registered directly
* by the Relay codebase.
*
* That is the arguments are all type variables, and there is a single
......
......@@ -21,6 +21,8 @@
* \file relay/backend/compile_engine.cc
* \brief Internal compilation engine.
*/
#include "compile_engine.h"
#include <tvm/schedule.h>
#include <tvm/packed_func_ext.h>
#include <tvm/operation.h>
......@@ -29,6 +31,7 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <topi/tags.h>
#include <utility>
......@@ -38,7 +41,6 @@
#include <vector>
#include <unordered_map>
#include "../ir/type_functor.h"
#include "compile_engine.h"
namespace tvm {
namespace relay {
......@@ -102,7 +104,7 @@ class ScheduleGetter :
public ExprFunctor<Array<Tensor>(const Expr&)> {
public:
explicit ScheduleGetter(Target target)
: target_(target) {}
: target_(target), device_copy_op_(Op::Get("device_copy")) {}
std::pair<Schedule, CachedFunc> Create(const Function& prim_func) {
static auto fschedule =
......@@ -250,11 +252,9 @@ class ScheduleGetter :
CHECK(call_node->op.as<OpNode>())
<< "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);
// Check if the op is a device copy op.
bool is_copy_op = op.same_as(Op::Get("device_copy"));
Array<Tensor> outputs;
// Skip fcompute for device copy operators as it is not registered.
if (is_copy_op) {
if (op == device_copy_op_) {
const auto* copy_input = inputs[0].operator->();
outputs.push_back(TensorNode::make(copy_input->shape, copy_input->dtype,
Operation(), 0));
......@@ -282,7 +282,7 @@ class ScheduleGetter :
}
// Set the name to `__copy`. It will be detected in graph runtime to perform
// data copy across devices.
if (is_copy_op) {
if (op == device_copy_op_) {
readable_name_stream_.str(std::string());
readable_name_stream_ << "__copy";
} else {
......@@ -332,6 +332,9 @@ class ScheduleGetter :
std::ostringstream readable_name_stream_;
std::unordered_map<Expr, Array<Tensor>, NodeHash, NodeEqual> memo_;
Array<Operation> scalars_;
// Cache device copy op for equivalence checking to reduce registry lookup
// overhead for each invocation of call node when retrieving schedules.
const Op& device_copy_op_;
};
// Creates shape function from functor.
......
......@@ -246,10 +246,12 @@ class Interpreter :
public ExprFunctor<Value(const Expr& n)>,
PatternFunctor<bool(const Pattern& p, const Value& v)> {
public:
Interpreter(Module mod,
DLContext context,
Target target)
: mod_(mod), context_(context), target_(target) {
Interpreter(Module mod, DLContext context, Target target)
: mod_(mod),
context_(context),
target_(target),
debug_op_(Op::Get("debug")),
shape_of_op_(Op::Get("shape_of")) {
engine_ = CompileEngine::Global();
}
......@@ -263,7 +265,7 @@ class Interpreter :
stack_.current_frame().locals.Set(id, v);
}
inline Value Lookup(const Var& local) {
Value Lookup(const Var& local) {
return stack_.Lookup(local);
}
......@@ -307,7 +309,7 @@ class Interpreter :
return TupleValueNode::make(values);
}
inline Value MakeClosure(const Function& func, Var letrec_name = Var()) {
Value MakeClosure(const Function& func, Var letrec_name = Var()) {
tvm::Map<Var, Value> captured_mod;
Array<Var> free_vars = FreeVars(func);
......@@ -454,9 +456,9 @@ class Interpreter :
Value InvokePrimitiveOp(const Function& func,
const Array<Value>& args) {
auto call_node = func->body.as<CallNode>();
const auto* call_node = func->body.as<CallNode>();
if (call_node && call_node->op == Op::Get("debug")) {
if (call_node && call_node->op == debug_op_) {
auto dattrs = call_node->attrs.as<DebugAttrs>();
auto interp_state = this->get_state(call_node->args[0]);
......@@ -540,7 +542,7 @@ class Interpreter :
Array<Shape> out_shapes;
auto ret_type = func->body->checked_type();
bool is_dyn = IsDynamic(func->checked_type());
if (call_node->op == Op::Get("shape_of")) {
if (call_node->op == shape_of_op_) {
// The output shape of shape_of must be static since Relay doesn't support
// dynamic rank tensors.
is_dyn = false;
......@@ -782,6 +784,9 @@ class Interpreter :
Stack stack_;
// Backend compile engine.
CompileEngine engine_;
// Cache ops that need to be frequently used later to reduce lookup overhead.
const Op& debug_op_;
const Op& shape_of_op_;
};
......
......@@ -62,6 +62,8 @@ namespace relay {
// \endcode
class CastCanonicalizer : public ExprMutator {
public:
CastCanonicalizer() : cast_op_(Op::Get("cast")) {}
Expr VisitExpr_(const CallNode* call) {
static auto fpattern = Op::GetAttr<TOpPattern>("TOpPattern");
......@@ -91,15 +93,17 @@ class CastCanonicalizer : public ExprMutator {
private:
std::unordered_map<const Node*, size_t> ref_counter_;
// cast op is frequently checked for equivalence. Therefore, we cache it to
// reduce lookup overhead.
const Op& cast_op_;
Expr GetNewCallArg(const Expr& e) {
// if e is an upcast and ref count > 1, create a copy; otherwise call the default visitor
static auto& cast = Op::Get("cast");
Expr new_expr = this->VisitExpr(e);
if (const CallNode* call = e.as<CallNode>()) {
if (call->op.same_as(cast)) {
if (call->op == cast_op_) {
auto attrs = call->attrs.as<CastAttrs>();
const auto* from_type = call->args[0]->type_as<TensorTypeNode>();
CHECK(from_type);
......@@ -108,7 +112,7 @@ class CastCanonicalizer : public ExprMutator {
if (++ref_counter_[call] > 1) {
const CallNode* new_call = new_expr.as<CallNode>();
CHECK(new_call);
CHECK(new_call->op.same_as(cast));
CHECK(new_call->op == cast_op_);
return CallNode::make(new_call->op, new_call->args, new_call->attrs,
new_call->type_args);
}
......
......@@ -24,6 +24,7 @@
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/transform.h>
#include "pattern_util.h"
......@@ -33,10 +34,11 @@ namespace relay {
class BiasAddSimplifier : public ExprMutator {
public:
BiasAddSimplifier() : bias_add_op_(Op::Get("nn.bias_add")) {}
Expr VisitExpr_(const CallNode* n) {
static const Op& bias_add = Op::Get("nn.bias_add");
auto new_n = ExprMutator::VisitExpr_(n);
if (n->op.same_as(bias_add)) {
if (n->op == bias_add_op_) {
Call call = Downcast<Call>(new_n);
CHECK_EQ(call->args.size(), 2);
const BiasAddAttrs* param = call->attrs.as<BiasAddAttrs>();
......@@ -54,6 +56,10 @@ class BiasAddSimplifier : public ExprMutator {
}
return new_n;
}
private:
// Cache the bias_add for equivalence checking.
const Op& bias_add_op_;
};
Expr CanonicalizeOps(const Expr& e) {
......
......@@ -27,29 +27,30 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
#include <algorithm>
#include <utility>
#include <unordered_map>
#include <unordered_set>
#include "./expr_subst.h"
#include "./pattern_util.h"
#include "./combine_parallel_op.h"
#include "expr_subst.h"
#include "pattern_util.h"
#include "combine_parallel_op.h"
namespace tvm {
namespace relay {
BranchGroupFinder::BranchGroupFinder(const std::string& op_name,
BranchGroupFinder::BranchGroupFinder(const Op& op,
FIsSupportedOp fis_supported_op,
FAreCompatibleOps fare_compatible_ops)
: op_name_(op_name),
: cached_op_(op),
fis_supported_op_(fis_supported_op),
fare_compatible_ops_(fare_compatible_ops) {
}
std::vector<Group> BranchGroupFinder::Find(const Expr& expr) {
const Op& op = Op::Get(op_name_);
this->VisitExpr(expr);
std::vector<Group> groups;
......@@ -57,7 +58,7 @@ std::vector<Group> BranchGroupFinder::Find(const Expr& expr) {
const auto& children = children_map_.at(root);
size_t ngroups = groups.size();
for (const CallNode* child : children) {
if (!child->op.same_as(op)) continue;
if (child->op != cached_op_) continue;
auto&& branch = CreateBranch(child);
// add the branch to a group, or create a new group
......@@ -97,9 +98,8 @@ Branch BranchGroupFinder::CreateBranch(const CallNode* op) {
}
void BranchGroupFinder::VisitExpr_(const CallNode* n) {
const Op& op = Op::Get(op_name_);
ExprVisitor::VisitExpr_(n);
if (n->op.same_as(op) && fis_supported_op_(n)) {
if (n->op == cached_op_ && fis_supported_op_(n)) {
op_roots_.insert(n->args[0]);
children_map_[n->args[0]].push_back(n);
} else {
......@@ -110,12 +110,12 @@ void BranchGroupFinder::VisitExpr_(const CallNode* n) {
}
ParallelOpCombiner::ParallelOpCombiner(const std::string& op_name, uint64_t min_num_branches)
: op_name_(op_name),
: cached_op_(Op::Get(op_name)),
min_num_branches_(min_num_branches) {
}
Expr ParallelOpCombiner::Combine(const Expr& expr) {
auto groups = BranchGroupFinder(op_name_,
auto groups = BranchGroupFinder(cached_op_,
[&](const CallNode* n) {
return IsSupportedOp(n);
},
......
......@@ -68,13 +68,13 @@ class BranchGroupFinder : private ExprVisitor {
public:
/*
* \brief Constructor
* \param op_name name of op to start each group
* \param op The op that indicates the start of each group
* \param fis_supported_op function that returns true if op
* is supported for combining
* \param fare_compatible_ops function that returns true if
* two ops are compatible for combining
*/
BranchGroupFinder(const std::string& op_name,
BranchGroupFinder(const Op& op,
FIsSupportedOp fis_supported_op,
FAreCompatibleOps fare_compatible_ops);
......@@ -87,8 +87,8 @@ class BranchGroupFinder : private ExprVisitor {
std::vector<Group> Find(const Expr& expr);
private:
/* \brief name of op to find parallel branches for */
std::string op_name_;
/* \brief Cache the op for finding parallel branches */
const Op& cached_op_;
/* \brief function to return true if op is eligible to be combined,
* false otherwise
......@@ -205,8 +205,8 @@ class ParallelOpCombiner {
ExprSubstMap* subst_map) = 0;
private:
/* \brief name of op to be combined */
std::string op_name_;
/* \brief Cache the op to be combined */
const Op& cached_op_;
/* \brief minimum number of parallel branches to combine */
uint64_t min_num_branches_;
......
......@@ -22,6 +22,7 @@
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/attrs/transform.h>
......@@ -33,7 +34,6 @@ namespace relay {
using FInterpreter = runtime::TypedPackedFunc<Value(Expr)>;
class ConstantChecker : private ExprVisitor {
public:
// Check whether an expression is constant. The results are memoized.
......@@ -78,8 +78,14 @@ TVM_REGISTER_API("relay._analysis.check_constant")
class ConstantFolder : public ExprMutator {
public:
explicit ConstantFolder(FInterpreter executor, Module module)
: executor_(executor), module_(module) {
}
: executor_(executor),
module_(module),
shape_of_op_(Op::Get("shape_of")),
invoke_tvm_op_(Op::Get("memory.invoke_tvm_op")),
shape_func_op_(Op::Get("memory.shape_func")),
alloc_tensor_op_(Op::Get("memory.alloc_tensor")),
alloc_storage_op_(Op::Get("memory.alloc_storage")),
cast_op_(Op::Get("cast")) {}
Expr VisitExpr_(const LetNode* op) final {
Expr value = this->Mutate(op->value);
......@@ -119,15 +125,15 @@ class ConstantFolder : public ExprMutator {
// skip stateful ops.
if (op_stateful.get(GetRef<Op>(op), false)) return res;
// Try to evaluate shape_of op
if (call->op.same_as(Op::Get("shape_of"))) {
if (call->op == shape_of_op_) {
return EvaluateShapeOf(res, origin_args, call->attrs);
}
// We should think about potentially constant evaluation over these ops too.
if (call->op.same_as(Op::Get("memory.invoke_tvm_op")) ||
call->op.same_as(Op::Get("memory.shape_func")) ||
call->op.same_as(Op::Get("memory.alloc_tensor")) ||
call->op.same_as(Op::Get("memory.alloc_storage"))) {
if (call->op == invoke_tvm_op_ ||
call->op == shape_func_op_ ||
call->op == alloc_tensor_op_ ||
call->op == alloc_storage_op_) {
return GetRef<Call>(call);
}
......@@ -162,6 +168,14 @@ class ConstantFolder : public ExprMutator {
// Module
Module module_;
// Cache the following ops for equivalence checking in this pass.
const Op& shape_of_op_;
const Op& invoke_tvm_op_;
const Op& shape_func_op_;
const Op& alloc_tensor_op_;
const Op& alloc_storage_op_;
const Op& cast_op_;
// Convert value to expression.
Expr ValueToExpr(Value value) {
if (const auto* val = value.as<TensorValueNode>()) {
......@@ -254,8 +268,7 @@ class ConstantFolder : public ExprMutator {
// Cast the constant into correct dtype
auto cast_attrs = make_node<CastAttrs>();
cast_attrs->dtype = param->dtype;
static const Op& cast_op = Op::Get("cast");
Expr ret = CallNode::make(cast_op, { shape }, Attrs(cast_attrs), {});
Expr ret = CallNode::make(cast_op_, { shape }, Attrs(cast_attrs), {});
return ConstEvaluate(ret);
}
};
......
......@@ -78,6 +78,8 @@ using common::LinkedList;
constexpr uint32_t kMaxFusedOps = 256;
// Cached handle to the "annotation.stop_fusion" op, compared against call->op
// in FuseMutator so the lookup is not repeated for every call node.
// NOTE(review): this reference is bound by calling Op::Get while namespace-scope
// objects are being dynamically initialized. C++ does not order dynamic
// initialization across translation units, so this presumably relies on the op
// already being registered when this initializer runs -- TODO confirm. A
// function-local static would defer the lookup to first use instead.
static const Op& stop_fusion_op = Op::Get("annotation.stop_fusion");
/*!
* \brief Indexed data flow graph in forward direction.
* This is a temporary data structure used for operator fusion analysis.
......@@ -860,7 +862,6 @@ class FuseMutator : private ExprMutator {
// Transform calls.
Expr VisitExpr_(const CallNode* call) {
static const Op& stop_fusion = Op::Get("annotation.stop_fusion");
if (call->op.as<OpNode>()) {
static auto fnoncomputational =
Op::GetAttr<TNonComputational>("TNonComputational");
......@@ -872,7 +873,7 @@ class FuseMutator : private ExprMutator {
// If it is a primitive op call
// then we must have a group assignment for it already.
CHECK(gmap_.count(call));
if (call->op.same_as(stop_fusion)) {
if (call->op == stop_fusion_op) {
return ExprMutator::VisitExpr(call->args[0]);
}
auto* ret_group = gmap_.at(call)->FindRoot();
......
......@@ -559,30 +559,28 @@ struct WithFuncIdAttrs : public tvm::AttrsNode<WithFuncIdAttrs> {
TVM_REGISTER_NODE_TYPE(WithFuncIdAttrs);
Op WithFuncIdOp() {
static const Op& op = Op::Get("annotation.with_funcid");
return op;
}
Expr MkWithFuncId(const Expr& expr, FuncId fid) {
auto attrs = make_node<WithFuncIdAttrs>();
attrs->fid = fid;
return CallNode::make(WithFuncIdOp(), {expr}, Attrs(attrs), {});
}
RELAY_REGISTER_OP("annotation.with_funcid")
.describe(R"code(Annotate a function with a funcid.)code"
TVM_ADD_FILELINE)
.set_num_inputs(1)
.add_argument("func", "Function", "The input data.");
// Cache with_funcid op to reduce lookup overhead during traversal.
// NOTE(review): this initializer calls Op::Get at namespace scope, so it looks
// like it must stay below the RELAY_REGISTER_OP("annotation.with_funcid")
// statement in this file: initializers within one translation unit run in
// definition order. Presumably the macro performs the registration at that
// point -- TODO confirm before moving either declaration.
static const Op& with_funcid_op = Op::Get("annotation.with_funcid");
Expr MkWithFuncId(const Expr& expr, FuncId fid) {
auto attrs = make_node<WithFuncIdAttrs>();
attrs->fid = fid;
return CallNode::make(with_funcid_op, {expr}, Attrs(attrs), {});
}
Expr StripWithFuncId(const Expr& e);
Function AsFunc(const Expr& e) {
if (e.as<FunctionNode>()) {
return Downcast<Function>(e);
} else if (const CallNode* c = e.as<CallNode>()) {
CHECK(c->op.same_as(WithFuncIdOp()));
CHECK(c->op == with_funcid_op);
CHECK_EQ(c->args.size(), 1);
return AsFunc(c->args[0]);
} else {
......@@ -604,7 +602,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
PStatic VisitExpr(const Expr& e, LetList* ll, const Var& name) {
if (const CallNode* c = e.as<CallNode>()) {
if (c->op.same_as(WithFuncIdOp())) {
if (c->op == with_funcid_op) {
CHECK_EQ(c->args.size(), 1);
return VisitExpr(c->args[0], ll, name);
}
......@@ -722,7 +720,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
}
PStatic VisitExpr_(const CallNode* op, LetList* ll) final {
if (op->op.same_as(WithFuncIdOp())) {
if (op->op == with_funcid_op) {
CHECK_EQ(op->args.size(), 1);
return VisitExpr(op->args[0], ll);
}
......@@ -1096,7 +1094,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
explicit RegisterFuncIdVisitor(PartialEvaluator* pe) : pe(pe) { }
void VisitExpr_(const CallNode* op) final {
if (op->op.same_as(WithFuncIdOp())) {
if (op->op == with_funcid_op) {
CHECK_EQ(op->args.size(), 1);
CHECK(op->attrs.defined());
CHECK(op->attrs.as<WithFuncIdAttrs>());
......@@ -1194,7 +1192,7 @@ Expr Remap(const Expr& e) {
Expr StripWithFuncId(const Expr& e) {
struct StripWithFuncIdMutator : ExprMutator, PatternMutator {
Expr VisitExpr_(const CallNode* op) final {
if (op->op.same_as(WithFuncIdOp())) {
if (op->op == with_funcid_op) {
CHECK_EQ(op->args.size(), 1);
return VisitExpr(op->args[0]);
} else {
......
......@@ -25,15 +25,17 @@
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include "./quantize.h"
namespace tvm {
namespace relay {
namespace quantize {
class StatsCollector : private ExprMutator {
public:
StatsCollector() : simulated_quantize_op_(Op::Get("relay.op.annotation.simulated_quantize")) {}
Expr Collect(const Expr& expr) {
auto new_e = this->Mutate(expr);
const FunctionNode* func = new_e.as<FunctionNode>();
......@@ -45,13 +47,13 @@ class StatsCollector : private ExprMutator {
private:
Array<Expr> profile_data_;
const Op& simulated_quantize_op_;
Expr VisitExpr_(const CallNode* call) {
static const Op& simulated_quantize = Op::Get("relay.op.annotation.simulated_quantize");
Expr new_e = ExprMutator::VisitExpr_(call);
const CallNode* new_call = new_e.as<CallNode>();
CHECK(new_call);
if (new_call->op.same_as(simulated_quantize)) {
if (new_call->op == simulated_quantize_op_) {
auto attrs = new_call->attrs.as<SimulatedQuantizeAttrs>();
// rewrite the annotation
auto new_attrs = make_node<SimulatedQuantizeAttrs>();
......
......@@ -91,7 +91,6 @@ Expr LayerNormToInferUnpack(const Attrs attrs,
return out;
}
Expr InstanceNormToInferUnpack(const Attrs attrs,
Expr data,
Expr gamma,
......@@ -125,23 +124,25 @@ Expr InstanceNormToInferUnpack(const Attrs attrs,
return out;
}
class InferenceSimplifier : public ExprMutator {
public:
Expr VisitExpr_(const TupleGetItemNode* n) final {
static const Op& batch_norm = Op::Get("nn.batch_norm");
static const Op& dropout = Op::Get("nn.dropout");
InferenceSimplifier()
: batch_norm_op_(Op::Get("nn.batch_norm")),
dropout_op_(Op::Get("nn.dropout")),
instance_norm_op_(Op::Get("nn.instance_norm")),
layer_norm_op_(Op::Get("nn.layer_norm")) {}
Expr VisitExpr_(const TupleGetItemNode* n) final {
Expr new_e = ExprMutator::VisitExpr_(n);
const auto* new_n = new_e.as<TupleGetItemNode>();
if (new_n->index != 0) {
return new_e;
}
if (const auto* call = new_n->tuple.as<CallNode>()) {
if (call->op.same_as(batch_norm)) {
if (call->op == batch_norm_op_) {
return BatchNormToInferUnpack(call->attrs, call->args[0], call->args[1], call->args[2],
call->args[3], call->args[4], ty_map_.at(call->args[0]));
} else if (call->op.same_as(dropout)) {
} else if (call->op == dropout_op_) {
return call->args[0];
}
}
......@@ -149,17 +150,14 @@ class InferenceSimplifier : public ExprMutator {
}
Expr VisitExpr_(const CallNode* n) {
static const Op& batch_norm = Op::Get("nn.batch_norm");
static const Op& instance_norm = Op::Get("nn.instance_norm");
static const Op& layer_norm = Op::Get("nn.layer_norm");
auto new_n = ExprMutator::VisitExpr_(n);
if (n->op.same_as(batch_norm)) {
if (n->op == batch_norm_op_) {
ty_map_[new_n.as<CallNode>()->args[0]] = n->args[0]->checked_type();
} else if (n->op.same_as(layer_norm)) {
} else if (n->op == layer_norm_op_) {
const auto* call = new_n.as<CallNode>();
return LayerNormToInferUnpack(call->attrs, call->args[0], call->args[1],
call->args[2], n->args[0]->checked_type());
} else if (n->op.same_as(instance_norm)) {
} else if (n->op == instance_norm_op_) {
const auto* call = new_n.as<CallNode>();
return InstanceNormToInferUnpack(call->attrs, call->args[0], call->args[1],
call->args[2], n->args[0]->checked_type());
......@@ -168,6 +166,13 @@ class InferenceSimplifier : public ExprMutator {
}
private:
// Cache the following ops. They will be used in the passes repeatedly for
// operator equivalence checking so that the registry lookup overhead can be
// reduced.
const Op& batch_norm_op_;
const Op& dropout_op_;
const Op& instance_norm_op_;
const Op& layer_norm_op_;
std::unordered_map<Expr, Type, NodeHash, NodeEqual> ty_map_;
};
......
......@@ -25,6 +25,7 @@
*/
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/pattern_functor.h>
#include "pass_util.h"
#include "../ir/type_functor.h"
......@@ -360,13 +361,14 @@ bool IsNDArrayAllGreaterEqual(const runtime::NDArray& tensor, T value) {
return true;
}
// Cache the operators that are checked recursively to reduce lookup overhead.
// These are the transform ops that IsAllPositiveConstant peels through.
// NOTE(review): each reference is bound by an Op::Get call that runs during
// dynamic initialization of this translation unit. Initialization order across
// translation units is not specified by C++, so this presumably depends on
// "expand_dims", "reshape", "transpose" and "squeeze" being registered before
// these initializers run -- TODO confirm. Function-local statics would defer
// the lookups to the first call instead.
static const auto& expand_dims_op = Op::Get("expand_dims");
static const auto& reshape_op = Op::Get("reshape");
static const auto& transpose_op = Op::Get("transpose");
static const auto& squeeze_op = Op::Get("squeeze");
bool IsAllPositiveConstant(const Expr& expr) {
// peel through a few common transform ops.
static const auto& expand_dims = Op::Get("expand_dims");
static const auto& reshape = Op::Get("reshape");
static const auto& transpose = Op::Get("transpose");
static const auto& squeeze = Op::Get("squeeze");
if (const auto* constant = expr.as<ConstantNode>()) {
const auto& tensor = constant->data;
const auto& dtype = tensor->dtype;
......@@ -389,10 +391,10 @@ bool IsAllPositiveConstant(const Expr& expr) {
}
} else if (const auto* op = expr.as<CallNode>()) {
// tail recursion.
if (op->op.same_as(expand_dims) ||
op->op.same_as(reshape) ||
op->op.same_as(transpose) ||
op->op.same_as(squeeze)) {
if (op->op == expand_dims_op ||
op->op == reshape_op ||
op->op == transpose_op ||
op->op == squeeze_op) {
return IsAllPositiveConstant(op->args[0]);
} else {
return false;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment