Commit 6d702ea8 by Haichen Shen Committed by Tianqi Chen

fix (#3550)

parent c855882a
...@@ -247,7 +247,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) { ...@@ -247,7 +247,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {}); return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
} }
TVM_REGISTER_API("relay._analysis.first_order_gradient") TVM_REGISTER_API("relay._transform.first_order_gradient")
.set_body_typed(FirstOrderGradient); .set_body_typed(FirstOrderGradient);
struct ReverseADType : TypeMutator { struct ReverseADType : TypeMutator {
......
...@@ -35,7 +35,7 @@ def test_id(): ...@@ -35,7 +35,7 @@ def test_id():
t = relay.TensorType(shape, dtype) t = relay.TensorType(shape, dtype)
x = relay.var("x", t) x = relay.var("x", t)
func = relay.Function([x], x) func = relay.Function([x], x)
back_func = run_infer_type(gradient(func)) back_func = run_infer_type(gradient(func, mode="first_order"))
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])])) assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
ex = create_executor() ex = create_executor()
x = rand(dtype, *shape) x = rand(dtype, *shape)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment