Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
ac3f5bd9
Unverified
Commit
ac3f5bd9
authored
May 16, 2019
by
Tianqi Chen
Committed by
GitHub
May 16, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RELAY] Hotfix build_module creation (#3198)
parent
493f90ff
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
23 additions
and
38 deletions
+23
-38
src/relay/backend/build_module.cc
+23
-38
No files found.
src/relay/backend/build_module.cc
View file @
ac3f5bd9
...
...
@@ -18,12 +18,11 @@
*/
/*!
* Copyright (c) 2019 by Contributors
* \file relay/backend/build_module.cc
* \brief Code generation for TVM's graph runtime.
*/
#include <tvm/build_module.h>
#include <tvm/runtime/device_api.h>
#include <tvm/relay/op.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/attrs/nn.h>
...
...
@@ -41,31 +40,6 @@ namespace backend {
using
TargetsMap
=
Map
<
tvm
::
Integer
,
tvm
::
Target
>
;
/*!
* \brief Context index to Target
*/
struct
ContextTargetMap
{
static
const
std
::
unordered_map
<
int
,
tvm
::
Target
>
mask2str
;
static
tvm
::
Target
Mask2Str
(
int
mask
)
{
CHECK_GT
(
mask2str
.
count
(
mask
),
0
)
<<
"Unknown mask."
;
return
mask2str
.
at
(
mask
);
}
};
const
std
::
unordered_map
<
int
,
tvm
::
Target
>
ContextTargetMap
::
mask2str
=
{
{
1
,
tvm
::
Target
::
create
(
"llvm"
)},
{
2
,
tvm
::
Target
::
create
(
"cuda"
)},
{
4
,
tvm
::
Target
::
create
(
"opencl"
)},
{
5
,
tvm
::
Target
::
create
(
"aocl"
)},
{
6
,
tvm
::
Target
::
create
(
"sdaccel"
)},
{
7
,
tvm
::
Target
::
create
(
"vulkan"
)},
{
8
,
tvm
::
Target
::
create
(
"metal"
)},
{
9
,
tvm
::
Target
::
create
(
"vpi"
)},
{
10
,
tvm
::
Target
::
create
(
"rocm"
)},
{
11
,
tvm
::
Target
::
create
(
"opengl"
)},
{
12
,
tvm
::
Target
::
create
(
"ext_dev"
)}
};
/*!
* \brief A data structure to map the names of specific optimizations to
* numeric optimization levels
*
...
...
@@ -310,8 +284,8 @@ class RelayBuildModule : public runtime::ModuleNode {
*
* \return Array<StringImm> names of params
*/
Array
<
HalideIR
::
Expr
>
ListParamNames
()
{
Array
<
HalideIR
::
Expr
>
ret
;
Array
<
tvm
::
Expr
>
ListParamNames
()
{
Array
<
tvm
::
Expr
>
ret
;
for
(
const
auto
&
kv
:
params_
)
{
ret
.
push_back
(
ir
::
StringImm
::
make
(
kv
.
first
));
}
...
...
@@ -470,12 +444,9 @@ class RelayBuildModule : public runtime::ModuleNode {
if
(
cfg
.
pass_enabled
(
"AlterOpLayout"
))
{
if
(
targets
.
size
()
==
1
)
{
func
=
CallPackedFunc
(
"relay._ir_pass.infer_type"
,
func
,
nullptr
);
auto
enter_pf
=
GetPackedFunc
(
"_EnterTargetScope"
);
auto
exit_pf
=
GetPackedFunc
(
"_ExitTargetScope"
);
for
(
const
auto
&
kv
:
targets
)
{
(
*
enter_pf
)
(
kv
.
second
);
TargetContext
tctx
(
kv
.
second
);
func
=
CallPackedFunc
(
"relay._ir_pass.AlterOpLayout"
,
func
);
(
*
exit_pf
)();
}
}
else
{
LOG
(
WARNING
)
<<
"AlterOpLayout pass is not enabled for heterogeneous"
...
...
@@ -487,6 +458,18 @@ class RelayBuildModule : public runtime::ModuleNode {
}
return
func
;
}
/*!
* \brief Create a default type.
* \param device_type The device type index.
* \return the default target for the device.
*/
Target
CreateDefaultTarget
(
int
device_type
)
{
std
::
string
name
=
runtime
::
DeviceName
(
device_type
);
if
(
name
==
"cpu"
)
return
Target
::
create
(
"llvm"
);
if
(
name
==
"gpu"
)
return
Target
::
create
(
"cuda"
);
return
Target
::
create
(
name
);
}
/*!
* \brief Update the target and fallback device required for heterogeneous
* compilation. CPU is used as the fallback device if it wasn't provided.
...
...
@@ -507,7 +490,7 @@ class RelayBuildModule : public runtime::ModuleNode {
if
(
tmp_map
.
count
(
cfg
.
fallback_device
)
==
0
)
{
device_target
.
Set
(
cfg
.
fallback_device
,
C
ontextTargetMap
::
Mask2Str
(
cfg
.
fallback_device
));
C
reateDefaultTarget
(
cfg
.
fallback_device
));
}
return
device_target
;
}
...
...
@@ -520,7 +503,8 @@ class RelayBuildModule : public runtime::ModuleNode {
* \param targets_map_ptr
* \return Function
*/
Function
RunDeviceAnnotationPass
(
Function
func
,
const
RelayBuildConfig
&
cfg
,
Function
RunDeviceAnnotationPass
(
Function
func
,
const
RelayBuildConfig
&
cfg
,
TargetsMap
*
targets_map_ptr
)
{
func
=
CallPackedFunc
(
"relay._ir_pass.infer_type"
,
func
,
nullptr
);
func
=
CallPackedFunc
(
"relay._ir_pass.RewriteDeviceAnnotation"
,
func
,
...
...
@@ -532,7 +516,7 @@ class RelayBuildModule : public runtime::ModuleNode {
"relay._ir_pass.CollectDeviceAnnotationOps"
,
func
,
nullptr
);
if
(
annotation_map
.
size
()
==
0
)
{
targets_map_ptr
->
Set
(
0
,
C
ontextTargetMap
::
Mask2Str
(
cfg
.
fallback_device
));
0
,
C
reateDefaultTarget
(
cfg
.
fallback_device
));
}
else
{
int64_t
dev_type
=
-
1
;
for
(
auto
kv
:
annotation_map
)
{
...
...
@@ -547,7 +531,7 @@ class RelayBuildModule : public runtime::ModuleNode {
<<
"found. Please check the "
<<
"RewriteAnnotation pass."
;
}
targets_map_ptr
->
Set
(
0
,
C
ontextTargetMap
::
Mask2Str
(
dev_type
));
targets_map_ptr
->
Set
(
0
,
C
reateDefaultTarget
(
dev_type
));
}
}
return
func
;
...
...
@@ -611,7 +595,8 @@ runtime::Module RelayBuildCreate() {
return
runtime
::
Module
(
exec
);
}
TVM_REGISTER_GLOBAL
(
"relay.build_module._BuildModule"
).
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
TVM_REGISTER_GLOBAL
(
"relay.build_module._BuildModule"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
RelayBuildCreate
();
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment