Commit 6d702ea8 by Haichen Shen Committed by Tianqi Chen

fix (#3550)

parent c855882a
......@@ -247,7 +247,7 @@ Expr FirstOrderGradient(const Expr& re, const Module& mod) {
return FunctionNode::make(f->params, body, GradRetType(GetRef<Function>(f)), {});
}
TVM_REGISTER_API("relay._analysis.first_order_gradient")
TVM_REGISTER_API("relay._transform.first_order_gradient")
.set_body_typed(FirstOrderGradient);
struct ReverseADType : TypeMutator {
......
......@@ -35,7 +35,7 @@ def test_id():
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
func = relay.Function([x], x)
back_func = run_infer_type(gradient(func))
back_func = run_infer_type(gradient(func, mode="first_order"))
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
ex = create_executor()
x = rand(dtype, *shape)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment