expr.cc 6.42 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/ir/expr.cc
 * \brief The expression AST nodes for the common IR infra.
 */
#include <tvm/runtime/registry.h>
#include <tvm/ir/expr.h>
26
#include <tvm/ir/function.h>
27 28 29 30 31
// NOTE: reverse dependency on te/tir.
// These dependencies do not happen at the interface level,
// and are only used in a minimal number of cases, where they are clearly marked.
//
// Rationale: convert from tir::IterVar and te::Tensor to PrimExpr.
#include <tvm/te/tensor.h>
33
#include <tvm/tir/expr.h>
34 35 36

namespace tvm {

37 38 39 40 41 42 43 44 45
// Construct a PrimExpr from plain C++ values:
//   int32_t     -> IntImm of type int32
//   float       -> FloatImm of type float32
//   std::string -> tir StringImm
PrimExpr::PrimExpr(int32_t value) : PrimExpr(IntImm(DataType::Int(32), value)) {}

PrimExpr::PrimExpr(float value) : PrimExpr(FloatImm(DataType::Float(32), value)) {}

PrimExpr::PrimExpr(std::string text) : PrimExpr(tir::StringImmNode::make(text)) {}

/*!
 * \brief Convert a generic object pointer into a PrimExpr.
 *
 * Besides genuine PrimExpr nodes, two reverse-dependency cases are accepted:
 * tir::IterVar (converted to its `var`) and te::Tensor (converted by calling
 * the tensor with no arguments). Any other object fails a CHECK.
 *
 * \param ptr The object to convert.
 * \return The converted expression.
 */
PrimExpr PrimExpr::FromObject_(ObjectPtr<Object> ptr) {
  using runtime::ObjectTypeChecker;
  if (ptr->IsInstance<tir::IterVarNode>()) {
    // IterVar: use the variable it carries.
    return tir::IterVar(ptr)->var;
  } else if (ptr->IsInstance<te::TensorNode>()) {
    // Tensor: apply its call operator with an empty index list.
    return te::Tensor(ptr)();
  }
  // Everything else must already be a PrimExpr node.
  CHECK(ObjectTypeChecker<PrimExpr>::Check(ptr.get()))
      << "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
      << " but get " << ptr->GetTypeKey();
  return PrimExpr(ptr);
}

/*!
 * \brief Construct an integer immediate.
 * \param dtype The literal type; must be a scalar int or uint type.
 * \param value The literal value; must be non-negative when dtype is unsigned.
 */
IntImm::IntImm(DataType dtype, int64_t value) {
  CHECK(dtype.is_scalar())
      << "ValueError: IntImm can only take scalar.";
  // Separate message from the scalar check so a wrong type class (e.g. float)
  // is reported accurately instead of as a lane-count problem.
  CHECK(dtype.is_int() || dtype.is_uint())
      << "ValueError: IntImm supports only int or uint type.";
  if (dtype.is_uint()) {
    CHECK_GE(value, 0U)
        << "ValueError: Literal value " << value
        << " is negative for unsigned integer type " << dtype;
  }
  ObjectPtr<IntImmNode> node = make_object<IntImmNode>();
  node->dtype = dtype;
  node->value = value;
  data_ = std::move(node);
}

// Expose the IntImm constructor as the packed function "ir.IntImm".
TVM_REGISTER_GLOBAL("ir.IntImm")
.set_body_typed([](DataType type, int64_t literal) -> IntImm {
  return IntImm(type, literal);
});

TVM_REGISTER_NODE_TYPE(IntImmNode);

// IntImm repr: int32 values print bare; any other dtype is printed as a
// parenthesized prefix, e.g. "(int64)7".
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<IntImmNode>([](const ObjectRef& node, ReprPrinter* p) {
    const auto* imm = static_cast<const IntImmNode*>(node.get());
    if (!(imm->dtype == DataType::Int(32))) {
      p->stream << "(" << imm->dtype << ")";
    }
    p->stream << imm->value;
  });

/*!
 * \brief Construct a floating-point immediate.
 * \param dtype The literal type; must have exactly one lane.
 * \param value The literal value.
 */
FloatImm::FloatImm(DataType dtype, double value) {
  CHECK_EQ(dtype.lanes(), 1)
      << "ValueError: FloatImm can only take scalar.";
  auto imm = make_object<FloatImmNode>();
  imm->value = value;
  imm->dtype = dtype;
  data_ = std::move(imm);
}

// Expose the FloatImm constructor as the packed function "ir.FloatImm".
TVM_REGISTER_GLOBAL("ir.FloatImm")
.set_body_typed([](DataType type, double literal) -> FloatImm {
  return FloatImm(type, literal);
});

TVM_REGISTER_NODE_TYPE(FloatImmNode);

// FloatImm repr: the value followed by a width suffix
// (none for 64-bit, 'f' for 32-bit, 'h' for 16-bit); other widths are fatal.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<FloatImmNode>([](const ObjectRef& node, ReprPrinter* p) {
    const auto* imm = static_cast<const FloatImmNode*>(node.get());
    auto& os = p->stream;
    const int bits = imm->dtype.bits();
    if (bits == 64) {
      os << imm->value;
    } else if (bits == 32) {
      os << imm->value << 'f';
    } else if (bits == 16) {
      os << imm->value << 'h';
    } else {
      LOG(FATAL) << "Unknown float type bits=" << bits;
    }
  });

// Build a range from [begin, end) as (min = begin, extent = end - begin).
// When begin is zero (per tir::is_zero), `end` itself is used as the extent
// rather than constructing the subtraction.
Range::Range(PrimExpr begin, PrimExpr end)
    : Range(make_object<RangeNode>(begin, tir::is_zero(begin) ? end : end - begin)) {}

// Build a range directly from its min and extent.
Range Range::make_by_min_extent(PrimExpr min, PrimExpr extent) {
  auto node = make_object<RangeNode>(min, extent);
  return Range(std::move(node));
}

// Packed-function entry points for the two ways of building a Range.
TVM_REGISTER_GLOBAL("ir.range_by_min_extent")
.set_body_typed(Range::make_by_min_extent);

TVM_REGISTER_GLOBAL("ir.Range")
.set_body([](TVMArgs args, TVMRetValue* rv) {
  // Arguments are (begin, end); both convert implicitly to PrimExpr.
  *rv = Range(args[0], args[1]);
});

// Range repr: "range(min=<min>, ext=<extent>)".
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<RangeNode>([](const ObjectRef& node, ReprPrinter* p) {
    const auto* range = static_cast<const RangeNode*>(node.get());
    auto& os = p->stream;
    os << "range(min=" << range->min;
    os << ", ext=" << range->extent << ')';
  });

// Node-type registrations for the containers printed below and for RangeNode.
TVM_REGISTER_NODE_TYPE(ArrayNode);
TVM_REGISTER_NODE_TYPE(MapNode);
TVM_REGISTER_NODE_TYPE(StrMapNode);
TVM_REGISTER_NODE_TYPE(RangeNode);

/*!
 * \brief Construct a GlobalVar.
 * \param name_hint The name stored on the node (moved in).
 */
GlobalVar::GlobalVar(std::string name_hint) {
  auto gvar = make_object<GlobalVarNode>();
  gvar->name_hint = std::move(name_hint);
  data_ = std::move(gvar);
}

TVM_REGISTER_NODE_TYPE(GlobalVarNode);

// Expose the GlobalVar constructor as the packed function "ir.GlobalVar".
TVM_REGISTER_GLOBAL("ir.GlobalVar")
.set_body_typed([](std::string name) -> GlobalVar {
  return GlobalVar(name);
});

// GlobalVar repr: "GlobalVar(<name_hint>)".
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<GlobalVarNode>([](const ObjectRef& ref, ReprPrinter* p) {
    const auto* gvar = static_cast<const GlobalVarNode*>(ref.get());
    p->stream << "GlobalVar(" << gvar->name_hint << ")";
  });

// Container printer
178 179
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ArrayNode>([](const ObjectRef& node, ReprPrinter* p) {
180 181 182 183 184 185 186 187 188 189 190
    auto* op = static_cast<const ArrayNode*>(node.get());
    p->stream << '[';
    for (size_t i = 0 ; i < op->data.size(); ++i) {
      if (i != 0) {
        p->stream << ", ";
      }
      p->Print(op->data[i]);
    }
    p->stream << ']';
});

191 192
// Map repr: "{k1: v1, k2: v2}" in the map's own iteration order; keys and
// values are both printed through the ReprPrinter.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<MapNode>([](const ObjectRef& node, ReprPrinter* p) {
    const auto* mnode = static_cast<const MapNode*>(node.get());
    bool first = true;
    p->stream << '{';
    for (const auto& kv : mnode->data) {
      if (!first) {
        p->stream << ", ";
      }
      first = false;
      p->Print(kv.first);
      p->stream << ": ";
      p->Print(kv.second);
    }
    p->stream << '}';
  });

// StrMap repr: like Map, but keys are raw strings wrapped in double quotes.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<StrMapNode>([](const ObjectRef& node, ReprPrinter* p) {
    const auto* smap = static_cast<const StrMapNode*>(node.get());
    bool first = true;
    p->stream << '{';
    for (const auto& kv : smap->data) {
      if (!first) {
        p->stream << ", ";
      }
      first = false;
      p->stream << '\"' << kv.first << "\": ";
      p->Print(kv.second);
    }
    p->stream << '}';
  });
}  // namespace tvm