Unverified Commit d5d63a44 by Tianqi Chen Committed by GitHub

[REFACTOR] Automatically deduce function type signature in Registry.set_body_typed (#4623)

Previously we support a limited case of function type deduction and in many places
we have to supply the type twice during set_body_typed (one in the template parameter, another in the lambda signature).

This PR improves the deduce function by enablng automatic function signature deduction.

```
TVM_REGISTER_GLOBAL("sub")
.set_body_typed([](int x, int y) -> int { return x - y; });
```

Unfortunately, because of template conflict, we can not support the original case
where both type signature and lambda are supplied through set_body_typed.

This PR refactors the existing regsitration to the new style.
parent f12c4fe2
...@@ -1031,6 +1031,42 @@ inline void for_each(const F& f, Args&&... args) { // NOLINT(*) ...@@ -1031,6 +1031,42 @@ inline void for_each(const F& f, Args&&... args) { // NOLINT(*)
for_each_dispatcher<sizeof...(Args) == 0, 0, F> for_each_dispatcher<sizeof...(Args) == 0, 0, F>
::run(f, std::forward<Args>(args)...); ::run(f, std::forward<Args>(args)...);
} }
template<typename T>
struct func_signature_helper {
using FType = void;
};
template<typename T, typename R, typename ...Args>
struct func_signature_helper<R (T::*)(Args...)> {
using FType = R(Args...);
};
template<typename T, typename R, typename ...Args>
struct func_signature_helper<R (T::*)(Args...) const> {
using FType = R(Args...);
};
/*!
* \brief template class to get function signature of a function or functor.
* \tparam T The funtion/functor type.
*/
template<typename T>
struct function_signature {
using FType = typename func_signature_helper<decltype(&T::operator())>::FType;
};
// handle case of function.
template<typename R, typename ...Args>
struct function_signature<R(Args...)> {
using FType = R(Args...);
};
// handle case of function ptr.
template<typename R, typename ...Args>
struct function_signature<R (*)(Args...)> {
using FType = R(Args...);
};
} // namespace detail } // namespace detail
/* \brief argument settter to PackedFunc */ /* \brief argument settter to PackedFunc */
......
...@@ -46,6 +46,7 @@ ...@@ -46,6 +46,7 @@
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <string> #include <string>
#include <vector> #include <vector>
#include <utility>
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
...@@ -66,28 +67,7 @@ class Registry { ...@@ -66,28 +67,7 @@ class Registry {
return set_body(PackedFunc(f)); return set_body(PackedFunc(f));
} }
/*! /*!
* \brief set the body of the function to be TypedPackedFunc. * \brief set the body of the function to the given function.
*
* \code
*
* TVM_REGISTER_GLOBAL("addone")
* .set_body_typed<int(int)>([](int x) { return x + 1; });
*
* \endcode
*
* \param f The body of the function.
* \tparam FType the signature of the function.
* \tparam FLambda The type of f.
*/
template<typename FType, typename FLambda>
Registry& set_body_typed(FLambda f) {
return set_body(TypedPackedFunc<FType>(f).packed());
}
/*!
* \brief set the body of the function to the given function pointer.
* Note that this doesn't work with lambdas, you need to
* explicitly give a type for those.
* Note that this will ignore default arg values and always require all arguments to be provided. * Note that this will ignore default arg values and always require all arguments to be provided.
* *
* \code * \code
...@@ -99,17 +79,20 @@ class Registry { ...@@ -99,17 +79,20 @@ class Registry {
* TVM_REGISTER_GLOBAL("multiply") * TVM_REGISTER_GLOBAL("multiply")
* .set_body_typed(multiply); // will have type int(int, int) * .set_body_typed(multiply); // will have type int(int, int)
* *
* // will have type int(int, int)
* TVM_REGISTER_GLOBAL("sub")
* .set_body_typed([](int a, int b) -> int { return a - b; });
*
* \endcode * \endcode
* *
* \param f The function to forward to. * \param f The function to forward to.
* \tparam R the return type of the function (inferred). * \tparam FLambda The signature of the function.
* \tparam Args the argument types of the function (inferred).
*/ */
template<typename R, typename ...Args> template<typename FLambda>
Registry& set_body_typed(R (*f)(Args...)) { Registry& set_body_typed(FLambda f) {
return set_body(TypedPackedFunc<R(Args...)>(f)); using FType = typename detail::function_signature<FLambda>::FType;
return set_body(TypedPackedFunc<FType>(std::move(f)).packed());
} }
/*! /*!
* \brief set the body of the function to be the passed method pointer. * \brief set the body of the function to be the passed method pointer.
* Note that this will ignore default arg values and always require all arguments to be provided. * Note that this will ignore default arg values and always require all arguments to be provided.
...@@ -132,10 +115,11 @@ class Registry { ...@@ -132,10 +115,11 @@ class Registry {
*/ */
template<typename T, typename R, typename ...Args> template<typename T, typename R, typename ...Args>
Registry& set_body_method(R (T::*f)(Args...)) { Registry& set_body_method(R (T::*f)(Args...)) {
return set_body_typed<R(T, Args...)>([f](T target, Args... params) -> R { auto fwrap =[f](T target, Args... params) -> R {
// call method pointer // call method pointer
return (target.*f)(params...); return (target.*f)(params...);
}); };
return set_body(TypedPackedFunc<R(T, Args...)>(fwrap));
} }
/*! /*!
...@@ -160,10 +144,11 @@ class Registry { ...@@ -160,10 +144,11 @@ class Registry {
*/ */
template<typename T, typename R, typename ...Args> template<typename T, typename R, typename ...Args>
Registry& set_body_method(R (T::*f)(Args...) const) { Registry& set_body_method(R (T::*f)(Args...) const) {
return set_body_typed<R(T, Args...)>([f](const T target, Args... params) -> R { auto fwrap = [f](const T target, Args... params) -> R {
// call method pointer // call method pointer
return (target.*f)(params...); return (target.*f)(params...);
}); };
return set_body(TypedPackedFunc<R(const T, Args...)>(fwrap));
} }
/*! /*!
...@@ -199,11 +184,12 @@ class Registry { ...@@ -199,11 +184,12 @@ class Registry {
template<typename TObjectRef, typename TNode, typename R, typename ...Args, template<typename TObjectRef, typename TNode, typename R, typename ...Args,
typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type> typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...)) { Registry& set_body_method(R (TNode::*f)(Args...)) {
return set_body_typed<R(TObjectRef, Args...)>([f](TObjectRef ref, Args... params) { auto fwrap = [f](TObjectRef ref, Args... params) {
TNode* target = ref.operator->(); TNode* target = ref.operator->();
// call method pointer // call method pointer
return (target->*f)(params...); return (target->*f)(params...);
}); };
return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap));
} }
/*! /*!
...@@ -239,11 +225,12 @@ class Registry { ...@@ -239,11 +225,12 @@ class Registry {
template<typename TObjectRef, typename TNode, typename R, typename ...Args, template<typename TObjectRef, typename TNode, typename R, typename ...Args,
typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type> typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...) const) { Registry& set_body_method(R (TNode::*f)(Args...) const) {
return set_body_typed<R(TObjectRef, Args...)>([f](TObjectRef ref, Args... params) { auto fwrap = [f](TObjectRef ref, Args... params) {
const TNode* target = ref.operator->(); const TNode* target = ref.operator->();
// call method pointer // call method pointer
return (target->*f)(params...); return (target->*f)(params...);
}); };
return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap));
} }
/*! /*!
......
...@@ -48,7 +48,7 @@ TVM_REGISTER_GLOBAL("arith.DetectClipBound") ...@@ -48,7 +48,7 @@ TVM_REGISTER_GLOBAL("arith.DetectClipBound")
.set_body_typed(DetectClipBound); .set_body_typed(DetectClipBound);
TVM_REGISTER_GLOBAL("arith.DeduceBound") TVM_REGISTER_GLOBAL("arith.DeduceBound")
.set_body_typed<IntSet(Expr, Expr, Map<Var, IntSet>, Map<Var, IntSet>)>([]( .set_body_typed([](
Expr v, Expr cond, Expr v, Expr cond,
const Map<Var, IntSet> hint_map, const Map<Var, IntSet> hint_map,
const Map<Var, IntSet> relax_map const Map<Var, IntSet> relax_map
......
...@@ -45,10 +45,10 @@ TVM_REGISTER_GLOBAL("_raw_ptr") ...@@ -45,10 +45,10 @@ TVM_REGISTER_GLOBAL("_raw_ptr")
}); });
TVM_REGISTER_GLOBAL("_save_json") TVM_REGISTER_GLOBAL("_save_json")
.set_body_typed<std::string(ObjectRef)>(SaveJSON); .set_body_typed(SaveJSON);
TVM_REGISTER_GLOBAL("_load_json") TVM_REGISTER_GLOBAL("_load_json")
.set_body_typed<ObjectRef(std::string)>(LoadJSON); .set_body_typed(LoadJSON);
TVM_REGISTER_GLOBAL("_TVMSetStream") TVM_REGISTER_GLOBAL("_TVMSetStream")
.set_body_typed(TVMSetStream); .set_body_typed(TVMSetStream);
......
...@@ -32,7 +32,7 @@ namespace tvm { ...@@ -32,7 +32,7 @@ namespace tvm {
namespace ir { namespace ir {
TVM_REGISTER_GLOBAL("_Var") TVM_REGISTER_GLOBAL("_Var")
.set_body_typed<VarExpr(std::string, DataType)>([](std::string s, DataType t) { .set_body_typed([](std::string s, DataType t) {
return Variable::make(t, s); return Variable::make(t, s);
}); });
...@@ -64,7 +64,7 @@ TVM_REGISTER_GLOBAL("make._range_by_min_extent") ...@@ -64,7 +64,7 @@ TVM_REGISTER_GLOBAL("make._range_by_min_extent")
.set_body_typed(Range::make_by_min_extent); .set_body_typed(Range::make_by_min_extent);
TVM_REGISTER_GLOBAL("make.For") TVM_REGISTER_GLOBAL("make.For")
.set_body_typed<Stmt(VarExpr, Expr, Expr, int, int, Stmt)>([]( .set_body_typed([](
VarExpr loop_var, Expr min, Expr extent, VarExpr loop_var, Expr min, Expr extent,
int for_type, int device_api, Stmt body) { int for_type, int device_api, Stmt body) {
return For::make(loop_var, return For::make(loop_var,
...@@ -99,7 +99,7 @@ TVM_REGISTER_GLOBAL("make.Realize") ...@@ -99,7 +99,7 @@ TVM_REGISTER_GLOBAL("make.Realize")
.set_body_typed(Realize::make); .set_body_typed(Realize::make);
TVM_REGISTER_GLOBAL("make.Call") TVM_REGISTER_GLOBAL("make.Call")
.set_body_typed<Expr(DataType, std::string, Array<Expr>, int, FunctionRef, int)>([]( .set_body_typed([](
DataType type, std::string name, DataType type, std::string name,
Array<Expr> args, int call_type, Array<Expr> args, int call_type,
FunctionRef func, int value_index FunctionRef func, int value_index
...@@ -116,9 +116,9 @@ TVM_REGISTER_GLOBAL("make.CommReducer") ...@@ -116,9 +116,9 @@ TVM_REGISTER_GLOBAL("make.CommReducer")
.set_body_typed(CommReducerNode::make); .set_body_typed(CommReducerNode::make);
// make from two arguments // make from two arguments
#define REGISTER_MAKE(Node) \ #define REGISTER_MAKE(Node) \
TVM_REGISTER_GLOBAL("make."#Node) \ TVM_REGISTER_GLOBAL("make."#Node) \
.set_body_typed(Node::make); \ .set_body_typed(Node::make); \
REGISTER_MAKE(Reduce); REGISTER_MAKE(Reduce);
REGISTER_MAKE(AttrStmt); REGISTER_MAKE(AttrStmt);
...@@ -168,32 +168,32 @@ TVM_REGISTER_GLOBAL("make.Block") ...@@ -168,32 +168,32 @@ TVM_REGISTER_GLOBAL("make.Block")
// has default args // has default args
TVM_REGISTER_GLOBAL("make.Allocate") TVM_REGISTER_GLOBAL("make.Allocate")
.set_body_typed<Stmt(VarExpr, DataType, Array<Expr>, Expr, Stmt)>([]( .set_body_typed([](
VarExpr buffer_var, DataType type, Array<Expr> extents, Expr condition, Stmt body VarExpr buffer_var, DataType type, Array<Expr> extents, Expr condition, Stmt body
){ ){
return Allocate::make(buffer_var, type, extents, condition, body); return Allocate::make(buffer_var, type, extents, condition, body);
}); });
// operator overloading, smarter than make // operator overloading, smarter than make
#define REGISTER_MAKE_BINARY_OP(Node, Func) \ #define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_GLOBAL("make."#Node) \ TVM_REGISTER_GLOBAL("make."#Node) \
.set_body_typed<Expr(Expr, Expr)>([](Expr a, Expr b) { \ .set_body_typed([](Expr a, Expr b) { \
return (Func(a, b)); \ return (Func(a, b)); \
}) })
#define REGISTER_MAKE_BIT_OP(Node, Func) \ #define REGISTER_MAKE_BIT_OP(Node, Func) \
TVM_REGISTER_GLOBAL("make."#Node) \ TVM_REGISTER_GLOBAL("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \ .set_body([](TVMArgs args, TVMRetValue *ret) { \
bool lhs_is_int = args[0].type_code() == kDLInt; \ bool lhs_is_int = args[0].type_code() == kDLInt; \
bool rhs_is_int = args[1].type_code() == kDLInt; \ bool rhs_is_int = args[1].type_code() == kDLInt; \
if (lhs_is_int) { \ if (lhs_is_int) { \
*ret = (Func(args[0].operator int(), args[1].operator Expr())); \ *ret = (Func(args[0].operator int(), args[1].operator Expr())); \
} else if (rhs_is_int) { \ } else if (rhs_is_int) { \
*ret = (Func(args[0].operator Expr(), args[1].operator int())); \ *ret = (Func(args[0].operator Expr(), args[1].operator int())); \
} else { \ } else { \
*ret = (Func(args[0].operator Expr(), args[1].operator Expr())); \ *ret = (Func(args[0].operator Expr(), args[1].operator Expr())); \
} \ } \
}) })
REGISTER_MAKE_BINARY_OP(_OpAdd, operator+); REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
...@@ -224,7 +224,7 @@ REGISTER_MAKE_BIT_OP(bitwise_xor, operator^); ...@@ -224,7 +224,7 @@ REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*) REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, operator>>); REGISTER_MAKE_BIT_OP(right_shift, operator>>);
TVM_REGISTER_GLOBAL("make._OpIfThenElse") TVM_REGISTER_GLOBAL("make._OpIfThenElse")
.set_body_typed<Expr(Expr, Expr, Expr)>([] (Expr cond, Expr true_value, Expr false_value) { .set_body_typed([] (Expr cond, Expr true_value, Expr false_value) {
return if_then_else(cond, true_value, false_value); return if_then_else(cond, true_value, false_value);
}); });
......
...@@ -236,22 +236,22 @@ TVM_REGISTER_GLOBAL("_Layout") ...@@ -236,22 +236,22 @@ TVM_REGISTER_GLOBAL("_Layout")
.set_body_typed(LayoutNode::make); .set_body_typed(LayoutNode::make);
TVM_REGISTER_GLOBAL("_LayoutIndexOf") TVM_REGISTER_GLOBAL("_LayoutIndexOf")
.set_body_typed<int(Layout, std::string)>([](Layout layout, std::string axis) { .set_body_typed([](Layout layout, std::string axis) -> int {
return layout.IndexOf(LayoutAxis::make(axis)); return layout.IndexOf(LayoutAxis::make(axis));
}); });
TVM_REGISTER_GLOBAL("_LayoutFactorOf") TVM_REGISTER_GLOBAL("_LayoutFactorOf")
.set_body_typed<int(Layout, std::string)>([](Layout layout, std::string axis) { .set_body_typed([](Layout layout, std::string axis) -> int {
return layout.FactorOf(LayoutAxis::make(axis)); return layout.FactorOf(LayoutAxis::make(axis));
}); });
TVM_REGISTER_GLOBAL("_LayoutNdim") TVM_REGISTER_GLOBAL("_LayoutNdim")
.set_body_typed<int(Layout)>([](Layout layout) { .set_body_typed([](Layout layout) -> int {
return layout.ndim(); return layout.ndim();
}); });
TVM_REGISTER_GLOBAL("_LayoutGetItem") TVM_REGISTER_GLOBAL("_LayoutGetItem")
.set_body_typed<std::string(Layout, int)>([](Layout layout, int idx) { .set_body_typed([](Layout layout, int idx) -> std::string {
const LayoutAxis& axis = layout[idx]; const LayoutAxis& axis = layout[idx];
return axis.name(); return axis.name();
}); });
...@@ -284,14 +284,12 @@ TVM_REGISTER_GLOBAL("_TensorEqual") ...@@ -284,14 +284,12 @@ TVM_REGISTER_GLOBAL("_TensorEqual")
.set_body_method(&Tensor::operator==); .set_body_method(&Tensor::operator==);
TVM_REGISTER_GLOBAL("_TensorHash") TVM_REGISTER_GLOBAL("_TensorHash")
.set_body_typed<int64_t(Tensor)>([](Tensor tensor) { .set_body_typed([](Tensor tensor) -> int64_t {
return static_cast<int64_t>(std::hash<Tensor>()(tensor)); return static_cast<int64_t>(std::hash<Tensor>()(tensor));
}); });
TVM_REGISTER_GLOBAL("_Placeholder") TVM_REGISTER_GLOBAL("_Placeholder")
.set_body_typed<Tensor(Array<Expr>, DataType, std::string)>([]( .set_body_typed([](Array<Expr> shape, DataType dtype, std::string name) {
Array<Expr> shape, DataType dtype, std::string name
) {
return placeholder(shape, dtype, name); return placeholder(shape, dtype, name);
}); });
...@@ -311,7 +309,7 @@ TVM_REGISTER_GLOBAL("_HybridOp") ...@@ -311,7 +309,7 @@ TVM_REGISTER_GLOBAL("_HybridOp")
.set_body_typed(HybridOpNode::make); .set_body_typed(HybridOpNode::make);
TVM_REGISTER_GLOBAL("_OpGetOutput") TVM_REGISTER_GLOBAL("_OpGetOutput")
.set_body_typed<Tensor(Operation, int64_t)>([](Operation op, int64_t output) { .set_body_typed([](Operation op, int64_t output) {
return op.output(static_cast<size_t>(output)); return op.output(static_cast<size_t>(output));
}); });
...@@ -322,9 +320,7 @@ TVM_REGISTER_GLOBAL("_OpInputTensors") ...@@ -322,9 +320,7 @@ TVM_REGISTER_GLOBAL("_OpInputTensors")
.set_body_method<Operation>(&OperationNode::InputTensors); .set_body_method<Operation>(&OperationNode::InputTensors);
TVM_REGISTER_GLOBAL("_IterVar") TVM_REGISTER_GLOBAL("_IterVar")
.set_body_typed<IterVar(Range, Var, int, std::string)>([]( .set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
Range dom, Var var, int iter_type, std::string thread_tag
) {
return IterVarNode::make( return IterVarNode::make(
dom, var, dom, var,
static_cast<IterVarType>(iter_type), static_cast<IterVarType>(iter_type),
...@@ -341,25 +337,21 @@ TVM_REGISTER_GLOBAL("_StageBind") ...@@ -341,25 +337,21 @@ TVM_REGISTER_GLOBAL("_StageBind")
.set_body_method(&Stage::bind); .set_body_method(&Stage::bind);
TVM_REGISTER_GLOBAL("_StageSplitByFactor") TVM_REGISTER_GLOBAL("_StageSplitByFactor")
.set_body_typed<Array<IterVar>(Stage, IterVar, Expr)>([]( .set_body_typed([](Stage stage, IterVar parent, Expr factor) {
Stage stage, IterVar parent, Expr factor
) {
IterVar outer, inner; IterVar outer, inner;
stage.split(parent, factor, &outer, &inner); stage.split(parent, factor, &outer, &inner);
return Array<IterVar>({outer, inner}); return Array<IterVar>({outer, inner});
}); });
TVM_REGISTER_GLOBAL("_StageSplitByNParts") TVM_REGISTER_GLOBAL("_StageSplitByNParts")
.set_body_typed<Array<IterVar>(Stage, IterVar, Expr)>([]( .set_body_typed([](Stage stage, IterVar parent, Expr nparts) {
Stage stage, IterVar parent, Expr nparts
) {
IterVar outer, inner; IterVar outer, inner;
stage.split_by_nparts(parent, nparts, &outer, &inner); stage.split_by_nparts(parent, nparts, &outer, &inner);
return Array<IterVar>({outer, inner}); return Array<IterVar>({outer, inner});
}); });
TVM_REGISTER_GLOBAL("_StageFuse") TVM_REGISTER_GLOBAL("_StageFuse")
.set_body_typed<IterVar(Stage, Array<IterVar>)>([](Stage stage, Array<IterVar> axes) { .set_body_typed([](Stage stage, Array<IterVar> axes) {
IterVar fused; IterVar fused;
stage.fuse(axes, &fused); stage.fuse(axes, &fused);
return fused; return fused;
...@@ -378,7 +370,7 @@ TVM_REGISTER_GLOBAL("_StageReorder") ...@@ -378,7 +370,7 @@ TVM_REGISTER_GLOBAL("_StageReorder")
.set_body_method(&Stage::reorder); .set_body_method(&Stage::reorder);
TVM_REGISTER_GLOBAL("_StageTile") TVM_REGISTER_GLOBAL("_StageTile")
.set_body_typed<Array<IterVar>(Stage, IterVar, IterVar, Expr, Expr)>([]( .set_body_typed([](
Stage stage, Stage stage,
IterVar x_parent, IterVar y_parent, IterVar x_parent, IterVar y_parent,
Expr x_factor, Expr y_factor Expr x_factor, Expr y_factor
......
...@@ -95,21 +95,21 @@ TVM_REGISTER_GLOBAL("ir_pass.StorageFlatten") ...@@ -95,21 +95,21 @@ TVM_REGISTER_GLOBAL("ir_pass.StorageFlatten")
}); });
TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore") TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore")
.set_body_typed<Stmt(const Stmt&, const Schedule&, const Map<Tensor, Buffer>&)> .set_body_typed
([](const Stmt& stmt, const Schedule& schedule, const Map<Tensor, Buffer>& extern_buffer) { ([](const Stmt& stmt, const Schedule& schedule, const Map<Tensor, Buffer>& extern_buffer) {
return RewriteForTensorCore(stmt, schedule, extern_buffer); return RewriteForTensorCore(stmt, schedule, extern_buffer);
}); });
TVM_REGISTER_GLOBAL("ir_pass.AttrsEqual") TVM_REGISTER_GLOBAL("ir_pass.AttrsEqual")
.set_body_typed<bool(const ObjectRef&, const ObjectRef&)>( .set_body_typed(
[](const ObjectRef& lhs, const ObjectRef& rhs) { [](const ObjectRef& lhs, const ObjectRef& rhs) {
return AttrsEqual()(lhs, rhs); return AttrsEqual()(lhs, rhs);
}); });
TVM_REGISTER_GLOBAL("ir_pass.AttrsHash") TVM_REGISTER_GLOBAL("ir_pass.AttrsHash")
.set_body_typed<int64_t(const ObjectRef&)>([](const ObjectRef &node) { .set_body_typed([](const ObjectRef &node) -> int64_t {
return AttrsHash()(node); return AttrsHash()(node);
}); });
TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar") TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar")
......
...@@ -106,7 +106,7 @@ void ErrorTest(int x, int y) { ...@@ -106,7 +106,7 @@ void ErrorTest(int x, int y) {
} }
TVM_REGISTER_GLOBAL("_ErrorTest") TVM_REGISTER_GLOBAL("_ErrorTest")
.set_body_typed<void(int, int)>(ErrorTest); .set_body_typed(ErrorTest);
// internal function used for debug and testing purposes // internal function used for debug and testing purposes
TVM_REGISTER_GLOBAL("_ndarray_use_count") TVM_REGISTER_GLOBAL("_ndarray_use_count")
......
...@@ -37,7 +37,7 @@ TypeVar TypeVarNode::make(std::string name, TypeKind kind) { ...@@ -37,7 +37,7 @@ TypeVar TypeVarNode::make(std::string name, TypeKind kind) {
TVM_REGISTER_NODE_TYPE(TypeVarNode); TVM_REGISTER_NODE_TYPE(TypeVarNode);
TVM_REGISTER_GLOBAL("relay._make.TypeVar") TVM_REGISTER_GLOBAL("relay._make.TypeVar")
.set_body_typed<TypeVar(std::string, int)>([](std::string name, int kind) { .set_body_typed([](std::string name, int kind) {
return TypeVarNode::make(name, static_cast<TypeKind>(kind)); return TypeVarNode::make(name, static_cast<TypeKind>(kind));
}); });
...@@ -58,7 +58,7 @@ GlobalTypeVar GlobalTypeVarNode::make(std::string name, TypeKind kind) { ...@@ -58,7 +58,7 @@ GlobalTypeVar GlobalTypeVarNode::make(std::string name, TypeKind kind) {
TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode); TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar") TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar")
.set_body_typed<GlobalTypeVar(std::string, int)>([](std::string name, int kind) { .set_body_typed([](std::string name, int kind) {
return GlobalTypeVarNode::make(name, static_cast<TypeKind>(kind)); return GlobalTypeVarNode::make(name, static_cast<TypeKind>(kind));
}); });
......
...@@ -51,7 +51,7 @@ EnvFunc EnvFunc::Get(const std::string& name) { ...@@ -51,7 +51,7 @@ EnvFunc EnvFunc::Get(const std::string& name) {
} }
TVM_REGISTER_GLOBAL("_EnvFuncGet") TVM_REGISTER_GLOBAL("_EnvFuncGet")
.set_body_typed<EnvFunc(const std::string& name)>(EnvFunc::Get); .set_body_typed(EnvFunc::Get);
TVM_REGISTER_GLOBAL("_EnvFuncCall") TVM_REGISTER_GLOBAL("_EnvFuncCall")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
...@@ -63,7 +63,7 @@ TVM_REGISTER_GLOBAL("_EnvFuncCall") ...@@ -63,7 +63,7 @@ TVM_REGISTER_GLOBAL("_EnvFuncCall")
}); });
TVM_REGISTER_GLOBAL("_EnvFuncGetPackedFunc") TVM_REGISTER_GLOBAL("_EnvFuncGetPackedFunc")
.set_body_typed<PackedFunc(const EnvFunc& n)>([](const EnvFunc&n) { .set_body_typed([](const EnvFunc&n) {
return n->func; return n->func;
}); });
......
...@@ -816,43 +816,43 @@ const CompileEngine& CompileEngine::Global() { ...@@ -816,43 +816,43 @@ const CompileEngine& CompileEngine::Global() {
} }
TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey") TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey")
.set_body_typed<CCacheKey(Function, Target)>(CCacheKeyNode::make); .set_body_typed(CCacheKeyNode::make);
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal") TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal")
.set_body_typed<CompileEngine()>([]() { .set_body_typed([]() {
return CompileEngine::Global(); return CompileEngine::Global();
}); });
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear") TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear")
.set_body_typed<void(const CompileEngine&)>([](CompileEngine self) { .set_body_typed([](CompileEngine self) {
self->Clear(); self->Clear();
}); });
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower") TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower")
.set_body_typed<CachedFunc(CompileEngine, CCacheKey)>( .set_body_typed(
[](CompileEngine self, CCacheKey key) { [](CompileEngine self, CCacheKey key) {
return self->Lower(key); return self->Lower(key);
}); });
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc") TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc")
.set_body_typed<CachedFunc(CompileEngine, CCacheKey)>( .set_body_typed(
[](CompileEngine self, CCacheKey key) { [](CompileEngine self, CCacheKey key) {
return self->LowerShapeFunc(key); return self->LowerShapeFunc(key);
}); });
TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions") TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions")
.set_body_typed<void(const CompileEngine&)>([](CompileEngine self) { .set_body_typed([](CompileEngine self) {
return self->LowerExternalFunctions(); return self->LowerExternalFunctions();
}); });
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT") TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT")
.set_body_typed<PackedFunc(CompileEngine, CCacheKey)>( .set_body_typed(
[](CompileEngine self, CCacheKey key) { [](CompileEngine self, CCacheKey key) {
return self->JIT(key); return self->JIT(key);
}); });
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems") TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems")
.set_body_typed<Array<ObjectRef>(CompileEngine)>( .set_body_typed(
[](CompileEngine self){ [](CompileEngine self){
return static_cast<CompileEngineImpl*>(self.operator->())->ListItems(); return static_cast<CompileEngineImpl*>(self.operator->())->ListItems();
}); });
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -390,7 +390,7 @@ Map<Expr, Array<IntegerArray> > GraphPlanMemory(const Function& func) { ...@@ -390,7 +390,7 @@ Map<Expr, Array<IntegerArray> > GraphPlanMemory(const Function& func) {
} }
TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory") TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory")
.set_body_typed<Map<Expr, Array<IntegerArray> >(const Function&)>(GraphPlanMemory); .set_body_typed(GraphPlanMemory);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -595,23 +595,23 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) { ...@@ -595,23 +595,23 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
// TODO(@jroesch): move to correct namespace? // TODO(@jroesch): move to correct namespace?
TVM_REGISTER_GLOBAL("relay._make._alpha_equal") TVM_REGISTER_GLOBAL("relay._make._alpha_equal")
.set_body_typed<bool(ObjectRef, ObjectRef)>([](ObjectRef a, ObjectRef b) { .set_body_typed([](ObjectRef a, ObjectRef b) {
return AlphaEqualHandler(false, false).Equal(a, b); return AlphaEqualHandler(false, false).Equal(a, b);
}); });
TVM_REGISTER_GLOBAL("relay._make._assert_alpha_equal") TVM_REGISTER_GLOBAL("relay._make._assert_alpha_equal")
.set_body_typed<void(ObjectRef, ObjectRef)>([](ObjectRef a, ObjectRef b) { .set_body_typed([](ObjectRef a, ObjectRef b) {
bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b); bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal"; CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal";
}); });
TVM_REGISTER_GLOBAL("relay._make._graph_equal") TVM_REGISTER_GLOBAL("relay._make._graph_equal")
.set_body_typed<bool(ObjectRef, ObjectRef)>([](ObjectRef a, ObjectRef b) { .set_body_typed([](ObjectRef a, ObjectRef b) {
return AlphaEqualHandler(true, false).Equal(a, b); return AlphaEqualHandler(true, false).Equal(a, b);
}); });
TVM_REGISTER_GLOBAL("relay._make._assert_graph_equal") TVM_REGISTER_GLOBAL("relay._make._assert_graph_equal")
.set_body_typed<void(ObjectRef, ObjectRef)>([](ObjectRef a, ObjectRef b) { .set_body_typed([](ObjectRef a, ObjectRef b) {
bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b); bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b);
CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal"; CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal";
}); });
......
...@@ -35,7 +35,7 @@ using namespace tvm::runtime; ...@@ -35,7 +35,7 @@ using namespace tvm::runtime;
TVM_REGISTER_NODE_TYPE(IdNode); TVM_REGISTER_NODE_TYPE(IdNode);
TVM_REGISTER_GLOBAL("relay._base.set_span") TVM_REGISTER_GLOBAL("relay._base.set_span")
.set_body_typed<void(ObjectRef, Span)>([](ObjectRef node_ref, Span sp) { .set_body_typed([](ObjectRef node_ref, Span sp) {
if (auto* rn = node_ref.as<RelayNode>()) { if (auto* rn = node_ref.as<RelayNode>()) {
CHECK(rn); CHECK(rn);
rn->span = sp; rn->span = sp;
......
...@@ -167,7 +167,7 @@ Function FunctionNode::SetParams(const tvm::Map<Var, Constant>& parameters) cons ...@@ -167,7 +167,7 @@ Function FunctionNode::SetParams(const tvm::Map<Var, Constant>& parameters) cons
} }
TVM_REGISTER_GLOBAL("relay._expr.FunctionSetParams") TVM_REGISTER_GLOBAL("relay._expr.FunctionSetParams")
.set_body_typed<Function(const Function&, const tvm::Map<Var, Constant>&)>( .set_body_typed(
[](const Function& func, const tvm::Map<Var, Constant>& parameters) { [](const Function& func, const tvm::Map<Var, Constant>& parameters) {
return func->SetParams(parameters); return func->SetParams(parameters);
}); });
...@@ -178,7 +178,7 @@ tvm::Map<Var, Constant> FunctionNode::GetParams() const { ...@@ -178,7 +178,7 @@ tvm::Map<Var, Constant> FunctionNode::GetParams() const {
} }
TVM_REGISTER_GLOBAL("relay._expr.FunctionGetParams") TVM_REGISTER_GLOBAL("relay._expr.FunctionGetParams")
.set_body_typed<tvm::Map<Var, Constant>(const Function&)>([](const Function& func) { .set_body_typed([](const Function& func) {
return func->GetParams(); return func->GetParams();
}); });
...@@ -367,12 +367,12 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) ...@@ -367,12 +367,12 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
}); });
TVM_REGISTER_GLOBAL("relay._expr.TempExprRealize") TVM_REGISTER_GLOBAL("relay._expr.TempExprRealize")
.set_body_typed<Expr(TempExpr)>([](TempExpr temp) { .set_body_typed([](TempExpr temp) {
return temp->Realize(); return temp->Realize();
}); });
TVM_REGISTER_GLOBAL("relay._expr.FunctionSetAttr") TVM_REGISTER_GLOBAL("relay._expr.FunctionSetAttr")
.set_body_typed<Function(Function, std::string, ObjectRef)>( .set_body_typed(
[](Function func, std::string name, ObjectRef ref) { [](Function func, std::string name, ObjectRef ref) {
return FunctionSetAttr(func, name, ref); return FunctionSetAttr(func, name, ref);
}); });
......
...@@ -348,7 +348,7 @@ void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) { ...@@ -348,7 +348,7 @@ void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) {
} }
TVM_REGISTER_GLOBAL("relay._analysis.post_order_visit") TVM_REGISTER_GLOBAL("relay._analysis.post_order_visit")
.set_body_typed<void(Expr, PackedFunc)>([](Expr expr, PackedFunc f) { .set_body_typed([](Expr expr, PackedFunc f) {
PostOrderVisit(expr, [f](const Expr& n) { PostOrderVisit(expr, [f](const Expr& n) {
f(n); f(n);
}); });
......
...@@ -424,12 +424,12 @@ size_t StructuralHash::operator()(const Expr& expr) const { ...@@ -424,12 +424,12 @@ size_t StructuralHash::operator()(const Expr& expr) const {
} }
TVM_REGISTER_GLOBAL("relay._analysis._expr_hash") TVM_REGISTER_GLOBAL("relay._analysis._expr_hash")
.set_body_typed<int64_t(ObjectRef)>([](ObjectRef ref) { .set_body_typed([](ObjectRef ref) {
return static_cast<int64_t>(RelayHashHandler().Hash(ref)); return static_cast<int64_t>(RelayHashHandler().Hash(ref));
}); });
TVM_REGISTER_GLOBAL("relay._analysis._type_hash") TVM_REGISTER_GLOBAL("relay._analysis._type_hash")
.set_body_typed<int64_t(Type)>([](Type type) { .set_body_typed([](Type type) {
return static_cast<int64_t>(RelayHashHandler().TypeHash(type)); return static_cast<int64_t>(RelayHashHandler().TypeHash(type));
}); });
......
...@@ -318,7 +318,7 @@ Module FromText(const std::string& source, const std::string& source_name) { ...@@ -318,7 +318,7 @@ Module FromText(const std::string& source, const std::string& source_name) {
TVM_REGISTER_NODE_TYPE(ModuleNode); TVM_REGISTER_NODE_TYPE(ModuleNode);
TVM_REGISTER_GLOBAL("relay._make.Module") TVM_REGISTER_GLOBAL("relay._make.Module")
.set_body_typed<Module(tvm::Map<GlobalVar, Function>, tvm::Map<GlobalTypeVar, TypeData>)>( .set_body_typed(
[](tvm::Map<GlobalVar, Function> funcs, tvm::Map<GlobalTypeVar, TypeData> types) { [](tvm::Map<GlobalVar, Function> funcs, tvm::Map<GlobalTypeVar, TypeData> types) {
return ModuleNode::make(funcs, types, {}); return ModuleNode::make(funcs, types, {});
}); });
...@@ -365,52 +365,49 @@ TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalTypeVar") ...@@ -365,52 +365,49 @@ TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalTypeVar")
.set_body_method<Module>(&ModuleNode::GetGlobalTypeVar); .set_body_method<Module>(&ModuleNode::GetGlobalTypeVar);
TVM_REGISTER_GLOBAL("relay._module.Module_Lookup") TVM_REGISTER_GLOBAL("relay._module.Module_Lookup")
.set_body_typed<Function(Module, GlobalVar)>([](Module mod, GlobalVar var) { .set_body_typed([](Module mod, GlobalVar var) {
return mod->Lookup(var); return mod->Lookup(var);
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_Lookup_str") TVM_REGISTER_GLOBAL("relay._module.Module_Lookup_str")
.set_body_typed<Function(Module, std::string)>([](Module mod, std::string var) { .set_body_typed([](Module mod, std::string var) {
return mod->Lookup(var); return mod->Lookup(var);
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef") TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef")
.set_body_typed<TypeData(Module, GlobalTypeVar)>([](Module mod, GlobalTypeVar var) { .set_body_typed([](Module mod, GlobalTypeVar var) {
return mod->LookupDef(var); return mod->LookupDef(var);
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef_str") TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef_str")
.set_body_typed<TypeData(Module, std::string)>([](Module mod, std::string var) { .set_body_typed([](Module mod, std::string var) {
return mod->LookupDef(var); return mod->LookupDef(var);
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_LookupTag") TVM_REGISTER_GLOBAL("relay._module.Module_LookupTag")
.set_body_typed<Constructor(Module, int32_t)>([](Module mod, int32_t tag) { .set_body_typed([](Module mod, int32_t tag) {
return mod->LookupTag(tag); return mod->LookupTag(tag);
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_FromExpr") TVM_REGISTER_GLOBAL("relay._module.Module_FromExpr")
.set_body_typed< .set_body_typed([](Expr e,
Module(Expr, tvm::Map<GlobalVar, Function> funcs,
tvm::Map<GlobalVar, Function>, tvm::Map<GlobalTypeVar, TypeData> type_defs) {
tvm::Map<GlobalTypeVar, TypeData>)>([](Expr e, return ModuleNode::FromExpr(e, funcs, type_defs);
tvm::Map<GlobalVar, Function> funcs, });
tvm::Map<GlobalTypeVar, TypeData> type_defs) {
return ModuleNode::FromExpr(e, funcs, type_defs);
});
TVM_REGISTER_GLOBAL("relay._module.Module_Update") TVM_REGISTER_GLOBAL("relay._module.Module_Update")
.set_body_typed<void(Module, Module)>([](Module mod, Module from) { .set_body_typed([](Module mod, Module from) {
mod->Update(from); mod->Update(from);
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_Import") TVM_REGISTER_GLOBAL("relay._module.Module_Import")
.set_body_typed<void(Module, std::string)>([](Module mod, std::string path) { .set_body_typed([](Module mod, std::string path) {
mod->Import(path); mod->Import(path);
}); });
TVM_REGISTER_GLOBAL("relay._module.Module_ImportFromStd") TVM_REGISTER_GLOBAL("relay._module.Module_ImportFromStd")
.set_body_typed<void(Module, std::string)>([](Module mod, std::string path) { .set_body_typed([](Module mod, std::string path) {
mod->ImportFromStd(path); mod->ImportFromStd(path);
});; });;
......
...@@ -136,7 +136,7 @@ void OpRegistry::UpdateAttr(const std::string& key, ...@@ -136,7 +136,7 @@ void OpRegistry::UpdateAttr(const std::string& key,
// Frontend APIs // Frontend APIs
TVM_REGISTER_GLOBAL("relay.op._ListOpNames") TVM_REGISTER_GLOBAL("relay.op._ListOpNames")
.set_body_typed<Array<tvm::Expr>()>([]() { .set_body_typed([]() {
Array<tvm::Expr> ret; Array<tvm::Expr> ret;
for (const std::string& name : for (const std::string& name :
dmlc::Registry<OpRegistry>::ListAllNames()) { dmlc::Registry<OpRegistry>::ListAllNames()) {
...@@ -145,7 +145,7 @@ TVM_REGISTER_GLOBAL("relay.op._ListOpNames") ...@@ -145,7 +145,7 @@ TVM_REGISTER_GLOBAL("relay.op._ListOpNames")
return ret; return ret;
}); });
TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed<Op(std::string)>(Op::Get); TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed(Op::Get);
TVM_REGISTER_GLOBAL("relay.op._OpGetAttr") TVM_REGISTER_GLOBAL("relay.op._OpGetAttr")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
......
...@@ -991,9 +991,7 @@ std::string AsText(const ObjectRef& node, ...@@ -991,9 +991,7 @@ std::string AsText(const ObjectRef& node,
} }
TVM_REGISTER_GLOBAL("relay._expr.AsText") TVM_REGISTER_GLOBAL("relay._expr.AsText")
.set_body_typed<std::string(const ObjectRef&, .set_body_typed(AsText);
bool,
runtime::TypedPackedFunc<std::string(Expr)>)>(AsText);
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -91,7 +91,7 @@ IncompleteType IncompleteTypeNode::make(Kind kind) { ...@@ -91,7 +91,7 @@ IncompleteType IncompleteTypeNode::make(Kind kind) {
TVM_REGISTER_NODE_TYPE(IncompleteTypeNode); TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
TVM_REGISTER_GLOBAL("relay._make.IncompleteType") TVM_REGISTER_GLOBAL("relay._make.IncompleteType")
.set_body_typed<IncompleteType(int)>([](int kind) { .set_body_typed([](int kind) {
return IncompleteTypeNode::make(static_cast<Kind>(kind)); return IncompleteTypeNode::make(static_cast<Kind>(kind));
}); });
...@@ -161,8 +161,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) ...@@ -161,8 +161,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
}); });
TVM_REGISTER_GLOBAL("relay._make.Any") TVM_REGISTER_GLOBAL("relay._make.Any")
.set_body_typed<IndexExpr()>([]() { return Any::make(); }); .set_body_typed([]() { return Any::make(); });
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
...@@ -40,7 +40,7 @@ namespace relay { ...@@ -40,7 +40,7 @@ namespace relay {
TVM_REGISTER_NODE_TYPE(OnDeviceAttrs); TVM_REGISTER_NODE_TYPE(OnDeviceAttrs);
TVM_REGISTER_GLOBAL("relay.op.annotation._make.on_device") TVM_REGISTER_GLOBAL("relay.op.annotation._make.on_device")
.set_body_typed<Expr(Expr, int)>([](Expr data, int device_type) { .set_body_typed([](Expr data, int device_type) {
auto attrs = make_object<OnDeviceAttrs>(); auto attrs = make_object<OnDeviceAttrs>();
attrs->device_type = device_type; attrs->device_type = device_type;
static const Op& op = Op::Get("on_device"); static const Op& op = Op::Get("on_device");
...@@ -63,7 +63,7 @@ Expr StopFusion(Expr data) { ...@@ -63,7 +63,7 @@ Expr StopFusion(Expr data) {
} }
TVM_REGISTER_GLOBAL("relay.op.annotation._make.stop_fusion") TVM_REGISTER_GLOBAL("relay.op.annotation._make.stop_fusion")
.set_body_typed<Expr(Expr)>([](Expr data) { .set_body_typed([](Expr data) {
return StopFusion(data); return StopFusion(data);
}); });
...@@ -145,7 +145,7 @@ Mark the end of bitpacking. ...@@ -145,7 +145,7 @@ Mark the end of bitpacking.
}); });
TVM_REGISTER_GLOBAL("relay.op.annotation._make.checkpoint") TVM_REGISTER_GLOBAL("relay.op.annotation._make.checkpoint")
.set_body_typed<Expr(Expr)>([](Expr data) { .set_body_typed([](Expr data) {
static const Op& op = Op::Get("annotation.checkpoint"); static const Op& op = Op::Get("annotation.checkpoint");
return CallNode::make(op, {data}, Attrs{}, {}); return CallNode::make(op, {data}, Attrs{}, {});
}); });
......
...@@ -42,7 +42,7 @@ namespace relay { ...@@ -42,7 +42,7 @@ namespace relay {
TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs); TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs);
TVM_REGISTER_GLOBAL("relay.op._make.device_copy") TVM_REGISTER_GLOBAL("relay.op._make.device_copy")
.set_body_typed<Expr(Expr, int, int)>([](Expr data, int src_dev_type, .set_body_typed([](Expr data, int src_dev_type,
int dst_dev_type) { int dst_dev_type) {
auto attrs = make_object<DeviceCopyAttrs>(); auto attrs = make_object<DeviceCopyAttrs>();
attrs->src_dev_type = src_dev_type; attrs->src_dev_type = src_dev_type;
......
...@@ -42,7 +42,7 @@ TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs); ...@@ -42,7 +42,7 @@ TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs);
// We should consider a better solution, i.e the type relation // We should consider a better solution, i.e the type relation
// being able to see the arguments as well? // being able to see the arguments as well?
TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_storage") TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_storage")
.set_body_typed<Expr(Expr, Expr, DataType)>([](Expr size, Expr alignment, DataType dtype) { .set_body_typed([](Expr size, Expr alignment, DataType dtype) {
auto attrs = make_object<AllocTensorAttrs>(); auto attrs = make_object<AllocTensorAttrs>();
attrs->dtype = dtype; attrs->dtype = dtype;
static const Op& op = Op::Get("memory.alloc_storage"); static const Op& op = Op::Get("memory.alloc_storage");
...@@ -88,7 +88,7 @@ RELAY_REGISTER_OP("memory.alloc_storage") ...@@ -88,7 +88,7 @@ RELAY_REGISTER_OP("memory.alloc_storage")
}); });
TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor") TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor")
.set_body_typed<Expr(Expr, Expr, DataType, Array<IndexExpr> assert_shape)>( .set_body_typed(
[](Expr storage, tvm::relay::Expr shape, DataType dtype, Array<IndexExpr> assert_shape) { [](Expr storage, tvm::relay::Expr shape, DataType dtype, Array<IndexExpr> assert_shape) {
auto attrs = make_object<AllocTensorAttrs>(); auto attrs = make_object<AllocTensorAttrs>();
attrs->dtype = dtype; attrs->dtype = dtype;
...@@ -209,7 +209,7 @@ bool InvokeTVMOPRel(const Array<Type>& types, int num_inputs, const Attrs& attrs ...@@ -209,7 +209,7 @@ bool InvokeTVMOPRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
} }
TVM_REGISTER_GLOBAL("relay.op.memory._make.invoke_tvm_op") TVM_REGISTER_GLOBAL("relay.op.memory._make.invoke_tvm_op")
.set_body_typed<Expr(Expr, Expr, Expr)>( .set_body_typed(
[](Expr func, Expr inputs, Expr outputs) { [](Expr func, Expr inputs, Expr outputs) {
return CallNode::make(Op::Get("memory.invoke_tvm_op"), {func, inputs, outputs}, Attrs()); return CallNode::make(Op::Get("memory.invoke_tvm_op"), {func, inputs, outputs}, Attrs());
}); });
...@@ -257,7 +257,7 @@ RELAY_REGISTER_OP("memory.kill") ...@@ -257,7 +257,7 @@ RELAY_REGISTER_OP("memory.kill")
}); });
TVM_REGISTER_GLOBAL("relay.op.memory._make.shape_func") TVM_REGISTER_GLOBAL("relay.op.memory._make.shape_func")
.set_body_typed<Expr(Expr, Expr, Expr, Array<tvm::Integer>)>( .set_body_typed(
[](Expr func, Expr inputs, Expr outputs, Array<tvm::Integer> is_input) { [](Expr func, Expr inputs, Expr outputs, Array<tvm::Integer> is_input) {
static const Op& op = Op::Get("memory.shape_func"); static const Op& op = Op::Get("memory.shape_func");
auto attrs = make_object<ShapeFuncAttrs>(); auto attrs = make_object<ShapeFuncAttrs>();
......
...@@ -326,7 +326,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch. ...@@ -326,7 +326,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
TVM_REGISTER_NODE_TYPE(SoftmaxAttrs); TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
TVM_REGISTER_GLOBAL("relay.op.nn._make.softmax") TVM_REGISTER_GLOBAL("relay.op.nn._make.softmax")
.set_body_typed<Call(Expr, int)>([](Expr data, int axis) { .set_body_typed([](Expr data, int axis) {
auto attrs = make_object<SoftmaxAttrs>(); auto attrs = make_object<SoftmaxAttrs>();
attrs->axis = axis; attrs->axis = axis;
static const Op& op = Op::Get("nn.softmax"); static const Op& op = Op::Get("nn.softmax");
...@@ -361,7 +361,7 @@ RELAY_REGISTER_OP("nn.softmax") ...@@ -361,7 +361,7 @@ RELAY_REGISTER_OP("nn.softmax")
// relay.nn.log_softmax // relay.nn.log_softmax
TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax") TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax")
.set_body_typed<Call(Expr, int)>([](Expr data, int axis) { .set_body_typed([](Expr data, int axis) {
auto attrs = make_object<SoftmaxAttrs>(); auto attrs = make_object<SoftmaxAttrs>();
attrs->axis = axis; attrs->axis = axis;
static const Op& op = Op::Get("nn.log_softmax"); static const Op& op = Op::Get("nn.log_softmax");
...@@ -470,7 +470,7 @@ Example:: ...@@ -470,7 +470,7 @@ Example::
// relu // relu
TVM_REGISTER_GLOBAL("relay.op.nn._make.relu") TVM_REGISTER_GLOBAL("relay.op.nn._make.relu")
.set_body_typed<Call(Expr)>([](Expr data) { .set_body_typed([](Expr data) {
static const Op& op = Op::Get("nn.relu"); static const Op& op = Op::Get("nn.relu");
return CallNode::make(op, {data}, Attrs(), {}); return CallNode::make(op, {data}, Attrs(), {});
}); });
......
...@@ -214,13 +214,12 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs, ...@@ -214,13 +214,12 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
} }
TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool2d") TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool2d")
.set_body_typed<Expr(Expr, Array<IndexExpr>, Array<IndexExpr>, Array<IndexExpr>, .set_body_typed([](Expr data,
std::string, bool)>([](Expr data, Array<IndexExpr> pool_size,
Array<IndexExpr> pool_size, Array<IndexExpr> strides,
Array<IndexExpr> strides, Array<IndexExpr> padding,
Array<IndexExpr> padding, std::string layout,
std::string layout, bool ceil_mode) {
bool ceil_mode) {
return MakeMaxPool<MaxPool2DAttrs>(data, pool_size, strides, padding, layout, ceil_mode, return MakeMaxPool<MaxPool2DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
"nn.max_pool2d"); "nn.max_pool2d");
}); });
...@@ -258,14 +257,13 @@ RELAY_REGISTER_OP("nn.max_pool2d") ...@@ -258,14 +257,13 @@ RELAY_REGISTER_OP("nn.max_pool2d")
// AvgPool2D // AvgPool2D
TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool2d") TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool2d")
.set_body_typed<Expr(Expr, Array<IndexExpr>, Array<IndexExpr>, Array<IndexExpr>, .set_body_typed([](Expr data,
std::string, bool, bool)>([](Expr data, Array<IndexExpr> pool_size,
Array<IndexExpr> pool_size, Array<IndexExpr> strides,
Array<IndexExpr> strides, Array<IndexExpr> padding,
Array<IndexExpr> padding, std::string layout,
std::string layout, bool ceil_mode,
bool ceil_mode, bool count_include_pad) {
bool count_include_pad) {
return MakeAvgPool<AvgPool2DAttrs>(data, pool_size, strides, padding, layout, ceil_mode, return MakeAvgPool<AvgPool2DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
count_include_pad, "nn.avg_pool2d"); count_include_pad, "nn.avg_pool2d");
}); });
...@@ -868,13 +866,12 @@ Array<Tensor> Pool3DCompute(const Attrs& attrs, ...@@ -868,13 +866,12 @@ Array<Tensor> Pool3DCompute(const Attrs& attrs,
} }
TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool3d") TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool3d")
.set_body_typed<Expr(Expr, Array<IndexExpr>, Array<IndexExpr>, Array<IndexExpr>, .set_body_typed([](Expr data,
std::string, bool)>([](Expr data, Array<IndexExpr> pool_size,
Array<IndexExpr> pool_size, Array<IndexExpr> strides,
Array<IndexExpr> strides, Array<IndexExpr> padding,
Array<IndexExpr> padding, std::string layout,
std::string layout, bool ceil_mode) {
bool ceil_mode) {
return MakeMaxPool<MaxPool3DAttrs>(data, pool_size, strides, padding, layout, ceil_mode, return MakeMaxPool<MaxPool3DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
"nn.max_pool3d"); "nn.max_pool3d");
}); });
...@@ -912,14 +909,13 @@ RELAY_REGISTER_OP("nn.max_pool3d") ...@@ -912,14 +909,13 @@ RELAY_REGISTER_OP("nn.max_pool3d")
// AvgPool3D // AvgPool3D
TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool3d") TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool3d")
.set_body_typed<Expr(Expr, Array<IndexExpr>, Array<IndexExpr>, Array<IndexExpr>, .set_body_typed([](Expr data,
std::string, bool, bool)>([](Expr data, Array<IndexExpr> pool_size,
Array<IndexExpr> pool_size, Array<IndexExpr> strides,
Array<IndexExpr> strides, Array<IndexExpr> padding,
Array<IndexExpr> padding, std::string layout,
std::string layout, bool ceil_mode,
bool ceil_mode, bool count_include_pad) {
bool count_include_pad) {
return MakeAvgPool<AvgPool3DAttrs>(data, pool_size, strides, padding, layout, ceil_mode, return MakeAvgPool<AvgPool3DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
count_include_pad, "nn.avg_pool3d"); count_include_pad, "nn.avg_pool3d");
}); });
......
...@@ -48,19 +48,19 @@ namespace relay { ...@@ -48,19 +48,19 @@ namespace relay {
* \param OpName the name of registry. * \param OpName the name of registry.
*/ */
#define RELAY_REGISTER_UNARY_OP(OpName) \ #define RELAY_REGISTER_UNARY_OP(OpName) \
TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ TVM_REGISTER_GLOBAL("relay.op._make." OpName) \
.set_body_typed<Expr(Expr)>([](Expr data) { \ .set_body_typed([](Expr data) { \
static const Op& op = Op::Get(OpName); \ static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {data}, Attrs(), {}); \ return CallNode::make(op, {data}, Attrs(), {}); \
}); \ }); \
RELAY_REGISTER_OP(OpName) \ RELAY_REGISTER_OP(OpName) \
.set_num_inputs(1) \ .set_num_inputs(1) \
.add_argument("data", "Tensor", "The input tensor.") \ .add_argument("data", "Tensor", "The input tensor.") \
.add_type_rel("Identity", IdentityRel) \ .add_type_rel("Identity", IdentityRel) \
.set_attr<TOpPattern>("TOpPattern", kElemWise) \ .set_attr<TOpPattern>("TOpPattern", kElemWise) \
.set_attr<TOpIsStateful>("TOpIsStateful", false) \ .set_attr<TOpIsStateful>("TOpIsStateful", false) \
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", \ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", \
ElemwiseArbitraryLayout) \ ElemwiseArbitraryLayout) \
/*! Quick helper macro /*! Quick helper macro
...@@ -73,38 +73,38 @@ namespace relay { ...@@ -73,38 +73,38 @@ namespace relay {
* *
* \param OpName the name of registry. * \param OpName the name of registry.
*/ */
#define RELAY_REGISTER_BINARY_OP(OpName) \ #define RELAY_REGISTER_BINARY_OP(OpName) \
TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ TVM_REGISTER_GLOBAL("relay.op._make." OpName) \
.set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) { \ .set_body_typed([](Expr lhs, Expr rhs) { \
static const Op& op = Op::Get(OpName); \ static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \ return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \
}); \ }); \
RELAY_REGISTER_OP(OpName) \ RELAY_REGISTER_OP(OpName) \
.set_num_inputs(2) \ .set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \ .add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \ .add_argument("rhs", "Tensor", "The right hand side tensor.") \
.add_type_rel("Broadcast", BroadcastRel) \ .add_type_rel("Broadcast", BroadcastRel) \
.set_attr<TOpPattern>("TOpPattern", kBroadcast) \ .set_attr<TOpPattern>("TOpPattern", kBroadcast) \
.set_attr<TOpIsStateful>("TOpIsStateful", false) \ .set_attr<TOpIsStateful>("TOpIsStateful", false) \
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", \ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", \
BinaryBroadcastLayout) BinaryBroadcastLayout)
// Comparisons // Comparisons
#define RELAY_REGISTER_CMP_OP(OpName) \ #define RELAY_REGISTER_CMP_OP(OpName) \
TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ TVM_REGISTER_GLOBAL("relay.op._make." OpName) \
.set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) { \ .set_body_typed([](Expr lhs, Expr rhs) { \
static const Op& op = Op::Get(OpName); \ static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \ return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \
}); \ }); \
RELAY_REGISTER_OP(OpName) \ RELAY_REGISTER_OP(OpName) \
.set_num_inputs(2) \ .set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \ .add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \ .add_argument("rhs", "Tensor", "The right hand side tensor.") \
.add_type_rel("BroadcastComp", BroadcastCompRel) \ .add_type_rel("BroadcastComp", BroadcastCompRel) \
.set_attr<TOpPattern>("TOpPattern", kBroadcast) \ .set_attr<TOpPattern>("TOpPattern", kBroadcast) \
.set_attr<TOpIsStateful>("TOpIsStateful", false) \ .set_attr<TOpIsStateful>("TOpIsStateful", false) \
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", \ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", \
BinaryBroadcastLayout) BinaryBroadcastLayout)
/*! \brief A helper class for matching and rewriting operators. */ /*! \brief A helper class for matching and rewriting operators. */
......
...@@ -303,7 +303,7 @@ bool ReduceRel(const Array<Type>& types, ...@@ -303,7 +303,7 @@ bool ReduceRel(const Array<Type>& types,
#define RELAY_REGISTER_REDUCE_OP(OpName) \ #define RELAY_REGISTER_REDUCE_OP(OpName) \
TVM_REGISTER_GLOBAL("relay.op._make." OpName) \ TVM_REGISTER_GLOBAL("relay.op._make." OpName) \
.set_body_typed<Call(Expr, Array<Integer>, bool, bool)>([]( \ .set_body_typed([]( \
Expr data, \ Expr data, \
Array<Integer> axis, \ Array<Integer> axis, \
bool keepdims, \ bool keepdims, \
......
...@@ -858,7 +858,7 @@ bool ArgWhereRel(const Array<Type>& types, ...@@ -858,7 +858,7 @@ bool ArgWhereRel(const Array<Type>& types,
} }
TVM_REGISTER_GLOBAL("relay.op._make.argwhere") TVM_REGISTER_GLOBAL("relay.op._make.argwhere")
.set_body_typed<Expr(Expr)>([](Expr data) { .set_body_typed([](Expr data) {
static const Op& op = Op::Get("argwhere"); static const Op& op = Op::Get("argwhere");
auto attrs = make_object<ArgWhereAttrs>(); auto attrs = make_object<ArgWhereAttrs>();
return CallNode::make(op, {data}, Attrs(attrs), {}); return CallNode::make(op, {data}, Attrs(attrs), {});
......
...@@ -158,7 +158,7 @@ RELAY_REGISTER_UNARY_OP("copy") ...@@ -158,7 +158,7 @@ RELAY_REGISTER_UNARY_OP("copy")
TVM_REGISTER_NODE_TYPE(ClipAttrs); TVM_REGISTER_NODE_TYPE(ClipAttrs);
TVM_REGISTER_GLOBAL("relay.op._make.clip") TVM_REGISTER_GLOBAL("relay.op._make.clip")
.set_body_typed<Expr(Expr, double, double)>([](Expr a, double a_min, double a_max) { .set_body_typed([](Expr a, double a_min, double a_max) {
auto attrs = make_object<ClipAttrs>(); auto attrs = make_object<ClipAttrs>();
attrs->a_min = a_min; attrs->a_min = a_min;
attrs->a_max = a_max; attrs->a_max = a_max;
...@@ -301,7 +301,7 @@ Array<Tensor> ShapeOfCompute(const Attrs& attrs, ...@@ -301,7 +301,7 @@ Array<Tensor> ShapeOfCompute(const Attrs& attrs,
} }
TVM_REGISTER_GLOBAL("relay.op._make.shape_of") TVM_REGISTER_GLOBAL("relay.op._make.shape_of")
.set_body_typed<Expr(Expr, DataType)>([](Expr data, DataType dtype) { .set_body_typed([](Expr data, DataType dtype) {
auto attrs = make_object<ShapeOfAttrs>(); auto attrs = make_object<ShapeOfAttrs>();
attrs->dtype = dtype; attrs->dtype = dtype;
static const Op& op = Op::Get("shape_of"); static const Op& op = Op::Get("shape_of");
...@@ -352,7 +352,7 @@ Array<Tensor> NdarraySizeCompute(const Attrs& attrs, ...@@ -352,7 +352,7 @@ Array<Tensor> NdarraySizeCompute(const Attrs& attrs,
} }
TVM_REGISTER_GLOBAL("relay.op.contrib._make.ndarray_size") TVM_REGISTER_GLOBAL("relay.op.contrib._make.ndarray_size")
.set_body_typed<Expr(Expr, DataType)>([](Expr data, DataType dtype) { .set_body_typed([](Expr data, DataType dtype) {
auto attrs = make_object<NdarraySizeAttrs>(); auto attrs = make_object<NdarraySizeAttrs>();
attrs->dtype = dtype; attrs->dtype = dtype;
static const Op& op = Op::Get("contrib.ndarray_size"); static const Op& op = Op::Get("contrib.ndarray_size");
......
...@@ -327,7 +327,7 @@ Array<Pattern> UnmatchedCases(const Match& match, const Module& mod) { ...@@ -327,7 +327,7 @@ Array<Pattern> UnmatchedCases(const Match& match, const Module& mod) {
// expose for testing only // expose for testing only
TVM_REGISTER_GLOBAL("relay._analysis.unmatched_cases") TVM_REGISTER_GLOBAL("relay._analysis.unmatched_cases")
.set_body_typed<Array<Pattern>(const Match&, const Module&)>( .set_body_typed(
[](const Match& match, const Module& mod_ref) { [](const Match& match, const Module& mod_ref) {
Module call_mod = mod_ref; Module call_mod = mod_ref;
if (!call_mod.defined()) { if (!call_mod.defined()) {
......
...@@ -67,7 +67,7 @@ RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize") ...@@ -67,7 +67,7 @@ RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
.add_type_rel("SimulatedQuantize", SimulatedQuantizeRel); .add_type_rel("SimulatedQuantize", SimulatedQuantizeRel);
TVM_REGISTER_GLOBAL("relay._quantize.simulated_quantize") TVM_REGISTER_GLOBAL("relay._quantize.simulated_quantize")
.set_body_typed<Expr(Expr, Expr, Expr, Expr, int, bool, std::string)>( .set_body_typed(
[](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max, [](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max,
int kind, bool sign, std::string rounding) { int kind, bool sign, std::string rounding) {
auto attrs = make_object<SimulatedQuantizeAttrs>(); auto attrs = make_object<SimulatedQuantizeAttrs>();
......
...@@ -79,7 +79,7 @@ bool TupleGetItemRel(const Array<Type>& types, ...@@ -79,7 +79,7 @@ bool TupleGetItemRel(const Array<Type>& types,
TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs); TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs);
TVM_REGISTER_GLOBAL("tvm.relay.type_relation.TupleGetItem") TVM_REGISTER_GLOBAL("tvm.relay.type_relation.TupleGetItem")
.set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>( .set_body_typed(
TupleGetItemRel); TupleGetItemRel);
struct ResolvedTypeInfo { struct ResolvedTypeInfo {
...@@ -840,7 +840,7 @@ Pass InferType() { ...@@ -840,7 +840,7 @@ Pass InferType() {
} }
TVM_REGISTER_GLOBAL("relay._transform.InferType") TVM_REGISTER_GLOBAL("relay._transform.InferType")
.set_body_typed<Pass()>([]() { .set_body_typed([]() {
return InferType(); return InferType();
}); });
......
...@@ -63,19 +63,18 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con ...@@ -63,19 +63,18 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con
* *
* \param OpName the name of registry. * \param OpName the name of registry.
*/ */
#define QNN_REGISTER_BINARY_OP(OpName) \ #define QNN_REGISTER_BINARY_OP(OpName) \
TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName) \ TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName) \
.set_body_typed<Expr(Expr, Expr, Expr, Expr, Expr, Expr, Expr, Expr)>( \ .set_body_typed([](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \
[](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \ Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \
Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \ static const Op& op = Op::Get("qnn." OpName); \
static const Op& op = Op::Get("qnn." OpName); \ return CallNode::make(op, {lhs, rhs, \
return CallNode::make(op, {lhs, rhs, \ lhs_scale, lhs_zero_point, \
lhs_scale, lhs_zero_point, \ rhs_scale, rhs_zero_point, \
rhs_scale, rhs_zero_point, \ output_scale, output_zero_point}, Attrs(), {}); \
output_scale, output_zero_point}, Attrs(), {}); \ }); \
}); \ RELAY_REGISTER_OP("qnn." OpName) \
RELAY_REGISTER_OP("qnn." OpName) \ .set_num_inputs(8) \
.set_num_inputs(8) \
.add_argument("lhs", "Tensor", "The left hand side quantized tensor.") \ .add_argument("lhs", "Tensor", "The left hand side quantized tensor.") \
.add_argument("rhs", "Tensor", "The right hand side quantized tensor.") \ .add_argument("rhs", "Tensor", "The right hand side quantized tensor.") \
.add_argument("lhs_scale", "Tensor", "The scale of the lhs tensor.") \ .add_argument("lhs_scale", "Tensor", "The scale of the lhs tensor.") \
......
...@@ -795,7 +795,7 @@ TVM_REGISTER_GLOBAL("relay._vm.GetPrimitiveFields") ...@@ -795,7 +795,7 @@ TVM_REGISTER_GLOBAL("relay._vm.GetPrimitiveFields")
}); });
TVM_REGISTER_GLOBAL("relay._vm.Load_Executable") TVM_REGISTER_GLOBAL("relay._vm.Load_Executable")
.set_body_typed<runtime::Module(std::string, runtime::Module)>([]( .set_body_typed([](
std::string code, std::string code,
runtime::Module lib) { runtime::Module lib) {
return Executable::Load(code, lib); return Executable::Load(code, lib);
......
...@@ -179,6 +179,24 @@ TEST(TypedPackedFunc, HighOrder) { ...@@ -179,6 +179,24 @@ TEST(TypedPackedFunc, HighOrder) {
CHECK_EQ(f1(3), 4); CHECK_EQ(f1(3), 4);
} }
TEST(TypedPackedFunc, Deduce) {
using namespace tvm::runtime;
using tvm::runtime::detail::function_signature;
TypedPackedFunc<int(float)> x;
auto f = [](int x) -> int {
return x + 1;
};
std::function<void(float)> y;
static_assert(std::is_same<function_signature<decltype(x)>::FType,
int(float)>::value, "invariant1");
static_assert(std::is_same<function_signature<decltype(f)>::FType,
int(int)>::value, "invariant2");
static_assert(std::is_same<function_signature<decltype(y)>::FType,
void(float)>::value, "invariant3");
}
TEST(PackedFunc, ObjectConversion) { TEST(PackedFunc, ObjectConversion) {
using namespace tvm; using namespace tvm;
......
...@@ -60,13 +60,13 @@ struct RPCEnv { ...@@ -60,13 +60,13 @@ struct RPCEnv {
}; };
TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath")
.set_body_typed<std::string(std::string)>([](std::string path) { .set_body_typed([](std::string path) {
static RPCEnv env; static RPCEnv env;
return env.GetPath(path); return env.GetPath(path);
}); });
TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module")
.set_body_typed<Module(std::string)>([](std::string path) { .set_body_typed([](std::string path) {
std::string file_name = "/rpc/" + path; std::string file_name = "/rpc/" + path;
LOG(INFO) << "Load module from " << file_name << " ..."; LOG(INFO) << "Load module from " << file_name << " ...";
return Module::LoadFromFile(file_name, ""); return Module::LoadFromFile(file_name, "");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment