/*!
 *  Copyright (c) 2018 by Contributors
 * \file attrs.cc
 * \brief DictAttrs implementation and the AttrsEqual / AttrsHash handlers.
 */
#include <tvm/attrs.h>

#include <algorithm>
#include <functional>
#include <string>
#include <utility>
#include <vector>

#include <tvm/api_registry.h>
#include "attr_functor.h"

namespace tvm {

void DictAttrsNode::VisitAttrs(AttrVisitor* v)  {
  v->Visit("__dict__", &dict);
}

15 16 17 18
void DictAttrsNode::VisitNonDefaultAttrs(AttrVisitor* v) {
  v->Visit("__dict__", &dict);
}

19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
/*!
 * \brief Fill the dictionary from packed (key, value) argument pairs.
 *
 * Arguments are consumed two at a time: args[idx] is read as the string key
 * and args[idx + 1] as its value. Node handles are stored unchanged, strings
 * are wrapped in an Expr, and any other value goes through TVMArgValue's
 * Expr conversion.
 *
 * \note allow_unknown is not read here; every key is accepted.
 * NOTE(review): an odd argument count makes args[idx + 1] out of range;
 * presumably rejected by TVMArgs::operator[] — confirm.
 */
void DictAttrsNode::InitByPackedArgs(
    const runtime::TVMArgs& args, bool allow_unknown) {
  for (int idx = 0; idx < args.size(); idx += 2) {
    const std::string key = args[idx];
    runtime::TVMArgValue value = args[idx + 1];
    switch (value.type_code()) {
      case kNodeHandle:
        dict.Set(key, value.operator NodeRef());
        break;
      case kStr:
        dict.Set(key, Expr(value.operator std::string()));
        break;
      default:
        dict.Set(key, value.operator Expr());
        break;
    }
  }
}

// DictAttrs has no statically declared fields, so the listing is empty.
Array<AttrFieldInfo> DictAttrsNode::ListFieldInfo() const {
  return Array<AttrFieldInfo>();
}

// Build a DictAttrs node wrapping `dict` (moved in, not copied).
Attrs DictAttrsNode::make(Map<std::string, NodeRef> dict) {
  auto node = make_node<DictAttrsNode>();
  node->dict = std::move(dict);
  return Attrs(node);
}

// Printing a DictAttrs node just streams its underlying dictionary.
TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
.set_dispatch<DictAttrsNode>([](const DictAttrsNode* node, IRPrinter* printer) {
    printer->stream << node->dict;
  });

TVM_REGISTER_NODE_TYPE(DictAttrsNode);

51 52
TVM_REGISTER_NODE_TYPE(AttrFieldInfoNode);

53 54

using namespace ir;
55 56 57 58 59 60
// Equal handler.
// Identical refs (including two undefined refs) are equal; if exactly one
// side is undefined they are unequal; otherwise hand off to VisitAttr, which
// selects the matching VisitAttr_ overload below.
bool AttrsEqualHandler::Equal(const NodeRef& lhs, const NodeRef& rhs) {
  if (lhs.same_as(rhs)) return true;
  return lhs.defined() && rhs.defined() && this->VisitAttr(lhs, rhs);
}

// Fallback for node kinds without a dedicated VisitAttr_ overload.
// Attribute nodes are compared structurally via ContentEqual, sharing this
// handler for nested values; any other node is equal only by identity.
bool AttrsEqualHandler::VisitAttrDefault_(const Node* lhs, const NodeRef& other) {
  if (!lhs->derived_from<BaseAttrsNode>()) {
    return lhs == other.get();
  }
  AttrsEqual equal;
  equal.handler_ = this;
  const auto* attrs = static_cast<const BaseAttrsNode*>(lhs);
  return attrs->ContentEqual(other.get(), equal);
}

bool AttrsEqualHandler::VisitAttr_(const IntImm* lhs, const NodeRef& other) {
  if (const auto* rhs = other.as<IntImm>()) {
    return lhs->value == rhs->value;
75
  }
76 77
  return false;
}
78

79 80 81
bool AttrsEqualHandler::VisitAttr_(const UIntImm* lhs, const NodeRef& other) {
  if (const auto* rhs = other.as<UIntImm>()) {
    return lhs->value == rhs->value;
82
  }
83 84
  return false;
}
85

86 87 88
bool AttrsEqualHandler::VisitAttr_(const FloatImm* lhs, const NodeRef& other) {
  if (const auto* rhs = other.as<FloatImm>()) {
    return lhs->value == rhs->value;
89
  }
90 91
  return false;
}
92

93 94 95
bool AttrsEqualHandler::VisitAttr_(const StringImm* lhs, const NodeRef& other) {
  if (const auto* rhs = other.as<StringImm>()) {
    return lhs->value == rhs->value;
96
  }
97 98
  return false;
}
99

100 101 102 103 104
/*!
 * \brief Element-wise comparison of two arrays.
 *
 * Returns true only when \p other is also an array of the same length whose
 * elements are pairwise Equal. A node of any other kind compares unequal;
 * falling through to `true` would make an array equal to e.g. an integer
 * constant, and would make Equal asymmetric (the constant visitors return
 * false on a kind mismatch).
 */
bool AttrsEqualHandler::VisitAttr_(const ArrayNode* lhs, const NodeRef& other) {
  const auto* rhs = other.as<ArrayNode>();
  if (rhs == nullptr) return false;
  if (rhs->data.size() != lhs->data.size()) return false;
  for (size_t i = 0; i < lhs->data.size(); ++i) {
    if (!Equal(NodeRef(lhs->data[i]), NodeRef(rhs->data[i]))) return false;
  }
  return true;
}

/*!
 * \brief Key-wise comparison of two string-keyed maps.
 *
 * Returns true only when \p other is also a string map with the same number
 * of entries, every key of \p lhs is present in it, and the values stored
 * under each key are Equal. A node of any other kind compares unequal;
 * falling through to `true` would make a map equal to unrelated nodes.
 */
bool AttrsEqualHandler::VisitAttr_(const StrMapNode* lhs, const NodeRef& other) {
  const auto* rhs = other.as<StrMapNode>();
  if (rhs == nullptr) return false;
  if (rhs->data.size() != lhs->data.size()) return false;
  for (const auto& kv : lhs->data) {
    auto it = rhs->data.find(kv.first);
    if (it == rhs->data.end()) return false;
    if (!Equal(NodeRef(kv.second), NodeRef(it->second))) return false;
  }
  return true;
}

// Binary expressions are equal when `other` is the same node kind and the
// `a` and `b` operands are pairwise Equal (checked left to right).
#define TVM_DEFINE_ATTRS_BINOP_EQUAL(NodeName)                                    \
  bool AttrsEqualHandler::VisitAttr_(const NodeName* lhs, const NodeRef& other) { \
    const auto* rhs = other.as<NodeName>();                                       \
    return rhs != nullptr && Equal(lhs->a, rhs->a) && Equal(lhs->b, rhs->b);      \
  }

TVM_DEFINE_ATTRS_BINOP_EQUAL(Add);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Sub);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mul);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Div);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Mod);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Max);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Min);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GE);
TVM_DEFINE_ATTRS_BINOP_EQUAL(GT);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LE);
TVM_DEFINE_ATTRS_BINOP_EQUAL(LT);
TVM_DEFINE_ATTRS_BINOP_EQUAL(EQ);
TVM_DEFINE_ATTRS_BINOP_EQUAL(NE);
TVM_DEFINE_ATTRS_BINOP_EQUAL(And);
TVM_DEFINE_ATTRS_BINOP_EQUAL(Or);

// Two Not nodes are equal when their single operands are Equal.
bool AttrsEqualHandler::VisitAttr_(const Not* lhs, const NodeRef& other) {
  const auto* rhs = other.as<Not>();
  if (rhs == nullptr) return false;
  return Equal(lhs->a, rhs->a);
}

// Casts are equal when they target the same type and cast Equal values.
bool AttrsEqualHandler::VisitAttr_(const Cast* lhs, const NodeRef& other) {
  const auto* rhs = other.as<Cast>();
  if (rhs == nullptr) return false;
  return lhs->type == rhs->type && Equal(lhs->value, rhs->value);
}

// Calls are equal when name, type and call_type all match and the argument
// lists are Equal.
bool AttrsEqualHandler::VisitAttr_(const Call* lhs, const NodeRef& other) {
  const auto* rhs = other.as<Call>();
  if (rhs == nullptr) return false;
  if (lhs->name != rhs->name) return false;
  if (lhs->type != rhs->type) return false;
  if (lhs->call_type != rhs->call_type) return false;
  return Equal(lhs->args, rhs->args);
}

// Selects are equal when the condition and both branch values are Equal.
bool AttrsEqualHandler::VisitAttr_(const Select* lhs, const NodeRef& other) {
  const auto* rhs = other.as<Select>();
  if (rhs == nullptr) return false;
  if (!Equal(lhs->condition, rhs->condition)) return false;
  if (!Equal(lhs->true_value, rhs->true_value)) return false;
  return Equal(lhs->false_value, rhs->false_value);
}

// Hash Handler.
// Fallback for node kinds without a dedicated VisitAttr_ overload.
// Attribute nodes are hashed structurally via ContentHash, sharing this
// handler for nested values; any other node falls back to NodeHash.
size_t AttrsHashHandler::VisitAttrDefault_(const Node* value) {
  if (!value->derived_from<BaseAttrsNode>()) {
    return NodeHash()(GetRef<NodeRef>(value));
  }
  AttrsHash hasher;
  hasher.handler_ = this;
  const auto* attrs = static_cast<const BaseAttrsNode*>(value);
  return attrs->ContentHash(hasher);
}

size_t AttrsHashHandler::VisitAttr_(const IntImm* op) {
  return std::hash<int64_t>()(op->value);
}
203

204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219
size_t AttrsHashHandler::VisitAttr_(const UIntImm* op) {
  return std::hash<uint64_t>()(op->value);
}

size_t AttrsHashHandler::VisitAttr_(const FloatImm* op) {
  return std::hash<double>()(op->value);
}

size_t AttrsHashHandler::VisitAttr_(const StringImm* op) {
  return std::hash<std::string>()(op->value);
}

// Arrays hash by seeding with the length and then folding in each element's
// hash in order.
size_t AttrsHashHandler::VisitAttr_(const ArrayNode* op) {
  size_t result = op->data.size();
  for (const auto& elem : op->data) {
    result = Combine(result, this->Hash(NodeRef(elem)));
  }
  return result;
}

size_t AttrsHashHandler::VisitAttr_(const StrMapNode* lhs) {
225 226 227 228 229
    using Entry = std::pair<std::string, NodePtr<Node> >;
    std::vector<Entry> data(lhs->data.begin(), lhs->data.end());
    std::sort(data.begin(), data.end(), [](const Entry& a, const Entry& b) {
        return a.first < b.first;
      });
230
    size_t result = 0;
231
    for (const Entry& kv : data) {
232 233
      result = Combine(result, std::hash<std::string>()(kv.first));
      result = Combine(result, this->Hash(NodeRef(kv.second)));
234
    }
235 236
    return result;
}
237 238


239 240 241 242 243 244 245 246 247
#define TVM_DEFINE_ATTRS_BINOP_HASH(NodeName)                           \
  size_t AttrsHashHandler::VisitAttr_(const NodeName* op) {             \
    static size_t key = std::hash<std::string>()(NodeName::_type_key);  \
    return Combine(key, Combine(Hash(op->a), Hash(op->b)));             \
  }                                                                     \

TVM_DEFINE_ATTRS_BINOP_HASH(Add);
TVM_DEFINE_ATTRS_BINOP_HASH(Sub);
TVM_DEFINE_ATTRS_BINOP_HASH(Mul);
Siva committed
248
TVM_DEFINE_ATTRS_BINOP_HASH(Div);
249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296
TVM_DEFINE_ATTRS_BINOP_HASH(Mod);
TVM_DEFINE_ATTRS_BINOP_HASH(Max);
TVM_DEFINE_ATTRS_BINOP_HASH(Min);
TVM_DEFINE_ATTRS_BINOP_HASH(GE);
TVM_DEFINE_ATTRS_BINOP_HASH(GT);
TVM_DEFINE_ATTRS_BINOP_HASH(LE);
TVM_DEFINE_ATTRS_BINOP_HASH(LT);
TVM_DEFINE_ATTRS_BINOP_HASH(EQ);
TVM_DEFINE_ATTRS_BINOP_HASH(NE);
TVM_DEFINE_ATTRS_BINOP_HASH(And);
TVM_DEFINE_ATTRS_BINOP_HASH(Or);

// Not: kind key combined with the operand hash.
size_t AttrsHashHandler::VisitAttr_(const Not* op) {
  static size_t key = std::hash<std::string>()(Not::_type_key);
  const size_t operand_hash = Hash(op->a);
  return Combine(key, operand_hash);
}

// Cast: kind key, then the target type, then the cast value.
size_t AttrsHashHandler::VisitAttr_(const Cast* op) {
  static size_t key = std::hash<std::string>()(Cast::_type_key);
  AttrsHash hasher;
  const size_t type_hash = hasher(op->type);
  const size_t value_hash = Hash(op->value);
  return Combine(Combine(key, type_hash), value_hash);
}

// Call: kind key, then name, type and argument list. call_type is compared
// by the equality handler but not hashed here; that is still consistent,
// since equal calls always hash equal.
size_t AttrsHashHandler::VisitAttr_(const Call* op) {
  static size_t key = std::hash<std::string>()(Call::_type_key);
  AttrsHash hasher;
  size_t res = Combine(key, hasher(op->name));
  res = Combine(res, hasher(op->type));
  return Combine(res, Hash(op->args));
}

// Select: kind key, then condition, true branch and false branch in order.
size_t AttrsHashHandler::VisitAttr_(const Select* op) {
  static size_t key = std::hash<std::string>()(Select::_type_key);
  const size_t cond_hash = Hash(op->condition);
  const size_t true_hash = Hash(op->true_value);
  const size_t false_hash = Hash(op->false_value);
  return Combine(Combine(Combine(key, cond_hash), true_hash), false_hash);
}

// Default case
// Compare two nodes. Identical refs short-circuit; otherwise use the attached
// handler when there is one (so nested comparisons share it), else a fresh one.
bool AttrsEqual::operator()(const NodeRef& lhs, const NodeRef& rhs) const {
  if (lhs.same_as(rhs)) return true;
  if (handler_ != nullptr) return handler_->Equal(lhs, rhs);
  return AttrsEqualHandler().Equal(lhs, rhs);
}

size_t AttrsHash::operator()(const NodeRef& node) const {
306
  if (!node.defined()) return 0;
307 308 309 310 311
  if (handler_ == nullptr) {
    return AttrsHashHandler().Hash(node);
  } else {
    return handler_->Hash(node);
  }
312 313
}

314 315
// The content hash of DictAttrs is the hash of its dictionary.
size_t DictAttrsNode::ContentHash(AttrsHash hasher) const {
  return hasher(dict);
}

bool DictAttrsNode::ContentEqual(const Node* other, AttrsEqual equal) const {
319 320 321
  if (this == other) return true;
  if (other == nullptr) return false;
  if (this->type_index() != other->type_index()) return false;
322
  return equal(this->dict, static_cast<const DictAttrsNode*>(other)->dict);
323 324
}

325 326 327 328 329
// Packed API: takes an Attrs as the first argument and returns its
// ListFieldInfo().
TVM_REGISTER_API("_AttrsListFieldInfo")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    Attrs attrs = args[0].operator Attrs();
    *ret = attrs->ListFieldInfo();
  });

}  // namespace tvm