Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
415a270d
Unverified
Commit
415a270d
authored
May 24, 2019
by
Tianqi Chen
Committed by
GitHub
May 24, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[C++][API] Consistent RAII scoping API. (#3231)
parent
b2f8b96a
Hide whitespace changes
Inline
Side-by-side
Showing
22 changed files
with
214 additions
and
190 deletions
+214
-190
include/tvm/arithmetic.h
+15
-10
include/tvm/base.h
+44
-0
include/tvm/build_module.h
+45
-86
python/tvm/build_module.py
+1
-1
python/tvm/target.py
+1
-1
src/api/api_arith.cc
+2
-2
src/arithmetic/analyzer.cc
+12
-5
src/arithmetic/rewrite_simplify.cc
+4
-4
src/arithmetic/stmt_simplify.cc
+5
-5
src/codegen/build_module.cc
+44
-35
src/codegen/codegen_aocl.cc
+3
-3
src/codegen/codegen_vhls.cc
+3
-3
src/codegen/llvm/codegen_llvm.cc
+3
-3
src/codegen/spirv/codegen_spirv.cc
+3
-3
src/relay/backend/build_module.cc
+5
-5
src/relay/backend/compile_engine.cc
+4
-4
src/relay/backend/vm/compiler.cc
+2
-2
src/relay/pass/fold_constant.cc
+4
-4
src/relay/pass/partial_eval.cc
+2
-2
tests/cpp/build_module_test.cc
+6
-6
tests/cpp/relay_build_module_test.cc
+3
-3
topi/src/topi.cc
+3
-3
No files found.
include/tvm/arithmetic.h
View file @
415a270d
...
@@ -290,14 +290,14 @@ class CanonicalSimplifier {
...
@@ -290,14 +290,14 @@ class CanonicalSimplifier {
};
};
/*!
/*!
* \brief
A RAII c
onstraint context.
* \brief
C
onstraint context.
*
*
* \code
* \code
*
*
* Var("x");
* Var("x");
* arith::Analyzer analyzer;
* arith::Analyzer analyzer;
* {
* {
*
arith::ConstraintContext cctx
(&analyzer, x % 3 == 0);
*
With<arith::ConstraintContext> scope
(&analyzer, x % 3 == 0);
* CHECK_EQ(analyzer.modular_set(x)->coeff, 3);
* CHECK_EQ(analyzer.modular_set(x)->coeff, 3);
* }
* }
* // constraint no longer in effect.
* // constraint no longer in effect.
...
@@ -306,19 +306,24 @@ class CanonicalSimplifier {
...
@@ -306,19 +306,24 @@ class CanonicalSimplifier {
* \endcode
* \endcode
*/
*/
class
ConstraintContext
{
class
ConstraintContext
{
public
:
private
:
// declare friend to enable with.
friend
class
With
<
ConstraintContext
>
;
/*!
/*!
* \brief Construct a constraint context.
* \brief Construct a constraint context.
* \param analyzer The analyzer.
* \param analyzer The analyzer.
* \param constraint The constraint to be applied.
* \param constraint The constraint to be applied.
*/
*/
ConstraintContext
(
Analyzer
*
analyzer
,
const
Expr
&
constraint
)
DMLC_THROW_EXCEPTION
;
ConstraintContext
(
Analyzer
*
analyzer
,
Expr
constraint
)
/*! \brief destructor */
:
analyzer_
(
analyzer
),
constraint_
(
constraint
)
{}
~
ConstraintContext
()
DMLC_THROW_EXCEPTION
{
// enter the scope.
exit_
();
void
EnterWithScope
();
}
// exit the scope.
void
ExitWithScope
();
private
:
/*! \brief The analyzer */
Analyzer
*
analyzer_
;
/*! \brief The constraint */
Expr
constraint_
;
/*! \brief function to be called in recovery */
/*! \brief function to be called in recovery */
std
::
function
<
void
()
>
exit_
;
std
::
function
<
void
()
>
exit_
;
};
};
...
...
include/tvm/base.h
View file @
415a270d
...
@@ -102,6 +102,50 @@ using ::tvm::AttrVisitor;
...
@@ -102,6 +102,50 @@ using ::tvm::AttrVisitor;
};
};
/*!
/*!
* \brief RAII wrapper function to enter and exit a context object
* similar to python's with syntax.
*
* \code
* // context class
* class MyContext {
* private:
* friend class With<MyContext>;
MyContext(arguments);
* void EnterWithScope();
* void ExitWithScope();
* };
*
* {
* With<MyContext> scope(arguments);
* // effect take place.
* }
* \endcode
*
* \tparam ContextType Type of the context object.
*/
template
<
typename
ContextType
>
class
With
{
public
:
/*!
* \brief constructor.
* Enter the scope of the context.
*/
template
<
typename
...
Args
>
explicit
With
(
Args
&&
...
args
)
:
ctx_
(
std
::
forward
<
Args
>
(
args
)...)
{
ctx_
.
EnterWithScope
();
}
/*! \brief destructor, leaves the scope of the context. */
~
With
()
DMLC_THROW_EXCEPTION
{
ctx_
.
ExitWithScope
();
}
private
:
/*! \brief internal context type. */
ContextType
ctx_
;
};
/*!
* \brief save the node as well as all the node it depends on as json.
* \brief save the node as well as all the node it depends on as json.
* This can be used to serialize any TVM object
* This can be used to serialize any TVM object
*
*
...
...
include/tvm/build_module.h
View file @
415a270d
...
@@ -37,7 +37,7 @@ namespace tvm {
...
@@ -37,7 +37,7 @@ namespace tvm {
/*!
/*!
* \brief Container for target device information.
* \brief Container for target device information.
* Use target::llvm, target::cuda etc functions instead of constructing directly.
*
Use target::llvm, target::cuda etc functions instead of constructing directly.
*/
*/
class
TargetNode
:
public
Node
{
class
TargetNode
:
public
Node
{
public
:
public
:
...
@@ -89,65 +89,47 @@ class TargetNode : public Node {
...
@@ -89,65 +89,47 @@ class TargetNode : public Node {
mutable
std
::
string
str_repr_
;
mutable
std
::
string
str_repr_
;
};
};
/*! \brief reference cpass to the target. */
class
Target
:
public
NodeRef
{
class
Target
:
public
NodeRef
{
public
:
public
:
Target
()
{}
Target
()
{}
explicit
Target
(
NodePtr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
explicit
Target
(
NodePtr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
/*!
/*!
* \brief Create a Target given a string
* \brief Create a Target given a string
* \param target_str the string to parse
* \param target_str the string to parse
*/
*/
TVM_DLL
static
Target
create
(
const
std
::
string
&
target_str
);
TVM_DLL
static
Target
Create
(
const
std
::
string
&
target_str
);
/*!
* \brief Push a new target context onto the thread local stack. The Target on top of
* the stack is used to determine which specialization to use when invoking a GenericFunc.
* \param target The target to set as the current context.
*/
TVM_DLL
static
void
EnterTargetScope
(
const
tvm
::
Target
&
target
);
/*!
* \brief Pop a target off the thread local context stack, restoring the previous target
* as the current context.
*/
TVM_DLL
static
void
ExitTargetScope
();
/*!
/*!
* \brief Get the current target context from thread local storage.
* \brief Get the current target context from thread local storage.
* \param allow_not_defined If the context stack is empty and this is set to true, an
* \param allow_not_defined If the context stack is empty and this is set to true, an
*
undefined Target will be returned. Otherwise, an empty context stack will cause a
*
undefined Target will be returned. Otherwise, an empty context stack will cause a
*
runtime error.
*
runtime error.
* \return The target that is the current context. The target may not be defined if
* \return The target that is the current context. The target may not be defined if
* allow_not_defined is true.
* allow_not_defined is true.
*/
*/
TVM_DLL
static
tvm
::
Target
current_targe
t
(
bool
allow_not_defined
=
true
);
TVM_DLL
static
tvm
::
Target
Curren
t
(
bool
allow_not_defined
=
true
);
inline
const
TargetNode
*
operator
->
()
const
{
const
TargetNode
*
operator
->
()
const
{
return
static_cast
<
const
TargetNode
*>
(
node_
.
get
());
return
static_cast
<
const
TargetNode
*>
(
node_
.
get
());
}
}
using
ContainerType
=
TargetNode
;
using
ContainerType
=
TargetNode
;
};
class
Internal
;
private
:
/*!
// enable with syntax.
* \brief RAII container to provide a scoped target context. Pushes a target onto the
friend
class
Internal
;
* context stack when constructed, and pops it when destructed.
friend
class
With
<
Target
>
;
*/
struct
TargetContext
{
/*!
/*!
* \brief
Enter a new target context. The given target becomes the new current context
.
* \brief
Push a new target context onto the thread local stack
.
*
When the TargetContext is destructed, the previous context is restored.
*
The Target on top of the stack is used to determine which
*
\param target The target to set as the new current context
.
*
specialization to use when invoking a GenericFunc
.
*/
*/
explicit
TargetContext
(
const
tvm
::
Target
&
target
)
{
TVM_DLL
void
EnterWithScope
();
Target
::
EnterTargetScope
(
target
);
/*!
}
* \brief Pop a target off the thread local context stack,
* restoring the previous target as the current context.
/*! \brief Destructor. Pops the context off the thread local stack. */
*/
~
TargetContext
()
{
TVM_DLL
void
ExitWithScope
();
Target
::
ExitTargetScope
();
}
};
};
/*! \brief This namespace provides functions to construct Target instances */
/*! \brief This namespace provides functions to construct Target instances */
...
@@ -190,11 +172,9 @@ TVM_DLL Target stackvm(const std::vector<std::string>& options =
...
@@ -190,11 +172,9 @@ TVM_DLL Target stackvm(const std::vector<std::string>& options =
}
// namespace target
}
// namespace target
class
BuildConfig
;
/*!
/*!
* \brief Container for build configuration options
* \brief Container for build configuration options
*/
*/
class
BuildConfigNode
:
public
Node
{
class
BuildConfigNode
:
public
Node
{
public
:
public
:
/*!
/*!
...
@@ -271,70 +251,49 @@ class BuildConfigNode : public Node {
...
@@ -271,70 +251,49 @@ class BuildConfigNode : public Node {
};
};
/*!
/*!
* \brief Container for build configuration options
* \brief Build configuration for compilations.
*/
*/
class
BuildConfig
:
public
::
tvm
::
NodeRef
{
class
BuildConfig
:
public
::
tvm
::
NodeRef
{
public
:
public
:
BuildConfig
()
{}
BuildConfig
()
{}
explicit
BuildConfig
(
NodePtr
<::
tvm
::
Node
>
n
)
:
NodeRef
(
n
)
{}
explicit
BuildConfig
(
NodePtr
<::
tvm
::
Node
>
n
)
:
NodeRef
(
n
)
{}
const
BuildConfigNode
*
operator
->
()
const
{
const
BuildConfigNode
*
operator
->
()
const
{
return
static_cast
<
const
BuildConfigNode
*>
(
node_
.
get
());
return
static_cast
<
const
BuildConfigNode
*>
(
node_
.
get
());
}
}
BuildConfigNode
*
operator
->
()
{
BuildConfigNode
*
operator
->
()
{
return
static_cast
<
BuildConfigNode
*>
(
node_
.
get
());
return
static_cast
<
BuildConfigNode
*>
(
node_
.
get
());
}
}
/*!
/*!
* \brief
Push a new BuildConfig context onto the thread local stack
.
* \brief
Construct a BuildConfig containing a empty build config node
.
* \
param build_config The configuration to set as the current context.
* \
return The new BuildConfig
*/
*/
TVM_DLL
static
void
EnterBuildConfigScope
(
const
tvm
::
BuildConfig
&
build_config
);
TVM_DLL
static
BuildConfig
Create
();
/*!
* \brief Pop a build config off the thread local context stack, restoring the previous
* configuration as the current context.
*/
TVM_DLL
static
void
ExitBuildConfigScope
();
/*!
/*!
* \brief Get the current BuildConfig context from thread local storage, or a default
* \brief Get the current BuildConfig context from thread local storage, or a default
* configuration if a BuildConfig scope has not been entered.
* configuration if a BuildConfig scope has not been entered.
* \return The configuration that is the current context.
* \return The configuration that is the current context.
*/
*/
TVM_DLL
static
tvm
::
BuildConfig
Current
();
TVM_DLL
static
BuildConfig
Current
();
using
ContainerType
=
BuildConfigNode
;
using
ContainerType
=
BuildConfigNode
;
}
;
class
Internal
;
/*!
private
:
* \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the
// Enable with syntax.
* context stack when constructed, and pops it when destructed.
friend
class
With
<
BuildConfig
>
;
*/
struct
BuildConfigContext
{
/*!
/*!
* \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current
* \brief Push a new BuildConfig context onto the thread local stack.
* context. When the BuildConfigContext is destructed, the previous context is restored.
* \param build_config The BuildConfig to set as the new current context.
*/
*/
explicit
BuildConfigContext
(
const
tvm
::
BuildConfig
&
build_config
)
{
TVM_DLL
void
EnterWithScope
();
BuildConfig
::
EnterBuildConfigScope
(
build_config
);
}
/*! \brief Destructor. Pops the context off the thread local stack. */
/*!
~
BuildConfigContext
()
{
* \brief Pop a build config off the thread local context stack,
BuildConfig
::
ExitBuildConfigScope
();
* restoring the previous configuration as the current context.
}
*/
TVM_DLL
void
ExitWithScope
();
};
};
/*!
/*!
* \brief Construct a BuildConfig containing a new BuildConfigNode
* \return The new BuildConfig
*/
TVM_DLL
BuildConfig
build_config
();
/*!
* \brief Build a LoweredFunc given a schedule, args and binds
* \brief Build a LoweredFunc given a schedule, args and binds
* \param sch The schedule to lower.
* \param sch The schedule to lower.
* \param args The arguments to the function.
* \param args The arguments to the function.
...
...
python/tvm/build_module.py
View file @
415a270d
...
@@ -187,7 +187,7 @@ class BuildConfig(NodeBase):
...
@@ -187,7 +187,7 @@ class BuildConfig(NodeBase):
def
__exit__
(
self
,
ptype
,
value
,
trace
):
def
__exit__
(
self
,
ptype
,
value
,
trace
):
if
self
.
dump_pass_ir
:
if
self
.
dump_pass_ir
:
BuildConfig
.
_dump_ir
.
exit
()
BuildConfig
.
_dump_ir
.
exit
()
_api_internal
.
_ExitBuildConfigScope
()
_api_internal
.
_ExitBuildConfigScope
(
self
)
def
__setattr__
(
self
,
name
,
value
):
def
__setattr__
(
self
,
name
,
value
):
if
name
in
BuildConfig
.
_node_defaults
:
if
name
in
BuildConfig
.
_node_defaults
:
...
...
python/tvm/target.py
View file @
415a270d
...
@@ -133,7 +133,7 @@ class Target(NodeBase):
...
@@ -133,7 +133,7 @@ class Target(NodeBase):
return
self
return
self
def
__exit__
(
self
,
ptype
,
value
,
trace
):
def
__exit__
(
self
,
ptype
,
value
,
trace
):
_api_internal
.
_ExitTargetScope
()
_api_internal
.
_ExitTargetScope
(
self
)
@register_node
@register_node
...
...
src/api/api_arith.cc
View file @
415a270d
...
@@ -123,8 +123,8 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
...
@@ -123,8 +123,8 @@ TVM_REGISTER_API("arith._CreateAnalyzer")
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
return
PackedFunc
([
self
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
// can't use make_shared due to noexcept(false) decl in destructor,
// can't use make_shared due to noexcept(false) decl in destructor,
// see https://stackoverflow.com/a/43907314
// see https://stackoverflow.com/a/43907314
auto
ctx
=
auto
ctx
=
std
::
shared_ptr
<
With
<
ConstraintContext
>
>
(
std
::
shared_ptr
<
ConstraintContext
>
(
new
ConstraintContext
(
self
.
get
(),
args
[
0
]));
new
With
<
ConstraintContext
>
(
self
.
get
(),
args
[
0
]));
auto
fexit
=
[
ctx
](
TVMArgs
,
TVMRetValue
*
)
mutable
{
auto
fexit
=
[
ctx
](
TVMArgs
,
TVMRetValue
*
)
mutable
{
ctx
.
reset
();
ctx
.
reset
();
};
};
...
...
src/arithmetic/analyzer.cc
View file @
415a270d
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
@@ -54,10 +54,12 @@ void Analyzer::Bind(const VarExpr& v, const Range& range) {
...
@@ -54,10 +54,12 @@ void Analyzer::Bind(const VarExpr& v, const Range& range) {
// skip rewrite simplify
// skip rewrite simplify
}
}
ConstraintContext
::
ConstraintContext
(
Analyzer
*
analyzer
,
const
Expr
&
constraint
)
{
void
ConstraintContext
::
EnterWithScope
()
{
CHECK
(
exit_
==
nullptr
);
// entering the scope.
// entering the scope.
auto
f0
=
analyzer
->
const_int_bound
.
EnterConstraint
(
constraint
);
auto
f0
=
analyzer
_
->
const_int_bound
.
EnterConstraint
(
constraint_
);
auto
f1
=
analyzer
->
modular_set
.
EnterConstraint
(
constraint
);
auto
f1
=
analyzer
_
->
modular_set
.
EnterConstraint
(
constraint_
);
// recovery function.
// recovery function.
exit_
=
[
f0
,
f1
]()
{
exit_
=
[
f0
,
f1
]()
{
if
(
f1
!=
nullptr
)
f1
();
if
(
f1
!=
nullptr
)
f1
();
...
@@ -65,6 +67,11 @@ ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint)
...
@@ -65,6 +67,11 @@ ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint)
};
};
}
}
void
ConstraintContext
::
ExitWithScope
()
{
CHECK
(
exit_
!=
nullptr
);
exit_
();
}
bool
Analyzer
::
CanProveGreaterEqual
(
const
Expr
&
expr
,
int64_t
lower_bound
)
{
bool
Analyzer
::
CanProveGreaterEqual
(
const
Expr
&
expr
,
int64_t
lower_bound
)
{
if
(
const
auto
*
ptr
=
expr
.
as
<
ir
::
IntImm
>
())
{
if
(
const
auto
*
ptr
=
expr
.
as
<
ir
::
IntImm
>
())
{
return
ptr
->
value
>
lower_bound
;
return
ptr
->
value
>
lower_bound
;
...
...
src/arithmetic/rewrite_simplify.cc
View file @
415a270d
...
@@ -1200,11 +1200,11 @@ Mutate_(const Select* op, const Expr& self) {
...
@@ -1200,11 +1200,11 @@ Mutate_(const Select* op, const Expr& self) {
Expr
cond
=
Mutate
(
op
->
condition
);
Expr
cond
=
Mutate
(
op
->
condition
);
Expr
true_value
,
false_value
;
Expr
true_value
,
false_value
;
{
{
ConstraintContext
constraint
(
parent_
,
cond
);
With
<
ConstraintContext
>
constraint
(
parent_
,
cond
);
true_value
=
Mutate
(
op
->
true_value
);
true_value
=
Mutate
(
op
->
true_value
);
}
}
{
{
ConstraintContext
constraint
(
parent_
,
Mutate
(
Not
::
make
(
cond
)));
With
<
ConstraintContext
>
constraint
(
parent_
,
Mutate
(
Not
::
make
(
cond
)));
false_value
=
Mutate
(
op
->
false_value
);
false_value
=
Mutate
(
op
->
false_value
);
}
}
if
(
is_zero
(
cond
))
{
if
(
is_zero
(
cond
))
{
...
@@ -1237,11 +1237,11 @@ Mutate_(const Call* op, const Expr& self) {
...
@@ -1237,11 +1237,11 @@ Mutate_(const Call* op, const Expr& self) {
Expr
cond
=
Mutate
(
op
->
args
[
0
]);
Expr
cond
=
Mutate
(
op
->
args
[
0
]);
Expr
true_value
,
false_value
;
Expr
true_value
,
false_value
;
{
{
ConstraintContext
constraint
(
parent_
,
cond
);
With
<
ConstraintContext
>
constraint
(
parent_
,
cond
);
true_value
=
Mutate
(
op
->
args
[
1
]);
true_value
=
Mutate
(
op
->
args
[
1
]);
}
}
{
{
ConstraintContext
constraint
(
parent_
,
Mutate
(
Not
::
make
(
cond
)));
With
<
ConstraintContext
>
constraint
(
parent_
,
Mutate
(
Not
::
make
(
cond
)));
false_value
=
Mutate
(
op
->
args
[
2
]);
false_value
=
Mutate
(
op
->
args
[
2
]);
}
}
if
(
is_zero
(
cond
))
{
if
(
is_zero
(
cond
))
{
...
...
src/arithmetic/stmt_simplify.cc
View file @
415a270d
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
@@ -48,11 +48,11 @@ class StmtSimplifier : public IRMutator {
...
@@ -48,11 +48,11 @@ class StmtSimplifier : public IRMutator {
Expr
condition
=
this
->
Mutate
(
op
->
condition
);
Expr
condition
=
this
->
Mutate
(
op
->
condition
);
Stmt
then_case
,
else_case
;
Stmt
then_case
,
else_case
;
{
{
ConstraintContext
ctx
(
&
analyzer_
,
condition
);
With
<
ConstraintContext
>
ctx
(
&
analyzer_
,
condition
);
then_case
=
this
->
Mutate
(
op
->
then_case
);
then_case
=
this
->
Mutate
(
op
->
then_case
);
}
}
if
(
op
->
else_case
.
defined
())
{
if
(
op
->
else_case
.
defined
())
{
ConstraintContext
ctx
(
&
analyzer_
,
Mutate
(
Not
::
make
(
condition
)));
With
<
ConstraintContext
>
ctx
(
&
analyzer_
,
Mutate
(
Not
::
make
(
condition
)));
else_case
=
this
->
Mutate
(
op
->
else_case
);
else_case
=
this
->
Mutate
(
op
->
else_case
);
}
}
if
(
is_one
(
condition
))
return
then_case
;
if
(
is_one
(
condition
))
return
then_case
;
...
@@ -94,7 +94,7 @@ class StmtSimplifier : public IRMutator {
...
@@ -94,7 +94,7 @@ class StmtSimplifier : public IRMutator {
Stmt
Mutate_
(
const
AssertStmt
*
op
,
const
Stmt
&
s
)
final
{
Stmt
Mutate_
(
const
AssertStmt
*
op
,
const
Stmt
&
s
)
final
{
Expr
condition
=
this
->
Mutate
(
op
->
condition
);
Expr
condition
=
this
->
Mutate
(
op
->
condition
);
Expr
message
=
this
->
Mutate
(
op
->
message
);
Expr
message
=
this
->
Mutate
(
op
->
message
);
ConstraintContext
ctx
(
&
analyzer_
,
condition
);
With
<
ConstraintContext
>
ctx
(
&
analyzer_
,
condition
);
Stmt
body
=
this
->
Mutate
(
op
->
body
);
Stmt
body
=
this
->
Mutate
(
op
->
body
);
if
(
condition
.
same_as
(
op
->
condition
)
&&
if
(
condition
.
same_as
(
op
->
condition
)
&&
...
...
src/codegen/build_module.cc
View file @
415a270d
...
@@ -18,7 +18,6 @@
...
@@ -18,7 +18,6 @@
*/
*/
/*!
/*!
* Copyright (c) 2017 by Contributors
* Compile executable modules.
* Compile executable modules.
* \file build_module.cc
* \file build_module.cc
*/
*/
...
@@ -148,8 +147,7 @@ TVM_REGISTER_API("_TargetCreate")
...
@@ -148,8 +147,7 @@ TVM_REGISTER_API("_TargetCreate")
TVM_REGISTER_API
(
"_TargetFromString"
)
TVM_REGISTER_API
(
"_TargetFromString"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
std
::
string
target_str
=
args
[
0
];
std
::
string
target_str
=
args
[
0
];
*
ret
=
Target
::
Create
(
target_str
);
*
ret
=
Target
::
create
(
target_str
);
});
});
std
::
vector
<
std
::
string
>
TargetNode
::
keys
()
const
{
std
::
vector
<
std
::
string
>
TargetNode
::
keys
()
const
{
...
@@ -207,7 +205,7 @@ std::string GetDeviceName(const std::string& target_str) {
...
@@ -207,7 +205,7 @@ std::string GetDeviceName(const std::string& target_str) {
return
""
;
return
""
;
}
}
Target
Target
::
c
reate
(
const
std
::
string
&
target_str
)
{
Target
Target
::
C
reate
(
const
std
::
string
&
target_str
)
{
if
(
target_str
.
length
()
==
0
)
{
if
(
target_str
.
length
()
==
0
)
{
LOG
(
ERROR
)
<<
"target_str must not be empty"
;
LOG
(
ERROR
)
<<
"target_str must not be empty"
;
}
}
...
@@ -231,25 +229,24 @@ Target Target::create(const std::string& target_str) {
...
@@ -231,25 +229,24 @@ Target Target::create(const std::string& target_str) {
struct
TVMTargetThreadLocalEntry
{
struct
TVMTargetThreadLocalEntry
{
/*! \brief The current target context */
/*! \brief The current target context */
std
::
stack
<
tvm
::
Target
>
context_stack
;
std
::
stack
<
tvm
::
Target
>
context_stack
;
TVMTargetThreadLocalEntry
()
{
}
};
};
/*! \brief Thread local store to hold the Target context stack. */
/*! \brief Thread local store to hold the Target context stack. */
typedef
dmlc
::
ThreadLocalStore
<
TVMTargetThreadLocalEntry
>
TVMTargetThreadLocalStore
;
typedef
dmlc
::
ThreadLocalStore
<
TVMTargetThreadLocalEntry
>
TVMTargetThreadLocalStore
;
void
Target
::
Enter
TargetScope
(
const
tvm
::
Target
&
target
)
{
void
Target
::
Enter
WithScope
(
)
{
TVMTargetThreadLocalEntry
*
entry
=
TVMTargetThreadLocalStore
::
Get
();
TVMTargetThreadLocalEntry
*
entry
=
TVMTargetThreadLocalStore
::
Get
();
entry
->
context_stack
.
push
(
target
);
entry
->
context_stack
.
push
(
*
this
);
}
}
void
Target
::
Exit
Target
Scope
()
{
void
Target
::
Exit
With
Scope
()
{
TVMTargetThreadLocalEntry
*
entry
=
TVMTargetThreadLocalStore
::
Get
();
TVMTargetThreadLocalEntry
*
entry
=
TVMTargetThreadLocalStore
::
Get
();
CHECK
(
!
entry
->
context_stack
.
empty
());
CHECK
(
entry
->
context_stack
.
top
().
same_as
(
*
this
));
entry
->
context_stack
.
pop
();
entry
->
context_stack
.
pop
();
}
}
tvm
::
Target
Target
::
current_targe
t
(
bool
allow_not_defined
)
{
tvm
::
Target
Target
::
Curren
t
(
bool
allow_not_defined
)
{
TVMTargetThreadLocalEntry
*
entry
=
TVMTargetThreadLocalStore
::
Get
();
TVMTargetThreadLocalEntry
*
entry
=
TVMTargetThreadLocalStore
::
Get
();
if
(
entry
->
context_stack
.
size
()
>
0
)
{
if
(
entry
->
context_stack
.
size
()
>
0
)
{
return
entry
->
context_stack
.
top
();
return
entry
->
context_stack
.
top
();
...
@@ -574,7 +571,7 @@ runtime::Module build(const Map<std::string, Array<LoweredFunc>>& inputs,
...
@@ -574,7 +571,7 @@ runtime::Module build(const Map<std::string, Array<LoweredFunc>>& inputs,
const
BuildConfig
&
config
)
{
const
BuildConfig
&
config
)
{
Map
<
Target
,
Array
<
LoweredFunc
>>
updated_input
;
Map
<
Target
,
Array
<
LoweredFunc
>>
updated_input
;
for
(
const
auto
&
it
:
inputs
)
{
for
(
const
auto
&
it
:
inputs
)
{
auto
target
=
Target
::
c
reate
(
it
.
first
);
auto
target
=
Target
::
C
reate
(
it
.
first
);
updated_input
.
Set
(
target
,
it
.
second
);
updated_input
.
Set
(
target
,
it
.
second
);
}
}
return
build
(
updated_input
,
target_host
,
config
);
return
build
(
updated_input
,
target_host
,
config
);
...
@@ -589,33 +586,35 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
...
@@ -589,33 +586,35 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
return
build
(
inputs
,
target_host
,
config
);
return
build
(
inputs
,
target_host
,
config
);
}
}
BuildConfig
build_config
()
{
BuildConfig
BuildConfig
::
Create
()
{
return
BuildConfig
(
make_node
<
BuildConfigNode
>
());
return
BuildConfig
(
make_node
<
BuildConfigNode
>
());
}
}
/*! \brief Entry to hold the BuildConfig context stack. */
/*! \brief Entry to hold the BuildConfig context stack. */
struct
TVMBuildConfigThreadLocalEntry
{
struct
TVMBuildConfigThreadLocalEntry
{
/*! \brief The default build config if the stack is empty */
/*! \brief The default build config if the stack is empty */
tvm
::
BuildConfig
default_config
;
BuildConfig
default_config
;
/*! \brief The current build config context */
/*! \brief The current build config context */
std
::
stack
<
tvm
::
BuildConfig
>
context_stack
;
std
::
stack
<
BuildConfig
>
context_stack
;
TVMBuildConfigThreadLocalEntry
()
:
TVMBuildConfigThreadLocalEntry
()
:
default_config
(
build_config
())
{
default_config
(
BuildConfig
::
Create
())
{
}
}
};
};
/*! \brief Thread local store to hold the BuildConfig context stack. */
/*! \brief Thread local store to hold the BuildConfig context stack. */
typedef
dmlc
::
ThreadLocalStore
<
TVMBuildConfigThreadLocalEntry
>
TVMBuildConfigThreadLocalStore
;
typedef
dmlc
::
ThreadLocalStore
<
TVMBuildConfigThreadLocalEntry
>
TVMBuildConfigThreadLocalStore
;
void
BuildConfig
::
Enter
BuildConfigScope
(
const
tvm
::
BuildConfig
&
build_config
)
{
void
BuildConfig
::
Enter
WithScope
(
)
{
TVMBuildConfigThreadLocalEntry
*
entry
=
TVMBuildConfigThreadLocalStore
::
Get
();
TVMBuildConfigThreadLocalEntry
*
entry
=
TVMBuildConfigThreadLocalStore
::
Get
();
entry
->
context_stack
.
push
(
build_config
);
entry
->
context_stack
.
push
(
*
this
);
}
}
void
BuildConfig
::
Exit
BuildConfig
Scope
()
{
void
BuildConfig
::
Exit
With
Scope
()
{
TVMBuildConfigThreadLocalEntry
*
entry
=
TVMBuildConfigThreadLocalStore
::
Get
();
TVMBuildConfigThreadLocalEntry
*
entry
=
TVMBuildConfigThreadLocalStore
::
Get
();
CHECK
(
!
entry
->
context_stack
.
empty
());
CHECK
(
entry
->
context_stack
.
top
().
same_as
(
*
this
));
entry
->
context_stack
.
pop
();
entry
->
context_stack
.
pop
();
}
}
...
@@ -714,7 +713,7 @@ GenericFunc& GenericFunc::register_func(const std::vector<std::string>& tags,
...
@@ -714,7 +713,7 @@ GenericFunc& GenericFunc::register_func(const std::vector<std::string>& tags,
void
GenericFunc
::
CallPacked
(
TVMArgs
args
,
TVMRetValue
*
ret
)
const
{
void
GenericFunc
::
CallPacked
(
TVMArgs
args
,
TVMRetValue
*
ret
)
const
{
auto
node
=
static_cast
<
GenericFuncNode
*>
(
node_
.
get
());
auto
node
=
static_cast
<
GenericFuncNode
*>
(
node_
.
get
());
auto
target
=
Target
::
current_targe
t
(
true
);
auto
target
=
Target
::
Curren
t
(
true
);
PackedFunc
func
;
PackedFunc
func
;
if
(
target
.
defined
())
{
if
(
target
.
defined
())
{
...
@@ -740,16 +739,21 @@ TVM_REGISTER_API("_GetCurrentBuildConfig")
...
@@ -740,16 +739,21 @@ TVM_REGISTER_API("_GetCurrentBuildConfig")
*
ret
=
BuildConfig
::
Current
();
*
ret
=
BuildConfig
::
Current
();
});
});
class
BuildConfig
::
Internal
{
public
:
static
void
EnterScope
(
BuildConfig
target
)
{
target
.
EnterWithScope
();
}
static
void
ExitScope
(
BuildConfig
target
)
{
target
.
ExitWithScope
();
}
};
TVM_REGISTER_API
(
"_EnterBuildConfigScope"
)
TVM_REGISTER_API
(
"_EnterBuildConfigScope"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body_typed
(
BuildConfig
::
Internal
::
EnterScope
);
BuildConfig
target
=
args
[
0
];
BuildConfig
::
EnterBuildConfigScope
(
target
);
});
TVM_REGISTER_API
(
"_ExitBuildConfigScope"
)
TVM_REGISTER_API
(
"_ExitBuildConfigScope"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body_typed
(
BuildConfig
::
Internal
::
ExitScope
);
BuildConfig
::
ExitBuildConfigScope
();
});
TVM_REGISTER_API
(
"_BuildConfigSetAddLowerPass"
)
TVM_REGISTER_API
(
"_BuildConfigSetAddLowerPass"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
...
@@ -836,18 +840,23 @@ TVM_REGISTER_API("_GenericFuncCallFunc")
...
@@ -836,18 +840,23 @@ TVM_REGISTER_API("_GenericFuncCallFunc")
TVM_REGISTER_API
(
"_GetCurrentTarget"
)
TVM_REGISTER_API
(
"_GetCurrentTarget"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
bool
allow_not_defined
=
args
[
0
];
bool
allow_not_defined
=
args
[
0
];
*
ret
=
Target
::
current_targe
t
(
allow_not_defined
);
*
ret
=
Target
::
Curren
t
(
allow_not_defined
);
});
});
class
Target
::
Internal
{
public
:
static
void
EnterScope
(
Target
target
)
{
target
.
EnterWithScope
();
}
static
void
ExitScope
(
Target
target
)
{
target
.
ExitWithScope
();
}
};
TVM_REGISTER_API
(
"_EnterTargetScope"
)
TVM_REGISTER_API
(
"_EnterTargetScope"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body_typed
(
Target
::
Internal
::
EnterScope
);
Target
target
=
args
[
0
];
Target
::
EnterTargetScope
(
target
);
});
TVM_REGISTER_API
(
"_ExitTargetScope"
)
TVM_REGISTER_API
(
"_ExitTargetScope"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
.
set_body_typed
(
Target
::
Internal
::
ExitScope
);
Target
::
ExitTargetScope
();
});
}
// namespace tvm
}
// namespace tvm
src/codegen/codegen_aocl.cc
View file @
415a270d
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
@@ -54,7 +54,7 @@ runtime::Module BuildAOCL(Array<LoweredFunc> funcs, std::string target_str,
...
@@ -54,7 +54,7 @@ runtime::Module BuildAOCL(Array<LoweredFunc> funcs, std::string target_str,
std
::
string
cmd
=
"aoc aocl.cl"
;
std
::
string
cmd
=
"aoc aocl.cl"
;
// AOCL supports fp64.
// AOCL supports fp64.
cmd
+=
" -Dcl_khr_fp64"
;
cmd
+=
" -Dcl_khr_fp64"
;
Target
target
=
Target
::
c
reate
(
target_str
);
Target
target
=
Target
::
C
reate
(
target_str
);
if
(
target
->
device_name
!=
""
)
{
if
(
target
->
device_name
!=
""
)
{
cmd
+=
" -board="
+
target
->
device_name
;
cmd
+=
" -board="
+
target
->
device_name
;
}
}
...
...
src/codegen/codegen_vhls.cc
View file @
415a270d
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
@@ -155,7 +155,7 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) {
...
@@ -155,7 +155,7 @@ runtime::Module BuildSDAccel(Array<LoweredFunc> funcs, std::string target_str) {
std
::
string
xclbin
;
std
::
string
xclbin
;
if
(
const
auto
*
f
=
Registry
::
Get
(
"tvm_callback_sdaccel_compile"
))
{
if
(
const
auto
*
f
=
Registry
::
Get
(
"tvm_callback_sdaccel_compile"
))
{
Target
target
=
Target
::
c
reate
(
target_str
);
Target
target
=
Target
::
C
reate
(
target_str
);
xclbin
=
(
*
f
)(
kernel_info
,
target
->
device_name
).
operator
std
::
string
();
xclbin
=
(
*
f
)(
kernel_info
,
target
->
device_name
).
operator
std
::
string
();
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"Cannot compile Vivado HLS code."
;
LOG
(
FATAL
)
<<
"Cannot compile Vivado HLS code."
;
...
...
src/codegen/llvm/codegen_llvm.cc
View file @
415a270d
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
@@ -1142,7 +1142,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
...
@@ -1142,7 +1142,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) {
}
}
void
CodeGenLLVM
::
VisitStmt_
(
const
AssertStmt
*
op
)
{
void
CodeGenLLVM
::
VisitStmt_
(
const
AssertStmt
*
op
)
{
arith
::
ConstraintContext
cctx
(
analyzer_
.
get
(),
op
->
condition
);
With
<
arith
::
ConstraintContext
>
cctx
(
analyzer_
.
get
(),
op
->
condition
);
this
->
VisitStmt
(
op
->
body
);
this
->
VisitStmt
(
op
->
body
);
}
}
...
...
src/codegen/spirv/codegen_spirv.cc
View file @
415a270d
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
@@ -626,7 +626,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) {
...
@@ -626,7 +626,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) {
}
}
void
CodeGenSPIRV
::
VisitStmt_
(
const
AssertStmt
*
op
)
{
void
CodeGenSPIRV
::
VisitStmt_
(
const
AssertStmt
*
op
)
{
arith
::
ConstraintContext
cctx
(
analyzer_
.
get
(),
op
->
condition
);
With
<
arith
::
ConstraintContext
>
cctx
(
analyzer_
.
get
(),
op
->
condition
);
this
->
VisitStmt
(
op
->
body
);
this
->
VisitStmt
(
op
->
body
);
}
}
...
...
src/relay/backend/build_module.cc
View file @
415a270d
...
@@ -445,7 +445,7 @@ class RelayBuildModule : public runtime::ModuleNode {
...
@@ -445,7 +445,7 @@ class RelayBuildModule : public runtime::ModuleNode {
if
(
targets
.
size
()
==
1
)
{
if
(
targets
.
size
()
==
1
)
{
func
=
CallPackedFunc
(
"relay._ir_pass.infer_type"
,
func
,
nullptr
);
func
=
CallPackedFunc
(
"relay._ir_pass.infer_type"
,
func
,
nullptr
);
for
(
const
auto
&
kv
:
targets
)
{
for
(
const
auto
&
kv
:
targets
)
{
TargetContext
tctx
(
kv
.
second
);
With
<
Target
>
tctx
(
kv
.
second
);
func
=
CallPackedFunc
(
"relay._ir_pass.AlterOpLayout"
,
func
);
func
=
CallPackedFunc
(
"relay._ir_pass.AlterOpLayout"
,
func
);
}
}
}
else
{
}
else
{
...
@@ -466,9 +466,9 @@ class RelayBuildModule : public runtime::ModuleNode {
...
@@ -466,9 +466,9 @@ class RelayBuildModule : public runtime::ModuleNode {
*/
*/
Target
CreateDefaultTarget
(
int
device_type
)
{
Target
CreateDefaultTarget
(
int
device_type
)
{
std
::
string
name
=
runtime
::
DeviceName
(
device_type
);
std
::
string
name
=
runtime
::
DeviceName
(
device_type
);
if
(
name
==
"cpu"
)
return
Target
::
c
reate
(
"llvm"
);
if
(
name
==
"cpu"
)
return
Target
::
C
reate
(
"llvm"
);
if
(
name
==
"gpu"
)
return
Target
::
c
reate
(
"cuda"
);
if
(
name
==
"gpu"
)
return
Target
::
C
reate
(
"cuda"
);
return
Target
::
c
reate
(
name
);
return
Target
::
C
reate
(
name
);
}
}
/*!
/*!
* \brief Update the target and fallback device required for heterogeneous
* \brief Update the target and fallback device required for heterogeneous
...
@@ -548,7 +548,7 @@ class RelayBuildModule : public runtime::ModuleNode {
...
@@ -548,7 +548,7 @@ class RelayBuildModule : public runtime::ModuleNode {
const
RelayBuildConfig
&
cfg
,
const
RelayBuildConfig
&
cfg
,
const
std
::
unordered_map
<
std
::
string
,
tvm
::
runtime
::
NDArray
>
&
params
)
{
const
std
::
unordered_map
<
std
::
string
,
tvm
::
runtime
::
NDArray
>
&
params
)
{
// convert
// convert
tvm_cfg_
=
build_config
();
tvm_cfg_
=
BuildConfig
::
Create
();
TargetsMap
device_target
;
TargetsMap
device_target
;
if
(
targets_
.
size
()
>
1
)
{
if
(
targets_
.
size
()
>
1
)
{
device_target
=
UpdateHeterogeneousInputs
(
targets_
,
cfg
);
device_target
=
UpdateHeterogeneousInputs
(
targets_
,
cfg
);
...
...
src/relay/backend/compile_engine.cc
View file @
415a270d
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
@@ -344,7 +344,7 @@ class CompileEngineImpl : public CompileEngineNode {
...
@@ -344,7 +344,7 @@ class CompileEngineImpl : public CompileEngineNode {
cache_
[
key
]
=
value
;
cache_
[
key
]
=
value
;
}
}
// Enforce use the target.
// Enforce use the target.
TargetContext
target_ctx
(
key
->
target
);
With
<
Target
>
target_scope
(
key
->
target
);
CHECK
(
!
value
->
cached_func
.
defined
());
CHECK
(
!
value
->
cached_func
.
defined
());
auto
spair
=
CreateSchedule
(
key
->
source_func
,
key
->
target
);
auto
spair
=
CreateSchedule
(
key
->
source_func
,
key
->
target
);
...
@@ -371,7 +371,7 @@ class CompileEngineImpl : public CompileEngineNode {
...
@@ -371,7 +371,7 @@ class CompileEngineImpl : public CompileEngineNode {
cache_node
->
funcs
=
(
*
f
)(
cache_node
->
funcs
=
(
*
f
)(
spair
.
first
,
all_args
,
cache_node
->
func_name
,
key
->
source_func
);
spair
.
first
,
all_args
,
cache_node
->
func_name
,
key
->
source_func
);
}
else
{
}
else
{
tvm
::
BuildConfig
bcfg
=
tvm
::
build_config
();
tvm
::
BuildConfig
bcfg
=
BuildConfig
::
Create
();
std
::
unordered_map
<
Tensor
,
Buffer
>
binds
;
std
::
unordered_map
<
Tensor
,
Buffer
>
binds
;
cache_node
->
funcs
=
tvm
::
lower
(
spair
.
first
,
all_args
,
cache_node
->
func_name
,
binds
,
bcfg
);
cache_node
->
funcs
=
tvm
::
lower
(
spair
.
first
,
all_args
,
cache_node
->
func_name
,
binds
,
bcfg
);
}
}
...
...
src/relay/backend/vm/compiler.cc
View file @
415a270d
...
@@ -364,7 +364,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
...
@@ -364,7 +364,7 @@ struct VMCompiler : ExprFunctor<void(const Expr& expr)> {
// Next generate the invoke instruction.
// Next generate the invoke instruction.
CHECK
(
func
->
IsPrimitive
());
CHECK
(
func
->
IsPrimitive
());
auto
target
=
Target
::
c
reate
(
"llvm"
);
auto
target
=
Target
::
C
reate
(
"llvm"
);
auto
key
=
CCacheKeyNode
::
make
(
func
,
target
);
auto
key
=
CCacheKeyNode
::
make
(
func
,
target
);
auto
cfunc
=
engine
->
Lower
(
key
);
auto
cfunc
=
engine
->
Lower
(
key
);
// TODO(jroesch): support lowered funcs for multiple targets
// TODO(jroesch): support lowered funcs for multiple targets
...
@@ -502,7 +502,7 @@ void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs,
...
@@ -502,7 +502,7 @@ void PopulatePackedFuncMap(const std::vector<LoweredFunc>& lowered_funcs,
runtime
::
Module
mod
;
runtime
::
Module
mod
;
if
(
lowered_funcs
.
size
()
>
0
)
{
if
(
lowered_funcs
.
size
()
>
0
)
{
// TODO(@jroesch): we need to read target from build config
// TODO(@jroesch): we need to read target from build config
Target
target
=
Target
::
c
reate
(
"llvm"
);
Target
target
=
Target
::
C
reate
(
"llvm"
);
if
(
const
auto
*
f
=
runtime
::
Registry
::
Get
(
"relay.backend.build"
))
{
if
(
const
auto
*
f
=
runtime
::
Registry
::
Get
(
"relay.backend.build"
))
{
mod
=
(
*
f
)(
tvm
::
Array
<
LoweredFunc
>
(
lowered_funcs
.
begin
(),
lowered_funcs
.
end
()),
target
);
mod
=
(
*
f
)(
tvm
::
Array
<
LoweredFunc
>
(
lowered_funcs
.
begin
(),
lowered_funcs
.
end
()),
target
);
}
else
{
}
else
{
...
...
src/relay/pass/fold_constant.cc
View file @
415a270d
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
@@ -203,10 +203,10 @@ Expr FoldConstant(const Expr& expr) {
...
@@ -203,10 +203,10 @@ Expr FoldConstant(const Expr& expr) {
DLContext
ctx
;
DLContext
ctx
;
ctx
.
device_type
=
kDLCPU
;
ctx
.
device_type
=
kDLCPU
;
ctx
.
device_id
=
0
;
ctx
.
device_id
=
0
;
Target
target
=
Target
::
c
reate
(
"llvm"
);
Target
target
=
Target
::
C
reate
(
"llvm"
);
// use a fresh build context
// use a fresh build context
// in case we are already in a build context.
// in case we are already in a build context.
BuildConfigContext
fresh_build_ctx
(
build_config
());
With
<
BuildConfig
>
fresh_build_ctx
(
BuildConfig
::
Create
());
return
ConstantFolder
(
CreateInterpreter
(
return
ConstantFolder
(
CreateInterpreter
(
Module
(
nullptr
),
ctx
,
target
)).
Mutate
(
expr
);
Module
(
nullptr
),
ctx
,
target
)).
Mutate
(
expr
);
...
...
src/relay/pass/partial_eval.cc
View file @
415a270d
...
@@ -375,10 +375,10 @@ DLContext CPUContext() {
...
@@ -375,10 +375,10 @@ DLContext CPUContext() {
}
}
FInterpreter
CPUInterpreter
()
{
FInterpreter
CPUInterpreter
()
{
Target
target
=
Target
::
c
reate
(
"llvm"
);
Target
target
=
Target
::
C
reate
(
"llvm"
);
// use a fresh build context
// use a fresh build context
// in case we are already in a build context.
// in case we are already in a build context.
BuildConfigContext
fresh_build_ctx
(
build_config
());
With
<
BuildConfig
>
fresh_build_ctx
(
BuildConfig
::
Create
());
return
CreateInterpreter
(
Module
(
nullptr
),
CPUContext
(),
target
);
return
CreateInterpreter
(
Module
(
nullptr
),
CPUContext
(),
target
);
}
}
...
...
tests/cpp/build_module_test.cc
View file @
415a270d
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
@@ -50,14 +50,14 @@ TEST(BuildModule, Basic) {
...
@@ -50,14 +50,14 @@ TEST(BuildModule, Basic) {
auto
args
=
Array
<
Tensor
>
({
A
,
B
,
C
});
auto
args
=
Array
<
Tensor
>
({
A
,
B
,
C
});
std
::
unordered_map
<
Tensor
,
Buffer
>
binds
;
std
::
unordered_map
<
Tensor
,
Buffer
>
binds
;
auto
config
=
build_config
();
auto
config
=
BuildConfig
::
Create
();
auto
target
=
target
::
llvm
();
auto
target
=
target
::
llvm
();
auto
lowered
=
lower
(
s
,
args
,
"func"
,
binds
,
config
);
auto
lowered
=
lower
(
s
,
args
,
"func"
,
binds
,
config
);
auto
module
=
build
(
lowered
,
target
,
Target
(),
config
);
auto
module
=
build
(
lowered
,
target
,
Target
(),
config
);
auto
mali_target
=
Target
::
c
reate
(
"opencl -model=Mali-T860MP4@800Mhz -device=mali"
);
auto
mali_target
=
Target
::
C
reate
(
"opencl -model=Mali-T860MP4@800Mhz -device=mali"
);
CHECK_EQ
(
mali_target
->
str
(),
"opencl -model=Mali-T860MP4@800Mhz -device=mali"
);
CHECK_EQ
(
mali_target
->
str
(),
"opencl -model=Mali-T860MP4@800Mhz -device=mali"
);
}
}
TEST
(
BuildModule
,
Heterogeneous
)
{
TEST
(
BuildModule
,
Heterogeneous
)
{
...
@@ -105,7 +105,7 @@ TEST(BuildModule, Heterogeneous) {
...
@@ -105,7 +105,7 @@ TEST(BuildModule, Heterogeneous) {
auto
s1
=
topi
::
cuda
::
schedule_injective
(
target_cuda
,
{
elemwise_add
});
auto
s1
=
topi
::
cuda
::
schedule_injective
(
target_cuda
,
{
elemwise_add
});
auto
s2
=
create_schedule
({
elemwise_sub
->
op
});
auto
s2
=
create_schedule
({
elemwise_sub
->
op
});
auto
config
=
build_config
();
auto
config
=
BuildConfig
::
Create
();
auto
args1
=
Array
<
Tensor
>
({
A
,
B
,
elemwise_add
});
auto
args1
=
Array
<
Tensor
>
({
A
,
B
,
elemwise_add
});
auto
args2
=
Array
<
Tensor
>
({
copy
,
C
,
elemwise_sub
});
auto
args2
=
Array
<
Tensor
>
({
copy
,
C
,
elemwise_sub
});
...
...
tests/cpp/relay_build_module_test.cc
View file @
415a270d
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...
@@ -75,7 +75,7 @@ TEST(Relay, BuildModule) {
...
@@ -75,7 +75,7 @@ TEST(Relay, BuildModule) {
auto
json_f
=
build_mod
.
GetFunction
(
"get_graph_json"
,
false
);
auto
json_f
=
build_mod
.
GetFunction
(
"get_graph_json"
,
false
);
auto
mod_f
=
build_mod
.
GetFunction
(
"get_module"
,
false
);
auto
mod_f
=
build_mod
.
GetFunction
(
"get_module"
,
false
);
Map
<
tvm
::
Integer
,
tvm
::
Target
>
targets
;
Map
<
tvm
::
Integer
,
tvm
::
Target
>
targets
;
Target
llvm_tgt
=
Target
::
c
reate
(
"llvm"
);
Target
llvm_tgt
=
Target
::
C
reate
(
"llvm"
);
targets
.
Set
(
0
,
llvm_tgt
);
targets
.
Set
(
0
,
llvm_tgt
);
build_f
(
func
,
targets
,
llvm_tgt
);
build_f
(
func
,
targets
,
llvm_tgt
);
std
::
string
json
=
json_f
();
std
::
string
json
=
json_f
();
...
...
topi/src/topi.cc
View file @
415a270d
...
@@ -94,7 +94,7 @@ inline bool IsTensorType(TVMArgValue arg) {
...
@@ -94,7 +94,7 @@ inline bool IsTensorType(TVMArgValue arg) {
TVM_REGISTER_GLOBAL
(
"topi.TEST_create_target"
)
TVM_REGISTER_GLOBAL
(
"topi.TEST_create_target"
)
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
.
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
*
rv
=
tvm
::
Target
::
c
reate
(
args
[
0
]);
*
rv
=
tvm
::
Target
::
C
reate
(
args
[
0
]);
});
});
/* Ops from broadcast.h */
/* Ops from broadcast.h */
...
@@ -640,7 +640,7 @@ using FTVMScheduleBuilder = std::function<
...
@@ -640,7 +640,7 @@ using FTVMScheduleBuilder = std::function<
*/
*/
inline
PackedFunc
WrapSchedule
(
FTVMScheduleBuilder
builder
)
{
inline
PackedFunc
WrapSchedule
(
FTVMScheduleBuilder
builder
)
{
return
PackedFunc
([
builder
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
return
PackedFunc
([
builder
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
target
=
Target
::
current_targe
t
(
false
);
auto
target
=
Target
::
Curren
t
(
false
);
Array
<
Tensor
>
outs
;
Array
<
Tensor
>
outs
;
NodeRef
argNodeRef
=
args
[
0
];
NodeRef
argNodeRef
=
args
[
0
];
if
(
argNodeRef
->
type_index
()
==
outs
->
type_index
())
{
if
(
argNodeRef
->
type_index
()
==
outs
->
type_index
())
{
...
@@ -712,7 +712,7 @@ using FTVMDenseOpBuilder = std::function<tvm::Tensor(const Target& target,
...
@@ -712,7 +712,7 @@ using FTVMDenseOpBuilder = std::function<tvm::Tensor(const Target& target,
*/
*/
inline
PackedFunc
WrapDenseOp
(
FTVMDenseOpBuilder
builder
)
{
inline
PackedFunc
WrapDenseOp
(
FTVMDenseOpBuilder
builder
)
{
return
PackedFunc
([
builder
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
return
PackedFunc
([
builder
](
TVMArgs
args
,
TVMRetValue
*
ret
)
{
auto
target
=
Target
::
current_targe
t
(
false
);
auto
target
=
Target
::
Curren
t
(
false
);
Tensor
data
=
args
[
0
];
Tensor
data
=
args
[
0
];
Tensor
weight
=
args
[
1
];
Tensor
weight
=
args
[
1
];
Tensor
bias
=
args
[
2
];
Tensor
bias
=
args
[
2
];
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment