Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
151707e0
Commit
151707e0
authored
Oct 23, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
check stmt in
parent
dac6b528
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
225 additions
and
45 deletions
+225
-45
include/tvm/base.h
+24
-2
include/tvm/expr.h
+31
-29
include/tvm/expr_node.h
+31
-12
include/tvm/stmt.h
+57
-0
include/tvm/stmt_node.h
+29
-1
src/README.md
+7
-0
src/expr/expr.cc
+8
-1
src/expr/stmt.cc
+38
-0
No files found.
include/tvm/base.h
View file @
151707e0
...
...
@@ -22,14 +22,29 @@ class NodeRef;
class
UnaryOp
;
class
BinaryOp
;
/*! \brief pointer type mask */
const
int
kPtrTypeMask
=
16
;
/*! \brief list of all supported data types */
enum
DataType
:
int
{
kUnknown
=
0
,
kInt32
=
1
,
kFloat32
=
2
kFloat32
=
2
,
kInt32Buffer
=
kInt32
|
kPtrTypeMask
,
kFloat32Buffer
=
kFloat32
|
kPtrTypeMask
};
/*!
* \brief convert pointer type to data type
* \param ptr_type The pointer type.
* \return The corresponding data type.
*/
inline
DataType
Ptr2DataType
(
DataType
ptr_type
)
{
CHECK_GE
(
ptr_type
,
kPtrTypeMask
);
return
static_cast
<
DataType
>
(
ptr_type
&
(
kPtrTypeMask
-
1
));
}
/*!
* \brief List of subset node types used for quick runtime switch.
*
* \note The value of NodeType is not used for serialization type_key is used instead.
...
...
@@ -45,6 +60,7 @@ enum NodeType {
kBinaryOpNode
,
kReduceNode
,
kTensorReadNode
,
kBufferReadNode
,
// stmt nodes
kStoreNode
,
kForRangeNode
,
...
...
@@ -157,6 +173,8 @@ class NodeRef {
inline
bool
operator
!=
(
const
NodeRef
&
other
)
const
;
/*! \return the hash function for NodeRef */
inline
size_t
hash
()
const
;
/*! \return the raw internal pointer of the node */
inline
Node
*
node_ptr
()
const
;
protected
:
template
<
typename
T
,
typename
>
...
...
@@ -217,7 +235,11 @@ inline bool NodeRef::operator!=(const NodeRef& other) const {
}
inline
size_t
NodeRef
::
hash
()
const
{
return
std
::
hash
<
Node
*>
()(
node_
.
get
());
return
std
::
hash
<
Node
*>
()(
node_ptr
());
}
inline
Node
*
NodeRef
::
node_ptr
()
const
{
return
node_
.
get
();
}
}
// namespace tvm
...
...
include/tvm/expr.h
View file @
151707e0
...
...
@@ -10,8 +10,9 @@
#include "./base.h"
namespace
tvm
{
//
f
orward declare Expr
//
F
orward declare Expr
class
Expr
;
class
Var
;
/*!
* \brief create a constant expression
...
...
@@ -24,34 +25,33 @@ template<typename T,
inline
Expr
constant
(
T
value
);
/*!
* \brief create a integer expression
* \param value The value to the expression
* \return the expression.
*/
Expr
IntConstant
(
int64_t
value
);
/*!
* \brief create a float expression.
* \param value The value to the expression
* \return the expression.
*/
Expr
FloatConstant
(
double
value
);
/*!
* \brief create a float expression.
* \param value The value to the expression
* \return the expression.
*/
Expr
BufferRead
(
Var
buffer
,
Expr
offset
);
/*!
* \brief a expression type, holds a ref to root of an AST
*/
class
Expr
:
public
NodeRef
{
public
:
/*! \brief default constructor */
Expr
()
=
default
;
/*!
* \brief copy constructor
* \param other the input
*/
Expr
(
const
Expr
&
other
)
=
default
;
/*!
* \brief move constructor
* \param other the input
*/
Expr
(
Expr
&&
other
)
=
default
;
/*!
* \brief assign operator.
* \param other the input.
* \return reference to self
*/
Expr
&
operator
=
(
const
Expr
&
other
)
=
default
;
/*!
* \brief assign move operator.
* \param other the input.
* \return reference to self
*/
Expr
&
operator
=
(
Expr
&&
other
)
=
default
;
Expr
()
{}
/*!
* \brief constructor from constant value
* \param value the constant value
...
...
@@ -82,15 +82,17 @@ class Expr : public NodeRef {
void
Print
(
std
::
ostream
&
os
)
const
;
// NOLINT(*)
};
/*! \brief Variable class */
/*!
* \brief Variable class to represent the symbolic placeholder
* in the DSL, internally it is a VarNode.
*
* The Variable is uniquely identified by the address of VarNode.
*/
class
Var
:
public
Expr
{
public
:
Var
(
std
::
string
name
=
""
,
DataType
dtype
=
kInt32
);
// NOLINT(*)
};
Expr
IntConstant
(
int64_t
value
);
Expr
FloatConstant
(
double
value
);
/*! \brief base of expression node */
class
ExprNode
:
public
Node
{
public
:
...
...
@@ -98,7 +100,7 @@ class ExprNode : public Node {
DataType
dtype_
{
kUnknown
};
};
// i
nline i
mplementations
// implementations
inline
DataType
Expr
::
dtype
()
const
{
return
static_cast
<
const
ExprNode
*>
(
node_
.
get
())
->
dtype_
;
}
...
...
include/tvm/expr_node.h
View file @
151707e0
...
...
@@ -12,10 +12,8 @@
#include "./expr.h"
namespace
tvm
{
/*! \brief variable node for symbolic variables */
class
VarNode
:
public
ExprNode
{
public
:
struct
VarNode
:
public
ExprNode
{
/*! \brief hint name of the variable */
std
::
string
name
;
/*! \brief constructor */
...
...
@@ -32,7 +30,7 @@ class VarNode : public ExprNode {
};
/*! \brief integer constant node */
class
IntNode
:
public
ExprNode
{
struct
IntNode
:
public
ExprNode
{
public
:
/*! \brief the value field */
int64_t
value
;
...
...
@@ -51,8 +49,7 @@ class IntNode : public ExprNode {
};
/*! \brief float constant node */
class
FloatNode
:
public
ExprNode
{
public
:
struct
FloatNode
:
public
ExprNode
{
/*! \brief the value field */
double
value
;
/*! \brief constructor */
...
...
@@ -61,7 +58,7 @@ class FloatNode : public ExprNode {
dtype_
=
kFloat32
;
}
const
char
*
type_key
()
const
override
{
return
"
In
tNode"
;
return
"
Floa
tNode"
;
}
void
VisitAttrs
(
AttrVisitor
*
visitor
)
override
{
visitor
->
Visit
(
"value"
,
&
value
);
...
...
@@ -70,8 +67,7 @@ class FloatNode : public ExprNode {
};
/*! \brief Unary mapping operator */
class
UnaryOpNode
:
public
ExprNode
{
public
:
struct
UnaryOpNode
:
public
ExprNode
{
/*! \brief The operator */
const
UnaryOp
*
op
;
/*! \brief The source expression */
...
...
@@ -105,7 +101,6 @@ class UnaryOpNode : public ExprNode {
/*! \brief Binary mapping operator */
struct
BinaryOpNode
:
public
ExprNode
{
public
:
/*! \brief The operator */
const
BinaryOp
*
op
;
/*! \brief The left operand */
...
...
@@ -143,7 +138,6 @@ struct BinaryOpNode : public ExprNode {
/*! \brief Reduction operator operator */
struct
ReduceNode
:
public
ExprNode
{
public
:
/*! \brief The operator */
const
BinaryOp
*
op
;
/*! \brief The source operand */
...
...
@@ -180,7 +174,6 @@ struct ReduceNode : public ExprNode {
/*! \brief Tensor read operator */
struct
TensorReadNode
:
public
ExprNode
{
public
:
/*! \brief The tensor to be read from */
Tensor
tensor
;
/*! \brief The indices of read */
...
...
@@ -215,6 +208,32 @@ struct TensorReadNode : public ExprNode {
}
};
/*! \brief Buffer read node */
struct
BufferReadNode
:
public
ExprNode
{
/*! \brief The buffer variable to be read from */
Var
buffer
;
/*! \brief The offset to be read from */
Expr
offset
;
/*! \brief constructor, do not use constructor */
BufferReadNode
()
{
node_type_
=
kBufferReadNode
;
}
const
char
*
type_key
()
const
override
{
return
"BufferReadNode"
;
}
void
Verify
()
const
override
{
CHECK_EQ
(
dtype_
,
Ptr2DataType
(
buffer
.
dtype
()));
CHECK_EQ
(
offset
.
dtype
(),
kInt32
);
}
void
VisitAttrs
(
AttrVisitor
*
visitor
)
override
{
visitor
->
Visit
(
"dtype"
,
&
dtype_
);
}
void
VisitNodeRefFields
(
FNodeRefVisit
fvisit
)
override
{
fvisit
(
"buffer"
,
&
buffer
);
fvisit
(
"offset"
,
&
offset
);
}
};
}
// namespace tvm
#endif // TVM_EXPR_NODE_H_
include/tvm/stmt.h
0 → 100644
View file @
151707e0
/*!
* Copyright (c) 2016 by Contributors
* \file stmt.h
* \brief The statement creation functions.
* The underlying container are defined in stmt_node.h
*/
#ifndef TVM_STMT_H_
#define TVM_STMT_H_
#include <type_traits>
#include "./base.h"
#include "./domain.h"
namespace
tvm
{
/*!
* \brief a expression type, holds a ref to root of an AST
*/
class
Stmt
:
public
NodeRef
{
public
:
/*! \brief default constructor */
Stmt
()
{}
/*!
* \brief constructor from node pointer
* \param nptr Another node shared pointer
*/
explicit
Stmt
(
std
::
shared_ptr
<
Node
>&&
nptr
)
:
NodeRef
(
std
::
move
(
nptr
))
{
CHECK
(
node_
.
get
()
!=
nullptr
);
}
};
/*!
* \brief construct Store Stmt.
* \param buffer The variable representing the buffer.
* \param offset The offset in the buffer
* \param src The source expression.
*/
Stmt
Store
(
Var
buffer
,
Expr
offset
,
Expr
src
);
/*!
* \brief construct ForRange Stmt
* \param loop_var The loop variable
* \param range The loop range
* \param body The loop body
*/
Stmt
ForRange
(
Var
loop_var
,
Range
range
,
Stmt
body
);
/*!
* \brief construct a IfThenElse
* \param cond The condition.
* \param then_body The body to go to in then condition.
* \param else_body The body to go to in else condition.
*/
Stmt
IfThenElse
(
Expr
cond
,
Stmt
then_body
,
Stmt
else_body
);
}
// namespace tvm
#endif // TVM_STMT_H_
include/tvm/stmt_node.h
View file @
151707e0
...
...
@@ -6,8 +6,15 @@
#ifndef TVM_STMT_NODE_H_
#define TVM_STMT_NODE_H_
#include "./base.h"
#include "./domain.h"
namespace
tvm
{
/*!
* \brief The internal base class of StmtNode
* So far no extra stuffs in here.
*/
struct
StmtNode
:
public
Node
{
};
...
...
@@ -23,11 +30,18 @@ struct StoreNode : public StmtNode {
StoreNode
()
{
node_type_
=
kStoreNode
;
}
const
char
*
type_key
()
const
override
{
return
"StoreNode"
;
}
void
VisitNodeRefFields
(
FNodeRefVisit
fvisit
)
override
{
fvisit
(
"buffer"
,
&
buffer
);
fvisit
(
"offset"
,
&
offset
);
fvisit
(
"src"
,
&
src
);
}
void
Verify
()
const
override
{
CHECK_EQ
(
Ptr2DataType
(
buffer
.
dtype
()),
src
.
dtype
());
CHECK_EQ
(
offset
.
dtype
(),
kInt32
);
}
};
/*! \brief for loop in range */
...
...
@@ -42,11 +56,19 @@ struct ForRangeNode : public StmtNode {
ForRangeNode
()
{
node_type_
=
kForRangeNode
;
}
const
char
*
type_key
()
const
override
{
return
"ForRangeNode"
;
}
void
VisitNodeRefFields
(
FNodeRefVisit
fvisit
)
override
{
fvisit
(
"loop_var"
,
&
loop_var
);
fvisit
(
"range"
,
&
range
);
fvisit
(
"body"
,
&
body
);
}
void
Verify
()
const
override
{
CHECK_EQ
(
loop_var
.
dtype
(),
kInt32
);
CHECK_EQ
(
this
->
range
->
begin
.
dtype
(),
loop_var
.
dtype
());
CHECK_EQ
(
this
->
range
->
end
.
dtype
(),
loop_var
.
dtype
());
}
};
/*! \brief conditional expression */
...
...
@@ -61,13 +83,19 @@ struct IfThenElseNode : public StmtNode {
IfThenElseNode
()
{
node_type_
=
kIfThenElseNode
;
}
const
char
*
type_key
()
const
override
{
return
"IfThenElseNode"
;
}
void
VisitNodeRefFields
(
FNodeRefVisit
fvisit
)
override
{
fvisit
(
"cond"
,
&
cond
);
fvisit
(
"then_body"
,
&
then_body
);
fvisit
(
"else_body"
,
&
else_body
);
}
void
Verify
()
const
override
{
CHECK_EQ
(
cond
.
dtype
(),
kInt32
);
}
};
}
// namespace tvm
#endif // TVM_
CODEGEN
_H_
#endif // TVM_
STMT_NODE
_H_
src/README.md
0 → 100644
View file @
151707e0
# Code organization
-
c_api C API related functions
-
lang The definition of DSL related data structure
-
schedule The Schedule->Stmt generation logic
-
codegen Backend code generation related
\ No newline at end of file
src/expr/expr.cc
View file @
151707e0
...
...
@@ -5,7 +5,6 @@
#include <tvm/expr.h>
#include <tvm/op.h>
#include <tvm/expr_node.h>
#include <cctype>
namespace
tvm
{
...
...
@@ -28,4 +27,12 @@ Expr FloatConstant(double value) {
return
Expr
(
std
::
move
(
nptr
));
}
Expr
BufferRead
(
Var
buffer
,
Expr
offset
)
{
auto
nptr
=
std
::
make_shared
<
BufferReadNode
>
();
nptr
->
buffer
=
std
::
move
(
buffer
);
nptr
->
offset
=
std
::
move
(
offset
);
nptr
->
Verify
();
return
Expr
(
std
::
move
(
nptr
));
}
}
// namespace tvm
src/expr/stmt.cc
0 → 100644
View file @
151707e0
/*!
* Copyright (c) 2016 by Contributors
* \file stmt.cc
*/
#include <tvm/expr.h>
#include <tvm/stmt.h>
#include <tvm/stmt_node.h>
namespace
tvm
{
Stmt
Store
(
Var
buffer
,
Expr
offset
,
Expr
src
)
{
auto
nptr
=
std
::
make_shared
<
StoreNode
>
();
nptr
->
buffer
=
std
::
move
(
buffer
);
nptr
->
offset
=
std
::
move
(
offset
);
nptr
->
src
=
std
::
move
(
src
);
nptr
->
Verify
();
return
Stmt
(
std
::
move
(
nptr
));
}
Stmt
ForRange
(
Var
loop_var
,
Range
range
,
Stmt
body
)
{
auto
nptr
=
std
::
make_shared
<
ForRangeNode
>
();
nptr
->
loop_var
=
std
::
move
(
loop_var
);
nptr
->
range
=
std
::
move
(
range
);
nptr
->
body
=
std
::
move
(
body
);
nptr
->
Verify
();
return
Stmt
(
std
::
move
(
nptr
));
}
Stmt
IfThenElse
(
Expr
cond
,
Stmt
then_body
,
Stmt
else_body
)
{
auto
nptr
=
std
::
make_shared
<
IfThenElseNode
>
();
nptr
->
cond
=
std
::
move
(
cond
);
nptr
->
then_body
=
std
::
move
(
then_body
);
nptr
->
else_body
=
std
::
move
(
else_body
);
nptr
->
Verify
();
return
Stmt
(
std
::
move
(
nptr
));
}
}
// namespace tvm
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment