Unverified Commit 88d2f34b by Tianqi Chen Committed by GitHub

[TIR] Introduce BufferLoad/Store (#5205)

Co-authored-by: Siyuan Feng <hzfengsy@sjtu.edu.cn>

This PR introduces BufferLoad/Store to TIR. The new nodes will replace
Provide and Call with Tensor arguments in the subsequent refactors.
parent 4e5c5843
...@@ -25,9 +25,8 @@ ...@@ -25,9 +25,8 @@
#define TVM_TIR_BUFFER_H_ #define TVM_TIR_BUFFER_H_
#include <tvm/node/container.h> #include <tvm/node/container.h>
#include <tvm/tir/expr.h> #include <tvm/ir/expr.h>
#include <tvm/tir/op.h> #include <tvm/tir/var.h>
#include <string> #include <string>
...@@ -36,6 +35,9 @@ namespace tir { ...@@ -36,6 +35,9 @@ namespace tir {
// Internal node container Buffer // Internal node container Buffer
class BufferNode; class BufferNode;
// forward declare Stmt
class Stmt;
/*! \brief buffer type */ /*! \brief buffer type */
enum BufferType : int { enum BufferType : int {
kDefault = 1, kDefault = 1,
...@@ -75,9 +77,9 @@ class Buffer : public ObjectRef { ...@@ -75,9 +77,9 @@ class Buffer : public ObjectRef {
* \param offset The offset of ptr. * \param offset The offset of ptr.
*/ */
TVM_DLL PrimExpr access_ptr(int access_mask, TVM_DLL PrimExpr access_ptr(int access_mask,
DataType ptr_type = DataType::Handle(), DataType ptr_type = DataType::Handle(),
int content_lanes = 1, int content_lanes = 1,
PrimExpr offset = make_const(DataType::Int(32), 0)) const; PrimExpr offset = IntImm(DataType::Int(32), 0)) const;
/*! /*!
* \brief Create an Expr that does a vector load at begin index. * \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index * \param begin The beginning index
......
...@@ -121,6 +121,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> { ...@@ -121,6 +121,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
virtual R VisitExpr_(const SizeVarNode* op, Args... args) { virtual R VisitExpr_(const SizeVarNode* op, Args... args) {
return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...); return VisitExpr_(static_cast<const VarNode*>(op), std::forward<Args>(args)...);
} }
virtual R VisitExpr_(const BufferLoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const LoadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const LetNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const CallNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
...@@ -164,6 +165,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> { ...@@ -164,6 +165,7 @@ class ExprFunctor<R(const PrimExpr& n, Args...)> {
IR_EXPR_FUNCTOR_DISPATCH(VarNode); IR_EXPR_FUNCTOR_DISPATCH(VarNode);
IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode); IR_EXPR_FUNCTOR_DISPATCH(SizeVarNode);
IR_EXPR_FUNCTOR_DISPATCH(LoadNode); IR_EXPR_FUNCTOR_DISPATCH(LoadNode);
IR_EXPR_FUNCTOR_DISPATCH(BufferLoadNode);
IR_EXPR_FUNCTOR_DISPATCH(LetNode); IR_EXPR_FUNCTOR_DISPATCH(LetNode);
IR_EXPR_FUNCTOR_DISPATCH(CallNode); IR_EXPR_FUNCTOR_DISPATCH(CallNode);
IR_EXPR_FUNCTOR_DISPATCH(AddNode); IR_EXPR_FUNCTOR_DISPATCH(AddNode);
...@@ -214,6 +216,7 @@ class TVM_DLL ExprVisitor : ...@@ -214,6 +216,7 @@ class TVM_DLL ExprVisitor :
void VisitExpr_(const VarNode* op) override; void VisitExpr_(const VarNode* op) override;
void VisitExpr_(const SizeVarNode* op) override; void VisitExpr_(const SizeVarNode* op) override;
void VisitExpr_(const LoadNode* op) override; void VisitExpr_(const LoadNode* op) override;
void VisitExpr_(const BufferLoadNode* op) override;
void VisitExpr_(const LetNode* op) override; void VisitExpr_(const LetNode* op) override;
void VisitExpr_(const CallNode* op) override; void VisitExpr_(const CallNode* op) override;
void VisitExpr_(const AddNode* op) override; void VisitExpr_(const AddNode* op) override;
...@@ -259,6 +262,7 @@ class TVM_DLL ExprMutator : ...@@ -259,6 +262,7 @@ class TVM_DLL ExprMutator :
PrimExpr VisitExpr_(const VarNode* op) override; PrimExpr VisitExpr_(const VarNode* op) override;
PrimExpr VisitExpr_(const SizeVarNode* op) override; PrimExpr VisitExpr_(const SizeVarNode* op) override;
PrimExpr VisitExpr_(const LoadNode* op) override; PrimExpr VisitExpr_(const LoadNode* op) override;
PrimExpr VisitExpr_(const BufferLoadNode* op) override;
PrimExpr VisitExpr_(const LetNode* op) override; PrimExpr VisitExpr_(const LetNode* op) override;
PrimExpr VisitExpr_(const CallNode* op) override; PrimExpr VisitExpr_(const CallNode* op) override;
PrimExpr VisitExpr_(const AddNode* op) override; PrimExpr VisitExpr_(const AddNode* op) override;
......
...@@ -275,6 +275,57 @@ class StoreNode : public StmtNode { ...@@ -275,6 +275,57 @@ class StoreNode : public StmtNode {
}; };
/*! /*!
* \brief Store value to the high dimension buffer.
*
* \code
*
* buffer[i, j] = value;
*
* \endcode
* \sa BufferLoad
*/
class BufferStore;
class BufferStoreNode : public StmtNode {
public:
/*! \brief The buffer variable. */
Buffer buffer;
/*! \brief The value to be stored. */
PrimExpr value;
/*! \brief The indices location to be stored. */
Array<PrimExpr> indices;
void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer", &buffer);
v->Visit("value", &value);
v->Visit("indices", &indices);
}
bool SEqualReduce(const BufferStoreNode* other, SEqualReducer equal) const {
return
equal(buffer, other->buffer) &&
equal(value, other->value) &&
equal(indices, other->indices);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer);
hash_reduce(value);
hash_reduce(indices);
}
static constexpr const char* _type_key = "BufferStore";
TVM_DECLARE_FINAL_OBJECT_INFO(BufferStoreNode, StmtNode);
};
class BufferStore : public Stmt {
public:
TVM_DLL explicit BufferStore(Buffer buffer,
PrimExpr value,
Array<PrimExpr> indices);
TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode);
};
/*!
* \brief Store value into mult-dimensional array defined by func. * \brief Store value into mult-dimensional array defined by func.
*/ */
class ProvideNode : public StmtNode { class ProvideNode : public StmtNode {
......
...@@ -91,6 +91,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> { ...@@ -91,6 +91,7 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const ForNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AllocateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const StoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const ProducerConsumerNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const ProducerConsumerNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
...@@ -154,6 +155,7 @@ class TVM_DLL StmtVisitor : ...@@ -154,6 +155,7 @@ class TVM_DLL StmtVisitor :
void VisitStmt_(const ForNode* op) override; void VisitStmt_(const ForNode* op) override;
void VisitStmt_(const AllocateNode* op) override; void VisitStmt_(const AllocateNode* op) override;
void VisitStmt_(const StoreNode* op) override; void VisitStmt_(const StoreNode* op) override;
void VisitStmt_(const BufferStoreNode* op) override;
void VisitStmt_(const FreeNode* op) override; void VisitStmt_(const FreeNode* op) override;
void VisitStmt_(const AssertStmtNode* op) override; void VisitStmt_(const AssertStmtNode* op) override;
void VisitStmt_(const ProducerConsumerNode* op) override; void VisitStmt_(const ProducerConsumerNode* op) override;
...@@ -248,6 +250,7 @@ class TVM_DLL StmtMutator : ...@@ -248,6 +250,7 @@ class TVM_DLL StmtMutator :
Stmt VisitStmt_(const ForNode* op) override; Stmt VisitStmt_(const ForNode* op) override;
Stmt VisitStmt_(const AllocateNode* op) override; Stmt VisitStmt_(const AllocateNode* op) override;
Stmt VisitStmt_(const StoreNode* op) override; Stmt VisitStmt_(const StoreNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
Stmt VisitStmt_(const FreeNode* op) override; Stmt VisitStmt_(const FreeNode* op) override;
Stmt VisitStmt_(const AssertStmtNode* op) override; Stmt VisitStmt_(const AssertStmtNode* op) override;
Stmt VisitStmt_(const ProducerConsumerNode* op) override; Stmt VisitStmt_(const ProducerConsumerNode* op) override;
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/tir/var.h
* \brief Variables in the TIR.
*/
#ifndef TVM_TIR_VAR_H_
#define TVM_TIR_VAR_H_
#include <tvm/node/node.h>
#include <tvm/runtime/data_type.h>
#include <tvm/ir/expr.h>
#include <string>
namespace tvm {
namespace tir {
/*!
* \brief A variable node in the IR.
*
* A variable is uniquely identified by its address.
*
* Each variable is only binded once in the following nodes:
* - Allocate
* - For
* - Let
* - LetStmt
*/
class VarNode : public PrimExprNode {
public:
/*!
* \brief The hint to the variable name.
* \note Each variable is uniquely identified by its address.
*/
std::string name_hint;
/*!
* \brief type annotaion of the variable.
*
* It is an optional field that provides a refined type of the variable than dtype.
*
* \sa tvm/ir/type.h for discussion of relations between runtime::DataType and Type.
*/
Type type_annotation;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("name", &name_hint);
v->Visit("type_annotation", &type_annotation);
}
bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
if (!equal(dtype, other->dtype)) return false;
if (!equal(type_annotation, other->type_annotation)) return false;
return equal.FreeVarEqualImpl(this, other);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dtype);
hash_reduce(type_annotation);
hash_reduce.FreeVarHashImpl(this);
}
static constexpr const char* _type_key = "tir.Var";
TVM_DECLARE_BASE_OBJECT_INFO(VarNode, PrimExprNode);
};
/*! \brief a named variable in TVM */
class Var : public PrimExpr {
public:
explicit Var(ObjectPtr<Object> n) : PrimExpr(n) {}
/*!
* \brief Constructor
* \param name_hint variable name
* \param dtype data type
*/
TVM_DLL explicit Var(std::string name_hint = "v",
DataType dtype = DataType::Int(32));
/*!
* \brief Constructor which provides a more detailed type annotation.
* \param name_hint variable name.
* \param type_annotation The type annotation.
*/
TVM_DLL explicit Var(std::string name_hint, Type type_annotation);
/*!
* \brief Make a new copy of var with same type, append suffix
* \param suffix The suffix to be appended.
* \return the new Var copy
*/
TVM_DLL Var copy_with_suffix(const std::string& suffix) const;
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
const VarNode* operator->() const {
return get();
}
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
const VarNode* get() const {
return static_cast<const VarNode*>(data_.get());
}
/*! \brief type indicate the container type */
using ContainerType = VarNode;
};
/*!
* \brief A variable node represent a tensor index size,
* whose value must be non-negative.
*/
class SizeVarNode : public VarNode {
public:
static constexpr const char* _type_key = "tir.SizeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode);
};
/*! \brief a named variable represents a tensor index size */
class SizeVar : public Var {
public:
explicit SizeVar(ObjectPtr<Object> n) : Var(n) {}
/*!
* \brief constructor
* \param name_hint variable name
* \param t data type
*/
TVM_DLL explicit SizeVar(std::string name_hint = "s",
DataType t = DataType::Int(32));
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
const SizeVarNode* operator->() const {
return get();
}
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
*/
const SizeVarNode* get() const {
return static_cast<const SizeVarNode*>(data_.get());
}
/*! \brief type indicate the container type */
using ContainerType = SizeVarNode;
};
/*! \brief container class of iteration variable. */
class IterVarNode;
using Region = Array<Range>;
/*!
* \brief Type of iteration variable.
* Each IterVar have a specific type.
*
* The type of iter var can be overriden via
* stage.iter_var_attrs given they are compatible.
*/
enum IterVarType : int {
/*!
* \brief Data parallel iteration.
* This normally corresponds to axis of Tensor.
* Allow all IterVar manipulations.
*
* \note This does not mean the loop
* have to be executed in parallel fashion.
*/
kDataPar = 0,
/*!
* \brief The IterVar itself is a thread-index
* of a fixed thread launching group.
* Note that this is already assumed to be paralellized.
*
* Disallow: split/fuse/vectorize/parallel
*/
kThreadIndex = 1,
/*!
* \brief Communicative reduction.
* Cannot be directly parallelized.
*
* Disallow: parallel/vectorize
*/
kCommReduce = 2,
/*!
* \brief Serial loops with loop carry dependency,
* the iteration must execute in order.
* Cannot be re-ordered.
*
* Disallow: reorder/parallel/vectorize
*/
kOrdered = 3,
/*!
* \brief IterVar is opaque,
*
* May not corresponds to any generated loop
* Disallow all IterVar manipulations and compute_at
*
* \note This is usually used to implement composite op
* or external op, where the
*/
kOpaque = 4,
// The following are possible additional
// types that are provided during schedule
/*!
* \brief The execution is unrolled.
*/
kUnrolled = 5,
/*!
* \brief The loop is vectorized.
*/
kVectorized = 6,
/*!
* \brief The loop is parallelized.
*/
kParallelized = 7,
/*!
* \brief Marks boundary of tensorization intrinsic.
*/
kTensorized = 8
};
/*!
* \brief Iteration Variable,
* represents an iteration over an integer interval.
*/
class IterVar : public ObjectRef {
public:
// construct a new iter var without a domain
IterVar() {}
// construct from shared ptr.
explicit IterVar(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const IterVarNode* operator->() const;
/*!
* \return the corresponding var in the IterVar.
*/
inline operator PrimExpr() const;
/*! \brief specify container node */
using ContainerType = IterVarNode;
};
using Domain = Array<Range>;
/*!
* \brief An iteration variable representing an iteration
* over a one dimensional interval.
*/
class IterVarNode : public Object {
public:
/*!
* \brief the domain of iteration, if known, can be None
* For the intermediate schedule node, before schedule.
*/
Range dom;
/*! \brief The looping variable */
Var var;
/*! \brief The type of the IterVar */
IterVarType iter_type;
/*!
* \brief additional tag on the iteration variable,
* set this if this is binded already to a known thread tag.
*/
std::string thread_tag;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dom", &dom);
v->Visit("var", &var);
v->Visit("iter_type", &iter_type);
v->Visit("thread_tag", &thread_tag);
}
bool SEqualReduce(const IterVarNode* other, SEqualReducer equal) const {
return
equal(dom, other->dom) &&
equal.DefEqual(var, other->var) &&
equal(iter_type, other->iter_type) &&
equal(thread_tag, other->thread_tag);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(dom);
hash_reduce.DefHash(var);
hash_reduce(iter_type);
hash_reduce(thread_tag);
}
TVM_DLL static IterVar make(Range dom, Var var,
IterVarType iter_type,
std::string thread_tag = "");
static constexpr const char* _type_key = "IterVar";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(IterVarNode, Object);
};
// inline implementations
inline const IterVarNode* IterVar::operator->() const {
return static_cast<const IterVarNode*>(data_.get());
}
inline IterVar::operator PrimExpr() const {
return (*this)->var;
}
inline const char* IterVarType2String(IterVarType t) {
switch (t) {
case kDataPar: return "DataPar";
case kThreadIndex: return "ThreadIndex";
case kCommReduce: return "CommReduce";
case kOrdered: return "Ordered";
case kOpaque: return "Opaque";
case kUnrolled: return "Unrolled";
case kVectorized: return "Vectorized";
case kParallelized: return "Parallelized";
case kTensorized: return "Tensorized";
}
return "Unknown";
}
} // namespace tir
} // namespace tvm
#endif // TVM_TIR_VAR_H_
...@@ -24,11 +24,11 @@ from .data_layout import Layout, BijectiveLayout, bijective_layout, layout ...@@ -24,11 +24,11 @@ from .data_layout import Layout, BijectiveLayout, bijective_layout, layout
from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast from .expr import Var, SizeVar, Reduce, FloatImm, IntImm, StringImm, Cast
from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod from .expr import Add, Sub, Mul, Div, Mod, FloorDiv, FloorMod
from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not from .expr import Min, Max, EQ, NE, LT, LE, GT, GE, And, Or, Not
from .expr import Select, Load, Ramp, Broadcast, Shuffle, Call, Let from .expr import Select, BufferLoad, Load, Ramp, Broadcast, Shuffle, Call, Let
from .expr import IterVar, Any from .expr import IterVar, Any
from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For
from .stmt import Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt from .stmt import BufferStore, Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list from .stmt import IfThenElse, Evaluate, Prefetch, LoweredFunc, stmt_seq, stmt_list
from .function import PrimFunc from .function import PrimFunc
......
...@@ -14,18 +14,15 @@ ...@@ -14,18 +14,15 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
"""Expression AST Node in TVM. # pylint: disable=redefined-builtin
"""TIR expression nodes.
User do not need to deal with expression AST node directly.
But they can be helpful for developer to do quick proptyping.
While not displayed in the document and python file.
Each expression node have subfields that can be visited from python side. Each expression node have subfields that can be visited from python side.
For example, you can use addexp.a to get the left operand of an Add node. For example, you can use addexp.a to get the left operand of an Add node.
.. code-block:: python .. code-block:: python
x = te.var("n") x = tvm.tir.Var("n", "int32")
y = x + 2 y = x + 2
assert(isinstance(y, tvm.tir.Add)) assert(isinstance(y, tvm.tir.Add))
assert(y.a == x) assert(y.a == x)
...@@ -859,6 +856,23 @@ class Load(PrimExprWithOp): ...@@ -859,6 +856,23 @@ class Load(PrimExprWithOp):
@tvm._ffi.register_object @tvm._ffi.register_object
class BufferLoad(PrimExprWithOp):
"""Buffer load node.
Parameters
----------
buffer : Buffer
The buffer to be loaded.
indices : List[PrimExpr]
The buffer indices.
"""
def __init__(self, buffer, indices):
self.__init_handle_by_constructor__(
_ffi_api.BufferLoad, buffer, indices)
@tvm._ffi.register_object
class Ramp(PrimExprWithOp): class Ramp(PrimExprWithOp):
"""Ramp node. """Ramp node.
......
...@@ -16,15 +16,12 @@ ...@@ -16,15 +16,12 @@
# under the License. # under the License.
"""Statement AST Node in TVM. """Statement AST Node in TVM.
User do not need to deal with AST node directly.
But they can be helpful for developer to do quick proptyping.
While not displayed in the document and python file.
Each statement node have subfields that can be visited from python side. Each statement node have subfields that can be visited from python side.
.. code-block:: python .. code-block:: python
x = te.var("n") x = tvm.tir.Var("n", "int32")
a = te.var("array", "handle") a = tvm.tir.Var("array", "handle")
st = tvm.tir.stmt.Store(a, x + 1, 1) st = tvm.tir.stmt.Store(a, x + 1, 1)
assert isinstance(st, tvm.tir.stmt.Store) assert isinstance(st, tvm.tir.stmt.Store)
assert(st.buffer_var == a) assert(st.buffer_var == a)
...@@ -164,6 +161,26 @@ class Store(Stmt): ...@@ -164,6 +161,26 @@ class Store(Stmt):
@tvm._ffi.register_object @tvm._ffi.register_object
class BufferStore(Stmt):
"""Buffer store node.
Parameters
----------
buffer : Buffer
The buffer.
value : PrimExpr
The value we to be stored.
indices : List[PrimExpr]
The indices location to be stored.
"""
def __init__(self, buffer, value, indices):
self.__init_handle_by_constructor__(
_ffi_api.BufferStore, buffer, value, indices)
@tvm._ffi.register_object
class Provide(Stmt): class Provide(Stmt):
"""Provide node. """Provide node.
......
...@@ -407,6 +407,22 @@ PrimExpr AnyNode::make() { ...@@ -407,6 +407,22 @@ PrimExpr AnyNode::make() {
return PrimExpr(n); return PrimExpr(n);
} }
BufferLoad::BufferLoad(Buffer buffer, Array<PrimExpr> indices) {
ObjectPtr<BufferLoadNode> node = make_object<BufferLoadNode>();
node->dtype = buffer->dtype;
node->buffer = std::move(buffer);
node->indices = std::move(indices);
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.BufferLoad")
.set_body_typed([](Buffer buffer, Array<PrimExpr> indices) {
return BufferLoad(buffer, indices);
});
TVM_REGISTER_NODE_TYPE(BufferLoadNode);
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<StringImmNode>([](const ObjectRef& node, ReprPrinter* p) { .set_dispatch<StringImmNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const StringImmNode*>(node.get()); auto* op = static_cast<const StringImmNode*>(node.get());
......
...@@ -36,6 +36,10 @@ void ExprVisitor::VisitExpr_(const LoadNode* op) { ...@@ -36,6 +36,10 @@ void ExprVisitor::VisitExpr_(const LoadNode* op) {
this->VisitExpr(op->predicate); this->VisitExpr(op->predicate);
} }
void ExprVisitor::VisitExpr_(const BufferLoadNode* op) {
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
}
void ExprVisitor::VisitExpr_(const LetNode* op) { void ExprVisitor::VisitExpr_(const LetNode* op) {
this->VisitExpr(op->value); this->VisitExpr(op->value);
this->VisitExpr(op->body); this->VisitExpr(op->body);
...@@ -128,6 +132,16 @@ PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) { ...@@ -128,6 +132,16 @@ PrimExpr ExprMutator::VisitExpr_(const LoadNode* op) {
} }
} }
PrimExpr ExprMutator::VisitExpr_(const BufferLoadNode* op) {
auto fmutate = [this](const PrimExpr& e) { return this->VisitExpr(e); };
Array<PrimExpr> indices = MutateArray(op->indices, fmutate);
if (indices.same_as(op->indices)) {
return GetRef<PrimExpr>(op);
} else {
return BufferLoad(op->buffer, indices);
}
}
PrimExpr ExprMutator::VisitExpr_(const LetNode* op) { PrimExpr ExprMutator::VisitExpr_(const LetNode* op) {
PrimExpr value = this->VisitExpr(op->value); PrimExpr value = this->VisitExpr(op->value);
PrimExpr body = this->VisitExpr(op->body); PrimExpr body = this->VisitExpr(op->body);
......
...@@ -324,6 +324,21 @@ Stmt EvaluateNode::make(PrimExpr value) { ...@@ -324,6 +324,21 @@ Stmt EvaluateNode::make(PrimExpr value) {
TVM_REGISTER_GLOBAL("tir.Evaluate") TVM_REGISTER_GLOBAL("tir.Evaluate")
.set_body_typed(EvaluateNode::make); .set_body_typed(EvaluateNode::make);
BufferStore::BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
ObjectPtr<BufferStoreNode> node = make_object<BufferStoreNode>();
node->buffer = std::move(buffer);
node->value = std::move(value);
node->indices = std::move(indices);
data_ = std::move(node);
}
TVM_REGISTER_GLOBAL("tir.BufferStore")
.set_body_typed([](Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
return BufferStore(buffer, value, indices);
});
TVM_REGISTER_NODE_TYPE(BufferStoreNode);
// Printers // Printers
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
......
...@@ -160,6 +160,10 @@ void StmtVisitor::VisitStmt_(const StoreNode* op) { ...@@ -160,6 +160,10 @@ void StmtVisitor::VisitStmt_(const StoreNode* op) {
this->VisitExpr(op->predicate); this->VisitExpr(op->predicate);
} }
void StmtVisitor::VisitStmt_(const BufferStoreNode* op) {
VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); });
}
void StmtVisitor::VisitStmt_(const IfThenElseNode* op) { void StmtVisitor::VisitStmt_(const IfThenElseNode* op) {
this->VisitExpr(op->condition); this->VisitExpr(op->condition);
this->VisitStmt(op->then_case); this->VisitStmt(op->then_case);
...@@ -343,6 +347,17 @@ Stmt StmtMutator::VisitStmt_(const StoreNode* op) { ...@@ -343,6 +347,17 @@ Stmt StmtMutator::VisitStmt_(const StoreNode* op) {
} }
} }
Stmt StmtMutator::VisitStmt_(const BufferStoreNode* op) {
Array<PrimExpr> indices = Internal::Mutate(this, op->indices);
if (indices.same_as(op->indices)) {
return GetRef<Stmt>(op);
} else {
auto n = CopyOnWrite(op);
n->indices = std::move(indices);
return Stmt(n);
}
}
Stmt StmtMutator::VisitStmt_(const ProvideNode* op) { Stmt StmtMutator::VisitStmt_(const ProvideNode* op) {
Array<PrimExpr> args = Internal::Mutate(this, op->args); Array<PrimExpr> args = Internal::Mutate(this, op->args);
PrimExpr value = this->VisitExpr(op->value); PrimExpr value = this->VisitExpr(op->value);
......
...@@ -292,7 +292,18 @@ def test_vars(): ...@@ -292,7 +292,18 @@ def test_vars():
assert isinstance(ptype.element_type, tvm.ir.PrimType) assert isinstance(ptype.element_type, tvm.ir.PrimType)
def test_buffer_load_store():
b = tvm.tir.decl_buffer((10,), "float32")
x = tvm.tir.BufferLoad(b, [0])
assert isinstance(x, tvm.tir.BufferLoad)
assert x.dtype == "float32"
assert x.buffer == b
s = tvm.tir.BufferStore(b, 0.1, [0])
assert isinstance(s, tvm.tir.BufferStore)
if __name__ == "__main__": if __name__ == "__main__":
test_buffer_load_store()
test_vars() test_vars()
test_prim_func() test_prim_func()
test_cast() test_cast()
......
...@@ -166,6 +166,22 @@ def test_stmt(): ...@@ -166,6 +166,22 @@ def test_stmt():
assert consistent_equal(func2(), func2()) assert consistent_equal(func2(), func2())
def test_buffer_load_store():
b = tvm.tir.decl_buffer((10, 10), "float32")
x = tvm.tir.BufferLoad(b, [0, 1])
y = tvm.tir.BufferLoad(b, [0, 1])
z = tvm.tir.BufferLoad(b, [1, 2])
assert consistent_equal(y, x)
assert not consistent_equal(y, z)
i = tvm.tir.Var("x", "int32")
sx = tvm.tir.BufferStore(b, 0.1, [0, i])
sy = tvm.tir.BufferStore(b, 0.1, [0, i])
sz = tvm.tir.BufferStore(b, 0.1, [1, i])
assert consistent_equal(sy, sx)
assert not consistent_equal(sy, sz)
if __name__ == "__main__": if __name__ == "__main__":
test_exprs() test_exprs()
test_prim_func() test_prim_func()
...@@ -173,3 +189,4 @@ if __name__ == "__main__": ...@@ -173,3 +189,4 @@ if __name__ == "__main__":
test_array() test_array()
test_env_func() test_env_func()
test_stmt() test_stmt()
test_buffer_load_store()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment