Commit e342fc36 by guoyuxuan

add ir

parent ff5c1a8e
......@@ -15,9 +15,8 @@
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/data_type.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/var.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/var.h>
#include <tvm/tsl/tir/buffer.h>
#include <algorithm>
......@@ -30,7 +29,7 @@
namespace tvm {
namespace tir {
class TULoadNode:public PrimExprNode{
class TULoadNode : public PrimExprNode {
public:
TslDataProducer producer;
Array<PrimExpr> union_indices;
......@@ -39,7 +38,7 @@ class TULoadNode:public PrimExprNode{
v->Visit("producer", &producer);
v->Visit("union_indices", &union_indices);
}
//TODO:investigate SEQUAL/SHASH
// TODO:investigate SEQUAL/SHASH
bool SEqualReduce(const TULoadNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(producer, other->producer) &&
equal(union_indices, other->union_indices);
......@@ -55,16 +54,118 @@ class TULoadNode:public PrimExprNode{
TVM_DECLARE_FINAL_OBJECT_INFO(TULoadNode, PrimExprNode);
};
class TULoad:public PrimExpr{
class TULoad : public PrimExpr {
public:
TVM_DLL explicit TULoad(TslDataProducer producer, Array<PrimExpr> union_indices);
TVM_DEFINE_OBJECT_REF_METHODS(TULoad, PrimExpr, TULoadNode);
};
/* OpNode start (yuxguo)
* PrimExprNode -> (TslUnaryOpNode, TslBinaryOpNode) -> (TslTGemmOpNode, TslTAddOpNode, ...)
*/
template <typename T>
class TslBinaryOpNode : public PrimExprNode {
public:
/*! \brief The left operand. */
PrimExpr a;
/*! \brief The right operand. */
PrimExpr b;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->dtype));
v->Visit("a", &a);
v->Visit("b", &b);
}
bool SEqualReduce(const T* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(a, other->a) && equal(b, other->b);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(a);
hash_reduce(b);
}
TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode);
};
template <typename T>
class TslUnaryOpNode : public PrimExprNode {
public:
/*! \brief The operand. */
PrimExpr a;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &(this->dtype));
v->Visit("a", &a);
}
bool SEqualReduce(const T* other, SEqualReducer equal) const {
return equal(dtype, other->dtype) && equal(a, other->a);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(a);
}
TVM_DECLARE_FINAL_OBJECT_INFO(T, PrimExprNode);
};
class TslTGemmNode : public TslBinaryOpNode<TslTGemmNode> {
public:
static constexpr const char* _type_key = "tir.TslTGemm";
};
class TslTAddNode : public TslBinaryOpNode<TslTAddNode> {
public:
static constexpr const char* _type_key = "tir.TslTAdd";
};
class TslTWriteNode : public TslBinaryOpNode<TslTWriteNode> {
public:
static constexpr const char* _type_key = "tir.TslTWrite";
};
class TslTStoreNode : public TslBinaryOpNode<TslTStoreNode> {
public:
static constexpr const char* _type_key = "tir.TslTStore";
};
/* Op start (yuxguo) manage opnode
* TslOperation (TslTensorGemmOp, TslTensorAddOp, ...)
* usage: using TslComputeOp to specify compute type and expanded in scheduleOps pass
*/
class TslTGemm : public PrimExpr {
public:
TVM_DLL TslTGemm(PrimExpr a, PrimExpr b);
TVM_DEFINE_OBJECT_REF_METHODS(TslTGemm, PrimExpr, TslTGemmNode);
};
class TslTAdd : public PrimExpr {
public:
TVM_DLL TslTAdd(PrimExpr a, PrimExpr b);
TVM_DEFINE_OBJECT_REF_METHODS(TslTAdd, PrimExpr, TslTAddNode);
};
class TslTWrite : public PrimExpr {
public:
TVM_DLL TslTWrite(PrimExpr a, PrimExpr b);
TVM_DEFINE_OBJECT_REF_METHODS(TslTWrite, PrimExpr, TslTWriteNode);
};
class TslTStore : public PrimExpr {
public:
TVM_DLL TslTStore(PrimExpr a, PrimExpr b);
TVM_DEFINE_OBJECT_REF_METHODS(TslTStore, PrimExpr, TslTStoreNode);
};
} // namespace tir
} // namespace tvm
#endif // TVM_TSL_TIR_EXPR_H_
......@@ -7,7 +7,6 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tsl/tir/expr.h>
#include <limits>
......@@ -18,21 +17,34 @@
namespace tvm {
namespace tir {
#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \
Name::Name(PrimExpr a, PrimExpr b) { \
using T = Name::ContainerType; \
ICHECK(a.defined()) << "ValueError: a is undefined\n"; \
ICHECK(b.defined()) << "ValueError: b is undefined\n"; \
ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; \
ObjectPtr<T> node = make_object<T>(); \
node->dtype = a.dtype(); \
node->a = std::move(a); \
node->b = std::move(b); \
data_ = std::move(node); \
}
TULoad::TULoad(TslDataProducer producer, Array<PrimExpr> union_indices) {
ObjectPtr<TULoadNode> node=make_object<TULoadNode>();
node->dtype=producer->GetDataType();
node->producer=std::move(producer);
node->union_indices=std::move(union_indices);
data_=std::move(node);
ObjectPtr<TULoadNode> node = make_object<TULoadNode>();
node->dtype = producer->GetDataType();
node->producer = std::move(producer);
node->union_indices = std::move(union_indices);
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.TULoad")
.set_body_typed([](DataProducer producer, Array<PrimExpr> union_indices) {
.set_body_typed([](DataProducer producer, Array<PrimExpr> union_indices) {
return ProducerLoad(producer, union_indices);
});
});
TVM_REGISTER_NODE_TYPE(TULoadNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TULoadNode>([](const ObjectRef& node, ReprPrinter* p) {
.set_dispatch<TULoadNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const TULoadNode*>(node.get());
p->stream << op->producer->GetNameHint() << "[";
for (size_t i = 0; i < op->union_indices.size(); ++i) {
......@@ -42,8 +54,86 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
}
}
p->stream << "]";
});
// TslTGemm
TVM_DEFINE_BINOP_CONSTRUCTOR(TslTGemm);
TVM_REGISTER_GLOBAL("tir.TslTGemm").set_body_typed([](PrimExpr a, PrimExpr b) {
return TslTGemm(a, b);
});
TVM_REGISTER_NODE_TYPE(TslTGemmNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TslTGemmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const TslTGemmNode*>(node.get());
p->stream << "TslTGemm(";
p->Print(op->a);
p->stream << ", ";
p->Print(op->b);
p->stream << ')';
});
// TslTAdd
TVM_DEFINE_BINOP_CONSTRUCTOR(TslTAdd);
TVM_REGISTER_GLOBAL("tir.TslTAdd").set_body_typed([](PrimExpr a, PrimExpr b) {
return TslTAdd(a, b);
});
TVM_REGISTER_NODE_TYPE(TslTAddNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TslTAddNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const TslTAddNode*>(node.get());
p->stream << "TslTAdd(";
p->Print(op->a);
p->stream << ", ";
p->Print(op->b);
p->stream << ')';
});
// TslTWrite
TVM_DEFINE_BINOP_CONSTRUCTOR(TslTWrite);
TVM_REGISTER_GLOBAL("tir.TslTWrite").set_body_typed([](PrimExpr a, PrimExpr b) {
return TslTWrite(a, b);
});
TVM_REGISTER_NODE_TYPE(TslTWriteNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TslTWriteNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const TslTWriteNode*>(node.get());
p->stream << "TslTWrite(";
p->Print(op->a);
p->stream << ", ";
p->Print(op->b);
p->stream << ')';
});
// TslTStore
TVM_DEFINE_BINOP_CONSTRUCTOR(TslTStore);
TVM_REGISTER_GLOBAL("tir.TslTStore").set_body_typed([](PrimExpr a, PrimExpr b) {
return TslTStore(a, b);
});
TVM_REGISTER_NODE_TYPE(TslTStoreNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TslTStoreNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const TslTStoreNode*>(node.get());
p->stream << "TslTStore(";
p->Print(op->a);
p->stream << ", ";
p->Print(op->b);
p->stream << ')';
});
} // namespace tir
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment