Commit 70d93028 by tqchen

Keep up with changes of NodeRef

parent 2fc12dcd
Subproject commit bf96f8af0dfd1f79d258c7c1506f9ded932b94a9
Subproject commit eb2f7d604a611318fc685172847bcf5ba2fcf835
......@@ -95,13 +95,13 @@ class RDomainNode : public Node {
RDomainNode(Array<Var> index, Domain domain)
: index(index), domain(domain) {
}
const char* type_key() const override {
return "RDomain";
}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("index", &index);
v->Visit("domain", &domain);
}
static constexpr const char* _type_key = "RDomain";
TVM_DECLARE_NODE_TYPE_INFO(RDomainNode);
};
inline const RDomainNode* RDomain::operator->() const {
......
......@@ -6,7 +6,7 @@
#ifndef TVM_IR_MUTATOR_H_
#define TVM_IR_MUTATOR_H_
#include <tvm/ir_node.h>
#include <tvm/ir_functor.h>
#include <unordered_map>
#include "./expr.h"
......@@ -16,7 +16,7 @@ namespace ir {
* \brief a base class for mutator to iterative mutate the IR
*
* This IRMutator is implemented via IRFunctor instead of Visitor Pattern.
* This enables easy extensions of possible new IRNode.
* This enables easy extensions of possible new Node.
* It also makes changing return types easier.
*
* \note If you want to return a different type other than Expr and Stmt,
......@@ -44,9 +44,9 @@ class IRMutator {
/*! \brief destructor */
virtual ~IRMutator() {}
/*! \brief functor type of expr mutation */
using FMutateExpr = IRFunctor<Expr(const IRNodeRef&, const Expr&, IRMutator*)>;
using FMutateExpr = IRFunctor<Expr(const NodeRef&, const Expr&, IRMutator*)>;
/*! \brief functor type of stmt mutation */
using FMutateStmt = IRFunctor<Stmt(const IRNodeRef&, const Stmt&, IRMutator*)>;
using FMutateStmt = IRFunctor<Stmt(const NodeRef&, const Stmt&, IRMutator*)>;
/*! \return internal vtable of expr */
static FMutateExpr& vtable_expr(); // NOLINT(*)
/*! \return internal stmt of expr */
......
......@@ -9,7 +9,7 @@
#ifndef TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
#include <tvm/ir_node.h>
#include <tvm/ir_functor.h>
#include <unordered_map>
#include <vector>
#include "./expr.h"
......
......@@ -15,7 +15,7 @@ namespace ir {
* \brief a base class for visitor to iterative traverse the IR
*
* This IRVisitor is implemented via IRFunctor
* This enables extensions of possible new IRNode.
* This enables extensions of possible new Node.
*
* \sa IRFunctor, PostOrderVisit
*/
......@@ -24,14 +24,14 @@ class IRVisitor {
/*!
* \brief recursively visit an IR node
*/
virtual void Visit(const IRNodeRef& node) {
virtual void Visit(const NodeRef& node) {
static const FVisit& f = vtable();
if (node.defined()) f(node, this);
}
/*! \brief destructor */
virtual ~IRVisitor() {}
/*! \brief functor type of visitor */
using FVisit = IRFunctor<void(const IRNodeRef&, IRVisitor*)>;
using FVisit = IRFunctor<void(const NodeRef&, IRVisitor*)>;
/*! \return internal vtable*/
static FVisit& vtable();
};
......@@ -42,7 +42,7 @@ class IRVisitor {
* \param node The ir to be visited.
* \param fvisit The visitor function to be applied.
*/
void PostOrderVisit(const IRNodeRef& node, std::function<void(const IRNodeRef&)> fvisit);
void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit);
} // namespace ir
} // namespace tvm
......
......@@ -23,9 +23,6 @@ class ComputeOpNode : public OperationNode {
/*! \brief constructor */
ComputeOpNode() {}
const char* type_key() const final {
return "ComputeOp";
}
size_t num_outputs() const final {
return 1;
}
......@@ -43,6 +40,9 @@ class ComputeOpNode : public OperationNode {
std::string name,
Array<Var> dim_var,
Expr body);
static constexpr const char* _type_key = "ComputeOp";
TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode);
};
......
......@@ -62,6 +62,10 @@ class ScheduleNode : public Node {
const char* type_key() const final {
return "Schedule";
}
const uint32_t type_index() const final {
static uint32_t tidx = TypeKey2Index(type_key());
return tidx;
}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("scope", &scope);
v->Visit("op", &op);
......
......@@ -46,14 +46,15 @@ class DimSplitNode : public SplitNode {
Expr factor;
/*! \brief constructor */
DimSplitNode() {}
const char* type_key() const final {
return "DimSplit";
}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("var", &var);
v->Visit("factor", &factor);
}
static Split make(Var var, Expr factor);
static constexpr const char* _type_key = "DimSplit";
TVM_DECLARE_NODE_TYPE_INFO(DimSplitNode);
};
// Implementations of inline functions
......
......@@ -104,9 +104,7 @@ class TensorNode : public FunctionBaseNode {
int value_index{0};
/*! \brief constructor */
TensorNode() {}
const char* type_key() const final {
return "Tensor";
}
void VisitAttrs(AttrVisitor* v) final {
v->Visit("shape", &shape);
v->Visit("name", &name);
......@@ -125,6 +123,9 @@ class TensorNode : public FunctionBaseNode {
Type dtype,
Operation op,
int value_index);
static constexpr const char* _type_key = "Tensor";
TVM_DECLARE_NODE_TYPE_INFO(TensorNode);
};
/*!
......
......@@ -9,5 +9,6 @@
#include "./base.h"
#include "./expr.h"
#include "./tensor.h"
#include "./operation.h"
#endif // TVM_TVM_H_
......@@ -26,9 +26,9 @@ TVM_REGISTER_API(_format_str)
CHECK(args.at(0).type_id == kNodeHandle);
std::ostringstream os;
auto& sptr = args.at(0).sptr;
if (sptr->is_type<TensorNode>()) {
if (dynamic_cast<const TensorNode*>(sptr.get())) {
os << args.at(0).operator Tensor();
} else if (sptr->is_type<RDomainNode>()) {
} else if (dynamic_cast<const RDomainNode*>(sptr.get())) {
os << args.at(0).operator RDomain();
} else if (dynamic_cast<const BaseExprNode*>(sptr.get())) {
os << args.at(0).operator Expr();
......
......@@ -22,7 +22,7 @@ namespace {
using namespace Halide::Internal;
// const expr
inline Expr ReturnSelfExpr(const IRNodeRef&, const Expr& e, IRMutator*) {
inline Expr ReturnSelfExpr(const NodeRef&, const Expr& e, IRMutator*) {
return e;
}
......
......@@ -12,9 +12,9 @@ namespace {
// visitor to implement apply
class IRApplyVisit : public IRVisitor {
public:
explicit IRApplyVisit(std::function<void(const IRNodeRef&)> f) : f_(f) {}
explicit IRApplyVisit(std::function<void(const NodeRef&)> f) : f_(f) {}
void Visit(const IRNodeRef& node) final {
void Visit(const NodeRef& node) final {
if (visited_.count(node.get()) != 0) return;
visited_.insert(node.get());
IRVisitor::Visit(node);
......@@ -22,13 +22,13 @@ class IRApplyVisit : public IRVisitor {
}
private:
std::function<void(const IRNodeRef&)> f_;
std::function<void(const NodeRef&)> f_;
std::unordered_set<const Node*> visited_;
};
} // namespace
void PostOrderVisit(const IRNodeRef& node, std::function<void(const IRNodeRef&)> fvisit) {
void PostOrderVisit(const NodeRef& node, std::function<void(const NodeRef&)> fvisit) {
IRApplyVisit(fvisit).Visit(node);
}
......@@ -42,7 +42,7 @@ namespace {
using namespace Halide::Internal;
void NoOp(const IRNodeRef& n, IRVisitor* v) {
void NoOp(const NodeRef& n, IRVisitor* v) {
}
inline void VisitArray(Array<Expr> arr, IRVisitor* v) {
......
......@@ -5,21 +5,37 @@
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_pass.h>
#include "./scope.h"
namespace tvm {
namespace ir {
namespace {
Stmt MakeCompute(const ComputeOpNode* op, const Array<Split>& splits) {
Tensor output;
std::vector<Expr> args(op->dim_var.size());
for (size_t i = 0; i < args.size(); ++i) {
args[i] = op->dim_var[i];
/*!
* \brief make nest loops given list of stmt, whose body is not defined.
* \param nest A list of For and LetStmt, whose body is not defined.
* \param body The inner-most body of the loop
*/
Stmt MakeLoop(std::vector<Stmt>&& nest, Stmt body) {
while (!nest.empty()) {
Stmt s = std::move(nest.back()); nest.pop_back();
if (s.as<For>()) {
auto n = std::make_shared<For>(*s.as<For>());
n->body = body;
body = Stmt(n);
} else if (s.as<LetStmt>()) {
auto n = std::make_shared<LetStmt>(*s.as<LetStmt>());
n->body = body;
body = Stmt(n);
} else if (s.as<AttrStmt>()) {
auto n = std::make_shared<AttrStmt>(*s.as<AttrStmt>());
n->body = body;
body = Stmt(n);
} else {
LOG(FATAL) << "not supported nest type";
}
Array<Expr> values{op->body};
Stmt stmt = Provide::make(output, values, args);
// add splits from ousside most to outsidemost to innermost
return stmt;
}
return body;
}
......
/*!
* Copyright (c) 2016 by Contributors
* \file scope.h
* \brief attribute scope data structure,
* defines attributes on current domain
*/
#ifndef TVM_PASS_SCOPE_H_
#define TVM_PASS_SCOPE_H_
#include <tvm/ir.h>
#include <unordered_map>
#include <vector>
#include <string>
namespace tvm {
namespace ir {
/*!
* \brief Attribute scope of Nodes in the IR.
* \tparam ValueType The value of of the scope.
*/
template<typename K, typename V>
class Scope {
public:
/*!
* \brief Push value to scope
* \param key the key to be pushed.
* \param v The value to be pushed.
*/
inline void Push(const K& key, V v) {
data_[key].emplace_back(v);
}
/*!
* \brief Pop value from scope.
* \param key the key to be poped
*/
inline void Pop(const K& key) {
auto& v = data_[key];
CHECK_NE(v.size(), 0);
v.pop_back();
}
/*!
* \brief Get value from the scope
* \param key the key to fetch.
* \return The value to be fetched.
*/
inline V operator[](const K& key) const {
const auto it = data_.find(key);
CHECK(it != data_.end() && it->second.size() != 0)
<< "cannot find value in scope";
return it->second.back();
}
private:
std::unordered_map<K, std::vector<V> > data_;
};
/*! \brief Attribute key for specific attribute */
struct AttrKey {
/*! \brief The node of the attribute */
NodeRef node;
/*! \brief The type key of the attribute. */
std::string type_key;
// overload operator ==
inline bool operator==(const AttrKey& other) const {
return node == other.node && type_key == other.type_key;
}
};
} // namespace ir
} // namespace tvm
namespace std {
template <>
struct hash<::tvm::ir::AttrKey> {
std::size_t operator()(const ::tvm::ir::AttrKey& k) const {
size_t lhs = k.node.hash();
size_t rhs = std::hash<std::string>()(k.type_key);
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
}
};
} // namespace std
#endif // TVM_PASS_SCOPE_H_
......@@ -17,7 +17,7 @@ namespace {
// global functor to get var definition from
struct FGetVarDef {
using FType = IRFunctor<VarExpr (const IRNodeRef&)>;
using FType = IRFunctor<VarExpr (const NodeRef&)>;
static FType& vtable() { // NOLINT(*)
static FType inst; return inst;
}
......@@ -37,8 +37,8 @@ TVM_STATIC_IR_FUNCTOR(FGetVarDef, vtable)
});
struct FSetVarDef {
using FTypeExpr = IRFunctor<Expr (const IRNodeRef&, VarExpr)>;
using FTypeStmt = IRFunctor<Stmt (const IRNodeRef&, VarExpr)>;
using FTypeExpr = IRFunctor<Expr (const NodeRef&, VarExpr)>;
using FTypeStmt = IRFunctor<Stmt (const NodeRef&, VarExpr)>;
static FTypeExpr& vtable_expr() { // NOLINT(*)
static FTypeExpr inst; return inst;
}
......@@ -69,7 +69,7 @@ class IRVerifySSA : public IRVisitor {
public:
bool is_ssa{true};
void Visit(const IRNodeRef& n) final {
void Visit(const NodeRef& n) final {
if (!is_ssa) return;
static auto& fget_var_def = FGetVarDef::vtable();
if (fget_var_def.can_dispatch(n)) {
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/ir_node.h>
#include <tvm/ir_functor.h>
TEST(IRF, Basic) {
using namespace Halide::Internal;
......@@ -9,7 +9,7 @@ TEST(IRF, Basic) {
Var x("x");
auto z = x + 1;
IRFunctor<int(const IRNodeRef& n, int b)> f;
IRFunctor<int(const NodeRef& n, int b)> f;
LOG(INFO) << "x";
f.set_dispatch<Variable>([](const Variable* n, int b) {
return b;
......
......@@ -11,7 +11,7 @@ TEST(IRVisitor, CountVar) {
Var x("x"), y;
auto z = x + 1 + y + y;
ir::PostOrderVisit(z, [&n_var](const IRNodeRef& n) {
ir::PostOrderVisit(z, [&n_var](const NodeRef& n) {
if (n.as<Variable>()) ++n_var;
});
CHECK_EQ(n_var, 2);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment