Unverified Commit ec0d497c by Tianqi Chen Committed by GitHub

[NODE][RELAY] Move most of the reference related code to node (#1747)

parent 1c2b0b65
......@@ -102,10 +102,10 @@ class TVM_DLL Node : public NodeBase {
template<typename T>
inline bool is_type() const;
/*!
* \brief Get a NodeRef that holds reference to this Node.
* \return the NodeRef
* \brief Get a NodePtr that holds reference to this Node.
* \return the NodePtr
*/
inline NodeRef GetNodeRef() const;
inline NodePtr<Node> GetNodePtr() const;
// node ref can see this
friend class NodeRef;
static constexpr const char* _type_key = "Node";
......@@ -177,6 +177,32 @@ class NodeRef {
};
/*!
* \brief Get a reference type from a Node ptr type
*
* It is always important to get a reference type
* if we want to return a value as reference or keep
* the node alive beyond the scope of the function.
*
* \param ptr The node pointer
* \tparam RefType The reference type
* \tparam NodeType The node type
* \return The corresponding RefType
*/
template <typename RefType, typename NodeType>
inline RefType GetRef(const NodeType* ptr);
/*!
* \brief Downcast a base reference type to a more specific type.
*
* \param ref The inptut reference
* \return The corresponding SubRef.
* \tparam SubRef The target specific reference type.
* \tparam BaseRef the current reference type.
*/
template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref);
/*!
* \brief helper macro to declare type information in a base node.
*/
#define TVM_DECLARE_BASE_NODE_INFO(TypeName, Parent) \
......@@ -218,8 +244,24 @@ inline bool Node::derived_from() const {
return this->_DerivedFrom(type_id);
}
inline NodeRef Node::GetNodeRef() const {
return NodeRef(NodePtr<Node>(const_cast<Node*>(this)));
inline NodePtr<Node> Node::GetNodePtr() const {
return NodePtr<Node>(const_cast<Node*>(this));
}
template <typename RefType, typename NodeType>
inline RefType GetRef(const NodeType* ptr) {
static_assert(std::is_base_of<typename RefType::ContainerType, NodeType>::value,
"Can only cast to the ref of same container type");
return RefType(ptr->GetNodePtr());
}
template <typename SubRef, typename BaseRef>
inline SubRef Downcast(BaseRef ref) {
CHECK(ref->template is_type<typename SubRef::ContainerType>() ||
ref->template derived_from<typename SubRef::ContainerType>())
<< "Downcast from " << ref->type_key() << " to "
<< SubRef::ContainerType::_type_key << " failed.";
return SubRef(std::move(ref.node_));
}
inline const Node* NodeRef::get() const {
......
......@@ -158,43 +158,6 @@ class RelayNode : public Node {
TVM_DECLARE_BASE_NODE_INFO(RelayNode, Node);
};
/*!
* \brief Get a reference type from a Node ptr type
*
* It is always important to get a reference type
* if we want to return a value as reference or keep
* the node alive beyond the scope of the function.
*
* \param ptr The node pointer
* \tparam RefType The reference type
* \tparam NodeType The node type
* \return The corresponding RefType
*/
template <typename RefType, typename NodeType>
RefType GetRef(const NodeType* ptr) {
static_assert(std::is_same<typename RefType::ContainerType, NodeType>::value,
"Can only cast to the ref of same container type");
return RefType(std::move(ptr->GetNodeRef().node_));
}
// TODO(@tqchen, @jroesch): can we move these semantics to HalideIR
template <typename T>
inline const T* As(const NodeRef& node) {
const Node* ptr = static_cast<const Node*>(node.get());
if (ptr && (ptr->is_type<T>() || ptr->derived_from<T>())) {
return static_cast<const T*>(ptr);
}
return nullptr;
}
template <typename SubRef, typename BaseRef>
SubRef Downcast(BaseRef ref) {
CHECK(ref->template is_type<typename SubRef::ContainerType>())
<< "Downcast from " << ref->type_key() << " to "
<< SubRef::ContainerType::_type_key << " failed.";
return SubRef(ref.node_);
}
} // namespace relay
} // namespace tvm
......
......@@ -65,7 +65,9 @@ class ConstantNode : public ExprNode {
TensorType tensor_type() const;
/*! \return Whether it is scalar(rank-0 tensor) */
bool is_scalar() const { return data->ndim == 0; }
bool is_scalar() const {
return data->ndim == 0;
}
void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("data", &data);
......@@ -341,7 +343,7 @@ RELAY_DEFINE_NODE_REF(Let, LetNode, Expr);
*
* let x = if (true) { 1 } else { 0 }; // x is 1
* let y = if (false) { 1 } else { 0 }; // y is 0
*
*
* \note This is similar to C's ternary operator.
*/
class If;
......
......@@ -139,19 +139,19 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
* the cost of using functional updates.
*/
class ExprMutator
: public ::tvm::relay::ExprFunctor<Expr(const Expr&, const Expr&)> {
: public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
public:
Expr Mutate(const Expr& expr);
Expr VisitExpr_(const VarNode* op, const Expr& e) override;
Expr VisitExpr_(const ConstantNode* op, const Expr& e) override;
Expr VisitExpr_(const GlobalVarNode* op, const Expr& e) override;
Expr VisitExpr_(const OpNode* op, const Expr& expr) override;
Expr VisitExpr_(const TupleNode* op, const Expr& e) override;
Expr VisitExpr_(const ParamNode* op, const Expr& e) override;
Expr VisitExpr_(const FunctionNode* op, const Expr& e) override;
Expr VisitExpr_(const CallNode* call_node, const Expr& e) override;
Expr VisitExpr_(const LetNode* op, const Expr& e) override;
Expr VisitExpr_(const IfNode* op, const Expr& e) override;
Expr VisitExpr_(const VarNode* op) override;
Expr VisitExpr_(const ConstantNode* op) override;
Expr VisitExpr_(const GlobalVarNode* op) override;
Expr VisitExpr_(const OpNode* op) override;
Expr VisitExpr_(const TupleNode* op) override;
Expr VisitExpr_(const ParamNode* op) override;
Expr VisitExpr_(const FunctionNode* op) override;
Expr VisitExpr_(const CallNode* call_node) override;
Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override;
/*! \brief Used to visit the types inside of expressions.
*
* Can be overloaded to transform the types in arbitrary
......@@ -162,7 +162,7 @@ class ExprMutator
private:
/*! \brief Internal map used for memoization. */
tvm::Map<Expr, Expr> memo_;
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
};
} // namespace relay
......
......@@ -41,12 +41,12 @@ void EnvironmentNode::Add(const GlobalVar &var,
const Function &func,
bool update) {
// Type check the item before we add it to the environment.
auto env = relay::GetRef<Environment>(this);
auto env = GetRef<Environment>(this);
Expr checked_expr = InferType(env, var, func);
if (const FunctionNode *func_node = checked_expr.as<FunctionNode>()) {
auto checked_func = relay::GetRef<Function>(func_node);
auto checked_func = GetRef<Function>(func_node);
auto type = checked_func->checked_type();
CHECK(IsFullyResolved(type));
......
......@@ -13,33 +13,33 @@ namespace tvm {
namespace relay {
Expr ExprMutator::Mutate(const Expr& expr) {
auto cached_expr = this->memo_.find(expr);
if (cached_expr != this->memo_.end()) {
return (*cached_expr).second;
auto it = this->memo_.find(expr);
if (it != this->memo_.end()) {
return it->second;
} else {
auto new_expr = this->ExprMutator::VisitExpr(expr, expr);
this->memo_.Set(expr, new_expr);
Expr new_expr = ExprMutator::VisitExpr(expr);
memo_[expr] = new_expr;
return new_expr;
}
}
Expr ExprMutator::VisitExpr_(const VarNode* op, const Expr& expr) {
return expr;
Expr ExprMutator::VisitExpr_(const VarNode* op) {
return GetRef<Expr>(op);
}
Expr ExprMutator::VisitExpr_(const ConstantNode* op, const Expr& expr) {
return expr;
Expr ExprMutator::VisitExpr_(const ConstantNode* op) {
return GetRef<Expr>(op);
}
Expr ExprMutator::VisitExpr_(const GlobalVarNode* op, const Expr& expr) {
return expr;
Expr ExprMutator::VisitExpr_(const GlobalVarNode* op) {
return GetRef<Expr>(op);
}
Expr ExprMutator::VisitExpr_(const OpNode* op, const Expr& expr) {
return expr;
Expr ExprMutator::VisitExpr_(const OpNode* op) {
return GetRef<Expr>(op);
}
Expr ExprMutator::VisitExpr_(const TupleNode* op, const Expr& e) {
Expr ExprMutator::VisitExpr_(const TupleNode* op) {
tvm::Array<Expr> fields;
bool all_fields_unchanged = true;
for (auto field : op->fields) {
......@@ -49,23 +49,23 @@ Expr ExprMutator::VisitExpr_(const TupleNode* op, const Expr& e) {
}
if (all_fields_unchanged) {
return e;
return GetRef<Expr>(op);
} else {
return TupleNode::make(fields);
}
}
Expr ExprMutator::VisitExpr_(const ParamNode* op, const Expr& e) {
Expr ExprMutator::VisitExpr_(const ParamNode* op) {
Var var = Downcast<Var>(this->Mutate(op->var));
auto type = this->VisitType(op->type);
if (var == op->var && type == op->type) {
return e;
if (op->var.same_as(var) && op->type.same_as(type)) {
return GetRef<Expr>(op);
} else {
return ParamNode::make(var, type);
}
}
Expr ExprMutator::VisitExpr_(const FunctionNode* op, const Expr& e) {
Expr ExprMutator::VisitExpr_(const FunctionNode* op) {
tvm::Array<TypeParam> ty_params;
bool all_ty_params_changed = true;
......@@ -86,74 +86,82 @@ Expr ExprMutator::VisitExpr_(const FunctionNode* op, const Expr& e) {
auto ret_type = this->VisitType(op->ret_type);
auto body = this->Mutate(op->body);
if (ty_params.same_as(op->type_params) && params.same_as(op->params) &&
ret_type.same_as(op->ret_type) && body.same_as(op->body)) {
return e;
if (ty_params.same_as(op->type_params) &&
params.same_as(op->params) &&
ret_type.same_as(op->ret_type) &&
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return FunctionNode::make(params, ret_type, body, ty_params);
}
}
Expr ExprMutator::VisitExpr_(const CallNode* call_node, const Expr& e) {
auto op = this->Mutate(call_node->op);
Expr ExprMutator::VisitExpr_(const CallNode* call_node) {
auto new_op = this->Mutate(call_node->op);
bool unchanged = call_node->op.same_as(new_op);
tvm::Array<Type> ty_args;
bool all_ty_args_unchanged = true;
for (auto ty_arg : call_node->type_args) {
auto new_ty_arg = this->VisitType(ty_arg);
ty_args.push_back(new_ty_arg);
all_ty_args_unchanged &= new_ty_arg.same_as(ty_arg);
unchanged &= new_ty_arg.same_as(ty_arg);
}
tvm::Array<Expr> call_args;
bool all_args_unchanged = true;
for (auto arg : call_node->args) {
auto new_arg = this->Mutate(arg);
call_args.push_back(new_arg);
all_args_unchanged &= new_arg.same_as(arg);
unchanged &= new_arg.same_as(arg);
}
if (all_ty_args_unchanged && all_args_unchanged &&
call_node->op.same_as(op)) {
return e;
if (unchanged) {
return GetRef<Expr>(call_node);
} else {
return CallNode::make(op, call_args, call_node->attrs, ty_args);
return CallNode::make(new_op, call_args, call_node->attrs, ty_args);
}
}
Expr ExprMutator::VisitExpr_(const LetNode* op, const Expr& e) {
Expr ExprMutator::VisitExpr_(const LetNode* op) {
Var var = Downcast<Var>(this->Mutate(op->var));
auto type = this->VisitType(op->value_type);
auto value = this->Mutate(op->value);
auto body = this->Mutate(op->body);
if (var.same_as(op->var) && type.same_as(op->value_type) &&
value.same_as(op->value) && body.same_as(op->body)) {
return e;
if (var.same_as(op->var) &&
type.same_as(op->value_type) &&
value.same_as(op->value) &&
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return LetNode::make(var, value, body, type);
}
}
Expr ExprMutator::VisitExpr_(const IfNode* op, const Expr& e) {
Expr ExprMutator::VisitExpr_(const IfNode* op) {
auto guard = this->Mutate(op->cond);
auto true_b = this->Mutate(op->true_branch);
auto false_b = this->Mutate(op->false_branch);
if (op->cond == guard && true_b == op->true_branch &&
false_b == op->false_branch) {
return e;
if (op->cond.same_as(guard) &&
op->true_branch.same_as(true_b) &&
op->false_branch.same_as(false_b)) {
return GetRef<Expr>(op);;
} else {
return IfNode::make(guard, true_b, false_b);
}
}
Type ExprMutator::VisitType(const Type& t) { return t; }
Type ExprMutator::VisitType(const Type& t) {
return t;
}
void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) { return; }
void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
}
void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) { return; }
void ExprVisitor::ExprVisitor::VisitExpr_(const GlobalVarNode* op) {
}
void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) { return; }
void ExprVisitor::ExprVisitor::VisitExpr_(const ConstantNode* op) {
}
void ExprVisitor::ExprVisitor::VisitExpr_(const TupleNode* op) {
for (auto field : op->fields) {
......@@ -202,4 +210,3 @@ void ExprVisitor::VisitType(const Type& t) { return; }
} // namespace relay
} // namespace tvm
......@@ -78,7 +78,8 @@ struct TypeMutator : TypeFunctor<Type(const Type& n)> {
Array<TypeConstraint> type_constraints;
for (auto type_cs : op->type_constraints) {
auto new_type_cs = VisitType(type_cs);
if (const TypeConstraintNode* tin = As<TypeConstraintNode>(new_type_cs)) {
if (const TypeConstraintNode* tin =
new_type_cs.as_derived<TypeConstraintNode>()) {
type_constraints.push_back(GetRef<TypeConstraint>(tin));
} else {
CHECK(false) << new_type_cs << std::endl;
......
......@@ -20,7 +20,7 @@ TEST(ExprNodeRef, Basic) {
Var x("x");
Expr z = max(x + 1 + 2, 100);
const ir::Max* op = z.as<ir::Max>();
CHECK(op->GetNodeRef().same_as(z));
CHECK(NodeRef(op->GetNodePtr()).same_as(z));
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment