Commit 151707e0 by tqchen

check stmt in

parent dac6b528
......@@ -22,14 +22,29 @@ class NodeRef;
class UnaryOp;
class BinaryOp;
/*! \brief pointer type mask */
const int kPtrTypeMask = 16;
/*! \brief list of all supported data types */
enum DataType : int {
kUnknown = 0,
kInt32 = 1,
kFloat32 = 2
kFloat32 = 2,
kInt32Buffer = kInt32 | kPtrTypeMask,
kFloat32Buffer = kFloat32 | kPtrTypeMask
};
/*!
* \brief convert pointer type to data type
* \param ptr_type The pointer type.
* \return The corresponding data type.
*/
inline DataType Ptr2DataType(DataType ptr_type) {
CHECK_GE(ptr_type, kPtrTypeMask);
return static_cast<DataType>(ptr_type & (kPtrTypeMask -1));
}
/*!
* \brief List of subset node types used for quick runtime switch.
*
* \note The value of NodeType is not used for serialization type_key is used instead.
......@@ -45,6 +60,7 @@ enum NodeType {
kBinaryOpNode,
kReduceNode,
kTensorReadNode,
kBufferReadNode,
// stmt nodes
kStoreNode,
kForRangeNode,
......@@ -157,6 +173,8 @@ class NodeRef {
inline bool operator!=(const NodeRef& other) const;
/*! \return the hash function for NodeRef */
inline size_t hash() const;
/*! \return the raw internal pointer of the node */
inline Node* node_ptr() const;
protected:
template<typename T, typename>
......@@ -217,7 +235,11 @@ inline bool NodeRef::operator!=(const NodeRef& other) const {
}
inline size_t NodeRef::hash() const {
return std::hash<Node*>()(node_.get());
return std::hash<Node*>()(node_ptr());
}
inline Node* NodeRef::node_ptr() const {
return node_.get();
}
} // namespace tvm
......
......@@ -10,8 +10,9 @@
#include "./base.h"
namespace tvm {
// forward declare Expr
// Forward declare Expr
class Expr;
class Var;
/*!
* \brief create a constant expression
......@@ -24,34 +25,33 @@ template<typename T,
inline Expr constant(T value);
/*!
* \brief create a integer expression
* \param value The value to the expression
* \return the expression.
*/
Expr IntConstant(int64_t value);
/*!
* \brief create a float expression.
* \param value The value to the expression
* \return the expression.
*/
Expr FloatConstant(double value);
/*!
* \brief create a float expression.
* \param value The value to the expression
* \return the expression.
*/
Expr BufferRead(Var buffer, Expr offset);
/*!
* \brief a expression type, holds a ref to root of an AST
*/
class Expr : public NodeRef {
public:
/*! \brief default constructor */
Expr() = default;
/*!
* \brief copy constructor
* \param other the input
*/
Expr(const Expr& other) = default;
/*!
* \brief move constructor
* \param other the input
*/
Expr(Expr&& other) = default;
/*!
* \brief assign operator.
* \param other the input.
* \return reference to self
*/
Expr& operator=(const Expr& other) = default;
/*!
* \brief assign move operator.
* \param other the input.
* \return reference to self
*/
Expr& operator=(Expr&& other) = default;
Expr() {}
/*!
* \brief constructor from constant value
* \param value the constant value
......@@ -82,15 +82,17 @@ class Expr : public NodeRef {
void Print(std::ostream& os) const; // NOLINT(*)
};
/*! \brief Variable class */
/*!
* \brief Variable class to represent the symbolic placeholder
* in the DSL, internally it is a VarNode.
*
* The Variable is uniquely identified by the address of VarNode.
*/
class Var : public Expr {
public:
Var(std::string name="", DataType dtype=kInt32); // NOLINT(*)
};
Expr IntConstant(int64_t value);
Expr FloatConstant(double value);
/*! \brief base of expression node */
class ExprNode : public Node {
public:
......@@ -98,7 +100,7 @@ class ExprNode : public Node {
DataType dtype_{kUnknown};
};
// inline implementations
// implementations
inline DataType Expr::dtype() const {
return static_cast<const ExprNode*>(node_.get())->dtype_;
}
......
......@@ -12,10 +12,8 @@
#include "./expr.h"
namespace tvm {
/*! \brief variable node for symbolic variables */
class VarNode : public ExprNode {
public:
struct VarNode : public ExprNode {
/*! \brief hint name of the variable */
std::string name;
/*! \brief constructor */
......@@ -32,7 +30,7 @@ class VarNode : public ExprNode {
};
/*! \brief integer constant node */
class IntNode : public ExprNode {
struct IntNode : public ExprNode {
public:
/*! \brief the value field */
int64_t value;
......@@ -51,8 +49,7 @@ class IntNode : public ExprNode {
};
/*! \brief float constant node */
class FloatNode : public ExprNode {
public:
struct FloatNode : public ExprNode {
/*! \brief the value field */
double value;
/*! \brief constructor */
......@@ -61,7 +58,7 @@ class FloatNode : public ExprNode {
dtype_ = kFloat32;
}
const char* type_key() const override {
return "IntNode";
return "FloatNode";
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("value", &value);
......@@ -70,8 +67,7 @@ class FloatNode : public ExprNode {
};
/*! \brief Unary mapping operator */
class UnaryOpNode : public ExprNode {
public:
struct UnaryOpNode : public ExprNode {
/*! \brief The operator */
const UnaryOp* op;
/*! \brief The source expression */
......@@ -105,7 +101,6 @@ class UnaryOpNode : public ExprNode {
/*! \brief Binary mapping operator */
struct BinaryOpNode : public ExprNode {
public:
/*! \brief The operator */
const BinaryOp* op;
/*! \brief The left operand */
......@@ -143,7 +138,6 @@ struct BinaryOpNode : public ExprNode {
/*! \brief Reduction operator operator */
struct ReduceNode : public ExprNode {
public:
/*! \brief The operator */
const BinaryOp* op;
/*! \brief The source operand */
......@@ -180,7 +174,6 @@ struct ReduceNode : public ExprNode {
/*! \brief Tensor read operator */
struct TensorReadNode : public ExprNode {
public:
/*! \brief The tensor to be read from */
Tensor tensor;
/*! \brief The indices of read */
......@@ -215,6 +208,32 @@ struct TensorReadNode : public ExprNode {
}
};
/*! \brief Buffer read node */
struct BufferReadNode : public ExprNode {
/*! \brief The buffer variable to be read from */
Var buffer;
/*! \brief The offset to be read from */
Expr offset;
/*! \brief constructor, do not use constructor */
BufferReadNode() {
node_type_ = kBufferReadNode;
}
const char* type_key() const override {
return "BufferReadNode";
}
void Verify() const override {
CHECK_EQ(dtype_, Ptr2DataType(buffer.dtype()));
CHECK_EQ(offset.dtype(), kInt32);
}
void VisitAttrs(AttrVisitor* visitor) override {
visitor->Visit("dtype", &dtype_);
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("buffer", &buffer);
fvisit("offset", &offset);
}
};
} // namespace tvm
#endif // TVM_EXPR_NODE_H_
/*!
* Copyright (c) 2016 by Contributors
* \file stmt.h
* \brief The statement creation functions.
* The underlying container are defined in stmt_node.h
*/
#ifndef TVM_STMT_H_
#define TVM_STMT_H_
#include <type_traits>
#include "./base.h"
#include "./domain.h"
namespace tvm {
/*!
* \brief a expression type, holds a ref to root of an AST
*/
class Stmt : public NodeRef {
public:
/*! \brief default constructor */
Stmt() {}
/*!
* \brief constructor from node pointer
* \param nptr Another node shared pointer
*/
explicit Stmt(std::shared_ptr<Node>&& nptr) : NodeRef(std::move(nptr)) {
CHECK(node_.get() != nullptr);
}
};
/*!
* \brief construct Store Stmt.
* \param buffer The variable representing the buffer.
* \param offset The offset in the buffer
* \param src The source expression.
*/
Stmt Store(Var buffer, Expr offset, Expr src);
/*!
* \brief construct ForRange Stmt
* \param loop_var The loop variable
* \param range The loop range
* \param body The loop body
*/
Stmt ForRange(Var loop_var, Range range, Stmt body);
/*!
* \brief construct a IfThenElse
* \param cond The condition.
* \param then_body The body to go to in then condition.
* \param else_body The body to go to in else condition.
*/
Stmt IfThenElse(Expr cond, Stmt then_body, Stmt else_body);
} // namespace tvm
#endif // TVM_STMT_H_
......@@ -6,8 +6,15 @@
#ifndef TVM_STMT_NODE_H_
#define TVM_STMT_NODE_H_
#include "./base.h"
#include "./domain.h"
namespace tvm {
/*!
* \brief The internal base class of StmtNode
* So far no extra stuffs in here.
*/
struct StmtNode : public Node {
};
......@@ -23,11 +30,18 @@ struct StoreNode : public StmtNode {
StoreNode() {
node_type_ = kStoreNode;
}
const char* type_key() const override {
return "StoreNode";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("buffer", &buffer);
fvisit("offset", &offset);
fvisit("src", &src);
}
void Verify() const override {
CHECK_EQ(Ptr2DataType(buffer.dtype()), src.dtype());
CHECK_EQ(offset.dtype(), kInt32);
}
};
/*! \brief for loop in range */
......@@ -42,11 +56,19 @@ struct ForRangeNode : public StmtNode {
ForRangeNode() {
node_type_ = kForRangeNode;
}
const char* type_key() const override {
return "ForRangeNode";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("loop_var", &loop_var);
fvisit("range", &range);
fvisit("body", &body);
}
void Verify() const override {
CHECK_EQ(loop_var.dtype(), kInt32);
CHECK_EQ(this->range->begin.dtype(), loop_var.dtype());
CHECK_EQ(this->range->end.dtype(), loop_var.dtype());
}
};
/*! \brief conditional expression */
......@@ -61,13 +83,19 @@ struct IfThenElseNode : public StmtNode {
IfThenElseNode() {
node_type_ = kIfThenElseNode;
}
const char* type_key() const override {
return "IfThenElseNode";
}
void VisitNodeRefFields(FNodeRefVisit fvisit) override {
fvisit("cond", &cond);
fvisit("then_body", &then_body);
fvisit("else_body", &else_body);
}
void Verify() const override {
CHECK_EQ(cond.dtype(), kInt32);
}
};
} // namespace tvm
#endif // TVM_CODEGEN_H_
#endif // TVM_STMT_NODE_H_
# Code organization
- c_api C API related functions
- lang The definition of DSL related data structure
- schedule The Schedule->Stmt generation logic
- codegen Backend code generation related
\ No newline at end of file
......@@ -5,7 +5,6 @@
#include <tvm/expr.h>
#include <tvm/op.h>
#include <tvm/expr_node.h>
#include <cctype>
namespace tvm {
......@@ -28,4 +27,12 @@ Expr FloatConstant(double value) {
return Expr(std::move(nptr));
}
Expr BufferRead(Var buffer, Expr offset) {
auto nptr = std::make_shared<BufferReadNode>();
nptr->buffer = std::move(buffer);
nptr->offset = std::move(offset);
nptr->Verify();
return Expr(std::move(nptr));
}
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* \file stmt.cc
*/
#include <tvm/expr.h>
#include <tvm/stmt.h>
#include <tvm/stmt_node.h>
namespace tvm {
Stmt Store(Var buffer, Expr offset, Expr src) {
auto nptr = std::make_shared<StoreNode>();
nptr->buffer = std::move(buffer);
nptr->offset = std::move(offset);
nptr->src = std::move(src);
nptr->Verify();
return Stmt(std::move(nptr));
}
Stmt ForRange(Var loop_var, Range range, Stmt body) {
auto nptr = std::make_shared<ForRangeNode>();
nptr->loop_var = std::move(loop_var);
nptr->range = std::move(range);
nptr->body = std::move(body);
nptr->Verify();
return Stmt(std::move(nptr));
}
Stmt IfThenElse(Expr cond, Stmt then_body, Stmt else_body) {
auto nptr = std::make_shared<IfThenElseNode>();
nptr->cond = std::move(cond);
nptr->then_body = std::move(then_body);
nptr->else_body = std::move(else_body);
nptr->Verify();
return Stmt(std::move(nptr));
}
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment