Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
2c1ca60e
Unverified
Commit
2c1ca60e
authored
Apr 14, 2020
by
masahi
Committed by
GitHub
Apr 13, 2020
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add memoized expr translator for use by backend codegen (#5325)
parent
0ab18036
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
61 additions
and
113 deletions
+61
-113
src/relay/backend/compile_engine.cc
+21
-39
src/relay/backend/contrib/codegen_c/codegen.cc
+1
-11
src/relay/backend/contrib/dnnl/codegen.cc
+1
-11
src/relay/backend/graph_runtime_codegen.cc
+3
-47
src/relay/backend/interpreter.cc
+0
-5
src/relay/backend/utils.h
+35
-0
No files found.
src/relay/backend/compile_engine.cc
View file @
2c1ca60e
...
...
@@ -21,29 +21,31 @@
* \file relay/backend/compile_engine.cc
* \brief Internal compialtion engine.
*/
#include "compile_engine.h"
#include <topi/tags.h>
#include <tvm/driver/driver_api.h>
#include <tvm/ir/type_functor.h>
#include <tvm/te/schedule.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/container.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/driver/driver_api.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>
#include <topi/tags.h>
#include <utility>
#include <functional>
#include <limits>
#include <mutex>
#include <functional>
#include <vector>
#include <unordered_map>
#include <utility>
#include <vector>
#include "
compile_engine
.h"
#include "
utils
.h"
namespace
tvm
{
namespace
relay
{
...
...
@@ -111,8 +113,7 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
// The getter to get schedule from compile engine.
// Get schedule from functor.
class
ScheduleGetter
:
public
ExprFunctor
<
Array
<
te
::
Tensor
>
(
const
Expr
&
)
>
{
class
ScheduleGetter
:
public
backend
::
MemoizedExprTranslator
<
Array
<
te
::
Tensor
>>
{
public
:
explicit
ScheduleGetter
(
Target
target
)
:
target_
(
target
),
device_copy_op_
(
Op
::
Get
(
"device_copy"
))
{}
...
...
@@ -179,17 +180,6 @@ class ScheduleGetter :
return
CachedFunc
(
cache_node
);
}
Array
<
te
::
Tensor
>
VisitExpr
(
const
Expr
&
expr
)
{
auto
it
=
memo_
.
find
(
expr
);
if
(
it
!=
memo_
.
end
())
{
return
it
->
second
;
}
else
{
Array
<
te
::
Tensor
>
res
=
ExprFunctor
::
VisitExpr
(
expr
);
memo_
[
expr
]
=
res
;
return
res
;
}
}
Array
<
te
::
Tensor
>
VisitExpr_
(
const
VarNode
*
op
)
final
{
LOG
(
FATAL
)
<<
"Free variable "
<<
op
->
name_hint
();
return
{};
...
...
@@ -327,7 +317,6 @@ class ScheduleGetter :
int
master_op_pattern_
{
0
};
OpImplementation
master_implementation_
;
std
::
ostringstream
readable_name_stream_
;
std
::
unordered_map
<
Expr
,
Array
<
te
::
Tensor
>
,
ObjectHash
,
ObjectEqual
>
memo_
;
Array
<
te
::
Operation
>
scalars_
;
// Cache device copy op for equivalence checking to reduce registry lookup
// overhead for each invocation of call node when retrieving schedules.
...
...
@@ -335,7 +324,7 @@ class ScheduleGetter :
};
// Creates shape function from functor.
class
MakeShapeFunc
:
public
ExprFunctor
<
Array
<
te
::
Tensor
>
(
const
Expr
&
)
>
{
class
MakeShapeFunc
:
public
backend
::
MemoizedExprTranslator
<
Array
<
te
::
Tensor
>
>
{
public
:
MakeShapeFunc
()
{}
...
...
@@ -422,19 +411,14 @@ class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
return
std
::
make_pair
(
schedule
,
cfunc
);
}
Array
<
te
::
Tensor
>
VisitExpr
(
const
Expr
&
expr
)
{
auto
it
=
memo_
.
find
(
expr
);
if
(
it
!=
memo_
.
end
())
{
return
it
->
second
;
}
else
{
Array
<
te
::
Tensor
>
res
=
ExprFunctor
::
VisitExpr
(
expr
);
if
(
expr
.
as
<
VarNode
>
()
==
nullptr
)
{
Array
<
te
::
Tensor
>
VisitExpr
(
const
Expr
&
expr
)
final
{
if
(
expr
.
as
<
VarNode
>
())
{
// Do not memoize vars because shape functions could use either the data
// or the shape of a var each time.
memo_
[
expr
]
=
res
;
}
return
res
;
return
ExprFunctor
::
VisitExpr
(
expr
);
}
// For other case, do memoized visit
return
backend
::
MemoizedExprTranslator
<
Array
<
te
::
Tensor
>>::
VisitExpr
(
expr
);
}
Array
<
te
::
Tensor
>
VisitExpr_
(
const
VarNode
*
var_node
)
final
{
...
...
@@ -577,8 +561,6 @@ class MakeShapeFunc : public ExprFunctor<Array<te::Tensor>(const Expr&)> {
std
::
unordered_map
<
Expr
,
Array
<
te
::
Tensor
>
,
ObjectHash
,
ObjectEqual
>
param_data_
;
/*! \brief Map from parameter to list of shape placeholder */
std
::
unordered_map
<
Expr
,
Array
<
te
::
Tensor
>
,
ObjectHash
,
ObjectEqual
>
param_shapes_
;
/*! \brief Memoized visit result */
std
::
unordered_map
<
Expr
,
Array
<
te
::
Tensor
>
,
ObjectHash
,
ObjectEqual
>
memo_
;
/*! \brief Stack of data dependencies for shape function */
std
::
vector
<
bool
>
data_dependants_
;
/*! \brief Scalars used in the shape function */
...
...
src/relay/backend/contrib/codegen_c/codegen.cc
View file @
2c1ca60e
...
...
@@ -40,18 +40,10 @@ using namespace backend;
* purpose. Only several binary options are covered. Users
* may need to extend them to cover more operators.
*/
class
CodegenC
:
public
ExprFunctor
<
std
::
vector
<
Output
>
(
const
Expr
&
)
>
,
public
CodegenCBase
{
class
CodegenC
:
public
MemoizedExprTranslator
<
std
::
vector
<
Output
>>
,
public
CodegenCBase
{
public
:
explicit
CodegenC
(
const
std
::
string
&
id
)
{
this
->
ext_func_id_
=
id
;
}
std
::
vector
<
Output
>
VisitExpr
(
const
Expr
&
expr
)
final
{
if
(
visited_
.
count
(
expr
))
return
visited_
.
at
(
expr
);
std
::
vector
<
Output
>
output
=
ExprFunctor
::
VisitExpr
(
expr
);
visited_
[
expr
]
=
output
;
return
output
;
}
std
::
vector
<
Output
>
VisitExprDefault_
(
const
Object
*
op
)
final
{
LOG
(
FATAL
)
<<
"C codegen doesn't support: "
<<
op
->
GetTypeKey
();
return
{};
...
...
@@ -208,8 +200,6 @@ class CodegenC : public ExprFunctor<std::vector<Output>(const Expr&)>,
std
::
vector
<
std
::
string
>
func_decl_
;
/*! \brief The declaration statements of buffers. */
std
::
vector
<
std
::
string
>
buf_decl_
;
/*! \brief The name and index pairs for output. */
std
::
unordered_map
<
Expr
,
std
::
vector
<
Output
>
,
ObjectHash
,
ObjectEqual
>
visited_
;
};
class
CSourceCodegen
:
public
CSourceModuleCodegenBase
{
...
...
src/relay/backend/contrib/dnnl/codegen.cc
View file @
2c1ca60e
...
...
@@ -128,18 +128,10 @@ std::vector<std::string> Add(const CallNode* call) {
// TODO(@zhiics, @comaniac): This is a basic implementation. We should implement
// all utilities and make a base class for users to implement.
class
CodegenDNNL
:
public
ExprFunctor
<
std
::
vector
<
Output
>
(
const
Expr
&
)
>
,
public
CodegenCBase
{
class
CodegenDNNL
:
public
MemoizedExprTranslator
<
std
::
vector
<
Output
>>
,
public
CodegenCBase
{
public
:
explicit
CodegenDNNL
(
const
std
::
string
&
id
)
{
this
->
ext_func_id_
=
id
;
}
std
::
vector
<
Output
>
VisitExpr
(
const
Expr
&
expr
)
final
{
if
(
visited_
.
count
(
expr
))
return
visited_
.
at
(
expr
);
std
::
vector
<
Output
>
output
=
ExprFunctor
::
VisitExpr
(
expr
);
visited_
[
expr
]
=
output
;
return
output
;
}
std
::
vector
<
Output
>
VisitExprDefault_
(
const
Object
*
op
)
final
{
LOG
(
FATAL
)
<<
"DNNL codegen doesn't support: "
<<
op
->
GetTypeKey
();
return
{};
...
...
@@ -343,8 +335,6 @@ class CodegenDNNL : public ExprFunctor<std::vector<Output>(const Expr&)>,
std
::
vector
<
std
::
string
>
ext_func_body
;
/*! \brief The declaration of intermeidate buffers. */
std
::
vector
<
std
::
string
>
buf_decl_
;
/*! \brief The cached expressions. */
std
::
unordered_map
<
Expr
,
std
::
vector
<
Output
>
,
ObjectHash
,
ObjectEqual
>
visited_
;
};
/*!
...
...
src/relay/backend/graph_runtime_codegen.cc
View file @
2c1ca60e
...
...
@@ -28,13 +28,12 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h>
#include <list>
#include <string>
#include <vector>
#include "utils.h"
#include "compile_engine.h"
#include "utils.h"
namespace
tvm
{
namespace
relay
{
...
...
@@ -190,11 +189,9 @@ class GraphOpNode : public GraphNode {
};
/*! \brief Code generator for graph runtime */
class
GraphRuntimeCodegen
:
public
::
tvm
::
relay
::
ExprFunctor
<
std
::
vector
<
GraphNodeRef
>
(
const
Expr
&
)
>
{
class
GraphRuntimeCodegen
:
public
backend
::
MemoizedExprTranslator
<
std
::
vector
<
GraphNodeRef
>>
{
public
:
GraphRuntimeCodegen
(
runtime
::
Module
*
mod
,
const
TargetsMap
&
targets
)
:
mod_
(
mod
)
{
GraphRuntimeCodegen
(
runtime
::
Module
*
mod
,
const
TargetsMap
&
targets
)
:
mod_
(
mod
)
{
compile_engine_
=
CompileEngine
::
Global
();
targets_
=
targets
;
}
...
...
@@ -313,47 +310,6 @@ class GraphRuntimeCodegen
return
{
GraphNodeRef
(
node_id
,
0
)};
}
/*! \brief Visitors */
std
::
unordered_map
<
Expr
,
std
::
vector
<
GraphNodeRef
>
,
ObjectHash
,
ObjectEqual
>
visitor_cache_
;
std
::
vector
<
GraphNodeRef
>
VisitExpr
(
const
Expr
&
expr
)
override
{
if
(
visitor_cache_
.
count
(
expr
))
return
visitor_cache_
.
at
(
expr
);
std
::
vector
<
GraphNodeRef
>
res
;
if
(
expr
.
as
<
ConstantNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
ConstantNode
>
());
}
else
if
(
expr
.
as
<
TupleNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
TupleNode
>
());
}
else
if
(
expr
.
as
<
VarNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
VarNode
>
());
}
else
if
(
expr
.
as
<
GlobalVarNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
GlobalVarNode
>
());
}
else
if
(
expr
.
as
<
FunctionNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
FunctionNode
>
());
}
else
if
(
expr
.
as
<
CallNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
CallNode
>
());
}
else
if
(
expr
.
as
<
LetNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
LetNode
>
());
}
else
if
(
expr
.
as
<
IfNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
IfNode
>
());
}
else
if
(
expr
.
as
<
OpNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
OpNode
>
());
}
else
if
(
expr
.
as
<
TupleGetItemNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
TupleGetItemNode
>
());
}
else
if
(
expr
.
as
<
RefCreateNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
RefCreateNode
>
());
}
else
if
(
expr
.
as
<
RefReadNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
RefReadNode
>
());
}
else
if
(
expr
.
as
<
RefWriteNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
RefWriteNode
>
());
}
else
if
(
expr
.
as
<
ConstructorNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
ConstructorNode
>
());
}
else
if
(
expr
.
as
<
MatchNode
>
())
{
res
=
VisitExpr_
(
expr
.
as
<
MatchNode
>
());
}
visitor_cache_
[
expr
]
=
res
;
return
res
;
}
std
::
vector
<
GraphNodeRef
>
VisitExpr_
(
const
VarNode
*
op
)
override
{
Expr
expr
=
GetRef
<
Expr
>
(
op
);
return
var_map_
[
expr
.
get
()];
...
...
src/relay/backend/interpreter.cc
View file @
2c1ca60e
...
...
@@ -244,11 +244,6 @@ class Interpreter :
return
VisitExpr
(
expr
);
}
ObjectRef
VisitExpr
(
const
Expr
&
expr
)
final
{
auto
ret
=
ExprFunctor
<
ObjectRef
(
const
Expr
&
n
)
>::
VisitExpr
(
expr
);
return
ret
;
}
ObjectRef
VisitExpr_
(
const
VarNode
*
var_node
)
final
{
return
Lookup
(
GetRef
<
Var
>
(
var_node
));
}
...
...
src/relay/backend/utils.h
View file @
2c1ca60e
...
...
@@ -27,6 +27,7 @@
#include <dmlc/json.h>
#include <tvm/driver/driver_api.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/type.h>
#include <tvm/target/codegen.h>
...
...
@@ -42,6 +43,40 @@
namespace
tvm
{
namespace
relay
{
namespace
backend
{
/*!
* \brief A simple wrapper around ExprFunctor for a single argument case.
* The result of visit is memoized.
*/
template
<
typename
OutputType
>
class
MemoizedExprTranslator
:
public
::
tvm
::
relay
::
ExprFunctor
<
OutputType
(
const
Expr
&
)
>
{
using
BaseFunctor
=
::
tvm
::
relay
::
ExprFunctor
<
OutputType
(
const
Expr
&
)
>
;
public
:
/*! \brief virtual destructor */
virtual
~
MemoizedExprTranslator
()
{}
/*!
* \brief The memoized call.
* \param n The expression node.
* \return The result of the call
*/
virtual
OutputType
VisitExpr
(
const
Expr
&
n
)
{
CHECK
(
n
.
defined
());
auto
it
=
memo_
.
find
(
n
);
if
(
it
!=
memo_
.
end
())
{
return
it
->
second
;
}
auto
res
=
BaseFunctor
::
VisitExpr
(
n
);
memo_
[
n
]
=
res
;
return
res
;
}
protected
:
/*! \brief Internal map used for memoization. */
std
::
unordered_map
<
Expr
,
OutputType
,
ObjectHash
,
ObjectEqual
>
memo_
;
};
/*!
* \brief Get the Packed Func
*
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment