/*!
 *  Copyright (c) 2018 by Contributors
 * \file src/tvm/ir/type.cc
 * \brief Construction, FFI registration, and printing of the Relay
 *        type-system AST nodes.
 */
#include <tvm/relay/type.h>

namespace tvm {
namespace relay {

using tvm::IRPrinter;
using namespace tvm::runtime;

14
TensorType TensorTypeNode::make(Array<IndexExpr> shape, DataType dtype) {
15
  NodePtr<TensorTypeNode> n = make_node<TensorTypeNode>();
16 17 18 19 20 21 22 23 24
  n->shape = std::move(shape);
  n->dtype = std::move(dtype);
  return TensorType(n);
}

/*!
 * \brief Build the tensor type of a single value of type \p dtype.
 *
 * A scalar is represented as a rank-0 tensor, i.e. one whose shape array
 * is empty.
 */
TensorType TensorTypeNode::Scalar(DataType dtype) {
  Array<IndexExpr> no_dims;
  return TensorTypeNode::make(no_dims, dtype);
}

Siju committed
25 26 27 28 29 30 31 32 33 34 35 36
IndexExpr TensorTypeNode::Size() const {
  if (shape.size() == 0) {
    return make_const(Int(64), 1);
  }

  IndexExpr size = shape[0];
  for (size_t i = 1; i < shape.size(); ++i) {
    size *= shape[i];
  }
  return size;
}

TVM_REGISTER_NODE_TYPE(TensorTypeNode);

// FFI constructor: relay._make.TensorType(shape, dtype).
TVM_REGISTER_API("relay._make.TensorType")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  Array<IndexExpr> shape = args[0];
  *ret = TensorTypeNode::make(shape, args[1]);
});

// Printer: renders as TensorType(<shape>, <dtype>).
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TensorTypeNode>([](const TensorTypeNode* node,
                                 tvm::IRPrinter* p) {
  p->stream << "TensorType(" << node->shape << ", " << node->dtype << ")";
});

51 52
TypeVar TypeVarNode::make(std::string name, TypeVarNode::Kind kind) {
  NodePtr<TypeVarNode> n = make_node<TypeVarNode>();
53 54
  n->var = tvm::Var(name);
  n->kind = std::move(kind);
55
  return TypeVar(n);
56 57
}

TVM_REGISTER_NODE_TYPE(TypeVarNode);

// FFI constructor: relay._make.TypeVar(name, kind). The kind arrives as a
// plain integer and is cast to TypeVarNode::Kind without range checking.
TVM_REGISTER_API("relay._make.TypeVar")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  int kind = args[1];
  *ret = TypeVarNode::make(args[0], static_cast<TypeVarNode::Kind>(kind));
});

// Printer: renders as TypeVarNode(<name_hint>, <kind>).
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeVarNode>([](const TypeVarNode* node,
                              tvm::IRPrinter* p) {
  p->stream << "TypeVarNode(" << node->var->name_hint << ", "
            << node->kind << ")";
});

74
IncompleteType IncompleteTypeNode::make(TypeVarNode::Kind kind) {
75 76 77 78 79
  auto n = make_node<IncompleteTypeNode>();
  n->kind = std::move(kind);
  return IncompleteType(n);
}

TVM_REGISTER_NODE_TYPE(IncompleteTypeNode);

// FFI constructor: relay._make.IncompleteType(kind). The kind arrives as a
// plain integer and is cast to TypeVarNode::Kind without range checking.
TVM_REGISTER_API("relay._make.IncompleteType")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  int kind = args[0];
  *ret = IncompleteTypeNode::make(static_cast<TypeVarNode::Kind>(kind));
});

// Printer: renders as IncompleteTypeNode(<kind>, <node address>). The address
// is presumably included so distinct incomplete types of the same kind can
// be told apart in dumps.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<IncompleteTypeNode>([](const IncompleteTypeNode* node,
                                     tvm::IRPrinter* p) {
  p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
});

/*!
 * \brief Construct a function type.
 * \param arg_types Types of the function's arguments.
 * \param ret_type The function's return type.
 * \param type_params Type variables of the function type.
 * \param type_constraints Type constraints attached to the function type.
 * \return The newly allocated FuncType.
 */
FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
                            Type ret_type,
                            tvm::Array<TypeVar> type_params,
                            tvm::Array<TypeConstraint> type_constraints) {
  NodePtr<FuncTypeNode> n = make_node<FuncTypeNode>();
  n->arg_types = std::move(arg_types);
  n->ret_type = std::move(ret_type);
  n->type_params = std::move(type_params);
  n->type_constraints = std::move(type_constraints);
  return FuncType(n);
}

TVM_REGISTER_NODE_TYPE(FuncTypeNode);

// FFI constructor:
// relay._make.FuncType(arg_types, ret_type, type_params, type_constraints).
TVM_REGISTER_API("relay._make.FuncType")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  *ret = FuncTypeNode::make(args[0], args[1], args[2], args[3]);
});

// Printer. Note that type_params is printed first, which differs from the
// argument order of FuncTypeNode::make.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FuncTypeNode>([](const FuncTypeNode* node,
                               tvm::IRPrinter* p) {
  p->stream << "FuncTypeNode(" << node->type_params << ", "
            << node->arg_types << ", " << node->ret_type << ", "
            << node->type_constraints << ")";
});

122 123 124 125
TypeRelation TypeRelationNode::make(TypeRelationFn func,
                                    Array<Type> args,
                                    int num_inputs,
                                    Attrs attrs) {
126
  NodePtr<TypeRelationNode> n = make_node<TypeRelationNode>();
127
  n->func = std::move(func);
128
  n->args = std::move(args);
129 130
  n->num_inputs = num_inputs;
  n->attrs = std::move(attrs);
131 132 133
  return TypeRelation(n);
}

TVM_REGISTER_NODE_TYPE(TypeRelationNode);

// FFI constructor: relay._make.TypeRelation(func, args, num_inputs, attrs).
TVM_REGISTER_API("relay._make.TypeRelation")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  *ret = TypeRelationNode::make(args[0], args[1], args[2], args[3]);
});

// Printer: renders as TypeRelationNode(<func name>, <args>). num_inputs and
// attrs are not printed.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeRelationNode>([](const TypeRelationNode* node,
                                   tvm::IRPrinter* p) {
  p->stream << "TypeRelationNode("
            << node->func->name
            << ", " << node->args << ")";
});

/*!
 * \brief Construct a tuple type.
 * \param fields Types of the tuple's fields, in order.
 * \return The newly allocated TupleType.
 */
TupleType TupleTypeNode::make(Array<Type> fields) {
  NodePtr<TupleTypeNode> n = make_node<TupleTypeNode>();
  n->fields = std::move(fields);
  return TupleType(n);
}

TVM_REGISTER_NODE_TYPE(TupleTypeNode);

// FFI constructor: relay._make.TupleType(fields).
TVM_REGISTER_API("relay._make.TupleType")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  *ret = TupleTypeNode::make(args[0]);
});

// Printer: renders as TupleTypeNode(<fields>).
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleTypeNode>([](const TupleTypeNode* node,
                                tvm::IRPrinter* p) {
  p->stream << "TupleTypeNode(" << node->fields << ")";
});

}  // namespace relay
}  // namespace tvm