Commit 3bfa5fc0 by Jared Roesch Committed by Tianqi Chen

[RELAY][TypeSystem] Add support for populating type args (#1962)

parent 3a1bb8c7
...@@ -485,6 +485,36 @@ inline ValueType OpMap<ValueType>::get(const Op& op, ...@@ -485,6 +485,36 @@ inline ValueType OpMap<ValueType>::get(const Op& op,
return map_.get<ValueType>(op, def_value); return map_.get<ValueType>(op, def_value);
} }
/*!
* \brief Check that an expression is a "primtive operator".
*
* Will return true if the expression is an operator which
* matches the form of primtive operators registered directly
* by the Relay codebase.
*
* That is the arguments are all type variables, and there is a single
* type relation applied to the input and output types.
*/
inline bool IsPrimitiveOp(const Expr& expr) {
const auto* op = expr.as<OpNode>();
if (!op) {
return false;
}
const auto& fn_ty = op->op_type;
if (fn_ty->type_constraints.size() != 1) return false;
const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
if (rel == nullptr) return false;
// validate if the type parameter matches up
for (size_t i = 0; i < fn_ty->type_params.size(); ++i) {
if (!fn_ty->type_params[i].same_as(rel->args[i])) return false;
}
return true;
}
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
#endif // TVM_RELAY_OP_H_ #endif // TVM_RELAY_OP_H_
...@@ -278,10 +278,7 @@ class TextPrinter : ...@@ -278,10 +278,7 @@ class TextPrinter :
} }
TextValue VisitExpr_(const CallNode* op) final { TextValue VisitExpr_(const CallNode* op) final {
// TODO(tqchen, M.K.): support generic call
// possibly through meta-data // possibly through meta-data
CHECK_EQ(op->type_args.size(), 0U)
<< "generic call not yet supported";
TextValue call_op = GetValue(op->op); TextValue call_op = GetValue(op->op);
std::vector<TextValue> args; std::vector<TextValue> args;
for (Expr arg : op->args) { for (Expr arg : op->args) {
...@@ -289,7 +286,23 @@ class TextPrinter : ...@@ -289,7 +286,23 @@ class TextPrinter :
} }
TextValue id = this->AllocTempVar(); TextValue id = this->AllocTempVar();
this->PrintIndent(); this->PrintIndent();
stream_ << id << " = " << call_op << "(";
stream_ << id << " = " << call_op;
auto type_args = op->type_args;
if (!IsPrimitiveOp(op->op) && type_args.size() > 0U) {
stream_ << "<";
for (size_t i = 0; i < op->type_args.size(); ++i) {
this->PrintType(type_args[i], stream_);
if (i + 1 != type_args.size()) {
stream_ << ", ";
}
}
stream_ << ">";
}
stream_ << "(";
for (size_t i = 0; i < args.size(); ++i) { for (size_t i = 0; i < args.size(); ++i) {
stream_ << args[i]; stream_ << args[i];
if (i + 1 != args.size()) { if (i + 1 != args.size()) {
......
...@@ -61,6 +61,17 @@ TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem") ...@@ -61,6 +61,17 @@ TVM_REGISTER_API("tvm.relay.type_relation.TupleGetItem")
.set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>( .set_body_typed<bool(const Array<Type>&, int, const Attrs&, const TypeReporter&)>(
TupleGetItemRel); TupleGetItemRel);
struct ResolvedTypeInfo {
explicit ResolvedTypeInfo(Type checked_type, Array<Type> type_args)
: checked_type(checked_type), type_args(type_args) {}
ResolvedTypeInfo() {}
Type checked_type;
// Only allocated when the expression is a call.
Array<Type> type_args = Array<Type>(NodePtr<Node>(nullptr));
};
// //
// The inference algorithm can roughly be devided into three stages: // The inference algorithm can roughly be devided into three stages:
// - Populate the constraints by visiting the expression (TypeInferencer.GetType) // - Populate the constraints by visiting the expression (TypeInferencer.GetType)
...@@ -87,7 +98,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -87,7 +98,8 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
Environment env_; Environment env_;
// map from expression to checked type // map from expression to checked type
// type inferencer will populate it up // type inferencer will populate it up
std::unordered_map<Expr, Type, NodeHash, NodeEqual> type_map_; std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual> type_map_;
// The solver used by the inferencer. // The solver used by the inferencer.
TypeSolver solver_; TypeSolver solver_;
// relation function // relation function
...@@ -111,11 +123,12 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -111,11 +123,12 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
// will call visit to deduce it if it is not in the type_map_ // will call visit to deduce it if it is not in the type_map_
Type GetType(const Expr &expr) { Type GetType(const Expr &expr) {
auto it = type_map_.find(expr); auto it = type_map_.find(expr);
if (it != type_map_.end()) { if (it != type_map_.end() && it->second.checked_type.defined()) {
return it->second; return it->second.checked_type;
} }
Type ret = this->VisitExpr(expr); Type ret = this->VisitExpr(expr);
type_map_[expr] = ret; ResolvedTypeInfo& rti = type_map_[expr];
rti.checked_type = ret;
return ret; return ret;
} }
...@@ -176,7 +189,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -176,7 +189,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
} }
CHECK(!type_map_.count(op->var)); CHECK(!type_map_.count(op->var));
// NOTE: no scoping is necessary because var are unique in program // NOTE: no scoping is necessary because var are unique in program
type_map_[op->var] = vtype; type_map_[op->var].checked_type = vtype;
return GetType(op->body); return GetType(op->body);
} }
...@@ -224,6 +237,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -224,6 +237,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
subst_map.Set(ty_param, fresh); subst_map.Set(ty_param, fresh);
ty_args->push_back(fresh); ty_args->push_back(fresh);
} }
Type ret_type = fn_ty->ret_type; Type ret_type = fn_ty->ret_type;
// If the function type is incomplete, place a new IncompleteType // If the function type is incomplete, place a new IncompleteType
...@@ -234,6 +248,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -234,6 +248,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
if (!ret_type.defined()) { if (!ret_type.defined()) {
ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType); ret_type = IncompleteTypeNode::make(TypeVarNode::Kind::kType);
} }
Type inst_ty = FuncTypeNode::make(fn_ty->arg_types, Type inst_ty = FuncTypeNode::make(fn_ty->arg_types,
ret_type, {}, ret_type, {},
fn_ty->type_constraints); fn_ty->type_constraints);
...@@ -241,49 +256,74 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -241,49 +256,74 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
return Downcast<FuncType>(inst_ty); return Downcast<FuncType>(inst_ty);
} }
void AddTypeArgs(const Expr& expr, Array<Type> type_args) {
auto type_info = type_map_.find(expr);
if (type_info == type_map_.end()) {
type_map_.insert({expr, ResolvedTypeInfo(Type(), type_args)});
} else {
CHECK(!type_info->second.type_args.defined());
type_info->second.type_args = type_args;
}
}
// Handle general call node. // Handle general call node.
Type GeneralCall(const CallNode* op, Array<Type> arg_types) { Type GeneralCall(const CallNode* call, Array<Type> arg_types) {
Type ftype = GetType(op->op); Type ftype = GetType(call->op);
auto* fn_ty_node = ftype.as<FuncTypeNode>(); auto* fn_ty_node = ftype.as<FuncTypeNode>();
CHECK(fn_ty_node != nullptr) CHECK(fn_ty_node != nullptr)
<< "only expressions with function types can be called, at " << "only expressions with function types can be called, at "
<< op->span; << call->span;
Array<Type> type_args; Array<Type> type_args;
FuncType fn_ty = Instantiate(fn_ty_node, &type_args); FuncType fn_ty = Instantiate(fn_ty_node, &type_args);
AddTypeArgs(GetRef<Call>(call), type_args);
size_t type_arity = fn_ty->arg_types.size(); size_t type_arity = fn_ty->arg_types.size();
size_t number_of_args = arg_types.size(); size_t number_of_args = arg_types.size();
if (type_arity != number_of_args) { if (type_arity != number_of_args) {
if (type_arity < number_of_args) { if (type_arity < number_of_args) {
LOG(FATAL) << "the function is provided too many arguments " << op->span; LOG(FATAL) << "the function is provided too many arguments " << call->span;
} else { } else {
LOG(FATAL) << "the function is provided too few arguments" << op->span; LOG(FATAL) << "the function is provided too few arguments" << call->span;
} }
} }
for (size_t i = 0; i < fn_ty->arg_types.size(); i++) { for (size_t i = 0; i < fn_ty->arg_types.size(); i++) {
this->Unify(fn_ty->arg_types[i], arg_types[i], op->args[i]->span); this->Unify(fn_ty->arg_types[i], arg_types[i], call->args[i]->span);
} }
for (auto cs : fn_ty->type_constraints) { for (auto cs : fn_ty->type_constraints) {
if (auto tr = cs.as<TypeRelationNode>()) {
solver_.AddConstraint(
TypeRelationNode::make(tr->func, tr->args, tr->num_inputs, call->attrs));
} else {
solver_.AddConstraint(cs); solver_.AddConstraint(cs);
} }
}
return fn_ty->ret_type; return fn_ty->ret_type;
} }
Type VisitExpr_(const CallNode* op) final { Type VisitExpr_(const CallNode* call) final {
// Fast path: well-formed primitive op
Array<Type> arg_types; Array<Type> arg_types;
for (Expr arg : op->args) { for (Expr arg : call->args) {
arg_types.push_back(GetType(arg)); arg_types.push_back(GetType(arg));
} }
if (const OpNode* opnode = op->op.as<OpNode>()) {
if (const OpNode* opnode = call->op.as<OpNode>()) {
Type rtype = PrimitiveCall(opnode->op_type.as<FuncTypeNode>(), Type rtype = PrimitiveCall(opnode->op_type.as<FuncTypeNode>(),
arg_types, arg_types,
op->attrs); call->attrs);
if (rtype.defined()) return rtype; if (rtype.defined()) {
AddTypeArgs(GetRef<Call>(call), arg_types);
return rtype;
} }
return GeneralCall(op, arg_types); }
return GeneralCall(call, arg_types);
} }
Type VisitExpr_(const FunctionNode* f) final { Type VisitExpr_(const FunctionNode* f) final {
...@@ -312,7 +352,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> { ...@@ -312,7 +352,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
class TypeInferencer::Resolver : public ExprMutator { class TypeInferencer::Resolver : public ExprMutator {
public: public:
Resolver(const std::unordered_map<Expr, Type, NodeHash, NodeEqual>& tmap, Resolver(const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap,
TypeSolver* solver) TypeSolver* solver)
: tmap_(tmap), solver_(solver) { : tmap_(tmap), solver_(solver) {
} }
...@@ -362,7 +402,7 @@ class TypeInferencer::Resolver : public ExprMutator { ...@@ -362,7 +402,7 @@ class TypeInferencer::Resolver : public ExprMutator {
Expr AttachCheckedType(const T* op) { Expr AttachCheckedType(const T* op) {
auto it = tmap_.find(GetRef<Expr>(op)); auto it = tmap_.find(GetRef<Expr>(op));
CHECK(it != tmap_.end()); CHECK(it != tmap_.end());
Type checked_type = solver_->Resolve(it->second); Type checked_type = solver_->Resolve(it->second.checked_type);
CHECK(checked_type.as<IncompleteTypeNode>() == nullptr) CHECK(checked_type.as<IncompleteTypeNode>() == nullptr)
<< "Cannot resolve type of " << GetRef<Expr>(op) << "Cannot resolve type of " << GetRef<Expr>(op)
<< " at " << op->span; << " at " << op->span;
...@@ -376,25 +416,37 @@ class TypeInferencer::Resolver : public ExprMutator { ...@@ -376,25 +416,37 @@ class TypeInferencer::Resolver : public ExprMutator {
} }
new_e->checked_type_ = checked_type; new_e->checked_type_ = checked_type;
} }
if (it->second.type_args.defined()) {
Call call = Downcast<Call>(new_e);
const CallNode* const_call_ref = call.operator->();
CallNode* call_ref = const_cast<CallNode*>(const_call_ref);
call_ref->type_args = it->second.type_args;
for (size_t i = 0; i < call->type_args.size(); i++) {
call_ref->type_args.Set(i, solver_->Resolve(call->type_args[i]));
}
}
return new_e; return new_e;
} }
Type VisitType(const Type& t) final { Type VisitType(const Type &t) final {
return solver_->Resolve(t); return solver_->Resolve(t);
} }
private: private:
const std::unordered_map<Expr, Type, NodeHash, NodeEqual>& tmap_; const std::unordered_map<Expr, ResolvedTypeInfo, NodeHash, NodeEqual>& tmap_;
TypeSolver* solver_; TypeSolver* solver_;
}; };
Expr TypeInferencer::Infer(Expr expr) { Expr TypeInferencer::Infer(Expr expr) {
// step 0: populate the constraints // Step 0: Populate the constraints.
GetType(expr); GetType(expr);
// step 1: solve the constraints // Step 1: Solve the constraints.
solver_.Solve(); solver_.Solve();
// step 2: attach resolved types to checked_type field // Step 2: Attach resolved types to checked_type field.
return Resolver(type_map_, &solver_).VisitExpr(expr); return Resolver(type_map_, &solver_).VisitExpr(expr);
} }
......
...@@ -91,6 +91,21 @@ def test_free_expr(): ...@@ -91,6 +91,21 @@ def test_free_expr():
yy = relay.ir_pass.infer_type(y) yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == relay.scalar_type("float32") assert yy.checked_type == relay.scalar_type("float32")
def test_type_args():
x = relay.var("x", shape=(10, 10))
y = relay.var("y", shape=(1, 10))
z = relay.add(x, y)
ty_z = relay.ir_pass.infer_type(z)
ty_args = ty_z.type_args
assert len(ty_args) == 2
assert ty_args[0].dtype == "float32"
assert ty_args[1].dtype == "float32"
sh1 = ty_args[0].shape
sh2 = ty_args[1].shape
assert sh1[0].value == 10
assert sh1[1].value == 10
assert sh2[0].value == 1
assert sh2[1].value == 10
if __name__ == "__main__": if __name__ == "__main__":
test_free_expr() test_free_expr()
...@@ -100,3 +115,5 @@ if __name__ == "__main__": ...@@ -100,3 +115,5 @@ if __name__ == "__main__":
test_decl() test_decl()
test_recursion() test_recursion()
test_tuple() test_tuple()
test_free_expr()
test_type_args()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment