Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
d3ee03eb
Commit
d3ee03eb
authored
Oct 19, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
expose range
parent
56e10eb0
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
94 additions
and
30 deletions
+94
-30
include/tvm/expr.h
+1
-1
include/tvm/expr_util.h
+12
-0
include/tvm/tensor.h
+13
-8
python/tvm/cpp/domain.py
+5
-0
python/tvm/cpp/expr.py
+0
-1
python/tvm/cpp/function.py
+0
-17
src/c_api/c_api_function.cc
+28
-2
src/c_api/c_api_registry.h
+11
-1
src/expr/tensor.cc
+19
-0
tests/cpp/tensor_test.cc
+5
-0
No files found.
include/tvm/expr.h
View file @
d3ee03eb
...
...
@@ -89,7 +89,7 @@ class Var : public Expr {
};
Expr
IntConstant
(
int64_t
value
);
Expr
FloatConstant
(
int64_t
value
);
Expr
FloatConstant
(
double
value
);
/*! \brief base of expression node */
class
ExprNode
:
public
Node
{
...
...
include/tvm/expr_util.h
View file @
d3ee03eb
...
...
@@ -40,6 +40,18 @@ inline void Visit(const Expr& expr, FVisit fvisit) {
Visit
(
n
->
src
,
fvisit
);
break
;
}
case
kReduceNode
:
{
const
auto
*
n
=
expr
.
Get
<
ReduceNode
>
();
Visit
(
n
->
src
,
fvisit
);
break
;
}
case
kTensorReadNode
:
{
const
auto
*
n
=
expr
.
Get
<
TensorReadNode
>
();
for
(
size_t
i
=
0
;
i
<
n
->
indices
.
size
();
++
i
)
{
Visit
(
n
->
indices
[
i
],
fvisit
);
}
break
;
}
default
:
break
;
}
fvisit
(
expr
);
...
...
include/tvm/tensor.h
View file @
d3ee03eb
...
...
@@ -7,6 +7,7 @@
#define TVM_TENSOR_H_
#include <string>
#include <vector>
#include <type_traits>
#include "./expr.h"
#include "./array.h"
...
...
@@ -46,17 +47,17 @@ class TensorNode : public Node {
using
FCompute
=
std
::
function
<
Expr
(
const
Array
<
Var
>&
i
)
>
;
// converters from other functions into fcompute
inline
FCompute
GetFCompute
(
std
::
function
<
Expr
(
Var
x
)
>
f
)
{
return
[
f
](
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
]);
};
inline
FCompute
GetFCompute
(
std
::
function
<
Expr
(
Var
x
)
>
f
)
{
return
[
f
]
(
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
]);
};
}
inline
FCompute
GetFCompute
(
std
::
function
<
Expr
(
Var
,
Var
)
>
f
)
{
return
[
f
](
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
]);
};
inline
FCompute
GetFCompute
(
std
::
function
<
Expr
(
Var
,
Var
)
>
f
)
{
return
[
f
]
(
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
]);
};
}
inline
FCompute
GetFCompute
(
std
::
function
<
Expr
(
Var
,
Var
,
Var
)
>
f
)
{
return
[
f
](
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
],
i
[
2
]);
};
inline
FCompute
GetFCompute
(
std
::
function
<
Expr
(
Var
,
Var
,
Var
)
>
f
)
{
return
[
f
]
(
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
],
i
[
2
]);
};
}
inline
FCompute
GetFCompute
(
std
::
function
<
Expr
(
Var
,
Var
,
Var
,
Var
)
>
f
)
{
return
[
f
](
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
],
i
[
2
],
i
[
3
]);
};
inline
FCompute
GetFCompute
(
std
::
function
<
Expr
(
Var
,
Var
,
Var
,
Var
)
>
f
)
{
return
[
f
]
(
const
Array
<
Var
>&
i
)
{
return
f
(
i
[
0
],
i
[
1
],
i
[
2
],
i
[
3
]);
};
}
/*!
...
...
@@ -132,6 +133,10 @@ class Tensor : public NodeRef {
* \return the result expression representing tensor read.
*/
Expr
operator
()(
Array
<
Expr
>
indices
)
const
;
/*! \return list of input tensors to this tensor */
std
::
vector
<
Tensor
>
InputTensors
()
const
;
/*! \return whether the tensor stores a result of reduction */
bool
IsRTensor
()
const
;
// printt function
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Tensor
&
t
)
{
// NOLINT(*)
os
<<
"Tensor(shape="
<<
t
.
shape
()
...
...
python/tvm/cpp/domain.py
0 → 100644
View file @
d3ee03eb
from
._ctypes._api
import
NodeBase
,
register_node
@register_node
(
"RangeNode"
)
class
Range
(
NodeBase
):
pass
python/tvm/cpp/expr.py
View file @
d3ee03eb
from
._ctypes._api
import
NodeBase
,
register_node
from
.function
import
binary_op
from
._function_internal
import
_binary_op
class
Expr
(
NodeBase
):
def
__add__
(
self
,
other
):
...
...
python/tvm/cpp/function.py
View file @
d3ee03eb
...
...
@@ -28,23 +28,6 @@ def _symbol(value):
return
value
def
binary_op
(
op
,
lhs
,
rhs
):
"""Binary operator given op lhs and rhs
Parameters
----------
op : str
The operator string
lhs : Expr/number
The left operand
rhs : Expr/number
The right operand
"""
return
_function_internal
.
_binary_op
(
op
,
_symbol
(
lhs
),
_symbol
(
rhs
))
def
max
(
lhs
,
rhs
):
"""Max of two expressions
...
...
src/c_api/c_api_function.cc
View file @
d3ee03eb
...
...
@@ -5,6 +5,7 @@
*/
#include <tvm/expr.h>
#include <tvm/op.h>
#include <tvm/tensor.h>
#include "./c_api_registry.h"
namespace
dmlc
{
...
...
@@ -37,7 +38,7 @@ TVM_REGISTER_API(constant)
})
.
add_argument
(
"src"
,
"Number"
,
"source number"
);
TVM_REGISTER_API
(
_
binary_op
)
TVM_REGISTER_API
(
binary_op
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
CHECK
(
args
.
at
(
0
).
type_id
==
kStr
);
*
ret
=
(
*
BinaryOp
::
Get
(
args
.
at
(
0
).
str
.
c_str
()))(
args
.
at
(
1
),
args
.
at
(
2
));
...
...
@@ -53,11 +54,36 @@ TVM_REGISTER_API(_raw_ptr)
})
.
add_argument
(
"src"
,
"NodeBase"
,
"the node base"
);
TVM_REGISTER_API
(
Range
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
*
ret
=
Range
(
args
.
at
(
0
),
args
.
at
(
1
));
})
.
add_argument
(
"begin"
,
"Expr"
,
"beginning of the range."
)
.
add_argument
(
"end"
,
"Expr"
,
"end of the range"
);
TVM_REGISTER_API
(
_TensorInput
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
*
ret
=
Tensor
(
static_cast
<
Array
<
Expr
>
>
(
args
.
at
(
0
)),
static_cast
<
std
::
string
>
(
args
.
at
(
1
)),
static_cast
<
DataType
>
(
static_cast
<
int
>
(
args
.
at
(
1
))));
});
// transformations
TVM_REGISTER_API
(
format_str
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
CHECK
(
args
.
at
(
0
).
type_id
==
kNodeHandle
);
std
::
ostringstream
os
;
os
<<
Expr
(
args
.
at
(
0
));
auto
&
sptr
=
args
.
at
(
0
).
sptr
;
if
(
sptr
->
is_type
<
TensorNode
>
())
{
os
<<
args
.
at
(
0
).
operator
Tensor
();
}
else
if
(
sptr
->
is_type
<
RDomainNode
>
())
{
os
<<
args
.
at
(
0
).
operator
RDomain
();
}
else
if
(
sptr
->
is_type
<
RangeNode
>
())
{
os
<<
args
.
at
(
0
).
operator
Range
();
}
else
{
os
<<
args
.
at
(
0
).
operator
Expr
();
}
*
ret
=
os
.
str
();
})
.
add_argument
(
"expr"
,
"Expr"
,
"expression to be printed"
);
...
...
src/c_api/c_api_registry.h
View file @
d3ee03eb
...
...
@@ -62,7 +62,17 @@ struct APIVariantValue {
if
(
type_id
==
kNull
)
return
T
();
CHECK_EQ
(
type_id
,
kNodeHandle
);
std
::
shared_ptr
<
Node
>
x
=
sptr
;
return
T
(
std
::
move
(
x
));
T
inst
;
inst
.
node_
=
std
::
move
(
x
);
return
inst
;
}
inline
operator
Expr
()
const
{
if
(
type_id
==
kNull
)
return
Expr
();
if
(
type_id
==
kLong
)
return
IntConstant
(
operator
int64_t
());
if
(
type_id
==
kDouble
)
return
FloatConstant
(
operator
double
());
CHECK_EQ
(
type_id
,
kNodeHandle
);
std
::
shared_ptr
<
Node
>
x
=
sptr
;
return
Expr
(
std
::
move
(
x
));
}
inline
operator
double
()
const
{
CHECK_EQ
(
type_id
,
kDouble
);
...
...
src/expr/tensor.cc
View file @
d3ee03eb
...
...
@@ -4,6 +4,7 @@
*/
#include <tvm/tensor.h>
#include <tvm/expr_node.h>
#include <tvm/expr_util.h>
#include <memory>
namespace
tvm
{
...
...
@@ -43,6 +44,24 @@ Expr Tensor::operator()(Array<Expr> indices) const {
return
Expr
(
std
::
move
(
node
));
}
std
::
vector
<
Tensor
>
Tensor
::
InputTensors
()
const
{
const
TensorNode
*
n
=
static_cast
<
const
TensorNode
*>
(
node_
.
get
());
std
::
vector
<
Tensor
>
inputs
;
if
(
n
->
source
.
is_null
())
return
inputs
;
Visit
(
n
->
source
,
[
&
inputs
](
const
Expr
&
e
)
{
if
(
e
.
node_type
()
==
kTensorReadNode
)
{
inputs
.
push_back
(
e
.
Get
<
TensorReadNode
>
()
->
tensor
);
}
});
return
inputs
;
}
bool
Tensor
::
IsRTensor
()
const
{
const
TensorNode
*
n
=
static_cast
<
const
TensorNode
*>
(
node_
.
get
());
if
(
n
->
source
.
is_null
())
return
false
;
return
n
->
source
.
node_type
()
==
kReduceNode
;
}
TVM_REGISTER_NODE_TYPE
(
TensorNode
);
}
// namespace tvm
tests/cpp/tensor_test.cc
View file @
d3ee03eb
...
...
@@ -13,6 +13,11 @@ TEST(Tensor, Basic) {
auto
C
=
Tensor
({
m
,
n
},
[
&
](
Var
i
,
Var
j
)
{
return
sum
(
A
(
i
,
rd
.
i0
())
*
B
(
j
,
rd
.
i0
()),
rd
);
},
"C"
);
auto
inputs
=
C
.
InputTensors
();
CHECK
(
inputs
[
0
]
==
A
);
CHECK
(
inputs
[
1
]
==
B
);
CHECK
(
C
.
IsRTensor
());
}
int
main
(
int
argc
,
char
**
argv
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment