ir.cc 4.15 KB
Newer Older
1 2
/*!
 *  Copyright (c) 2016 by Contributors
 * \file ir.cc
 */
#include <tvm/base.h>
#include <tvm/expr.h>
tqchen committed
7
#include <tvm/ir.h>
ziheng committed
8
#include <tvm/ir_pass.h>
9
#include <ir/IR.h>
tqchen committed
10
#include <ir/IRPrinter.h>
11
#include <memory>
12
#include "../pass/ir_util.h"
13

14
namespace HalideIR {
tqchen committed
15 16
namespace Internal {

ziheng committed
17
using tvm::ir::CommReducerNode;
tqchen committed
18
using tvm::ir::Reduce;
tqchen committed
19
using tvm::ir::AttrStmt;
tqchen committed
20

tqchen committed
21
template<>
22
void ExprNode<Reduce>::accept(IRVisitor *v, const Expr&) const {
tqchen committed
23 24 25
  LOG(FATAL) << "Reduce do not work with old Visitor, use IRFunctor style visitor";
}

tqchen committed
26 27
// Pretty-printer for Reduce: prints every field as
// reduce(combiner=..., source=..., axis=..., where=..., value_index=...).
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<Reduce>([](const Reduce *node, IRPrinter *printer) {
  auto& os = printer->stream;
  os << "reduce(combiner=" << node->combiner
     << ", source=" << node->source
     << ", axis=" << node->axis
     << ", where=" << node->condition
     << ", value_index=" << node->value_index
     << ")";
});

ziheng committed
37 38
// Pretty-printer for CommReducerNode: prints
// comm_reducer(result=..., lhs=..., rhs=..., identity_element=...).
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<CommReducerNode>([](const CommReducerNode *node, IRPrinter *printer) {
  auto& os = printer->stream;
  os << "comm_reducer(result=" << node->result;
  os << ", lhs=" << node->lhs;
  os << ", rhs=" << node->rhs;
  os << ", identity_element=" << node->identity_element;
  os << ")";
});
tqchen committed
45
}  // namespace Internal
46
}  // namespace HalideIR
tqchen committed
47

48
namespace tvm {
tqchen committed
49
namespace ir {
50

51 52 53 54
/*!
 * \brief Construct a commutative reducer.
 * \param lhs Placeholder variables standing for the left operand tuple.
 * \param rhs Placeholder variables standing for the right operand tuple.
 * \param result Combiner expressions written in terms of lhs/rhs; the
 *        call operator substitutes concrete operands into these.
 * \param identity_element Identity value for each reduced component.
 * \return The newly created CommReducer.
 *
 *  The call operator pairs lhs[i]/rhs[i] with the i-th operands, so the
 *  arities are validated here to surface malformed reducers at construction
 *  time rather than at first use.
 */
CommReducer CommReducerNode::make(Array<Var> lhs,
                                  Array<Var> rhs,
                                  Array<Expr> result,
                                  Array<Expr> identity_element) {
  CHECK_EQ(lhs.size(), rhs.size())
      << "CommReducer: lhs and rhs must have the same number of variables";
  CHECK_EQ(result.size(), lhs.size())
      << "CommReducer: number of results must match number of lhs variables";
  CHECK_EQ(identity_element.size(), lhs.size())
      << "CommReducer: number of identity elements must match number of lhs variables";
  auto node = std::make_shared<CommReducerNode>();
  // Parameters are taken by value; move them into the node instead of copying.
  node->lhs = std::move(lhs);
  node->rhs = std::move(rhs);
  node->result = std::move(result);
  node->identity_element = std::move(identity_element);
  return CommReducer(node);
}

63 64 65 66
/*!
 * \brief Apply the combiner to two operand tuples.
 * \param a Left operand tuple, bound element-wise to lhs.
 * \param b Right operand tuple, bound element-wise to rhs.
 * \return The result expressions with lhs[i] -> a[i] and rhs[i] -> b[i]
 *         substituted.
 */
Array<Expr> CommReducerNode::operator()(Array<Expr> a, Array<Expr> b) const {
  // Operand tuples must agree with each other and with the placeholders.
  CHECK_EQ(a.size(), b.size());
  CHECK_EQ(lhs.size(), a.size());
  CHECK_EQ(rhs.size(), b.size());

  Map<Var, Expr> bindings;
  const size_t arity = a.size();
  for (size_t k = 0; k != arity; ++k) {
    bindings.Set(lhs[k], a[k]);
    bindings.Set(rhs[k], b[k]);
  }

  auto substitute = [&bindings](const Expr& expr) {
    return Substitute(expr, bindings);
  };
  return UpdateArray(result, substitute);
}

77 78
/*!
 * \brief Construct a Reduce expression.
 * \param combiner The commutative reducer used to combine values.
 * \param source The source expressions being reduced.
 * \param axis Reduction axes; each must be defined and have iter_type
 *        kCommReduce (i.e. created by reduce_axis).
 * \param condition Predicate guarding the reduction; an undefined
 *        condition is replaced by const_true().
 * \param value_index Index into source selecting which component this
 *        expression yields; its type becomes the Reduce's type.
 * \return The Reduce expression.
 */
Expr Reduce::make(CommReducer combiner, Array<Expr> source,
                  Array<IterVar> axis, Expr condition, int value_index) {
  CHECK(source.defined());
  // Validate before indexing source[value_index] below.
  CHECK_GE(value_index, 0) << "Reduce: value_index must be non-negative";
  CHECK_LT(static_cast<size_t>(value_index), source.size())
      << "Reduce: value_index out of range of source";
  for (size_t i = 0; i < axis.size(); ++i) {
    // Check definedness first: reading iter_type through an undefined
    // IterVar would dereference a null node.
    CHECK(axis[i].defined());
    CHECK_EQ(axis[i]->iter_type, kCommReduce)
        << "Can only take axis created by reduce_axis";
  }
  if (!condition.defined()) {
    condition = const_true();
  }
  auto n = std::make_shared<Reduce>();
  n->type = source[value_index].type();
  n->combiner = std::move(combiner);
  n->source = std::move(source);
  n->axis = std::move(axis);
  n->condition = std::move(condition);
  n->value_index = value_index;
  return Expr(n);
}

100
// Register CommReducerNode (constructed by CommReducerNode::make in this
// file) with the node-type registry via TVM_REGISTER_NODE_TYPE.
TVM_REGISTER_NODE_TYPE(CommReducerNode);
tqchen committed
101 102
// Register the TVM-specific IR nodes Reduce (constructed by Reduce::make in
// this file) and AttrStmt.
TVM_REGISTER_NODE_TYPE(Reduce);
TVM_REGISTER_NODE_TYPE(AttrStmt);
103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143

// Register the remaining IR node types (HalideIR-derived nodes, visible
// unqualified in tvm::ir -- presumably re-exported by tvm/ir.h; verify).
// NOTE(review): these run as static registrations; keep the order unchanged
// unless it is confirmed that registration order does not matter.
TVM_REGISTER_NODE_TYPE(FloatImm);
TVM_REGISTER_NODE_TYPE(IntImm);
TVM_REGISTER_NODE_TYPE(UIntImm);
TVM_REGISTER_NODE_TYPE(StringImm);
TVM_REGISTER_NODE_TYPE(Cast);
TVM_REGISTER_NODE_TYPE(Variable);
TVM_REGISTER_NODE_TYPE(Add);
TVM_REGISTER_NODE_TYPE(Sub);
TVM_REGISTER_NODE_TYPE(Mul);
TVM_REGISTER_NODE_TYPE(Div);
TVM_REGISTER_NODE_TYPE(Mod);
TVM_REGISTER_NODE_TYPE(Min);
TVM_REGISTER_NODE_TYPE(Max);
TVM_REGISTER_NODE_TYPE(EQ);
TVM_REGISTER_NODE_TYPE(NE);
TVM_REGISTER_NODE_TYPE(LT);
TVM_REGISTER_NODE_TYPE(LE);
TVM_REGISTER_NODE_TYPE(GT);
TVM_REGISTER_NODE_TYPE(GE);
TVM_REGISTER_NODE_TYPE(And);
TVM_REGISTER_NODE_TYPE(Or);
TVM_REGISTER_NODE_TYPE(Not);
TVM_REGISTER_NODE_TYPE(Select);
TVM_REGISTER_NODE_TYPE(Load);
TVM_REGISTER_NODE_TYPE(Ramp);
TVM_REGISTER_NODE_TYPE(Broadcast);
TVM_REGISTER_NODE_TYPE(Call);
TVM_REGISTER_NODE_TYPE(Let);
TVM_REGISTER_NODE_TYPE(LetStmt);
TVM_REGISTER_NODE_TYPE(AssertStmt);
TVM_REGISTER_NODE_TYPE(ProducerConsumer);
TVM_REGISTER_NODE_TYPE(For);
TVM_REGISTER_NODE_TYPE(Store);
TVM_REGISTER_NODE_TYPE(Provide);
TVM_REGISTER_NODE_TYPE(Allocate);
TVM_REGISTER_NODE_TYPE(Free);
TVM_REGISTER_NODE_TYPE(Realize);
TVM_REGISTER_NODE_TYPE(Block);
TVM_REGISTER_NODE_TYPE(IfThenElse);
TVM_REGISTER_NODE_TYPE(Evaluate);
144

tqchen committed
145
}  // namespace ir
146
}  // namespace tvm