/*!
 *  Copyright (c) 2018 by Contributors
 * \file src/tvm/ir/type.cc
 * \brief Constructors, FFI registrations, and IR printers for the
 *        type-system AST nodes of Relay.
 */
#include <tvm/relay/type.h>

namespace tvm {
namespace relay {

using tvm::IRPrinter;
using namespace tvm::runtime;

14
TensorType TensorTypeNode::make(Array<IndexExpr> shape, DataType dtype) {
15
  NodePtr<TensorTypeNode> n = make_node<TensorTypeNode>();
16 17 18 19 20 21 22 23 24 25 26
  n->shape = std::move(shape);
  n->dtype = std::move(dtype);
  return TensorType(n);
}

/*!
 * \brief Build a scalar (rank-0) tensor type.
 * \param dtype The element data type.
 * \return A TensorType whose shape array is empty.
 */
TensorType TensorTypeNode::Scalar(DataType dtype) {
  // A default-constructed shape array has no dimensions.
  return make(Array<IndexExpr>(), dtype);
}

// FFI constructor: relay._make.TensorType(shape, dtype) -> TensorType.
// args[0] is converted to the shape array, args[1] to the element dtype.
TVM_REGISTER_API("relay._make.TensorType")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  Array<IndexExpr> shape = args[0];
  // `make` takes the shape by value as a sink; hand over the local.
  *ret = TensorTypeNode::make(std::move(shape), args[1]);
});

// Textual form: TensorTypeNode(<dtype>, <shape>).
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TensorTypeNode>([](const TensorTypeNode* op,
                                 tvm::IRPrinter* printer) {
  auto& os = printer->stream;
  os << "TensorTypeNode(" << op->dtype << ", " << op->shape << ")";
});

/*!
 * \brief Construct a type parameter.
 * \param name Name hint used for the underlying tvm::Var.
 * \param kind The kind of the parameter.
 * \return A newly allocated TypeParam node wrapped in its reference.
 */
TypeParam TypeParamNode::make(std::string name, TypeParamNode::Kind kind) {
  NodePtr<TypeParamNode> n = make_node<TypeParamNode>();
  // `name` is a by-value sink; move it rather than copy it again.
  n->var = tvm::Var(std::move(name));
  n->kind = kind;
  return TypeParam(n);
}

// FFI constructor: relay._make.TypeParam(name, kind) -> TypeParam.
// The kind arrives as a plain integer and is cast to TypeParamNode::Kind.
TVM_REGISTER_API("relay._make.TypeParam")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  const int kind_code = args[1];
  const auto kind = static_cast<TypeParamNode::Kind>(kind_code);
  *ret = TypeParamNode::make(args[0], kind);
});

// Textual form: TypeParamNode(<var name hint>, <kind>).
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeParamNode>([](const TypeParamNode* op,
                                tvm::IRPrinter* printer) {
  auto& os = printer->stream;
  os << "TypeParamNode(" << op->var->name_hint << ", " << op->kind << ")";
});

58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
IncompleteType IncompleteTypeNode::make(TypeParamNode::Kind kind) {
  auto n = make_node<IncompleteTypeNode>();
  n->kind = std::move(kind);
  return IncompleteType(n);
}

// FFI constructor: relay._make.IncompleteType(kind) -> IncompleteType.
// The kind arrives as a plain integer and is cast to TypeParamNode::Kind.
TVM_REGISTER_API("relay._make.IncompleteType")
.set_body([](TVMArgs args, TVMRetValue* ret) {
  const int kind_code = args[0];
  const auto kind = static_cast<TypeParamNode::Kind>(kind_code);
  *ret = IncompleteTypeNode::make(kind);
});

// Textual form: IncompleteTypeNode(<kind>, <node pointer>). Printing the
// node pointer gives each incomplete type a distinguishing identity in
// the output.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<IncompleteTypeNode>([](const IncompleteTypeNode* op,
                                     tvm::IRPrinter* printer) {
  auto& os = printer->stream;
  os << "IncompleteTypeNode(" << op->kind << ", " << op << ")";
});

/*!
 * \brief Construct a function type.
 * \param arg_types Types of the function's arguments.
 * \param ret_type Type of the function's result.
 * \param type_params Type parameters bound by this function type.
 * \param type_constraints Constraints attached to the function type.
 * \return A newly allocated FuncType node wrapped in its reference.
 */
FuncType FuncTypeNode::make(tvm::Array<Type> arg_types,
                            Type ret_type,
                            tvm::Array<TypeParam> type_params,
                            tvm::Array<TypeConstraint> type_constraints) {
  NodePtr<FuncTypeNode> n = make_node<FuncTypeNode>();
  n->arg_types = std::move(arg_types);
  n->ret_type = std::move(ret_type);
  n->type_params = std::move(type_params);
  n->type_constraints = std::move(type_constraints);
  return FuncType(n);
}

// FFI constructor:
//   relay._make.FuncType(arg_types, ret_type, type_params, type_constraints)
// The four packed arguments are forwarded positionally to FuncTypeNode::make.
TVM_REGISTER_API("relay._make.FuncType")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  FuncType result = FuncTypeNode::make(args[0], args[1], args[2], args[3]);
  *ret = result;
});

// Textual form:
//   FuncTypeNode(<type_params>, <arg_types>, <ret_type>, <type_constraints>).
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<FuncTypeNode>([](const FuncTypeNode* op,
                               tvm::IRPrinter* printer) {
  auto& os = printer->stream;
  os << "FuncTypeNode(";
  os << op->type_params << ", ";
  os << op->arg_types << ", ";
  os << op->ret_type << ", ";
  os << op->type_constraints << ")";
});

102 103 104 105
TypeRelation TypeRelationNode::make(TypeRelationFn func,
                                    Array<Type> args,
                                    int num_inputs,
                                    Attrs attrs) {
106
  NodePtr<TypeRelationNode> n = make_node<TypeRelationNode>();
107
  n->func = std::move(func);
108
  n->args = std::move(args);
109 110
  n->num_inputs = num_inputs;
  n->attrs = std::move(attrs);
111 112 113 114 115
  return TypeRelation(n);
}

// FFI constructor:
//   relay._make.TypeRelation(func, args, num_inputs, attrs) -> TypeRelation.
// The four packed arguments are forwarded positionally to TypeRelationNode::make.
TVM_REGISTER_API("relay._make.TypeRelation")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  *ret = TypeRelationNode::make(args[0], args[1], args[2], args[3]);
});

// Textual form: TypeRelationNode(<relation function name>, <args>).
// Note: num_inputs and attrs are not included in the printed form.
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TypeRelationNode>([](const TypeRelationNode *node,
                                   tvm::IRPrinter *p) {
  p->stream << "TypeRelationNode("
            << node->func->name
            << ", " << node->args << ")";
});

/*!
 * \brief Construct a tuple type.
 * \param fields The type of each tuple element, in order.
 * \return A newly allocated TupleType node wrapped in its reference.
 */
TupleType TupleTypeNode::make(Array<Type> fields) {
  NodePtr<TupleTypeNode> n = make_node<TupleTypeNode>();
  n->fields = std::move(fields);
  return TupleType(n);
}

// FFI constructor: relay._make.TupleType(fields) -> TupleType.
TVM_REGISTER_API("relay._make.TupleType")
.set_body([](TVMArgs args, TVMRetValue *ret) {
  Array<Type> fields = args[0];
  *ret = TupleTypeNode::make(fields);
});

// Textual form: TupleTypeNode(<fields>).
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleTypeNode>([](const TupleTypeNode* op,
                                tvm::IRPrinter* printer) {
  auto& os = printer->stream;
  os << "TupleTypeNode(" << op->fields << ")";
});

}  // namespace relay
}  // namespace tvm