Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
2c1ca60e
Unverified
Commit
2c1ca60e
authored
Apr 14, 2020
by
masahi
Committed by
GitHub
Apr 13, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add memoized expr translator for use by backend codegen (#5325)
parent
0ab18036
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
63 additions
and
115 deletions
+63
-115
src/relay/backend/compile_engine.cc
+23
-41
src/relay/backend/contrib/codegen_c/codegen.cc
+1
-11
src/relay/backend/contrib/dnnl/codegen.cc
+1
-11
src/relay/backend/graph_runtime_codegen.cc
+3
-47
src/relay/backend/interpreter.cc
+0
-5
src/relay/backend/utils.h
+35
-0
No files found.
src/relay/backend/compile_engine.cc
View file @
2c1ca60e
...
@@ -21,29 +21,31 @@
...
@@ -21,29 +21,31 @@
* \file relay/backend/compile_engine.cc
* \file relay/backend/compile_engine.cc
* \brief Internal compialtion engine.
* \brief Internal compialtion engine.
*/
*/
#include "compile_engine.h"
#include <topi/tags.h>
#include <tvm/driver/driver_api.h>
#include <tvm/ir/type_functor.h>
#include <tvm/ir/type_functor.h>
#include <tvm/te/schedule.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/driver/driver_api.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>
#include <topi/tags.h>
#include <functional>
#include <utility>
#include <limits>
#include <limits>
#include <mutex>
#include <mutex>
#include <functional>
#include <vector>
#include <unordered_map>
#include <unordered_map>
#include <utility>
#include <vector>
#include "
compile_engine
.h"
#include "
utils
.h"
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
...
@@ -111,8 +113,7 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
...
@@ -111,8 +113,7 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
// The getter to get schedule from compile engine.
// The getter to get schedule from compile engine.
// Get schedule from functor.
// Get schedule from functor.
class
ScheduleGetter
:
class
ScheduleGetter
:
public
backend
::
MemoizedExprTranslator
<
Array
<
te
::
Tensor
>>
{
public
ExprFunctor
<
Array
<
te
::
Tensor
>
(
const
Expr
&
)
>
{
public
:
public
:
explicit
ScheduleGetter
(
Target
target
)
explicit
ScheduleGetter
(
Target
target
)
:
target_
(
target
),
device_copy_op_
(
Op
::
Get
(
"device_copy"
))
{}
:
target_
(
target
),
device_copy_op_
(
Op
::
Get
(
"device_copy"
))
{}
...
@@ -179,17 +180,6 @@ class ScheduleGetter :
...
@@ -179,17 +180,6 @@ class ScheduleGetter :
return
CachedFunc
(
cache_node
);
return
CachedFunc
(
cache_node
);
}
}
Array
<
te
::
Tensor
>
VisitExpr
(
const
Expr
&
expr
)
{
auto
it
=
memo_
.
find
(
expr
);
if
(
it
!=
memo_
.
end
())
{
return
it
->
second
;
}
else
{
Array
<
te
::
Tensor
>
res
=
ExprFunctor
::
VisitExpr
(
expr
);
memo_
[
expr
]
=
res
;
return
res
;
}
}
Array
<
te
::
Tensor
>
VisitExpr_
(
const
VarNode
*
op
)
final
{
Array
<
te
::
Tensor
>
VisitExpr_
(
const
VarNode
*
op
)
final
{
LOG
(
FATAL
)
<<
"Free variable "
<<
op
->
name_hint
();
LOG
(
FATAL
)
<<
"Free variable "
<<
op
->
name_hint
();
return
{};
return
{};
...
@@ -327,7 +317,6 @@ class ScheduleGetter :
...
@@ -327,7 +317,6 @@ class ScheduleGetter :
int
master_op_pattern_
{
0
};
int
master_op_pattern_
{
0
};
OpImplementation
master_implementation_
;
OpImplementation
master_implementation_
;
std
::
ostringstream
readable_name_stream_
;
std
::
ostringstream
readable_name_stream_
;
std
::
unordered_map
<
Expr
,
Array
<
te
::
Tensor
>
,
ObjectHash
,
ObjectEqual
>
memo_
;
Array
<
te
::
Operation
>
scalars_
;
Array
<
te
::
Operation
>
scalars_
;
// Cache device copy op for equivalence checking to reduce registry lookup
// Cache device copy op for equivalence checking to reduce registry lookup
// overhead for each invocation of call node when retrieving schedules.
// overhead for each invocation of call node when retrieving schedules.
...
@@ -335,7 +324,7 @@ class ScheduleGetter :
...
@@ -335,7 +324,7 @@ class ScheduleGetter :
};
};
// Creates shape function from functor.
// Creates shape function from functor.
class
MakeShapeFunc
:
public
ExprFunctor
<
Array
<
te
::
Tensor
>
(
const
Expr
&
)
>
{
class
MakeShapeFunc
:
public
backend
::
MemoizedExprTranslator
<
Array
<
te
::
Tensor
>
>
{
public
:
public
:
MakeShapeFunc
()
{}
MakeShapeFunc
()
{}
...
@@ -422,19 +411,14 @@ class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
...
@@ -422,19 +411,14 @@ class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
return
std
::
make_pair
(
schedule
,
cfunc
);
return
std
::
make_pair
(
schedule
,
cfunc
);
}
}
Array
<
te
::
Tensor
>
VisitExpr
(
const
Expr
&
expr
)
{
Array
<
te
::
Tensor
>
VisitExpr
(
const
Expr
&
expr
)
final
{
auto
it
=
memo_
.
find
(
expr
);
if
(
expr
.
as
<
VarNode
>
())
{
if
(
it
!=
memo_
.
end
())
{
// Do not memoize vars because shape functions could use either the data
return
it
->
second
;
// or the shape of a var each time.
}
else
{
return
ExprFunctor
::
VisitExpr
(
expr
);
Array
<
te
::
Tensor
>
res
=
ExprFunctor
::
VisitExpr
(
expr
);
if
(
expr
.
as
<
VarNode
>
()
==
nullptr
)
{
// Do not memoize vars because shape functions could use either the data
// or the shape of a var each time.
memo_
[
expr
]
=
res
;
}
return
res
;
}
}
// For other case, do memoized visit
return
backend
::
MemoizedExprTranslator
<
Array
<
te
::
Tensor
>>::
VisitExpr
(
expr
);
}
}
Array
<
te
::
Tensor
>
VisitExpr_
(
const
VarNode
*
var_node
)
final
{
Array
<
te
::
Tensor
>
VisitExpr_
(
const
VarNode
*
var_node
)
final
{
...
@@ -577,8 +561,6 @@ class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
...
@@ -577,8 +561,6 @@ class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
std
::
unordered_map
<
Expr
,
Array
<
te
::
Tensor
>
,
ObjectHash
,
ObjectEqual
>
param_data_
;
std
::
unordered_map
<
Expr
,
Array
<
te
::
Tensor
>
,
ObjectHash
,
ObjectEqual
>
param_data_
;
/*! \brief Map from parameter to list of shape placeholder */
/*! \brief Map from parameter to list of shape placeholder */
std
::
unordered_map
<
Expr
,
Array
<
te
::
Tensor
>
,
ObjectHash
,
ObjectEqual
>
param_shapes_
;
std
::
unordered_map
<
Expr
,
Array
<
te
::
Tensor
>
,
ObjectHash
,
ObjectEqual
>
param_shapes_
;
/*! \brief Memoized visit result */
std
::
unordered_map
<
Expr
,
Array
<
te
::
Tensor
>
,
ObjectHash
,
ObjectEqual
>
memo_
;
/*! \brief Stack of data dependencies for shape function */
/*! \brief Stack of data dependencies for shape function */
std
::
vector
<
bool
>
data_dependants_
;
std
::
vector
<
bool
>
data_dependants_
;
/*! \brief Scalars used in the shape function */
/*! \brief Scalars used in the shape function */
...
...
src/relay/backend/contrib/codegen_c/codegen.cc
View file @
2c1ca60e
...
@@ -40,18 +40,10 @@ using namespace backend;
...
@@ -40,18 +40,10 @@ using namespace backend;
* purpose. Only several binary options are covered. Users
* purpose. Only several binary options are covered. Users
* may need to extend them to cover more operators.
* may need to extend them to cover more operators.
*/
*/
class
CodegenC
:
public
ExprFunctor
<
std
::
vector
<
Output
>
(
const
Expr
&
)
>
,
class
CodegenC
:
public
MemoizedExprTranslator
<
std
::
vector
<
Output
>>
,
public
CodegenCBase
{
public
CodegenCBase
{
public
:
public
:
explicit
CodegenC
(
const
std
::
string
&
id
)
{
this
->
ext_func_id_
=
id
;
}
explicit
CodegenC
(
const
std
::
string
&
id
)
{
this
->
ext_func_id_
=
id
;
}
std
::
vector
<
Output
>
VisitExpr
(
const
Expr
&
expr
)
final
{
if
(
visited_
.
count
(
expr
))
return
visited_
.
at
(
expr
);
std
::
vector
<
Output
>
output
=
ExprFunctor
::
VisitExpr
(
expr
);
visited_
[
expr
]
=
output
;
return
output
;
}
std
::
vector
<
Output
>
VisitExprDefault_
(
const
Object
*
op
)
final
{
std
::
vector
<
Output
>
VisitExprDefault_
(
const
Object
*
op
)
final
{
LOG
(
FATAL
)
<<
"C codegen doesn't support: "
<<
op
->
GetTypeKey
();
LOG
(
FATAL
)
<<
"C codegen doesn't support: "
<<
op
->
GetTypeKey
();
return
{};
return
{};
...
@@ -208,8 +200,6 @@ class CodegenC : public ExprFunctor<std::vector<Output>(const Expr&)>,
...
@@ -208,8 +200,6 @@ class CodegenC : public ExprFunctor<std::vector<Output>(const Expr&)>,
std
::
vector
<
std
::
string
>
func_decl_
;
std
::
vector
<
std
::
string
>
func_decl_
;
/*! \brief The declaration statements of buffers. */
/*! \brief The declaration statements of buffers. */
std
::
vector
<
std
::
string
>
buf_decl_
;
std
::
vector
<
std
::
string
>
buf_decl_
;
/*! \brief The name and index pairs for output. */
std
::
unordered_map
<
Expr
,
std
::
vector
<
Output
>
,
ObjectHash
,
ObjectEqual
>
visited_
;
};
};
class
CSourceCodegen
:
public
CSourceModuleCodegenBase
{
class
CSourceCodegen
:
public
CSourceModuleCodegenBase
{
...
...
src/relay/backend/contrib/dnnl/codegen.cc
View file @
2c1ca60e
...
@@ -128,18 +128,10 @@ std::vector<std::string> Add(const CallNode* call) {
...
@@ -128,18 +128,10 @@ std::vector<std::string> Add(const CallNode* call) {
// TODO(@zhiics, @comaniac): This is a basic implementation. We should implement
// TODO(@zhiics, @comaniac): This is a basic implementation. We should implement
// all utilities and make a base class for users to implement.
// all utilities and make a base class for users to implement.
class
CodegenDNNL
:
public
ExprFunctor
<
std
::
vector
<
Output
>
(
const
Expr
&
)
>
,
class
CodegenDNNL
:
public
MemoizedExprTranslator
<
std
::
vector
<
Output
>>
,
public
CodegenCBase
{
public
CodegenCBase
{
public
:
public
:
explicit
CodegenDNNL
(
const
std
::
string
&
id
)
{
this
->
ext_func_id_
=
id
;
}
explicit
CodegenDNNL
(
const
std
::
string
&
id
)
{
this
->
ext_func_id_
=
id
;
}
std
::
vector
<
Output
>
VisitExpr
(
const
Expr
&
expr
)
final
{
if
(
visited_
.
count
(
expr
))
return
visited_
.
at
(
expr
);
std
::
vector
<
Output
>
output
=
ExprFunctor
::
VisitExpr
(
expr
);
visited_
[
expr
]
=
output
;
return
output
;
}
std
::
vector
<
Output
>
VisitExprDefault_
(
const
Object
*
op
)
final
{
std
::
vector
<
Output
>
VisitExprDefault_
(
const
Object
*
op
)
final
{
LOG
(
FATAL
)
<<
"DNNL codegen doesn't support: "
<<
op
->
GetTypeKey
();
LOG
(
FATAL
)
<<
"DNNL codegen doesn't support: "
<<
op
->
GetTypeKey
();
return
{};
return
{};
...
@@ -343,8 +335,6 @@ class CodegenDNNL : public ExprFunctor<std::vector<Output>(const Expr&)>,
...
@@ -343,8 +335,6 @@ class CodegenDNNL : public ExprFunctor<std::vector<Output>(const Expr&)>,
std
::
vector
<
std
::
string
>
ext_func_body
;
std
::
vector
<
std
::
string
>
ext_func_body
;
/*! \brief The declaration of intermeidate buffers. */
/*! \brief The declaration of intermeidate buffers. */
std
::
vector
<
std
::
string
>
buf_decl_
;
std
::
vector
<
std
::
string
>
buf_decl_
;
/*! \brief The cached expressions. */
std
::
unordered_map
<
Expr
,
std
::
vector
<
Output
>
,
ObjectHash
,
ObjectEqual
>
visited_
;
};
};
/*!
/*!
...
...
src/relay/backend/graph_runtime_codegen.cc
View file @
2c1ca60e
...
@@ -28,13 +28,12 @@
...
@@ -28,13 +28,12 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/device_api.h>
#include <list>
#include <list>
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "utils.h"
#include "compile_engine.h"
#include "compile_engine.h"
#include "utils.h"
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
...
@@ -190,11 +189,9 @@ class GraphOpNode : public GraphNode {
...
@@ -190,11 +189,9 @@ class GraphOpNode : public GraphNode {
};
};
/*! \brief Code generator for graph runtime */
/*! \brief Code generator for graph runtime */
class
GraphRuntimeCodegen
class
GraphRuntimeCodegen
:
public
backend
::
MemoizedExprTranslator
<
std
::
vector
<
GraphNodeRef
>>
{
:
public
::
tvm
::
relay
::
ExprFunctor
<
std
::
vector
<
GraphNodeRef
>
(
const
Expr
&
)
>
{
public
:
public
:
GraphRuntimeCodegen
(
runtime
::
Module
*
mod
,
const
TargetsMap
&
targets
)
GraphRuntimeCodegen
(
runtime
::
Module
*
mod
,
const
TargetsMap
&
targets
)
:
mod_
(
mod
)
{
:
mod_
(
mod
)
{
compile_engine_
=
CompileEngine
::
Global
();
compile_engine_
=
CompileEngine
::
Global
();
targets_
=
targets
;
targets_
=
targets
;
}
}
...
@@ -313,47 +310,6 @@ class GraphRuntimeCodegen
...
@@ -313,47 +310,6 @@ class GraphRuntimeCodegen
return
{
GraphNodeRef
(
node_id
,
0
)};
return
{
GraphNodeRef
(
node_id
,
0
)};
}
}
/*! \brief Visitors */
std
::
unordered_map
<
Expr
,
std
::
vector
<
GraphNodeRef
>
,
ObjectHash
,
ObjectEqual
>
visitor_cache_
;
std
::
vector
<
GraphNodeRef
>
VisitExpr
(
const
Expr
&
expr
)
override
{
if
(
visitor_cache_
.
count
(
expr
))
return
visitor_cache_
.
at
(
expr
);
std
::
vector
<
GraphNodeRef
>
res
;
if
(
expr
.
as
<
ConstantNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
ConstantNode
>
());
}
else
if
(
expr
.
as
<
TupleNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
TupleNode
>
());
}
else
if
(
expr
.
as
<
VarNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
VarNode
>
());
}
else
if
(
expr
.
as
<
GlobalVarNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
GlobalVarNode
>
());
}
else
if
(
expr
.
as
<
FunctionNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
FunctionNode
>
());
}
else
if
(
expr
.
as
<
CallNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
CallNode
>
());
}
else
if
(
expr
.
as
<
LetNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
LetNode
>
());
}
else
if
(
expr
.
as
<
IfNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
IfNode
>
());
}
else
if
(
expr
.
as
<
OpNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
OpNode
>
());
}
else
if
(
expr
.
as
<
TupleGetItemNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
TupleGetItemNode
>
());
}
else
if
(
expr
.
as
<
RefCreateNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
RefCreateNode
>
());
}
else
if
(
expr
.
as
<
RefReadNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
RefReadNode
>
());
}
else
if
(
expr
.
as
<
RefWriteNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
RefWriteNode
>
());
}
else
if
(
expr
.
as
<
ConstructorNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
ConstructorNode
>
());
}
else
if
(
expr
.
as
<
MatchNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
MatchNode
>
());
}
visitor_cache_
[
expr
]
=
res
;
return
res
;
}
std
::
vector
<
GraphNodeRef
>
VisitExpr_
(
const
VarNode
*
op
)
override
{
std
::
vector
<
GraphNodeRef
>
VisitExpr_
(
const
VarNode
*
op
)
override
{
Expr
expr
=
GetRef
<
Expr
>
(
op
);
Expr
expr
=
GetRef
<
Expr
>
(
op
);
return
var_map_
[
expr
.
get
()];
return
var_map_
[
expr
.
get
()];
...
...
src/relay/backend/interpreter.cc
View file @
2c1ca60e
...
@@ -244,11 +244,6 @@ class Interpreter :
...
@@ -244,11 +244,6 @@ class Interpreter :
return
VisitExpr
(
expr
);
return
VisitExpr
(
expr
);
}
}
ObjectRef
VisitExpr
(
const
Expr
&
expr
)
final
{
auto
ret
=
ExprFunctor
<
ObjectRef
(
const
Expr
&
n
)
>::
VisitExpr
(
expr
);
return
ret
;
}
ObjectRef
VisitExpr_
(
const
VarNode
*
var_node
)
final
{
ObjectRef
VisitExpr_
(
const
VarNode
*
var_node
)
final
{
return
Lookup
(
GetRef
<
Var
>
(
var_node
));
return
Lookup
(
GetRef
<
Var
>
(
var_node
));
}
}
...
...
src/relay/backend/utils.h
View file @
2c1ca60e
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include <dmlc/json.h>
#include <dmlc/json.h>
#include <tvm/driver/driver_api.h>
#include <tvm/driver/driver_api.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/relay/type.h>
#include <tvm/target/codegen.h>
#include <tvm/target/codegen.h>
...
@@ -42,6 +43,40 @@
...
@@ -42,6 +43,40 @@
namespace
tvm
{
namespace
tvm
{
namespace
relay
{
namespace
relay
{
namespace
backend
{
namespace
backend
{
/*!
* \brief A simple wrapper around ExprFunctor for a single argument case.
* The result of visit is memoized.
*/
template
<
typename
OutputType
>
class
MemoizedExprTranslator
:
public
::
tvm
::
relay
::
ExprFunctor
<
OutputType
(
const
Expr
&
)
>
{
using
BaseFunctor
=
::
tvm
::
relay
::
ExprFunctor
<
OutputType
(
const
Expr
&
)
>
;
public
:
/*! \brief virtual destructor */
virtual
~
MemoizedExprTranslator
()
{}
/*!
* \brief The memoized call.
* \param n The expression node.
* \return The result of the call
*/
virtual
OutputType
VisitExpr
(
const
Expr
&
n
)
{
CHECK
(
n
.
defined
());
auto
it
=
memo_
.
find
(
n
);
if
(
it
!=
memo_
.
end
())
{
return
it
->
second
;
}
auto
res
=
BaseFunctor
::
VisitExpr
(
n
);
memo_
[
n
]
=
res
;
return
res
;
}
protected
:
/*! \brief Internal map used for memoization. */
std
::
unordered_map
<
Expr
,
OutputType
,
ObjectHash
,
ObjectEqual
>
memo_
;
};
/*!
/*!
* \brief Get the Packed Func
* \brief Get the Packed Func
*
*
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment