Skip to content

Commit

Permalink
update to v0.14.0 (#79)
Browse files Browse the repository at this point in the history
update to v0.14.0
  • Loading branch information
yymao authored Nov 17, 2019
2 parents 93ba3a0 + 5dce260 commit 758dd41
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 108 deletions.
59 changes: 36 additions & 23 deletions SAGA/database/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def download_as_file(self, file_path, overwrite=False, compress=False):
with file_open(file_path, "wb") as f:
for chunk in r.iter_content(chunk_size=(16 * 1024 * 1024)):
f.write(chunk)
except:
except: # noqa: E722
if os.path.isfile(file_path):
os.unlink(file_path)
raise
Expand Down Expand Up @@ -195,9 +195,9 @@ class DataObject(object):
def __init__(
self, remote, local=None, cache_in_memory=False, use_local_first=False
):
self._local = None
self.remote = remote
self.local = local
self.local_type = type(remote) if local is None else type(local)
self.use_local_first = use_local_first
self.cache_in_memory = cache_in_memory
self._cached_table = None
Expand All @@ -207,10 +207,22 @@ def __init__(
"Must specify `local` when setting `use_local_first=True`."
)

def _get_local(self):
if self.local is not None and not isinstance(self.local, self.local_type):
self.local = self.local_type(self.local)
return self.local
@property
def local(self):
return self._local

@local.setter
def local(self, value):
if value is None:
self._local = None
elif isinstance(value, FileObject):
self._local = value
elif isinstance(self._local, FileObject):
self._local = type(self._local)(value, **self._local.kwargs)
elif isinstance(self.remote, FileObject):
self._local = type(self.remote)(value, **self.remote.kwargs)
else:
self._local = FileObject(value)

def read(self, reload=False, **kwargs):
"""
Expand All @@ -231,13 +243,13 @@ def read(self, reload=False, **kwargs):
return table

if self.use_local_first:
if not self._get_local().isfile():
if not self.local.isfile():
logging.warning(
"Cannot find local file; attempt to download from remote..."
)
self.download()
try:
table = self._get_local().read()
table = self.local.read()
except (IOError, OSError):
logging.warning(
"Failed to read local file; attempt to read remote file..."
Expand All @@ -247,24 +259,24 @@ def read(self, reload=False, **kwargs):
try:
table = self.remote.read(**kwargs)
except Exception as read_exception: # pylint: disable=W0703
if self._get_local() is None:
if self.local is None:
raise read_exception
logging.warning(
"Failed to read remote; fall back to read local file..."
)
if not self._get_local().isfile():
if not self.local.isfile():
logging.warning(
"Cannot find local file; attempt to download from remote..."
)
self.download()
table = self._get_local().read()
table = self.local.read()

if self.cache_in_memory:
self.store_cache(table)

return table

def write(self, table, dest="remote", overwrite=False):
def write(self, table, dest=None, overwrite=False):
"""
write the data to file
Expand All @@ -277,10 +289,12 @@ def write(self, table, dest="remote", overwrite=False):
overwrite : bool, optional
if set to true, overwrite existing file
"""
if dest is None:
dest = "local" if self.use_local_first else "remote"
if dest.lower() == "remote":
f = self.remote
elif dest.lower() == "local":
f = self._get_local()
f = self.local
else:
raise KeyError('dest must be "remote" or "local"')

Expand All @@ -307,7 +321,7 @@ def download(
"""
if local_file_path is None:
try:
local_file_path = self._get_local().path
local_file_path = self.local.path
except AttributeError:
pass
else:
Expand All @@ -318,15 +332,7 @@ def download(
)

if set_as_local:
kwargs = dict()
try:
kwargs = self._get_local().kwargs
except AttributeError:
try:
kwargs = self.remote.kwargs
except AttributeError:
pass
self.local = self.local_type(local_file_path, **kwargs)
self.local = local_file_path

@staticmethod
def _copy_table(table):
Expand All @@ -342,3 +348,10 @@ def retrive_cache(self):

def store_cache(self, table):
self._cached_table = self._copy_table(table)

@property
def path(self):
return self.local.path if self.use_local_first else self.remote.path

def isfile(self):
return self.local.isfile() if self.use_local_first else self.remote.isfile()
111 changes: 75 additions & 36 deletions SAGA/database/saga_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
"hosts_v2": GoogleSheets(
"1b3k2eyFjHFDtmHce1xi6JKuj3ATOWYduTBFftx5oPp8", 1765625842
),
"host_stats": GoogleSheets(
"1b3k2eyFjHFDtmHce1xi6JKuj3ATOWYduTBFftx5oPp8", 1217798377
),
"host_remove": GoogleSheets(
"1Y3nO7VyU4jDiBPawCs8wJQt2s_PIAKRj-HSrmcWeQZo",
1133875164,
Expand Down Expand Up @@ -143,9 +146,6 @@ def __init__(self, shared_dir=None, local_dir=None):
)
)
),
"saga_clean_specs": DataObject(
FitsTable(os.path.join(self._local_dir, "saga_clean_specs.fits.gz"))
),
"hyperleda_kt12": DataObject(
HyperledaQuery(
"v IS NOT NULL and modbest IS NOT NULL and kt<12 and objtype='G'",
Expand Down Expand Up @@ -377,43 +377,27 @@ def __init__(self, shared_dir=None, local_dir=None):
self._tables["footprint_" + name] = DataObject(obj)

for k, v in known_google_sheets.items():
if k == "hosts_v2":
self._tables[k] = DataObject(
v,
CsvTable(
os.path.join(
self._shared_dir, "HostCatalogs", "host_list_v2.csv"
)
),
cache_in_memory=True,
)
elif k == "hosts_v1":
self._tables[k] = DataObject(
v,
CsvTable(
os.path.join(
self._shared_dir, "HostCatalogs", "host_list_v1.csv"
)
),
cache_in_memory=True,
)
elif k == "lowz_fields":
self._tables[k] = DataObject(
v,
CsvTable(
os.path.join(
self._shared_dir, "HostCatalogs", "lowz_fields.csv"
)
),
cache_in_memory=True,
)
else:
self._tables[k] = DataObject(v, CsvTable(), cache_in_memory=True)
self._tables[k] = DataObject(v, CsvTable(), cache_in_memory=True)

self._tables["hosts_v2"].local = CsvTable(
os.path.join(self._shared_dir, "HostCatalogs", "host_list_v2.csv")
)

self._tables["hosts_v1"].local = CsvTable(
os.path.join(self._shared_dir, "HostCatalogs", "host_list_v1.csv")
)

self._tables["lowz_fields"].local = CsvTable(
os.path.join(self._shared_dir, "HostCatalogs", "lowz_fields.csv")
)

self._tables["hosts"] = self._tables["hosts_v2"]
self._tables["master_list"] = self._tables["master_list_v2"]

self._file_path_pattern = {
"base_v2p1": os.path.join(
self._local_dir, "base_catalogs_v2.1", "base_v2_{}.fits.gz"
),
"base_v2": os.path.join(
self._local_dir, "base_catalogs", "base_v2_{}.fits.gz"
),
Expand All @@ -439,9 +423,38 @@ def __init__(self, shared_dir=None, local_dir=None):
self._local_dir, "external_catalogs", "decals", "{}_decals.fits.gz"
),
}
self._file_path_pattern["base"] = self._file_path_pattern["base_v2"]

self._possible_base_versions = tuple(
k.partition("_v")[2]
for k in self._file_path_pattern
if k.startswith("base_")
)
self._file_path_pattern["sdss"] = self._file_path_pattern["sdss_dr14"]
self._file_path_pattern["des"] = self._file_path_pattern["des_dr1"]
self.set_default_base_version()

def _add_derived_data(self):
t = FitsTable(self.base_file_path_pattern.format("saga_clean_specs"))
if "saga_clean_specs" in self._tables:
self._tables["saga_clean_specs"].remote = t
else:
self._tables["saga_clean_specs"] = DataObject(t)

t = FastCsvTable(
self.base_file_path_pattern.format("host_stats").replace(".fits.gz", ".csv")
)

if "host_stats" in self._tables:
self._tables["host_stats"].local = t
self._tables["host_stats"].use_local_first = True
self._tables["host_stats"].clear_cache()
else:
self._tables["host_stats"] = DataObject(
known_google_sheets["host_stats"],
t,
use_local_first=True,
cache_in_memory=True,
)

def _set_file_path_pattern(self, key, value):
self._file_path_pattern[key] = value
Expand All @@ -452,6 +465,8 @@ def _set_file_path_pattern(self, key, value):
]
for k in keys_to_del:
del self._tables[k]
if key == "base":
self._add_derived_data()

@property
def base_file_path_pattern(self):
Expand Down Expand Up @@ -525,3 +540,27 @@ def keys(self):

def _ipython_key_completions_(self):
return list(self.keys())

def resolve_base_version(self, version=None):
"""
resolve `version` into `build_version` and `version_postfix`
"""

if version is None:
return 2, ""
version = str(version).lower().strip().lstrip("v")
if version in ("paper1", "p1"):
return 0, "_v0p1"
while version.endswith(".0"):
version = version[:-2]
version.replace(".", "p")
if version not in self._possible_base_versions:
raise ValueError("version value unknown!")
return int(version[0]), "_v" + version

def set_default_base_version(self, version=None):
_, version_postfix = self.resolve_base_version(version)
if not version_postfix:
version_postfix = '_v2p1'
self._file_path_pattern["base"] = self._file_path_pattern["base" + version_postfix]
self._add_derived_data()
1 change: 1 addition & 0 deletions SAGA/hosts/cuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
good_hosts = Query("HOST_SCORE >= 3")
preferred_hosts = Query("HOST_SCORE >= 4")
good = good_hosts & has_image
build_default = potential_hosts & has_image

hostlist_v1 = QueryMaker.in1d("PGC", _list_by_pgc["hostlist_v1"])
paper1_complete = QueryMaker.in1d("PGC", _list_by_pgc["paper1_complete"])
Expand Down
Loading

0 comments on commit 758dd41

Please sign in to comment.