Commit 672203f2 by 雾雨魔理沙, committed by Thierry Moreau

[Relay] [Error] Fix error in partial evaluator (#3693)

* fix

* lint
parent 8ad36a17
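For context, the new test_concat below exercises the case this commit fixes: partially evaluating a call whose arguments carry no static value (here, a tuple of free function parameters fed to concatenate). A minimal sketch of that scenario, assuming a TVM build that includes this patch; the relay.Module / InferType / PartialEvaluate pipeline shown here is illustrative and not part of the commit:

    import tvm
    from tvm import relay
    from tvm.relay import op, transform

    # Neither x nor y has a static value, so the partial evaluator cannot
    # constant-fold the concatenate call and must keep it in dynamic form.
    t = relay.TensorType([10], "float32")
    x = relay.var("x", t)
    y = relay.var("y", t)
    f = relay.Function([x, y], op.concatenate([x, y], axis=0))

    mod = relay.Module.from_expr(f)
    mod = transform.InferType()(mod)
    # Pre-patch, this path could hit the fatal "Unknown case" branch in Reflect.
    mod = transform.PartialEvaluate()(mod)
    print(mod)

With the patch, non-static arguments raise the new ReflectError internally, which is caught and turned into an ordinary dynamic call instead of aborting.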
@@ -131,7 +131,7 @@ Expr PostProcess(const Expr&);
/*! \brief The base container type of Relay values. */
class StaticNode : public RelayNode {
public:
-static constexpr const char* _type_key = "relay.Value";
+static constexpr const char* _type_key = "relay.Static";
TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode);
};
@@ -161,6 +161,7 @@ struct PStaticNode : Node {
PStaticNode(const Static& pstatic, const Expr& dynamic) :
pstatic(pstatic), dynamic(dynamic), created_time(time()) { }
explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { }
+static constexpr const char* _type_key = "relay.PStatic";
TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node);
};
@@ -169,6 +170,7 @@ RELAY_DEFINE_NODE_REF(PStatic, PStaticNode, NodeRef);
struct STupleNode : StaticNode {
std::vector<PStatic> fields;
explicit STupleNode(const std::vector<PStatic>& fields) : fields(fields) { }
+static constexpr const char* _type_key = "relay.STuple";
TVM_DECLARE_NODE_TYPE_INFO(STupleNode, StaticNode);
};
@@ -181,7 +183,8 @@ Static MkSTuple(const std::vector<PStatic>& fields) {
struct STensorNode : StaticNode {
runtime::NDArray data;
explicit STensorNode(const NDArray& data) : data(data) { }
-TVM_DECLARE_NODE_TYPE_INFO(STupleNode, StaticNode);
+static constexpr const char* _type_key = "relay.STensor";
+TVM_DECLARE_NODE_TYPE_INFO(STensorNode, StaticNode);
};
RELAY_DEFINE_NODE_REF(STensor, STensorNode, Value);
@@ -195,6 +198,7 @@ struct SConstructorNode : StaticNode {
std::vector<PStatic> fields;
SConstructorNode(const Constructor& constructor, const std::vector<PStatic>& fields) :
constructor(constructor), fields(fields) { }
+static constexpr const char* _type_key = "relay.SConstructor";
TVM_DECLARE_NODE_TYPE_INFO(SConstructorNode, StaticNode);
};
@@ -205,6 +209,7 @@ Static MkSConstructor(const Constructor& constructor, const std::vector<PStatic>
}
struct SRefNode : StaticNode {
+static constexpr const char* _type_key = "relay.SRef";
// we will use the address as the guid for hashing
TVM_DECLARE_NODE_TYPE_INFO(SRefNode, StaticNode);
};
@@ -223,6 +228,7 @@ using Func = std::function<PStatic(const std::vector<PStatic>&,
struct SFuncNode : StaticNode {
Func func;
explicit SFuncNode(const Func& func) : func(func) { }
+static constexpr const char* _type_key = "relay.SFunc";
TVM_DECLARE_NODE_TYPE_INFO(SFuncNode, StaticNode);
};
@@ -711,8 +717,14 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
return VisitFunc(GetRef<Function>(op), ll);
}
+struct ReflectError : dmlc::Error {
+ReflectError() : dmlc::Error("static value not found") { }
+};
Expr Reflect(const PStatic& st) {
-if (const STensorNode* op = st->pstatic.as<STensorNode>()) {
+if (!st->pstatic.defined()) {
+throw ReflectError();
+} else if (const STensorNode* op = st->pstatic.as<STensorNode>()) {
return ConstantNode::make(op->data);
} else if (const STupleNode* op = st->pstatic.as<STupleNode>()) {
tvm::Array<Expr> fields;
@@ -721,7 +733,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
}
return TupleNode::make(fields);
} else {
LOG(FATAL) << "Unknown case";
LOG(FATAL) << "Unknown case: " << st->dynamic;
throw;
}
}
@@ -767,19 +779,22 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
for (const PStatic& ps : pv) {
ns_args.push_back(ps->dynamic);
}
-PStatic ns = NoStatic(ll->Push(CallNode::make(expr, ns_args, attrs, type_args)));
+auto ns = [&]() {
+return NoStatic(ll->Push(CallNode::make(expr, ns_args, attrs, type_args)));
+};
if (StatefulOp(expr)) {
-return ns;
+return ns();
}
-tvm::Array<Expr> args;
-for (const PStatic& ps : pv) {
-if (ps->pstatic.defined()) {
+try {
+tvm::Array<Expr> args;
+for (const PStatic& ps : pv) {
args.push_back(Reflect(ps));
-} else {
-return ns;
}
+return ConstEvaluate(CallNode::make(expr, args, attrs, type_args), ll);
+}
+catch (const ReflectError&) {
+return ns();
}
-return ConstEvaluate(CallNode::make(expr, args, attrs, type_args), ll);
};
}
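The shape of the change to the constant-evaluation helper above, restated as a small plain-Python analogue (hypothetical names, not TVM API): the dynamic fallback becomes a lazily built thunk, and reflection reports "no static value" by raising, which also covers non-static values nested inside tuple arguments that the old per-argument defined() check did not reach:

    class ReflectError(Exception):
        """A value (possibly nested inside a tuple) has no static part."""

    def reflect(ps):
        # None stands in for an undefined pstatic; tuples mirror STupleNode fields.
        if ps is None:
            raise ReflectError()
        if isinstance(ps, tuple):
            return tuple(reflect(field) for field in ps)
        return ps

    def const_evaluate_call(args, const_evaluate, dynamic_call):
        ns = lambda: dynamic_call(args)  # built only on the fallback path
        try:
            return const_evaluate([reflect(a) for a in args])
        except ReflectError:
            return ns()

    # A tuple with one non-static field falls back to the dynamic call
    # instead of failing halfway through reflection.
    print(const_evaluate_call([(1.0, None)], sum, lambda a: ("dynamic", a)))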
@@ -18,7 +18,7 @@
import numpy as np
import tvm
from tvm import relay
-from tvm.relay.analysis import alpha_equal
+from tvm.relay.analysis import alpha_equal, assert_alpha_equal
from tvm.relay.prelude import Prelude
from tvm.relay import op, create_executor, transform
from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate
@@ -306,6 +306,14 @@ def test_double():
assert alpha_equal(res.body, make_nat_expr(p, 6))
+def test_concat():
+t = relay.TensorType([10], "float32")
+x = Var("x", t)
+y = Var("x", t)
+orig = run_infer_type(Function([x, y], op.concatenate([x, y], axis=0)))
+assert_alpha_equal(orig, dcpe(orig))
if __name__ == '__main__':
test_ref()
test_tuple()
@@ -323,3 +331,4 @@ if __name__ == '__main__':
test_nat_id()
test_global_match_nat_id()
test_match_nat_id()
+test_concat()