diff --git a/pyop2/caching.py b/pyop2/caching.py index 7b0b85735..24a3f5513 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -273,9 +273,10 @@ def decorator(func): def wrapper(*args, **kwargs): if collective: comm, disk_key = key(*args, **kwargs) + disk_key = _as_hexdigest(disk_key) k = hash_comm(comm), disk_key else: - k = key(*args, **kwargs) + k = _as_hexdigest(key(*args, **kwargs)) # first try the in-memory cache try: @@ -286,40 +287,59 @@ def wrapper(*args, **kwargs): # then try to retrieve from disk if collective: if comm.rank == 0: - v = _disk_cache_setdefault(cachedir, disk_key, lambda: func(*args, **kwargs)) + v = _disk_cache_get(cachedir, disk_key) comm.bcast(v, root=0) else: v = comm.bcast(None, root=0) else: - v = _disk_cache_setdefault(cachedir, k, lambda: func(*args, **kwargs)) + v = _disk_cache_get(cachedir, k) + if v is not None: + return cache.setdefault(k, v) + + # if all else fails call func and populate the caches + v = func(*args, **kwargs) + if collective: + if comm.rank == 0: + _disk_cache_set(cachedir, disk_key, v) + else: + _disk_cache_set(cachedir, k, v) return cache.setdefault(k, v) return wrapper return decorator -def _disk_cache_setdefault(cachedir, key, default): - """If ``key`` is in cache, return it. If not, store ``default`` in the cache - and return it. +def _as_hexdigest(key): + return hashlib.md5(str(key).encode()).hexdigest() - :arg cachedir: The cache directory. - :arg key: The cache key. - :arg default: Lazily evaluated callable that returns a new value to insert into the cache. - :returns: The value associated with ``key``. - """ - key = hashlib.md5(str(key).encode()).hexdigest() - key1, key2 = key[:2], key[2:] +def _disk_cache_get(cachedir, key): + """Retrieve a value from the disk cache. - basedir = Path(cachedir, key1) - filepath = basedir.joinpath(key2) + :arg cachedir: The cache directory. + :arg key: The cache key (must be a string). + :returns: The cached object if found, else ``None``. + """ + filepath = Path(cachedir, key[:2], key[2:]) try: with open(filepath, "rb") as f: return pickle.load(f) except FileNotFoundError: - basedir.mkdir(parents=True, exist_ok=True) - tempfile = basedir.joinpath(f"{key2}_p{os.getpid()}.tmp") - obj = default() - with open(tempfile, "wb") as f: - pickle.dump(obj, f) - tempfile.rename(filepath) - return obj + return None + + +def _disk_cache_set(cachedir, key, value): + """Store a new value in the disk cache. + + :arg cachedir: The cache directory. + :arg key: The cache key (must be a string). + :arg value: The new item to store in the cache. + """ + k1, k2 = key[:2], key[2:] + basedir = Path(cachedir, k1) + basedir.mkdir(parents=True, exist_ok=True) + + tempfile = basedir.joinpath(f"{k2}_p{os.getpid()}.tmp") + filepath = basedir.joinpath(k2) + with open(tempfile, "wb") as f: + pickle.dump(value, f) + tempfile.rename(filepath)