Commit 14597ce4 by guoyuxuan

merge xujianxing fix.

parent ee64bda4
...@@ -7,13 +7,12 @@ ...@@ -7,13 +7,12 @@
#define TVM_TSL_TE_OPERATION_H_ #define TVM_TSL_TE_OPERATION_H_
#include <tvm/arith/analyzer.h> #include <tvm/arith/analyzer.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h> #include <tvm/te/schedule.h>
#include <tvm/te/tensor.h> #include <tvm/te/tensor.h>
#include <tvm/tir/buffer.h> #include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h> #include <tvm/tir/expr.h>
#include <tvm/tir/op.h> #include <tvm/tir/op.h>
#include <tvm/te/operation.h>
#include <tvm/tsl/te/tensor.h> #include <tvm/tsl/te/tensor.h>
#include <string> #include <string>
...@@ -24,7 +23,7 @@ namespace tvm { ...@@ -24,7 +23,7 @@ namespace tvm {
/*! \brief Tensor expression language DSL. */ /*! \brief Tensor expression language DSL. */
namespace te { namespace te {
class TVM_DLL TslOperationNode: public Object{ class TVM_DLL TslOperationNode : public Object {
public: public:
/*! \brief optional name of the operation */ /*! \brief optional name of the operation */
std::string name; std::string name;
...@@ -54,25 +53,24 @@ class TVM_DLL TslOperationNode: public Object{ ...@@ -54,25 +53,24 @@ class TVM_DLL TslOperationNode: public Object{
* \brief List all the input Tensors. * \brief List all the input Tensors.
* \return List of input tensors. * \return List of input tensors.
*/ */
virtual Array<Tensor> InputTensorUnions() const = 0; virtual Array<TensorUnion> InputTensorUnions() const = 0;
static constexpr const char* _type_key = "TslOperation"; static constexpr const char* _type_key = "TslOperation";
TVM_DECLARE_BASE_OBJECT_INFO(TslOperationNode, Object); TVM_DECLARE_BASE_OBJECT_INFO(TslOperationNode, Object);
}; };
class TslPlaceholderOpNode:public TslOperationNode{ class TslPlaceholderOpNode : public TslOperationNode {
public: public:
Array<PrimExpr> union_shape; Array<PrimExpr> union_shape;
Array<PrimExpr> elem_shape; Array<PrimExpr> elem_shape;
DataType dtype; DataType dtype;
int num_outputs() const final; int num_outputs() const final;
Array<IterVar> root_iter_vars() const final{return {};} Array<IterVar> root_iter_vars() const final { return {}; }
DataType output_dtype(size_t i) const final; DataType output_dtype(size_t i) const final;
Array<PrimExpr> output_unionshape(size_t i) const final; Array<PrimExpr> output_unionshape(size_t i) const final;
Array<PrimExpr> output_elemshape(size_t i) const final; Array<PrimExpr> output_elemshape(size_t i) const final;
Array<Tensor> InputTensorUnions() const final {return {};} Array<TensorUnion> InputTensorUnions() const final { return {}; }
void VisitAttrs(AttrVisitor* v) { void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name); v->Visit("name", &name);
v->Visit("tag", &tag); v->Visit("tag", &tag);
...@@ -87,18 +85,59 @@ class TslPlaceholderOpNode:public TslOperationNode{ ...@@ -87,18 +85,59 @@ class TslPlaceholderOpNode:public TslOperationNode{
class TslPlaceholderOp : public TslOperation { class TslPlaceholderOp : public TslOperation {
public: public:
TVM_DLL TslPlaceholderOp(std::string name, Array<PrimExpr> union_shape,Array<PrimExpr> elem_shape, DataType dtype); TVM_DLL TslPlaceholderOp(std::string name, Array<PrimExpr> union_shape,
Array<PrimExpr> elem_shape, DataType dtype);
TVM_DEFINE_OBJECT_REF_METHODS(TslPlaceholderOp, TslOperation, TslPlaceholderOpNode); TVM_DEFINE_OBJECT_REF_METHODS(TslPlaceholderOp, TslOperation, TslPlaceholderOpNode);
}; };
TVM_DLL TensorUnion tsl_placeholder(Array<PrimExpr> union_shape,Array<PrimExpr> elem_shape, DataType dtype, std::string name); TVM_DLL TensorUnion tsl_placeholder(Array<PrimExpr> union_shape, Array<PrimExpr> elem_shape,
DataType dtype, std::string name);
inline const TslOperationNode* TslOperation::operator->() const { inline const TslOperationNode* TslOperation::operator->() const {
return static_cast<const TslOperationNode*>(get()); return static_cast<const TslOperationNode*>(get());
} }
/*!
 * \brief Base node for Tsl compute-style operations.
 *
 * Carries the itervars over the union dimensions shared by all concrete
 * compute op nodes, and provides the common root-itervar / union-shape logic.
 */
class TVM_DLL TslBaseComputeOpNode : public TslOperationNode {
public:
/*! \brief IterVars over the data-parallel union dimensions. */
Array<IterVar> union_axis;
/*! \brief IterVars over the reduction union dimensions (may be empty). */
Array<IterVar> union_reduce_axis;
Array<IterVar> root_iter_vars() const final; // root union itervars
/*! \brief Union-dimension shape of output \p idx (extents of union_axis). */
Array<PrimExpr> output_unionshape(size_t idx) const final;
Array<PrimExpr> output_elemshape(size_t idx) const final { return {}; } // TODO:deal with this
static constexpr const char* _type_key = "TslBaseComputeOp";
TVM_DECLARE_BASE_OBJECT_INFO(TslBaseComputeOpNode, TslOperationNode);
};
/*!
 * \brief A Tsl operation that computes its outputs from expressions.
 *
 * One body expression per output; inputs are the TensorUnions loaded
 * inside the body (see InputTensorUnions).
 */
class TVM_DLL TslComputeOpNode : public TslBaseComputeOpNode {
public:
/*! \brief The compute expressions, one per output. */
Array<PrimExpr> body;
TslComputeOpNode() {}
/*! \brief Number of outputs, i.e. body.size(). */
int num_outputs() const final;
/*! \brief dtype of output \p i, taken from body[i]. */
DataType output_dtype(size_t i) const final;
/*! \brief Distinct TensorUnions read by the body, in first-use order. */
Array<TensorUnion> InputTensorUnions() const final;
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
v->Visit("union_axis", &union_axis);
v->Visit("union_reduce_axis", &union_reduce_axis);
v->Visit("body", &body);
}
static constexpr const char* _type_key = "TslComputeOp";
TVM_DECLARE_FINAL_OBJECT_INFO(TslComputeOpNode, TslBaseComputeOpNode);
};
class TslComputeOp : public TslOperation {
public:
TVM_DLL TslComputeOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<IterVar> union_axis, Array<PrimExpr> body);
TVM_DEFINE_OBJECT_REF_METHODS(TslComputeOp, TslOperation, ComputeOpNode);
};
/*!
 * \brief Create a TensorUnion by computing an expression over union dimensions.
 * \param union_shape Extent of each union dimension; one data-parallel itervar
 *        is created per dimension.
 * \param fcompute Callback mapping the itervar vars to the body expression.
 * \param name Operation name, defaults to "TU".
 * \param tag Scheduling tag.
 * \param attrs Auxiliary operation attributes.
 * \return Output 0 of the created TslComputeOp.
 */
TVM_DLL TensorUnion Tslcompute(Array<PrimExpr> union_shape, FCompute fcompute,
std::string name = "TU", std::string tag = "",
Map<String, ObjectRef> attrs = {});
} // namespace te } // namespace te
} // namespace tvm } // namespace tvm
......
...@@ -36,7 +36,7 @@ class TslOperation : public ObjectRef { ...@@ -36,7 +36,7 @@ class TslOperation : public ObjectRef {
using ContainerType = TslOperationNode; using ContainerType = TslOperationNode;
}; };
class TensorUnionNode : public TslDataProducerNode { class TensorUnionNode : public tvm::tir::TslDataProducerNode {
public: public:
Array<PrimExpr> union_shape; Array<PrimExpr> union_shape;
Array<PrimExpr> elem_shape; Array<PrimExpr> elem_shape;
...@@ -59,7 +59,7 @@ class TensorUnionNode : public TslDataProducerNode { ...@@ -59,7 +59,7 @@ class TensorUnionNode : public TslDataProducerNode {
TVM_DECLARE_FINAL_OBJECT_INFO(TensorUnionNode, TslDataProducerNode); TVM_DECLARE_FINAL_OBJECT_INFO(TensorUnionNode, TslDataProducerNode);
}; };
class TensorUnion : public TslDataProducer { class TensorUnion : public tvm::tir::TslDataProducer {
public: public:
TVM_DLL TensorUnion(Array<PrimExpr> union_shape, Array<PrimExpr> elem_shape, DataType dtype, TVM_DLL TensorUnion(Array<PrimExpr> union_shape, Array<PrimExpr> elem_shape, DataType dtype,
TslOperation op, int value_index); TslOperation op, int value_index);
...@@ -106,4 +106,22 @@ inline bool TensorUnion::operator!=(const TensorUnion& other) const { return !(* ...@@ -106,4 +106,22 @@ inline bool TensorUnion::operator!=(const TensorUnion& other) const { return !(*
} // namespace te } // namespace te
} // namespace tvm } // namespace tvm
namespace std {
/*! \brief Hash a TslOperation by object pointer identity. */
template <>
struct hash<::tvm::te::TslOperation> : public ::tvm::ObjectPtrHash {};
/*!
 * \brief Hash a TensorUnion by its source operation when one is defined
 *        (so all outputs of the same op share a hash bucket); otherwise
 *        fall back to hashing the TensorUnion object itself.
 */
template <>
struct hash<::tvm::te::TensorUnion> {
  std::size_t operator()(const ::tvm::te::TensorUnion& k) const {
    ::tvm::ObjectPtrHash hasher;
    if (k.defined() && k->op.defined()) {
      return hasher(k->op);
    }
    return hasher(k);
  }
};
}  // namespace std (fix: dropped the stray ';' after the closing brace)
#endif // TVM_TSL_TE_TENSOR_H_ #endif // TVM_TSL_TE_TENSOR_H_
//
// Created by bb on 2021/1/3.
//
#include <tvm/runtime/registry.h>
#include <tvm/tsl/te/operation.h>
#include <tvm/te/operation.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tsl/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <unordered_set>
namespace tvm {
namespace te {
// Root itervars are the data-parallel union axes followed by any reduction
// union axes.
Array<IterVar> TslBaseComputeOpNode::root_iter_vars() const {
  // Fast path: no reductions, the parallel axes alone are the roots.
  if (union_reduce_axis.size() == 0) {
    return union_axis;
  }
  Array<IterVar> all_axes = union_axis;
  for (const auto& reduce_iv : union_reduce_axis) {
    all_axes.push_back(reduce_iv);
  }
  return all_axes;
}
/*!
 * \brief Union-dimension shape of output \p idx.
 *
 * For now all outputs of a BaseComputeOp share the same union shape: the
 * extents of the data-parallel union axes, in axis order.
 */
Array<PrimExpr> TslBaseComputeOpNode::output_unionshape(size_t idx) const {
  // Cast avoids a signed/unsigned comparison; num_outputs() is non-negative.
  CHECK_LT(idx, static_cast<size_t>(num_outputs()));
  Array<PrimExpr> shape;
  for (const auto& ivar : this->union_axis) {
    const Range& r = ivar->dom;
    shape.push_back(r->extent);
  }
  return shape;
}
// Collect the distinct TensorUnions read by the body expressions, in the
// order they are first encountered during a post-order walk.
Array<TensorUnion> TslComputeOpNode::InputTensorUnions() const {
  Array<TensorUnion> inputs;
  std::unordered_set<TensorUnion> seen;
  auto collect = [&inputs, &seen](const ObjectRef& node) {
    const auto* load = node.as<tir::TULoadNode>();
    if (load == nullptr) {
      return;
    }
    TensorUnion tu = Downcast<TensorUnion>(load->producer);
    // insert().second is true only on first sight — keeps inputs deduplicated.
    if (seen.insert(tu).second) {
      inputs.push_back(tu);
    }
  };
  for (const auto& expr : body) {
    tir::PostOrderVisit(expr, collect);
  }
  return inputs;
}
/*! \brief dtype of output \p i: the dtype of the i-th body expression. */
DataType TslComputeOpNode::output_dtype(size_t i) const {
  // Cast avoids a signed/unsigned comparison; num_outputs() is non-negative.
  CHECK_LT(i, static_cast<size_t>(num_outputs()));
  return body[i].dtype();
}
// One output per body expression. Explicit cast documents the intentional
// size_t -> int narrowing (body sizes are far below INT_MAX in practice).
int TslComputeOpNode::num_outputs() const { return static_cast<int>(body.size()); }
/*!
 * \brief Construct a TslComputeOp node and take ownership of the arguments.
 *
 * If the first body expression is a top-level Reduce, its axes become the
 * op's union_reduce_axis (all bodies are expected to agree — enforcement is
 * the job of the not-yet-enabled verifier below).
 */
TslComputeOp::TslComputeOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
                           Array<IterVar> union_axis, Array<PrimExpr> body) {
  // Normalize: callers may pass an undefined attrs map.
  if (!attrs.defined()) {
    attrs = Map<String, ObjectRef>();
  }
  auto n = make_object<TslComputeOpNode>();
  n->name = std::move(name);
  n->tag = std::move(tag);
  n->attrs = std::move(attrs);
  n->union_axis = std::move(union_axis);
  n->body = std::move(body);
  // Guard against an empty body before inspecting body[0]; the single as<>
  // replaces the previous redundant IsInstance<> + as<> pair.
  if (n->body.size() != 0) {
    if (const auto* reduce = n->body[0].as<tir::ReduceNode>()) {
      n->union_reduce_axis = reduce->axis;
    }
  }
  // VerifyComputeOp(n.get()); //TODO: uncomment this after finishing verifier.
  data_ = std::move(n);
}
// Build a TslComputeOp from an fcompute callback: one data-parallel itervar
// "ax<i>" per union dimension, then return the op's first output.
TensorUnion Tslcompute(Array<PrimExpr> union_shape, FCompute fcompute, std::string name,
                       std::string tag, Map<String, ObjectRef> attrs) {
  std::vector<IterVar> axis;
  std::vector<Var> args;
  size_t ndim = union_shape.size();
  for (size_t dim = 0; dim < ndim; ++dim) {
    std::ostringstream axis_name;
    axis_name << "ax" << dim;
    // The loop variable doubles as the fcompute argument for this dimension.
    Var loop_var(axis_name.str(), union_shape[dim].dtype());
    axis.emplace_back(IterVar(Range(0, union_shape[dim]), loop_var, kDataPar));
    args.push_back(loop_var);
  }
  return TslComputeOp(name, tag, attrs, axis, {fcompute(args)}).output(0);
}
// FFI hook: expose the TslComputeOp constructor to the frontend under the
// global name "te.TslComputeOp".
TVM_REGISTER_GLOBAL("te.TslComputeOp")
.set_body_typed([](std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<IterVar> union_axis, Array<PrimExpr> body) {
return TslComputeOp(name, tag, attrs, union_axis, body);
});
/*class TslComputeVerifier final : protected tir::ExprVisitor {
public:
/// Special member functions
//@{
explicit ComputeVerifier(const ComputeOpNode* compute)
: compute_(compute), reduce_(compute->body[0].as<tir::ReduceNode>()) {}
virtual ~ComputeVerifier() = default;
ComputeVerifier(const ComputeVerifier&) = delete;
ComputeVerifier(ComputeVerifier&&) = delete;
ComputeVerifier& operator=(const ComputeVerifier&) = delete;
ComputeVerifier& operator=(ComputeVerifier&&) = delete;
//@}
/// Interface to perform compute verification
void Run() {
for (const PrimExpr e : compute_->body) {
// Check for consistency of top level reductions
const tir::ReduceNode* reduce = e.as<tir::ReduceNode>();
CHECK((reduce && reduce_) || (!reduce && !reduce_)) << "All ComputeOp should be consistent "
<< "with being Reduce operation or not.";
if (reduce && reduce_) {
CHECK(ReduceEqual(reduce, reduce_)) << "The Reduce inputs of ComputeOp should "
<< "have the same attribute except value_index";
}
level_ = 0;
ExprVisitor::VisitExpr(e);
}
}
protected:
/// Visitor implementation
//@{
void VisitExpr(const PrimExpr& n) final {
++level_;
ExprVisitor::VisitExpr(n);
--level_;
}
void VisitExpr_(const tir::ReduceNode* op) final {
// Check for non top level reductions
CHECK(0 == level_) << "Reductions are only allowed at the top level of compute. "
<< "Please create another tensor for further composition.";
}
//@}
private:
const ComputeOpNode* compute_{nullptr}; ///< ComputeOpNode to verify
const tir::ReduceNode* reduce_{nullptr}; ///< Top level Reduce operation
int level_{0}; ///< Level of op being processed
};
} // namespace*/ //TODO: complete this reduction verifier. it's not very significant in current setup though.
} // namespace te
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment