Unverified Commit d8f06020 by Tianqi Chen Committed by GitHub

[REFACTOR][IR] Unified IR Primitive Op and Registry (#4687)

This PR migrates relay's Op into the ir folder.
Op and its registry provides an useful mechanism to
store any attribute meta-data of an operator include
function signatures, lowering rules, side effect etc.

These features are not only useful for Relay, but also needed in the low-level IR.
At the current moment, intrinsic functions in the low-level IR are simply
represented by a string. This means we cannot type-check the low-level IR
when the type does not meet the constraint, nor can we obtain further
information such as side-effect and read write relation of these intrinsics
wrt to arguments.

Op will be used as the way to handle primitive ops(in DL terminology)
(builtin intrinsics or in compiler terminology).
We will perform follow-up refactors to make low-level CallNode
take Op as the function argument.
parent a2fe7a3e
......@@ -51,6 +51,7 @@
#include <tvm/runtime/object.h>
#include <tvm/node/node.h>
#include <tvm/node/env_func.h>
#include <tvm/node/container.h>
#include <tvm/ir/span.h>
#include <string>
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/ir/type_relation.h
* \brief Type relation function for type checking.
*/
#ifndef TVM_IR_TYPE_RELATION_H_
#define TVM_IR_TYPE_RELATION_H_
#include <tvm/ir/type.h>
#include <tvm/attrs.h>
namespace tvm {
// TODO(tqchen): remove after migrate Module to ir.
namespace relay {
struct Module;
}
/*!
* \brief reporter that reports back to the
* type resolution information.
*/
class TypeReporterNode : public Object {
public:
/*!
* \brief Create a type equality constraint.
*
* The "assign direction" acts as a hint to the solver
* showing that it is more likely to resolve dst by src.
* But it is possible for the solver to resolve src by dst as well.
*/
TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;
/*!
* \brief assert shape expression comparison.
* \note Use assert only if any of the condition input is symbolic.
* \param cond The condition of operation.
* \return false if assertation can be proven to have failed
* true if solver can still proceed.
*/
TVM_DLL virtual bool Assert(const PrimExpr& cond)= 0;
/*!
* \brief assert shape expression equals each other.
* \param lhs The left operand.
* \param rhs The right operand.
* \return false if assertation can be proven to have failed
* true if solver can still proceed.
*/
TVM_DLL virtual bool AssertEQ(const PrimExpr& lhs, const PrimExpr& rhs) = 0;
/*!
* \brief Set the location at which to report unification errors.
* \param ref The program node to report the error.
*/
TVM_DLL virtual void SetLocation(const ObjectRef& ref) = 0;
/*!
* \brief Retrieve the current global module.
* \return The global module.
*/
TVM_DLL virtual relay::Module GetModule() = 0;
// solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.TypeReporter";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object);
};
/*!
* \brief Container class of TypeReporter.
* \sa TypeReporterNode
*/
class TypeReporter : public ObjectRef {
public:
TypeReporter() {}
explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) {
}
TypeReporterNode* operator->() const {
return const_cast<TypeReporterNode*>(
static_cast<const TypeReporterNode*>(get()));
}
using ContainerType = TypeReporterNode;
};
/*!
* \brief User defined type constraint function.
*
* If the input type information can be used to fully decide
* the IncompleteTypes, then the function should call
* reporter.Assign to report the new types, and return true.
* Otherwise, the function should return false.
*
* \param args The arguments to the relation.
* The types are stored in the form of
* [input_type_0, input_type_1, ... input_type_n,
* output_type_0, output_type_1, ... output_type_m]
*
* \param num_inputs Number of input types in the args.
* \param attrs The additional attributes of the operator.
* \param reporter The reporter to report solution to.
* \return false if This relation cannot be resolved.
* true if this relation has been resolved.
*/
using TypeRelationFn =
TypedEnvFunc<bool(const Array<Type>& args,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter)>;
/*!
* \brief User defined type relation, is an input-output relation on types.
*/
class TypeRelation;
/*!
* \brief TypeRelation container.
* \note This node is not directly serializable.
* The type function need to be lookedup in the module.
*/
class TypeRelationNode : public TypeConstraintNode {
public:
/*!
* \brief The function on input and output variables which
* this is not directly serializable,
* need to be looked-up in the module.
*/
TypeRelationFn func;
/*! \brief The type arguments to the type function. */
tvm::Array<Type> args;
/*! \brief Number of inputs arguments */
int num_inputs;
/*! \brief Attributes to the relation function */
Attrs attrs;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("args", &args);
v->Visit("num_inputs", &num_inputs);
v->Visit("attrs", &attrs);
v->Visit("span", &span);
}
TVM_DLL static TypeRelation make(TypeRelationFn func,
Array<Type> args,
int num_args,
Attrs attrs);
static constexpr const char* _type_key = "relay.TypeRelation";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
};
class TypeRelation : public TypeConstraint {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode);
};
} // namespace tvm
#endif // TVM_IR_TYPE_RELATION_H_
......@@ -24,8 +24,8 @@
#ifndef TVM_RELAY_TYPE_H_
#define TVM_RELAY_TYPE_H_
#include <tvm/ir/type.h>
#include <tvm/ir/type_relation.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/node/env_func.h>
......@@ -51,6 +51,11 @@ using TypeConstraint = tvm::TypeConstraint;
using TypeConstraintNode = tvm::TypeConstraintNode;
using FuncType = tvm::FuncType;
using FuncTypeNode = tvm::FuncTypeNode;
using TypeRelation = tvm::TypeRelation;
using TypeRelationNode = tvm::TypeRelationNode;
using TypeRelationFn = tvm::TypeRelationFn;
using TypeReporter = tvm::TypeReporter;
using TypeReporterNode = tvm::TypeReporterNode;
/*!
* \brief Base of all Tensor types
......@@ -235,146 +240,6 @@ class RefType : public Type {
TVM_DEFINE_OBJECT_REF_METHODS(RefType, Type, RefTypeNode);
};
class TypeReporter;
/*!
* \brief reporter that reports back to the
* type resolution information.
*/
class TypeReporterNode : public Object {
public:
/*!
* \brief Create a type equality constraint.
*
* The "assign direction" acts as a hint to the solver
* showing that it is more likely to resolve dst by src.
* But it is possible for the solver to resolve src by dst as well.
*/
TVM_DLL virtual void Assign(const Type& dst, const Type& src) = 0;
/*!
* \brief assert shape expression comparison.
* \note Use assert only if any of the condition input is symbolic.
* \param cond The condition of operation.
* \return false if assertation can be proven to have failed
* true if solver can still proceed.
*/
TVM_DLL virtual bool Assert(const IndexExpr& cond)= 0;
/*!
* \brief assert shape expression equals each other.
* \param lhs The left operand.
* \param rhs The right operand.
* \return false if assertation can be proven to have failed
* true if solver can still proceed.
*/
TVM_DLL virtual bool AssertEQ(const IndexExpr& lhs, const IndexExpr& rhs) = 0;
/*!
* \brief Set the location at which to report unification errors.
* \param ref The program node to report the error.
*/
TVM_DLL virtual void SetLocation(const ObjectRef& ref) = 0;
/*!
* \brief Retrieve the current global module.
* \return The global module.
*/
TVM_DLL virtual Module GetModule() = 0;
// solver is not serializable.
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "relay.TypeReporter";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeReporterNode, Object);
};
/*!
* \brief Container class of TypeReporter.
* \sa TypeReporterNode
*/
class TypeReporter : public ObjectRef {
public:
TypeReporter() {}
explicit TypeReporter(::tvm::ObjectPtr<::tvm::Object> n) : ObjectRef(n) {
}
TypeReporterNode* operator->() const {
return const_cast<TypeReporterNode*>(
static_cast<const TypeReporterNode*>(get()));
}
using ContainerType = TypeReporterNode;
};
/*!
* \brief User defined type constraint function.
*
* If the input type information can be used to fully decide
* the IncompleteTypes, then the function should call
* reporter.Assign to report the new types, and return true.
* Otherwise, the function should return false.
*
* \param args The arguments to the relation.
* The types are stored in the form of
* [input_type_0, input_type_1, ... input_type_n,
* output_type_0, output_type_1, ... output_type_m]
*
* \param num_inputs Number of input types in the args.
* \param attrs The additional attributes of the operator.
* \param reporter The reporter to report solution to.
* \return false if This relation cannot be resolved.
* true if this relation has been resolved.
*/
using TypeRelationFn =
TypedEnvFunc<bool(const Array<Type>& args,
int num_inputs,
const Attrs& attrs,
const TypeReporter& reporter)>;
/*!
* \brief User defined type relation, is an input-output relation on types.
*/
class TypeRelation;
/*!
* \brief TypeRelation container.
* \note This node is not directly serializable.
* The type function need to be lookedup in the module.
*/
class TypeRelationNode : public TypeConstraintNode {
public:
/*!
* \brief The function on input and output variables which
* this is not directly serializable,
* need to be looked-up in the module.
*/
TypeRelationFn func;
/*! \brief The type arguments to the type function. */
tvm::Array<Type> args;
/*! \brief Number of inputs arguments */
int num_inputs;
/*! \brief Attributes to the relation function */
Attrs attrs;
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("func", &func);
v->Visit("args", &args);
v->Visit("num_inputs", &num_inputs);
v->Visit("attrs", &attrs);
v->Visit("span", &span);
}
TVM_DLL static TypeRelation make(TypeRelationFn func,
Array<Type> args,
int num_args,
Attrs attrs);
static constexpr const char* _type_key = "relay.TypeRelation";
TVM_DECLARE_FINAL_OBJECT_INFO(TypeRelationNode, TypeConstraintNode);
};
class TypeRelation : public TypeConstraint {
public:
TVM_DEFINE_OBJECT_REF_METHODS(TypeRelation, TypeConstraint, TypeRelationNode);
};
// The following fields contains advanced typing
// Only keep the class name and reserved for future usage.
class GenericTensorType;
......
......@@ -18,11 +18,11 @@
*/
/*!
* \file src/tvm/relay/op.cc
* \brief Resolve incomplete types to complete types.
* \file src/tvm/ir/op.cc
* \brief Primitive operators and intrinsics.
*/
#include <tvm/relay/op.h>
#include <tvm/relay/type.h>
#include <tvm/ir/op.h>
#include <tvm/ir/type.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
......@@ -31,11 +31,10 @@
namespace dmlc {
// enable registry
DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry);
DMLC_REGISTRY_ENABLE(::tvm::OpRegistry);
} // namespace dmlc
namespace tvm {
namespace relay {
::dmlc::Registry<OpRegistry>* OpRegistry::Registry() {
return ::dmlc::Registry<OpRegistry>::Get();
......@@ -230,5 +229,4 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << "Op(" << node->name << ")";
});
} // namespace relay
} // namespace tvm
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/tvm/ir/type_relation.cc
* \brief Type relation
*/
#include <tvm/ir/type.h>
#include <tvm/ir/type_relation.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
namespace tvm {
TypeRelation TypeRelationNode::make(TypeRelationFn func,
Array<Type> args,
int num_inputs,
Attrs attrs) {
ObjectPtr<TypeRelationNode> n = make_object<TypeRelationNode>();
n->func = std::move(func);
n->args = std::move(args);
n->num_inputs = num_inputs;
n->attrs = std::move(attrs);
return TypeRelation(n);
}
TVM_REGISTER_NODE_TYPE(TypeRelationNode);
TVM_REGISTER_GLOBAL("relay._make.TypeRelation")
.set_body_typed(TypeRelationNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeRelationNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TypeRelationNode*>(ref.get());
p->stream << "TypeRelationNode("
<< node->func->name
<< ", " << node->args << ")";
});
} // namespace tvm
......@@ -101,30 +101,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
p->stream << "IncompleteTypeNode(" << node->kind << ", " << node << ")";
});
TypeRelation TypeRelationNode::make(TypeRelationFn func,
Array<Type> args,
int num_inputs,
Attrs attrs) {
ObjectPtr<TypeRelationNode> n = make_object<TypeRelationNode>();
n->func = std::move(func);
n->args = std::move(args);
n->num_inputs = num_inputs;
n->attrs = std::move(attrs);
return TypeRelation(n);
}
TVM_REGISTER_NODE_TYPE(TypeRelationNode);
TVM_REGISTER_GLOBAL("relay._make.TypeRelation")
.set_body_typed(TypeRelationNode::make);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<TypeRelationNode>([](const ObjectRef& ref, NodePrinter* p) {
auto* node = static_cast<const TypeRelationNode*>(ref.get());
p->stream << "TypeRelationNode("
<< node->func->name
<< ", " << node->args << ")";
});
TupleType TupleTypeNode::make(Array<Type> fields) {
ObjectPtr<TupleTypeNode> n = make_object<TupleTypeNode>();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment