Commit 51785062 by James Gilles Committed by Tianqi Chen

[REFACTOR] Use more TypedPackedFuncs (#2981)

* Add `set_body_simple` to Registry, refactor a lot of code to use it

* Add more types to Relay PackedFuncs

* Add Registry::set_body_method to easily make Node methods into
PackedFuncs

* Add set_body_method, set_body_node_method; start typing api_lang

* Add some docs, remove unused script

* Fix mysterious linter problem

* Touch up api_ir.cc

* Fix some issues with TOPI argument counts

* Revert changes to topi.cc to avoid problems with optional arguments

* A little more cleanup

* Type more of the api _ functions

* Whitespace

* Finalize names and docs for new registry helpers

* Update docs
parent 57f47a17
......@@ -83,6 +83,169 @@ class Registry {
Registry& set_body_typed(FLambda f) {
return set_body(TypedPackedFunc<FType>(f).packed());
}
/*!
* \brief set the body of the function to the given function pointer.
* Note that this doesn't work with lambdas, you need to
* explicitly give a type for those.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
*
* int multiply(int x, int y) {
* return x * y;
* }
*
* TVM_REGISTER_API("multiply")
* .set_body_typed(multiply); // will have type int(int, int)
*
* \endcode
*
* \param f The function to forward to.
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename R, typename ...Args>
Registry& set_body_typed(R (*f)(Args...)) {
return set_body(TypedPackedFunc<R(Args...)>(f));
}
/*!
* \brief set the body of the function to be the passed method pointer.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
*
* // node subclass:
* struct Example {
* int doThing(int x);
* }
* TVM_REGISTER_API("Example_doThing")
* .set_body_method(&Example::doThing); // will have type int(Example, int)
*
* \endcode
*
* \param f the method pointer to forward to.
* \tparam T the type containing the method (inferred).
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename T, typename R, typename ...Args>
Registry& set_body_method(R (T::*f)(Args...)) {
return set_body_typed<R(T, Args...)>([f](T target, Args... params) -> R {
// call method pointer
return (target.*f)(params...);
});
}
/*!
* \brief set the body of the function to be the passed method pointer.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
*
* // node subclass:
* struct Example {
* int doThing(int x);
* }
* TVM_REGISTER_API("Example_doThing")
* .set_body_method(&Example::doThing); // will have type int(Example, int)
*
* \endcode
*
* \param f the method pointer to forward to.
* \tparam T the type containing the method (inferred).
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename T, typename R, typename ...Args>
Registry& set_body_method(R (T::*f)(Args...) const) {
return set_body_typed<R(T, Args...)>([f](const T target, Args... params) -> R {
// call method pointer
return (target.*f)(params...);
});
}
/*!
* \brief set the body of the function to be the passed method pointer.
* Used when calling a method on a Node subclass through a NodeRef subclass.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
*
* // node subclass:
* struct ExampleNode: BaseNode {
* int doThing(int x);
* }
*
* // noderef subclass
* struct Example;
*
* TVM_REGISTER_API("Example_doThing")
* .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
*
* // note that just doing:
* // .set_body_method(&ExampleNode::doThing);
* // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
*
* \endcode
*
* \param f the method pointer to forward to.
* \tparam TNodeRef the node reference type to call the method on
* \tparam TNode the node type containing the method (inferred).
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename TNodeRef, typename TNode, typename R, typename ...Args,
typename = typename std::enable_if<std::is_base_of<NodeRef, TNodeRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...)) {
return set_body_typed<R(TNodeRef, Args...)>([f](TNodeRef ref, Args... params) {
TNode* target = ref.operator->();
// call method pointer
return (target->*f)(params...);
});
}
/*!
* \brief set the body of the function to be the passed method pointer.
* Used when calling a method on a Node subclass through a NodeRef subclass.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
*
* // node subclass:
* struct ExampleNode: BaseNode {
* int doThing(int x);
* }
*
* // noderef subclass
* struct Example;
*
* TVM_REGISTER_API("Example_doThing")
* .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
*
* // note that just doing:
* // .set_body_method(&ExampleNode::doThing);
* // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
*
* \endcode
*
* \param f the method pointer to forward to.
* \tparam TNodeRef the node reference type to call the method on
* \tparam TNode the node type containing the method (inferred).
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename TNodeRef, typename TNode, typename R, typename ...Args,
typename = typename std::enable_if<std::is_base_of<NodeRef, TNodeRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...) const) {
return set_body_typed<R(TNodeRef, Args...)>([f](TNodeRef ref, Args... params) {
const TNode* target = ref.operator->();
// call method pointer
return (target->*f)(params...);
});
}
/*!
* \brief Register a function with given name
* \param name The name of the function.
......
......@@ -360,9 +360,7 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.GraphKeyGetGraph")
});
TVM_REGISTER_GLOBAL("nnvm.compiler.MakeGraphKey")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
*rv = GraphKeyNode::make(args[0], args[1], args[2]);
});
.set_body_typed(GraphKeyNode::make);
// This can be used to extract workloads from nnvm compiler
TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs")
......
......@@ -235,8 +235,6 @@ std::string GraphDeepCompare(const Graph& a,
}
TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
*rv = GraphDeepCompare(args[0], args[1], args[2]);
});
.set_body_typed(GraphDeepCompare);
} // namespace compiler
} // namespace nnvm
......@@ -31,73 +31,51 @@ namespace tvm {
namespace arith {
TVM_REGISTER_API("arith.intset_single_point")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::single_point(args[0]);
});
.set_body_typed(IntSet::single_point);
TVM_REGISTER_API("arith.intset_vector")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::vector(args[0]);
});
.set_body_typed(IntSet::vector);
TVM_REGISTER_API("arith.intset_interval")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::interval(args[0], args[1]);
});
.set_body_typed(IntSet::interval);
TVM_REGISTER_API("arith.DetectLinearEquation")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DetectLinearEquation(args[0], args[1]);
});
.set_body_typed(DetectLinearEquation);
TVM_REGISTER_API("arith.DetectClipBound")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DetectClipBound(args[0], args[1]);
});
.set_body_typed(DetectClipBound);
TVM_REGISTER_API("arith.DeduceBound")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0], args[1],
args[2].operator Map<Var, IntSet>(),
args[3].operator Map<Var, IntSet>());
});
.set_body_typed<IntSet(Expr, Expr, Map<Var, IntSet>, Map<Var, IntSet>)>([](
Expr v, Expr cond,
const Map<Var, IntSet> hint_map,
const Map<Var, IntSet> relax_map
) {
return DeduceBound(v, cond, hint_map, relax_map);
});
TVM_REGISTER_API("arith.DomainTouched")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DomainTouched(args[0], args[1], args[2], args[3]);
});
.set_body_typed(DomainTouched);
TVM_REGISTER_API("_IntervalSetGetMin")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().min();
});
.set_body_method(&IntSet::min);
TVM_REGISTER_API("_IntervalSetGetMax")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().max();
});
.set_body_method(&IntSet::max);
TVM_REGISTER_API("_IntSetIsNothing")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().is_nothing();
});
.set_body_method(&IntSet::is_nothing);
TVM_REGISTER_API("_IntSetIsEverything")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().is_everything();
});
.set_body_method(&IntSet::is_everything);
TVM_REGISTER_API("arith._make_ConstIntBound")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ConstIntBoundNode::make(args[0], args[1]);
});
.set_body_typed(ConstIntBoundNode::make);
TVM_REGISTER_API("arith._make_ModularSet")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ModularSetNode::make(args[0], args[1]);
});
.set_body_typed(ModularSetNode::make);
TVM_REGISTER_API("arith._CreateAnalyzer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
......
......@@ -50,9 +50,8 @@ TVM_REGISTER_API("_load_json")
.set_body_typed<NodeRef(std::string)>(LoadJSON<NodeRef>);
TVM_REGISTER_API("_TVMSetStream")
.set_body([](TVMArgs args, TVMRetValue *ret) {
TVMSetStream(args[0], args[1], args[2]);
});
.set_body_typed(TVMSetStream);
TVM_REGISTER_API("_save_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
CHECK_EQ(args.size() % 2, 0u);
......
......@@ -41,8 +41,6 @@ TVM_REGISTER_API("codegen._Build")
});
TVM_REGISTER_API("module._PackImportsToC")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = PackImportsToC(args[0], args[1]);
});
.set_body_typed(PackImportsToC);
} // namespace codegen
} // namespace tvm
......@@ -31,54 +31,43 @@ namespace tvm {
namespace ir {
TVM_REGISTER_API("_Var")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Variable::make(args[1], args[0]);
.set_body_typed<VarExpr(std::string, Type)>([](std::string s, Type t) {
return Variable::make(t, s);
});
TVM_REGISTER_API("make.abs")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::abs(args[0]);
});
.set_body_typed(tvm::abs);
TVM_REGISTER_API("make.floor")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::floor(args[0]);
});
.set_body_typed(tvm::floor);
TVM_REGISTER_API("make.ceil")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::ceil(args[0]);
});
.set_body_typed(tvm::ceil);
TVM_REGISTER_API("make.round")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::round(args[0]);
});
.set_body_typed(tvm::round);
TVM_REGISTER_API("make.trunc")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::trunc(args[0]);
});
.set_body_typed(tvm::trunc);
TVM_REGISTER_API("make._cast")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = tvm::cast(args[0], args[1]);
});
.set_body_typed(tvm::cast);
TVM_REGISTER_API("make._range_by_min_extent")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Range::make_by_min_extent(args[0], args[1]);
});
.set_body_typed(Range::make_by_min_extent);
TVM_REGISTER_API("make.For")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = For::make(args[0],
args[1],
args[2],
static_cast<ForType>(args[3].operator int()),
static_cast<HalideIR::DeviceAPI>(args[4].operator int()),
args[5]);
});
.set_body_typed<Stmt(VarExpr, Expr, Expr, int, int, Stmt)>([](
VarExpr loop_var, Expr min, Expr extent,
int for_type, int device_api, Stmt body
) {
return For::make(loop_var,
min,
extent,
static_cast<ForType>(for_type),
static_cast<HalideIR::DeviceAPI>(device_api),
body);
});
TVM_REGISTER_API("make.Load")
.set_body([](TVMArgs args, TVMRetValue *ret) {
......@@ -101,114 +90,87 @@ TVM_REGISTER_API("make.Store")
});
TVM_REGISTER_API("make.Realize")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Realize::make(args[0],
args[1],
args[2],
args[3],
args[4],
args[5]);
});
.set_body_typed(Realize::make);
TVM_REGISTER_API("make.Call")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Call::make(args[0],
args[1],
args[2],
static_cast<Call::CallType>(args[3].operator int()),
args[4],
args[5]);
});
.set_body_typed<Expr(Type, std::string, Array<Expr>, int, FunctionRef, int)>([](
Type type, std::string name,
Array<Expr> args, int call_type,
FunctionRef func, int value_index
) {
return Call::make(type,
name,
args,
static_cast<Call::CallType>(call_type),
func,
value_index);
});
TVM_REGISTER_API("make.CommReducer")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CommReducerNode::make(args[0],
args[1],
args[2],
args[3]);
});
.set_body_typed(CommReducerNode::make);
// make from two arguments
#define REGISTER_MAKE1(Node) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0]); \
}) \
#define REGISTER_MAKE2(Node) \
#define REGISTER_MAKE(Node) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0], args[1]); \
}) \
#define REGISTER_MAKE3(Node) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0], args[1], args[2]); \
}) \
#define REGISTER_MAKE4(Node) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0], args[1], args[2], args[3]); \
}) \
#define REGISTER_MAKE5(Node) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \
}) \
REGISTER_MAKE5(Reduce);
REGISTER_MAKE4(AttrStmt);
REGISTER_MAKE2(IntImm);
REGISTER_MAKE2(UIntImm);
REGISTER_MAKE2(FloatImm);
REGISTER_MAKE1(StringImm);
REGISTER_MAKE2(Add);
REGISTER_MAKE2(Sub);
REGISTER_MAKE2(Mul);
REGISTER_MAKE2(Div);
REGISTER_MAKE2(Mod);
REGISTER_MAKE2(Min);
REGISTER_MAKE2(Max);
REGISTER_MAKE2(EQ);
REGISTER_MAKE2(NE);
REGISTER_MAKE2(LT);
REGISTER_MAKE2(LE);
REGISTER_MAKE2(GT);
REGISTER_MAKE2(GE);
REGISTER_MAKE2(And);
REGISTER_MAKE2(Or);
REGISTER_MAKE1(Not);
REGISTER_MAKE3(Select);
REGISTER_MAKE3(Ramp);
REGISTER_MAKE2(Cast);
REGISTER_MAKE2(Broadcast);
REGISTER_MAKE2(Shuffle);
REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE3(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE5(Allocate);
REGISTER_MAKE4(Provide);
REGISTER_MAKE4(Prefetch);
REGISTER_MAKE1(Free);
REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate);
.set_body_typed(Node::make); \
REGISTER_MAKE(Reduce);
REGISTER_MAKE(AttrStmt);
REGISTER_MAKE(IntImm);
REGISTER_MAKE(UIntImm);
REGISTER_MAKE(FloatImm);
REGISTER_MAKE(StringImm);
REGISTER_MAKE(Add);
REGISTER_MAKE(Sub);
REGISTER_MAKE(Mul);
REGISTER_MAKE(Div);
REGISTER_MAKE(Mod);
REGISTER_MAKE(Min);
REGISTER_MAKE(Max);
REGISTER_MAKE(EQ);
REGISTER_MAKE(NE);
REGISTER_MAKE(LT);
REGISTER_MAKE(LE);
REGISTER_MAKE(GT);
REGISTER_MAKE(GE);
REGISTER_MAKE(And);
REGISTER_MAKE(Or);
REGISTER_MAKE(Not);
REGISTER_MAKE(Select);
REGISTER_MAKE(Ramp);
REGISTER_MAKE(Cast);
REGISTER_MAKE(Broadcast);
REGISTER_MAKE(Shuffle);
REGISTER_MAKE(Let);
REGISTER_MAKE(LetStmt);
REGISTER_MAKE(AssertStmt);
REGISTER_MAKE(ProducerConsumer);
REGISTER_MAKE(Provide);
REGISTER_MAKE(Prefetch);
REGISTER_MAKE(Free);
REGISTER_MAKE(IfThenElse);
REGISTER_MAKE(Evaluate);
// overloaded, needs special handling
TVM_REGISTER_API("make.Block")
.set_body_typed(static_cast<Stmt (*)(Stmt, Stmt)>(Block::make));
// has default args
TVM_REGISTER_API("make.Allocate")
.set_body_typed<Stmt(VarExpr, Type, Array<Expr>, Expr, Stmt)>([](
VarExpr buffer_var, Type type, Array<Expr> extents, Expr condition, Stmt body
){
return Allocate::make(buffer_var, type, extents, condition, body);
});
// operator overloading, smarter than make
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
Expr a = args[0], b = args[1]; \
*ret = (Func(a, b)); \
.set_body_typed<Expr(Expr, Expr)>([](Expr a, Expr b) { \
return (Func(a, b)); \
})
#define REGISTER_MAKE_BIT_OP(Node, Func) \
......
......@@ -119,68 +119,43 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
});
// make from two arguments
#define REGISTER_PASS1(PassName) \
#define REGISTER_PASS(PassName) \
TVM_REGISTER_API("ir_pass."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0]); \
}) \
#define REGISTER_PASS2(PassName) \
TVM_REGISTER_API("ir_pass."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0], args[1]); \
}) \
#define REGISTER_PASS3(PassName) \
TVM_REGISTER_API("ir_pass."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0], args[1], args[2]); \
}) \
#define REGISTER_PASS4(PassName) \
TVM_REGISTER_API("ir_pass."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0], args[1], args[2], args[3]); \
}) \
#define REGISTER_PASS5(PassName) \
TVM_REGISTER_API("ir_pass."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0], args[1], args[2], args[3], args[4]); \
}) \
REGISTER_PASS1(ConvertSSA);
REGISTER_PASS1(VerifySSA);
REGISTER_PASS1(RewriteUnsafeSelect);
REGISTER_PASS4(Inline);
REGISTER_PASS4(IRTransform);
REGISTER_PASS1(VectorizeLoop);
REGISTER_PASS5(UnrollLoop);
REGISTER_PASS3(InjectCopyIntrin);
REGISTER_PASS2(ThreadSync);
REGISTER_PASS5(MakeAPI);
REGISTER_PASS2(BindDeviceType);
REGISTER_PASS1(SplitHostDevice);
REGISTER_PASS1(StorageRewrite);
REGISTER_PASS1(CoProcSync);
REGISTER_PASS1(LowerStorageAccessInfo);
REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(InjectPrefetch);
REGISTER_PASS2(InjectDoubleBuffer);
REGISTER_PASS2(LoopPartition);
REGISTER_PASS1(RemoveNoOp);
REGISTER_PASS2(SplitPipeline);
REGISTER_PASS2(LiftAttrScope);
REGISTER_PASS1(NarrowChannelAccess);
REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS2(LowerWarpMemory);
REGISTER_PASS2(RemapThreadAxis);
REGISTER_PASS2(LowerIntrin);
REGISTER_PASS1(LowerTVMBuiltin);
REGISTER_PASS1(CombineContextCall);
REGISTER_PASS2(VerifyMemory);
REGISTER_PASS2(VerifyGPUCode);
REGISTER_PASS1(DecorateDeviceScope);
REGISTER_PASS1(InstrumentBoundCheckers);
.set_body_typed(PassName); \
REGISTER_PASS(ConvertSSA);
REGISTER_PASS(VerifySSA);
REGISTER_PASS(RewriteUnsafeSelect);
REGISTER_PASS(Inline);
REGISTER_PASS(IRTransform);
REGISTER_PASS(VectorizeLoop);
REGISTER_PASS(UnrollLoop);
REGISTER_PASS(InjectCopyIntrin);
REGISTER_PASS(ThreadSync);
REGISTER_PASS(MakeAPI);
REGISTER_PASS(BindDeviceType);
REGISTER_PASS(SplitHostDevice);
REGISTER_PASS(StorageRewrite);
REGISTER_PASS(CoProcSync);
REGISTER_PASS(LowerStorageAccessInfo);
REGISTER_PASS(InjectVirtualThread);
REGISTER_PASS(InjectPrefetch);
REGISTER_PASS(InjectDoubleBuffer);
REGISTER_PASS(LoopPartition);
REGISTER_PASS(RemoveNoOp);
REGISTER_PASS(SplitPipeline);
REGISTER_PASS(LiftAttrScope);
REGISTER_PASS(NarrowChannelAccess);
REGISTER_PASS(LowerThreadAllreduce);
REGISTER_PASS(LowerWarpMemory);
REGISTER_PASS(RemapThreadAxis);
REGISTER_PASS(LowerIntrin);
REGISTER_PASS(LowerTVMBuiltin);
REGISTER_PASS(CombineContextCall);
REGISTER_PASS(VerifyMemory);
REGISTER_PASS(VerifyGPUCode);
REGISTER_PASS(DecorateDeviceScope);
REGISTER_PASS(InstrumentBoundCheckers);
} // namespace ir
} // namespace tvm
......@@ -33,15 +33,11 @@ namespace tvm {
namespace schedule {
TVM_REGISTER_API("schedule.AutoInlineElemWise")
.set_body([](TVMArgs args, TVMRetValue* ret) {
AutoInlineElemWise(args[0]);
});
.set_body_typed(AutoInlineElemWise);
TVM_REGISTER_API("schedule.AutoInlineInjective")
.set_body([](TVMArgs args, TVMRetValue* ret) {
AutoInlineInjective(args[0]);
});
.set_body_typed(AutoInlineInjective);
TVM_REGISTER_API("schedule.ScheduleOps")
.set_body([](TVMArgs args, TVMRetValue* ret) {
......@@ -51,25 +47,17 @@ TVM_REGISTER_API("schedule.ScheduleOps")
*ret = ScheduleOps(args[0], args[1], args[2]);
});
#define REGISTER_SCHEDULE_PASS1(PassName) \
TVM_REGISTER_API("schedule."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0]); \
}) \
#define REGISTER_SCHEDULE_PASS2(PassName) \
#define REGISTER_SCHEDULE_PASS(PassName) \
TVM_REGISTER_API("schedule."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0], args[1]); \
}) \
.set_body_typed(PassName); \
REGISTER_SCHEDULE_PASS1(InferBound);
REGISTER_SCHEDULE_PASS1(CreateReadGraph);
REGISTER_SCHEDULE_PASS2(PostDFSOrder);
REGISTER_SCHEDULE_PASS1(CreateAttachPath);
REGISTER_SCHEDULE_PASS1(ScanGetBody);
REGISTER_SCHEDULE_PASS1(ScanFixPointAnalysis);
REGISTER_SCHEDULE_PASS(InferBound);
REGISTER_SCHEDULE_PASS(CreateReadGraph);
REGISTER_SCHEDULE_PASS(PostDFSOrder);
REGISTER_SCHEDULE_PASS(CreateAttachPath);
REGISTER_SCHEDULE_PASS(ScanGetBody);
REGISTER_SCHEDULE_PASS(ScanFixPointAnalysis);
} // namespace schedule
} // namespace tvm
......@@ -263,8 +263,6 @@ runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
}
TVM_REGISTER_API("codegen.build_opencl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildOpenCL(args[0]);
});
.set_body_typed(BuildOpenCL);
} // namespace codegen
} // namespace tvm
......@@ -302,9 +302,7 @@ runtime::Module BuildOpenGL(Array<LoweredFunc> funcs) {
}
TVM_REGISTER_API("codegen.build_opengl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildOpenGL(args[0]);
});
.set_body_typed(BuildOpenGL);
} // namespace codegen
} // namespace tvm
......@@ -164,9 +164,7 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) {
}
TVM_REGISTER_API("codegen.build_sdaccel")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildSDAccel(args[0], args[1]);
});
.set_body_typed(BuildSDAccel);
} // namespace codegen
} // namespace tvm
......@@ -265,9 +265,7 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
}
TVM_REGISTER_API("codegen.build_rocm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildAMDGPU(args[0], args[1]);
});
.set_body_typed(BuildAMDGPU);
} // namespace codegen
} // namespace tvm
......
......@@ -243,9 +243,7 @@ runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
}
TVM_REGISTER_API("codegen.build_nvptx")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildNVPTX(args[0], args[1]);
});
.set_body_typed(BuildNVPTX);
} // namespace codegen
} // namespace tvm
......
......@@ -155,8 +155,6 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
}
TVM_REGISTER_API("codegen.build_cuda")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildCUDA(args[0]);
});
.set_body_typed(BuildCUDA);
} // namespace codegen
} // namespace tvm
......@@ -188,8 +188,6 @@ runtime::Module DeviceSourceModuleCreate(
}
TVM_REGISTER_GLOBAL("module.source_module_create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = SourceModuleCreate(args[0], args[1]);
});
.set_body_typed(SourceModuleCreate);
} // namespace codegen
} // namespace tvm
......@@ -103,9 +103,7 @@ runtime::Module BuildSPIRV(Array<LoweredFunc> funcs) {
}
TVM_REGISTER_API("codegen.build_vulkan")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildSPIRV(args[0]);
});
.set_body_typed(BuildSPIRV);
} // namespace codegen
} // namespace tvm
......@@ -522,8 +522,6 @@ runtime::Module BuildStackVM(const Array<LoweredFunc>& funcs) {
}
TVM_REGISTER_API("codegen.build_stackvm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildStackVM(args[0]);
});
.set_body_typed(BuildStackVM);
} // namespace codegen
} // namespace tvm
......@@ -51,9 +51,7 @@ Closure ClosureNode::make(tvm::Map<Var, Value> env, Function func) {
}
TVM_REGISTER_API("relay._make.Closure")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ClosureNode::make(args[0], args[1]);
});
.set_body_typed(ClosureNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ClosureNode>([](const ClosureNode* node, tvm::IRPrinter* p) {
......@@ -67,9 +65,7 @@ TupleValue TupleValueNode::make(tvm::Array<Value> value) {
}
TVM_REGISTER_API("relay._make.TupleValue")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TupleValueNode::make(args[0]);
});
.set_body_typed(TupleValueNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleValueNode>([](const TupleValueNode* node, tvm::IRPrinter* p) {
......@@ -90,10 +86,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
});
TVM_REGISTER_API("relay._make.TensorValue")
.set_body([](TVMArgs args, TVMRetValue* ret) {
runtime::NDArray data = args[0];
*ret = TensorValueNode::make(data);
});
.set_body_typed(TensorValueNode::make);
RefValue RefValueNode::make(Value value) {
NodePtr<RefValueNode> n = make_node<RefValueNode>();
......@@ -102,9 +95,7 @@ RefValue RefValueNode::make(Value value) {
}
TVM_REGISTER_API("relay._make.RefValue")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = RefValueNode::make(args[0]);
});
.set_body_typed(RefValueNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefValueNode>([](const RefValueNode* node,
......@@ -121,9 +112,7 @@ ConstructorValue ConstructorValueNode::make(Constructor constructor,
}
TVM_REGISTER_API("relay._make.ConstructorValue")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ConstructorValueNode::make(args[0], args[1]);
});
.set_body_typed(ConstructorValueNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstructorValueNode>([](const ConstructorValueNode* node,
......@@ -614,9 +603,7 @@ CreateInterpreter(
}
TVM_REGISTER_API("relay.backend.CreateInterpreter")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = CreateInterpreter(args[0], args[1], args[2]);
});
.set_body_typed(CreateInterpreter);
TVM_REGISTER_NODE_TYPE(ClosureNode);
TVM_REGISTER_NODE_TYPE(TupleValueNode);
......
......@@ -36,9 +36,7 @@ PatternWildcard PatternWildcardNode::make() {
TVM_REGISTER_NODE_TYPE(PatternWildcardNode);
TVM_REGISTER_API("relay._make.PatternWildcard")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = PatternWildcardNode::make();
});
.set_body_typed(PatternWildcardNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternWildcardNode>([](const PatternWildcardNode* node,
......@@ -55,9 +53,7 @@ PatternVar PatternVarNode::make(tvm::relay::Var var) {
TVM_REGISTER_NODE_TYPE(PatternVarNode);
TVM_REGISTER_API("relay._make.PatternVar")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = PatternVarNode::make(args[0]);
});
.set_body_typed(PatternVarNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternVarNode>([](const PatternVarNode* node,
......@@ -76,9 +72,7 @@ PatternConstructor PatternConstructorNode::make(Constructor constructor,
TVM_REGISTER_NODE_TYPE(PatternConstructorNode);
TVM_REGISTER_API("relay._make.PatternConstructor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = PatternConstructorNode::make(args[0], args[1]);
});
.set_body_typed(PatternConstructorNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternConstructorNode>([](const PatternConstructorNode* node,
......@@ -100,9 +94,7 @@ Constructor ConstructorNode::make(std::string name_hint,
TVM_REGISTER_NODE_TYPE(ConstructorNode);
TVM_REGISTER_API("relay._make.Constructor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ConstructorNode::make(args[0], args[1], args[2]);
});
.set_body_typed(ConstructorNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstructorNode>([](const ConstructorNode* node,
......@@ -124,9 +116,7 @@ TypeData TypeDataNode::make(GlobalTypeVar header,
TVM_REGISTER_NODE_TYPE(TypeDataNode);
TVM_REGISTER_API("relay._make.TypeData")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TypeDataNode::make(args[0], args[1], args[2]);
});
.set_body_typed(TypeDataNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeDataNode>([](const TypeDataNode* node,
......@@ -145,9 +135,7 @@ Clause ClauseNode::make(Pattern lhs, Expr rhs) {
TVM_REGISTER_NODE_TYPE(ClauseNode);
TVM_REGISTER_API("relay._make.Clause")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ClauseNode::make(args[0], args[1]);
});
.set_body_typed(ClauseNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ClauseNode>([](const ClauseNode* node,
......@@ -166,9 +154,7 @@ Match MatchNode::make(Expr data, tvm::Array<Clause> clauses) {
TVM_REGISTER_NODE_TYPE(MatchNode);
TVM_REGISTER_API("relay._make.Match")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = MatchNode::make(args[0], args[1]);
});
.set_body_typed(MatchNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<MatchNode>([](const MatchNode* node,
......
......@@ -505,18 +505,18 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
// TODO(@jroesch): move to correct namespace?
TVM_REGISTER_API("relay._make._alpha_equal")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = AlphaEqualHandler(false).Equal(args[0], args[1]);
.set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
return AlphaEqualHandler(false).Equal(a, b);
});
TVM_REGISTER_API("relay._make._type_alpha_equal")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = AlphaEqualHandler(false).TypeEqual(args[0], args[1]);
.set_body_typed<bool(Type, Type)>([](Type a, Type b) {
return AlphaEqualHandler(false).TypeEqual(a, b);
});
TVM_REGISTER_API("relay._make._graph_equal")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = AlphaEqualHandler(true).Equal(args[0], args[1]);
.set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
return AlphaEqualHandler(true).Equal(a, b);
});
} // namespace relay
} // namespace tvm
......@@ -52,9 +52,7 @@ SourceName SourceName::Get(const std::string& name) {
}
TVM_REGISTER_API("relay._make.SourceName")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = SourceName::Get(args[0]);
});
.set_body_typed(SourceName::Get);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SourceNameNode>([](const SourceNameNode* node, tvm::IRPrinter* p) {
......@@ -78,9 +76,7 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) {
TVM_REGISTER_NODE_TYPE(SpanNode);
TVM_REGISTER_API("relay._make.Span")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = SpanNode::make(args[0], args[1], args[2]);
});
.set_body_typed(SpanNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SpanNode>([](const SpanNode* node, tvm::IRPrinter* p) {
......@@ -91,11 +87,9 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(IdNode);
TVM_REGISTER_API("relay._base.set_span")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef node_ref = args[0];
.set_body_typed<void(NodeRef, Span)>([](NodeRef node_ref, Span sp) {
auto rn = node_ref.as_derived<RelayNode>();
CHECK(rn);
Span sp = args[1];
rn->span = sp;
});
......
......@@ -39,9 +39,7 @@ Constant ConstantNode::make(runtime::NDArray data) {
TVM_REGISTER_NODE_TYPE(ConstantNode);
TVM_REGISTER_API("relay._make.Constant")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ConstantNode::make(args[0]);
});
.set_body_typed(ConstantNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstantNode>([](const ConstantNode* node, tvm::IRPrinter* p) {
......@@ -73,9 +71,7 @@ Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
TVM_REGISTER_NODE_TYPE(TupleNode);
TVM_REGISTER_API("relay._make.Tuple")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TupleNode::make(args[0]);
});
.set_body_typed(TupleNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleNode>([](const TupleNode* node, tvm::IRPrinter* p) {
......@@ -99,9 +95,7 @@ Var VarNode::make(std::string name_hint, Type type_annotation) {
TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_API("relay._make.Var")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = VarNode::make(args[0].operator std::string(), args[1]);
});
.set_body_typed(static_cast<Var (*)(std::string, Type)>(VarNode::make));
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<VarNode>([](const VarNode* node, tvm::IRPrinter* p) {
......@@ -122,9 +116,7 @@ GlobalVar GlobalVarNode::make(std::string name_hint) {
TVM_REGISTER_NODE_TYPE(GlobalVarNode);
TVM_REGISTER_API("relay._make.GlobalVar")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = GlobalVarNode::make(args[0]);
});
.set_body_typed(GlobalVarNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const GlobalVarNode* node, tvm::IRPrinter* p) {
......@@ -201,9 +193,7 @@ Function FunctionSetAttr(const Function& func, const std::string& key, const Nod
TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_API("relay._make.Function")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = FunctionNode::make(args[0], args[1], args[2], args[3], args[4]);
});
.set_body_typed(FunctionNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FunctionNode>([](const FunctionNode* node,
......@@ -226,9 +216,7 @@ Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
TVM_REGISTER_NODE_TYPE(CallNode);
TVM_REGISTER_API("relay._make.Call")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = CallNode::make(args[0], args[1], args[2], args[3]);
});
.set_body_typed(CallNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
......@@ -247,9 +235,7 @@ Let LetNode::make(Var var, Expr value, Expr body) {
TVM_REGISTER_NODE_TYPE(LetNode);
TVM_REGISTER_API("relay._make.Let")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = LetNode::make(args[0], args[1], args[2]);
});
.set_body_typed(LetNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<LetNode>([](const LetNode* node, tvm::IRPrinter* p) {
......@@ -267,9 +253,8 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
TVM_REGISTER_NODE_TYPE(IfNode);
TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = IfNode::make(args[0], args[1], args[2]);
});
TVM_REGISTER_API("relay._make.If")
.set_body_typed(IfNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<IfNode>([](const IfNode* node, tvm::IRPrinter* p) {
......@@ -286,9 +271,8 @@ TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
TVM_REGISTER_API("relay._make.TupleGetItem").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TupleGetItemNode::make(args[0], args[1]);
});
TVM_REGISTER_API("relay._make.TupleGetItem")
.set_body_typed(TupleGetItemNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleGetItemNode>([](const TupleGetItemNode* node, tvm::IRPrinter* p) {
......@@ -301,9 +285,8 @@ RefCreate RefCreateNode::make(Expr value) {
return RefCreate(n);
}
TVM_REGISTER_API("relay._make.RefCreate").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = RefCreateNode::make(args[0]);
});
TVM_REGISTER_API("relay._make.RefCreate")
.set_body_typed(RefCreateNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefCreateNode>([](const RefCreateNode* node, tvm::IRPrinter* p) {
......@@ -317,9 +300,7 @@ RefRead RefReadNode::make(Expr ref) {
}
TVM_REGISTER_API("relay._make.RefRead")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = RefReadNode::make(args[0]);
});
.set_body_typed(RefReadNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefReadNode>([](const RefReadNode* node, tvm::IRPrinter* p) {
......@@ -334,9 +315,7 @@ RefWrite RefWriteNode::make(Expr ref, Expr value) {
}
TVM_REGISTER_API("relay._make.RefWrite")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = RefWriteNode::make(args[0], args[1]);
});
.set_body_typed(RefWriteNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefWriteNode>([](const RefWriteNode* node, tvm::IRPrinter* p) {
......@@ -344,9 +323,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
});
TVM_REGISTER_API("relay._expr.TempExprRealize")
.set_body([](TVMArgs args, TVMRetValue* ret) {
TempExpr temp = args[0];
*ret = temp->Realize();
.set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
return temp->Realize();
});
} // namespace relay
......
......@@ -346,9 +346,8 @@ void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) {
}
TVM_REGISTER_API("relay._ir_pass.post_order_visit")
.set_body([](TVMArgs args, TVMRetValue *ret) {
PackedFunc f = args[1];
PostOrderVisit(args[0], [f](const Expr& n) {
.set_body_typed<void(Expr, PackedFunc)>([](Expr expr, PackedFunc f) {
PostOrderVisit(expr, [f](const Expr& n) {
f(n);
});
});
......
......@@ -410,14 +410,14 @@ size_t StructuralHash::operator()(const Expr& expr) const {
}
TVM_REGISTER_API("relay._ir_pass._expr_hash")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = static_cast<int64_t>(RelayHashHandler().Hash(args[0]));
});
.set_body_typed<int64_t(NodeRef)>([](NodeRef ref) {
return static_cast<int64_t>(RelayHashHandler().Hash(ref));
});
TVM_REGISTER_API("relay._ir_pass._type_hash")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = static_cast<int64_t>(RelayHashHandler().TypeHash(args[0]));
});
.set_body_typed<int64_t(Type)>([](Type type) {
return static_cast<int64_t>(RelayHashHandler().TypeHash(type));
});
} // namespace relay
} // namespace tvm
......@@ -181,66 +181,43 @@ Module ModuleNode::FromExpr(
TVM_REGISTER_NODE_TYPE(ModuleNode);
TVM_REGISTER_API("relay._make.Module")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = ModuleNode::make(args[0], args[1]);
});
.set_body_typed(ModuleNode::make);
TVM_REGISTER_API("relay._make.Module_Add")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
mod->Add(args[1], args[2], args[3]);
});
.set_body_method<Module>(&ModuleNode::Add);
TVM_REGISTER_API("relay._module.Module_AddDef")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
mod->AddDef(args[1], args[2]);
});
.set_body_method<Module>(&ModuleNode::AddDef);
TVM_REGISTER_API("relay._module.Module_GetGlobalVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
*ret = mod->GetGlobalVar(args[1]);
});
.set_body_method<Module>(&ModuleNode::GetGlobalVar);
TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVar")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
*ret = mod->GetGlobalTypeVar(args[1]);
});
.set_body_method<Module>(&ModuleNode::GetGlobalTypeVar);
TVM_REGISTER_API("relay._module.Module_Lookup")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
GlobalVar var = args[1];
*ret = mod->Lookup(var);
.set_body_typed<Function(Module, GlobalVar)>([](Module mod, GlobalVar var) {
return mod->Lookup(var);
});
TVM_REGISTER_API("relay._module.Module_Lookup_str")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
std::string var_name = args[1];
*ret = mod->Lookup(var_name);
.set_body_typed<Function(Module, std::string)>([](Module mod, std::string var) {
return mod->Lookup(var);
});
TVM_REGISTER_API("relay._module.Module_LookupDef")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
GlobalTypeVar var = args[1];
*ret = mod->LookupDef(var);
.set_body_typed<TypeData(Module, GlobalTypeVar)>([](Module mod, GlobalTypeVar var) {
return mod->LookupDef(var);
});
TVM_REGISTER_API("relay._module.Module_LookupDef_str")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
std::string var_name = args[1];
*ret = mod->LookupDef(var_name);
.set_body_typed<TypeData(Module, std::string)>([](Module mod, std::string var) {
return mod->LookupDef(var);
});
TVM_REGISTER_API("relay._module.Module_Update")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Module mod = args[0];
mod->Update(args[1]);
.set_body_typed<void(Module, Module)>([](Module mod, Module from) {
mod->Update(from);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
......
......@@ -56,10 +56,7 @@ IndexExpr TensorTypeNode::Size() const {
TVM_REGISTER_NODE_TYPE(TensorTypeNode);
TVM_REGISTER_API("relay._make.TensorType")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Array<IndexExpr> shape = args[0];
*ret = TensorTypeNode::make(shape, args[1]);
});
.set_body_typed(TensorTypeNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TensorTypeNode>([](const TensorTypeNode* node,
......@@ -77,10 +74,8 @@ TypeVar TypeVarNode::make(std::string name, Kind kind) {
TVM_REGISTER_NODE_TYPE(TypeVarNode);
TVM_REGISTER_API("relay._make.TypeVar")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int kind = args[1];
*ret =
TypeVarNode::make(args[0], static_cast<Kind>(kind));
.set_body_typed<TypeVar(std::string, int)>([](std::string name, int kind) {
return TypeVarNode::make(name, static_cast<Kind>(kind));
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
......@@ -100,10 +95,9 @@ GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) {
TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
TVM_REGISTER_API("relay._make.GlobalTypeVar")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int kind = args[1];
*ret = GlobalTypeVarNode::make(args[0], static_cast<Kind>(kind));
});
.set_body_typed<GlobalTypeVar(std::string, int)>([](std::string name, int kind) {
return GlobalTypeVarNode::make(name, static_cast<Kind>(kind));
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const GlobalTypeVarNode *node,
......@@ -122,9 +116,7 @@ TypeCall TypeCallNode::make(Type func, tvm::Array<Type> args) {
TVM_REGISTER_NODE_TYPE(TypeCallNode);
TVM_REGISTER_API("relay._make.TypeCall")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TypeCallNode::make(args[0], args[1]);
});
.set_body_typed(TypeCallNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeCallNode>([](const TypeCallNode* node,
......@@ -142,9 +134,8 @@ IncompleteType IncompleteTypeNode::make(Kind kind) {
TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
TVM_REGISTER_API("relay._make.IncompleteType")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int kind = args[0];
*ret = IncompleteTypeNode::make(static_cast<Kind>(kind));
.set_body_typed<IncompleteType(int)>([](int kind) {
return IncompleteTypeNode::make(static_cast<Kind>(kind));
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
......@@ -169,9 +160,7 @@ FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
TVM_REGISTER_NODE_TYPE(FuncTypeNode);
TVM_REGISTER_API("relay._make.FuncType")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]);
});
.set_body_typed(FuncTypeNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FuncTypeNode>([](const FuncTypeNode* node,
......@@ -196,9 +185,7 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func,
TVM_REGISTER_NODE_TYPE(TypeRelationNode);
TVM_REGISTER_API("relay._make.TypeRelation")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TypeRelationNode::make(args[0], args[1], args[2], args[3]);
});
.set_body_typed(TypeRelationNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeRelationNode>([](const TypeRelationNode* node, tvm::IRPrinter* p) {
......@@ -216,9 +203,7 @@ TupleType TupleTypeNode::make(Array<Type> fields) {
TVM_REGISTER_NODE_TYPE(TupleTypeNode);
TVM_REGISTER_API("relay._make.TupleType")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TupleTypeNode::make(args[0]);
});
.set_body_typed(TupleTypeNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleTypeNode>([](const TupleTypeNode* node,
......@@ -233,9 +218,7 @@ RefType RefTypeNode::make(Type value) {
}
TVM_REGISTER_API("relay._make.RefType")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = RefTypeNode::make(args[0]);
});
.set_body_typed(RefTypeNode::make);
TVM_REGISTER_NODE_TYPE(RefTypeNode);
......
......@@ -64,9 +64,7 @@ Expr MakeDebug(Expr expr, std::string name) {
}
TVM_REGISTER_API("relay.op._make.debug")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeDebug, args, rv);
});
.set_body_typed(MakeDebug);
} // namespace relay
} // namespace tvm
......
......@@ -105,9 +105,7 @@ Expr MakeResize(Expr data,
TVM_REGISTER_API("relay.op.image._make.resize")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 5>(MakeResize, args, rv);
});
.set_body_typed(MakeResize);
RELAY_REGISTER_OP("image.resize")
......
......@@ -170,9 +170,7 @@ Expr MakeConv2D(Expr data,
TVM_REGISTER_API("relay.op.nn._make.conv2d")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 12>(MakeConv2D, args, rv);
});
.set_body_typed(MakeConv2D);
RELAY_REGISTER_OP("nn.conv2d")
......@@ -324,9 +322,7 @@ Expr MakeConv2DTranspose(Expr data,
TVM_REGISTER_API("relay.op.nn._make.conv2d_transpose")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 12>(MakeConv2DTranspose, args, rv);
});
.set_body_typed(MakeConv2DTranspose);
RELAY_REGISTER_OP("nn.conv2d_transpose")
.describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution).
......@@ -465,9 +461,7 @@ Expr MakeConv2DWinograd(Expr data,
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_without_weight_transform")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 13>(MakeConv2DWinograd, args, rv);
});
.set_body_typed(MakeConv2DWinograd);
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
......@@ -530,9 +524,7 @@ Expr MakeConv2DWinogradWeightTransform(Expr weight,
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_weight_transform")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeConv2DWinogradWeightTransform, args, rv);
});
.set_body_typed(MakeConv2DWinogradWeightTransform);
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform")
......@@ -580,9 +572,7 @@ Expr MakeConv2DWinogradNNPACK(Expr data,
}
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_nnpack_without_weight_transform")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 12>(MakeConv2DWinogradNNPACK, args, rv);
});
.set_body_typed(MakeConv2DWinogradNNPACK);
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
.describe(R"code(Compute conv2d with winograd nnpack. Only supports NCHW layout.
......@@ -649,9 +639,7 @@ Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight,
}
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_nnpack_weight_transform")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeConv2DWinogradNNPACKWeightTransform, args, rv);
});
.set_body_typed(MakeConv2DWinogradNNPACKWeightTransform);
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_weight_transform")
.describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK.
......@@ -698,9 +686,7 @@ Expr MakeConv2DNCHWc(Expr data,
}
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_NCHWc")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 12>(MakeConv2DNCHWc, args, rv);
});
.set_body_typed(MakeConv2DNCHWc);
RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
......@@ -750,9 +736,7 @@ Expr MakeDepthwiseConv2DNCHWc(Expr data,
}
TVM_REGISTER_API("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 12>(MakeDepthwiseConv2DNCHWc, args, rv);
});
.set_body_typed(MakeDepthwiseConv2DNCHWc);
RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
......@@ -910,9 +894,7 @@ Expr MakeDeformableConv2D(Expr data,
}
TVM_REGISTER_API("relay.op.nn._make.deformable_conv2d")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 14>(MakeDeformableConv2D, args, rv);
});
.set_body_typed(MakeDeformableConv2D);
} // namespace relay
......
......@@ -78,9 +78,7 @@ Expr MakeBiasAdd(Expr data,
TVM_REGISTER_API("relay.op.nn._make.bias_add")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeBiasAdd, args, rv);
});
.set_body_typed(MakeBiasAdd);
RELAY_REGISTER_OP("nn.bias_add")
......@@ -145,9 +143,7 @@ Expr MakeDense(Expr data,
TVM_REGISTER_API("relay.op.nn._make.dense")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeDense, args, rv);
});
.set_body_typed(MakeDense);
RELAY_REGISTER_OP("nn.dense")
......@@ -179,9 +175,7 @@ Expr MakeLeakyRelu(Expr data,
TVM_REGISTER_API("relay.op.nn._make.leaky_relu")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeLeakyRelu, args, rv);
});
.set_body_typed(MakeLeakyRelu);
RELAY_REGISTER_OP("nn.leaky_relu")
......@@ -244,9 +238,7 @@ Expr MakePRelu(Expr data,
TVM_REGISTER_API("relay.op.nn._make.prelu")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakePRelu, args, rv);
});
.set_body_typed(MakePRelu);
RELAY_REGISTER_OP("nn.prelu")
......@@ -276,17 +268,14 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
TVM_REGISTER_API("relay.op.nn._make.softmax")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
auto make_func = [](Expr data, int axis) {
auto attrs = make_node<SoftmaxAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.softmax");
return CallNode::make(op, {data}, Attrs(attrs), {});
};
runtime::detail::unpack_call<Expr, 2>(make_func, args, rv);
.set_body_typed<Call(Expr, int)>([](Expr data, int axis) {
auto attrs = make_node<SoftmaxAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.softmax");
return CallNode::make(op, {data}, Attrs(attrs), {});
});
RELAY_REGISTER_OP("nn.softmax")
.describe(R"code(Softmax layer.
......@@ -314,15 +303,11 @@ RELAY_REGISTER_OP("nn.softmax")
// relay.nn.log_softmax
TVM_REGISTER_API("relay.op.nn._make.log_softmax")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
auto make_func = [](Expr data, int axis) {
auto attrs = make_node<SoftmaxAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.log_softmax");
return CallNode::make(op, {data}, Attrs(attrs), {});
};
runtime::detail::unpack_call<Expr, 2>(make_func, args, rv);
.set_body_typed<Call(Expr, int)>([](Expr data, int axis) {
auto attrs = make_node<SoftmaxAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.log_softmax");
return CallNode::make(op, {data}, Attrs(attrs), {});
});
RELAY_REGISTER_OP("nn.log_softmax")
......@@ -382,9 +367,7 @@ Expr MakeBatchFlatten(Expr data) {
TVM_REGISTER_API("relay.op.nn._make.batch_flatten")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 1>(MakeBatchFlatten, args, rv);
});
.set_body_typed(MakeBatchFlatten);
RELAY_REGISTER_OP("nn.batch_flatten")
......@@ -424,7 +407,7 @@ Example::
// relu
TVM_REGISTER_API("relay.op.nn._make.relu")
.set_body_typed<Expr(Expr)>([](Expr data) {
.set_body_typed<Call(Expr)>([](Expr data) {
static const Op& op = Op::Get("nn.relu");
return CallNode::make(op, {data}, Attrs(), {});
});
......@@ -469,9 +452,7 @@ Expr MakeLRN(Expr data,
}
TVM_REGISTER_API("relay.op.nn._make.lrn")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 6>(MakeLRN, args, rv);
});
.set_body_typed(MakeLRN);
RELAY_REGISTER_OP("nn.lrn")
.describe(R"code(LRN layer.
......@@ -509,9 +490,7 @@ Expr MakeL2Normalize(Expr data,
}
TVM_REGISTER_API("relay.op.nn._make.l2_normalize")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeL2Normalize, args, rv);
});
.set_body_typed(MakeL2Normalize);
RELAY_REGISTER_OP("nn.l2_normalize")
.describe(R"code(L2 Normalization layer.
......@@ -556,9 +535,7 @@ Expr MakeDropout(Expr data, double rate) {
}
TVM_REGISTER_API("relay.op.nn._make.dropout")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeDropout, args, rv);
});
.set_body_typed(MakeDropout);
RELAY_REGISTER_OP("nn.dropout")
.describe(R"code(Applies the dropout operation to the input array.
......@@ -622,9 +599,7 @@ Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr movi
}
TVM_REGISTER_API("relay.op.nn._make.batch_norm")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 9>(MakeBatchNorm, args, rv);
});
.set_body_typed(MakeBatchNorm);
RELAY_REGISTER_OP("nn.batch_norm")
.describe(R"code(Batch normalization layer (Ioffe and Szegedy, 2014).
......@@ -711,9 +686,7 @@ Expr MakeBatchMatmul(Expr x,
TVM_REGISTER_API("relay.op.nn._make.batch_matmul")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeBatchMatmul, args, rv);
});
.set_body_typed(MakeBatchMatmul);
RELAY_REGISTER_OP("nn.batch_matmul")
......
......@@ -115,9 +115,7 @@ Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value) {
}
TVM_REGISTER_API("relay.op.nn._make.pad")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakePad, args, rv);
});
.set_body_typed(MakePad);
RELAY_REGISTER_OP("nn.pad")
.describe(R"code(Pad for n-D tensor.
......
......@@ -186,9 +186,7 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
}
TVM_REGISTER_API("relay.op.nn._make.max_pool2d")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 6>(MakeMaxPool2D, args, rv);
});
.set_body_typed(MakeMaxPool2D);
RELAY_REGISTER_OP("nn.max_pool2d")
......@@ -242,9 +240,7 @@ Expr MakeAvgPool2D(Expr data,
TVM_REGISTER_API("relay.op.nn._make.avg_pool2d")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 7>(MakeAvgPool2D, args, rv);
});
.set_body_typed(MakeAvgPool2D);
RELAY_REGISTER_OP("nn.avg_pool2d")
......@@ -345,9 +341,7 @@ Expr MakeGlobalAvgPool2D(Expr data,
TVM_REGISTER_API("relay.op.nn._make.global_avg_pool2d")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeGlobalAvgPool2D, args, rv);
});
.set_body_typed(MakeGlobalAvgPool2D);
// GlobalAvgPool
RELAY_REGISTER_OP("nn.global_avg_pool2d")
......@@ -378,9 +372,7 @@ Expr MakeGlobalMaxPool2D(Expr data,
}
TVM_REGISTER_API("relay.op.nn._make.global_max_pool2d")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeGlobalMaxPool2D, args, rv);
});
.set_body_typed(MakeGlobalMaxPool2D);
RELAY_REGISTER_OP("nn.global_max_pool2d")
......
......@@ -110,9 +110,7 @@ Expr MakeUpSampling(Expr data,
TVM_REGISTER_API("relay.op.nn._make.upsampling")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 4>(MakeUpSampling, args, rv);
});
.set_body_typed(MakeUpSampling);
RELAY_REGISTER_OP("nn.upsampling")
......
......@@ -265,8 +265,8 @@ bool ReduceRel(const Array<Type>& types,
#define RELAY_REGISTER_REDUCE_OP(OpName) \
TVM_REGISTER_API("relay.op._make." OpName) \
.set_body([](const TVMArgs& args, TVMRetValue* rv) { \
auto make_func = [](Expr data, \
.set_body_typed<Call(Expr, Array<Integer>, bool, bool)>([]( \
Expr data, \
Array<Integer> axis, \
bool keepdims, \
bool exclude) { \
......@@ -276,8 +276,6 @@ bool ReduceRel(const Array<Type>& types,
attrs->exclude = exclude; \
static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {data}, Attrs(attrs), {}); \
}; \
runtime::detail::unpack_call<Expr, 4>(make_func, args, rv); \
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(1) \
......
......@@ -81,9 +81,7 @@ Expr MakeCast(Expr data,
}
TVM_REGISTER_API("relay._make.cast")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeCast, args, rv);
});
.set_body_typed(MakeCast);
RELAY_REGISTER_OP("cast")
.describe(R"code(Cast the data into a new data type.
......@@ -161,9 +159,7 @@ Expr MakeExpandDims(Expr data,
}
TVM_REGISTER_API("relay.op._make.expand_dims")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeExpandDims, args, rv);
});
.set_body_typed(MakeExpandDims);
RELAY_REGISTER_OP("expand_dims")
.describe(R"code(Insert `num_newaxis` axises at the position given by `axis`
......@@ -279,9 +275,7 @@ Expr MakeConcatenate(Expr data,
}
TVM_REGISTER_API("relay.op._make.concatenate")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeConcatenate, args, rv);
});
.set_body_typed(MakeConcatenate);
RELAY_REGISTER_OP("concatenate")
.describe(R"code(Concatenate the input tensors along the given axis.
......@@ -367,9 +361,7 @@ Expr MakeStack(Expr data,
}
TVM_REGISTER_API("relay.op._make.stack")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeStack, args, rv);
});
.set_body_typed(MakeStack);
RELAY_REGISTER_OP("stack")
.describe(R"code(Stack the input tensors along the given axis.
......@@ -461,9 +453,7 @@ Expr MakeTranspose(Expr data,
}
TVM_REGISTER_API("relay.op._make.transpose")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeTranspose, args, rv);
});
.set_body_typed(MakeTranspose);
RELAY_REGISTER_OP("transpose")
.describe(R"code(Permutes the dimensions of an array.
......@@ -598,9 +588,7 @@ Expr MakeReshape(Expr data,
}
TVM_REGISTER_API("relay.op._make.reshape")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeReshape, args, rv);
});
.set_body_typed(MakeReshape);
RELAY_REGISTER_OP("reshape")
.describe(R"code(Reshapes the input array.
......@@ -698,9 +686,7 @@ Expr MakeReshapeLike(Expr data,
TVM_REGISTER_API("relay.op._make.reshape_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeReshapeLike, args, rv);
});
.set_body_typed(MakeReshapeLike);
RELAY_REGISTER_OP("reshape_like")
......@@ -790,9 +776,7 @@ Expr MakeTake(Expr data,
}
TVM_REGISTER_API("relay.op._make.take")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 4>(MakeTake, args, rv);
});
.set_body_typed(MakeTake);
RELAY_REGISTER_OP("take")
.describe(R"code(Take elements from an array along an axis.
......@@ -873,9 +857,7 @@ Expr MakeFull(Expr fill_value,
}
TVM_REGISTER_API("relay.op._make.full")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeFull, args, rv);
});
.set_body_typed(MakeFull);
RELAY_REGISTER_OP("full")
.describe(R"code(Fill array with scalar value.
......@@ -910,9 +892,7 @@ Expr MakeZeros(Array<IndexExpr> shape,
}
TVM_REGISTER_API("relay.op._make.zeros")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeZeros, args, rv);
});
.set_body_typed(MakeZeros);
RELAY_REGISTER_OP("zeros")
.describe(R"code(Fill array with zeros.
......@@ -933,9 +913,7 @@ Expr MakeOnes(Array<IndexExpr> shape,
}
TVM_REGISTER_API("relay.op._make.ones")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeOnes, args, rv);
});
.set_body_typed(MakeOnes);
RELAY_REGISTER_OP("ones")
.describe(R"code(Fill array with ones.
......@@ -982,9 +960,7 @@ Expr MakeFullLike(Expr data,
}
TVM_REGISTER_API("relay.op._make.full_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeFullLike, args, rv);
});
.set_body_typed(MakeFullLike);
RELAY_REGISTER_OP("full_like")
.describe(R"code(Return an scalar value array with the same shape
......@@ -1041,9 +1017,7 @@ Expr MakeArange(tvm::Expr start,
}
TVM_REGISTER_API("relay.op._make.arange")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 4>(MakeArange, args, rv);
});
.set_body_typed(MakeArange);
RELAY_REGISTER_OP("arange")
.describe(R"code(Returns evenly spaced values within a given interval.
......@@ -1117,9 +1091,7 @@ Expr MakeRepeat(Expr data,
}
TVM_REGISTER_API("relay.op._make.repeat")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeRepeat, args, rv);
});
.set_body_typed(MakeRepeat);
RELAY_REGISTER_OP("repeat")
.describe(R"code(Repeat elements of an array `repeats` times along axis `axis`
......@@ -1217,9 +1189,7 @@ Expr MakeTile(Expr data,
}
TVM_REGISTER_API("relay.op._make.tile")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeTile, args, rv);
});
.set_body_typed(MakeTile);
RELAY_REGISTER_OP("tile")
.describe(R"code(Repeat the whole array multiple times.
......@@ -1280,9 +1250,7 @@ Expr MakeReverse(Expr data,
}
TVM_REGISTER_API("relay.op._make.reverse")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeReverse, args, rv);
});
.set_body_typed(MakeReverse);
RELAY_REGISTER_OP("reverse")
.describe(R"code(Reverses the order of elements along given `axis` while preserving array shape.
......@@ -1345,9 +1313,7 @@ Array<Tensor> WhereCompute(const Attrs& attrs,
}
TVM_REGISTER_API("relay.op._make.where")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeWhere, args, rv);
});
.set_body_typed(MakeWhere);
RELAY_REGISTER_OP("where")
.describe(R"code(
......@@ -1400,9 +1366,7 @@ Expr MakeSqueeze(Expr data,
}
TVM_REGISTER_API("relay.op._make.squeeze")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeSqueeze, args, rv);
});
.set_body_typed(MakeSqueeze);
bool SqueezeRel(const Array<Type>& types,
......@@ -1507,9 +1471,7 @@ Array<Tensor> CollapseSumLikeCompute(const Attrs& attrs,
}
TVM_REGISTER_API("relay.op._make.collapse_sum_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeCollapseSumLike, args, rv);
});
.set_body_typed(MakeCollapseSumLike);
RELAY_REGISTER_OP("collapse_sum_like")
.describe(R"code(Collapse the first input to match the shape of the second input.
......@@ -1554,9 +1516,7 @@ Array<Tensor> BroadCastToCompute(const Attrs& attrs,
}
TVM_REGISTER_API("relay.op._make.broadcast_to")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeBroadCastTo, args, rv);
});
.set_body_typed(MakeBroadCastTo);
RELAY_REGISTER_OP("broadcast_to")
.describe(R"code(Broadcast the first input to match the shape argument.
......@@ -1594,9 +1554,7 @@ Array<Tensor> BroadCastToLikeCompute(const Attrs& attrs,
}
TVM_REGISTER_API("relay.op._make.broadcast_to_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeBroadCastToLike, args, rv);
});
.set_body_typed(MakeBroadCastToLike);
RELAY_REGISTER_OP("broadcast_to_like")
.describe(R"code(Broadcast the first input to match the shape of the second input.
......@@ -1806,9 +1764,7 @@ Array<Tensor> StridedSliceCompute(const Attrs& attrs,
TVM_REGISTER_API("relay.op._make.strided_slice")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 4>(MakeStridedSlice, args, rv);
});
.set_body_typed(MakeStridedSlice);
RELAY_REGISTER_OP("strided_slice")
......@@ -2081,9 +2037,7 @@ Array<Tensor> SliceLikeCompute(const Attrs& attrs,
TVM_REGISTER_API("relay.op._make.slice_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeSliceLike, args, rv);
});
.set_body_typed(MakeSliceLike);
RELAY_REGISTER_OP("slice_like")
......@@ -2144,9 +2098,7 @@ Expr MakeLayoutTransform(Expr data,
}
TVM_REGISTER_API("relay.op._make.layout_transform")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 3>(MakeLayoutTransform, args, rv);
});
.set_body_typed(MakeLayoutTransform);
RELAY_REGISTER_OP("layout_transform")
.describe(R"code(Transform the input data layout.
......@@ -2174,9 +2126,7 @@ Expr MakeReverseReshape(Expr data,
}
TVM_REGISTER_API("relay.op._make._contrib_reverse_reshape")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeReverseReshape, args, rv);
});
.set_body_typed(MakeReverseReshape);
RELAY_REGISTER_OP("_contrib_reverse_reshape")
.describe(R"code(Reshapes the input array where the special values are inferred from
......@@ -2250,9 +2200,7 @@ Expr MakeGatherND(Expr data,
}
TVM_REGISTER_API("relay.op._make.gather_nd")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeGatherND, args, rv);
});
.set_body_typed(MakeGatherND);
RELAY_REGISTER_OP("gather_nd")
.describe(R"code(Gather elements or slices from data and store to
......
......@@ -73,9 +73,7 @@ Expr MakeMultiBoxPrior(Expr data,
TVM_REGISTER_API("relay.op.vision._make.multibox_prior")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 6>(MakeMultiBoxPrior, args, rv);
});
.set_body_typed(MakeMultiBoxPrior);
RELAY_REGISTER_OP("vision.multibox_prior")
......@@ -147,9 +145,7 @@ Expr MakeMultiBoxTransformLoc(Expr cls_prob,
}
TVM_REGISTER_API("relay.op.vision._make.multibox_transform_loc")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 6>(MakeMultiBoxTransformLoc, args, rv);
});
.set_body_typed(MakeMultiBoxTransformLoc);
RELAY_REGISTER_OP("vision.multibox_transform_loc")
.describe(R"doc("Location transformation for multibox detection."
......
......@@ -59,9 +59,7 @@ Expr MakeGetValidCounts(Expr data,
TVM_REGISTER_API("relay.op.vision._make.get_valid_counts")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeGetValidCounts, args, rv);
});
.set_body_typed(MakeGetValidCounts);
RELAY_REGISTER_OP("vision.get_valid_counts")
......@@ -125,9 +123,7 @@ Expr MakeNMS(Expr data,
TVM_REGISTER_API("relay.op.vision._make.non_max_suppression")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 9>(MakeNMS, args, rv);
});
.set_body_typed(MakeNMS);
RELAY_REGISTER_OP("vision.non_max_suppression")
......
......@@ -62,9 +62,7 @@ Expr MakeROIAlign(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spa
}
TVM_REGISTER_API("relay.op.vision._make.roi_align")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 6>(MakeROIAlign, args, rv);
});
.set_body_typed(MakeROIAlign);
RELAY_REGISTER_OP("vision.roi_align")
.describe(R"doc(ROI Align operator.
......@@ -114,9 +112,7 @@ Expr MakeROIPool(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spat
}
TVM_REGISTER_API("relay.op.vision._make.roi_pool")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 5>(MakeROIPool, args, rv);
});
.set_body_typed(MakeROIPool);
RELAY_REGISTER_OP("vision.roi_pool")
.describe(R"doc(ROI Pool operator.
......@@ -182,9 +178,7 @@ Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array<IndexExpr>
}
TVM_REGISTER_API("relay.op.vision._make.proposal")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 11>(MakeProposal, args, rv);
});
.set_body_typed(MakeProposal);
RELAY_REGISTER_OP("vision.proposal")
.describe(R"code(Generate region proposals via RPN.
......
......@@ -71,9 +71,7 @@ Expr MakeYoloReorg(Expr data,
TVM_REGISTER_API("relay.op.vision._make.yolo_reorg")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 2>(MakeYoloReorg, args, rv);
});
.set_body_typed(MakeYoloReorg);
RELAY_REGISTER_OP("vision.yolo_reorg")
......
......@@ -61,9 +61,7 @@ Expr CanonicalizeOps(const Expr& e) {
}
TVM_REGISTER_API("relay._ir_pass.canonicalize_ops")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = CanonicalizeOps(args[0]);
});
.set_body_typed(CanonicalizeOps);
} // namespace relay
} // namespace tvm
......@@ -355,9 +355,7 @@ Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) {
}
TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = CombineParallelConv2D(args[0], args[1]);
});
.set_body_typed(CombineParallelConv2D);
} // namespace relay
} // namespace tvm
......@@ -148,9 +148,7 @@ Expr DeadCodeElimination(const Expr& e) {
}
TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = DeadCodeElimination(args[0]);
});
.set_body_typed(DeadCodeElimination);
} // namespace relay
} // namespace tvm
......@@ -493,19 +493,13 @@ Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr) {
}
TVM_REGISTER_API("relay._ir_pass.CollectDeviceInfo")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CollectDeviceInfo(args[0]);
});
.set_body_typed(CollectDeviceInfo);
TVM_REGISTER_API("relay._ir_pass.RewriteDeviceAnnotation")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = RewriteAnnotatedOps(args[0], args[1]);
});
.set_body_typed(RewriteAnnotatedOps);
TVM_REGISTER_API("relay._ir_pass.CollectDeviceAnnotationOps")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = CollectDeviceAnnotationOps(args[0]);
});
.set_body_typed(CollectDeviceAnnotationOps);
} // namespace relay
} // namespace tvm
......@@ -210,9 +210,7 @@ Expr FoldConstant(const Expr& expr) {
}
TVM_REGISTER_API("relay._ir_pass.FoldConstant")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FoldConstant(args[0]);
});
.set_body_typed(FoldConstant);
} // namespace relay
} // namespace tvm
......@@ -912,8 +912,6 @@ Expr FuseOps(const Expr& expr, int fuse_opt_level) {
}
TVM_REGISTER_API("relay._ir_pass.FuseOps")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = FuseOps(args[0], args[1]);
});
.set_body_typed(FuseOps);
} // namespace relay
} // namespace tvm
......@@ -247,10 +247,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
}
TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size(), 2);
*ret = FirstOrderGradient(args[0], args[1]);
});
.set_body_typed(FirstOrderGradient);
struct ReverseADType : TypeMutator {
Type VisitType_(const TensorTypeNode* ttn) final {
......@@ -263,7 +260,7 @@ struct ReverseAD : ExprMutator {
Var bp;
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
ReverseAD(const Var& bp) : bp(bp) { }
ReverseAD(const Var& bp) : bp(bp) { } /// NOLINT(*)
Expr VisitExpr_(const OpNode* op) final {
LOG(FATAL) << "op should only be inside call";
......@@ -349,10 +346,7 @@ Expr Gradient(const Expr& re, const Module& mod) {
}
TVM_REGISTER_API("relay._ir_pass.gradient")
.set_body([](TVMArgs args, TVMRetValue* ret) {
CHECK_EQ(args.size(), 2);
*ret = Gradient(args[0], args[1]);
});
.set_body_typed(Gradient);
} // namespace relay
} // namespace tvm
......@@ -147,9 +147,7 @@ int64_t GetTotalMacNumber(const Expr& expr) {
}
TVM_REGISTER_API("relay._ir_pass.GetTotalMacNumber")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = GetTotalMacNumber(args[0]);
});
.set_body_typed(GetTotalMacNumber);
} // namespace mac_count
} // namespace relay
......
......@@ -426,12 +426,7 @@ Pass CreateSequentialPass(const tvm::Array<Pass>& passes,
TVM_REGISTER_NODE_TYPE(PassInfoNode);
TVM_REGISTER_API("relay._ir_pass.PassInfo")
.set_body([](TVMArgs args, TVMRetValue* ret) {
int opt_level = args[0];
std::string name = args[1];
tvm::Array<tvm::Expr> required = args[2];
*ret = PassInfoNode::make(opt_level, name, required);
});
.set_body_typed(PassInfoNode::make);
TVM_REGISTER_API("relay._ir_pass.Info")
.set_body([](TVMArgs args, TVMRetValue* ret) {
......@@ -456,13 +451,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(ModulePassNode);
TVM_REGISTER_API("relay._ir_pass.CreateModulePass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
PackedFunc pass_func = args[0];
int opt_level = args[1];
std::string name = args[2];
tvm::Array<tvm::Expr> required = args[3];
*ret = CreateModulePass(pass_func, opt_level, name, required);
});
.set_body_typed(CreateModulePass);
TVM_REGISTER_API("relay._ir_pass.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
......@@ -487,13 +476,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(FunctionPassNode);
TVM_REGISTER_API("relay._ir_pass.CreateFunctionPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
PackedFunc pass_func = args[0];
int opt_level = args[1];
std::string name = args[2];
tvm::Array<tvm::Expr> required = args[3];
*ret = CreateFunctionPass(pass_func, opt_level, name, required);
});
.set_body_typed(CreateFunctionPass);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FunctionPassNode>([](const FunctionPassNode* node,
......@@ -541,9 +524,7 @@ TVM_REGISTER_API("relay._ir_pass.SetContext")
TVM_REGISTER_NODE_TYPE(PassContextNode);
TVM_REGISTER_API("relay._ir_pass.PassContext")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = PassContextNode::make();
});
.set_body_typed(PassContextNode::make);
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PassContextNode>([](const PassContextNode* node,
......
......@@ -571,20 +571,13 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
});
TVM_REGISTER_API("relay._quantize._GetCurrentQConfig")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = QConfig::Current();
});
.set_body_typed(QConfig::Current);
TVM_REGISTER_API("relay._quantize._EnterQConfigScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
QConfig target = args[0];
QConfig::EnterQConfigScope(target);
});
.set_body_typed(QConfig::EnterQConfigScope);
TVM_REGISTER_API("relay._quantize._ExitQConfigScope")
.set_body([](TVMArgs args, TVMRetValue* ret) {
QConfig::ExitQConfigScope();
});
.set_body_typed(QConfig::ExitQConfigScope);
} // namespace quantize
} // namespace relay
......
......@@ -103,9 +103,7 @@ Expr SimplifyInference(const Expr& e) {
}
TVM_REGISTER_API("relay._ir_pass.simplify_inference")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = SimplifyInference(args[0]);
});
.set_body_typed(SimplifyInference);
} // namespace relay
} // namespace tvm
......@@ -491,9 +491,7 @@ Expr ToANormalForm(const Expr& e, const Module& m) {
}
TVM_REGISTER_API("relay._ir_pass.to_a_normal_form")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ToANormalForm(args[0], args[1]);
});
.set_body_typed(static_cast<Expr (*)(const Expr&, const Module&)>(ToANormalForm));
} // namespace relay
} // namespace tvm
......@@ -77,9 +77,7 @@ Expr ToGraphNormalForm(const Expr& e) {
}
TVM_REGISTER_API("relay._ir_pass.to_graph_normal_form")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ToGraphNormalForm(args[0]);
});
.set_body_typed(ToGraphNormalForm);
} // namespace relay
} // namespace tvm
......@@ -801,8 +801,8 @@ Function InferType(const Function& func,
}
TVM_REGISTER_API("relay._ir_pass.infer_type")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = InferType(args[0], args[1]);
.set_body_typed<Expr(const Expr&, const Module&)>([](const Expr& expr, const Module& mod_ref) {
return InferType(expr, mod_ref);
});
} // namespace relay
} // namespace tvm
......@@ -275,9 +275,7 @@ tvm::Array<Var> AllVars(const Expr& expr) {
}
TVM_REGISTER_API("relay._ir_pass.free_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = FreeVars(args[0]);
});
.set_body_typed(FreeVars);
TVM_REGISTER_API("relay._ir_pass.bound_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
......@@ -290,9 +288,7 @@ TVM_REGISTER_API("relay._ir_pass.bound_vars")
});
TVM_REGISTER_API("relay._ir_pass.all_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = AllVars(args[0]);
});
.set_body_typed(AllVars);
TVM_REGISTER_API("relay._ir_pass.free_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) {
......
......@@ -79,10 +79,7 @@ bool WellFormed(const Expr& e) {
}
TVM_REGISTER_API("relay._ir_pass.well_formed")
.set_body([](TVMArgs args, TVMRetValue *ret) {
Expr e = args[0];
*ret = WellFormed(e);
});
.set_body_typed(WellFormed);
} // namespace relay
} // namespace tvm
......@@ -308,18 +308,12 @@ Module CUDAModuleLoadBinary(void* strm) {
}
TVM_REGISTER_GLOBAL("module.loadfile_cubin")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CUDAModuleLoadFile(args[0], args[1]);
});
.set_body_typed(CUDAModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadfile_ptx")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CUDAModuleLoadFile(args[0], args[1]);
});
.set_body_typed(CUDAModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadbinary_cuda")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CUDAModuleLoadBinary(args[0]);
});
.set_body_typed(CUDAModuleLoadBinary);
} // namespace runtime
} // namespace tvm
......@@ -310,13 +310,9 @@ Module MetalModuleLoadBinary(void* strm) {
}
TVM_REGISTER_GLOBAL("module.loadfile_metal")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = MetalModuleLoadFile(args[0], args[1]);
});
.set_body_typed(MetalModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadbinary_metal")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = MetalModuleLoadBinary(args[0]);
});
.set_body_typed(MetalModuleLoadBinary);
} // namespace runtime
} // namespace tvm
......@@ -69,9 +69,7 @@ Module AOCLModuleLoadFile(const std::string& file_name,
}
TVM_REGISTER_GLOBAL("module.loadfile_aocx")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = AOCLModuleLoadFile(args[0], args[1]);
});
.set_body_typed(AOCLModuleLoadFile);
} // namespace runtime
} // namespace tvm
......@@ -281,18 +281,12 @@ Module OpenCLModuleLoadBinary(void* strm) {
}
TVM_REGISTER_GLOBAL("module.loadfile_cl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = OpenCLModuleLoadFile(args[0], args[1]);
});
.set_body_typed(OpenCLModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadfile_clbin")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = OpenCLModuleLoadFile(args[0], args[1]);
});
.set_body_typed(OpenCLModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadbinary_opencl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = OpenCLModuleLoadBinary(args[0]);
});
.set_body_typed(OpenCLModuleLoadBinary);
} // namespace runtime
} // namespace tvm
......@@ -80,13 +80,9 @@ Module SDAccelModuleLoadBinary(void* strm) {
}
TVM_REGISTER_GLOBAL("module.loadfile_xclbin")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = SDAccelModuleLoadFile(args[0], args[1]);
});
.set_body_typed(SDAccelModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadfile_awsxclbin")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = SDAccelModuleLoadFile(args[0], args[1]);
});
.set_body_typed(SDAccelModuleLoadFile);
} // namespace runtime
} // namespace tvm
......@@ -243,14 +243,10 @@ Module ROCMModuleLoadBinary(void* strm) {
TVM_REGISTER_GLOBAL("module.loadbinary_hsaco")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = ROCMModuleLoadBinary(args[0]);
});
.set_body_typed(ROCMModuleLoadBinary);
TVM_REGISTER_GLOBAL("module.loadbinary_hip")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = ROCMModuleLoadBinary(args[0]);
});
.set_body_typed(ROCMModuleLoadBinary);
} // namespace runtime
} // namespace tvm
......@@ -64,8 +64,6 @@ PackedFunc CreateEventDrivenServer(PackedFunc fsend,
}
TVM_REGISTER_GLOBAL("rpc._CreateEventDrivenServer")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CreateEventDrivenServer(args[0], args[1], args[2]);
});
.set_body_typed(CreateEventDrivenServer);
} // namespace runtime
} // namespace tvm
......@@ -110,9 +110,7 @@ void RPCServerLoop(int sockfd) {
}
TVM_REGISTER_GLOBAL("rpc._Connect")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = RPCClientConnect(args[0], args[1], args[2]);
});
.set_body_typed(RPCClientConnect);
TVM_REGISTER_GLOBAL("rpc._ServerLoop")
.set_body([](TVMArgs args, TVMRetValue* rv) {
......
......@@ -142,9 +142,7 @@ Module StackVMModuleCreate(std::unordered_map<std::string, StackVM> fmap,
}
TVM_REGISTER_GLOBAL("module.loadfile_stackvm")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = StackVMModuleNode::LoadFromFile(args[0], args[1]);
});
.set_body_typed(StackVMModuleNode::LoadFromFile);
} // namespace runtime
} // namespace tvm
......@@ -427,13 +427,9 @@ Module VulkanModuleLoadBinary(void* strm) {
}
TVM_REGISTER_GLOBAL("module.loadfile_vulkan")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = VulkanModuleLoadFile(args[0], args[1]);
});
.set_body_typed(VulkanModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadbinary_vulkan")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = VulkanModuleLoadBinary(args[0]);
});
.set_body_typed(VulkanModuleLoadBinary);
} // namespace runtime
} // namespace tvm
......@@ -60,16 +60,16 @@ struct RPCEnv {
};
TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath")
.set_body([](TVMArgs args, TVMRetValue* rv) {
.set_body_typed<std::string(std::string)>([](std::string path) {
static RPCEnv env;
*rv = env.GetPath(args[0]);
return env.GetPath(path);
});
TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string file_name = "/rpc/" + args[0].operator std::string();
*rv = Module::LoadFromFile(file_name, "");
.set_body_typed<Module(std::string)>([](std::string path) {
std::string file_name = "/rpc/" + path;
LOG(INFO) << "Load module from " << file_name << " ...";
return Module::LoadFromFile(file_name, "");
});
} // namespace contrib
} // namespace tvm
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment