base.cc 2.28 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/*!
 *  Copyright (c) 2018 by Contributors
 * \file base.cc
 * \brief The core base types for Relay.
 */
#include <tvm/api_registry.h>
#include <tvm/relay/base.h>

namespace tvm {
namespace relay {

using tvm::IRPrinter;
using namespace tvm::runtime;

15 16 17 18
NodePtr<SourceNameNode> GetSourceNameNode(const std::string& name) {
  // always return pointer as the reference can change as map re-allocate.
  // or use another level of indirection by creating a unique_ptr
  static std::unordered_map<std::string, NodePtr<SourceNameNode> > source_map;
19 20 21

  auto sn = source_map.find(name);
  if (sn == source_map.end()) {
22 23
    NodePtr<SourceNameNode> n = make_node<SourceNameNode>();
    source_map[name] = n;
24
    n->name = std::move(name);
25
    return n;
26 27 28 29 30
  } else {
    return sn->second;
  }
}

31 32 33
SourceName SourceName::Get(const std::string& name) {
  return SourceName(GetSourceNameNode(name));
}
34

35 36 37 38 39
TVM_REGISTER_API("relay._make.SourceName")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    *ret = SourceName::Get(args[0]);
  });

40
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
41
.set_dispatch<SourceNameNode>([](const SourceNameNode* node, tvm::IRPrinter* p) {
42 43
    p->stream << "SourceName(" << node->name << ", " << node << ")";
  });
44 45

TVM_REGISTER_NODE_TYPE(SourceNameNode)
46
.set_creator(GetSourceNameNode)
47 48 49 50 51
.set_global_key([](const Node* n) {
    return static_cast<const SourceNameNode*>(n)->name;
  });

/*!
 * \brief Construct a Span describing a location in a source.
 * \param source The source the span belongs to.
 * \param lineno The line number stored in the span.
 * \param col_offset The column offset stored in the span.
 * \return A new Span wrapping a freshly created SpanNode.
 */
Span SpanNode::make(SourceName source, int lineno, int col_offset) {
  auto n = make_node<SpanNode>();
  n->source = std::move(source);
  n->lineno = lineno;
  n->col_offset = col_offset;
  return Span(n);
}

59 60
TVM_REGISTER_NODE_TYPE(SpanNode);

61
TVM_REGISTER_API("relay._make.Span")
62
.set_body([](TVMArgs args, TVMRetValue* ret) {
63 64 65 66
    *ret = SpanNode::make(args[0], args[1], args[2]);
  });

// Textual form of a SpanNode: "SpanNode(<source>, <lineno>, <col_offset>)".
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SpanNode>([](const SpanNode* node, tvm::IRPrinter* p) {
    p->stream << "SpanNode(" << node->source << ", " << node->lineno << ", "
              << node->col_offset << ")";
  });

72 73
TVM_REGISTER_NODE_TYPE(IdNode);

74 75 76 77 78 79 80 81 82
TVM_REGISTER_API("relay._base.set_span")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    NodeRef node_ref = args[0];
    auto rn = node_ref.as_derived<RelayNode>();
    CHECK(rn);
    Span sp = args[1];
    rn->span = sp;
});

83 84
}  // namespace relay
}  // namespace tvm