Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
330d49f8
Commit
330d49f8
authored
May 04, 2017
by
Tianqi Chen
Committed by
GitHub
May 04, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[IR] Update new version of HalideIR (#116)
parent
d3c8256b
Hide whitespace changes
Inline
Side-by-side
Showing
28 changed files
with
150 additions
and
90 deletions
+150
-90
HalideIR
+1
-1
include/tvm/ir.h
+9
-0
include/tvm/ir_mutator.h
+1
-0
include/tvm/ir_pass.h
+1
-1
src/api/api_ir.cc
+27
-11
src/arithmetic/compute_expr.h
+1
-1
src/arithmetic/int_set.cc
+1
-1
src/codegen/codegen_c.cc
+5
-1
src/codegen/llvm/codegen_llvm.cc
+12
-5
src/codegen/stack_vm/codegen_stack_vm.cc
+3
-3
src/codegen/verilog/verilog_ir.cc
+5
-3
src/lang/buffer.cc
+5
-2
src/op/compute_op.cc
+4
-2
src/pass/inject_virtual_thread.cc
+8
-9
src/pass/ir_mutator.cc
+17
-5
src/pass/ir_util.h
+4
-2
src/pass/ir_visitor.cc
+2
-0
src/pass/lower_packed_call.cc
+2
-2
src/pass/lower_thread_allreduce.cc
+9
-7
src/pass/make_api.cc
+4
-3
src/pass/narrow_channel_access.cc
+4
-2
src/pass/split_pipeline.cc
+3
-2
src/pass/storage_flatten.cc
+2
-2
src/pass/storage_rewrite.cc
+2
-2
src/pass/storage_sync.cc
+1
-1
src/pass/vectorize_loop.cc
+16
-6
tests/cpp/ir_cse_pass_test.cc
+0
-15
tests/cpp/ir_simplify_test.cc
+1
-1
No files found.
HalideIR
@
4fffc62c
Subproject commit
398edacd956c6de82185821ffd9f482598182e5
1
Subproject commit
4fffc62c124651c1cde18f31957db413b677d60
1
include/tvm/ir.h
View file @
330d49f8
...
@@ -174,6 +174,14 @@ namespace intrinsic {
...
@@ -174,6 +174,14 @@ namespace intrinsic {
/*!
/*!
* \brief See pesudo code
* \brief See pesudo code
*
*
* Handle tvm_address_of(Load *op) {
* return &op->buffer_var[index];
* }
*/
constexpr
const
char
*
tvm_address_of
=
"tvm_address_of"
;
/*!
* \brief See pesudo code
*
* Type tvm_struct_get(StructType* arr, int index, int field_id) {
* Type tvm_struct_get(StructType* arr, int index, int field_id) {
* return arr[index]->field;
* return arr[index]->field;
* }
* }
...
@@ -355,6 +363,7 @@ using Halide::Internal::Realize;
...
@@ -355,6 +363,7 @@ using Halide::Internal::Realize;
using
Halide
::
Internal
::
Block
;
using
Halide
::
Internal
::
Block
;
using
Halide
::
Internal
::
IfThenElse
;
using
Halide
::
Internal
::
IfThenElse
;
using
Halide
::
Internal
::
Evaluate
;
using
Halide
::
Internal
::
Evaluate
;
using
Halide
::
Internal
::
Shuffle
;
// ir functions
// ir functions
using
Halide
::
Internal
::
is_const_power_of_two_integer
;
using
Halide
::
Internal
::
is_const_power_of_two_integer
;
...
...
include/tvm/ir_mutator.h
View file @
330d49f8
...
@@ -98,6 +98,7 @@ class IRMutator {
...
@@ -98,6 +98,7 @@ class IRMutator {
virtual
Expr
Mutate_
(
const
UIntImm
*
op
,
const
Expr
&
e
);
virtual
Expr
Mutate_
(
const
UIntImm
*
op
,
const
Expr
&
e
);
virtual
Expr
Mutate_
(
const
FloatImm
*
op
,
const
Expr
&
e
);
virtual
Expr
Mutate_
(
const
FloatImm
*
op
,
const
Expr
&
e
);
virtual
Expr
Mutate_
(
const
StringImm
*
op
,
const
Expr
&
e
);
virtual
Expr
Mutate_
(
const
StringImm
*
op
,
const
Expr
&
e
);
virtual
Expr
Mutate_
(
const
Shuffle
*
op
,
const
Expr
&
e
);
};
};
}
// namespace ir
}
// namespace ir
...
...
include/tvm/ir_pass.h
View file @
330d49f8
...
@@ -10,7 +10,7 @@
...
@@ -10,7 +10,7 @@
#define TVM_IR_PASS_H_
#define TVM_IR_PASS_H_
#include <ir/IREquality.h>
#include <ir/IREquality.h>
#include <
pass
/Simplify.h>
#include <
arithmetic
/Simplify.h>
#include <tvm/ir_functor.h>
#include <tvm/ir_functor.h>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
...
...
src/api/api_ir.cc
View file @
330d49f8
...
@@ -26,6 +26,26 @@ TVM_REGISTER_API("make.For")
...
@@ -26,6 +26,26 @@ TVM_REGISTER_API("make.For")
args
[
5
]);
args
[
5
]);
});
});
TVM_REGISTER_API
(
"make.Load"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Type
t
=
args
[
0
];
if
(
args
.
size
()
==
3
)
{
*
ret
=
Load
::
make
(
t
,
args
[
1
],
args
[
2
],
const_true
(
t
.
lanes
()));
}
else
{
*
ret
=
Load
::
make
(
t
,
args
[
1
],
args
[
2
],
args
[
3
]);
}
});
TVM_REGISTER_API
(
"make.Store"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Expr
value
=
args
[
1
];
if
(
args
.
size
()
==
3
)
{
*
ret
=
Store
::
make
(
args
[
0
],
value
,
args
[
2
],
const_true
(
value
.
type
().
lanes
()));
}
else
{
*
ret
=
Store
::
make
(
args
[
0
],
value
,
args
[
2
],
args
[
3
]);
}
});
TVM_REGISTER_API
(
"make.Realize"
)
TVM_REGISTER_API
(
"make.Realize"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
Realize
::
make
(
args
[
0
],
*
ret
=
Realize
::
make
(
args
[
0
],
...
@@ -47,15 +67,6 @@ TVM_REGISTER_API("make.Call")
...
@@ -47,15 +67,6 @@ TVM_REGISTER_API("make.Call")
args
[
5
]);
args
[
5
]);
});
});
TVM_REGISTER_API
(
"make.Allocate"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
Allocate
::
make
(
args
[
0
],
args
[
1
],
args
[
2
],
args
[
3
],
args
[
4
]);
});
TVM_REGISTER_API
(
"make.CommReducer"
)
TVM_REGISTER_API
(
"make.CommReducer"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
*
ret
=
CommReducerNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
]);
*
ret
=
CommReducerNode
::
make
(
args
[
0
],
args
[
1
],
args
[
2
]);
...
@@ -87,6 +98,12 @@ TVM_REGISTER_API("make.CommReducer")
...
@@ -87,6 +98,12 @@ TVM_REGISTER_API("make.CommReducer")
*ret = Node::make(args[0], args[1], args[2], args[3]); \
*ret = Node::make(args[0], args[1], args[2], args[3]); \
}) \
}) \
#define REGISTER_MAKE5(Node) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
*ret = Node::make(args[0], args[1], args[2], args[3], args[4]); \
}) \
#define REGISTER_MAKE_BINARY_OP(Node) \
#define REGISTER_MAKE_BINARY_OP(Node) \
TVM_REGISTER_API("make."#Node) \
TVM_REGISTER_API("make."#Node) \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
.set_body([](TVMArgs args, TVMRetValue *ret) { \
...
@@ -125,8 +142,7 @@ REGISTER_MAKE3(Let);
...
@@ -125,8 +142,7 @@ REGISTER_MAKE3(Let);
REGISTER_MAKE3
(
LetStmt
);
REGISTER_MAKE3
(
LetStmt
);
REGISTER_MAKE2
(
AssertStmt
);
REGISTER_MAKE2
(
AssertStmt
);
REGISTER_MAKE3
(
ProducerConsumer
);
REGISTER_MAKE3
(
ProducerConsumer
);
REGISTER_MAKE3
(
Load
);
REGISTER_MAKE5
(
Allocate
);
REGISTER_MAKE3
(
Store
);
REGISTER_MAKE4
(
Provide
);
REGISTER_MAKE4
(
Provide
);
REGISTER_MAKE1
(
Free
);
REGISTER_MAKE1
(
Free
);
REGISTER_MAKE2
(
Block
);
REGISTER_MAKE2
(
Block
);
...
...
src/arithmetic/compute_expr.h
View file @
330d49f8
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_
#include <tvm/ir.h>
#include <tvm/ir.h>
#include <
pass
/Interval.h>
#include <
arithmetic
/Interval.h>
#include <limits>
#include <limits>
namespace
tvm
{
namespace
tvm
{
...
...
src/arithmetic/int_set.cc
View file @
330d49f8
...
@@ -6,7 +6,7 @@
...
@@ -6,7 +6,7 @@
#include <tvm/ir.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_pass.h>
#include <tvm/arithmetic.h>
#include <tvm/arithmetic.h>
#include <
pass
/Interval.h>
#include <
arithmetic
/Interval.h>
#include <unordered_map>
#include <unordered_map>
#include "./compute_expr.h"
#include "./compute_expr.h"
#include "./int_set_internal.h"
#include "./int_set_internal.h"
...
...
src/codegen/codegen_c.cc
View file @
330d49f8
...
@@ -471,7 +471,7 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
...
@@ -471,7 +471,7 @@ void CodeGenC::VisitExpr_(const Call *op, std::ostream& os) { // NOLINT(*)
PrintBinaryIntrinsitc
(
op
,
" << "
,
os
,
this
);
PrintBinaryIntrinsitc
(
op
,
" << "
,
os
,
this
);
}
else
if
(
op
->
is_intrinsic
(
Call
::
shift_right
))
{
}
else
if
(
op
->
is_intrinsic
(
Call
::
shift_right
))
{
PrintBinaryIntrinsitc
(
op
,
" >> "
,
os
,
this
);
PrintBinaryIntrinsitc
(
op
,
" >> "
,
os
,
this
);
}
else
if
(
op
->
is_intrinsic
(
Call
::
address_of
))
{
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_
address_of
))
{
const
Load
*
l
=
op
->
args
[
0
].
as
<
Load
>
();
const
Load
*
l
=
op
->
args
[
0
].
as
<
Load
>
();
CHECK
(
op
->
args
.
size
()
==
1
&&
l
);
CHECK
(
op
->
args
.
size
()
==
1
&&
l
);
os
<<
"(("
;
os
<<
"(("
;
...
@@ -535,6 +535,8 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
...
@@ -535,6 +535,8 @@ void CodeGenC::VisitExpr_(const Load* op, std::ostream& os) { // NOLINT(*)
std
::
string
ref
=
GetBufferRef
(
op
->
type
,
op
->
buffer_var
.
get
(),
op
->
index
);
std
::
string
ref
=
GetBufferRef
(
op
->
type
,
op
->
buffer_var
.
get
(),
op
->
index
);
os
<<
ref
;
os
<<
ref
;
}
else
{
}
else
{
CHECK
(
is_one
(
op
->
predicate
))
<<
"predicated load is not supported"
;
Expr
base
;
Expr
base
;
if
(
TryGetRamp1Base
(
op
->
index
,
op
->
type
.
lanes
(),
&
base
))
{
if
(
TryGetRamp1Base
(
op
->
index
,
op
->
type
.
lanes
(),
&
base
))
{
std
::
string
ref
=
GetVecLoad
(
op
->
type
,
op
->
buffer_var
.
get
(),
base
);
std
::
string
ref
=
GetVecLoad
(
op
->
type
,
op
->
buffer_var
.
get
(),
base
);
...
@@ -575,6 +577,8 @@ void CodeGenC::VisitStmt_(const Store* op) {
...
@@ -575,6 +577,8 @@ void CodeGenC::VisitStmt_(const Store* op) {
this
->
PrintIndent
();
this
->
PrintIndent
();
stream
<<
ref
<<
" = "
<<
value
<<
";
\n
"
;
stream
<<
ref
<<
" = "
<<
value
<<
";
\n
"
;
}
else
{
}
else
{
CHECK
(
is_one
(
op
->
predicate
))
<<
"Predicated store is not supported"
;
Expr
base
;
Expr
base
;
if
(
TryGetRamp1Base
(
op
->
index
,
t
.
lanes
(),
&
base
))
{
if
(
TryGetRamp1Base
(
op
->
index
,
t
.
lanes
(),
&
base
))
{
std
::
string
value
=
this
->
PrintExpr
(
op
->
value
);
std
::
string
value
=
this
->
PrintExpr
(
op
->
value
);
...
...
src/codegen/llvm/codegen_llvm.cc
View file @
330d49f8
...
@@ -702,7 +702,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
...
@@ -702,7 +702,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
return
builder_
->
CreateLShr
(
return
builder_
->
CreateLShr
(
MakeValue
(
op
->
args
[
0
]),
MakeValue
(
op
->
args
[
1
]));
MakeValue
(
op
->
args
[
0
]),
MakeValue
(
op
->
args
[
1
]));
}
}
}
else
if
(
op
->
is_intrinsic
(
Call
::
address_of
))
{
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_
address_of
))
{
const
Load
*
l
=
op
->
args
[
0
].
as
<
Load
>
();
const
Load
*
l
=
op
->
args
[
0
].
as
<
Load
>
();
CHECK
(
op
->
args
.
size
()
==
1
&&
l
);
CHECK
(
op
->
args
.
size
()
==
1
&&
l
);
return
CreateBufferPtr
(
return
CreateBufferPtr
(
...
@@ -752,7 +752,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
...
@@ -752,7 +752,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinstic(const Call* op) {
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"Unknown stack alloca type "
<<
type
;
LOG
(
FATAL
)
<<
"Unknown stack alloca type "
<<
type
;
}
}
}
else
if
(
op
->
is_intrinsic
(
Call
::
null_handle
))
{
}
else
if
(
op
->
is_intrinsic
(
Call
::
reinterpret
)
&&
is_zero
(
op
->
args
[
0
]
))
{
return
llvm
::
Constant
::
getNullValue
(
t_void_p_
);
return
llvm
::
Constant
::
getNullValue
(
t_void_p_
);
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"Unknown intrinstic "
<<
op
->
name
;
LOG
(
FATAL
)
<<
"Unknown intrinstic "
<<
op
->
name
;
...
@@ -1077,6 +1077,8 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(
...
@@ -1077,6 +1077,8 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(
}
}
llvm
::
Value
*
CodeGenLLVM
::
VisitExpr_
(
const
Load
*
op
)
{
llvm
::
Value
*
CodeGenLLVM
::
VisitExpr_
(
const
Load
*
op
)
{
CHECK
(
is_one
(
op
->
predicate
))
<<
"Predicated Load is not supported"
;
Type
t
=
op
->
type
;
Type
t
=
op
->
type
;
const
Ramp
*
ramp
=
op
->
index
.
as
<
Ramp
>
();
const
Ramp
*
ramp
=
op
->
index
.
as
<
Ramp
>
();
llvm
::
Value
*
buf
=
GetVarValue
(
op
->
buffer_var
.
get
());
llvm
::
Value
*
buf
=
GetVarValue
(
op
->
buffer_var
.
get
());
...
@@ -1135,12 +1137,14 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
...
@@ -1135,12 +1137,14 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
t
,
op
->
buffer_var
,
t
,
op
->
buffer_var
,
Ramp
::
make
(
arith
::
ComputeExpr
<
Add
>
(
Ramp
::
make
(
arith
::
ComputeExpr
<
Add
>
(
ramp
->
base
,
make_const
(
bt
,
first_shift
)),
ramp
->
base
,
make_const
(
bt
,
first_shift
)),
make_const
(
bt
,
1
),
ramp
->
lanes
)));
make_const
(
bt
,
1
),
ramp
->
lanes
),
const_true
(
t
.
lanes
())));
llvm
::
Value
*
next
=
MakeValue
(
Load
::
make
(
llvm
::
Value
*
next
=
MakeValue
(
Load
::
make
(
t
,
op
->
buffer_var
,
t
,
op
->
buffer_var
,
Ramp
::
make
(
arith
::
ComputeExpr
<
Add
>
(
Ramp
::
make
(
arith
::
ComputeExpr
<
Add
>
(
ramp
->
base
,
make_const
(
bt
,
ramp
->
lanes
+
next_shift
)),
ramp
->
base
,
make_const
(
bt
,
ramp
->
lanes
+
next_shift
)),
make_const
(
bt
,
1
),
ramp
->
lanes
)));
make_const
(
bt
,
1
),
ramp
->
lanes
),
const_true
(
t
.
lanes
())));
// shuffle
// shuffle
std
::
vector
<
llvm
::
Constant
*>
indices
;
std
::
vector
<
llvm
::
Constant
*>
indices
;
int
target_index
=
0
;
int
target_index
=
0
;
...
@@ -1170,7 +1174,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
...
@@ -1170,7 +1174,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
make_const
(
ramp
->
base
.
type
(),
1
),
make_const
(
ramp
->
base
.
type
(),
1
),
lanes
);
lanes
);
// load value then flip
// load value then flip
llvm
::
Value
*
v
=
MakeValue
(
Load
::
make
(
t
,
op
->
buffer_var
,
neg_ramp
));
llvm
::
Value
*
v
=
MakeValue
(
Load
::
make
(
t
,
op
->
buffer_var
,
neg_ramp
,
const_true
(
t
.
lanes
())));
return
CreateVecFlip
(
v
);
return
CreateVecFlip
(
v
);
}
else
{
}
else
{
llvm
::
Value
*
ret
=
llvm
::
UndefValue
::
get
(
LLVMType
(
t
));
llvm
::
Value
*
ret
=
llvm
::
UndefValue
::
get
(
LLVMType
(
t
));
...
@@ -1187,6 +1192,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
...
@@ -1187,6 +1192,8 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Load* op) {
// stmts
// stmts
void
CodeGenLLVM
::
VisitStmt_
(
const
Store
*
op
)
{
void
CodeGenLLVM
::
VisitStmt_
(
const
Store
*
op
)
{
CHECK
(
is_one
(
op
->
predicate
))
<<
"Predicated Load is not supported"
;
llvm
::
Value
*
value
=
MakeValue
(
op
->
value
);
llvm
::
Value
*
value
=
MakeValue
(
op
->
value
);
Type
t
=
op
->
value
.
type
();
Type
t
=
op
->
value
.
type
();
const
Ramp
*
ramp
=
op
->
index
.
as
<
Ramp
>
();
const
Ramp
*
ramp
=
op
->
index
.
as
<
Ramp
>
();
...
...
src/codegen/stack_vm/codegen_stack_vm.cc
View file @
330d49f8
...
@@ -121,7 +121,7 @@ void CodeGenStackVM::VisitStmt_(const Allocate* op) {
...
@@ -121,7 +121,7 @@ void CodeGenStackVM::VisitStmt_(const Allocate* op) {
}
}
void
CodeGenStackVM
::
VisitExpr_
(
const
Call
*
op
)
{
void
CodeGenStackVM
::
VisitExpr_
(
const
Call
*
op
)
{
if
(
op
->
is_intrinsic
(
Call
::
address_of
))
{
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_
address_of
))
{
const
Load
*
l
=
op
->
args
[
0
].
as
<
Load
>
();
const
Load
*
l
=
op
->
args
[
0
].
as
<
Load
>
();
CHECK
(
op
->
args
.
size
()
==
1
&&
l
);
CHECK
(
op
->
args
.
size
()
==
1
&&
l
);
this
->
PushOp
(
StackVM
::
LOAD_HEAP
,
GetVarID
(
l
->
buffer_var
.
get
()));
this
->
PushOp
(
StackVM
::
LOAD_HEAP
,
GetVarID
(
l
->
buffer_var
.
get
()));
...
@@ -129,8 +129,8 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
...
@@ -129,8 +129,8 @@ void CodeGenStackVM::VisitExpr_(const Call* op) {
this
->
PushOp
(
StackVM
::
PUSH_I64
,
l
->
type
.
element_of
().
bytes
());
this
->
PushOp
(
StackVM
::
PUSH_I64
,
l
->
type
.
element_of
().
bytes
());
this
->
PushOp
(
StackVM
::
MUL_I64
);
this
->
PushOp
(
StackVM
::
MUL_I64
);
this
->
PushOp
(
StackVM
::
ADDR_ADD
);
this
->
PushOp
(
StackVM
::
ADDR_ADD
);
}
else
if
(
op
->
is_intrinsic
(
Call
::
null_handle
))
{
}
else
if
(
op
->
is_intrinsic
(
Call
::
reinterpret
))
{
this
->
Push
Op
(
StackVM
::
PUSH_I64
,
0
);
this
->
Push
(
op
->
args
[
0
]
);
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_struct_get
))
{
}
else
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_struct_get
))
{
CHECK_EQ
(
op
->
args
.
size
(),
3U
);
CHECK_EQ
(
op
->
args
.
size
(),
3U
);
int
kind
=
op
->
args
[
2
].
as
<
IntImm
>
()
->
value
;
int
kind
=
op
->
args
[
2
].
as
<
IntImm
>
()
->
value
;
...
...
src/codegen/verilog/verilog_ir.cc
View file @
330d49f8
...
@@ -217,11 +217,13 @@ class PipelineExtractor: public IRVisitor {
...
@@ -217,11 +217,13 @@ class PipelineExtractor: public IRVisitor {
if
(
is_zero
(
op
->
index
)
&&
load
)
{
if
(
is_zero
(
op
->
index
)
&&
load
)
{
compute
->
body
=
Store
::
make
(
compute
->
body
=
Store
::
make
(
op
->
buffer_var
,
op
->
buffer_var
,
Load
::
make
(
load
->
type
,
load
->
buffer_var
,
repl
.
Mutate
(
load
->
index
)),
Load
::
make
(
load
->
type
,
load
->
buffer_var
,
op
->
index
);
repl
.
Mutate
(
load
->
index
),
op
->
predicate
),
op
->
index
,
op
->
predicate
);
}
else
{
}
else
{
compute
->
body
=
Store
::
make
(
compute
->
body
=
Store
::
make
(
op
->
buffer_var
,
repl
.
Mutate
(
op
->
value
),
repl
.
Mutate
(
op
->
index
));
op
->
buffer_var
,
repl
.
Mutate
(
op
->
value
),
repl
.
Mutate
(
op
->
index
),
op
->
predicate
);
}
}
compute
->
inputs
=
repl
.
inputs_
;
compute
->
inputs
=
repl
.
inputs_
;
pipeline_
->
stages
.
push_back
(
ComputeBlock
(
compute
));
pipeline_
->
stages
.
push_back
(
ComputeBlock
(
compute
));
...
...
src/lang/buffer.cc
View file @
330d49f8
...
@@ -49,13 +49,16 @@ inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
...
@@ -49,13 +49,16 @@ inline Expr BufferOffset(const BufferNode* n, Array<Expr> index) {
Expr
Buffer
::
MakeLoad
(
Array
<
Expr
>
index
)
const
{
Expr
Buffer
::
MakeLoad
(
Array
<
Expr
>
index
)
const
{
const
BufferNode
*
n
=
operator
->
();
const
BufferNode
*
n
=
operator
->
();
return
ir
::
Load
::
make
(
n
->
dtype
,
n
->
data
,
BufferOffset
(
n
,
index
));
return
ir
::
Load
::
make
(
n
->
dtype
,
n
->
data
,
BufferOffset
(
n
,
index
),
const_true
(
n
->
dtype
.
lanes
()));
}
}
Stmt
Buffer
::
MakeStore
(
Array
<
Expr
>
index
,
Expr
value
)
const
{
Stmt
Buffer
::
MakeStore
(
Array
<
Expr
>
index
,
Expr
value
)
const
{
const
BufferNode
*
n
=
operator
->
();
const
BufferNode
*
n
=
operator
->
();
CHECK_EQ
(
value
.
type
(),
n
->
dtype
);
CHECK_EQ
(
value
.
type
(),
n
->
dtype
);
return
ir
::
Store
::
make
(
n
->
data
,
value
,
BufferOffset
(
n
,
index
));
return
ir
::
Store
::
make
(
n
->
data
,
value
,
BufferOffset
(
n
,
index
),
const_true
(
n
->
dtype
.
lanes
()));
}
}
Buffer
BufferNode
::
make
(
std
::
string
name
,
Buffer
BufferNode
::
make
(
std
::
string
name
,
...
...
src/op/compute_op.cc
View file @
330d49f8
...
@@ -254,19 +254,21 @@ Stmt MakeCrossThreadReduction(
...
@@ -254,19 +254,21 @@ Stmt MakeCrossThreadReduction(
}
}
}
}
}
}
Type
t
=
reduce
->
type
;
Expr
pred
=
const_true
(
t
.
lanes
());
Stmt
reduce_body
=
Store
::
make
(
res_handle
,
Stmt
reduce_body
=
Store
::
make
(
res_handle
,
Call
::
make
(
Call
::
make
(
reduce
->
type
,
reduce
->
type
,
ir
::
intrinsic
::
tvm_thread_allreduce
,
ir
::
intrinsic
::
tvm_thread_allreduce
,
freduce_args
,
Call
::
Intrinsic
),
freduce_args
,
Call
::
Intrinsic
),
0
);
0
,
pred
);
reduce_body
=
AttrStmt
::
make
(
reduce_body
=
AttrStmt
::
make
(
reduce
->
combiner
,
reduce
->
combiner
,
attr
::
reduce_scope
,
attr
::
reduce_scope
,
make_zero
(
reduce
->
type
),
make_zero
(
reduce
->
type
),
reduce_body
);
reduce_body
);
Stmt
assign_body
=
Provide
::
make
(
Stmt
assign_body
=
Provide
::
make
(
stage
->
op
,
0
,
Load
::
make
(
reduce
->
type
,
res_handle
,
0
),
args
);
stage
->
op
,
0
,
Load
::
make
(
reduce
->
type
,
res_handle
,
0
,
pred
),
args
);
assign_body
=
MergeNest
(
op
::
MakeIfNest
(
thread_head_check
),
assign_body
);
assign_body
=
MergeNest
(
op
::
MakeIfNest
(
thread_head_check
),
assign_body
);
assign_body
=
MergeNest
(
op
::
MakeIfNest
(
conds
),
assign_body
);
assign_body
=
MergeNest
(
op
::
MakeIfNest
(
conds
),
assign_body
);
Stmt
body
=
Allocate
::
make
(
Stmt
body
=
Allocate
::
make
(
...
...
src/pass/inject_virtual_thread.cc
View file @
330d49f8
...
@@ -152,11 +152,7 @@ class VTInjector : public IRMutator {
...
@@ -152,11 +152,7 @@ class VTInjector : public IRMutator {
return
e
;
return
e
;
}
}
Expr
RewriteIndex
(
Expr
index
,
Expr
alloc_extent
)
const
{
Expr
RewriteIndex
(
Expr
index
,
Expr
alloc_extent
)
const
{
if
(
index_rewrite_strategy_
==
0
)
{
return
index
+
var_
*
alloc_extent
;
return
index
*
num_threads_
+
var_
;
}
else
{
return
index
+
var_
*
alloc_extent
;
}
}
}
// Load
// Load
Expr
Mutate_
(
const
Load
*
op
,
const
Expr
&
e
)
final
{
Expr
Mutate_
(
const
Load
*
op
,
const
Expr
&
e
)
final
{
...
@@ -168,7 +164,8 @@ class VTInjector : public IRMutator {
...
@@ -168,7 +164,8 @@ class VTInjector : public IRMutator {
auto
it
=
touched_alloc_
.
find
(
op
->
buffer_var
.
get
());
auto
it
=
touched_alloc_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
!=
touched_alloc_
.
end
())
{
if
(
it
!=
touched_alloc_
.
end
())
{
return
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
return
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
RewriteIndex
(
op
->
index
,
it
->
second
));
RewriteIndex
(
op
->
index
,
it
->
second
),
op
->
predicate
);
}
else
{
}
else
{
return
expr
;
return
expr
;
}
}
...
@@ -184,7 +181,8 @@ class VTInjector : public IRMutator {
...
@@ -184,7 +181,8 @@ class VTInjector : public IRMutator {
if
(
it
!=
touched_alloc_
.
end
())
{
if
(
it
!=
touched_alloc_
.
end
())
{
return
Store
::
make
(
op
->
buffer_var
,
return
Store
::
make
(
op
->
buffer_var
,
op
->
value
,
op
->
value
,
RewriteIndex
(
op
->
index
,
it
->
second
));
RewriteIndex
(
op
->
index
,
it
->
second
),
op
->
predicate
);
}
else
{
}
else
{
return
stmt
;
return
stmt
;
}
}
...
@@ -307,6 +305,9 @@ class VTInjector : public IRMutator {
...
@@ -307,6 +305,9 @@ class VTInjector : public IRMutator {
for
(
size_t
i
=
1
;
i
<
extents
.
size
();
++
i
)
{
for
(
size_t
i
=
1
;
i
<
extents
.
size
();
++
i
)
{
stride
=
arith
::
ComputeExpr
<
Mul
>
(
stride
,
extents
[
i
]);
stride
=
arith
::
ComputeExpr
<
Mul
>
(
stride
,
extents
[
i
]);
}
}
if
(
op
->
type
.
lanes
()
!=
0
)
{
stride
=
stride
*
op
->
type
.
lanes
();
}
Array
<
Expr
>
other
;
Array
<
Expr
>
other
;
other
.
push_back
(
num_threads_
);
other
.
push_back
(
num_threads_
);
for
(
Expr
e
:
extents
)
{
for
(
Expr
e
:
extents
)
{
...
@@ -368,8 +369,6 @@ class VTInjector : public IRMutator {
...
@@ -368,8 +369,6 @@ class VTInjector : public IRMutator {
Var
var_
;
Var
var_
;
// the threads/lanes
// the threads/lanes
int
num_threads_
;
int
num_threads_
;
// Index rewriting strategy
int
index_rewrite_strategy_
{
1
};
// whethe the loop is already injected.
// whethe the loop is already injected.
bool
vt_loop_injected_
{
false
};
bool
vt_loop_injected_
{
false
};
// whether current expression get touched.
// whether current expression get touched.
...
...
src/pass/ir_mutator.cc
View file @
330d49f8
...
@@ -143,10 +143,11 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
...
@@ -143,10 +143,11 @@ Stmt IRMutator::Mutate_(const IfThenElse *op, const Stmt& s) {
Stmt
IRMutator
::
Mutate_
(
const
Store
*
op
,
const
Stmt
&
s
)
{
Stmt
IRMutator
::
Mutate_
(
const
Store
*
op
,
const
Stmt
&
s
)
{
Expr
value
=
this
->
Mutate
(
op
->
value
);
Expr
value
=
this
->
Mutate
(
op
->
value
);
Expr
index
=
this
->
Mutate
(
op
->
index
);
Expr
index
=
this
->
Mutate
(
op
->
index
);
if
(
value
.
same_as
(
op
->
value
)
&&
index
.
same_as
(
op
->
index
))
{
Expr
pred
=
this
->
Mutate
(
op
->
predicate
);
if
(
value
.
same_as
(
op
->
value
)
&&
index
.
same_as
(
op
->
index
)
&&
pred
.
same_as
(
op
->
predicate
))
{
return
s
;
return
s
;
}
else
{
}
else
{
return
Store
::
make
(
op
->
buffer_var
,
value
,
index
);
return
Store
::
make
(
op
->
buffer_var
,
value
,
index
,
pred
);
}
}
}
}
...
@@ -263,10 +264,11 @@ Expr IRMutator::Mutate_(const Variable *op, const Expr& e) {
...
@@ -263,10 +264,11 @@ Expr IRMutator::Mutate_(const Variable *op, const Expr& e) {
Expr
IRMutator
::
Mutate_
(
const
Load
*
op
,
const
Expr
&
e
)
{
Expr
IRMutator
::
Mutate_
(
const
Load
*
op
,
const
Expr
&
e
)
{
Expr
index
=
this
->
Mutate
(
op
->
index
);
Expr
index
=
this
->
Mutate
(
op
->
index
);
if
(
index
.
same_as
(
op
->
index
))
{
Expr
pred
=
this
->
Mutate
(
op
->
predicate
);
if
(
index
.
same_as
(
op
->
index
)
&&
pred
.
same_as
(
op
->
predicate
))
{
return
e
;
return
e
;
}
else
{
}
else
{
return
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
index
);
return
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
index
,
pred
);
}
}
}
}
...
@@ -383,6 +385,15 @@ Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) {
...
@@ -383,6 +385,15 @@ Expr IRMutator::Mutate_(const Broadcast *op, const Expr& e) {
}
}
}
}
Expr
IRMutator
::
Mutate_
(
const
Shuffle
*
op
,
const
Expr
&
e
)
{
auto
new_vec
=
MutateArray
(
op
->
vectors
,
this
);
if
(
new_vec
.
same_as
(
op
->
vectors
))
{
return
e
;
}
else
{
return
Shuffle
::
make
(
new_vec
,
op
->
indices
);
}
}
#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \
#define DEFINE_OP_RETURN_SELF_EXPR_MUTATE_(OP) \
Expr IRMutator::Mutate_(const OP *op, const Expr& e) { \
Expr IRMutator::Mutate_(const OP *op, const Expr& e) { \
return e; \
return e; \
...
@@ -422,7 +433,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
...
@@ -422,7 +433,8 @@ TVM_STATIC_IR_FUNCTOR(IRMutator, vtable_expr)
.
DISPATCH_TO_MUTATE_EXPR
(
IntImm
)
.
DISPATCH_TO_MUTATE_EXPR
(
IntImm
)
.
DISPATCH_TO_MUTATE_EXPR
(
UIntImm
)
.
DISPATCH_TO_MUTATE_EXPR
(
UIntImm
)
.
DISPATCH_TO_MUTATE_EXPR
(
FloatImm
)
.
DISPATCH_TO_MUTATE_EXPR
(
FloatImm
)
.
DISPATCH_TO_MUTATE_EXPR
(
StringImm
);
.
DISPATCH_TO_MUTATE_EXPR
(
StringImm
)
.
DISPATCH_TO_MUTATE_EXPR
(
Shuffle
);
}
// namespace ir
}
// namespace ir
}
// namespace tvm
}
// namespace tvm
src/pass/ir_util.h
View file @
330d49f8
...
@@ -111,8 +111,10 @@ inline Expr TVMStructGet(
...
@@ -111,8 +111,10 @@ inline Expr TVMStructGet(
*/
*/
inline
Expr
AddressOffset
(
Var
handle
,
Type
dtype
,
int
offset
)
{
inline
Expr
AddressOffset
(
Var
handle
,
Type
dtype
,
int
offset
)
{
return
Call
::
make
(
return
Call
::
make
(
Handle
(),
Call
::
address_of
,
Handle
(),
intrinsic
::
tvm_address_of
,
{
Load
::
make
(
dtype
,
handle
,
make_const
(
Int
(
32
),
offset
))},
Call
::
PureIntrinsic
);
{
Load
::
make
(
dtype
,
handle
,
make_const
(
Int
(
32
),
offset
*
dtype
.
lanes
()),
const_true
(
dtype
.
lanes
()))},
Call
::
PureIntrinsic
);
}
}
/*!
/*!
...
...
src/pass/ir_visitor.cc
View file @
330d49f8
...
@@ -81,11 +81,13 @@ void IRVisitor::Visit_(const Allocate *op) {
...
@@ -81,11 +81,13 @@ void IRVisitor::Visit_(const Allocate *op) {
void
IRVisitor
::
Visit_
(
const
Load
*
op
)
{
void
IRVisitor
::
Visit_
(
const
Load
*
op
)
{
this
->
Visit
(
op
->
index
);
this
->
Visit
(
op
->
index
);
this
->
Visit
(
op
->
predicate
);
}
}
void
IRVisitor
::
Visit_
(
const
Store
*
op
)
{
void
IRVisitor
::
Visit_
(
const
Store
*
op
)
{
this
->
Visit
(
op
->
value
);
this
->
Visit
(
op
->
value
);
this
->
Visit
(
op
->
index
);
this
->
Visit
(
op
->
index
);
this
->
Visit
(
op
->
predicate
);
}
}
void
IRVisitor
::
Visit_
(
const
IfThenElse
*
op
)
{
void
IRVisitor
::
Visit_
(
const
IfThenElse
*
op
)
{
...
...
src/pass/lower_packed_call.cc
View file @
330d49f8
...
@@ -99,7 +99,7 @@ class PackedCallBuilder : public IRMutator {
...
@@ -99,7 +99,7 @@ class PackedCallBuilder : public IRMutator {
for
(
size_t
i
=
0
;
i
<
op
->
args
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
op
->
args
.
size
();
++
i
)
{
prep_seq_
.
emplace_back
(
prep_seq_
.
emplace_back
(
Store
::
make
(
stack_shape_
,
Convert
(
Int
(
64
),
op
->
args
[
i
]),
Store
::
make
(
stack_shape_
,
Convert
(
Int
(
64
),
op
->
args
[
i
]),
ConstInt32
(
stack_begin
+
i
)));
ConstInt32
(
stack_begin
+
i
)
,
const_true
(
1
)
));
}
}
return
AddressOffset
(
stack_shape_
,
Int
(
64
),
stack_begin
);
return
AddressOffset
(
stack_shape_
,
Int
(
64
),
stack_begin
);
}
}
...
@@ -169,7 +169,7 @@ class PackedCallBuilder : public IRMutator {
...
@@ -169,7 +169,7 @@ class PackedCallBuilder : public IRMutator {
prep_seq_
.
emplace_back
(
prep_seq_
.
emplace_back
(
Store
::
make
(
stack_tcode_
,
Store
::
make
(
stack_tcode_
,
ConstInt32
(
arg_tcode
),
ConstInt32
(
arg_tcode
),
stack_index
));
stack_index
,
const_true
(
1
)
));
}
}
// UPDATE stack value
// UPDATE stack value
max_arg_stack_
=
std
::
max
(
run_arg_stack_
,
max_arg_stack_
);
max_arg_stack_
=
std
::
max
(
run_arg_stack_
,
max_arg_stack_
);
...
...
src/pass/lower_thread_allreduce.cc
View file @
330d49f8
...
@@ -143,9 +143,10 @@ class ThreadAllreduceBuilder : public IRMutator {
...
@@ -143,9 +143,10 @@ class ThreadAllreduceBuilder : public IRMutator {
int
threadx_extent
=
1
;
int
threadx_extent
=
1
;
Expr
reduce_index
=
FlattenThread
(
vred
,
&
reduce_extent
);
Expr
reduce_index
=
FlattenThread
(
vred
,
&
reduce_extent
);
Expr
group_index
=
FlattenThread
(
vpar
,
&
group_extent
);
Expr
group_index
=
FlattenThread
(
vpar
,
&
group_extent
);
Expr
pred
=
const_true
(
value
.
type
().
lanes
());
if
(
reduce_extent
==
1
)
{
if
(
reduce_extent
==
1
)
{
// special case, no reduction is needed.
// special case, no reduction is needed.
return
Store
::
make
(
op
->
buffer_var
,
value
,
0
);
return
Store
::
make
(
op
->
buffer_var
,
value
,
0
,
pred
);
}
}
// Whether the threadIdx.x is involved in reduction.
// Whether the threadIdx.x is involved in reduction.
if
(
vred
[
0
].
scope
.
dim_index
==
0
)
{
if
(
vred
[
0
].
scope
.
dim_index
==
0
)
{
...
@@ -155,7 +156,7 @@ class ThreadAllreduceBuilder : public IRMutator {
...
@@ -155,7 +156,7 @@ class ThreadAllreduceBuilder : public IRMutator {
std
::
vector
<
Stmt
>
seq
;
std
::
vector
<
Stmt
>
seq
;
seq
.
emplace_back
(
Store
::
make
(
seq
.
emplace_back
(
Store
::
make
(
shared_buf
,
value
,
shared_buf
,
value
,
BufIndex
(
reduce_index
,
group_index
,
reduce_extent
)));
BufIndex
(
reduce_index
,
group_index
,
reduce_extent
)
,
pred
));
seq
.
emplace_back
(
SyncThread
(
"shared"
));
seq
.
emplace_back
(
SyncThread
(
"shared"
));
seq
.
emplace_back
(
MakeBufAllreduce
(
seq
.
emplace_back
(
MakeBufAllreduce
(
combiner
,
value
.
type
(),
shared_buf
,
combiner
,
value
.
type
(),
shared_buf
,
...
@@ -164,11 +165,12 @@ class ThreadAllreduceBuilder : public IRMutator {
...
@@ -164,11 +165,12 @@ class ThreadAllreduceBuilder : public IRMutator {
load_remap_
[
op
->
buffer_var
.
get
()]
=
load_remap_
[
op
->
buffer_var
.
get
()]
=
Load
::
make
(
Load
::
make
(
value
.
type
(),
shared_buf
,
value
.
type
(),
shared_buf
,
BufIndex
(
make_zero
(
reduce_index
.
type
()),
group_index
,
reduce_extent
));
BufIndex
(
make_zero
(
reduce_index
.
type
()),
group_index
,
reduce_extent
),
pred
);
alloc_remap_
[
op
->
buffer_var
.
get
()]
=
alloc_remap_
[
op
->
buffer_var
.
get
()]
=
Allocate
::
make
(
shared_buf
,
value
.
type
(),
Allocate
::
make
(
shared_buf
,
value
.
type
(),
{
Expr
(
group_extent
),
Expr
(
reduce_extent
)},
{
Expr
(
group_extent
),
Expr
(
reduce_extent
)},
const_true
()
,
Evaluate
::
make
(
0
));
pred
,
Evaluate
::
make
(
0
));
return
MergeSeq
(
seq
);
return
MergeSeq
(
seq
);
}
}
// make allreduce.
// make allreduce.
...
@@ -192,9 +194,9 @@ class ThreadAllreduceBuilder : public IRMutator {
...
@@ -192,9 +194,9 @@ class ThreadAllreduceBuilder : public IRMutator {
auto
freduce
=
[
&
](
int
offset
)
{
auto
freduce
=
[
&
](
int
offset
)
{
Expr
b
=
Load
::
make
(
Expr
b
=
Load
::
make
(
type
,
shared_buf
,
type
,
shared_buf
,
BufIndex
(
reduce_index
+
offset
,
group_index
,
reduce_extent
));
BufIndex
(
reduce_index
+
offset
,
group_index
,
reduce_extent
)
,
const_true
()
);
Expr
a
=
Load
::
make
(
type
,
shared_buf
,
buf_index
);
Expr
a
=
Load
::
make
(
type
,
shared_buf
,
buf_index
,
const_true
()
);
return
Store
::
make
(
shared_buf
,
(
*
combiner
)(
a
,
b
),
buf_index
);
return
Store
::
make
(
shared_buf
,
(
*
combiner
)(
a
,
b
),
buf_index
,
const_true
()
);
};
};
// Step one, check for
// Step one, check for
if
(
reduce_align
>
reduce_extent
)
{
if
(
reduce_align
>
reduce_extent
)
{
...
...
src/pass/make_api.cc
View file @
330d49f8
...
@@ -122,7 +122,8 @@ LoweredFunc MakeAPI(Stmt body,
...
@@ -122,7 +122,8 @@ LoweredFunc MakeAPI(Stmt body,
Var
tcode
(
v_arg
->
name_hint
+
".code"
,
Int
(
32
));
Var
tcode
(
v_arg
->
name_hint
+
".code"
,
Int
(
32
));
seq_init
.
emplace_back
(
LetStmt
::
make
(
seq_init
.
emplace_back
(
LetStmt
::
make
(
tcode
,
Load
::
make
(
tcode
,
Load
::
make
(
Int
(
32
),
v_packed_arg_type_ids
,
IntImm
::
make
(
Int
(
32
),
i
)),
nop
));
Int
(
32
),
v_packed_arg_type_ids
,
IntImm
::
make
(
Int
(
32
),
i
),
const_true
(
1
)),
nop
));
Type
t
=
v_arg
.
type
();
Type
t
=
v_arg
.
type
();
if
(
t
.
is_handle
())
{
if
(
t
.
is_handle
())
{
std
::
ostringstream
msg
;
std
::
ostringstream
msg
;
...
@@ -191,7 +192,7 @@ LoweredFunc MakeAPI(Stmt body,
...
@@ -191,7 +192,7 @@ LoweredFunc MakeAPI(Stmt body,
f_push
(
buf
->
shape
[
k
],
f_push
(
buf
->
shape
[
k
],
cast
(
buf
->
shape
[
k
].
type
(),
cast
(
buf
->
shape
[
k
].
type
(),
Load
::
make
(
tvm_shape_type
,
v_shape
,
Load
::
make
(
tvm_shape_type
,
v_shape
,
IntImm
::
make
(
Int
(
32
),
k
))),
IntImm
::
make
(
Int
(
32
),
k
)
,
const_true
(
1
)
)),
field_name
.
str
());
field_name
.
str
());
}
}
// strides field
// strides field
...
@@ -212,7 +213,7 @@ LoweredFunc MakeAPI(Stmt body,
...
@@ -212,7 +213,7 @@ LoweredFunc MakeAPI(Stmt body,
f_push
(
buf
->
strides
[
k
],
f_push
(
buf
->
strides
[
k
],
cast
(
buf
->
shape
[
k
].
type
(),
cast
(
buf
->
shape
[
k
].
type
(),
Load
::
make
(
tvm_shape_type
,
v_strides
,
Load
::
make
(
tvm_shape_type
,
v_strides
,
IntImm
::
make
(
Int
(
32
),
k
))),
IntImm
::
make
(
Int
(
32
),
k
)
,
const_true
(
1
)
)),
field_name
.
str
());
field_name
.
str
());
}
}
}
}
...
...
src/pass/narrow_channel_access.cc
View file @
330d49f8
...
@@ -75,7 +75,8 @@ class ChannelAccessIndexRewriter : public IRMutator {
...
@@ -75,7 +75,8 @@ class ChannelAccessIndexRewriter : public IRMutator {
op
=
expr
.
as
<
Load
>
();
op
=
expr
.
as
<
Load
>
();
if
(
read_access_
&&
buf_var_
==
op
->
buffer_var
.
get
())
{
if
(
read_access_
&&
buf_var_
==
op
->
buffer_var
.
get
())
{
return
Load
::
make
(
return
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
ir
::
Simplify
(
op
->
index
-
min_
));
op
->
type
,
op
->
buffer_var
,
ir
::
Simplify
(
op
->
index
-
min_
),
op
->
predicate
);
}
else
{
}
else
{
return
expr
;
return
expr
;
}
}
...
@@ -85,7 +86,8 @@ class ChannelAccessIndexRewriter : public IRMutator {
...
@@ -85,7 +86,8 @@ class ChannelAccessIndexRewriter : public IRMutator {
op
=
stmt
.
as
<
Store
>
();
op
=
stmt
.
as
<
Store
>
();
if
(
!
read_access_
&&
buf_var_
==
op
->
buffer_var
.
get
())
{
if
(
!
read_access_
&&
buf_var_
==
op
->
buffer_var
.
get
())
{
return
Store
::
make
(
return
Store
::
make
(
op
->
buffer_var
,
op
->
value
,
ir
::
Simplify
(
op
->
index
-
min_
));
op
->
buffer_var
,
op
->
value
,
ir
::
Simplify
(
op
->
index
-
min_
),
op
->
predicate
);
}
else
{
}
else
{
return
stmt
;
return
stmt
;
}
}
...
...
src/pass/split_pipeline.cc
View file @
330d49f8
...
@@ -170,12 +170,13 @@ class StageSplitter : public IRMutator {
...
@@ -170,12 +170,13 @@ class StageSplitter : public IRMutator {
Expr
index
=
Mutate
(
op
->
index
);
Expr
index
=
Mutate
(
op
->
index
);
Stmt
provide
=
Store
::
make
(
Stmt
provide
=
Store
::
make
(
ch
->
handle_var
,
ch
->
handle_var
,
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
index
),
0
);
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
index
,
op
->
predicate
),
0
,
op
->
predicate
);
Stmt
temp
=
nest_
.
back
();
nest_
.
pop_back
();
Stmt
temp
=
nest_
.
back
();
nest_
.
pop_back
();
stages_
.
emplace_back
(
BuildStage
(
provide
,
ch
));
stages_
.
emplace_back
(
BuildStage
(
provide
,
ch
));
nest_
.
push_back
(
temp
);
nest_
.
push_back
(
temp
);
fifo_map_
[
ch
->
handle_var
.
get
()]
=
ch
;
fifo_map_
[
ch
->
handle_var
.
get
()]
=
ch
;
return
Load
::
make
(
op
->
type
,
ch
->
handle_var
,
0
);
return
Load
::
make
(
op
->
type
,
ch
->
handle_var
,
0
,
op
->
predicate
);
}
}
Stmt
Split
(
Stmt
stmt
,
const
ProducerConsumer
*
env
)
{
Stmt
Split
(
Stmt
stmt
,
const
ProducerConsumer
*
env
)
{
...
...
src/pass/storage_flatten.cc
View file @
330d49f8
...
@@ -33,7 +33,7 @@ class StorageFlattener : public IRMutator {
...
@@ -33,7 +33,7 @@ class StorageFlattener : public IRMutator {
op
=
stmt
.
as
<
Store
>
();
op
=
stmt
.
as
<
Store
>
();
auto
it
=
extern_buf_remap_
.
find
(
op
->
buffer_var
.
get
());
auto
it
=
extern_buf_remap_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
!=
extern_buf_remap_
.
end
())
{
if
(
it
!=
extern_buf_remap_
.
end
())
{
return
Store
::
make
(
it
->
second
,
op
->
value
,
op
->
index
);
return
Store
::
make
(
it
->
second
,
op
->
value
,
op
->
index
,
op
->
predicate
);
}
else
{
}
else
{
return
stmt
;
return
stmt
;
}
}
...
@@ -115,7 +115,7 @@ class StorageFlattener : public IRMutator {
...
@@ -115,7 +115,7 @@ class StorageFlattener : public IRMutator {
op
=
expr
.
as
<
Load
>
();
op
=
expr
.
as
<
Load
>
();
auto
it
=
extern_buf_remap_
.
find
(
op
->
buffer_var
.
get
());
auto
it
=
extern_buf_remap_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
!=
extern_buf_remap_
.
end
())
{
if
(
it
!=
extern_buf_remap_
.
end
())
{
return
Load
::
make
(
op
->
type
,
it
->
second
,
op
->
index
);
return
Load
::
make
(
op
->
type
,
it
->
second
,
op
->
index
,
op
->
predicate
);
}
else
{
}
else
{
return
expr
;
return
expr
;
}
}
...
...
src/pass/storage_rewrite.cc
View file @
330d49f8
...
@@ -194,14 +194,14 @@ class StoragePlanRewriter : public IRMutator {
...
@@ -194,14 +194,14 @@ class StoragePlanRewriter : public IRMutator {
op
=
stmt
.
as
<
Store
>
();
op
=
stmt
.
as
<
Store
>
();
auto
it
=
alloc_map_
.
find
(
op
->
buffer_var
.
get
());
auto
it
=
alloc_map_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
==
alloc_map_
.
end
())
return
stmt
;
if
(
it
==
alloc_map_
.
end
())
return
stmt
;
return
Store
::
make
(
it
->
second
->
alloc_var
,
op
->
value
,
op
->
index
);
return
Store
::
make
(
it
->
second
->
alloc_var
,
op
->
value
,
op
->
index
,
op
->
predicate
);
}
}
Expr
Mutate_
(
const
Load
*
op
,
const
Expr
&
e
)
final
{
Expr
Mutate_
(
const
Load
*
op
,
const
Expr
&
e
)
final
{
Expr
expr
=
IRMutator
::
Mutate_
(
op
,
e
);
Expr
expr
=
IRMutator
::
Mutate_
(
op
,
e
);
op
=
expr
.
as
<
Load
>
();
op
=
expr
.
as
<
Load
>
();
auto
it
=
alloc_map_
.
find
(
op
->
buffer_var
.
get
());
auto
it
=
alloc_map_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
==
alloc_map_
.
end
())
return
expr
;
if
(
it
==
alloc_map_
.
end
())
return
expr
;
return
Load
::
make
(
op
->
type
,
it
->
second
->
alloc_var
,
op
->
index
);
return
Load
::
make
(
op
->
type
,
it
->
second
->
alloc_var
,
op
->
index
,
op
->
predicate
);
}
}
Expr
Mutate_
(
const
Variable
*
op
,
const
Expr
&
e
)
final
{
Expr
Mutate_
(
const
Variable
*
op
,
const
Expr
&
e
)
final
{
auto
it
=
alloc_map_
.
find
(
op
);
auto
it
=
alloc_map_
.
find
(
op
);
...
...
src/pass/storage_sync.cc
View file @
330d49f8
...
@@ -100,7 +100,7 @@ class StorageSyncPlanner : public IRVisitor {
...
@@ -100,7 +100,7 @@ class StorageSyncPlanner : public IRVisitor {
}
}
}
}
void
Visit_
(
const
Call
*
op
)
final
{
void
Visit_
(
const
Call
*
op
)
final
{
if
(
op
->
is_intrinsic
(
Call
::
address_of
))
{
if
(
op
->
is_intrinsic
(
intrinsic
::
tvm_
address_of
))
{
const
Load
*
l
=
op
->
args
[
0
].
as
<
Load
>
();
const
Load
*
l
=
op
->
args
[
0
].
as
<
Load
>
();
IRVisitor
::
Visit_
(
l
);
IRVisitor
::
Visit_
(
l
);
}
else
{
}
else
{
...
...
src/pass/vectorize_loop.cc
View file @
330d49f8
...
@@ -34,7 +34,8 @@ class VecAllocAccess : public IRMutator {
...
@@ -34,7 +34,8 @@ class VecAllocAccess : public IRMutator {
op
=
expr
.
as
<
Load
>
();
op
=
expr
.
as
<
Load
>
();
if
(
op
->
buffer_var
.
get
()
==
buf_
)
{
if
(
op
->
buffer_var
.
get
()
==
buf_
)
{
return
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
return
Load
::
make
(
op
->
type
,
op
->
buffer_var
,
op
->
index
*
var_lanes_
+
var_
);
op
->
index
*
var_lanes_
+
var_
,
op
->
predicate
);
}
else
{
}
else
{
return
expr
;
return
expr
;
}
}
...
@@ -46,7 +47,8 @@ class VecAllocAccess : public IRMutator {
...
@@ -46,7 +47,8 @@ class VecAllocAccess : public IRMutator {
if
(
op
->
buffer_var
.
get
()
==
buf_
)
{
if
(
op
->
buffer_var
.
get
()
==
buf_
)
{
return
Store
::
make
(
op
->
buffer_var
,
return
Store
::
make
(
op
->
buffer_var
,
op
->
value
,
op
->
value
,
op
->
index
*
var_lanes_
+
var_
);
op
->
index
*
var_lanes_
+
var_
,
op
->
predicate
);
}
else
{
}
else
{
return
stmt
;
return
stmt
;
}
}
...
@@ -160,11 +162,16 @@ class Vectorizer : public IRMutator {
...
@@ -160,11 +162,16 @@ class Vectorizer : public IRMutator {
// Load
// Load
Expr
Mutate_
(
const
Load
*
op
,
const
Expr
&
e
)
final
{
Expr
Mutate_
(
const
Load
*
op
,
const
Expr
&
e
)
final
{
Expr
index
=
this
->
Mutate
(
op
->
index
);
Expr
index
=
this
->
Mutate
(
op
->
index
);
if
(
index
.
same_as
(
op
->
index
))
{
Expr
pred
=
this
->
Mutate
(
op
->
predicate
);
if
(
index
.
same_as
(
op
->
index
)
&&
pred
.
same_as
(
op
->
predicate
))
{
return
e
;
return
e
;
}
else
{
}
else
{
return
Load
::
make
(
op
->
type
.
with_lanes
(
index
.
type
().
lanes
()),
int
lanes
=
std
::
max
(
index
.
type
().
lanes
(),
pred
.
type
().
lanes
());
op
->
buffer_var
,
index
);
return
Load
::
make
(
op
->
type
.
with_lanes
(
lanes
),
op
->
buffer_var
,
BroadcastTo
(
index
,
lanes
),
BroadcastTo
(
pred
,
lanes
));
}
}
}
}
// Let
// Let
...
@@ -201,13 +208,16 @@ class Vectorizer : public IRMutator {
...
@@ -201,13 +208,16 @@ class Vectorizer : public IRMutator {
Stmt
Mutate_
(
const
Store
*
op
,
const
Stmt
&
s
)
final
{
Stmt
Mutate_
(
const
Store
*
op
,
const
Stmt
&
s
)
final
{
Expr
value
=
this
->
Mutate
(
op
->
value
);
Expr
value
=
this
->
Mutate
(
op
->
value
);
Expr
index
=
this
->
Mutate
(
op
->
index
);
Expr
index
=
this
->
Mutate
(
op
->
index
);
Expr
pred
=
this
->
Mutate
(
op
->
predicate
);
if
(
value
.
same_as
(
op
->
value
)
&&
index
.
same_as
(
op
->
index
))
{
if
(
value
.
same_as
(
op
->
value
)
&&
index
.
same_as
(
op
->
index
))
{
return
s
;
return
s
;
}
else
{
}
else
{
int
lanes
=
std
::
max
(
value
.
type
().
lanes
(),
index
.
type
().
lanes
());
int
lanes
=
std
::
max
(
value
.
type
().
lanes
(),
index
.
type
().
lanes
());
lanes
=
std
::
max
(
lanes
,
pred
.
type
().
lanes
());
return
Store
::
make
(
op
->
buffer_var
,
return
Store
::
make
(
op
->
buffer_var
,
BroadcastTo
(
value
,
lanes
),
BroadcastTo
(
value
,
lanes
),
BroadcastTo
(
index
,
lanes
));
BroadcastTo
(
index
,
lanes
),
BroadcastTo
(
pred
,
lanes
));
}
}
}
}
// For
// For
...
...
tests/cpp/ir_cse_pass_test.cc
deleted
100644 → 0
View file @
d3c8256b
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <pass/CSE.h>
TEST
(
IR_PASS
,
CSE
)
{
using
namespace
Halide
::
Internal
;
cse_test
();
}
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
FLAGS_gtest_death_test_style
=
"threadsafe"
;
return
RUN_ALL_TESTS
();
}
tests/cpp/ir_simplify_test.cc
View file @
330d49f8
#include <dmlc/logging.h>
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/tvm.h>
#include <
pass
/Simplify.h>
#include <
arithmetic
/Simplify.h>
TEST
(
IRSIMPLIFY
,
Basic
)
{
TEST
(
IRSIMPLIFY
,
Basic
)
{
using
namespace
Halide
::
Internal
;
using
namespace
Halide
::
Internal
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment