Commit 7de8a3a4 by Lianmin Zheng Committed by Tianqi Chen

[TOPI] Memoize winograd matrix (#3687)

* [TOPI] Memoize winograd matrix

* lint

* Fix name
parent 33ab3c60
...@@ -34,9 +34,11 @@ class Cache(object): ...@@ -34,9 +34,11 @@ class Cache(object):
---------- ----------
key: str key: str
The file key to the function The file key to the function
save_at_exit: bool
Whether save the cache to file when the program exits
""" """
cache_by_key = {} cache_by_key = {}
def __init__(self, key): def __init__(self, key, save_at_exit):
cache_dir = ".pkl_memoize_py{0}".format(sys.version_info[0]) cache_dir = ".pkl_memoize_py{0}".format(sys.version_info[0])
if not os.path.exists(cache_dir): if not os.path.exists(cache_dir):
os.mkdir(cache_dir) os.mkdir(cache_dir)
...@@ -49,6 +51,7 @@ class Cache(object): ...@@ -49,6 +51,7 @@ class Cache(object):
else: else:
self.cache = {} self.cache = {}
self.dirty = False self.dirty = False
self.save_at_exit = save_at_exit
def save(self): def save(self):
if self.dirty: if self.dirty:
...@@ -60,16 +63,19 @@ class Cache(object): ...@@ -60,16 +63,19 @@ class Cache(object):
def _atexit(): def _atexit():
"""Save handler.""" """Save handler."""
for value in Cache.cache_by_key.values(): for value in Cache.cache_by_key.values():
if value.save_at_exit:
value.save() value.save()
def memoize(key): def memoize(key, save_at_exit=False):
"""Memoize the result of function and reuse multiple times. """Memoize the result of function and reuse multiple times.
Parameters Parameters
---------- ----------
key: str key: str
The unique key to the file The unique key to the file
save_at_exit: bool
Whether save the cache to file when the program exits
Returns Returns
------- -------
...@@ -81,9 +87,9 @@ def memoize(key): ...@@ -81,9 +87,9 @@ def memoize(key):
allow_types = (string_types, int, float) allow_types = (string_types, int, float)
fkey = key + "." + f.__name__ + ".pkl" fkey = key + "." + f.__name__ + ".pkl"
if fkey not in Cache.cache_by_key: if fkey not in Cache.cache_by_key:
Cache.cache_by_key[fkey] = Cache(fkey) Cache.cache_by_key[fkey] = Cache(fkey, save_at_exit)
cache = Cache.cache_by_key[fkey] cache = Cache.cache_by_key[fkey]
cargs = tuple(x.cell_contents for x in f.__closure__) cargs = tuple(x.cell_contents for x in f.__closure__) if f.__closure__ else ()
cargs = (len(cargs),) + cargs cargs = (len(cargs),) + cargs
def _memoized_f(func, *args, **kwargs): def _memoized_f(func, *args, **kwargs):
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
from operator import mul from operator import mul
from functools import reduce from functools import reduce
import numpy as np import numpy as np
from tvm.contrib.pickle_memoize import memoize
from ..util import const_matrix from ..util import const_matrix
...@@ -131,6 +132,8 @@ def _interpolation_points(degree): ...@@ -131,6 +132,8 @@ def _interpolation_points(degree):
return np.array(in_pts[degree-1], dtype=np.float64) return np.array(in_pts[degree-1], dtype=np.float64)
@memoize("topi.nn.winograd_matrices", save_at_exit=False)
def winograd_transform_matrices(tile_size, kernel_size, out_dtype): def winograd_transform_matrices(tile_size, kernel_size, out_dtype):
"""Compute the A, B, and G transform matrices for `tile_size` as a `tvm.Expr`. """Compute the A, B, and G transform matrices for `tile_size` as a `tvm.Expr`.
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment