Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
6d460606
Unverified
Commit
6d460606
authored
Mar 02, 2019
by
Tianqi Chen
Committed by
GitHub
Mar 02, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[EXPR] ir_operator.h->expr_operator.h Centralize const folder logic (#2719)
parent
1eb1dac4
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
375 additions
and
213 deletions
+375
-213
include/tvm/buffer.h
+1
-1
include/tvm/data_layout.h
+1
-1
include/tvm/expr_operator.h
+4
-4
include/tvm/operation.h
+1
-1
include/tvm/tensor.h
+1
-1
include/tvm/tvm.h
+1
-1
src/api/api_ir.cc
+1
-2
src/arithmetic/const_fold.h
+289
-0
src/arithmetic/modular_set.cc
+1
-1
src/lang/expr.cc
+1
-1
src/lang/expr_operator.cc
+67
-193
src/op/hybrid_op.cc
+1
-1
src/pass/ir_util.h
+1
-1
src/pass/storage_flatten.cc
+1
-1
src/relay/op/nn/pad.cc
+1
-1
src/relay/op/tensor/transform.cc
+1
-1
src/relay/pass/fuse_ops.cc
+1
-1
tests/cpp/ir_mutator_test.cc
+1
-1
No files found.
include/tvm/buffer.h
View file @
6d460606
...
...
@@ -10,7 +10,7 @@
#include "base.h"
#include "expr.h"
#include "
i
r_operator.h"
#include "
exp
r_operator.h"
#include "tvm/node/container.h"
namespace
tvm
{
...
...
include/tvm/data_layout.h
View file @
6d460606
...
...
@@ -16,7 +16,7 @@
#include <utility>
#include <algorithm>
#include "
i
r_operator.h"
#include "
exp
r_operator.h"
namespace
tvm
{
...
...
include/tvm/
i
r_operator.h
→
include/tvm/
exp
r_operator.h
View file @
6d460606
/*!
* Copyright (c) 2018 by Contributors
* \file tvm/
i
r_operator.h
* \file tvm/
exp
r_operator.h
* \brief Common operators defined for Expr.
*
* \note Most of the operator defined here perform simple constant folding
* when the type is int32 or int64 for simplifying the index expressions.
*/
#ifndef TVM_
I
R_OPERATOR_H_
#define TVM_
I
R_OPERATOR_H_
#ifndef TVM_
EXP
R_OPERATOR_H_
#define TVM_
EXP
R_OPERATOR_H_
#include <algorithm>
#include <type_traits>
...
...
@@ -617,4 +617,4 @@ TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(operator&&);
TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD
(
operator
||
);
}
// namespace tvm
#endif // TVM_
I
R_OPERATOR_H_
#endif // TVM_
EXP
R_OPERATOR_H_
include/tvm/operation.h
View file @
6d460606
...
...
@@ -10,7 +10,7 @@
#include <vector>
#include <unordered_map>
#include "expr.h"
#include "
i
r_operator.h"
#include "
exp
r_operator.h"
#include "tensor.h"
#include "schedule.h"
#include "arithmetic.h"
...
...
include/tvm/tensor.h
View file @
6d460606
...
...
@@ -14,7 +14,7 @@
#include "base.h"
#include "expr.h"
#include "
i
r_operator.h"
#include "
exp
r_operator.h"
#include "arithmetic.h"
namespace
tvm
{
...
...
include/tvm/tvm.h
View file @
6d460606
...
...
@@ -8,7 +8,7 @@
#include "base.h"
#include "expr.h"
#include "
i
r_operator.h"
#include "
exp
r_operator.h"
#include "tensor.h"
#include "operation.h"
#include "packed_func_ext.h"
...
...
src/api/api_ir.cc
View file @
6d460606
...
...
@@ -5,9 +5,8 @@
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/ir_operator.h>
#include <tvm/api_registry.h>
#include <tvm/
i
r_operator.h>
#include <tvm/
exp
r_operator.h>
namespace
tvm
{
namespace
ir
{
...
...
src/arithmetic/const_fold.h
0 → 100644
View file @
6d460606
/*!
* Copyright (c) 2019 by Contributors
* \file const_fold.h
* \brief Centralized location for constant folding.
*/
#ifndef TVM_ARITHMETIC_CONST_FOLD_H_
#define TVM_ARITHMETIC_CONST_FOLD_H_
#include <tvm/ir.h>
#include <algorithm>
namespace
tvm
{
namespace
arith
{
/*!
* \brief Try to run binary compute with constant folding.
*
* \param a The left operand.
* \param b The right operand.
* \tparam Op The operator type.
*
* \note a and b Must already matched data types with each other.
* \return nullptr if constant fold fails, otherwise return folded result.
*/
template
<
typename
Op
>
inline
Expr
TryConstFold
(
Expr
a
,
Expr
b
);
/*!
* \brief Try to run unary compute with constant folding.
*
* \param a The left operand.
* \tparam Op The operator type.
*
* \note a and b Must already matched data types with each other.
* \return nullptr if constant fold fails, otherwise return folded result.
*/
template
<
typename
Op
>
inline
Expr
TryConstFold
(
Expr
a
);
/*!
* \brief Check whether type is used to represent index.
*
* Index types are frequently used in shape computation
* and need to be aggressively constant-folded.
*
* \param type The type to represent index.
* \return the checked result.
*/
inline
bool
IsIndexType
(
const
Type
&
type
)
{
return
type
.
is_int
()
&&
type
.
lanes
()
==
1
&&
(
type
.
bits
()
==
32
||
type
.
bits
()
==
64
);
}
#define TVM_ARITH_CONST_PROPAGATION(BODY) \
using ir::IntImm; \
using ir::UIntImm; \
using ir::FloatImm; \
const IntImm* pa = a.as<IntImm>(); \
const IntImm* pb = b.as<IntImm>(); \
const FloatImm* fa = a.as<FloatImm>(); \
const FloatImm* fb = b.as<FloatImm>(); \
BODY;
#define TVM_INDEX_CONST_PROPAGATION(BODY) \
using ir::IntImm; \
using ir::UIntImm; \
const IntImm* pa = a.as<IntImm>(); \
const IntImm* pb = b.as<IntImm>(); \
const Type& ta = a.type(); \
const Type& tb = b.type(); \
if (arith::IsIndexType(ta) && arith::IsIndexType(tb)) { \
BODY; \
} \
// specialization of constant folders.
template
<>
inline
Expr
TryConstFold
<
ir
::
Add
>
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
const
Type
&
rtype
=
a
.
type
();
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
pa
->
value
+
pb
->
value
);
if
(
pa
&&
pa
->
value
==
0
)
return
b
;
if
(
pb
&&
pb
->
value
==
0
)
return
a
;
if
(
fa
&&
fb
)
return
FloatImm
::
make
(
rtype
,
fa
->
value
+
fb
->
value
);
if
(
fa
&&
fa
->
value
==
0
)
return
b
;
if
(
fb
&&
fb
->
value
==
0
)
return
a
;
});
return
Expr
();
}
template
<>
inline
Expr
TryConstFold
<
ir
::
Sub
>
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
const
Type
&
rtype
=
a
.
type
();
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
pa
->
value
-
pb
->
value
);
if
(
pb
&&
pb
->
value
==
0
)
return
a
;
if
(
fa
&&
fb
)
return
FloatImm
::
make
(
rtype
,
fa
->
value
-
fb
->
value
);
if
(
fb
&&
fb
->
value
==
0
)
return
a
;
});
return
Expr
();
}
template
<>
inline
Expr
TryConstFold
<
ir
::
Mul
>
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
const
Type
&
rtype
=
a
.
type
();
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
pa
->
value
*
pb
->
value
);
if
(
pa
)
{
if
(
pa
->
value
==
1
)
return
b
;
if
(
pa
->
value
==
0
)
return
a
;
}
if
(
pb
)
{
if
(
pb
->
value
==
1
)
return
a
;
if
(
pb
->
value
==
0
)
return
b
;
}
if
(
fa
&&
fb
)
return
FloatImm
::
make
(
rtype
,
fa
->
value
*
fb
->
value
);
if
(
fa
)
{
if
(
fa
->
value
==
1
)
return
b
;
if
(
fa
->
value
==
0
)
return
a
;
}
if
(
fb
)
{
if
(
fb
->
value
==
1
)
return
a
;
if
(
fb
->
value
==
0
)
return
b
;
}
});
return
Expr
();
}
template
<>
inline
Expr
TryConstFold
<
ir
::
Div
>
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
const
Type
&
rtype
=
a
.
type
();
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
if
(
pa
&&
pb
&&
pa
->
value
>=
0
&&
pb
->
value
>
0
)
{
return
IntImm
::
make
(
rtype
,
pa
->
value
/
pb
->
value
);
}
if
(
pa
)
{
if
(
pa
->
value
==
0
)
return
a
;
}
if
(
pb
)
{
if
(
pb
->
value
==
1
)
return
a
;
CHECK_NE
(
pb
->
value
,
0
)
<<
"Divide by zero"
;
}
if
(
fa
&&
fb
&&
fb
->
value
!=
0
)
{
return
FloatImm
::
make
(
rtype
,
fa
->
value
/
fb
->
value
);
}
if
(
fa
&&
fa
->
value
==
0
)
return
a
;
if
(
fb
)
{
if
(
fb
->
value
==
1
)
return
a
;
CHECK_NE
(
fb
->
value
,
0
)
<<
"Divide by zero"
;
}
});
return
Expr
();
}
template
<>
inline
Expr
TryConstFold
<
ir
::
Mod
>
(
Expr
a
,
Expr
b
)
{
TVM_INDEX_CONST_PROPAGATION
({
const
Type
&
rtype
=
a
.
type
();
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
if
(
pa
&&
pb
&&
pa
->
value
>=
0
&&
pb
->
value
>
0
)
{
return
IntImm
::
make
(
rtype
,
pa
->
value
%
pb
->
value
);
}
if
(
pa
)
{
if
(
pa
->
value
==
0
)
return
a
;
}
if
(
pb
)
{
if
(
pb
->
value
==
1
)
return
make_zero
(
rtype
);
CHECK_NE
(
pb
->
value
,
0
)
<<
"Divide by zero"
;
}
});
return
Expr
();
}
template
<>
inline
Expr
TryConstFold
<
ir
::
Min
>
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
const
Type
&
rtype
=
a
.
type
();
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
std
::
min
(
pa
->
value
,
pb
->
value
));
if
(
fa
&&
fb
)
return
FloatImm
::
make
(
rtype
,
std
::
min
(
fa
->
value
,
fb
->
value
));
});
return
Expr
();
}
template
<>
inline
Expr
TryConstFold
<
ir
::
Max
>
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
const
Type
&
rtype
=
a
.
type
();
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
std
::
max
(
pa
->
value
,
pb
->
value
));
if
(
fa
&&
fb
)
return
FloatImm
::
make
(
rtype
,
std
::
max
(
fa
->
value
,
fb
->
value
));
});
return
Expr
();
}
template
<>
inline
Expr
TryConstFold
<
ir
::
GT
>
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
if
(
pa
&&
pb
)
return
UIntImm
::
make
(
UInt
(
1
),
pa
->
value
>
pb
->
value
);
if
(
fa
&&
fb
)
return
UIntImm
::
make
(
UInt
(
1
),
fa
->
value
>
fb
->
value
);
});
return
Expr
();
}
template
<>
inline
Expr
TryConstFold
<
ir
::
GE
>
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
if
(
pa
&&
pb
)
return
UIntImm
::
make
(
UInt
(
1
),
pa
->
value
>=
pb
->
value
);
if
(
fa
&&
fb
)
return
UIntImm
::
make
(
UInt
(
1
),
fa
->
value
>=
fb
->
value
);
});
return
Expr
();
}
template
<>
inline
Expr
TryConstFold
<
ir
::
LT
>
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
if
(
pa
&&
pb
)
return
UIntImm
::
make
(
UInt
(
1
),
pa
->
value
<
pb
->
value
);
if
(
fa
&&
fb
)
return
UIntImm
::
make
(
UInt
(
1
),
fa
->
value
<
fb
->
value
);
});
return
Expr
();
}
template
<>
inline
Expr
TryConstFold
<
ir
::
LE
>
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
if
(
pa
&&
pb
)
return
UIntImm
::
make
(
UInt
(
1
),
pa
->
value
<=
pb
->
value
);
if
(
fa
&&
fb
)
return
UIntImm
::
make
(
UInt
(
1
),
fa
->
value
<=
fb
->
value
);
});
return
Expr
();
}
template
<>
inline
Expr
TryConstFold
<
ir
::
EQ
>
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
if
(
pa
&&
pb
)
return
UIntImm
::
make
(
UInt
(
1
),
pa
->
value
==
pb
->
value
);
if
(
fa
&&
fb
)
return
UIntImm
::
make
(
UInt
(
1
),
fa
->
value
==
fb
->
value
);
});
return
Expr
();
}
template
<>
inline
Expr
TryConstFold
<
ir
::
NE
>
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
if
(
pa
&&
pb
)
return
UIntImm
::
make
(
UInt
(
1
),
pa
->
value
!=
pb
->
value
);
if
(
fa
&&
fb
)
return
UIntImm
::
make
(
UInt
(
1
),
fa
->
value
!=
fb
->
value
);
});
return
Expr
();
}
template
<>
inline
Expr
TryConstFold
<
ir
::
And
>
(
Expr
a
,
Expr
b
)
{
using
ir
::
UIntImm
;
const
UIntImm
*
pa
=
a
.
as
<
UIntImm
>
();
const
UIntImm
*
pb
=
b
.
as
<
UIntImm
>
();
if
(
pa
&&
pa
->
value
)
return
b
;
if
(
pa
&&
!
pa
->
value
)
return
a
;
if
(
pb
&&
pb
->
value
)
return
a
;
if
(
pb
&&
!
pb
->
value
)
return
b
;
return
Expr
();
}
template
<>
inline
Expr
TryConstFold
<
ir
::
Or
>
(
Expr
a
,
Expr
b
)
{
using
ir
::
UIntImm
;
const
UIntImm
*
pa
=
a
.
as
<
UIntImm
>
();
const
UIntImm
*
pb
=
b
.
as
<
UIntImm
>
();
if
(
pa
&&
pa
->
value
)
return
a
;
if
(
pa
&&
!
pa
->
value
)
return
b
;
if
(
pb
&&
pb
->
value
)
return
b
;
if
(
pb
&&
!
pb
->
value
)
return
a
;
return
Expr
();
}
template
<>
inline
Expr
TryConstFold
<
ir
::
Not
>
(
Expr
a
)
{
using
ir
::
UIntImm
;
const
UIntImm
*
pa
=
a
.
as
<
UIntImm
>
();
if
(
pa
)
{
return
UIntImm
::
make
(
UInt
(
1
),
!
(
pa
->
value
));
}
return
Expr
();
}
}
// namespace arith
}
// namespace tvm
#endif // TVM_ARITHMETIC_CONST_FOLD_H_
src/arithmetic/modular_set.cc
View file @
6d460606
...
...
@@ -4,7 +4,7 @@
* \brief Modular set analysis
*/
#include <tvm/arithmetic.h>
#include <tvm/
i
r_operator.h>
#include <tvm/
exp
r_operator.h>
#include <tvm/ir_functor_ext.h>
#include <limits>
#include "pattern_match.h"
...
...
src/lang/expr.cc
View file @
6d460606
...
...
@@ -5,7 +5,7 @@
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/
i
r_operator.h>
#include <tvm/
exp
r_operator.h>
#include <ir/IRPrinter.h>
#include <memory>
...
...
src/lang/
i
r_operator.cc
→
src/lang/
exp
r_operator.cc
View file @
6d460606
/*!
* Copyright (c) 2017 by Contributors
* \file
i
r_operator.cc
* \file
exp
r_operator.cc
*/
#include <tvm/base.h>
#include <tvm/ir.h>
#include <tvm/
i
r_operator.h>
#include <tvm/
exp
r_operator.h>
#include <cmath>
// Centralized header for constant folders.
#include "../arithmetic/const_fold.h"
namespace
tvm
{
/*!
* \brief Check whether type is used to represent index.
*
* Index types are frequently used in shape computation
* and need to be aggressively constant-folded.
*
* \param type The type to represent index.
* \return the checked result.
*/
inline
bool
IsIndexType
(
const
Type
&
type
)
{
return
type
.
is_int
()
&&
type
.
lanes
()
==
1
&&
(
type
.
bits
()
==
32
||
type
.
bits
()
==
64
);
}
// simple cast that only checks if type matches and cast
inline
Expr
SimpleCast
(
const
Type
&
t
,
Expr
value
)
{
if
(
value
.
type
()
==
t
)
return
value
;
...
...
@@ -135,45 +123,14 @@ Expr reinterpret(const Type& t, Expr value) {
return
ir
::
Call
::
make
(
t
,
ir
::
Call
::
reinterpret
,
{
value
},
ir
::
Call
::
PureIntrinsic
);
}
#define TVM_INDEX_CONST_PROPAGATION(BODY) \
using ir::IntImm; \
using ir::UIntImm; \
const IntImm* pa = a.as<IntImm>(); \
const IntImm* pb = b.as<IntImm>(); \
const Type& ta = a.type(); \
const Type& tb = b.type(); \
if (IsIndexType(ta) && IsIndexType(tb)) { \
BODY; \
} \
BinaryOpMatchTypes(a, b);
#define TVM_ARITH_CONST_PROPAGATION(BODY) \
using ir::IntImm; \
using ir::UIntImm; \
using ir::FloatImm; \
BinaryOpMatchTypes(a, b); \
const IntImm* pa = a.as<IntImm>(); \
const IntImm* pb = b.as<IntImm>(); \
const FloatImm* fa = a.as<FloatImm>(); \
const FloatImm* fb = b.as<FloatImm>(); \
BODY;
Expr
operator
+
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
const
Type
&
ta
=
a
.
type
();
const
Type
&
tb
=
b
.
type
();
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
pa
->
value
+
pb
->
value
);
if
(
pa
&&
pa
->
value
==
0
)
return
SimpleCast
(
rtype
,
b
);
if
(
pb
&&
pb
->
value
==
0
)
return
SimpleCast
(
rtype
,
a
);
if
(
fa
&&
fb
)
return
FloatImm
::
make
(
rtype
,
fa
->
value
+
fb
->
value
);
if
(
fa
&&
fa
->
value
==
0
)
return
SimpleCast
(
rtype
,
b
);
if
(
fb
&&
fb
->
value
==
0
)
return
SimpleCast
(
rtype
,
a
);
});
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
Add
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
return
ir
::
Add
::
make
(
a
,
b
);
}
// negation
Expr
operator
-
(
Expr
a
)
{
using
ir
::
IntImm
;
using
ir
::
FloatImm
;
...
...
@@ -185,114 +142,44 @@ Expr operator-(Expr a) {
}
Expr
operator
-
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
const
Type
&
ta
=
a
.
type
();
const
Type
&
tb
=
b
.
type
();
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
pa
->
value
-
pb
->
value
);
if
(
pb
&&
pb
->
value
==
0
)
return
SimpleCast
(
rtype
,
a
);
if
(
fa
&&
fb
)
return
FloatImm
::
make
(
rtype
,
fa
->
value
-
fb
->
value
);
if
(
fb
&&
fb
->
value
==
0
)
return
SimpleCast
(
rtype
,
a
);
});
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
Sub
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
return
ir
::
Sub
::
make
(
a
,
b
);
}
Expr
operator
*
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
const
Type
&
ta
=
a
.
type
();
const
Type
&
tb
=
b
.
type
();
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
pa
->
value
*
pb
->
value
);
if
(
pa
)
{
if
(
pa
->
value
==
1
)
return
SimpleCast
(
rtype
,
b
);
if
(
pa
->
value
==
0
)
return
SimpleCast
(
rtype
,
a
);
}
if
(
pb
)
{
if
(
pb
->
value
==
1
)
return
SimpleCast
(
rtype
,
a
);
if
(
pb
->
value
==
0
)
return
SimpleCast
(
rtype
,
b
);
}
if
(
fa
&&
fb
)
return
FloatImm
::
make
(
rtype
,
fa
->
value
*
fb
->
value
);
if
(
fa
)
{
if
(
fa
->
value
==
1
)
return
SimpleCast
(
rtype
,
b
);
if
(
fa
->
value
==
0
)
return
SimpleCast
(
rtype
,
a
);
}
if
(
fb
)
{
if
(
fb
->
value
==
1
)
return
SimpleCast
(
rtype
,
a
);
if
(
fb
->
value
==
0
)
return
SimpleCast
(
rtype
,
b
);
}
});
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
Mul
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
return
ir
::
Mul
::
make
(
a
,
b
);
}
Expr
operator
/
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
const
Type
&
ta
=
a
.
type
();
const
Type
&
tb
=
b
.
type
();
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
if
(
pa
&&
pb
&&
pa
->
value
>=
0
&&
pb
->
value
>
0
)
{
return
IntImm
::
make
(
rtype
,
pa
->
value
/
pb
->
value
);
}
if
(
pa
)
{
if
(
pa
->
value
==
0
)
return
SimpleCast
(
rtype
,
a
);
}
if
(
pb
)
{
if
(
pb
->
value
==
1
)
return
SimpleCast
(
rtype
,
a
);
CHECK_NE
(
pb
->
value
,
0
)
<<
"Divide by zero"
;
}
if
(
fa
&&
fb
&&
fb
->
value
!=
0
)
{
return
FloatImm
::
make
(
rtype
,
fa
->
value
/
fb
->
value
);
}
if
(
fa
&&
fa
->
value
==
0
)
{
return
SimpleCast
(
rtype
,
a
);
}
if
(
fb
)
{
if
(
fb
->
value
==
1
)
return
SimpleCast
(
rtype
,
a
);
CHECK_NE
(
fb
->
value
,
0
)
<<
"Divide by zero"
;
}
});
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
Div
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
return
ir
::
Div
::
make
(
a
,
b
);
}
Expr
operator
%
(
Expr
a
,
Expr
b
)
{
TVM_INDEX_CONST_PROPAGATION
({
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
// due to division and mod can have different modes
// only constant fold positive number where rule is fixed.
if
(
pa
&&
pb
&&
pa
->
value
>=
0
&&
pb
->
value
>
0
)
{
return
IntImm
::
make
(
rtype
,
pa
->
value
%
pb
->
value
);
}
if
(
pa
)
{
if
(
pa
->
value
==
0
)
return
SimpleCast
(
rtype
,
a
);
}
if
(
pb
)
{
if
(
pb
->
value
==
1
)
return
make_zero
(
rtype
);
CHECK_NE
(
pb
->
value
,
0
)
<<
"Divide by zero"
;
}
});
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
Mod
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
return
ir
::
Mod
::
make
(
a
,
b
);
}
Expr
min
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
const
Type
&
ta
=
a
.
type
();
const
Type
&
tb
=
b
.
type
();
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
std
::
min
(
pa
->
value
,
pb
->
value
));
if
(
fa
&&
fb
)
return
FloatImm
::
make
(
rtype
,
std
::
min
(
fa
->
value
,
fb
->
value
));
});
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
Min
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
return
ir
::
Min
::
make
(
a
,
b
);
}
Expr
max
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
const
Type
&
ta
=
a
.
type
();
const
Type
&
tb
=
b
.
type
();
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
std
::
max
(
pa
->
value
,
pb
->
value
));
if
(
fa
&&
fb
)
return
FloatImm
::
make
(
rtype
,
std
::
max
(
fa
->
value
,
fb
->
value
));
});
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
Max
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
return
ir
::
Max
::
make
(
a
,
b
);
}
...
...
@@ -328,129 +215,116 @@ Expr likely(Expr cond) {
}
Expr
operator
>
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
if
(
pa
&&
pb
)
return
UIntImm
::
make
(
UInt
(
1
),
pa
->
value
>
pb
->
value
);
if
(
fa
&&
fb
)
return
UIntImm
::
make
(
UInt
(
1
),
fa
->
value
>
fb
->
value
);
});
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
GT
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
return
ir
::
GT
::
make
(
a
,
b
);
}
Expr
operator
>=
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
if
(
pa
&&
pb
)
return
UIntImm
::
make
(
UInt
(
1
),
pa
->
value
>=
pb
->
value
);
if
(
fa
&&
fb
)
return
UIntImm
::
make
(
UInt
(
1
),
fa
->
value
>=
fb
->
value
);
});
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
GE
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
return
ir
::
GE
::
make
(
a
,
b
);
}
Expr
operator
<
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
if
(
pa
&&
pb
)
return
UIntImm
::
make
(
UInt
(
1
),
pa
->
value
<
pb
->
value
);
if
(
fa
&&
fb
)
return
UIntImm
::
make
(
UInt
(
1
),
fa
->
value
<
fb
->
value
);
});
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
LT
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
return
ir
::
LT
::
make
(
a
,
b
);
}
Expr
operator
<=
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
if
(
pa
&&
pb
)
return
UIntImm
::
make
(
UInt
(
1
),
pa
->
value
<=
pb
->
value
);
if
(
fa
&&
fb
)
return
UIntImm
::
make
(
UInt
(
1
),
fa
->
value
<=
fb
->
value
);
});
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
LE
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
return
ir
::
LE
::
make
(
a
,
b
);
}
Expr
operator
==
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
if
(
pa
&&
pb
)
return
UIntImm
::
make
(
UInt
(
1
),
pa
->
value
==
pb
->
value
);
if
(
fa
&&
fb
)
return
UIntImm
::
make
(
UInt
(
1
),
fa
->
value
==
fb
->
value
);
});
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
EQ
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
return
ir
::
EQ
::
make
(
a
,
b
);
}
Expr
operator
!=
(
Expr
a
,
Expr
b
)
{
TVM_ARITH_CONST_PROPAGATION
({
if
(
pa
&&
pb
)
return
UIntImm
::
make
(
UInt
(
1
),
pa
->
value
!=
pb
->
value
);
if
(
fa
&&
fb
)
return
UIntImm
::
make
(
UInt
(
1
),
fa
->
value
!=
fb
->
value
);
});
BinaryOpMatchTypes
(
a
,
b
);
Expr
ret
=
arith
::
TryConstFold
<
ir
::
NE
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
return
ir
::
NE
::
make
(
a
,
b
);
}
Expr
operator
&&
(
Expr
a
,
Expr
b
)
{
using
ir
::
UIntImm
;
if
(
a
.
type
().
is_bool
()
&&
b
.
type
().
is_bool
())
{
const
UIntImm
*
pa
=
a
.
as
<
UIntImm
>
();
const
UIntImm
*
pb
=
b
.
as
<
UIntImm
>
();
if
(
pa
&&
pa
->
value
)
return
b
;
if
(
pa
&&
!
pa
->
value
)
return
a
;
if
(
pb
&&
pb
->
value
)
return
a
;
if
(
pb
&&
!
pb
->
value
)
return
b
;
}
CHECK
(
a
.
type
().
is_bool
());
CHECK
(
b
.
type
().
is_bool
());
Expr
ret
=
arith
::
TryConstFold
<
ir
::
And
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
return
ir
::
And
::
make
(
a
,
b
);
}
Expr
operator
||
(
Expr
a
,
Expr
b
)
{
using
ir
::
UIntImm
;
if
(
a
.
type
().
is_bool
()
&&
b
.
type
().
is_bool
())
{
const
UIntImm
*
pa
=
a
.
as
<
UIntImm
>
();
const
UIntImm
*
pb
=
b
.
as
<
UIntImm
>
();
if
(
pa
&&
pa
->
value
)
return
a
;
if
(
pa
&&
!
pa
->
value
)
return
b
;
if
(
pb
&&
pb
->
value
)
return
b
;
if
(
pb
&&
!
pb
->
value
)
return
a
;
}
CHECK
(
a
.
type
().
is_bool
());
CHECK
(
b
.
type
().
is_bool
());
Expr
ret
=
arith
::
TryConstFold
<
ir
::
Or
>
(
a
,
b
);
if
(
ret
.
defined
())
return
ret
;
return
ir
::
Or
::
make
(
a
,
b
);
}
Expr
operator
!
(
Expr
a
)
{
using
ir
::
UIntImm
;
const
UIntImm
*
pa
=
a
.
as
<
UIntImm
>
();
if
(
pa
)
{
return
UIntImm
::
make
(
UInt
(
1
),
!
(
pa
->
value
));
}
CHECK
(
a
.
type
().
is_bool
());
Expr
ret
=
arith
::
TryConstFold
<
ir
::
Not
>
(
a
);
if
(
ret
.
defined
())
return
ret
;
return
ir
::
Not
::
make
(
a
);
}
Expr
operator
>>
(
Expr
a
,
Expr
b
)
{
BinaryOpMatchTypes
(
a
,
b
);
TVM_INDEX_CONST_PROPAGATION
({
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
const
Type
&
rtype
=
a
.
type
()
;
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
(
pa
->
value
>>
pb
->
value
));
if
(
pb
)
{
if
(
pb
->
value
==
0
)
return
SimpleCast
(
rtype
,
a
)
;
if
(
pb
->
value
==
0
)
return
a
;
}
});
return
ir
::
Call
::
make
(
a
.
type
(),
ir
::
Call
::
shift_right
,
{
a
,
b
},
ir
::
Call
::
PureIntrinsic
);
}
Expr
operator
<<
(
Expr
a
,
Expr
b
)
{
BinaryOpMatchTypes
(
a
,
b
);
TVM_INDEX_CONST_PROPAGATION
({
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
const
Type
&
rtype
=
a
.
type
()
;
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
(
pa
->
value
<<
pb
->
value
));
if
(
pb
)
{
if
(
pb
->
value
==
0
)
return
SimpleCast
(
rtype
,
a
)
;
if
(
pb
->
value
==
0
)
return
a
;
}
});
return
ir
::
Call
::
make
(
a
.
type
(),
ir
::
Call
::
shift_left
,
{
a
,
b
},
ir
::
Call
::
PureIntrinsic
);
}
Expr
operator
&
(
Expr
a
,
Expr
b
)
{
BinaryOpMatchTypes
(
a
,
b
);
TVM_INDEX_CONST_PROPAGATION
({
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
const
Type
&
rtype
=
a
.
type
()
;
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
(
pa
->
value
&
pb
->
value
));
});
return
ir
::
Call
::
make
(
a
.
type
(),
ir
::
Call
::
bitwise_and
,
{
a
,
b
},
ir
::
Call
::
PureIntrinsic
);
}
Expr
operator
|
(
Expr
a
,
Expr
b
)
{
BinaryOpMatchTypes
(
a
,
b
);
TVM_INDEX_CONST_PROPAGATION
({
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
const
Type
&
rtype
=
a
.
type
()
;
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
(
pa
->
value
|
pb
->
value
));
});
return
ir
::
Call
::
make
(
a
.
type
(),
ir
::
Call
::
bitwise_or
,
{
a
,
b
},
ir
::
Call
::
PureIntrinsic
);
}
Expr
operator
^
(
Expr
a
,
Expr
b
)
{
BinaryOpMatchTypes
(
a
,
b
);
TVM_INDEX_CONST_PROPAGATION
({
Type
rtype
=
ta
.
bits
()
>=
tb
.
bits
()
?
ta
:
tb
;
const
Type
&
rtype
=
a
.
type
()
;
if
(
pa
&&
pb
)
return
IntImm
::
make
(
rtype
,
(
pa
->
value
^
pb
->
value
));
});
return
ir
::
Call
::
make
(
a
.
type
(),
ir
::
Call
::
bitwise_xor
,
{
a
,
b
},
ir
::
Call
::
PureIntrinsic
);
...
...
src/op/hybrid_op.cc
View file @
6d460606
...
...
@@ -7,8 +7,8 @@
#include <tvm/arithmetic.h>
#include <tvm/ir.h>
#include <tvm/ir_mutator.h>
#include <tvm/ir_operator.h>
#include <tvm/ir_pass.h>
#include <tvm/expr_operator.h>
#include <ir/Expr.h>
#include <unordered_set>
#include <string>
...
...
src/pass/ir_util.h
View file @
6d460606
...
...
@@ -7,7 +7,7 @@
#define TVM_PASS_IR_UTIL_H_
#include <tvm/ir.h>
#include <tvm/
i
r_operator.h>
#include <tvm/
exp
r_operator.h>
#include <tvm/runtime/device_api.h>
#include <vector>
...
...
src/pass/storage_flatten.cc
View file @
6d460606
...
...
@@ -8,7 +8,7 @@
#include <tvm/expr.h>
#include <tvm/operation.h>
#include <tvm/ir_mutator.h>
#include <tvm/
i
r_operator.h>
#include <tvm/
exp
r_operator.h>
#include <tvm/ir_pass.h>
#include <tvm/buffer.h>
#include <tvm/target_info.h>
...
...
src/relay/op/nn/pad.cc
View file @
6d460606
...
...
@@ -4,7 +4,7 @@
* \brief Implementation of operator pad
*/
#include <tvm/data_layout.h>
#include <tvm/
i
r_operator.h>
#include <tvm/
exp
r_operator.h>
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <topi/nn.h>
...
...
src/relay/op/tensor/transform.cc
View file @
6d460606
...
...
@@ -5,7 +5,7 @@
*/
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/
i
r_operator.h>
#include <tvm/
exp
r_operator.h>
#include <tvm/ir.h>
#include <tvm/data_layout.h>
#include <topi/transform.h>
...
...
src/relay/pass/fuse_ops.cc
View file @
6d460606
...
...
@@ -6,7 +6,7 @@
* \brief This is a backend-aware optimization pass.
* Fuse necessary ops into a single one.
*/
#include <tvm/
i
r_operator.h>
#include <tvm/
exp
r_operator.h>
#include <tvm/relay/pass.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
...
...
tests/cpp/ir_mutator_test.cc
View file @
6d460606
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/ir_mutator.h>
#include <tvm/
i
r_operator.h>
#include <tvm/
exp
r_operator.h>
namespace
{
using
namespace
tvm
::
ir
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment