Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
5445a936
Commit
5445a936
authored
Dec 01, 2016
by
tqchen
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Refactor to use iterVar
parent
7591714a
Show whitespace changes
Inline
Side-by-side
Showing
24 changed files
with
318 additions
and
281 deletions
+318
-281
HalideIR
+1
-1
include/tvm/domain.h
+0
-145
include/tvm/expr.h
+134
-9
include/tvm/ir.h
+5
-5
include/tvm/operation.h
+0
-1
include/tvm/split.h
+0
-1
include/tvm/tensor.h
+4
-9
python/tvm/_ctypes/_api.py
+1
-0
python/tvm/collections.py
+3
-4
python/tvm/expr.py
+5
-1
python/tvm/function.py
+25
-24
python/tvm/tensor.py
+14
-2
src/c_api/c_api_function.cc
+1
-15
src/c_api/c_api_lang.cc
+4
-4
src/c_api/c_api_registry.h
+6
-0
src/lang/domain.cc
+0
-37
src/lang/expr.cc
+75
-0
src/lang/ir.cc
+5
-7
src/lang/tensor.cc
+6
-0
src/pass/ir_mutator.cc
+9
-7
src/pass/ir_visitor.cc
+4
-4
src/pass/schedule_ops.cc
+0
-1
tests/cpp/tensor_test.cc
+13
-1
tests/python/test_tensor.py
+3
-3
No files found.
HalideIR
@
e96ee0f2
Subproject commit e
b2f7d604a611318fc685172847bcf5ba2fcf835
Subproject commit e
96ee0f2fb5239021c0facd5398a9a96644bc411
include/tvm/domain.h
deleted
100644 → 0
View file @
7591714a
/*!
* Copyright (c) 2016 by Contributors
* \file domain.h
* \brief Defines the domain in AST
*/
#ifndef TVM_DOMAIN_H_
#define TVM_DOMAIN_H_
#include <ir/Range.h>
#include <memory>
#include "./base.h"
#include "./expr.h"
namespace
tvm
{
/*! \brief container class of reduction domain */
class
RDomainNode
;
class
IterDomainNode
;
/*!
* \brief same as Halide::IR::Range
* except it provide an constructor with (begin, end)
*
* \note Traditional Halide's Range have a constructor with
* (begin, extent), which does not match the convention in e.g. python.
* We decided to correct it by removing the constructor in HalideIR,
* and add it back in TVM's range.
*/
class
Range
:
public
Halide
::
IR
::
Range
{
public
:
/*! \brief constructor */
Range
()
{}
explicit
Range
(
std
::
shared_ptr
<
Node
>
n
)
:
Halide
::
IR
::
Range
(
n
)
{}
/*!
* \brief constructor by begin and end
* \param begin The begin of the range.
* \param end The end of the range.
*/
Range
(
Expr
begin
,
Expr
end
);
static
Range
make_with_min_extent
(
Expr
min
,
Expr
extent
);
};
/*! \brief Domain is a multi-dimensional range */
using
Domain
=
Array
<
Range
>
;
/*! \brief reduction domain */
class
RDomain
:
public
NodeRef
{
public
:
/*! \brief constructor*/
RDomain
()
{}
explicit
RDomain
(
std
::
shared_ptr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
/*!
* constructor by domain
* \param domain The domain of reduction.
*/
explicit
RDomain
(
Domain
domain
);
/*!
* \brief constructor by list of ranges
* \param domain The reduction domain
*/
explicit
RDomain
(
std
::
initializer_list
<
Range
>
domain
)
:
RDomain
(
Domain
(
domain
))
{}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline
const
RDomainNode
*
operator
->
()
const
;
/*! \return The dimension of the RDomain */
inline
size_t
ndim
()
const
;
/*!
* \param i the index.
* \return i-th index variable in the RDomain
*/
inline
Var
index
(
size_t
i
)
const
;
/*! \return the 0-th index of the domain */
inline
Var
i0
()
const
{
return
index
(
0
);
}
// low level constructor
static
RDomain
make
(
Array
<
Var
>
index
,
Domain
domain
);
};
/*! \brief use RDom as alias of RDomain */
using
RDom
=
RDomain
;
/*!
* \brief An iteration variable representing an iteration
* over a one dimensional domain.
*/
class
IterVarNode
:
public
Node
{
/*! \brief The */
Var
var
;
/*! \brief the domain of iteration */
Range
dom
;
/*! \brief additional tag on the iteration variable */
std
::
string
tag
;
};
/*! \brief reduction domain node */
class
RDomainNode
:
public
Node
{
public
:
/*! \brief internal index */
Array
<
Var
>
index
;
/*! \brief The inernal domain */
Domain
domain
;
/*! \brief constructor */
RDomainNode
()
{}
RDomainNode
(
Array
<
Var
>
index
,
Domain
domain
)
:
index
(
index
),
domain
(
domain
)
{
}
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"index"
,
&
index
);
v
->
Visit
(
"domain"
,
&
domain
);
}
static
constexpr
const
char
*
_type_key
=
"RDomain"
;
TVM_DECLARE_NODE_TYPE_INFO
(
RDomainNode
);
};
inline
const
RDomainNode
*
RDomain
::
operator
->
()
const
{
return
static_cast
<
const
RDomainNode
*>
(
node_
.
get
());
}
inline
size_t
RDomain
::
ndim
()
const
{
return
(
*
this
)
->
index
.
size
();
}
inline
Var
RDomain
::
index
(
size_t
i
)
const
{
return
(
*
this
)
->
index
[
i
];
}
// overload print function
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
RDomain
&
r
){
// NOLINT(*)
os
<<
"rdomain("
<<
r
->
domain
<<
")"
;
return
os
;
}
}
// namespace tvm
#endif // TVM_DOMAIN_H_
include/tvm/expr.h
View file @
5445a936
/*!
* Copyright (c) 2016 by Contributors
* \file expr.h
* \brief
Defines the expressions in AST
.
* \brief
The Expr and related elements in DataFlow construction
.
*/
#ifndef TVM_EXPR_H_
#define TVM_EXPR_H_
#include <ir/Expr.h>
#include <ir/IRPrinter.h>
#include <ir/IROperator.h>
#include <string>
#include <algorithm>
#include "./base.h"
namespace
tvm
{
...
...
@@ -19,20 +21,14 @@ using Halide::Int;
using
Halide
::
UInt
;
using
Halide
::
Handle
;
// functions
using
Halide
::
cast
;
using
Halide
::
min
;
using
Halide
::
max
;
using
Halide
::
abs
;
using
Halide
::
select
;
using
Halide
::
Expr
;
using
Halide
::
VarExpr
;
using
Halide
::
IR
::
FunctionRef
;
using
Halide
::
IR
::
FunctionBaseNode
;
using
Halide
::
Internal
::
Stmt
;
using
Halide
::
Internal
::
IRPrinter
;
/*! \brief a named variable in TVM */
class
Var
:
public
Halide
::
VarExpr
{
public
:
explicit
Var
(
const
std
::
string
&
name_hint
=
"v"
,
...
...
@@ -41,5 +37,134 @@ class Var : public Halide::VarExpr {
explicit
Var
(
std
::
shared_ptr
<
Node
>
n
)
:
VarExpr
(
n
)
{}
};
/*! \brief container class of iteration variable. */
class
IterVarNode
;
/*!
* \brief same as Halide::IR::Range
* except it provide an constructor with (begin, end)
*
* \note Traditional Halide's Range have a constructor with
* (begin, extent), which does not match the convention in e.g. python.
* We decided to correct it by removing the constructor in HalideIR,
* and add it back in TVM's range.
*/
class
Range
:
public
Halide
::
IR
::
Range
{
public
:
/*! \brief constructor */
Range
()
{}
explicit
Range
(
std
::
shared_ptr
<
Node
>
n
)
:
Halide
::
IR
::
Range
(
n
)
{}
/*!
* \brief constructor by begin and end
* \param begin The begin of the range.
* \param end The end of the range.
*/
Range
(
Expr
begin
,
Expr
end
);
static
Range
make_with_min_extent
(
Expr
min
,
Expr
extent
);
};
/*!
* \brief Iteration Variable,
* represents an iteration over an integer interval.
*/
class
IterVar
:
public
NodeRef
{
public
:
// construct a new iter var without a domain
IterVar
()
{}
// construct from shared ptr.
explicit
IterVar
(
std
::
shared_ptr
<
Node
>
n
)
:
NodeRef
(
n
)
{}
/*!
* \brief construction of iteration variable.
* \param dom The iteration domain.
* \param var_name The name of iteration variable.
* \param thread_tag The additional tag to indicate whether the var is binded to fixed-thread.
*/
explicit
IterVar
(
Range
dom
,
std
::
string
var_name
=
"i"
,
std
::
string
thread_tag
=
""
);
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline
const
IterVarNode
*
operator
->
()
const
;
/*!
* \return the corresponding var in the IterVar.
*/
inline
operator
Expr
()
const
;
/*! \brief specify container node */
using
ContainerType
=
IterVarNode
;
};
using
Domain
=
Array
<
Range
>
;
// functions
using
Halide
::
cast
;
using
Halide
::
min
;
using
Halide
::
max
;
using
Halide
::
abs
;
using
Halide
::
select
;
/*!
* \brief sum of of source expression over rdom
* \param source The source expression.
*/
Expr
sum
(
Expr
source
,
Array
<
IterVar
>
rdom
);
/*!
* \brief max of of source expression over rdom
* \param source The source expression.
*/
Expr
max
(
Expr
source
,
Array
<
IterVar
>
rdom
);
/*!
* \brief max of of source expression over rdom
* \param source The source expression.
*/
Expr
min
(
Expr
source
,
Array
<
IterVar
>
rdom
);
// print functions for expr
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
NodeRef
&
n
);
// NOLINT(*)
// definition of Node.
/*!
* \brief An iteration variable representing an iteration
* over a one dimensional interval.
*/
class
IterVarNode
:
public
Node
{
public
:
/*! \brief The looping variable */
Var
var
;
/*!
* \brief the domain of iteration, if known, can be None
* For the intermediate schedule node, before schedule.
*/
Range
dom
;
/*!
* \brief additional tag on the iteration variable,
* set this if this is binded already to a known thread tag.
*/
std
::
string
thread_tag
;
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"var"
,
&
var
);
v
->
Visit
(
"dom"
,
&
dom
);
v
->
Visit
(
"thread_tag"
,
&
thread_tag
);
}
static
IterVar
make
(
Var
var
,
Range
dom
,
std
::
string
thread_tag
);
static
constexpr
const
char
*
_type_key
=
"IterVar"
;
TVM_DECLARE_NODE_TYPE_INFO
(
IterVarNode
);
};
// inline implementations
inline
const
IterVarNode
*
IterVar
::
operator
->
()
const
{
return
static_cast
<
const
IterVarNode
*>
(
node_
.
get
());
}
inline
IterVar
::
operator
Expr
()
const
{
return
(
*
this
)
->
var
;
}
}
// namespace tvm
#endif // TVM_EXPR_H_
include/tvm/ir.h
View file @
5445a936
...
...
@@ -11,7 +11,7 @@
#include <type_traits>
#include <string>
#include "./base.h"
#include "./
domain
.h"
#include "./
expr
.h"
namespace
tvm
{
namespace
ir
{
...
...
@@ -30,11 +30,11 @@ struct Reduce : public ExprNode<Reduce> {
std
::
string
op
;
/*! \brief The source operand */
Expr
source
;
/*! \brief The reduction domain */
RDomain
rdom
;
/*! \brief The reduction domain
s
*/
Array
<
IterVar
>
rdom
;
/*! \brief construct expr from
name
and rdom */
static
Expr
make
(
std
::
string
name
,
Expr
src
,
RDomain
rdom
);
/*! \brief construct expr from
op
and rdom */
static
Expr
make
(
std
::
string
op
,
Expr
src
,
Array
<
IterVar
>
rdom
);
void
VisitAttrs
(
AttrVisitor
*
v
)
final
{
v
->
Visit
(
"dtype"
,
&
type
);
...
...
include/tvm/operation.h
View file @
5445a936
...
...
@@ -8,7 +8,6 @@
#include <string>
#include "./expr.h"
#include "./domain.h"
#include "./tensor.h"
namespace
tvm
{
...
...
include/tvm/split.h
View file @
5445a936
...
...
@@ -8,7 +8,6 @@
#include "./base.h"
#include "./expr.h"
#include "./domain.h"
namespace
tvm
{
...
...
include/tvm/tensor.h
View file @
5445a936
...
...
@@ -14,7 +14,6 @@
#include "./base.h"
#include "./expr.h"
#include "./domain.h"
namespace
tvm
{
...
...
@@ -66,8 +65,8 @@ class Tensor : public FunctionRef {
* \return the result expression representing tensor read.
*/
Expr
operator
()(
Array
<
Expr
>
indices
)
const
;
/
/ overload print function
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Tensor
&
t
)
;
/
*! \brief specify container node */
using
ContainerType
=
TensorNode
;
};
/*! \brief Operation that produces tensors */
...
...
@@ -87,6 +86,8 @@ class Operation : public NodeRef {
* \return The i-th output.
*/
Tensor
output
(
size_t
i
)
const
;
/*! \brief specify container node */
using
ContainerType
=
OperationNode
;
};
/*! \brief Node to represent a tensor */
...
...
@@ -162,11 +163,5 @@ inline size_t Tensor::ndim() const {
return
(
*
this
)
->
shape
.
size
();
}
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
Tensor
&
t
)
{
// NOLINT(*)
os
<<
"Tensor(shape="
<<
t
->
shape
<<
", name="
<<
t
->
name
<<
')'
;
return
os
;
}
}
// namespace tvm
#endif // TVM_TENSOR_H_
python/tvm/_ctypes/_api.py
View file @
5445a936
...
...
@@ -118,6 +118,7 @@ def convert(value):
raise
ValueError
(
"don't know how to handle type
%
s"
%
type
(
value
))
return
value
def
_push_arg
(
arg
):
a
=
ArgVariant
()
if
arg
is
None
:
...
...
python/tvm/collections.py
View file @
5445a936
...
...
@@ -2,6 +2,7 @@
from
__future__
import
absolute_import
as
_abs
from
._ctypes._api
import
NodeBase
,
register_node
from
.
import
_function_internal
from
.
import
expr
as
_expr
@register_node
class
Array
(
NodeBase
):
...
...
@@ -19,11 +20,9 @@ class Array(NodeBase):
@register_node
class
Range
(
NodeBase
):
def
__repr__
(
self
):
return
(
'Range(min='
+
str
(
self
.
min
)
+
', extent='
+
str
(
self
.
extent
)
+
')'
)
pass
@register_node
class
RDomain
(
NodeBas
e
):
class
IterVar
(
_expr
.
ExprCompatibl
e
):
pass
python/tvm/expr.py
View file @
5445a936
...
...
@@ -2,7 +2,7 @@ from __future__ import absolute_import as _abs
from
._ctypes._api
import
NodeBase
,
register_node
from
.
import
make
as
_make
class
Expr
(
NodeBase
):
class
Expr
Compatible
(
NodeBase
):
def
__add__
(
self
,
other
):
return
_make
.
Add
(
self
,
other
)
...
...
@@ -36,6 +36,10 @@ class Expr(NodeBase):
def
__neg__
(
self
):
return
self
.
__mul__
(
-
1
)
class
Expr
(
ExprCompatible
):
pass
class
ConstExpr
(
Expr
):
pass
...
...
python/tvm/function.py
View file @
5445a936
...
...
@@ -103,33 +103,34 @@ def compute(shape, fcompute, name="TensorCompute"):
shape
,
name
,
body
.
dtype
,
op_node
,
0
)
def
RDomain
(
dom
):
"""Create a
reduction domain given domain
def
IterVar
(
dom
,
name
=
'iter'
,
thread_tag
=
''
):
"""Create a
iteration variable
Parameters
----------
dom : list of Range or list of pairs
The reduction domain.
dom : Range
The domain of iteration.
name : str
The name of iteration variable.
thread_tag : str
The thread tag of the iteration variable.
Returns
-------
rdom : RDomain
The result
rdomain
iter_var : IterVar
The result
itervar
"""
if
not
isinstance
(
dom
,
(
list
,
tuple
)):
dom
=
[
dom
]
elif
not
isinstance
(
dom
[
0
],
(
list
,
tuple
)):
dom
=
[
dom
]
dnorm
=
[]
for
x
in
dom
:
if
isinstance
(
x
,
(
list
,
tuple
)):
if
len
(
x
)
!=
2
:
if
isinstance
(
dom
,
(
list
,
tuple
)):
if
len
(
dom
)
!=
2
:
raise
ValueError
(
"need to list of ranges"
)
dnorm
.
append
(
Range
(
x
[
0
],
x
[
1
]))
else
:
dnorm
.
append
(
x
)
dnorm
=
convert
(
dnorm
)
return
_function_internal
.
_RDomain
(
dnorm
)
dom
=
Range
(
dom
[
0
],
dom
[
1
])
if
not
isinstance
(
dom
,
_collections
.
Range
):
raise
ValueError
(
"dom need to be Range"
)
return
_function_internal
.
_IterVar
(
dom
,
name
,
thread_tag
)
def
sum
(
expr
,
rdom
):
...
...
@@ -143,10 +144,11 @@ def sum(expr, rdom):
rdom : RDomain
The reduction domainx
"""
assert
isinstance
(
rdom
,
_collections
.
RDomain
)
rdom
=
rdom
if
isinstance
(
rdom
,
list
)
else
[
rdom
]
x
=
_make
.
Reduce
(
"Add"
,
expr
,
rdom
)
return
x
def
min
(
expr
,
rdom
):
"""Create a min expression over rdom
...
...
@@ -158,11 +160,11 @@ def min(expr, rdom):
rdom : RDomain
The reduction domainx
"""
assert
isinstance
(
expr
,
_expr
.
Expr
)
assert
isinstance
(
rdom
,
_collections
.
RDomain
)
rdom
=
rdom
if
isinstance
(
rdom
,
list
)
else
[
rdom
]
x
=
_make
.
Reduce
(
"Min"
,
expr
,
rdom
)
return
x
def
max
(
expr
,
rdom
):
"""Create a min expression over rdom
...
...
@@ -174,8 +176,7 @@ def max(expr, rdom):
rdom : RDomain
The reduction domainx
"""
assert
isinstance
(
expr
,
_expr
.
Expr
)
assert
isinstance
(
rdom
,
_collections
.
RDomain
)
rdom
=
rdom
if
isinstance
(
rdom
,
list
)
else
[
rdom
]
x
=
_make
.
Reduce
(
"Max"
,
expr
,
rdom
)
return
x
...
...
python/tvm/tensor.py
View file @
5445a936
from
__future__
import
absolute_import
as
_abs
from
._ctypes._api
import
NodeBase
,
register_node
from
._ctypes._api
import
NodeBase
,
register_node
,
convert
from
.
import
collections
as
_collections
from
.
import
make
as
_make
from
.
import
expr
as
_expr
...
...
@@ -10,7 +11,18 @@ class Tensor(NodeBase):
ndim
=
self
.
ndim
if
len
(
indices
)
!=
ndim
:
raise
ValueError
(
"Need to provide
%
d index in tensor slice"
%
ndim
)
return
_make
.
Call
(
self
.
dtype
,
self
.
name
,
indices
,
_expr
.
Call
.
Halide
,
self
,
0
)
indices
=
convert
(
indices
)
args
=
[]
for
x
in
indices
:
if
isinstance
(
x
,
_collections
.
IterVar
):
args
.
append
(
x
.
var
)
elif
isinstance
(
x
,
_expr
.
Expr
):
args
.
append
(
x
)
else
:
raise
ValueError
(
"The indices must be expression"
)
return
_make
.
Call
(
self
.
dtype
,
self
.
name
,
args
,
_expr
.
Call
.
Halide
,
self
,
0
)
@property
def
ndim
(
self
):
...
...
src/c_api/c_api_function.cc
View file @
5445a936
...
...
@@ -4,9 +4,7 @@
* \file c_api_impl.cc
*/
#include <tvm/expr.h>
#include <tvm/domain.h>
#include <tvm/tensor.h>
#include <ir/IROperator.h>
#include "./c_api_registry.h"
namespace
dmlc
{
...
...
@@ -22,21 +20,9 @@ TVM_REGISTER_API(_format_str)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
using
Halide
::
Internal
::
BaseExprNode
;
using
Halide
::
Internal
::
BaseStmtNode
;
CHECK
(
args
.
at
(
0
).
type_id
==
kNodeHandle
);
std
::
ostringstream
os
;
auto
&
sptr
=
args
.
at
(
0
).
sptr
;
if
(
dynamic_cast
<
const
TensorNode
*>
(
sptr
.
get
()))
{
os
<<
args
.
at
(
0
).
operator
Tensor
();
}
else
if
(
dynamic_cast
<
const
RDomainNode
*>
(
sptr
.
get
()))
{
os
<<
args
.
at
(
0
).
operator
RDomain
();
}
else
if
(
dynamic_cast
<
const
BaseExprNode
*>
(
sptr
.
get
()))
{
os
<<
args
.
at
(
0
).
operator
Expr
();
}
else
if
(
dynamic_cast
<
const
BaseStmtNode
*>
(
sptr
.
get
()))
{
os
<<
args
.
at
(
0
).
operator
Stmt
();
}
else
{
LOG
(
FATAL
)
<<
"don't know how to print input NodeBaseType"
;
}
os
<<
args
.
at
(
0
).
operator
NodeRef
();
*
ret
=
os
.
str
();
})
.
add_argument
(
"expr"
,
"Node"
,
"expression to be printed"
);
...
...
src/c_api/c_api_lang.cc
View file @
5445a936
...
...
@@ -5,10 +5,8 @@
*/
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/domain.h>
#include <tvm/split.h>
#include <tvm/schedule.h>
#include <ir/IROperator.h>
#include "./c_api_registry.h"
namespace
tvm
{
...
...
@@ -95,11 +93,13 @@ TVM_REGISTER_API(_ComputeOp)
args
.
at
(
3
));
});
TVM_REGISTER_API
(
_RDomain
)
TVM_REGISTER_API
(
_IterVar
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
*
ret
=
RDomain
(
args
.
at
(
0
).
operator
Domain
(
));
*
ret
=
IterVar
(
args
.
at
(
0
),
args
.
at
(
1
),
args
.
at
(
2
));
});
TVM_REGISTER_API
(
_DimSplit
)
.
set_body
([](
const
ArgStack
&
args
,
RetValue
*
ret
)
{
*
ret
=
DimSplitNode
::
make
(
args
.
at
(
0
),
args
.
at
(
1
));
...
...
src/c_api/c_api_registry.h
View file @
5445a936
...
...
@@ -125,8 +125,14 @@ class APIVariantValue {
return
Expr
(
static_cast
<
float
>
(
operator
double
()));
}
CHECK_EQ
(
type_id
,
kNodeHandle
);
if
(
sptr
->
is_type
<
IterVarNode
>
())
{
return
IterVar
(
sptr
)
->
var
;
}
else
{
CHECK
(
dynamic_cast
<
typename
Expr
::
ContainerType
*>
(
sptr
.
get
()))
<<
"did not pass in Expr in a place need Expr"
;
return
Expr
(
sptr
);
}
}
inline
operator
double
()
const
{
CHECK_EQ
(
type_id
,
kDouble
);
return
v_union
.
v_double
;
...
...
src/lang/domain.cc
deleted
100644 → 0
View file @
7591714a
/*!
* Copyright (c) 2016 by Contributors
* \file domain.cc
*/
#include <tvm/base.h>
#include <tvm/domain.h>
namespace
tvm
{
Range
::
Range
(
Expr
begin
,
Expr
end
)
:
Range
(
std
::
make_shared
<
Halide
::
IR
::
RangeNode
>
(
begin
,
end
-
begin
))
{
// TODO(tqchen) add simplify to end - begin
}
Range
Range
::
make_with_min_extent
(
Expr
min
,
Expr
extent
)
{
return
Range
(
std
::
make_shared
<
Halide
::
IR
::
RangeNode
>
(
min
,
extent
));
}
RDomain
::
RDomain
(
Domain
domain
)
{
std
::
vector
<
Var
>
index
;
for
(
size_t
i
=
0
;
i
<
domain
.
size
();
++
i
)
{
std
::
ostringstream
os
;
os
<<
"reduction_index"
<<
i
;
index
.
push_back
(
Var
(
os
.
str
()));
}
Array
<
Var
>
idx
(
index
);
node_
=
std
::
make_shared
<
RDomainNode
>
(
std
::
move
(
idx
),
std
::
move
(
domain
));
}
RDomain
RDomain
::
make
(
Array
<
Var
>
index
,
Domain
domain
)
{
return
RDomain
(
std
::
make_shared
<
RDomainNode
>
(
index
,
domain
));
}
TVM_REGISTER_NODE_TYPE
(
RDomainNode
);
}
// namespace tvm
src/lang/expr.cc
0 → 100644
View file @
5445a936
/*!
* Copyright (c) 2016 by Contributors
* \file expr.cc
*/
#include <tvm/base.h>
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <ir/IRPrinter.h>
#include <memory>
namespace
dmlc
{
DMLC_REGISTRY_ENABLE
(
::
tvm
::
NodeFactoryReg
);
}
// namespace dmlc
namespace
tvm
{
Range
::
Range
(
Expr
begin
,
Expr
end
)
:
Range
(
std
::
make_shared
<
Halide
::
IR
::
RangeNode
>
(
begin
,
end
-
begin
))
{
// TODO(tqchen) add simplify to end - begin
}
Range
Range
::
make_with_min_extent
(
Expr
min
,
Expr
extent
)
{
return
Range
(
std
::
make_shared
<
Halide
::
IR
::
RangeNode
>
(
min
,
extent
));
}
IterVar
::
IterVar
(
Range
dom
,
std
::
string
var_name
,
std
::
string
thread_tag
)
:
IterVar
(
IterVarNode
::
make
(
Var
(
var_name
,
Int
(
32
)),
dom
,
thread_tag
))
{}
IterVar
IterVarNode
::
make
(
Var
var
,
Range
dom
,
std
::
string
thread_tag
)
{
std
::
shared_ptr
<
IterVarNode
>
n
=
std
::
make_shared
<
IterVarNode
>
();
n
->
var
=
var
;
n
->
dom
=
dom
;
n
->
thread_tag
=
thread_tag
;
return
IterVar
(
n
);
}
Expr
sum
(
Expr
source
,
Array
<
IterVar
>
rdom
)
{
return
ir
::
Reduce
::
make
(
"Add"
,
source
,
rdom
);
}
Expr
max
(
Expr
source
,
Array
<
IterVar
>
rdom
)
{
return
ir
::
Reduce
::
make
(
"Max"
,
source
,
rdom
);
}
Expr
min
(
Expr
source
,
Array
<
IterVar
>
rdom
)
{
return
ir
::
Reduce
::
make
(
"Min"
,
source
,
rdom
);
}
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
NodeRef
&
n
)
{
// NOLINT(*)
IRPrinter
(
os
).
print
(
n
);
return
os
;
}
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
.
set_dispatch
<
IterVarNode
>
([](
const
IterVarNode
*
op
,
IRPrinter
*
p
)
{
p
->
stream
<<
"iter_var("
;
if
(
op
->
var
->
name_hint
.
length
()
!=
0
)
{
p
->
stream
<<
op
->
var
->
name_hint
<<
", "
;
}
p
->
stream
<<
op
->
dom
;
if
(
op
->
thread_tag
.
length
()
!=
0
)
{
p
->
stream
<<
", "
<<
op
->
thread_tag
;
}
p
->
stream
<<
")"
;
});
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
.
set_dispatch
<
Halide
::
IR
::
RangeNode
>
([](
const
Halide
::
IR
::
RangeNode
*
op
,
IRPrinter
*
p
)
{
p
->
stream
<<
"range(min="
<<
op
->
min
<<
", ext="
<<
op
->
extent
<<
')'
;
});
TVM_REGISTER_NODE_TYPE
(
IterVarNode
);
}
// namespace tvm
src/lang/ir.cc
View file @
5445a936
/*!
* Copyright (c) 2016 by Contributors
* \file ir
_node
.cc
* \file ir.cc
*/
#include <tvm/base.h>
#include <tvm/expr.h>
...
...
@@ -9,11 +9,6 @@
#include <ir/IRPrinter.h>
#include <memory>
namespace
dmlc
{
DMLC_REGISTRY_ENABLE
(
::
tvm
::
NodeFactoryReg
);
}
// namespace dmlc
namespace
Halide
{
namespace
Internal
{
...
...
@@ -53,9 +48,12 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
namespace
tvm
{
namespace
ir
{
Expr
Reduce
::
make
(
std
::
string
op
,
Expr
source
,
RDomain
rdom
)
{
Expr
Reduce
::
make
(
std
::
string
op
,
Expr
source
,
Array
<
IterVar
>
rdom
)
{
auto
n
=
std
::
make_shared
<
Reduce
>
();
CHECK
(
source
.
defined
());
for
(
size_t
i
=
0
;
i
<
rdom
.
size
();
++
i
)
{
CHECK
(
rdom
[
i
].
defined
());
}
n
->
type
=
source
.
type
();
n
->
source
=
source
;
n
->
op
=
op
;
...
...
src/lang/tensor.cc
View file @
5445a936
...
...
@@ -41,6 +41,12 @@ Tensor TensorNode::make(Array<Expr> shape,
return
Tensor
(
n
);
}
TVM_STATIC_IR_FUNCTOR
(
IRPrinter
,
vtable
)
.
set_dispatch
<
TensorNode
>
([](
const
TensorNode
*
t
,
IRPrinter
*
p
)
{
p
->
stream
<<
"Tensor(shape="
<<
t
->
shape
<<
", name="
<<
t
->
name
<<
')'
;
});
TVM_REGISTER_NODE_TYPE
(
TensorNode
);
}
// namespace tvm
src/pass/ir_mutator.cc
View file @
5445a936
...
...
@@ -42,27 +42,29 @@ inline Array<Expr> MutateArray(Array<Expr> arr, IRMutator *m) {
}
}
inline
RDomain
MutateRDom
(
RDomain
rdom
,
IRMutator
*
m
)
{
std
::
vector
<
Range
>
new_dom
(
rdom
->
domain
.
size
());
inline
Array
<
IterVar
>
MutateRDom
(
Array
<
IterVar
>
rdom
,
IRMutator
*
m
)
{
std
::
vector
<
IterVar
>
new_dom
(
rdom
.
size
());
bool
changed
=
false
;
for
(
size_t
i
=
0
;
i
<
rdom
->
domain
.
size
();
i
++
)
{
Range
r
=
rdom
->
domain
[
i
];
for
(
size_t
i
=
0
;
i
<
rdom
.
size
();
i
++
)
{
IterVar
v
=
rdom
[
i
];
Range
r
=
v
->
dom
;
Expr
new_min
=
m
->
Mutate
(
r
->
min
);
Expr
new_extent
=
m
->
Mutate
(
r
->
extent
);
if
(
!
r
->
min
.
same_as
(
new_min
))
changed
=
true
;
if
(
!
r
->
extent
.
same_as
(
new_extent
))
changed
=
true
;
new_dom
[
i
]
=
Range
::
make_with_min_extent
(
new_min
,
new_extent
);
new_dom
[
i
]
=
IterVarNode
::
make
(
v
->
var
,
Range
::
make_with_min_extent
(
new_min
,
new_extent
),
v
->
thread_tag
);
}
if
(
!
changed
)
{
return
rdom
;
}
else
{
return
RDomain
::
make
(
rdom
->
index
,
Domain
(
new_dom
)
);
return
Array
<
IterVar
>
(
new_dom
);
}
}
TVM_STATIC_IR_FUNCTOR
(
IRMutator
,
vtable_expr
)
.
set_dispatch
<
Reduce
>
([](
const
Reduce
*
op
,
const
Expr
&
e
,
IRMutator
*
m
)
{
RDomain
new_rdom
=
MutateRDom
(
op
->
rdom
,
m
);
Array
<
IterVar
>
new_rdom
=
MutateRDom
(
op
->
rdom
,
m
);
Expr
new_source
=
m
->
Mutate
(
op
->
source
);
if
(
op
->
rdom
.
same_as
(
new_rdom
)
&&
op
->
source
.
same_as
(
new_source
))
{
...
...
src/pass/ir_visitor.cc
View file @
5445a936
...
...
@@ -45,15 +45,15 @@ using namespace Halide::Internal;
void
NoOp
(
const
NodeRef
&
n
,
IRVisitor
*
v
)
{
}
inline
void
VisitArray
(
Array
<
Expr
>
arr
,
IRVisitor
*
v
)
{
inline
void
VisitArray
(
const
Array
<
Expr
>&
arr
,
IRVisitor
*
v
)
{
for
(
size_t
i
=
0
;
i
<
arr
.
size
();
i
++
)
{
v
->
Visit
(
arr
[
i
]);
}
}
inline
void
VisitRDom
(
RDomain
rdom
,
IRVisitor
*
v
)
{
for
(
size_t
i
=
0
;
i
<
rdom
->
domain
.
size
();
i
++
)
{
Range
r
=
rdom
->
domain
[
i
]
;
inline
void
VisitRDom
(
const
Array
<
IterVar
>&
rdom
,
IRVisitor
*
v
)
{
for
(
size_t
i
=
0
;
i
<
rdom
.
size
();
i
++
)
{
Range
r
=
rdom
[
i
]
->
dom
;
v
->
Visit
(
r
->
min
);
v
->
Visit
(
r
->
extent
);
}
...
...
src/pass/schedule_ops.cc
View file @
5445a936
...
...
@@ -67,7 +67,6 @@ void MakeLoop(const DimSplitNode* op,
Stmt
MakePipeline
(
const
Schedule
&
sch
,
Stmt
body
)
{
return
body
;
}
...
...
tests/cpp/tensor_test.cc
View file @
5445a936
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <tvm/tvm.h>
...
...
@@ -14,6 +13,19 @@ TEST(Tensor, Basic) {
},
"C"
);
}
TEST
(
Tensor
,
Reduce
)
{
using
namespace
tvm
;
Var
m
(
"m"
),
n
(
"n"
),
l
(
"l"
);
Tensor
A
({
m
,
l
},
"A"
);
Tensor
B
({
n
,
l
},
"B"
);
IterVar
rv
(
Range
{
0
,
l
},
"k"
);
auto
C
=
Compute
({
m
,
n
},
[
&
](
Var
i
,
Var
j
)
{
return
sum
(
max
(
A
(
i
,
rv
)
*
B
(
j
,
rv
),
1
),
{
rv
});
},
"C"
);
LOG
(
INFO
)
<<
C
->
op
.
as
<
ComputeOpNode
>
()
->
body
;
}
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
FLAGS_gtest_death_test_style
=
"threadsafe"
;
...
...
tests/python/test_tensor.py
View file @
5445a936
...
...
@@ -7,7 +7,7 @@ def test_tensor():
A
=
tvm
.
placeholder
((
m
,
l
),
name
=
'A'
)
B
=
tvm
.
placeholder
((
n
,
l
),
name
=
'B'
)
T
=
tvm
.
compute
((
m
,
n
,
l
),
lambda
i
,
j
,
k
:
A
(
i
,
k
)
*
B
(
j
,
k
))
print
(
T
)
print
(
T
.
op
.
body
)
assert
(
tuple
(
T
.
shape
)
==
(
m
,
n
,
l
))
assert
(
A
.
source
is
None
)
...
...
@@ -19,8 +19,8 @@ def test_tensor_reduce():
A
=
tvm
.
placeholder
((
m
,
l
),
name
=
'A'
)
B
=
tvm
.
placeholder
((
n
,
l
),
name
=
'B'
)
T
=
tvm
.
compute
((
m
,
n
,
l
),
lambda
i
,
j
,
k
:
A
(
i
,
k
)
*
B
(
j
,
k
))
r
d
=
tvm
.
RDomain
(
tvm
.
Range
(
A
.
shape
[
1
])
)
C
=
tvm
.
compute
((
m
,
n
),
lambda
i
,
j
:
tvm
.
sum
(
T
(
i
,
j
,
r
d
.
index
[
0
]),
rdom
=
rd
))
r
v
=
tvm
.
IterVar
((
0
,
A
.
shape
[
1
]),
name
=
"k"
)
C
=
tvm
.
compute
((
m
,
n
),
lambda
i
,
j
:
tvm
.
sum
(
T
(
i
,
j
,
r
v
+
1
),
rdom
=
rv
))
print
(
C
.
op
.
body
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment