Unverified commit 2adcb738 by Zhi, committed by GitHub

[Refactor] Relay Node::make to constructor (#5128)

* Convert Relay Node::make to constructors

* PatternWildcard

* Address comments
parent 1405b7ba
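The pattern applied throughout this change, sketched here with a hypothetical FooNode/Foo pair rather than any class from the patch: the static FooNode::make factory is dropped, and the reference class Foo gains a constructor that allocates the node and assigns it to the ObjectRef's data_ field.

// Before: static factory on the node class (the style removed by this refactor).
// Foo FooNode::make(int value) {
//   ObjectPtr<FooNode> n = make_object<FooNode>();
//   n->value = value;
//   return Foo(n);
// }

// After: constructor on the reference class.
Foo::Foo(int value) {
  ObjectPtr<FooNode> n = make_object<FooNode>();
  n->value = value;
  data_ = std::move(n);  // the ObjectRef now holds the freshly created node
}

Call sites change from FooNode::make(v) to Foo(v), and FFI registrations wrap the constructor in a small lambda, as the hunks below show.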
......@@ -26,11 +26,12 @@
#include <tvm/ir/attrs.h>
#include <tvm/ir/adt.h>
#include <tvm/relay/base.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <string>
#include <functional>
#include "./base.h"
#include "./type.h"
#include "./expr.h"
#include <utility>
namespace tvm {
namespace relay {
......@@ -69,10 +70,6 @@ class PatternWildcard;
/*! \brief PatternWildcard container node */
class PatternWildcardNode : public PatternNode {
public:
PatternWildcardNode() {}
TVM_DLL static PatternWildcard make();
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("span", &span);
}
......@@ -83,7 +80,29 @@ class PatternWildcardNode : public PatternNode {
class PatternWildcard : public Pattern {
public:
TVM_DEFINE_OBJECT_REF_METHODS(PatternWildcard, Pattern, PatternWildcardNode);
/*! \brief Overload the default constructor. */
TVM_DLL PatternWildcard();
explicit PatternWildcard(ObjectPtr<Object> n) : Pattern(n) {}
/*! \brief Copy constructor. */
PatternWildcard(const PatternWildcard& pat) : PatternWildcard(pat.data_) {}
/*! \brief Move constructor. */
PatternWildcard(PatternWildcard&& pat) : PatternWildcard(std::move(pat.data_)) {}
/*! \brief Copy assignment. */
PatternWildcard& operator=(const PatternWildcard& other) {
(*this).data_ = other.data_;
return *this;
}
/*! \brief Move assignment. */
PatternWildcard& operator=(PatternWildcard&& other) {
(*this).data_ = std::move(other.data_);
return *this;
}
const PatternWildcardNode* operator->() const {
return static_cast<const PatternWildcardNode*>(get());
}
using ContainerType = PatternWildcardNode;
};
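With the overloaded default constructor, PatternWildcard() yields a pattern backed by a freshly allocated PatternWildcardNode (see the constructor body later in this diff), so call sites simplify as in the wildcard-expansion hunks below:

Array<Pattern> args;
for (auto inp : constructor->inputs) {
  args.push_back(PatternWildcard());  // previously PatternWildcardNode::make()
}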
/*! \brief A var pattern. Accept all input and bind to a var. */
......@@ -91,13 +110,9 @@ class PatternVar;
/*! \brief PatternVar container node */
class PatternVarNode : public PatternNode {
public:
PatternVarNode() {}
/*! \brief Variable that stores the matched value. */
tvm::relay::Var var;
TVM_DLL static PatternVar make(tvm::relay::Var var);
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("var", &var);
v->Visit("span", &span);
......@@ -109,6 +124,12 @@ class PatternVarNode : public PatternNode {
class PatternVar : public Pattern {
public:
/*!
* \brief Constructor
* \param var The variable that stores the matched value.
*/
TVM_DLL explicit PatternVar(tvm::relay::Var var);
TVM_DEFINE_OBJECT_REF_METHODS(PatternVar, Pattern, PatternVarNode);
};
......@@ -122,10 +143,6 @@ class PatternConstructorNode : public PatternNode {
/*! Sub-patterns to match against each input to the constructor. */
tvm::Array<Pattern> patterns;
PatternConstructorNode() {}
TVM_DLL static PatternConstructor make(Constructor constructor, tvm::Array<Pattern> var);
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("constructor", &constructor);
v->Visit("patterns", &patterns);
......@@ -138,6 +155,13 @@ class PatternConstructorNode : public PatternNode {
class PatternConstructor : public Pattern {
public:
/*!
* \brief Constructor
* \param constructor The constructor to match against.
* \param patterns The sub-patterns to match against each input to the constructor.
*/
TVM_DLL PatternConstructor(Constructor constructor, tvm::Array<Pattern> patterns);
TVM_DEFINE_OBJECT_REF_METHODS(PatternConstructor, Pattern, PatternConstructorNode);
};
......@@ -149,10 +173,6 @@ class PatternTupleNode : public PatternNode {
/*! Sub-patterns to match against each value of the tuple. */
tvm::Array<Pattern> patterns;
PatternTupleNode() {}
TVM_DLL static PatternTuple make(tvm::Array<Pattern> var);
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("patterns", &patterns);
v->Visit("span", &span);
......@@ -164,6 +184,12 @@ class PatternTupleNode : public PatternNode {
class PatternTuple : public Pattern {
public:
/*!
* \brief Constructor
* \param patterns The sub-patterns to match against each value of the tuple
*/
TVM_DLL explicit PatternTuple(tvm::Array<Pattern> patterns);
TVM_DEFINE_OBJECT_REF_METHODS(PatternTuple, Pattern, PatternTupleNode);
};
......@@ -182,14 +208,19 @@ class ClauseNode : public Object {
v->Visit("rhs", &rhs);
}
TVM_DLL static Clause make(Pattern lhs, Expr rhs);
static constexpr const char* _type_key = "relay.Clause";
TVM_DECLARE_FINAL_OBJECT_INFO(ClauseNode, Object);
};
class Clause : public ObjectRef {
public:
/*!
* \brief Constructor
* \param lhs The pattern matched by the clause.
* \param rhs The expression evaluated when the pattern matches.
*/
TVM_DLL explicit Clause(Pattern lhs, Expr rhs);
TVM_DEFINE_OBJECT_REF_METHODS(Clause, ObjectRef, ClauseNode);
};
......@@ -217,14 +248,20 @@ class MatchNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static Match make(Expr data, tvm::Array<Clause> pattern, bool complete = true);
static constexpr const char* _type_key = "relay.Match";
TVM_DECLARE_FINAL_OBJECT_INFO(MatchNode, ExprNode);
};
class Match : public Expr {
public:
/*!
* \brief Constructor
* \param data The input being deconstructed.
* \param clauses The clauses for matching.
* \param complete Whether the match covers all possible cases.
*/
TVM_DLL Match(Expr data, tvm::Array<Clause> clauses, bool complete = true);
TVM_DEFINE_OBJECT_REF_METHODS(Match, RelayExpr, MatchNode);
};
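A hedged usage sketch of the new Clause and Match constructors; data and body stand in for arbitrary relay::Expr values and are not taken from this patch:

// Build a match expression with a single catch-all clause.
Clause catch_all(PatternWildcard(), body);
Match m(data, {catch_all}, /*complete=*/true);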
......
......@@ -103,6 +103,12 @@ class IdNode : public Object {
class Id : public ObjectRef {
public:
/*!
* \brief The constructor
* \param name_hint The name of the variable.
*/
TVM_DLL explicit Id(std::string name_hint);
TVM_DEFINE_OBJECT_REF_METHODS(Id, ObjectRef, IdNode);
};
......
......@@ -72,14 +72,18 @@ class ConstantNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static Constant make(runtime::NDArray data);
static constexpr const char* _type_key = "relay.Constant";
TVM_DECLARE_FINAL_OBJECT_INFO(ConstantNode, ExprNode);
};
class Constant : public Expr {
public:
/*!
* \brief The constructor
* \param data The data of the constant tensor.
*/
TVM_DLL explicit Constant(runtime::NDArray data);
TVM_DEFINE_OBJECT_REF_METHODS(Constant, RelayExpr, ConstantNode);
};
......@@ -97,14 +101,18 @@ class TupleNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static Tuple make(tvm::Array<relay::Expr> fields);
static constexpr const char* _type_key = "relay.Tuple";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleNode, ExprNode);
};
class Tuple : public Expr {
public:
/*!
* \brief The constructor
* \param fields The fields of a tuple.
*/
TVM_DLL explicit Tuple(tvm::Array<relay::Expr> fields);
TVM_DEFINE_OBJECT_REF_METHODS(Tuple, RelayExpr, TupleNode);
};
......@@ -161,6 +169,21 @@ class VarNode : public ExprNode {
class Var : public Expr {
public:
/*!
* \brief The constructor
* \param name_hint The name hint of a variable.
* \param type_annotation The type annotation of a variable.
*/
TVM_DLL Var(std::string name_hint, Type type_annotation) :
Var(Id(name_hint), type_annotation) {}
/*!
* \brief The constructor
* \param vid The unique id of a variable.
* \param type_annotation The type annotation of a variable.
*/
TVM_DLL Var(Id vid, Type type_annotation);
TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode);
};
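A hedged sketch of the two Var constructors; the names and the empty type annotation are placeholders, not values from this patch:

Var x("x", Type());   // delegates to Var(Id("x"), Type())
Id vid("y");
Var y(vid, Type());   // construct from an explicit Id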
......@@ -215,17 +238,24 @@ class CallNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static Call make(Expr op,
Array<Expr> args,
Attrs attrs = Attrs(),
Array<Type> type_args = Array<Type>());
static constexpr const char* _type_key = "relay.Call";
TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode);
};
class Call : public Expr {
public:
/*!
* \brief The constructor
* \param op The operator to be invoked.
* \param args The arguments of the call.
* \param attrs The attributes of the call node.
* \param type_args The type arguments passed to a polymorphic function.
*/
TVM_DLL Call(Expr op,
Array<Expr> args,
Attrs attrs = Attrs(),
Array<Type> type_args = Array<Type>());
TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode);
};
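Call sites now construct the call node directly; for example, the nn.relu registration later in this diff builds its call as:

static const Op& op = Op::Get("nn.relu");
return Call(op, {data}, Attrs(), {});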
......@@ -259,14 +289,20 @@ class LetNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static Let make(Var var, Expr value, Expr body);
static constexpr const char* _type_key = "relay.Let";
TVM_DECLARE_FINAL_OBJECT_INFO(LetNode, ExprNode);
};
class Let : public Expr {
public:
/*!
* \brief The constructor
* \param var The variable being bound.
* \param value The value to bind to the variable.
* \param body The body of the let binding.
*/
TVM_DLL Let(Var var, Expr value, Expr body);
TVM_DEFINE_OBJECT_REF_METHODS(Let, RelayExpr, LetNode);
};
......@@ -300,14 +336,20 @@ class IfNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static If make(Expr cond, Expr true_branch, Expr false_branch);
static constexpr const char* _type_key = "relay.If";
TVM_DECLARE_FINAL_OBJECT_INFO(IfNode, ExprNode);
};
class If : public Expr {
public:
/*!
* \brief The constructor
* \param cond The condition of the if expression.
* \param true_branch The branch taken when the condition is true.
* \param false_branch The branch taken when the condition is false.
*/
TVM_DLL If(Expr cond, Expr true_branch, Expr false_branch);
TVM_DEFINE_OBJECT_REF_METHODS(If, RelayExpr, IfNode);
};
......@@ -327,14 +369,19 @@ class TupleGetItemNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static TupleGetItem make(Expr tuple, int index);
static constexpr const char* _type_key = "relay.TupleGetItem";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleGetItemNode, ExprNode);
};
class TupleGetItem : public Expr {
public:
/*!
* \brief The constructor
* \param tuple The tuple to get an element from.
* \param index The index for extracting a value in the tuple.
*/
TVM_DLL TupleGetItem(Expr tuple, int index);
TVM_DEFINE_OBJECT_REF_METHODS(TupleGetItem, RelayExpr, TupleGetItemNode);
};
......@@ -351,14 +398,18 @@ class RefCreateNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static RefCreate make(Expr value);
static constexpr const char* _type_key = "relay.RefCreate";
TVM_DECLARE_FINAL_OBJECT_INFO(RefCreateNode, ExprNode);
};
class RefCreate : public Expr {
public:
/*!
* \brief The constructor
* \param value The initial value of the reference.
*/
TVM_DLL explicit RefCreate(Expr value);
TVM_DEFINE_OBJECT_REF_METHODS(RefCreate, RelayExpr, RefCreateNode);
};
......@@ -375,14 +426,18 @@ class RefReadNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}
TVM_DLL static RefRead make(Expr ref);
static constexpr const char* _type_key = "relay.RefRead";
TVM_DECLARE_FINAL_OBJECT_INFO(RefReadNode, ExprNode);
};
class RefRead : public Expr {
public:
/*!
* \brief The constructor
* \param ref The reference to read from.
*/
TVM_DLL explicit RefRead(Expr ref);
TVM_DEFINE_OBJECT_REF_METHODS(RefRead, RelayExpr, RefReadNode);
};
/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
......@@ -409,6 +464,13 @@ class RefWriteNode : public ExprNode {
class RefWrite : public Expr {
public:
/*!
* \brief The constructor
* \param ref The reference to write to.
* \param value The value to write.
*/
TVM_DLL RefWrite(Expr ref, Expr value);
TVM_DEFINE_OBJECT_REF_METHODS(RefWrite, RelayExpr, RefWriteNode);
};
......
......@@ -335,9 +335,6 @@ class BijectiveLayoutNode : public Object {
static constexpr const char* _type_key = "BijectiveLayout";
TVM_DECLARE_FINAL_OBJECT_INFO(BijectiveLayoutNode, Object);
TVM_DLL static BijectiveLayout make(const Layout& src_layout,
const Layout& dst_layout);
};
/*! \brief Bijective function mapping for data layout transformation.
......@@ -349,6 +346,12 @@ class BijectiveLayout : public ObjectRef {
public:
BijectiveLayout() = default;
explicit BijectiveLayout(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief The constructor
* \param src_layout The source layout
* \param dst_layout The destination layout
*/
TVM_DLL BijectiveLayout(Layout src_layout, Layout dst_layout);
// Given the source shape, infer the destination shape.
TVM_DLL Array<PrimExpr> ForwardShape(const Array<PrimExpr>& shape) const;
......
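The BijectiveLayout reference is likewise constructed directly now; the layout-relation call sites later in this diff take the form:

const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);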
......@@ -191,9 +191,9 @@ Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
for (auto constructor : td->constructors) {
Array<Pattern> args;
for (auto inp : constructor->inputs) {
args.push_back(PatternWildcardNode::make());
args.push_back(PatternWildcard());
}
ret.push_back(PatternConstructorNode::make(constructor, args));
ret.push_back(PatternConstructor(constructor, args));
}
return ret;
}
......@@ -212,7 +212,7 @@ Array<Pattern> ExpandWildcardsConstructor(const PatternConstructor& clause_ctor,
auto all_subfields = CartesianProduct(values_by_field);
Array<Pattern> ret;
for (auto subfields : all_subfields) {
ret.push_back(PatternConstructorNode::make(ctor_cand->constructor, subfields));
ret.push_back(PatternConstructor(ctor_cand->constructor, subfields));
}
return ret;
}
......@@ -226,9 +226,9 @@ Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
if (cand.as<PatternWildcardNode>()) {
Array<Pattern> args;
for (auto inp : clause_tuple->patterns) {
args.push_back(PatternWildcardNode::make());
args.push_back(PatternWildcard());
}
return {PatternTupleNode::make(args)};
return {PatternTuple(args)};
}
auto tuple_cand = Downcast<PatternTuple>(cand);
......@@ -245,7 +245,7 @@ Array<Pattern> ExpandWildcardsTuple(const PatternTuple& clause_tuple,
auto all_subfields = CartesianProduct(values_by_field);
Array<Pattern> ret;
for (auto subfields : all_subfields) {
ret.push_back(PatternTupleNode::make(subfields));
ret.push_back(PatternTuple(subfields));
}
return ret;
}
......@@ -272,7 +272,7 @@ Array<Pattern> UnmatchedCases(const Match& match, const IRModule& mod) {
* return failed_candidates
*/
std::stack<Pattern> candidates;
candidates.push(PatternWildcardNode::make());
candidates.push(PatternWildcard());
CandidateChecker checker;
Array<Pattern> failures;
......
......@@ -666,7 +666,7 @@ TVM_REGISTER_GLOBAL("relay.analysis._test_type_solver")
ErrorReporter *err_reporter = new ErrorReporter();
auto module = IRModule({}, {});
auto dummy_fn_name = GlobalVar("test");
module->Add(dummy_fn_name, Function({}, TupleNode::make({}), Type(), {}, {}));
module->Add(dummy_fn_name, Function({}, Tuple(tvm::Array<relay::Expr>({})), Type(), {}, {}));
auto solver = std::make_shared<TypeSolver>(dummy_fn_name, module, err_reporter);
auto mod = [module, solver, err_reporter](std::string name) -> PackedFunc {
......@@ -689,7 +689,7 @@ TVM_REGISTER_GLOBAL("relay.analysis._test_type_solver")
});
} else if (name == "AddConstraint") {
return TypedPackedFunc<void(TypeConstraint)>([solver](TypeConstraint c) {
Expr e = VarNode::make("dummy_var",
Expr e = Var("dummy_var",
IncompleteType(Kind::kType));
return solver->AddConstraint(c, e);
});
......
......@@ -433,7 +433,7 @@ Expr TypeSubst(const Expr& expr, const tvm::Map<TypeVar, Type>& subst_map) {
Clause VisitClause(const Clause& c) final {
Pattern pat = VisitPattern(c->lhs);
return ClauseNode::make(pat, VisitExpr(c->rhs));
return Clause(pat, VisitExpr(c->rhs));
}
private:
......
......@@ -210,7 +210,7 @@ class RelayBuildModule : public runtime::ModuleNode {
Map<std::string, Constant> GetParams() {
Map<std::string, Constant> ret;
for (const auto& kv : ret_.params) {
ret.Set(kv.first, ConstantNode::make(kv.second));
ret.Set(kv.first, Constant(kv.second));
}
return ret;
}
......
......@@ -60,11 +60,11 @@ LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation im
data_ = std::move(n);
}
CCacheKey CCacheKeyNode::make(Function source_func, Target target) {
CCacheKey::CCacheKey(Function source_func, Target target) {
auto n = make_object<CCacheKeyNode>();
n->source_func = std::move(source_func);
n->target = std::move(target);
return CCacheKey(n);
data_ = std::move(n);
}
struct IsDynamicVisitor : public TypeVisitor {
......@@ -819,7 +819,9 @@ TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput")
});
TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey")
.set_body_typed(CCacheKeyNode::make);
.set_body_typed([](Function source_func, Target target) {
return CCacheKey(source_func, target);
});
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal")
.set_body_typed([]() {
......
......@@ -124,14 +124,6 @@ class CCacheKeyNode : public Object {
* \return The result of equality check.
*/
inline bool Equal(const CCacheKeyNode* other) const;
/*!
* \brief create a cache key.
* \param source_func The source function.
* \param target The target device.
* \return the created key.
*/
TVM_DLL static CCacheKey make(Function source_func,
Target target);
static constexpr const char* _type_key = "relay.CCacheKey";
TVM_DECLARE_FINAL_OBJECT_INFO(CCacheKeyNode, tvm::Object);
......@@ -148,6 +140,14 @@ class CCacheKey : public ObjectRef {
public:
CCacheKey() {}
explicit CCacheKey(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief The constructor
* \param source_func The source function.
* \param target The target device.
*/
TVM_DLL CCacheKey(Function source_func, Target target);
const CCacheKeyNode* operator->() const {
return static_cast<const CCacheKeyNode*>(get());
}
......
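CCacheKey is now built with its constructor and handed straight to the compile engine, as the interpreter and VM compiler changes below show:

CCacheKey key(func, target);
auto cfunc = engine_->Lower(key);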
......@@ -309,7 +309,7 @@ class Interpreter :
Array<Shape> ComputeDynamicShape(const Function& func,
const Array<ObjectRef>& args) {
auto key = CCacheKeyNode::make(func, Target::Create("llvm"));
CCacheKey key(func, Target::Create("llvm"));
auto cfunc = engine_->LowerShapeFunc(key);
size_t arity = cfunc->inputs.size() + cfunc->outputs.size();
......@@ -520,7 +520,7 @@ class Interpreter :
out_shapes = ComputeDynamicShape(func, args);
}
PackedFunc packed_func = engine_->JIT(CCacheKeyNode::make(func, target_));
PackedFunc packed_func = engine_->JIT(CCacheKey(func, target_));
TVMRetValue rv;
if (const TupleTypeNode* rtype = func->body->checked_type().as<TupleTypeNode>()) {
CHECK(!is_dyn || out_shapes.size() == rtype->fields.size());
......
......@@ -113,7 +113,7 @@ BindParamsByName(relay::Function func,
if (repeat_var.count(arg)) {
LOG(FATAL) << "Multiple args in the function have name " << kv.first;
}
bind_dict[arg] = ConstantNode::make(kv.second);
bind_dict[arg] = Constant(kv.second);
}
Expr bound_expr = relay::Bind(func, bind_dict);
Function ret = Downcast<Function>(bound_expr);
......
......@@ -404,7 +404,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
// Lower shape function
auto key = CCacheKeyNode::make(func, target_host_);
CCacheKey key(func, target_host_);
auto cfunc = engine_->LowerShapeFunc(key);
int op_index = -1;
if (context_->seen_funcs.count(cfunc->funcs[0]) == 0) {
......@@ -485,7 +485,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
}
}
auto key = CCacheKeyNode::make(func, target);
CCacheKey key(func, target);
auto cfunc = engine_->Lower(key);
auto op_index = -1;
......@@ -780,7 +780,7 @@ PackedFunc VMCompiler::GetFunction(const std::string& name,
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
Map<std::string, Constant> ret;
for (const auto& kv : params_) {
ret.Set(kv.first, ConstantNode::make(kv.second));
ret.Set(kv.first, Constant(kv.second));
}
*rv = ret;
});
......
......@@ -92,7 +92,7 @@ struct PrimitiveInliner : ExprMutator {
auto new_arg = VisitExpr(arg);
call_args.push_back(new_arg);
}
return CallNode::make(GetRef<Function>(func), call_args, call->attrs, call->type_args);
return Call(GetRef<Function>(func), call_args, call->attrs, call->type_args);
}
}
......@@ -102,7 +102,7 @@ struct PrimitiveInliner : ExprMutator {
auto new_arg = VisitExpr(arg);
call_args.push_back(new_arg);
}
return CallNode::make(GetRef<GlobalVar>(global), call_args, call->attrs, call->type_args);
return Call(GetRef<GlobalVar>(global), call_args, call->attrs, call->type_args);
}
return ExprMutator::VisitExpr_(call);
......
......@@ -73,7 +73,7 @@ class LambdaLifter : public ExprMutator {
letrec_.pop_back();
}
auto body = VisitExpr(let_node->body);
return LetNode::make(let_node->var, value, body);
return Let(let_node->var, value, body);
}
Expr VisitExpr_(const CallNode* call_node) final {
......@@ -83,7 +83,7 @@ class LambdaLifter : public ExprMutator {
if (!letrec_.empty() && var == letrec_.back()) {
auto it = lambda_map_.find(var);
CHECK(it != lambda_map_.end());
return CallNode::make(it->second, call->args, call_node->attrs,
return Call(it->second, call->args, call_node->attrs,
call_node->type_args);
}
}
......@@ -118,7 +118,7 @@ class LambdaLifter : public ExprMutator {
for (auto fv : captured_vars) {
fvs.push_back(fv);
}
lambda_map_.emplace(letrec_.back(), CallNode::make(global, fvs));
lambda_map_.emplace(letrec_.back(), Call(global, fvs));
} else {
lambda_map_.emplace(letrec_.back(), global);
}
......@@ -178,7 +178,7 @@ class LambdaLifter : public ExprMutator {
for (auto fv : captured_vars) {
fvs.push_back(fv);
}
return CallNode::make(global, fvs);
return Call(global, fvs);
}
}
......
......@@ -27,31 +27,35 @@
namespace tvm {
namespace relay {
PatternWildcard PatternWildcardNode::make() {
PatternWildcard::PatternWildcard() {
ObjectPtr<PatternWildcardNode> n = make_object<PatternWildcardNode>();
return PatternWildcard(n);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(PatternWildcardNode);
TVM_REGISTER_GLOBAL("relay.ir.PatternWildcard")
.set_body_typed(PatternWildcardNode::make);
.set_body_typed([]() {
return PatternWildcard();
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PatternWildcardNode>([](const ObjectRef& ref, ReprPrinter* p) {
p->stream << "PatternWildcardNode()";
});
PatternVar PatternVarNode::make(tvm::relay::Var var) {
PatternVar::PatternVar(tvm::relay::Var var) {
ObjectPtr<PatternVarNode> n = make_object<PatternVarNode>();
n->var = std::move(var);
return PatternVar(n);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(PatternVarNode);
TVM_REGISTER_GLOBAL("relay.ir.PatternVar")
.set_body_typed(PatternVarNode::make);
.set_body_typed([](tvm::relay::Var var) {
return PatternVar(var);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PatternVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
......@@ -59,18 +63,20 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "PatternVarNode(" << node->var << ")";
});
PatternConstructor PatternConstructorNode::make(Constructor constructor,
tvm::Array<Pattern> patterns) {
PatternConstructor::PatternConstructor(Constructor constructor,
tvm::Array<Pattern> patterns) {
ObjectPtr<PatternConstructorNode> n = make_object<PatternConstructorNode>();
n->constructor = std::move(constructor);
n->patterns = std::move(patterns);
return PatternConstructor(n);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(PatternConstructorNode);
TVM_REGISTER_GLOBAL("relay.ir.PatternConstructor")
.set_body_typed(PatternConstructorNode::make);
.set_body_typed([](Constructor constructor, tvm::Array<Pattern> patterns) {
return PatternConstructor(constructor, patterns);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PatternConstructorNode>([](const ObjectRef& ref, ReprPrinter* p) {
......@@ -79,16 +85,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< ", " << node->patterns << ")";
});
PatternTuple PatternTupleNode::make(tvm::Array<Pattern> patterns) {
PatternTuple::PatternTuple(tvm::Array<Pattern> patterns) {
ObjectPtr<PatternTupleNode> n = make_object<PatternTupleNode>();
n->patterns = std::move(patterns);
return PatternTuple(n);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(PatternTupleNode);
TVM_REGISTER_GLOBAL("relay.ir.PatternTuple")
.set_body_typed(PatternTupleNode::make);
.set_body_typed([](tvm::Array<Pattern> patterns) {
return PatternTuple(patterns);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PatternTupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
......@@ -96,17 +104,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "PatternTupleNode(" << node->patterns << ")";
});
Clause ClauseNode::make(Pattern lhs, Expr rhs) {
Clause::Clause(Pattern lhs, Expr rhs) {
ObjectPtr<ClauseNode> n = make_object<ClauseNode>();
n->lhs = std::move(lhs);
n->rhs = std::move(rhs);
return Clause(n);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(ClauseNode);
TVM_REGISTER_GLOBAL("relay.ir.Clause")
.set_body_typed(ClauseNode::make);
.set_body_typed([](Pattern lhs, Expr rhs) {
return Clause(lhs, rhs);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ClauseNode>([](const ObjectRef& ref, ReprPrinter* p) {
......@@ -115,18 +125,20 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< node->rhs << ")";
});
Match MatchNode::make(Expr data, tvm::Array<Clause> clauses, bool complete) {
Match::Match(Expr data, tvm::Array<Clause> clauses, bool complete) {
ObjectPtr<MatchNode> n = make_object<MatchNode>();
n->data = std::move(data);
n->clauses = std::move(clauses);
n->complete = complete;
return Match(n);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(MatchNode);
TVM_REGISTER_GLOBAL("relay.ir.Match")
.set_body_typed(MatchNode::make);
.set_body_typed([](Expr data, tvm::Array<Clause> clauses, bool complete) {
return Match(data, clauses, complete);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MatchNode>([](const ObjectRef& ref, ReprPrinter* p) {
......
......@@ -33,6 +33,12 @@ using namespace tvm::runtime;
TVM_REGISTER_NODE_TYPE(IdNode);
Id::Id(std::string name_hint) {
ObjectPtr<IdNode> n = make_object<IdNode>();
n->name_hint = std::move(name_hint);
data_ = std::move(n);
}
TVM_REGISTER_GLOBAL("ir.NodeSetSpan")
.set_body_typed([](ObjectRef node_ref, Span sp) {
if (auto* rn = node_ref.as<RelayNode>()) {
......
......@@ -30,16 +30,18 @@ namespace relay {
using tvm::ReprPrinter;
using namespace tvm::runtime;
Constant ConstantNode::make(runtime::NDArray data) {
Constant::Constant(runtime::NDArray data) {
ObjectPtr<ConstantNode> n = make_object<ConstantNode>();
n->data = std::move(data);
return Constant(n);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(ConstantNode);
TVM_REGISTER_GLOBAL("relay.ir.Constant")
.set_body_typed(ConstantNode::make);
.set_body_typed([](runtime::NDArray data) {
return Constant(data);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ConstantNode>([](const ObjectRef& ref, ReprPrinter* p) {
......@@ -63,16 +65,18 @@ TensorType ConstantNode::tensor_type() const {
return TensorType(shape, dtype);
}
Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
Tuple::Tuple(tvm::Array<relay::Expr> fields) {
ObjectPtr<TupleNode> n = make_object<TupleNode>();
n->fields = std::move(fields);
return Tuple(n);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(TupleNode);
TVM_REGISTER_GLOBAL("relay.ir.Tuple")
.set_body_typed(TupleNode::make);
.set_body_typed([](tvm::Array<relay::Expr> fields) {
return Tuple(fields);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TupleNode>([](const ObjectRef& ref, ReprPrinter* p) {
......@@ -81,23 +85,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});
Var VarNode::make(Id vid, Type type_annotation) {
Var::Var(Id vid, Type type_annotation) {
ObjectPtr<VarNode> n = make_object<VarNode>();
n->vid = std::move(vid);
n->type_annotation = std::move(type_annotation);
return Var(n);
}
Var VarNode::make(std::string name_hint, Type type_annotation) {
ObjectPtr<IdNode> n = make_object<IdNode>();
n->name_hint = std::move(name_hint);
return VarNode::make(Id(n), type_annotation);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_GLOBAL("relay.ir.Var")
.set_body_typed(static_cast<Var (*)(std::string, Type)>(VarNode::make));
.set_body_typed([](std::string str, Type type_annotation) {
return Var(str, type_annotation);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<VarNode>([](const ObjectRef& ref, ReprPrinter* p) {
......@@ -110,21 +110,21 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << ")";
});
Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
Array<Type> type_args) {
Call::Call(Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args) {
ObjectPtr<CallNode> n = make_object<CallNode>();
n->op = std::move(op);
n->args = std::move(args);
n->attrs = std::move(attrs);
n->type_args = std::move(type_args);
return Call(n);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(CallNode);
TVM_REGISTER_GLOBAL("relay.ir.Call")
.set_body_typed(CallNode::make);
.set_body_typed([](Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args) {
return Call(op, args, attrs, type_args);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<CallNode>([](const ObjectRef& ref, ReprPrinter* p) {
......@@ -133,18 +133,20 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< node->attrs << ", " << node->type_args << ")";
});
Let LetNode::make(Var var, Expr value, Expr body) {
Let::Let(Var var, Expr value, Expr body) {
ObjectPtr<LetNode> n = make_object<LetNode>();
n->var = std::move(var);
n->value = std::move(value);
n->body = std::move(body);
return Let(n);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(LetNode);
TVM_REGISTER_GLOBAL("relay.ir.Let")
.set_body_typed(LetNode::make);
.set_body_typed([](Var var, Expr value, Expr body) {
return Let(var, value, body);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<LetNode>([](const ObjectRef& ref, ReprPrinter* p) {
......@@ -153,18 +155,20 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< ", " << node->body << ")";
});
If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
If::If(Expr cond, Expr true_branch, Expr false_branch) {
ObjectPtr<IfNode> n = make_object<IfNode>();
n->cond = std::move(cond);
n->true_branch = std::move(true_branch);
n->false_branch = std::move(false_branch);
return If(n);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(IfNode);
TVM_REGISTER_GLOBAL("relay.ir.If")
.set_body_typed(IfNode::make);
.set_body_typed([](Expr cond, Expr true_branch, Expr false_branch) {
return If(cond, true_branch, false_branch);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IfNode>([](const ObjectRef& ref, ReprPrinter* p) {
......@@ -173,17 +177,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< ", " << node->false_branch << ")";
});
TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
TupleGetItem::TupleGetItem(Expr tuple, int index) {
ObjectPtr<TupleGetItemNode> n = make_object<TupleGetItemNode>();
n->tuple = std::move(tuple);
n->index = index;
return TupleGetItem(n);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem")
.set_body_typed(TupleGetItemNode::make);
.set_body_typed([](Expr tuple, int index) {
return TupleGetItem(tuple, index);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TupleGetItemNode>([](const ObjectRef& ref, ReprPrinter* p) {
......@@ -191,16 +197,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
});
RefCreate RefCreateNode::make(Expr value) {
RefCreate::RefCreate(Expr value) {
ObjectPtr<RefCreateNode> n = make_object<RefCreateNode>();
n->value = std::move(value);
return RefCreate(n);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(RefCreateNode);
TVM_REGISTER_GLOBAL("relay.ir.RefCreate")
.set_body_typed(RefCreateNode::make);
.set_body_typed([](Expr value) {
return RefCreate(value);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RefCreateNode>([](const ObjectRef& ref, ReprPrinter* p) {
......@@ -208,16 +216,18 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "RefCreateNode(" << node->value << ")";
});
RefRead RefReadNode::make(Expr ref) {
RefRead::RefRead(Expr ref) {
ObjectPtr<RefReadNode> n = make_object<RefReadNode>();
n->ref = std::move(ref);
return RefRead(n);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(RefReadNode);
TVM_REGISTER_GLOBAL("relay.ir.RefRead")
.set_body_typed(RefReadNode::make);
.set_body_typed([](Expr ref) {
return RefRead(ref);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RefReadNode>([](const ObjectRef& ref, ReprPrinter* p) {
......@@ -225,17 +235,19 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "RefReadNode(" << node->ref << ")";
});
RefWrite RefWriteNode::make(Expr ref, Expr value) {
RefWrite::RefWrite(Expr ref, Expr value) {
ObjectPtr<RefWriteNode> n = make_object<RefWriteNode>();
n->ref = std::move(ref);
n->value = std::move(value);
return RefWrite(n);
data_ = std::move(n);
}
TVM_REGISTER_NODE_TYPE(RefWriteNode);
TVM_REGISTER_GLOBAL("relay.ir.RefWrite")
.set_body_typed(RefWriteNode::make);
.set_body_typed([](Expr ref, Expr value) {
return RefWrite(ref, value);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RefWriteNode>([](const ObjectRef& ref, ReprPrinter* p) {
......
......@@ -47,7 +47,7 @@ Expr ExprMutator::VisitExpr_(const VarNode* op) {
if (op->type_annotation.defined()) {
auto type = this->VisitType(op->type_annotation);
if (!op->type_annotation.same_as(type)) {
return VarNode::make(op->vid, type);
return Var(op->vid, type);
}
}
// default case return self.
......@@ -78,7 +78,7 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op) {
if (all_fields_unchanged) {
return GetRef<Expr>(op);
} else {
return TupleNode::make(fields);
return Tuple(fields);
}
}
......@@ -134,7 +134,7 @@ Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
if (unchanged) {
return GetRef<Expr>(call_node);
} else {
return CallNode::make(new_op, call_args, call_node->attrs, ty_args);
return Call(new_op, call_args, call_node->attrs, ty_args);
}
}
......@@ -148,7 +148,7 @@ Expr ExprMutator::VisitExpr_(const LetNode* op) {
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return LetNode::make(var, value, body);
return Let(var, value, body);
}
}
......@@ -161,7 +161,7 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) {
op->false_branch.same_as(false_b)) {
return GetRef<Expr>(op);
} else {
return IfNode::make(guard, true_b, false_b);
return If(guard, true_b, false_b);
}
}
......@@ -170,7 +170,7 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
if (g->tuple == t) {
return GetRef<Expr>(g);
} else {
return TupleGetItemNode::make(t, g->index);
return TupleGetItem(t, g->index);
}
}
......@@ -179,7 +179,7 @@ Expr ExprMutator::VisitExpr_(const RefCreateNode* op) {
if (value.same_as(op->value)) {
return GetRef<Expr>(op);
} else {
return RefCreateNode::make(value);
return RefCreate(value);
}
}
......@@ -188,7 +188,7 @@ Expr ExprMutator::VisitExpr_(const RefReadNode* op) {
if (ref.same_as(op->ref)) {
return GetRef<Expr>(op);
} else {
return RefReadNode::make(ref);
return RefRead(ref);
}
}
......@@ -198,7 +198,7 @@ Expr ExprMutator::VisitExpr_(const RefWriteNode* op) {
if (ref.same_as(op->ref) && value.same_as(op->value)) {
return GetRef<Expr>(op);
} else {
return RefWriteNode::make(ref, value);
return RefWrite(ref, value);
}
}
......@@ -211,12 +211,12 @@ Expr ExprMutator::VisitExpr_(const MatchNode* m) {
for (const Clause& p : m->clauses) {
clauses.push_back(VisitClause(p));
}
return MatchNode::make(VisitExpr(m->data), clauses, m->complete);
return Match(VisitExpr(m->data), clauses, m->complete);
}
Clause ExprMutator::VisitClause(const Clause& c) {
Pattern p = VisitPattern(c->lhs);
return ClauseNode::make(p, VisitExpr(c->rhs));
return Clause(p, VisitExpr(c->rhs));
}
Pattern ExprMutator::VisitPattern(const Pattern& p) { return p; }
......@@ -391,7 +391,7 @@ class ExprBinder : public ExprMutator, PatternMutator {
Clause VisitClause(const Clause& c) final {
Pattern pat = VisitPattern(c->lhs);
return ClauseNode::make(pat, VisitExpr(c->rhs));
return Clause(pat, VisitExpr(c->rhs));
}
Var VisitVar(const Var& v) final {
......
......@@ -36,7 +36,7 @@ Pattern PatternMutator::VisitPattern_(const PatternWildcardNode* op) {
}
Pattern PatternMutator::VisitPattern_(const PatternVarNode* op) {
return PatternVarNode::make(VisitVar(op->var));
return PatternVar(VisitVar(op->var));
}
Pattern PatternMutator::VisitPattern_(const PatternConstructorNode* op) {
......@@ -44,7 +44,7 @@ Pattern PatternMutator::VisitPattern_(const PatternConstructorNode* op) {
for (const auto& p : op->patterns) {
pat.push_back(VisitPattern(p));
}
return PatternConstructorNode::make(VisitConstructor(op->constructor), pat);
return PatternConstructor(VisitConstructor(op->constructor), pat);
}
Pattern PatternMutator::VisitPattern_(const PatternTupleNode* op) {
......@@ -52,7 +52,7 @@ Pattern PatternMutator::VisitPattern_(const PatternTupleNode* op) {
for (const auto& p : op->patterns) {
pat.push_back(VisitPattern(p));
}
return PatternTupleNode::make(pat);
return PatternTuple(pat);
}
Type PatternMutator::VisitType(const Type& t) {
......@@ -62,7 +62,7 @@ Type PatternMutator::VisitType(const Type& t) {
Var PatternMutator::VisitVar(const Var& v) {
if (var_map_.count(v) == 0) {
var_map_.insert(std::pair<Var, Var>(v,
VarNode::make(v->name_hint(),
Var(v->name_hint(),
VisitType(v->type_annotation))));
}
return var_map_.at(v);
......
......@@ -75,10 +75,6 @@ class FunctionPassNode : public PassNode {
*/
PassInfo Info() const override { return pass_info; }
TVM_DLL static FunctionPass make(
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
PassInfo pass_info);
static constexpr const char* _type_key = "relay.FunctionPass";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPassNode, PassNode);
......@@ -95,16 +91,25 @@ class FunctionPassNode : public PassNode {
class FunctionPass : public Pass {
public:
/*!
* \brief The constructor
* \param pass_func The packed function which implements a pass.
* \param pass_info The pass info.
*/
TVM_DLL FunctionPass(
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
PassInfo pass_info);
TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode);
};
FunctionPass FunctionPassNode::make(
FunctionPass::FunctionPass(
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
PassInfo pass_info) {
auto n = make_object<FunctionPassNode>();
n->pass_func = std::move(pass_func);
n->pass_info = std::move(pass_info);
return FunctionPass(n);
data_ = std::move(n);
}
// Perform Module -> Module optimizations at the Function level.
......@@ -149,13 +154,16 @@ Pass CreateFunctionPass(
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required) {
PassInfo pass_info = PassInfo(opt_level, name, required);
return FunctionPassNode::make(pass_func, pass_info);
return FunctionPass(pass_func, pass_info);
}
TVM_REGISTER_NODE_TYPE(FunctionPassNode);
TVM_REGISTER_GLOBAL("relay._transform.MakeFunctionPass")
.set_body_typed(FunctionPassNode::make);
.set_body_typed([](runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func,
PassInfo pass_info) {
return FunctionPass(pass_func, pass_info);
});
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FunctionPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
......
......@@ -56,7 +56,7 @@ Expr MakeArgsort(Expr data,
attrs->is_ascend = is_ascend;
attrs->dtype = dtype;
static const Op& op = Op::Get("argsort");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
......
......@@ -79,7 +79,7 @@ Expr MakeTopK(Expr data,
attrs->is_ascend = is_ascend;
attrs->dtype = dtype;
static const Op& op = Op::Get("topk");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
......
......@@ -44,7 +44,7 @@ TVM_REGISTER_GLOBAL("relay.op.annotation._make.on_device")
auto attrs = make_object<OnDeviceAttrs>();
attrs->device_type = device_type;
static const Op& op = Op::Get("on_device");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
});
RELAY_REGISTER_OP("on_device")
......@@ -59,7 +59,7 @@ RELAY_REGISTER_OP("on_device")
Expr StopFusion(Expr data) {
static const Op& op = Op::Get("annotation.stop_fusion");
return CallNode::make(op, {data}, Attrs{}, {});
return Call(op, {data}, Attrs{}, {});
}
TVM_REGISTER_GLOBAL("relay.op.annotation._make.stop_fusion")
......@@ -90,7 +90,7 @@ Expr CastHint(Expr data, DataType dtype) {
auto attrs = make_object<CastHintAttrs>();
attrs->dtype = dtype;
static const Op& op = Op::Get("annotation.cast_hint");
return CallNode::make(op, {data}, Attrs{attrs}, {});
return Call(op, {data}, Attrs{attrs}, {});
}
RELAY_REGISTER_OP("annotation.cast_hint")
......@@ -147,7 +147,7 @@ Mark the end of bitpacking.
TVM_REGISTER_GLOBAL("relay.op.annotation._make.checkpoint")
.set_body_typed([](Expr data) {
static const Op& op = Op::Get("annotation.checkpoint");
return CallNode::make(op, {data}, Attrs{}, {});
return Call(op, {data}, Attrs{}, {});
});
RELAY_REGISTER_OP("annotation.checkpoint")
......@@ -195,7 +195,7 @@ TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_begin")
auto attrs = make_object<CompilerAttrs>();
attrs->compiler = compiler;
static const Op& op = Op::Get("annotation.compiler_begin");
return CallNode::make(op, {expr}, Attrs(attrs), {});
return Call(op, {expr}, Attrs(attrs), {});
});
RELAY_REGISTER_OP("annotation.compiler_end")
......@@ -220,7 +220,7 @@ TVM_REGISTER_GLOBAL("relay.op.annotation._make.compiler_end")
auto attrs = make_object<CompilerAttrs>();
attrs->compiler = compiler;
static const Op& op = Op::Get("annotation.compiler_end");
return CallNode::make(op, {expr}, Attrs(attrs), {});
return Call(op, {expr}, Attrs(attrs), {});
});
} // namespace relay
......
......@@ -61,7 +61,7 @@ Expr MakeDebug(Expr expr, std::string name) {
dattrs->debug_func = EnvFunc();
}
static const Op& op = Op::Get("debug");
return CallNode::make(op, {expr}, Attrs(dattrs), {});
return Call(op, {expr}, Attrs(dattrs), {});
}
TVM_REGISTER_GLOBAL("relay.op._make.debug")
......
......@@ -48,7 +48,7 @@ TVM_REGISTER_GLOBAL("relay.op._make.device_copy")
attrs->src_dev_type = src_dev_type;
attrs->dst_dev_type = dst_dev_type;
static const Op& op = Op::Get("device_copy");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
});
RELAY_REGISTER_OP("device_copy")
......
......@@ -62,7 +62,7 @@ Expr MakeDilation2D(Expr data,
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("image.dilation2d");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
return Call(op, {data, weight}, Attrs(attrs), {});
}
template <typename AttrType>
......@@ -80,18 +80,18 @@ bool Dilation2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
CHECK(trans_in_layout.defined())
<< "Dilation2D only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
CHECK(trans_kernel_layout.defined())
<< "Dilation2D only support kernel layouts that are convertible from OIHW."
<< " But got " << kernel_layout;
Layout out_layout(param->data_layout);
const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
CHECK(trans_out_layout.defined())
<< "Dilation2D only support output layouts that are convertible from NCHW."
<< " But got " << out_layout;
......
......@@ -44,7 +44,7 @@ bool ResizeRel(const Array<Type>& types,
const ResizeAttrs* param = attrs.as<ResizeAttrs>();
CHECK(param != nullptr);
const Layout in_layout(param->layout);
auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW);
auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
CHECK(layout_converter.defined())
<< "Resize only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
......@@ -80,7 +80,7 @@ Expr MakeResize(Expr data,
attrs->coordinate_transformation_mode = coordinate_transformation_mode;
attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("image.resize");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
......@@ -135,7 +135,7 @@ bool CropAndResizeRel(const Array<Type>& types,
// 4-D tensor of shape [num_boxes, crop_height, crop_width, depth]
static const Layout kNCHW("NCHW");
const Layout in_layout(param->layout);
auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW);
auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
auto oshape = layout_converter.ForwardShape(data->shape);
oshape.Set(0, box_indices->shape[0]);
oshape.Set(2, crop_size[0]);
......@@ -163,7 +163,7 @@ Expr MakeCropAndResize(Expr data,
attrs->extrapolation_value = std::move(extrapolation_value);
attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("image.crop_and_resize");
return CallNode::make(op, {data, boxes, box_indices}, Attrs(attrs), {});
return Call(op, {data, boxes, box_indices}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.image._make.crop_and_resize")
......
......@@ -46,7 +46,7 @@ TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_storage")
auto attrs = make_object<AllocTensorAttrs>();
attrs->dtype = dtype;
static const Op& op = Op::Get("memory.alloc_storage");
return CallNode::make(op, {size, alignment}, Attrs(attrs), {});
return Call(op, {size, alignment}, Attrs(attrs), {});
});
bool AllocStorageRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
......@@ -98,7 +98,7 @@ TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor")
attrs->const_shape = Downcast<Constant>(shape);
}
static const Op& op = Op::Get("memory.alloc_tensor");
return CallNode::make(op, {storage, shape}, Attrs(attrs), {});
return Call(op, {storage, shape}, Attrs(attrs), {});
});
std::vector<int64_t> FromConstShape(Constant konst) {
......@@ -211,7 +211,7 @@ bool InvokeTVMOPRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
TVM_REGISTER_GLOBAL("relay.op.memory._make.invoke_tvm_op")
.set_body_typed(
[](Expr func, Expr inputs, Expr outputs) {
return CallNode::make(Op::Get("memory.invoke_tvm_op"), {func, inputs, outputs}, Attrs());
return Call(Op::Get("memory.invoke_tvm_op"), {func, inputs, outputs}, Attrs());
});
RELAY_REGISTER_OP("memory.invoke_tvm_op")
......@@ -262,7 +262,7 @@ TVM_REGISTER_GLOBAL("relay.op.memory._make.shape_func")
static const Op& op = Op::Get("memory.shape_func");
auto attrs = make_object<ShapeFuncAttrs>();
attrs->is_input = is_input;
return CallNode::make(op, {func, inputs, outputs}, Attrs(attrs), {});
return Call(op, {func, inputs, outputs}, Attrs(attrs), {});
});
static void FlattenTypeAux(const Type& type, std::vector<TensorType>* out) {
......
......@@ -94,7 +94,7 @@ Expr MakeBitPack(Expr data, int bits, int pack_axis, int bit_axis, DataType pack
attrs->pack_type = pack_type;
attrs->name = name;
static const Op& op = Op::Get("nn.bitpack");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.bitpack").set_body_typed(MakeBitPack);
......@@ -130,7 +130,7 @@ bool BinaryConv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attr
static const Layout kNCHW("NCHW");
const Layout in_layout(param->data_layout);
const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
CHECK(param->channels.defined());
CHECK(param->kernel_size.defined());
......@@ -167,7 +167,7 @@ Expr MakeBinaryConv2D(Expr data, Expr weight, Array<IndexExpr> strides, Array<In
attrs->out_dtype = std::move(out_dtype);
attrs->unipolar = unipolar;
static const Op& op = Op::Get("nn.bitserial_conv2d");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
return Call(op, {data, weight}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.bitserial_conv2d").set_body_typed(MakeBinaryConv2D);
......@@ -235,7 +235,7 @@ Expr MakeBinaryDense(Expr data, Expr weight, IndexExpr units, int data_bits, int
attrs->out_dtype = out_dtype;
attrs->unipolar = unipolar;
static const Op& op = Op::Get("nn.bitserial_dense");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
return Call(op, {data, weight}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.bitserial_dense").set_body_typed(MakeBinaryDense);
......
......@@ -60,7 +60,7 @@ Expr MakeConv(Expr data,
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get(op_name);
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
return Call(op, {data, weight}, Attrs(attrs), {});
}
......@@ -218,18 +218,18 @@ bool Conv2DTransposeRel(const Array<Type>& types,
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout;
......@@ -324,7 +324,7 @@ Expr MakeConv2DTranspose(Expr data,
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.conv2d_transpose");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
return Call(op, {data, weight}, Attrs(attrs), {});
}
......@@ -383,18 +383,18 @@ bool Conv1DTransposeRel(const Array<Type>& types,
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCW);
const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCW."
<< " But got " << in_layout;
const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIW);
const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIW."
<< " But got "<< kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCW);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCW."
<< " But got " << out_layout;
......@@ -483,7 +483,7 @@ Expr MakeConv1DTranspose(Expr data,
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.conv1d_transpose");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
return Call(op, {data, weight}, Attrs(attrs), {});
}
......@@ -538,18 +538,18 @@ bool Conv2DWinogradRel(const Array<Type>& types,
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIHW."
<< " But got "<< kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout;
......@@ -632,7 +632,7 @@ Expr MakeConv2DWinograd(Expr data,
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.contrib_conv2d_winograd_without_weight_transform");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
return Call(op, {data, weight}, Attrs(attrs), {});
}
......@@ -695,7 +695,7 @@ Expr MakeConv2DWinogradWeightTransform(Expr weight,
auto attrs = make_object<Conv2DWinogradWeightTransformAttrs>();
attrs->tile_size = tile_size;
static const Op& op = Op::Get("nn.contrib_conv2d_winograd_weight_transform");
return CallNode::make(op, {weight}, Attrs(attrs), {});
return Call(op, {weight}, Attrs(attrs), {});
}
......@@ -759,7 +759,7 @@ Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight,
attrs->convolution_algorithm = convolution_algorithm;
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.contrib_conv2d_winograd_nnpack_weight_transform");
return CallNode::make(op, {weight}, Attrs(attrs), {});
return Call(op, {weight}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_winograd_nnpack_weight_transform")
......@@ -805,7 +805,7 @@ Expr MakeConv2DNCHWc(Expr data,
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.contrib_conv2d_NCHWc");
return CallNode::make(op, {data, kernel}, Attrs(attrs), {});
return Call(op, {data, kernel}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_NCHWc")
......@@ -855,7 +855,7 @@ Expr MakeDepthwiseConv2DNCHWc(Expr data,
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.contrib_depthwise_conv2d_NCHWc");
return CallNode::make(op, {data, kernel}, Attrs(attrs), {});
return Call(op, {data, kernel}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc")
......@@ -1017,7 +1017,7 @@ Expr MakeDeformableConv2D(Expr data,
attrs->out_layout = out_layout;
attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("nn.deformable_conv2d");
return CallNode::make(op, {data, offset, weight}, Attrs{attrs}, {});
return Call(op, {data, offset, weight}, Attrs{attrs}, {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.deformable_conv2d")
......
......@@ -48,18 +48,18 @@ bool Conv1DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCW);
const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCW."
<< " But got " << in_layout;
const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIW);
const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIW."
<< " But got " << kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCW);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCW."
<< " But got " << out_layout;
......@@ -136,18 +136,18 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCHW);
const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCHW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIHW);
const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIHW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIHW."
<< " But got " << kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCHW);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCHW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCHW."
<< " But got " << out_layout;
......@@ -255,18 +255,18 @@ bool Conv3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const Layout in_layout(param->data_layout);
const Layout kernel_layout(param->kernel_layout);
const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCDHW);
const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNCDHW);
CHECK(trans_in_layout.defined())
<< "Conv only support input layouts that are convertible from NCDHW."
<< " But got " << in_layout;
const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIDHW);
const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kOIDHW);
CHECK(trans_kernel_layout.defined())
<< "Conv only support kernel layouts that are convertible from OIDHW."
<< " But got " << kernel_layout;
Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCDHW);
const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNCDHW);
CHECK(trans_out_layout.defined())
<< "Conv only support output layouts that are convertible from NCDHW."
<< " But got " << out_layout;
......
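Editor's note: the Conv1D/2D/3D relations change in lockstep — the layout converters are now built with the tir::BijectiveLayout constructor rather than BijectiveLayoutNode::make, and the result is still probed with defined() to reject layouts that cannot be converted. A minimal sketch of that probe, assuming <tvm/tir/data_layout.h> provides Layout and BijectiveLayout as used above; ConvertibleToNCHW is a made-up helper name.

#include <tvm/tir/data_layout.h>

namespace tvm {
namespace relay {

// Returns true when `in_layout` can be bijectively mapped to NCHW,
// the same check the relations above perform before erroring out.
bool ConvertibleToNCHW(const tir::Layout& in_layout) {
  static const tir::Layout kNCHW("NCHW");
  const auto trans = tir::BijectiveLayout(in_layout, kNCHW);
  return trans.defined();
}

}  // namespace relay
}  // namespace tvm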
......@@ -75,7 +75,7 @@ Expr MakeBiasAdd(Expr data,
auto attrs = make_object<BiasAddAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.bias_add");
return CallNode::make(op, {data, bias}, Attrs(attrs), {});
return Call(op, {data, bias}, Attrs(attrs), {});
}
......@@ -108,7 +108,7 @@ Expr MakeFIFOBuffer(Expr input, Expr buffer, int axis) {
auto attrs = make_object<FIFOBufferAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.fifo_buffer");
return CallNode::make(op, {input, buffer}, Attrs(attrs), {});
return Call(op, {input, buffer}, Attrs(attrs), {});
}
bool FIFOBufferRel(const Array<Type>& types,
......@@ -180,7 +180,7 @@ Expr MakeDense(Expr data,
attrs->units = units;
attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("nn.dense");
return CallNode::make(op, {data, weight}, Attrs(attrs), {});
return Call(op, {data, weight}, Attrs(attrs), {});
}
......@@ -212,7 +212,7 @@ Expr MakeLeakyRelu(Expr data,
auto attrs = make_object<LeakyReluAttrs>();
attrs->alpha = alpha;
static const Op& op = Op::Get("nn.leaky_relu");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
......@@ -291,7 +291,7 @@ Expr MakePRelu(Expr data,
auto attrs = make_object<PReluAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.prelu");
return CallNode::make(op, {data, alpha}, Attrs(attrs), {});
return Call(op, {data, alpha}, Attrs(attrs), {});
}
......@@ -329,7 +329,7 @@ TVM_REGISTER_GLOBAL("relay.op.nn._make.softmax")
auto attrs = make_object<SoftmaxAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.softmax");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
});
......@@ -363,7 +363,7 @@ TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax")
auto attrs = make_object<SoftmaxAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.log_softmax");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
});
RELAY_REGISTER_OP("nn.log_softmax")
......@@ -422,7 +422,7 @@ bool BatchFlattenRel(const Array<Type>& types,
Expr MakeBatchFlatten(Expr data) {
static const Op& op = Op::Get("nn.batch_flatten");
return CallNode::make(op, {data}, Attrs(), {});
return Call(op, {data}, Attrs(), {});
}
......@@ -468,7 +468,7 @@ Example::
TVM_REGISTER_GLOBAL("relay.op.nn._make.relu")
.set_body_typed([](Expr data) {
static const Op& op = Op::Get("nn.relu");
return CallNode::make(op, {data}, Attrs(), {});
return Call(op, {data}, Attrs(), {});
});
RELAY_REGISTER_OP("nn.relu")
......@@ -506,7 +506,7 @@ Expr MakeLRN(Expr data,
attrs->beta = beta;
attrs->bias = bias;
static const Op& op = Op::Get("nn.lrn");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.lrn")
......@@ -544,7 +544,7 @@ Expr MakeL2Normalize(Expr data,
attrs->eps = eps;
attrs->axis = std::move(axis);
static const Op& op = Op::Get("nn.l2_normalize");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.l2_normalize")
......@@ -589,7 +589,7 @@ Expr MakeDropout(Expr data, double rate) {
auto attrs = make_object<DropoutAttrs>();
attrs->rate = rate;
static const Op& op = Op::Get("nn.dropout");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.dropout")
......@@ -687,7 +687,7 @@ Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr movi
attrs->center = center;
attrs->scale = scale;
static const Op& op = Op::Get("nn.batch_norm");
return CallNode::make(op, {data, gamma, beta, moving_mean, moving_var}, Attrs(attrs), {});
return Call(op, {data, gamma, beta, moving_mean, moving_var}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.batch_norm")
......@@ -770,7 +770,7 @@ Expr MakeInstanceNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon
attrs->center = center;
attrs->scale = scale;
static const Op& op = Op::Get("nn.instance_norm");
return CallNode::make(op, {data, gamma, beta}, Attrs(attrs), {});
return Call(op, {data, gamma, beta}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.instance_norm")
......@@ -840,7 +840,7 @@ Expr MakeLayerNorm(Expr data, Expr gamma, Expr beta, int axis, double epsilon,
attrs->center = center;
attrs->scale = scale;
static const Op& op = Op::Get("nn.layer_norm");
return CallNode::make(op, {data, gamma, beta}, Attrs(attrs), {});
return Call(op, {data, gamma, beta}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.layer_norm")
......@@ -891,7 +891,7 @@ bool BatchMatmulRel(const Array<Type>& types,
Expr MakeBatchMatmul(Expr x,
Expr y) {
static const Op& op = Op::Get("nn.batch_matmul");
return CallNode::make(op, {x, y}, Attrs(), {});
return Call(op, {x, y}, Attrs(), {});
}
......@@ -948,7 +948,7 @@ bool CrossEntropyRel(const Array<Type>& types,
// Positional relay function to create cross_entropy operator used by frontend FFI.
Expr MakeCrossEntropy(Expr predictions, Expr targets) {
static const Op& op = Op::Get("nn.cross_entropy");
return CallNode::make(op, {predictions, targets}, Attrs(), {});
return Call(op, {predictions, targets}, Attrs(), {});
}
......@@ -971,7 +971,7 @@ Do log on the data - do not accept logits.
// Positional relay function to create cross_entropy_with_logits operator used by frontend FFI.
Expr MakeCrossEntropyWithLogits(Expr predictions, Expr targets) {
static const Op& op = Op::Get("nn.cross_entropy_with_logits");
return CallNode::make(op, {predictions, targets}, Attrs(), {});
return Call(op, {predictions, targets}, Attrs(), {});
}
......@@ -1005,7 +1005,7 @@ bool DepthToSpaceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
CHECK(param != nullptr);
const int block_size = param->block_size;
const Layout in_layout(param->layout);
auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW);
auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
CHECK(layout_converter.defined())
<< "DepthToSpace only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
......@@ -1030,7 +1030,7 @@ Expr MakeDepthToSpace(Expr data, int block_size, std::string layout, std::string
attrs->layout = std::move(layout);
attrs->mode = std::move(mode);
static const Op& op = Op::Get("nn.depth_to_space");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.depth_to_space").set_body_typed(MakeDepthToSpace);
......@@ -1063,7 +1063,7 @@ bool SpaceToDepthRel(const Array<Type>& types, int num_inputs, const Attrs& attr
CHECK(param != nullptr);
const int block_size = param->block_size;
const Layout in_layout(param->layout);
auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW);
auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
CHECK(layout_converter.defined())
<< "SpaceToDepth only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
......@@ -1087,7 +1087,7 @@ Expr MakeSpaceToDepth(Expr data, int block_size, std::string layout) {
attrs->block_size = block_size;
attrs->layout = std::move(layout);
static const Op& op = Op::Get("nn.space_to_depth");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.space_to_depth").set_body_typed(MakeSpaceToDepth);
......
......@@ -196,7 +196,7 @@ Expr MakePad(Expr data,
attrs->pad_width = std::move(pad_width);
attrs->pad_mode = std::move(pad_mode);
static const Op& op = Op::Get("nn.pad");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.pad")
......@@ -270,7 +270,7 @@ Expr MakeMirrorPad(Expr data, Array<Array<IndexExpr> > pad_width, std::string mo
attrs->mode = mode;
attrs->pad_width = std::move(pad_width);
static const Op& op = Op::Get("nn.mirror_pad");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.mirror_pad")
......
......@@ -70,7 +70,7 @@ Expr MakeMaxPool(Expr data,
attrs->layout = std::move(layout);
attrs->ceil_mode = ceil_mode;
static const Op& op = Op::Get(op_name);
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
template <typename T>
......@@ -90,7 +90,7 @@ Expr MakeAvgPool(Expr data,
attrs->ceil_mode = ceil_mode;
attrs->count_include_pad = count_include_pad;
static const Op& op = Op::Get(op_name);
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
template <typename AttrType>
......@@ -175,7 +175,7 @@ Array<te::Tensor> Pool2DCompute(const Attrs& attrs,
auto ceil_mode = param->ceil_mode;
Layout layout(param->layout);
CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined())
CHECK(tir::BijectiveLayout(layout, kNCHW).defined())
<< "max_pool2d currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
<< "max_pool2d does not support input split on height";
......@@ -336,7 +336,7 @@ Array<te::Tensor> GlobalPool2DCompute(const Attrs& attrs,
const auto* param = attrs.as<GlobalPool2DAttrs>();
CHECK(param != nullptr);
Layout layout(param->layout);
CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined())
CHECK(tir::BijectiveLayout(layout, kNCHW).defined())
<< "global_avg_pool2d currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
<< "global_avg_pool2d does not support input split on height";
......@@ -355,7 +355,7 @@ Expr MakeGlobalAvgPool2D(Expr data,
auto attrs = make_object<GlobalPool2DAttrs>();
attrs->layout = std::move(layout);
static const Op& op = Op::Get("nn.global_avg_pool2d");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
......@@ -387,7 +387,7 @@ Expr MakeGlobalMaxPool2D(Expr data,
auto attrs = make_object<GlobalPool2DAttrs>();
attrs->layout = std::move(layout);
static const Op& op = Op::Get("nn.global_max_pool2d");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.global_max_pool2d")
......@@ -469,7 +469,7 @@ Array<te::Tensor> AdaptivePool2DCompute(const Attrs& attrs,
const auto* param = attrs.as<AdaptivePool2DAttrs>();
CHECK(param != nullptr);
Layout layout(param->layout);
CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined())
CHECK(tir::BijectiveLayout(layout, kNCHW).defined())
<< "Adaptive pool2d currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
<< "Adaptive pool2d does not support input split on height";
......@@ -507,7 +507,7 @@ Expr MakeAdaptiveAvgPool2D(Expr data,
attrs->output_size = std::move(output_size);
attrs->layout = std::move(layout);
static const Op& op = Op::Get("nn.adaptive_avg_pool2d");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool2d")
......@@ -545,7 +545,7 @@ Expr MakeAdaptiveMaxPool2D(Expr data,
attrs->output_size = std::move(output_size);
attrs->layout = std::move(layout);
static const Op& op = Op::Get("nn.adaptive_max_pool2d");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool2d")
......@@ -637,7 +637,7 @@ Array<te::Tensor> AdaptivePool3DCompute(const Attrs& attrs,
const auto* param = attrs.as<AdaptivePool3DAttrs>();
CHECK(param != nullptr);
Layout layout(param->layout);
CHECK(BijectiveLayoutNode::make(layout, kNCDHW).defined())
CHECK(tir::BijectiveLayout(layout, kNCDHW).defined())
<< "Adaptive pool3d currently only supports layouts that are convertible from NCDHW";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('d')), -1)
<< "Adaptive pool3d does not support input split on depth";
......@@ -683,7 +683,7 @@ Expr MakeAdaptiveMaxPool3D(Expr data,
attrs->output_size = std::move(output_size);
attrs->layout = std::move(layout);
static const Op& op = Op::Get("nn.adaptive_max_pool3d");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_max_pool3d")
......@@ -721,7 +721,7 @@ Expr MakeAdaptiveAvgPool3D(Expr data,
attrs->output_size = std::move(output_size);
attrs->layout = std::move(layout);
static const Op& op = Op::Get("nn.adaptive_avg_pool3d");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.adaptive_avg_pool3d")
......@@ -776,7 +776,7 @@ Array<te::Tensor> Pool2DGradCompute(const Attrs& attrs,
auto ceil_mode = param->ceil_mode;
Layout layout(param->layout);
CHECK(BijectiveLayoutNode::make(layout, kNCHW).defined())
CHECK(tir::BijectiveLayout(layout, kNCHW).defined())
<< "pool2d_grad currently only supports layouts that are convertible from NCHW";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
<< "pool2d_grad does not support input split on height";
......@@ -820,7 +820,7 @@ Expr MakeMaxPool2DGrad(Expr out_grad, Expr data, Array<IndexExpr> pool_size,
attrs->layout = std::move(layout);
attrs->ceil_mode = ceil_mode;
static const Op& op = Op::Get("nn.max_pool2d_grad");
return CallNode::make(op, {out_grad, data}, Attrs(attrs), {});
return Call(op, {out_grad, data}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool2d_grad").set_body_typed(MakeMaxPool2DGrad);
......@@ -869,7 +869,7 @@ Expr MakeAvgPool2DGrad(Expr out_grad, Expr data, Array<IndexExpr> pool_size,
attrs->ceil_mode = ceil_mode;
attrs->count_include_pad = count_include_pad;
static const Op& op = Op::Get("nn.avg_pool2d_grad");
return CallNode::make(op, {out_grad, data}, Attrs(attrs), {});
return Call(op, {out_grad, data}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool2d_grad").set_body_typed(MakeAvgPool2DGrad);
......@@ -975,7 +975,7 @@ Array<te::Tensor> Pool1DCompute(const Attrs& attrs,
auto ceil_mode = param->ceil_mode;
Layout layout(param->layout);
CHECK(BijectiveLayoutNode::make(layout, kNCW).defined())
CHECK(tir::BijectiveLayout(layout, kNCW).defined())
<< "max_pool1d currently only supports layouts that are convertible from NCW";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
<< "max_pool1d does not support input split on width";
......@@ -1166,7 +1166,7 @@ Array<te::Tensor> Pool3DCompute(const Attrs& attrs,
auto ceil_mode = param->ceil_mode;
Layout layout(param->layout);
CHECK(BijectiveLayoutNode::make(layout, kNCDHW).defined())
CHECK(tir::BijectiveLayout(layout, kNCDHW).defined())
<< "max_pool3d currently only supports layouts that are convertible from NCDHW";
CHECK_EQ(layout.IndexOf(LayoutAxis::Get('d')), -1)
<< "max_pool3d does not support input split on depth";
......
......@@ -67,7 +67,7 @@ bool SparseDenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
Expr MakeSparseDense(Expr data, Expr weight_data, Expr weight_indices, Expr weight_indptr) {
auto attrs = make_object<SparseDenseAttrs>();
static const Op& op = Op::Get("nn.sparse_dense");
return CallNode::make(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {});
return Call(op, {data, weight_data, weight_indices, weight_indptr}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_dense")
......@@ -116,7 +116,7 @@ bool SparseTransposeRel(const Array<Type>& types, int num_inputs, const Attrs& a
Expr MakeSparseTranspose(Expr sparse_data, Expr sparse_indices, Expr sparse_indptr) {
auto attrs = make_object<SparseTransposeAttrs>();
static const Op& op = Op::Get("nn.sparse_transpose");
return CallNode::make(op, {sparse_data, sparse_indices, sparse_indptr}, Attrs(attrs), {});
return Call(op, {sparse_data, sparse_indices, sparse_indptr}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.sparse_transpose")
......
......@@ -76,7 +76,7 @@ bool UpSamplingRel(const Array<Type>& types,
CHECK(param != nullptr);
const Layout in_layout(param->layout);
auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW);
auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
CHECK(layout_converter.defined())
<< "UpSampling only support input layouts that are convertible from NCHW."
<< " But got " << in_layout;
......@@ -108,7 +108,7 @@ Expr MakeUpSampling(Expr data,
attrs->scale_w = scale_w;
attrs->align_corners = align_corners;
static const Op& op = Op::Get("nn.upsampling");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling")
......@@ -155,7 +155,7 @@ bool UpSampling3DRel(const Array<Type>& types,
CHECK(param != nullptr);
const Layout in_layout(param->layout);
auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCDHW);
auto layout_converter = tir::BijectiveLayout(in_layout, kNCDHW);
CHECK(layout_converter.defined())
<< "UpSampling3D only support input layouts that are convertible from NCDHW."
<< " But got " << in_layout;
......@@ -189,7 +189,7 @@ Expr MakeUpSampling3D(Expr data,
attrs->scale_w = scale_w;
attrs->coordinate_transformation_mode = coordinate_transformation_mode;
static const Op& op = Op::Get("nn.upsampling3d");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.upsampling3d")
......
......@@ -51,7 +51,7 @@ namespace relay {
TVM_REGISTER_GLOBAL("relay.op._make." OpName) \
.set_body_typed([](Expr data) { \
static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {data}, Attrs(), {}); \
return Call(op, {data}, Attrs(), {}); \
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(1) \
......@@ -77,7 +77,7 @@ namespace relay {
TVM_REGISTER_GLOBAL("relay.op._make." OpName) \
.set_body_typed([](Expr lhs, Expr rhs) { \
static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \
return Call(op, {lhs, rhs}, Attrs(), {}); \
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(2) \
......@@ -94,7 +94,7 @@ namespace relay {
TVM_REGISTER_GLOBAL("relay.op._make." OpName) \
.set_body_typed([](Expr lhs, Expr rhs) { \
static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \
return Call(op, {lhs, rhs}, Attrs(), {}); \
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(2) \
......
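Editor's note: the unary and binary elementwise macros above change only inside the typed lambda they register — the packed function still caches the Op lookup and now returns a Call built with the constructor. Expanded by hand for a single one-input op, the FFI half of the macro comes out roughly as below; "sigmoid" is only an example op name and the metadata half of the macro is omitted.

#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/runtime/registry.h>

namespace tvm {
namespace relay {

// Hand-expanded sketch of the one-input maker the macro above registers.
TVM_REGISTER_GLOBAL("relay.op._make.sigmoid")
.set_body_typed([](Expr data) {
  static const Op& op = Op::Get("sigmoid");
  return Call(op, {data}, Attrs(), {});
});

}  // namespace relay
}  // namespace tvm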
......@@ -317,7 +317,7 @@ bool ReduceRel(const Array<Type>& types,
attrs->keepdims = keepdims; \
attrs->exclude = exclude; \
static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {data}, Attrs(attrs), {}); \
return Call(op, {data}, Attrs(attrs), {}); \
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(1) \
......@@ -624,7 +624,7 @@ Expr MakeVariance(Expr data,
attrs->keepdims = keepdims;
attrs->exclude = exclude;
static const Op& op = Op::Get("variance");
return CallNode::make(op, {data, mean}, Attrs(attrs), {});
return Call(op, {data, mean}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op._make._variance")
......
......@@ -184,7 +184,7 @@ TVM_REGISTER_GLOBAL("relay.op._make.clip")
attrs->a_min = a_min;
attrs->a_max = a_max;
static const Op& op = Op::Get("clip");
return CallNode::make(op, {a}, Attrs(attrs), {});
return Call(op, {a}, Attrs(attrs), {});
});
RELAY_REGISTER_OP("clip")
......@@ -347,7 +347,7 @@ TVM_REGISTER_GLOBAL("relay.op._make.shape_of")
auto attrs = make_object<ShapeOfAttrs>();
attrs->dtype = dtype;
static const Op& op = Op::Get("shape_of");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
});
RELAY_REGISTER_OP("shape_of")
......@@ -397,7 +397,7 @@ TVM_REGISTER_GLOBAL("relay.op._make.ndarray_size")
auto attrs = make_object<NdarraySizeAttrs>();
attrs->dtype = dtype;
static const Op& op = Op::Get("ndarray_size");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
});
RELAY_REGISTER_OP("ndarray_size")
......
......@@ -68,7 +68,7 @@ Expr MakeMultiBoxPrior(Expr data,
attrs->offsets = std::move(offsets);
attrs->clip = clip;
static const Op& op = Op::Get("vision.multibox_prior");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
......@@ -141,7 +141,7 @@ Expr MakeMultiBoxTransformLoc(Expr cls_prob,
attrs->threshold = std::move(threshold);
attrs->variances = std::move(variances);
static const Op& op = Op::Get("vision.multibox_transform_loc");
return CallNode::make(op, {cls_prob, loc_pred, anchor}, Attrs(attrs), {});
return Call(op, {cls_prob, loc_pred, anchor}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.vision._make.multibox_transform_loc")
......
......@@ -57,7 +57,7 @@ Expr MakeGetValidCounts(Expr data,
attrs->id_index = id_index;
attrs->score_index = score_index;
static const Op& op = Op::Get("vision.get_valid_counts");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
......@@ -125,7 +125,7 @@ Expr MakeNMS(Expr data,
attrs->return_indices = return_indices;
attrs->invalid_to_bottom = invalid_to_bottom;
static const Op& op = Op::Get("vision.non_max_suppression");
return CallNode::make(op, {data, valid_count}, Attrs(attrs), {});
return Call(op, {data, valid_count}, Attrs(attrs), {});
}
......
......@@ -57,7 +57,7 @@ Expr MakeROIAlign(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spa
attrs->sample_ratio = sample_ratio;
attrs->layout = layout;
static const Op& op = Op::Get("vision.roi_align");
return CallNode::make(op, {data, rois}, Attrs(attrs), {});
return Call(op, {data, rois}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_align")
......@@ -107,7 +107,7 @@ Expr MakeROIPool(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spat
attrs->spatial_scale = spatial_scale;
attrs->layout = layout;
static const Op& op = Op::Get("vision.roi_pool");
return CallNode::make(op, {data, rois}, Attrs(attrs), {});
return Call(op, {data, rois}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.vision._make.roi_pool")
......@@ -173,7 +173,7 @@ Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array<IndexExpr>
attrs->rpn_min_size = rpn_min_size;
attrs->iou_loss = iou_loss;
static const Op& op = Op::Get("vision.proposal");
return CallNode::make(op, {cls_prob, bbox_pred, im_info}, Attrs(attrs), {});
return Call(op, {cls_prob, bbox_pred, im_info}, Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relay.op.vision._make.proposal")
......
......@@ -65,7 +65,7 @@ Expr MakeYoloReorg(Expr data,
auto attrs = make_object<YoloReorgAttrs>();
attrs->stride = stride;
static const Op& op = Op::Get("vision.yolo_reorg");
return CallNode::make(op, {data}, Attrs(attrs), {});
return Call(op, {data}, Attrs(attrs), {});
}
......
......@@ -113,7 +113,7 @@ Expr MakeQnnConcatenate(Expr data, Expr input_scales, Expr input_zero_points, Ex
auto attrs = make_object<ConcatenateAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("qnn.concatenate");
return CallNode::make(op,
return Call(op,
{data, input_scales, input_zero_points, output_scale, output_zero_point},
Attrs(attrs), {});
}
......@@ -184,7 +184,7 @@ Expr ConcatenateQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
}
idx++;
}
return MakeConcatenate(TupleNode::make(requantized_exprs), concatenate_attrs->axis);
return MakeConcatenate(Tuple(requantized_exprs), concatenate_attrs->axis);
}
RELAY_REGISTER_OP("qnn.concatenate")
......
......@@ -673,7 +673,7 @@ Expr MakeQnnConv2D(Expr data, Expr weight, Expr input_zero_point, Expr kernel_ze
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("qnn.conv2d");
return CallNode::make(
return Call(
op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale},
Attrs(attrs), {});
}
......
......@@ -72,7 +72,7 @@ Expr MakeQuantizedDense(Expr data, Expr weight, Expr input_zero_point, Expr kern
attrs->units = std::move(units);
attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("qnn.dense");
return CallNode::make(
return Call(
op, {data, weight, input_zero_point, kernel_zero_point, input_scale, kernel_scale},
Attrs(attrs), {});
}
......
......@@ -62,7 +62,7 @@ Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point) {
// A more detailed explanation can be found here -
// https://github.com/google/gemmlowp/blob/master/doc/quantization.md
static const Op& op = Op::Get("qnn.dequantize");
return CallNode::make(op, {data, input_scale, input_zero_point}, Attrs(), {});
return Call(op, {data, input_scale, input_zero_point}, Attrs(), {});
}
Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
......
......@@ -68,10 +68,10 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con
.set_body_typed([](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \
Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \
static const Op& op = Op::Get("qnn." OpName); \
return CallNode::make(op, {lhs, rhs, \
lhs_scale, lhs_zero_point, \
rhs_scale, rhs_zero_point, \
output_scale, output_zero_point}, Attrs(), {}); \
return Call(op, {lhs, rhs, \
lhs_scale, lhs_zero_point, \
rhs_scale, rhs_zero_point, \
output_scale, output_zero_point}, Attrs(), {}); \
}); \
RELAY_REGISTER_OP("qnn." OpName) \
.set_num_inputs(8) \
......
......@@ -77,7 +77,7 @@ Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, int axis
// A more detailed explanation can be found here -
// https://github.com/google/gemmlowp/blob/master/doc/quantization.md
static const Op& op = Op::Get("qnn.quantize");
return CallNode::make(op, {data, output_scale, output_zero_point}, Attrs(attrs), {});
return Call(op, {data, output_scale, output_zero_point}, Attrs(attrs), {});
}
Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale,
......
......@@ -299,7 +299,7 @@ Expr MakeRequantize(Expr data, Expr input_scale, Expr input_zero_point, Expr out
attrs->rounding = std::move(rounding);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("qnn.requantize");
return CallNode::make(op, {data, input_scale, input_zero_point, output_scale, output_zero_point},
return Call(op, {data, input_scale, input_zero_point, output_scale, output_zero_point},
Attrs(attrs), {});
}
......
......@@ -45,8 +45,6 @@ class QAnnotateExprNode : public TempExprNode {
v->Visit("kind", &kind);
}
TVM_DLL static QAnnotateExpr make(Expr expr, QAnnotateKind kind);
Expr Realize() const final;
static constexpr const char* _type_key = "relay.QAnnotateExpr";
......@@ -55,6 +53,13 @@ class QAnnotateExprNode : public TempExprNode {
class QAnnotateExpr : public TempExpr {
public:
/*!
* \brief The constructor
* \param expr The original relay expression.
* \param kind The annotation kind.
*/
TVM_DLL QAnnotateExpr(Expr expr, QAnnotateKind kind);
TVM_DEFINE_OBJECT_REF_METHODS(QAnnotateExpr, TempExpr, QAnnotateExprNode);
};
......@@ -63,18 +68,17 @@ Expr QAnnotateExprNode::Realize() const {
return expr;
}
QAnnotateExpr QAnnotateExprNode::make(Expr expr, QAnnotateKind kind) {
QAnnotateExpr::QAnnotateExpr(Expr expr, QAnnotateKind kind) {
auto rnode = make_object<QAnnotateExprNode>();
rnode->expr = expr;
rnode->expr = std::move(expr);
rnode->kind = kind;
return QAnnotateExpr(rnode);
data_ = std::move(rnode);
}
TVM_REGISTER_GLOBAL("relay._quantize.make_annotate_expr")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = QAnnotateExprNode::make(args[0],
static_cast<QAnnotateKind>(args[1].operator int()));
});
.set_body_typed([](Expr expr, int kind) {
return QAnnotateExpr(expr, static_cast<QAnnotateKind>(kind));
});
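Editor's note: this hunk is the template for the whole patch — the static QAnnotateExprNode::make factory becomes a QAnnotateExpr constructor that fills the node and installs it via data_ = std::move(rnode), and the raw set_body registration above now simply forwards to that constructor through set_body_typed. A stripped-down sketch of a node/reference pair under the new convention; MyExprNode/MyExpr are illustrative names, not part of the patch.

#include <tvm/relay/expr.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>

namespace tvm {
namespace relay {

// Illustrative node carrying an expression and an integer tag.
class MyExprNode : public Object {
 public:
  Expr expr;
  int kind;
  static constexpr const char* _type_key = "relay.MyExpr";
  TVM_DECLARE_FINAL_OBJECT_INFO(MyExprNode, Object);
};

class MyExpr : public ObjectRef {
 public:
  // Replaces the old static MyExprNode::make(Expr, int) factory.
  MyExpr(Expr expr, int kind) {
    auto n = make_object<MyExprNode>();
    n->expr = std::move(expr);
    n->kind = kind;
    data_ = std::move(n);  // install the freshly built node into this reference
  }
  TVM_DEFINE_OBJECT_REF_METHODS(MyExpr, ObjectRef, MyExprNode);
};

}  // namespace relay
}  // namespace tvm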
Pass QuantizeAnnotate() {
......@@ -87,7 +91,7 @@ Pass QuantizeAnnotate() {
const PackedFunc* f =
runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
return static_cast<Expr>(QAnnotateExprNode::make(ret, kQInput));
return static_cast<Expr>(QAnnotateExpr(ret, kQInput));
}
return e;
};
......
......@@ -150,7 +150,7 @@ class StatsCollector : private ExprMutator {
auto new_e = this->Mutate(expr);
const FunctionNode* func = new_e.as<FunctionNode>();
CHECK(func) << "Input shoule be Function";
Expr new_body = TupleNode::make(std::move(profile_data_));
Expr new_body = Tuple(std::move(profile_data_));
return Function(FreeVars(new_body), new_body, NullValue<Type>(), func->type_params,
func->attrs);
}
......@@ -173,7 +173,7 @@ class StatsCollector : private ExprMutator {
new_attrs->kind = QAnnotateKind::kQIdentity;
new_attrs->sign = attrs->sign;
new_attrs->rounding = attrs->rounding;
Expr identity_quantize = CallNode::make(new_call->op, new_args, Attrs{new_attrs}, {});
Expr identity_quantize = Call(new_call->op, new_args, Attrs{new_attrs}, {});
// add non-const expressions to profile data
if (attrs->kind != QAnnotateKind::kQWeight) {
......
......@@ -45,8 +45,6 @@ class QPartitionExprNode : public TempExprNode {
v->Visit("expr", &expr);
}
TVM_DLL static QPartitionExpr make(Expr expr);
Expr Realize() const final;
static constexpr const char* _type_key = "relay.QPartitionExpr";
......@@ -55,6 +53,12 @@ class QPartitionExprNode : public TempExprNode {
class QPartitionExpr : public TempExpr {
public:
/*!
* \brief The constructor
* \param expr The original relay expression.
*/
TVM_DLL explicit QPartitionExpr(Expr expr);
TVM_DEFINE_OBJECT_REF_METHODS(QPartitionExpr, TempExpr, QPartitionExprNode);
};
......@@ -66,16 +70,16 @@ Expr QPartitionExprNode::Realize() const {
return StopFusion(ret);
}
QPartitionExpr QPartitionExprNode::make(Expr expr) {
QPartitionExpr::QPartitionExpr(Expr expr) {
auto rnode = make_object<QPartitionExprNode>();
rnode->expr = expr;
return QPartitionExpr(rnode);
rnode->expr = std::move(expr);
data_ = std::move(rnode);
}
TVM_REGISTER_GLOBAL("relay._quantize.make_partition_expr")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = QPartitionExprNode::make(args[0]);
});
.set_body_typed([](Expr expr) {
return QPartitionExpr(expr);
});
Pass QuantizePartition() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
......
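Editor's note: alongside the constructors, the quantize FFI hooks move from the untyped set_body form (manually unpacking TVMArgs and writing TVMRetValue) to set_body_typed lambdas that take and return C++ types directly, as in the two registrations above. A minimal before/after sketch of that change; "relay._example.make_thing" and MakeThing are placeholders.

#include <tvm/relay/expr.h>
#include <tvm/runtime/registry.h>

namespace tvm {
namespace relay {

// Placeholder helper so the registration below is self-contained.
Expr MakeThing(Expr expr) { return expr; }

// Old style, kept as a comment for contrast:
// TVM_REGISTER_GLOBAL("relay._example.make_thing")
// .set_body([](TVMArgs args, TVMRetValue* ret) {
//   *ret = MakeThing(args[0]);
// });

// New style: the typed signature does the marshalling.
TVM_REGISTER_GLOBAL("relay._example.make_thing")
.set_body_typed([](Expr expr) {
  return MakeThing(expr);
});

}  // namespace relay
}  // namespace tvm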
......@@ -75,7 +75,7 @@ TVM_REGISTER_GLOBAL("relay._quantize.simulated_quantize")
attrs->sign = sign;
attrs->rounding = rounding;
static const Op& op = Op::Get("relay.op.annotation.simulated_quantize");
return CallNode::make(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {});
return Call(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {});
});
......
......@@ -67,14 +67,14 @@ class QRealizeIntExprNode : public QRealizeExprNode {
Expr Realize() const final;
TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype);
static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr";
TVM_DECLARE_FINAL_OBJECT_INFO(QRealizeIntExprNode, QRealizeExprNode);
};
class QRealizeIntExpr : public QRealizeExpr {
public:
TVM_DLL QRealizeIntExpr(Expr data, Expr dom_scale, DataType dtype);
TVM_DEFINE_OBJECT_REF_METHODS(QRealizeIntExpr, QRealizeExpr, QRealizeIntExprNode);
};
......@@ -87,18 +87,17 @@ Expr QRealizeIntExprNode::Realize() const {
return data;
}
QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale, DataType dtype) {
QRealizeIntExpr::QRealizeIntExpr(Expr data, Expr dom_scale, DataType dtype) {
ObjectPtr<QRealizeIntExprNode> n = make_object<QRealizeIntExprNode>();
n->data = std::move(data);
n->dom_scale = std::move(dom_scale);
n->dtype = std::move(dtype);
return QRealizeIntExpr(n);
data_ = std::move(n);
}
inline Expr ForwardOp(const Call& ref_call, const Array<Expr>& args) {
return CallNode::make(ref_call->op,
args, ref_call->attrs, ref_call->type_args);
return Call(ref_call->op, args, ref_call->attrs, ref_call->type_args);
}
......@@ -150,7 +149,7 @@ Expr QuantizeRealize(const Call& ref_call,
if (idom_scale_imm == odom_scale_imm) {
// same domain scale, only clip
data = Clip(data, clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
return QRealizeIntExpr(data, dom_scale, n->dtype);
}
float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm);
......@@ -170,14 +169,14 @@ Expr QuantizeRealize(const Call& ref_call,
static_cast<int>(shift_nbit)));
}
data = Clip(data, clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
return QRealizeIntExpr(data, dom_scale, n->dtype);
} else {
data = Cast(data, DataType::Int(64));
data = qnn::FixedPointMultiply(data, idom_scale_imm / odom_scale_imm,
ref_call->type_as<TensorTypeNode>()->shape,
cfg->rounding);
data = Cast(Clip(data, clip_min_imm, clip_max_imm), n->dtype);
return QRealizeIntExprNode::make(data, dom_scale, n->dtype);
return QRealizeIntExpr(data, dom_scale, n->dtype);
}
}
......@@ -186,7 +185,7 @@ Expr QuantizeRealize(const Call& ref_call,
Expr data = new_args[0];
Expr scaled_data = Multiply(data, MakeConstantScalar(DataType::Float(32), 1 / dom_scale_imm));
Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
return QRealizeIntExprNode::make(round_data, dom_scale, DataType::Float(32));
return QRealizeIntExpr(round_data, dom_scale, DataType::Float(32));
}
Expr FoldConstantOpt(const Expr& expr) {
......@@ -225,11 +224,11 @@ Expr Conv2dRealize(const Call& ref_call,
DataType out_dtype = cfg->dtype_activation;
attrs->out_dtype = out_dtype;
Expr ret = CallNode::make(ref_call->op,
Expr ret = Call(ref_call->op,
{ldata, rdata}, Attrs(attrs), ref_call->type_args);
Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
Expr dom_scale = FoldConstantOpt(mul);
return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
return QRealizeIntExpr(ret, dom_scale, out_dtype);
}
RELAY_REGISTER_OP("nn.conv2d")
......@@ -259,11 +258,11 @@ Expr DenseRealize(const Call& ref_call,
DataType out_dtype = cfg->dtype_activation;
attrs->out_dtype = out_dtype;
Expr ret = CallNode::make(ref_call->op,
Expr ret = Call(ref_call->op,
{ldata, rdata}, Attrs(attrs), ref_call->type_args);
Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
Expr dom_scale = FoldConstantOpt(mul);
return QRealizeIntExprNode::make(ret, dom_scale, out_dtype);
return QRealizeIntExpr(ret, dom_scale, out_dtype);
}
RELAY_REGISTER_OP("nn.dense")
......@@ -293,7 +292,7 @@ Expr MulRealize(const Call& ref_call,
Expr ret = ForwardOp(ref_call, {ldata, rdata});
Expr mul = Multiply(lhs->dom_scale, rhs->dom_scale);
Expr dom_scale = FoldConstantOpt(mul);
return QRealizeIntExprNode::make(ret, dom_scale, dtype);
return QRealizeIntExpr(ret, dom_scale, dtype);
}
CHECK(!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>());
return Expr(nullptr);
......@@ -377,7 +376,7 @@ Expr AddRealize(const Call& ref_call,
Expr dom_scale;
Array<Expr> ret_args = UnifyDTypeScale(ref_call->args, new_args, &dtype, &dom_scale);
Expr ret = ForwardOp(ref_call, ret_args);
return QRealizeIntExprNode::make(ret, dom_scale, dtype);
return QRealizeIntExpr(ret, dom_scale, dtype);
}
CHECK(!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>());
......@@ -398,9 +397,9 @@ Expr ClipRealize(const Call& ref_call,
attrs->a_min = ref_attrs->a_min / dom_scale;
attrs->a_max = ref_attrs->a_max / dom_scale;
Expr ret = CallNode::make(ref_call->op,
Expr ret = Call(ref_call->op,
{n->data}, Attrs(attrs), ref_call->type_args);
return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype);
return QRealizeIntExpr(ret, n->dom_scale, n->dtype);
}
CHECK(!new_args[0]->IsInstance<TempExprNode>());
return Expr(nullptr);
......@@ -427,8 +426,8 @@ Expr ConcatenateRealize(const Call& ref_call,
DataType dtype;
Expr dom_scale;
Array<Expr> ret_args = UnifyDTypeScale(ref_arr, arr, &dtype, &dom_scale);
Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)});
return QRealizeIntExprNode::make(ret, dom_scale, dtype);
Expr ret = ForwardOp(ref_call, {Tuple(ret_args)});
return QRealizeIntExpr(ret, dom_scale, dtype);
} else {
for (auto arg : new_args) {
CHECK(!arg->IsInstance<TempExprNode>());
......@@ -448,7 +447,7 @@ Expr IdentityRealize(const Call& ref_call,
CHECK_EQ(new_args.size(), 1);
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
Expr ret = ForwardOp(ref_call, {n->data});
return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype);
return QRealizeIntExpr(ret, n->dom_scale, n->dtype);
}
CHECK(!new_args[0]->IsInstance<TempExprNode>());
return Expr(nullptr);
......@@ -472,7 +471,7 @@ Expr CastDtypeInputRealize(const Call& ref_call,
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
Expr data = Cast(n->data, cfg->dtype_input);
Expr ret = ForwardOp(ref_call, {data});
return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input);
return QRealizeIntExpr(ret, n->dom_scale, cfg->dtype_input);
}
CHECK(!new_args[0]->IsInstance<TempExprNode>());
return Expr(nullptr);
......@@ -493,7 +492,7 @@ Expr AvgPoolRealize(const Call& ref_call,
data = Cast(n->data, cfg->dtype_activation);
}
Expr ret = ForwardOp(ref_call, {data});
return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_activation);
return QRealizeIntExpr(ret, n->dom_scale, cfg->dtype_activation);
}
CHECK(!new_args[0]->IsInstance<TempExprNode>());
return Expr(nullptr);
......@@ -512,7 +511,7 @@ Expr CastHintRealize(const Call& ref_call,
CHECK_EQ(new_args.size(), 1);
if (const auto* n = new_args[0].as<QRealizeIntExprNode>()) {
Expr ret = Cast(n->data, param->dtype);
return QRealizeIntExprNode::make(ret, n->dom_scale, param->dtype);
return QRealizeIntExpr(ret, n->dom_scale, param->dtype);
}
CHECK(!new_args[0]->IsInstance<TempExprNode>());
return Expr(nullptr);
......
......@@ -93,7 +93,7 @@ class AlterTransformMemorizer : public TransformMemorizer {
}
}
if (!modified) {
new_e = CallNode::make(ref_call->op, new_args, ref_call->attrs);
new_e = Call(ref_call->op, new_args, ref_call->attrs);
}
const CallNode* new_call = new_e.as<CallNode>();
......
......@@ -58,7 +58,7 @@ class AnnotateTargetWrapper : public ExprMutator {
Expr begin = (*begin_op)(it, target_);
compiler_begins.push_back(begin);
}
Expr update_call = CallNode::make(call->op, compiler_begins, call->attrs);
Expr update_call = Call(call->op, compiler_begins, call->attrs);
const auto* end_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_end");
CHECK(end_op);
......
......@@ -83,7 +83,7 @@ class CastCanonicalizer : public ExprMutator {
if (unchanged) {
return GetRef<Expr>(call);
}
return CallNode::make(call->op, call_args, call->attrs, call->type_args);
return Call(call->op, call_args, call->attrs, call->type_args);
}
}
......@@ -112,7 +112,7 @@ class CastCanonicalizer : public ExprMutator {
const CallNode* new_call = new_expr.as<CallNode>();
CHECK(new_call);
CHECK(new_call->op == cast_op_);
return CallNode::make(new_call->op, new_call->args, new_call->attrs,
return Call(new_call->op, new_call->args, new_call->attrs,
new_call->type_args);
}
}
......
......@@ -67,9 +67,9 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
CHECK(attrs_b);
const auto* tweight_a = a->args[1]->type_as<TensorTypeNode>();
const auto* tweight_b = b->args[1]->type_as<TensorTypeNode>();
const auto shape_a = BijectiveLayoutNode::make(
const auto shape_a = tir::BijectiveLayout(
Layout(attrs_a->kernel_layout), kOIHW).ForwardShape(tweight_a->shape);
const auto shape_b = BijectiveLayoutNode::make(
const auto shape_b = tir::BijectiveLayout(
Layout(attrs_b->kernel_layout), kOIHW).ForwardShape(tweight_b->shape);
return eq(attrs_a->strides, attrs_b->strides) && eq(attrs_a->padding, attrs_b->padding) &&
......@@ -108,7 +108,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
channel_pos_ = layout.find('C');
CHECK_NE(channel_pos_, std::string::npos);
return CallNode::make(conv2d, {data, new_weight}, Attrs{new_attrs}, {});
return Call(conv2d, {data, new_weight}, Attrs{new_attrs}, {});
}
bool IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
......@@ -159,11 +159,11 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
tuple.push_back(branch[depth]->args[i]);
}
auto concat = MakeConcatenate(TupleNode::make(tuple), arg_channel_pos);
auto concat = MakeConcatenate(Tuple(tuple), arg_channel_pos);
new_args.push_back(std::move(concat));
}
return CallNode::make(call->op, new_args, call->attrs, {});
return Call(call->op, new_args, call->attrs, {});
}
void UpdateGroupOutput(const Expr& data,
......@@ -203,7 +203,7 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
}
auto index = branches[0][0]->attrs.as<Conv2DAttrs>()->kernel_layout.find('O');
CHECK_NE(index, std::string::npos);
return std::make_tuple(MakeConcatenate(TupleNode::make(weights), index),
return std::make_tuple(MakeConcatenate(Tuple(weights), index),
tir::make_const(DataType::Int(32), num_filters));
}
};
......
......@@ -105,10 +105,10 @@ Call ParallelOpBatchCombiner::MakeCombinedOp(const Group& branches) {
arg_from_all_branches.push_back(branch[0]->args[i]);
}
new_args.push_back(MakeStack(TupleNode::make(arg_from_all_branches), 0));
new_args.push_back(MakeStack(Tuple(arg_from_all_branches), 0));
}
return CallNode::make(batch_op, new_args, Attrs(), {});
return Call(batch_op, new_args, Attrs(), {});
}
bool ParallelOpBatchCombiner::IsArgCompatible(const CallNode* a, const CallNode* b, size_t index) {
......@@ -153,11 +153,11 @@ Call ParallelOpBatchCombiner::MakeCombinedCallFromFollowingOps(const Expr& data,
}
}
auto stack = MakeStack(TupleNode::make(tuple), 0);
auto stack = MakeStack(Tuple(tuple), 0);
new_args.push_back(std::move(stack));
}
return CallNode::make(call->op, new_args, call->attrs, {});
return Call(call->op, new_args, call->attrs, {});
}
void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data,
......@@ -167,7 +167,7 @@ void ParallelOpBatchCombiner::UpdateGroupOutput(const Expr& data,
int index = 0;
auto split = MakeSplit(data, Integer(branches.size()), 0);
for (const auto& branch : branches) {
auto split_data = TupleGetItemNode::make(split, index++);
auto split_data = TupleGetItem(split, index++);
auto squeezed_data = MakeSqueeze(split_data, {0});
subst_map->insert({GetRef<Expr>(branch[depth]), squeezed_data});
}
......
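Editor's note: the combiner passes above show the container constructors doing the same job — Tuple(...) and TupleGetItem(...) replace TupleNode::make and TupleGetItemNode::make when branch arguments are packed for a combined op and unpacked again per branch. A tiny sketch of that pack/unpack round trip; PackAndPick is a made-up name for illustration.

#include <tvm/relay/expr.h>

namespace tvm {
namespace relay {

// Pack a set of expressions into a tuple and project one field back out,
// using the reference constructors introduced by this patch.
Expr PackAndPick(const Array<Expr>& fields, int index) {
  Tuple packed(fields);                // was TupleNode::make(fields)
  return TupleGetItem(packed, index);  // was TupleGetItemNode::make(packed, index)
}

}  // namespace relay
}  // namespace tvm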
......@@ -99,7 +99,7 @@ class ConvertTransformMemorizer : public TransformMemorizer {
}
}
if (!modified) {
new_e = CallNode::make(ref_call->op, new_args, ref_call->attrs);
new_e = Call(ref_call->op, new_args, ref_call->attrs);
}
const CallNode* new_call = new_e.as<CallNode>();
......
......@@ -44,7 +44,7 @@ Expr DeDup(const Expr& e) {
Var Fresh(const Var& v) {
CHECK_EQ(rename_.count(v), 0);
CHECK_EQ(memo_.count(v), 0) << v.as<VarNode>();
Var ret = VarNode::make(v->name_hint(), VisitType(v->type_annotation));
Var ret = Var(v->name_hint(), VisitType(v->type_annotation));
rename_[v] = ret;
return ret;
}
......@@ -62,7 +62,7 @@ Expr DeDup(const Expr& e) {
Expr VisitExpr_(const LetNode* op) final {
Var v = Fresh(op->var);
return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
return Let(v, VisitExpr(op->value), VisitExpr(op->body));
}
Type VisitType(const Type& t) final {
......@@ -90,7 +90,7 @@ Expr DeDup(const Expr& e) {
}
Pattern VisitPattern_(const PatternVarNode* op) final {
return PatternVarNode::make(Fresh(op->var));
return PatternVar(Fresh(op->var));
}
Type VisitType_(const TypeVarNode* op) final {
......
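Editor's note: DeDup exercises three more of the new constructors in one place — fresh variables via Var(name, type), rebinding via Let(var, value, body), and the matching pattern via PatternVar(var). A compact sketch of binding a value to a fresh variable with the new constructors; BindToFreshVar is a made-up helper name.

#include <tvm/relay/expr.h>

namespace tvm {
namespace relay {

// Bind `value` to a brand-new variable and return it through a let-expression,
// mirroring Fresh() plus the LetNode visitor above.
Expr BindToFreshVar(const Expr& value) {
  Var v("fresh_var", Type());  // was VarNode::make("fresh_var", Type())
  return Let(v, value, v);     // was LetNode::make(v, value, v)
}

}  // namespace relay
}  // namespace tvm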
......@@ -84,7 +84,7 @@ class Eliminator : private ExprMutator {
Expr VisitExpr_(const LetNode* op) final {
Var v = op->var;
if (HasLet(v)) {
return LetNode::make(v, VisitExpr(op->value), VisitExpr(op->body));
return Let(v, VisitExpr(op->value), VisitExpr(op->body));
} else {
return VisitExpr(op->body);
}
......
......@@ -134,7 +134,7 @@ class RewriteAnnotation : public ExprMutator {
if (value.same_as(op->value) && body.same_as(op->body)) {
return ExprMutator::VisitExpr_(op);
} else {
Expr new_let = LetNode::make(op->var, value, body);
Expr new_let = Let(op->var, value, body);
UpdateAnnotationMap(op, new_let.operator->());
return this->VisitExpr(new_let);
}
......@@ -149,7 +149,7 @@ class RewriteAnnotation : public ExprMutator {
}
if (annotated) {
Expr new_tuple = TupleNode::make(fields);
Expr new_tuple = Tuple(fields);
UpdateAnnotationMap(op, new_tuple.operator->());
return this->VisitExpr(new_tuple);
} else {
......@@ -161,7 +161,7 @@ class RewriteAnnotation : public ExprMutator {
Expr tuple = op->tuple;
if (NeedDeviceCopy(tuple.operator->(), op)) {
Expr new_expr =
TupleGetItemNode::make(GetDeviceCopyExpr(tuple, op), op->index);
TupleGetItem(GetDeviceCopyExpr(tuple, op), op->index);
UpdateAnnotationMap(op, new_expr.operator->());
return this->VisitExpr(new_expr);
} else {
......@@ -178,7 +178,7 @@ class RewriteAnnotation : public ExprMutator {
if_node->false_branch.same_as(false_br)) {
return ExprMutator::VisitExpr_(if_node);
} else {
Expr new_if = IfNode::make(cond, true_br, false_br);
Expr new_if = If(cond, true_br, false_br);
UpdateAnnotationMap(if_node, new_if.operator->());
return this->VisitExpr(new_if);
}
......@@ -201,7 +201,7 @@ class RewriteAnnotation : public ExprMutator {
}
if (annotated) {
Call new_call = CallNode::make(call_node->op, new_args, call_node->attrs,
Call new_call = Call(call_node->op, new_args, call_node->attrs,
call_node->type_args);
UpdateAnnotationMap(call_node, new_call.operator->());
......@@ -284,7 +284,7 @@ class RewriteAnnotation : public ExprMutator {
attrs->src_dev_type = src_dev_type;
attrs->dst_dev_type = dst_dev_type;
static const Op& op = Op::Get("device_copy");
Call device_copy = CallNode::make(op, {src}, Attrs(attrs), {});
Call device_copy = Call(op, {src}, Attrs(attrs), {});
annotation_map_.insert({device_copy.operator->(), dst_dev_type});
return device_copy;
}
......@@ -526,7 +526,7 @@ Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) {
} else if (tuple->fields.size() == new_body.size()) {
return new_expr;
} else {
Tuple tuple_body = TupleNode::make(new_body);
Tuple tuple_body = Tuple(new_body);
return Function(params, tuple_body, Type(nullptr),
fn->type_params, fn->attrs);
}
......@@ -545,7 +545,7 @@ Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) {
return new_fields.size() == 1 ? new_fields[0] : new_expr;
} else {
return new_fields.size() == 1 ? new_fields[0]
: TupleNode::make(new_fields);
: Tuple(new_fields);
}
} else {
return new_expr;
......
......@@ -89,7 +89,7 @@ class EtaExpander : public ExprMutator {
for (const auto& arg : call->args) {
new_args.push_back(VisitExpr(arg));
}
return CallNode::make(new_op, new_args, call->attrs, call->type_args);
return Call(new_op, new_args, call->attrs, call->type_args);
}
Expr VisitExpr_(const ConstructorNode* cons_node) final {
......@@ -101,14 +101,14 @@ class EtaExpander : public ExprMutator {
tvm::Array<Expr> params;
for (const auto& type : cons->inputs) {
Type param_type = type_var_replacer_.VisitType(type);
params.push_back(VarNode::make("eta_expand_param", param_type));
params.push_back(Var("eta_expand_param", param_type));
}
tvm::Array<Type> type_params;
TypeData adt_def = mod_->LookupTypeDef(cons->belong_to);
for (const auto& type_var : adt_def->type_vars) {
type_params.push_back(type_var_replacer_.VisitType(type_var));
}
Expr body = CallNode::make(cons, params, Attrs());
Expr body = Call(cons, params, Attrs());
Type ret_type = TypeCall(cons->belong_to, type_params);
return Function(
......@@ -130,14 +130,14 @@ class EtaExpander : public ExprMutator {
tvm::Array<Expr> params;
tvm::Array<Var> args;
for (size_t i = 0; i < func->params.size(); ++i) {
auto var = VarNode::make("eta_expand_param", func->params[i]->type_annotation);
auto var = Var("eta_expand_param", func->params[i]->type_annotation);
params.push_back(var);
args.push_back(var);
}
return Function(
args,
CallNode::make(gvar, params),
Call(gvar, params),
func->ret_type,
func->type_params);
} else {
......
......@@ -103,7 +103,7 @@ class ConstantFolder : public ExprMutator {
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return LetNode::make(var, value, body);
return Let(var, value, body);
}
}
}
......@@ -187,14 +187,14 @@ class ConstantFolder : public ExprMutator {
CHECK_GT(dim, 0)
<< "invalid dimension after constant eval";
}
return ConstantNode::make(nd_array);
return Constant(nd_array);
} else if (const auto* val = value.as<runtime::ADTObj>()) {
runtime::ADT adt = GetRef<runtime::ADT>(val);
Array<Expr> fields;
for (size_t i = 0; i < adt.size(); ++i) {
fields.push_back(ObjectToExpr(adt[i]));
}
return TupleNode::make(fields);
return Tuple(fields);
} else {
LOG(FATAL) << "Cannot handle " << value->GetTypeKey();
return Expr();
......@@ -267,13 +267,13 @@ class ConstantFolder : public ExprMutator {
if (shape->data.Shape().size() == 0 && GetScalarFromConstant<int32_t>(shape) == 0) {
auto ndarray = runtime::NDArray::Empty({}, cdtype, ctx);
shape = ConstantNode::make(ndarray);
shape = Constant(ndarray);
}
// Cast the constant into correct dtype
auto cast_attrs = make_object<CastAttrs>();
cast_attrs->dtype = param->dtype;
Expr ret = CallNode::make(cast_op_, { shape }, Attrs(cast_attrs), {});
Expr ret = Call(cast_op_, { shape }, Attrs(cast_attrs), {});
return ConstEvaluate(ret);
}
};
......
......@@ -92,22 +92,28 @@ class MessageNode : public RelayNode {
*/
bool require_positive;
static Message make(const AxesSet& axes, bool require_positive);
static constexpr const char* _type_key = "relay.pass.fold_scale_axis.Message";
TVM_DECLARE_FINAL_OBJECT_INFO(MessageNode, RelayNode);
};
class Message : public ObjectRef {
public:
/*!
* \brief The constructor
* \param axes Axes for scaling
* \param require_positive If folding requires the scales to be positive
* values.
*/
Message(const AxesSet& axes, bool require_positive);
TVM_DEFINE_OBJECT_REF_METHODS(Message, ObjectRef, MessageNode);
};
Message MessageNode::make(const AxesSet& axes, bool require_positive) {
Message::Message(const AxesSet& axes, bool require_positive) {
auto n = make_object<MessageNode>();
n->axes = axes;
n->require_positive = require_positive;
return Message(n);
data_ = std::move(n);
}
/*!
......@@ -150,7 +156,7 @@ Message Intersect(const Message& lhs, const Message& rhs) {
if (!lhs.defined()) return lhs;
if (!rhs.defined()) return rhs;
auto axes = Intersect(lhs->axes, rhs->axes);
return MessageNode::make(axes, lhs->require_positive || rhs->require_positive);
return Message(axes, lhs->require_positive || rhs->require_positive);
}
/*!
......@@ -315,7 +321,7 @@ class ForwardPrep : private ExprVisitor {
// Intermediate operators
Array<Message> ReluForwardPrep(const Call& call, const Message& out_message) {
if (out_message.defined()) {
return {MessageNode::make(out_message->axes, true)};
return {Message(out_message->axes, true)};
}
return {out_message};
}
......@@ -327,7 +333,7 @@ Expr ReluForwardRewrite(const Call& ref_call,
if (input == nullptr) return Expr(nullptr);
// return transformed conv2d
auto rnode = make_object<ScaledExprNode>();
rnode->value = CallNode::make(
rnode->value = Call(
ref_call->op, {input->value}, ref_call->attrs, ref_call->type_args);
rnode->scale = input->scale;
rnode->axes = input->axes;
......@@ -377,7 +383,7 @@ Expr AddSubForwardRewrite(const Call& ref_call,
Expr scale = ExpandBiasToMatchAxis(
slhs->scale, tlhs->shape.size(), slhs->axes);
Expr rhs = Divide(new_args[1], scale);
rnode->value = CallNode::make(ref_call->op, {slhs->value, rhs},
rnode->value = Call(ref_call->op, {slhs->value, rhs},
ref_call->attrs, ref_call->type_args);
rnode->scale = slhs->scale;
rnode->axes = slhs->axes;
......@@ -387,7 +393,7 @@ Expr AddSubForwardRewrite(const Call& ref_call,
Expr scale = ExpandBiasToMatchAxis(
srhs->scale, trhs->shape.size(), srhs->axes);
Expr lhs = Divide(new_args[0], scale);
rnode->value = CallNode::make(ref_call->op, {lhs, srhs->value},
rnode->value = Call(ref_call->op, {lhs, srhs->value},
ref_call->attrs, ref_call->type_args);
rnode->scale = srhs->scale;
rnode->axes = srhs->axes;
......@@ -476,7 +482,7 @@ Array<Message> Conv2DForwardPrep(const Call& call, const Message& out_message) {
data_axes = {c_big_axis};
}
if (data_axes.defined()) {
return {MessageNode::make(data_axes, false), none};
return {Message(data_axes, false), none};
}
return {none, none};
}
......@@ -521,7 +527,7 @@ Expr Conv2DForwardRewrite(const Call& ref_call,
weight = Multiply(weight, scale);
}
// return transformed conv2d
return CallNode::make(
return Call(
ref_call->op, {sdata->value, weight}, ref_call->attrs, ref_call->type_args);
}
......@@ -726,7 +732,7 @@ Expr BackwardTransformerNode::Transform(
// Intermediate operators
Message ReluBackwardPrep(const Call& call, const Array<Message>& in_messages) {
if (in_messages[0].defined()) {
return MessageNode::make(in_messages[0]->axes, true);
return Message(in_messages[0]->axes, true);
}
return in_messages[0];
}
......@@ -740,7 +746,7 @@ Expr ReluBackwardTransform(const Call& call,
}
Expr input = transformer->Transform(
call->args[0], message, scale);
return CallNode::make(call->op, {input}, call->attrs, call->type_args);
return Call(call->op, {input}, call->attrs, call->type_args);
}
RELAY_REGISTER_OP("nn.relu")
......@@ -796,7 +802,7 @@ Expr AddSubBackwardTransform(const Call& call,
CHECK(equal(message->axes, lhs_message->axes));
Expr lhs = transformer->Transform(call->args[0], message, scale);
Expr rhs = transformer->Transform(call->args[1], message, scale);
return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
return Call(call->op, {lhs, rhs}, call->attrs, call->type_args);
} else if (lhs_message.defined()) {
CHECK(equal(message->axes, lhs_message->axes));
Expr lhs = transformer->Transform(call->args[0], message, scale);
......@@ -805,7 +811,7 @@ Expr AddSubBackwardTransform(const Call& call,
Expr rhs_scale = ExpandBiasToMatchAxis(
scale, tlhs->shape.size(), message->axes);
rhs = Multiply(rhs, rhs_scale);
return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
return Call(call->op, {lhs, rhs}, call->attrs, call->type_args);
} else if (rhs_message.defined()) {
CHECK(equal(message->axes, rhs_message->axes));
Expr lhs = transformer->Transform(
......@@ -814,7 +820,7 @@ Expr AddSubBackwardTransform(const Call& call,
Expr lhs_scale = ExpandBiasToMatchAxis(
scale, trhs->shape.size(), message->axes);
lhs = Multiply(lhs, lhs_scale);
return CallNode::make(call->op, {lhs, rhs}, call->attrs, call->type_args);
return Call(call->op, {lhs, rhs}, call->attrs, call->type_args);
} else {
LOG(FATAL) << "outstanding scale";
return Expr();
......@@ -890,7 +896,7 @@ Message Conv2DBackwardPrep(const Call& call, const Array<Message>& in_messages)
kernel_layout.IndexOf(LayoutAxis::Get('i')) < 0 &&
c_small_axis < 0 &&
(param->groups == 1 || is_depthwise_conv2d)) {
return MessageNode::make({c_big_axis}, false);
return Message({c_big_axis}, false);
} else {
return NullValue<Message>();
}
......@@ -930,7 +936,7 @@ Expr Conv2DBackwardTransform(const Call& call,
Expr wscale = ExpandBiasToMatchAxis(
scale, kernel_layout.ndim(), {big_oc_axis});
weight = Multiply(weight, wscale);
return CallNode::make(
return Call(
call->op, {data, weight}, call->attrs, call->type_args);
}
......
......@@ -125,7 +125,7 @@ class ForwardRewriter : private ExprMutator {
if (tuple.same_as(op->tuple)) {
return GetRef<Expr>(op);
} else {
return TupleGetItemNode::make(tuple, op->index);
return TupleGetItem(tuple, op->index);
}
}
}
......@@ -142,7 +142,7 @@ class ForwardRewriter : private ExprMutator {
if (all_fields_unchanged) {
return GetRef<Expr>(op);
} else {
return TupleNode::make(fields);
return Tuple(fields);
}
}
......@@ -185,7 +185,7 @@ class ForwardRewriter : private ExprMutator {
}
}
if (unchanged) return ref_call;
return CallNode::make(
return Call(
new_op, call_args, call_node->attrs, call_node->type_args);
}
};
......
......@@ -837,7 +837,7 @@ class FuseMutator : private ExprMutator {
// create a new parameter.
std::ostringstream os;
os << "p" << params.size();
auto var = VarNode::make(os.str(), type);
auto var = Var(os.str(), type);
params.push_back(var);
arguments.push_back(expr);
return var;
......@@ -878,7 +878,7 @@ class FuseMutator : private ExprMutator {
auto* ret_group = gmap_.at(call)->FindRoot();
Array<Expr> new_args = GetNewArguments(call->args, ret_group);
auto new_call = CallNode::make(
auto new_call = Call(
call->op, new_args, call->attrs, call->type_args);
if (ret_group->root_ref == call) {
......@@ -902,13 +902,13 @@ class FuseMutator : private ExprMutator {
}
// This tuple is an intermediate node in the group
Array<Expr> new_fields = GetNewArguments(tuple->fields, ret_group);
return TupleNode::make(new_fields);
return Tuple(new_fields);
}
Expr VisitExpr_(const TupleGetItemNode* tuple_get) {
auto* ret_group = gmap_.at(tuple_get)->FindRoot();
auto new_tuple = GetNewArguments({tuple_get->tuple}, ret_group)[0];
auto new_node = TupleGetItemNode::make(new_tuple, tuple_get->index);
auto new_node = TupleGetItem(new_tuple, tuple_get->index);
if (ret_group->root_ref == tuple_get) {
if (gmap_.at(tuple_get->tuple.get())->FindRoot() != ret_group) {
// Isolated. This case occurs when tuple is created by an Opaque op
......@@ -934,7 +934,7 @@ class FuseMutator : private ExprMutator {
const GroupInfo& ginfo = ginfo_[group];
auto func = Function(ginfo.params, body, ret_type, {});
func = WithAttr(std::move(func), attr::kPrimitive, tvm::Integer(visitor.has_call));
return CallNode::make(func, ginfo.arguments, Attrs());
return Call(func, ginfo.arguments, Attrs());
}
Array<Expr> GetNewArguments(const tvm::Array<Expr>& args,
......
......@@ -154,7 +154,7 @@ struct FirstOrderReverseAD : ExprFunctor<ADValue(const Expr &)> {
for (const ADValue& adval : args) {
call_args.push_back(adval->get<ADTensor>().forward);
}
auto orig = CallNode::make(op_ref, call_args, attrs, type_args);
auto orig = Call(op_ref, call_args, attrs, type_args);
orig->checked_type_ = orig_type;
auto ret = std::make_shared<ADTensor>(ll, orig);
backprop_actions.push_back([this, args, orig, ret, op_ref](LetList* ll) {
......@@ -250,7 +250,7 @@ Expr FirstOrderGradient(const Expr& re, const IRModule& mod) {
for (const auto& a : args) {
grad_res.push_back(a->get<ADTensor>().reverse);
}
return TupleNode::make(grad_res);
return Tuple(grad_res);
});
return Pair(res.forward, grad);
});
......@@ -297,7 +297,7 @@ Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
fields.push_back(field);
types.push_back(field->checked_type_);
}
auto ret = TupleNode::make(fields);
auto ret = Tuple(fields);
ret->checked_type_ = TupleType(types);
return std::move(ret);
} else {
......@@ -316,14 +316,14 @@ void TransferGrads(const Type& forward_type,
CHECK(IsAtomic(from)) << from;
CHECK(IsAtomic(to)) << to;
if (forward_type.as<TensorTypeNode>()) {
auto from_ref = TupleGetItemNode::make(from, 1);
auto to_ref = TupleGetItemNode::make(to, 1);
ll->Push(RefWriteNode::make(to_ref, RefReadNode::make(from_ref)));
auto from_ref = TupleGetItem(from, 1);
auto to_ref = TupleGetItem(to, 1);
ll->Push(RefWrite(to_ref, RefRead(from_ref)));
} else if (auto* tt = forward_type.as<TupleTypeNode>()) {
for (size_t i = 0; i < tt->fields.size(); ++i) {
TransferGrads(tt->fields[i],
ll->Push(TupleGetItemNode::make(from, i)),
ll->Push(TupleGetItemNode::make(to, i)),
ll->Push(TupleGetItem(from, i)),
ll->Push(TupleGetItem(to, i)),
ll);
}
} else {
......@@ -335,7 +335,7 @@ void TransferGrads(const Type& forward_type,
/*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) {
auto rev = [&](const Expr& e) {
return Pair(e, ll->Push(RefCreateNode::make(ZerosLike(e))));
return Pair(e, ll->Push(RefCreate(ZerosLike(e))));
};
auto rev_type = [&](const Type& forward_type) {
return ReverseType(forward_type);
......@@ -357,7 +357,7 @@ Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) {
/*! \brief ReverseType(t) -> t. Get the gradient. */
Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) {
auto grad = [&](const Expr& e) {
return ll->Push(RefReadNode::make(GetField(e, 1)));
return ll->Push(RefRead(GetField(e, 1)));
};
auto grad_type = [&](const Type& forward_type) {
return forward_type;
......@@ -367,8 +367,8 @@ Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) {
void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
if (t.as<TensorTypeNode>()) {
ll->Push(RefWriteNode::make(GetField(arg, 1),
Add(ll->Push(RefReadNode::make(GetField(arg, 1))),
ll->Push(RefWrite(GetField(arg, 1),
Add(ll->Push(RefRead(GetField(arg, 1))),
grad)));
} else if (auto* tt = t.as<TupleTypeNode>()) {
for (size_t i = 0; i < tt->fields.size(); ++i) {
......@@ -384,8 +384,8 @@ void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
}
Expr BPEmpty() {
Expr unitF = Function({}, TupleNode::make({}), TupleType::Empty(), {});
return RefCreateNode::make(unitF);
Expr unitF = Function({}, Tuple(tvm::Array<Expr>({})), TupleType::Empty(), {});
return RefCreate(unitF);
}
struct ReverseAD : ExprMutator {
......@@ -412,7 +412,7 @@ struct ReverseAD : ExprMutator {
return LetList::With([&](LetList* ll) {
auto x_var = ll->Push(x);
auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
auto bpv = ll->Push(RefReadNode::make(bp));
auto bpv = ll->Push(RefRead(bp));
Expr nbp = Function(
{},
LetList::With([&](LetList* ll) {
......@@ -422,12 +422,12 @@ struct ReverseAD : ExprMutator {
auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));
TransferGrads(call->checked_type(), ret, dup_ad, ll);
ll->Push(CallNode::make(RefReadNode::make(dup_bp), {}));
return CallNode::make(bpv, {});
ll->Push(Call(RefRead(dup_bp), {}));
return Call(bpv, {});
}),
TupleType::Empty(),
{});
ll->Push(RefWriteNode::make(bp, nbp));
ll->Push(RefWrite(bp, nbp));
return ret;
});
}
......@@ -451,12 +451,12 @@ struct ReverseAD : ExprMutator {
for (size_t i = 0; i < args.size(); i++) {
orig_args.push_back(GetValue(call->args[i]->checked_type(), args[i], ll));
}
Expr orig = CallNode::make(call->op, orig_args, call->attrs, call->type_args);
Expr orig = Call(call->op, orig_args, call->attrs, call->type_args);
orig->checked_type_ = call->checked_type();
Var orig_var = ll->Push(orig);
orig_var->checked_type_ = call->checked_type();
auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
auto bpv = ll->Push(RefReadNode::make(bp));
auto bpv = ll->Push(RefRead(bp));
Expr nbp = Function(
{},
LetList::With([&](LetList* ll) {
......@@ -465,11 +465,11 @@ struct ReverseAD : ExprMutator {
for (size_t i = 0; i < args.size(); ++i) {
UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
}
return CallNode::make(bpv, {});
return Call(bpv, {});
}),
TupleType::Empty(),
{});
ll->Push(RefWriteNode::make(bp, nbp));
ll->Push(RefWrite(bp, nbp));
return ret;
});
}
......@@ -478,11 +478,11 @@ struct ReverseAD : ExprMutator {
Expr VisitExpr_(const ConstantNode* op) final {
Expr e = GetRef<Expr>(op);
return Pair(e, RefCreateNode::make(ZerosLike(e)));
return Pair(e, RefCreate(ZerosLike(e)));
}
Expr VisitExpr_(const IfNode* op) final {
return IfNode::make(TupleGetItemNode::make(VisitExpr(op->cond), 0),
return If(TupleGetItem(VisitExpr(op->cond), 0),
VisitExpr(op->true_branch),
VisitExpr(op->false_branch));
}
......@@ -545,13 +545,13 @@ Expr Gradient(const Expr& re, const IRModule& mod) {
Expr rev = ReverseAD(bp, std::make_shared<ReverseAD::ADVarMap>())(e);
std::vector<Expr> args;
for (const auto& p : f->params) {
args.push_back(ll->Push(Pair(p, RefCreateNode::make(ZerosLike(p)))));
args.push_back(ll->Push(Pair(p, RefCreate(ZerosLike(p)))));
}
auto c = ll->Push(CallNode::make(rev, args));
auto c = ll->Push(Call(rev, args));
std::function<void(const Expr&, const Type&)> init_grad;
init_grad = [&](const Expr& e, const Type& t) {
if (t.as<TensorTypeNode>()) {
ll->Push(RefWriteNode::make(GetField(e, 1), OnesLike(GetField(e, 0))));
ll->Push(RefWrite(GetField(e, 1), OnesLike(GetField(e, 0))));
} else if (auto tt = t.as<TupleTypeNode>()) {
CHECK_GT(tt->fields.size(), 0);
init_grad(ll->Push(GetField(e, 0)), tt->fields[0]);
......@@ -561,10 +561,10 @@ Expr Gradient(const Expr& re, const IRModule& mod) {
}
};
init_grad(c, f->body->checked_type());
ll->Push(CallNode::make(RefReadNode::make(bp), {}));
ll->Push(Call(RefRead(bp), {}));
std::vector<Expr> ret;
for (const auto& a : args) {
ret.push_back(RefReadNode::make(GetField(a, 1)));
ret.push_back(RefRead(GetField(a, 1)));
}
std::function<Expr(const Expr&, const Type&)> get_final_result;
get_final_result = [&](const Expr& e, const Type& t) -> Expr {
......@@ -575,13 +575,13 @@ Expr Gradient(const Expr& re, const IRModule& mod) {
for (size_t i = 0; i < tt->fields.size(); ++i) {
fields.push_back(get_final_result(ll->Push(GetField(e, i)), tt->fields[i]));
}
return TupleNode::make(fields);
return Tuple(fields);
} else {
LOG(FATAL) << "unhandled type " << t;
throw;
}
};
return Pair(get_final_result(c, f->body->checked_type()), TupleNode::make(ret));
return Pair(get_final_result(c, f->body->checked_type()), Tuple(ret));
});
return Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
}
......
......@@ -151,7 +151,7 @@ class Inliner : ExprMutator {
return Bind(func->body, bind_map);
}
} else if (const auto* call_node = callee.as<CallNode>()) {
return CallNode::make(func, args, call_node->attrs, call_node->type_args);
return Call(func, args, call_node->attrs, call_node->type_args);
} else {
return std::move(func);
}
......
......@@ -78,7 +78,7 @@ class LetList {
* \return a Var that hold the inserted expr.
*/
Var Push(Expr expr, Type ty) {
return Push(VarNode::make("x", ty), expr);
return Push(Var("x", ty), expr);
}
/*!
......@@ -103,7 +103,7 @@ class LetList {
CHECK(!used_);
Expr ret = body;
for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) {
ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret);
ret = Let(std::get<0>(*rit), std::get<1>(*rit), ret);
}
used_ = true;
return ret;
......@@ -120,10 +120,10 @@ class LetList {
* // Automatically call Get with LetList::With
* return LetList::With([&](LetList* ll) {
* // Turn a call to plus into a variable to avoid duplication of code
* Var b = ll->Push(CallNode::make(plus, {a, a}));
* Var c = ll->Push(CallNode::make(plus, {b, b}));
* Var d = ll->Push(CallNode::make(plus, {c, c}));
* return CallNode::make(plus, {d, d});
* Var b = ll->Push(Call(plus, {a, a}));
* Var c = ll->Push(Call(plus, {b, b}));
* Var d = ll->Push(Call(plus, {c, c}));
* return Call(plus, {d, d});
* });
* }
* \endcode
......@@ -136,7 +136,7 @@ class LetList {
return ll.Get(f(&ll));
}
static Expr Let(const Expr& e, const std::function<Expr(const Var&)>& f) {
static Expr LetBind(const Expr& e, const std::function<Expr(const Var&)>& f) {
return With([&](LetList* ll) {
return f(ll->Push(e));
});
......
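A minimal usage sketch (not part of this commit) showing how a caller combines the renamed LetList::LetBind with the constructor-style expression API above; the plus operator expression and argument a are hypothetical:

// Hedged sketch: bind an intermediate call to a fresh var, then reuse it.
Expr SquareTwice(const Expr& plus, const Expr& a) {
  return LetList::LetBind(Call(plus, {a, a}), [&](const Var& b) {
    // b names (plus a a); LetBind wraps the result in the required lets.
    return Call(plus, {b, b});
  });
}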
......@@ -45,7 +45,7 @@ class MergeCompositeWrapper : public ExprMutator {
if (var_map->find(pattern->name_hint()) == var_map->end()) {
// if we haven't encountered this var yet, make a new free var and associate
// it with the value at 'root'
auto free_var = VarNode::make(pattern->name_hint(), Type());
auto free_var = Var(pattern->name_hint(), Type());
var_map->Set(pattern->name_hint(), Array<Expr>({free_var, root}));
return std::move(free_var);
} else {
......@@ -132,7 +132,7 @@ class MergeCompositeWrapper : public ExprMutator {
new_args.push_back(new_arg);
i++;
}
return CallNode::make(root->op, new_args, root->attrs);
return Call(root->op, new_args, root->attrs);
}
Expr VisitExpr_(const CallNode* cn) {
......@@ -149,7 +149,7 @@ class MergeCompositeWrapper : public ExprMutator {
auto new_e = this->Mutate(arg);
new_args.push_back(new_e);
}
return CallNode::make(call->op, new_args, call->attrs);
return Call(call->op, new_args, call->attrs);
}
}
......@@ -175,7 +175,7 @@ class MergeCompositeWrapper : public ExprMutator {
for (const auto& free_var : free_vars) {
args.push_back(args_map[free_var->name_hint()][1]);
}
auto new_call = CallNode::make(f, args);
auto new_call = Call(f, args);
return std::move(new_call);
}
return std::move(call);
......
......@@ -603,7 +603,7 @@ static const Op& with_funcid_op = Op::Get("annotation.with_funcid");
Expr MkWithFuncId(const Expr& expr, FuncId fid) {
auto attrs = make_object<WithFuncIdAttrs>();
attrs->fid = fid;
return CallNode::make(with_funcid_op, {expr}, Attrs(attrs), {});
return Call(with_funcid_op, {expr}, Attrs(attrs), {});
}
Expr StripWithFuncId(const Expr& e);
......@@ -658,7 +658,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
value.push_back(ps);
expr.push_back(ps->dynamic);
}
return HasStatic(MkSTuple(value), ll->Push(TupleNode::make(expr)));
return HasStatic(MkSTuple(value), ll->Push(Tuple(expr)));
}
PStatic VisitExpr_(const TupleGetItemNode* op, LetList* ll) final {
......@@ -666,7 +666,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
if (ps->pstatic.defined()) {
return Downcast<STuple>(ps->pstatic)->fields[op->index];
} else {
return NoStatic(ll->Push(TupleGetItemNode::make(ps->dynamic, op->index)));
return NoStatic(ll->Push(TupleGetItem(ps->dynamic, op->index)));
}
}
......@@ -724,7 +724,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
});
});
store_.Invalidate();
return NoStatic(ll->Push(IfNode::make(c->dynamic, t, f)));
return NoStatic(ll->Push(If(c->dynamic, t, f)));
}
}
......@@ -732,7 +732,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
PStatic ps = VisitExpr(op->value, ll);
Static r = MkSRef();
store_.Insert(r.as<SRefNode>(), ps);
return HasStatic(r, ll->Push(RefCreateNode::make(ps->dynamic)));
return HasStatic(r, ll->Push(RefCreate(ps->dynamic)));
}
PStatic VisitExpr_(const RefWriteNode* op, LetList* ll) final {
......@@ -743,7 +743,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
} else {
store_.Invalidate();
}
return HasStatic(MkSTuple({}), ll->Push(RefWriteNode::make(r->dynamic, v->dynamic)));
return HasStatic(MkSTuple({}), ll->Push(RefWrite(r->dynamic, v->dynamic)));
}
PStatic VisitExpr_(const RefReadNode* op, LetList* ll) final {
......@@ -754,7 +754,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
return ret;
}
}
return NoStatic(ll->Push(RefReadNode::make(r->dynamic)));
return NoStatic(ll->Push(RefRead(r->dynamic)));
}
PStatic VisitExpr_(const CallNode* op, LetList* ll) final {
......@@ -774,7 +774,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
return Downcast<SFunc>(f->pstatic)->func(f, x, op->attrs, op->type_args, ll);
} else {
store_.Invalidate();
return NoStatic(ll->Push(CallNode::make(f->dynamic, x_dyn, op->attrs, op->type_args)));
return NoStatic(ll->Push(Call(f->dynamic, x_dyn, op->attrs, op->type_args)));
}
}
......@@ -872,7 +872,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
for (const auto& v : pv) {
dyn.push_back(v->dynamic);
}
return NoStatic(ll->Push(CallNode::make(var, dyn, attrs, type_args)));
return NoStatic(ll->Push(Call(var, dyn, attrs, type_args)));
}
});
};
......@@ -898,7 +898,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
PStatic VisitFunc(const Function& func,
LetList* ll,
const Var& name = VarNode::make("x", Type())) {
const Var& name = Var("x", Type())) {
Func f = VisitFuncStatic(func, name);
Function u_func = AsFunc(RegisterFuncId(DeDup(AnnotateFuncId(func))));
// TODO(@M.K.): we seem to reduce the Landin knot into letrec.
......@@ -919,13 +919,13 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
if (!st->pstatic.defined()) {
throw ReflectError();
} else if (const STensorNode* op = st->pstatic.as<STensorNode>()) {
return ConstantNode::make(op->data);
return Constant(op->data);
} else if (const STupleNode* op = st->pstatic.as<STupleNode>()) {
tvm::Array<Expr> fields;
for (const PStatic& field : op->fields) {
fields.push_back(Reflect(field));
}
return TupleNode::make(fields);
return Tuple(fields);
} else {
LOG(FATAL) << "Unknown case: " << st->dynamic;
throw;
......@@ -935,7 +935,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
PStatic Reify(const ObjectRef& v, LetList* ll) const {
if (v->IsInstance<runtime::NDArray::ContainerType>()) {
auto nd_array = Downcast<runtime::NDArray>(v);
return HasStatic(MkSTensor(nd_array), ll->Push(ConstantNode::make(nd_array)));
return HasStatic(MkSTensor(nd_array), ll->Push(Constant(nd_array)));
} else if (const runtime::ADTObj* op = v.as<runtime::ADTObj>()) {
std::vector<PStatic> fields;
tvm::Array<Expr> fields_dyn;
......@@ -945,7 +945,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
fields.push_back(ps);
fields_dyn.push_back(ps->dynamic);
}
return HasStatic(MkSTuple(fields), ll->Push(TupleNode::make(fields_dyn)));
return HasStatic(MkSTuple(fields), ll->Push(Tuple(fields_dyn)));
} else {
LOG(FATAL) << "Unknown case";
throw;
......@@ -977,7 +977,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
ns_args.push_back(ps->dynamic);
}
auto ns = [&]() {
return NoStatic(ll->Push(CallNode::make(expr, ns_args, attrs, type_args)));
return NoStatic(ll->Push(Call(expr, ns_args, attrs, type_args)));
};
if (StatefulOp(expr)) {
return ns();
......@@ -987,7 +987,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
for (const PStatic& ps : pv) {
args.push_back(Reflect(ps));
}
return ConstEvaluate(CallNode::make(expr, args, attrs, type_args), ll);
return ConstEvaluate(Call(expr, args, attrs, type_args), ll);
}
catch (const ReflectError&) {
return ns();
......@@ -1010,7 +1010,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
for (const PStatic& ps : pv) {
dyn.push_back(ps->dynamic);
}
return HasStatic(MkSConstructor(c, pv), ll->Push(CallNode::make(c, dyn)));
return HasStatic(MkSConstructor(c, pv), ll->Push(Call(c, dyn)));
};
return HasStatic(MkSFunc(f), GetRef<Expr>(op));
}
......@@ -1036,10 +1036,10 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
return VisitExpr(c->rhs, ll)->dynamic;
});
});
clauses.push_back(ClauseNode::make(c->lhs, expr));
clauses.push_back(Clause(c->lhs, expr));
}
store_.Invalidate();
return NoStatic(ll->Push(MatchNode::make(ps->dynamic, clauses, op->complete)));
return NoStatic(ll->Push(Match(ps->dynamic, clauses, op->complete)));
}();
default:
LOG(FATAL) << "Unknown MatchStatus";
......
......@@ -169,7 +169,7 @@ class Partitioner : public ExprMutator {
auto compiler_attrs = call->attrs.as<CompilerAttrs>();
// The type of the created variable is the same as the compiler_begin
// node.
auto var = VarNode::make(compiler_attrs->compiler + "_input" + std::to_string(var_id_++),
auto var = Var(compiler_attrs->compiler + "_input" + std::to_string(var_id_++),
call->checked_type_);
// Find the corresponding subgraph and add the argument.
......@@ -246,7 +246,7 @@ class Partitioner : public ExprMutator {
module_->Add(glob_func, subgraph_func);
// The return type of callnode is the same as the type of the
// compiler_end node.
auto ret = CallNode::make(glob_func, args);
auto ret = Call(glob_func, args);
ret->checked_type_ = call->checked_type_;
return std::move(ret);
}
......@@ -264,7 +264,7 @@ class Partitioner : public ExprMutator {
for (auto field : op->fields) {
fields.push_back(VisitExpr(field));
}
return TupleNode::make(fields);
return Tuple(fields);
}
}
......@@ -275,7 +275,7 @@ class Partitioner : public ExprMutator {
} else {
AddToSubgraph(subgraph, g->tuple);
auto t = VisitExpr(g->tuple);
return TupleGetItemNode::make(t, g->index);
return TupleGetItem(t, g->index);
}
}
......@@ -309,7 +309,7 @@ class Partitioner : public ExprMutator {
auto value = VisitExpr(op->value);
auto body = VisitExpr(op->body);
return LetNode::make(var, value, body);
return Let(var, value, body);
}
}
......@@ -324,7 +324,7 @@ class Partitioner : public ExprMutator {
auto guard = VisitExpr(op->cond);
auto true_b = VisitExpr(op->true_branch);
auto false_b = VisitExpr(op->false_branch);
return IfNode::make(guard, true_b, false_b);
return If(guard, true_b, false_b);
}
}
......@@ -335,7 +335,7 @@ class Partitioner : public ExprMutator {
} else {
AddToSubgraph(subgraph, op->value);
Expr value = VisitExpr(op->value);
return RefCreateNode::make(value);
return RefCreate(value);
}
}
......@@ -346,7 +346,7 @@ class Partitioner : public ExprMutator {
} else {
AddToSubgraph(subgraph, op->ref);
Expr ref = VisitExpr(op->ref);
return RefReadNode::make(ref);
return RefRead(ref);
}
}
......@@ -358,7 +358,7 @@ class Partitioner : public ExprMutator {
AddToSubgraph(subgraph, op->ref);
Expr ref = VisitExpr(op->ref);
Expr value = VisitExpr(op->value);
return RefWriteNode::make(ref, value);
return RefWrite(ref, value);
}
}
......
......@@ -155,7 +155,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
Expr Compound(const Expr& orig, const Expr& now, const Var& v) {
Var var = v.defined() ?
v :
VarNode::make(std::string("x"), Type());
Var(std::string("x"), Type());
return GetScope(orig)->ll->Push(var, now);
}
......@@ -165,7 +165,7 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
for (const auto& a : c->args) {
args.push_back(VisitExpr(a));
}
return Compound(e, CallNode::make(VisitExpr(c->op), args, c->attrs, c->type_args), v);
return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v);
}
Expr VisitExpr_(const TupleNode* t, const Var& v) final {
......@@ -174,32 +174,32 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
for (const auto& a : t->fields) {
fields.push_back(VisitExpr(a));
}
return Compound(e, TupleNode::make(fields), v);
return Compound(e, Tuple(fields), v);
}
Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final {
Expr e = GetRef<Expr>(t);
return Compound(e, TupleGetItemNode::make(VisitExpr(t->tuple), t->index), v);
return Compound(e, TupleGetItem(VisitExpr(t->tuple), t->index), v);
}
Expr VisitExpr_(const RefCreateNode* r, const Var& v) final {
Expr e = GetRef<Expr>(r);
return Compound(e, RefCreateNode::make(VisitExpr(r->value)), v);
return Compound(e, RefCreate(VisitExpr(r->value)), v);
}
Expr VisitExpr_(const RefReadNode* r, const Var& v) final {
Expr e = GetRef<Expr>(r);
return Compound(e, RefReadNode::make(VisitExpr(r->ref)), v);
return Compound(e, RefRead(VisitExpr(r->ref)), v);
}
Expr VisitExpr_(const RefWriteNode* r, const Var& v) final {
Expr e = GetRef<Expr>(r);
return Compound(e, RefWriteNode::make(VisitExpr(r->ref), VisitExpr(r->value)), v);
return Compound(e, RefWrite(VisitExpr(r->ref), VisitExpr(r->value)), v);
}
Expr VisitExpr_(const IfNode* i, const Var& v) final {
Expr e = GetRef<Expr>(i);
Expr ret = IfNode::make(VisitExpr(i->cond),
Expr ret = If(VisitExpr(i->cond),
GetSubScope(e, 1)->ll->Get(VisitExpr(i->true_branch)),
GetSubScope(e, 2)->ll->Get(VisitExpr(i->false_branch)));
return Compound(e, ret, v);
......@@ -257,11 +257,11 @@ class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
Expr data = VisitExpr(m->data);
std::vector<Clause> clauses;
for (const Clause& c : m->clauses) {
clauses.push_back(ClauseNode::make(
clauses.push_back(Clause(
c->lhs,
GetSubScope(e, 1 + clauses.size())->ll->Get(VisitExpr(c->rhs))));
}
return Compound(e, MatchNode::make(data, clauses, m->complete), v);
return Compound(e, Match(data, clauses, m->complete), v);
}
};
......
......@@ -137,7 +137,7 @@ Function ToCPS(const Function& f,
Expr VisitExpr_(const LetNode* op, const MCont& k) final {
return VisitExpr(op->value, [&](const Expr& v) {
return LetNode::make(remap(op->var), v, VisitExpr(op->body, k));
return Let(remap(op->var), v, VisitExpr(op->body, k));
});
}
......@@ -155,7 +155,7 @@ Function ToCPS(const Function& f,
}
Pattern VisitPattern_(const PatternVarNode* op) final {
return PatternVarNode::make(remap(op->var));
return PatternVar(remap(op->var));
}
Expr VisitExpr_(const GlobalVarNode* op, const MCont& k) final {
......@@ -177,18 +177,18 @@ Function ToCPS(const Function& f,
}
Expr VisitExpr_(const RefCreateNode* op, const MCont& k) final {
return VisitExpr(op->value, [&](const Expr& v) { return k(RefCreateNode::make(v)); });
return VisitExpr(op->value, [&](const Expr& v) { return k(RefCreate(v)); });
}
Expr reify(const MCont& k) {
Var arg = VarNode::make("arg", Type());
Var arg = Var("arg", Type());
return Function({arg}, k(arg), Type(), {}, {});
}
Expr reify(const MCont& k, const std::function<Expr(MCont)>& cont) {
return LetList::Let(reify(k),
return LetList::LetBind(reify(k),
[&](const Var& f) {
return cont([&](const Expr& e) { return CallNode::make(f, {e}); });
return cont([&](const Expr& e) { return Call(f, {e}); });
});
}
......@@ -196,7 +196,7 @@ Function ToCPS(const Function& f,
return reify(k, [&](const MCont& kf) {
return VisitExpr(op->cond,
[&](const Expr& v) {
return IfNode::make(v, VisitExpr(op->true_branch, kf), VisitExpr(op->false_branch, kf));
return If(v, VisitExpr(op->true_branch, kf), VisitExpr(op->false_branch, kf));
});
});
}
......@@ -206,9 +206,9 @@ Function ToCPS(const Function& f,
return VisitExpr(op->data, [&](const Expr& v) {
tvm::Array<Clause> clauses;
for (const auto& c : op->clauses) {
clauses.push_back(ClauseNode::make(VisitPattern(c->lhs), VisitExpr(c->rhs, kf)));
clauses.push_back(Clause(VisitPattern(c->lhs), VisitExpr(c->rhs, kf)));
}
return MatchNode::make(v, clauses, op->complete);
return Match(v, clauses, op->complete);
});
});
}
......@@ -216,7 +216,7 @@ Function ToCPS(const Function& f,
Expr VisitExpr_(const RefReadNode* op, const MCont& k) final {
return VisitExpr(op->ref,
[&](const Expr& r) {
return LetList::Let(RefReadNode::make(r), k);
return LetList::LetBind(RefRead(r), k);
});
}
......@@ -225,7 +225,7 @@ Function ToCPS(const Function& f,
[&](const Expr& r) {
return VisitExpr(op->value,
[&](const Expr& v) {
return LetList::Let(RefWriteNode::make(r, v), k);
return LetList::LetBind(RefWrite(r, v), k);
});
});
}
......@@ -235,7 +235,7 @@ Function ToCPS(const Function& f,
std::function<Expr()> next;
next = [&]() {
return (fields.size() == op->fields.size()) ?
k(TupleNode::make(fields)) :
k(Tuple(fields)) :
VisitExpr(op->fields[fields.size()], [&](const Expr& v) {
fields.push_back(v);
return next();
......@@ -246,7 +246,7 @@ Function ToCPS(const Function& f,
Expr VisitExpr_(const TupleGetItemNode* op, const MCont& k) final {
return VisitExpr(op->tuple, [&](const Expr& v) {
return k(TupleGetItemNode::make(v, op->index));
return k(TupleGetItem(v, op->index));
});
}
......@@ -256,7 +256,7 @@ Function ToCPS(const Function& f,
std::function<Expr()> next;
next = [&]() {
if (args.size() == op->args.size()) {
return LetList::Let(CallNode::make(op->op, args, op->attrs, op->type_args), k);
return LetList::LetBind(Call(op->op, args, op->attrs, op->type_args), k);
} else {
return VisitExpr(op->args[args.size()], [&](const Expr& v) {
args.push_back(v);
......@@ -272,7 +272,7 @@ Function ToCPS(const Function& f,
next = [&]() {
if (args.size() == op->args.size()) {
args.push_back(reify(k));
return Expr(CallNode::make(f, args, op->attrs, op->type_args));
return Expr(Call(f, args, op->attrs, op->type_args));
} else {
return VisitExpr(op->args[args.size()], [&](const Expr& v) {
args.push_back(v);
......@@ -287,7 +287,7 @@ Function ToCPS(const Function& f,
}
}
} mut(remap, answer, m, vm, cm);
Var k = VarNode::make("k", Arrow(CPSType(function_type->ret_type, answer), answer));
Var k = Var("k", Arrow(CPSType(function_type->ret_type, answer), answer));
tvm::Array<Var> new_params;
for (const Var& v : f->params) {
new_params.push_back(remap(v));
......@@ -295,7 +295,7 @@ Function ToCPS(const Function& f,
new_params.push_back(k);
return Function(new_params,
mut.VisitExpr(f->body,
[&](const Expr& e) { return CallNode::make(k, {e}); }),
[&](const Expr& e) { return Call(k, {e}); }),
answer,
f->type_params,
f->attrs);
......@@ -311,7 +311,7 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) {
void VisitExpr_(const VarNode* vn) final {
Var v = GetRef<Var>(vn);
if (vm->count(v) == 0) {
auto ret = VarNode::make(v->name_hint(), CPSType(v->checked_type(), answer));
auto ret = Var(v->name_hint(), CPSType(v->checked_type(), answer));
vm->insert({v, ret});
}
}
......@@ -340,7 +340,7 @@ Function UnCPS(const Function& f) {
CHECK_GT(f->params.size(), 0);
std::vector<Var> new_params;
for (const auto& p : f->params) {
new_params.push_back(VarNode::make(p->name_hint(), p->checked_type()));
new_params.push_back(Var(p->name_hint(), p->checked_type()));
}
auto cont_type = Downcast<FuncType>(new_params.back()->type_annotation);
new_params.pop_back();
......@@ -354,7 +354,7 @@ Function UnCPS(const Function& f) {
new_type_params.pop_back();
// TODO(@M.K.): make alphaequal work on free term
// CHECK(AlphaEqual(cont_type, Arrow(new_ret_type, answer_type)));
auto x = VarNode::make("x", new_ret_type);
auto x = Var("x", new_ret_type);
auto cont = Function({x}, x, new_ret_type, {}, {});
tvm::Array<Expr> args;
for (const auto& p : new_params) {
......@@ -367,7 +367,7 @@ Function UnCPS(const Function& f) {
}
type_args.push_back(new_ret_type);
return Function(new_params,
CallNode::make(f, args, {}, type_args),
Call(f, args, {}, type_args),
new_ret_type,
new_type_params,
f->attrs);
......
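A minimal sketch (not part of this commit) of the pattern-matching constructors the CPS conversion above now uses; the scrutinee expression data is hypothetical and assumed to have an ADT type:

// Hedged sketch: a single-clause, complete match that binds the value to v.
Expr IdentityMatch(const Expr& data) {
  Var v("v", Type());
  Clause clause(PatternVar(v), v);
  return Match(data, {clause}, /*complete=*/true);
}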
......@@ -63,7 +63,7 @@ class GNF : public ExprMutator {
}
static Expr WrapRec(const Var& var, const Expr& val) {
return UseVar(var, val) ? LetNode::make(var, val, var) : val;
return UseVar(var, val) ? Let(var, val, var) : val;
}
Expr VisitExpr_(const LetNode* ln) override {
......
......@@ -138,7 +138,7 @@ class TransformMemorizer : public ObjectRef {
// 2) Insert layout transform on the transformed src.
CHECK(new_src_layout.defined() && dst_layout.defined())
<< "Cannot insert layout transform because there are undefined layouts";
CHECK(BijectiveLayoutNode::make(new_src_layout, dst_layout).defined())
CHECK(tir::BijectiveLayout(new_src_layout, dst_layout).defined())
<< "Cannot insert layout transform because there are inconvertible layouts: "
<< new_src_layout << " v.s. " << dst_layout;
return MakeLayoutTransform(input_expr, new_src_layout.name(), dst_layout.name());
......@@ -258,7 +258,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
Expr tmp = push_back_one_arg(x);
fields.push_back(tmp);
}
normal_new_args.push_back(TupleNode::make(fields));
normal_new_args.push_back(Tuple(fields));
} else {
Expr tmp = push_back_one_arg(new_arg);
normal_new_args.push_back(tmp);
......@@ -325,7 +325,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
transformed_tuple_arg.push_back(memorizer.Transform(arg_item, new_in[pt], new_in2[pt]));
pt++;
}
transformed_args.push_back(TupleNode::make(transformed_tuple_arg));
transformed_args.push_back(Tuple(transformed_tuple_arg));
} else {
transformed_args.push_back(memorizer.Transform(arg, new_in[pt], new_in2[pt]));
pt++;
......@@ -336,21 +336,21 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
// state[node] = (old_out, new_out)
// (handle tuple output)
if (ref_call->checked_type()->IsInstance<TupleTypeNode>()) {
Expr tuple_output = CallNode::make(new_call->op, transformed_args, new_call->attrs);
Expr tuple_output = Call(new_call->op, transformed_args, new_call->attrs);
Array<Expr> fields;
for (size_t i = 0; i < new_out.size(); ++i) {
auto rnode = make_object<LayoutAlternatedExprNode<TransformMemorizerT>>();
rnode->value = TupleGetItemNode::make(tuple_output, i);
rnode->value = TupleGetItem(tuple_output, i);
rnode->old_layout = old_out[i];
rnode->new_layout = new_out[i];
rnode->memorizer = memorizer;
fields.push_back(Expr(rnode));
}
return TupleNode::make(fields);
return Tuple(fields);
} else {
auto rnode = make_object<LayoutAlternatedExprNode<TransformMemorizerT>>();
CHECK_EQ(new_out.size(), 1);
rnode->value = CallNode::make(new_call->op, transformed_args, new_call->attrs);
rnode->value = Call(new_call->op, transformed_args, new_call->attrs);
rnode->old_layout = old_out[0];
rnode->new_layout = new_out[0];
rnode->memorizer = memorizer;
......
......@@ -350,20 +350,18 @@ Array<PrimExpr> BijectiveLayout::BackwardShape(const Array<PrimExpr>& shape) con
self->src_layout->axes, self->backward_rule);
}
BijectiveLayout BijectiveLayoutNode::make(const Layout& src_layout,
const Layout& dst_layout) {
BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) {
auto n = make_object<BijectiveLayoutNode>();
n->src_layout = src_layout;
n->dst_layout = dst_layout;
n->src_layout = std::move(src_layout);
n->dst_layout = std::move(dst_layout);
if (!GetStoreRule(&n->forward_rule, n->src_layout, n->dst_layout)) {
// not convertible
return BijectiveLayout();
// To be consistent with previous behavior, a nullptr layout is created
// when the arguments are not convertible.
if (GetStoreRule(&n->forward_rule, n->src_layout, n->dst_layout)) {
CHECK(GetStoreRule(&n->backward_rule, n->dst_layout, n->src_layout));
data_ = std::move(n);
}
CHECK(GetStoreRule(&n->backward_rule, n->dst_layout, n->src_layout));
return BijectiveLayout(n);
}
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......@@ -398,7 +396,9 @@ TVM_REGISTER_GLOBAL("tir.LayoutGetItem")
});
TVM_REGISTER_GLOBAL("tir.BijectiveLayout")
.set_body_typed(BijectiveLayoutNode::make);
.set_body_typed([](Layout src_layout, Layout dst_layout) -> BijectiveLayout {
return BijectiveLayout(src_layout, dst_layout);
});
TVM_REGISTER_GLOBAL("tir.BijectiveLayoutForwardIndex")
.set_body_method(&BijectiveLayout::ForwardIndex);
......
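A minimal sketch (not part of this commit) of the new BijectiveLayout constructor: as the comment above notes, it keeps the old make semantics by leaving the object undefined when the two layouts are not convertible, so callers still guard with defined(). The string-based Layout constructor is assumed:

// Hedged sketch: construct a converter and check convertibility before use.
tir::BijectiveLayout converter(tir::Layout("NCHW"), tir::Layout("NCHW16c"));
if (converter.defined()) {
  // Safe to call converter.ForwardShape(...) / converter.BackwardShape(...).
}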
......@@ -76,12 +76,12 @@ TVM_REGISTER_GLOBAL("relay.backend.lower_call")
TEST(Relay, BuildModule) {
auto tensor_type = relay::TensorType({2, 3}, DataType::Float(32));
auto a = relay::VarNode::make("a", tensor_type);
auto b = relay::VarNode::make("b", tensor_type);
auto a = relay::Var("a", tensor_type);
auto b = relay::Var("b", tensor_type);
auto add_op = relay::Op::Get("add");
auto x = relay::CallNode::make(add_op, {a, b}, tvm::Attrs(), {});
auto c = relay::VarNode::make("c", tensor_type);
auto y = relay::CallNode::make(add_op, {x, c}, tvm::Attrs(), {});
auto x = relay::Call(add_op, {a, b}, tvm::Attrs(), {});
auto c = relay::Var("c", tensor_type);
auto y = relay::Call(add_op, {x, c}, tvm::Attrs(), {});
auto func = relay::Function(relay::FreeVars(y), y, relay::Type(), {});
auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
......
......@@ -27,11 +27,11 @@
TEST(Relay, SelfReference) {
using namespace tvm;
auto tensor_type = relay::TensorType({}, DataType::Bool());
auto x = relay::VarNode::make("x", relay::Type());
auto x = relay::Var("x", relay::Type());
auto f = relay::Function(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});
CHECK(f->IsInstance<BaseFuncNode>());
auto y = relay::VarNode::make("y", tensor_type);
auto call = relay::CallNode::make(f, Array<relay::Expr>{ y });
auto y = relay::Var("y", tensor_type);
auto call = relay::Call(f, Array<relay::Expr>{ y });
auto fx = relay::Function(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
auto mod = IRModule::FromExpr(fx);
mod = relay::transform::InferType()(mod);
......
......@@ -41,17 +41,17 @@ TEST(Relay, Sequential) {
tvm::runtime::NDArray::Empty({1, 2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
// Create a function for optimization.
auto c = relay::ConstantNode::make(c_data);
auto a = relay::VarNode::make("a", tensor_type);
auto x = relay::VarNode::make("x", tensor_type);
auto c = relay::Constant(c_data);
auto a = relay::Var("a", tensor_type);
auto x = relay::Var("x", tensor_type);
auto add_op = relay::Op::Get("add");
auto y = relay::CallNode::make(add_op, {c, c});
y = relay::CallNode::make(add_op, {x, y});
auto z = relay::CallNode::make(add_op, {y, c});
auto z1 = relay::CallNode::make(add_op, {y, c});
auto z2 = relay::CallNode::make(add_op, {z, z1});
auto y = relay::Call(add_op, {c, c});
y = relay::Call(add_op, {x, y});
auto z = relay::Call(add_op, {y, c});
auto z1 = relay::Call(add_op, {y, c});
auto z2 = relay::Call(add_op, {z, z1});
// Let expression and variable a should be dead-code eliminated.
auto z3 = relay::LetNode::make(a, c, z2);
auto z3 = relay::Let(a, c, z2);
relay::Function func =
relay::Function(relay::FreeVars(z3), z3, relay::Type(), {});
......@@ -89,12 +89,12 @@ TEST(Relay, Sequential) {
CHECK(f.defined());
// Expected function
auto c1 = relay::ConstantNode::make(c_data);
auto x1 = relay::VarNode::make("x", tensor_type);
auto y1 = relay::CallNode::make(add_op, {c1, c1});
y1 = relay::CallNode::make(add_op, {x1, y1});
auto zz = relay::CallNode::make(add_op, {y1, c1});
zz = relay::CallNode::make(add_op, {zz, zz});
auto c1 = relay::Constant(c_data);
auto x1 = relay::Var("x", tensor_type);
auto y1 = relay::Call(add_op, {c1, c1});
y1 = relay::Call(add_op, {x1, y1});
auto zz = relay::Call(add_op, {y1, c1});
zz = relay::Call(add_op, {zz, zz});
relay::Function expected_func =
relay::Function(relay::FreeVars(zz), zz, relay::Type(), {});
......
......@@ -1258,7 +1258,7 @@ inline Tensor layout_transform(const Tensor& src,
CHECK(src_layout_struct.defined() && dst_layout_struct.defined())
<< "cannot convert from/to undefined layout";
auto layout_converter = BijectiveLayoutNode::make(src_layout_struct, dst_layout_struct);
auto layout_converter = tir::BijectiveLayout(src_layout_struct, dst_layout_struct);
CHECK(layout_converter.defined())
<< "cannot convert from " << src_layout << " to " << dst_layout;
......