Unverified Commit d5d63a44 by Tianqi Chen Committed by GitHub

[REFACTOR] Automatically deduce function type signature in Registry.set_body_typed (#4623)

Previously we supported only a limited case of function type deduction, and in many places
we had to supply the type twice during set_body_typed (once in the template parameter, and again in the lambda signature).

This PR improves the deduction by enabling automatic function signature deduction.

```
TVM_REGISTER_GLOBAL("sub")
.set_body_typed([](int x, int y) -> int { return x - y; });
```

Unfortunately, because of template conflict, we can not support the original case
where both type signature and lambda are supplied through set_body_typed.

This PR refactors the existing registrations to the new style.
parent f12c4fe2
......@@ -1031,6 +1031,42 @@ inline void for_each(const F& f, Args&&... args) { // NOLINT(*)
for_each_dispatcher<sizeof...(Args) == 0, 0, F>
::run(f, std::forward<Args>(args)...);
}
// Helper that extracts a plain function type R(Args...) from a
// pointer-to-member operator(), used to deduce the signature of
// lambdas and other functors.
template<typename T>
struct func_signature_helper {
// Fallback when T is not a supported member-function-pointer type.
using FType = void;
};
// Specialization: non-const operator() (e.g. mutable lambdas).
template<typename T, typename R, typename ...Args>
struct func_signature_helper<R (T::*)(Args...)> {
using FType = R(Args...);
};
// Specialization: const operator() (ordinary, non-mutable lambdas).
template<typename T, typename R, typename ...Args>
struct func_signature_helper<R (T::*)(Args...) const> {
using FType = R(Args...);
};
/*!
 * \brief template class to get function signature of a function or functor.
 * \tparam T The function/functor type.
 *
 * For class types (functors/lambdas) the signature is deduced from
 * decltype(&T::operator()); generic lambdas, whose operator() is a
 * template, cannot be deduced this way.
 */
template<typename T>
struct function_signature {
using FType = typename func_signature_helper<decltype(&T::operator())>::FType;
};
// handle case of plain function type, e.g. int(int, int).
template<typename R, typename ...Args>
struct function_signature<R(Args...)> {
using FType = R(Args...);
};
// handle case of function pointer, e.g. int (*)(int, int).
template<typename R, typename ...Args>
struct function_signature<R (*)(Args...)> {
using FType = R(Args...);
};
} // namespace detail
/* \brief argument setter to PackedFunc */
......
......@@ -46,6 +46,7 @@
#include <tvm/runtime/packed_func.h>
#include <string>
#include <vector>
#include <utility>
namespace tvm {
namespace runtime {
......@@ -66,28 +67,7 @@ class Registry {
return set_body(PackedFunc(f));
}
/*!
* \brief set the body of the function to be TypedPackedFunc.
*
* \code
*
* TVM_REGISTER_GLOBAL("addone")
* .set_body_typed<int(int)>([](int x) { return x + 1; });
*
* \endcode
*
* \param f The body of the function.
* \tparam FType the signature of the function.
* \tparam FLambda The type of f.
*/
template<typename FType, typename FLambda>
Registry& set_body_typed(FLambda f) {
return set_body(TypedPackedFunc<FType>(f).packed());
}
/*!
* \brief set the body of the function to the given function pointer.
* Note that this doesn't work with lambdas, you need to
* explicitly give a type for those.
* \brief set the body of the function to the given function.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
......@@ -99,17 +79,20 @@ class Registry {
* TVM_REGISTER_GLOBAL("multiply")
* .set_body_typed(multiply); // will have type int(int, int)
*
* // will have type int(int, int)
* TVM_REGISTER_GLOBAL("sub")
* .set_body_typed([](int a, int b) -> int { return a - b; });
*
* \endcode
*
* \param f The function to forward to.
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
* \tparam FLambda The signature of the function.
*/
template<typename R, typename ...Args>
Registry& set_body_typed(R (*f)(Args...)) {
return set_body(TypedPackedFunc<R(Args...)>(f));
template<typename FLambda>
Registry& set_body_typed(FLambda f) {
using FType = typename detail::function_signature<FLambda>::FType;
return set_body(TypedPackedFunc<FType>(std::move(f)).packed());
}
/*!
* \brief set the body of the function to be the passed method pointer.
* Note that this will ignore default arg values and always require all arguments to be provided.
......@@ -132,10 +115,11 @@ class Registry {
*/
template<typename T, typename R, typename ...Args>
Registry& set_body_method(R (T::*f)(Args...)) {
return set_body_typed<R(T, Args...)>([f](T target, Args... params) -> R {
auto fwrap =[f](T target, Args... params) -> R {
// call method pointer
return (target.*f)(params...);
});
};
return set_body(TypedPackedFunc<R(T, Args...)>(fwrap));
}
/*!
......@@ -160,10 +144,11 @@ class Registry {
*/
template<typename T, typename R, typename ...Args>
Registry& set_body_method(R (T::*f)(Args...) const) {
return set_body_typed<R(T, Args...)>([f](const T target, Args... params) -> R {
auto fwrap = [f](const T target, Args... params) -> R {
// call method pointer
return (target.*f)(params...);
});
};
return set_body(TypedPackedFunc<R(const T, Args...)>(fwrap));
}
/*!
......@@ -199,11 +184,12 @@ class Registry {
template<typename TObjectRef, typename TNode, typename R, typename ...Args,
typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...)) {
return set_body_typed<R(TObjectRef, Args...)>([f](TObjectRef ref, Args... params) {
auto fwrap = [f](TObjectRef ref, Args... params) {
TNode* target = ref.operator->();
// call method pointer
return (target->*f)(params...);
});
};
return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap));
}
/*!
......@@ -239,11 +225,12 @@ class Registry {
template<typename TObjectRef, typename TNode, typename R, typename ...Args,
typename = typename std::enable_if<std::is_base_of<ObjectRef, TObjectRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...) const) {
return set_body_typed<R(TObjectRef, Args...)>([f](TObjectRef ref, Args... params) {
auto fwrap = [f](TObjectRef ref, Args... params) {
const TNode* target = ref.operator->();
// call method pointer
return (target->*f)(params...);
});
};
return set_body(TypedPackedFunc<R(TObjectRef, Args...)>(fwrap));
}
/*!
......
......@@ -48,7 +48,7 @@ TVM_REGISTER_GLOBAL("arith.DetectClipBound")
.set_body_typed(DetectClipBound);
TVM_REGISTER_GLOBAL("arith.DeduceBound")
.set_body_typed<IntSet(Expr, Expr, Map<Var, IntSet>, Map<Var, IntSet>)>([](
.set_body_typed([](
Expr v, Expr cond,
const Map<Var, IntSet> hint_map,
const Map<Var, IntSet> relax_map
......
......@@ -45,10 +45,10 @@ TVM_REGISTER_GLOBAL("_raw_ptr")
});
TVM_REGISTER_GLOBAL("_save_json")
.set_body_typed<std::string(ObjectRef)>(SaveJSON);
.set_body_typed(SaveJSON);
TVM_REGISTER_GLOBAL("_load_json")
.set_body_typed<ObjectRef(std::string)>(LoadJSON);
.set_body_typed(LoadJSON);
TVM_REGISTER_GLOBAL("_TVMSetStream")
.set_body_typed(TVMSetStream);
......
......@@ -32,7 +32,7 @@ namespace tvm {
namespace ir {
TVM_REGISTER_GLOBAL("_Var")
.set_body_typed<VarExpr(std::string, DataType)>([](std::string s, DataType t) {
.set_body_typed([](std::string s, DataType t) {
return Variable::make(t, s);
});
......@@ -64,7 +64,7 @@ TVM_REGISTER_GLOBAL("make._range_by_min_extent")
.set_body_typed(Range::make_by_min_extent);
TVM_REGISTER_GLOBAL("make.For")
.set_body_typed<Stmt(VarExpr, Expr, Expr, int, int, Stmt)>([](
.set_body_typed([](
VarExpr loop_var, Expr min, Expr extent,
int for_type, int device_api, Stmt body) {
return For::make(loop_var,
......@@ -99,7 +99,7 @@ TVM_REGISTER_GLOBAL("make.Realize")
.set_body_typed(Realize::make);
TVM_REGISTER_GLOBAL("make.Call")
.set_body_typed<Expr(DataType, std::string, Array<Expr>, int, FunctionRef, int)>([](
.set_body_typed([](
DataType type, std::string name,
Array<Expr> args, int call_type,
FunctionRef func, int value_index
......@@ -116,9 +116,9 @@ TVM_REGISTER_GLOBAL("make.CommReducer")
.set_body_typed(CommReducerNode::make);
// make from two arguments
#define REGISTER_MAKE(Node) \
#define REGISTER_MAKE(Node) \
TVM_REGISTER_GLOBAL("make."#Node) \
.set_body_typed(Node::make); \
.set_body_typed(Node::make); \
REGISTER_MAKE(Reduce);
REGISTER_MAKE(AttrStmt);
......@@ -168,32 +168,32 @@ TVM_REGISTER_GLOBAL("make.Block")
// has default args
TVM_REGISTER_GLOBAL("make.Allocate")
.set_body_typed<Stmt(VarExpr, DataType, Array<Expr>, Expr, Stmt)>([](
.set_body_typed([](
VarExpr buffer_var, DataType type, Array<Expr> extents, Expr condition, Stmt body
){
return Allocate::make(buffer_var, type, extents, condition, body);
});
// operator overloading, smarter than make
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
#define REGISTER_MAKE_BINARY_OP(Node, Func) \
TVM_REGISTER_GLOBAL("make."#Node) \
.set_body_typed<Expr(Expr, Expr)>([](Expr a, Expr b) { \
return (Func(a, b)); \
})
.set_body_typed([](Expr a, Expr b) { \
return (Func(a, b)); \
})
#define REGISTER_MAKE_BIT_OP(Node, Func) \
TVM_REGISTER_GLOBAL("make."#Node) \
TVM_REGISTER_GLOBAL("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
bool lhs_is_int = args[0].type_code() == kDLInt; \
bool rhs_is_int = args[1].type_code() == kDLInt; \
if (lhs_is_int) { \
*ret = (Func(args[0].operator int(), args[1].operator Expr())); \
} else if (rhs_is_int) { \
*ret = (Func(args[0].operator Expr(), args[1].operator int())); \
} else { \
*ret = (Func(args[0].operator Expr(), args[1].operator Expr())); \
} \
})
bool lhs_is_int = args[0].type_code() == kDLInt; \
bool rhs_is_int = args[1].type_code() == kDLInt; \
if (lhs_is_int) { \
*ret = (Func(args[0].operator int(), args[1].operator Expr())); \
} else if (rhs_is_int) { \
*ret = (Func(args[0].operator Expr(), args[1].operator int())); \
} else { \
*ret = (Func(args[0].operator Expr(), args[1].operator Expr())); \
} \
})
REGISTER_MAKE_BINARY_OP(_OpAdd, operator+);
......@@ -224,7 +224,7 @@ REGISTER_MAKE_BIT_OP(bitwise_xor, operator^);
REGISTER_MAKE_BIT_OP(left_shift, operator<<); // NOLINT(*)
REGISTER_MAKE_BIT_OP(right_shift, operator>>);
TVM_REGISTER_GLOBAL("make._OpIfThenElse")
.set_body_typed<Expr(Expr, Expr, Expr)>([] (Expr cond, Expr true_value, Expr false_value) {
.set_body_typed([] (Expr cond, Expr true_value, Expr false_value) {
return if_then_else(cond, true_value, false_value);
});
......
......@@ -236,22 +236,22 @@ TVM_REGISTER_GLOBAL("_Layout")
.set_body_typed(LayoutNode::make);
TVM_REGISTER_GLOBAL("_LayoutIndexOf")
.set_body_typed<int(Layout, std::string)>([](Layout layout, std::string axis) {
.set_body_typed([](Layout layout, std::string axis) -> int {
return layout.IndexOf(LayoutAxis::make(axis));
});
TVM_REGISTER_GLOBAL("_LayoutFactorOf")
.set_body_typed<int(Layout, std::string)>([](Layout layout, std::string axis) {
.set_body_typed([](Layout layout, std::string axis) -> int {
return layout.FactorOf(LayoutAxis::make(axis));
});
TVM_REGISTER_GLOBAL("_LayoutNdim")
.set_body_typed<int(Layout)>([](Layout layout) {
.set_body_typed([](Layout layout) -> int {
return layout.ndim();
});
TVM_REGISTER_GLOBAL("_LayoutGetItem")
.set_body_typed<std::string(Layout, int)>([](Layout layout, int idx) {
.set_body_typed([](Layout layout, int idx) -> std::string {
const LayoutAxis& axis = layout[idx];
return axis.name();
});
......@@ -284,14 +284,12 @@ TVM_REGISTER_GLOBAL("_TensorEqual")
.set_body_method(&Tensor::operator==);
TVM_REGISTER_GLOBAL("_TensorHash")
.set_body_typed<int64_t(Tensor)>([](Tensor tensor) {
.set_body_typed([](Tensor tensor) -> int64_t {
return static_cast<int64_t>(std::hash<Tensor>()(tensor));
});
TVM_REGISTER_GLOBAL("_Placeholder")
.set_body_typed<Tensor(Array<Expr>, DataType, std::string)>([](
Array<Expr> shape, DataType dtype, std::string name
) {
.set_body_typed([](Array<Expr> shape, DataType dtype, std::string name) {
return placeholder(shape, dtype, name);
});
......@@ -311,7 +309,7 @@ TVM_REGISTER_GLOBAL("_HybridOp")
.set_body_typed(HybridOpNode::make);
TVM_REGISTER_GLOBAL("_OpGetOutput")
.set_body_typed<Tensor(Operation, int64_t)>([](Operation op, int64_t output) {
.set_body_typed([](Operation op, int64_t output) {
return op.output(static_cast<size_t>(output));
});
......@@ -322,9 +320,7 @@ TVM_REGISTER_GLOBAL("_OpInputTensors")
.set_body_method<Operation>(&OperationNode::InputTensors);
TVM_REGISTER_GLOBAL("_IterVar")
.set_body_typed<IterVar(Range, Var, int, std::string)>([](
Range dom, Var var, int iter_type, std::string thread_tag
) {
.set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) {
return IterVarNode::make(
dom, var,
static_cast<IterVarType>(iter_type),
......@@ -341,25 +337,21 @@ TVM_REGISTER_GLOBAL("_StageBind")
.set_body_method(&Stage::bind);
TVM_REGISTER_GLOBAL("_StageSplitByFactor")
.set_body_typed<Array<IterVar>(Stage, IterVar, Expr)>([](
Stage stage, IterVar parent, Expr factor
) {
.set_body_typed([](Stage stage, IterVar parent, Expr factor) {
IterVar outer, inner;
stage.split(parent, factor, &outer, &inner);
return Array<IterVar>({outer, inner});
});
TVM_REGISTER_GLOBAL("_StageSplitByNParts")
.set_body_typed<Array<IterVar>(Stage, IterVar, Expr)>([](
Stage stage, IterVar parent, Expr nparts
) {
.set_body_typed([](Stage stage, IterVar parent, Expr nparts) {
IterVar outer, inner;
stage.split_by_nparts(parent, nparts, &outer, &inner);
return Array<IterVar>({outer, inner});
});
TVM_REGISTER_GLOBAL("_StageFuse")
.set_body_typed<IterVar(Stage, Array<IterVar>)>([](Stage stage, Array<IterVar> axes) {
.set_body_typed([](Stage stage, Array<IterVar> axes) {
IterVar fused;
stage.fuse(axes, &fused);
return fused;
......@@ -378,7 +370,7 @@ TVM_REGISTER_GLOBAL("_StageReorder")
.set_body_method(&Stage::reorder);
TVM_REGISTER_GLOBAL("_StageTile")
.set_body_typed<Array<IterVar>(Stage, IterVar, IterVar, Expr, Expr)>([](
.set_body_typed([](
Stage stage,
IterVar x_parent, IterVar y_parent,
Expr x_factor, Expr y_factor
......
......@@ -95,21 +95,21 @@ TVM_REGISTER_GLOBAL("ir_pass.StorageFlatten")
});
TVM_REGISTER_GLOBAL("ir_pass.RewriteForTensorCore")
.set_body_typed<Stmt(const Stmt&, const Schedule&, const Map<Tensor, Buffer>&)>
.set_body_typed
([](const Stmt& stmt, const Schedule& schedule, const Map<Tensor, Buffer>& extern_buffer) {
return RewriteForTensorCore(stmt, schedule, extern_buffer);
});
TVM_REGISTER_GLOBAL("ir_pass.AttrsEqual")
.set_body_typed<bool(const ObjectRef&, const ObjectRef&)>(
.set_body_typed(
[](const ObjectRef& lhs, const ObjectRef& rhs) {
return AttrsEqual()(lhs, rhs);
});
TVM_REGISTER_GLOBAL("ir_pass.AttrsHash")
.set_body_typed<int64_t(const ObjectRef&)>([](const ObjectRef &node) {
.set_body_typed([](const ObjectRef &node) -> int64_t {
return AttrsHash()(node);
});
});
TVM_REGISTER_GLOBAL("ir_pass.ExprUseVar")
......
......@@ -106,7 +106,7 @@ void ErrorTest(int x, int y) {
}
TVM_REGISTER_GLOBAL("_ErrorTest")
.set_body_typed<void(int, int)>(ErrorTest);
.set_body_typed(ErrorTest);
// internal function used for debug and testing purposes
TVM_REGISTER_GLOBAL("_ndarray_use_count")
......
......@@ -37,7 +37,7 @@ TypeVar TypeVarNode::make(std::string name, TypeKind kind) {
TVM_REGISTER_NODE_TYPE(TypeVarNode);
TVM_REGISTER_GLOBAL("relay._make.TypeVar")
.set_body_typed<TypeVar(std::string, int)>([](std::string name, int kind) {
.set_body_typed([](std::string name, int kind) {
return TypeVarNode::make(name, static_cast<TypeKind>(kind));
});
......@@ -58,7 +58,7 @@ GlobalTypeVar GlobalTypeVarNode::make(std::string name, TypeKind kind) {
TVM_REGISTER_NODE_TYPE(GlobalTypeVarNode);
TVM_REGISTER_GLOBAL("relay._make.GlobalTypeVar")
.set_body_typed<GlobalTypeVar(std::string, int)>([](std::string name, int kind) {
.set_body_typed([](std::string name, int kind) {
return GlobalTypeVarNode::make(name, static_cast<TypeKind>(kind));
});
......
......@@ -51,7 +51,7 @@ EnvFunc EnvFunc::Get(const std::string& name) {
}
TVM_REGISTER_GLOBAL("_EnvFuncGet")
.set_body_typed<EnvFunc(const std::string& name)>(EnvFunc::Get);
.set_body_typed(EnvFunc::Get);
TVM_REGISTER_GLOBAL("_EnvFuncCall")
.set_body([](TVMArgs args, TVMRetValue* rv) {
......@@ -63,7 +63,7 @@ TVM_REGISTER_GLOBAL("_EnvFuncCall")
});
TVM_REGISTER_GLOBAL("_EnvFuncGetPackedFunc")
.set_body_typed<PackedFunc(const EnvFunc& n)>([](const EnvFunc&n) {
.set_body_typed([](const EnvFunc&n) {
return n->func;
});
......
......@@ -816,43 +816,43 @@ const CompileEngine& CompileEngine::Global() {
}
TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey")
.set_body_typed<CCacheKey(Function, Target)>(CCacheKeyNode::make);
.set_body_typed(CCacheKeyNode::make);
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal")
.set_body_typed<CompileEngine()>([]() {
.set_body_typed([]() {
return CompileEngine::Global();
});
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear")
.set_body_typed<void(const CompileEngine&)>([](CompileEngine self) {
.set_body_typed([](CompileEngine self) {
self->Clear();
});
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower")
.set_body_typed<CachedFunc(CompileEngine, CCacheKey)>(
.set_body_typed(
[](CompileEngine self, CCacheKey key) {
return self->Lower(key);
});
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc")
.set_body_typed<CachedFunc(CompileEngine, CCacheKey)>(
.set_body_typed(
[](CompileEngine self, CCacheKey key) {
return self->LowerShapeFunc(key);
});
TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions")
.set_body_typed<void(const CompileEngine&)>([](CompileEngine self) {
.set_body_typed([](CompileEngine self) {
return self->LowerExternalFunctions();
});
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT")
.set_body_typed<PackedFunc(CompileEngine, CCacheKey)>(
.set_body_typed(
[](CompileEngine self, CCacheKey key) {
return self->JIT(key);
});
TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems")
.set_body_typed<Array<ObjectRef>(CompileEngine)>(
.set_body_typed(
[](CompileEngine self){
return static_cast<CompileEngineImpl*>(self.operator->())->ListItems();
});
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -390,7 +390,7 @@ Map<Expr, Array<IntegerArray> > GraphPlanMemory(const Function& func) {
}
TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory")
.set_body_typed<Map<Expr, Array<IntegerArray> >(const Function&)>(GraphPlanMemory);
.set_body_typed(GraphPlanMemory);
} // namespace relay
} // namespace tvm
......@@ -595,23 +595,23 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
// TODO(@jroesch): move to correct namespace?
TVM_REGISTER_GLOBAL("relay._make._alpha_equal")
.set_body_typed<bool(ObjectRef, ObjectRef)>([](ObjectRef a, ObjectRef b) {
.set_body_typed([](ObjectRef a, ObjectRef b) {
return AlphaEqualHandler(false, false).Equal(a, b);
});
TVM_REGISTER_GLOBAL("relay._make._assert_alpha_equal")
.set_body_typed<void(ObjectRef, ObjectRef)>([](ObjectRef a, ObjectRef b) {
.set_body_typed([](ObjectRef a, ObjectRef b) {
bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal";
});
TVM_REGISTER_GLOBAL("relay._make._graph_equal")
.set_body_typed<bool(ObjectRef, ObjectRef)>([](ObjectRef a, ObjectRef b) {
.set_body_typed([](ObjectRef a, ObjectRef b) {
return AlphaEqualHandler(true, false).Equal(a, b);
});
TVM_REGISTER_GLOBAL("relay._make._assert_graph_equal")
.set_body_typed<void(ObjectRef, ObjectRef)>([](ObjectRef a, ObjectRef b) {
.set_body_typed([](ObjectRef a, ObjectRef b) {
bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b);
CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal";
});
......
......@@ -35,7 +35,7 @@ using namespace tvm::runtime;
TVM_REGISTER_NODE_TYPE(IdNode);
TVM_REGISTER_GLOBAL("relay._base.set_span")
.set_body_typed<void(ObjectRef, Span)>([](ObjectRef node_ref, Span sp) {
.set_body_typed([](ObjectRef node_ref, Span sp) {
if (auto* rn = node_ref.as<RelayNode>()) {
CHECK(rn);
rn->span = sp;
......
......@@ -167,7 +167,7 @@ Function FunctionNode::SetParams(const tvm::Map<Var, Constant>& parameters) cons
}
TVM_REGISTER_GLOBAL("relay._expr.FunctionSetParams")
.set_body_typed<Function(const Function&, const tvm::Map<Var, Constant>&)>(
.set_body_typed(
[](const Function& func, const tvm::Map<Var, Constant>& parameters) {
return func->SetParams(parameters);
});
......@@ -178,7 +178,7 @@ tvm::Map<Var, Constant> FunctionNode::GetParams() const {
}
TVM_REGISTER_GLOBAL("relay._expr.FunctionGetParams")
.set_body_typed<tvm::Map<Var, Constant>(const Function&)>([](const Function& func) {
.set_body_typed([](const Function& func) {
return func->GetParams();
});
......@@ -367,12 +367,12 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
});
TVM_REGISTER_GLOBAL("relay._expr.TempExprRealize")
.set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
.set_body_typed([](TempExpr temp) {
return temp->Realize();
});
TVM_REGISTER_GLOBAL("relay._expr.FunctionSetAttr")
.set_body_typed<Function(Function, std::string, ObjectRef)>(
.set_body_typed(
[](Function func, std::string name, ObjectRef ref) {
return FunctionSetAttr(func, name, ref);
});
......
......@@ -348,7 +348,7 @@ void PostOrderVisit(const Expr& e, std::function<void(const Expr&)> fvisit) {
}
TVM_REGISTER_GLOBAL("relay._analysis.post_order_visit")
.set_body_typed<void(Expr, PackedFunc)>([](Expr expr, PackedFunc f) {
.set_body_typed([](Expr expr, PackedFunc f) {
PostOrderVisit(expr, [f](const Expr& n) {
f(n);
});
......
......@@ -424,12 +424,12 @@ size_t StructuralHash::operator()(const Expr& expr) const {
}
TVM_REGISTER_GLOBAL("relay._analysis._expr_hash")
.set_body_typed<int64_t(ObjectRef)>([](ObjectRef ref) {
.set_body_typed([](ObjectRef ref) {
return static_cast<int64_t>(RelayHashHandler().Hash(ref));
});
TVM_REGISTER_GLOBAL("relay._analysis._type_hash")
.set_body_typed<int64_t(Type)>([](Type type) {
.set_body_typed([](Type type) {
return static_cast<int64_t>(RelayHashHandler().TypeHash(type));
});
......
......@@ -318,7 +318,7 @@ Module FromText(const std::string& source, const std::string& source_name) {
TVM_REGISTER_NODE_TYPE(ModuleNode);
TVM_REGISTER_GLOBAL("relay._make.Module")
.set_body_typed<Module(tvm::Map<GlobalVar, Function>, tvm::Map<GlobalTypeVar, TypeData>)>(
.set_body_typed(
[](tvm::Map<GlobalVar, Function> funcs, tvm::Map<GlobalTypeVar, TypeData> types) {
return ModuleNode::make(funcs, types, {});
});
......@@ -365,52 +365,49 @@ TVM_REGISTER_GLOBAL("relay._module.Module_GetGlobalTypeVar")
.set_body_method<Module>(&ModuleNode::GetGlobalTypeVar);
TVM_REGISTER_GLOBAL("relay._module.Module_Lookup")
.set_body_typed<Function(Module, GlobalVar)>([](Module mod, GlobalVar var) {
.set_body_typed([](Module mod, GlobalVar var) {
return mod->Lookup(var);
});
TVM_REGISTER_GLOBAL("relay._module.Module_Lookup_str")
.set_body_typed<Function(Module, std::string)>([](Module mod, std::string var) {
.set_body_typed([](Module mod, std::string var) {
return mod->Lookup(var);
});
TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef")
.set_body_typed<TypeData(Module, GlobalTypeVar)>([](Module mod, GlobalTypeVar var) {
.set_body_typed([](Module mod, GlobalTypeVar var) {
return mod->LookupDef(var);
});
TVM_REGISTER_GLOBAL("relay._module.Module_LookupDef_str")
.set_body_typed<TypeData(Module, std::string)>([](Module mod, std::string var) {
.set_body_typed([](Module mod, std::string var) {
return mod->LookupDef(var);
});
TVM_REGISTER_GLOBAL("relay._module.Module_LookupTag")
.set_body_typed<Constructor(Module, int32_t)>([](Module mod, int32_t tag) {
.set_body_typed([](Module mod, int32_t tag) {
return mod->LookupTag(tag);
});
TVM_REGISTER_GLOBAL("relay._module.Module_FromExpr")
.set_body_typed<
Module(Expr,
tvm::Map<GlobalVar, Function>,
tvm::Map<GlobalTypeVar, TypeData>)>([](Expr e,
tvm::Map<GlobalVar, Function> funcs,
tvm::Map<GlobalTypeVar, TypeData> type_defs) {
return ModuleNode::FromExpr(e, funcs, type_defs);
});
.set_body_typed([](Expr e,
tvm::Map<GlobalVar, Function> funcs,
tvm::Map<GlobalTypeVar, TypeData> type_defs) {
return ModuleNode::FromExpr(e, funcs, type_defs);
});
TVM_REGISTER_GLOBAL("relay._module.Module_Update")
.set_body_typed<void(Module, Module)>([](Module mod, Module from) {
.set_body_typed([](Module mod, Module from) {
mod->Update(from);
});
TVM_REGISTER_GLOBAL("relay._module.Module_Import")
.set_body_typed<void(Module, std::string)>([](Module mod, std::string path) {
.set_body_typed([](Module mod, std::string path) {
mod->Import(path);
});
TVM_REGISTER_GLOBAL("relay._module.Module_ImportFromStd")
.set_body_typed<void(Module, std::string)>([](Module mod, std::string path) {
.set_body_typed([](Module mod, std::string path) {
mod->ImportFromStd(path);
});;
......
......@@ -136,7 +136,7 @@ void OpRegistry::UpdateAttr(const std::string& key,
// Frontend APIs
TVM_REGISTER_GLOBAL("relay.op._ListOpNames")
.set_body_typed<Array<tvm::Expr>()>([]() {
.set_body_typed([]() {
Array<tvm::Expr> ret;
for (const std::string& name :
dmlc::Registry<OpRegistry>::ListAllNames()) {
......@@ -145,7 +145,7 @@ TVM_REGISTER_GLOBAL("relay.op._ListOpNames")
return ret;
});
TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed<Op(std::string)>(Op::Get);
TVM_REGISTER_GLOBAL("relay.op._GetOp").set_body_typed(Op::Get);
TVM_REGISTER_GLOBAL("relay.op._OpGetAttr")
.set_body([](TVMArgs args, TVMRetValue* rv) {
......
......@@ -991,9 +991,7 @@ std::string AsText(const ObjectRef& node,
}
TVM_REGISTER_GLOBAL("relay._expr.AsText")
.set_body_typed<std::string(const ObjectRef&,
bool,
runtime::TypedPackedFunc<std::string(Expr)>)>(AsText);
.set_body_typed(AsText);
} // namespace relay
} // namespace tvm
......@@ -91,7 +91,7 @@ IncompleteType IncompleteTypeNode::make(Kind kind) {
TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);
TVM_REGISTER_GLOBAL("relay._make.IncompleteType")
.set_body_typed<IncompleteType(int)>([](int kind) {
.set_body_typed([](int kind) {
return IncompleteTypeNode::make(static_cast<Kind>(kind));
});
......@@ -161,8 +161,7 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
});
TVM_REGISTER_GLOBAL("relay._make.Any")
.set_body_typed<IndexExpr()>([]() { return Any::make(); });
.set_body_typed([]() { return Any::make(); });
} // namespace relay
} // namespace tvm
......@@ -40,7 +40,7 @@ namespace relay {
TVM_REGISTER_NODE_TYPE(OnDeviceAttrs);
TVM_REGISTER_GLOBAL("relay.op.annotation._make.on_device")
.set_body_typed<Expr(Expr, int)>([](Expr data, int device_type) {
.set_body_typed([](Expr data, int device_type) {
auto attrs = make_object<OnDeviceAttrs>();
attrs->device_type = device_type;
static const Op& op = Op::Get("on_device");
......@@ -63,7 +63,7 @@ Expr StopFusion(Expr data) {
}
TVM_REGISTER_GLOBAL("relay.op.annotation._make.stop_fusion")
.set_body_typed<Expr(Expr)>([](Expr data) {
.set_body_typed([](Expr data) {
return StopFusion(data);
});
......@@ -145,7 +145,7 @@ Mark the end of bitpacking.
});
TVM_REGISTER_GLOBAL("relay.op.annotation._make.checkpoint")
.set_body_typed<Expr(Expr)>([](Expr data) {
.set_body_typed([](Expr data) {
static const Op& op = Op::Get("annotation.checkpoint");
return CallNode::make(op, {data}, Attrs{}, {});
});
......
......@@ -42,7 +42,7 @@ namespace relay {
TVM_REGISTER_NODE_TYPE(DeviceCopyAttrs);
TVM_REGISTER_GLOBAL("relay.op._make.device_copy")
.set_body_typed<Expr(Expr, int, int)>([](Expr data, int src_dev_type,
.set_body_typed([](Expr data, int src_dev_type,
int dst_dev_type) {
auto attrs = make_object<DeviceCopyAttrs>();
attrs->src_dev_type = src_dev_type;
......
......@@ -42,7 +42,7 @@ TVM_REGISTER_NODE_TYPE(ShapeFuncAttrs);
// We should consider a better solution, i.e the type relation
// being able to see the arguments as well?
TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_storage")
.set_body_typed<Expr(Expr, Expr, DataType)>([](Expr size, Expr alignment, DataType dtype) {
.set_body_typed([](Expr size, Expr alignment, DataType dtype) {
auto attrs = make_object<AllocTensorAttrs>();
attrs->dtype = dtype;
static const Op& op = Op::Get("memory.alloc_storage");
......@@ -88,7 +88,7 @@ RELAY_REGISTER_OP("memory.alloc_storage")
});
TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor")
.set_body_typed<Expr(Expr, Expr, DataType, Array<IndexExpr> assert_shape)>(
.set_body_typed(
[](Expr storage, tvm::relay::Expr shape, DataType dtype, Array<IndexExpr> assert_shape) {
auto attrs = make_object<AllocTensorAttrs>();
attrs->dtype = dtype;
......@@ -209,7 +209,7 @@ bool InvokeTVMOPRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
}
TVM_REGISTER_GLOBAL("relay.op.memory._make.invoke_tvm_op")
.set_body_typed<Expr(Expr, Expr, Expr)>(
.set_body_typed(
[](Expr func, Expr inputs, Expr outputs) {
return CallNode::make(Op::Get("memory.invoke_tvm_op"), {func, inputs, outputs}, Attrs());
});
......@@ -257,7 +257,7 @@ RELAY_REGISTER_OP("memory.kill")
});
TVM_REGISTER_GLOBAL("relay.op.memory._make.shape_func")
.set_body_typed<Expr(Expr, Expr, Expr, Array<tvm::Integer>)>(
.set_body_typed(
[](Expr func, Expr inputs, Expr outputs, Array<tvm::Integer> is_input) {
static const Op& op = Op::Get("memory.shape_func");
auto attrs = make_object<ShapeFuncAttrs>();
......
......@@ -326,7 +326,7 @@ where :math:`*` is an channelwise multiplication for each sample in the batch.
TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
TVM_REGISTER_GLOBAL("relay.op.nn._make.softmax")
.set_body_typed<Call(Expr, int)>([](Expr data, int axis) {
.set_body_typed([](Expr data, int axis) {
auto attrs = make_object<SoftmaxAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.softmax");
......@@ -361,7 +361,7 @@ RELAY_REGISTER_OP("nn.softmax")
// relay.nn.log_softmax
TVM_REGISTER_GLOBAL("relay.op.nn._make.log_softmax")
.set_body_typed<Call(Expr, int)>([](Expr data, int axis) {
.set_body_typed([](Expr data, int axis) {
auto attrs = make_object<SoftmaxAttrs>();
attrs->axis = axis;
static const Op& op = Op::Get("nn.log_softmax");
......@@ -470,7 +470,7 @@ Example::
// relu
TVM_REGISTER_GLOBAL("relay.op.nn._make.relu")
.set_body_typed<Call(Expr)>([](Expr data) {
.set_body_typed([](Expr data) {
static const Op& op = Op::Get("nn.relu");
return CallNode::make(op, {data}, Attrs(), {});
});
......
......@@ -214,13 +214,12 @@ Array<Tensor> Pool2DCompute(const Attrs& attrs,
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool2d")
.set_body_typed<Expr(Expr, Array<IndexExpr>, Array<IndexExpr>, Array<IndexExpr>,
std::string, bool)>([](Expr data,
Array<IndexExpr> pool_size,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
std::string layout,
bool ceil_mode) {
.set_body_typed([](Expr data,
Array<IndexExpr> pool_size,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
std::string layout,
bool ceil_mode) {
return MakeMaxPool<MaxPool2DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
"nn.max_pool2d");
});
......@@ -258,14 +257,13 @@ RELAY_REGISTER_OP("nn.max_pool2d")
// AvgPool2D
TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool2d")
.set_body_typed<Expr(Expr, Array<IndexExpr>, Array<IndexExpr>, Array<IndexExpr>,
std::string, bool, bool)>([](Expr data,
Array<IndexExpr> pool_size,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
std::string layout,
bool ceil_mode,
bool count_include_pad) {
.set_body_typed([](Expr data,
Array<IndexExpr> pool_size,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
std::string layout,
bool ceil_mode,
bool count_include_pad) {
return MakeAvgPool<AvgPool2DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
count_include_pad, "nn.avg_pool2d");
});
......@@ -868,13 +866,12 @@ Array<Tensor> Pool3DCompute(const Attrs& attrs,
}
TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool3d")
.set_body_typed<Expr(Expr, Array<IndexExpr>, Array<IndexExpr>, Array<IndexExpr>,
std::string, bool)>([](Expr data,
Array<IndexExpr> pool_size,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
std::string layout,
bool ceil_mode) {
.set_body_typed([](Expr data,
Array<IndexExpr> pool_size,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
std::string layout,
bool ceil_mode) {
return MakeMaxPool<MaxPool3DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
"nn.max_pool3d");
});
......@@ -912,14 +909,13 @@ RELAY_REGISTER_OP("nn.max_pool3d")
// AvgPool3D
TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool3d")
.set_body_typed<Expr(Expr, Array<IndexExpr>, Array<IndexExpr>, Array<IndexExpr>,
std::string, bool, bool)>([](Expr data,
Array<IndexExpr> pool_size,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
std::string layout,
bool ceil_mode,
bool count_include_pad) {
.set_body_typed([](Expr data,
Array<IndexExpr> pool_size,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
std::string layout,
bool ceil_mode,
bool count_include_pad) {
return MakeAvgPool<AvgPool3DAttrs>(data, pool_size, strides, padding, layout, ceil_mode,
count_include_pad, "nn.avg_pool3d");
});
......
......@@ -48,19 +48,19 @@ namespace relay {
* \param OpName the name of registry.
*/
#define RELAY_REGISTER_UNARY_OP(OpName) \
TVM_REGISTER_GLOBAL("relay.op._make." OpName) \
.set_body_typed<Expr(Expr)>([](Expr data) { \
static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {data}, Attrs(), {}); \
}); \
TVM_REGISTER_GLOBAL("relay.op._make." OpName) \
.set_body_typed([](Expr data) { \
static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {data}, Attrs(), {}); \
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(1) \
.add_argument("data", "Tensor", "The input tensor.") \
.add_type_rel("Identity", IdentityRel) \
.set_attr<TOpPattern>("TOpPattern", kElemWise) \
.set_attr<TOpIsStateful>("TOpIsStateful", false) \
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", \
ElemwiseArbitraryLayout) \
.set_num_inputs(1) \
.add_argument("data", "Tensor", "The input tensor.") \
.add_type_rel("Identity", IdentityRel) \
.set_attr<TOpPattern>("TOpPattern", kElemWise) \
.set_attr<TOpIsStateful>("TOpIsStateful", false) \
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", \
ElemwiseArbitraryLayout) \
/*! Quick helper macro
......@@ -73,38 +73,38 @@ namespace relay {
*
* \param OpName the name of registry.
*/
#define RELAY_REGISTER_BINARY_OP(OpName) \
#define RELAY_REGISTER_BINARY_OP(OpName) \
TVM_REGISTER_GLOBAL("relay.op._make." OpName) \
.set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) { \
static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \
.add_type_rel("Broadcast", BroadcastRel) \
.set_attr<TOpPattern>("TOpPattern", kBroadcast) \
.set_attr<TOpIsStateful>("TOpIsStateful", false) \
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", \
BinaryBroadcastLayout)
.set_body_typed([](Expr lhs, Expr rhs) { \
static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \
.add_type_rel("Broadcast", BroadcastRel) \
.set_attr<TOpPattern>("TOpPattern", kBroadcast) \
.set_attr<TOpIsStateful>("TOpIsStateful", false) \
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", \
BinaryBroadcastLayout)
// Comparisons
#define RELAY_REGISTER_CMP_OP(OpName) \
#define RELAY_REGISTER_CMP_OP(OpName) \
TVM_REGISTER_GLOBAL("relay.op._make." OpName) \
.set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) { \
static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \
.add_type_rel("BroadcastComp", BroadcastCompRel) \
.set_attr<TOpPattern>("TOpPattern", kBroadcast) \
.set_attr<TOpIsStateful>("TOpIsStateful", false) \
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", \
BinaryBroadcastLayout)
.set_body_typed([](Expr lhs, Expr rhs) { \
static const Op& op = Op::Get(OpName); \
return CallNode::make(op, {lhs, rhs}, Attrs(), {}); \
}); \
RELAY_REGISTER_OP(OpName) \
.set_num_inputs(2) \
.add_argument("lhs", "Tensor", "The left hand side tensor.") \
.add_argument("rhs", "Tensor", "The right hand side tensor.") \
.add_type_rel("BroadcastComp", BroadcastCompRel) \
.set_attr<TOpPattern>("TOpPattern", kBroadcast) \
.set_attr<TOpIsStateful>("TOpIsStateful", false) \
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", \
BinaryBroadcastLayout)
/*! \brief A helper class for matching and rewriting operators. */
......
......@@ -303,7 +303,7 @@ bool ReduceRel(const Array<Type>& types,
#define RELAY_REGISTER_REDUCE_OP(OpName) \
TVM_REGISTER_GLOBAL("relay.op._make." OpName) \
.set_body_typed<Call(Expr, Array<Integer>, bool, bool)>([]( \
.set_body_typed([]( \
Expr data, \
Array<Integer> axis, \
bool keepdims, \
......
......@@ -858,7 +858,7 @@ bool ArgWhereRel(const Array<Type>& types,
}
TVM_REGISTER_GLOBAL("relay.op._make.argwhere")
.set_body_typed<Expr(Expr)>([](Expr data) {
.set_body_typed([](Expr data) {
static const Op& op = Op::Get("argwhere");
auto attrs = make_object<ArgWhereAttrs>();
return CallNode::make(op, {data}, Attrs(attrs), {});
......
......@@ -158,7 +158,7 @@ RELAY_REGISTER_UNARY_OP("copy")
TVM_REGISTER_NODE_TYPE(ClipAttrs);
TVM_REGISTER_GLOBAL("relay.op._make.clip")
.set_body_typed<Expr(Expr, double, double)>([](Expr a, double a_min, double a_max) {
.set_body_typed([](Expr a, double a_min, double a_max) {
auto attrs = make_object<ClipAttrs>();
attrs->a_min = a_min;
attrs->a_max = a_max;
......@@ -301,7 +301,7 @@ Array<Tensor> ShapeOfCompute(const Attrs& attrs,
}
TVM_REGISTER_GLOBAL("relay.op._make.shape_of")
.set_body_typed<Expr(Expr, DataType)>([](Expr data, DataType dtype) {
.set_body_typed([](Expr data, DataType dtype) {
auto attrs = make_object<ShapeOfAttrs>();
attrs->dtype = dtype;
static const Op& op = Op::Get("shape_of");
......@@ -352,7 +352,7 @@ Array<Tensor> NdarraySizeCompute(const Attrs& attrs,
}
TVM_REGISTER_GLOBAL("relay.op.contrib._make.ndarray_size")
.set_body_typed<Expr(Expr, DataType)>([](Expr data, DataType dtype) {
.set_body_typed([](Expr data, DataType dtype) {
auto attrs = make_object<NdarraySizeAttrs>();
attrs->dtype = dtype;
static const Op& op = Op::Get("contrib.ndarray_size");
......
......@@ -327,7 +327,7 @@ Array<Pattern> UnmatchedCases(const Match& match, const Module& mod) {
// expose for testing only
TVM_REGISTER_GLOBAL("relay._analysis.unmatched_cases")
.set_body_typed<Array<Pattern>(const Match&, const Module&)>(
.set_body_typed(
[](const Match& match, const Module& mod_ref) {
Module call_mod = mod_ref;
if (!call_mod.defined()) {
......
......@@ -67,7 +67,7 @@ RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
.add_type_rel("SimulatedQuantize", SimulatedQuantizeRel);
TVM_REGISTER_GLOBAL("relay._quantize.simulated_quantize")
.set_body_typed<Expr(Expr, Expr, Expr, Expr, int, bool, std::string)>(
.set_body_typed(
[](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max,
int kind, bool sign, std::string rounding) {
auto attrs = make_object<SimulatedQuantizeAttrs>();
......
......@@ -79,7 +79,7 @@ bool TupleGetItemRel(const Array<Type>& types,
TVM_REGISTER_NODE_TYPE(TupleGetItemAttrs);
TVM_REGISTER_GLOBAL("tvm.relay.type_relation.TupleGetItem")
.set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
.set_body_typed(
TupleGetItemRel);
struct ResolvedTypeInfo {
......@@ -840,7 +840,7 @@ Pass InferType() {
}
TVM_REGISTER_GLOBAL("relay._transform.InferType")
.set_body_typed<Pass()>([]() {
.set_body_typed([]() {
return InferType();
});
......
......@@ -63,19 +63,18 @@ static inline bool QnnBroadcastRel(const Array<Type>& types, int num_inputs, con
*
* \param OpName the name of registry.
*/
#define QNN_REGISTER_BINARY_OP(OpName) \
TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName) \
.set_body_typed<Expr(Expr, Expr, Expr, Expr, Expr, Expr, Expr, Expr)>( \
[](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \
Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \
static const Op& op = Op::Get("qnn." OpName); \
return CallNode::make(op, {lhs, rhs, \
lhs_scale, lhs_zero_point, \
rhs_scale, rhs_zero_point, \
output_scale, output_zero_point}, Attrs(), {}); \
}); \
RELAY_REGISTER_OP("qnn." OpName) \
.set_num_inputs(8) \
#define QNN_REGISTER_BINARY_OP(OpName) \
TVM_REGISTER_GLOBAL("relay.qnn.op._make." OpName) \
.set_body_typed([](Expr lhs, Expr rhs, Expr lhs_scale, Expr lhs_zero_point, Expr rhs_scale, \
Expr rhs_zero_point, Expr output_scale, Expr output_zero_point) { \
static const Op& op = Op::Get("qnn." OpName); \
return CallNode::make(op, {lhs, rhs, \
lhs_scale, lhs_zero_point, \
rhs_scale, rhs_zero_point, \
output_scale, output_zero_point}, Attrs(), {}); \
}); \
RELAY_REGISTER_OP("qnn." OpName) \
.set_num_inputs(8) \
.add_argument("lhs", "Tensor", "The left hand side quantized tensor.") \
.add_argument("rhs", "Tensor", "The right hand side quantized tensor.") \
.add_argument("lhs_scale", "Tensor", "The scale of the lhs tensor.") \
......
......@@ -795,7 +795,7 @@ TVM_REGISTER_GLOBAL("relay._vm.GetPrimitiveFields")
});
TVM_REGISTER_GLOBAL("relay._vm.Load_Executable")
.set_body_typed<runtime::Module(std::string, runtime::Module)>([](
.set_body_typed([](
std::string code,
runtime::Module lib) {
return Executable::Load(code, lib);
......
......@@ -179,6 +179,24 @@ TEST(TypedPackedFunc, HighOrder) {
CHECK_EQ(f1(3), 4);
}
// Verifies the compile-time function-signature deduction machinery
// (detail::function_signature) that set_body_typed relies on.
// Every check here is a static_assert, so this test passes iff it compiles.
TEST(TypedPackedFunc, Deduce) {
using namespace tvm::runtime;
using tvm::runtime::detail::function_signature;
// Functor case: TypedPackedFunc exposes operator(), so deduction
// goes through func_signature_helper<decltype(&T::operator())>.
TypedPackedFunc<int(float)> x;
// Lambda case: the closure type's `operator() const` yields int(int).
auto f = [](int x) -> int {
return x + 1;
};
// std::function case: also deduced via its operator().
std::function<void(float)> y;
static_assert(std::is_same<function_signature<decltype(x)>::FType,
int(float)>::value, "invariant1");
static_assert(std::is_same<function_signature<decltype(f)>::FType,
int(int)>::value, "invariant2");
static_assert(std::is_same<function_signature<decltype(y)>::FType,
void(float)>::value, "invariant3");
}
TEST(PackedFunc, ObjectConversion) {
using namespace tvm;
......
......@@ -60,13 +60,13 @@ struct RPCEnv {
};
TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath")
.set_body_typed<std::string(std::string)>([](std::string path) {
.set_body_typed([](std::string path) {
static RPCEnv env;
return env.GetPath(path);
});
TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module")
.set_body_typed<Module(std::string)>([](std::string path) {
.set_body_typed([](std::string path) {
std::string file_name = "/rpc/" + path;
LOG(INFO) << "Load module from " << file_name << " ...";
return Module::LoadFromFile(file_name, "");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment