Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Commit

Permalink
Merge pull request #650 from OP2/connorjward/fix-deadlocking-decorator
Browse files Browse the repository at this point in the history
Fix deadlocking disk_cached decorator
  • Loading branch information
connorjward authored Mar 22, 2022
2 parents 4ba2952 + 462da86 commit 5fb8559
Showing 1 changed file with 42 additions and 22 deletions.
64 changes: 42 additions & 22 deletions pyop2/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,9 +273,10 @@ def decorator(func):
def wrapper(*args, **kwargs):
if collective:
comm, disk_key = key(*args, **kwargs)
disk_key = _as_hexdigest(disk_key)
k = hash_comm(comm), disk_key
else:
k = key(*args, **kwargs)
k = _as_hexdigest(key(*args, **kwargs))

# first try the in-memory cache
try:
Expand All @@ -286,40 +287,59 @@ def wrapper(*args, **kwargs):
# then try to retrieve from disk
if collective:
if comm.rank == 0:
v = _disk_cache_setdefault(cachedir, disk_key, lambda: func(*args, **kwargs))
v = _disk_cache_get(cachedir, disk_key)
comm.bcast(v, root=0)
else:
v = comm.bcast(None, root=0)
else:
v = _disk_cache_setdefault(cachedir, k, lambda: func(*args, **kwargs))
v = _disk_cache_get(cachedir, k)
if v is not None:
return cache.setdefault(k, v)

# if all else fails call func and populate the caches
v = func(*args, **kwargs)
if collective:
if comm.rank == 0:
_disk_cache_set(cachedir, disk_key, v)
else:
_disk_cache_set(cachedir, k, v)
return cache.setdefault(k, v)
return wrapper
return decorator


def _disk_cache_setdefault(cachedir, key, default):
"""If ``key`` is in cache, return it. If not, store ``default`` in the cache
and return it.
def _as_hexdigest(key):
return hashlib.md5(str(key).encode()).hexdigest()

:arg cachedir: The cache directory.
:arg key: The cache key.
:arg default: Lazily evaluated callable that returns a new value to insert into the cache.

:returns: The value associated with ``key``.
"""
key = hashlib.md5(str(key).encode()).hexdigest()
key1, key2 = key[:2], key[2:]
def _disk_cache_get(cachedir, key):
"""Retrieve a value from the disk cache.
basedir = Path(cachedir, key1)
filepath = basedir.joinpath(key2)
:arg cachedir: The cache directory.
:arg key: The cache key (must be a string).
:returns: The cached object if found, else ``None``.
"""
filepath = Path(cachedir, key[:2], key[2:])
try:
with open(filepath, "rb") as f:
return pickle.load(f)
except FileNotFoundError:
basedir.mkdir(parents=True, exist_ok=True)
tempfile = basedir.joinpath(f"{key2}_p{os.getpid()}.tmp")
obj = default()
with open(tempfile, "wb") as f:
pickle.dump(obj, f)
tempfile.rename(filepath)
return obj
return None


def _disk_cache_set(cachedir, key, value):
"""Store a new value in the disk cache.
:arg cachedir: The cache directory.
:arg key: The cache key (must be a string).
:arg value: The new item to store in the cache.
"""
k1, k2 = key[:2], key[2:]
basedir = Path(cachedir, k1)
basedir.mkdir(parents=True, exist_ok=True)

tempfile = basedir.joinpath(f"{k2}_p{os.getpid()}.tmp")
filepath = basedir.joinpath(k2)
with open(tempfile, "wb") as f:
pickle.dump(value, f)
tempfile.rename(filepath)

0 comments on commit 5fb8559

Please sign in to comment.