Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
415a270d
Unverified
Commit
415a270d
authored
May 24, 2019
by
Tianqi Chen
Committed by
GitHub
May 24, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[C++][API] Consistent RAII scoping API. (#3231)
parent
b2f8b96a
Hide whitespace changes
Inline
Side-by-side
Showing
22 changed files
with
214 additions
and
190 deletions
+214
-190
include/tvm/arithmetic.h
+15
-10
include/tvm/base.h
+44
-0
include/tvm/build_module.h
+45
-86
python/tvm/build_module.py
+1
-1
python/tvm/target.py
+1
-1
src/api/api_arith.cc
+2
-2
src/arithmetic/analyzer.cc
+12
-5
src/arithmetic/rewrite_simplify.cc
+4
-4
src/arithmetic/stmt_simplify.cc
+5
-5
src/codegen/build_module.cc
+44
-35
src/codegen/codegen_aocl.cc
+3
-3
src/codegen/codegen_vhls.cc
+3
-3
src/codegen/llvm/codegen_llvm.cc
+3
-3
src/codegen/spirv/codegen_spirv.cc
+3
-3
src/relay/backend/build_module.cc
+5
-5
src/relay/backend/compile_engine.cc
+4
-4
src/relay/backend/vm/compiler.cc
+2
-2
src/relay/pass/fold_constant.cc
+4
-4
src/relay/pass/partial_eval.cc
+2
-2
tests/cpp/build_module_test.cc
+6
-6
tests/cpp/relay_build_module_test.cc
+3
-3
topi/src/topi.cc
+3
-3
No files found.
include/tvm/arithmetic.h
View file @
415a270d
...
...
@@ -290,14 +290,14 @@ class CanonicalSimplifier {
};
/*!
* \brief
A RAII c
onstraint context.
* \brief
C
onstraint context.
*
* \code
*
* Var("x");
* arith::Analyzer analyzer;
* {
*
arith::ConstraintContext cctx
(&analyzer, x % 3 == 0);
*
With<arith::ConstraintContext> scope
(&analyzer, x % 3 == 0);
* CHECK_EQ(analyzer.modular_set(x)->coeff, 3);
* }
* // constraint no longer in effect.
...
...
@@ -306,19 +306,24 @@ class CanonicalSimplifier {
* \endcode
*/
class
ConstraintContext
{
public
:
private
:
// declare friend to enable with.
friend
class
With
<
ConstraintContext
>
;
/*!
* \brief Construct a constraint context.
* \param analyzer The analyzer.
* \param constraint The constraint to be applied.
*/
ConstraintContext
(
Analyzer
*
analyzer
,
const
Expr
&
constraint
)
DMLC_THROW_EXCEPTION
;
/*! \brief destructor */
~
ConstraintContext
()
DMLC_THROW_EXCEPTION
{
exit_
();
}
private
:
ConstraintContext
(
Analyzer
*
analyzer
,
Expr
constraint
)
:
analyzer_
(
analyzer
),
constraint_
(
constraint
)
{}
// enter the scope.
void
EnterWithScope
();
// exit the scope.
void
ExitWithScope
();
/*! \brief The analyzer */
Analyzer
*
analyzer_
;
/*! \brief The constraint */
Expr
constraint_
;
/*! \brief function to be called in recovery */
std
::
function
<
void
()
>
exit_
;
};
...
...
include/tvm/base.h
View file @
415a270d
...
...
@@ -102,6 +102,50 @@ using ::tvm::AttrVisitor;
};
/*!
* \brief RAII wrapper function to enter and exit a context object
* similar to python's with syntax.
*
* \code
* // context class
* class MyContext {
* private:
* friend class With<MyContext>;
MyContext(arguments);
* void EnterWithScope();
* void ExitWithScope();
* };
*
* {
* With<MyContext> scope(arguments);
* // effect take place.
* }
* \endcode
*
* \tparam ContextType Type of the context object.
*/
template
<
typename
ContextType
>
class
With
{
public
:
/*!
* \brief constructor.
* Enter the scope of the context.
*/
template
<
typename
...
Args
>
explicit
With
(
Args
&&
...
args
)
:
ctx_
(
std
::
forward
<
Args
>
(
args
)...)
{
ctx_
.
EnterWithScope
();
}
/*! \brief destructor, leaves the scope of the context. */
~
With
()
DMLC_THROW_EXCEPTION
{
ctx_
.
ExitWithScope
();
}
private
:
/*! \brief internal context type. */
ContextType
ctx_
;
};
/*!
* \brief save the node as well as all the node it depends on as json.
* This can be used to serialize any TVM object
*
...
...
include/tvm/build_module.h
View file @
415a270d
...
...
@@ -37,7 +37,7 @@ namespace tvm {
/*!
* \brief Container for target device information.
* Use target::llvm, target::cuda etc functions instead of constructing directly.
*
Use target::llvm, target::cuda etc functions instead of constructing directly.
*/
class
TargetNode
:
public
Node
{
public
:
...
...
@@ -89,65 +89,47 @@ class TargetNode : public Node {
mutable
std
::
string
str_repr_
;
};
/*! \brief reference cpass to the target. */
class
Target
:
public
NodeRef
{
public
:
Target
()
{}
explicit
Target
(
NodePtr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
/*!
* \brief Create a Target given a string
* \param target_str the string to parse
*/
TVM_DLL
static
Target
create
(
const
std
::
string
&
target_str
);
/*!
* \brief Push a new target context onto the thread local stack. The Target on top of
* the stack is used to determine which specialization to use when invoking a GenericFunc.
* \param target The target to set as the current context.
*/
TVM_DLL
static
void
EnterTargetScope
(
const
tvm
::
Target
&
target
);
/*!
* \brief Pop a target off the thread local context stack, restoring the previous target
* as the current context.
*/
TVM_DLL
static
void
ExitTargetScope
();
TVM_DLL
static
Target
Create
(
const
std
::
string
&
target_str
);
/*!
* \brief Get the current target context from thread local storage.
* \param allow_not_defined If the context stack is empty and this is set to true, an
*
undefined Target will be returned. Otherwise, an empty context stack will cause a
*
runtime error.
* \return The target that is the current context. The target may not be defined if
* allow_not_defined is true.
*/
TVM_DLL
static
tvm
::
Target
current_targe
t
(
bool
allow_not_defined
=
true
);
* \brief Get the current target context from thread local storage.
* \param allow_not_defined If the context stack is empty and this is set to true, an
*
undefined Target will be returned. Otherwise, an empty context stack will cause a
*
runtime error.
* \return The target that is the current context. The target may not be defined if
* allow_not_defined is true.
*/
TVM_DLL
static
tvm
::
Target
Curren
t
(
bool
allow_not_defined
=
true
);
inline
const
TargetNode
*
operator
->
()
const
{
const
TargetNode
*
operator
->
()
const
{
return
static_cast
<
const
TargetNode
*>
(
node_
.
get
());
}
using
ContainerType
=
TargetNode
;
};
/*!
* \brief RAII container to provide a scoped target context. Pushes a target onto the
* context stack when constructed, and pops it when destructed.
*/
struct
TargetContext
{
class
Internal
;
private
:
// enable with syntax.
friend
class
Internal
;
friend
class
With
<
Target
>
;
/*!
* \brief
Enter a new target context. The given target becomes the new current context
.
*
When the TargetContext is destructed, the previous context is restored.
*
\param target The target to set as the new current context
.
* \brief
Push a new target context onto the thread local stack
.
*
The Target on top of the stack is used to determine which
*
specialization to use when invoking a GenericFunc
.
*/
explicit
TargetContext
(
const
tvm
::
Target
&
target
)
{
Target
::
EnterTargetScope
(
target
);
}
/*! \brief Destructor. Pops the context off the thread local stack. */
~
TargetContext
()
{
Target
::
ExitTargetScope
();
}
TVM_DLL
void
EnterWithScope
();
/*!
* \brief Pop a target off the thread local context stack,
* restoring the previous target as the current context.
*/
TVM_DLL
void
ExitWithScope
();
};
/*! \brief This namespace provides functions to construct Target instances */
...
...
@@ -190,11 +172,9 @@ TVM_DLL Target stackvm(const std::vector<std::string>& options =
}
// namespace target
class
BuildConfig
;
/*!
* \brief Container for build configuration options
*/
* \brief Container for build configuration options
*/
class
BuildConfigNode
:
public
Node
{
public
:
/*!
...
...
@@ -271,70 +251,49 @@ class BuildConfigNode : public Node {
};
/*!
* \brief Container for build configuration options
*/
* \brief Build configuration for compilations.
*/
class
BuildConfig
:
public
::
tvm
::
NodeRef
{
public
:
BuildConfig
()
{}
explicit
BuildConfig
(
NodePtr
<::
tvm
::
Node
>
n
)
:
NodeRef
(
n
)
{}
const
BuildConfigNode
*
operator
->
()
const
{
return
static_cast
<
const
BuildConfigNode
*>
(
node_
.
get
());
}
BuildConfigNode
*
operator
->
()
{
return
static_cast
<
BuildConfigNode
*>
(
node_
.
get
());
}
/*!
* \brief
Push a new BuildConfig context onto the thread local stack
.
* \
param build_config The configuration to set as the current context.
* \brief
Construct a BuildConfig containing a empty build config node
.
* \
return The new BuildConfig
*/
TVM_DLL
static
void
EnterBuildConfigScope
(
const
tvm
::
BuildConfig
&
build_config
);
/*!
* \brief Pop a build config off the thread local context stack, restoring the previous
* configuration as the current context.
*/
TVM_DLL
static
void
ExitBuildConfigScope
();
TVM_DLL
static
BuildConfig
Create
();
/*!
* \brief Get the current BuildConfig context from thread local storage, or a default
* configuration if a BuildConfig scope has not been entered.
* \return The configuration that is the current context.
*/
TVM_DLL
static
tvm
::
BuildConfig
Current
();
TVM_DLL
static
BuildConfig
Current
();
using
ContainerType
=
BuildConfigNode
;
}
;
class
Internal
;
/*!
* \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the
* context stack when constructed, and pops it when destructed.
*/
struct
BuildConfigContext
{
private
:
// Enable with syntax.
friend
class
With
<
BuildConfig
>
;
/*!
* \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current
* context. When the BuildConfigContext is destructed, the previous context is restored.
* \param build_config The BuildConfig to set as the new current context.
* \brief Push a new BuildConfig context onto the thread local stack.
*/
explicit
BuildConfigContext
(
const
tvm
::
BuildConfig
&
build_config
)
{
BuildConfig
::
EnterBuildConfigScope
(
build_config
);
}
TVM_DLL
void
EnterWithScope
();
/*! \brief Destructor. Pops the context off the thread local stack. */
~
BuildConfigContext
()
{
BuildConfig
::
ExitBuildConfigScope
();
}
/*!
* \brief Pop a build config off the thread local context stack,
* restoring the previous configuration as the current context.
*/
TVM_DLL
void
ExitWithScope
();
};
/*!
* \brief Construct a BuildConfig containing a new BuildConfigNode
* \return The new BuildConfig
*/
TVM_DLL
BuildConfig
build_config
();
/*!
* \brief Build a LoweredFunc given a schedule, args and binds
* \param sch The schedule to lower.
* \param args The arguments to the function.
...
...
python/tvm/build_module.py
View file @
415a270d
...
...
@@ -187,7 +187,7 @@ class BuildConfig(NodeBase):
def
__exit__
(
self
,
ptype
,
value
,
trace
):
if
self
.
dump_pass_ir
:
BuildConfig
.
_dump_ir
.
exit
()
_api_internal
.
_ExitBuildConfigScope
()
_api_internal
.
_ExitBuildConfigScope
(
self
)
def
__setattr__
(
self
,
name
,
value
):
if
name
in
BuildConfig
.
_node_defaults
:
...
...
python/tvm/target.py
View file @
415a270d
...
...
@@ -133,7 +133,7 @@ class Target(NodeBase):
return
self
def
__exit__
(
self
,
ptype
,
value
,
trace
):
_api_internal
.
_ExitTargetScope
()
_api_internal
.
_ExitTargetScope
(
self
)
@register_node
...
...
src/api/api_arith.cc
View file @
415a270d
...
...
@@ -123,8 +123,8 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
// can't use make_shared due to noexcept(false) decl in destructor,
// see https://stackoverflow.com/a/43907314
auto
ctx
=
std
::
shared_ptr
<
ConstraintContext
>
(
new
ConstraintContext
(
self
.
get
(),
args
[
0
]));
auto
ctx
=
std
::
shared_ptr
<
With
<
ConstraintContext
>
>
(
new
With
<
ConstraintContext
>
(
self
.
get
(),
args
[
0
]));
auto
fexit
=
[
ctx
](
TVMArgs
,
TVMRetValue
*
)
mutable
{
ctx
.
reset
();
};
...
...
src/arithmetic/analyzer.cc
View file @
415a270d
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -54,10 +54,12 @@ void Analyzer::Bind(const VarExpr& v, const Range& range) {
// skip rewrite simplify
}
ConstraintContext
::
ConstraintContext
(
Analyzer
*
analyzer
,
const
Expr
&
constraint
)
{
void
ConstraintContext
::
EnterWithScope
()
{
CHECK
(
exit_
==
nullptr
);
// entering the scope.
auto
f0
=
analyzer
->
const_int_bound
.
EnterConstraint
(
constraint
);
auto
f1
=
analyzer
->
modular_set
.
EnterConstraint
(
constraint
);
auto
f0
=
analyzer
_
->
const_int_bound
.
EnterConstraint
(
constraint_
);
auto
f1
=
analyzer
_
->
modular_set
.
EnterConstraint
(
constraint_
);
// recovery function.
exit_
=
[
f0
,
f1
]()
{
if
(
f1
!=
nullptr
)
f1
();
...
...
@@ -65,6 +67,11 @@ ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint)
};
}
void
ConstraintContext
::
ExitWithScope
()
{
CHECK
(
exit_
!=
nullptr
);
exit_
();
}
bool
Analyzer
::
CanProveGreaterEqual
(
const
Expr
&
expr
,
int64_t
lower_bound
)
{
if
(
const
auto
*
ptr
=
expr
.
as
<
ir
::
IntImm
>
())
{
return
ptr
->
value
>
lower_bound
;
...
...
src/arithmetic/rewrite_simplify.cc
View file @
415a270d
...
...
@@ -1200,11 +1200,11 @@ Mutate_(const Select* op, const Expr& self) {
Expr
cond
=
Mutate
(
op
->
condition
);
Expr
true_value
,
false_value
;
{
ConstraintContext
constraint
(
parent_
,
cond
);
With
<
ConstraintContext
>
constraint
(
parent_
,
cond
);
true_value
=
Mutate
(
op
->
true_value
);
}
{
ConstraintContext
constraint
(
parent_
,
Mutate
(
Not
::
make
(
cond
)));
With
<
ConstraintContext
>
constraint
(
parent_
,
Mutate
(
Not
::
make
(
cond
)));
false_value
=
Mutate
(
op
->
false_value
);
}
if
(
is_zero
(
cond
))
{
...
...
@@ -1237,11 +1237,11 @@ Mutate_(const Call* op, const Expr& self) {
Expr
cond
=
Mutate
(
op
->
args
[
0
]);
Expr
true_value
,
false_value
;
{
ConstraintContext
constraint
(
parent_
,
cond
);
With
<
ConstraintContext
>
constraint
(
parent_
,
cond
);
true_value
=
Mutate
(
op
->
args
[
1
]);
}
{
ConstraintContext
constraint
(
parent_
,
Mutate
(
Not
::
make
(
cond
)));
With
<
ConstraintContext
>
constraint
(
parent_
,
Mutate
(
Not
::
make
(
cond
)));
false_value
=
Mutate
(
op
->
args
[
2
]);
}
if
(
is_zero
(
cond
))
{
...
...
src/arithmetic/stmt_simplify.cc
View file @
415a270d
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -48,11 +48,11 @@ class StmtSimplifier : public IRMutator {
Expr
condition
=
this
->
Mutate
(
op
->
condition
);
Stmt
then_case
,
else_case
;
{
ConstraintContext
ctx
(
&
analyzer_
,
condition
);
With
<
ConstraintContext
>
ctx
(
&
analyzer_
,
condition
);
then_case
=
this
->
Mutate
(
op
->
then_case
);
}
if
(
op
->
else_case
.
defined
())
{
ConstraintContext
ctx
(
&
analyzer_
,
Mutate
(
Not
::
make
(
condition
)));
With
<
ConstraintContext
>
ctx
(
&
analyzer_
,
Mutate
(
Not
::
make
(
condition
)));
else_case
=
this
->
Mutate
(
op
->
else_case
);
}
if
(
is_one
(
condition
))
return
then_case
;
...
...
@@ -94,7 +94,7 @@ class StmtSimplifier : public IRMutator {
Stmt
Mutate_
(
const
AssertStmt
*
op
,
const
Stmt
&
s
)
final
{
Expr
condition
=
this
->
Mutate
(
op
->
condition
);
Expr
message
=
this
->
Mutate
(
op
->
message
);
ConstraintContext
ctx
(
&
analyzer_
,
condition
);
With
<
ConstraintContext
>
ctx
(
&
analyzer_
,
condition
);
Stmt
body
=
this
->
Mutate
(
op
->
body
);
if
(
condition
.
same_as
(
op
->
condition
)
&&
...
...
src/codegen/build_module.cc
View file @
415a270d
...
...
@@ -18,7 +18,6 @@
*/
/*!
* Copyright (c) 2017 by Contributors
* Compile executable modules.
* \file build_module.cc
*/
...
...
@@ -148,8 +147,7 @@ TVM_REGISTER_API("_TargetCreate")
TVM_REGISTER_API
(
"_TargetFromString"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
std
::
string
target_str
=
args
[
0
];
*
ret
=
Target
::
create
(
target_str
);
*
ret
=
Target
::
Create
(
target_str
);
});
std
::
vector
<
std
::
string
>
TargetNode
::
keys
()
const
{
...
...
@@ -207,7 +205,7 @@ std::string GetDeviceName(const std::string& target_str) {
return
""
;
}
Target
Target
::
c
reate
(
const
std
::
string
&
target_str
)
{
Target
Target
::
C
reate
(
const
std
::
string
&
target_str
)
{
if
(
target_str
.
length
()
==
0
)
{
LOG
(
ERROR
)
<<
"target_str must not be empty"
;
}
...
...
@@ -231,25 +229,24 @@ Target Target::create(const std::string& target_str) {
struct
TVMTargetThreadLocalEntry
{
/*! \brief The current target context */
std
::
stack
<
tvm
::
Target
>
context_stack
;
TVMTargetThreadLocalEntry
()
{
}
};
/*! \brief Thread local store to hold the Target context stack. */
typedef
dmlc
::
ThreadLocalStore
<
TVMTargetThreadLocalEntry
>
TVMTargetThreadLocalStore
;
void
Target
::
Enter
TargetScope
(
const
tvm
::
Target
&
target
)
{
void
Target
::
Enter
WithScope
(
)
{
TVMTargetThreadLocalEntry
*
entry
=
TVMTargetThreadLocalStore
::
Get
();
entry
->
context_stack
.
push
(
target
);
entry
->
context_stack
.
push
(
*
this
);
}
void
Target
::
Exit
Target
Scope
()
{
void
Target
::
Exit
With
Scope
()
{
TVMTargetThreadLocalEntry
*
entry
=
TVMTargetThreadLocalStore
::
Get
();
CHECK
(
!
entry
->
context_stack
.
empty
());
CHECK
(
entry
->
context_stack
.
top
().
same_as
(
*
this
));
entry
->
context_stack
.
pop
();
}
tvm
::
Target
Target
::
current_targe
t
(
bool
allow_not_defined
)
{
tvm
::
Target
Target
::
Curren
t
(
bool
allow_not_defined
)
{
TVMTargetThreadLocalEntry
*
entry
=
TVMTargetThreadLocalStore
::
Get
();
if
(
entry
->
context_stack
.
size
()
>
0
)
{
return
entry
->
context_stack
.
top
();
...
...
@@ -574,7 +571,7 @@ runtime::Module build(const Map<std::string, Array<LoweredFunc>>& inputs,
const
BuildConfig
&
config
)
{
Map
<
Target
,
Array
<
LoweredFunc
>>
updated_input
;
for
(
const
auto
&
it
:
inputs
)
{
auto
target
=
Target
::
c
reate
(
it
.
first
);
auto
target
=
Target
::
C
reate
(
it
.
first
);
updated_input
.
Set
(
target
,
it
.
second
);
}
return
build
(
updated_input
,
target_host
,
config
);
...
...
@@ -589,33 +586,35 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
return
build
(
inputs
,
target_host
,
config
);
}
BuildConfig
build_config
()
{
BuildConfig
BuildConfig
::
Create
()
{
return
BuildConfig
(
make_node
<
BuildConfigNode
>
());
}
/*! \brief Entry to hold the BuildConfig context stack. */
struct
TVMBuildConfigThreadLocalEntry
{
/*! \brief The default build config if the stack is empty */
tvm
::
BuildConfig
default_config
;
BuildConfig
default_config
;
/*! \brief The current build config context */
std
::
stack
<
tvm
::
BuildConfig
>
context_stack
;
std
::
stack
<
BuildConfig
>
context_stack
;
TVMBuildConfigThreadLocalEntry
()
:
default_config
(
build_config
())
{
default_config
(
BuildConfig
::
Create
())
{
}
};
/*! \brief Thread local store to hold the BuildConfig context stack. */
typedef
dmlc
::
ThreadLocalStore
<
TVMBuildConfigThreadLocalEntry
>
TVMBuildConfigThreadLocalStore
;
void
BuildConfig
::
Enter
BuildConfigScope
(
const
tvm
::
BuildConfig
&
build_config
)
{
void
BuildConfig
::
Enter
WithScope
(
)
{
TVMBuildConfigThreadLocalEntry
*
entry
=
TVMBuildConfigThreadLocalStore
::
Get
();
entry
->
context_stack
.
push
(
build_config
);
entry
->
context_stack
.
push
(
*
this
);
}
void
BuildConfig
::
Exit
BuildConfig
Scope
()
{
void
BuildConfig
::
Exit
With
Scope
()
{
TVMBuildConfigThreadLocalEntry
*
entry
=
TVMBuildConfigThreadLocalStore
::
Get
();
CHECK
(
!
entry
->
context_stack
.
empty
());
CHECK
(
entry
->
context_stack
.
top
().
same_as
(
*
this
));
entry
->
context_stack
.
pop
();
}
...
...
@@ -714,7 +713,7 @@ GenericFunc& GenericFunc::register_func(const std::vector<std::string>& tags,
void
GenericFunc
::
CallPacked
(
TVMArgs
args
,
TVMRetValue
*
ret
)
const
{
auto
node
=
static_cast
<
GenericFuncNode
*>
(
node_
.
get
());
auto
target
=
Target
::
current_targe
t
(
true
);
auto
target
=
Target
::
Curren
t
(
true
);
PackedFunc
func
;
if
(
target
.
defined
())
{
...
...
@@ -740,16 +739,21 @@ TVM_REGISTER_API("_GetCurrentBuildConfig")
*
ret
=
BuildConfig
::
Current
();
});
class
BuildConfig
::
Internal
{
public
:
static
void
EnterScope
(
BuildConfig
target
)
{
target
.
EnterWithScope
();
}
static
void
ExitScope
(
BuildConfig
target
)
{
target
.
ExitWithScope
();
}
};
TVM_REGISTER_API
(
"_EnterBuildConfigScope"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
BuildConfig
target
=
args
[
0
];
BuildConfig
::
EnterBuildConfigScope
(
target
);
});
.
set_body_typed
(
BuildConfig
::
Internal
::
EnterScope
);
TVM_REGISTER_API
(
"_ExitBuildConfigScope"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
BuildConfig
::
ExitBuildConfigScope
();
});
.
set_body_typed
(
BuildConfig
::
Internal
::
ExitScope
);
TVM_REGISTER_API
(
"_BuildConfigSetAddLowerPass"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
...
...
@@ -836,18 +840,23 @@ TVM_REGISTER_API("_GenericFuncCallFunc")
TVM_REGISTER_API
(
"_GetCurrentTarget"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
bool
allow_not_defined
=
args
[
0
];
*
ret
=
Target
::
current_targe
t
(
allow_not_defined
);
*
ret
=
Target
::
Curren
t
(
allow_not_defined
);
});
class
Target
::
Internal
{
public
:
static
void
EnterScope
(
Target
target
)
{
target
.
EnterWithScope
();
}
static
void
ExitScope
(
Target
target
)
{
target
.
ExitWithScope
();
}
};
TVM_REGISTER_API
(
"_EnterTargetScope"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Target
target
=
args
[
0
];
Target
::
EnterTargetScope
(
target
);
});
.
set_body_typed
(
Target
::
Internal
::
EnterScope
);
TVM_REGISTER_API
(
"_ExitTargetScope"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
Target
::
ExitTargetScope
();
});
.
set_body_typed
(
Target
::
Internal
::
ExitScope
);
}
// namespace tvm
src/codegen/codegen_aocl.cc
View file @
415a270d
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -54,7 +54,7 @@ runtime::Module BuildAOCL(Array<LoweredFunc> funcs, std::string target_str,
std
::
string
cmd
=
"aoc aocl.cl"
;
// AOCL supports fp64.
cmd
+=
" -Dcl_khr_fp64"
;
Target
target
=
Target
::
c
reate
(
target_str
);
Target
target
=
Target
::
C
reate
(
target_str
);
if
(
target
->
device_name
!=
""
)
{
cmd
+=
" -board="
+
target
->
device_name
;
}
...
...
src/codegen/codegen_vhls.cc
View file @
415a270d
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -155,7 +155,7 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) {
std
::
string
xclbin
;
if
(
const
auto
*
f
=
Registry
::
Get
(
"tvm_callback_sdaccel_compile"
))
{
Target
target
=
Target
::
c
reate
(
target_str
);
Target
target
=
Target
::
C
reate
(
target_str
);
xclbin
=
(
*
f
)(
kernel_info
,
target
->
device_name
).
operator
std
::
string
();
}
else
{
LOG
(
FATAL
)
<<
"Cannot compile Vivado HLS code."
;
...
...
src/codegen/llvm/codegen_llvm.cc
View file @
415a270d
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -1142,7 +1142,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
}
void
CodeGenLLVM
::
VisitStmt_
(
const
AssertStmt
*
op
)
{
arith
::
ConstraintContext
cctx
(
analyzer_
.
get
(),
op
->
condition
);
With
<
arith
::
ConstraintContext
>
cctx
(
analyzer_
.
get
(),
op
->
condition
);
this
->
VisitStmt
(
op
->
body
);
}
...
...
src/codegen/spirv/codegen_spirv.cc
View file @
415a270d
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -626,7 +626,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) {
}
void
CodeGenSPIRV
::
VisitStmt_
(
const
AssertStmt
*
op
)
{
arith
::
ConstraintContext
cctx
(
analyzer_
.
get
(),
op
->
condition
);
With
<
arith
::
ConstraintContext
>
cctx
(
analyzer_
.
get
(),
op
->
condition
);
this
->
VisitStmt
(
op
->
body
);
}
...
...
src/relay/backend/build_module.cc
View file @
415a270d
...
...
@@ -445,7 +445,7 @@ class RelayBuildModule : public runtime::ModuleNode {
if
(
targets
.
size
()
==
1
)
{
func
=
CallPackedFunc
(
"relay._ir_pass.infer_type"
,
func
,
nullptr
);
for
(
const
auto
&
kv
:
targets
)
{
TargetContext
tctx
(
kv
.
second
);
With
<
Target
>
tctx
(
kv
.
second
);
func
=
CallPackedFunc
(
"relay._ir_pass.AlterOpLayout"
,
func
);
}
}
else
{
...
...
@@ -466,9 +466,9 @@ class RelayBuildModule : public runtime::ModuleNode {
*/
Target
CreateDefaultTarget
(
int
device_type
)
{
std
::
string
name
=
runtime
::
DeviceName
(
device_type
);
if
(
name
==
"cpu"
)
return
Target
::
c
reate
(
"llvm"
);
if
(
name
==
"gpu"
)
return
Target
::
c
reate
(
"cuda"
);
return
Target
::
c
reate
(
name
);
if
(
name
==
"cpu"
)
return
Target
::
C
reate
(
"llvm"
);
if
(
name
==
"gpu"
)
return
Target
::
C
reate
(
"cuda"
);
return
Target
::
C
reate
(
name
);
}
/*!
* \brief Update the target and fallback device required for heterogeneous
...
...
@@ -548,7 +548,7 @@ class RelayBuildModule : public runtime::ModuleNode {
const
RelayBuildConfig
&
cfg
,
const
std
::
unordered_map
<
std
::
string
,
tvm
::
runtime
::
NDArray
>
&
params
)
{
// convert
tvm_cfg_
=
build_config
();
tvm_cfg_
=
BuildConfig
::
Create
();
TargetsMap
device_target
;
if
(
targets_
.
size
()
>
1
)
{
device_target
=
UpdateHeterogeneousInputs
(
targets_
,
cfg
);
...
...
src/relay/backend/compile_engine.cc
View file @
415a270d
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -344,7 +344,7 @@ class CompileEngineImpl : public CompileEngineNode {
cache_
[
key
]
=
value
;
}
// Enforce use the target.
TargetContext
target_ctx
(
key
->
target
);
With
<
Target
>
target_scope
(
key
->
target
);
CHECK
(
!
value
->
cached_func
.
defined
());
auto
spair
=
CreateSchedule
(
key
->
source_func
,
key
->
target
);
...
...
@@ -371,7 +371,7 @@ class CompileEngineImpl : public CompileEngineNode {
cache_node
->
funcs
=
(
*
f
)(
spair
.
first
,
all_args
,
cache_node
->
func_name
,
key
->
source_func
);
}
else
{
tvm
::
BuildConfig
bcfg
=
tvm
::
build_config
();
tvm
::
BuildConfig
bcfg
=
BuildConfig
::
Create
();
std
::
unordered_map
<
Tensor
,
Buffer
>
binds
;
cache_node
->
funcs
=
tvm
::
lower
(
spair
.
first
,
all_args
,
cache_node
->
func_name
,
binds
,
bcfg
);
}
...
...
src/relay/backend/vm/compiler.cc
View file @
415a270d
...
...
@@ -364,7 +364,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
// Next generate the invoke instruction.
CHECK
(
func
->
IsPrimitive
());
auto
target
=
Target
::
c
reate
(
"llvm"
);
auto
target
=
Target
::
C
reate
(
"llvm"
);
auto
key
=
CCacheKeyNode
::
make
(
func
,
target
);
auto
cfunc
=
engine
->
Lower
(
key
);
// TODO(jroesch): support lowered funcs for multiple targets
...
...
@@ -502,7 +502,7 @@ void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs,
runtime
::
Module
mod
;
if
(
lowered_funcs
.
size
()
>
0
)
{
// TODO(@jroesch): we need to read target from build config
Target
target
=
Target
::
c
reate
(
"llvm"
);
Target
target
=
Target
::
C
reate
(
"llvm"
);
if
(
const
auto
*
f
=
runtime
::
Registry
::
Get
(
"relay.backend.build"
))
{
mod
=
(
*
f
)(
tvm
::
Array
<
LoweredFunc
>
(
lowered_funcs
.
begin
(),
lowered_funcs
.
end
()),
target
);
}
else
{
...
...
src/relay/pass/fold_constant.cc
View file @
415a270d
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -203,10 +203,10 @@ Expr FoldConstant(const Expr& expr) {
DLContext
ctx
;
ctx
.
device_type
=
kDLCPU
;
ctx
.
device_id
=
0
;
Target
target
=
Target
::
c
reate
(
"llvm"
);
Target
target
=
Target
::
C
reate
(
"llvm"
);
// use a fresh build context
// in case we are already in a build context.
BuildConfigContext
fresh_build_ctx
(
build_config
());
With
<
BuildConfig
>
fresh_build_ctx
(
BuildConfig
::
Create
());
return
ConstantFolder
(
CreateInterpreter
(
Module
(
nullptr
),
ctx
,
target
)).
Mutate
(
expr
);
...
...
src/relay/pass/partial_eval.cc
View file @
415a270d
...
...
@@ -375,10 +375,10 @@ DLContext CPUContext() {
}
FInterpreter
CPUInterpreter
()
{
Target
target
=
Target
::
c
reate
(
"llvm"
);
Target
target
=
Target
::
C
reate
(
"llvm"
);
// use a fresh build context
// in case we are already in a build context.
BuildConfigContext
fresh_build_ctx
(
build_config
());
With
<
BuildConfig
>
fresh_build_ctx
(
BuildConfig
::
Create
());
return
CreateInterpreter
(
Module
(
nullptr
),
CPUContext
(),
target
);
}
...
...
tests/cpp/build_module_test.cc
View file @
415a270d
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -50,14 +50,14 @@ TEST(BuildModule, Basic) {
auto
args
=
Array
<
Tensor
>
({
A
,
B
,
C
});
std
::
unordered_map
<
Tensor
,
Buffer
>
binds
;
auto
config
=
build_config
();
auto
config
=
BuildConfig
::
Create
();
auto
target
=
target
::
llvm
();
auto
lowered
=
lower
(
s
,
args
,
"func"
,
binds
,
config
);
auto
module
=
build
(
lowered
,
target
,
Target
(),
config
);
auto
mali_target
=
Target
::
c
reate
(
"opencl -model=Mali-T860MP4@800Mhz -device=mali"
);
CHECK_EQ
(
mali_target
->
str
(),
"opencl -model=Mali-T860MP4@800Mhz -device=mali"
);
auto
mali_target
=
Target
::
C
reate
(
"opencl -model=Mali-T860MP4@800Mhz -device=mali"
);
CHECK_EQ
(
mali_target
->
str
(),
"opencl -model=Mali-T860MP4@800Mhz -device=mali"
);
}
TEST
(
BuildModule
,
Heterogeneous
)
{
...
...
@@ -105,7 +105,7 @@ TEST(BuildModule, Heterogeneous) {
auto
s1
=
topi
::
cuda
::
schedule_injective
(
target_cuda
,
{
elemwise_add
});
auto
s2
=
create_schedule
({
elemwise_sub
->
op
});
auto
config
=
build_config
();
auto
config
=
BuildConfig
::
Create
();
auto
args1
=
Array
<
Tensor
>
({
A
,
B
,
elemwise_add
});
auto
args2
=
Array
<
Tensor
>
({
copy
,
C
,
elemwise_sub
});
...
...
tests/cpp/relay_build_module_test.cc
View file @
415a270d
...
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
...
@@ -75,7 +75,7 @@ TEST(Relay, BuildModule) {
auto
json_f
=
build_mod
.
GetFunction
(
"get_graph_json"
,
false
);
auto
mod_f
=
build_mod
.
GetFunction
(
"get_module"
,
false
);
Map
<
tvm
::
Integer
,
tvm
::
Target
>
targets
;
Target
llvm_tgt
=
Target
::
c
reate
(
"llvm"
);
Target
llvm_tgt
=
Target
::
C
reate
(
"llvm"
);
targets
.
Set
(
0
,
llvm_tgt
);
build_f
(
func
,
targets
,
llvm_tgt
);
std
::
string
json
=
json_f
();
...
...
topi/src/topi.cc
View file @
415a270d
...
...
@@ -94,7 +94,7 @@ inline bool IsTensorType(TVMArgValue arg) {
TVM_REGISTER_GLOBAL
(
"topi.TEST_create_target"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
tvm
::
Target
::
c
reate
(
args
[
0
]);
*
rv
=
tvm
::
Target
::
C
reate
(
args
[
0
]);
});
/* Ops from broadcast.h */
...
...
@@ -640,7 +640,7 @@ using FTVMScheduleBuilder = std::function<
*/
inline
PackedFunc
WrapSchedule
(
FTVMScheduleBuilder
builder
)
{
return
PackedFunc
([
builder
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
target
=
Target
::
current_targe
t
(
false
);
auto
target
=
Target
::
Curren
t
(
false
);
Array
<
Tensor
>
outs
;
NodeRef
argNodeRef
=
args
[
0
];
if
(
argNodeRef
->
type_index
()
==
outs
->
type_index
())
{
...
...
@@ -712,7 +712,7 @@ using FTVMDenseOpBuilder = std::function<tvm::Tensor(const Target& target,
*/
inline
PackedFunc
WrapDenseOp
(
FTVMDenseOpBuilder
builder
)
{
return
PackedFunc
([
builder
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
target
=
Target
::
current_targe
t
(
false
);
auto
target
=
Target
::
Curren
t
(
false
);
Tensor
data
=
args
[
0
];
Tensor
weight
=
args
[
1
];
Tensor
bias
=
args
[
2
];
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment