Commit 70d93028 by tqchen

Keep up with changes of NodeRef

parent 2fc12dcd
Subproject commit bf96f8af0dfd1f79d258c7c1506f9ded932b94a9 Subproject commit eb2f7d604a611318fc685172847bcf5ba2fcf835
...@@ -95,13 +95,13 @@ class RDomainNode : public Node { ...@@ -95,13 +95,13 @@ class RDomainNode : public Node {
RDomainNode(Array<Var> index, Domain domain) RDomainNode(Array<Var> index, Domain domain)
: index(index), domain(domain) { : index(index), domain(domain) {
} }
const char* type_key() const override {
return "RDomain";
}
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("index", &index); v->Visit("index", &index);
v->Visit("domain", &domain); v->Visit("domain", &domain);
} }
static constexpr const char* _type_key = "RDomain";
TVM_DECLARE_NODE_TYPE_INFO(RDomainNode);
}; };
inline const RDomainNode* RDomain::operator->() const { inline const RDomainNode* RDomain::operator->() const {
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#ifndef TVM_IR_MUTATOR_H_ #ifndef TVM_IR_MUTATOR_H_
#define TVM_IR_MUTATOR_H_ #define TVM_IR_MUTATOR_H_
#include <tvm/ir_node.h> #include <tvm/ir_functor.h>
#include <unordered_map> #include <unordered_map>
#include "./expr.h" #include "./expr.h"
...@@ -16,7 +16,7 @@ namespace ir { ...@@ -16,7 +16,7 @@ namespace ir {
* \brief a base class for mutator to iterative mutate the IR * \brief a base class for mutator to iterative mutate the IR
* *
* This IRMutator is implemented via IRFunctor instead of Visitor Pattern. * This IRMutator is implemented via IRFunctor instead of Visitor Pattern.
* This enables easy extensions of possible new IRNode. * This enables easy extensions of possible new Node.
* It also makes changing return types easier. * It also makes changing return types easier.
* *
* \note If you want to return a different type other than Expr and Stmt, * \note If you want to return a different type other than Expr and Stmt,
...@@ -44,9 +44,9 @@ class IRMutator { ...@@ -44,9 +44,9 @@ class IRMutator {
/*! \brief destructor */ /*! \brief destructor */
virtual ~IRMutator() {} virtual ~IRMutator() {}
/*! \brief functor type of expr mutation */ /*! \brief functor type of expr mutation */
using FMutateExpr = IRFunctor<Expr(const IRNodeRef&, const Expr&, IRMutator*)>; using FMutateExpr = IRFunctor<Expr(const NodeRef&, const Expr&, IRMutator*)>;
/*! \brief functor type of stmt mutation */ /*! \brief functor type of stmt mutation */
using FMutateStmt = IRFunctor<Stmt(const IRNodeRef&, const Stmt&, IRMutator*)>; using FMutateStmt = IRFunctor<Stmt(const NodeRef&, const Stmt&, IRMutator*)>;
/*! \return internal vtable of expr */ /*! \return internal vtable of expr */
static FMutateExpr& vtable_expr(); // NOLINT(*) static FMutateExpr& vtable_expr(); // NOLINT(*)
/*! \return internal stmt of expr */ /*! \return internal stmt of expr */
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#ifndef TVM_IR_PASS_H_ #ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_ #define TVM_IR_PASS_H_
#include <tvm/ir_node.h> #include <tvm/ir_functor.h>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "./expr.h" #include "./expr.h"
......
...@@ -15,7 +15,7 @@ namespace ir { ...@@ -15,7 +15,7 @@ namespace ir {
* \brief a base class for visitor to iterative traverse the IR * \brief a base class for visitor to iterative traverse the IR
* *
* This IRVisitor is implemented via IRFunctor * This IRVisitor is implemented via IRFunctor
* This enables extensions of possible new IRNode. * This enables extensions of possible new Node.
* *
* \sa IRFunctor, PostOrderVisit * \sa IRFunctor, PostOrderVisit
*/ */
...@@ -24,14 +24,14 @@ class IRVisitor { ...@@ -24,14 +24,14 @@ class IRVisitor {
/*! /*!
* \brief recursively visit an IR node * \brief recursively visit an IR node
*/ */
virtual void Visit(const IRNodeRef& node) { virtual void Visit(const NodeRef& node) {
static const FVisit& f = vtable(); static const FVisit& f = vtable();
if (node.defined()) f(node, this); if (node.defined()) f(node, this);
} }
/*! \brief destructor */ /*! \brief destructor */
virtual ~IRVisitor() {} virtual ~IRVisitor() {}
/*! \brief functor type of visitor */ /*! \brief functor type of visitor */
using FVisit = IRFunctor<void(const IRNodeRef&, IRVisitor*)>; using FVisit = IRFunctor<void(const NodeRef&, IRVisitor*)>;
/*! \return internal vtable*/ /*! \return internal vtable*/
static FVisit& vtable(); static FVisit& vtable();
}; };
...@@ -42,7 +42,7 @@ class IRVisitor { ...@@ -42,7 +42,7 @@ class IRVisitor {
* \param node The ir to be visited. * \param node The ir to be visited.
* \param fvisit The visitor function to be applied. * \param fvisit The visitor function to be applied.
*/ */
void PostOrderVisit(const IRNodeRef& node, std::function<void(const IRNodeRef&)> fvisit); void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit);
} // namespace ir } // namespace ir
} // namespace tvm } // namespace tvm
......
...@@ -23,9 +23,6 @@ class ComputeOpNode : public OperationNode { ...@@ -23,9 +23,6 @@ class ComputeOpNode : public OperationNode {
/*! \brief constructor */ /*! \brief constructor */
ComputeOpNode() {} ComputeOpNode() {}
const char* type_key() const final {
return "ComputeOp";
}
size_t num_outputs() const final { size_t num_outputs() const final {
return 1; return 1;
} }
...@@ -43,6 +40,9 @@ class ComputeOpNode : public OperationNode { ...@@ -43,6 +40,9 @@ class ComputeOpNode : public OperationNode {
std::string name, std::string name,
Array<Var> dim_var, Array<Var> dim_var,
Expr body); Expr body);
static constexpr const char* _type_key = "ComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode);
}; };
......
...@@ -62,6 +62,10 @@ class ScheduleNode : public Node { ...@@ -62,6 +62,10 @@ class ScheduleNode : public Node {
const char* type_key() const final { const char* type_key() const final {
return "Schedule"; return "Schedule";
} }
const uint32_t type_index() const final {
static uint32_t tidx = TypeKey2Index(type_key());
return tidx;
}
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("scope", &scope); v->Visit("scope", &scope);
v->Visit("op", &op); v->Visit("op", &op);
......
...@@ -46,14 +46,15 @@ class DimSplitNode : public SplitNode { ...@@ -46,14 +46,15 @@ class DimSplitNode : public SplitNode {
Expr factor; Expr factor;
/*! \brief constructor */ /*! \brief constructor */
DimSplitNode() {} DimSplitNode() {}
const char* type_key() const final {
return "DimSplit";
}
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("var", &var); v->Visit("var", &var);
v->Visit("factor", &factor); v->Visit("factor", &factor);
} }
static Split make(Var var, Expr factor); static Split make(Var var, Expr factor);
static constexpr const char* _type_key = "DimSplit";
TVM_DECLARE_NODE_TYPE_INFO(DimSplitNode);
}; };
// Implementations of inline functions // Implementations of inline functions
......
...@@ -104,9 +104,7 @@ class TensorNode : public FunctionBaseNode { ...@@ -104,9 +104,7 @@ class TensorNode : public FunctionBaseNode {
int value_index{0}; int value_index{0};
/*! \brief constructor */ /*! \brief constructor */
TensorNode() {} TensorNode() {}
const char* type_key() const final {
return "Tensor";
}
void VisitAttrs(AttrVisitor* v) final { void VisitAttrs(AttrVisitor* v) final {
v->Visit("shape", &shape); v->Visit("shape", &shape);
v->Visit("name", &name); v->Visit("name", &name);
...@@ -125,6 +123,9 @@ class TensorNode : public FunctionBaseNode { ...@@ -125,6 +123,9 @@ class TensorNode : public FunctionBaseNode {
Type dtype, Type dtype,
Operation op, Operation op,
int value_index); int value_index);
static constexpr const char* _type_key = "Tensor";
TVM_DECLARE_NODE_TYPE_INFO(TensorNode);
}; };
/*! /*!
......
...@@ -9,5 +9,6 @@ ...@@ -9,5 +9,6 @@
#include "./base.h" #include "./base.h"
#include "./expr.h" #include "./expr.h"
#include "./tensor.h" #include "./tensor.h"
#include "./operation.h"
#endif // TVM_TVM_H_ #endif // TVM_TVM_H_
...@@ -26,9 +26,9 @@ TVM_REGISTER_API(_format_str) ...@@ -26,9 +26,9 @@ TVM_REGISTER_API(_format_str)
CHECK(args.at(0).type_id == kNodeHandle); CHECK(args.at(0).type_id == kNodeHandle);
std::ostringstream os; std::ostringstream os;
auto& sptr = args.at(0).sptr; auto& sptr = args.at(0).sptr;
if (sptr->is_type<TensorNode>()) { if (dynamic_cast<const TensorNode*>(sptr.get())) {
os << args.at(0).operator Tensor(); os << args.at(0).operator Tensor();
} else if (sptr->is_type<RDomainNode>()) { } else if (dynamic_cast<const RDomainNode*>(sptr.get())) {
os << args.at(0).operator RDomain(); os << args.at(0).operator RDomain();
} else if (dynamic_cast<const BaseExprNode*>(sptr.get())) { } else if (dynamic_cast<const BaseExprNode*>(sptr.get())) {
os << args.at(0).operator Expr(); os << args.at(0).operator Expr();
......
...@@ -22,7 +22,7 @@ namespace { ...@@ -22,7 +22,7 @@ namespace {
using namespace Halide::Internal; using namespace Halide::Internal;
// const expr // const expr
inline Expr ReturnSelfExpr(const IRNodeRef&, const Expr& e, IRMutator*) { inline Expr ReturnSelfExpr(const NodeRef&, const Expr& e, IRMutator*) {
return e; return e;
} }
......
...@@ -12,9 +12,9 @@ namespace { ...@@ -12,9 +12,9 @@ namespace {
// visitor to implement apply // visitor to implement apply
class IRApplyVisit : public IRVisitor { class IRApplyVisit : public IRVisitor {
public: public:
explicit IRApplyVisit(std::function<void(const IRNodeRef&)> f) : f_(f) {} explicit IRApplyVisit(std::function<void(const NodeRef&)> f) : f_(f) {}
void Visit(const IRNodeRef& node) final { void Visit(const NodeRef& node) final {
if (visited_.count(node.get()) != 0) return; if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get()); visited_.insert(node.get());
IRVisitor::Visit(node); IRVisitor::Visit(node);
...@@ -22,13 +22,13 @@ class IRApplyVisit : public IRVisitor { ...@@ -22,13 +22,13 @@ class IRApplyVisit : public IRVisitor {
} }
private: private:
std::function<void(const IRNodeRef&)> f_; std::function<void(const NodeRef&)> f_;
std::unordered_set<const Node*> visited_; std::unordered_set<const Node*> visited_;
}; };
} // namespace } // namespace
void PostOrderVisit(const IRNodeRef& node, std::function<void(const IRNodeRef&)> fvisit) { void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit) {
IRApplyVisit(fvisit).Visit(node); IRApplyVisit(fvisit).Visit(node);
} }
...@@ -42,7 +42,7 @@ namespace { ...@@ -42,7 +42,7 @@ namespace {
using namespace Halide::Internal; using namespace Halide::Internal;
void NoOp(const IRNodeRef& n, IRVisitor* v) { void NoOp(const NodeRef& n, IRVisitor* v) {
} }
inline void VisitArray(Array<Expr> arr, IRVisitor* v) { inline void VisitArray(Array<Expr> arr, IRVisitor* v) {
......
...@@ -5,21 +5,37 @@ ...@@ -5,21 +5,37 @@
#include <tvm/ir.h> #include <tvm/ir.h>
#include <tvm/ir_mutator.h> #include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include "./scope.h"
namespace tvm { namespace tvm {
namespace ir { namespace ir {
namespace { namespace {
Stmt MakeCompute(const ComputeOpNode* op, const Array<Split>& splits) { /*!
Tensor output; * \brief make nest loops given list of stmt, whose body is not defined.
std::vector<Expr> args(op->dim_var.size()); * \param nest A list of For and LetStmt, whose body is not defined.
for (size_t i = 0; i < args.size(); ++i) { * \param body The inner-most body of the loop
args[i] = op->dim_var[i]; */
Stmt MakeLoop(std::vector<Stmt>&& nest, Stmt body) {
while (!nest.empty()) {
Stmt s = std::move(nest.back()); nest.pop_back();
if (s.as<For>()) {
auto n = std::make_shared<For>(*s.as<For>());
n->body = body;
body = Stmt(n);
} else if (s.as<LetStmt>()) {
auto n = std::make_shared<LetStmt>(*s.as<LetStmt>());
n->body = body;
body = Stmt(n);
} else if (s.as<AttrStmt>()) {
auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>());
n->body = body;
body = Stmt(n);
} else {
LOG(FATAL) << "not supported nest type";
} }
Array<Expr> values{op->body}; }
Stmt stmt = Provide::make(output, values, args); return body;
// add splits from ousside most to outsidemost to innermost
return stmt;
} }
......
/*!
* Copyright (c) 2016 by Contributors
* \file scope.h
* \brief attribute scope data structure,
* defines attributes on current domain
*/
#ifndef TVM_PASS_SCOPE_H_
#define TVM_PASS_SCOPE_H_
#include <tvm/ir.h>
#include <unordered_map>
#include <vector>
#include <string>
namespace tvm {
namespace ir {
/*!
* \brief Attribute scope of Nodes in the IR.
* \tparam ValueType The value of of the scope.
*/
template<typename K, typename V>
class Scope {
public:
/*!
* \brief Push value to scope
* \param key the key to be pushed.
* \param v The value to be pushed.
*/
inline void Push(const K& key, V v) {
data_[key].emplace_back(v);
}
/*!
* \brief Pop value from scope.
* \param key the key to be poped
*/
inline void Pop(const K& key) {
auto& v = data_[key];
CHECK_NE(v.size(), 0);
v.pop_back();
}
/*!
* \brief Get value from the scope
* \param key the key to fetch.
* \return The value to be fetched.
*/
inline V operator[](const K& key) const {
const auto it = data_.find(key);
CHECK(it != data_.end() && it->second.size() != 0)
<< "cannot find value in scope";
return it->second.back();
}
private:
std::unordered_map<K, std::vector<V> > data_;
};
/*! \brief Attribute key for specific attribute */
struct AttrKey {
/*! \brief The node of the attribute */
NodeRef node;
/*! \brief The type key of the attribute. */
std::string type_key;
// overload operator ==
inline bool operator==(const AttrKey& other) const {
return node == other.node && type_key == other.type_key;
}
};
} // namespace ir
} // namespace tvm
namespace std {
template <>
struct hash<::tvm::ir::AttrKey> {
std::size_t operator()(const ::tvm::ir::AttrKey& k) const {
size_t lhs = k.node.hash();
size_t rhs = std::hash<std::string>()(k.type_key);
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
}
};
} // namespace std
#endif // TVM_PASS_SCOPE_H_
...@@ -17,7 +17,7 @@ namespace { ...@@ -17,7 +17,7 @@ namespace {
// global functor to get var definition from // global functor to get var definition from
struct FGetVarDef { struct FGetVarDef {
using FType = IRFunctor<VarExpr (const IRNodeRef&)>; using FType = IRFunctor<VarExpr (const NodeRef&)>;
static FType& vtable() { // NOLINT(*) static FType& vtable() { // NOLINT(*)
static FType inst; return inst; static FType inst; return inst;
} }
...@@ -37,8 +37,8 @@ TVM_STATIC_IR_FUNCTOR(FGetVarDef, vtable) ...@@ -37,8 +37,8 @@ TVM_STATIC_IR_FUNCTOR(FGetVarDef, vtable)
}); });
struct FSetVarDef { struct FSetVarDef {
using FTypeExpr = IRFunctor<Expr (const IRNodeRef&, VarExpr)>; using FTypeExpr = IRFunctor<Expr (const NodeRef&, VarExpr)>;
using FTypeStmt = IRFunctor<Stmt (const IRNodeRef&, VarExpr)>; using FTypeStmt = IRFunctor<Stmt (const NodeRef&, VarExpr)>;
static FTypeExpr& vtable_expr() { // NOLINT(*) static FTypeExpr& vtable_expr() { // NOLINT(*)
static FTypeExpr inst; return inst; static FTypeExpr inst; return inst;
} }
...@@ -69,7 +69,7 @@ class IRVerifySSA : public IRVisitor { ...@@ -69,7 +69,7 @@ class IRVerifySSA : public IRVisitor {
public: public:
bool is_ssa{true}; bool is_ssa{true};
void Visit(const IRNodeRef& n) final { void Visit(const NodeRef& n) final {
if (!is_ssa) return; if (!is_ssa) return;
static auto& fget_var_def = FGetVarDef::vtable(); static auto& fget_var_def = FGetVarDef::vtable();
if (fget_var_def.can_dispatch(n)) { if (fget_var_def.can_dispatch(n)) {
......
#include <dmlc/logging.h> #include <dmlc/logging.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <tvm/tvm.h> #include <tvm/tvm.h>
#include <tvm/ir_node.h> #include <tvm/ir_functor.h>
TEST(IRF, Basic) { TEST(IRF, Basic) {
using namespace Halide::Internal; using namespace Halide::Internal;
...@@ -9,7 +9,7 @@ TEST(IRF, Basic) { ...@@ -9,7 +9,7 @@ TEST(IRF, Basic) {
Var x("x"); Var x("x");
auto z = x + 1; auto z = x + 1;
IRFunctor<int(const IRNodeRef& n, int b)> f; IRFunctor<int(const NodeRef& n, int b)> f;
LOG(INFO) << "x"; LOG(INFO) << "x";
f.set_dispatch<Variable>([](const Variable* n, int b) { f.set_dispatch<Variable>([](const Variable* n, int b) {
return b; return b;
......
...@@ -11,7 +11,7 @@ TEST(IRVisitor, CountVar) { ...@@ -11,7 +11,7 @@ TEST(IRVisitor, CountVar) {
Var x("x"), y; Var x("x"), y;
auto z = x + 1 + y + y; auto z = x + 1 + y + y;
ir::PostOrderVisit(z, [&n_var](const IRNodeRef& n) { ir::PostOrderVisit(z, [&n_var](const NodeRef& n) {
if (n.as<Variable>()) ++n_var; if (n.as<Variable>()) ++n_var;
}); });
CHECK_EQ(n_var, 2); CHECK_EQ(n_var, 2);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment