Unverified Commit fc0149d5 by Tianqi Chen Committed by GitHub

[ATTR] Introduce Integer container (#1994)

parent 5742995d
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include <string> #include <string>
#include "ir.h" #include "ir.h"
#include "base.h" #include "base.h"
#include "expr.h"
#include "packed_func_ext.h" #include "packed_func_ext.h"
namespace tvm { namespace tvm {
...@@ -73,7 +74,6 @@ inline Type NullValue<Type>() { ...@@ -73,7 +74,6 @@ inline Type NullValue<Type>() {
return Type(Type::Handle, 0, 0); return Type(Type::Handle, 0, 0);
} }
/*! \brief Error thrown during attribute checking. */ /*! \brief Error thrown during attribute checking. */
struct AttrError : public dmlc::Error { struct AttrError : public dmlc::Error {
/*! /*!
......
...@@ -29,6 +29,7 @@ using HalideIR::VarExpr; ...@@ -29,6 +29,7 @@ using HalideIR::VarExpr;
using HalideIR::IR::RangeNode; using HalideIR::IR::RangeNode;
using HalideIR::IR::FunctionRef; using HalideIR::IR::FunctionRef;
using HalideIR::IR::FunctionBaseNode; using HalideIR::IR::FunctionBaseNode;
using HalideIR::Internal::IntImm;
using HalideIR::Internal::Stmt; using HalideIR::Internal::Stmt;
using HalideIR::Internal::IRPrinter; using HalideIR::Internal::IRPrinter;
using HalideIR::Internal::Variable; using HalideIR::Internal::Variable;
...@@ -83,6 +84,51 @@ class Var : public HalideIR::VarExpr { ...@@ -83,6 +84,51 @@ class Var : public HalideIR::VarExpr {
}; };
/*!
* \brief Container of constant ineteger (IntImm).
*
* This is used to store and automate type check
* attributes that must be constant integer.
*/
class Integer : public Expr {
public:
Integer() : Expr() {}
/*!
* \brief constructor from node.
*/
explicit Integer(NodePtr<Node> node) : Expr(node) {}
/*!
* \brief Construct integer from int value.
*/
Integer(int value) : Expr(value) {} // NOLINT(*)
/*!
* \brief Assign an expression to integer.
* \param other another expression.
*/
Integer& operator=(const Integer& other) {
node_ = other.node_;
return *this;
}
/*!
* \brief Get pointer to the internal value.
* \return the content of the integer.
*/
const IntImm* operator->() const {
return static_cast<const IntImm*>(node_.get());
}
/*!
* \brief convert to int64_t
*/
operator int64_t() const {
CHECK(node_ != nullptr)
<< " Trying get reference a null Integer";
return (*this)->value;
}
/*! \brief type indicate the container type */
using ContainerType = IntImm;
};
/*! \brief container class of iteration variable. */ /*! \brief container class of iteration variable. */
class IterVarNode; class IterVarNode;
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <memory> #include <memory>
#include <limits>
#include <type_traits> #include <type_traits>
#include "base.h" #include "base.h"
...@@ -126,6 +127,8 @@ inline TNodeRef TVMArgValue::AsNodeRef() const { ...@@ -126,6 +127,8 @@ inline TNodeRef TVMArgValue::AsNodeRef() const {
inline TVMArgValue::operator HalideIR::Expr() const { inline TVMArgValue::operator HalideIR::Expr() const {
if (type_code_ == kNull) return Expr(); if (type_code_ == kNull) return Expr();
if (type_code_ == kDLInt) { if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return Expr(static_cast<int>(value_.v_int64)); return Expr(static_cast<int>(value_.v_int64));
} }
if (type_code_ == kDLFloat) { if (type_code_ == kDLFloat) {
...@@ -145,6 +148,20 @@ inline TVMArgValue::operator HalideIR::Expr() const { ...@@ -145,6 +148,20 @@ inline TVMArgValue::operator HalideIR::Expr() const {
return Expr(sptr); return Expr(sptr);
} }
inline TVMArgValue::operator tvm::Integer() const {
if (type_code_ == kNull) return Integer();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return Integer(static_cast<int>(value_.v_int64));
}
NodePtr<Node>& sptr = *ptr<NodePtr<Node> >();
CHECK(NodeTypeChecker<Integer>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<Expr>()
<< " but get " << sptr->type_key();
return Integer(sptr);
}
inline NodePtr<Node>& TVMArgValue::node_sptr() { inline NodePtr<Node>& TVMArgValue::node_sptr() {
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
return *ptr<NodePtr<Node> >(); return *ptr<NodePtr<Node> >();
......
...@@ -317,7 +317,7 @@ struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> { ...@@ -317,7 +317,7 @@ struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
/*! \brief Attributes for LRN operator */ /*! \brief Attributes for LRN operator */
struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> { struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {
IndexExpr size; IndexExpr size;
IndexExpr axis; int axis;
double bias; double bias;
double alpha; double alpha;
double beta; double beta;
...@@ -340,7 +340,7 @@ struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> { ...@@ -340,7 +340,7 @@ struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {
/*! \brief Attributes for L2Normalize operator */ /*! \brief Attributes for L2Normalize operator */
struct L2NormalizeAttrs : public tvm::AttrsNode<L2NormalizeAttrs> { struct L2NormalizeAttrs : public tvm::AttrsNode<L2NormalizeAttrs> {
double eps; double eps;
Array<IndexExpr> axis; Array<Integer> axis;
TVM_DECLARE_ATTRS(L2NormalizeAttrs, "relay.attrs.L2NormalizeAttrs") { TVM_DECLARE_ATTRS(L2NormalizeAttrs, "relay.attrs.L2NormalizeAttrs") {
TVM_ATTR_FIELD(eps) TVM_ATTR_FIELD(eps)
......
...@@ -53,7 +53,7 @@ struct ConcatenateAttrs : public tvm::AttrsNode<ConcatenateAttrs> { ...@@ -53,7 +53,7 @@ struct ConcatenateAttrs : public tvm::AttrsNode<ConcatenateAttrs> {
/*! \brief Attributes used in transpose operators */ /*! \brief Attributes used in transpose operators */
struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> { struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
Array<IndexExpr> axes; Array<Integer> axes;
TVM_DECLARE_ATTRS(TransposeAttrs, "relay.attrs.TransposeAttrs") { TVM_DECLARE_ATTRS(TransposeAttrs, "relay.attrs.TransposeAttrs") {
TVM_ATTR_FIELD(axes) TVM_ATTR_FIELD(axes)
.describe("The target axes order, reverse order if not specified."); .describe("The target axes order, reverse order if not specified.");
...@@ -70,10 +70,10 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> { ...@@ -70,10 +70,10 @@ struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
}; // struct ReshapeAttrs }; // struct ReshapeAttrs
struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> { struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
IndexExpr axis; Integer axis;
TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") { TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
TVM_ATTR_FIELD(axis).set_default(NullValue<IndexExpr>()) TVM_ATTR_FIELD(axis).set_default(NullValue<Integer>())
.describe("The axis over which to select values."); .describe("The axis over which to select values.");
} }
}; };
......
...@@ -32,6 +32,9 @@ struct Expr; ...@@ -32,6 +32,9 @@ struct Expr;
#endif #endif
namespace tvm { namespace tvm {
// forward declarations
class Integer;
namespace runtime { namespace runtime {
// forward declarations // forward declarations
class TVMArgs; class TVMArgs;
...@@ -559,6 +562,7 @@ class TVMArgValue : public TVMPODValue_ { ...@@ -559,6 +562,7 @@ class TVMArgValue : public TVMPODValue_ {
inline bool IsNodeType() const; inline bool IsNodeType() const;
inline operator HalideIR::Type() const; inline operator HalideIR::Type() const;
inline operator HalideIR::Expr() const; inline operator HalideIR::Expr() const;
inline operator tvm::Integer() const;
// get internal node ptr, if it is node // get internal node ptr, if it is node
inline NodePtr<Node>& node_sptr(); inline NodePtr<Node>& node_sptr();
}; };
......
...@@ -317,7 +317,7 @@ TVM_REGISTER_NODE_TYPE(LRNAttrs); ...@@ -317,7 +317,7 @@ TVM_REGISTER_NODE_TYPE(LRNAttrs);
Expr MakeLRN(Expr data, Expr MakeLRN(Expr data,
IndexExpr size, IndexExpr size,
IndexExpr axis, int axis,
double alpha, double alpha,
double beta, double beta,
double bias) { double bias) {
...@@ -337,7 +337,7 @@ TVM_REGISTER_API("relay.op.nn._make.lrn") ...@@ -337,7 +337,7 @@ TVM_REGISTER_API("relay.op.nn._make.lrn")
}); });
RELAY_REGISTER_OP("nn.lrn") RELAY_REGISTER_OP("nn.lrn")
.describe(R"code(LRN layer. .describe(R"code(LRN layer.
Normalize the input in a local region across or within feature maps. Normalize the input in a local region across or within feature maps.
Each input value is divided by (1 + (\alpha/n) \sum_i x_i^2)^\beta, Each input value is divided by (1 + (\alpha/n) \sum_i x_i^2)^\beta,
...@@ -362,7 +362,7 @@ TVM_REGISTER_NODE_TYPE(L2NormalizeAttrs); ...@@ -362,7 +362,7 @@ TVM_REGISTER_NODE_TYPE(L2NormalizeAttrs);
Expr MakeL2Normalize(Expr data, Expr MakeL2Normalize(Expr data,
double eps, double eps,
Array<IndexExpr> axis) { Array<Integer> axis) {
auto attrs = make_node<L2NormalizeAttrs>(); auto attrs = make_node<L2NormalizeAttrs>();
attrs->eps = eps; attrs->eps = eps;
attrs->axis = std::move(axis); attrs->axis = std::move(axis);
......
...@@ -218,24 +218,23 @@ bool TransposeRel(const Array<Type>& types, ...@@ -218,24 +218,23 @@ bool TransposeRel(const Array<Type>& types,
} }
const auto* param = attrs.as<TransposeAttrs>(); const auto* param = attrs.as<TransposeAttrs>();
const int ndim = data->shape.size(); const int ndim = data->shape.size();
const Array<IndexExpr>& axes = param->axes; const Array<Integer>& axes = param->axes;
// check dimension match // check dimension match
CHECK(axes.empty() || static_cast<int>(axes.size()) == ndim) CHECK(!axes.defined() || static_cast<int>(axes.size()) == ndim)
<< "Dimension mismatch: axes has " << axes.size() << " elements" << "Dimension mismatch: axes has " << axes.size() << " elements"
<< ", but data.ndim = " << ndim; << ", but data.ndim = " << ndim;
// construct int_axes // construct int_axes
std::vector<int> int_axes; std::vector<int> int_axes;
int_axes.reserve(ndim); int_axes.reserve(ndim);
if (axes.empty()) { // used not defined to check if it is None.
if (!axes.defined()) {
for (int i = ndim - 1; i >= 0; --i) { for (int i = ndim - 1; i >= 0; --i) {
int_axes.push_back(i); int_axes.push_back(i);
} }
} else { } else {
std::vector<int> axis_used(ndim, 0); std::vector<int> axis_used(ndim, 0);
for (const IndexExpr& e : axes) { for (const Integer& e : axes) {
const int64_t *axis_ptr = as_const_int(e); int64_t axis = e;
CHECK(axis_ptr != nullptr);
int axis = *axis_ptr;
// sanity check for axis and ndim // sanity check for axis and ndim
CHECK(-ndim <= axis && axis < ndim) CHECK(-ndim <= axis && axis < ndim)
<< "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)" << "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)"
...@@ -245,7 +244,7 @@ bool TransposeRel(const Array<Type>& types, ...@@ -245,7 +244,7 @@ bool TransposeRel(const Array<Type>& types,
// sanity check for duplication // sanity check for duplication
CHECK(!axis_used[axis]) << "Duplicate axes in transpose: " << axis; CHECK(!axis_used[axis]) << "Duplicate axes in transpose: " << axis;
axis_used[axis] = 1; axis_used[axis] = 1;
int_axes.push_back(axis); int_axes.push_back(static_cast<int>(axis));
} }
} }
std::vector<IndexExpr> oshape; std::vector<IndexExpr> oshape;
...@@ -258,7 +257,7 @@ bool TransposeRel(const Array<Type>& types, ...@@ -258,7 +257,7 @@ bool TransposeRel(const Array<Type>& types,
} }
Expr MakeTranspose(Expr data, Expr MakeTranspose(Expr data,
Array<IndexExpr> axes) { Array<Integer> axes) {
auto attrs = make_node<TransposeAttrs>(); auto attrs = make_node<TransposeAttrs>();
attrs->axes = std::move(axes); attrs->axes = std::move(axes);
static const Op& op = Op::Get("transpose"); static const Op& op = Op::Get("transpose");
...@@ -401,7 +400,7 @@ bool TakeRel(const Array<Type>& types, ...@@ -401,7 +400,7 @@ bool TakeRel(const Array<Type>& types,
std::vector<IndexExpr> oshape; std::vector<IndexExpr> oshape;
const auto ndim_data = static_cast<int>(data->shape.size()); const auto ndim_data = static_cast<int>(data->shape.size());
const auto ndim_indices = static_cast<int>(indices->shape.size()); const auto ndim_indices = static_cast<int>(indices->shape.size());
auto axis = (*as_const_int(param->axis)); int axis = static_cast<int>(param->axis->value);
if (axis < 0) axis += ndim_data; if (axis < 0) axis += ndim_data;
CHECK_LE(axis, ndim_data) CHECK_LE(axis, ndim_data)
<< "axis should be with in data shape" << "axis should be with in data shape"
...@@ -424,9 +423,9 @@ bool TakeRel(const Array<Type>& types, ...@@ -424,9 +423,9 @@ bool TakeRel(const Array<Type>& types,
Expr MakeTake(Expr data, Expr MakeTake(Expr data,
Expr indices, Expr indices,
IndexExpr axis) { Integer axis) {
auto attrs = make_node<TakeAttrs>(); auto attrs = make_node<TakeAttrs>();
attrs->axis = axis; attrs->axis = std::move(axis);
static const Op& op = Op::Get("take"); static const Op& op = Op::Get("take");
return CallNode::make(op, {data, indices}, Attrs(attrs), {}); return CallNode::make(op, {data, indices}, Attrs(attrs), {});
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment