Commit 672203f2 by 雾雨魔理沙 Committed by Thierry Moreau

[Relay] [Error] Fix error in partial evaluator (#3693)

* fix

* lint
parent 8ad36a17
...@@ -131,7 +131,7 @@ Expr PostProcess(const Expr&); ...@@ -131,7 +131,7 @@ Expr PostProcess(const Expr&);
/*! \brief The base container type of Relay values. */ /*! \brief The base container type of Relay values. */
class StaticNode : public RelayNode { class StaticNode : public RelayNode {
public: public:
static constexpr const char* _type_key = "relay.Value"; static constexpr const char* _type_key = "relay.Static";
TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode); TVM_DECLARE_BASE_NODE_INFO(ValueNode, RelayNode);
}; };
...@@ -161,6 +161,7 @@ struct PStaticNode : Node { ...@@ -161,6 +161,7 @@ struct PStaticNode : Node {
PStaticNode(const Static& pstatic, const Expr& dynamic) : PStaticNode(const Static& pstatic, const Expr& dynamic) :
pstatic(pstatic), dynamic(dynamic), created_time(time()) { } pstatic(pstatic), dynamic(dynamic), created_time(time()) { }
explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { } explicit PStaticNode(const Expr& dynamic) : PStaticNode(Static(), dynamic) { }
static constexpr const char* _type_key = "relay.PStatic";
TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node); TVM_DECLARE_NODE_TYPE_INFO(PStaticNode, Node);
}; };
...@@ -169,6 +170,7 @@ RELAY_DEFINE_NODE_REF(PStatic, PStaticNode, NodeRef); ...@@ -169,6 +170,7 @@ RELAY_DEFINE_NODE_REF(PStatic, PStaticNode, NodeRef);
struct STupleNode : StaticNode { struct STupleNode : StaticNode {
std::vector<PStatic> fields; std::vector<PStatic> fields;
explicit STupleNode(const std::vector<PStatic>& fields) : fields(fields) { } explicit STupleNode(const std::vector<PStatic>& fields) : fields(fields) { }
static constexpr const char* _type_key = "relay.STuple";
TVM_DECLARE_NODE_TYPE_INFO(STupleNode, StaticNode); TVM_DECLARE_NODE_TYPE_INFO(STupleNode, StaticNode);
}; };
...@@ -181,7 +183,8 @@ Static MkSTuple(const std::vector<PStatic>& fields) { ...@@ -181,7 +183,8 @@ Static MkSTuple(const std::vector<PStatic>& fields) {
struct STensorNode : StaticNode { struct STensorNode : StaticNode {
runtime::NDArray data; runtime::NDArray data;
explicit STensorNode(const NDArray& data) : data(data) { } explicit STensorNode(const NDArray& data) : data(data) { }
TVM_DECLARE_NODE_TYPE_INFO(STupleNode, StaticNode); static constexpr const char* _type_key = "relay.STensor";
TVM_DECLARE_NODE_TYPE_INFO(STensorNode, StaticNode);
}; };
RELAY_DEFINE_NODE_REF(STensor, STensorNode, Value); RELAY_DEFINE_NODE_REF(STensor, STensorNode, Value);
...@@ -195,6 +198,7 @@ struct SConstructorNode : StaticNode { ...@@ -195,6 +198,7 @@ struct SConstructorNode : StaticNode {
std::vector<PStatic> fields; std::vector<PStatic> fields;
SConstructorNode(const Constructor& constructor, const std::vector<PStatic>& fields) : SConstructorNode(const Constructor& constructor, const std::vector<PStatic>& fields) :
constructor(constructor), fields(fields) { } constructor(constructor), fields(fields) { }
static constexpr const char* _type_key = "relay.SConstructor";
TVM_DECLARE_NODE_TYPE_INFO(SConstructorNode, StaticNode); TVM_DECLARE_NODE_TYPE_INFO(SConstructorNode, StaticNode);
}; };
...@@ -205,6 +209,7 @@ Static MkSConstructor(const Constructor& constructor, const std::vector<PStatic> ...@@ -205,6 +209,7 @@ Static MkSConstructor(const Constructor& constructor, const std::vector<PStatic>
} }
struct SRefNode : StaticNode { struct SRefNode : StaticNode {
static constexpr const char* _type_key = "relay.SRef";
// we will use the address as the guid for hashing // we will use the address as the guid for hashing
TVM_DECLARE_NODE_TYPE_INFO(SRefNode, StaticNode); TVM_DECLARE_NODE_TYPE_INFO(SRefNode, StaticNode);
}; };
...@@ -223,6 +228,7 @@ using Func = std::function<PStatic(const std::vector<PStatic>&, ...@@ -223,6 +228,7 @@ using Func = std::function<PStatic(const std::vector<PStatic>&,
struct SFuncNode : StaticNode { struct SFuncNode : StaticNode {
Func func; Func func;
explicit SFuncNode(const Func& func) : func(func) { } explicit SFuncNode(const Func& func) : func(func) { }
static constexpr const char* _type_key = "relay.SFunc";
TVM_DECLARE_NODE_TYPE_INFO(SFuncNode, StaticNode); TVM_DECLARE_NODE_TYPE_INFO(SFuncNode, StaticNode);
}; };
...@@ -711,8 +717,14 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -711,8 +717,14 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
return VisitFunc(GetRef<Function>(op), ll); return VisitFunc(GetRef<Function>(op), ll);
} }
struct ReflectError : dmlc::Error {
ReflectError() : dmlc::Error("static value not found") { }
};
Expr Reflect(const PStatic& st) { Expr Reflect(const PStatic& st) {
if (const STensorNode* op = st->pstatic.as<STensorNode>()) { if (!st->pstatic.defined()) {
throw ReflectError();
} else if (const STensorNode* op = st->pstatic.as<STensorNode>()) {
return ConstantNode::make(op->data); return ConstantNode::make(op->data);
} else if (const STupleNode* op = st->pstatic.as<STupleNode>()) { } else if (const STupleNode* op = st->pstatic.as<STupleNode>()) {
tvm::Array<Expr> fields; tvm::Array<Expr> fields;
...@@ -721,7 +733,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -721,7 +733,7 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
} }
return TupleNode::make(fields); return TupleNode::make(fields);
} else { } else {
LOG(FATAL) << "Unknown case"; LOG(FATAL) << "Unknown case: " << st->dynamic;
throw; throw;
} }
} }
...@@ -767,19 +779,22 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)> ...@@ -767,19 +779,22 @@ class PartialEvaluator : public ExprFunctor<PStatic(const Expr& e, LetList* ll)>
for (const PStatic& ps : pv) { for (const PStatic& ps : pv) {
ns_args.push_back(ps->dynamic); ns_args.push_back(ps->dynamic);
} }
PStatic ns = NoStatic(ll->Push(CallNode::make(expr, ns_args, attrs, type_args))); auto ns = [&]() {
return NoStatic(ll->Push(CallNode::make(expr, ns_args, attrs, type_args)));
};
if (StatefulOp(expr)) { if (StatefulOp(expr)) {
return ns; return ns();
} }
tvm::Array<Expr> args; try {
for (const PStatic& ps : pv) { tvm::Array<Expr> args;
if (ps->pstatic.defined()) { for (const PStatic& ps : pv) {
args.push_back(Reflect(ps)); args.push_back(Reflect(ps));
} else {
return ns;
} }
return ConstEvaluate(CallNode::make(expr, args, attrs, type_args), ll);
}
catch (const ReflectError&) {
return ns();
} }
return ConstEvaluate(CallNode::make(expr, args, attrs, type_args), ll);
}; };
} }
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
import numpy as np import numpy as np
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.analysis import alpha_equal from tvm.relay.analysis import alpha_equal, assert_alpha_equal
from tvm.relay.prelude import Prelude from tvm.relay.prelude import Prelude
from tvm.relay import op, create_executor, transform from tvm.relay import op, create_executor, transform
from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate from tvm.relay import Var, TypeVar, TupleGetItem, Let, Function, const, RefRead, RefWrite, RefCreate
...@@ -306,6 +306,14 @@ def test_double(): ...@@ -306,6 +306,14 @@ def test_double():
assert alpha_equal(res.body, make_nat_expr(p, 6)) assert alpha_equal(res.body, make_nat_expr(p, 6))
def test_concat():
t = relay.TensorType([10], "float32")
x = Var("x", t)
y = Var("x", t)
orig = run_infer_type(Function([x, y], op.concatenate([x, y], axis=0)))
assert_alpha_equal(orig, dcpe(orig))
if __name__ == '__main__': if __name__ == '__main__':
test_ref() test_ref()
test_tuple() test_tuple()
...@@ -323,3 +331,4 @@ if __name__ == '__main__': ...@@ -323,3 +331,4 @@ if __name__ == '__main__':
test_nat_id() test_nat_id()
test_global_match_nat_id() test_global_match_nat_id()
test_match_nat_id() test_match_nat_id()
test_concat()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment