// base.cc
/*!
 *  Copyright (c) 2018 by Contributors
 * \file base.cc
 * \brief The core base types for Relay.
 */
#include <tvm/api_registry.h>
#include <tvm/relay/base.h>

namespace tvm {
namespace relay {

using tvm::IRPrinter;
using namespace tvm::runtime;

15 16 17 18
NodePtr<SourceNameNode> GetSourceNameNode(const std::string& name) {
  // always return pointer as the reference can change as map re-allocate.
  // or use another level of indirection by creating a unique_ptr
  static std::unordered_map<std::string, NodePtr<SourceNameNode> > source_map;
19 20 21

  auto sn = source_map.find(name);
  if (sn == source_map.end()) {
22 23 24 25
    NodePtr<SourceNameNode> n = make_node<SourceNameNode>();
    n->name = std::move(name);
    source_map[name] = n;
    return n;
26 27 28 29 30
  } else {
    return sn->second;
  }
}

31 32 33
SourceName SourceName::Get(const std::string& name) {
  return SourceName(GetSourceNameNode(name));
}
34 35

// Printer for SourceNameNode: emits the name followed by the node's address,
// i.e. SourceName(<name>, <pointer>).
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SourceNameNode>(
    [](const SourceNameNode* source_name, tvm::IRPrinter* printer) {
      printer->stream << "SourceName(" << source_name->name << ", "
                      << source_name << ")";
    });

// Register SourceNameNode with the node-type registry. Creation goes through
// GetSourceNameNode, and the node's name string serves as its global key.
TVM_REGISTER_NODE_TYPE(SourceNameNode)
.set_creator(GetSourceNameNode)
.set_global_key([](const Node* node) {
    const SourceNameNode* source_name = static_cast<const SourceNameNode*>(node);
    return source_name->name;
  });

/*!
 * \brief Build a Span from a source name and a position.
 * \param source The source the span refers to.
 * \param lineno The line number, stored as given.
 * \param col_offset The column offset, stored as given.
 * \return A new Span holding these fields.
 */
Span SpanNode::make(SourceName source, int lineno, int col_offset) {
  NodePtr<SpanNode> span = make_node<SpanNode>();
  span->lineno = lineno;
  span->col_offset = col_offset;
  span->source = std::move(source);
  return Span(span);
}

54 55
TVM_REGISTER_NODE_TYPE(SpanNode);

56
TVM_REGISTER_API("relay._make.Span")
57
.set_body([](TVMArgs args, TVMRetValue* ret) {
58 59 60 61
    *ret = SpanNode::make(args[0], args[1], args[2]);
  });

// Printer for SpanNode: SpanNode(<source>, <lineno>, <col_offset>).
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SpanNode>([](const SpanNode* span, tvm::IRPrinter* printer) {
    auto& os = printer->stream;
    os << "SpanNode(" << span->source << ", " << span->lineno << ", ";
    os << span->col_offset << ")";
  });

}  // namespace relay
}  // namespace tvm