Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
330d49f8
Commit
330d49f8
authored
May 04, 2017
by
Tianqi Chen
Committed by
GitHub
May 04, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[IR] Update new version of HalideIR (#116)
parent
d3c8256b
Show whitespace changes
Inline
Side-by-side
Showing
28 changed files
with
149 additions
and
89 deletions
+149
-89
HalideIR
+1
-1
include/tvm/ir.h
+9
-0
include/tvm/ir_mutator.h
+1
-0
include/tvm/ir_pass.h
+1
-1
src/api/api_ir.cc
+27
-11
src/arithmetic/compute_expr.h
+1
-1
src/arithmetic/int_set.cc
+1
-1
src/codegen/codegen_c.cc
+5
-1
src/codegen/llvm/codegen_llvm.cc
+12
-5
src/codegen/stack_vm/codegen_stack_vm.cc
+3
-3
src/codegen/verilog/verilog_ir.cc
+5
-3
src/lang/buffer.cc
+5
-2
src/op/compute_op.cc
+4
-2
src/pass/inject_virtual_thread.cc
+7
-8
src/pass/ir_mutator.cc
+17
-5
src/pass/ir_util.h
+4
-2
src/pass/ir_visitor.cc
+2
-0
src/pass/lower_packed_call.cc
+2
-2
src/pass/lower_thread_allreduce.cc
+9
-7
src/pass/make_api.cc
+4
-3
src/pass/narrow_channel_access.cc
+4
-2
src/pass/split_pipeline.cc
+3
-2
src/pass/storage_flatten.cc
+2
-2
src/pass/storage_rewrite.cc
+2
-2
src/pass/storage_sync.cc
+1
-1
src/pass/vectorize_loop.cc
+16
-6
tests/cpp/ir_cse_pass_test.cc
+0
-15
tests/cpp/ir_simplify_test.cc
+1
-1
No files found.
HalideIR
@
4fffc62c
Subproject commit
398edacd956c6de82185821ffd9f482598182e5
1
Subproject commit
4fffc62c124651c1cde18f31957db413b677d60
1
include/tvm/ir.h
View file @
330d49f8
...
...
@@ -174,6 +174,14 @@ namespace intrinsic {
/*!
* \brief See pesudo code
*
* Handle tvm_address_of(Load *op) {
* return &op->buffer_var[index];
* }
*/
constexpr
const
char
*
tvm_address_of
=
"tvm_address_of"
;
/*!
* \brief See pesudo code
*
* Type tvm_struct_get(StructType* arr, int index, int field_id) {
* return arr[index]->field;
* }
...
...
@@ -355,6 +363,7 @@ using Halide::Internal::Realize;
using
Halide
::
Internal
::
Block
;
using
Halide
::
Internal
::
IfThenElse
;
using
Halide
::
Internal
::
Evaluate
;
using
Halide
::
Internal
::
Shuffle
;
// ir functions
using
Halide
::
Internal
::
is_const_power_of_two_integer
;
...
...
include/tvm/ir_mutator.h
View file @
330d49f8
...
...
@@ -98,6 +98,7 @@ class IRMutator {
virtual
Expr
Mutate_
(
const
UIntImm
*
op
,
const
Expr
&
e
);
virtual
Expr
Mutate_
(
const
FloatImm
*
op
,
const
Expr
&
e
);
virtual
Expr
Mutate_
(
const
StringImm
*
op
,
const
Expr
&
e
);
virtual
Expr
Mutate_
(
const
Shuffle
*
op
,
const
Expr
&
e
);
};
}
// namespace ir
...
...
include/tvm/ir_pass.h
View file @
330d49f8
...
...
@@ -10,7 +10,7 @@
#define TVM_IR_PASS_H_
#include <ir/IREquality.h>
#include <
pass
/Simplify.h>
#include <
arithmetic
/Simplify.h>
#include <tvm/ir_functor.h>
#include <unordered_map>
#include <vector>
...
...
src/api/api_ir.cc
View file @
330d49f8
...
...
@@ -26,6 +26,26 @@ TVM_REGISTER_API("make.For")
args
[
5
]);
});
TVM_REGISTER_API
(
"make.Load"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Type
t
=
args
[
0
];
if
(
args
.
size
()
==
3
)
{
*
ret
=
Load
::
make
(
t
,
args
[
1
],
args
[
2
],
const_true
(
t
.
lanes
()));
}
else
{
*
ret
=
Load
::
make
(
t
,
args
[
1
],
args
[
2
],
args
[
3
]);
}
});
TVM_REGISTER_API
(
"make.Store"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Expr
value
=
args
[
1
];
if
(
args
.
size
()
==
3
)
{
*
ret
=
Store
::
make
(
args
[
0
],
value
,
args
[
2
],
const_true
(
value
.
type
().
lanes
()));
}
else
{
*
ret
=
Store
::
make
(
args
[
0
],
value
,
args
[
2
],
args
[
3
]);
}
});
TVM_REGISTER_API
(
"make.Realize"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
Realize
::
make
(
args
[
0
],
...
...
@@ -47,15 +67,6 @@ TVM_REGISTER_API("make.Call")
args
[
5
]);
});
TVM_REGISTER_API
(
"make.Allocate"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
Allocate
::
make
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
],
args
[
4
]);
});
TVM_REGISTER_API
(
"make.CommReducer"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
CommReducerNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
]);
...
...
@@ -87,6 +98,12 @@ TVM_REGISTER_API("make.CommReducer")
*ret = Node::make(args[0], args[1], args[2], args[3]); \
}) \
#define REGISTER_MAKE5(Node) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \
}) \
#define REGISTER_MAKE_BINARY_OP(Node) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
...
...
@@ -125,8 +142,7 @@ REGISTER_MAKE3(Let);
REGISTER_MAKE3
(
LetStmt
);
REGISTER_MAKE2
(
AssertStmt
);
REGISTER_MAKE3
(
ProducerConsumer
);
REGISTER_MAKE3
(
Load
);
REGISTER_MAKE3
(
Store
);
REGISTER_MAKE5
(
Allocate
);
REGISTER_MAKE4
(
Provide
);
REGISTER_MAKE1
(
Free
);
REGISTER_MAKE2
(
Block
);
...
...
src/arithmetic/compute_expr.h
View file @
330d49f8
...
...
@@ -8,7 +8,7 @@
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_
#include <tvm/ir.h>
#include <
pass
/Interval.h>
#include <
arithmetic
/Interval.h>
#include <limits>
namespace
tvm
{
...
...
src/arithmetic/int_set.cc
View file @
330d49f8
...
...
@@ -6,7 +6,7 @@
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <
pass
/Interval.h>
#include <
arithmetic
/Interval.h>
#include <unordered_map>
#include "./compute_expr.h"
#include "./int_set_internal.h"
...
...
src/codegen/codegen_c.cc
View file @
330d49f8
...
...
@@ -471,7 +471,7 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
PrintBinaryIntrinsitc
(
op
,
" << "
,
os
,
this
);
}
else
if
(
op
->
is_intrinsic
(
Call
::
shift_right
))
{
PrintBinaryIntrinsitc
(
op
,
" >> "
,
os
,
this
);
}
else
if
(
op
->
is_intrinsic
(
Call
::
address_of
))
{
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_
address_of
))
{
const
Load
*
l
=
op
->
args
[
0
].
as
<
Load
>
();
CHECK
(
op
->
args
.
size
()
==
1
&&
l
);
os
<<
"(("
;
...
...
@@ -535,6 +535,8 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
std
::
string
ref
=
GetBufferRef
(
op
->
type
,
op
->
buffer_var
.
get
(),
op
->
index
);
os
<<
ref
;
}
else
{
CHECK
(
is_one
(
op
->
predicate
))
<<
"predicated load is not supported"
;
Expr
base
;
if
(
TryGetRamp1Base
(
op
->
index
,
op
->
type
.
lanes
(),
&
base
))
{
std
::
string
ref
=
GetVecLoad
(
op
->
type
,
op
->
buffer_var
.
get
(),
base
);
...
...
@@ -575,6 +577,8 @@ void CodeGenC::VisitStmt_(const Store* op) {
this
->
PrintIndent
();
stream
<<
ref
<<
" = "
<<
value
<<
";
\n
"
;
}
else
{
CHECK
(
is_one
(
op
->
predicate
))
<<
"Predicated store is not supported"
;
Expr
base
;
if
(
TryGetRamp1Base
(
op
->
index
,
t
.
lanes
(),
&
base
))
{
std
::
string
value
=
this
->
PrintExpr
(
op
->
value
);
...
...
src/codegen/llvm/codegen_llvm.cc
View file @
330d49f8
...
...
@@ -702,7 +702,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
return
builder_
->
CreateLShr
(
MakeValue
(
op
->
args
[
0
]),
MakeValue
(
op
->
args
[
1
]));
}
}
else
if
(
op
->
is_intrinsic
(
Call
::
address_of
))
{
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_
address_of
))
{
const
Load
*
l
=
op
->
args
[
0
].
as
<
Load
>
();
CHECK
(
op
->
args
.
size
()
==
1
&&
l
);
return
CreateBufferPtr
(
...
...
@@ -752,7 +752,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
}
else
{
LOG
(
FATAL
)
<<
"Unknown stack alloca type "
<<
type
;
}
}
else
if
(
op
->
is_intrinsic
(
Call
::
null_handle
))
{
}
else
if
(
op
->
is_intrinsic
(
Call
::
reinterpret
)
&&
is_zero
(
op
->
args
[
0
]
))
{
return
llvm
::
Constant
::
getNullValue
(
t_void_p_
);
}
else
{
LOG
(
FATAL
)
<<
"Unknown intrinstic "
<<
op
->
name
;
...
...
@@ -1077,6 +1077,8 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(
}
llvm
::
Value
*
CodeGenLLVM
::
VisitExpr_
(
const
Load
*
op
)
{
CHECK
(
is_one
(
op
->
predicate
))
<<
"Predicated Load is not supported"
;
Type
t
=
op
->
type
;
const
Ramp
*
ramp
=
op
->
index
.
as
<
Ramp
>
();
llvm
::
Value
*
buf
=
GetVarValue
(
op
->
buffer_var
.
get
());
...
...
@@ -1135,12 +1137,14 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
t
,
op
->
buffer_var
,
Ramp
::
make
(
arith
::
ComputeExpr
<
Add
>
(
ramp
->
base
,
make_const
(
bt
,
first_shift
)),
make_const
(
bt
,
1
),
ramp
->
lanes
)));
make_const
(
bt
,
1
),
ramp
->
lanes
),
const_true
(
t
.
lanes
())));
llvm
::
Value
*
next
=
MakeValue
(
Load
::
make
(
t
,
op
->
buffer_var
,
Ramp
::
make
(
arith
::
ComputeExpr
<
Add
>
(
ramp
->
base
,
make_const
(
bt
,
ramp
->
lanes
+
next_shift
)),
make_const
(
bt
,
1
),
ramp
->
lanes
)));
make_const
(
bt
,
1
),
ramp
->
lanes
),
const_true
(
t
.
lanes
())));
// shuffle
std
::
vector
<
llvm
::
Constant
*>
indices
;
int
target_index
=
0
;
...
...
@@ -1170,7 +1174,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
make_const
(
ramp
->
base
.
type
(),
1
),
lanes
);
// load value then flip
llvm
::
Value
*
v
=
MakeValue
(
Load
::
make
(
t
,
op
->
buffer_var
,
neg_ramp
));
llvm
::
Value
*
v
=
MakeValue
(
Load
::
make
(
t
,
op
->
buffer_var
,
neg_ramp
,
const_true
(
t
.
lanes
())));
return
CreateVecFlip
(
v
);
}
else
{
llvm
::
Value
*
ret
=
llvm
::
UndefValue
::
get
(
LLVMType
(
t
));
...
...
@@ -1187,6 +1192,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
// stmts
void
CodeGenLLVM
::
VisitStmt_
(
const
Store
*
op
)
{
CHECK
(
is_one
(
op
->
predicate
))
<<
"Predicated Load is not supported"
;
llvm
::
Value
*
value
=
MakeValue
(
op
->
value
);
Type
t
=
op
->
value
.
type
();
const
Ramp
*
ramp
=
op
->
index
.
as
<
Ramp
>
();
...
...
src/codegen/stack_vm/codegen_stack_vm.cc
View file @
330d49f8
...
...
@@ -121,7 +121,7 @@ void CodeGenStackVM::VisitStmt_(const Allocate* op) {
}
void
CodeGenStackVM
::
VisitExpr_
(
const
Call
*
op
)
{
if
(
op
->
is_intrinsic
(
Call
::
address_of
))
{
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_
address_of
))
{
const
Load
*
l
=
op
->
args
[
0
].
as
<
Load
>
();
CHECK
(
op
->
args
.
size
()
==
1
&&
l
);
this
->
PushOp
(
StackVM
::
LOAD_HEAP
,
GetVarID
(
l
->
buffer_var
.
get
()));
...
...
@@ -129,8 +129,8 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
this
->
PushOp
(
StackVM
::
PUSH_I64
,
l
->
type
.
element_of
().
bytes
());
this
->
PushOp
(
StackVM
::
MUL_I64
);
this
->
PushOp
(
StackVM
::
ADDR_ADD
);
}
else
if
(
op
->
is_intrinsic
(
Call
::
null_handle
))
{
this
->
Push
Op
(
StackVM
::
PUSH_I64
,
0
);
}
else
if
(
op
->
is_intrinsic
(
Call
::
reinterpret
))
{
this
->
Push
(
op
->
args
[
0
]
);
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_struct_get
))
{
CHECK_EQ
(
op
->
args
.
size
(),
3U
);
int
kind
=
op
->
args
[
2
].
as
<
IntImm
>
()
->
value
;
...
...
src/codegen/verilog/verilog_ir.cc
View file @
330d49f8
...
...
@@ -217,11 +217,13 @@ class PipelineExtractor: public IRVisitor {
if
(
is_zero
(
op
->
index
)
&&
load
)
{
compute
->
body
=
Store
::
make
(
op
->
buffer_var
,
Load
::
make
(
load
->
type
,
load
->
buffer_var
,
repl
.
Mutate
(
load
->
index
)),
op
->
index
);
Load
::
make
(
load
->
type
,
load
->
buffer_var
,
repl
.
Mutate
(
load
->
index
),
op
->
predicate
),
op
->
index
,
op
->
predicate
);
}
else
{
compute
->
body
=
Store
::
make
(
op
->
buffer_var
,
repl
.
Mutate
(
op
->
value
),
repl
.
Mutate
(
op
->
index
));
op
->
buffer_var
,
repl
.
Mutate
(
op
->
value
),
repl
.
Mutate
(
op
->
index
),
op
->
predicate
);
}
compute
->
inputs
=
repl
.
inputs_
;
pipeline_
->
stages
.
push_back
(
ComputeBlock
(
compute
));
...
...
src/lang/buffer.cc
View file @
330d49f8
...
...
@@ -49,13 +49,16 @@ inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
Expr
Buffer
::
MakeLoad
(
Array
<
Expr
>
index
)
const
{
const
BufferNode
*
n
=
operator
->
();
return
ir
::
Load
::
make
(
n
->
dtype
,
n
->
data
,
BufferOffset
(
n
,
index
));
return
ir
::
Load
::
make
(
n
->
dtype
,
n
->
data
,
BufferOffset
(
n
,
index
),
const_true
(
n
->
dtype
.
lanes
()));
}
Stmt
Buffer
::
MakeStore
(
Array
<
Expr
>
index
,
Expr
value
)
const
{
const
BufferNode
*
n
=
operator
->
();
CHECK_EQ
(
value
.
type
(),
n
->
dtype
);
return
ir
::
Store
::
make
(
n
->
data
,
value
,
BufferOffset
(
n
,
index
));
return
ir
::
Store
::
make
(
n
->
data
,
value
,
BufferOffset
(
n
,
index
),
const_true
(
n
->
dtype
.
lanes
()));
}
Buffer
BufferNode
::
make
(
std
::
string
name
,
...
...
src/op/compute_op.cc
View file @
330d49f8
...
...
@@ -254,19 +254,21 @@ Stmt MakeCrossThreadReduction(
}
}
}
Type
t
=
reduce
->
type
;
Expr
pred
=
const_true
(
t
.
lanes
());
Stmt
reduce_body
=
Store
::
make
(
res_handle
,
Call
::
make
(
reduce
->
type
,
ir
::
intrinsic
::
tvm_thread_allreduce
,
freduce_args
,
Call
::
Intrinsic
),
0
);
0
,
pred
);
reduce_body
=
AttrStmt
::
make
(
reduce
->
combiner
,
attr
::
reduce_scope
,
make_zero
(
reduce
->
type
),
reduce_body
);
Stmt
assign_body
=
Provide
::
make
(
stage
->
op
,
0
,
Load
::
make
(
reduce
->
type
,
res_handle
,
0
),
args
);
stage
->
op
,
0
,
Load
::
make
(
reduce
->
type
,
res_handle
,
0
,
pred
),
args
);
assign_body
=
MergeNest
(
op
::
MakeIfNest
(
thread_head_check
),
assign_body
);
assign_body
=
MergeNest
(
op
::
MakeIfNest
(
conds
),
assign_body
);
Stmt
body
=
Allocate
::
make
(
...
...
src/pass/inject_virtual_thread.cc
View file @
330d49f8
...
...
@@ -152,12 +152,8 @@ class VTInjector : public IRMutator {
return
e
;
}
Expr
RewriteIndex
(
Expr
index
,
Expr
alloc_extent
)
const
{
if
(
index_rewrite_strategy_
==
0
)
{
return
index
*
num_threads_
+
var_
;
}
else
{
return
index
+
var_
*
alloc_extent
;
}
}
// Load
Expr
Mutate_
(
const
Load
*
op
,
const
Expr
&
e
)
final
{
Expr
expr
=
IRMutator
::
Mutate_
(
op
,
e
);
...
...
@@ -168,7 +164,8 @@ class VTInjector : public IRMutator {
auto
it
=
touched_alloc_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
!=
touched_alloc_
.
end
())
{
return
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
RewriteIndex
(
op
->
index
,
it
->
second
));
RewriteIndex
(
op
->
index
,
it
->
second
),
op
->
predicate
);
}
else
{
return
expr
;
}
...
...
@@ -184,7 +181,8 @@ class VTInjector : public IRMutator {
if
(
it
!=
touched_alloc_
.
end
())
{
return
Store
::
make
(
op
->
buffer_var
,
op
->
value
,
RewriteIndex
(
op
->
index
,
it
->
second
));
RewriteIndex
(
op
->
index
,
it
->
second
),
op
->
predicate
);
}
else
{
return
stmt
;
}
...
...
@@ -307,6 +305,9 @@ class VTInjector : public IRMutator {
for
(
size_t
i
=
1
;
i
<
extents
.
size
();
++
i
)
{
stride
=
arith
::
ComputeExpr
<
Mul
>
(
stride
,
extents
[
i
]);
}
if
(
op
->
type
.
lanes
()
!=
0
)
{
stride
=
stride
*
op
->
type
.
lanes
();
}
Array
<
Expr
>
other
;
other
.
push_back
(
num_threads_
);
for
(
Expr
e
:
extents
)
{
...
...
@@ -368,8 +369,6 @@ class VTInjector : public IRMutator {
Var
var_
;
// the threads/lanes
int
num_threads_
;
// Index rewriting strategy
int
index_rewrite_strategy_
{
1
};
// whethe the loop is already injected.
bool
vt_loop_injected_
{
false
};
// whether current expression get touched.
...
...
src/pass/ir_mutator.cc
View file @
330d49f8
...
...
@@ -143,10 +143,11 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
Stmt
IRMutator
::
Mutate_
(
const
Store
*
op
,
const
Stmt
&
s
)
{
Expr
value
=
this
->
Mutate
(
op
->
value
);
Expr
index
=
this
->
Mutate
(
op
->
index
);
if
(
value
.
same_as
(
op
->
value
)
&&
index
.
same_as
(
op
->
index
))
{
Expr
pred
=
this
->
Mutate
(
op
->
predicate
);
if
(
value
.
same_as
(
op
->
value
)
&&
index
.
same_as
(
op
->
index
)
&&
pred
.
same_as
(
op
->
predicate
))
{
return
s
;
}
else
{
return
Store
::
make
(
op
->
buffer_var
,
value
,
index
);
return
Store
::
make
(
op
->
buffer_var
,
value
,
index
,
pred
);
}
}
...
...
@@ -263,10 +264,11 @@ Expr IRMutator::Mutate_(const Variable *op, const Expr& e) {
Expr
IRMutator
::
Mutate_
(
const
Load
*
op
,
const
Expr
&
e
)
{
Expr
index
=
this
->
Mutate
(
op
->
index
);
if
(
index
.
same_as
(
op
->
index
))
{
Expr
pred
=
this
->
Mutate
(
op
->
predicate
);
if
(
index
.
same_as
(
op
->
index
)
&&
pred
.
same_as
(
op
->
predicate
))
{
return
e
;
}
else
{
return
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
index
);
return
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
index
,
pred
);
}
}
...
...
@@ -383,6 +385,15 @@ Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) {
}
}
Expr
IRMutator
::
Mutate_
(
const
Shuffle
*
op
,
const
Expr
&
e
)
{
auto
new_vec
=
MutateArray
(
op
->
vectors
,
this
);
if
(
new_vec
.
same_as
(
op
->
vectors
))
{
return
e
;
}
else
{
return
Shuffle
::
make
(
new_vec
,
op
->
indices
);
}
}
#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \
Expr IRMutator::Mutate_(const OP *op, const Expr& e) { \
return e; \
...
...
@@ -422,7 +433,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.
DISPATCH_TO_MUTATE_EXPR
(
IntImm
)
.
DISPATCH_TO_MUTATE_EXPR
(
UIntImm
)
.
DISPATCH_TO_MUTATE_EXPR
(
FloatImm
)
.
DISPATCH_TO_MUTATE_EXPR
(
StringImm
);
.
DISPATCH_TO_MUTATE_EXPR
(
StringImm
)
.
DISPATCH_TO_MUTATE_EXPR
(
Shuffle
);
}
// namespace ir
}
// namespace tvm
src/pass/ir_util.h
View file @
330d49f8
...
...
@@ -111,8 +111,10 @@ inline Expr TVMStructGet(
*/
inline
Expr
AddressOffset
(
Var
handle
,
Type
dtype
,
int
offset
)
{
return
Call
::
make
(
Handle
(),
Call
::
address_of
,
{
Load
::
make
(
dtype
,
handle
,
make_const
(
Int
(
32
),
offset
))},
Call
::
PureIntrinsic
);
Handle
(),
intrinsic
::
tvm_address_of
,
{
Load
::
make
(
dtype
,
handle
,
make_const
(
Int
(
32
),
offset
*
dtype
.
lanes
()),
const_true
(
dtype
.
lanes
()))},
Call
::
PureIntrinsic
);
}
/*!
...
...
src/pass/ir_visitor.cc
View file @
330d49f8
...
...
@@ -81,11 +81,13 @@ void IRVisitor::Visit_(const Allocate *op) {
void
IRVisitor
::
Visit_
(
const
Load
*
op
)
{
this
->
Visit
(
op
->
index
);
this
->
Visit
(
op
->
predicate
);
}
void
IRVisitor
::
Visit_
(
const
Store
*
op
)
{
this
->
Visit
(
op
->
value
);
this
->
Visit
(
op
->
index
);
this
->
Visit
(
op
->
predicate
);
}
void
IRVisitor
::
Visit_
(
const
IfThenElse
*
op
)
{
...
...
src/pass/lower_packed_call.cc
View file @
330d49f8
...
...
@@ -99,7 +99,7 @@ class PackedCallBuilder : public IRMutator {
for
(
size_t
i
=
0
;
i
<
op
->
args
.
size
();
++
i
)
{
prep_seq_
.
emplace_back
(
Store
::
make
(
stack_shape_
,
Convert
(
Int
(
64
),
op
->
args
[
i
]),
ConstInt32
(
stack_begin
+
i
)));
ConstInt32
(
stack_begin
+
i
)
,
const_true
(
1
)
));
}
return
AddressOffset
(
stack_shape_
,
Int
(
64
),
stack_begin
);
}
...
...
@@ -169,7 +169,7 @@ class PackedCallBuilder : public IRMutator {
prep_seq_
.
emplace_back
(
Store
::
make
(
stack_tcode_
,
ConstInt32
(
arg_tcode
),
stack_index
));
stack_index
,
const_true
(
1
)
));
}
// UPDATE stack value
max_arg_stack_
=
std
::
max
(
run_arg_stack_
,
max_arg_stack_
);
...
...
src/pass/lower_thread_allreduce.cc
View file @
330d49f8
...
...
@@ -143,9 +143,10 @@ class ThreadAllreduceBuilder : public IRMutator {
int
threadx_extent
=
1
;
Expr
reduce_index
=
FlattenThread
(
vred
,
&
reduce_extent
);
Expr
group_index
=
FlattenThread
(
vpar
,
&
group_extent
);
Expr
pred
=
const_true
(
value
.
type
().
lanes
());
if
(
reduce_extent
==
1
)
{
// special case, no reduction is needed.
return
Store
::
make
(
op
->
buffer_var
,
value
,
0
);
return
Store
::
make
(
op
->
buffer_var
,
value
,
0
,
pred
);
}
// Whether the threadIdx.x is involved in reduction.
if
(
vred
[
0
].
scope
.
dim_index
==
0
)
{
...
...
@@ -155,7 +156,7 @@ class ThreadAllreduceBuilder : public IRMutator {
std
::
vector
<
Stmt
>
seq
;
seq
.
emplace_back
(
Store
::
make
(
shared_buf
,
value
,
BufIndex
(
reduce_index
,
group_index
,
reduce_extent
)));
BufIndex
(
reduce_index
,
group_index
,
reduce_extent
)
,
pred
));
seq
.
emplace_back
(
SyncThread
(
"shared"
));
seq
.
emplace_back
(
MakeBufAllreduce
(
combiner
,
value
.
type
(),
shared_buf
,
...
...
@@ -164,11 +165,12 @@ class ThreadAllreduceBuilder : public IRMutator {
load_remap_
[
op
->
buffer_var
.
get
()]
=
Load
::
make
(
value
.
type
(),
shared_buf
,
BufIndex
(
make_zero
(
reduce_index
.
type
()),
group_index
,
reduce_extent
));
BufIndex
(
make_zero
(
reduce_index
.
type
()),
group_index
,
reduce_extent
),
pred
);
alloc_remap_
[
op
->
buffer_var
.
get
()]
=
Allocate
::
make
(
shared_buf
,
value
.
type
(),
{
Expr
(
group_extent
),
Expr
(
reduce_extent
)},
const_true
()
,
Evaluate
::
make
(
0
));
pred
,
Evaluate
::
make
(
0
));
return
MergeSeq
(
seq
);
}
// make allreduce.
...
...
@@ -192,9 +194,9 @@ class ThreadAllreduceBuilder : public IRMutator {
auto
freduce
=
[
&
](
int
offset
)
{
Expr
b
=
Load
::
make
(
type
,
shared_buf
,
BufIndex
(
reduce_index
+
offset
,
group_index
,
reduce_extent
));
Expr
a
=
Load
::
make
(
type
,
shared_buf
,
buf_index
);
return
Store
::
make
(
shared_buf
,
(
*
combiner
)(
a
,
b
),
buf_index
);
BufIndex
(
reduce_index
+
offset
,
group_index
,
reduce_extent
)
,
const_true
()
);
Expr
a
=
Load
::
make
(
type
,
shared_buf
,
buf_index
,
const_true
()
);
return
Store
::
make
(
shared_buf
,
(
*
combiner
)(
a
,
b
),
buf_index
,
const_true
()
);
};
// Step one, check for
if
(
reduce_align
>
reduce_extent
)
{
...
...
src/pass/make_api.cc
View file @
330d49f8
...
...
@@ -122,7 +122,8 @@ LoweredFunc MakeAPI(Stmt body,
Var
tcode
(
v_arg
->
name_hint
+
".code"
,
Int
(
32
));
seq_init
.
emplace_back
(
LetStmt
::
make
(
tcode
,
Load
::
make
(
Int
(
32
),
v_packed_arg_type_ids
,
IntImm
::
make
(
Int
(
32
),
i
)),
nop
));
Int
(
32
),
v_packed_arg_type_ids
,
IntImm
::
make
(
Int
(
32
),
i
),
const_true
(
1
)),
nop
));
Type
t
=
v_arg
.
type
();
if
(
t
.
is_handle
())
{
std
::
ostringstream
msg
;
...
...
@@ -191,7 +192,7 @@ LoweredFunc MakeAPI(Stmt body,
f_push
(
buf
->
shape
[
k
],
cast
(
buf
->
shape
[
k
].
type
(),
Load
::
make
(
tvm_shape_type
,
v_shape
,
IntImm
::
make
(
Int
(
32
),
k
))),
IntImm
::
make
(
Int
(
32
),
k
)
,
const_true
(
1
)
)),
field_name
.
str
());
}
// strides field
...
...
@@ -212,7 +213,7 @@ LoweredFunc MakeAPI(Stmt body,
f_push
(
buf
->
strides
[
k
],
cast
(
buf
->
shape
[
k
].
type
(),
Load
::
make
(
tvm_shape_type
,
v_strides
,
IntImm
::
make
(
Int
(
32
),
k
))),
IntImm
::
make
(
Int
(
32
),
k
)
,
const_true
(
1
)
)),
field_name
.
str
());
}
}
...
...
src/pass/narrow_channel_access.cc
View file @
330d49f8
...
...
@@ -75,7 +75,8 @@ class ChannelAccessIndexRewriter : public IRMutator {
op
=
expr
.
as
<
Load
>
();
if
(
read_access_
&&
buf_var_
==
op
->
buffer_var
.
get
())
{
return
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
ir
::
Simplify
(
op
->
index
-
min_
));
op
->
type
,
op
->
buffer_var
,
ir
::
Simplify
(
op
->
index
-
min_
),
op
->
predicate
);
}
else
{
return
expr
;
}
...
...
@@ -85,7 +86,8 @@ class ChannelAccessIndexRewriter : public IRMutator {
op
=
stmt
.
as
<
Store
>
();
if
(
!
read_access_
&&
buf_var_
==
op
->
buffer_var
.
get
())
{
return
Store
::
make
(
op
->
buffer_var
,
op
->
value
,
ir
::
Simplify
(
op
->
index
-
min_
));
op
->
buffer_var
,
op
->
value
,
ir
::
Simplify
(
op
->
index
-
min_
),
op
->
predicate
);
}
else
{
return
stmt
;
}
...
...
src/pass/split_pipeline.cc
View file @
330d49f8
...
...
@@ -170,12 +170,13 @@ class StageSplitter : public IRMutator {
Expr
index
=
Mutate
(
op
->
index
);
Stmt
provide
=
Store
::
make
(
ch
->
handle_var
,
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
index
),
0
);
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
index
,
op
->
predicate
),
0
,
op
->
predicate
);
Stmt
temp
=
nest_
.
back
();
nest_
.
pop_back
();
stages_
.
emplace_back
(
BuildStage
(
provide
,
ch
));
nest_
.
push_back
(
temp
);
fifo_map_
[
ch
->
handle_var
.
get
()]
=
ch
;
return
Load
::
make
(
op
->
type
,
ch
->
handle_var
,
0
);
return
Load
::
make
(
op
->
type
,
ch
->
handle_var
,
0
,
op
->
predicate
);
}
Stmt
Split
(
Stmt
stmt
,
const
ProducerConsumer
*
env
)
{
...
...
src/pass/storage_flatten.cc
View file @
330d49f8
...
...
@@ -33,7 +33,7 @@ class StorageFlattener : public IRMutator {
op
=
stmt
.
as
<
Store
>
();
auto
it
=
extern_buf_remap_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
!=
extern_buf_remap_
.
end
())
{
return
Store
::
make
(
it
->
second
,
op
->
value
,
op
->
index
);
return
Store
::
make
(
it
->
second
,
op
->
value
,
op
->
index
,
op
->
predicate
);
}
else
{
return
stmt
;
}
...
...
@@ -115,7 +115,7 @@ class StorageFlattener : public IRMutator {
op
=
expr
.
as
<
Load
>
();
auto
it
=
extern_buf_remap_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
!=
extern_buf_remap_
.
end
())
{
return
Load
::
make
(
op
->
type
,
it
->
second
,
op
->
index
);
return
Load
::
make
(
op
->
type
,
it
->
second
,
op
->
index
,
op
->
predicate
);
}
else
{
return
expr
;
}
...
...
src/pass/storage_rewrite.cc
View file @
330d49f8
...
...
@@ -194,14 +194,14 @@ class StoragePlanRewriter : public IRMutator {
op
=
stmt
.
as
<
Store
>
();
auto
it
=
alloc_map_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
==
alloc_map_
.
end
())
return
stmt
;
return
Store
::
make
(
it
->
second
->
alloc_var
,
op
->
value
,
op
->
index
);
return
Store
::
make
(
it
->
second
->
alloc_var
,
op
->
value
,
op
->
index
,
op
->
predicate
);
}
Expr
Mutate_
(
const
Load
*
op
,
const
Expr
&
e
)
final
{
Expr
expr
=
IRMutator
::
Mutate_
(
op
,
e
);
op
=
expr
.
as
<
Load
>
();
auto
it
=
alloc_map_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
==
alloc_map_
.
end
())
return
expr
;
return
Load
::
make
(
op
->
type
,
it
->
second
->
alloc_var
,
op
->
index
);
return
Load
::
make
(
op
->
type
,
it
->
second
->
alloc_var
,
op
->
index
,
op
->
predicate
);
}
Expr
Mutate_
(
const
Variable
*
op
,
const
Expr
&
e
)
final
{
auto
it
=
alloc_map_
.
find
(
op
);
...
...
src/pass/storage_sync.cc
View file @
330d49f8
...
...
@@ -100,7 +100,7 @@ class StorageSyncPlanner : public IRVisitor {
}
}
void
Visit_
(
const
Call
*
op
)
final
{
if
(
op
->
is_intrinsic
(
Call
::
address_of
))
{
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_
address_of
))
{
const
Load
*
l
=
op
->
args
[
0
].
as
<
Load
>
();
IRVisitor
::
Visit_
(
l
);
}
else
{
...
...
src/pass/vectorize_loop.cc
View file @
330d49f8
...
...
@@ -34,7 +34,8 @@ class VecAllocAccess : public IRMutator {
op
=
expr
.
as
<
Load
>
();
if
(
op
->
buffer_var
.
get
()
==
buf_
)
{
return
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
op
->
index
*
var_lanes_
+
var_
);
op
->
index
*
var_lanes_
+
var_
,
op
->
predicate
);
}
else
{
return
expr
;
}
...
...
@@ -46,7 +47,8 @@ class VecAllocAccess : public IRMutator {
if
(
op
->
buffer_var
.
get
()
==
buf_
)
{
return
Store
::
make
(
op
->
buffer_var
,
op
->
value
,
op
->
index
*
var_lanes_
+
var_
);
op
->
index
*
var_lanes_
+
var_
,
op
->
predicate
);
}
else
{
return
stmt
;
}
...
...
@@ -160,11 +162,16 @@ class Vectorizer : public IRMutator {
// Load
Expr
Mutate_
(
const
Load
*
op
,
const
Expr
&
e
)
final
{
Expr
index
=
this
->
Mutate
(
op
->
index
);
if
(
index
.
same_as
(
op
->
index
))
{
Expr
pred
=
this
->
Mutate
(
op
->
predicate
);
if
(
index
.
same_as
(
op
->
index
)
&&
pred
.
same_as
(
op
->
predicate
))
{
return
e
;
}
else
{
return
Load
::
make
(
op
->
type
.
with_lanes
(
index
.
type
().
lanes
()),
op
->
buffer_var
,
index
);
int
lanes
=
std
::
max
(
index
.
type
().
lanes
(),
pred
.
type
().
lanes
());
return
Load
::
make
(
op
->
type
.
with_lanes
(
lanes
),
op
->
buffer_var
,
BroadcastTo
(
index
,
lanes
),
BroadcastTo
(
pred
,
lanes
));
}
}
// Let
...
...
@@ -201,13 +208,16 @@ class Vectorizer : public IRMutator {
Stmt
Mutate_
(
const
Store
*
op
,
const
Stmt
&
s
)
final
{
Expr
value
=
this
->
Mutate
(
op
->
value
);
Expr
index
=
this
->
Mutate
(
op
->
index
);
Expr
pred
=
this
->
Mutate
(
op
->
predicate
);
if
(
value
.
same_as
(
op
->
value
)
&&
index
.
same_as
(
op
->
index
))
{
return
s
;
}
else
{
int
lanes
=
std
::
max
(
value
.
type
().
lanes
(),
index
.
type
().
lanes
());
lanes
=
std
::
max
(
lanes
,
pred
.
type
().
lanes
());
return
Store
::
make
(
op
->
buffer_var
,
BroadcastTo
(
value
,
lanes
),
BroadcastTo
(
index
,
lanes
));
BroadcastTo
(
index
,
lanes
),
BroadcastTo
(
pred
,
lanes
));
}
}
// For
...
...
tests/cpp/ir_cse_pass_test.cc
deleted
100644 → 0
View file @
d3c8256b
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <pass/CSE.h>
TEST
(
IR_PASS
,
CSE
)
{
using
namespace
Halide
::
Internal
;
cse_test
();
}
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
FLAGS_gtest_death_test_style
=
"threadsafe"
;
return
RUN_ALL_TESTS
();
}
tests/cpp/ir_simplify_test.cc
View file @
330d49f8
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <
pass
/Simplify.h>
#include <
arithmetic
/Simplify.h>
TEST
(
IRSIMPLIFY
,
Basic
)
{
using
namespace
Halide
::
Internal
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment