Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tvm08dev
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
YuxuanGuo
tvm08dev
Commits
e342fc36
Commit
e342fc36
authored
Jan 05, 2021
by
guoyuxuan
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add ir
parent
ff5c1a8e
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
206 additions
and
15 deletions
+206
-15
tvm/include/tvm/tsl/tir/expr.h
+107
-6
tvm/src/tsl/te/operation/compute_op.cc
+0
-0
tvm/src/tsl/te/operation/placeholder_op.cc
+0
-0
tvm/src/tsl/te/tensor.cc
+0
-0
tvm/src/tsl/tir/ir/expr.cc
+99
-9
No files found.
tvm/include/tvm/tsl/tir/expr.h
View file @
e342fc36
...
...
@@ -15,9 +15,8 @@
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/data_type.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/var.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/var.h>
#include <tvm/tsl/tir/buffer.h>
#include <algorithm>
...
...
@@ -30,7 +29,7 @@
namespace
tvm
{
namespace
tir
{
class
TULoadNode
:
public
PrimExprNode
{
class
TULoadNode
:
public
PrimExprNode
{
public
:
TslDataProducer
producer
;
Array
<
PrimExpr
>
union_indices
;
...
...
@@ -39,7 +38,7 @@ class TULoadNode:public PrimExprNode{
v
->
Visit
(
"producer"
,
&
producer
);
v
->
Visit
(
"union_indices"
,
&
union_indices
);
}
//TODO:investigate SEQUAL/SHASH
//
TODO:investigate SEQUAL/SHASH
bool
SEqualReduce
(
const
TULoadNode
*
other
,
SEqualReducer
equal
)
const
{
return
equal
(
dtype
,
other
->
dtype
)
&&
equal
(
producer
,
other
->
producer
)
&&
equal
(
union_indices
,
other
->
union_indices
);
...
...
@@ -55,16 +54,118 @@ class TULoadNode:public PrimExprNode{
TVM_DECLARE_FINAL_OBJECT_INFO
(
TULoadNode
,
PrimExprNode
);
};
class
TULoad
:
public
PrimExpr
{
class
TULoad
:
public
PrimExpr
{
public
:
TVM_DLL
explicit
TULoad
(
TslDataProducer
producer
,
Array
<
PrimExpr
>
union_indices
);
TVM_DEFINE_OBJECT_REF_METHODS
(
TULoad
,
PrimExpr
,
TULoadNode
);
};
/* OpNode start (yuxguo)
* PrimExprNode -> (TslUnaryOpNode, TslBinaryOpNode) -> (TslTGemmOpNode, TslTAddOpNode, ...)
*/
template
<
typename
T
>
class
TslBinaryOpNode
:
public
PrimExprNode
{
public
:
/*! \brief The left operand. */
PrimExpr
a
;
/*! \brief The right operand. */
PrimExpr
b
;
void
VisitAttrs
(
AttrVisitor
*
v
)
{
v
->
Visit
(
"dtype"
,
&
(
this
->
dtype
));
v
->
Visit
(
"a"
,
&
a
);
v
->
Visit
(
"b"
,
&
b
);
}
bool
SEqualReduce
(
const
T
*
other
,
SEqualReducer
equal
)
const
{
return
equal
(
dtype
,
other
->
dtype
)
&&
equal
(
a
,
other
->
a
)
&&
equal
(
b
,
other
->
b
);
}
void
SHashReduce
(
SHashReducer
hash_reduce
)
const
{
hash_reduce
(
dtype
);
hash_reduce
(
a
);
hash_reduce
(
b
);
}
TVM_DECLARE_FINAL_OBJECT_INFO
(
T
,
PrimExprNode
);
};
template
<
typename
T
>
class
TslUnaryOpNode
:
public
PrimExprNode
{
public
:
/*! \brief The operand. */
PrimExpr
a
;
void
VisitAttrs
(
AttrVisitor
*
v
)
{
v
->
Visit
(
"dtype"
,
&
(
this
->
dtype
));
v
->
Visit
(
"a"
,
&
a
);
}
bool
SEqualReduce
(
const
T
*
other
,
SEqualReducer
equal
)
const
{
return
equal
(
dtype
,
other
->
dtype
)
&&
equal
(
a
,
other
->
a
);
}
void
SHashReduce
(
SHashReducer
hash_reduce
)
const
{
hash_reduce
(
dtype
);
hash_reduce
(
a
);
}
TVM_DECLARE_FINAL_OBJECT_INFO
(
T
,
PrimExprNode
);
};
class
TslTGemmNode
:
public
TslBinaryOpNode
<
TslTGemmNode
>
{
public
:
static
constexpr
const
char
*
_type_key
=
"tir.TslTGemm"
;
};
class
TslTAddNode
:
public
TslBinaryOpNode
<
TslTAddNode
>
{
public
:
static
constexpr
const
char
*
_type_key
=
"tir.TslTAdd"
;
};
class
TslTWriteNode
:
public
TslBinaryOpNode
<
TslTWriteNode
>
{
public
:
static
constexpr
const
char
*
_type_key
=
"tir.TslTWrite"
;
};
class
TslTStoreNode
:
public
TslBinaryOpNode
<
TslTStoreNode
>
{
public
:
static
constexpr
const
char
*
_type_key
=
"tir.TslTStore"
;
};
/* Op start (yuxguo) manage opnode
* TslOperation (TslTensorGemmOp, TslTensorAddOp, ...)
* usage: using TslComputeOp to specify compute type and expanded in scheduleOps pass
*/
class
TslTGemm
:
public
PrimExpr
{
public
:
TVM_DLL
TslTGemm
(
PrimExpr
a
,
PrimExpr
b
);
TVM_DEFINE_OBJECT_REF_METHODS
(
TslTGemm
,
PrimExpr
,
TslTGemmNode
);
};
class
TslTAdd
:
public
PrimExpr
{
public
:
TVM_DLL
TslTAdd
(
PrimExpr
a
,
PrimExpr
b
);
TVM_DEFINE_OBJECT_REF_METHODS
(
TslTAdd
,
PrimExpr
,
TslTAddNode
);
};
class
TslTWrite
:
public
PrimExpr
{
public
:
TVM_DLL
TslTWrite
(
PrimExpr
a
,
PrimExpr
b
);
TVM_DEFINE_OBJECT_REF_METHODS
(
TslTWrite
,
PrimExpr
,
TslTWriteNode
);
};
class
TslTStore
:
public
PrimExpr
{
public
:
TVM_DLL
TslTStore
(
PrimExpr
a
,
PrimExpr
b
);
TVM_DEFINE_OBJECT_REF_METHODS
(
TslTStore
,
PrimExpr
,
TslTStoreNode
);
};
}
// namespace tir
}
// namespace tvm
#endif // TVM_TSL_TIR_EXPR_H_
tvm/src/tsl/t
sl
/operation/compute_op.cc
→
tvm/src/tsl/t
e
/operation/compute_op.cc
View file @
e342fc36
File moved
tvm/src/tsl/t
sl
/operation/placeholder_op.cc
→
tvm/src/tsl/t
e
/operation/placeholder_op.cc
View file @
e342fc36
File moved
tvm/src/tsl/t
sl
/tensor.cc
→
tvm/src/tsl/t
e
/tensor.cc
View file @
e342fc36
File moved
tvm/src/tsl/tir/ir/expr.cc
View file @
e342fc36
...
...
@@ -7,7 +7,6 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tsl/tir/expr.h>
#include <limits>
...
...
@@ -18,21 +17,34 @@
namespace
tvm
{
namespace
tir
{
#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \
Name::Name(PrimExpr a, PrimExpr b) { \
using T = Name::ContainerType; \
ICHECK(a.defined()) << "ValueError: a is undefined\n"; \
ICHECK(b.defined()) << "ValueError: b is undefined\n"; \
ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types\n"; \
ObjectPtr<T> node = make_object<T>(); \
node->dtype = a.dtype(); \
node->a = std::move(a); \
node->b = std::move(b); \
data_ = std::move(node); \
}
TULoad
::
TULoad
(
TslDataProducer
producer
,
Array
<
PrimExpr
>
union_indices
)
{
ObjectPtr
<
TULoadNode
>
node
=
make_object
<
TULoadNode
>
();
node
->
dtype
=
producer
->
GetDataType
();
node
->
producer
=
std
::
move
(
producer
);
node
->
union_indices
=
std
::
move
(
union_indices
);
data_
=
std
::
move
(
node
);
ObjectPtr
<
TULoadNode
>
node
=
make_object
<
TULoadNode
>
();
node
->
dtype
=
producer
->
GetDataType
();
node
->
producer
=
std
::
move
(
producer
);
node
->
union_indices
=
std
::
move
(
union_indices
);
data_
=
std
::
move
(
node
);
}
TVM_REGISTER_GLOBAL
(
"tir.TULoad"
)
.
set_body_typed
([](
DataProducer
producer
,
Array
<
PrimExpr
>
union_indices
)
{
.
set_body_typed
([](
DataProducer
producer
,
Array
<
PrimExpr
>
union_indices
)
{
return
ProducerLoad
(
producer
,
union_indices
);
});
});
TVM_REGISTER_NODE_TYPE
(
TULoadNode
);
TVM_STATIC_IR_FUNCTOR
(
ReprPrinter
,
vtable
)
.
set_dispatch
<
TULoadNode
>
([](
const
ObjectRef
&
node
,
ReprPrinter
*
p
)
{
.
set_dispatch
<
TULoadNode
>
([](
const
ObjectRef
&
node
,
ReprPrinter
*
p
)
{
auto
*
op
=
static_cast
<
const
TULoadNode
*>
(
node
.
get
());
p
->
stream
<<
op
->
producer
->
GetNameHint
()
<<
"["
;
for
(
size_t
i
=
0
;
i
<
op
->
union_indices
.
size
();
++
i
)
{
...
...
@@ -42,8 +54,86 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
}
}
p
->
stream
<<
"]"
;
});
// TslTGemm
TVM_DEFINE_BINOP_CONSTRUCTOR
(
TslTGemm
);
TVM_REGISTER_GLOBAL
(
"tir.TslTGemm"
).
set_body_typed
([](
PrimExpr
a
,
PrimExpr
b
)
{
return
TslTGemm
(
a
,
b
);
});
TVM_REGISTER_NODE_TYPE
(
TslTGemmNode
);
TVM_STATIC_IR_FUNCTOR
(
ReprPrinter
,
vtable
)
.
set_dispatch
<
TslTGemmNode
>
([](
const
ObjectRef
&
node
,
ReprPrinter
*
p
)
{
auto
*
op
=
static_cast
<
const
TslTGemmNode
*>
(
node
.
get
());
p
->
stream
<<
"TslTGemm("
;
p
->
Print
(
op
->
a
);
p
->
stream
<<
", "
;
p
->
Print
(
op
->
b
);
p
->
stream
<<
')'
;
});
// TslTAdd
TVM_DEFINE_BINOP_CONSTRUCTOR
(
TslTAdd
);
TVM_REGISTER_GLOBAL
(
"tir.TslTAdd"
).
set_body_typed
([](
PrimExpr
a
,
PrimExpr
b
)
{
return
TslTAdd
(
a
,
b
);
});
TVM_REGISTER_NODE_TYPE
(
TslTAddNode
);
TVM_STATIC_IR_FUNCTOR
(
ReprPrinter
,
vtable
)
.
set_dispatch
<
TslTAddNode
>
([](
const
ObjectRef
&
node
,
ReprPrinter
*
p
)
{
auto
*
op
=
static_cast
<
const
TslTAddNode
*>
(
node
.
get
());
p
->
stream
<<
"TslTAdd("
;
p
->
Print
(
op
->
a
);
p
->
stream
<<
", "
;
p
->
Print
(
op
->
b
);
p
->
stream
<<
')'
;
});
// TslTWrite
TVM_DEFINE_BINOP_CONSTRUCTOR
(
TslTWrite
);
TVM_REGISTER_GLOBAL
(
"tir.TslTWrite"
).
set_body_typed
([](
PrimExpr
a
,
PrimExpr
b
)
{
return
TslTWrite
(
a
,
b
);
});
TVM_REGISTER_NODE_TYPE
(
TslTWriteNode
);
TVM_STATIC_IR_FUNCTOR
(
ReprPrinter
,
vtable
)
.
set_dispatch
<
TslTWriteNode
>
([](
const
ObjectRef
&
node
,
ReprPrinter
*
p
)
{
auto
*
op
=
static_cast
<
const
TslTWriteNode
*>
(
node
.
get
());
p
->
stream
<<
"TslTWrite("
;
p
->
Print
(
op
->
a
);
p
->
stream
<<
", "
;
p
->
Print
(
op
->
b
);
p
->
stream
<<
')'
;
});
// TslTStore
TVM_DEFINE_BINOP_CONSTRUCTOR
(
TslTStore
);
TVM_REGISTER_GLOBAL
(
"tir.TslTStore"
).
set_body_typed
([](
PrimExpr
a
,
PrimExpr
b
)
{
return
TslTStore
(
a
,
b
);
});
TVM_REGISTER_NODE_TYPE
(
TslTStoreNode
);
TVM_STATIC_IR_FUNCTOR
(
ReprPrinter
,
vtable
)
.
set_dispatch
<
TslTStoreNode
>
([](
const
ObjectRef
&
node
,
ReprPrinter
*
p
)
{
auto
*
op
=
static_cast
<
const
TslTStoreNode
*>
(
node
.
get
());
p
->
stream
<<
"TslTStore("
;
p
->
Print
(
op
->
a
);
p
->
stream
<<
", "
;
p
->
Print
(
op
->
b
);
p
->
stream
<<
')'
;
});
}
// namespace tir
}
// namespace tvm
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment