Commit 7de8a3a4 by Lianmin Zheng Committed by Tianqi Chen

[TOPI] Memoize winograd matrix (#3687)

* [TOPI] Memoize winograd matrix

* lint

* Fix name
parent 33ab3c60
......@@ -34,9 +34,11 @@ class Cache(object):
----------
key: str
The file key to the function
save_at_exit: bool
Whether save the cache to file when the program exits
"""
cache_by_key = {}
def __init__(self, key):
def __init__(self, key, save_at_exit):
cache_dir = ".pkl_memoize_py{0}".format(sys.version_info[0])
if not os.path.exists(cache_dir):
os.mkdir(cache_dir)
......@@ -49,6 +51,7 @@ class Cache(object):
else:
self.cache = {}
self.dirty = False
self.save_at_exit = save_at_exit
def save(self):
if self.dirty:
......@@ -60,16 +63,19 @@ class Cache(object):
def _atexit():
"""Save handler."""
for value in Cache.cache_by_key.values():
if value.save_at_exit:
value.save()
def memoize(key):
def memoize(key, save_at_exit=False):
"""Memoize the result of function and reuse multiple times.
Parameters
----------
key: str
The unique key to the file
save_at_exit: bool
Whether save the cache to file when the program exits
Returns
-------
......@@ -81,9 +87,9 @@ def memoize(key):
allow_types = (string_types, int, float)
fkey = key + "." + f.__name__ + ".pkl"
if fkey not in Cache.cache_by_key:
Cache.cache_by_key[fkey] = Cache(fkey)
Cache.cache_by_key[fkey] = Cache(fkey, save_at_exit)
cache = Cache.cache_by_key[fkey]
cargs = tuple(x.cell_contents for x in f.__closure__)
cargs = tuple(x.cell_contents for x in f.__closure__) if f.__closure__ else ()
cargs = (len(cargs),) + cargs
def _memoized_f(func, *args, **kwargs):
......
......@@ -25,6 +25,7 @@
from operator import mul
from functools import reduce
import numpy as np
from tvm.contrib.pickle_memoize import memoize
from ..util import const_matrix
......@@ -131,6 +132,8 @@ def _interpolation_points(degree):
return np.array(in_pts[degree-1], dtype=np.float64)
@memoize("topi.nn.winograd_matrices", save_at_exit=False)
def winograd_transform_matrices(tile_size, kernel_size, out_dtype):
"""Compute the A, B, and G transform matrices for `tile_size` as a `tvm.Expr`.
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment