Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
59448fed
Unverified
Commit
59448fed
authored
Jul 06, 2019
by
Tianqi Chen
Committed by
GitHub
Jul 06, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[ARITH] Refactor: Remove un-necessary usage of ComputeExpr (#3503)
parent
54f9d20a
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
45 additions
and
65 deletions
+45
-65
src/arithmetic/compute_expr.h
+10
-12
src/arithmetic/detect_linear_equation.cc
+3
-4
src/codegen/codegen_cuda.cc
+0
-1
src/codegen/llvm/codegen_llvm.cc
+1
-3
src/codegen/spirv/codegen_spirv.cc
+3
-5
src/lang/buffer.cc
+1
-2
src/pass/arg_binder.cc
+0
-1
src/pass/inject_double_buffer.cc
+7
-6
src/pass/inject_virtual_thread.cc
+3
-3
src/pass/lower_thread_allreduce.cc
+0
-0
src/pass/lower_warp_memory.cc
+0
-2
src/pass/make_api.cc
+0
-1
src/pass/storage_flatten.cc
+3
-4
src/pass/unroll_loop.cc
+1
-5
src/pass/vectorize_loop.cc
+3
-4
src/schedule/message_passing.cc
+10
-12
No files found.
src/arithmetic/compute_expr.h
View file @
59448fed
...
@@ -18,10 +18,8 @@
...
@@ -18,10 +18,8 @@
*/
*/
/*!
/*!
* Copyright (c) 2017 by Contributors
* \file compute_expr.h
* \file compute_expr.h
* \brief Utility integer expression with quick eager simplification.
* \brief Utility to invoke certan compute operations.
* This is weaker than Simplify but can be done Eagerly.
*/
*/
#ifndef TVM_ARITHMETIC_COMPUTE_EXPR_H_
#ifndef TVM_ARITHMETIC_COMPUTE_EXPR_H_
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_
#define TVM_ARITHMETIC_COMPUTE_EXPR_H_
...
@@ -41,7 +39,7 @@ namespace arith {
...
@@ -41,7 +39,7 @@ namespace arith {
* \return The result.
* \return The result.
*/
*/
template
<
typename
OP
>
template
<
typename
OP
>
inline
Expr
Compute
Expr
(
Expr
lhs
,
Expr
rhs
)
{
inline
Expr
Compute
(
Expr
lhs
,
Expr
rhs
)
{
return
OP
::
make
(
lhs
,
rhs
);
return
OP
::
make
(
lhs
,
rhs
);
}
}
...
@@ -79,37 +77,37 @@ inline bool GetConstInt(Expr e, int* out) {
...
@@ -79,37 +77,37 @@ inline bool GetConstInt(Expr e, int* out) {
}
}
template
<>
template
<>
inline
Expr
Compute
Expr
<
ir
::
Add
>
(
Expr
a
,
Expr
b
)
{
inline
Expr
Compute
<
ir
::
Add
>
(
Expr
a
,
Expr
b
)
{
return
a
+
b
;
return
a
+
b
;
}
}
template
<>
template
<>
inline
Expr
Compute
Expr
<
ir
::
Sub
>
(
Expr
a
,
Expr
b
)
{
inline
Expr
Compute
<
ir
::
Sub
>
(
Expr
a
,
Expr
b
)
{
return
a
-
b
;
return
a
-
b
;
}
}
template
<>
template
<>
inline
Expr
Compute
Expr
<
ir
::
Mul
>
(
Expr
a
,
Expr
b
)
{
inline
Expr
Compute
<
ir
::
Mul
>
(
Expr
a
,
Expr
b
)
{
return
a
*
b
;
return
a
*
b
;
}
}
template
<>
template
<>
inline
Expr
Compute
Expr
<
ir
::
Div
>
(
Expr
a
,
Expr
b
)
{
inline
Expr
Compute
<
ir
::
Div
>
(
Expr
a
,
Expr
b
)
{
return
a
/
b
;
return
a
/
b
;
}
}
template
<>
template
<>
inline
Expr
Compute
Expr
<
ir
::
Mod
>
(
Expr
a
,
Expr
b
)
{
inline
Expr
Compute
<
ir
::
Mod
>
(
Expr
a
,
Expr
b
)
{
return
a
%
b
;
return
a
%
b
;
}
}
template
<>
template
<>
inline
Expr
Compute
Expr
<
ir
::
Max
>
(
Expr
a
,
Expr
b
)
{
inline
Expr
Compute
<
ir
::
Max
>
(
Expr
a
,
Expr
b
)
{
return
max
(
a
,
b
);
return
max
(
a
,
b
);
}
}
template
<>
template
<>
inline
Expr
Compute
Expr
<
ir
::
Min
>
(
Expr
a
,
Expr
b
)
{
inline
Expr
Compute
<
ir
::
Min
>
(
Expr
a
,
Expr
b
)
{
return
min
(
a
,
b
);
return
min
(
a
,
b
);
}
}
...
@@ -121,7 +119,7 @@ inline Expr ComputeReduce(const Array<Expr>& values, Expr empty_value) {
...
@@ -121,7 +119,7 @@ inline Expr ComputeReduce(const Array<Expr>& values, Expr empty_value) {
}
}
Expr
res
=
values
[
0
];
Expr
res
=
values
[
0
];
for
(
size_t
i
=
1
;
i
<
values
.
size
();
++
i
)
{
for
(
size_t
i
=
1
;
i
<
values
.
size
();
++
i
)
{
res
=
Compute
Expr
<
Op
>
(
res
,
values
[
i
]);
res
=
Compute
<
Op
>
(
res
,
values
[
i
]);
}
}
return
res
;
return
res
;
}
}
...
...
src/arithmetic/detect_linear_equation.cc
View file @
59448fed
...
@@ -27,7 +27,6 @@
...
@@ -27,7 +27,6 @@
#include <tvm/ir_visitor.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/ir_functor_ext.h>
#include <tvm/arithmetic.h>
#include <tvm/arithmetic.h>
#include "compute_expr.h"
namespace
tvm
{
namespace
tvm
{
namespace
arith
{
namespace
arith
{
...
@@ -127,18 +126,18 @@ class LinearEqDetector
...
@@ -127,18 +126,18 @@ class LinearEqDetector
Expr
AddCombine
(
Expr
a
,
Expr
b
)
{
Expr
AddCombine
(
Expr
a
,
Expr
b
)
{
if
(
!
a
.
defined
())
return
b
;
if
(
!
a
.
defined
())
return
b
;
if
(
!
b
.
defined
())
return
a
;
if
(
!
b
.
defined
())
return
a
;
return
ComputeExpr
<
Add
>
(
a
,
b
)
;
return
a
+
b
;
}
}
Expr
SubCombine
(
Expr
a
,
Expr
b
)
{
Expr
SubCombine
(
Expr
a
,
Expr
b
)
{
// Check b first in case they are both undefined
// Check b first in case they are both undefined
if
(
!
b
.
defined
())
return
a
;
if
(
!
b
.
defined
())
return
a
;
if
(
!
a
.
defined
())
return
-
b
;
if
(
!
a
.
defined
())
return
-
b
;
return
ComputeExpr
<
Sub
>
(
a
,
b
)
;
return
a
-
b
;
}
}
Expr
MulCombine
(
Expr
a
,
Expr
b
)
{
Expr
MulCombine
(
Expr
a
,
Expr
b
)
{
if
(
!
a
.
defined
())
return
a
;
if
(
!
a
.
defined
())
return
a
;
if
(
!
b
.
defined
())
return
b
;
if
(
!
b
.
defined
())
return
b
;
return
ComputeExpr
<
Mul
>
(
a
,
b
)
;
return
a
*
b
;
}
}
};
};
...
...
src/codegen/codegen_cuda.cc
View file @
59448fed
...
@@ -27,7 +27,6 @@
...
@@ -27,7 +27,6 @@
#include <vector>
#include <vector>
#include <string>
#include <string>
#include "codegen_cuda.h"
#include "codegen_cuda.h"
#include "../arithmetic/compute_expr.h"
namespace
tvm
{
namespace
tvm
{
namespace
codegen
{
namespace
codegen
{
...
...
src/codegen/llvm/codegen_llvm.cc
View file @
59448fed
...
@@ -748,9 +748,7 @@ void CodeGenLLVM::Scalarize(const Expr& e,
...
@@ -748,9 +748,7 @@ void CodeGenLLVM::Scalarize(const Expr& e,
std
::
function
<
void
(
int
i
,
llvm
::
Value
*
v
)
>
f
)
{
std
::
function
<
void
(
int
i
,
llvm
::
Value
*
v
)
>
f
)
{
if
(
const
Ramp
*
ramp
=
e
.
as
<
Ramp
>
())
{
if
(
const
Ramp
*
ramp
=
e
.
as
<
Ramp
>
())
{
for
(
int
i
=
0
;
i
<
ramp
->
type
.
lanes
();
++
i
)
{
for
(
int
i
=
0
;
i
<
ramp
->
type
.
lanes
();
++
i
)
{
Expr
offset
=
arith
::
ComputeExpr
<
Add
>
(
Expr
offset
=
ramp
->
base
+
(
ramp
->
stride
*
i
);
ramp
->
base
,
arith
::
ComputeExpr
<
Mul
>
(
ramp
->
stride
,
i
));
f
(
i
,
MakeValue
(
offset
));
f
(
i
,
MakeValue
(
offset
));
}
}
}
else
{
}
else
{
...
...
src/codegen/spirv/codegen_spirv.cc
View file @
59448fed
...
@@ -25,8 +25,8 @@
...
@@ -25,8 +25,8 @@
#include <tvm/ir.h>
#include <tvm/ir.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_pass.h>
#include <string>
#include <string>
#include "../../arithmetic/compute_expr.h"
#include "codegen_spirv.h"
#include "codegen_spirv.h"
#include "../../arithmetic/compute_expr.h"
namespace
tvm
{
namespace
tvm
{
namespace
codegen
{
namespace
codegen
{
...
@@ -339,7 +339,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Ramp* op) {
...
@@ -339,7 +339,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Ramp* op) {
spirv
::
Value
v
=
base
;
spirv
::
Value
v
=
base
;
if
(
i
!=
0
)
{
if
(
i
!=
0
)
{
spirv
::
Value
offset
=
MakeValue
(
spirv
::
Value
offset
=
MakeValue
(
arith
::
ComputeExpr
<
Mul
>
(
make_const
(
op
->
stride
.
type
(),
i
),
op
->
stride
)
);
make_const
(
op
->
stride
.
type
(),
i
)
*
op
->
stride
);
v
=
builder_
->
Add
(
v
,
offset
);
v
=
builder_
->
Add
(
v
,
offset
);
}
}
values
.
push_back
(
v
);
values
.
push_back
(
v
);
...
@@ -419,9 +419,7 @@ void CodeGenSPIRV::Scalarize(const Expr& e,
...
@@ -419,9 +419,7 @@ void CodeGenSPIRV::Scalarize(const Expr& e,
std
::
function
<
void
(
int
i
,
spirv
::
Value
v
)
>
f
)
{
std
::
function
<
void
(
int
i
,
spirv
::
Value
v
)
>
f
)
{
if
(
const
Ramp
*
ramp
=
e
.
as
<
Ramp
>
())
{
if
(
const
Ramp
*
ramp
=
e
.
as
<
Ramp
>
())
{
for
(
int
i
=
0
;
i
<
ramp
->
type
.
lanes
();
++
i
)
{
for
(
int
i
=
0
;
i
<
ramp
->
type
.
lanes
();
++
i
)
{
Expr
offset
=
arith
::
ComputeExpr
<
Add
>
(
Expr
offset
=
ramp
->
base
+
ramp
->
stride
*
i
;
ramp
->
base
,
arith
::
ComputeExpr
<
Mul
>
(
ramp
->
stride
,
i
));
f
(
i
,
MakeValue
(
offset
));
f
(
i
,
MakeValue
(
offset
));
}
}
}
else
{
}
else
{
...
...
src/lang/buffer.cc
View file @
59448fed
...
@@ -378,8 +378,7 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr
...
@@ -378,8 +378,7 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, Expr
extent
=
make_const
(
self
->
DefaultIndexType
(),
1
);
extent
=
make_const
(
self
->
DefaultIndexType
(),
1
);
}
else
if
(
self
->
strides
.
size
()
==
self
->
shape
.
size
())
{
}
else
if
(
self
->
strides
.
size
()
==
self
->
shape
.
size
())
{
int
highest_dim
=
0
;
int
highest_dim
=
0
;
extent
=
arith
::
ComputeExpr
<
ir
::
Mul
>
(
extent
=
self
->
strides
[
highest_dim
]
*
self
->
shape
[
highest_dim
]
-
offset
;
self
->
strides
[
highest_dim
],
self
->
shape
[
highest_dim
])
-
offset
;
}
else
{
}
else
{
extent
=
arith
::
ComputeReduce
<
ir
::
Mul
>
(
self
->
shape
,
Expr
())
-
offset
;
extent
=
arith
::
ComputeReduce
<
ir
::
Mul
>
(
self
->
shape
,
Expr
())
-
offset
;
}
}
...
...
src/pass/arg_binder.cc
View file @
59448fed
...
@@ -18,7 +18,6 @@
...
@@ -18,7 +18,6 @@
*/
*/
/*!
/*!
* Copyright (c) 2017 by Contributors
* \file arg_binder.cc
* \file arg_binder.cc
* \brief Helper utility to match and bind arguments.
* \brief Helper utility to match and bind arguments.
*/
*/
...
...
src/pass/inject_double_buffer.cc
View file @
59448fed
...
@@ -26,6 +26,7 @@
...
@@ -26,6 +26,7 @@
#include <tvm/ir_pass.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_visitor.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_mutator.h>
#include <tvm/expr_operator.h>
#include "ir_util.h"
#include "ir_util.h"
#include "../arithmetic/compute_expr.h"
#include "../arithmetic/compute_expr.h"
...
@@ -100,8 +101,8 @@ class DoubleBufferInjector : public IRMutator {
...
@@ -100,8 +101,8 @@ class DoubleBufferInjector : public IRMutator {
Stmt
Mutate_
(
const
Allocate
*
op
,
const
Stmt
&
s
)
final
{
Stmt
Mutate_
(
const
Allocate
*
op
,
const
Stmt
&
s
)
final
{
auto
it
=
dbuffer_info_
.
find
(
op
->
buffer_var
.
get
());
auto
it
=
dbuffer_info_
.
find
(
op
->
buffer_var
.
get
());
if
(
it
!=
dbuffer_info_
.
end
())
{
if
(
it
!=
dbuffer_info_
.
end
())
{
it
->
second
.
stride
=
arith
::
ComputeReduce
<
Mul
>
it
->
second
.
stride
=
arith
::
ComputeReduce
<
Mul
>
(
(
op
->
extents
,
Expr
())
*
op
->
type
.
lanes
();
op
->
extents
,
Expr
())
*
op
->
type
.
lanes
();
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
Stmt
stmt
=
IRMutator
::
Mutate_
(
op
,
s
);
op
=
stmt
.
as
<
Allocate
>
();
op
=
stmt
.
as
<
Allocate
>
();
Array
<
Expr
>
new_extents
{
make_const
(
op
->
extents
[
0
].
type
(),
2
)};
Array
<
Expr
>
new_extents
{
make_const
(
op
->
extents
[
0
].
type
(),
2
)};
...
@@ -135,11 +136,11 @@ class DoubleBufferInjector : public IRMutator {
...
@@ -135,11 +136,11 @@ class DoubleBufferInjector : public IRMutator {
<<
"It is better to split with multiple of 2"
;
<<
"It is better to split with multiple of 2"
;
CHECK
(
is_zero
(
old_loop
->
min
));
CHECK
(
is_zero
(
old_loop
->
min
));
Expr
zero
=
old_loop
->
min
;
Expr
zero
=
old_loop
->
min
;
Expr
new_ext
=
arith
::
ComputeExpr
<
Sub
>
(
Expr
new_ext
=
old_loop
->
extent
,
make_const
(
old_loop
->
loop_var
.
type
(),
1
)
);
old_loop
->
extent
-
make_const
(
old_loop
->
loop_var
.
type
(),
1
);
Expr
factor
=
make_const
(
new_ext
.
type
(),
split_loop_
);
Expr
factor
=
make_const
(
new_ext
.
type
(),
split_loop_
);
Expr
outer_ext
=
arith
::
ComputeExpr
<
Div
>
(
new_ext
,
factor
)
;
Expr
outer_ext
=
new_ext
/
factor
;
Expr
tail_base
=
arith
::
ComputeExpr
<
Mul
>
(
outer_ext
,
factor
)
;
Expr
tail_base
=
outer_ext
*
factor
;
Var
outer_var
(
old_loop
->
loop_var
->
name_hint
+
".outer"
,
old_loop
->
loop_var
.
type
());
Var
outer_var
(
old_loop
->
loop_var
->
name_hint
+
".outer"
,
old_loop
->
loop_var
.
type
());
std
::
unordered_map
<
const
Variable
*
,
Expr
>
vmap
;
std
::
unordered_map
<
const
Variable
*
,
Expr
>
vmap
;
std
::
vector
<
Stmt
>
loop_seq
;
std
::
vector
<
Stmt
>
loop_seq
;
...
...
src/pass/inject_virtual_thread.cc
View file @
59448fed
...
@@ -18,7 +18,6 @@
...
@@ -18,7 +18,6 @@
*/
*/
/*!
/*!
* Copyright (c) 2017 by Contributors
* \file inject_virtual_thread.cc
* \file inject_virtual_thread.cc
*/
*/
#include <tvm/ir.h>
#include <tvm/ir.h>
...
@@ -37,6 +36,7 @@ class ExprTouched final : public IRVisitor {
...
@@ -37,6 +36,7 @@ class ExprTouched final : public IRVisitor {
explicit
ExprTouched
(
const
std
::
unordered_set
<
const
Variable
*>
&
touched
,
explicit
ExprTouched
(
const
std
::
unordered_set
<
const
Variable
*>
&
touched
,
bool
check_write
)
bool
check_write
)
:
touched_var_
(
touched
),
check_write_
(
check_write
)
{}
:
touched_var_
(
touched
),
check_write_
(
check_write
)
{}
void
Visit
(
const
NodeRef
&
n
)
final
{
void
Visit
(
const
NodeRef
&
n
)
final
{
// early stopping
// early stopping
if
(
expr_touched_
&&
!
check_write_
)
return
;
if
(
expr_touched_
&&
!
check_write_
)
return
;
...
@@ -241,8 +241,8 @@ class VTInjector : public IRMutator {
...
@@ -241,8 +241,8 @@ class VTInjector : public IRMutator {
visit_touched_var_
=
true
;
visit_touched_var_
=
true
;
Expr
offset
=
Mutate
(
op
->
args
[
2
]);
Expr
offset
=
Mutate
(
op
->
args
[
2
]);
Expr
extent
=
Mutate
(
op
->
args
[
3
]);
Expr
extent
=
Mutate
(
op
->
args
[
3
]);
Expr
stride
=
arith
::
ComputeExpr
<
Div
>
(
Expr
stride
=
it
->
second
,
make_const
(
offset
.
type
(),
dtype
.
lanes
()
));
it
->
second
/
make_const
(
offset
.
type
(),
dtype
.
lanes
(
));
offset
=
stride
*
var_
+
offset
;
offset
=
stride
*
var_
+
offset
;
return
Call
::
make
(
return
Call
::
make
(
op
->
type
,
op
->
name
,
op
->
type
,
op
->
name
,
...
...
src/pass/lower_thread_allreduce.cc
View file @
59448fed
src/pass/lower_warp_memory.cc
View file @
59448fed
...
@@ -18,8 +18,6 @@
...
@@ -18,8 +18,6 @@
*/
*/
/*!
/*!
* Copyright (c) 2018 by Contributors
*
* Lower warp memory to use local memory
* Lower warp memory to use local memory
* and shuffle intrinsics.
* and shuffle intrinsics.
*
*
...
...
src/pass/make_api.cc
View file @
59448fed
...
@@ -33,7 +33,6 @@
...
@@ -33,7 +33,6 @@
#include "ir_util.h"
#include "ir_util.h"
#include "arg_binder.h"
#include "arg_binder.h"
#include "../arithmetic/compute_expr.h"
namespace
tvm
{
namespace
tvm
{
namespace
ir
{
namespace
ir
{
...
...
src/pass/storage_flatten.cc
View file @
59448fed
...
@@ -211,7 +211,7 @@ class StorageFlattener : public IRMutator {
...
@@ -211,7 +211,7 @@ class StorageFlattener : public IRMutator {
stride
=
ir
::
Simplify
(
stride
);
stride
=
ir
::
Simplify
(
stride
);
}
}
rstrides
.
push_back
(
stride
);
rstrides
.
push_back
(
stride
);
stride
=
arith
::
ComputeExpr
<
Mul
>
(
stride
,
shape
[
dim
])
;
stride
=
stride
*
shape
[
dim
]
;
}
}
strides
=
Array
<
Expr
>
(
rstrides
.
rbegin
(),
rstrides
.
rend
());
strides
=
Array
<
Expr
>
(
rstrides
.
rbegin
(),
rstrides
.
rend
());
}
}
...
@@ -237,7 +237,7 @@ class StorageFlattener : public IRMutator {
...
@@ -237,7 +237,7 @@ class StorageFlattener : public IRMutator {
int
first_dim
=
0
;
int
first_dim
=
0
;
ret
=
Allocate
::
make
(
ret
=
Allocate
::
make
(
e
.
buffer
->
data
,
storage_type
,
e
.
buffer
->
data
,
storage_type
,
{
arith
::
ComputeExpr
<
Mul
>
(
e
.
buffer
->
strides
[
first_dim
],
e
.
buffer
->
shape
[
first_dim
])
},
{
e
.
buffer
->
strides
[
first_dim
]
*
e
.
buffer
->
shape
[
first_dim
]
},
make_const
(
Bool
(
e
.
buffer
->
dtype
.
lanes
()),
true
),
body
);
make_const
(
Bool
(
e
.
buffer
->
dtype
.
lanes
()),
true
),
body
);
}
else
{
}
else
{
shape
=
e
.
buffer
->
shape
;
shape
=
e
.
buffer
->
shape
;
...
@@ -414,8 +414,7 @@ class StorageFlattener : public IRMutator {
...
@@ -414,8 +414,7 @@ class StorageFlattener : public IRMutator {
if
(
be
.
bounds
.
size
()
!=
0
)
{
if
(
be
.
bounds
.
size
()
!=
0
)
{
CHECK_EQ
(
tuple
->
args
.
size
(),
be
.
bounds
.
size
()
*
2
);
CHECK_EQ
(
tuple
->
args
.
size
(),
be
.
bounds
.
size
()
*
2
);
for
(
size_t
i
=
0
;
i
<
be
.
buffer
->
shape
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
be
.
buffer
->
shape
.
size
();
++
i
)
{
begins
.
push_back
(
begins
.
push_back
(
tuple
->
args
[
2
*
i
]
-
be
.
bounds
[
i
]
->
min
);
arith
::
ComputeExpr
<
Sub
>
(
tuple
->
args
[
2
*
i
],
be
.
bounds
[
i
]
->
min
));
extents
.
push_back
(
tuple
->
args
[
2
*
i
+
1
]);
extents
.
push_back
(
tuple
->
args
[
2
*
i
+
1
]);
}
}
}
else
{
}
else
{
...
...
src/pass/unroll_loop.cc
View file @
59448fed
...
@@ -18,7 +18,6 @@
...
@@ -18,7 +18,6 @@
*/
*/
/*!
/*!
* Copyright (c) 2017 by Contributors
* Loop unrolling as in Halide pipeline.
* Loop unrolling as in Halide pipeline.
* \file unroll_loop.cc
* \file unroll_loop.cc
*/
*/
...
@@ -144,7 +143,6 @@ class LoopUnroller : public IRMutator {
...
@@ -144,7 +143,6 @@ class LoopUnroller : public IRMutator {
}
}
Stmt
Unroll
(
const
For
*
op
)
{
Stmt
Unroll
(
const
For
*
op
)
{
using
arith
::
ComputeExpr
;
int
value
=
GetExtent
(
op
);
int
value
=
GetExtent
(
op
);
// For loop must have a constant integer extent
// For loop must have a constant integer extent
CHECK_NE
(
value
,
-
1
)
<<
"loop doesn't have a constant integer extent"
;
CHECK_NE
(
value
,
-
1
)
<<
"loop doesn't have a constant integer extent"
;
...
@@ -154,9 +152,7 @@ class LoopUnroller : public IRMutator {
...
@@ -154,9 +152,7 @@ class LoopUnroller : public IRMutator {
Stmt
unrolled
;
Stmt
unrolled
;
for
(
int
i
=
0
;
i
<
value
;
++
i
)
{
for
(
int
i
=
0
;
i
<
value
;
++
i
)
{
Var
lv
(
op
->
loop_var
.
node_
);
Var
lv
(
op
->
loop_var
.
node_
);
vmap
.
Set
(
lv
,
vmap
.
Set
(
lv
,
op
->
min
+
make_const
(
op
->
loop_var
.
type
(),
i
));
ComputeExpr
<
Add
>
(
op
->
min
,
make_const
(
op
->
loop_var
.
type
(),
i
)));
Stmt
step
=
Substitute
(
body
,
vmap
);
Stmt
step
=
Substitute
(
body
,
vmap
);
if
(
unrolled
.
defined
())
{
if
(
unrolled
.
defined
())
{
unrolled
=
Block
::
make
(
unrolled
,
step
);
unrolled
=
Block
::
make
(
unrolled
,
step
);
...
...
src/pass/vectorize_loop.cc
View file @
59448fed
...
@@ -18,7 +18,6 @@
...
@@ -18,7 +18,6 @@
*/
*/
/*!
/*!
* Copyright (c) 2017 by Contributors
* \file vectorize_loop.cc
* \file vectorize_loop.cc
*/
*/
// Loop vectorizer as in Halide pipeline.
// Loop vectorizer as in Halide pipeline.
...
@@ -486,13 +485,13 @@ class Vectorizer : public IRMutator {
...
@@ -486,13 +485,13 @@ class Vectorizer : public IRMutator {
const
Ramp
*
a_ramp
=
a
.
as
<
Ramp
>
();
const
Ramp
*
a_ramp
=
a
.
as
<
Ramp
>
();
if
(
a
.
type
().
lanes
()
==
1
&&
b_ramp
)
{
if
(
a
.
type
().
lanes
()
==
1
&&
b_ramp
)
{
return
Ramp
::
make
(
return
Ramp
::
make
(
arith
::
Compute
Expr
<
T
>
(
a
,
b_ramp
->
base
),
arith
::
Compute
<
T
>
(
a
,
b_ramp
->
base
),
arith
::
Compute
Expr
<
T
>
(
make_zero
(
b_ramp
->
stride
.
type
()),
b_ramp
->
stride
),
arith
::
Compute
<
T
>
(
make_zero
(
b_ramp
->
stride
.
type
()),
b_ramp
->
stride
),
b_ramp
->
lanes
);
b_ramp
->
lanes
);
}
}
if
(
b
.
type
().
lanes
()
==
1
&&
a_ramp
)
{
if
(
b
.
type
().
lanes
()
==
1
&&
a_ramp
)
{
return
Ramp
::
make
(
return
Ramp
::
make
(
arith
::
Compute
Expr
<
T
>
(
a_ramp
->
base
,
b
),
a_ramp
->
stride
,
a_ramp
->
lanes
);
arith
::
Compute
<
T
>
(
a_ramp
->
base
,
b
),
a_ramp
->
stride
,
a_ramp
->
lanes
);
}
}
}
}
return
T
::
make
(
BroadcastTo
(
a
,
lanes
),
BroadcastTo
(
b
,
lanes
));
return
T
::
make
(
BroadcastTo
(
a
,
lanes
),
BroadcastTo
(
b
,
lanes
));
...
...
src/schedule/message_passing.cc
View file @
59448fed
...
@@ -18,7 +18,6 @@
...
@@ -18,7 +18,6 @@
*/
*/
/*!
/*!
* Copyright (c) 2017 by Contributors
* \file message_passing.cc
* \file message_passing.cc
* \brief The message passing domain.
* \brief The message passing domain.
*/
*/
...
@@ -32,12 +31,11 @@ namespace tvm {
...
@@ -32,12 +31,11 @@ namespace tvm {
namespace
schedule
{
namespace
schedule
{
using
namespace
ir
;
using
namespace
ir
;
using
namespace
arith
;
void
Update
(
std
::
unordered_map
<
IterVar
,
Range
>*
p_state
,
void
Update
(
std
::
unordered_map
<
IterVar
,
Range
>*
p_state
,
const
IterVar
&
iv
,
const
IterVar
&
iv
,
Range
r
,
Range
r
,
Analyzer
*
analyzer
)
{
arith
::
Analyzer
*
analyzer
)
{
auto
it
=
p_state
->
find
(
iv
);
auto
it
=
p_state
->
find
(
iv
);
if
(
it
==
p_state
->
end
())
{
if
(
it
==
p_state
->
end
())
{
(
*
p_state
)[
iv
]
=
r
;
(
*
p_state
)[
iv
]
=
r
;
...
@@ -145,8 +143,8 @@ void PassUpIndex(const Stage& stage,
...
@@ -145,8 +143,8 @@ void PassUpIndex(const Stage& stage,
Expr
factor
=
dom_map
.
at
(
s
->
inner
)
->
extent
;
Expr
factor
=
dom_map
.
at
(
s
->
inner
)
->
extent
;
Expr
outer_min
=
dom_map
.
at
(
s
->
outer
)
->
min
;
Expr
outer_min
=
dom_map
.
at
(
s
->
outer
)
->
min
;
Expr
inner_min
=
dom_map
.
at
(
s
->
inner
)
->
min
;
Expr
inner_min
=
dom_map
.
at
(
s
->
inner
)
->
min
;
state
[
s
->
outer
]
=
ComputeExpr
<
Div
>
(
value
,
factor
)
;
state
[
s
->
outer
]
=
value
/
factor
;
state
[
s
->
inner
]
=
ComputeExpr
<
Mod
>
(
value
,
factor
)
;
state
[
s
->
inner
]
=
value
%
factor
;
// add min if they exist
// add min if they exist
if
(
!
is_zero
(
outer_min
))
{
if
(
!
is_zero
(
outer_min
))
{
state
[
s
->
outer
]
=
state
[
s
->
outer
]
+
outer_min
;
state
[
s
->
outer
]
=
state
[
s
->
outer
]
+
outer_min
;
...
@@ -189,8 +187,8 @@ void PassDownIndex(const Stage& stage,
...
@@ -189,8 +187,8 @@ void PassDownIndex(const Stage& stage,
CHECK
(
is_zero
(
r
->
min
));
CHECK
(
is_zero
(
r
->
min
));
Expr
parent
=
state
.
at
(
s
->
parent
);
Expr
parent
=
state
.
at
(
s
->
parent
);
Expr
factor
=
r
->
extent
;
Expr
factor
=
r
->
extent
;
state
[
s
->
outer
]
=
ComputeExpr
<
Div
>
(
parent
,
factor
)
;
state
[
s
->
outer
]
=
parent
/
factor
;
state
[
s
->
inner
]
=
ComputeExpr
<
Mod
>
(
parent
,
factor
)
;
state
[
s
->
inner
]
=
parent
%
factor
;
}
else
if
(
const
FuseNode
*
s
=
rel
.
as
<
FuseNode
>
())
{
}
else
if
(
const
FuseNode
*
s
=
rel
.
as
<
FuseNode
>
())
{
if
(
!
state
.
count
(
s
->
inner
)
&&
!
state
.
count
(
s
->
outer
))
{
if
(
!
state
.
count
(
s
->
inner
)
&&
!
state
.
count
(
s
->
outer
))
{
CHECK
(
allow_missing
);
CHECK
(
allow_missing
);
...
@@ -240,7 +238,7 @@ void PassUpDomain(const SplitNode* s,
...
@@ -240,7 +238,7 @@ void PassUpDomain(const SplitNode* s,
CHECK
(
outer
.
defined
());
CHECK
(
outer
.
defined
());
CHECK
(
inner
.
defined
());
CHECK
(
inner
.
defined
());
CHECK
(
factor
.
defined
());
CHECK
(
factor
.
defined
());
*
parent
=
EvalSet
(
*
parent
=
arith
::
EvalSet
(
s
->
outer
->
var
*
factor
+
s
->
inner
->
var
+
parent_min
,
s
->
outer
->
var
*
factor
+
s
->
inner
->
var
+
parent_min
,
{{
s
->
outer
,
outer
},
{
s
->
inner
,
inner
}});
{{
s
->
outer
,
outer
},
{
s
->
inner
,
inner
}});
}
}
...
@@ -290,7 +288,7 @@ void PassUpDomain(const RebaseNode* s,
...
@@ -290,7 +288,7 @@ void PassUpDomain(const RebaseNode* s,
return
;
return
;
}
}
Expr
parent_min
=
dom_map
.
at
(
s
->
parent
)
->
min
;
Expr
parent_min
=
dom_map
.
at
(
s
->
parent
)
->
min
;
*
parent
=
EvalSet
(
s
->
rebased
->
var
+
parent_min
,
*
parent
=
arith
::
EvalSet
(
s
->
rebased
->
var
+
parent_min
,
{{
s
->
rebased
,
rebased
}});
{{
s
->
rebased
,
rebased
}});
}
}
...
@@ -476,7 +474,7 @@ std::vector<Expr> MakeBoundCheck(
...
@@ -476,7 +474,7 @@ std::vector<Expr> MakeBoundCheck(
const
std
::
unordered_map
<
IterVar
,
Expr
>&
value_map
,
const
std
::
unordered_map
<
IterVar
,
Expr
>&
value_map
,
bool
skip_ivar_domain
,
bool
skip_ivar_domain
,
const
std
::
unordered_set
<
IterVar
>&
skip_iter
)
{
const
std
::
unordered_set
<
IterVar
>&
skip_iter
)
{
Analyzer
analyzer
;
arith
::
Analyzer
analyzer
;
std
::
unordered_map
<
IterVar
,
bool
>
bound_state
;
std
::
unordered_map
<
IterVar
,
bool
>
bound_state
;
for
(
IterVar
iv
:
stage
->
leaf_iter_vars
)
{
for
(
IterVar
iv
:
stage
->
leaf_iter_vars
)
{
...
@@ -496,7 +494,7 @@ std::vector<Expr> MakeBoundCheck(
...
@@ -496,7 +494,7 @@ std::vector<Expr> MakeBoundCheck(
if
(
skip_iter
.
count
(
iv
)
||
iv
->
iter_type
==
kOpaque
)
continue
;
if
(
skip_iter
.
count
(
iv
)
||
iv
->
iter_type
==
kOpaque
)
continue
;
if
(
bound_state
.
at
(
iv
))
{
if
(
bound_state
.
at
(
iv
))
{
Range
dom
=
dom_map
.
at
(
iv
);
Range
dom
=
dom_map
.
at
(
iv
);
Expr
value
=
ComputeExpr
<
Sub
>
(
value_map
.
at
(
iv
),
dom
->
min
)
;
Expr
value
=
value_map
.
at
(
iv
)
-
dom
->
min
;
Expr
vmax
=
EvalSet
(
value
,
iset_dmap
).
max
();
Expr
vmax
=
EvalSet
(
value
,
iset_dmap
).
max
();
if
(
vmax
.
type
()
!=
value
.
type
()
||
!
analyzer
.
CanProve
(
vmax
<
dom
->
extent
))
{
if
(
vmax
.
type
()
!=
value
.
type
()
||
!
analyzer
.
CanProve
(
vmax
<
dom
->
extent
))
{
preds
.
emplace_back
(
value
<
dom
->
extent
);
preds
.
emplace_back
(
value
<
dom
->
extent
);
...
@@ -508,7 +506,7 @@ std::vector<Expr> MakeBoundCheck(
...
@@ -508,7 +506,7 @@ std::vector<Expr> MakeBoundCheck(
Range
dom
=
dom_map
.
at
(
iv
);
Range
dom
=
dom_map
.
at
(
iv
);
CHECK
(
iv
->
dom
.
defined
());
CHECK
(
iv
->
dom
.
defined
());
if
(
!
skip_ivar_domain
&&
!
iv
->
dom
.
same_as
(
dom
))
{
if
(
!
skip_ivar_domain
&&
!
iv
->
dom
.
same_as
(
dom
))
{
Expr
value
=
ComputeExpr
<
Sub
>
(
value_map
.
at
(
iv
),
iv
->
dom
->
min
)
;
Expr
value
=
value_map
.
at
(
iv
)
-
iv
->
dom
->
min
;
IntSet
s
=
EvalSet
(
value
,
iset_dmap
);
IntSet
s
=
EvalSet
(
value
,
iset_dmap
);
Expr
vmin
=
s
.
min
();
Expr
vmin
=
s
.
min
();
Expr
vmax
=
s
.
max
();
Expr
vmax
=
s
.
max
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment