expr.cc 2.05 KB
Newer Older
tqchen committed
1 2 3 4 5 6 7
/*!
 *  Copyright (c) 2016 by Contributors
 * \file expr.cc
 */
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_operator.h>
#include <ir/IRPrinter.h>
#include <memory>
#include <string>
#include <utility>

namespace tvm {
13

14
using HalideIR::IR::RangeNode;
15

tqchen committed
16
/*!
 * \brief Build a range from a begin and an end expression.
 *
 * The node is stored as (min = begin, extent = end - begin).
 * \param begin The first value of the range (becomes `min`).
 * \param end The end bound used to derive the extent.
 */
Range::Range(Expr begin, Expr end)
    : Range(make_node<RangeNode>(
          begin,
          // When begin is a literal zero the extent is just `end`; this avoids
          // building the redundant expression `end - 0`.
          !is_zero(begin) ? (end - begin) : end)) {
}

22
/*!
 * \brief Build a range directly from its minimum and extent.
 * \param min The minimum value of the range.
 * \param extent The extent of the range, stored as given.
 * \return The constructed Range.
 */
Range Range::make_by_min_extent(Expr min, Expr extent) {
  NodePtr<RangeNode> node = make_node<RangeNode>(min, extent);
  return Range(node);
}

26 27
/*!
 * \brief Create an IterVar node from its components.
 * \param dom The domain assigned to the iteration variable.
 * \param var The loop variable.
 * \param t The iteration type.
 * \param thread_tag The thread tag string stored on the node.
 * \return An IterVar wrapping the newly created node.
 */
IterVar IterVarNode::make(Range dom, Var var,
                          IterVarType t, std::string thread_tag) {
  NodePtr<IterVarNode> n = make_node<IterVarNode>();
  // The parameters are taken by value and each is read exactly once, so move
  // them into the node. This avoids an extra reference-count increment and
  // decrement on the handle types and a full copy of the tag string.
  n->dom = std::move(dom);
  n->var = std::move(var);
  n->iter_type = t;
  n->thread_tag = std::move(thread_tag);
  return IterVar(std::move(n));
}

36 37 38 39 40 41 42 43 44 45
/*!
 * \brief Create a thread-index IterVar.
 *
 * The loop variable is named after the tag, and the same tag is stored as the
 * IterVar's thread tag.
 * \param dom The domain of the axis.
 * \param tag The thread tag, also used as the variable name.
 * \return IterVar of type kThreadIndex.
 */
IterVar thread_axis(Range dom, std::string tag) {
  // `tag` is used twice, so it must not be moved into either argument.
  Var thread_var(tag);
  return IterVarNode::make(dom, thread_var, kThreadIndex, tag);
}

/*!
 * \brief Create a commutative-reduction IterVar.
 *
 * No thread tag is passed, so the tag is whatever default IterVarNode::make
 * declares (presumably empty; confirm in the header).
 * \param dom The domain of the reduction axis.
 * \param name The name of the loop variable.
 * \return IterVar of type kCommReduce.
 */
IterVar reduce_axis(Range dom, std::string name) {
  Var reduce_var(name);
  return IterVarNode::make(dom, reduce_var, kCommReduce);
}

tqchen committed
46 47 48 49 50
std::ostream& operator<<(std::ostream& os, const NodeRef& n) {  // NOLINT(*)
  IRPrinter(os).print(n);
  return os;
}

51 52 53 54
/*!
 * \brief Debug helper: write a node to std::cerr followed by a newline.
 * \param n The node to print.
 */
void Dump(const NodeRef& n) {
  std::ostream& err = std::cerr;
  err << n << "\n";
}

55 56 57 58
/*!
 * \brief Convenience factory for a Var.
 * \param name_hint The name hint of the variable.
 * \param t The data type of the variable.
 * \return The constructed Var.
 */
Var var(const std::string& name_hint, Type t) {
  Var result(name_hint, t);
  return result;
}

tqchen committed
59 60 61 62 63 64
// Printer for IterVarNode. Output is "iter_var(" followed by:
//   - the variable name and ", " when the name is non-empty,
//   - the domain when it is defined,
//   - ", " and the thread tag when the tag is non-empty,
// and a closing ")".
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<IterVarNode>([](const IterVarNode *op, IRPrinter *p) {
    std::ostream& out = p->stream;
    out << "iter_var(";
    const std::string& name = op->var->name_hint;
    if (!name.empty()) {
      out << name << ", ";
    }
    if (op->dom.defined()) {
      out << op->dom;
    }
    if (!op->thread_tag.empty()) {
      out << ", " << op->thread_tag;
    }
    out << ")";
  });

// Printer for RangeNode: emits "range(min=<min>, ext=<extent>)".
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<RangeNode>([](const RangeNode *op, IRPrinter *p) {
    p->stream << "range(min=" << op->min
              << ", ext=" << op->extent
              << ')';
  });

79 80 81

// Register the container and IR node types with TVM_REGISTER_NODE_TYPE
// (presumably making them known to the node reflection system by type key;
// confirm against the macro definition). Keep the order unchanged, since
// static-registration order may matter.
TVM_REGISTER_NODE_TYPE(ArrayNode);
TVM_REGISTER_NODE_TYPE(MapNode);
TVM_REGISTER_NODE_TYPE(StrMapNode);
TVM_REGISTER_NODE_TYPE(RangeNode);
TVM_REGISTER_NODE_TYPE(IterVarNode);

}  // namespace tvm