Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
97be70a0
Commit
97be70a0
authored
Feb 24, 2019
by
雾雨魔理沙
Committed by
Tianqi Chen
Feb 24, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Relay] add more function to prelude (#2660)
parent
4ba30478
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
83 additions
and
8 deletions
+83
-8
python/tvm/relay/backend/interpreter.py
+5
-2
python/tvm/relay/prelude.py
+58
-1
src/relay/ir/module.cc
+1
-0
src/relay/pass/type_solver.cc
+2
-5
tests/python/relay/test_adt.py
+17
-0
No files found.
python/tvm/relay/backend/interpreter.py
View file @
97be70a0
...
@@ -250,12 +250,15 @@ class Interpreter(Executor):
...
@@ -250,12 +250,15 @@ class Interpreter(Executor):
The optimized expression.
The optimized expression.
"""
"""
# TODO: We need to move this optimization code into the optimizer/pass manager
# TODO: We need to move this optimization code into the optimizer/pass manager
ck_expr
=
ir_pass
.
infer_type
(
expr
,
mod
=
self
.
mod
)
wrapped_expr
=
expr
if
isinstance
(
expr
,
Function
)
else
Function
([],
expr
)
if
self
.
mod
:
self
.
mod
[
self
.
mod
.
entry_func
]
=
wrapped_expr
ck_expr
=
ir_pass
.
infer_type
(
wrapped_expr
,
mod
=
self
.
mod
)
simp_expr
=
ir_pass
.
simplify_inference
(
ck_expr
)
simp_expr
=
ir_pass
.
simplify_inference
(
ck_expr
)
ck_simp
=
ir_pass
.
infer_type
(
simp_expr
,
mod
=
self
.
mod
)
ck_simp
=
ir_pass
.
infer_type
(
simp_expr
,
mod
=
self
.
mod
)
fused_expr
=
ir_pass
.
fuse_ops
(
ck_simp
)
fused_expr
=
ir_pass
.
fuse_ops
(
ck_simp
)
ck_fused
=
ir_pass
.
infer_type
(
fused_expr
,
mod
=
self
.
mod
)
ck_fused
=
ir_pass
.
infer_type
(
fused_expr
,
mod
=
self
.
mod
)
return
ck_fused
return
ck_fused
if
isinstance
(
expr
,
Function
)
else
Call
(
ck_fused
,
[])
def
_make_executor
(
self
,
expr
):
def
_make_executor
(
self
,
expr
):
def
_interp_wrapper
(
*
args
,
**
kwargs
):
def
_interp_wrapper
(
*
args
,
**
kwargs
):
...
...
python/tvm/relay/prelude.py
View file @
97be70a0
...
@@ -340,7 +340,10 @@ class Prelude:
...
@@ -340,7 +340,10 @@ class Prelude:
Match
(
t
,
[
rose_case
]),
self
.
tree
(
b
),
[
a
,
b
])
Match
(
t
,
[
rose_case
]),
self
.
tree
(
b
),
[
a
,
b
])
def
define_tree_size
(
self
):
def
define_tree_size
(
self
):
"""Defines a function that computes the size of a tree as a nat."""
"""Defines a function that computes the size of a tree as a nat.
Signature: fn<a>(t : tree[a]) -> nat
"""
self
.
size
=
GlobalVar
(
"size"
)
self
.
size
=
GlobalVar
(
"size"
)
a
=
TypeVar
(
"a"
)
a
=
TypeVar
(
"a"
)
t
=
Var
(
"t"
,
self
.
tree
(
a
))
t
=
Var
(
"t"
,
self
.
tree
(
a
))
...
@@ -351,6 +354,56 @@ class Prelude:
...
@@ -351,6 +354,56 @@ class Prelude:
self
.
mod
[
self
.
size
]
=
Function
([
t
],
self
.
mod
[
self
.
size
]
=
Function
([
t
],
Match
(
t
,
[
rose_case
]),
self
.
nat
(),
[
a
])
Match
(
t
,
[
rose_case
]),
self
.
nat
(),
[
a
])
def
define_id
(
self
):
"""Defines a function that return it's argument.
Signature: fn<a>(x : a) -> a
"""
self
.
id
=
GlobalVar
(
"id"
)
a
=
TypeVar
(
"a"
)
x
=
Var
(
"x"
,
a
)
self
.
mod
[
self
.
id
]
=
Function
([
x
],
x
,
a
,
[
a
])
def
define_compose
(
self
):
"""Defines a function that compose two function.
Signature: fn<a, b, c>(f : fn(b) -> c, g : fn(a) -> b) -> fn(a) -> c
"""
self
.
compose
=
GlobalVar
(
"compose"
)
a
=
TypeVar
(
"a"
)
b
=
TypeVar
(
"b"
)
c
=
TypeVar
(
"c"
)
f
=
Var
(
"f"
,
FuncType
([
b
],
c
))
g
=
Var
(
"g"
,
FuncType
([
a
],
b
))
x
=
Var
(
"x"
)
self
.
mod
[
self
.
compose
]
=
Function
([
f
,
g
],
Function
([
x
],
f
(
g
(
x
))),
FuncType
([
a
],
c
),
[
a
,
b
,
c
])
def
define_iterate
(
self
):
"""Define a function that take a number n, a function f,
and return a closure that apply f n time on it's argument.
Signature: fn<a>(n : nat, f : fn(a) -> a) -> fn(a) -> a
"""
self
.
iterate
=
GlobalVar
(
"iterate"
)
a
=
TypeVar
(
"a"
)
f
=
Var
(
"f"
,
FuncType
([
a
],
a
))
x
=
Var
(
"x"
,
self
.
nat
())
y
=
Var
(
"y"
,
self
.
nat
())
z
=
Var
(
"z"
)
z_case
=
Clause
(
PatternConstructor
(
self
.
z
),
Function
([
z
],
z
))
# todo: fix typechecker so Function([z], z) can be replaced by self.id
s_case
=
Clause
(
PatternConstructor
(
self
.
s
,
[
PatternVar
(
y
)]),
self
.
compose
(
f
,
self
.
iterate
(
f
,
y
)))
self
.
mod
[
self
.
iterate
]
=
Function
([
f
,
x
],
Match
(
x
,
[
z_case
,
s_case
]),
FuncType
([
a
],
a
),
[
a
])
def
__init__
(
self
,
mod
):
def
__init__
(
self
,
mod
):
self
.
mod
=
mod
self
.
mod
=
mod
self
.
define_list_adt
()
self
.
define_list_adt
()
...
@@ -377,3 +430,7 @@ class Prelude:
...
@@ -377,3 +430,7 @@ class Prelude:
self
.
define_tree_adt
()
self
.
define_tree_adt
()
self
.
define_tree_map
()
self
.
define_tree_map
()
self
.
define_tree_size
()
self
.
define_tree_size
()
self
.
define_id
()
self
.
define_compose
()
self
.
define_iterate
()
src/relay/ir/module.cc
View file @
97be70a0
...
@@ -83,6 +83,7 @@ void ModuleNode::Add(const GlobalVar& var,
...
@@ -83,6 +83,7 @@ void ModuleNode::Add(const GlobalVar& var,
CHECK
(
AlphaEqual
(
type
,
old_type
))
CHECK
(
AlphaEqual
(
type
,
old_type
))
<<
"Module#update changes type, not possible in this mode."
;
<<
"Module#update changes type, not possible in this mode."
;
}
}
var
->
checked_type_
=
type
;
AddUnchecked
(
var
,
checked_func
);
AddUnchecked
(
var
,
checked_func
);
}
}
...
...
src/relay/pass/type_solver.cc
View file @
97be70a0
...
@@ -400,11 +400,8 @@ Type TypeSolver::Unify(const Type& dst, const Type& src, const NodeRef&) {
...
@@ -400,11 +400,8 @@ Type TypeSolver::Unify(const Type& dst, const Type& src, const NodeRef&) {
}
}
void
TypeSolver
::
ReportError
(
const
Error
&
err
,
const
NodeRef
&
location
)
{
void
TypeSolver
::
ReportError
(
const
Error
&
err
,
const
NodeRef
&
location
)
{
this
->
err_reporter_
->
ReportAt
(
err_reporter_
->
ReportAt
(
current_func
,
location
,
err
);
this
->
current_func
,
}
location
,
err
);
}
// Add type constraint to the solver.
// Add type constraint to the solver.
void
TypeSolver
::
AddConstraint
(
const
TypeConstraint
&
constraint
,
const
NodeRef
&
loc
)
{
void
TypeSolver
::
AddConstraint
(
const
TypeConstraint
&
constraint
,
const
NodeRef
&
loc
)
{
...
...
tests/python/relay/test_adt.py
View file @
97be70a0
...
@@ -43,6 +43,9 @@ rose = p.rose
...
@@ -43,6 +43,9 @@ rose = p.rose
tmap
=
p
.
tmap
tmap
=
p
.
tmap
size
=
p
.
size
size
=
p
.
size
compose
=
p
.
compose
iterate
=
p
.
iterate
# this is an example of using the adt value in python side
# this is an example of using the adt value in python side
def
count
(
n
):
def
count
(
n
):
assert
isinstance
(
n
,
ConstructorValue
)
assert
isinstance
(
n
,
ConstructorValue
)
...
@@ -93,6 +96,7 @@ def tree_to_dict(t):
...
@@ -93,6 +96,7 @@ def tree_to_dict(t):
def
test_nat_value
():
def
test_nat_value
():
assert
count
(
make_nat
(
10
))
==
10
assert
count
(
make_nat
(
10
))
==
10
assert
count
(
intrp
.
evaluate
(
s
(
s
(
z
()))))
==
2
def
test_nat_constructor
():
def
test_nat_constructor
():
...
@@ -577,6 +581,17 @@ def test_nested_pattern_match():
...
@@ -577,6 +581,17 @@ def test_nested_pattern_match():
assert
count
(
res
)
==
2
assert
count
(
res
)
==
2
def
test_compose
():
n
=
relay
.
Var
(
'n'
)
inc
=
relay
.
Function
([
n
],
s
(
n
))
x
=
relay
.
Var
(
'x'
)
res
=
intrp
.
evaluate
(
relay
.
Call
(
compose
(
inc
,
double
),
[
s
(
s
(
z
()))]))
assert
count
(
res
)
==
5
def
test_iterate
():
expr
=
relay
.
Call
(
iterate
(
double
,
build_nat
(
2
)),
[
build_nat
(
3
)])
res
=
intrp
.
evaluate
(
relay
.
Function
([],
expr
)())
assert
count
(
res
)
==
12
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test_nat_constructor
()
test_nat_constructor
()
...
@@ -598,3 +613,5 @@ if __name__ == "__main__":
...
@@ -598,3 +613,5 @@ if __name__ == "__main__":
test_sum
()
test_sum
()
test_tmap
()
test_tmap
()
test_size
()
test_size
()
test_compose
()
test_iterate
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment