/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file attrs.cc
 * \brief DictAttrs implementation and the equality/hash handlers used to compare attribute values.
 */
#include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h>

#include <algorithm>
#include <string>
#include <utility>
#include <vector>

#include "attr_functor.h"

namespace tvm {

/*!
 * \brief Report the attribute dictionary to a visitor.
 *
 * The whole dictionary is exposed as a single field named "__dict__".
 * \param v The visitor receiving the field.
 */
void DictAttrsNode::VisitAttrs(AttrVisitor* v) {
  v->Visit("__dict__", &this->dict);
}

/*!
 * \brief Report the non-default attributes to a visitor.
 *
 * A dictionary has no notion of defaults, so this reports exactly the same
 * "__dict__" field as VisitAttrs.
 * \param v The visitor receiving the field.
 */
void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) {
  v->Visit("__dict__", &this->dict);
}

/*!
 * \brief Populate the dictionary from packed (key, value) argument pairs.
 *
 * Arguments are consumed two at a time: args[i] is converted to the key
 * string and args[i + 1] supplies the value. Object references are stored
 * as-is. String values are wrapped with PrimExpr's string constructor. Any
 * other value is converted to a PrimExpr.
 *
 * \param args Packed arguments laid out as key0, value0, key1, value1, ...
 * \param allow_unknown Not consulted by this implementation; every key is accepted.
 */
void DictAttrsNode::InitByPackedArgs(
    const runtime::TVMArgs& args, bool allow_unknown) {
  for (int idx = 0; idx < args.size(); idx += 2) {
    std::string name = args[idx];
    runtime::TVMArgValue arg = args[idx + 1];
    if (arg.IsObjectRef<ObjectRef>()) {
      dict.Set(name, arg.operator ObjectRef());
      continue;
    }
    if (arg.type_code() == kTVMStr) {
      // Raw strings are turned into an expression before being stored.
      dict.Set(name, PrimExpr(arg.operator std::string()));
      continue;
    }
    dict.Set(name, arg.operator PrimExpr());
  }
}

/*!
 * \brief List the statically declared fields of this attribute node.
 * \return Always an empty array: the entries of a DictAttrsNode are held in
 *         `dict` rather than in declared fields.
 */
Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const {
  Array<AttrFieldInfo> no_fields;
  return no_fields;
}

/*!
 * \brief Create a DictAttrs node that owns the given dictionary.
 * \param dict The entries to store; taken by value and moved into the node.
 * \return An Attrs reference to the new node.
 */
Attrs DictAttrsNode::make(Map<std::string, ObjectRef> dict) {
  auto node = make_object<DictAttrsNode>();
  node->dict = std::move(dict);
  return Attrs(node);
}

// Print a DictAttrsNode by streaming its underlying dictionary.
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<DictAttrsNode>([](const ObjectRef& node, ReprPrinter* p) {
    const auto* attrs = static_cast<const DictAttrsNode*>(node.get());
    p->stream << attrs->dict;
  });

TVM_REGISTER_NODE_TYPE(DictAttrsNode);

70 71
TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode);

72

73
using namespace tir;
74
// Equal handler.
75
bool AttrsEqualHandler::Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
76 77 78 79
  if (lhs.same_as(rhs)) return true;
  if (!lhs.defined() || !rhs.defined()) return false;
  return this->VisitAttr(lhs, rhs);
}
80

81 82
bool AttrsEqualHandler::VisitAttrDefault_(const Object* lhs, const ObjectRef& other) {
  if (lhs->IsInstance<BaseAttrsNode>()) {
83 84 85 86
    AttrsEqual equal;
    equal.handler_ = this;
    return static_cast<const BaseAttrsNode*>(lhs)->ContentEqual(
        other.get(), equal);
87
  }
88 89
  return lhs == other.get();
}
90

91 92
bool AttrsEqualHandler::VisitAttr_(const IntImmNode* lhs, const ObjectRef& other) {
  if (const auto* rhs = other.as<IntImmNode>()) {
93
    return lhs->value == rhs->value;
94
  }
95 96
  return false;
}
97

98 99
bool AttrsEqualHandler::VisitAttr_(const FloatImmNode* lhs, const ObjectRef& other) {
  if (const auto* rhs = other.as<FloatImmNode>()) {
100
    return lhs->value == rhs->value;
101
  }
102 103
  return false;
}
104

105 106
bool AttrsEqualHandler::VisitAttr_(const StringImmNode* lhs, const ObjectRef& other) {
  if (const auto* rhs = other.as<StringImmNode>()) {
107
    return lhs->value == rhs->value;
108
  }
109 110
  return false;
}
111

112
bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const ObjectRef& other) {
113 114
  if (const auto* rhs = other.as<ArrayNode>()) {
    if (rhs->data.size() != lhs->data.size()) return false;
115
    for (size_t i = 0; i < lhs->data.size(); ++i) {
116
      if (!Equal(lhs->data[i], rhs->data[i])) return false;
117 118
    }
  }
119 120
  return true;
}
121

122
bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const ObjectRef& other) {
123 124 125 126 127
  if (const auto* rhs = other.as<StrMapNode>()) {
    if (rhs->data.size() != lhs->data.size()) return false;
    for (const auto& kv : lhs->data) {
      auto it = rhs->data.find(kv.first);
      if (it == rhs->data.end()) return false;
128
      if (!Equal(kv.second, it->second)) return false;
129 130
    }
  }
131 132
  return true;
}
133

134
#define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName)                          \
135
  bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const ObjectRef& other) { \
136 137 138 139 140 141 142 143 144
    if (const auto* rhs = other.as<NodeName>()) {                       \
      if (!Equal(lhs->a, rhs->a)) return false;                         \
      if (!Equal(lhs->b, rhs->b)) return false;                         \
      return true;                                                      \
    } else {                                                            \
      return false;                                                     \
    }                                                                   \
  }                                                                     \

145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164
TVM_DEFINE_ATTRS_BINOP_EQUAL(AddNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(SubNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(MulNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(DivNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(ModNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorDivNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(FloorModNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(MaxNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(MinNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GENode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GTNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LENode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LTNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(EQNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(NENode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(AndNode);
TVM_DEFINE_ATTRS_BINOP_EQUAL(OrNode);

bool AttrsEqualHandler::VisitAttr_(const NotNode* lhs, const ObjectRef& other) {
  if (const auto* rhs = other.as<NotNode>()) {
165 166 167
    return Equal(lhs->a, rhs->a);
  } else {
    return false;
168
  }
169
}
170

171 172
bool AttrsEqualHandler::VisitAttr_(const CastNode* lhs, const ObjectRef& other) {
  if (const auto* rhs = other.as<CastNode>()) {
173
    if (lhs->dtype != rhs->dtype) return false;
174 175 176
    return Equal(lhs->value, rhs->value);
  } else {
    return false;
177
  }
178
}
179

180 181
bool AttrsEqualHandler::VisitAttr_(const CallNode* lhs, const ObjectRef& other) {
  if (const auto* rhs = other.as<CallNode>()) {
182 183
    return
        lhs->name == rhs->name &&
184
        lhs->dtype == rhs->dtype &&
185 186 187 188
        lhs->call_type == rhs->call_type &&
        Equal(lhs->args, rhs->args);
  } else {
    return false;
189
  }
190
}
191

192 193
bool AttrsEqualHandler::VisitAttr_(const SelectNode* lhs, const ObjectRef& other) {
  if (const auto* rhs = other.as<SelectNode>()) {
194 195 196 197 198 199
    return
        Equal(lhs->condition, rhs->condition) &&
        Equal(lhs->true_value, rhs->true_value) &&
        Equal(lhs->false_value, rhs->false_value);
  } else {
    return false;
200
  }
201
}
202

203
// Hash Handler.
204 205
size_t AttrsHashHandler::VisitAttrDefault_(const Object* value) {
  if (value->IsInstance<BaseAttrsNode>()) {
206 207 208 209
    AttrsHash hasher;
    hasher.handler_ = this;
    return static_cast<const BaseAttrsNode*>(value)->ContentHash(hasher);
  } else {
210
    return ObjectHash()(GetRef<ObjectRef>(value));
211
  }
212
}
213

214
size_t AttrsHashHandler::VisitAttr_(const IntImmNode* op) {
215 216
  return std::hash<int64_t>()(op->value);
}
217

218
size_t AttrsHashHandler::VisitAttr_(const FloatImmNode* op) {
219 220 221
  return std::hash<double>()(op->value);
}

222
size_t AttrsHashHandler::VisitAttr_(const StringImmNode* op) {
223 224 225 226 227 228
  return std::hash<std::string>()(op->value);
}

size_t AttrsHashHandler::VisitAttr_(const ArrayNode* op) {
  size_t result = op->data.size();
  for (size_t  i = 0; i < op->data.size(); ++i) {
229
    result = Combine(result, this->Hash(op->data[i]));
230
  }
231 232
  return result;
}
233

234
size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) {
235
    using Entry = std::pair<std::string, ObjectRef>;
236 237 238 239
    std::vector<Entry> data(lhs->data.begin(), lhs->data.end());
    std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) {
        return a.first < b.first;
      });
240
    size_t result = 0;
241
    for (const Entry& kv : data) {
242
      result = Combine(result, std::hash<std::string>()(kv.first));
243
      result = Combine(result, this->Hash(kv.second));
244
    }
245 246
    return result;
}
247 248


249 250 251 252 253 254
#define TVM_DEFINE_ATTRS_BINOP_HASH(NodeName)                           \
  size_t AttrsHashHandler::VisitAttr_(const NodeName* op) {             \
    static size_t key = std::hash<std::string>()(NodeName::_type_key);  \
    return Combine(key, Combine(Hash(op->a), Hash(op->b)));             \
  }                                                                     \

255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274
TVM_DEFINE_ATTRS_BINOP_HASH(AddNode);
TVM_DEFINE_ATTRS_BINOP_HASH(SubNode);
TVM_DEFINE_ATTRS_BINOP_HASH(MulNode);
TVM_DEFINE_ATTRS_BINOP_HASH(DivNode);
TVM_DEFINE_ATTRS_BINOP_HASH(ModNode);
TVM_DEFINE_ATTRS_BINOP_HASH(FloorDivNode);
TVM_DEFINE_ATTRS_BINOP_HASH(FloorModNode);
TVM_DEFINE_ATTRS_BINOP_HASH(MaxNode);
TVM_DEFINE_ATTRS_BINOP_HASH(MinNode);
TVM_DEFINE_ATTRS_BINOP_HASH(GENode);
TVM_DEFINE_ATTRS_BINOP_HASH(GTNode);
TVM_DEFINE_ATTRS_BINOP_HASH(LENode);
TVM_DEFINE_ATTRS_BINOP_HASH(LTNode);
TVM_DEFINE_ATTRS_BINOP_HASH(EQNode);
TVM_DEFINE_ATTRS_BINOP_HASH(NENode);
TVM_DEFINE_ATTRS_BINOP_HASH(AndNode);
TVM_DEFINE_ATTRS_BINOP_HASH(OrNode);

size_t AttrsHashHandler::VisitAttr_(const NotNode* op) {
  static size_t key = std::hash<std::string>()(NotNode::_type_key);
275 276 277
  return Combine(key, Hash(op->a));
}

278 279
size_t AttrsHashHandler::VisitAttr_(const CastNode* op) {
  static size_t key = std::hash<std::string>()(CastNode::_type_key);
280 281
  AttrsHash hasher;
  size_t res = key;
282
  res = Combine(res, hasher(op->dtype));
283 284 285 286
  res = Combine(res, Hash(op->value));
  return res;
}

287 288
size_t AttrsHashHandler::VisitAttr_(const CallNode* op) {
  static size_t key = std::hash<std::string>()(CallNode::_type_key);
289 290 291
  AttrsHash hasher;
  size_t res = key;
  res = Combine(res, hasher(op->name));
292
  res = Combine(res, hasher(op->dtype));
293 294 295 296
  res = Combine(res, Hash(op->args));
  return res;
}

297 298
// Hash a select: fold condition, true branch and false branch into the type key.
size_t AttrsHashHandler::VisitAttr_(const SelectNode* op) {
  static size_t key = std::hash<std::string>()(SelectNode::_type_key);
  const size_t cond_hash = Hash(op->condition);
  const size_t then_hash = Hash(op->true_value);
  const size_t else_hash = Hash(op->false_value);
  return Combine(Combine(Combine(key, cond_hash), then_hash), else_hash);
}

// Default case
308
bool AttrsEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
309
  if (lhs.same_as(rhs)) return true;
310 311 312 313 314
  if (handler_ == nullptr) {
    return AttrsEqualHandler().Equal(lhs, rhs);
  } else {
    return handler_->Equal(lhs, rhs);
  }
315 316
}

317
/*!
 * \brief Hash an ObjectRef as an attribute value.
 *
 * An undefined reference hashes to 0. Otherwise the attached handler_ is used
 * if set, or a fresh AttrsHashHandler if not.
 */
size_t AttrsHash::operator()(const ObjectRef& node) const {
  if (!node.defined()) return 0;
  if (handler_ != nullptr) return handler_->Hash(node);
  return AttrsHashHandler().Hash(node);
}

// The content hash of a DictAttrsNode is the hash of its dictionary.
size_t DictAttrsNode::ContentHash(AttrsHash hasher) const {
  const auto& entries = this->dict;
  return hasher(entries);
}

bool DictAttrsNode::ContentEqual(const Object* other, AttrsEqual equal) const {
331 332 333
  if (this == other) return true;
  if (other == nullptr) return false;
  if (this->type_index() != other->type_index()) return false;
334
  return equal(this->dict, static_cast<const DictAttrsNode*>(other)->dict);
335 336
}

337
// Packed function "ir.AttrsListFieldInfo": returns ListFieldInfo() of its Attrs argument.
TVM_REGISTER_GLOBAL("ir.AttrsListFieldInfo")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  Attrs attrs = args[0].operator Attrs();
  *ret = attrs->ListFieldInfo();
});

}  // namespace tvm