Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
3ba5c15b
Commit
3ba5c15b
authored
Dec 31, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
IntSet Evaluation, skeleton finish
parent
cea88d00
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
432 additions
and
37 deletions
+432
-37
src/bound/bound.cc
+7
-9
src/bound/int_set.cc
+343
-0
src/bound/int_set.h
+82
-25
src/lang/schedule.cc
+0
-1
src/pass/schedule_ops.cc
+0
-2
No files found.
src/bound/bound.cc
View file @
3ba5c15b
...
@@ -66,23 +66,21 @@ void PassUp(const Schedule& s,
...
@@ -66,23 +66,21 @@ void PassUp(const Schedule& s,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
std
::
unordered_map
<
IterVar
,
IntSet
>*
p_state
)
{
std
::
unordered_map
<
IterVar
,
IntSet
>*
p_state
)
{
auto
&
state
=
*
p_state
;
auto
&
state
=
*
p_state
;
for
(
size_t
i
=
s
->
relations
.
size
();
i
!=
0
;
--
i
)
{
for
(
size_t
i
=
s
->
relations
.
size
();
i
!=
0
;
--
i
)
{
IterVarRelation
rel
=
s
->
relations
[
i
-
1
];
IterVarRelation
rel
=
s
->
relations
[
i
-
1
];
if
(
rel
.
as
<
SplitNode
>
())
{
if
(
rel
.
as
<
SplitNode
>
())
{
IntSet
parent
;
IntSet
parent
;
const
SplitNode
*
r
=
rel
.
as
<
SplitNode
>
();
const
SplitNode
*
r
=
rel
.
as
<
SplitNode
>
();
IntSet
::
PassUp
(
PassUp
(
r
,
dom_map
,
r
,
dom_map
,
state
.
at
(
r
->
outer
),
state
.
at
(
r
->
inner
),
state
.
at
(
r
->
outer
),
state
.
at
(
r
->
inner
),
&
parent
);
&
parent
);
state
[
r
->
parent
]
=
parent
;
state
[
r
->
parent
]
=
parent
;
}
else
if
(
rel
.
as
<
FuseNode
>
())
{
}
else
if
(
rel
.
as
<
FuseNode
>
())
{
IntSet
outer
,
inner
;
IntSet
outer
,
inner
;
const
FuseNode
*
r
=
rel
.
as
<
FuseNode
>
();
const
FuseNode
*
r
=
rel
.
as
<
FuseNode
>
();
IntSet
::
PassUp
(
PassUp
(
r
,
dom_map
,
r
,
dom_map
,
state
.
at
(
r
->
fused
),
state
.
at
(
r
->
fused
),
&
outer
,
&
inner
);
&
outer
,
&
inner
);
state
[
r
->
outer
]
=
outer
;
state
[
r
->
outer
]
=
outer
;
state
[
r
->
inner
]
=
inner
;
state
[
r
->
inner
]
=
inner
;
}
else
{
}
else
{
...
...
src/bound/int_set.cc
0 → 100644
View file @
3ba5c15b
/*!
* Copyright (c) 2016 by Contributors
* \file int_set.cc
* \brief The integer set functions
*/
#include <tvm/ir.h>
#include "./int_set.h"
namespace
tvm
{
namespace
bound
{
using
namespace
ir
;
/*!
* \brief Internal node container of int set.
*/
class
IntSetNode
:
public
Node
{
public
:
/*! \brief The base range scope */
Range
base
;
/*! \brief additional strided domain */
Array
<
Range
>
domain
;
/*! \brief The stride of each strided domain */
Array
<
Expr
>
stride
;
/*!
* \brief The concrete set,
* used when concrete execution is enabled.
*/
std
::
vector
<
int32_t
>
concrete
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"base"
,
&
base
);
v
->
Visit
(
"domain"
,
&
domain
);
v
->
Visit
(
"stride"
,
&
stride
);
}
static
constexpr
const
char
*
_type_key
=
"IntSet"
;
TVM_DECLARE_NODE_TYPE_INFO
(
IntSetNode
);
};
TVM_REGISTER_NODE_TYPE
(
IntSetNode
);
namespace
{
inline
bool
Match
(
const
Expr
&
e
,
int64_t
value
)
{
const
ir
::
IntImm
*
v
=
e
.
as
<
ir
::
IntImm
>
();
return
v
!=
nullptr
&&
v
->
value
;
}
// whether a exactly matches b.
inline
bool
Match
(
const
IntSet
&
a
,
const
Range
&
b
)
{
if
(
a
->
base
==
b
&&
a
->
domain
.
size
()
==
0
&&
a
->
concrete
.
size
()
==
0
)
{
return
true
;
}
else
{
return
false
;
}
}
// whether a exactly matches b.
inline
bool
Match
(
const
IntSet
&
a
,
const
Expr
&
b
)
{
if
(
a
->
domain
.
size
()
==
0
&&
a
->
concrete
.
size
()
==
0
)
{
return
Match
(
a
->
base
->
extent
,
1
)
&&
a
->
base
->
min
.
same_as
(
b
);
}
else
{
return
false
;
}
}
inline
bool
IsNumber
(
const
IntSet
&
s
)
{
if
(
s
->
domain
.
size
()
!=
0
)
return
false
;
if
(
s
->
concrete
.
size
()
!=
0
)
{
return
s
->
concrete
.
size
()
==
1
;
}
return
Match
(
s
->
base
->
extent
,
1
);
}
inline
Expr
AsNumber
(
const
IntSet
&
s
)
{
return
s
->
base
->
min
;
}
// set combination rule by operators
template
<
typename
T
>
inline
IntSet
BinaryCombine
(
IntSet
a
,
IntSet
b
)
{
LOG
(
WARNING
)
<<
"cannot evaluate binary op "
<<
T
::
_type_key
;
return
IntSet
::
make_all_set
();
}
template
<>
inline
IntSet
BinaryCombine
<
Add
>
(
IntSet
a
,
IntSet
b
)
{
auto
n
=
std
::
make_shared
<
IntSetNode
>
(
*
(
a
.
operator
->
()));
for
(
size_t
i
=
0
;
i
<
b
->
domain
.
size
();
++
i
)
{
n
->
domain
.
push_back
(
b
->
domain
[
i
]);
n
->
stride
.
push_back
(
b
->
stride
[
i
]);
}
if
(
IsNumber
(
a
))
{
n
->
base
=
Range
::
make_with_min_extent
(
a
->
base
->
min
+
b
->
base
->
min
,
b
->
base
->
extent
);
}
else
if
(
IsNumber
(
b
))
{
n
->
base
=
Range
::
make_with_min_extent
(
a
->
base
->
min
+
b
->
base
->
min
,
a
->
base
->
extent
);
}
else
{
n
->
base
=
Range
::
make_with_min_extent
(
a
->
base
->
min
+
b
->
base
->
min
,
a
->
base
->
extent
+
b
->
base
->
extent
-
1
);
}
return
IntSet
(
n
);
}
inline
Range
Negation
(
Range
a
)
{
if
(
Match
(
a
->
extent
,
1
))
{
return
Range
::
make_with_min_extent
(
-
a
->
min
,
a
->
extent
);
}
else
{
return
Range
::
make_with_min_extent
(
-
(
a
->
min
+
a
->
extent
-
1
),
a
->
extent
);
}
}
inline
IntSet
Negation
(
IntSet
a
)
{
CHECK_EQ
(
a
->
concrete
.
size
(),
0
);
auto
n
=
std
::
make_shared
<
IntSetNode
>
();
n
->
base
=
Negation
(
a
->
base
);
for
(
size_t
i
=
0
;
i
<
a
->
domain
.
size
();
++
i
)
{
n
->
domain
.
push_back
(
Negation
(
a
->
domain
[
i
]));
n
->
stride
.
push_back
(
a
->
stride
[
i
]);
}
return
IntSet
(
a
);
}
template
<>
inline
IntSet
BinaryCombine
<
Sub
>
(
IntSet
a
,
IntSet
b
)
{
return
BinaryCombine
<
Add
>
(
a
,
Negation
(
b
));
}
inline
IntSet
BinaryMul
(
IntSet
a
,
Expr
b
)
{
// copy construct
if
(
Match
(
b
,
1
))
return
a
;
if
(
Match
(
b
,
-
1
))
return
Negation
(
a
);
auto
n
=
std
::
make_shared
<
IntSetNode
>
();
n
->
base
=
Range
::
make_with_min_extent
(
0
,
1
);
n
->
domain
.
push_back
(
a
->
base
);
n
->
stride
.
push_back
(
b
);
for
(
size_t
i
=
0
;
i
<
a
->
domain
.
size
();
++
i
)
{
n
->
domain
.
push_back
(
a
->
domain
[
i
]);
n
->
stride
.
push_back
(
a
->
stride
[
i
]
*
b
);
}
return
IntSet
(
a
);
}
template
<>
inline
IntSet
BinaryCombine
<
Mul
>
(
IntSet
a
,
IntSet
b
)
{
if
(
IsNumber
(
a
))
{
return
BinaryMul
(
a
,
AsNumber
(
b
));
}
else
if
(
IsNumber
(
b
))
{
return
BinaryMul
(
b
,
AsNumber
(
a
));
}
else
{
return
IntSet
::
make_all_set
();
}
}
}
// namespace
inline
const
IntSetNode
*
IntSet
::
operator
->
()
const
{
return
static_cast
<
const
IntSetNode
*>
(
node_
.
get
());
}
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
.
set_dispatch
<
IntSetNode
>
([](
const
IntSetNode
*
op
,
IRPrinter
*
p
)
{
p
->
stream
<<
"int-set(base="
;
p
->
print
(
op
->
base
);
p
->
stream
<<
')'
;
});
IntSet
IntSet
::
make
(
Range
dom
)
{
auto
n
=
std
::
make_shared
<
IntSetNode
>
();
n
->
base
=
dom
;
return
IntSet
(
n
);
}
void
PassUp
(
const
SplitNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
outer
,
const
IntSet
&
inner
,
IntSet
*
parent
)
{
if
(
dom_map
.
count
(
s
->
outer
)
&&
dom_map
.
count
(
s
->
inner
)
&&
dom_map
.
count
(
s
->
parent
)
&&
Match
(
outer
,
dom_map
.
at
(
s
->
outer
))
&&
Match
(
inner
,
dom_map
.
at
(
s
->
inner
)))
{
*
parent
=
IntSet
::
make
(
dom_map
.
at
(
s
->
parent
));
return
;
}
// copy construct
auto
n
=
std
::
make_shared
<
IntSetNode
>
(
*
(
inner
.
operator
->
()));
if
(
IsNumber
(
outer
))
{
// shift the base offset
n
->
base
=
Range
::
make_with_min_extent
(
AsNumber
(
outer
)
*
s
->
factor
+
inner
->
base
->
min
,
inner
->
base
->
extent
);
*
parent
=
IntSet
(
n
);
}
else
{
// default use all domains in the data.
n
->
domain
.
push_back
(
outer
->
base
);
n
->
stride
.
push_back
(
s
->
factor
);
for
(
size_t
i
=
0
;
i
<
outer
->
domain
.
size
();
++
i
)
{
n
->
domain
.
push_back
(
outer
->
domain
[
i
]);
n
->
stride
.
push_back
(
outer
->
stride
[
i
]
*
s
->
factor
);
}
}
}
void
PassUp
(
const
FuseNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
fused
,
IntSet
*
outer
,
IntSet
*
inner
)
{
CHECK
(
dom_map
.
count
(
s
->
outer
));
CHECK
(
dom_map
.
count
(
s
->
inner
));
CHECK
(
dom_map
.
count
(
s
->
fused
));
if
(
Match
(
fused
,
dom_map
.
at
(
s
->
fused
)))
{
*
outer
=
IntSet
::
make
(
dom_map
.
at
(
s
->
outer
));
*
inner
=
IntSet
::
make
(
dom_map
.
at
(
s
->
inner
));
return
;
}
if
(
IsNumber
(
fused
))
{
Expr
value
=
AsNumber
(
fused
);
Expr
factor
=
dom_map
.
at
(
s
->
outer
)
->
extent
;
*
outer
=
IntSet
::
make
(
Range
::
make_with_min_extent
(
value
/
factor
,
1
));
*
inner
=
IntSet
::
make
(
Range
::
make_with_min_extent
(
value
%
factor
,
1
));
}
else
{
LOG
(
WARNING
)
<<
"use fallback inference rule in fuse"
;
// simply use the entire set, this rule can be enhanced.
*
outer
=
IntSet
::
make
(
dom_map
.
at
(
s
->
outer
));
*
inner
=
IntSet
::
make
(
dom_map
.
at
(
s
->
inner
));
return
;
}
}
namespace
{
// evaluator to evaluate the int set
class
IRSetEvaluator
{
public
:
inline
IntSet
Eval
(
Expr
expr
)
{
static
const
FType
&
f
=
vtable
();
if
(
f
.
can_dispatch
(
expr
))
{
return
f
(
expr
,
expr
,
this
);
}
else
{
LOG
(
WARNING
)
<<
"cannot evaluate set type "
<<
expr
->
type_key
();
return
IntSet
::
make_all_set
();
}
}
using
FType
=
tvm
::
IRFunctor
<
IntSet
(
const
NodeRef
&
,
const
Expr
&
,
IRSetEvaluator
*
)
>
;
static
FType
&
vtable
()
{
// NOLINT(*)
static
FType
inst
;
return
inst
;
}
std
::
unordered_map
<
const
Variable
*
,
IntSet
>
dom_map
;
};
inline
IntSet
ConstOp
(
const
NodeRef
&
,
const
Expr
&
e
,
IRSetEvaluator
*
)
{
return
IntSet
::
make
(
Range
::
make_with_min_extent
(
e
,
1
));
}
TVM_STATIC_IR_FUNCTOR
(
IRSetEvaluator
,
vtable
)
.
set_dispatch
<
IntImm
>
(
ConstOp
)
.
set_dispatch
<
UIntImm
>
(
ConstOp
)
.
set_dispatch
<
FloatImm
>
(
ConstOp
);
TVM_STATIC_IR_FUNCTOR
(
IRSetEvaluator
,
vtable
)
.
set_dispatch
<
Variable
>
([](
const
Variable
*
op
,
const
Expr
&
e
,
IRSetEvaluator
*
m
)
{
auto
it
=
m
->
dom_map
.
find
(
op
);
if
(
it
!=
m
->
dom_map
.
end
())
{
return
it
->
second
;
}
else
{
return
IntSet
::
make
(
Range
::
make_with_min_extent
(
e
,
1
));
}
});
// binary operator
template
<
typename
T
>
inline
IntSet
Binary
(
const
T
*
op
,
const
Expr
&
e
,
IRSetEvaluator
*
m
)
{
IntSet
a
=
m
->
Eval
(
op
->
a
);
IntSet
b
=
m
->
Eval
(
op
->
b
);
if
(
IsNumber
(
a
)
&&
IsNumber
(
b
))
{
if
(
Match
(
a
,
op
->
a
)
&&
Match
(
b
,
op
->
b
))
{
return
IntSet
::
make
(
Range
::
make_with_min_extent
(
e
,
1
));
}
else
{
return
IntSet
::
make
(
Range
::
make_with_min_extent
(
T
::
make
(
AsNumber
(
a
),
AsNumber
(
b
)),
1
));
}
}
else
{
return
BinaryCombine
<
T
>
(
a
,
b
);
}
}
TVM_STATIC_IR_FUNCTOR
(
IRSetEvaluator
,
vtable
)
.
set_dispatch
<
Add
>
(
Binary
<
Add
>
)
.
set_dispatch
<
Sub
>
(
Binary
<
Sub
>
)
.
set_dispatch
<
Mul
>
(
Binary
<
Mul
>
)
.
set_dispatch
<
Div
>
(
Binary
<
Div
>
)
.
set_dispatch
<
Mod
>
(
Binary
<
Mod
>
)
.
set_dispatch
<
Min
>
(
Binary
<
Min
>
)
.
set_dispatch
<
Max
>
(
Binary
<
Max
>
);
// use simply bound for logical expressions for now.
inline
IntSet
Logical
(
const
NodeRef
&
,
const
Expr
&
e
,
IRSetEvaluator
*
)
{
return
IntSet
::
make
(
Range
::
make_with_min_extent
(
0
,
2
));
}
TVM_STATIC_IR_FUNCTOR
(
IRSetEvaluator
,
vtable
)
.
set_dispatch
<
EQ
>
(
Logical
)
.
set_dispatch
<
NE
>
(
Logical
)
.
set_dispatch
<
LT
>
(
Logical
)
.
set_dispatch
<
LE
>
(
Logical
)
.
set_dispatch
<
GT
>
(
Logical
)
.
set_dispatch
<
GE
>
(
Logical
)
.
set_dispatch
<
And
>
(
Logical
)
.
set_dispatch
<
Or
>
(
Logical
);
}
// namespace
IntSet
Eval
(
Expr
e
,
const
std
::
unordered_map
<
IterVar
,
IntSet
>&
dom_map
)
{
IRSetEvaluator
m
;
for
(
auto
kv
:
dom_map
)
{
m
.
dom_map
[
kv
.
first
->
var
.
as
<
Variable
>
()]
=
kv
.
second
;
}
return
m
.
Eval
(
e
);
}
}
// namespace bound
}
// namespace tvm
src/bound/int_set.h
View file @
3ba5c15b
/*!
/*!
* Copyright (c) 2016 by Contributors
* Copyright (c) 2016 by Contributors
* \file int_set.h
* \file int_set.h
* \brief Abstract
class for iteration integer set
s.
* \brief Abstract
ion for all integer set operation
s.
*/
*/
#ifndef TVM_BOUND_INT_SET_H_
#ifndef TVM_BOUND_INT_SET_H_
#define TVM_BOUND_INT_SET_H_
#define TVM_BOUND_INT_SET_H_
...
@@ -11,35 +11,92 @@
...
@@ -11,35 +11,92 @@
namespace
tvm
{
namespace
tvm
{
namespace
bound
{
namespace
bound
{
// internal node container of int set.
class
IntSetNode
;
/*!
/*!
* \brief
abstract class of integer set for iteration sets
.
* \brief
Integer set class, represent a set of integers in one dimension
.
*/
*/
class
IntSet
{
class
IntSet
:
public
NodeRef
{
public
:
public
:
/
/ constructor
/
*! \brief constructor */
IntSet
()
;
IntSet
()
{}
//
whether the set is same as range
//
constructor from not deontainer.
bool
SameAs
(
const
Range
&
r
)
const
;
explicit
IntSet
(
std
::
shared_ptr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
/
/ make integer set by range
/
*! \return whether the set is empty */
static
IntSet
make
(
Range
r
);
inline
bool
is_empty
()
const
{
// make integer set as a constant value
return
!
defined
();
static
IntSet
make
(
Expr
value
);
}
/
/ upward inference function
/
*!
// get the int set of parent given int set of outer and in
ner
* \brief access the internal node contai
ner
static
void
PassUp
(
const
SplitNode
*
s
,
* \return the pointer to the internal node container
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
*/
const
IntSet
&
outer
,
inline
const
IntSetNode
*
operator
->
()
const
;
const
IntSet
&
inner
,
/*!
IntSet
*
parent
);
* \param dom The domain to be created.
// upward inference functio
n
* \return create integer set from existing domai
n
// get the int set of outer and inner given int set of fused.
*/
static
void
PassUp
(
const
FuseNode
*
s
,
static
IntSet
make
(
Range
dom
);
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
/*!
const
IntSet
&
fused
,
* \return create integer set that represents everything
IntSet
*
outer
,
*/
IntSet
*
inner
);
static
IntSet
make_all_set
(
);
};
};
/*!
* \brief Find an symbolic integer set that contains all possible values of
* e given the domain of each iteration variables.
*
* \param e The expression to be evaluated.
* \param dom_map The domain of each variable.
* \return An integer set that can cover all the possible values of e.
*/
IntSet
Eval
(
Expr
e
,
const
std
::
unordered_map
<
IterVar
,
IntSet
>&
dom_map
);
/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Split relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param outer domain of outer iteration.
* \param inner domain of inner iteration.
* \param parent The result domain of parent.
*/
void
PassUp
(
const
SplitNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
outer
,
const
IntSet
&
inner
,
IntSet
*
parent
);
/*!
* \brief Conditional upward message passing.
*
* Get domain of parent, condition on domain of children.
* Domain is represented as IntSet.
*
* \param s The Fuse relation node.
* \param dom_map The old domain result from downward message passing.
* Contains the domain set if all the children are full set.
* \param fused domain of fused iteration.
* \param outer The result domain of outer iteration.
* \param inner The result domain of inner iteration.
*/
void
PassUp
(
const
FuseNode
*
s
,
const
std
::
unordered_map
<
IterVar
,
Range
>&
dom_map
,
const
IntSet
&
fused
,
IntSet
*
outer
,
IntSet
*
inner
);
/*!
* \brief Create an union set of all sets
* \param sets The sets to be unioned
* \return the set after union
*/
IntSet
Union
(
const
Array
<
IntSet
>&
sets
);
}
// namespace bound
}
// namespace bound
}
// namespace tvm
}
// namespace tvm
...
...
src/lang/schedule.cc
View file @
3ba5c15b
...
@@ -152,7 +152,6 @@ Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent,
...
@@ -152,7 +152,6 @@ Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent,
IterVar
*
p_x_outer
,
IterVar
*
p_y_outer
,
IterVar
*
p_x_outer
,
IterVar
*
p_y_outer
,
IterVar
*
p_x_inner
,
IterVar
*
p_y_inner
,
IterVar
*
p_x_inner
,
IterVar
*
p_y_inner
,
Expr
x_factor
,
Expr
y_factor
)
{
// NOLINT(*)
Expr
x_factor
,
Expr
y_factor
)
{
// NOLINT(*)
split
(
x_parent
,
p_x_outer
,
p_x_inner
,
x_factor
);
split
(
x_parent
,
p_x_outer
,
p_x_inner
,
x_factor
);
split
(
y_parent
,
p_y_outer
,
p_y_inner
,
y_factor
);
split
(
y_parent
,
p_y_outer
,
p_y_inner
,
y_factor
);
reorder
(
Array
<
IterVar
>
({
*
p_x_inner
,
*
p_y_inner
,
*
p_x_outer
,
*
p_y_outer
}));
reorder
(
Array
<
IterVar
>
({
*
p_x_inner
,
*
p_y_inner
,
*
p_x_outer
,
*
p_y_outer
}));
...
...
src/pass/schedule_ops.cc
View file @
3ba5c15b
...
@@ -10,8 +10,6 @@
...
@@ -10,8 +10,6 @@
namespace
tvm
{
namespace
tvm
{
namespace
ir
{
namespace
ir
{
namespace
{
namespace
{
}
// namespace
}
// namespace
}
// namespace ir
}
// namespace ir
}
// namespace tvm
}
// namespace tvm
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment