Commit 19f8c123 by 雾雨魔理沙 Committed by Haichen Shen

[Relay][Op] Make Type Relation catch more errors (#3899)

* save

* init

* move type_relations
parent ca0292d8
...@@ -30,6 +30,7 @@ ...@@ -30,6 +30,7 @@
#include <tvm/relay/op.h> #include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h> #include <tvm/relay/op_attr_types.h>
#include <vector> #include <vector>
#include "type_relations.h"
#include "../pass/alter_op_layout.h" #include "../pass/alter_op_layout.h"
namespace tvm { namespace tvm {
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2019 by Contributors
* \file transform.cc * \file transform.cc
* \brief Transform operators. * \brief Transform operators.
*/ */
...@@ -1541,7 +1541,6 @@ RELAY_REGISTER_OP("squeeze") ...@@ -1541,7 +1541,6 @@ RELAY_REGISTER_OP("squeeze")
.set_attr<TOpPattern>("TOpPattern", kInjective); .set_attr<TOpPattern>("TOpPattern", kInjective);
// Have no idea how to assert the constraint.
// CollapseSumLike: <A, B> -> B where BroadCast(A, B) = A // CollapseSumLike: <A, B> -> B where BroadCast(A, B) = A
bool CollapseSumLikeRel(const Array<Type>& types, bool CollapseSumLikeRel(const Array<Type>& types,
int num_inputs, int num_inputs,
...@@ -1549,7 +1548,7 @@ bool CollapseSumLikeRel(const Array<Type>& types, ...@@ -1549,7 +1548,7 @@ bool CollapseSumLikeRel(const Array<Type>& types,
const TypeReporter& reporter) { const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3); CHECK_EQ(types.size(), 3);
reporter->Assign(types[2], types[1]); reporter->Assign(types[2], types[1]);
return true; return BroadcastRel({types[0], types[1], types[0]}, 2, Attrs(), reporter);
} }
Expr MakeCollapseSumLike(Expr data, Expr MakeCollapseSumLike(Expr data,
...@@ -1593,7 +1592,7 @@ bool BroadCastToRel(const Array<Type>& types, ...@@ -1593,7 +1592,7 @@ bool BroadCastToRel(const Array<Type>& types,
if (intt == nullptr) { return false; } if (intt == nullptr) { return false; }
auto type = TensorTypeNode::make(ioattrs->shape, intt->dtype); auto type = TensorTypeNode::make(ioattrs->shape, intt->dtype);
reporter->Assign(types[1], type); reporter->Assign(types[1], type);
return true; return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter);
} }
Expr MakeBroadCastTo(Expr data, Array<IndexExpr> shape) { Expr MakeBroadCastTo(Expr data, Array<IndexExpr> shape) {
...@@ -1632,7 +1631,7 @@ bool BroadCastToLikeRel(const Array<Type>& types, ...@@ -1632,7 +1631,7 @@ bool BroadCastToLikeRel(const Array<Type>& types,
const TypeReporter& reporter) { const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3); CHECK_EQ(types.size(), 3);
reporter->Assign(types[2], types[1]); reporter->Assign(types[2], types[1]);
return true; return BroadcastRel({types[0], types[1], types[1]}, 2, Attrs(), reporter);
} }
Expr MakeBroadCastToLike(Expr data, Expr MakeBroadCastToLike(Expr data,
...@@ -2493,9 +2492,9 @@ RELAY_REGISTER_OP("one_hot") ...@@ -2493,9 +2492,9 @@ RELAY_REGISTER_OP("one_hot")
**off_value** Value to fill at all other positions besides indices. **off_value** Value to fill at all other positions besides indices.
**depth** Depth of the one-hot dimension. **depth** Depth of the one-hot dimension.
**axis** Axis to fill. **axis** Axis to fill.
**dtype**)code" TVM_ADD_FILELINE) **dtype**)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.OneHotAttrs") .set_attrs_type_key("relay.attrs.OneHotAttrs")
.set_num_inputs(3) .set_num_inputs(3)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment