Commit b0b51f25 by Tianqi Chen Committed by Zhi

[REFACTOR][IR] attrs.h -> ir (#4709)

This PR moves attrs.h into the ir folder as it
can serve as a common infra for building ir dats structures.

We also moved common container(FloatImm) into ir/expr.h
parent 83da72f2
......@@ -677,13 +677,13 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high));
}
}
if (t.is_float()) return ir::FloatImmNode::make(t, static_cast<double>(value));
if (t.is_float()) return FloatImm(t, static_cast<double>(value));
// For now, we store const scalar values of custom datatypes within doubles; later, during the
// datatypes lowering pass, we will lower the value to its true representation in the format
// specified by the datatype.
// TODO(gus) when do we need to start worrying about doubles not being precise enough?
if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin)) {
return ir::FloatImmNode::make(t, static_cast<double>(value));
return FloatImm(t, static_cast<double>(value));
}
LOG(FATAL) << "cannot make const for type " << t;
return PrimExpr();
......
......@@ -37,25 +37,9 @@ namespace tvm {
namespace ir {
using IntImmNode = tvm::IntImmNode;
using FloatImmNode = tvm::FloatImmNode;
using VarNode = tvm::VarNode;
/*! \brief Floating point constants. */
class FloatImmNode : public PrimExprNode {
public:
/*! \brief The constant value content. */
double value;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
}
TVM_DLL static PrimExpr make(DataType t, double value);
static constexpr const char* _type_key = "FloatImm";
TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
};
/*! \brief String constants, only used in asserts. */
class StringImmNode : public PrimExprNode {
public:
......
......@@ -16,10 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/attrs.h
* \brief TVM attribute module
* \file tvm/ir/attrs.h
* \brief Helpers for attribute objects.
*
* This module enables declaration of named attributes
* which support default value setup and bound checking.
......@@ -42,20 +41,19 @@
*
* \sa AttrsNode, TVM_DECLARE_ATTRS, TVM_ATTR_FIELD
*/
#ifndef TVM_ATTRS_H_
#define TVM_ATTRS_H_
#ifndef TVM_IR_ATTRS_H_
#define TVM_IR_ATTRS_H_
#include <dmlc/common.h>
#include <tvm/ir/expr.h>
#include <tvm/runtime/packed_func.h>
#include <unordered_map>
#include <vector>
#include <functional>
#include <type_traits>
#include <string>
#include <utility>
#include "ir.h"
#include "base.h"
#include "expr.h"
#include "packed_func_ext.h"
namespace tvm {
/*!
......@@ -481,45 +479,36 @@ template<typename T>
inline void SetValue(T* ptr, const TVMArgValue& val) {
*ptr = val.operator T();
}
template<typename T>
inline void SetIntValue(T* ptr, const TVMArgValue& val) {
if (val.type_code() == kDLInt) {
*ptr = static_cast<T>(val.value().v_int64);
} else {
PrimExpr expr = val;
CHECK(expr.defined());
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<T>(op->value);
} else {
LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey();
}
IntImm expr = val;
*ptr = static_cast<T>(expr->value);
}
}
template<>
inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
if (val.type_code() == kStr) {
*ptr = val.operator std::string();
} else {
PrimExpr expr = val;
const ir::StringImmNode* op = expr.as<ir::StringImmNode>();
CHECK(op != nullptr);
*ptr = op->value;
LOG(FATAL) << "Expect str";
}
}
template<>
inline void SetValue(DataType* ptr, const TVMArgValue& val) {
*ptr = val.operator DataType();
}
template<>
inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
if (val.type_code() == kDLFloat || val.type_code() == kDLInt) {
*ptr = val.operator double();
} else {
PrimExpr expr = val;
ObjectRef expr = val;
CHECK(expr.defined());
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
if (const IntImmNode* op = expr.as<IntImmNode>()) {
*ptr = static_cast<double>(op->value);
} else if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
} else if (const FloatImmNode* op = expr.as<FloatImmNode>()) {
*ptr = static_cast<double>(op->value);
} else {
LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
......@@ -611,7 +600,7 @@ struct TypeName<uint64_t> {
template<>
struct TypeName<DataType> {
static constexpr const char* value = "Type";
static constexpr const char* value = "DataType";
};
template<>
......@@ -872,4 +861,4 @@ inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*)
}
} // namespace tvm
#endif // TVM_ATTRS_H_
#endif // TVM_IR_ATTRS_H_
......@@ -182,6 +182,56 @@ class IntImm : public PrimExpr {
};
/*!
* \brief Constant floating point literals in the program.
* \sa FloatImm
*/
class FloatImmNode : public PrimExprNode {
public:
/*! \brief The constant value content. */
double value;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
}
static constexpr const char* _type_key = "FloatImm";
TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
};
/*!
* \brief Managed reference class to FloatImmNode.
*
* \sa FloatImmNode
*/
class FloatImm : public PrimExpr {
public:
/*!
* \brief Constructor
*/
FloatImm() {}
/*!
* \brief constructor from node.
*/
explicit FloatImm(ObjectPtr<Object> node) : PrimExpr(node) {}
/*!
* \brief Constructor.
* \param dtype The data type of the value.
* \param value The internal value.
*/
TVM_DLL FloatImm(DataType dtype, double value);
/*!
* \brief Get pointer to the container.
* \return The pointer.
*/
const FloatImmNode* operator->() const {
return static_cast<const FloatImmNode*>(get());
}
/*! \brief type indicate the container type */
using ContainerType = FloatImmNode;
};
/*!
* \brief Base node of all non-primitive expressions.
*
* RelayExpr supports tensor types, functions and ADT as
......
......@@ -26,7 +26,7 @@
#define TVM_IR_OP_H_
#include <dmlc/registry.h>
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/type.h>
......@@ -296,7 +296,8 @@ class OpRegistry {
// return internal pointer to op.
inline OpNode* get();
// update the attribute OpMap
TVM_DLL void UpdateAttr(const std::string& key, TVMRetValue value,
TVM_DLL void UpdateAttr(const std::string& key,
runtime::TVMRetValue value,
int plevel);
};
......@@ -316,7 +317,7 @@ class GenericOpMap {
* \param op The key to the map
* \return the const reference to the content value.
*/
inline const TVMRetValue& operator[](const Op& op) const;
inline const runtime::TVMRetValue& operator[](const Op& op) const;
/*!
* \brief get the corresponding value element at op with default value.
* \param op The key to the map
......@@ -342,7 +343,7 @@ class GenericOpMap {
// the attribute field.
std::string attr_name_;
// internal data
std::vector<std::pair<TVMRetValue, int> > data_;
std::vector<std::pair<runtime::TVMRetValue, int> > data_;
// The value
GenericOpMap() = default;
};
......@@ -532,7 +533,7 @@ template <typename ValueType>
inline OpRegistry& OpRegistry::set_attr( // NOLINT(*)
const std::string& attr_name, const ValueType& value, int plevel) {
CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
TVMRetValue rv;
runtime::TVMRetValue rv;
rv = value;
UpdateAttr(attr_name, rv, plevel);
return *this;
......@@ -548,7 +549,8 @@ inline int GenericOpMap::count(const Op& op) const {
}
}
inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const {
inline const runtime::TVMRetValue&
GenericOpMap::operator[](const Op& op) const {
CHECK(op.defined());
const uint32_t idx = op->index_;
CHECK(idx < data_.size() && data_[idx].second != 0)
......
......@@ -27,7 +27,7 @@
#include <tvm/ir/type.h>
#include <tvm/ir/module.h>
#include <tvm/ir/env_func.h>
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
namespace tvm {
......
......@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ADT_H_
#define TVM_RELAY_ADT_H_
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/adt.h>
#include <string>
#include <functional>
......
......@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_ALGORITHM_H_
#define TVM_RELAY_ATTRS_ALGORITHM_H_
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <string>
......
......@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_ANNOTATION_H_
#define TVM_RELAY_ATTRS_ANNOTATION_H_
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <string>
namespace tvm {
......
......@@ -25,7 +25,7 @@
#ifndef TVM_RELAY_ATTRS_BITSERIAL_H_
#define TVM_RELAY_ATTRS_BITSERIAL_H_
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <string>
......
......@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_DEBUG_H_
#define TVM_RELAY_ATTRS_DEBUG_H_
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <string>
namespace tvm {
......
......@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_DEVICE_COPY_H_
#define TVM_RELAY_ATTRS_DEVICE_COPY_H_
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <string>
namespace tvm {
......
......@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_IMAGE_H_
#define TVM_RELAY_ATTRS_IMAGE_H_
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <string>
......
......@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_MEMORY_H_
#define TVM_RELAY_ATTRS_MEMORY_H_
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/expr.h>
#include <string>
......
......@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_NN_H_
#define TVM_RELAY_ATTRS_NN_H_
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <string>
......
......@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_REDUCE_H_
#define TVM_RELAY_ATTRS_REDUCE_H_
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <string>
namespace tvm {
......
......@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_TRANSFORM_H_
#define TVM_RELAY_ATTRS_TRANSFORM_H_
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <tvm/relay/expr.h>
#include <string>
......
......@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_VISION_H_
#define TVM_RELAY_ATTRS_VISION_H_
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <string>
......
......@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_EXPR_H_
#define TVM_RELAY_EXPR_H_
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/expr.h>
#include <string>
#include <functional>
......
......@@ -24,7 +24,7 @@
#ifndef TVM_RELAY_QNN_ATTRS_H_
#define TVM_RELAY_QNN_ATTRS_H_
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <string>
namespace tvm {
......
......@@ -26,6 +26,7 @@
#include <tvm/ir/type.h>
#include <tvm/ir/type_relation.h>
#include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir/env_func.h>
......@@ -34,7 +35,7 @@
#include <string>
#include "base.h"
#include "../attrs.h"
namespace tvm {
namespace relay {
......
......@@ -130,7 +130,6 @@ TVM_REGISTER_GLOBAL("make.CommReducer")
REGISTER_MAKE(Reduce);
REGISTER_MAKE(AttrStmt);
REGISTER_MAKE(FloatImm);
REGISTER_MAKE(StringImm);
REGISTER_MAKE(Add);
......
......@@ -23,7 +23,7 @@
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/runtime/registry.h>
......
......@@ -23,7 +23,7 @@
*/
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h>
#include <tvm/ir/env_func.h>
#include <tvm/packed_func_ext.h>
......
......@@ -102,7 +102,7 @@ inline PrimExpr TryConstFold<ir::AddNode>(PrimExpr a, PrimExpr b) {
if (pa && pb) return IntImm(rtype, pa->value + pb->value);
if (pa && pa->value == 0) return b;
if (pb && pb->value == 0) return a;
if (fa && fb) return FloatImmNode::make(rtype, fa->value + fb->value);
if (fa && fb) return FloatImm(rtype, fa->value + fb->value);
if (fa && fa->value == 0) return b;
if (fb && fb->value == 0) return a;
});
......@@ -115,7 +115,7 @@ inline PrimExpr TryConstFold<ir::SubNode>(PrimExpr a, PrimExpr b) {
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, pa->value - pb->value);
if (pb && pb->value == 0) return a;
if (fa && fb) return FloatImmNode::make(rtype, fa->value - fb->value);
if (fa && fb) return FloatImm(rtype, fa->value - fb->value);
if (fb && fb->value == 0) return a;
});
return PrimExpr();
......@@ -134,7 +134,7 @@ inline PrimExpr TryConstFold<ir::MulNode>(PrimExpr a, PrimExpr b) {
if (pb->value == 1) return a;
if (pb->value == 0) return b;
}
if (fa && fb) return FloatImmNode::make(rtype, fa->value * fb->value);
if (fa && fb) return FloatImm(rtype, fa->value * fb->value);
if (fa) {
if (fa->value == 1) return b;
if (fa->value == 0) return a;
......@@ -165,7 +165,7 @@ inline PrimExpr TryConstFold<ir::DivNode>(PrimExpr a, PrimExpr b) {
CHECK_NE(pb->value, 0) << "Divide by zero";
}
if (fa && fb && fb->value != 0) {
return FloatImmNode::make(rtype, fa->value / fb->value);
return FloatImm(rtype, fa->value / fb->value);
}
if (fa && fa->value == 0) return a;
if (fb) {
......@@ -210,7 +210,7 @@ inline PrimExpr TryConstFold<ir::FloorDivNode>(PrimExpr a, PrimExpr b) {
CHECK_NE(pb->value, 0) << "Divide by zero";
}
if (fa && fb && fb->value != 0) {
return FloatImmNode::make(rtype, std::floor(fa->value / fb->value));
return FloatImm(rtype, std::floor(fa->value / fb->value));
}
if (fa && fa->value == 0) return a;
if (fb) {
......@@ -244,7 +244,7 @@ inline PrimExpr TryConstFold<ir::MinNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, std::min(pa->value, pb->value));
if (fa && fb) return FloatImmNode::make(rtype, std::min(fa->value, fb->value));
if (fa && fb) return FloatImm(rtype, std::min(fa->value, fb->value));
});
if (a.same_as(b)) return a;
return PrimExpr();
......@@ -255,7 +255,7 @@ inline PrimExpr TryConstFold<ir::MaxNode>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
const DataType& rtype = a.dtype();
if (pa && pb) return IntImm(rtype, std::max(pa->value, pb->value));
if (fa && fb) return FloatImmNode::make(rtype, std::max(fa->value, fb->value));
if (fa && fb) return FloatImm(rtype, std::max(fa->value, fb->value));
});
if (a.same_as(b)) return a;
return PrimExpr();
......
......@@ -255,10 +255,10 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
feature_row.push_back(Array<PrimExpr>{std::string("_itervar_"), var});
Array<PrimExpr> attr{std::string("_attr_"),
FloatImmNode::make(DataType::Float(32), trans(fea.length)),
FloatImm(DataType::Float(32), trans(fea.length)),
IntImm(DataType::Int(32), fea.nest_level),
FloatImmNode::make(DataType::Float(32), trans(fea.topdown_product)),
FloatImmNode::make(DataType::Float(32), trans(fea.bottomup_product)),
FloatImm(DataType::Float(32), trans(fea.topdown_product)),
FloatImm(DataType::Float(32), trans(fea.bottomup_product)),
};
// one hot annotation
for (int i = 0; i < kNum; i++) {
......@@ -268,9 +268,9 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
// arithmetic
feature_row.push_back(Array<PrimExpr>{std::string("_arith_"),
FloatImmNode::make(DataType::Float(32), trans(fea.add_ct)),
FloatImmNode::make(DataType::Float(32), trans(fea.mul_ct)),
FloatImmNode::make(DataType::Float(32), trans(fea.div_ct)),
FloatImm(DataType::Float(32), trans(fea.add_ct)),
FloatImm(DataType::Float(32), trans(fea.mul_ct)),
FloatImm(DataType::Float(32), trans(fea.div_ct)),
});
// touch map
......@@ -283,12 +283,12 @@ void GetItervarFeature(Stmt stmt, bool take_log, Array<Array<Array<PrimExpr> > >
TouchPattern &v = fea.touch_feature[k];
feature_row.push_back(
Array<PrimExpr>{k,
FloatImmNode::make(DataType::Float(32), trans(v.stride)),
FloatImmNode::make(DataType::Float(32), trans(v.mod)),
FloatImmNode::make(DataType::Float(32), trans(v.count)),
FloatImmNode::make(DataType::Float(32), trans(v.reuse)),
FloatImmNode::make(DataType::Float(32), trans(v.thread_count)),
FloatImmNode::make(DataType::Float(32), trans(v.thread_reuse)),
FloatImm(DataType::Float(32), trans(v.stride)),
FloatImm(DataType::Float(32), trans(v.mod)),
FloatImm(DataType::Float(32), trans(v.count)),
FloatImm(DataType::Float(32), trans(v.reuse)),
FloatImm(DataType::Float(32), trans(v.thread_count)),
FloatImm(DataType::Float(32), trans(v.thread_reuse)),
});
}
......
......@@ -95,7 +95,7 @@ llvm::Value* CodeGenX86_64::VisitExpr_(const CastNode* op) {
ir::CallNode::PureIntrinsic)),
MakeValue(
ir::BroadcastNode::make(
ir::FloatImmNode::make(DataType::Float(32), 0), from.lanes())),
FloatImm(DataType::Float(32), 0), from.lanes())),
/*mask=*/MakeValue(IntImm(DataType::Int(16), -1)),
/*rounding-mode=*/MakeValue(IntImm(DataType::Int(32), 4)),
});
......
......@@ -27,8 +27,8 @@
* - array of attributes
* - map of attributes
*/
#ifndef TVM_LANG_ATTR_FUNCTOR_H_
#define TVM_LANG_ATTR_FUNCTOR_H_
#ifndef TVM_IR_ATTR_FUNCTOR_H_
#define TVM_IR_ATTR_FUNCTOR_H_
#include <tvm/node/functor.h>
#include <utility>
......@@ -230,4 +230,4 @@ class AttrsHashHandler :
}
};
} // namespace tvm
#endif // TVM_LANG_ATTR_FUNCTOR_H_
#endif // TVM_IR_ATTR_FUNCTOR_H_
......@@ -20,7 +20,7 @@
/*!
* \file attrs.cc
*/
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
......
......@@ -45,6 +45,22 @@ TVM_REGISTER_GLOBAL("make.IntImm")
return IntImm(dtype, value);
});
FloatImm::FloatImm(DataType dtype, double value) {
CHECK_EQ(dtype.lanes(), 1)
<< "ValueError: FloatImm can only take scalar.";
ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
node->dtype = dtype;
node->value = value;
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("make.FloatImm")
.set_body_typed([](DataType dtype, double value) {
return FloatImm(dtype, value);
});
GlobalVar::GlobalVar(std::string name_hint) {
ObjectPtr<GlobalVarNode> n = make_object<GlobalVarNode>();
n->name_hint = std::move(name_hint);
......
......@@ -36,6 +36,10 @@ DMLC_REGISTRY_ENABLE(::tvm::OpRegistry);
namespace tvm {
using runtime::TVMRetValue;
using runtime::TVMArgs;
using runtime::PackedFunc;
::dmlc::Registry<OpRegistry>* OpRegistry::Registry() {
return ::dmlc::Registry<OpRegistry>::Get();
}
......
......@@ -33,7 +33,7 @@ PrimExpr::PrimExpr(int32_t value)
: PrimExpr(IntImm(DataType::Int(32), value)) {}
PrimExpr::PrimExpr(float value)
: PrimExpr(ir::FloatImmNode::make(DataType::Float(32), value)) {}
: PrimExpr(FloatImm(DataType::Float(32), value)) {}
PrimExpr::PrimExpr(std::string str)
: PrimExpr(ir::StringImmNode::make(str)) {}
......
......@@ -108,11 +108,11 @@ PrimExpr max_value(const DataType& dtype) {
}
} else if (dtype.is_float()) {
if (dtype.bits() == 64) {
return FloatImmNode::make(dtype, std::numeric_limits<double>::max());
return FloatImm(dtype, std::numeric_limits<double>::max());
} else if (dtype.bits() == 32) {
return FloatImmNode::make(dtype, std::numeric_limits<float>::max());
return FloatImm(dtype, std::numeric_limits<float>::max());
} else if (dtype.bits() == 16) {
return FloatImmNode::make(dtype, 65504.0);
return FloatImm(dtype, 65504.0);
}
}
LOG(FATAL) << "Cannot decide max_value for type" << dtype;
......@@ -134,11 +134,11 @@ PrimExpr min_value(const DataType& dtype) {
return IntImm(dtype, 0);
} else if (dtype.is_float()) {
if (dtype.bits() == 64) {
return FloatImmNode::make(dtype, std::numeric_limits<double>::lowest());
return FloatImm(dtype, std::numeric_limits<double>::lowest());
} else if (dtype.bits() == 32) {
return FloatImmNode::make(dtype, std::numeric_limits<float>::lowest());
return FloatImm(dtype, std::numeric_limits<float>::lowest());
} else if (dtype.bits() == 16) {
return FloatImmNode::make(dtype, -65504.0);
return FloatImm(dtype, -65504.0);
}
}
LOG(FATAL) << "Cannot decide min_value for type" << dtype;
......@@ -219,7 +219,7 @@ PrimExpr operator-(PrimExpr a) {
const IntImmNode* pa = a.as<IntImmNode>();
const FloatImmNode* fa = a.as<FloatImmNode>();
if (pa) return IntImm(a.dtype(), -pa->value);
if (fa) return ir::FloatImmNode::make(a.dtype(), -fa->value);
if (fa) return FloatImm(a.dtype(), -fa->value);
return make_zero(a.dtype()) - a;
}
......@@ -492,7 +492,7 @@ PrimExpr abs(PrimExpr x) {
using ir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) {
return ir::FloatImmNode::make(x.dtype(), std::fabs(fx->value));
return FloatImm(x.dtype(), std::fabs(fx->value));
}
return ir::CallNode::make(x.dtype(), "fabs", {x}, ir::CallNode::PureIntrinsic);
} else if (x.dtype().is_uint()) {
......@@ -593,28 +593,28 @@ PrimExpr fmod(PrimExpr x, PrimExpr y) {
PrimExpr floor(PrimExpr x) {
using ir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImmNode::make(x.dtype(), std::floor(fx->value));
if (fx) return FloatImm(x.dtype(), std::floor(fx->value));
return ir::CallNode::make(x.dtype(), "floor", {x}, ir::CallNode::PureIntrinsic);
}
PrimExpr ceil(PrimExpr x) {
using ir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImmNode::make(x.dtype(), std::ceil(fx->value));
if (fx) return FloatImm(x.dtype(), std::ceil(fx->value));
return ir::CallNode::make(x.dtype(), "ceil", {x}, ir::CallNode::PureIntrinsic);
}
PrimExpr round(PrimExpr x) {
using ir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImmNode::make(x.dtype(), std::nearbyint(fx->value));
if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
return ir::CallNode::make(x.dtype(), "round", {x}, ir::CallNode::PureIntrinsic);
}
PrimExpr nearbyint(PrimExpr x) {
using ir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) return FloatImmNode::make(x.dtype(), std::nearbyint(fx->value));
if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value));
return ir::CallNode::make(x.dtype(), "nearbyint", {x}, ir::CallNode::PureIntrinsic);
}
......@@ -622,7 +622,7 @@ PrimExpr trunc(PrimExpr x) {
using ir::FloatImmNode;
const FloatImmNode* fx = x.as<FloatImmNode>();
if (fx) {
return FloatImmNode::make(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) :
return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) :
std::floor(fx->value)));
}
return ir::CallNode::make(x.dtype(), "trunc", {x}, ir::CallNode::PureIntrinsic);
......
......@@ -32,7 +32,7 @@ namespace ir {
// constructors
PrimExpr FloatImmNode::make(DataType t, double value) {
PrimExpr FloatImm(DataType t, double value) {
CHECK_EQ(t.lanes(), 1)
<< "ValueError: FloatImm can only take scalar";
ObjectPtr<FloatImmNode> node = make_object<FloatImmNode>();
......
......@@ -25,10 +25,14 @@
#include <tvm/node/node.h>
#include <tvm/node/container.h>
#include <tvm/node/reflection.h>
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
namespace tvm {
using runtime::TVMRetValue;
using runtime::TVMArgs;
using runtime::PackedFunc;
// Attr getter.
class AttrGetter : public AttrVisitor {
public:
......
......@@ -29,7 +29,7 @@
#include <tvm/node/container.h>
#include <tvm/node/reflection.h>
#include <tvm/node/serialization.h>
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <string>
#include <map>
......
......@@ -25,7 +25,7 @@
#define TVM_PASS_STORAGE_ACCESS_H_
#include <tvm/ir.h>
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_functor_ext.h>
#include <vector>
......
......@@ -29,7 +29,7 @@
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/nn.h>
#include "type_functor.h"
#include "../../lang/attr_functor.h"
#include "../../ir/attr_functor.h"
namespace tvm {
namespace relay {
......
......@@ -26,9 +26,9 @@
#include <tvm/relay/pattern_functor.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/relay/analysis.h>
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include "type_functor.h"
#include "../../lang/attr_functor.h"
#include "../../ir/attr_functor.h"
namespace tvm {
namespace relay {
......
......@@ -38,7 +38,7 @@
#include "doc.h"
#include "type_functor.h"
#include "../pass/dependency_graph.h"
#include "../../lang/attr_functor.h"
#include "../../ir/attr_functor.h"
namespace tvm {
namespace relay {
......
......@@ -19,7 +19,9 @@
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/expr_operator.h>
#include <tvm/packed_func_ext.h>
#include <tvm/ir.h>
namespace tvm {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment