From 6b1cd7bffad636cd558ad403ccbc0291139340dc Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Mon, 28 Oct 2024 15:58:13 +0100 Subject: [PATCH] [ENH] add caching action for elephant-data (#633) - add caching action for elephant-data --- .github/workflows/CI.yml | 161 ++++++++++++------ .github/workflows/cache_elephant_data.yml | 65 +++++++ elephant/datasets.py | 149 ++++++++-------- elephant/test/test_causality.py | 28 +-- elephant/test/test_datasets.py | 35 ++++ elephant/test/test_phase_analysis.py | 45 ++--- elephant/test/test_spectral.py | 13 +- elephant/test/test_spike_train_correlation.py | 12 +- .../test/test_spike_train_dissimilarity.py | 28 +-- elephant/test/test_spike_train_synchrony.py | 2 +- 10 files changed, 368 insertions(+), 170 deletions(-) create mode 100644 .github/workflows/cache_elephant_data.yml create mode 100644 elephant/test/test_datasets.py diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 91f42577d..d01bba6be 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -70,34 +70,32 @@ jobs: fail-fast: false steps: - # used to reset cache every month - - name: Get current year-month - id: date - run: echo "date=$(date +'%Y-%m')" >> $GITHUB_OUTPUT - - - name: Get pip cache dir - id: pip-cache - run: | - echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT - - - uses: actions/checkout@v3 + - uses: actions/checkout@v4.1.6 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5.1.0 with: python-version: ${{ matrix.python-version }} + check-latest: true cache: 'pip' - cache-dependency-path: '**/requirements.txt' + cache-dependency-path: | + **/requirements.txt + **/requirements-extras.txt + **/requirements-tests.txt - - name: Cache test_env - uses: actions/cache@v3 + - name: Get current hash (SHA) of the elephant_data repo + id: elephant-data + run: | + echo "dataset_hash=$(git ls-remote https://gin.g-node.org/NeuralEnsemble/elephant-data.git HEAD | cut -f1)" >> $GITHUB_OUTPUT + + - uses: actions/cache/restore@v4.0.2 + # Loading cache of elephant-data + id: cache-datasets with: - path: ${{ steps.pip-cache.outputs.dir }} - # Look to see if there is a cache hit for the corresponding requirements files - # cache will be reset on changes to any requirements or every month - key: ${{ runner.os }}-venv-${{ hashFiles('**/requirements.txt') }}-${{ hashFiles('**/requirements-tests.txt') }} - -${{ hashFiles('**/requirements-extras.txt') }}-${{ hashFiles('**/CI.yml') }}-${{ hashFiles('setup.py') }} - -${{ steps.date.outputs.date }} + path: ~/elephant-data + key: datasets-${{ steps.elephant-data.outputs.dataset_hash }} + restore-keys: datasets- + enableCrossOsArchive: true - name: Install dependencies run: | @@ -112,6 +110,11 @@ jobs: - name: Test with pytest run: | + if [ -d ~/elephant-data ]; then + export ELEPHANT_DATA_LOCATION=~/elephant-data + echo $ELEPHANT_DATA_LOCATION + fi + coverage run --source=elephant -m pytest coveralls --service=github || echo "Coveralls submission failed" env: @@ -146,6 +149,19 @@ jobs: path: ~/conda_pkgs_dir key: ${{ runner.os }}-conda-${{hashFiles('requirements/environment.yml') }}-${{ hashFiles('**/CI.yml') }}-${{ steps.date.outputs.date }} + - name: Get current hash (SHA) of the elephant_data repo + id: elephant-data + run: | + echo "dataset_hash=$(git ls-remote https://gin.g-node.org/NeuralEnsemble/elephant-data.git HEAD | cut -f1)" >> $GITHUB_OUTPUT + + - uses: actions/cache/restore@v4.0.2 + # Loading cache of elephant-data + id: cache-datasets + with: + path: ~/elephant-data + key: datasets-${{ steps.elephant-data.outputs.dataset_hash }} + restore-keys: datasets- + - uses: conda-incubator/setup-miniconda@a4260408e20b96e80095f42ff7f1a15b27dd94ca # corresponds to v3.0.4 with: auto-update-conda: true @@ -173,6 +189,10 @@ jobs: - name: Test with pytest shell: bash -l {0} run: | + if [ -d ~/elephant-data ]; then + export ELEPHANT_DATA_LOCATION=~/elephant-data + echo $ELEPHANT_DATA_LOCATION + fi pytest --cov=elephant # __ ___ _ @@ -192,24 +212,32 @@ jobs: os: [windows-latest] steps: - - name: Get current year-month - id: date - run: echo "date=$(date +'%Y-%m')" >> $GITHUB_OUTPUT - - - uses: actions/checkout@v3 + - uses: actions/checkout@v4.1.6 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5.1.0 with: python-version: ${{ matrix.python-version }} + check-latest: true + cache: 'pip' + cache-dependency-path: | + **/requirements.txt + **/requirements-extras.txt + **/requirements-tests.txt - - name: Cache pip - uses: actions/cache@v3 + - name: Get current hash (SHA) of the elephant_data repo + id: elephant-data + run: | + echo "dataset_hash=$(git ls-remote https://gin.g-node.org/NeuralEnsemble/elephant-data.git HEAD | cut -f1)" >> $GITHUB_OUTPUT + + - uses: actions/cache/restore@v4.0.2 + # Loading cache of elephant-data + id: cache-datasets with: - path: ~\AppData\Local\pip\Cache - # Look to see if there is a cache hit for the corresponding requirements files - key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}-${{ hashFiles('**/requirements-tests.txt') }} - -${{ hashFiles('**/requirements-extras.txt') }}-${{ hashFiles('setup.py') }} -${{ hashFiles('**/CI.yml') }}-${{ steps.date.outputs.date }} + path: ~/elephant-data + key: datasets-${{ steps.elephant-data.outputs.dataset_hash }} + restore-keys: datasets- + enableCrossOsArchive: true - name: Install dependencies run: | @@ -224,6 +252,10 @@ jobs: - name: Test with pytest run: | + if (Test-Path "$env:USERPROFILE\elephant-data") { + $env:ELEPHANT_DATA_LOCATION = "$env:USERPROFILE\elephant-data" + Write-Output $env:ELEPHANT_DATA_LOCATION + } pytest --cov=elephant # __ __ ____ ___ @@ -246,29 +278,32 @@ jobs: fail-fast: false steps: - - name: Get current year-month - id: date - run: echo "date=$(date +'%Y-%m')" >> $GITHUB_OUTPUT - - uses: actions/checkout@v3 + - uses: actions/checkout@v4.1.6 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5.1.0 with: python-version: ${{ matrix.python-version }} + check-latest: true + cache: 'pip' + cache-dependency-path: | + **/requirements.txt + **/requirements-extras.txt + **/requirements-tests.txt - - name: Get pip cache dir - id: pip-cache + - name: Get current hash (SHA) of the elephant_data repo + id: elephant-data run: | - echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + echo "dataset_hash=$(git ls-remote https://gin.g-node.org/NeuralEnsemble/elephant-data.git HEAD | cut -f1)" >> $GITHUB_OUTPUT - - name: Cache test_env - uses: actions/cache@v3 + - uses: actions/cache/restore@v4.0.2 + # Loading cache of elephant-data + id: cache-datasets with: - path: ${{ steps.pip-cache.outputs.dir }} - # look to see if there is a cache hit for the corresponding requirements files - # cache will be reset on changes to any requirements or every month - key: ${{ runner.os }}-venv-${{ hashFiles('**/requirements.txt') }}-${{ hashFiles('**/requirements-tests.txt') }} - -${{ hashFiles('**/requirements-extras.txt') }}-${{ hashFiles('setup.py') }} -${{ hashFiles('**/CI.yml') }}-${{ steps.date.outputs.date }} + path: ~/elephant-data + key: datasets-${{ steps.elephant-data.outputs.dataset_hash }} + restore-keys: datasets- + enableCrossOsArchive: true - name: Setup environment run: | @@ -287,6 +322,10 @@ jobs: - name: Test with pytest run: | + if [ -d ~/elephant-data ]; then + export ELEPHANT_DATA_LOCATION=~/elephant-data + echo $ELEPHANT_DATA_LOCATION + fi mpiexec -n 1 python -m mpi4py -m coverage run --source=elephant -m pytest coveralls --service=github || echo "Coveralls submission failed" env: @@ -316,7 +355,7 @@ jobs: id: date run: echo "date=$(date +'%Y-%m')" >> $GITHUB_OUTPUT - - uses: actions/checkout@v3 + - uses: actions/checkout@v4.1.6 - name: Get pip cache dir id: pip-cache @@ -330,6 +369,20 @@ jobs: key: ${{ runner.os }}-pip-${{hashFiles('requirements/environment-tests.yml') }}-${{ hashFiles('**/CI.yml') }}-${{ steps.date.outputs.date }} + - name: Get current hash (SHA) of the elephant_data repo + id: elephant-data + run: | + echo "dataset_hash=$(git ls-remote https://gin.g-node.org/NeuralEnsemble/elephant-data.git HEAD | cut -f1)" >> $GITHUB_OUTPUT + + - uses: actions/cache/restore@v4.0.2 + # Loading cache of elephant-data + id: cache-datasets + with: + path: ~/elephant-data + key: datasets-${{ steps.elephant-data.outputs.dataset_hash }} + restore-keys: datasets- + enableCrossOsArchive: true + - uses: conda-incubator/setup-miniconda@030178870c779d9e5e1b4e563269f3aa69b04081 # corresponds to v3.0.3 with: auto-update-conda: true @@ -358,6 +411,10 @@ jobs: - name: Test with pytest shell: bash -el {0} run: | + if [ -d ~/elephant-data ]; then + export ELEPHANT_DATA_LOCATION=~/elephant-data + echo $ELEPHANT_DATA_LOCATION + fi pytest --cov=elephant # ____ @@ -383,7 +440,7 @@ jobs: id: date run: echo "date=$(date +'%Y-%m')" >> $GITHUB_OUTPUT - - uses: actions/checkout@v3 + - uses: actions/checkout@v4.1.6 - name: Get pip cache dir id: pip-cache @@ -448,10 +505,10 @@ jobs: - name: Get current year-month id: date run: echo "::set-output name=date::$(date +'%Y-%m')" - - uses: actions/checkout@v3 + - uses: actions/checkout@v4.1.6 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5.1.0 with: python-version: ${{ matrix.python-version }} diff --git a/.github/workflows/cache_elephant_data.yml b/.github/workflows/cache_elephant_data.yml new file mode 100644 index 000000000..ab17af4f9 --- /dev/null +++ b/.github/workflows/cache_elephant_data.yml @@ -0,0 +1,65 @@ +name: Create caches for elephant_data + +on: + workflow_dispatch: # Workflow can be triggered manually via GH actions webinterface + push: # When something is pushed into master this checks if caches need to re-created + branches: + - master + schedule: + - cron: "11 23 * * *" # Daily at 23:11 UTC + + +jobs: + create-data-cache-if-missing: + name: Caching data env + runs-on: ubuntu-latest + strategy: + # do not cancel all in-progress jobs if any matrix job fails + fail-fast: false + + steps: + - name: Get current hash (SHA) of the elephant_data repo + id: elephant-data + run: | + echo "dataset_hash=$(git ls-remote https://gin.g-node.org/NeuralEnsemble/elephant-data.git HEAD | cut -f1)" >> $GITHUB_OUTPUT + + - uses: actions/cache@v4.0.2 + # Loading cache of elephant-data + id: cache-datasets + with: + path: ~/elephant-data + key: datasets-${{ steps.elephant-data.outputs.dataset_hash }} + + - name: Cache found? + run: echo "Cache-hit == ${{steps.cache-datasets.outputs.cache-hit == 'true'}}" + + - name: Configuring git + if: steps.cache-datasets.outputs.cache-hit != 'true' + run: | + git config --global user.email "elephant_ci@fake_mail.com" + git config --global user.name "elephant CI" + git config --global filter.annex.process "git-annex filter-process" # recommended for efficiency + + - name: Install Datalad Linux + if: steps.cache-datasets.outputs.cache-hit != 'true' + run: | + python -m pip install -U pip # Official recommended way + pip install datalad-installer + datalad-installer --sudo ok git-annex --method datalad/packages + pip install datalad + + - name: Download dataset + id: download-dataset + if: steps.cache-datasets.outputs.cache-hit != 'true' + # Download repository and also fetch data + run: | + cd ~ + datalad --version + datalad install --recursive --get-data https://gin.g-node.org/NeuralEnsemble/elephant-data + + - name: Show size of the cache to assert data is downloaded + run: | + cd ~ + du -hs ~/elephant-data + ls -lh ~/elephant-data + diff --git a/elephant/datasets.py b/elephant/datasets.py index 58d52d31f..39ed56911 100644 --- a/elephant/datasets.py +++ b/elephant/datasets.py @@ -1,17 +1,19 @@ import hashlib +import os +import ssl import tempfile +from urllib.parse import urlparse import warnings -import ssl - -from elephant import _get_version +from os import environ, getenv from pathlib import Path -from urllib.request import urlretrieve, urlopen from urllib.error import HTTPError, URLError +from urllib.request import urlopen, urlretrieve from zipfile import ZipFile -from os import environ, getenv from tqdm import tqdm +from elephant import _get_version + ELEPHANT_TMP_DIR = Path(tempfile.gettempdir()) / "elephant" @@ -74,65 +76,76 @@ def download(url, filepath=None, checksum=None, verbose=True): def download_datasets(repo_path, filepath=None, checksum=None, verbose=True): r""" - This function can be used to download files from elephant-data using - only the path relative to the root of the elephant-data repository. - The default URL used, points to elephants corresponding release of - elephant-data. - Different versions of the elephant package may require different - versions of elephant-data. - e.g. the follwoing URLs: - - https://web.gin.g-node.org/NeuralEnsemble/elephant-data/raw/0.0.1 - points to release v0.0.1. - - https://web.gin.g-node.org/NeuralEnsemble/elephant-data/raw/master - always points to the latest state of elephant-data. - - https://datasets.python-elephant.org/ - points to the root of elephant data - - To change this URL, use the environment variable `ELEPHANT_DATA_URL`. - When using data, which is not yet contained in the master branch or a - release of elephant data, e.g. during development, this variable can - be used to change the default URL. - For example to use data on branch `multitaper`, change the - `ELEPHANT_DATA_URL` to - https://web.gin.g-node.org/NeuralEnsemble/elephant-data/raw/multitaper. - For a complete example, see Examples section. - - Parameters - ---------- - repo_path : str - String denoting the path relative to elephant-data repository root - filepath : str, optional - Path to temporary folder where the downloaded files will be stored - checksum : str, otpional - Checksum to verify dara integrity after download - verbose : bool, optional - Whether to disable the entire progressbar wrapper []. - If set to None, disable on non-TTY. - Default: True - - Returns - ------- - filepath : pathlib.Path - Path to downloaded files. - - - Notes - ----- - The default URL always points to elephant-data. Please - do not change its value. For development purposes use the environment - variable 'ELEPHANT_DATA_URL'. - - Examples - -------- - The following example downloads a file from elephant-data branch - 'multitaper', by setting the environment variable to the branch URL: - - >>> import os - >>> from elephant.datasets import download_datasets - >>> os.environ["ELEPHANT_DATA_URL"] = "https://web.gin.g-node.org/NeuralEnsemble/elephant-data/raw/multitaper" # noqa - >>> download_datasets("unittest/spectral/multitaper_psd/data/time_series.npy") # doctest: +SKIP - PosixPath('/tmp/elephant/time_series.npy') - """ + This function can be used to download files from elephant-data using + only the path relative to the root of the elephant-data repository. + The default URL used, points to elephants corresponding release of + elephant-data. + Different versions of the elephant package may require different + versions of elephant-data. + e.g. the following URLs: + - https://web.gin.g-node.org/NeuralEnsemble/elephant-data/raw/0.0.1 + points to release v0.0.1. + - https://web.gin.g-node.org/NeuralEnsemble/elephant-data/raw/master + always points to the latest state of elephant-data. + - https://datasets.python-elephant.org/ + points to the root of elephant data + + To change this URL, use the environment variable `ELEPHANT_DATA_LOCATION`. + When using data, which is not yet contained in the master branch or a + release of elephant data, e.g. during development, this variable can + be used to change the default URL. + For example to use data on branch `multitaper`, change the + `ELEPHANT_DATA_LOCATION` to + https://web.gin.g-node.org/NeuralEnsemble/elephant-data/raw/multitaper. + For a complete example, see Examples section. + + To use a local copy of elephant-data, use the environment variable + `ELEPHANT_DATA_LOCATION`, e.g. set to /home/user/elephant-data. + + Parameters + ---------- + repo_path : str + String denoting the path relative to elephant-data repository root + filepath : str, optional + Path to temporary folder where the downloaded files will be stored + checksum : str, optional + Checksum to verify data integrity after download + verbose : bool, optional + Whether to disable the entire progressbar wrapper []. + If set to None, disable on non-TTY. + Default: True + + Returns + ------- + filepath : pathlib.Path + Path to downloaded files. + + + Notes + ----- + The default URL always points to elephant-data. Please + do not change its value. For development purposes use the environment + variable 'ELEPHANT_DATA_LOCATION'. + + Examples + -------- + The following example downloads a file from elephant-data branch + 'multitaper', by setting the environment variable to the branch URL: + + >>> import os + >>> from elephant.datasets import download_datasets + >>> os.environ["ELEPHANT_DATA_LOCATION"] = "https://web.gin.g-node.org/NeuralEnsemble/elephant-data/raw/multitaper" # noqa + >>> download_datasets("unittest/spectral/multitaper_psd/data/time_series.npy") # doctest: +SKIP + PosixPath('/tmp/elephant/time_series.npy') + """ + + env_var = 'ELEPHANT_DATA_LOCATION' + if env_var in os.environ: # user did set path or URL + if os.path.exists(getenv(env_var)): + return Path(f"{getenv(env_var)}/{repo_path}") + elif urlparse(getenv(env_var)).scheme not in ('http', 'https'): + raise ValueError(f"The environment variable {env_var} must be set to either an existing file system path " + f"or a valid URL. Given value: '{getenv(env_var)}' is neither.") # this url redirects to the current location of elephant-data url_to_root = "https://datasets.python-elephant.org/" @@ -141,7 +154,7 @@ def download_datasets(repo_path, filepath=None, checksum=None, # (version elephant is equal to version elephant-data) default_url = url_to_root + f"raw/v{_get_version()}" - if 'ELEPHANT_DATA_URL' not in environ: # user did not set URL + if env_var not in environ: # user did not set URL # is 'version-URL' available? (not for elephant development version) try: urlopen(default_url+'/README.md') @@ -149,7 +162,7 @@ def download_datasets(repo_path, filepath=None, checksum=None, except HTTPError as error: # if corresponding elephant-data version is not found, # use latest commit of elephant-data - default_url = url_to_root + f"raw/master" + default_url = url_to_root + "raw/master" warnings.warn(f"No corresponding version of elephant-data found.\n" f"Elephant version: {_get_version()}. " @@ -164,12 +177,12 @@ def download_datasets(repo_path, filepath=None, checksum=None, ctx.check_hostname = True urlopen(default_url + '/README.md') except HTTPError: # e.g. 404 - default_url = url_to_root + f"raw/master" + default_url = url_to_root + "raw/master" warnings.warn(f"Data URL:{default_url}, error: {error}." f"{error.reason}") - url = f"{getenv('ELEPHANT_DATA_URL', default_url)}/{repo_path}" + url = f"{getenv(env_var, default_url)}/{repo_path}" return download(url, filepath, checksum, verbose) diff --git a/elephant/test/test_causality.py b/elephant/test/test_causality.py index 552fe2f05..648af2a04 100644 --- a/elephant/test/test_causality.py +++ b/elephant/test/test_causality.py @@ -16,7 +16,7 @@ from elephant.spectral import multitaper_cross_spectrum, multitaper_coherence import elephant.causality.granger -from elephant.datasets import download_datasets, ELEPHANT_TMP_DIR +from elephant.datasets import download_datasets class PairwiseGrangerTestCase(unittest.TestCase): @@ -453,12 +453,16 @@ def test_pairwise_spectral_granger_against_ground_truth(self): ("noise_covariance.npy", "6f80ccff2b2aa9485dc9c01d81570bf5") ] + downloaded_files = {} for filename, checksum in files_to_download: - download_datasets(repo_path=f"{repo_path}/{filename}", - checksum=checksum) - signals = np.load(ELEPHANT_TMP_DIR / 'time_series.npy') - weights = np.load(ELEPHANT_TMP_DIR / 'weights.npy') - cov = np.load(ELEPHANT_TMP_DIR / 'noise_covariance.npy') + downloaded_files[filename] = { + 'filename': filename, + 'path': download_datasets(repo_path=f"{repo_path}/{filename}", + checksum=checksum)} + + signals = np.load(downloaded_files['time_series.npy']['path']) + weights = np.load(downloaded_files['weights.npy']['path']) + cov = np.load(downloaded_files['noise_covariance.npy']['path']) # Estimate spectral Granger Causality f, spectral_causality = \ @@ -532,11 +536,15 @@ def test_pairwise_spectral_granger_against_r_grangers(self): ("gc_matrix.npy", "c57262145e74a178588ff0a1004879e2") ] + downloaded_files = {} for filename, checksum in files_to_download: - download_datasets(repo_path=f"{repo_path}/{filename}", - checksum=checksum) - signal = np.load(ELEPHANT_TMP_DIR / 'time_series_small.npy') - gc_matrix = np.load(ELEPHANT_TMP_DIR / 'gc_matrix.npy') + downloaded_files[filename] = { + 'filename': filename, + 'path': download_datasets(repo_path=f"{repo_path}/{filename}", + checksum=checksum)} + + signal = np.load(downloaded_files['time_series_small.npy']['path']) + gc_matrix = np.load(downloaded_files['gc_matrix.npy']['path']) denom = 20 f, spectral_causality = \ diff --git a/elephant/test/test_datasets.py b/elephant/test/test_datasets.py new file mode 100644 index 000000000..24334e2f6 --- /dev/null +++ b/elephant/test/test_datasets.py @@ -0,0 +1,35 @@ +import unittest +import os +from unittest.mock import patch +from pathlib import Path +import urllib + +from elephant.datasets import download_datasets + + +class TestDownloadDatasets(unittest.TestCase): + @patch.dict(os.environ, {'ELEPHANT_DATA_LOCATION': '/valid/path'}, clear=True) + @patch('os.path.exists', return_value=True) + def test_valid_path(self, mock_exists): + repo_path = 'some/repo/path' + expected = Path('/valid/path/some/repo/path') + result = download_datasets(repo_path) + self.assertEqual(result, expected) + + @patch.dict(os.environ, {'ELEPHANT_DATA_LOCATION': 'http://valid.url'}, clear=True) + @patch('os.path.exists', return_value=False) + def test_valid_url(self, mock_exists): + repo_path = 'some/repo/path' + self.assertRaises(urllib.error.URLError, download_datasets, repo_path) + + @patch.dict(os.environ, {'ELEPHANT_DATA_LOCATION': 'invalid_path_or_url'}, clear=True) + @patch('os.path.exists', return_value=False) + def test_invalid_value(self, mock_exists): + repo_path = 'some/repo/path' + with self.assertRaises(ValueError) as cm: + download_datasets(repo_path) + self.assertIn("invalid_path_or_url", str(cm.exception)) + + +if __name__ == '__main__': + unittest.main() diff --git a/elephant/test/test_phase_analysis.py b/elephant/test/test_phase_analysis.py index 6b8218c99..c328681ec 100644 --- a/elephant/test/test_phase_analysis.py +++ b/elephant/test/test_phase_analysis.py @@ -384,6 +384,7 @@ def setUpClass(cls): # downloaded once into a local temporary directory # and then loaded/ read for each test function individually. + cls.tmp_path = {} # REAL DATA real_data_path = "unittest/phase_analysis/weighted_phase_lag_index/" \ "data/wpli_real_data" @@ -398,8 +399,10 @@ def setUpClass(cls): ) for filename, checksum in cls.files_to_download_real: # files will be downloaded to ELEPHANT_TMP_DIR - cls.tmp_path = download_datasets( - f"{real_data_path}/{filename}", checksum=checksum) + cls.tmp_path[filename] = { + 'filename': filename, + 'path': download_datasets( + f"{real_data_path}/{filename}", checksum=checksum)} # ARTIFICIAL DATA artificial_data_path = "unittest/phase_analysis/" \ "weighted_phase_lag_index/data/wpli_specific_artificial_dataset" @@ -409,11 +412,13 @@ def setUpClass(cls): ) for filename, checksum in cls.files_to_download_artificial: # files will be downloaded to ELEPHANT_TMP_DIR - cls.tmp_path = download_datasets( - f"{artificial_data_path}/{filename}", checksum=checksum) + cls.tmp_path[filename] = { + 'filename': filename, + 'path': download_datasets( + f"{artificial_data_path}/{filename}", checksum=checksum)} # GROUND TRUTH DATA ground_truth_data_path = "unittest/phase_analysis/" \ - "weighted_phase_lag_index/data/wpli_ground_truth" + "weighted_phase_lag_index/data/wpli_ground_truth" cls.files_to_download_ground_truth = ( ("ground_truth_WPLI_from_ft_connectivity_wpli_" "with_real_LFPs_R2G.csv", "4d9a7b7afab7d107023956077ab11fef"), @@ -422,8 +427,10 @@ def setUpClass(cls): ) for filename, checksum in cls.files_to_download_ground_truth: # files will be downloaded into ELEPHANT_TMP_DIR - cls.tmp_path = download_datasets( - f"{ground_truth_data_path}/{filename}", checksum=checksum) + cls.tmp_path[filename] = { + 'filename': filename, + 'path': download_datasets( + f"{ground_truth_data_path}/{filename}", checksum=checksum)} def setUp(self): self.tolerance = 1e-15 @@ -431,10 +438,10 @@ def setUp(self): # load real/artificial LFP-dataset for ground-truth consistency checks # real LFP-dataset dataset1_real = scipy.io.loadmat( - f"{self.tmp_path.parent}/{self.files_to_download_real[0][0]}", + f"{self.tmp_path[self.files_to_download_real[0][0]]['path']}", squeeze_me=True) dataset2_real = scipy.io.loadmat( - f"{self.tmp_path.parent}/{self.files_to_download_real[1][0]}", + f"{self.tmp_path[self.files_to_download_real[1][0]]['path']}", squeeze_me=True) # get relevant values @@ -449,12 +456,12 @@ def setUp(self): signal=self.lfps2_real, sampling_rate=self.sf2_real) # artificial LFP-dataset - dataset1_artificial = scipy.io.loadmat( - f"{self.tmp_path.parent}/" - f"{self.files_to_download_artificial[0][0]}", squeeze_me=True) - dataset2_artificial = scipy.io.loadmat( - f"{self.tmp_path.parent}/" - f"{self.files_to_download_artificial[1][0]}", squeeze_me=True) + dataset1_path = \ + f"{self.tmp_path[self.files_to_download_artificial[0][0]]['path']}" + dataset1_artificial = scipy.io.loadmat(dataset1_path, squeeze_me=True) + dataset2_path = \ + f"{self.tmp_path[self.files_to_download_artificial[1][0]]['path']}" + dataset2_artificial = scipy.io.loadmat(dataset2_path, squeeze_me=True) # get relevant values self.lfps1_artificial = dataset1_artificial['lfp_matrix'] * pq.uV self.sf1_artificial = dataset1_artificial['sf'] * pq.Hz @@ -469,12 +476,10 @@ def setUp(self): # load ground-truth reference calculated by: # Matlab package 'FieldTrip': ft_connectivity_wpli() self.wpli_ground_truth_ft_connectivity_wpli_real = np.loadtxt( - f"{self.tmp_path.parent}/" - f"{self.files_to_download_ground_truth[0][0]}", + f"{self.tmp_path[self.files_to_download_ground_truth[0][0]]['path']}", # noqa delimiter=',', dtype=np.float64) self.wpli_ground_truth_ft_connectivity_artificial = np.loadtxt( - f"{self.tmp_path.parent}/" - f"{self.files_to_download_ground_truth[1][0]}", + f"{self.tmp_path[self.files_to_download_ground_truth[1][0]]['path']}", # noqa delimiter=',', dtype=np.float64) def test_WPLI_ground_truth_consistency_real_LFP_dataset(self): @@ -700,7 +705,7 @@ def test_WPLI_raises_error_if_AnalogSignals_have_diff_sampling_rate(): def test_WPLI_raises_error_if_sampling_rate_not_given(self): """ Test if WPLI raises a ValueError, when the sampling rate is not given - for np.array() or Quanitity input. + for np.array() or Quantity input. """ signal_x = np.random.random([40, 2100]) * pq.mV signal_y = np.random.random([40, 2100]) * pq.mV diff --git a/elephant/test/test_spectral.py b/elephant/test/test_spectral.py index 41244fb66..00cdd6941 100644 --- a/elephant/test/test_spectral.py +++ b/elephant/test/test_spectral.py @@ -20,7 +20,7 @@ from packaging import version import elephant.spectral -from elephant.datasets import download_datasets, ELEPHANT_TMP_DIR +from elephant.datasets import download_datasets class WelchPSDTestCase(unittest.TestCase): @@ -278,12 +278,15 @@ def test_multitaper_psd_against_nitime(self): ("psd_nitime.npy", "89d1f53957e66c786049ea425b53c0e8") ] + downloaded_files = {} for filename, checksum in files_to_download: - download_datasets(repo_path=f"{repo_path}/{filename}", - checksum=checksum) + downloaded_files[filename] = { + 'filename': filename, + 'path': download_datasets(repo_path=f"{repo_path}/{filename}", + checksum=checksum)} - time_series = np.load(ELEPHANT_TMP_DIR / 'time_series.npy') - psd_nitime = np.load(ELEPHANT_TMP_DIR / 'psd_nitime.npy') + time_series = np.load(downloaded_files['time_series.npy']['path']) + psd_nitime = np.load(downloaded_files['psd_nitime.npy']['path']) freqs, psd_multitaper = elephant.spectral.multitaper_psd( signal=time_series, fs=0.1, nw=4, num_tapers=8) diff --git a/elephant/test/test_spike_train_correlation.py b/elephant/test/test_spike_train_correlation.py index cae01e479..90de65ea1 100644 --- a/elephant/test/test_spike_train_correlation.py +++ b/elephant/test/test_spike_train_correlation.py @@ -20,7 +20,7 @@ from elephant.spike_train_generation import StationaryPoissonProcess, \ StationaryGammaProcess import math -from elephant.datasets import download_datasets, ELEPHANT_TMP_DIR +from elephant.datasets import download_datasets from elephant.spike_train_generation import homogeneous_poisson_process, \ homogeneous_gamma_process @@ -852,11 +852,15 @@ def test_sttc_validation_test(self): files_to_download = [("spike_time_tiling_coefficient_results.nix", "e3749d79046622494660a03e89950f51")] + downloaded_files = {} for filename, checksum in files_to_download: - filepath = download_datasets(repo_path=f"{repo_path}/{filename}", - checksum=checksum) + downloaded_files[filename] = { + 'filename': filename, + 'path': download_datasets(repo_path=f"{repo_path}/{filename}", + checksum=checksum)} - reader = NixIO(filepath, mode='ro') + reader = NixIO(downloaded_files[ + 'spike_time_tiling_coefficient_results.nix']['path'], mode='ro') test_data_block = reader.read() for segment in test_data_block[0].segments: diff --git a/elephant/test/test_spike_train_dissimilarity.py b/elephant/test/test_spike_train_dissimilarity.py index 6cab2858b..4619d4bba 100644 --- a/elephant/test/test_spike_train_dissimilarity.py +++ b/elephant/test/test_spike_train_dissimilarity.py @@ -15,7 +15,7 @@ from elephant.spike_train_generation import StationaryPoissonProcess import elephant.spike_train_dissimilarity as stds -from elephant.datasets import download_datasets, ELEPHANT_TMP_DIR +from elephant.datasets import download_datasets class TimeScaleDependSpikeTrainDissimMeasuresTestCase(unittest.TestCase): @@ -398,12 +398,16 @@ def test_victor_purpura_matlab_comparison_float(self): ("times_float.npy", "ed1ff4d2c0eeed4a2b50a456803656be"), ("matlab_results_float.npy", "a17f049e7ad0ddf7ca812e86fdb92646")] + downloaded_files = {} for filename, checksum in files_to_download: - download_datasets(repo_path=f"{repo_path}/{filename}", - checksum=checksum) + downloaded_files[filename] = { + 'filename': filename, + 'path': download_datasets(repo_path=f"{repo_path}/{filename}", + checksum=checksum)} - times_float = np.load(ELEPHANT_TMP_DIR / 'times_float.npy') - mat_res_float = np.load(ELEPHANT_TMP_DIR / 'matlab_results_float.npy') + times_float = np.load(downloaded_files['times_float.npy']['path']) + mat_res_float = np.load(downloaded_files[ + 'matlab_results_float.npy']['path']) r_float = SpikeTrain(times_float[0], units='ms', t_start=0, t_stop=1000 * ms) @@ -428,12 +432,16 @@ def test_victor_purpura_matlab_comparison_int(self): ("times_int.npy", "aa1411c04da3f58d8b8913ae2f935057"), ("matlab_results_int.npy", "7edd32e50edde12dc1ef4aa5f57f70fb")] + downloaded_files = {} for filename, checksum in files_to_download: - download_datasets(repo_path=f"{repo_path}/{filename}", - checksum=checksum) - - times_int = np.load(ELEPHANT_TMP_DIR / 'times_int.npy') - mat_res_int = np.load(ELEPHANT_TMP_DIR / 'matlab_results_int.npy') + downloaded_files[filename] = { + 'filename': filename, + 'path': download_datasets(repo_path=f"{repo_path}/{filename}", + checksum=checksum)} + + times_int = np.load(downloaded_files['times_int.npy']['path']) + mat_res_int = np.load( + downloaded_files['matlab_results_int.npy']['path']) r_int = SpikeTrain(times_int[0], units='ms', t_start=0, t_stop=1000 * ms) diff --git a/elephant/test/test_spike_train_synchrony.py b/elephant/test/test_spike_train_synchrony.py index 58be525eb..d3455b115 100644 --- a/elephant/test/test_spike_train_synchrony.py +++ b/elephant/test/test_spike_train_synchrony.py @@ -168,7 +168,7 @@ def test_spike_contrast_with_Izhikevich_network_auto(self): checksum = "70e848500c1d9c6403b66de8c741d849" filepath_zip = download_datasets(repo_path=izhikevich_gin, checksum=checksum) - unzip(filepath_zip) + unzip(filepath_zip, outdir=filepath_zip.parent) filepath_json = filepath_zip.with_suffix(".json") with open(filepath_json) as read_file: data = json.load(read_file)