Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
5fced923
Commit
5fced923
authored
Jan 12, 2017
by
Tianqi Chen
Committed by
Haichen Shen
Jan 12, 2017
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[LANG] Enable json load/save and pickle (#10)
parent
7250005d
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
521 additions
and
37 deletions
+521
-37
include/tvm/base.h
+38
-2
include/tvm/c_api.h
+5
-4
python/tvm/_ctypes/_api.py
+23
-1
python/tvm/function.py
+32
-0
src/base/common.h
+42
-0
src/base/saveload_json.cc
+306
-0
src/c_api/c_api_function.cc
+12
-0
src/c_api/c_api_registry.h
+1
-27
src/lang/expr.cc
+9
-2
src/schedule/schedule_lang.cc
+1
-0
tests/cpp/expr_test.cc
+3
-0
tests/python/test_lang_basic.py
+11
-0
tests/python/test_lang_container.py
+20
-0
tests/python/test_lang_schedule.py
+13
-0
tests/python/test_lang_tensor.py
+5
-1
No files found.
include/tvm/base.h
View file @
5fced923
...
...
@@ -21,6 +21,41 @@ using ::tvm::Node;
using
::
tvm
::
NodeRef
;
using
::
tvm
::
AttrVisitor
;
/*!
* \brief save the node as well as all the node it depends on as json.
* This can be used to serialize any TVM object
*
* \return the string representation of the node.
*/
std
::
string
SaveJSON
(
const
NodeRef
&
node
);
/*!
* \brief Internal implementation of LoadJSON
* Load tvm Node object from json and return a shared_ptr of Node.
* \param json_str The json string to load from.
*
* \return The shared_ptr of the Node.
*/
std
::
shared_ptr
<
Node
>
LoadJSON_
(
std
::
string
json_str
);
/*!
* \brief Load the node from json string.
* This can be used to deserialize any TVM object.
*
* \param json_str The json string to load from.
*
* \tparam NodeType the nodetype
*
* \code
* Expr e = LoadJSON<Expr>(json_str);
* \endcode
*/
template
<
typename
NodeType
,
typename
=
typename
std
::
enable_if
<
std
::
is_base_of
<
NodeRef
,
NodeType
>::
value
>::
type
>
inline
NodeType
LoadJSON
(
const
std
::
string
&
json_str
)
{
return
NodeType
(
LoadJSON_
(
json_str
));
}
/*! \brief typedef the factory function of data iterator */
using
NodeFactory
=
std
::
function
<
std
::
shared_ptr
<
Node
>
()
>
;
/*!
...
...
@@ -32,8 +67,9 @@ struct NodeFactoryReg
};
#define TVM_REGISTER_NODE_TYPE(TypeName) \
DMLC_REGISTRY_REGISTER(::tvm::NodeFactoryReg, NodeFactoryReg, TypeName) \
.set_body([]() { return std::make_shared<TypeName>(); })
static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \
::dmlc::Registry<::tvm::NodeFactoryReg>::Get()->__REGISTER__(TypeName::_type_key) \
.set_body([]() { return std::make_shared<TypeName>(); })
}
// namespace tvm
#endif // TVM_BASE_H_
include/tvm/c_api.h
View file @
5fced923
...
...
@@ -15,14 +15,15 @@
/*! \brief TVM_DLL prefix for windows */
#ifdef _WIN32
#ifdef TVM_EXPORTS
#define TVM_DLL
TVM_EXTERN_C
__declspec(dllexport)
#define TVM_DLL __declspec(dllexport)
#else
#define TVM_DLL
TVM_EXTERN_C
__declspec(dllimport)
#define TVM_DLL __declspec(dllimport)
#endif
#else
#define TVM_DLL
TVM_EXTERN_C
#define TVM_DLL
#endif
TVM_EXTERN_C
{
/*! \brief handle to functions */
typedef
void
*
FunctionHandle
;
/*! \brief handle to node */
...
...
@@ -147,5 +148,5 @@ TVM_DLL int TVMNodeGetAttr(NodeHandle handle,
TVM_DLL
int
TVMNodeListAttrNames
(
NodeHandle
handle
,
int
*
out_size
,
const
char
***
out_array
);
}
// TVM_EXTERN_C
#endif // TVM_C_API_H_
python/tvm/_ctypes/_api.py
View file @
5fced923
...
...
@@ -89,7 +89,6 @@ class NodeBase(object):
"'
%
s' object has no attribute '
%
s'"
%
(
str
(
type
(
self
)),
name
))
return
value
def
__hash__
(
self
):
return
_function_internal
.
_raw_ptr
(
self
)
...
...
@@ -111,6 +110,29 @@ class NodeBase(object):
names
.
append
(
py_str
(
plist
[
i
]))
return
names
def
__reduce__
(
self
):
return
(
type
(
self
),
(
None
,),
self
.
__getstate__
())
def
__getstate__
(
self
):
handle
=
self
.
handle
if
handle
is
not
None
:
return
{
'handle'
:
_function_internal
.
_save_json
(
self
)}
else
:
return
{
'handle'
:
None
}
def
__setstate__
(
self
,
state
):
# pylint: disable=assigning-non-slot
handle
=
state
[
'handle'
]
if
handle
is
not
None
:
json_str
=
handle
_push_arg
(
json_str
)
other
=
_function_internal
.
_load_json
(
json_str
)
self
.
handle
=
other
.
handle
other
.
handle
=
None
else
:
self
.
handle
=
None
def
const
(
value
,
dtype
=
None
):
"""construct a constant"""
if
dtype
is
None
:
...
...
python/tvm/function.py
View file @
5fced923
...
...
@@ -19,6 +19,38 @@ def const(value, dtype=None):
return
_function_internal
.
_const
(
value
,
dtype
)
def
load_json
(
json_str
):
"""Load tvm object from json_str.
Parameters
----------
json_str : str
The json string
Returns
-------
node : Node
The loaded tvm node.
"""
return
_function_internal
.
_load_json
(
json_str
)
def
save_json
(
node
):
"""Load tvm object as json string.
Parameters
----------
node : Node
A TVM Node object to be saved.
Returns
-------
json_str : str
Saved json string.
"""
return
_function_internal
.
_save_json
(
node
)
def
Var
(
name
=
"tindex"
,
dtype
=
int32
):
"""Create a new variable with specified name and dtype
...
...
src/base/common.h
0 → 100644
View file @
5fced923
/*!
* Copyright (c) 2016 by Contributors
* \file common.h
* \brief Common utilities
*/
#ifndef TVM_BASE_COMMON_H_
#define TVM_BASE_COMMON_H_
#include <tvm/base.h>
#include <string>
namespace
tvm
{
inline
std
::
string
Type2String
(
const
Type
&
t
)
{
std
::
ostringstream
os
;
os
<<
t
;
return
os
.
str
();
}
inline
Type
String2Type
(
std
::
string
s
)
{
std
::
istringstream
is
(
s
);
halide_type_code_t
code
=
Type
::
Int
;
if
(
s
.
substr
(
0
,
3
)
==
"int"
)
{
code
=
Type
::
Int
;
s
=
s
.
substr
(
3
);
}
else
if
(
s
.
substr
(
0
,
4
)
==
"uint"
)
{
code
=
Type
::
UInt
;
s
=
s
.
substr
(
4
);
}
else
if
(
s
.
substr
(
0
,
5
)
==
"float"
)
{
code
=
Type
::
Float
;
s
=
s
.
substr
(
5
);
}
else
if
(
s
.
substr
(
0
,
5
)
==
"float"
)
{
code
=
Type
::
Float
;
s
=
s
.
substr
(
5
);
}
else
{
LOG
(
FATAL
)
<<
"unknown type "
<<
s
;
}
int
bits
=
32
,
lanes
=
1
;
if
(
sscanf
(
s
.
c_str
(),
"%dx%d"
,
&
bits
,
&
lanes
)
==
0
)
{
LOG
(
FATAL
)
<<
"unknown type "
<<
s
;
}
return
Type
(
code
,
bits
,
lanes
);
}
}
// namespace tvm
#endif // TVM_BASE_COMMON_H_
src/base/saveload_json.cc
0 → 100644
View file @
5fced923
/*!
* Copyright (c) 2016 by Contributors
* \file saveload_json.cc
* \brief Utilities to save/load TVM objects.
*/
#include <tvm/base.h>
#include <tvm/container.h>
#include <dmlc/json.h>
#include <string>
#include "./common.h"
namespace
tvm
{
// indexer to index all the ndoes
class
NodeIndexer
:
public
AttrVisitor
{
public
:
std
::
unordered_map
<
Node
*
,
size_t
>
node_index
{{
nullptr
,
0
}};
std
::
vector
<
Node
*>
node_list
{
nullptr
};
void
Visit
(
const
char
*
key
,
double
*
value
)
final
{}
void
Visit
(
const
char
*
key
,
int64_t
*
value
)
final
{}
void
Visit
(
const
char
*
key
,
uint64_t
*
value
)
final
{}
void
Visit
(
const
char
*
key
,
int
*
value
)
final
{}
void
Visit
(
const
char
*
key
,
bool
*
value
)
final
{}
void
Visit
(
const
char
*
key
,
std
::
string
*
value
)
final
{}
void
Visit
(
const
char
*
key
,
Type
*
value
)
final
{}
void
Visit
(
const
char
*
key
,
NodeRef
*
value
)
final
{
MakeIndex
(
value
->
node_
.
get
());
}
// make index of all the children of node
void
MakeIndex
(
Node
*
node
)
{
if
(
node
==
nullptr
)
return
;
if
(
node_index
.
count
(
node
))
return
;
CHECK_EQ
(
node_index
.
size
(),
node_list
.
size
());
node_index
[
node
]
=
node_list
.
size
();
node_list
.
push_back
(
node
);
if
(
node
->
is_type
<
ArrayNode
>
())
{
ArrayNode
*
n
=
static_cast
<
ArrayNode
*>
(
node
);
for
(
const
auto
&
sp
:
n
->
data
)
{
MakeIndex
(
sp
.
get
());
}
}
else
if
(
node
->
is_type
<
MapNode
>
())
{
MapNode
*
n
=
static_cast
<
MapNode
*>
(
node
);
for
(
const
auto
&
kv
:
n
->
data
)
{
MakeIndex
(
kv
.
first
.
get
());
MakeIndex
(
kv
.
second
.
get
());
}
}
else
{
node
->
VisitAttrs
(
this
);
}
}
};
// use map so attributes are ordered.
using
AttrMap
=
std
::
map
<
std
::
string
,
std
::
string
>
;
// A Node structure for JSON node.
struct
JSONNode
{
// The type key of the data
std
::
string
type_key
;
// the attributes
AttrMap
attrs
;
// container data
std
::
vector
<
size_t
>
data
;
void
Save
(
dmlc
::
JSONWriter
*
writer
)
const
{
writer
->
BeginObject
();
writer
->
WriteObjectKeyValue
(
"type_key"
,
type_key
);
if
(
attrs
.
size
()
!=
0
)
{
writer
->
WriteObjectKeyValue
(
"attrs"
,
attrs
);
}
if
(
data
.
size
()
!=
0
)
{
writer
->
WriteObjectKeyValue
(
"data"
,
data
);
}
writer
->
EndObject
();
}
void
Load
(
dmlc
::
JSONReader
*
reader
)
{
attrs
.
clear
();
data
.
clear
();
type_key
.
clear
();
dmlc
::
JSONObjectReadHelper
helper
;
helper
.
DeclareOptionalField
(
"type_key"
,
&
type_key
);
helper
.
DeclareOptionalField
(
"attrs"
,
&
attrs
);
helper
.
DeclareOptionalField
(
"data"
,
&
data
);
helper
.
ReadAllFields
(
reader
);
}
};
class
JSONAttrGetter
:
public
AttrVisitor
{
public
:
const
std
::
unordered_map
<
Node
*
,
size_t
>*
node_index_
;
JSONNode
*
node_
;
void
Visit
(
const
char
*
key
,
double
*
value
)
final
{
node_
->
attrs
[
key
]
=
std
::
to_string
(
*
value
);
}
void
Visit
(
const
char
*
key
,
int64_t
*
value
)
final
{
node_
->
attrs
[
key
]
=
std
::
to_string
(
*
value
);
}
void
Visit
(
const
char
*
key
,
uint64_t
*
value
)
final
{
node_
->
attrs
[
key
]
=
std
::
to_string
(
*
value
);
}
void
Visit
(
const
char
*
key
,
int
*
value
)
final
{
node_
->
attrs
[
key
]
=
std
::
to_string
(
*
value
);
}
void
Visit
(
const
char
*
key
,
bool
*
value
)
final
{
node_
->
attrs
[
key
]
=
std
::
to_string
(
*
value
);
}
void
Visit
(
const
char
*
key
,
std
::
string
*
value
)
final
{
node_
->
attrs
[
key
]
=
*
value
;
}
void
Visit
(
const
char
*
key
,
Type
*
value
)
final
{
node_
->
attrs
[
key
]
=
Type2String
(
*
value
);
}
void
Visit
(
const
char
*
key
,
NodeRef
*
value
)
final
{
node_
->
attrs
[
key
]
=
std
::
to_string
(
node_index_
->
at
(
value
->
node_
.
get
()));
}
// Get the node
void
Get
(
Node
*
node
)
{
if
(
node
==
nullptr
)
{
node_
->
type_key
.
clear
();
return
;
}
node_
->
type_key
=
node
->
type_key
();
node_
->
attrs
.
clear
();
node_
->
data
.
clear
();
if
(
node
->
is_type
<
ArrayNode
>
())
{
ArrayNode
*
n
=
static_cast
<
ArrayNode
*>
(
node
);
for
(
size_t
i
=
0
;
i
<
n
->
data
.
size
();
++
i
)
{
node_
->
data
.
push_back
(
node_index_
->
at
(
n
->
data
[
i
].
get
()));
}
}
else
if
(
node
->
is_type
<
MapNode
>
())
{
MapNode
*
n
=
static_cast
<
MapNode
*>
(
node
);
std
::
vector
<
std
::
pair
<
size_t
,
size_t
>
>
elems
;
for
(
const
auto
&
kv
:
n
->
data
)
{
node_
->
data
.
push_back
(
node_index_
->
at
(
kv
.
first
.
get
()));
node_
->
data
.
push_back
(
node_index_
->
at
(
kv
.
second
.
get
()));
}
}
else
{
node
->
VisitAttrs
(
this
);
}
}
};
class
JSONAttrSetter
:
public
AttrVisitor
{
public
:
const
std
::
vector
<
std
::
shared_ptr
<
Node
>
>*
node_list_
;
JSONNode
*
node_
;
std
::
string
GetValue
(
const
char
*
key
)
const
{
auto
it
=
node_
->
attrs
.
find
(
key
);
if
(
it
==
node_
->
attrs
.
end
())
{
LOG
(
FATAL
)
<<
"JSONReader: cannot find field "
<<
key
;
}
return
it
->
second
;
}
template
<
typename
T
>
void
ParseValue
(
const
char
*
key
,
T
*
value
)
const
{
std
::
istringstream
is
(
GetValue
(
key
));
is
>>
*
value
;
if
(
is
.
fail
())
{
LOG
(
FATAL
)
<<
"Wrong value format for field "
<<
key
;
}
}
void
Visit
(
const
char
*
key
,
double
*
value
)
final
{
ParseValue
(
key
,
value
);
}
void
Visit
(
const
char
*
key
,
int64_t
*
value
)
final
{
ParseValue
(
key
,
value
);
}
void
Visit
(
const
char
*
key
,
uint64_t
*
value
)
final
{
ParseValue
(
key
,
value
);
}
void
Visit
(
const
char
*
key
,
int
*
value
)
final
{
ParseValue
(
key
,
value
);
}
void
Visit
(
const
char
*
key
,
bool
*
value
)
final
{
ParseValue
(
key
,
value
);
}
void
Visit
(
const
char
*
key
,
std
::
string
*
value
)
final
{
*
value
=
GetValue
(
key
);
}
void
Visit
(
const
char
*
key
,
Type
*
value
)
final
{
std
::
string
stype
=
GetValue
(
key
);
*
value
=
String2Type
(
stype
);
}
void
Visit
(
const
char
*
key
,
NodeRef
*
value
)
final
{
size_t
index
;
ParseValue
(
key
,
&
index
);
value
->
node_
=
node_list_
->
at
(
index
);
}
// Get the node
void
Set
(
Node
*
node
)
{
if
(
node
==
nullptr
)
return
;
if
(
node
->
is_type
<
ArrayNode
>
())
{
ArrayNode
*
n
=
static_cast
<
ArrayNode
*>
(
node
);
n
->
data
.
clear
();
for
(
size_t
index
:
node_
->
data
)
{
n
->
data
.
push_back
(
node_list_
->
at
(
index
));
}
}
else
if
(
node
->
is_type
<
MapNode
>
())
{
MapNode
*
n
=
static_cast
<
MapNode
*>
(
node
);
CHECK_EQ
(
node_
->
data
.
size
()
%
2
,
0U
);
for
(
size_t
i
=
0
;
i
<
node_
->
data
.
size
();
i
+=
2
)
{
n
->
data
[
node_list_
->
at
(
node_
->
data
[
i
])]
=
node_list_
->
at
(
node_
->
data
[
i
+
1
]);
}
}
else
{
node
->
VisitAttrs
(
this
);
}
}
};
// json graph structure to store node
struct
JSONGraph
{
// the root of the graph
size_t
root
;
// the nodes of the graph
std
::
vector
<
JSONNode
>
nodes
;
// global attributes
AttrMap
attrs
;
void
Save
(
dmlc
::
JSONWriter
*
writer
)
const
{
writer
->
BeginObject
();
writer
->
WriteObjectKeyValue
(
"root"
,
root
);
writer
->
WriteObjectKeyValue
(
"nodes"
,
nodes
);
if
(
attrs
.
size
()
!=
0
)
{
writer
->
WriteObjectKeyValue
(
"attrs"
,
attrs
);
}
writer
->
EndObject
();
}
void
Load
(
dmlc
::
JSONReader
*
reader
)
{
attrs
.
clear
();
dmlc
::
JSONObjectReadHelper
helper
;
helper
.
DeclareField
(
"root"
,
&
root
);
helper
.
DeclareField
(
"nodes"
,
&
nodes
);
helper
.
DeclareOptionalField
(
"attrs"
,
&
attrs
);
helper
.
ReadAllFields
(
reader
);
}
static
JSONGraph
Create
(
const
NodeRef
&
root
)
{
JSONGraph
g
;
NodeIndexer
indexer
;
indexer
.
MakeIndex
(
root
.
node_
.
get
());
JSONAttrGetter
getter
;
getter
.
node_index_
=
&
indexer
.
node_index
;
for
(
Node
*
n
:
indexer
.
node_list
)
{
JSONNode
jnode
;
getter
.
node_
=
&
jnode
;
getter
.
Get
(
n
);
g
.
nodes
.
emplace_back
(
std
::
move
(
jnode
));
}
g
.
attrs
[
"tvm_version"
]
=
"0.1.0"
;
g
.
root
=
indexer
.
node_index
.
at
(
root
.
node_
.
get
());
return
g
;
}
};
std
::
string
SaveJSON
(
const
NodeRef
&
n
)
{
auto
jgraph
=
JSONGraph
::
Create
(
n
);
std
::
ostringstream
os
;
dmlc
::
JSONWriter
writer
(
&
os
);
jgraph
.
Save
(
&
writer
);
return
os
.
str
();
}
std
::
shared_ptr
<
Node
>
LoadJSON_
(
std
::
string
json_str
)
{
std
::
istringstream
is
(
json_str
);
dmlc
::
JSONReader
reader
(
&
is
);
JSONGraph
jgraph
;
// load in json graph.
jgraph
.
Load
(
&
reader
);
std
::
vector
<
std
::
shared_ptr
<
Node
>
>
nodes
;
// node 0 is always null
nodes
.
reserve
(
jgraph
.
nodes
.
size
());
for
(
const
JSONNode
&
jnode
:
jgraph
.
nodes
)
{
if
(
jnode
.
type_key
.
length
()
!=
0
)
{
auto
*
f
=
dmlc
::
Registry
<
NodeFactoryReg
>::
Find
(
jnode
.
type_key
);
CHECK
(
f
!=
nullptr
)
<<
"Node type
\'
"
<<
jnode
.
type_key
<<
"
\'
is not registered in TVM"
;
nodes
.
emplace_back
(
f
->
body
());
}
else
{
nodes
.
emplace_back
(
std
::
shared_ptr
<
Node
>
());
}
}
CHECK_EQ
(
nodes
.
size
(),
jgraph
.
nodes
.
size
());
JSONAttrSetter
setter
;
setter
.
node_list_
=
&
nodes
;
for
(
size_t
i
=
0
;
i
<
nodes
.
size
();
++
i
)
{
setter
.
node_
=
&
jgraph
.
nodes
[
i
];
setter
.
Set
(
nodes
[
i
].
get
());
}
return
nodes
.
at
(
jgraph
.
root
);
}
}
// namespace tvm
src/c_api/c_api_function.cc
View file @
5fced923
...
...
@@ -34,4 +34,16 @@ TVM_REGISTER_API(_raw_ptr)
})
.
add_argument
(
"src"
,
"NodeBase"
,
"the node base"
);
TVM_REGISTER_API
(
_save_json
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
*
ret
=
SaveJSON
(
args
.
at
(
0
));
})
.
add_argument
(
"src"
,
"json_str"
,
"the node "
);
TVM_REGISTER_API
(
_load_json
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
*
ret
=
NodeRef
(
LoadJSON_
(
args
.
at
(
0
)));
})
.
add_argument
(
"src"
,
"NodeBase"
,
"the node"
);
}
// namespace tvm
src/c_api/c_api_registry.h
View file @
5fced923
...
...
@@ -13,36 +13,10 @@
#include <limits>
#include <string>
#include <vector>
#include "../base/common.h"
namespace
tvm
{
inline
std
::
string
Type2String
(
const
Type
&
t
)
{
std
::
ostringstream
os
;
os
<<
t
;
return
os
.
str
();
}
inline
Type
String2Type
(
std
::
string
s
)
{
std
::
istringstream
is
(
s
);
halide_type_code_t
code
=
Type
::
Int
;
if
(
s
.
substr
(
0
,
3
)
==
"int"
)
{
code
=
Type
::
Int
;
s
=
s
.
substr
(
3
);
}
else
if
(
s
.
substr
(
0
,
4
)
==
"uint"
)
{
code
=
Type
::
UInt
;
s
=
s
.
substr
(
4
);
}
else
if
(
s
.
substr
(
0
,
5
)
==
"float"
)
{
code
=
Type
::
Float
;
s
=
s
.
substr
(
5
);
}
else
if
(
s
.
substr
(
0
,
5
)
==
"float"
)
{
code
=
Type
::
Float
;
s
=
s
.
substr
(
5
);
}
else
{
LOG
(
FATAL
)
<<
"unknown type "
<<
s
;
}
int
bits
=
32
,
lanes
=
1
;
if
(
sscanf
(
s
.
c_str
(),
"%dx%d"
,
&
bits
,
&
lanes
)
==
0
)
{
LOG
(
FATAL
)
<<
"unknown type "
<<
s
;
}
return
Type
(
code
,
bits
,
lanes
);
}
inline
const
char
*
TypeId2Str
(
ArgVariantID
type_id
)
{
switch
(
type_id
)
{
case
kNull
:
return
"Null"
;
...
...
src/lang/expr.cc
View file @
5fced923
...
...
@@ -13,8 +13,11 @@ DMLC_REGISTRY_ENABLE(::tvm::NodeFactoryReg);
}
// namespace dmlc
namespace
tvm
{
using
Halide
::
IR
::
RangeNode
;
Range
::
Range
(
Expr
begin
,
Expr
end
)
:
Range
(
std
::
make_shared
<
Halide
::
IR
::
RangeNode
>
(
:
Range
(
std
::
make_shared
<
RangeNode
>
(
begin
,
is_zero
(
begin
)
?
end
:
(
end
-
begin
)))
{
}
...
...
@@ -67,10 +70,14 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
});
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
.
set_dispatch
<
Halide
::
IR
::
RangeNode
>
([](
const
Halide
::
IR
::
RangeNode
*
op
,
IRPrinter
*
p
)
{
.
set_dispatch
<
RangeNode
>
([](
const
Halide
::
IR
::
RangeNode
*
op
,
IRPrinter
*
p
)
{
p
->
stream
<<
"range(min="
<<
op
->
min
<<
", ext="
<<
op
->
extent
<<
')'
;
});
TVM_REGISTER_NODE_TYPE
(
ArrayNode
);
TVM_REGISTER_NODE_TYPE
(
MapNode
);
TVM_REGISTER_NODE_TYPE
(
RangeNode
);
TVM_REGISTER_NODE_TYPE
(
IterVarNode
);
}
// namespace tvm
src/schedule/schedule_lang.cc
View file @
5fced923
...
...
@@ -206,5 +206,6 @@ IterVarRelation FuseNode::make(
TVM_REGISTER_NODE_TYPE
(
StageNode
);
TVM_REGISTER_NODE_TYPE
(
SplitNode
);
TVM_REGISTER_NODE_TYPE
(
FuseNode
);
TVM_REGISTER_NODE_TYPE
(
ScheduleNode
);
}
// namespace tvm
tests/cpp/expr_test.cc
View file @
5fced923
...
...
@@ -6,8 +6,11 @@ TEST(Expr, Basic) {
using
namespace
tvm
;
Var
x
(
"x"
);
auto
z
=
max
(
x
+
1
+
2
,
100
);
NodeRef
tmp
=
z
;
Expr
zz
(
tmp
.
node_
);
std
::
ostringstream
os
;
os
<<
z
;
CHECK
(
zz
.
same_as
(
z
));
CHECK
(
os
.
str
()
==
"max(((x + 1) + 2), 100)"
);
}
...
...
tests/python/test_lang_basic.py
View file @
5fced923
...
...
@@ -5,6 +5,16 @@ def test_const():
assert
x
.
dtype
==
'int32'
assert
isinstance
(
x
,
tvm
.
expr
.
IntImm
)
def
test_const_saveload_json
():
# save load json
x
=
tvm
.
const
(
1
)
y
=
tvm
.
const
(
10
)
z
=
x
+
y
z
=
z
+
z
json_str
=
tvm
.
save_json
(
z
)
zz
=
tvm
.
load_json
(
json_str
)
assert
tvm
.
save_json
(
zz
)
==
tvm
.
save_json
(
z
)
def
test_make
():
x
=
tvm
.
const
(
1
)
y
=
tvm
.
make
.
IntImm
(
'int32'
,
1
)
...
...
@@ -57,6 +67,7 @@ def test_stmt():
if
__name__
==
"__main__"
:
test_attr
()
test_const
()
test_const_saveload_json
()
test_make
()
test_ir
()
test_basic
()
...
...
tests/python/test_lang_container.py
View file @
5fced923
...
...
@@ -4,6 +4,12 @@ def test_array():
a
=
tvm
.
convert
([
1
,
2
,
3
])
assert
len
(
a
)
==
3
def
test_array_save_load_json
():
a
=
tvm
.
convert
([
1
,
2
,
3
])
json_str
=
tvm
.
save_json
(
a
)
a_loaded
=
tvm
.
load_json
(
json_str
)
assert
(
a
[
1
]
.
value
==
2
)
def
test_map
():
a
=
tvm
.
Var
(
'a'
)
b
=
tvm
.
Var
(
'b'
)
...
...
@@ -15,6 +21,20 @@ def test_map():
assert
str
(
dd
)
==
str
(
amap
)
assert
a
+
1
not
in
amap
def
test_map_save_load_json
():
a
=
tvm
.
Var
(
'a'
)
b
=
tvm
.
Var
(
'b'
)
amap
=
tvm
.
convert
({
a
:
2
,
b
:
3
})
json_str
=
tvm
.
save_json
(
amap
)
amap
=
tvm
.
load_json
(
json_str
)
assert
len
(
amap
)
==
2
dd
=
{
kv
[
0
]
.
name
:
kv
[
1
]
.
value
for
kv
in
amap
.
items
()}
assert
(
dd
==
{
"a"
:
2
,
"b"
:
3
})
if
__name__
==
"__main__"
:
test_array
()
test_map
()
test_array_save_load_json
()
test_map_save_load_json
()
tests/python/test_lang_schedule.py
View file @
5fced923
import
tvm
import
pickle
as
pkl
def
test_schedule_create
():
m
=
tvm
.
Var
(
'm'
)
...
...
@@ -17,6 +18,18 @@ def test_schedule_create():
s
[
T
]
.
reorder
(
xi2
,
xi1
)
assert
T
.
op
.
axis
[
1
]
in
s
[
T
]
.
leaf_iter_vars
# save load json
json_str
=
tvm
.
save_json
(
s
)
s_loaded
=
tvm
.
load_json
(
json_str
)
assert
isinstance
(
s_loaded
,
tvm
.
schedule
.
Schedule
)
assert
(
str
(
s_loaded
.
roots
[
0
]
.
body
)
==
str
(
s
.
roots
[
0
]
.
body
))
# pickle unpickle
dump
=
pkl
.
dumps
(
s
)
s_loaded
=
pkl
.
loads
(
dump
)
assert
isinstance
(
s_loaded
,
tvm
.
schedule
.
Schedule
)
assert
(
str
(
s_loaded
.
roots
[
0
]
.
body
)
==
str
(
s
.
roots
[
0
]
.
body
))
def
test_reorder
():
m
=
tvm
.
Var
(
'm'
)
A
=
tvm
.
placeholder
((
m
,),
name
=
'A'
)
...
...
tests/python/test_lang_tensor.py
View file @
5fced923
...
...
@@ -27,7 +27,11 @@ def test_tensor_reduce():
T
=
tvm
.
compute
((
m
,
n
,
l
),
lambda
i
,
j
,
k
:
A
[
i
,
k
]
*
B
[
j
,
k
])
rv
=
tvm
.
IterVar
((
0
,
A
.
shape
[
1
]),
name
=
"k"
)
C
=
tvm
.
compute
((
m
,
n
),
lambda
i
,
j
:
tvm
.
sum
(
T
(
i
,
j
,
rv
+
1
),
rdom
=
rv
))
print
(
C
.
op
.
body
)
# json load save
C_json
=
tvm
.
save_json
(
C
)
C_loaded
=
tvm
.
load_json
(
C_json
)
assert
(
isinstance
(
C_loaded
,
tvm
.
tensor
.
Tensor
))
assert
(
str
(
C_loaded
)
==
str
(
C
))
if
__name__
==
"__main__"
:
test_tensor
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment