Commit 51785062 by James Gilles Committed by Tianqi Chen

[REFACTOR] Use more TypedPackedFuncs (#2981)

* Add `set_body_simple` to Registry, refactor a lot of code to use it

* Add more types to Relay PackedFuncs

* Add Registry::set_body_method to easily make Node methods into
PackedFuncs

* Add set_body_method, set_body_node_method; start typing api_lang

* Add some docs, remove unused script

* Fix mysterious linter problem

* Touch up api_ir.cc

* Fix some issues with TOPI argument counts

* Revert changes to topi.cc to avoid problems with optional arguments

* A little more cleanup

* Type more of the api _ functions

* Whitespace

* Finalize names and docs for new registry helpers

* Update docs
parent 57f47a17
...@@ -83,6 +83,169 @@ class Registry { ...@@ -83,6 +83,169 @@ class Registry {
Registry& set_body_typed(FLambda f) { Registry& set_body_typed(FLambda f) {
return set_body(TypedPackedFunc<FType>(f).packed()); return set_body(TypedPackedFunc<FType>(f).packed());
} }
/*!
* \brief set the body of the function to the given function pointer.
* Note that this doesn't work with lambdas, you need to
* explicitly give a type for those.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
*
* int multiply(int x, int y) {
* return x * y;
* }
*
* TVM_REGISTER_API("multiply")
* .set_body_typed(multiply); // will have type int(int, int)
*
* \endcode
*
* \param f The function to forward to.
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename R, typename ...Args>
Registry& set_body_typed(R (*f)(Args...)) {
return set_body(TypedPackedFunc<R(Args...)>(f));
}
/*!
* \brief set the body of the function to be the passed method pointer.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
*
* // node subclass:
* struct Example {
* int doThing(int x);
* }
* TVM_REGISTER_API("Example_doThing")
* .set_body_method(&Example::doThing); // will have type int(Example, int)
*
* \endcode
*
* \param f the method pointer to forward to.
* \tparam T the type containing the method (inferred).
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename T, typename R, typename ...Args>
Registry& set_body_method(R (T::*f)(Args...)) {
return set_body_typed<R(T, Args...)>([f](T target, Args... params) -> R {
// call method pointer
return (target.*f)(params...);
});
}
/*!
* \brief set the body of the function to be the passed method pointer.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
*
* // node subclass:
* struct Example {
* int doThing(int x);
* }
* TVM_REGISTER_API("Example_doThing")
* .set_body_method(&Example::doThing); // will have type int(Example, int)
*
* \endcode
*
* \param f the method pointer to forward to.
* \tparam T the type containing the method (inferred).
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename T, typename R, typename ...Args>
Registry& set_body_method(R (T::*f)(Args...) const) {
return set_body_typed<R(T, Args...)>([f](const T target, Args... params) -> R {
// call method pointer
return (target.*f)(params...);
});
}
/*!
* \brief set the body of the function to be the passed method pointer.
* Used when calling a method on a Node subclass through a NodeRef subclass.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
*
* // node subclass:
* struct ExampleNode: BaseNode {
* int doThing(int x);
* }
*
* // noderef subclass
* struct Example;
*
* TVM_REGISTER_API("Example_doThing")
* .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
*
* // note that just doing:
* // .set_body_method(&ExampleNode::doThing);
* // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
*
* \endcode
*
* \param f the method pointer to forward to.
* \tparam TNodeRef the node reference type to call the method on
* \tparam TNode the node type containing the method (inferred).
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename TNodeRef, typename TNode, typename R, typename ...Args,
typename = typename std::enable_if<std::is_base_of<NodeRef, TNodeRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...)) {
return set_body_typed<R(TNodeRef, Args...)>([f](TNodeRef ref, Args... params) {
TNode* target = ref.operator->();
// call method pointer
return (target->*f)(params...);
});
}
/*!
* \brief set the body of the function to be the passed method pointer.
* Used when calling a method on a Node subclass through a NodeRef subclass.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
*
* // node subclass:
* struct ExampleNode: BaseNode {
* int doThing(int x);
* }
*
* // noderef subclass
* struct Example;
*
* TVM_REGISTER_API("Example_doThing")
* .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
*
* // note that just doing:
* // .set_body_method(&ExampleNode::doThing);
* // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
*
* \endcode
*
* \param f the method pointer to forward to.
* \tparam TNodeRef the node reference type to call the method on
* \tparam TNode the node type containing the method (inferred).
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename TNodeRef, typename TNode, typename R, typename ...Args,
typename = typename std::enable_if<std::is_base_of<NodeRef, TNodeRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...) const) {
return set_body_typed<R(TNodeRef, Args...)>([f](TNodeRef ref, Args... params) {
const TNode* target = ref.operator->();
// call method pointer
return (target->*f)(params...);
});
}
/*! /*!
* \brief Register a function with given name * \brief Register a function with given name
* \param name The name of the function. * \param name The name of the function.
......
...@@ -360,9 +360,7 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.GraphKeyGetGraph") ...@@ -360,9 +360,7 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.GraphKeyGetGraph")
}); });
TVM_REGISTER_GLOBAL("nnvm.compiler.MakeGraphKey") TVM_REGISTER_GLOBAL("nnvm.compiler.MakeGraphKey")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) { .set_body_typed(GraphKeyNode::make);
*rv = GraphKeyNode::make(args[0], args[1], args[2]);
});
// This can be used to extract workloads from nnvm compiler // This can be used to extract workloads from nnvm compiler
TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs") TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs")
......
...@@ -235,8 +235,6 @@ std::string GraphDeepCompare(const Graph& a, ...@@ -235,8 +235,6 @@ std::string GraphDeepCompare(const Graph& a,
} }
TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare") TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) { .set_body_typed(GraphDeepCompare);
*rv = GraphDeepCompare(args[0], args[1], args[2]);
});
} // namespace compiler } // namespace compiler
} // namespace nnvm } // namespace nnvm
...@@ -31,73 +31,51 @@ namespace tvm { ...@@ -31,73 +31,51 @@ namespace tvm {
namespace arith { namespace arith {
TVM_REGISTER_API("arith.intset_single_point") TVM_REGISTER_API("arith.intset_single_point")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(IntSet::single_point);
*ret = IntSet::single_point(args[0]);
});
TVM_REGISTER_API("arith.intset_vector") TVM_REGISTER_API("arith.intset_vector")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(IntSet::vector);
*ret = IntSet::vector(args[0]);
});
TVM_REGISTER_API("arith.intset_interval") TVM_REGISTER_API("arith.intset_interval")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(IntSet::interval);
*ret = IntSet::interval(args[0], args[1]);
});
TVM_REGISTER_API("arith.DetectLinearEquation") TVM_REGISTER_API("arith.DetectLinearEquation")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(DetectLinearEquation);
*ret = DetectLinearEquation(args[0], args[1]);
});
TVM_REGISTER_API("arith.DetectClipBound") TVM_REGISTER_API("arith.DetectClipBound")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(DetectClipBound);
*ret = DetectClipBound(args[0], args[1]);
});
TVM_REGISTER_API("arith.DeduceBound") TVM_REGISTER_API("arith.DeduceBound")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed<IntSet(Expr, Expr, Map<Var, IntSet>, Map<Var, IntSet>)>([](
*ret = DeduceBound(args[0], args[1], Expr v, Expr cond,
args[2].operator Map<Var, IntSet>(), const Map<Var, IntSet> hint_map,
args[3].operator Map<Var, IntSet>()); const Map<Var, IntSet> relax_map
}); ) {
return DeduceBound(v, cond, hint_map, relax_map);
});
TVM_REGISTER_API("arith.DomainTouched") TVM_REGISTER_API("arith.DomainTouched")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(DomainTouched);
*ret = DomainTouched(args[0], args[1], args[2], args[3]);
});
TVM_REGISTER_API("_IntervalSetGetMin") TVM_REGISTER_API("_IntervalSetGetMin")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_method(&IntSet::min);
*ret = args[0].operator IntSet().min();
});
TVM_REGISTER_API("_IntervalSetGetMax") TVM_REGISTER_API("_IntervalSetGetMax")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_method(&IntSet::max);
*ret = args[0].operator IntSet().max();
});
TVM_REGISTER_API("_IntSetIsNothing") TVM_REGISTER_API("_IntSetIsNothing")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_method(&IntSet::is_nothing);
*ret = args[0].operator IntSet().is_nothing();
});
TVM_REGISTER_API("_IntSetIsEverything") TVM_REGISTER_API("_IntSetIsEverything")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_method(&IntSet::is_everything);
*ret = args[0].operator IntSet().is_everything();
});
TVM_REGISTER_API("arith._make_ConstIntBound") TVM_REGISTER_API("arith._make_ConstIntBound")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(ConstIntBoundNode::make);
*ret = ConstIntBoundNode::make(args[0], args[1]);
});
TVM_REGISTER_API("arith._make_ModularSet") TVM_REGISTER_API("arith._make_ModularSet")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(ModularSetNode::make);
*ret = ModularSetNode::make(args[0], args[1]);
});
TVM_REGISTER_API("arith._CreateAnalyzer") TVM_REGISTER_API("arith._CreateAnalyzer")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
......
...@@ -50,9 +50,8 @@ TVM_REGISTER_API("_load_json") ...@@ -50,9 +50,8 @@ TVM_REGISTER_API("_load_json")
.set_body_typed<NodeRef(std::string)>(LoadJSON<NodeRef>); .set_body_typed<NodeRef(std::string)>(LoadJSON<NodeRef>);
TVM_REGISTER_API("_TVMSetStream") TVM_REGISTER_API("_TVMSetStream")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(TVMSetStream);
TVMSetStream(args[0], args[1], args[2]);
});
TVM_REGISTER_API("_save_param_dict") TVM_REGISTER_API("_save_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body([](TVMArgs args, TVMRetValue *rv) {
CHECK_EQ(args.size() % 2, 0u); CHECK_EQ(args.size() % 2, 0u);
......
...@@ -41,8 +41,6 @@ TVM_REGISTER_API("codegen._Build") ...@@ -41,8 +41,6 @@ TVM_REGISTER_API("codegen._Build")
}); });
TVM_REGISTER_API("module._PackImportsToC") TVM_REGISTER_API("module._PackImportsToC")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(PackImportsToC);
*ret = PackImportsToC(args[0], args[1]);
});
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -31,54 +31,43 @@ namespace tvm { ...@@ -31,54 +31,43 @@ namespace tvm {
namespace ir { namespace ir {
TVM_REGISTER_API("_Var") TVM_REGISTER_API("_Var")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed<VarExpr(std::string, Type)>([](std::string s, Type t) {
*ret = Variable::make(args[1], args[0]); return Variable::make(t, s);
}); });
TVM_REGISTER_API("make.abs") TVM_REGISTER_API("make.abs")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(tvm::abs);
*ret = tvm::abs(args[0]);
});
TVM_REGISTER_API("make.floor") TVM_REGISTER_API("make.floor")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(tvm::floor);
*ret = tvm::floor(args[0]);
});
TVM_REGISTER_API("make.ceil") TVM_REGISTER_API("make.ceil")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(tvm::ceil);
*ret = tvm::ceil(args[0]);
});
TVM_REGISTER_API("make.round") TVM_REGISTER_API("make.round")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(tvm::round);
*ret = tvm::round(args[0]);
});
TVM_REGISTER_API("make.trunc") TVM_REGISTER_API("make.trunc")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(tvm::trunc);
*ret = tvm::trunc(args[0]);
});
TVM_REGISTER_API("make._cast") TVM_REGISTER_API("make._cast")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(tvm::cast);
*ret = tvm::cast(args[0], args[1]);
});
TVM_REGISTER_API("make._range_by_min_extent") TVM_REGISTER_API("make._range_by_min_extent")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(Range::make_by_min_extent);
*ret = Range::make_by_min_extent(args[0], args[1]);
});
TVM_REGISTER_API("make.For") TVM_REGISTER_API("make.For")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed<Stmt(VarExpr, Expr, Expr, int, int, Stmt)>([](
*ret = For::make(args[0], VarExpr loop_var, Expr min, Expr extent,
args[1], int for_type, int device_api, Stmt body
args[2], ) {
static_cast<ForType>(args[3].operator int()), return For::make(loop_var,
static_cast<HalideIR::DeviceAPI>(args[4].operator int()), min,
args[5]); extent,
}); static_cast<ForType>(for_type),
static_cast<HalideIR::DeviceAPI>(device_api),
body);
});
TVM_REGISTER_API("make.Load") TVM_REGISTER_API("make.Load")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
...@@ -101,114 +90,87 @@ TVM_REGISTER_API("make.Store") ...@@ -101,114 +90,87 @@ TVM_REGISTER_API("make.Store")
}); });
TVM_REGISTER_API("make.Realize") TVM_REGISTER_API("make.Realize")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(Realize::make);
*ret = Realize::make(args[0],
args[1],
args[2],
args[3],
args[4],
args[5]);
});
TVM_REGISTER_API("make.Call") TVM_REGISTER_API("make.Call")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed<Expr(Type, std::string, Array<Expr>, int, FunctionRef, int)>([](
*ret = Call::make(args[0], Type type, std::string name,
args[1], Array<Expr> args, int call_type,
args[2], FunctionRef func, int value_index
static_cast<Call::CallType>(args[3].operator int()), ) {
args[4], return Call::make(type,
args[5]); name,
}); args,
static_cast<Call::CallType>(call_type),
func,
value_index);
});
TVM_REGISTER_API("make.CommReducer") TVM_REGISTER_API("make.CommReducer")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(CommReducerNode::make);
*ret = CommReducerNode::make(args[0],
args[1],
args[2],
args[3]);
});
// make from two arguments // make from two arguments
#define REGISTER_MAKE1(Node) \ #define REGISTER_MAKE(Node) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0]); \
}) \
#define REGISTER_MAKE2(Node) \
TVM_REGISTER_API("make."#Node) \ TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body_typed(Node::make); \
*ret = Node::make(args[0], args[1]); \
}) \ REGISTER_MAKE(Reduce);
REGISTER_MAKE(AttrStmt);
#define REGISTER_MAKE3(Node) \
TVM_REGISTER_API("make."#Node) \ REGISTER_MAKE(IntImm);
.set_body([](TVMArgs args, TVMRetValue *ret) { \ REGISTER_MAKE(UIntImm);
*ret = Node::make(args[0], args[1], args[2]); \ REGISTER_MAKE(FloatImm);
}) \ REGISTER_MAKE(StringImm);
#define REGISTER_MAKE4(Node) \ REGISTER_MAKE(Add);
TVM_REGISTER_API("make."#Node) \ REGISTER_MAKE(Sub);
.set_body([](TVMArgs args, TVMRetValue *ret) { \ REGISTER_MAKE(Mul);
*ret = Node::make(args[0], args[1], args[2], args[3]); \ REGISTER_MAKE(Div);
}) \ REGISTER_MAKE(Mod);
REGISTER_MAKE(Min);
#define REGISTER_MAKE5(Node) \ REGISTER_MAKE(Max);
TVM_REGISTER_API("make."#Node) \ REGISTER_MAKE(EQ);
.set_body([](TVMArgs args, TVMRetValue *ret) { \ REGISTER_MAKE(NE);
*ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \ REGISTER_MAKE(LT);
}) \ REGISTER_MAKE(LE);
REGISTER_MAKE(GT);
REGISTER_MAKE(GE);
REGISTER_MAKE5(Reduce); REGISTER_MAKE(And);
REGISTER_MAKE4(AttrStmt); REGISTER_MAKE(Or);
REGISTER_MAKE2(IntImm); REGISTER_MAKE(Not);
REGISTER_MAKE2(UIntImm); REGISTER_MAKE(Select);
REGISTER_MAKE2(FloatImm); REGISTER_MAKE(Ramp);
REGISTER_MAKE1(StringImm); REGISTER_MAKE(Cast);
REGISTER_MAKE(Broadcast);
REGISTER_MAKE2(Add); REGISTER_MAKE(Shuffle);
REGISTER_MAKE2(Sub); REGISTER_MAKE(Let);
REGISTER_MAKE2(Mul); REGISTER_MAKE(LetStmt);
REGISTER_MAKE2(Div); REGISTER_MAKE(AssertStmt);
REGISTER_MAKE2(Mod); REGISTER_MAKE(ProducerConsumer);
REGISTER_MAKE2(Min); REGISTER_MAKE(Provide);
REGISTER_MAKE2(Max); REGISTER_MAKE(Prefetch);
REGISTER_MAKE2(EQ); REGISTER_MAKE(Free);
REGISTER_MAKE2(NE); REGISTER_MAKE(IfThenElse);
REGISTER_MAKE2(LT); REGISTER_MAKE(Evaluate);
REGISTER_MAKE2(LE);
REGISTER_MAKE2(GT); // overloaded, needs special handling
REGISTER_MAKE2(GE); TVM_REGISTER_API("make.Block")
REGISTER_MAKE2(And); .set_body_typed(static_cast<Stmt (*)(Stmt, Stmt)>(Block::make));
REGISTER_MAKE2(Or);
// has default args
REGISTER_MAKE1(Not); TVM_REGISTER_API("make.Allocate")
REGISTER_MAKE3(Select); .set_body_typed<Stmt(VarExpr, Type, Array<Expr>, Expr, Stmt)>([](
REGISTER_MAKE3(Ramp); VarExpr buffer_var, Type type, Array<Expr> extents, Expr condition, Stmt body
REGISTER_MAKE2(Cast); ){
REGISTER_MAKE2(Broadcast); return Allocate::make(buffer_var, type, extents, condition, body);
REGISTER_MAKE2(Shuffle); });
REGISTER_MAKE3(Let);
REGISTER_MAKE3(LetStmt);
REGISTER_MAKE3(AssertStmt);
REGISTER_MAKE3(ProducerConsumer);
REGISTER_MAKE5(Allocate);
REGISTER_MAKE4(Provide);
REGISTER_MAKE4(Prefetch);
REGISTER_MAKE1(Free);
REGISTER_MAKE2(Block);
REGISTER_MAKE3(IfThenElse);
REGISTER_MAKE1(Evaluate);
// operator overloading, smarter than make // operator overloading, smarter than make
#define REGISTER_MAKE_BINARY_OP(Node, Func) \ #define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_API("make."#Node) \ TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body_typed<Expr(Expr, Expr)>([](Expr a, Expr b) { \
Expr a = args[0], b = args[1]; \ return (Func(a, b)); \
*ret = (Func(a, b)); \
}) })
#define REGISTER_MAKE_BIT_OP(Node, Func) \ #define REGISTER_MAKE_BIT_OP(Node, Func) \
......
...@@ -32,19 +32,14 @@ ...@@ -32,19 +32,14 @@
#include <tvm/build_module.h> #include <tvm/build_module.h>
#include <tvm/data_layout.h> #include <tvm/data_layout.h>
namespace tvm { namespace tvm {
TVM_REGISTER_API("_min_value") TVM_REGISTER_API("_min_value")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Type::min);
Type t = args[0].operator Type();
*ret = t.min();
});
TVM_REGISTER_API("_max_value") TVM_REGISTER_API("_max_value")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Type::max);
Type t = args[0].operator Type();
*ret = t.max();
});
TVM_REGISTER_API("_const") TVM_REGISTER_API("_const")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
...@@ -58,9 +53,7 @@ TVM_REGISTER_API("_const") ...@@ -58,9 +53,7 @@ TVM_REGISTER_API("_const")
}); });
TVM_REGISTER_API("_str") TVM_REGISTER_API("_str")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(ir::StringImm::make);
*ret = ir::StringImm::make(args[0]);
});
TVM_REGISTER_API("_Array") TVM_REGISTER_API("_Array")
...@@ -214,373 +207,217 @@ TVM_REGISTER_API("Range") ...@@ -214,373 +207,217 @@ TVM_REGISTER_API("Range")
}); });
TVM_REGISTER_API("_Buffer") TVM_REGISTER_API("_Buffer")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(BufferNode::make);
*ret = BufferNode::make(args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[6],
args[7],
args[8]);
});
TVM_REGISTER_API("_BufferAccessPtr") TVM_REGISTER_API("_BufferAccessPtr")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Buffer::access_ptr);
*ret = args[0].operator Buffer()
.access_ptr(args[1], args[2], args[3], args[4]);
});
TVM_REGISTER_API("_BufferVLoad") TVM_REGISTER_API("_BufferVLoad")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Buffer::vload);
*ret = args[0].operator Buffer()
.vload(args[1], args[2]);
});
TVM_REGISTER_API("_BufferVStore") TVM_REGISTER_API("_BufferVStore")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Buffer::vstore);
*ret = args[0].operator Buffer()
.vstore(args[1], args[2]);
});
TVM_REGISTER_API("_Layout") TVM_REGISTER_API("_Layout")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(LayoutNode::make);
*ret = LayoutNode::make(args[0]);
});
TVM_REGISTER_API("_LayoutIndexOf") TVM_REGISTER_API("_LayoutIndexOf")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<int(Layout, std::string)>([](Layout layout, std::string axis) {
*ret = args[0].operator Layout() return layout.IndexOf(LayoutAxis::make(axis));
.IndexOf(LayoutAxis::make(args[1]));
}); });
TVM_REGISTER_API("_LayoutFactorOf") TVM_REGISTER_API("_LayoutFactorOf")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<int(Layout, std::string)>([](Layout layout, std::string axis) {
*ret = args[0].operator Layout() return layout.FactorOf(LayoutAxis::make(axis));
.FactorOf(LayoutAxis::make(args[1]));
}); });
TVM_REGISTER_API("_LayoutNdim") TVM_REGISTER_API("_LayoutNdim")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<int(Layout)>([](Layout layout) {
*ret = static_cast<int64_t>(args[0].operator Layout().ndim()); return layout.ndim();
}); });
TVM_REGISTER_API("_LayoutGetItem") TVM_REGISTER_API("_LayoutGetItem")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<std::string(Layout, int)>([](Layout layout, int idx) {
const LayoutAxis& axis = args[0].operator Layout()[args[1]]; const LayoutAxis& axis = layout[idx];
*ret = axis.name(); return axis.name();
}); });
TVM_REGISTER_API("_BijectiveLayout") TVM_REGISTER_API("_BijectiveLayout")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(BijectiveLayoutNode::make);
*ret = BijectiveLayoutNode::make(args[0], args[1]);
});
TVM_REGISTER_API("_BijectiveLayoutForwardIndex") TVM_REGISTER_API("_BijectiveLayoutForwardIndex")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&BijectiveLayout::ForwardIndex);
*ret = args[0].operator BijectiveLayout()
.ForwardIndex(args[1]);
});
TVM_REGISTER_API("_BijectiveLayoutBackwardIndex") TVM_REGISTER_API("_BijectiveLayoutBackwardIndex")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&BijectiveLayout::BackwardIndex);
*ret = args[0].operator BijectiveLayout()
.BackwardIndex(args[1]);
});
TVM_REGISTER_API("_BijectiveLayoutForwardShape") TVM_REGISTER_API("_BijectiveLayoutForwardShape")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&BijectiveLayout::ForwardShape);
*ret = args[0].operator BijectiveLayout()
.ForwardShape(args[1]);
});
TVM_REGISTER_API("_BijectiveLayoutBackwardShape") TVM_REGISTER_API("_BijectiveLayoutBackwardShape")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&BijectiveLayout::BackwardShape);
*ret = args[0].operator BijectiveLayout()
.BackwardShape(args[1]);
});
TVM_REGISTER_API("_Tensor") TVM_REGISTER_API("_Tensor")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(TensorNode::make);
*ret = TensorNode::make(args[0],
args[1],
args[2],
args[3]);
});
TVM_REGISTER_API("_TensorIntrin") TVM_REGISTER_API("_TensorIntrin")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(TensorIntrinNode::make);
*ret = TensorIntrinNode::make(args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[6]);
});
TVM_REGISTER_API("_TensorIntrinCall") TVM_REGISTER_API("_TensorIntrinCall")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(TensorIntrinCallNode::make);
*ret = TensorIntrinCallNode::make(args[0],
args[1],
args[2],
args[3]);
});
TVM_REGISTER_API("_TensorEqual") TVM_REGISTER_API("_TensorEqual")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Tensor::operator==);
*ret = args[0].operator Tensor() == args[1].operator Tensor();
});
TVM_REGISTER_API("_TensorHash") TVM_REGISTER_API("_TensorHash")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<int64_t(Tensor)>([](Tensor tensor) {
*ret = static_cast<int64_t>( return static_cast<int64_t>(std::hash<Tensor>()(tensor));
std::hash<Tensor>()(args[0].operator Tensor()));
}); });
TVM_REGISTER_API("_Placeholder") TVM_REGISTER_API("_Placeholder")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<Tensor(Array<Expr>, Type, std::string)>([](
*ret = placeholder(args[0], Array<Expr> shape, Type dtype, std::string name
args[1], ) {
args[2]); return placeholder(shape, dtype, name);
}); });
TVM_REGISTER_API("_ComputeOp") TVM_REGISTER_API("_ComputeOp")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(ComputeOpNode::make);
*ret = ComputeOpNode::make(args[0],
args[1],
args[2],
args[3],
args[4]);
});
TVM_REGISTER_API("_ScanOp") TVM_REGISTER_API("_ScanOp")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(ScanOpNode::make);
*ret = ScanOpNode::make(args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[6],
args[7]);
});
TVM_REGISTER_API("_TensorComputeOp") TVM_REGISTER_API("_TensorComputeOp")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(TensorComputeOpNode::make);
*ret = TensorComputeOpNode::make(args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[6],
args[7]);
});
TVM_REGISTER_API("_ExternOp") TVM_REGISTER_API("_ExternOp")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(ExternOpNode::make);
*ret = ExternOpNode::make(args[0],
args[1],
args[2],
args[3],
args[4],
args[5],
args[6]);
});
TVM_REGISTER_API("_HybridOp") TVM_REGISTER_API("_HybridOp")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(HybridOpNode::make);
*ret = HybridOpNode::make(args[0],
args[1],
args[2],
args[3],
args[4],
args[5]);
});
TVM_REGISTER_API("_OpGetOutput") TVM_REGISTER_API("_OpGetOutput")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<Tensor(Operation, int64_t)>([](Operation op, int64_t output) {
*ret = args[0].operator Operation().output( return op.output(static_cast<size_t>(output));
static_cast<size_t>(args[1].operator int64_t())); });
});
TVM_REGISTER_API("_OpNumOutputs") TVM_REGISTER_API("_OpNumOutputs")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method<Operation>(&OperationNode::num_outputs);
*ret = args[0].operator Operation()->num_outputs();
});
TVM_REGISTER_API("_OpInputTensors") TVM_REGISTER_API("_OpInputTensors")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method<Operation>(&OperationNode::InputTensors);
*ret = args[0].operator Operation()->InputTensors();
});
TVM_REGISTER_API("_IterVar") TVM_REGISTER_API("_IterVar")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<IterVar(Range, Var, int, std::string)>([](
*ret = IterVarNode::make( Range dom, Var var, int iter_type, std::string thread_tag
args[0], args[1], ) {
static_cast<IterVarType>(args[2].operator int()), return IterVarNode::make(
args[3]); dom, var,
}); static_cast<IterVarType>(iter_type),
thread_tag);
});
TVM_REGISTER_API("_CreateSchedule") TVM_REGISTER_API("_CreateSchedule")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(create_schedule);
*ret = create_schedule(args[0].operator Array<Operation>());
});
TVM_REGISTER_API("_StageSetScope") TVM_REGISTER_API("_StageSetScope")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Stage::set_scope);
args[0].operator Stage()
.set_scope(args[1]);
});
TVM_REGISTER_API("_StageBind") TVM_REGISTER_API("_StageBind")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Stage::bind);
args[0].operator Stage()
.bind(args[1], args[2]);
});
TVM_REGISTER_API("_StageSplitByFactor") TVM_REGISTER_API("_StageSplitByFactor")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<Array<IterVar>(Stage, IterVar, Expr)>([](
IterVar outer, inner; Stage stage, IterVar parent, Expr factor
args[0].operator Stage() ) {
.split(args[1], args[2], &outer, &inner); IterVar outer, inner;
*ret = Array<IterVar>({outer, inner}); stage.split(parent, factor, &outer, &inner);
}); return Array<IterVar>({outer, inner});
});
TVM_REGISTER_API("_StageSplitByNParts") TVM_REGISTER_API("_StageSplitByNParts")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<Array<IterVar>(Stage, IterVar, Expr)>([](
IterVar outer, inner; Stage stage, IterVar parent, Expr nparts
args[0].operator Stage() ) {
.split_by_nparts(args[1], args[2], &outer, &inner); IterVar outer, inner;
*ret = Array<IterVar>({outer, inner}); stage.split_by_nparts(parent, nparts, &outer, &inner);
}); return Array<IterVar>({outer, inner});
});
TVM_REGISTER_API("_StageFuse") TVM_REGISTER_API("_StageFuse")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<IterVar(Stage, Array<IterVar>)>([](Stage stage, Array<IterVar> axes) {
IterVar fused; IterVar fused;
args[0].operator Stage() stage.fuse(axes, &fused);
.fuse(args[1], &fused); return fused;
*ret = fused;
}); });
TVM_REGISTER_API("_StageComputeAt") TVM_REGISTER_API("_StageComputeAt")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Stage::compute_at);
args[0].operator Stage()
.compute_at(args[1], args[2]);
});
TVM_REGISTER_API("_StageComputeInline") TVM_REGISTER_API("_StageComputeInline")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Stage::compute_inline);
args[0].operator Stage()
.compute_inline();
});
TVM_REGISTER_API("_StageComputeRoot") TVM_REGISTER_API("_StageComputeRoot")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Stage::compute_root);
args[0].operator Stage()
.compute_root();
});
TVM_REGISTER_API("_StageReorder") TVM_REGISTER_API("_StageReorder")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Stage::reorder);
args[0].operator Stage()
.reorder(args[1]);
});
TVM_REGISTER_API("_StageTile") TVM_REGISTER_API("_StageTile")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<Array<IterVar>(Stage, IterVar, IterVar, Expr, Expr)>([](
Stage stage,
IterVar x_parent, IterVar y_parent,
Expr x_factor, Expr y_factor
) {
IterVar x_outer, y_outer, x_inner, y_inner; IterVar x_outer, y_outer, x_inner, y_inner;
args[0].operator Stage() stage.tile(x_parent, y_parent,
.tile(args[1], args[2], x_factor, y_factor,
args[3], args[4], &x_outer, &y_outer,
&x_outer, &y_outer, &x_inner, &y_inner);
&x_inner, &y_inner); return Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
}); });
TVM_REGISTER_API("_StageEnvThreads") TVM_REGISTER_API("_StageEnvThreads")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Stage::env_threads);
args[0].operator Stage()
.env_threads(args[1]);
});
TVM_REGISTER_API("_StageSetStorePredicate") TVM_REGISTER_API("_StageSetStorePredicate")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Stage::set_store_predicate);
args[0].operator Stage()
.set_store_predicate(args[1]);
});
TVM_REGISTER_API("_StageUnroll") TVM_REGISTER_API("_StageUnroll")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Stage::unroll);
args[0].operator Stage()
.unroll(args[1]);
});
TVM_REGISTER_API("_StageVectorize") TVM_REGISTER_API("_StageVectorize")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Stage::vectorize);
args[0].operator Stage()
.vectorize(args[1]);
});
TVM_REGISTER_API("_StageTensorize") TVM_REGISTER_API("_StageTensorize")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Stage::tensorize);
args[0].operator Stage()
.tensorize(args[1], args[2]);
});
TVM_REGISTER_API("_StageParallel") TVM_REGISTER_API("_StageParallel")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Stage::parallel);
args[0].operator Stage()
.parallel(args[1]);
});
TVM_REGISTER_API("_StagePragma") TVM_REGISTER_API("_StagePragma")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Stage::pragma);
args[0].operator Stage()
.pragma(args[1], args[2], args[3]);
});
TVM_REGISTER_API("_StagePrefetch") TVM_REGISTER_API("_StagePrefetch")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_method(&Stage::prefetch);
args[0].operator Stage()
.prefetch(args[1], args[2], args[3]);
});
TVM_REGISTER_API("_StageStorageAlign") TVM_REGISTER_API("_StageStorageAlign")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_method(&Stage::storage_align);
args[0].operator Stage()
.storage_align(args[1], args[2], args[3]);
});
TVM_REGISTER_API("_StageDoubleBuffer") TVM_REGISTER_API("_StageDoubleBuffer")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_method(&Stage::double_buffer);
args[0].operator Stage().double_buffer();
});
TVM_REGISTER_API("_StageOpenGL") TVM_REGISTER_API("_StageOpenGL")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_method(&Stage::opengl);
args[0].operator Stage().opengl();
});
TVM_REGISTER_API("_ScheduleNormalize") TVM_REGISTER_API("_ScheduleNormalize")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Schedule::normalize);
*ret = args[0].operator Schedule()
.normalize();
});
TVM_REGISTER_API("_ScheduleCreateGroup") TVM_REGISTER_API("_ScheduleCreateGroup")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Schedule::create_group);
*ret = args[0].operator Schedule()
.create_group(args[1], args[2], args[3]);
});
TVM_REGISTER_API("_ScheduleCacheRead") TVM_REGISTER_API("_ScheduleCacheRead")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Schedule::cache_read);
*ret = args[0].operator Schedule()
.cache_read(args[1], args[2], args[3]);
});
TVM_REGISTER_API("_ScheduleCacheWrite") TVM_REGISTER_API("_ScheduleCacheWrite")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
...@@ -594,16 +431,9 @@ TVM_REGISTER_API("_ScheduleCacheWrite") ...@@ -594,16 +431,9 @@ TVM_REGISTER_API("_ScheduleCacheWrite")
}); });
TVM_REGISTER_API("_ScheduleRFactor") TVM_REGISTER_API("_ScheduleRFactor")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method(&Schedule::rfactor);
*ret = args[0].operator Schedule()
.rfactor(args[1], args[2], args[3]);
});
TVM_REGISTER_API("_CommReducerCombine") TVM_REGISTER_API("_CommReducerCombine")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_method<ir::CommReducer>(&ir::CommReducerNode::operator());
const ir::CommReducerNode* combiner =
args[0].operator ir::CommReducer().as<ir::CommReducerNode>();
*ret = (*combiner)(args[1], args[2]);
});
} // namespace tvm } // namespace tvm
...@@ -119,68 +119,43 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit") ...@@ -119,68 +119,43 @@ TVM_REGISTER_API("ir_pass.PostOrderVisit")
}); });
// make from two arguments // make from two arguments
#define REGISTER_PASS1(PassName) \ #define REGISTER_PASS(PassName) \
TVM_REGISTER_API("ir_pass."#PassName) \ TVM_REGISTER_API("ir_pass."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body_typed(PassName); \
*ret = PassName(args[0]); \
}) \
REGISTER_PASS(ConvertSSA);
#define REGISTER_PASS2(PassName) \ REGISTER_PASS(VerifySSA);
TVM_REGISTER_API("ir_pass."#PassName) \ REGISTER_PASS(RewriteUnsafeSelect);
.set_body([](TVMArgs args, TVMRetValue *ret) { \ REGISTER_PASS(Inline);
*ret = PassName(args[0], args[1]); \ REGISTER_PASS(IRTransform);
}) \ REGISTER_PASS(VectorizeLoop);
REGISTER_PASS(UnrollLoop);
#define REGISTER_PASS3(PassName) \ REGISTER_PASS(InjectCopyIntrin);
TVM_REGISTER_API("ir_pass."#PassName) \ REGISTER_PASS(ThreadSync);
.set_body([](TVMArgs args, TVMRetValue *ret) { \ REGISTER_PASS(MakeAPI);
*ret = PassName(args[0], args[1], args[2]); \ REGISTER_PASS(BindDeviceType);
}) \ REGISTER_PASS(SplitHostDevice);
REGISTER_PASS(StorageRewrite);
#define REGISTER_PASS4(PassName) \ REGISTER_PASS(CoProcSync);
TVM_REGISTER_API("ir_pass."#PassName) \ REGISTER_PASS(LowerStorageAccessInfo);
.set_body([](TVMArgs args, TVMRetValue *ret) { \ REGISTER_PASS(InjectVirtualThread);
*ret = PassName(args[0], args[1], args[2], args[3]); \ REGISTER_PASS(InjectPrefetch);
}) \ REGISTER_PASS(InjectDoubleBuffer);
REGISTER_PASS(LoopPartition);
#define REGISTER_PASS5(PassName) \ REGISTER_PASS(RemoveNoOp);
TVM_REGISTER_API("ir_pass."#PassName) \ REGISTER_PASS(SplitPipeline);
.set_body([](TVMArgs args, TVMRetValue *ret) { \ REGISTER_PASS(LiftAttrScope);
*ret = PassName(args[0], args[1], args[2], args[3], args[4]); \ REGISTER_PASS(NarrowChannelAccess);
}) \ REGISTER_PASS(LowerThreadAllreduce);
REGISTER_PASS(LowerWarpMemory);
REGISTER_PASS1(ConvertSSA); REGISTER_PASS(RemapThreadAxis);
REGISTER_PASS1(VerifySSA); REGISTER_PASS(LowerIntrin);
REGISTER_PASS1(RewriteUnsafeSelect); REGISTER_PASS(LowerTVMBuiltin);
REGISTER_PASS4(Inline); REGISTER_PASS(CombineContextCall);
REGISTER_PASS4(IRTransform); REGISTER_PASS(VerifyMemory);
REGISTER_PASS1(VectorizeLoop); REGISTER_PASS(VerifyGPUCode);
REGISTER_PASS5(UnrollLoop); REGISTER_PASS(DecorateDeviceScope);
REGISTER_PASS3(InjectCopyIntrin); REGISTER_PASS(InstrumentBoundCheckers);
REGISTER_PASS2(ThreadSync);
REGISTER_PASS5(MakeAPI);
REGISTER_PASS2(BindDeviceType);
REGISTER_PASS1(SplitHostDevice);
REGISTER_PASS1(StorageRewrite);
REGISTER_PASS1(CoProcSync);
REGISTER_PASS1(LowerStorageAccessInfo);
REGISTER_PASS1(InjectVirtualThread);
REGISTER_PASS1(InjectPrefetch);
REGISTER_PASS2(InjectDoubleBuffer);
REGISTER_PASS2(LoopPartition);
REGISTER_PASS1(RemoveNoOp);
REGISTER_PASS2(SplitPipeline);
REGISTER_PASS2(LiftAttrScope);
REGISTER_PASS1(NarrowChannelAccess);
REGISTER_PASS2(LowerThreadAllreduce);
REGISTER_PASS2(LowerWarpMemory);
REGISTER_PASS2(RemapThreadAxis);
REGISTER_PASS2(LowerIntrin);
REGISTER_PASS1(LowerTVMBuiltin);
REGISTER_PASS1(CombineContextCall);
REGISTER_PASS2(VerifyMemory);
REGISTER_PASS2(VerifyGPUCode);
REGISTER_PASS1(DecorateDeviceScope);
REGISTER_PASS1(InstrumentBoundCheckers);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
...@@ -33,15 +33,11 @@ namespace tvm { ...@@ -33,15 +33,11 @@ namespace tvm {
namespace schedule { namespace schedule {
TVM_REGISTER_API("schedule.AutoInlineElemWise") TVM_REGISTER_API("schedule.AutoInlineElemWise")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(AutoInlineElemWise);
AutoInlineElemWise(args[0]);
});
TVM_REGISTER_API("schedule.AutoInlineInjective") TVM_REGISTER_API("schedule.AutoInlineInjective")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(AutoInlineInjective);
AutoInlineInjective(args[0]);
});
TVM_REGISTER_API("schedule.ScheduleOps") TVM_REGISTER_API("schedule.ScheduleOps")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
...@@ -51,25 +47,17 @@ TVM_REGISTER_API("schedule.ScheduleOps") ...@@ -51,25 +47,17 @@ TVM_REGISTER_API("schedule.ScheduleOps")
*ret = ScheduleOps(args[0], args[1], args[2]); *ret = ScheduleOps(args[0], args[1], args[2]);
}); });
#define REGISTER_SCHEDULE_PASS1(PassName) \ #define REGISTER_SCHEDULE_PASS(PassName) \
TVM_REGISTER_API("schedule."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = PassName(args[0]); \
}) \
#define REGISTER_SCHEDULE_PASS2(PassName) \
TVM_REGISTER_API("schedule."#PassName) \ TVM_REGISTER_API("schedule."#PassName) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body_typed(PassName); \
*ret = PassName(args[0], args[1]); \
}) \
REGISTER_SCHEDULE_PASS1(InferBound); REGISTER_SCHEDULE_PASS(InferBound);
REGISTER_SCHEDULE_PASS1(CreateReadGraph); REGISTER_SCHEDULE_PASS(CreateReadGraph);
REGISTER_SCHEDULE_PASS2(PostDFSOrder); REGISTER_SCHEDULE_PASS(PostDFSOrder);
REGISTER_SCHEDULE_PASS1(CreateAttachPath); REGISTER_SCHEDULE_PASS(CreateAttachPath);
REGISTER_SCHEDULE_PASS1(ScanGetBody); REGISTER_SCHEDULE_PASS(ScanGetBody);
REGISTER_SCHEDULE_PASS1(ScanFixPointAnalysis); REGISTER_SCHEDULE_PASS(ScanFixPointAnalysis);
} // namespace schedule } // namespace schedule
} // namespace tvm } // namespace tvm
...@@ -263,8 +263,6 @@ runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) { ...@@ -263,8 +263,6 @@ runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
} }
TVM_REGISTER_API("codegen.build_opencl") TVM_REGISTER_API("codegen.build_opencl")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(BuildOpenCL);
*rv = BuildOpenCL(args[0]);
});
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -302,9 +302,7 @@ runtime::Module BuildOpenGL(Array<LoweredFunc> funcs) { ...@@ -302,9 +302,7 @@ runtime::Module BuildOpenGL(Array<LoweredFunc> funcs) {
} }
TVM_REGISTER_API("codegen.build_opengl") TVM_REGISTER_API("codegen.build_opengl")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(BuildOpenGL);
*rv = BuildOpenGL(args[0]);
});
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -164,9 +164,7 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) { ...@@ -164,9 +164,7 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) {
} }
TVM_REGISTER_API("codegen.build_sdaccel") TVM_REGISTER_API("codegen.build_sdaccel")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(BuildSDAccel);
*rv = BuildSDAccel(args[0], args[1]);
});
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -265,9 +265,7 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) { ...@@ -265,9 +265,7 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
} }
TVM_REGISTER_API("codegen.build_rocm") TVM_REGISTER_API("codegen.build_rocm")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(BuildAMDGPU);
*rv = BuildAMDGPU(args[0], args[1]);
});
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
......
...@@ -243,9 +243,7 @@ runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) { ...@@ -243,9 +243,7 @@ runtime::Module BuildNVPTX(Array<LoweredFunc> funcs, std::string target) {
} }
TVM_REGISTER_API("codegen.build_nvptx") TVM_REGISTER_API("codegen.build_nvptx")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(BuildNVPTX);
*rv = BuildNVPTX(args[0], args[1]);
});
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
......
...@@ -155,8 +155,6 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) { ...@@ -155,8 +155,6 @@ runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
} }
TVM_REGISTER_API("codegen.build_cuda") TVM_REGISTER_API("codegen.build_cuda")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(BuildCUDA);
*rv = BuildCUDA(args[0]);
});
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -188,8 +188,6 @@ runtime::Module DeviceSourceModuleCreate( ...@@ -188,8 +188,6 @@ runtime::Module DeviceSourceModuleCreate(
} }
TVM_REGISTER_GLOBAL("module.source_module_create") TVM_REGISTER_GLOBAL("module.source_module_create")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(SourceModuleCreate);
*rv = SourceModuleCreate(args[0], args[1]);
});
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -103,9 +103,7 @@ runtime::Module BuildSPIRV(Array<LoweredFunc> funcs) { ...@@ -103,9 +103,7 @@ runtime::Module BuildSPIRV(Array<LoweredFunc> funcs) {
} }
TVM_REGISTER_API("codegen.build_vulkan") TVM_REGISTER_API("codegen.build_vulkan")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(BuildSPIRV);
*rv = BuildSPIRV(args[0]);
});
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -522,8 +522,6 @@ runtime::Module BuildStackVM(const Array<LoweredFunc>& funcs) { ...@@ -522,8 +522,6 @@ runtime::Module BuildStackVM(const Array<LoweredFunc>& funcs) {
} }
TVM_REGISTER_API("codegen.build_stackvm") TVM_REGISTER_API("codegen.build_stackvm")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(BuildStackVM);
*rv = BuildStackVM(args[0]);
});
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -51,9 +51,7 @@ Closure ClosureNode::make(tvm::Map<Var, Value> env, Function func) { ...@@ -51,9 +51,7 @@ Closure ClosureNode::make(tvm::Map<Var, Value> env, Function func) {
} }
TVM_REGISTER_API("relay._make.Closure") TVM_REGISTER_API("relay._make.Closure")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(ClosureNode::make);
*ret = ClosureNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ClosureNode>([](const ClosureNode* node, tvm::IRPrinter* p) { .set_dispatch<ClosureNode>([](const ClosureNode* node, tvm::IRPrinter* p) {
...@@ -67,9 +65,7 @@ TupleValue TupleValueNode::make(tvm::Array<Value> value) { ...@@ -67,9 +65,7 @@ TupleValue TupleValueNode::make(tvm::Array<Value> value) {
} }
TVM_REGISTER_API("relay._make.TupleValue") TVM_REGISTER_API("relay._make.TupleValue")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(TupleValueNode::make);
*ret = TupleValueNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleValueNode>([](const TupleValueNode* node, tvm::IRPrinter* p) { .set_dispatch<TupleValueNode>([](const TupleValueNode* node, tvm::IRPrinter* p) {
...@@ -90,10 +86,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -90,10 +86,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
}); });
TVM_REGISTER_API("relay._make.TensorValue") TVM_REGISTER_API("relay._make.TensorValue")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(TensorValueNode::make);
runtime::NDArray data = args[0];
*ret = TensorValueNode::make(data);
});
RefValue RefValueNode::make(Value value) { RefValue RefValueNode::make(Value value) {
NodePtr<RefValueNode> n = make_node<RefValueNode>(); NodePtr<RefValueNode> n = make_node<RefValueNode>();
...@@ -102,9 +95,7 @@ RefValue RefValueNode::make(Value value) { ...@@ -102,9 +95,7 @@ RefValue RefValueNode::make(Value value) {
} }
TVM_REGISTER_API("relay._make.RefValue") TVM_REGISTER_API("relay._make.RefValue")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(RefValueNode::make);
*ret = RefValueNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefValueNode>([](const RefValueNode* node, .set_dispatch<RefValueNode>([](const RefValueNode* node,
...@@ -121,9 +112,7 @@ ConstructorValue ConstructorValueNode::make(Constructor constructor, ...@@ -121,9 +112,7 @@ ConstructorValue ConstructorValueNode::make(Constructor constructor,
} }
TVM_REGISTER_API("relay._make.ConstructorValue") TVM_REGISTER_API("relay._make.ConstructorValue")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(ConstructorValueNode::make);
*ret = ConstructorValueNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstructorValueNode>([](const ConstructorValueNode* node, .set_dispatch<ConstructorValueNode>([](const ConstructorValueNode* node,
...@@ -614,9 +603,7 @@ CreateInterpreter( ...@@ -614,9 +603,7 @@ CreateInterpreter(
} }
TVM_REGISTER_API("relay.backend.CreateInterpreter") TVM_REGISTER_API("relay.backend.CreateInterpreter")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(CreateInterpreter);
*ret = CreateInterpreter(args[0], args[1], args[2]);
});
TVM_REGISTER_NODE_TYPE(ClosureNode); TVM_REGISTER_NODE_TYPE(ClosureNode);
TVM_REGISTER_NODE_TYPE(TupleValueNode); TVM_REGISTER_NODE_TYPE(TupleValueNode);
......
...@@ -36,9 +36,7 @@ PatternWildcard PatternWildcardNode::make() { ...@@ -36,9 +36,7 @@ PatternWildcard PatternWildcardNode::make() {
TVM_REGISTER_NODE_TYPE(PatternWildcardNode); TVM_REGISTER_NODE_TYPE(PatternWildcardNode);
TVM_REGISTER_API("relay._make.PatternWildcard") TVM_REGISTER_API("relay._make.PatternWildcard")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(PatternWildcardNode::make);
*ret = PatternWildcardNode::make();
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternWildcardNode>([](const PatternWildcardNode* node, .set_dispatch<PatternWildcardNode>([](const PatternWildcardNode* node,
...@@ -55,9 +53,7 @@ PatternVar PatternVarNode::make(tvm::relay::Var var) { ...@@ -55,9 +53,7 @@ PatternVar PatternVarNode::make(tvm::relay::Var var) {
TVM_REGISTER_NODE_TYPE(PatternVarNode); TVM_REGISTER_NODE_TYPE(PatternVarNode);
TVM_REGISTER_API("relay._make.PatternVar") TVM_REGISTER_API("relay._make.PatternVar")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(PatternVarNode::make);
*ret = PatternVarNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternVarNode>([](const PatternVarNode* node, .set_dispatch<PatternVarNode>([](const PatternVarNode* node,
...@@ -76,9 +72,7 @@ PatternConstructor PatternConstructorNode::make(Constructor constructor, ...@@ -76,9 +72,7 @@ PatternConstructor PatternConstructorNode::make(Constructor constructor,
TVM_REGISTER_NODE_TYPE(PatternConstructorNode); TVM_REGISTER_NODE_TYPE(PatternConstructorNode);
TVM_REGISTER_API("relay._make.PatternConstructor") TVM_REGISTER_API("relay._make.PatternConstructor")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(PatternConstructorNode::make);
*ret = PatternConstructorNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PatternConstructorNode>([](const PatternConstructorNode* node, .set_dispatch<PatternConstructorNode>([](const PatternConstructorNode* node,
...@@ -100,9 +94,7 @@ Constructor ConstructorNode::make(std::string name_hint, ...@@ -100,9 +94,7 @@ Constructor ConstructorNode::make(std::string name_hint,
TVM_REGISTER_NODE_TYPE(ConstructorNode); TVM_REGISTER_NODE_TYPE(ConstructorNode);
TVM_REGISTER_API("relay._make.Constructor") TVM_REGISTER_API("relay._make.Constructor")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(ConstructorNode::make);
*ret = ConstructorNode::make(args[0], args[1], args[2]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstructorNode>([](const ConstructorNode* node, .set_dispatch<ConstructorNode>([](const ConstructorNode* node,
...@@ -124,9 +116,7 @@ TypeData TypeDataNode::make(GlobalTypeVar header, ...@@ -124,9 +116,7 @@ TypeData TypeDataNode::make(GlobalTypeVar header,
TVM_REGISTER_NODE_TYPE(TypeDataNode); TVM_REGISTER_NODE_TYPE(TypeDataNode);
TVM_REGISTER_API("relay._make.TypeData") TVM_REGISTER_API("relay._make.TypeData")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(TypeDataNode::make);
*ret = TypeDataNode::make(args[0], args[1], args[2]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeDataNode>([](const TypeDataNode* node, .set_dispatch<TypeDataNode>([](const TypeDataNode* node,
...@@ -145,9 +135,7 @@ Clause ClauseNode::make(Pattern lhs, Expr rhs) { ...@@ -145,9 +135,7 @@ Clause ClauseNode::make(Pattern lhs, Expr rhs) {
TVM_REGISTER_NODE_TYPE(ClauseNode); TVM_REGISTER_NODE_TYPE(ClauseNode);
TVM_REGISTER_API("relay._make.Clause") TVM_REGISTER_API("relay._make.Clause")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(ClauseNode::make);
*ret = ClauseNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ClauseNode>([](const ClauseNode* node, .set_dispatch<ClauseNode>([](const ClauseNode* node,
...@@ -166,9 +154,7 @@ Match MatchNode::make(Expr data, tvm::Array<Clause> clauses) { ...@@ -166,9 +154,7 @@ Match MatchNode::make(Expr data, tvm::Array<Clause> clauses) {
TVM_REGISTER_NODE_TYPE(MatchNode); TVM_REGISTER_NODE_TYPE(MatchNode);
TVM_REGISTER_API("relay._make.Match") TVM_REGISTER_API("relay._make.Match")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(MatchNode::make);
*ret = MatchNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<MatchNode>([](const MatchNode* node, .set_dispatch<MatchNode>([](const MatchNode* node,
......
...@@ -505,18 +505,18 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) { ...@@ -505,18 +505,18 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
// TODO(@jroesch): move to correct namespace? // TODO(@jroesch): move to correct namespace?
TVM_REGISTER_API("relay._make._alpha_equal") TVM_REGISTER_API("relay._make._alpha_equal")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
*ret = AlphaEqualHandler(false).Equal(args[0], args[1]); return AlphaEqualHandler(false).Equal(a, b);
}); });
TVM_REGISTER_API("relay._make._type_alpha_equal") TVM_REGISTER_API("relay._make._type_alpha_equal")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<bool(Type, Type)>([](Type a, Type b) {
*ret = AlphaEqualHandler(false).TypeEqual(args[0], args[1]); return AlphaEqualHandler(false).TypeEqual(a, b);
}); });
TVM_REGISTER_API("relay._make._graph_equal") TVM_REGISTER_API("relay._make._graph_equal")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
*ret = AlphaEqualHandler(true).Equal(args[0], args[1]); return AlphaEqualHandler(true).Equal(a, b);
}); });
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -52,9 +52,7 @@ SourceName SourceName::Get(const std::string& name) { ...@@ -52,9 +52,7 @@ SourceName SourceName::Get(const std::string& name) {
} }
TVM_REGISTER_API("relay._make.SourceName") TVM_REGISTER_API("relay._make.SourceName")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(SourceName::Get);
*ret = SourceName::Get(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SourceNameNode>([](const SourceNameNode* node, tvm::IRPrinter* p) { .set_dispatch<SourceNameNode>([](const SourceNameNode* node, tvm::IRPrinter* p) {
...@@ -78,9 +76,7 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) { ...@@ -78,9 +76,7 @@ Span SpanNode::make(SourceName source, int lineno, int col_offset) {
TVM_REGISTER_NODE_TYPE(SpanNode); TVM_REGISTER_NODE_TYPE(SpanNode);
TVM_REGISTER_API("relay._make.Span") TVM_REGISTER_API("relay._make.Span")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(SpanNode::make);
*ret = SpanNode::make(args[0], args[1], args[2]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SpanNode>([](const SpanNode* node, tvm::IRPrinter* p) { .set_dispatch<SpanNode>([](const SpanNode* node, tvm::IRPrinter* p) {
...@@ -91,11 +87,9 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -91,11 +87,9 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(IdNode); TVM_REGISTER_NODE_TYPE(IdNode);
TVM_REGISTER_API("relay._base.set_span") TVM_REGISTER_API("relay._base.set_span")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<void(NodeRef, Span)>([](NodeRef node_ref, Span sp) {
NodeRef node_ref = args[0];
auto rn = node_ref.as_derived<RelayNode>(); auto rn = node_ref.as_derived<RelayNode>();
CHECK(rn); CHECK(rn);
Span sp = args[1];
rn->span = sp; rn->span = sp;
}); });
......
...@@ -39,9 +39,7 @@ Constant ConstantNode::make(runtime::NDArray data) { ...@@ -39,9 +39,7 @@ Constant ConstantNode::make(runtime::NDArray data) {
TVM_REGISTER_NODE_TYPE(ConstantNode); TVM_REGISTER_NODE_TYPE(ConstantNode);
TVM_REGISTER_API("relay._make.Constant") TVM_REGISTER_API("relay._make.Constant")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(ConstantNode::make);
*ret = ConstantNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<ConstantNode>([](const ConstantNode* node, tvm::IRPrinter* p) { .set_dispatch<ConstantNode>([](const ConstantNode* node, tvm::IRPrinter* p) {
...@@ -73,9 +71,7 @@ Tuple TupleNode::make(tvm::Array<relay::Expr> fields) { ...@@ -73,9 +71,7 @@ Tuple TupleNode::make(tvm::Array<relay::Expr> fields) {
TVM_REGISTER_NODE_TYPE(TupleNode); TVM_REGISTER_NODE_TYPE(TupleNode);
TVM_REGISTER_API("relay._make.Tuple") TVM_REGISTER_API("relay._make.Tuple")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(TupleNode::make);
*ret = TupleNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleNode>([](const TupleNode* node, tvm::IRPrinter* p) { .set_dispatch<TupleNode>([](const TupleNode* node, tvm::IRPrinter* p) {
...@@ -99,9 +95,7 @@ Var VarNode::make(std::string name_hint, Type type_annotation) { ...@@ -99,9 +95,7 @@ Var VarNode::make(std::string name_hint, Type type_annotation) {
TVM_REGISTER_NODE_TYPE(VarNode); TVM_REGISTER_NODE_TYPE(VarNode);
TVM_REGISTER_API("relay._make.Var") TVM_REGISTER_API("relay._make.Var")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(static_cast<Var (*)(std::string, Type)>(VarNode::make));
*ret = VarNode::make(args[0].operator std::string(), args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<VarNode>([](const VarNode* node, tvm::IRPrinter* p) { .set_dispatch<VarNode>([](const VarNode* node, tvm::IRPrinter* p) {
...@@ -122,9 +116,7 @@ GlobalVar GlobalVarNode::make(std::string name_hint) { ...@@ -122,9 +116,7 @@ GlobalVar GlobalVarNode::make(std::string name_hint) {
TVM_REGISTER_NODE_TYPE(GlobalVarNode); TVM_REGISTER_NODE_TYPE(GlobalVarNode);
TVM_REGISTER_API("relay._make.GlobalVar") TVM_REGISTER_API("relay._make.GlobalVar")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(GlobalVarNode::make);
*ret = GlobalVarNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const GlobalVarNode* node, tvm::IRPrinter* p) { .set_dispatch<GlobalVarNode>([](const GlobalVarNode* node, tvm::IRPrinter* p) {
...@@ -201,9 +193,7 @@ Function FunctionSetAttr(const Function& func, const std::string& key, const Nod ...@@ -201,9 +193,7 @@ Function FunctionSetAttr(const Function& func, const std::string& key, const Nod
TVM_REGISTER_NODE_TYPE(FunctionNode); TVM_REGISTER_NODE_TYPE(FunctionNode);
TVM_REGISTER_API("relay._make.Function") TVM_REGISTER_API("relay._make.Function")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(FunctionNode::make);
*ret = FunctionNode::make(args[0], args[1], args[2], args[3], args[4]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FunctionNode>([](const FunctionNode* node, .set_dispatch<FunctionNode>([](const FunctionNode* node,
...@@ -226,9 +216,7 @@ Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs, ...@@ -226,9 +216,7 @@ Call CallNode::make(Expr op, Array<Expr> args, Attrs attrs,
TVM_REGISTER_NODE_TYPE(CallNode); TVM_REGISTER_NODE_TYPE(CallNode);
TVM_REGISTER_API("relay._make.Call") TVM_REGISTER_API("relay._make.Call")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(CallNode::make);
*ret = CallNode::make(args[0], args[1], args[2], args[3]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) { .set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
...@@ -247,9 +235,7 @@ Let LetNode::make(Var var, Expr value, Expr body) { ...@@ -247,9 +235,7 @@ Let LetNode::make(Var var, Expr value, Expr body) {
TVM_REGISTER_NODE_TYPE(LetNode); TVM_REGISTER_NODE_TYPE(LetNode);
TVM_REGISTER_API("relay._make.Let") TVM_REGISTER_API("relay._make.Let")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(LetNode::make);
*ret = LetNode::make(args[0], args[1], args[2]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<LetNode>([](const LetNode* node, tvm::IRPrinter* p) { .set_dispatch<LetNode>([](const LetNode* node, tvm::IRPrinter* p) {
...@@ -267,9 +253,8 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) { ...@@ -267,9 +253,8 @@ If IfNode::make(Expr cond, Expr true_branch, Expr false_branch) {
TVM_REGISTER_NODE_TYPE(IfNode); TVM_REGISTER_NODE_TYPE(IfNode);
TVM_REGISTER_API("relay._make.If").set_body([](TVMArgs args, TVMRetValue* ret) { TVM_REGISTER_API("relay._make.If")
*ret = IfNode::make(args[0], args[1], args[2]); .set_body_typed(IfNode::make);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<IfNode>([](const IfNode* node, tvm::IRPrinter* p) { .set_dispatch<IfNode>([](const IfNode* node, tvm::IRPrinter* p) {
...@@ -286,9 +271,8 @@ TupleGetItem TupleGetItemNode::make(Expr tuple, int index) { ...@@ -286,9 +271,8 @@ TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
TVM_REGISTER_NODE_TYPE(TupleGetItemNode); TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
TVM_REGISTER_API("relay._make.TupleGetItem").set_body([](TVMArgs args, TVMRetValue* ret) { TVM_REGISTER_API("relay._make.TupleGetItem")
*ret = TupleGetItemNode::make(args[0], args[1]); .set_body_typed(TupleGetItemNode::make);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleGetItemNode>([](const TupleGetItemNode* node, tvm::IRPrinter* p) { .set_dispatch<TupleGetItemNode>([](const TupleGetItemNode* node, tvm::IRPrinter* p) {
...@@ -301,9 +285,8 @@ RefCreate RefCreateNode::make(Expr value) { ...@@ -301,9 +285,8 @@ RefCreate RefCreateNode::make(Expr value) {
return RefCreate(n); return RefCreate(n);
} }
TVM_REGISTER_API("relay._make.RefCreate").set_body([](TVMArgs args, TVMRetValue* ret) { TVM_REGISTER_API("relay._make.RefCreate")
*ret = RefCreateNode::make(args[0]); .set_body_typed(RefCreateNode::make);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefCreateNode>([](const RefCreateNode* node, tvm::IRPrinter* p) { .set_dispatch<RefCreateNode>([](const RefCreateNode* node, tvm::IRPrinter* p) {
...@@ -317,9 +300,7 @@ RefRead RefReadNode::make(Expr ref) { ...@@ -317,9 +300,7 @@ RefRead RefReadNode::make(Expr ref) {
} }
TVM_REGISTER_API("relay._make.RefRead") TVM_REGISTER_API("relay._make.RefRead")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(RefReadNode::make);
*ret = RefReadNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefReadNode>([](const RefReadNode* node, tvm::IRPrinter* p) { .set_dispatch<RefReadNode>([](const RefReadNode* node, tvm::IRPrinter* p) {
...@@ -334,9 +315,7 @@ RefWrite RefWriteNode::make(Expr ref, Expr value) { ...@@ -334,9 +315,7 @@ RefWrite RefWriteNode::make(Expr ref, Expr value) {
} }
TVM_REGISTER_API("relay._make.RefWrite") TVM_REGISTER_API("relay._make.RefWrite")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(RefWriteNode::make);
*ret = RefWriteNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefWriteNode>([](const RefWriteNode* node, tvm::IRPrinter* p) { .set_dispatch<RefWriteNode>([](const RefWriteNode* node, tvm::IRPrinter* p) {
...@@ -344,9 +323,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -344,9 +323,8 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
}); });
TVM_REGISTER_API("relay._expr.TempExprRealize") TVM_REGISTER_API("relay._expr.TempExprRealize")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
TempExpr temp = args[0]; return temp->Realize();
*ret = temp->Realize();
}); });
} // namespace relay } // namespace relay
......
...@@ -346,9 +346,8 @@ void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) { ...@@ -346,9 +346,8 @@ void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) {
} }
TVM_REGISTER_API("relay._ir_pass.post_order_visit") TVM_REGISTER_API("relay._ir_pass.post_order_visit")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed<void(Expr, PackedFunc)>([](Expr expr, PackedFunc f) {
PackedFunc f = args[1]; PostOrderVisit(expr, [f](const Expr& n) {
PostOrderVisit(args[0], [f](const Expr& n) {
f(n); f(n);
}); });
}); });
......
...@@ -410,14 +410,14 @@ size_t StructuralHash::operator()(const Expr& expr) const { ...@@ -410,14 +410,14 @@ size_t StructuralHash::operator()(const Expr& expr) const {
} }
TVM_REGISTER_API("relay._ir_pass._expr_hash") TVM_REGISTER_API("relay._ir_pass._expr_hash")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<int64_t(NodeRef)>([](NodeRef ref) {
*ret = static_cast<int64_t>(RelayHashHandler().Hash(args[0])); return static_cast<int64_t>(RelayHashHandler().Hash(ref));
}); });
TVM_REGISTER_API("relay._ir_pass._type_hash") TVM_REGISTER_API("relay._ir_pass._type_hash")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<int64_t(Type)>([](Type type) {
*ret = static_cast<int64_t>(RelayHashHandler().TypeHash(args[0])); return static_cast<int64_t>(RelayHashHandler().TypeHash(type));
}); });
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -181,66 +181,43 @@ Module ModuleNode::FromExpr( ...@@ -181,66 +181,43 @@ Module ModuleNode::FromExpr(
TVM_REGISTER_NODE_TYPE(ModuleNode); TVM_REGISTER_NODE_TYPE(ModuleNode);
TVM_REGISTER_API("relay._make.Module") TVM_REGISTER_API("relay._make.Module")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(ModuleNode::make);
*ret = ModuleNode::make(args[0], args[1]);
});
TVM_REGISTER_API("relay._make.Module_Add") TVM_REGISTER_API("relay._make.Module_Add")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_method<Module>(&ModuleNode::Add);
Module mod = args[0];
mod->Add(args[1], args[2], args[3]);
});
TVM_REGISTER_API("relay._module.Module_AddDef") TVM_REGISTER_API("relay._module.Module_AddDef")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_method<Module>(&ModuleNode::AddDef);
Module mod = args[0];
mod->AddDef(args[1], args[2]);
});
TVM_REGISTER_API("relay._module.Module_GetGlobalVar") TVM_REGISTER_API("relay._module.Module_GetGlobalVar")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_method<Module>(&ModuleNode::GetGlobalVar);
Module mod = args[0];
*ret = mod->GetGlobalVar(args[1]);
});
TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVar") TVM_REGISTER_API("relay._module.Module_GetGlobalTypeVar")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_method<Module>(&ModuleNode::GetGlobalTypeVar);
Module mod = args[0];
*ret = mod->GetGlobalTypeVar(args[1]);
});
TVM_REGISTER_API("relay._module.Module_Lookup") TVM_REGISTER_API("relay._module.Module_Lookup")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed<Function(Module, GlobalVar)>([](Module mod, GlobalVar var) {
Module mod = args[0]; return mod->Lookup(var);
GlobalVar var = args[1];
*ret = mod->Lookup(var);
}); });
TVM_REGISTER_API("relay._module.Module_Lookup_str") TVM_REGISTER_API("relay._module.Module_Lookup_str")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed<Function(Module, std::string)>([](Module mod, std::string var) {
Module mod = args[0]; return mod->Lookup(var);
std::string var_name = args[1];
*ret = mod->Lookup(var_name);
}); });
TVM_REGISTER_API("relay._module.Module_LookupDef") TVM_REGISTER_API("relay._module.Module_LookupDef")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed<TypeData(Module, GlobalTypeVar)>([](Module mod, GlobalTypeVar var) {
Module mod = args[0]; return mod->LookupDef(var);
GlobalTypeVar var = args[1];
*ret = mod->LookupDef(var);
}); });
TVM_REGISTER_API("relay._module.Module_LookupDef_str") TVM_REGISTER_API("relay._module.Module_LookupDef_str")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed<TypeData(Module, std::string)>([](Module mod, std::string var) {
Module mod = args[0]; return mod->LookupDef(var);
std::string var_name = args[1];
*ret = mod->LookupDef(var_name);
}); });
TVM_REGISTER_API("relay._module.Module_Update") TVM_REGISTER_API("relay._module.Module_Update")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed<void(Module, Module)>([](Module mod, Module from) {
Module mod = args[0]; mod->Update(from);
mod->Update(args[1]);
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
......
...@@ -56,10 +56,7 @@ IndexExpr TensorTypeNode::Size() const { ...@@ -56,10 +56,7 @@ IndexExpr TensorTypeNode::Size() const {
TVM_REGISTER_NODE_TYPE(TensorTypeNode); TVM_REGISTER_NODE_TYPE(TensorTypeNode);
TVM_REGISTER_API("relay._make.TensorType") TVM_REGISTER_API("relay._make.TensorType")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(TensorTypeNode::make);
Array<IndexExpr> shape = args[0];
*ret = TensorTypeNode::make(shape, args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TensorTypeNode>([](const TensorTypeNode* node, .set_dispatch<TensorTypeNode>([](const TensorTypeNode* node,
...@@ -77,10 +74,8 @@ TypeVar TypeVarNode::make(std::string name, Kind kind) { ...@@ -77,10 +74,8 @@ TypeVar TypeVarNode::make(std::string name, Kind kind) {
TVM_REGISTER_NODE_TYPE(TypeVarNode); TVM_REGISTER_NODE_TYPE(TypeVarNode);
TVM_REGISTER_API("relay._make.TypeVar") TVM_REGISTER_API("relay._make.TypeVar")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<TypeVar(std::string, int)>([](std::string name, int kind) {
int kind = args[1]; return TypeVarNode::make(name, static_cast<Kind>(kind));
*ret =
TypeVarNode::make(args[0], static_cast<Kind>(kind));
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
...@@ -100,10 +95,9 @@ GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) { ...@@ -100,10 +95,9 @@ GlobalTypeVar GlobalTypeVarNode::make(std::string name, Kind kind) {
TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode); TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
TVM_REGISTER_API("relay._make.GlobalTypeVar") TVM_REGISTER_API("relay._make.GlobalTypeVar")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<GlobalTypeVar(std::string, int)>([](std::string name, int kind) {
int kind = args[1]; return GlobalTypeVarNode::make(name, static_cast<Kind>(kind));
*ret = GlobalTypeVarNode::make(args[0], static_cast<Kind>(kind)); });
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<GlobalTypeVarNode>([](const GlobalTypeVarNode *node, .set_dispatch<GlobalTypeVarNode>([](const GlobalTypeVarNode *node,
...@@ -122,9 +116,7 @@ TypeCall TypeCallNode::make(Type func, tvm::Array<Type> args) { ...@@ -122,9 +116,7 @@ TypeCall TypeCallNode::make(Type func, tvm::Array<Type> args) {
TVM_REGISTER_NODE_TYPE(TypeCallNode); TVM_REGISTER_NODE_TYPE(TypeCallNode);
TVM_REGISTER_API("relay._make.TypeCall") TVM_REGISTER_API("relay._make.TypeCall")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(TypeCallNode::make);
*ret = TypeCallNode::make(args[0], args[1]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeCallNode>([](const TypeCallNode* node, .set_dispatch<TypeCallNode>([](const TypeCallNode* node,
...@@ -142,9 +134,8 @@ IncompleteType IncompleteTypeNode::make(Kind kind) { ...@@ -142,9 +134,8 @@ IncompleteType IncompleteTypeNode::make(Kind kind) {
TVM_REGISTER_NODE_TYPE(IncompleteTypeNode); TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
TVM_REGISTER_API("relay._make.IncompleteType") TVM_REGISTER_API("relay._make.IncompleteType")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<IncompleteType(int)>([](int kind) {
int kind = args[0]; return IncompleteTypeNode::make(static_cast<Kind>(kind));
*ret = IncompleteTypeNode::make(static_cast<Kind>(kind));
}); });
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
...@@ -169,9 +160,7 @@ FuncType FuncTypeNode::make(tvm::Array<Type> arg_types, ...@@ -169,9 +160,7 @@ FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
TVM_REGISTER_NODE_TYPE(FuncTypeNode); TVM_REGISTER_NODE_TYPE(FuncTypeNode);
TVM_REGISTER_API("relay._make.FuncType") TVM_REGISTER_API("relay._make.FuncType")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(FuncTypeNode::make);
*ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FuncTypeNode>([](const FuncTypeNode* node, .set_dispatch<FuncTypeNode>([](const FuncTypeNode* node,
...@@ -196,9 +185,7 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func, ...@@ -196,9 +185,7 @@ TypeRelation TypeRelationNode::make(TypeRelationFn func,
TVM_REGISTER_NODE_TYPE(TypeRelationNode); TVM_REGISTER_NODE_TYPE(TypeRelationNode);
TVM_REGISTER_API("relay._make.TypeRelation") TVM_REGISTER_API("relay._make.TypeRelation")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(TypeRelationNode::make);
*ret = TypeRelationNode::make(args[0], args[1], args[2], args[3]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeRelationNode>([](const TypeRelationNode* node, tvm::IRPrinter* p) { .set_dispatch<TypeRelationNode>([](const TypeRelationNode* node, tvm::IRPrinter* p) {
...@@ -216,9 +203,7 @@ TupleType TupleTypeNode::make(Array<Type> fields) { ...@@ -216,9 +203,7 @@ TupleType TupleTypeNode::make(Array<Type> fields) {
TVM_REGISTER_NODE_TYPE(TupleTypeNode); TVM_REGISTER_NODE_TYPE(TupleTypeNode);
TVM_REGISTER_API("relay._make.TupleType") TVM_REGISTER_API("relay._make.TupleType")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(TupleTypeNode::make);
*ret = TupleTypeNode::make(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleTypeNode>([](const TupleTypeNode* node, .set_dispatch<TupleTypeNode>([](const TupleTypeNode* node,
...@@ -233,9 +218,7 @@ RefType RefTypeNode::make(Type value) { ...@@ -233,9 +218,7 @@ RefType RefTypeNode::make(Type value) {
} }
TVM_REGISTER_API("relay._make.RefType") TVM_REGISTER_API("relay._make.RefType")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(RefTypeNode::make);
*ret = RefTypeNode::make(args[0]);
});
TVM_REGISTER_NODE_TYPE(RefTypeNode); TVM_REGISTER_NODE_TYPE(RefTypeNode);
......
...@@ -64,9 +64,7 @@ Expr MakeDebug(Expr expr, std::string name) { ...@@ -64,9 +64,7 @@ Expr MakeDebug(Expr expr, std::string name) {
} }
TVM_REGISTER_API("relay.op._make.debug") TVM_REGISTER_API("relay.op._make.debug")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeDebug);
runtime::detail::unpack_call<Expr, 2>(MakeDebug, args, rv);
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
......
...@@ -105,9 +105,7 @@ Expr MakeResize(Expr data, ...@@ -105,9 +105,7 @@ Expr MakeResize(Expr data,
TVM_REGISTER_API("relay.op.image._make.resize") TVM_REGISTER_API("relay.op.image._make.resize")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeResize);
runtime::detail::unpack_call<Expr, 5>(MakeResize, args, rv);
});
RELAY_REGISTER_OP("image.resize") RELAY_REGISTER_OP("image.resize")
......
...@@ -170,9 +170,7 @@ Expr MakeConv2D(Expr data, ...@@ -170,9 +170,7 @@ Expr MakeConv2D(Expr data,
TVM_REGISTER_API("relay.op.nn._make.conv2d") TVM_REGISTER_API("relay.op.nn._make.conv2d")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeConv2D);
runtime::detail::unpack_call<Expr, 12>(MakeConv2D, args, rv);
});
RELAY_REGISTER_OP("nn.conv2d") RELAY_REGISTER_OP("nn.conv2d")
...@@ -324,9 +322,7 @@ Expr MakeConv2DTranspose(Expr data, ...@@ -324,9 +322,7 @@ Expr MakeConv2DTranspose(Expr data,
TVM_REGISTER_API("relay.op.nn._make.conv2d_transpose") TVM_REGISTER_API("relay.op.nn._make.conv2d_transpose")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeConv2DTranspose);
runtime::detail::unpack_call<Expr, 12>(MakeConv2DTranspose, args, rv);
});
RELAY_REGISTER_OP("nn.conv2d_transpose") RELAY_REGISTER_OP("nn.conv2d_transpose")
.describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution). .describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution).
...@@ -465,9 +461,7 @@ Expr MakeConv2DWinograd(Expr data, ...@@ -465,9 +461,7 @@ Expr MakeConv2DWinograd(Expr data,
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_without_weight_transform") TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_without_weight_transform")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeConv2DWinograd);
runtime::detail::unpack_call<Expr, 13>(MakeConv2DWinograd, args, rv);
});
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform") RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_without_weight_transform")
...@@ -530,9 +524,7 @@ Expr MakeConv2DWinogradWeightTransform(Expr weight, ...@@ -530,9 +524,7 @@ Expr MakeConv2DWinogradWeightTransform(Expr weight,
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_weight_transform") TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_weight_transform")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeConv2DWinogradWeightTransform);
runtime::detail::unpack_call<Expr, 2>(MakeConv2DWinogradWeightTransform, args, rv);
});
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform") RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_weight_transform")
...@@ -580,9 +572,7 @@ Expr MakeConv2DWinogradNNPACK(Expr data, ...@@ -580,9 +572,7 @@ Expr MakeConv2DWinogradNNPACK(Expr data,
} }
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_nnpack_without_weight_transform") TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_nnpack_without_weight_transform")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeConv2DWinogradNNPACK);
runtime::detail::unpack_call<Expr, 12>(MakeConv2DWinogradNNPACK, args, rv);
});
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_without_weight_transform") RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_without_weight_transform")
.describe(R"code(Compute conv2d with winograd nnpack. Only supports NCHW layout. .describe(R"code(Compute conv2d with winograd nnpack. Only supports NCHW layout.
...@@ -649,9 +639,7 @@ Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight, ...@@ -649,9 +639,7 @@ Expr MakeConv2DWinogradNNPACKWeightTransform(Expr weight,
} }
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_nnpack_weight_transform") TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_winograd_nnpack_weight_transform")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeConv2DWinogradNNPACKWeightTransform);
runtime::detail::unpack_call<Expr, 3>(MakeConv2DWinogradNNPACKWeightTransform, args, rv);
});
RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_weight_transform") RELAY_REGISTER_OP("nn.contrib_conv2d_winograd_nnpack_weight_transform")
.describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK. .describe(R"code(Weight transformation of winograd fast convolution algorithm with NNPACK.
...@@ -698,9 +686,7 @@ Expr MakeConv2DNCHWc(Expr data, ...@@ -698,9 +686,7 @@ Expr MakeConv2DNCHWc(Expr data,
} }
TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_NCHWc") TVM_REGISTER_API("relay.op.nn._make.contrib_conv2d_NCHWc")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeConv2DNCHWc);
runtime::detail::unpack_call<Expr, 12>(MakeConv2DNCHWc, args, rv);
});
RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc") RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
...@@ -750,9 +736,7 @@ Expr MakeDepthwiseConv2DNCHWc(Expr data, ...@@ -750,9 +736,7 @@ Expr MakeDepthwiseConv2DNCHWc(Expr data,
} }
TVM_REGISTER_API("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc") TVM_REGISTER_API("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeDepthwiseConv2DNCHWc);
runtime::detail::unpack_call<Expr, 12>(MakeDepthwiseConv2DNCHWc, args, rv);
});
RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc") RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
...@@ -910,9 +894,7 @@ Expr MakeDeformableConv2D(Expr data, ...@@ -910,9 +894,7 @@ Expr MakeDeformableConv2D(Expr data,
} }
TVM_REGISTER_API("relay.op.nn._make.deformable_conv2d") TVM_REGISTER_API("relay.op.nn._make.deformable_conv2d")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeDeformableConv2D);
runtime::detail::unpack_call<Expr, 14>(MakeDeformableConv2D, args, rv);
});
} // namespace relay } // namespace relay
......
...@@ -78,9 +78,7 @@ Expr MakeBiasAdd(Expr data, ...@@ -78,9 +78,7 @@ Expr MakeBiasAdd(Expr data,
TVM_REGISTER_API("relay.op.nn._make.bias_add") TVM_REGISTER_API("relay.op.nn._make.bias_add")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeBiasAdd);
runtime::detail::unpack_call<Expr, 3>(MakeBiasAdd, args, rv);
});
RELAY_REGISTER_OP("nn.bias_add") RELAY_REGISTER_OP("nn.bias_add")
...@@ -145,9 +143,7 @@ Expr MakeDense(Expr data, ...@@ -145,9 +143,7 @@ Expr MakeDense(Expr data,
TVM_REGISTER_API("relay.op.nn._make.dense") TVM_REGISTER_API("relay.op.nn._make.dense")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeDense);
runtime::detail::unpack_call<Expr, 3>(MakeDense, args, rv);
});
RELAY_REGISTER_OP("nn.dense") RELAY_REGISTER_OP("nn.dense")
...@@ -179,9 +175,7 @@ Expr MakeLeakyRelu(Expr data, ...@@ -179,9 +175,7 @@ Expr MakeLeakyRelu(Expr data,
TVM_REGISTER_API("relay.op.nn._make.leaky_relu") TVM_REGISTER_API("relay.op.nn._make.leaky_relu")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeLeakyRelu);
runtime::detail::unpack_call<Expr, 2>(MakeLeakyRelu, args, rv);
});
RELAY_REGISTER_OP("nn.leaky_relu") RELAY_REGISTER_OP("nn.leaky_relu")
...@@ -244,9 +238,7 @@ Expr MakePRelu(Expr data, ...@@ -244,9 +238,7 @@ Expr MakePRelu(Expr data,
TVM_REGISTER_API("relay.op.nn._make.prelu") TVM_REGISTER_API("relay.op.nn._make.prelu")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakePRelu);
runtime::detail::unpack_call<Expr, 3>(MakePRelu, args, rv);
});
RELAY_REGISTER_OP("nn.prelu") RELAY_REGISTER_OP("nn.prelu")
...@@ -276,17 +268,14 @@ where :math:`*` is an channelwise multiplication for each sample in the batch. ...@@ -276,17 +268,14 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
TVM_REGISTER_NODE_TYPE(SoftmaxAttrs); TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
TVM_REGISTER_API("relay.op.nn._make.softmax") TVM_REGISTER_API("relay.op.nn._make.softmax")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed<Call(Expr, int)>([](Expr data, int axis) {
auto make_func = [](Expr data, int axis) { auto attrs = make_node<SoftmaxAttrs>();
auto attrs = make_node<SoftmaxAttrs>(); attrs->axis = axis;
attrs->axis = axis; static const Op& op = Op::Get("nn.softmax");
static const Op& op = Op::Get("nn.softmax"); return CallNode::make(op, {data}, Attrs(attrs), {});
return CallNode::make(op, {data}, Attrs(attrs), {});
};
runtime::detail::unpack_call<Expr, 2>(make_func, args, rv);
}); });
RELAY_REGISTER_OP("nn.softmax") RELAY_REGISTER_OP("nn.softmax")
.describe(R"code(Softmax layer. .describe(R"code(Softmax layer.
...@@ -314,15 +303,11 @@ RELAY_REGISTER_OP("nn.softmax") ...@@ -314,15 +303,11 @@ RELAY_REGISTER_OP("nn.softmax")
// relay.nn.log_softmax // relay.nn.log_softmax
TVM_REGISTER_API("relay.op.nn._make.log_softmax") TVM_REGISTER_API("relay.op.nn._make.log_softmax")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed<Call(Expr, int)>([](Expr data, int axis) {
auto make_func = [](Expr data, int axis) { auto attrs = make_node<SoftmaxAttrs>();
auto attrs = make_node<SoftmaxAttrs>(); attrs->axis = axis;
attrs->axis = axis; static const Op& op = Op::Get("nn.log_softmax");
static const Op& op = Op::Get("nn.log_softmax"); return CallNode::make(op, {data}, Attrs(attrs), {});
return CallNode::make(op, {data}, Attrs(attrs), {});
};
runtime::detail::unpack_call<Expr, 2>(make_func, args, rv);
}); });
RELAY_REGISTER_OP("nn.log_softmax") RELAY_REGISTER_OP("nn.log_softmax")
...@@ -382,9 +367,7 @@ Expr MakeBatchFlatten(Expr data) { ...@@ -382,9 +367,7 @@ Expr MakeBatchFlatten(Expr data) {
TVM_REGISTER_API("relay.op.nn._make.batch_flatten") TVM_REGISTER_API("relay.op.nn._make.batch_flatten")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeBatchFlatten);
runtime::detail::unpack_call<Expr, 1>(MakeBatchFlatten, args, rv);
});
RELAY_REGISTER_OP("nn.batch_flatten") RELAY_REGISTER_OP("nn.batch_flatten")
...@@ -424,7 +407,7 @@ Example:: ...@@ -424,7 +407,7 @@ Example::
// relu // relu
TVM_REGISTER_API("relay.op.nn._make.relu") TVM_REGISTER_API("relay.op.nn._make.relu")
.set_body_typed<Expr(Expr)>([](Expr data) { .set_body_typed<Call(Expr)>([](Expr data) {
static const Op& op = Op::Get("nn.relu"); static const Op& op = Op::Get("nn.relu");
return CallNode::make(op, {data}, Attrs(), {}); return CallNode::make(op, {data}, Attrs(), {});
}); });
...@@ -469,9 +452,7 @@ Expr MakeLRN(Expr data, ...@@ -469,9 +452,7 @@ Expr MakeLRN(Expr data,
} }
TVM_REGISTER_API("relay.op.nn._make.lrn") TVM_REGISTER_API("relay.op.nn._make.lrn")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeLRN);
runtime::detail::unpack_call<Expr, 6>(MakeLRN, args, rv);
});
RELAY_REGISTER_OP("nn.lrn") RELAY_REGISTER_OP("nn.lrn")
.describe(R"code(LRN layer. .describe(R"code(LRN layer.
...@@ -509,9 +490,7 @@ Expr MakeL2Normalize(Expr data, ...@@ -509,9 +490,7 @@ Expr MakeL2Normalize(Expr data,
} }
TVM_REGISTER_API("relay.op.nn._make.l2_normalize") TVM_REGISTER_API("relay.op.nn._make.l2_normalize")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeL2Normalize);
runtime::detail::unpack_call<Expr, 3>(MakeL2Normalize, args, rv);
});
RELAY_REGISTER_OP("nn.l2_normalize") RELAY_REGISTER_OP("nn.l2_normalize")
.describe(R"code(L2 Normalization layer. .describe(R"code(L2 Normalization layer.
...@@ -556,9 +535,7 @@ Expr MakeDropout(Expr data, double rate) { ...@@ -556,9 +535,7 @@ Expr MakeDropout(Expr data, double rate) {
} }
TVM_REGISTER_API("relay.op.nn._make.dropout") TVM_REGISTER_API("relay.op.nn._make.dropout")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeDropout);
runtime::detail::unpack_call<Expr, 2>(MakeDropout, args, rv);
});
RELAY_REGISTER_OP("nn.dropout") RELAY_REGISTER_OP("nn.dropout")
.describe(R"code(Applies the dropout operation to the input array. .describe(R"code(Applies the dropout operation to the input array.
...@@ -622,9 +599,7 @@ Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr movi ...@@ -622,9 +599,7 @@ Expr MakeBatchNorm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr movi
} }
TVM_REGISTER_API("relay.op.nn._make.batch_norm") TVM_REGISTER_API("relay.op.nn._make.batch_norm")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeBatchNorm);
runtime::detail::unpack_call<Expr, 9>(MakeBatchNorm, args, rv);
});
RELAY_REGISTER_OP("nn.batch_norm") RELAY_REGISTER_OP("nn.batch_norm")
.describe(R"code(Batch normalization layer (Ioffe and Szegedy, 2014). .describe(R"code(Batch normalization layer (Ioffe and Szegedy, 2014).
...@@ -711,9 +686,7 @@ Expr MakeBatchMatmul(Expr x, ...@@ -711,9 +686,7 @@ Expr MakeBatchMatmul(Expr x,
TVM_REGISTER_API("relay.op.nn._make.batch_matmul") TVM_REGISTER_API("relay.op.nn._make.batch_matmul")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeBatchMatmul);
runtime::detail::unpack_call<Expr, 2>(MakeBatchMatmul, args, rv);
});
RELAY_REGISTER_OP("nn.batch_matmul") RELAY_REGISTER_OP("nn.batch_matmul")
......
...@@ -115,9 +115,7 @@ Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value) { ...@@ -115,9 +115,7 @@ Expr MakePad(Expr data, Array<Array<IndexExpr> > pad_width, double pad_value) {
} }
TVM_REGISTER_API("relay.op.nn._make.pad") TVM_REGISTER_API("relay.op.nn._make.pad")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakePad);
runtime::detail::unpack_call<Expr, 3>(MakePad, args, rv);
});
RELAY_REGISTER_OP("nn.pad") RELAY_REGISTER_OP("nn.pad")
.describe(R"code(Pad for n-D tensor. .describe(R"code(Pad for n-D tensor.
......
...@@ -186,9 +186,7 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs, ...@@ -186,9 +186,7 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
} }
TVM_REGISTER_API("relay.op.nn._make.max_pool2d") TVM_REGISTER_API("relay.op.nn._make.max_pool2d")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeMaxPool2D);
runtime::detail::unpack_call<Expr, 6>(MakeMaxPool2D, args, rv);
});
RELAY_REGISTER_OP("nn.max_pool2d") RELAY_REGISTER_OP("nn.max_pool2d")
...@@ -242,9 +240,7 @@ Expr MakeAvgPool2D(Expr data, ...@@ -242,9 +240,7 @@ Expr MakeAvgPool2D(Expr data,
TVM_REGISTER_API("relay.op.nn._make.avg_pool2d") TVM_REGISTER_API("relay.op.nn._make.avg_pool2d")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeAvgPool2D);
runtime::detail::unpack_call<Expr, 7>(MakeAvgPool2D, args, rv);
});
RELAY_REGISTER_OP("nn.avg_pool2d") RELAY_REGISTER_OP("nn.avg_pool2d")
...@@ -345,9 +341,7 @@ Expr MakeGlobalAvgPool2D(Expr data, ...@@ -345,9 +341,7 @@ Expr MakeGlobalAvgPool2D(Expr data,
TVM_REGISTER_API("relay.op.nn._make.global_avg_pool2d") TVM_REGISTER_API("relay.op.nn._make.global_avg_pool2d")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeGlobalAvgPool2D);
runtime::detail::unpack_call<Expr, 2>(MakeGlobalAvgPool2D, args, rv);
});
// GlobalAvgPool // GlobalAvgPool
RELAY_REGISTER_OP("nn.global_avg_pool2d") RELAY_REGISTER_OP("nn.global_avg_pool2d")
...@@ -378,9 +372,7 @@ Expr MakeGlobalMaxPool2D(Expr data, ...@@ -378,9 +372,7 @@ Expr MakeGlobalMaxPool2D(Expr data,
} }
TVM_REGISTER_API("relay.op.nn._make.global_max_pool2d") TVM_REGISTER_API("relay.op.nn._make.global_max_pool2d")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeGlobalMaxPool2D);
runtime::detail::unpack_call<Expr, 2>(MakeGlobalMaxPool2D, args, rv);
});
RELAY_REGISTER_OP("nn.global_max_pool2d") RELAY_REGISTER_OP("nn.global_max_pool2d")
......
...@@ -110,9 +110,7 @@ Expr MakeUpSampling(Expr data, ...@@ -110,9 +110,7 @@ Expr MakeUpSampling(Expr data,
TVM_REGISTER_API("relay.op.nn._make.upsampling") TVM_REGISTER_API("relay.op.nn._make.upsampling")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeUpSampling);
runtime::detail::unpack_call<Expr, 4>(MakeUpSampling, args, rv);
});
RELAY_REGISTER_OP("nn.upsampling") RELAY_REGISTER_OP("nn.upsampling")
......
...@@ -265,8 +265,8 @@ bool ReduceRel(const Array<Type>& types, ...@@ -265,8 +265,8 @@ bool ReduceRel(const Array<Type>& types,
#define RELAY_REGISTER_REDUCE_OP(OpName) \ #define RELAY_REGISTER_REDUCE_OP(OpName) \
TVM_REGISTER_API("relay.op._make." OpName) \ TVM_REGISTER_API("relay.op._make." OpName) \
.set_body([](const TVMArgs& args, TVMRetValue* rv) { \ .set_body_typed<Call(Expr, Array<Integer>, bool, bool)>([]( \
auto make_func = [](Expr data, \ Expr data, \
Array<Integer> axis, \ Array<Integer> axis, \
bool keepdims, \ bool keepdims, \
bool exclude) { \ bool exclude) { \
...@@ -276,8 +276,6 @@ bool ReduceRel(const Array<Type>& types, ...@@ -276,8 +276,6 @@ bool ReduceRel(const Array<Type>& types,
attrs->exclude = exclude; \ attrs->exclude = exclude; \
static const Op& op = Op::Get(OpName); \ static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {data}, Attrs(attrs), {}); \ return CallNode::make(op, {data}, Attrs(attrs), {}); \
}; \
runtime::detail::unpack_call<Expr, 4>(make_func, args, rv); \
}); \ }); \
RELAY_REGISTER_OP(OpName) \ RELAY_REGISTER_OP(OpName) \
.set_num_inputs(1) \ .set_num_inputs(1) \
......
...@@ -81,9 +81,7 @@ Expr MakeCast(Expr data, ...@@ -81,9 +81,7 @@ Expr MakeCast(Expr data,
} }
TVM_REGISTER_API("relay._make.cast") TVM_REGISTER_API("relay._make.cast")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeCast);
runtime::detail::unpack_call<Expr, 2>(MakeCast, args, rv);
});
RELAY_REGISTER_OP("cast") RELAY_REGISTER_OP("cast")
.describe(R"code(Cast the data into a new data type. .describe(R"code(Cast the data into a new data type.
...@@ -161,9 +159,7 @@ Expr MakeExpandDims(Expr data, ...@@ -161,9 +159,7 @@ Expr MakeExpandDims(Expr data,
} }
TVM_REGISTER_API("relay.op._make.expand_dims") TVM_REGISTER_API("relay.op._make.expand_dims")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeExpandDims);
runtime::detail::unpack_call<Expr, 3>(MakeExpandDims, args, rv);
});
RELAY_REGISTER_OP("expand_dims") RELAY_REGISTER_OP("expand_dims")
.describe(R"code(Insert `num_newaxis` axises at the position given by `axis` .describe(R"code(Insert `num_newaxis` axises at the position given by `axis`
...@@ -279,9 +275,7 @@ Expr MakeConcatenate(Expr data, ...@@ -279,9 +275,7 @@ Expr MakeConcatenate(Expr data,
} }
TVM_REGISTER_API("relay.op._make.concatenate") TVM_REGISTER_API("relay.op._make.concatenate")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeConcatenate);
runtime::detail::unpack_call<Expr, 2>(MakeConcatenate, args, rv);
});
RELAY_REGISTER_OP("concatenate") RELAY_REGISTER_OP("concatenate")
.describe(R"code(Concatenate the input tensors along the given axis. .describe(R"code(Concatenate the input tensors along the given axis.
...@@ -367,9 +361,7 @@ Expr MakeStack(Expr data, ...@@ -367,9 +361,7 @@ Expr MakeStack(Expr data,
} }
TVM_REGISTER_API("relay.op._make.stack") TVM_REGISTER_API("relay.op._make.stack")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeStack);
runtime::detail::unpack_call<Expr, 2>(MakeStack, args, rv);
});
RELAY_REGISTER_OP("stack") RELAY_REGISTER_OP("stack")
.describe(R"code(Stack the input tensors along the given axis. .describe(R"code(Stack the input tensors along the given axis.
...@@ -461,9 +453,7 @@ Expr MakeTranspose(Expr data, ...@@ -461,9 +453,7 @@ Expr MakeTranspose(Expr data,
} }
TVM_REGISTER_API("relay.op._make.transpose") TVM_REGISTER_API("relay.op._make.transpose")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeTranspose);
runtime::detail::unpack_call<Expr, 2>(MakeTranspose, args, rv);
});
RELAY_REGISTER_OP("transpose") RELAY_REGISTER_OP("transpose")
.describe(R"code(Permutes the dimensions of an array. .describe(R"code(Permutes the dimensions of an array.
...@@ -598,9 +588,7 @@ Expr MakeReshape(Expr data, ...@@ -598,9 +588,7 @@ Expr MakeReshape(Expr data,
} }
TVM_REGISTER_API("relay.op._make.reshape") TVM_REGISTER_API("relay.op._make.reshape")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeReshape);
runtime::detail::unpack_call<Expr, 2>(MakeReshape, args, rv);
});
RELAY_REGISTER_OP("reshape") RELAY_REGISTER_OP("reshape")
.describe(R"code(Reshapes the input array. .describe(R"code(Reshapes the input array.
...@@ -698,9 +686,7 @@ Expr MakeReshapeLike(Expr data, ...@@ -698,9 +686,7 @@ Expr MakeReshapeLike(Expr data,
TVM_REGISTER_API("relay.op._make.reshape_like") TVM_REGISTER_API("relay.op._make.reshape_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeReshapeLike);
runtime::detail::unpack_call<Expr, 2>(MakeReshapeLike, args, rv);
});
RELAY_REGISTER_OP("reshape_like") RELAY_REGISTER_OP("reshape_like")
...@@ -790,9 +776,7 @@ Expr MakeTake(Expr data, ...@@ -790,9 +776,7 @@ Expr MakeTake(Expr data,
} }
TVM_REGISTER_API("relay.op._make.take") TVM_REGISTER_API("relay.op._make.take")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeTake);
runtime::detail::unpack_call<Expr, 4>(MakeTake, args, rv);
});
RELAY_REGISTER_OP("take") RELAY_REGISTER_OP("take")
.describe(R"code(Take elements from an array along an axis. .describe(R"code(Take elements from an array along an axis.
...@@ -873,9 +857,7 @@ Expr MakeFull(Expr fill_value, ...@@ -873,9 +857,7 @@ Expr MakeFull(Expr fill_value,
} }
TVM_REGISTER_API("relay.op._make.full") TVM_REGISTER_API("relay.op._make.full")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeFull);
runtime::detail::unpack_call<Expr, 3>(MakeFull, args, rv);
});
RELAY_REGISTER_OP("full") RELAY_REGISTER_OP("full")
.describe(R"code(Fill array with scalar value. .describe(R"code(Fill array with scalar value.
...@@ -910,9 +892,7 @@ Expr MakeZeros(Array<IndexExpr> shape, ...@@ -910,9 +892,7 @@ Expr MakeZeros(Array<IndexExpr> shape,
} }
TVM_REGISTER_API("relay.op._make.zeros") TVM_REGISTER_API("relay.op._make.zeros")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeZeros);
runtime::detail::unpack_call<Expr, 2>(MakeZeros, args, rv);
});
RELAY_REGISTER_OP("zeros") RELAY_REGISTER_OP("zeros")
.describe(R"code(Fill array with zeros. .describe(R"code(Fill array with zeros.
...@@ -933,9 +913,7 @@ Expr MakeOnes(Array<IndexExpr> shape, ...@@ -933,9 +913,7 @@ Expr MakeOnes(Array<IndexExpr> shape,
} }
TVM_REGISTER_API("relay.op._make.ones") TVM_REGISTER_API("relay.op._make.ones")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeOnes);
runtime::detail::unpack_call<Expr, 2>(MakeOnes, args, rv);
});
RELAY_REGISTER_OP("ones") RELAY_REGISTER_OP("ones")
.describe(R"code(Fill array with ones. .describe(R"code(Fill array with ones.
...@@ -982,9 +960,7 @@ Expr MakeFullLike(Expr data, ...@@ -982,9 +960,7 @@ Expr MakeFullLike(Expr data,
} }
TVM_REGISTER_API("relay.op._make.full_like") TVM_REGISTER_API("relay.op._make.full_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeFullLike);
runtime::detail::unpack_call<Expr, 2>(MakeFullLike, args, rv);
});
RELAY_REGISTER_OP("full_like") RELAY_REGISTER_OP("full_like")
.describe(R"code(Return an scalar value array with the same shape .describe(R"code(Return an scalar value array with the same shape
...@@ -1041,9 +1017,7 @@ Expr MakeArange(tvm::Expr start, ...@@ -1041,9 +1017,7 @@ Expr MakeArange(tvm::Expr start,
} }
TVM_REGISTER_API("relay.op._make.arange") TVM_REGISTER_API("relay.op._make.arange")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeArange);
runtime::detail::unpack_call<Expr, 4>(MakeArange, args, rv);
});
RELAY_REGISTER_OP("arange") RELAY_REGISTER_OP("arange")
.describe(R"code(Returns evenly spaced values within a given interval. .describe(R"code(Returns evenly spaced values within a given interval.
...@@ -1117,9 +1091,7 @@ Expr MakeRepeat(Expr data, ...@@ -1117,9 +1091,7 @@ Expr MakeRepeat(Expr data,
} }
TVM_REGISTER_API("relay.op._make.repeat") TVM_REGISTER_API("relay.op._make.repeat")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeRepeat);
runtime::detail::unpack_call<Expr, 3>(MakeRepeat, args, rv);
});
RELAY_REGISTER_OP("repeat") RELAY_REGISTER_OP("repeat")
.describe(R"code(Repeat elements of an array `repeats` times along axis `axis` .describe(R"code(Repeat elements of an array `repeats` times along axis `axis`
...@@ -1217,9 +1189,7 @@ Expr MakeTile(Expr data, ...@@ -1217,9 +1189,7 @@ Expr MakeTile(Expr data,
} }
TVM_REGISTER_API("relay.op._make.tile") TVM_REGISTER_API("relay.op._make.tile")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeTile);
runtime::detail::unpack_call<Expr, 2>(MakeTile, args, rv);
});
RELAY_REGISTER_OP("tile") RELAY_REGISTER_OP("tile")
.describe(R"code(Repeat the whole array multiple times. .describe(R"code(Repeat the whole array multiple times.
...@@ -1280,9 +1250,7 @@ Expr MakeReverse(Expr data, ...@@ -1280,9 +1250,7 @@ Expr MakeReverse(Expr data,
} }
TVM_REGISTER_API("relay.op._make.reverse") TVM_REGISTER_API("relay.op._make.reverse")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeReverse);
runtime::detail::unpack_call<Expr, 2>(MakeReverse, args, rv);
});
RELAY_REGISTER_OP("reverse") RELAY_REGISTER_OP("reverse")
.describe(R"code(Reverses the order of elements along given `axis` while preserving array shape. .describe(R"code(Reverses the order of elements along given `axis` while preserving array shape.
...@@ -1345,9 +1313,7 @@ Array<Tensor> WhereCompute(const Attrs& attrs, ...@@ -1345,9 +1313,7 @@ Array<Tensor> WhereCompute(const Attrs& attrs,
} }
TVM_REGISTER_API("relay.op._make.where") TVM_REGISTER_API("relay.op._make.where")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeWhere);
runtime::detail::unpack_call<Expr, 3>(MakeWhere, args, rv);
});
RELAY_REGISTER_OP("where") RELAY_REGISTER_OP("where")
.describe(R"code( .describe(R"code(
...@@ -1400,9 +1366,7 @@ Expr MakeSqueeze(Expr data, ...@@ -1400,9 +1366,7 @@ Expr MakeSqueeze(Expr data,
} }
TVM_REGISTER_API("relay.op._make.squeeze") TVM_REGISTER_API("relay.op._make.squeeze")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeSqueeze);
runtime::detail::unpack_call<Expr, 2>(MakeSqueeze, args, rv);
});
bool SqueezeRel(const Array<Type>& types, bool SqueezeRel(const Array<Type>& types,
...@@ -1507,9 +1471,7 @@ Array<Tensor> CollapseSumLikeCompute(const Attrs& attrs, ...@@ -1507,9 +1471,7 @@ Array<Tensor> CollapseSumLikeCompute(const Attrs& attrs,
} }
TVM_REGISTER_API("relay.op._make.collapse_sum_like") TVM_REGISTER_API("relay.op._make.collapse_sum_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeCollapseSumLike);
runtime::detail::unpack_call<Expr, 2>(MakeCollapseSumLike, args, rv);
});
RELAY_REGISTER_OP("collapse_sum_like") RELAY_REGISTER_OP("collapse_sum_like")
.describe(R"code(Collapse the first input to match the shape of the second input. .describe(R"code(Collapse the first input to match the shape of the second input.
...@@ -1554,9 +1516,7 @@ Array<Tensor> BroadCastToCompute(const Attrs& attrs, ...@@ -1554,9 +1516,7 @@ Array<Tensor> BroadCastToCompute(const Attrs& attrs,
} }
TVM_REGISTER_API("relay.op._make.broadcast_to") TVM_REGISTER_API("relay.op._make.broadcast_to")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeBroadCastTo);
runtime::detail::unpack_call<Expr, 2>(MakeBroadCastTo, args, rv);
});
RELAY_REGISTER_OP("broadcast_to") RELAY_REGISTER_OP("broadcast_to")
.describe(R"code(Broadcast the first input to match the shape argument. .describe(R"code(Broadcast the first input to match the shape argument.
...@@ -1594,9 +1554,7 @@ Array<Tensor> BroadCastToLikeCompute(const Attrs& attrs, ...@@ -1594,9 +1554,7 @@ Array<Tensor> BroadCastToLikeCompute(const Attrs& attrs,
} }
TVM_REGISTER_API("relay.op._make.broadcast_to_like") TVM_REGISTER_API("relay.op._make.broadcast_to_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeBroadCastToLike);
runtime::detail::unpack_call<Expr, 2>(MakeBroadCastToLike, args, rv);
});
RELAY_REGISTER_OP("broadcast_to_like") RELAY_REGISTER_OP("broadcast_to_like")
.describe(R"code(Broadcast the first input to match the shape of the second input. .describe(R"code(Broadcast the first input to match the shape of the second input.
...@@ -1806,9 +1764,7 @@ Array<Tensor> StridedSliceCompute(const Attrs& attrs, ...@@ -1806,9 +1764,7 @@ Array<Tensor> StridedSliceCompute(const Attrs& attrs,
TVM_REGISTER_API("relay.op._make.strided_slice") TVM_REGISTER_API("relay.op._make.strided_slice")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeStridedSlice);
runtime::detail::unpack_call<Expr, 4>(MakeStridedSlice, args, rv);
});
RELAY_REGISTER_OP("strided_slice") RELAY_REGISTER_OP("strided_slice")
...@@ -2081,9 +2037,7 @@ Array<Tensor> SliceLikeCompute(const Attrs& attrs, ...@@ -2081,9 +2037,7 @@ Array<Tensor> SliceLikeCompute(const Attrs& attrs,
TVM_REGISTER_API("relay.op._make.slice_like") TVM_REGISTER_API("relay.op._make.slice_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeSliceLike);
runtime::detail::unpack_call<Expr, 3>(MakeSliceLike, args, rv);
});
RELAY_REGISTER_OP("slice_like") RELAY_REGISTER_OP("slice_like")
...@@ -2144,9 +2098,7 @@ Expr MakeLayoutTransform(Expr data, ...@@ -2144,9 +2098,7 @@ Expr MakeLayoutTransform(Expr data,
} }
TVM_REGISTER_API("relay.op._make.layout_transform") TVM_REGISTER_API("relay.op._make.layout_transform")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeLayoutTransform);
runtime::detail::unpack_call<Expr, 3>(MakeLayoutTransform, args, rv);
});
RELAY_REGISTER_OP("layout_transform") RELAY_REGISTER_OP("layout_transform")
.describe(R"code(Transform the input data layout. .describe(R"code(Transform the input data layout.
...@@ -2174,9 +2126,7 @@ Expr MakeReverseReshape(Expr data, ...@@ -2174,9 +2126,7 @@ Expr MakeReverseReshape(Expr data,
} }
TVM_REGISTER_API("relay.op._make._contrib_reverse_reshape") TVM_REGISTER_API("relay.op._make._contrib_reverse_reshape")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeReverseReshape);
runtime::detail::unpack_call<Expr, 2>(MakeReverseReshape, args, rv);
});
RELAY_REGISTER_OP("_contrib_reverse_reshape") RELAY_REGISTER_OP("_contrib_reverse_reshape")
.describe(R"code(Reshapes the input array where the special values are inferred from .describe(R"code(Reshapes the input array where the special values are inferred from
...@@ -2250,9 +2200,7 @@ Expr MakeGatherND(Expr data, ...@@ -2250,9 +2200,7 @@ Expr MakeGatherND(Expr data,
} }
TVM_REGISTER_API("relay.op._make.gather_nd") TVM_REGISTER_API("relay.op._make.gather_nd")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeGatherND);
runtime::detail::unpack_call<Expr, 2>(MakeGatherND, args, rv);
});
RELAY_REGISTER_OP("gather_nd") RELAY_REGISTER_OP("gather_nd")
.describe(R"code(Gather elements or slices from data and store to .describe(R"code(Gather elements or slices from data and store to
......
...@@ -73,9 +73,7 @@ Expr MakeMultiBoxPrior(Expr data, ...@@ -73,9 +73,7 @@ Expr MakeMultiBoxPrior(Expr data,
TVM_REGISTER_API("relay.op.vision._make.multibox_prior") TVM_REGISTER_API("relay.op.vision._make.multibox_prior")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeMultiBoxPrior);
runtime::detail::unpack_call<Expr, 6>(MakeMultiBoxPrior, args, rv);
});
RELAY_REGISTER_OP("vision.multibox_prior") RELAY_REGISTER_OP("vision.multibox_prior")
...@@ -147,9 +145,7 @@ Expr MakeMultiBoxTransformLoc(Expr cls_prob, ...@@ -147,9 +145,7 @@ Expr MakeMultiBoxTransformLoc(Expr cls_prob,
} }
TVM_REGISTER_API("relay.op.vision._make.multibox_transform_loc") TVM_REGISTER_API("relay.op.vision._make.multibox_transform_loc")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeMultiBoxTransformLoc);
runtime::detail::unpack_call<Expr, 6>(MakeMultiBoxTransformLoc, args, rv);
});
RELAY_REGISTER_OP("vision.multibox_transform_loc") RELAY_REGISTER_OP("vision.multibox_transform_loc")
.describe(R"doc("Location transformation for multibox detection." .describe(R"doc("Location transformation for multibox detection."
......
...@@ -59,9 +59,7 @@ Expr MakeGetValidCounts(Expr data, ...@@ -59,9 +59,7 @@ Expr MakeGetValidCounts(Expr data,
TVM_REGISTER_API("relay.op.vision._make.get_valid_counts") TVM_REGISTER_API("relay.op.vision._make.get_valid_counts")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeGetValidCounts);
runtime::detail::unpack_call<Expr, 2>(MakeGetValidCounts, args, rv);
});
RELAY_REGISTER_OP("vision.get_valid_counts") RELAY_REGISTER_OP("vision.get_valid_counts")
...@@ -125,9 +123,7 @@ Expr MakeNMS(Expr data, ...@@ -125,9 +123,7 @@ Expr MakeNMS(Expr data,
TVM_REGISTER_API("relay.op.vision._make.non_max_suppression") TVM_REGISTER_API("relay.op.vision._make.non_max_suppression")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeNMS);
runtime::detail::unpack_call<Expr, 9>(MakeNMS, args, rv);
});
RELAY_REGISTER_OP("vision.non_max_suppression") RELAY_REGISTER_OP("vision.non_max_suppression")
......
...@@ -62,9 +62,7 @@ Expr MakeROIAlign(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spa ...@@ -62,9 +62,7 @@ Expr MakeROIAlign(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spa
} }
TVM_REGISTER_API("relay.op.vision._make.roi_align") TVM_REGISTER_API("relay.op.vision._make.roi_align")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeROIAlign);
runtime::detail::unpack_call<Expr, 6>(MakeROIAlign, args, rv);
});
RELAY_REGISTER_OP("vision.roi_align") RELAY_REGISTER_OP("vision.roi_align")
.describe(R"doc(ROI Align operator. .describe(R"doc(ROI Align operator.
...@@ -114,9 +112,7 @@ Expr MakeROIPool(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spat ...@@ -114,9 +112,7 @@ Expr MakeROIPool(Expr data, Expr rois, Array<IndexExpr> pooled_size, double spat
} }
TVM_REGISTER_API("relay.op.vision._make.roi_pool") TVM_REGISTER_API("relay.op.vision._make.roi_pool")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeROIPool);
runtime::detail::unpack_call<Expr, 5>(MakeROIPool, args, rv);
});
RELAY_REGISTER_OP("vision.roi_pool") RELAY_REGISTER_OP("vision.roi_pool")
.describe(R"doc(ROI Pool operator. .describe(R"doc(ROI Pool operator.
...@@ -182,9 +178,7 @@ Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array<IndexExpr> ...@@ -182,9 +178,7 @@ Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array<IndexExpr>
} }
TVM_REGISTER_API("relay.op.vision._make.proposal") TVM_REGISTER_API("relay.op.vision._make.proposal")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeProposal);
runtime::detail::unpack_call<Expr, 11>(MakeProposal, args, rv);
});
RELAY_REGISTER_OP("vision.proposal") RELAY_REGISTER_OP("vision.proposal")
.describe(R"code(Generate region proposals via RPN. .describe(R"code(Generate region proposals via RPN.
......
...@@ -71,9 +71,7 @@ Expr MakeYoloReorg(Expr data, ...@@ -71,9 +71,7 @@ Expr MakeYoloReorg(Expr data,
TVM_REGISTER_API("relay.op.vision._make.yolo_reorg") TVM_REGISTER_API("relay.op.vision._make.yolo_reorg")
.set_body([](const TVMArgs& args, TVMRetValue* rv) { .set_body_typed(MakeYoloReorg);
runtime::detail::unpack_call<Expr, 2>(MakeYoloReorg, args, rv);
});
RELAY_REGISTER_OP("vision.yolo_reorg") RELAY_REGISTER_OP("vision.yolo_reorg")
......
...@@ -61,9 +61,7 @@ Expr CanonicalizeOps(const Expr& e) { ...@@ -61,9 +61,7 @@ Expr CanonicalizeOps(const Expr& e) {
} }
TVM_REGISTER_API("relay._ir_pass.canonicalize_ops") TVM_REGISTER_API("relay._ir_pass.canonicalize_ops")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(CanonicalizeOps);
*ret = CanonicalizeOps(args[0]);
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -355,9 +355,7 @@ Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) { ...@@ -355,9 +355,7 @@ Expr CombineParallelConv2D(const Expr& expr, uint64_t min_num_branches) {
} }
TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D") TVM_REGISTER_API("relay._ir_pass.CombineParallelConv2D")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(CombineParallelConv2D);
*ret = CombineParallelConv2D(args[0], args[1]);
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -148,9 +148,7 @@ Expr DeadCodeElimination(const Expr& e) { ...@@ -148,9 +148,7 @@ Expr DeadCodeElimination(const Expr& e) {
} }
TVM_REGISTER_API("relay._ir_pass.dead_code_elimination") TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(DeadCodeElimination);
*ret = DeadCodeElimination(args[0]);
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -493,19 +493,13 @@ Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr) { ...@@ -493,19 +493,13 @@ Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr) {
} }
TVM_REGISTER_API("relay._ir_pass.CollectDeviceInfo") TVM_REGISTER_API("relay._ir_pass.CollectDeviceInfo")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(CollectDeviceInfo);
*ret = CollectDeviceInfo(args[0]);
});
TVM_REGISTER_API("relay._ir_pass.RewriteDeviceAnnotation") TVM_REGISTER_API("relay._ir_pass.RewriteDeviceAnnotation")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(RewriteAnnotatedOps);
*ret = RewriteAnnotatedOps(args[0], args[1]);
});
TVM_REGISTER_API("relay._ir_pass.CollectDeviceAnnotationOps") TVM_REGISTER_API("relay._ir_pass.CollectDeviceAnnotationOps")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(CollectDeviceAnnotationOps);
*ret = CollectDeviceAnnotationOps(args[0]);
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -210,9 +210,7 @@ Expr FoldConstant(const Expr& expr) { ...@@ -210,9 +210,7 @@ Expr FoldConstant(const Expr& expr) {
} }
TVM_REGISTER_API("relay._ir_pass.FoldConstant") TVM_REGISTER_API("relay._ir_pass.FoldConstant")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(FoldConstant);
*ret = FoldConstant(args[0]);
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -912,8 +912,6 @@ Expr FuseOps(const Expr& expr, int fuse_opt_level) { ...@@ -912,8 +912,6 @@ Expr FuseOps(const Expr& expr, int fuse_opt_level) {
} }
TVM_REGISTER_API("relay._ir_pass.FuseOps") TVM_REGISTER_API("relay._ir_pass.FuseOps")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(FuseOps);
*ret = FuseOps(args[0], args[1]);
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -247,10 +247,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) { ...@@ -247,10 +247,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
} }
TVM_REGISTER_API("relay._ir_pass.first_order_gradient") TVM_REGISTER_API("relay._ir_pass.first_order_gradient")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(FirstOrderGradient);
CHECK_EQ(args.size(), 2);
*ret = FirstOrderGradient(args[0], args[1]);
});
struct ReverseADType : TypeMutator { struct ReverseADType : TypeMutator {
Type VisitType_(const TensorTypeNode* ttn) final { Type VisitType_(const TensorTypeNode* ttn) final {
...@@ -263,7 +260,7 @@ struct ReverseAD : ExprMutator { ...@@ -263,7 +260,7 @@ struct ReverseAD : ExprMutator {
Var bp; Var bp;
const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient"); const OpMap<FPrimalGradient> rev_map = Op::GetAttr<FPrimalGradient>("FPrimalGradient");
ReverseAD(const Var& bp) : bp(bp) { } ReverseAD(const Var& bp) : bp(bp) { } /// NOLINT(*)
Expr VisitExpr_(const OpNode* op) final { Expr VisitExpr_(const OpNode* op) final {
LOG(FATAL) << "op should only be inside call"; LOG(FATAL) << "op should only be inside call";
...@@ -349,10 +346,7 @@ Expr Gradient(const Expr& re, const Module& mod) { ...@@ -349,10 +346,7 @@ Expr Gradient(const Expr& re, const Module& mod) {
} }
TVM_REGISTER_API("relay._ir_pass.gradient") TVM_REGISTER_API("relay._ir_pass.gradient")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(Gradient);
CHECK_EQ(args.size(), 2);
*ret = Gradient(args[0], args[1]);
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -147,9 +147,7 @@ int64_t GetTotalMacNumber(const Expr& expr) { ...@@ -147,9 +147,7 @@ int64_t GetTotalMacNumber(const Expr& expr) {
} }
TVM_REGISTER_API("relay._ir_pass.GetTotalMacNumber") TVM_REGISTER_API("relay._ir_pass.GetTotalMacNumber")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(GetTotalMacNumber);
*ret = GetTotalMacNumber(args[0]);
});
} // namespace mac_count } // namespace mac_count
} // namespace relay } // namespace relay
......
...@@ -426,12 +426,7 @@ Pass CreateSequentialPass(const tvm::Array<Pass>& passes, ...@@ -426,12 +426,7 @@ Pass CreateSequentialPass(const tvm::Array<Pass>& passes,
TVM_REGISTER_NODE_TYPE(PassInfoNode); TVM_REGISTER_NODE_TYPE(PassInfoNode);
TVM_REGISTER_API("relay._ir_pass.PassInfo") TVM_REGISTER_API("relay._ir_pass.PassInfo")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(PassInfoNode::make);
int opt_level = args[0];
std::string name = args[1];
tvm::Array<tvm::Expr> required = args[2];
*ret = PassInfoNode::make(opt_level, name, required);
});
TVM_REGISTER_API("relay._ir_pass.Info") TVM_REGISTER_API("relay._ir_pass.Info")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
...@@ -456,13 +451,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -456,13 +451,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(ModulePassNode); TVM_REGISTER_NODE_TYPE(ModulePassNode);
TVM_REGISTER_API("relay._ir_pass.CreateModulePass") TVM_REGISTER_API("relay._ir_pass.CreateModulePass")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(CreateModulePass);
PackedFunc pass_func = args[0];
int opt_level = args[1];
std::string name = args[2];
tvm::Array<tvm::Expr> required = args[3];
*ret = CreateModulePass(pass_func, opt_level, name, required);
});
TVM_REGISTER_API("relay._ir_pass.RunPass") TVM_REGISTER_API("relay._ir_pass.RunPass")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
...@@ -487,13 +476,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -487,13 +476,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(FunctionPassNode); TVM_REGISTER_NODE_TYPE(FunctionPassNode);
TVM_REGISTER_API("relay._ir_pass.CreateFunctionPass") TVM_REGISTER_API("relay._ir_pass.CreateFunctionPass")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(CreateFunctionPass);
PackedFunc pass_func = args[0];
int opt_level = args[1];
std::string name = args[2];
tvm::Array<tvm::Expr> required = args[3];
*ret = CreateFunctionPass(pass_func, opt_level, name, required);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FunctionPassNode>([](const FunctionPassNode* node, .set_dispatch<FunctionPassNode>([](const FunctionPassNode* node,
...@@ -541,9 +524,7 @@ TVM_REGISTER_API("relay._ir_pass.SetContext") ...@@ -541,9 +524,7 @@ TVM_REGISTER_API("relay._ir_pass.SetContext")
TVM_REGISTER_NODE_TYPE(PassContextNode); TVM_REGISTER_NODE_TYPE(PassContextNode);
TVM_REGISTER_API("relay._ir_pass.PassContext") TVM_REGISTER_API("relay._ir_pass.PassContext")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(PassContextNode::make);
*ret = PassContextNode::make();
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<PassContextNode>([](const PassContextNode* node, .set_dispatch<PassContextNode>([](const PassContextNode* node,
......
...@@ -571,20 +571,13 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) ...@@ -571,20 +571,13 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
}); });
TVM_REGISTER_API("relay._quantize._GetCurrentQConfig") TVM_REGISTER_API("relay._quantize._GetCurrentQConfig")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(QConfig::Current);
*ret = QConfig::Current();
});
TVM_REGISTER_API("relay._quantize._EnterQConfigScope") TVM_REGISTER_API("relay._quantize._EnterQConfigScope")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(QConfig::EnterQConfigScope);
QConfig target = args[0];
QConfig::EnterQConfigScope(target);
});
TVM_REGISTER_API("relay._quantize._ExitQConfigScope") TVM_REGISTER_API("relay._quantize._ExitQConfigScope")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(QConfig::ExitQConfigScope);
QConfig::ExitQConfigScope();
});
} // namespace quantize } // namespace quantize
} // namespace relay } // namespace relay
......
...@@ -103,9 +103,7 @@ Expr SimplifyInference(const Expr& e) { ...@@ -103,9 +103,7 @@ Expr SimplifyInference(const Expr& e) {
} }
TVM_REGISTER_API("relay._ir_pass.simplify_inference") TVM_REGISTER_API("relay._ir_pass.simplify_inference")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(SimplifyInference);
*ret = SimplifyInference(args[0]);
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -491,9 +491,7 @@ Expr ToANormalForm(const Expr& e, const Module& m) { ...@@ -491,9 +491,7 @@ Expr ToANormalForm(const Expr& e, const Module& m) {
} }
TVM_REGISTER_API("relay._ir_pass.to_a_normal_form") TVM_REGISTER_API("relay._ir_pass.to_a_normal_form")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(static_cast<Expr (*)(const Expr&, const Module&)>(ToANormalForm));
*ret = ToANormalForm(args[0], args[1]);
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -77,9 +77,7 @@ Expr ToGraphNormalForm(const Expr& e) { ...@@ -77,9 +77,7 @@ Expr ToGraphNormalForm(const Expr& e) {
} }
TVM_REGISTER_API("relay._ir_pass.to_graph_normal_form") TVM_REGISTER_API("relay._ir_pass.to_graph_normal_form")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(ToGraphNormalForm);
*ret = ToGraphNormalForm(args[0]);
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -801,8 +801,8 @@ Function InferType(const Function& func, ...@@ -801,8 +801,8 @@ Function InferType(const Function& func,
} }
TVM_REGISTER_API("relay._ir_pass.infer_type") TVM_REGISTER_API("relay._ir_pass.infer_type")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed<Expr(const Expr&, const Module&)>([](const Expr& expr, const Module& mod_ref) {
*ret = InferType(args[0], args[1]); return InferType(expr, mod_ref);
}); });
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -275,9 +275,7 @@ tvm::Array<Var> AllVars(const Expr& expr) { ...@@ -275,9 +275,7 @@ tvm::Array<Var> AllVars(const Expr& expr) {
} }
TVM_REGISTER_API("relay._ir_pass.free_vars") TVM_REGISTER_API("relay._ir_pass.free_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(FreeVars);
*ret = FreeVars(args[0]);
});
TVM_REGISTER_API("relay._ir_pass.bound_vars") TVM_REGISTER_API("relay._ir_pass.bound_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
...@@ -290,9 +288,7 @@ TVM_REGISTER_API("relay._ir_pass.bound_vars") ...@@ -290,9 +288,7 @@ TVM_REGISTER_API("relay._ir_pass.bound_vars")
}); });
TVM_REGISTER_API("relay._ir_pass.all_vars") TVM_REGISTER_API("relay._ir_pass.all_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body_typed(AllVars);
*ret = AllVars(args[0]);
});
TVM_REGISTER_API("relay._ir_pass.free_type_vars") TVM_REGISTER_API("relay._ir_pass.free_type_vars")
.set_body([](TVMArgs args, TVMRetValue* ret) { .set_body([](TVMArgs args, TVMRetValue* ret) {
......
...@@ -79,10 +79,7 @@ bool WellFormed(const Expr& e) { ...@@ -79,10 +79,7 @@ bool WellFormed(const Expr& e) {
} }
TVM_REGISTER_API("relay._ir_pass.well_formed") TVM_REGISTER_API("relay._ir_pass.well_formed")
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body_typed(WellFormed);
Expr e = args[0];
*ret = WellFormed(e);
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -308,18 +308,12 @@ Module CUDAModuleLoadBinary(void* strm) { ...@@ -308,18 +308,12 @@ Module CUDAModuleLoadBinary(void* strm) {
} }
TVM_REGISTER_GLOBAL("module.loadfile_cubin") TVM_REGISTER_GLOBAL("module.loadfile_cubin")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(CUDAModuleLoadFile);
*rv = CUDAModuleLoadFile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("module.loadfile_ptx") TVM_REGISTER_GLOBAL("module.loadfile_ptx")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(CUDAModuleLoadFile);
*rv = CUDAModuleLoadFile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("module.loadbinary_cuda") TVM_REGISTER_GLOBAL("module.loadbinary_cuda")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(CUDAModuleLoadBinary);
*rv = CUDAModuleLoadBinary(args[0]);
});
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -310,13 +310,9 @@ Module MetalModuleLoadBinary(void* strm) { ...@@ -310,13 +310,9 @@ Module MetalModuleLoadBinary(void* strm) {
} }
TVM_REGISTER_GLOBAL("module.loadfile_metal") TVM_REGISTER_GLOBAL("module.loadfile_metal")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(MetalModuleLoadFile);
*rv = MetalModuleLoadFile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("module.loadbinary_metal") TVM_REGISTER_GLOBAL("module.loadbinary_metal")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(MetalModuleLoadBinary);
*rv = MetalModuleLoadBinary(args[0]);
});
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -69,9 +69,7 @@ Module AOCLModuleLoadFile(const std::string& file_name, ...@@ -69,9 +69,7 @@ Module AOCLModuleLoadFile(const std::string& file_name,
} }
TVM_REGISTER_GLOBAL("module.loadfile_aocx") TVM_REGISTER_GLOBAL("module.loadfile_aocx")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(AOCLModuleLoadFile);
*rv = AOCLModuleLoadFile(args[0], args[1]);
});
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -281,18 +281,12 @@ Module OpenCLModuleLoadBinary(void* strm) { ...@@ -281,18 +281,12 @@ Module OpenCLModuleLoadBinary(void* strm) {
} }
TVM_REGISTER_GLOBAL("module.loadfile_cl") TVM_REGISTER_GLOBAL("module.loadfile_cl")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(OpenCLModuleLoadFile);
*rv = OpenCLModuleLoadFile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("module.loadfile_clbin") TVM_REGISTER_GLOBAL("module.loadfile_clbin")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(OpenCLModuleLoadFile);
*rv = OpenCLModuleLoadFile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("module.loadbinary_opencl") TVM_REGISTER_GLOBAL("module.loadbinary_opencl")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(OpenCLModuleLoadBinary);
*rv = OpenCLModuleLoadBinary(args[0]);
});
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -80,13 +80,9 @@ Module SDAccelModuleLoadBinary(void* strm) { ...@@ -80,13 +80,9 @@ Module SDAccelModuleLoadBinary(void* strm) {
} }
TVM_REGISTER_GLOBAL("module.loadfile_xclbin") TVM_REGISTER_GLOBAL("module.loadfile_xclbin")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(SDAccelModuleLoadFile);
*rv = SDAccelModuleLoadFile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("module.loadfile_awsxclbin") TVM_REGISTER_GLOBAL("module.loadfile_awsxclbin")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(SDAccelModuleLoadFile);
*rv = SDAccelModuleLoadFile(args[0], args[1]);
});
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -243,14 +243,10 @@ Module ROCMModuleLoadBinary(void* strm) { ...@@ -243,14 +243,10 @@ Module ROCMModuleLoadBinary(void* strm) {
TVM_REGISTER_GLOBAL("module.loadbinary_hsaco") TVM_REGISTER_GLOBAL("module.loadbinary_hsaco")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(ROCMModuleLoadBinary);
*rv = ROCMModuleLoadBinary(args[0]);
});
TVM_REGISTER_GLOBAL("module.loadbinary_hip") TVM_REGISTER_GLOBAL("module.loadbinary_hip")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(ROCMModuleLoadBinary);
*rv = ROCMModuleLoadBinary(args[0]);
});
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -64,8 +64,6 @@ PackedFunc CreateEventDrivenServer(PackedFunc fsend, ...@@ -64,8 +64,6 @@ PackedFunc CreateEventDrivenServer(PackedFunc fsend,
} }
TVM_REGISTER_GLOBAL("rpc._CreateEventDrivenServer") TVM_REGISTER_GLOBAL("rpc._CreateEventDrivenServer")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(CreateEventDrivenServer);
*rv = CreateEventDrivenServer(args[0], args[1], args[2]);
});
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -110,9 +110,7 @@ void RPCServerLoop(int sockfd) { ...@@ -110,9 +110,7 @@ void RPCServerLoop(int sockfd) {
} }
TVM_REGISTER_GLOBAL("rpc._Connect") TVM_REGISTER_GLOBAL("rpc._Connect")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(RPCClientConnect);
*rv = RPCClientConnect(args[0], args[1], args[2]);
});
TVM_REGISTER_GLOBAL("rpc._ServerLoop") TVM_REGISTER_GLOBAL("rpc._ServerLoop")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
......
...@@ -142,9 +142,7 @@ Module StackVMModuleCreate(std::unordered_map<std::string, StackVM> fmap, ...@@ -142,9 +142,7 @@ Module StackVMModuleCreate(std::unordered_map<std::string, StackVM> fmap,
} }
TVM_REGISTER_GLOBAL("module.loadfile_stackvm") TVM_REGISTER_GLOBAL("module.loadfile_stackvm")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(StackVMModuleNode::LoadFromFile);
*rv = StackVMModuleNode::LoadFromFile(args[0], args[1]);
});
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -427,13 +427,9 @@ Module VulkanModuleLoadBinary(void* strm) { ...@@ -427,13 +427,9 @@ Module VulkanModuleLoadBinary(void* strm) {
} }
TVM_REGISTER_GLOBAL("module.loadfile_vulkan") TVM_REGISTER_GLOBAL("module.loadfile_vulkan")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(VulkanModuleLoadFile);
*rv = VulkanModuleLoadFile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("module.loadbinary_vulkan") TVM_REGISTER_GLOBAL("module.loadbinary_vulkan")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed(VulkanModuleLoadBinary);
*rv = VulkanModuleLoadBinary(args[0]);
});
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
...@@ -60,16 +60,16 @@ struct RPCEnv { ...@@ -60,16 +60,16 @@ struct RPCEnv {
}; };
TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body_typed<std::string(std::string)>([](std::string path) {
static RPCEnv env; static RPCEnv env;
*rv = env.GetPath(args[0]); return env.GetPath(path);
}); });
TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module")
.set_body([](TVMArgs args, TVMRetValue *rv) { .set_body_typed<Module(std::string)>([](std::string path) {
std::string file_name = "/rpc/" + args[0].operator std::string(); std::string file_name = "/rpc/" + path;
*rv = Module::LoadFromFile(file_name, "");
LOG(INFO) << "Load module from " << file_name << " ..."; LOG(INFO) << "Load module from " << file_name << " ...";
return Module::LoadFromFile(file_name, "");
}); });
} // namespace contrib } // namespace contrib
} // namespace tvm } // namespace tvm
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment