From 1ac4d0311cf391d707d5e185002d1ca970e567e4 Mon Sep 17 00:00:00 2001 From: Jonathan Green Date: Wed, 4 Sep 2024 11:12:15 -0300 Subject: [PATCH] Convert MARC Export to use Celery (PP-1472) (#2017) This PR takes the approach of processing batch_size (default: 500) records in one task, then saving the output to redis and re-queuing the task to process the next batch_size of records. Once the data in redis is large enough, a multipart upload is started in S3, and the multipart data is cached in redis. This continues until the file is completely generated. --- .gitignore | 3 + README.md | 2 +- bin/cache_marc_files | 6 - docker-compose.yml | 2 +- docker/services/cron/cron.d/circulation | 3 - .../api/admin/controller/catalog_services.py | 4 +- src/palace/manager/api/circulation_manager.py | 5 +- src/palace/manager/api/controller/marc.py | 16 +- src/palace/manager/celery/tasks/marc.py | 161 ++++ src/palace/manager/marc/__init__.py | 0 .../{core/marc.py => marc/annotator.py} | 406 ++------ src/palace/manager/marc/exporter.py | 373 ++++++++ src/palace/manager/marc/settings.py | 73 ++ src/palace/manager/marc/uploader.py | 144 +++ src/palace/manager/scripts/marc.py | 223 ----- src/palace/manager/service/celery/celery.py | 6 + .../integration_registry/catalog_services.py | 14 +- .../manager/service/redis/models/lock.py | 203 +++- .../manager/service/redis/models/marc.py | 312 ++++++ src/palace/manager/service/redis/redis.py | 15 + src/palace/manager/service/storage/s3.py | 2 + tests/conftest.py | 1 + tests/fixtures/database.py | 59 +- tests/fixtures/marc.py | 99 ++ tests/fixtures/s3.py | 173 +++- .../admin/controller/test_catalog_services.py | 29 +- tests/manager/api/controller/test_marc.py | 21 +- tests/manager/celery/tasks/test_marc.py | 321 +++++++ tests/manager/core/test_marc.py | 900 ------------------ tests/manager/marc/__init__.py | 0 tests/manager/marc/test_annotator.py | 716 ++++++++++++++ tests/manager/marc/test_exporter.py | 425 +++++++++ tests/manager/marc/test_uploader.py | 334 +++++++ tests/manager/scripts/test_marc.py | 466 --------- .../manager/service/redis/models/test_lock.py | 153 ++- .../manager/service/redis/models/test_marc.py | 469 +++++++++ tests/manager/service/storage/test_s3.py | 97 +- tox.ini | 2 +- 38 files changed, 4162 insertions(+), 2076 deletions(-) delete mode 100755 bin/cache_marc_files create mode 100644 src/palace/manager/celery/tasks/marc.py create mode 100644 src/palace/manager/marc/__init__.py rename src/palace/manager/{core/marc.py => marc/annotator.py} (60%) create mode 100644 src/palace/manager/marc/exporter.py create mode 100644 src/palace/manager/marc/settings.py create mode 100644 src/palace/manager/marc/uploader.py delete mode 100644 src/palace/manager/scripts/marc.py create mode 100644 src/palace/manager/service/redis/models/marc.py create mode 100644 tests/fixtures/marc.py create mode 100644 tests/manager/celery/tasks/test_marc.py delete mode 100644 tests/manager/core/test_marc.py create mode 100644 tests/manager/marc/__init__.py create mode 100644 tests/manager/marc/test_annotator.py create mode 100644 tests/manager/marc/test_exporter.py create mode 100644 tests/manager/marc/test_uploader.py delete mode 100644 tests/manager/scripts/test_marc.py create mode 100644 tests/manager/service/redis/models/test_marc.py diff --git a/.gitignore b/.gitignore index 7ee0099696..50d20bcb5c 100644 --- a/.gitignore +++ b/.gitignore @@ -78,3 +78,6 @@ docs/source/* .DS_Store src/palace/manager/core/_version.py + +# Celery beat schedule file +celerybeat-schedule.db diff 
--git a/README.md b/README.md index 2589105bd5..25ae67f438 100644 --- a/README.md +++ b/README.md @@ -149,7 +149,7 @@ grant all privileges on database circ to palace; Redis is used as the broker for Celery and the caching layer. You can run Redis with docker using the following command: ```sh -docker run -d --name redis -p 6379:6379 redis +docker run -d --name redis -p 6379:6379 redis/redis-stack-server ``` ### Environment variables diff --git a/bin/cache_marc_files b/bin/cache_marc_files deleted file mode 100755 index b42e34ce62..0000000000 --- a/bin/cache_marc_files +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/env python -"""Refresh and store the MARC files for lanes.""" - -from palace.manager.scripts.marc import CacheMARCFiles - -CacheMARCFiles().run() diff --git a/docker-compose.yml b/docker-compose.yml index f9e801d4a2..ad6faefe4d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -110,7 +110,7 @@ services: retries: 5 redis: - image: "redis:7" + image: "redis/redis-stack-server:7.4.0-v0" healthcheck: test: ["CMD", "redis-cli", "ping"] interval: 30s diff --git a/docker/services/cron/cron.d/circulation b/docker/services/cron/cron.d/circulation index c82a52652b..10122136ed 100644 --- a/docker/services/cron/cron.d/circulation +++ b/docker/services/cron/cron.d/circulation @@ -36,9 +36,6 @@ HOME=/var/www/circulation # Sync a library's collection with NoveList 0 0 * * 0 root bin/run -d 60 novelist_update >> /var/log/cron.log 2>&1 -# Generate MARC files for libraries that have a MARC exporter configured. -0 3,11 * * * root bin/run cache_marc_files >> /var/log/cron.log 2>&1 - # The remaining scripts keep the circulation manager in sync with # specific types of collections. diff --git a/src/palace/manager/api/admin/controller/catalog_services.py b/src/palace/manager/api/admin/controller/catalog_services.py index 5245ec6ee1..b2dd6ddaef 100644 --- a/src/palace/manager/api/admin/controller/catalog_services.py +++ b/src/palace/manager/api/admin/controller/catalog_services.py @@ -8,9 +8,9 @@ ) from palace.manager.api.admin.form_data import ProcessFormData from palace.manager.api.admin.problem_details import MULTIPLE_SERVICES_FOR_LIBRARY -from palace.manager.core.marc import MARCExporter from palace.manager.integration.goals import Goals from palace.manager.integration.settings import BaseSettings +from palace.manager.marc.exporter import MarcExporter from palace.manager.sqlalchemy.listeners import site_configuration_has_changed from palace.manager.sqlalchemy.model.integration import ( IntegrationConfiguration, @@ -21,7 +21,7 @@ class CatalogServicesController( - IntegrationSettingsController[MARCExporter], + IntegrationSettingsController[MarcExporter], AdminPermissionsControllerMixin, ): def process_catalog_services(self) -> Response | ProblemDetail: diff --git a/src/palace/manager/api/circulation_manager.py b/src/palace/manager/api/circulation_manager.py index 36d569c7e3..d6d2d682b7 100644 --- a/src/palace/manager/api/circulation_manager.py +++ b/src/palace/manager/api/circulation_manager.py @@ -343,7 +343,10 @@ def setup_one_time_controllers(self): """ self.index_controller = IndexController(self) self.opds_feeds = OPDSFeedController(self) - self.marc_records = MARCRecordController(self.services.storage.public()) + self.marc_records = MARCRecordController( + self.services.storage.public(), + self.services.integration_registry.catalog_services(), + ) self.loans = LoanController(self) self.annotations = AnnotationController(self) self.urn_lookup = URNLookupController(self) diff --git 
a/src/palace/manager/api/controller/marc.py b/src/palace/manager/api/controller/marc.py index 802a576081..3114fbca40 100644 --- a/src/palace/manager/api/controller/marc.py +++ b/src/palace/manager/api/controller/marc.py @@ -9,8 +9,11 @@ from sqlalchemy import select from sqlalchemy.orm import Session -from palace.manager.core.marc import MARCExporter from palace.manager.integration.goals import Goals +from palace.manager.marc.exporter import MarcExporter +from palace.manager.service.integration_registry.catalog_services import ( + CatalogServicesRegistry, +) from palace.manager.service.storage.s3 import S3Service from palace.manager.sqlalchemy.model.collection import Collection from palace.manager.sqlalchemy.model.integration import ( @@ -49,21 +52,24 @@ class MARCRecordController: """ - def __init__(self, storage_service: S3Service | None) -> None: + def __init__( + self, storage_service: S3Service | None, registry: CatalogServicesRegistry + ) -> None: self.storage_service = storage_service + self.registry = registry @staticmethod def library() -> Library: return flask.request.library # type: ignore[no-any-return,attr-defined] - @staticmethod - def has_integration(session: Session, library: Library) -> bool: + def has_integration(self, session: Session, library: Library) -> bool: + protocols = self.registry.get_protocols(MarcExporter) integration_query = ( select(IntegrationLibraryConfiguration) .join(IntegrationConfiguration) .where( IntegrationConfiguration.goal == Goals.CATALOG_GOAL, - IntegrationConfiguration.protocol == MARCExporter.__name__, + IntegrationConfiguration.protocol.in_(protocols), IntegrationLibraryConfiguration.library == library, ) ) diff --git a/src/palace/manager/celery/tasks/marc.py b/src/palace/manager/celery/tasks/marc.py new file mode 100644 index 0000000000..2d164adcb2 --- /dev/null +++ b/src/palace/manager/celery/tasks/marc.py @@ -0,0 +1,161 @@ +import datetime +from typing import Any + +from celery import shared_task + +from palace.manager.celery.task import Task +from palace.manager.marc.exporter import LibraryInfo, MarcExporter +from palace.manager.marc.uploader import MarcUploadManager +from palace.manager.service.celery.celery import QueueNames +from palace.manager.service.redis.models.marc import ( + MarcFileUploadSession, + MarcFileUploadState, +) +from palace.manager.util.datetime_helpers import utc_now + + +@shared_task(queue=QueueNames.default, bind=True) +def marc_export(task: Task, force: bool = False) -> None: + """ + Export MARC records for all collections with the `export_marc_records` flag set to True, whose libraries + have a MARC exporter integration enabled. + """ + + with task.session() as session: + registry = task.services.integration_registry.catalog_services() + start_time = utc_now() + collections = MarcExporter.enabled_collections(session, registry) + for collection in collections: + # Collection.id should never be able to be None here, but mypy doesn't know that. + # So we assert it for mypy's benefit. + assert collection.id is not None + upload_session = MarcFileUploadSession( + task.services.redis.client(), collection.id + ) + with upload_session.lock() as acquired: + if not acquired: + task.log.info( + f"Skipping collection {collection.name} ({collection.id}) because another task holds its lock." 
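The skip-if-locked guard that `marc_export` applies per collection can be illustrated with a plain redis-py lock. This is only a sketch of the general pattern — the PR itself uses its own `MarcFileUploadSession` lock built on RedisJSON — and the key name and timeout below are made up:

```python
import redis

client = redis.Redis()


def export_collection_if_idle(collection_id: int) -> None:
    # One lock per collection; the timeout guards against a crashed worker
    # holding the lock forever. Both values here are illustrative.
    lock = client.lock(f"marc-export:{collection_id}", timeout=20 * 60)
    if not lock.acquire(blocking=False):
        print(f"Skipping collection {collection_id}: another task holds the lock.")
        return
    try:
        ...  # generate records for this collection
    finally:
        lock.release()
```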
+ ) + continue + + if ( + upload_state := upload_session.state() + ) != MarcFileUploadState.INITIAL: + task.log.info( + f"Skipping collection {collection.name} ({collection.id}) because it is already being " + f"processed (state: {upload_state})." + ) + continue + + libraries_info = MarcExporter.enabled_libraries( + session, registry, collection.id + ) + needs_update = ( + any(info.needs_update for info in libraries_info) or force + ) + + if not needs_update: + task.log.info( + f"Skipping collection {collection.name} ({collection.id}) because it has been updated recently." + ) + continue + + works = MarcExporter.query_works( + session, + collection.id, + work_id_offset=0, + batch_size=1, + ) + if not works: + task.log.info( + f"Skipping collection {collection.name} ({collection.id}) because it has no works." + ) + continue + + task.log.info( + f"Generating MARC records for collection {collection.name} ({collection.id})." + ) + upload_session.set_state(MarcFileUploadState.QUEUED) + marc_export_collection.delay( + collection_id=collection.id, + start_time=start_time, + libraries=[l.dict() for l in libraries_info], + ) + + +@shared_task(queue=QueueNames.default, bind=True) +def marc_export_collection( + task: Task, + collection_id: int, + start_time: datetime.datetime, + libraries: list[dict[str, Any]], + batch_size: int = 500, + last_work_id: int | None = None, + update_number: int = 0, +) -> None: + """ + Export MARC records for a single collection. + + This task is designed to be re-queued until all works in the collection have been processed, + this can take some time, however each individual task should complete quickly, so that it + doesn't block other tasks from running. + """ + + base_url = task.services.config.sitewide.base_url() + storage_service = task.services.storage.public() + libraries_info = [LibraryInfo.parse_obj(l) for l in libraries] + upload_manager = MarcUploadManager( + storage_service, + MarcFileUploadSession( + task.services.redis.client(), collection_id, update_number + ), + ) + with upload_manager.begin(): + if not upload_manager.locked: + task.log.info( + f"Skipping collection {collection_id} because another task is already processing it." + ) + return + + with task.session() as session: + works = MarcExporter.query_works( + session, + collection_id, + work_id_offset=last_work_id, + batch_size=batch_size, + ) + for work in works: + MarcExporter.process_work( + work, libraries_info, base_url, upload_manager=upload_manager + ) + + # Sync the upload_manager to ensure that all the data is written to storage. + upload_manager.sync() + + if len(works) == batch_size: + # This task is complete, but there are more works waiting to be exported. So we requeue ourselves + # to process the next batch. + raise task.replace( + marc_export_collection.s( + collection_id=collection_id, + start_time=start_time, + libraries=[l.dict() for l in libraries_info], + batch_size=batch_size, + last_work_id=works[-1].id, + update_number=upload_manager.update_number, + ) + ) + + # If we got here, we have finished generating MARC records. Cleanup and exit. 
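The batch-and-requeue flow that `marc_export_collection` implements can be reduced to a small self-contained sketch. Everything here is illustrative (the fake `WORK_IDS` dataset, the batch size, the task and helper names); the one real piece of Celery machinery is `self.replace()`, which swaps the running task for a new signature so that each individual run stays short:

```python
from celery import shared_task

WORK_IDS = list(range(1234))  # stand-in for the works in a collection
BATCH_SIZE = 500              # mirrors the PR's default batch_size


def next_batch(after_id: int | None, limit: int) -> list[int]:
    # Stand-in for MarcExporter.query_works: ids greater than after_id, ascending.
    start = 0 if after_id is None else after_id + 1
    return [i for i in WORK_IDS if i >= start][:limit]


@shared_task(bind=True)
def export_batch(self, collection_id: int, last_work_id: int | None = None) -> None:
    ids = next_batch(last_work_id, BATCH_SIZE)
    for work_id in ids:
        ...  # generate the MARC record for work_id and buffer it

    if len(ids) == BATCH_SIZE:
        # A full batch means more works probably remain: replace this task
        # with a fresh copy of itself that starts after the last processed id.
        raise self.replace(
            export_batch.s(collection_id=collection_id, last_work_id=ids[-1])
        )
    # A short (or empty) batch means the collection is finished.
```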
+ with task.transaction() as session: + collection = MarcExporter.collection(session, collection_id) + collection_name = collection.name if collection else "unknown" + completed_uploads = upload_manager.complete() + MarcExporter.create_marc_upload_records( + session, start_time, collection_id, libraries_info, completed_uploads + ) + upload_manager.remove_session() + task.log.info( + f"Finished generating MARC records for collection '{collection_name}' ({collection_id})." + ) diff --git a/src/palace/manager/marc/__init__.py b/src/palace/manager/marc/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/palace/manager/core/marc.py b/src/palace/manager/marc/annotator.py similarity index 60% rename from src/palace/manager/core/marc.py rename to src/palace/manager/marc/annotator.py index d6720d7f72..47446955f0 100644 --- a/src/palace/manager/core/marc.py +++ b/src/palace/manager/marc/annotator.py @@ -2,40 +2,20 @@ import re import urllib.parse -from collections.abc import Mapping -from datetime import datetime -from io import BytesIO -from uuid import UUID, uuid4 +from collections.abc import Mapping, Sequence -import pytz -from pydantic import NonNegativeInt from pymarc import Field, Indicators, Record, Subfield -from sqlalchemy import select -from sqlalchemy.engine import ScalarResult -from sqlalchemy.orm.session import Session +from sqlalchemy.orm import Session from palace.manager.core.classifier import Classifier -from palace.manager.integration.base import HasLibraryIntegrationConfiguration -from palace.manager.integration.settings import ( - BaseSettings, - ConfigurationFormItem, - ConfigurationFormItemType, - FormField, -) -from palace.manager.service.storage.s3 import S3Service -from palace.manager.sqlalchemy.model.collection import Collection from palace.manager.sqlalchemy.model.edition import Edition from palace.manager.sqlalchemy.model.identifier import Identifier -from palace.manager.sqlalchemy.model.library import Library from palace.manager.sqlalchemy.model.licensing import DeliveryMechanism, LicensePool -from palace.manager.sqlalchemy.model.marcfile import MarcFile from palace.manager.sqlalchemy.model.resource import Representation from palace.manager.sqlalchemy.model.work import Work -from palace.manager.sqlalchemy.util import create from palace.manager.util.datetime_helpers import utc_now from palace.manager.util.languages import LanguageCodes from palace.manager.util.log import LoggerMixin -from palace.manager.util.uuid import uuid_encode class Annotator(LoggerMixin): @@ -63,83 +43,90 @@ class Annotator(LoggerMixin): (Representation.PDF_MEDIA_TYPE, DeliveryMechanism.ADOBE_DRM): "Adobe PDF eBook", } - def __init__( - self, - cm_url: str, + @classmethod + def marc_record(cls, work: Work, license_pool: LicensePool) -> Record: + edition = license_pool.presentation_edition + identifier = license_pool.identifier + + record = cls._record() + cls.add_control_fields(record, identifier, license_pool, edition) + cls.add_isbn(record, identifier) + + # TODO: The 240 and 130 fields are for translated works, so they can be grouped even + # though they have different titles. We do not group editions of the same work in + # different languages, so we can't use those yet. 
+ + cls.add_title(record, edition) + cls.add_contributors(record, edition) + cls.add_publisher(record, edition) + cls.add_physical_description(record, edition) + cls.add_audience(record, work) + cls.add_series(record, edition) + cls.add_system_details(record) + cls.add_ebooks_subject(record) + cls.add_distributor(record, license_pool) + cls.add_formats(record, license_pool) + cls.add_summary(record, work) + cls.add_genres(record, work) + + return record + + @classmethod + def library_marc_record( + cls, + record: Record, + identifier: Identifier, + base_url: str, library_short_name: str, - web_client_urls: list[str], + web_client_urls: Sequence[str], organization_code: str | None, include_summary: bool, include_genres: bool, - ) -> None: - self.cm_url = cm_url - self.library_short_name = library_short_name - self.web_client_urls = web_client_urls - self.organization_code = organization_code - self.include_summary = include_summary - self.include_genres = include_genres - - def annotate_work_record( - self, - revised: bool, - work: Work, - active_license_pool: LicensePool, - edition: Edition, - identifier: Identifier, ) -> Record: - """Add metadata from this work to a MARC record. - - :param revised: Whether this record is being revised. - :param work: The Work whose record is being annotated. - :param active_license_pool: Of all the LicensePools associated with this - Work, the client has expressed interest in this one. - :param edition: The Edition to use when associating bibliographic - metadata with this entry. - :param identifier: Of all the Identifiers associated with this - Work, the client has expressed interest in this one. - - :return: A pymarc Record object. - """ - record = Record(leader=self.leader(revised), force_utf8=True) - self.add_control_fields(record, identifier, active_license_pool, edition) - self.add_isbn(record, identifier) + record = cls._copy_record(record) - # TODO: The 240 and 130 fields are for translated works, so they can be grouped even - # though they have different titles. We do not group editions of the same work in - # different languages, so we can't use those yet. 
+ if organization_code: + cls.add_marc_organization_code(record, organization_code) - self.add_title(record, edition) - self.add_contributors(record, edition) - self.add_publisher(record, edition) - self.add_physical_description(record, edition) - self.add_audience(record, work) - self.add_series(record, edition) - self.add_system_details(record) - self.add_ebooks_subject(record) - self.add_distributor(record, active_license_pool) - self.add_formats(record, active_license_pool) + fields_to_remove = [] - if self.organization_code: - self.add_marc_organization_code(record, self.organization_code) + if not include_summary: + fields_to_remove.append("520") - if self.include_summary: - self.add_summary(record, work) + if not include_genres: + fields_to_remove.append("650") - if self.include_genres: - self.add_genres(record, work) + if fields_to_remove: + record.remove_fields(*fields_to_remove) - self.add_web_client_urls( + cls.add_web_client_urls( record, identifier, - self.library_short_name, - self.cm_url, - self.web_client_urls, + library_short_name, + base_url, + web_client_urls, ) return record @classmethod - def leader(cls, revised: bool) -> str: + def _record(cls, leader: str | None = None) -> Record: + leader = leader or cls.leader() + return Record(leader=leader, force_utf8=True) + + @classmethod + def _copy_record(cls, record: Record) -> Record: + copied = cls._record(record.leader) + copied.add_field(*record.get_fields()) + return copied + + @classmethod + def set_revised(cls, record: Record, revised: bool = True) -> Record: + record.leader.record_status = "c" if revised else "n" + return record + + @classmethod + def leader(cls, revised: bool = False) -> str: # The record length is automatically updated once fields are added. initial_record_length = "00000" @@ -558,20 +545,20 @@ def add_web_client_urls( record: Record, identifier: Identifier, library_short_name: str, - cm_url: str, - web_client_urls: list[str], + base_url: str, + web_client_urls: Sequence[str], ) -> None: qualified_identifier = urllib.parse.quote( f"{identifier.type}/{identifier.identifier}", safe="" ) + link = "{}/{}/works/{}".format( + base_url, + library_short_name, + qualified_identifier, + ) + encoded_link = urllib.parse.quote(link, safe="") for web_client_base_url in web_client_urls: - link = "{}/{}/works/{}".format( - cm_url, - library_short_name, - qualified_identifier, - ) - encoded_link = urllib.parse.quote(link, safe="") url = f"{web_client_base_url}/book/{encoded_link}" record.add_field( Field( @@ -580,244 +567,3 @@ def add_web_client_urls( subfields=[Subfield(code="u", value=url)], ) ) - - -class MarcExporterSettings(BaseSettings): - # This setting (in days) controls how often MARC files should be - # automatically updated. Since the crontab in docker isn't easily - # configurable, we can run a script daily but check this to decide - # whether to do anything. 
- update_frequency: NonNegativeInt = FormField( - 30, - form=ConfigurationFormItem( - label="Update frequency (in days)", - type=ConfigurationFormItemType.NUMBER, - required=True, - ), - alias="marc_update_frequency", - ) - - -class MarcExporterLibrarySettings(BaseSettings): - # MARC organization codes are assigned by the - # Library of Congress and can be found here: - # http://www.loc.gov/marc/organizations/org-search.php - organization_code: str | None = FormField( - None, - form=ConfigurationFormItem( - label="The MARC organization code for this library (003 field).", - description="MARC organization codes are assigned by the Library of Congress.", - type=ConfigurationFormItemType.TEXT, - ), - alias="marc_organization_code", - ) - - web_client_url: str | None = FormField( - None, - form=ConfigurationFormItem( - label="The base URL for the web catalog for this library, for the 856 field.", - description="If using a library registry that provides a web catalog, this can be left blank.", - type=ConfigurationFormItemType.TEXT, - ), - alias="marc_web_client_url", - ) - - include_summary: bool = FormField( - False, - form=ConfigurationFormItem( - label="Include summaries in MARC records (520 field)", - type=ConfigurationFormItemType.SELECT, - options={"false": "Do not include summaries", "true": "Include summaries"}, - ), - ) - - include_genres: bool = FormField( - False, - form=ConfigurationFormItem( - label="Include Palace Collection Manager genres in MARC records (650 fields)", - type=ConfigurationFormItemType.SELECT, - options={ - "false": "Do not include Palace Collection Manager genres", - "true": "Include Palace Collection Manager genres", - }, - ), - alias="include_simplified_genres", - ) - - -class MARCExporter( - HasLibraryIntegrationConfiguration[ - MarcExporterSettings, MarcExporterLibrarySettings - ], - LoggerMixin, -): - """Turn a work into a record for a MARC file.""" - - # The minimum size each piece of a multipart upload should be - MINIMUM_UPLOAD_BATCH_SIZE_BYTES = 5 * 1024 * 1024 # 5MB - - def __init__( - self, - _db: Session, - storage_service: S3Service, - ): - self._db = _db - self.storage_service = storage_service - - @classmethod - def label(cls) -> str: - return "MARC Export" - - @classmethod - def description(cls) -> str: - return ( - "Export metadata into MARC files that can be imported into an ILS manually." 
- ) - - @classmethod - def settings_class(cls) -> type[MarcExporterSettings]: - return MarcExporterSettings - - @classmethod - def library_settings_class(cls) -> type[MarcExporterLibrarySettings]: - return MarcExporterLibrarySettings - - @classmethod - def create_record( - cls, - revised: bool, - work: Work, - annotator: Annotator, - ) -> Record | None: - """Build a complete MARC record for a given work.""" - pool = work.active_license_pool() - if not pool: - return None - - edition = pool.presentation_edition - identifier = pool.identifier - - return annotator.annotate_work_record(revised, work, pool, edition, identifier) - - @staticmethod - def _date_to_string(date: datetime) -> str: - return date.astimezone(pytz.UTC).strftime("%Y-%m-%d") - - def _file_key( - self, - uuid: UUID, - library: Library, - collection: Collection, - creation_time: datetime, - since_time: datetime | None = None, - ) -> str: - """The path to the hosted MARC file for the given library, collection, - and date range.""" - root = "marc" - short_name = str(library.short_name) - creation = self._date_to_string(creation_time) - - if since_time: - file_type = f"delta.{self._date_to_string(since_time)}.{creation}" - else: - file_type = f"full.{creation}" - - uuid_encoded = uuid_encode(uuid) - collection_name = collection.name.replace(" ", "_") - filename = f"{collection_name}.{file_type}.{uuid_encoded}.mrc" - parts = [root, short_name, filename] - return "/".join(parts) - - def query_works( - self, - collection: Collection, - since_time: datetime | None, - creation_time: datetime, - batch_size: int, - ) -> ScalarResult: - query = ( - select(Work) - .join(LicensePool) - .join(Collection) - .where( - Collection.id == collection.id, - Work.last_update_time <= creation_time, - ) - ) - - if since_time is not None: - query = query.where(Work.last_update_time >= since_time) - - return self._db.execute(query).unique().yield_per(batch_size).scalars() - - def records( - self, - library: Library, - collection: Collection, - annotator: Annotator, - *, - creation_time: datetime, - since_time: datetime | None = None, - batch_size: int = 500, - ) -> None: - """ - Create and export a MARC file for the books in a collection. - """ - uuid = uuid4() - key = self._file_key(uuid, library, collection, creation_time, since_time) - - with self.storage_service.multipart( - key, - content_type=Representation.MARC_MEDIA_TYPE, - ) as upload: - this_batch = BytesIO() - - works = self.query_works(collection, since_time, creation_time, batch_size) - for work in works: - # Create a record for each work and add it to the MARC file in progress. - record = self.create_record( - since_time is not None, - work, - annotator, - ) - if record: - record_bytes = record.as_marc() - this_batch.write(record_bytes) - if ( - this_batch.getbuffer().nbytes - >= self.MINIMUM_UPLOAD_BATCH_SIZE_BYTES - ): - # We've reached or exceeded the upload threshold. - # Upload one part of the multipart document. - upload.upload_part(this_batch.getvalue()) - this_batch.seek(0) - this_batch.truncate() - - # Upload the final part of the multi-document, if - # necessary. - if this_batch.getbuffer().nbytes > 0: - upload.upload_part(this_batch.getvalue()) - - if upload.complete: - create( - self._db, - MarcFile, - id=uuid, - library=library, - collection=collection, - created=creation_time, - since=since_time, - key=key, - ) - else: - if upload.exception: - # Log the exception and move on to the next file. We will try again next script run. 
- self.log.error( - f"Failed to upload MARC file for {library.short_name}/{collection.name}: {upload.exception}", - exc_info=upload.exception, - ) - else: - # There were no records to upload. This is not an error, but we should log it. - self.log.info( - f"No MARC records to upload for {library.short_name}/{collection.name}." - ) diff --git a/src/palace/manager/marc/exporter.py b/src/palace/manager/marc/exporter.py new file mode 100644 index 0000000000..13a587a7ba --- /dev/null +++ b/src/palace/manager/marc/exporter.py @@ -0,0 +1,373 @@ +from __future__ import annotations + +import datetime +from collections.abc import Iterable, Sequence +from uuid import UUID, uuid4 + +import pytz +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.orm import Session, aliased + +from palace.manager.integration.base import HasLibraryIntegrationConfiguration +from palace.manager.integration.goals import Goals +from palace.manager.marc.annotator import Annotator +from palace.manager.marc.settings import ( + MarcExporterLibrarySettings, + MarcExporterSettings, +) +from palace.manager.marc.uploader import MarcUploadManager +from palace.manager.service.integration_registry.catalog_services import ( + CatalogServicesRegistry, +) +from palace.manager.sqlalchemy.model.collection import Collection +from palace.manager.sqlalchemy.model.discovery_service_registration import ( + DiscoveryServiceRegistration, +) +from palace.manager.sqlalchemy.model.integration import ( + IntegrationConfiguration, + IntegrationLibraryConfiguration, +) +from palace.manager.sqlalchemy.model.library import Library +from palace.manager.sqlalchemy.model.licensing import LicensePool +from palace.manager.sqlalchemy.model.marcfile import MarcFile +from palace.manager.sqlalchemy.model.work import Work +from palace.manager.sqlalchemy.util import create +from palace.manager.util.datetime_helpers import utc_now +from palace.manager.util.log import LoggerMixin +from palace.manager.util.uuid import uuid_encode + + +class LibraryInfo(BaseModel): + library_id: int + library_short_name: str + last_updated: datetime.datetime | None + needs_update: bool + organization_code: str | None + include_summary: bool + include_genres: bool + web_client_urls: tuple[str, ...] + + s3_key_full_uuid: str + s3_key_full: str + + s3_key_delta_uuid: str + s3_key_delta: str | None = None + + class Config: + frozen = True + + +class MarcExporter( + HasLibraryIntegrationConfiguration[ + MarcExporterSettings, MarcExporterLibrarySettings + ], + LoggerMixin, +): + """ + This class provides the logic for exporting MARC records for a collection to S3. + """ + + @classmethod + def label(cls) -> str: + return "MARC Export" + + @classmethod + def description(cls) -> str: + return ( + "Export metadata into MARC files that can be imported into an ILS manually." 
+ ) + + @classmethod + def settings_class(cls) -> type[MarcExporterSettings]: + return MarcExporterSettings + + @classmethod + def library_settings_class(cls) -> type[MarcExporterLibrarySettings]: + return MarcExporterLibrarySettings + + @staticmethod + def _s3_key( + library: Library, + collection: Collection, + creation_time: datetime.datetime, + uuid: UUID, + since_time: datetime.datetime | None = None, + ) -> str: + """The path to the hosted MARC file for the given library, collection, + and date range.""" + + def date_to_string(date: datetime.datetime) -> str: + return date.astimezone(pytz.UTC).strftime("%Y-%m-%d") + + root = "marc" + short_name = str(library.short_name) + creation = date_to_string(creation_time) + + if since_time: + file_type = f"delta.{date_to_string(since_time)}.{creation}" + else: + file_type = f"full.{creation}" + + uuid_encoded = uuid_encode(uuid) + collection_name = collection.name.replace(" ", "_") + filename = f"{collection_name}.{file_type}.{uuid_encoded}.mrc" + parts = [root, short_name, filename] + return "/".join(parts) + + @staticmethod + def _needs_update( + last_updated_time: datetime.datetime | None, update_frequency: int + ) -> bool: + return not last_updated_time or ( + last_updated_time.date() + <= (utc_now() - datetime.timedelta(days=update_frequency)).date() + ) + + @staticmethod + def _web_client_urls( + session: Session, library: Library, url: str | None = None + ) -> tuple[str, ...]: + """Find web client URLs configured by the registry for this library.""" + urls = [ + s.web_client + for s in session.execute( + select(DiscoveryServiceRegistration.web_client).where( + DiscoveryServiceRegistration.library == library, + DiscoveryServiceRegistration.web_client != None, + ) + ).all() + ] + + if url: + urls.append(url) + + return tuple(urls) + + @classmethod + def _enabled_collections_and_libraries( + cls, + session: Session, + registry: CatalogServicesRegistry, + collection_id: int | None = None, + ) -> set[tuple[Collection, IntegrationLibraryConfiguration]]: + collection_integration_configuration = aliased(IntegrationConfiguration) + collection_integration_library_configuration = aliased( + IntegrationLibraryConfiguration + ) + library_integration_library_configuration = aliased( + IntegrationLibraryConfiguration, + name="library_integration_library_configuration", + ) + library_integration_configuration = aliased(IntegrationConfiguration) + + protocols = registry.get_protocols(cls) + + collection_query = ( + select(Collection, library_integration_library_configuration) + .select_from(Collection) + .join(collection_integration_configuration) + .join(collection_integration_library_configuration) + .join(Library) + .join(library_integration_library_configuration) + .join(library_integration_configuration) + .where( + Collection.export_marc_records == True, + library_integration_configuration.goal == Goals.CATALOG_GOAL, + library_integration_configuration.protocol.in_(protocols), + ) + ) + if collection_id is not None: + collection_query = collection_query.where(Collection.id == collection_id) + return { + (r.Collection, r.library_integration_library_configuration) + for r in session.execute(collection_query) + } + + @staticmethod + def _last_updated( + session: Session, library: Library, collection: Collection + ) -> datetime.datetime | None: + """Find the most recent MarcFile creation time.""" + last_updated_file = session.execute( + select(MarcFile.created) + .where( + MarcFile.library == library, + MarcFile.collection == collection, + ) + 
.order_by(MarcFile.created.desc()) + ).first() + + return last_updated_file.created if last_updated_file else None + + @classmethod + def enabled_collections( + cls, session: Session, registry: CatalogServicesRegistry + ) -> set[Collection]: + return {c for c, _ in cls._enabled_collections_and_libraries(session, registry)} + + @classmethod + def enabled_libraries( + cls, session: Session, registry: CatalogServicesRegistry, collection_id: int + ) -> Sequence[LibraryInfo]: + library_info = [] + creation_time = utc_now() + for collection, library_integration in cls._enabled_collections_and_libraries( + session, registry, collection_id + ): + library = library_integration.library + library_id = library.id + library_short_name = library.short_name + if library_id is None or library_short_name is None: + cls.logger().warning( + f"Library {library} is missing an ID or short name." + ) + continue + last_updated_time = cls._last_updated(session, library, collection) + update_frequency = cls.settings_load( + library_integration.parent + ).update_frequency + library_settings = cls.library_settings_load(library_integration) + needs_update = cls._needs_update(last_updated_time, update_frequency) + web_client_urls = cls._web_client_urls( + session, library, library_settings.web_client_url + ) + s3_key_full_uuid = uuid4() + s3_key_full = cls._s3_key( + library, + collection, + creation_time, + s3_key_full_uuid, + ) + s3_key_delta_uuid = uuid4() + s3_key_delta = ( + cls._s3_key( + library, + collection, + creation_time, + s3_key_delta_uuid, + since_time=last_updated_time, + ) + if last_updated_time + else None + ) + library_info.append( + LibraryInfo( + library_id=library_id, + library_short_name=library_short_name, + last_updated=last_updated_time, + needs_update=needs_update, + organization_code=library_settings.organization_code, + include_summary=library_settings.include_summary, + include_genres=library_settings.include_genres, + web_client_urls=web_client_urls, + s3_key_full_uuid=str(s3_key_full_uuid), + s3_key_full=s3_key_full, + s3_key_delta_uuid=str(s3_key_delta_uuid), + s3_key_delta=s3_key_delta, + ) + ) + library_info.sort(key=lambda info: info.library_id) + return library_info + + @staticmethod + def query_works( + session: Session, + collection_id: int, + work_id_offset: int | None, + batch_size: int, + ) -> list[Work]: + query = ( + select(Work) + .join(LicensePool) + .where( + LicensePool.collection_id == collection_id, + ) + .limit(batch_size) + .order_by(Work.id.asc()) + ) + + if work_id_offset is not None: + query = query.where(Work.id > work_id_offset) + + return session.execute(query).scalars().unique().all() + + @staticmethod + def collection(session: Session, collection_id: int) -> Collection | None: + return session.execute( + select(Collection).where(Collection.id == collection_id) + ).scalar_one_or_none() + + @classmethod + def process_work( + cls, + work: Work, + libraries_info: Iterable[LibraryInfo], + base_url: str, + *, + upload_manager: MarcUploadManager, + annotator: type[Annotator] = Annotator, + ) -> None: + pool = work.active_license_pool() + if pool is None: + return + base_record = annotator.marc_record(work, pool) + + for library_info in libraries_info: + library_record = annotator.library_marc_record( + base_record, + pool.identifier, + base_url, + library_info.library_short_name, + library_info.web_client_urls, + library_info.organization_code, + library_info.include_summary, + library_info.include_genres, + ) + + upload_manager.add_record( + 
library_info.s3_key_full, + library_record.as_marc(), + ) + + if ( + library_info.last_updated + and library_info.s3_key_delta + and work.last_update_time + and work.last_update_time > library_info.last_updated + ): + upload_manager.add_record( + library_info.s3_key_delta, + annotator.set_revised(library_record).as_marc(), + ) + + @staticmethod + def create_marc_upload_records( + session: Session, + start_time: datetime.datetime, + collection_id: int, + libraries_info: Iterable[LibraryInfo], + uploaded_keys: set[str], + ) -> None: + for library_info in libraries_info: + if library_info.s3_key_full in uploaded_keys: + create( + session, + MarcFile, + id=library_info.s3_key_full_uuid, + library_id=library_info.library_id, + collection_id=collection_id, + created=start_time, + key=library_info.s3_key_full, + ) + if library_info.s3_key_delta and library_info.s3_key_delta in uploaded_keys: + create( + session, + MarcFile, + id=library_info.s3_key_delta_uuid, + library_id=library_info.library_id, + collection_id=collection_id, + created=start_time, + since=library_info.last_updated, + key=library_info.s3_key_delta, + ) diff --git a/src/palace/manager/marc/settings.py b/src/palace/manager/marc/settings.py new file mode 100644 index 0000000000..4412876fe4 --- /dev/null +++ b/src/palace/manager/marc/settings.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from pydantic import NonNegativeInt + +from palace.manager.integration.settings import ( + BaseSettings, + ConfigurationFormItem, + ConfigurationFormItemType, + FormField, +) + + +class MarcExporterSettings(BaseSettings): + # This setting (in days) controls how often MARC files should be + # automatically updated. We run the celery task to update the MARC + # files on a schedule, but this setting easily allows admins to + # generate files more or less often. 
+ update_frequency: NonNegativeInt = FormField( + 30, + form=ConfigurationFormItem( + label="Update frequency (in days)", + type=ConfigurationFormItemType.NUMBER, + required=True, + ), + alias="marc_update_frequency", + ) + + +class MarcExporterLibrarySettings(BaseSettings): + # MARC organization codes are assigned by the + # Library of Congress and can be found here: + # http://www.loc.gov/marc/organizations/org-search.php + organization_code: str | None = FormField( + None, + form=ConfigurationFormItem( + label="The MARC organization code for this library (003 field).", + description="MARC organization codes are assigned by the Library of Congress.", + type=ConfigurationFormItemType.TEXT, + ), + alias="marc_organization_code", + ) + + web_client_url: str | None = FormField( + None, + form=ConfigurationFormItem( + label="The base URL for the web catalog for this library, for the 856 field.", + description="If using a library registry that provides a web catalog, this can be left blank.", + type=ConfigurationFormItemType.TEXT, + ), + alias="marc_web_client_url", + ) + + include_summary: bool = FormField( + False, + form=ConfigurationFormItem( + label="Include summaries in MARC records (520 field)", + type=ConfigurationFormItemType.SELECT, + options={"false": "Do not include summaries", "true": "Include summaries"}, + ), + ) + + include_genres: bool = FormField( + False, + form=ConfigurationFormItem( + label="Include Palace Collection Manager genres in MARC records (650 fields)", + type=ConfigurationFormItemType.SELECT, + options={ + "false": "Do not include Palace Collection Manager genres", + "true": "Include Palace Collection Manager genres", + }, + ), + alias="include_simplified_genres", + ) diff --git a/src/palace/manager/marc/uploader.py b/src/palace/manager/marc/uploader.py new file mode 100644 index 0000000000..81677977dd --- /dev/null +++ b/src/palace/manager/marc/uploader.py @@ -0,0 +1,144 @@ +from collections import defaultdict +from collections.abc import Generator, Sequence +from contextlib import contextmanager + +from celery.exceptions import Ignore, Retry +from typing_extensions import Self + +from palace.manager.service.redis.models.marc import MarcFileUploadSession +from palace.manager.service.storage.s3 import S3Service +from palace.manager.sqlalchemy.model.resource import Representation +from palace.manager.util.log import LoggerMixin + + +class MarcUploadManager(LoggerMixin): + """ + This class is used to manage the upload of MARC files to S3. The upload is done in multiple + parts, so that the Celery task can be broken up into multiple steps, saving the progress + between steps to redis, and flushing them to S3 when the buffer is large enough. + + This class orchestrates the upload process, delegating the redis operation to the + `MarcFileUploadSession` class, and the S3 upload to the `S3Service` class. 
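Taken together with the Celery task above, the intended call pattern for this class looks roughly like the following sketch. It assumes an already-configured `S3Service` and Redis client (`storage_service` and `redis_client` are placeholders), and the S3 key is invented for illustration:

```python
from palace.manager.marc.uploader import MarcUploadManager
from palace.manager.service.redis.models.marc import MarcFileUploadSession


def upload_collection(storage_service, redis_client, collection_id: int) -> set[str]:
    uploader = MarcUploadManager(
        storage_service,  # a configured S3Service (placeholder)
        MarcFileUploadSession(redis_client, collection_id),
    )
    with uploader.begin():
        if not uploader.locked:
            return set()  # another task owns this collection's session

        # Buffer records locally, then push them to redis; sync() flushes any
        # redis buffer that has grown past the S3 multipart minimum.
        uploader.add_record("marc/library/example.full.mrc", b"...record bytes...")
        uploader.sync()

        # On the final batch: finish the multipart uploads and drop the
        # session state from redis.
        uploaded_keys = uploader.complete()
        uploader.remove_session()
        return uploaded_keys
```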
+ """ + + def __init__( + self, storage_service: S3Service, upload_session: MarcFileUploadSession + ): + self.storage_service = storage_service + self.upload_session = upload_session + self._buffers: defaultdict[str, str] = defaultdict(str) + self._locked = False + + @property + def locked(self) -> bool: + return self._locked + + @property + def update_number(self) -> int: + return self.upload_session.update_number + + def add_record(self, key: str, record: bytes) -> None: + self._buffers[key] += record.decode() + + def _s3_sync(self, needs_upload: Sequence[str]) -> None: + upload_ids = self.upload_session.get_upload_ids(needs_upload) + for key in needs_upload: + if upload_ids.get(key) is None: + upload_id = self.storage_service.multipart_create( + key, content_type=Representation.MARC_MEDIA_TYPE + ) + self.upload_session.set_upload_id(key, upload_id) + upload_ids[key] = upload_id + + part_number, data = self.upload_session.get_part_num_and_buffer(key) + upload_part = self.storage_service.multipart_upload( + key, upload_ids[key], part_number, data.encode() + ) + self.upload_session.add_part_and_clear_buffer(key, upload_part) + + def sync(self) -> None: + # First sync our buffers to redis + buffer_lengths = self.upload_session.append_buffers(self._buffers) + self._buffers.clear() + + # Then, if any of our redis buffers are large enough, upload them to S3 + needs_upload = [ + key + for key, length in buffer_lengths.items() + if length > self.storage_service.MINIMUM_MULTIPART_UPLOAD_SIZE + ] + + if not needs_upload: + return + + self._s3_sync(needs_upload) + + def _abort(self) -> None: + in_progress = self.upload_session.get() + for key, upload in in_progress.items(): + if upload.upload_id is None: + # This upload has not started, so there is nothing to abort. + continue + try: + self.storage_service.multipart_abort(key, upload.upload_id) + except Exception as e: + # We log and keep going, since we want to abort as many uploads as possible + # even if some fail, this is likely already being called in an exception handler. + # So we want to do as much cleanup as possible. + self.log.exception( + f"Failed to abort upload {key} (UploadID: {upload.upload_id}) due to exception ({e})." + ) + + # Delete our in-progress uploads from redis as well + self.remove_session() + + def complete(self) -> set[str]: + # Make sure any local data we have is synced + self.sync() + + in_progress = self.upload_session.get() + for key, upload in in_progress.items(): + if upload.upload_id is None: + # We haven't started the upload. At this point there is no reason to start a + # multipart upload, just upload the file directly and continue. + self.storage_service.store( + key, upload.buffer, Representation.MARC_MEDIA_TYPE + ) + else: + if upload.buffer != "": + # Upload the last chunk if the buffer is not empty, the final part has no + # minimum size requirement. 
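The final-part exception mentioned in the comment above comes from S3's multipart rules: every part except the last must be at least 5 MiB. A minimal boto3 sketch of the same create/upload/complete sequence the storage service performs (bucket and key names are made up):

```python
import boto3

s3 = boto3.client("s3")
bucket, key = "example-bucket", "marc/example.mrc"  # hypothetical names

# 1. Start the multipart upload and remember its id.
upload_id = s3.create_multipart_upload(Bucket=bucket, Key=key)["UploadId"]

# 2. Upload parts; every part except the last must be >= 5 MiB.
parts = []
for number, chunk in enumerate([b"a" * (5 * 1024 * 1024), b"tail"], start=1):
    response = s3.upload_part(
        Bucket=bucket, Key=key, UploadId=upload_id, PartNumber=number, Body=chunk
    )
    parts.append({"PartNumber": number, "ETag": response["ETag"]})

# 3. Complete the upload (or abort_multipart_upload on failure).
s3.complete_multipart_upload(
    Bucket=bucket, Key=key, UploadId=upload_id, MultipartUpload={"Parts": parts}
)
```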
+ upload_part = self.storage_service.multipart_upload( + key, upload.upload_id, len(upload.parts), upload.buffer.encode() + ) + upload.parts.append(upload_part) + + # Complete the multipart upload + self.storage_service.multipart_complete( + key, upload.upload_id, upload.parts + ) + + # Delete our in-progress uploads data from redis + if in_progress: + self.upload_session.clear_uploads() + + # Return the keys that were uploaded + return set(in_progress.keys()) + + def remove_session(self) -> None: + self.upload_session.delete() + + @contextmanager + def begin(self) -> Generator[Self, None, None]: + self._locked = self.upload_session.acquire() + try: + yield self + except Exception as e: + # We want to ignore any celery exceptions that are expected, but + # handle cleanup for any other cases. + if not isinstance(e, (Retry, Ignore)): + self._abort() + raise + finally: + self.upload_session.release() + self._locked = False diff --git a/src/palace/manager/scripts/marc.py b/src/palace/manager/scripts/marc.py deleted file mode 100644 index 572c6fe761..0000000000 --- a/src/palace/manager/scripts/marc.py +++ /dev/null @@ -1,223 +0,0 @@ -from __future__ import annotations - -import argparse -import datetime -from collections.abc import Sequence -from datetime import timedelta -from typing import Any - -from sqlalchemy import select -from sqlalchemy.exc import NoResultFound -from sqlalchemy.orm import Session - -from palace.manager.core.config import CannotLoadConfiguration -from palace.manager.core.marc import Annotator as MarcAnnotator -from palace.manager.core.marc import ( - MARCExporter, - MarcExporterLibrarySettings, - MarcExporterSettings, -) -from palace.manager.integration.goals import Goals -from palace.manager.scripts.input import LibraryInputScript -from palace.manager.sqlalchemy.model.collection import Collection -from palace.manager.sqlalchemy.model.discovery_service_registration import ( - DiscoveryServiceRegistration, -) -from palace.manager.sqlalchemy.model.integration import ( - IntegrationConfiguration, - IntegrationLibraryConfiguration, -) -from palace.manager.sqlalchemy.model.library import Library -from palace.manager.sqlalchemy.model.marcfile import MarcFile -from palace.manager.util.datetime_helpers import utc_now - - -class CacheMARCFiles(LibraryInputScript): - """Generate and cache MARC files for each input library.""" - - name = "Cache MARC files" - - @classmethod - def arg_parser(cls, _db: Session) -> argparse.ArgumentParser: # type: ignore[override] - parser = super().arg_parser(_db) - parser.add_argument( - "--force", - help="Generate new MARC files even if MARC files have already been generated recently enough", - dest="force", - action="store_true", - ) - return parser - - def __init__( - self, - _db: Session | None = None, - cmd_args: Sequence[str] | None = None, - exporter: MARCExporter | None = None, - *args: Any, - **kwargs: Any, - ) -> None: - super().__init__(_db, *args, **kwargs) - self.force = False - self.parse_args(cmd_args) - self.storage_service = self.services.storage.public() - self.base_url = self.services.config.sitewide.base_url() - if self.base_url is None: - raise CannotLoadConfiguration( - f"Missing required environment variable: PALACE_BASE_URL." 
- ) - - self.exporter = exporter or MARCExporter(self._db, self.storage_service) - - def parse_args(self, cmd_args: Sequence[str] | None = None) -> argparse.Namespace: - parser = self.arg_parser(self._db) - parsed = parser.parse_args(cmd_args) - self.force = parsed.force - return parsed - - def settings( - self, library: Library - ) -> tuple[MarcExporterSettings, MarcExporterLibrarySettings]: - integration_query = ( - select(IntegrationLibraryConfiguration) - .join(IntegrationConfiguration) - .where( - IntegrationConfiguration.goal == Goals.CATALOG_GOAL, - IntegrationConfiguration.protocol == MARCExporter.__name__, - IntegrationLibraryConfiguration.library == library, - ) - ) - integration = self._db.execute(integration_query).scalar_one() - - library_settings = MARCExporter.library_settings_load(integration) - settings = MARCExporter.settings_load(integration.parent) - - return settings, library_settings - - def process_libraries(self, libraries: Sequence[Library]) -> None: - if not self.storage_service: - self.log.info("No storage service was found.") - return - - super().process_libraries(libraries) - - def get_collections(self, library: Library) -> Sequence[Collection]: - return self._db.scalars( - select(Collection).where( - Collection.libraries.contains(library), - Collection.export_marc_records == True, - ) - ).all() - - def get_web_client_urls( - self, library: Library, url: str | None = None - ) -> list[str]: - """Find web client URLs configured by the registry for this library.""" - urls = [ - s.web_client - for s in self._db.execute( - select(DiscoveryServiceRegistration.web_client).where( - DiscoveryServiceRegistration.library == library, - DiscoveryServiceRegistration.web_client != None, - ) - ).all() - ] - - if url: - urls.append(url) - - return urls - - def process_library( - self, library: Library, annotator_cls: type[MarcAnnotator] = MarcAnnotator - ) -> None: - try: - settings, library_settings = self.settings(library) - except NoResultFound: - return - - self.log.info("Processing library %s" % library.name) - - update_frequency = int(settings.update_frequency) - - # Find the collections for this library. - collections = self.get_collections(library) - - # Find web client URLs configured by the registry for this library. - web_client_urls = self.get_web_client_urls( - library, library_settings.web_client_url - ) - - annotator = annotator_cls( - self.base_url, - library.short_name or "", - web_client_urls, - library_settings.organization_code, - library_settings.include_summary, - library_settings.include_genres, - ) - - # We set the creation time to be the start of the batch. Any updates that happen during the batch will be - # included in the next batch. 
- creation_time = utc_now() - - for collection in collections: - self.process_collection( - library, - collection, - annotator, - update_frequency, - creation_time, - ) - - def last_updated( - self, library: Library, collection: Collection - ) -> datetime.datetime | None: - """Find the most recent MarcFile creation time.""" - last_updated_file = self._db.execute( - select(MarcFile.created) - .where( - MarcFile.library == library, - MarcFile.collection == collection, - ) - .order_by(MarcFile.created.desc()) - ).first() - - return last_updated_file.created if last_updated_file else None - - def process_collection( - self, - library: Library, - collection: Collection, - annotator: MarcAnnotator, - update_frequency: int, - creation_time: datetime.datetime, - ) -> None: - last_update = self.last_updated(library, collection) - - if ( - not self.force - and last_update - and (last_update > creation_time - timedelta(days=update_frequency)) - ): - self.log.info( - f"Skipping collection {collection.name} because last update was less than {update_frequency} days ago" - ) - return - - # First update the file with ALL the records. - self.exporter.records( - library, collection, annotator, creation_time=creation_time - ) - - # Then create a new file with changes since the last update. - if last_update: - self.exporter.records( - library, - collection, - annotator, - creation_time=creation_time, - since_time=last_update, - ) - - self._db.commit() - self.log.info("Processed collection %s" % collection.name) diff --git a/src/palace/manager/service/celery/celery.py b/src/palace/manager/service/celery/celery.py index adb4486d61..cb73dba995 100644 --- a/src/palace/manager/service/celery/celery.py +++ b/src/palace/manager/service/celery/celery.py @@ -37,6 +37,12 @@ def beat_schedule() -> dict[str, Any]: "task": "search.search_indexing", "schedule": crontab(minute="*"), # Run every minute }, + "marc_export": { + "task": "marc.marc_export", + "schedule": crontab( + hour="3,11", minute="0" + ), # Run twice a day at 3:00 AM and 11:00 AM + }, } diff --git a/src/palace/manager/service/integration_registry/catalog_services.py b/src/palace/manager/service/integration_registry/catalog_services.py index 913b8c4f1f..1a6593d62f 100644 --- a/src/palace/manager/service/integration_registry/catalog_services.py +++ b/src/palace/manager/service/integration_registry/catalog_services.py @@ -1,9 +1,17 @@ -from palace.manager.core.marc import MARCExporter +from __future__ import annotations + +from typing import TYPE_CHECKING + from palace.manager.integration.goals import Goals from palace.manager.service.integration_registry.base import IntegrationRegistry +if TYPE_CHECKING: + from palace.manager.marc.exporter import MarcExporter # noqa: autoflake -class CatalogServicesRegistry(IntegrationRegistry[MARCExporter]): + +class CatalogServicesRegistry(IntegrationRegistry["MarcExporter"]): def __init__(self) -> None: + from palace.manager.marc.exporter import MarcExporter + super().__init__(Goals.CATALOG_GOAL) - self.register(MARCExporter) + self.register(MarcExporter, aliases=["MARCExporter"]) diff --git a/src/palace/manager/service/redis/models/lock.py b/src/palace/manager/service/redis/models/lock.py index ef4c348aff..1fa232ef27 100644 --- a/src/palace/manager/service/redis/models/lock.py +++ b/src/palace/manager/service/redis/models/lock.py @@ -1,11 +1,12 @@ +import json import random import time from abc import ABC, abstractmethod -from collections.abc import Generator, Sequence +from collections.abc import Generator, Mapping, 
Sequence from contextlib import contextmanager from datetime import timedelta from functools import cached_property -from typing import cast +from typing import TypeVar, cast from uuid import uuid4 from palace.manager.celery.task import Task @@ -232,3 +233,201 @@ def __init__( else: name = [lock_name] super().__init__(redis_client, name, random_value, lock_timeout, retry_delay) + + +class RedisJsonLock(BaseRedisLock, ABC): + _GET_LOCK_FUNCTION = """ + local function get_lock_value(key, json_key) + local value = redis.call("json.get", key, json_key) + if not value then + return nil + end + return cjson.decode(value)[1] + end + """ + + _ACQUIRE_SCRIPT = f""" + {_GET_LOCK_FUNCTION} + -- If the locks json object doesn't exist, create it with the initial value + redis.call("json.set", KEYS[1], "$", ARGV[4], "nx") + + -- Get the current lock value + local lock_value = get_lock_value(KEYS[1], ARGV[1]) + if not lock_value then + -- The lock isn't currently locked, so we lock it and set the timeout + redis.call("json.set", KEYS[1], ARGV[1], cjson.encode(ARGV[2])) + redis.call("pexpire", KEYS[1], ARGV[3]) + return 1 + elseif lock_value == ARGV[2] then + -- The lock is already held by us, so we extend the timeout + redis.call("pexpire", KEYS[1], ARGV[3]) + return 2 + else + -- The lock is held by someone else, we do nothing + return nil + end + """ + + _RELEASE_SCRIPT = f""" + {_GET_LOCK_FUNCTION} + if get_lock_value(KEYS[1], ARGV[1]) == ARGV[2] then + redis.call("json.del", KEYS[1], ARGV[1]) + return 1 + else + return nil + end + """ + + _EXTEND_SCRIPT = f""" + {_GET_LOCK_FUNCTION} + if get_lock_value(KEYS[1], ARGV[1]) == ARGV[2] then + redis.call("pexpire", KEYS[1], ARGV[3]) + return 1 + else + return nil + end + """ + + _DELETE_SCRIPT = f""" + {_GET_LOCK_FUNCTION} + if get_lock_value(KEYS[1], ARGV[1]) == ARGV[2] then + redis.call("del", KEYS[1]) + return 1 + else + return nil + end + """ + + def __init__( + self, + redis_client: Redis, + random_value: str | None = None, + ): + super().__init__(redis_client, random_value) + + # Register our scripts + self._acquire_script = self._redis_client.register_script(self._ACQUIRE_SCRIPT) + self._release_script = self._redis_client.register_script(self._RELEASE_SCRIPT) + self._extend_script = self._redis_client.register_script(self._EXTEND_SCRIPT) + self._delete_script = self._redis_client.register_script(self._DELETE_SCRIPT) + + @property + @abstractmethod + def _lock_timeout_ms(self) -> int: + """ + The lock timeout in milliseconds. + """ + ... + + @property + def _lock_json_key(self) -> str: + """ + The key to use for the lock value in the JSON object. + + This can be overridden if you need to store the lock value in a different key. It should + be a Redis JSONPath. + See: https://redis.io/docs/latest/develop/data-types/json/path/ + """ + return "$.lock" + + @property + def _initial_value(self) -> str: + """ + The initial value to use for the locks JSON object. + """ + return json.dumps({}) + + T = TypeVar("T") + + @classmethod + def _parse_multi( + cls, value: Mapping[str, Sequence[T]] | None + ) -> dict[str, T | None]: + """ + Helper function that makes it easier to work with the results of a JSON GET command, + where you request multiple keys. 
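The unwrapping these helpers do reflects how RedisJSON answers JSONPath queries: a `$`-style path returns a list of matches even when only one value is expected. A small redis-py sketch, assuming a Redis server with the RedisJSON module loaded (e.g. redis-stack) and an invented key name:

```python
import redis

client = redis.Redis(decode_responses=True)

client.json().set("demo:session", "$", {"lock": "abc123"})

matches = client.json().get("demo:session", "$.lock")
print(matches)  # a $-path yields a list of matches, e.g. ['abc123']


def parse_value(value):
    # Mirrors RedisJsonLock._parse_value: first match if any, otherwise None.
    return value[0] if value else None


print(parse_value(matches))  # 'abc123'
print(parse_value(client.json().get("demo:session", "$.missing")))  # None
```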
+ """ + if value is None: + return {} + return {k: cls._parse_value(v) for k, v in value.items()} + + @staticmethod + def _parse_value(value: Sequence[T] | None) -> T | None: + """ + Helper function to parse the value from the results of a JSON GET command, where you + expect the JSONPath to return a single value. + """ + if value is None: + return None + try: + return value[0] + except IndexError: + return None + + @classmethod + def _parse_value_or_raise(cls, value: Sequence[T] | None) -> T: + """ + Wrapper around _parse_value that raises an exception if the value is None. + """ + parsed_value = cls._parse_value(value) + if parsed_value is None: + raise LockError(f"Could not parse value ({json.dumps(value)})") + return parsed_value + + def acquire(self) -> bool: + return ( + self._acquire_script( + keys=(self.key,), + args=( + self._lock_json_key, + self._random_value, + self._lock_timeout_ms, + self._initial_value, + ), + ) + is not None + ) + + def release(self) -> bool: + """ + Release the lock. + + You must have the lock to release it. This will unset the lock value in the JSON object, but importantly + it will not delete the JSON object itself. If you want to delete the JSON object, use the delete method. + """ + return ( + self._release_script( + keys=(self.key,), + args=(self._lock_json_key, self._random_value), + ) + is not None + ) + + def locked(self, by_us: bool = False) -> bool: + lock_value: str | None = self._parse_value( + self._redis_client.json().get(self.key, self._lock_json_key) + ) + if by_us: + return lock_value == self._random_value + return lock_value is not None + + def extend_timeout(self) -> bool: + return ( + self._extend_script( + keys=(self.key,), + args=(self._lock_json_key, self._random_value, self._lock_timeout_ms), + ) + is not None + ) + + def delete(self) -> bool: + """ + Delete the whole json object, including the lock. Must have the lock to delete the object. 
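+
+        Returns True if the object was deleted, or False if we do not hold the lock.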
+ """ + return ( + self._delete_script( + keys=(self.key,), + args=(self._lock_json_key, self._random_value), + ) + is not None + ) diff --git a/src/palace/manager/service/redis/models/marc.py b/src/palace/manager/service/redis/models/marc.py new file mode 100644 index 0000000000..6f443b7ceb --- /dev/null +++ b/src/palace/manager/service/redis/models/marc.py @@ -0,0 +1,312 @@ +from __future__ import annotations + +import json +import sys +from collections.abc import Callable, Generator, Mapping, Sequence +from contextlib import contextmanager +from enum import auto +from functools import cached_property +from typing import Any + +from pydantic import BaseModel +from redis import ResponseError, WatchError + +from palace.manager.service.redis.models.lock import LockError, RedisJsonLock +from palace.manager.service.redis.redis import Pipeline, Redis +from palace.manager.service.storage.s3 import MultipartS3UploadPart +from palace.manager.sqlalchemy.model.collection import Collection +from palace.manager.util.log import LoggerMixin + +# TODO: Remove this when we drop support for Python 3.10 +if sys.version_info >= (3, 11): + from enum import StrEnum +else: + from backports.strenum import StrEnum + + +class MarcFileUploadSessionError(LockError): + pass + + +class MarcFileUpload(BaseModel): + buffer: str = "" + upload_id: str | None = None + parts: list[MultipartS3UploadPart] = [] + + +class MarcFileUploadState(StrEnum): + INITIAL = auto() + QUEUED = auto() + UPLOADING = auto() + + +class MarcFileUploadSession(RedisJsonLock, LoggerMixin): + """ + This class is used as a lock for the Celery MARC export task, to ensure that only one + task can upload MARC files for a given collection at a time. It increments an update + number each time an update is made, to guard against corruption if a task gets run + twice. + + It stores the intermediate results of the MARC file generation process, so that the task + can complete in multiple steps, saving the progress between steps to redis, and flushing + them to S3 when the buffer is full. + + This object is focused on the redis part of this operation, the actual s3 upload orchestration + is handled by the `MarcUploadManager` class. + """ + + def __init__( + self, + redis_client: Redis, + collection_id: int, + update_number: int = 0, + ): + super().__init__(redis_client) + self._collection_id = collection_id + self._update_number = update_number + + @cached_property + def key(self) -> str: + return self._redis_client.get_key( + self.__class__.__name__, + Collection.redis_key_from_id(self._collection_id), + ) + + @property + def _lock_timeout_ms(self) -> int: + return 20 * 60 * 1000 # 20 minutes + + @property + def update_number(self) -> int: + return self._update_number + + @property + def _initial_value(self) -> str: + """ + The initial value to use for the locks JSON object. 
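+
+        The object tracks the in-progress file uploads, an update number used to
+        detect stale task runs, and the overall state of the session.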
+ """ + return json.dumps( + {"uploads": {}, "update_number": 0, "state": MarcFileUploadState.INITIAL} + ) + + @property + def _update_number_json_key(self) -> str: + return "$.update_number" + + @property + def _uploads_json_key(self) -> str: + return "$.uploads" + + @property + def _state_json_key(self) -> str: + return "$.state" + + @staticmethod + def _upload_initial_value(buffer_data: str) -> dict[str, Any]: + return MarcFileUpload(buffer=buffer_data).dict(exclude_none=True) + + def _upload_path(self, upload_key: str) -> str: + return f"{self._uploads_json_key}['{upload_key}']" + + def _buffer_path(self, upload_key: str) -> str: + upload_path = self._upload_path(upload_key) + return f"{upload_path}.buffer" + + def _upload_id_path(self, upload_key: str) -> str: + upload_path = self._upload_path(upload_key) + return f"{upload_path}.upload_id" + + def _parts_path(self, upload_key: str) -> str: + upload_path = self._upload_path(upload_key) + return f"{upload_path}.parts" + + @contextmanager + def _pipeline( + self, begin_transaction: bool = True + ) -> Generator[Pipeline, None, None]: + with self._redis_client.pipeline() as pipe: + pipe.watch(self.key) + fetched_data = self._parse_multi( + pipe.json().get( + self.key, self._lock_json_key, self._update_number_json_key + ) + ) + # Check that we hold the lock + if ( + remote_random := fetched_data.get(self._lock_json_key) + ) != self._random_value: + raise MarcFileUploadSessionError( + f"Must hold lock to update upload session. " + f"Expected: {self._random_value}, got: {remote_random}" + ) + # Check that the update number is correct + if ( + remote_update_number := fetched_data.get(self._update_number_json_key) + ) != self._update_number: + raise MarcFileUploadSessionError( + f"Update number mismatch. " + f"Expected: {self._update_number}, got: {remote_update_number}" + ) + if begin_transaction: + pipe.multi() + yield pipe + + def _execute_pipeline( + self, + pipe: Pipeline, + updates: int, + *, + state: MarcFileUploadState = MarcFileUploadState.UPLOADING, + ) -> list[Any]: + if not pipe.explicit_transaction: + raise MarcFileUploadSessionError( + "Pipeline should be in explicit transaction mode before executing." + ) + pipe.json().set(self.key, path=self._state_json_key, obj=state) + pipe.json().numincrby(self.key, self._update_number_json_key, updates) + pipe.pexpire(self.key, self._lock_timeout_ms) + try: + pipe_results = pipe.execute() + except WatchError as e: + raise MarcFileUploadSessionError( + "Failed to update buffers. Another process is modifying the buffers." 
+ ) from e + self._update_number = self._parse_value_or_raise(pipe_results[-2]) + + return pipe_results[:-3] + + def append_buffers(self, data: Mapping[str, str]) -> dict[str, int]: + if not data: + return {} + + set_results = {} + with self._pipeline(begin_transaction=False) as pipe: + existing_uploads: list[str] = self._parse_value_or_raise( + pipe.json().objkeys(self.key, self._uploads_json_key) + ) + pipe.multi() + for key, value in data.items(): + if value == "": + continue + if key in existing_uploads: + pipe.json().strappend( + self.key, path=self._buffer_path(key), value=value + ) + else: + pipe.json().set( + self.key, + path=self._upload_path(key), + obj=self._upload_initial_value(value), + ) + set_results[key] = len(value) + + pipe_results = self._execute_pipeline(pipe, len(data)) + + if not all(pipe_results): + raise MarcFileUploadSessionError("Failed to append buffers.") + + return { + k: set_results[k] if v is True else self._parse_value_or_raise(v) + for k, v in zip(data.keys(), pipe_results) + } + + def add_part_and_clear_buffer(self, key: str, part: MultipartS3UploadPart) -> None: + with self._pipeline() as pipe: + pipe.json().arrappend( + self.key, + self._parts_path(key), + part.dict(), + ) + pipe.json().set( + self.key, + path=self._buffer_path(key), + obj="", + ) + pipe_results = self._execute_pipeline(pipe, 1) + + if not all(pipe_results): + raise MarcFileUploadSessionError("Failed to add part and clear buffer.") + + def set_upload_id(self, key: str, upload_id: str) -> None: + with self._pipeline() as pipe: + pipe.json().set( + self.key, + path=self._upload_id_path(key), + obj=upload_id, + nx=True, + ) + pipe_results = self._execute_pipeline(pipe, 1) + + if not all(pipe_results): + raise MarcFileUploadSessionError("Failed to set upload ID.") + + def clear_uploads(self) -> None: + with self._pipeline() as pipe: + pipe.json().clear(self.key, self._uploads_json_key) + pipe_results = self._execute_pipeline(pipe, 1) + + if not all(pipe_results): + raise MarcFileUploadSessionError("Failed to clear uploads.") + + def _get_specific( + self, + keys: str | Sequence[str], + get_path: Callable[[str], str], + ) -> dict[str, Any]: + if isinstance(keys, str): + keys = [keys] + paths = {get_path(k): k for k in keys} + results = self._redis_client.json().get(self.key, *paths.keys()) + if len(keys) == 1: + return {keys[0]: self._parse_value(results)} + else: + return {paths[k]: v for k, v in self._parse_multi(results).items()} + + def _get_all(self, key: str) -> dict[str, Any]: + get_results = self._redis_client.json().get(self.key, key) + results: dict[str, Any] | None = self._parse_value(get_results) + + if results is None: + return {} + + return results + + def get(self, keys: str | Sequence[str] | None = None) -> dict[str, MarcFileUpload]: + if keys is None: + uploads = self._get_all(self._uploads_json_key) + else: + uploads = self._get_specific(keys, self._upload_path) + + return { + k: MarcFileUpload.parse_obj(v) for k, v in uploads.items() if v is not None + } + + def get_upload_ids(self, keys: str | Sequence[str]) -> dict[str, str]: + return self._get_specific(keys, self._upload_id_path) + + def get_part_num_and_buffer(self, key: str) -> tuple[int, str]: + try: + with self._redis_client.pipeline() as pipe: + pipe.json().get(self.key, self._buffer_path(key)) + pipe.json().arrlen(self.key, self._parts_path(key)) + results = pipe.execute() + except ResponseError as e: + raise MarcFileUploadSessionError( + "Failed to get part number and buffer data." 
+ ) from e + + buffer_data: str = self._parse_value_or_raise(results[0]) + part_number: int = self._parse_value_or_raise(results[1]) + + return part_number, buffer_data + + def state(self) -> MarcFileUploadState | None: + get_results = self._redis_client.json().get(self.key, self._state_json_key) + state: str | None = self._parse_value(get_results) + if state is None: + return None + return MarcFileUploadState(state) + + def set_state(self, state: MarcFileUploadState) -> None: + with self._pipeline() as pipe: + self._execute_pipeline(pipe, 0, state=state) diff --git a/src/palace/manager/service/redis/redis.py b/src/palace/manager/service/redis/redis.py index c9cd41be9b..25b06f91b5 100644 --- a/src/palace/manager/service/redis/redis.py +++ b/src/palace/manager/service/redis/redis.py @@ -79,19 +79,31 @@ def key_args(self, args: list[Any]) -> Sequence[str]: RedisCommandArgs("KEYS"), RedisCommandArgs("GET"), RedisCommandArgs("EXPIRE"), + RedisCommandArgs("PEXPIRE"), RedisCommandArgs("GETRANGE"), RedisCommandArgs("SET"), RedisCommandArgs("TTL"), RedisCommandArgs("PTTL"), + RedisCommandArgs("PTTL"), RedisCommandArgs("SADD"), RedisCommandArgs("SPOP"), RedisCommandArgs("SCARD"), + RedisCommandArgs("WATCH"), RedisCommandArgs("SRANDMEMBER"), RedisCommandArgs("SREM"), RedisCommandArgs("DEL", args_end=None), RedisCommandArgs("MGET", args_end=None), RedisCommandArgs("EXISTS", args_end=None), RedisCommandArgs("EXPIRETIME"), + RedisCommandArgs("JSON.CLEAR"), + RedisCommandArgs("JSON.SET"), + RedisCommandArgs("JSON.STRLEN"), + RedisCommandArgs("JSON.STRAPPEND"), + RedisCommandArgs("JSON.NUMINCRBY"), + RedisCommandArgs("JSON.GET"), + RedisCommandArgs("JSON.OBJKEYS"), + RedisCommandArgs("JSON.ARRAPPEND"), + RedisCommandArgs("JSON.ARRLEN"), RedisVariableCommandArgs("EVALSHA", key_index=1), ] } @@ -161,3 +173,6 @@ def _prefix(self) -> str: def execute_command(self, *args: Any, **options: Any) -> Any: self._check_prefix(*args) return super().execute_command(*args, **options) + + def __enter__(self) -> Pipeline: + return self diff --git a/src/palace/manager/service/storage/s3.py b/src/palace/manager/service/storage/s3.py index 97704bfffa..fa8dc2e91a 100644 --- a/src/palace/manager/service/storage/s3.py +++ b/src/palace/manager/service/storage/s3.py @@ -110,6 +110,8 @@ def exception(self) -> BaseException | None: class S3Service(LoggerMixin): + MINIMUM_MULTIPART_UPLOAD_SIZE = 5 * 1024 * 1024 # 5MB + def __init__( self, client: S3Client, diff --git a/tests/conftest.py b/tests/conftest.py index 395e05d264..387bbb70b4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,6 +13,7 @@ "tests.fixtures.files", "tests.fixtures.flask", "tests.fixtures.library", + "tests.fixtures.marc", "tests.fixtures.odl", "tests.fixtures.redis", "tests.fixtures.s3", diff --git a/tests/fixtures/database.py b/tests/fixtures/database.py index 7b1b384c27..f8baad2054 100644 --- a/tests/fixtures/database.py +++ b/tests/fixtures/database.py @@ -6,7 +6,7 @@ import tempfile import time import uuid -from collections.abc import Generator, Iterable +from collections.abc import Generator, Iterable, Mapping from contextlib import contextmanager from functools import cached_property from textwrap import dedent @@ -36,10 +36,15 @@ from palace.manager.core.config import Configuration from palace.manager.core.exceptions import BasePalaceException, PalaceValueError from palace.manager.core.opds_import import OPDSAPI -from palace.manager.integration.base import HasIntegrationConfiguration +from palace.manager.integration.base import ( + 
HasIntegrationConfiguration, + HasLibraryIntegrationConfiguration, +) from palace.manager.integration.base import SettingsType as TIntegrationSettings from palace.manager.integration.configuration.library import LibrarySettings from palace.manager.integration.goals import Goals +from palace.manager.integration.settings import BaseSettings +from palace.manager.service.integration_registry.base import IntegrationRegistry from palace.manager.sqlalchemy.constants import MediaTypes from palace.manager.sqlalchemy.model.classification import ( Classification, @@ -921,6 +926,16 @@ def license( def isbn_take(self) -> str: return self._isbns.pop() + @cached_property + def _goal_registry_mapping(self) -> Mapping[Goals, IntegrationRegistry[Any]]: + return { + Goals.CATALOG_GOAL: self._services.services.integration_registry.catalog_services(), + Goals.DISCOVERY_GOAL: self._services.services.integration_registry.discovery(), + Goals.LICENSE_GOAL: self._services.services.integration_registry.license_providers(), + Goals.METADATA_GOAL: self._services.services.integration_registry.metadata(), + Goals.PATRON_AUTH_GOAL: self._services.services.integration_registry.patron_auth(), + } + def integration_configuration( self, protocol: type[HasIntegrationConfiguration[TIntegrationSettings]] | str, @@ -930,17 +945,10 @@ def integration_configuration( name: str | None = None, settings: TIntegrationSettings | None = None, ) -> IntegrationConfiguration: - registry_mapping = { - Goals.CATALOG_GOAL: self._services.services.integration_registry.catalog_services(), - Goals.DISCOVERY_GOAL: self._services.services.integration_registry.discovery(), - Goals.LICENSE_GOAL: self._services.services.integration_registry.license_providers(), - Goals.METADATA_GOAL: self._services.services.integration_registry.metadata(), - Goals.PATRON_AUTH_GOAL: self._services.services.integration_registry.patron_auth(), - } protocol_str = ( protocol if isinstance(protocol, str) - else registry_mapping[goal].get_protocol(protocol) + else self._goal_registry_mapping[goal].get_protocol(protocol) ) assert protocol_str is not None integration, ignore = get_one_or_create( @@ -972,6 +980,37 @@ def integration_configuration( return integration + def integration_library_configuration( + self, + parent: IntegrationConfiguration, + library: Library, + settings: BaseSettings | None = None, + ) -> IntegrationLibraryConfiguration: + assert parent.goal is not None + assert parent.protocol is not None + parent_cls = self._goal_registry_mapping[parent.goal][parent.protocol] + if not issubclass(parent_cls, HasLibraryIntegrationConfiguration): + raise TypeError( + f"{parent_cls.__name__} does not support library configuration" + ) + + integration, ignore = get_one_or_create( + self.session, + IntegrationLibraryConfiguration, + parent=parent, + library=library, + ) + + if settings is not None: + if not isinstance(settings, parent_cls.library_settings_class()): + raise TypeError( + f"settings must be an instance of {parent_cls.library_settings_class().__name__} " + f"not {settings.__class__.__name__}" + ) + parent_cls.library_settings_update(integration, settings) + + return integration + def discovery_service_integration( self, url: str | None = None ) -> IntegrationConfiguration: diff --git a/tests/fixtures/marc.py b/tests/fixtures/marc.py new file mode 100644 index 0000000000..06ed89d44e --- /dev/null +++ b/tests/fixtures/marc.py @@ -0,0 +1,99 @@ +import datetime +from collections.abc import Sequence + +import pytest + +from palace.manager.integration.goals import 
Goals +from palace.manager.marc.exporter import LibraryInfo, MarcExporter +from palace.manager.marc.settings import MarcExporterLibrarySettings +from palace.manager.sqlalchemy.model.collection import Collection +from palace.manager.sqlalchemy.model.integration import IntegrationConfiguration +from palace.manager.sqlalchemy.model.marcfile import MarcFile +from palace.manager.sqlalchemy.model.work import Work +from palace.manager.sqlalchemy.util import create +from palace.manager.util.datetime_helpers import utc_now +from tests.fixtures.database import DatabaseTransactionFixture +from tests.fixtures.services import ServicesFixture + + +class MarcExporterFixture: + def __init__( + self, db: DatabaseTransactionFixture, services_fixture: ServicesFixture + ): + self._db = db + self._services_fixture = services_fixture + + self.registry = ( + services_fixture.services.integration_registry.catalog_services() + ) + self.session = db.session + + self.library1 = db.default_library() + self.library1.short_name = "library1" + self.library2 = db.library(short_name="library2") + + self.collection1 = db.collection(name="collection1") + self.collection2 = db.collection() + self.collection3 = db.collection() + + self.collection1.libraries = [self.library1, self.library2] + self.collection2.libraries = [self.library1] + self.collection3.libraries = [self.library2] + + self.test_marc_file_key = "test-file-1.mrc" + + def integration(self) -> IntegrationConfiguration: + return self._db.integration_configuration(MarcExporter, Goals.CATALOG_GOAL) + + def work(self, collection: Collection | None = None) -> Work: + collection = collection or self.collection1 + edition = self._db.edition() + self._db.licensepool(edition, collection=collection) + work = self._db.work(presentation_edition=edition) + work.last_update_time = utc_now() + return work + + def works(self, collection: Collection | None = None) -> list[Work]: + return [self.work(collection) for _ in range(5)] + + def configure_export(self) -> None: + marc_integration = self.integration() + self._db.integration_library_configuration( + marc_integration, + self.library1, + MarcExporterLibrarySettings(organization_code="library1-org"), + ) + self._db.integration_library_configuration( + marc_integration, + self.library2, + MarcExporterLibrarySettings(organization_code="library2-org"), + ) + + self.collection1.export_marc_records = True + self.collection2.export_marc_records = True + self.collection3.export_marc_records = True + + create( + self.session, + MarcFile, + library=self.library1, + collection=self.collection1, + key=self.test_marc_file_key, + created=utc_now() - datetime.timedelta(days=7), + ) + + def enabled_libraries( + self, collection: Collection | None = None + ) -> Sequence[LibraryInfo]: + collection = collection or self.collection1 + assert collection.id is not None + return MarcExporter.enabled_libraries( + self.session, self.registry, collection_id=collection.id + ) + + +@pytest.fixture +def marc_exporter_fixture( + db: DatabaseTransactionFixture, services_fixture: ServicesFixture +) -> MarcExporterFixture: + return MarcExporterFixture(db, services_fixture) diff --git a/tests/fixtures/s3.py b/tests/fixtures/s3.py index 19ec790b60..5dc1de11c0 100644 --- a/tests/fixtures/s3.py +++ b/tests/fixtures/s3.py @@ -2,12 +2,27 @@ import functools import sys +import uuid +from collections.abc import Generator +from dataclasses import dataclass, field from typing import TYPE_CHECKING, BinaryIO, NamedTuple, Protocol from unittest.mock import MagicMock +from 
uuid import uuid4 import pytest +from mypy_boto3_s3 import S3Client +from pydantic import AnyHttpUrl -from palace.manager.service.storage.s3 import MultipartS3ContextManager, S3Service +from palace.manager.service.configuration.service_configuration import ( + ServiceConfiguration, +) +from palace.manager.service.storage.container import Storage +from palace.manager.service.storage.s3 import ( + MultipartS3ContextManager, + MultipartS3UploadPart, + S3Service, +) +from tests.fixtures.config import FixtureTestUrlConfiguration if sys.version_info >= (3, 11): from typing import Self @@ -54,14 +69,28 @@ def upload_part(self, content: bytes) -> None: def _upload_complete(self) -> None: if self.content: self._complete = True - self.parent.uploads.append( - MockS3ServiceUpload(self.key, self.content, self.media_type) + self.parent.uploads[self.key] = MockS3ServiceUpload( + self.key, self.content, self.media_type ) def _upload_abort(self) -> None: ... +@dataclass +class MockMultipartUploadPart: + part_data: MultipartS3UploadPart + content: bytes + + +@dataclass +class MockMultipartUpload: + key: str + upload_id: str + parts: list[MockMultipartUploadPart] = field(default_factory=list) + content_type: str | None = None + + class MockS3Service(S3Service): def __init__( self, @@ -71,16 +100,19 @@ def __init__( url_template: str, ) -> None: super().__init__(client, region, bucket, url_template) - self.uploads: list[MockS3ServiceUpload] = [] + self.uploads: dict[str, MockS3ServiceUpload] = {} self.mocked_multipart_upload: MockMultipartS3ContextManager | None = None + self.upload_in_progress: dict[str, MockMultipartUpload] = {} + self.aborted: list[str] = [] + def store_stream( self, key: str, stream: BinaryIO, content_type: str | None = None, ) -> str | None: - self.uploads.append(MockS3ServiceUpload(key, stream.read(), content_type)) + self.uploads[key] = MockS3ServiceUpload(key, stream.read(), content_type) return self.generate_url(key) def multipart( @@ -91,6 +123,45 @@ def multipart( ) return self.mocked_multipart_upload + def multipart_create(self, key: str, content_type: str | None = None) -> str: + upload_id = str(uuid4()) + self.upload_in_progress[key] = MockMultipartUpload( + key, upload_id, content_type=content_type + ) + return upload_id + + def multipart_upload( + self, key: str, upload_id: str, part_number: int, content: bytes + ) -> MultipartS3UploadPart: + etag = str(uuid4()) + part = MultipartS3UploadPart(etag=etag, part_number=part_number) + assert key in self.upload_in_progress + assert self.upload_in_progress[key].upload_id == upload_id + self.upload_in_progress[key].parts.append( + MockMultipartUploadPart(part, content) + ) + return part + + def multipart_complete( + self, key: str, upload_id: str, parts: list[MultipartS3UploadPart] + ) -> None: + assert key in self.upload_in_progress + assert self.upload_in_progress[key].upload_id == upload_id + complete_upload = self.upload_in_progress.pop(key) + for part_stored, part_passed_in in zip(complete_upload.parts, parts): + assert part_stored.part_data == part_passed_in + self.uploads[key] = MockS3ServiceUpload( + key, + b"".join(part_stored.content for part_stored in complete_upload.parts), + complete_upload.content_type, + ) + + def multipart_abort(self, key: str, upload_id: str) -> None: + assert key in self.upload_in_progress + assert self.upload_in_progress[key].upload_id == upload_id + self.upload_in_progress.pop(key) + self.aborted.append(key) + class S3ServiceProtocol(Protocol): def __call__( @@ -133,3 +204,95 @@ def 
mock_service(self) -> MockS3Service: @pytest.fixture def s3_service_fixture() -> S3ServiceFixture: return S3ServiceFixture() + + +class S3UploaderIntegrationConfiguration(FixtureTestUrlConfiguration): + url: AnyHttpUrl + user: str + password: str + + class Config(ServiceConfiguration.Config): + env_prefix = "PALACE_TEST_MINIO_" + + +class S3ServiceIntegrationFixture: + def __init__(self): + self.container = Storage() + self.configuration = S3UploaderIntegrationConfiguration.from_env() + self.analytics_bucket = self.random_name("analytics") + self.public_access_bucket = self.random_name("public") + self.container.config.from_dict( + { + "access_key": self.configuration.user, + "secret_key": self.configuration.password, + "endpoint_url": self.configuration.url, + "region": "us-east-1", + "analytics_bucket": self.analytics_bucket, + "public_access_bucket": self.public_access_bucket, + "url_template": self.configuration.url + "/{bucket}/{key}", + } + ) + self.buckets = [] + self.create_buckets() + + @classmethod + def random_name(cls, prefix: str = "test"): + return f"{prefix}-{uuid.uuid4()}" + + @property + def s3_client(self) -> S3Client: + return self.container.s3_client() + + @property + def public(self) -> S3Service: + return self.container.public() + + @property + def analytics(self) -> S3Service: + return self.container.analytics() + + def create_bucket(self, bucket_name: str) -> None: + client = self.s3_client + client.create_bucket(Bucket=bucket_name) + self.buckets.append(bucket_name) + + def get_bucket(self, bucket_name: str) -> str: + if bucket_name == "public": + return self.public_access_bucket + elif bucket_name == "analytics": + return self.analytics_bucket + else: + raise ValueError(f"Unknown bucket name: {bucket_name}") + + def create_buckets(self) -> None: + for bucket in [self.analytics_bucket, self.public_access_bucket]: + self.create_bucket(bucket) + + def list_objects(self, bucket_name: str) -> list[str]: + bucket = self.get_bucket(bucket_name) + response = self.s3_client.list_objects(Bucket=bucket) + return [object["Key"] for object in response.get("Contents", [])] + + def get_object(self, bucket_name: str, key: str) -> bytes: + bucket = self.get_bucket(bucket_name) + response = self.s3_client.get_object(Bucket=bucket, Key=key) + return response["Body"].read() + + def close(self): + for bucket in self.buckets: + response = self.s3_client.list_objects(Bucket=bucket) + + for object in response.get("Contents", []): + object_key = object["Key"] + self.s3_client.delete_object(Bucket=bucket, Key=object_key) + + self.s3_client.delete_bucket(Bucket=bucket) + + +@pytest.fixture +def s3_service_integration_fixture() -> ( + Generator[S3ServiceIntegrationFixture, None, None] +): + fixture = S3ServiceIntegrationFixture() + yield fixture + fixture.close() diff --git a/tests/manager/api/admin/controller/test_catalog_services.py b/tests/manager/api/admin/controller/test_catalog_services.py index 60119e5a6d..81ae168430 100644 --- a/tests/manager/api/admin/controller/test_catalog_services.py +++ b/tests/manager/api/admin/controller/test_catalog_services.py @@ -19,8 +19,9 @@ NO_PROTOCOL_FOR_NEW_SERVICE, UNKNOWN_PROTOCOL, ) -from palace.manager.core.marc import MARCExporter, MarcExporterLibrarySettings from palace.manager.integration.goals import Goals +from palace.manager.marc.exporter import MarcExporter +from palace.manager.marc.settings import MarcExporterLibrarySettings from palace.manager.sqlalchemy.model.integration import IntegrationConfiguration from palace.manager.sqlalchemy.util 
import get_one from palace.manager.util.problem_detail import ProblemDetail @@ -60,7 +61,7 @@ def test_catalog_services_get_with_no_services( assert 1 == len(protocols) assert protocols[0].get("name") == controller.registry.get_protocol( - MARCExporter + MarcExporter ) assert "settings" in protocols[0] assert "library_settings" in protocols[0] @@ -76,7 +77,7 @@ def test_catalog_services_get_with_marc_exporter( ) integration = db.integration_configuration( - MARCExporter, + MarcExporter, Goals.CATALOG_GOAL, name="name", libraries=[db.default_library()], @@ -84,7 +85,7 @@ def test_catalog_services_get_with_marc_exporter( library_settings_integration = integration.for_library(db.default_library()) assert library_settings_integration is not None - MARCExporter.library_settings_update( + MarcExporter.library_settings_update( library_settings_integration, library_settings ) @@ -120,28 +121,28 @@ def test_catalog_services_get_with_marc_exporter( id="unknown protocol", ), pytest.param( - {"protocol": "MARCExporter", "id": "123"}, + {"protocol": "MarcExporter", "id": "123"}, MISSING_SERVICE, True, None, id="unknown id", ), pytest.param( - {"protocol": "MARCExporter", "id": ""}, + {"protocol": "MarcExporter", "id": ""}, CANNOT_CHANGE_PROTOCOL, True, None, id="cannot change protocol", ), pytest.param( - {"protocol": "MARCExporter"}, + {"protocol": "MarcExporter"}, MISSING_SERVICE_NAME, True, None, id="no name", ), pytest.param( - {"protocol": "MARCExporter", "name": "existing integration"}, + {"protocol": "MarcExporter", "name": "existing integration"}, INTEGRATION_NAME_ALREADY_IN_USE, True, None, @@ -149,7 +150,7 @@ def test_catalog_services_get_with_marc_exporter( ), pytest.param( { - "protocol": "MARCExporter", + "protocol": "MarcExporter", "name": "new name", "libraries": json.dumps([{"short_name": "default"}]), }, @@ -203,7 +204,7 @@ def test_catalog_services_post_create( controller: CatalogServicesController, db: DatabaseTransactionFixture, ): - protocol = controller.registry.get_protocol(MARCExporter) + protocol = controller.registry.get_protocol(MarcExporter) assert protocol is not None with flask_app_fixture.test_request_context_system_admin("/", method="POST"): @@ -241,7 +242,7 @@ def test_catalog_services_post_create( assert service.name == "exporter name" assert service.libraries == [db.default_library()] - settings = MARCExporter.library_settings_load(service.library_configurations[0]) + settings = MarcExporter.library_settings_load(service.library_configurations[0]) assert settings.include_summary is False assert settings.include_genres is True @@ -252,7 +253,7 @@ def test_catalog_services_post_edit( db: DatabaseTransactionFixture, ): service = db.integration_configuration( - MARCExporter, + MarcExporter, Goals.CATALOG_GOAL, name="name", ) @@ -287,7 +288,7 @@ def test_catalog_services_post_edit( assert service.name == "exporter name" assert service.libraries == [db.default_library()] - settings = MARCExporter.library_settings_load(service.library_configurations[0]) + settings = MarcExporter.library_settings_load(service.library_configurations[0]) assert settings.include_summary is True assert settings.include_genres is False @@ -298,7 +299,7 @@ def test_catalog_services_delete( db: DatabaseTransactionFixture, ): service = db.integration_configuration( - MARCExporter, + MarcExporter, Goals.CATALOG_GOAL, ) diff --git a/tests/manager/api/controller/test_marc.py b/tests/manager/api/controller/test_marc.py index cadc2584ee..d4ebec061d 100644 --- a/tests/manager/api/controller/test_marc.py 
+++ b/tests/manager/api/controller/test_marc.py @@ -7,8 +7,11 @@ from flask import Response from palace.manager.api.controller.marc import MARCRecordController -from palace.manager.core.marc import MARCExporter from palace.manager.integration.goals import Goals +from palace.manager.marc.exporter import MarcExporter +from palace.manager.service.integration_registry.catalog_services import ( + CatalogServicesRegistry, +) from palace.manager.service.storage.s3 import S3Service from palace.manager.sqlalchemy.model.collection import Collection from palace.manager.sqlalchemy.model.library import Library @@ -16,14 +19,18 @@ from palace.manager.sqlalchemy.util import create from palace.manager.util.datetime_helpers import utc_now from tests.fixtures.database import DatabaseTransactionFixture +from tests.fixtures.services import ServicesFixture class MARCRecordControllerFixture: - def __init__(self, db: DatabaseTransactionFixture): + def __init__( + self, db: DatabaseTransactionFixture, registry: CatalogServicesRegistry + ): self.db = db + self.registry = registry self.mock_s3_service = MagicMock(spec=S3Service) self.mock_s3_service.generate_url = lambda x: "http://s3.url/" + x - self.controller = MARCRecordController(self.mock_s3_service) + self.controller = MARCRecordController(self.mock_s3_service, self.registry) self.library = db.default_library() self.collection = db.default_collection() self.collection.export_marc_records = True @@ -35,7 +42,7 @@ def __init__(self, db: DatabaseTransactionFixture): def integration(self, library: Library | None = None): library = library or self.library return self.db.integration_configuration( - MARCExporter, + MarcExporter, Goals.CATALOG_GOAL, libraries=[library], ) @@ -73,9 +80,11 @@ def get_response_html(self, response: Response) -> str: @pytest.fixture def marc_record_controller_fixture( - db: DatabaseTransactionFixture, + db: DatabaseTransactionFixture, services_fixture: ServicesFixture ) -> MARCRecordControllerFixture: - return MARCRecordControllerFixture(db) + return MARCRecordControllerFixture( + db, services_fixture.services.integration_registry.catalog_services() + ) class TestMARCRecordController: diff --git a/tests/manager/celery/tasks/test_marc.py b/tests/manager/celery/tasks/test_marc.py new file mode 100644 index 0000000000..3b796de2ed --- /dev/null +++ b/tests/manager/celery/tasks/test_marc.py @@ -0,0 +1,321 @@ +from typing import Any +from unittest.mock import ANY, call, patch + +import pytest +from pymarc import MARCReader +from sqlalchemy import select + +from palace.manager.celery.tasks import marc +from palace.manager.marc.exporter import MarcExporter +from palace.manager.marc.uploader import MarcUploadManager +from palace.manager.service.logging.configuration import LogLevel +from palace.manager.service.redis.models.marc import ( + MarcFileUploadSession, + MarcFileUploadSessionError, + MarcFileUploadState, +) +from palace.manager.sqlalchemy.model.collection import Collection +from palace.manager.sqlalchemy.model.marcfile import MarcFile +from palace.manager.sqlalchemy.model.work import Work +from palace.manager.sqlalchemy.util import create +from palace.manager.util.datetime_helpers import utc_now +from tests.fixtures.celery import CeleryFixture +from tests.fixtures.database import DatabaseTransactionFixture +from tests.fixtures.marc import MarcExporterFixture +from tests.fixtures.redis import RedisFixture +from tests.fixtures.s3 import S3ServiceFixture, S3ServiceIntegrationFixture +from tests.fixtures.services import ServicesFixture + + 
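+# These tests cover the two Celery tasks: marc_export, which queues a
+# marc_export_collection task for each eligible collection that has works to
+# export, and marc_export_collection, which generates the MARC files in batches,
+# buffering intermediate state in redis and uploading the finished files to storage.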
+class TestMarcExport: + def test_no_works( + self, + db: DatabaseTransactionFixture, + redis_fixture: RedisFixture, + marc_exporter_fixture: MarcExporterFixture, + celery_fixture: CeleryFixture, + ): + marc_exporter_fixture.configure_export() + with patch.object(marc, "marc_export_collection") as marc_export_collection: + # Because none of the collections have works, we should skip all of them. + marc.marc_export.delay().wait() + marc_export_collection.delay.assert_not_called() + + def test_normal_run( + self, + db: DatabaseTransactionFixture, + redis_fixture: RedisFixture, + marc_exporter_fixture: MarcExporterFixture, + celery_fixture: CeleryFixture, + ): + marc_exporter_fixture.configure_export() + with patch.object(marc, "marc_export_collection") as marc_export_collection: + # Runs against all the expected collections + collections = [ + marc_exporter_fixture.collection1, + marc_exporter_fixture.collection2, + marc_exporter_fixture.collection3, + ] + for collection in collections: + marc_exporter_fixture.work(collection) + marc.marc_export.delay().wait() + marc_export_collection.delay.assert_has_calls( + [ + call(collection_id=collection.id, start_time=ANY, libraries=ANY) + for collection in collections + ], + any_order=True, + ) + + def test_skip_collections( + self, + db: DatabaseTransactionFixture, + redis_fixture: RedisFixture, + marc_exporter_fixture: MarcExporterFixture, + celery_fixture: CeleryFixture, + ): + marc_exporter_fixture.configure_export() + collections = [ + marc_exporter_fixture.collection1, + marc_exporter_fixture.collection2, + marc_exporter_fixture.collection3, + ] + for collection in collections: + marc_exporter_fixture.work(collection) + with patch.object(marc, "marc_export_collection") as marc_export_collection: + # Collection 1 should be skipped because it is locked + assert marc_exporter_fixture.collection1.id is not None + MarcFileUploadSession( + redis_fixture.client, marc_exporter_fixture.collection1.id + ).acquire() + + # Collection 2 should be skipped because it was updated recently + create( + db.session, + MarcFile, + library=marc_exporter_fixture.library1, + collection=marc_exporter_fixture.collection2, + created=utc_now(), + key="test-file-2.mrc", + ) + + # Collection 3 should be skipped because its state is not INITIAL + assert marc_exporter_fixture.collection3.id is not None + upload_session = MarcFileUploadSession( + redis_fixture.client, marc_exporter_fixture.collection3.id + ) + with upload_session.lock() as acquired: + assert acquired + upload_session.set_state(MarcFileUploadState.QUEUED) + + marc.marc_export.delay().wait() + marc_export_collection.delay.assert_not_called() + + +class MarcExportCollectionFixture: + def __init__( + self, + db: DatabaseTransactionFixture, + celery_fixture: CeleryFixture, + redis_fixture: RedisFixture, + marc_exporter_fixture: MarcExporterFixture, + s3_service_integration_fixture: S3ServiceIntegrationFixture, + s3_service_fixture: S3ServiceFixture, + services_fixture: ServicesFixture, + ): + self.db = db + self.celery_fixture = celery_fixture + self.redis_fixture = redis_fixture + self.marc_exporter_fixture = marc_exporter_fixture + self.s3_service_integration_fixture = s3_service_integration_fixture + self.s3_service_fixture = s3_service_fixture + self.services_fixture = services_fixture + + self.mock_s3 = self.s3_service_fixture.mock_service() + self.mock_s3.MINIMUM_MULTIPART_UPLOAD_SIZE = 10 + marc_exporter_fixture.configure_export() + + self.start_time = utc_now() + + def marc_files(self) -> list[MarcFile]: + 
# We need to ignore the test-file-1.mrc file, which is created by our call to configure_export. + return [ + f + for f in self.db.session.execute(select(MarcFile)).scalars().all() + if f.key != self.marc_exporter_fixture.test_marc_file_key + ] + + def redis_data(self, collection: Collection) -> dict[str, Any] | None: + assert collection.id is not None + uploads = MarcFileUploadSession(self.redis_fixture.client, collection.id) + return self.redis_fixture.client.json().get(uploads.key) + + def setup_minio_storage(self) -> None: + self.services_fixture.services.storage.override( + self.s3_service_integration_fixture.container + ) + + def setup_mock_storage(self) -> None: + self.services_fixture.services.storage.public.override(self.mock_s3) + + def works(self, collection: Collection) -> list[Work]: + return [self.marc_exporter_fixture.work(collection) for _ in range(15)] + + def export_collection(self, collection: Collection) -> None: + service = self.services_fixture.services.integration_registry.catalog_services() + assert collection.id is not None + info = MarcExporter.enabled_libraries(self.db.session, service, collection.id) + libraries = [l.dict() for l in info] + marc.marc_export_collection.delay( + collection.id, batch_size=5, start_time=self.start_time, libraries=libraries + ).wait() + + +@pytest.fixture +def marc_export_collection_fixture( + db: DatabaseTransactionFixture, + celery_fixture: CeleryFixture, + redis_fixture: RedisFixture, + marc_exporter_fixture: MarcExporterFixture, + s3_service_integration_fixture: S3ServiceIntegrationFixture, + s3_service_fixture: S3ServiceFixture, + services_fixture: ServicesFixture, +) -> MarcExportCollectionFixture: + return MarcExportCollectionFixture( + db, + celery_fixture, + redis_fixture, + marc_exporter_fixture, + s3_service_integration_fixture, + s3_service_fixture, + services_fixture, + ) + + +class TestMarcExportCollection: + def test_normal_run( + self, + s3_service_integration_fixture: S3ServiceIntegrationFixture, + marc_exporter_fixture: MarcExporterFixture, + marc_export_collection_fixture: MarcExportCollectionFixture, + ): + marc_export_collection_fixture.setup_minio_storage() + collection = marc_exporter_fixture.collection1 + work_uris = [ + work.license_pools[0].identifier.urn + for work in marc_export_collection_fixture.works(collection) + ] + + # Run the full end-to-end process for exporting a collection, this should generate + # 3 batches of 5 works each, putting the results into minio. + marc_export_collection_fixture.export_collection(collection) + + # Verify that we didn't leave anything in the redis cache. + assert marc_export_collection_fixture.redis_data(collection) is None + + # Verify that the expected number of files were uploaded to minio. + uploaded_files = s3_service_integration_fixture.list_objects("public") + assert len(uploaded_files) == 3 + + # Verify that the expected number of marc files were created in the database. + marc_files = marc_export_collection_fixture.marc_files() + assert len(marc_files) == 3 + filenames = [marc_file.key for marc_file in marc_files] + + # Verify that the uploaded files are the expected ones. + assert set(uploaded_files) == set(filenames) + + # Verify that the marc files contain the expected works. 
+ for file in uploaded_files: + data = s3_service_integration_fixture.get_object("public", file) + records = list(MARCReader(data)) + assert len(records) == len(work_uris) + marc_uris = [record["001"].data for record in records] + assert set(marc_uris) == set(work_uris) + + # Make sure the records have the correct organization code. + expected_org = "library1-org" if "library1" in file else "library2-org" + assert all(record["003"].data == expected_org for record in records) + + # Make sure records have the correct status + expected_status = "c" if "delta" in file else "n" + assert all( + record.leader.record_status == expected_status for record in records + ) + + def test_collection_no_works( + self, + marc_exporter_fixture: MarcExporterFixture, + s3_service_integration_fixture: S3ServiceIntegrationFixture, + marc_export_collection_fixture: MarcExportCollectionFixture, + ): + marc_export_collection_fixture.setup_minio_storage() + collection = marc_exporter_fixture.collection2 + marc_export_collection_fixture.export_collection(collection) + + assert marc_export_collection_fixture.marc_files() == [] + assert s3_service_integration_fixture.list_objects("public") == [] + assert marc_export_collection_fixture.redis_data(collection) is None + + def test_exception_handled( + self, + marc_exporter_fixture: MarcExporterFixture, + marc_export_collection_fixture: MarcExportCollectionFixture, + ): + marc_export_collection_fixture.setup_mock_storage() + collection = marc_exporter_fixture.collection1 + marc_export_collection_fixture.works(collection) + + with patch.object(MarcUploadManager, "complete") as complete: + complete.side_effect = Exception("Test Exception") + with pytest.raises(Exception, match="Test Exception"): + marc_export_collection_fixture.export_collection(collection) + + # After the exception, we should have aborted the multipart uploads and deleted the redis data. + assert marc_export_collection_fixture.marc_files() == [] + assert marc_export_collection_fixture.redis_data(collection) is None + assert len(marc_export_collection_fixture.mock_s3.aborted) == 3 + + def test_locked( + self, + redis_fixture: RedisFixture, + marc_exporter_fixture: MarcExporterFixture, + marc_export_collection_fixture: MarcExportCollectionFixture, + caplog: pytest.LogCaptureFixture, + ): + caplog.set_level(LogLevel.info) + collection = marc_exporter_fixture.collection1 + assert collection.id is not None + MarcFileUploadSession(redis_fixture.client, collection.id).acquire() + marc_export_collection_fixture.setup_mock_storage() + with patch.object(MarcExporter, "query_works") as query: + marc_export_collection_fixture.export_collection(collection) + query.assert_not_called() + assert "another task is already processing it" in caplog.text + + def test_outdated_task_run( + self, + redis_fixture: RedisFixture, + marc_exporter_fixture: MarcExporterFixture, + marc_export_collection_fixture: MarcExportCollectionFixture, + caplog: pytest.LogCaptureFixture, + ): + # In the case that an old task is run again for some reason, it should + # detect that its update number is incorrect and exit. + caplog.set_level(LogLevel.info) + collection = marc_exporter_fixture.collection1 + marc_export_collection_fixture.setup_mock_storage() + assert collection.id is not None + + # Acquire the lock and start an upload, this simulates another task having done work + # that the current task doesn't know about. 
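+        # Appending to a buffer increments the session's update number in redis, so
+        # the export task, which starts from update number 0, fails its check with
+        # an "Update number mismatch" error instead of applying stale updates.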
+ uploads = MarcFileUploadSession(redis_fixture.client, collection.id) + with uploads.lock() as locked: + assert locked + uploads.append_buffers({"test": "data"}) + + with pytest.raises(MarcFileUploadSessionError, match="Update number mismatch"): + marc_export_collection_fixture.export_collection(collection) + + assert marc_export_collection_fixture.marc_files() == [] + assert marc_export_collection_fixture.redis_data(collection) is None diff --git a/tests/manager/core/test_marc.py b/tests/manager/core/test_marc.py deleted file mode 100644 index 12671e3f2d..0000000000 --- a/tests/manager/core/test_marc.py +++ /dev/null @@ -1,900 +0,0 @@ -from __future__ import annotations - -import datetime -import functools -import logging -import urllib -from typing import TYPE_CHECKING -from unittest.mock import MagicMock, create_autospec, patch - -import pytest -from pymarc import Indicators, MARCReader, Record -from pytest import LogCaptureFixture - -from palace.manager.core.marc import Annotator, MARCExporter -from palace.manager.sqlalchemy.model.classification import Genre -from palace.manager.sqlalchemy.model.contributor import Contributor -from palace.manager.sqlalchemy.model.datasource import DataSource -from palace.manager.sqlalchemy.model.edition import Edition -from palace.manager.sqlalchemy.model.identifier import Identifier -from palace.manager.sqlalchemy.model.licensing import ( - DeliveryMechanism, - LicensePoolDeliveryMechanism, - RightsStatus, -) -from palace.manager.sqlalchemy.model.marcfile import MarcFile -from palace.manager.sqlalchemy.model.resource import Representation -from palace.manager.util.datetime_helpers import datetime_utc, utc_now -from palace.manager.util.uuid import uuid_encode - -if TYPE_CHECKING: - from tests.fixtures.database import DatabaseTransactionFixture - from tests.fixtures.s3 import MockS3Service, S3ServiceFixture - - -class AnnotateWorkRecordFixture: - def __init__(self): - self.cm_url = "http://cm.url" - self.short_name = "short_name" - self.web_client_urls = ["http://webclient.url"] - self.organization_name = "org" - self.include_summary = True - self.include_genres = True - - self.annotator = Annotator( - self.cm_url, - self.short_name, - self.web_client_urls, - self.organization_name, - self.include_summary, - self.include_genres, - ) - - self.revised = MagicMock() - self.work = MagicMock() - self.pool = MagicMock() - self.edition = MagicMock() - self.identifier = MagicMock() - - self.mock_leader = create_autospec(self.annotator.leader, return_value=" " * 24) - self.mock_add_control_fields = create_autospec( - self.annotator.add_control_fields - ) - self.mock_add_marc_organization_code = create_autospec( - self.annotator.add_marc_organization_code - ) - self.mock_add_isbn = create_autospec(self.annotator.add_isbn) - self.mock_add_title = create_autospec(self.annotator.add_title) - self.mock_add_contributors = create_autospec(self.annotator.add_contributors) - self.mock_add_publisher = create_autospec(self.annotator.add_publisher) - self.mock_add_distributor = create_autospec(self.annotator.add_distributor) - self.mock_add_physical_description = create_autospec( - self.annotator.add_physical_description - ) - self.mock_add_audience = create_autospec(self.annotator.add_audience) - self.mock_add_series = create_autospec(self.annotator.add_series) - self.mock_add_system_details = create_autospec( - self.annotator.add_system_details - ) - self.mock_add_formats = create_autospec(self.annotator.add_formats) - self.mock_add_summary = 
create_autospec(self.annotator.add_summary) - self.mock_add_genres = create_autospec(self.annotator.add_genres) - self.mock_add_ebooks_subject = create_autospec( - self.annotator.add_ebooks_subject - ) - self.mock_add_web_client_urls = create_autospec( - self.annotator.add_web_client_urls - ) - - self.annotator.leader = self.mock_leader - self.annotator.add_control_fields = self.mock_add_control_fields - self.annotator.add_marc_organization_code = self.mock_add_marc_organization_code - self.annotator.add_isbn = self.mock_add_isbn - self.annotator.add_title = self.mock_add_title - self.annotator.add_contributors = self.mock_add_contributors - self.annotator.add_publisher = self.mock_add_publisher - self.annotator.add_distributor = self.mock_add_distributor - self.annotator.add_physical_description = self.mock_add_physical_description - self.annotator.add_audience = self.mock_add_audience - self.annotator.add_series = self.mock_add_series - self.annotator.add_system_details = self.mock_add_system_details - self.annotator.add_formats = self.mock_add_formats - self.annotator.add_summary = self.mock_add_summary - self.annotator.add_genres = self.mock_add_genres - self.annotator.add_ebooks_subject = self.mock_add_ebooks_subject - self.annotator.add_web_client_urls = self.mock_add_web_client_urls - - self.annotate_work_record = functools.partial( - self.annotator.annotate_work_record, - self.revised, - self.work, - self.pool, - self.edition, - self.identifier, - ) - - -@pytest.fixture -def annotate_work_record_fixture() -> AnnotateWorkRecordFixture: - return AnnotateWorkRecordFixture() - - -class TestAnnotator: - def test_annotate_work_record( - self, annotate_work_record_fixture: AnnotateWorkRecordFixture - ) -> None: - fixture = annotate_work_record_fixture - with patch("palace.manager.core.marc.Record") as mock_record: - fixture.annotate_work_record() - - mock_record.assert_called_once_with( - force_utf8=True, leader=fixture.mock_leader.return_value - ) - fixture.mock_leader.assert_called_once_with(fixture.revised) - record = mock_record() - fixture.mock_add_control_fields.assert_called_once_with( - record, fixture.identifier, fixture.pool, fixture.edition - ) - fixture.mock_add_marc_organization_code.assert_called_once_with( - record, fixture.organization_name - ) - fixture.mock_add_isbn.assert_called_once_with(record, fixture.identifier) - fixture.mock_add_title.assert_called_once_with(record, fixture.edition) - fixture.mock_add_contributors.assert_called_once_with(record, fixture.edition) - fixture.mock_add_publisher.assert_called_once_with(record, fixture.edition) - fixture.mock_add_distributor.assert_called_once_with(record, fixture.pool) - fixture.mock_add_physical_description.assert_called_once_with( - record, fixture.edition - ) - fixture.mock_add_audience.assert_called_once_with(record, fixture.work) - fixture.mock_add_series.assert_called_once_with(record, fixture.edition) - fixture.mock_add_system_details.assert_called_once_with(record) - fixture.mock_add_formats.assert_called_once_with(record, fixture.pool) - fixture.mock_add_summary.assert_called_once_with(record, fixture.work) - fixture.mock_add_genres.assert_called_once_with(record, fixture.work) - fixture.mock_add_ebooks_subject.assert_called_once_with(record) - fixture.mock_add_web_client_urls.assert_called_once_with( - record, - fixture.identifier, - fixture.short_name, - fixture.cm_url, - fixture.web_client_urls, - ) - - def test_annotate_work_record_no_summary( - self, annotate_work_record_fixture: AnnotateWorkRecordFixture 
- ) -> None: - fixture = annotate_work_record_fixture - fixture.annotator.include_summary = False - fixture.annotate_work_record() - - assert fixture.mock_add_summary.call_count == 0 - - def test_annotate_work_record_no_genres( - self, annotate_work_record_fixture: AnnotateWorkRecordFixture - ) -> None: - fixture = annotate_work_record_fixture - fixture.annotator.include_genres = False - fixture.annotate_work_record() - - assert fixture.mock_add_genres.call_count == 0 - - def test_annotate_work_record_no_organization_code( - self, annotate_work_record_fixture: AnnotateWorkRecordFixture - ) -> None: - fixture = annotate_work_record_fixture - fixture.annotator.organization_code = None - fixture.annotate_work_record() - - assert fixture.mock_add_marc_organization_code.call_count == 0 - - def test_leader(self): - leader = Annotator.leader(False) - assert leader == "00000nam 2200000 4500" - - # If the record is revised, the leader is different. - leader = Annotator.leader(True) - assert leader == "00000cam 2200000 4500" - - @staticmethod - def _check_control_field(record, tag, expected): - [field] = record.get_fields(tag) - assert field.value() == expected - - @staticmethod - def _check_field( - record, tag, expected_subfields, expected_indicators: Indicators | None = None - ): - if not expected_indicators: - expected_indicators = Indicators(" ", " ") - [field] = record.get_fields(tag) - assert field.indicators == expected_indicators - for subfield, value in expected_subfields.items(): - assert field.get_subfields(subfield)[0] == value - - def test_add_control_fields(self, db: DatabaseTransactionFixture): - # This edition has one format and was published before 1900. - edition, pool = db.edition(with_license_pool=True) - identifier = pool.identifier - edition.issued = datetime_utc(956, 1, 1) - - now = utc_now() - record = Record() - - Annotator.add_control_fields(record, identifier, pool, edition) - self._check_control_field(record, "001", identifier.urn) - assert now.strftime("%Y%m%d") in record.get_fields("005")[0].value() - self._check_control_field(record, "006", "m d ") - self._check_control_field(record, "007", "cr cn ---anuuu") - self._check_control_field( - record, "008", now.strftime("%y%m%d") + "s0956 xxu eng " - ) - - # This French edition has two formats and was published in 2018. 
- edition2, pool2 = db.edition(with_license_pool=True) - identifier2 = pool2.identifier - edition2.issued = datetime_utc(2018, 2, 3) - edition2.language = "fre" - LicensePoolDeliveryMechanism.set( - pool2.data_source, - identifier2, - Representation.PDF_MEDIA_TYPE, - DeliveryMechanism.ADOBE_DRM, - RightsStatus.IN_COPYRIGHT, - ) - - record = Record() - Annotator.add_control_fields(record, identifier2, pool2, edition2) - self._check_control_field(record, "001", identifier2.urn) - assert now.strftime("%Y%m%d") in record.get_fields("005")[0].value() - self._check_control_field(record, "006", "m d ") - self._check_control_field(record, "007", "cr cn ---mnuuu") - self._check_control_field( - record, "008", now.strftime("%y%m%d") + "s2018 xxu fre " - ) - - def test_add_marc_organization_code(self): - record = Record() - Annotator.add_marc_organization_code(record, "US-MaBoDPL") - self._check_control_field(record, "003", "US-MaBoDPL") - - def test_add_isbn(self, db: DatabaseTransactionFixture): - isbn = db.identifier(identifier_type=Identifier.ISBN) - record = Record() - Annotator.add_isbn(record, isbn) - self._check_field(record, "020", {"a": isbn.identifier}) - - # If the identifier isn't an ISBN, but has an equivalent that is, it still - # works. - equivalent = db.identifier() - data_source = DataSource.lookup(db.session, DataSource.OCLC) - equivalent.equivalent_to(data_source, isbn, 1) - record = Record() - Annotator.add_isbn(record, equivalent) - self._check_field(record, "020", {"a": isbn.identifier}) - - # If there is no ISBN, the field is left out. - non_isbn = db.identifier() - record = Record() - Annotator.add_isbn(record, non_isbn) - assert [] == record.get_fields("020") - - def test_add_title(self, db: DatabaseTransactionFixture): - edition = db.edition() - edition.title = "The Good Soldier" - edition.sort_title = "Good Soldier, The" - edition.subtitle = "A Tale of Passion" - - record = Record() - Annotator.add_title(record, edition) - [field] = record.get_fields("245") - self._check_field( - record, - "245", - { - "a": edition.title, - "b": edition.subtitle, - "c": edition.author, - }, - Indicators("0", "4"), - ) - - # If there's no subtitle or no author, those subfields are left out. - edition.subtitle = None - edition.author = None - - record = Record() - Annotator.add_title(record, edition) - [field] = record.get_fields("245") - self._check_field( - record, - "245", - { - "a": edition.title, - }, - Indicators("0", "4"), - ) - assert [] == field.get_subfields("b") - assert [] == field.get_subfields("c") - - def test_add_contributors(self, db: DatabaseTransactionFixture): - author = "a" - author2 = "b" - translator = "c" - - # Edition with one author gets a 100 field and no 700 fields. - edition = db.edition(authors=[author]) - edition.sort_author = "sorted" - - record = Record() - Annotator.add_contributors(record, edition) - assert [] == record.get_fields("700") - self._check_field( - record, "100", {"a": edition.sort_author}, Indicators("1", " ") - ) - - # Edition with two authors and a translator gets three 700 fields and no 100 fields. 
- edition = db.edition(authors=[author, author2]) - edition.add_contributor(translator, Contributor.Role.TRANSLATOR) - - record = Record() - Annotator.add_contributors(record, edition) - assert [] == record.get_fields("100") - fields = record.get_fields("700") - for field in fields: - assert Indicators("1", " ") == field.indicators - [author_field, author2_field, translator_field] = sorted( - fields, key=lambda x: x.get_subfields("a")[0] - ) - assert author == author_field.get_subfields("a")[0] - assert Contributor.Role.PRIMARY_AUTHOR == author_field.get_subfields("e")[0] - assert author2 == author2_field.get_subfields("a")[0] - assert Contributor.Role.AUTHOR == author2_field.get_subfields("e")[0] - assert translator == translator_field.get_subfields("a")[0] - assert Contributor.Role.TRANSLATOR == translator_field.get_subfields("e")[0] - - def test_add_publisher(self, db: DatabaseTransactionFixture): - edition = db.edition() - edition.publisher = db.fresh_str() - edition.issued = datetime_utc(1894, 4, 5) - - record = Record() - Annotator.add_publisher(record, edition) - self._check_field( - record, - "264", - { - "a": "[Place of publication not identified]", - "b": edition.publisher, - "c": "1894", - }, - Indicators(" ", "1"), - ) - - # If there's no publisher, the field is left out. - record = Record() - edition.publisher = None - Annotator.add_publisher(record, edition) - assert [] == record.get_fields("264") - - def test_add_distributor(self, db: DatabaseTransactionFixture): - edition, pool = db.edition(with_license_pool=True) - record = Record() - Annotator.add_distributor(record, pool) - self._check_field( - record, "264", {"b": pool.data_source.name}, Indicators(" ", "2") - ) - - def test_add_physical_description(self, db: DatabaseTransactionFixture): - book = db.edition() - book.medium = Edition.BOOK_MEDIUM - audio = db.edition() - audio.medium = Edition.AUDIO_MEDIUM - - record = Record() - Annotator.add_physical_description(record, book) - self._check_field(record, "300", {"a": "1 online resource"}) - self._check_field( - record, - "336", - { - "a": "text", - "b": "txt", - "2": "rdacontent", - }, - ) - self._check_field( - record, - "337", - { - "a": "computer", - "b": "c", - "2": "rdamedia", - }, - ) - self._check_field( - record, - "338", - { - "a": "online resource", - "b": "cr", - "2": "rdacarrier", - }, - ) - self._check_field( - record, - "347", - { - "a": "text file", - "2": "rda", - }, - ) - self._check_field( - record, - "380", - { - "a": "eBook", - "2": "tlcgt", - }, - ) - - record = Record() - Annotator.add_physical_description(record, audio) - self._check_field( - record, - "300", - { - "a": "1 sound file", - "b": "digital", - }, - ) - self._check_field( - record, - "336", - { - "a": "spoken word", - "b": "spw", - "2": "rdacontent", - }, - ) - self._check_field( - record, - "337", - { - "a": "computer", - "b": "c", - "2": "rdamedia", - }, - ) - self._check_field( - record, - "338", - { - "a": "online resource", - "b": "cr", - "2": "rdacarrier", - }, - ) - self._check_field( - record, - "347", - { - "a": "audio file", - "2": "rda", - }, - ) - assert [] == record.get_fields("380") - - def test_add_audience(self, db: DatabaseTransactionFixture): - for audience, term in list(Annotator.AUDIENCE_TERMS.items()): - work = db.work(audience=audience) - record = Record() - Annotator.add_audience(record, work) - self._check_field( - record, - "385", - { - "a": term, - "2": "tlctarget", - }, - ) - - def test_add_series(self, db: DatabaseTransactionFixture): - edition = db.edition() 
- edition.series = db.fresh_str() - edition.series_position = 5 - record = Record() - Annotator.add_series(record, edition) - self._check_field( - record, - "490", - { - "a": edition.series, - "v": str(edition.series_position), - }, - Indicators("0", " "), - ) - - # If there's no series position, the same field is used without - # the v subfield. - edition.series_position = None - record = Record() - Annotator.add_series(record, edition) - self._check_field( - record, - "490", - { - "a": edition.series, - }, - Indicators("0", " "), - ) - [field] = record.get_fields("490") - assert [] == field.get_subfields("v") - - # If there's no series, the field is left out. - edition.series = None - record = Record() - Annotator.add_series(record, edition) - assert [] == record.get_fields("490") - - def test_add_system_details(self): - record = Record() - Annotator.add_system_details(record) - self._check_field(record, "538", {"a": "Mode of access: World Wide Web."}) - - def test_add_formats(self, db: DatabaseTransactionFixture): - edition, pool = db.edition(with_license_pool=True) - epub_no_drm, ignore = DeliveryMechanism.lookup( - db.session, Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM - ) - pool.delivery_mechanisms[0].delivery_mechanism = epub_no_drm - LicensePoolDeliveryMechanism.set( - pool.data_source, - pool.identifier, - Representation.PDF_MEDIA_TYPE, - DeliveryMechanism.ADOBE_DRM, - RightsStatus.IN_COPYRIGHT, - ) - - record = Record() - Annotator.add_formats(record, pool) - fields = record.get_fields("538") - assert 2 == len(fields) - [pdf, epub] = sorted(fields, key=lambda x: x.get_subfields("a")[0]) - assert "Adobe PDF eBook" == pdf.get_subfields("a")[0] - assert Indicators(" ", " ") == pdf.indicators - assert "EPUB eBook" == epub.get_subfields("a")[0] - assert Indicators(" ", " ") == epub.indicators - - def test_add_summary(self, db: DatabaseTransactionFixture): - work = db.work(with_license_pool=True) - work.summary_text = "

Summary

" - - # Build and validate a record with a `520|a` summary. - record = Record() - Annotator.add_summary(record, work) - self._check_field(record, "520", {"a": " Summary "}) - exported_record = record.as_marc() - - # Round trip the exported record to validate it. - marc_reader = MARCReader(exported_record) - round_tripped_record = next(marc_reader) - self._check_field(round_tripped_record, "520", {"a": " Summary "}) - - def test_add_simplified_genres(self, db: DatabaseTransactionFixture): - work = db.work(with_license_pool=True) - fantasy, ignore = Genre.lookup(db.session, "Fantasy", autocreate=True) - romance, ignore = Genre.lookup(db.session, "Romance", autocreate=True) - work.genres = [fantasy, romance] - - record = Record() - Annotator.add_genres(record, work) - fields = record.get_fields("650") - [fantasy_field, romance_field] = sorted( - fields, key=lambda x: x.get_subfields("a")[0] - ) - assert Indicators("0", "7") == fantasy_field.indicators - assert "Fantasy" == fantasy_field.get_subfields("a")[0] - assert "Library Simplified" == fantasy_field.get_subfields("2")[0] - assert Indicators("0", "7") == romance_field.indicators - assert "Romance" == romance_field.get_subfields("a")[0] - assert "Library Simplified" == romance_field.get_subfields("2")[0] - - def test_add_ebooks_subject(self): - record = Record() - Annotator.add_ebooks_subject(record) - self._check_field( - record, "655", {"a": "Electronic books."}, Indicators(" ", "0") - ) - - def test_add_web_client_urls_empty(self): - record = MagicMock(spec=Record) - identifier = MagicMock() - Annotator.add_web_client_urls(record, identifier, "", "", []) - assert record.add_field.call_count == 0 - - def test_add_web_client_urls(self, db: DatabaseTransactionFixture): - record = Record() - identifier = db.identifier() - short_name = "short_name" - cm_url = "http://cm.url" - web_client_urls = ["http://webclient1.url", "http://webclient2.url"] - Annotator.add_web_client_urls( - record, identifier, short_name, cm_url, web_client_urls - ) - fields = record.get_fields("856") - assert len(fields) == 2 - [field1, field2] = fields - assert field1.indicators == Indicators("4", "0") - assert field2.indicators == Indicators("4", "0") - - # The URL for a work is constructed as: - # - //works/ - work_link_template = "{cm_base}/{lib}/works/{qid}" - # It is then encoded and the web client URL is constructed in this form: - # - /book/ - client_url_template = "{client_base}/book/{work_link}" - - qualified_identifier = urllib.parse.quote( - identifier.type + "/" + identifier.identifier, safe="" - ) - - expected_work_link = work_link_template.format( - cm_base=cm_url, lib=short_name, qid=qualified_identifier - ) - encoded_work_link = urllib.parse.quote(expected_work_link, safe="") - - expected_client_url_1 = client_url_template.format( - client_base=web_client_urls[0], work_link=encoded_work_link - ) - expected_client_url_2 = client_url_template.format( - client_base=web_client_urls[1], work_link=encoded_work_link - ) - - # A few checks to ensure that our setup is useful. 
- assert web_client_urls[0] != web_client_urls[1] - assert expected_client_url_1 != expected_client_url_2 - assert expected_client_url_1.startswith(web_client_urls[0]) - assert expected_client_url_2.startswith(web_client_urls[1]) - - assert field1.get_subfields("u")[0] == expected_client_url_1 - assert field2.get_subfields("u")[0] == expected_client_url_2 - - -class MarcExporterFixture: - def __init__(self, db: DatabaseTransactionFixture, s3: MockS3Service): - self.db = db - - self.now = utc_now() - self.library = db.default_library() - self.s3_service = s3 - self.exporter = MARCExporter(self.db.session, s3) - self.mock_annotator = MagicMock(spec=Annotator) - assert self.library.short_name is not None - self.annotator = Annotator( - "http://cm.url", - self.library.short_name, - ["http://webclient.url"], - "org", - True, - True, - ) - - self.library = db.library() - self.collection = db.collection() - self.collection.libraries.append(self.library) - - self.now = utc_now() - self.yesterday = self.now - datetime.timedelta(days=1) - self.last_week = self.now - datetime.timedelta(days=7) - - self.w1 = db.work( - genre="Mystery", with_open_access_download=True, collection=self.collection - ) - self.w1.last_update_time = self.yesterday - self.w2 = db.work( - genre="Mystery", with_open_access_download=True, collection=self.collection - ) - self.w2.last_update_time = self.last_week - - self.records = functools.partial( - self.exporter.records, - self.library, - self.collection, - annotator=self.annotator, - creation_time=self.now, - ) - - -@pytest.fixture -def marc_exporter_fixture( - db: DatabaseTransactionFixture, - s3_service_fixture: S3ServiceFixture, -) -> MarcExporterFixture: - return MarcExporterFixture(db, s3_service_fixture.mock_service()) - - -class TestMARCExporter: - def test_create_record( - self, db: DatabaseTransactionFixture, marc_exporter_fixture: MarcExporterFixture - ): - work = db.work( - with_license_pool=True, - title="old title", - authors=["old author"], - data_source_name=DataSource.OVERDRIVE, - ) - - mock_revised = MagicMock() - - create_record = functools.partial( - MARCExporter.create_record, - revised=mock_revised, - work=work, - annotator=marc_exporter_fixture.mock_annotator, - ) - - record = create_record() - assert record is not None - - # Make sure we pass the expected arguments to Annotator.annotate_work_record - marc_exporter_fixture.mock_annotator.annotate_work_record.assert_called_once_with( - mock_revised, - work, - work.license_pools[0], - work.license_pools[0].presentation_edition, - work.license_pools[0].identifier, - ) - - def test_records( - self, - db: DatabaseTransactionFixture, - marc_exporter_fixture: MarcExporterFixture, - ): - storage_service = marc_exporter_fixture.s3_service - creation_time = marc_exporter_fixture.now - - marc_exporter_fixture.records() - - # The file was mirrored and a MarcFile was created to track the mirrored file. 
- assert len(storage_service.uploads) == 1 - [cache] = db.session.query(MarcFile).all() - assert cache.library == marc_exporter_fixture.library - assert cache.collection == marc_exporter_fixture.collection - - short_name = marc_exporter_fixture.library.short_name - collection_name = marc_exporter_fixture.collection.name - date_str = creation_time.strftime("%Y-%m-%d") - uuid_str = uuid_encode(cache.id) - - assert ( - cache.key - == f"marc/{short_name}/{collection_name}.full.{date_str}.{uuid_str}.mrc" - ) - assert cache.created == creation_time - assert cache.since is None - - records = list(MARCReader(storage_service.uploads[0].content)) - assert len(records) == 2 - - title_fields = [record.get_fields("245") for record in records] - titles = {fields[0].get_subfields("a")[0] for fields in title_fields} - assert titles == { - marc_exporter_fixture.w1.title, - marc_exporter_fixture.w2.title, - } - - def test_records_since_time( - self, - db: DatabaseTransactionFixture, - marc_exporter_fixture: MarcExporterFixture, - ): - # If the `since` parameter is set, only works updated since that time - # are included in the export and the filename reflects that we created - # a partial export. - since = marc_exporter_fixture.now - datetime.timedelta(days=3) - storage_service = marc_exporter_fixture.s3_service - creation_time = marc_exporter_fixture.now - - marc_exporter_fixture.records( - since_time=since, - ) - [cache] = db.session.query(MarcFile).all() - assert cache.library == marc_exporter_fixture.library - assert cache.collection == marc_exporter_fixture.collection - - short_name = marc_exporter_fixture.library.short_name - collection_name = marc_exporter_fixture.collection.name - from_date = since.strftime("%Y-%m-%d") - to_date = creation_time.strftime("%Y-%m-%d") - uuid_str = uuid_encode(cache.id) - - assert ( - cache.key - == f"marc/{short_name}/{collection_name}.delta.{from_date}.{to_date}.{uuid_str}.mrc" - ) - assert cache.created == creation_time - assert cache.since == since - - # Only the work updated since the `since` time is included in the export. - [record] = list(MARCReader(storage_service.uploads[0].content)) - [title_field] = record.get_fields("245") - assert title_field.get_subfields("a")[0] == marc_exporter_fixture.w1.title - - def test_records_none( - self, - db: DatabaseTransactionFixture, - marc_exporter_fixture: MarcExporterFixture, - caplog: LogCaptureFixture, - ): - # If there are no works to export, no file is created and a log message is generated. - caplog.set_level(logging.INFO) - - storage_service = marc_exporter_fixture.s3_service - - # Remove the works from the database. - db.session.delete(marc_exporter_fixture.w1) - db.session.delete(marc_exporter_fixture.w2) - - marc_exporter_fixture.records() - - assert [] == storage_service.uploads - assert db.session.query(MarcFile).count() == 0 - assert len(caplog.records) == 1 - assert "No MARC records to upload" in caplog.text - - def test_records_exception( - self, - db: DatabaseTransactionFixture, - marc_exporter_fixture: MarcExporterFixture, - caplog: LogCaptureFixture, - ): - # If an exception occurs while exporting, no file is created and a log message is generated. - caplog.set_level(logging.ERROR) - - exporter = marc_exporter_fixture.exporter - storage_service = marc_exporter_fixture.s3_service - - # Mock our query function to raise an exception. 
- exporter.query_works = MagicMock(side_effect=Exception("Boom!")) - - marc_exporter_fixture.records() - - assert [] == storage_service.uploads - assert db.session.query(MarcFile).count() == 0 - assert len(caplog.records) == 1 - assert "Failed to upload MARC file" in caplog.text - assert "Boom!" in caplog.text - - def test_records_minimum_size( - self, - marc_exporter_fixture: MarcExporterFixture, - ): - exporter = marc_exporter_fixture.exporter - storage_service = marc_exporter_fixture.s3_service - - exporter.MINIMUM_UPLOAD_BATCH_SIZE_BYTES = 100 - - # Mock the "records" generated, and force the response to be of certain sizes - created_record_mock = MagicMock() - created_record_mock.as_marc = MagicMock( - side_effect=[b"1" * 600, b"2" * 20, b"3" * 500, b"4" * 10] - ) - exporter.create_record = lambda *args: created_record_mock - - # Mock the query_works to return 4 works - exporter.query_works = MagicMock( - return_value=[MagicMock(), MagicMock(), MagicMock(), MagicMock()] - ) - - marc_exporter_fixture.records() - - assert storage_service.mocked_multipart_upload is not None - # Even though there are 4 parts, we upload in 3 batches due to minimum size limitations - # The "4"th part gets uploaded due it being the tail piece - assert len(storage_service.mocked_multipart_upload.content_parts) == 3 - assert storage_service.mocked_multipart_upload.content_parts == [ - b"1" * 600, - b"2" * 20 + b"3" * 500, - b"4" * 10, - ] diff --git a/tests/manager/marc/__init__.py b/tests/manager/marc/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/manager/marc/test_annotator.py b/tests/manager/marc/test_annotator.py new file mode 100644 index 0000000000..41d59c1254 --- /dev/null +++ b/tests/manager/marc/test_annotator.py @@ -0,0 +1,716 @@ +from __future__ import annotations + +import functools +import urllib +from unittest.mock import MagicMock + +import pytest +from freezegun import freeze_time +from pymarc import Indicators, MARCReader, Record + +from palace.manager.marc.annotator import Annotator +from palace.manager.sqlalchemy.model.classification import Genre +from palace.manager.sqlalchemy.model.contributor import Contributor +from palace.manager.sqlalchemy.model.datasource import DataSource +from palace.manager.sqlalchemy.model.edition import Edition +from palace.manager.sqlalchemy.model.identifier import Identifier +from palace.manager.sqlalchemy.model.licensing import ( + DeliveryMechanism, + LicensePool, + LicensePoolDeliveryMechanism, + RightsStatus, +) +from palace.manager.sqlalchemy.model.resource import Representation +from palace.manager.sqlalchemy.model.work import Work +from palace.manager.util.datetime_helpers import datetime_utc, utc_now +from tests.fixtures.database import DatabaseTransactionFixture + + +class AnnotatorFixture: + def __init__(self, db: DatabaseTransactionFixture): + self._db = db + self.cm_url = "http://cm.url" + self.short_name = "short_name" + self.web_client_urls = ["http://webclient.url"] + self.organization_name = "org" + self.include_summary = True + self.include_genres = True + + self.annotator = Annotator() + + @staticmethod + def assert_control_field(record: Record, tag: str, expected: str) -> None: + [field] = record.get_fields(tag) + assert field.value() == expected + + @staticmethod + def assert_field( + record: Record, + tag: str, + expected_subfields: dict[str, str], + expected_indicators: Indicators | None = None, + ) -> None: + if not expected_indicators: + expected_indicators = Indicators(" ", " ") + [field] = 
record.get_fields(tag) + assert field.indicators == expected_indicators + for subfield, value in expected_subfields.items(): + assert field.get_subfields(subfield)[0] == value + + @staticmethod + def record_tags(record: Record) -> set[int]: + return {int(f.tag) for f in record.fields} + + def assert_record_tags( + self, + record: Record, + includes: set[int] | None = None, + excludes: set[int] | None = None, + ) -> None: + tags = self.record_tags(record) + assert includes or excludes + if includes: + assert includes.issubset(tags) + if excludes: + assert excludes.isdisjoint(tags) + + def record(self) -> Record: + return self.annotator._record() + + def test_work(self) -> tuple[Work, LicensePool]: + edition, pool = self._db.edition( + with_license_pool=True, identifier_type=Identifier.ISBN + ) + work = self._db.work(presentation_edition=edition) + work.summary_text = "Summary" + fantasy, ignore = Genre.lookup(self._db.session, "Fantasy", autocreate=True) + romance, ignore = Genre.lookup(self._db.session, "Romance", autocreate=True) + work.genres = [fantasy, romance] + edition.issued = datetime_utc(956, 1, 1) + edition.series = self._db.fresh_str() + edition.series_position = 5 + return work, pool + + +@pytest.fixture +def annotator_fixture( + db: DatabaseTransactionFixture, +) -> AnnotatorFixture: + return AnnotatorFixture(db) + + +class TestAnnotator: + def test_marc_record( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ) -> None: + work, pool = annotator_fixture.test_work() + annotator = annotator_fixture.annotator + + record = annotator.marc_record(work, pool) + assert annotator_fixture.record_tags(record) == { + 1, + 5, + 6, + 7, + 8, + 20, + 245, + 100, + 264, + 300, + 336, + 385, + 490, + 538, + 655, + 520, + 650, + 337, + 338, + 347, + 380, + } + + def test__copy_record(self, annotator_fixture: AnnotatorFixture): + work, pool = annotator_fixture.test_work() + annotator = annotator_fixture.annotator + record = annotator.marc_record(work, pool) + copied = annotator_fixture.annotator._copy_record(record) + assert copied is not record + assert copied.as_marc() == record.as_marc() + + def test_library_marc_record(self, annotator_fixture: AnnotatorFixture): + work, pool = annotator_fixture.test_work() + annotator = annotator_fixture.annotator + generic_record = annotator.marc_record(work, pool) + + library_marc_record = functools.partial( + annotator.library_marc_record, + record=generic_record, + identifier=pool.identifier, + base_url="http://cm.url", + library_short_name="short_name", + web_client_urls=["http://webclient.url"], + organization_code="xyz", + include_summary=True, + include_genres=True, + ) + + library_record = library_marc_record() + annotator_fixture.assert_record_tags( + library_record, includes={3, 520, 650, 856} + ) + + # Make sure the generic record did not get modified. + assert generic_record != library_record + assert generic_record.as_marc() != library_record.as_marc() + annotator_fixture.assert_record_tags(generic_record, excludes={3, 856}) + + # If the summary is not included, the 520 field is left out. + library_record = library_marc_record(include_summary=False) + annotator_fixture.assert_record_tags( + library_record, includes={3, 650, 856}, excludes={520} + ) + + # If the genres are not included, the 650 field is left out. 
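These include/exclude switches boil down to copying the shared record and stripping the optional fields before the library-specific ones are stamped on. A minimal sketch of that shape, assuming pymarc's Record API (the real library_marc_record also adds the 003 and 856 fields and may differ in detail):

    from pymarc import Record

    def filter_for_library(
        shared: Record, include_summary: bool, include_genres: bool
    ) -> Record:
        # Round-trip through the binary form so the shared record is left untouched.
        copied = Record(data=shared.as_marc())
        if not include_summary:
            copied.remove_fields("520")  # summary note
        if not include_genres:
            copied.remove_fields("650")  # genre headings
        return copied

The assertions that follow cover each combination of these switches.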
+ library_record = library_marc_record(include_genres=False) + annotator_fixture.assert_record_tags( + library_record, includes={3, 520, 856}, excludes={650} + ) + + # If the genres and summary are not included, the 520 and 650 fields are left out. + library_record = library_marc_record( + include_summary=False, include_genres=False + ) + annotator_fixture.assert_record_tags( + library_record, includes={3, 856}, excludes={520, 650} + ) + + # If the organization code is not provided, the 003 field is left out. + library_record = library_marc_record(organization_code=None) + annotator_fixture.assert_record_tags( + library_record, includes={520, 650, 856}, excludes={3} + ) + + # If the web client URLs are not provided, the 856 fields are left out. + library_record = library_marc_record(web_client_urls=[]) + annotator_fixture.assert_record_tags( + library_record, includes={3, 520, 650}, excludes={856} + ) + + def test_leader(self, annotator_fixture: AnnotatorFixture): + leader = annotator_fixture.annotator.leader(False) + assert leader == "00000nam 2200000 4500" + + # If the record is revised, the leader is different. + leader = Annotator.leader(True) + assert leader == "00000cam 2200000 4500" + + @freeze_time() + def test_add_control_fields( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + # This edition has one format and was published before 1900. + edition, pool = db.edition(with_license_pool=True) + identifier = pool.identifier + edition.issued = datetime_utc(956, 1, 1) + + now = utc_now() + record = annotator_fixture.record() + + annotator_fixture.annotator.add_control_fields( + record, identifier, pool, edition + ) + annotator_fixture.assert_control_field(record, "001", identifier.urn) + assert now.strftime("%Y%m%d") in record.get_fields("005")[0].value() + annotator_fixture.assert_control_field(record, "006", "m d ") + annotator_fixture.assert_control_field(record, "007", "cr cn ---anuuu") + annotator_fixture.assert_control_field( + record, "008", now.strftime("%y%m%d") + "s0956 xxu eng " + ) + + # This French edition has two formats and was published in 2018. 
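Before that variation, it is worth seeing the fixed-position layout behind the 008 value the previous assertions check. A rough illustration (not the exporter's code) of the 40-character field:

    from palace.manager.util.datetime_helpers import utc_now

    def sketch_008(issued_year: int, language: str) -> str:
        return (
            utc_now().strftime("%y%m%d")  # 00-05: date the record was generated
            + "s"                         # 06: single known publication date
            + f"{issued_year:04d}"        # 07-10: date 1 (year of issue)
            + " " * 4                     # 11-14: date 2, blank here
            + "xxu"                       # 15-17: place of publication code
            + " " * 17                    # 18-34: unused by this exporter
            + language                    # 35-37: language code
            + "  "                        # 38-39
        )

    assert len(sketch_008(956, "eng")) == 40

For the French edition below, only the year, the language code, and the file-formats byte of the 007 ("a" for a single format, "m" for multiple) change.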
+ edition2, pool2 = db.edition(with_license_pool=True) + identifier2 = pool2.identifier + edition2.issued = datetime_utc(2018, 2, 3) + edition2.language = "fre" + LicensePoolDeliveryMechanism.set( + pool2.data_source, + identifier2, + Representation.PDF_MEDIA_TYPE, + DeliveryMechanism.ADOBE_DRM, + RightsStatus.IN_COPYRIGHT, + ) + + record = annotator_fixture.record() + annotator_fixture.annotator.add_control_fields( + record, identifier2, pool2, edition2 + ) + annotator_fixture.assert_control_field(record, "001", identifier2.urn) + assert now.strftime("%Y%m%d") in record.get_fields("005")[0].value() + annotator_fixture.assert_control_field(record, "006", "m d ") + annotator_fixture.assert_control_field(record, "007", "cr cn ---mnuuu") + annotator_fixture.assert_control_field( + record, "008", now.strftime("%y%m%d") + "s2018 xxu fre " + ) + + def test_add_marc_organization_code(self, annotator_fixture: AnnotatorFixture): + record = annotator_fixture.record() + annotator_fixture.annotator.add_marc_organization_code(record, "US-MaBoDPL") + annotator_fixture.assert_control_field(record, "003", "US-MaBoDPL") + + def test_add_isbn( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + isbn = db.identifier(identifier_type=Identifier.ISBN) + record = annotator_fixture.record() + annotator_fixture.annotator.add_isbn(record, isbn) + annotator_fixture.assert_field(record, "020", {"a": isbn.identifier}) + + # If the identifier isn't an ISBN, but has an equivalent that is, it still + # works. + equivalent = db.identifier() + data_source = DataSource.lookup(db.session, DataSource.OCLC) + equivalent.equivalent_to(data_source, isbn, 1) + record = annotator_fixture.record() + annotator_fixture.annotator.add_isbn(record, equivalent) + annotator_fixture.assert_field(record, "020", {"a": isbn.identifier}) + + # If there is no ISBN, the field is left out. + non_isbn = db.identifier() + record = annotator_fixture.record() + annotator_fixture.annotator.add_isbn(record, non_isbn) + assert [] == record.get_fields("020") + + def test_add_title( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + edition = db.edition() + edition.title = "The Good Soldier" + edition.sort_title = "Good Soldier, The" + edition.subtitle = "A Tale of Passion" + + record = annotator_fixture.record() + annotator_fixture.annotator.add_title(record, edition) + assert len(record.get_fields("245")) == 1 + annotator_fixture.assert_field( + record, + "245", + { + "a": edition.title, + "b": edition.subtitle, + "c": edition.author, + }, + Indicators("0", "4"), + ) + + # If there's no subtitle or no author, those subfields are left out. + edition.subtitle = None + edition.author = None + + record = annotator_fixture.record() + annotator_fixture.annotator.add_title(record, edition) + [field] = record.get_fields("245") + annotator_fixture.assert_field( + record, + "245", + { + "a": edition.title, + }, + Indicators("0", "4"), + ) + assert [] == field.get_subfields("b") + assert [] == field.get_subfields("c") + + def test_add_contributors( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + author = "a" + author2 = "b" + translator = "c" + + # Edition with one author gets a 100 field and no 700 fields. 
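Both shapes referred to here are ordinary pymarc fields: a 100 main entry keyed on the sort name, or one 700 added entry per contributor with the role in $e. A sketch with illustrative values, assuming pymarc 5's Field, Subfield and Indicators types:

    from pymarc import Field, Indicators, Subfield

    from palace.manager.sqlalchemy.model.contributor import Contributor

    # Single author: one 100 main entry.
    main_entry = Field(
        tag="100",
        indicators=Indicators("1", " "),
        subfields=[Subfield("a", "Ford, Ford Madox")],  # illustrative name
    )

    # Multiple contributors: a 700 added entry each, with the role in $e.
    added_entry = Field(
        tag="700",
        indicators=Indicators("1", " "),
        subfields=[
            Subfield("a", "Ford, Ford Madox"),  # illustrative name
            Subfield("e", Contributor.Role.TRANSLATOR),
        ],
    )

The single-author case is set up first, below.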
+ edition = db.edition(authors=[author]) + edition.sort_author = "sorted" + + record = annotator_fixture.record() + annotator_fixture.annotator.add_contributors(record, edition) + assert [] == record.get_fields("700") + annotator_fixture.assert_field( + record, "100", {"a": edition.sort_author}, Indicators("1", " ") + ) + + # Edition with two authors and a translator gets three 700 fields and no 100 fields. + edition = db.edition(authors=[author, author2]) + edition.add_contributor(translator, Contributor.Role.TRANSLATOR) + + record = annotator_fixture.record() + annotator_fixture.annotator.add_contributors(record, edition) + assert [] == record.get_fields("100") + fields = record.get_fields("700") + for field in fields: + assert Indicators("1", " ") == field.indicators + [author_field, author2_field, translator_field] = sorted( + fields, key=lambda x: x.get_subfields("a")[0] + ) + assert author == author_field.get_subfields("a")[0] + assert Contributor.Role.PRIMARY_AUTHOR == author_field.get_subfields("e")[0] + assert author2 == author2_field.get_subfields("a")[0] + assert Contributor.Role.AUTHOR == author2_field.get_subfields("e")[0] + assert translator == translator_field.get_subfields("a")[0] + assert Contributor.Role.TRANSLATOR == translator_field.get_subfields("e")[0] + + def test_add_publisher( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + edition = db.edition() + edition.publisher = db.fresh_str() + edition.issued = datetime_utc(1894, 4, 5) + + record = annotator_fixture.record() + annotator_fixture.annotator.add_publisher(record, edition) + annotator_fixture.assert_field( + record, + "264", + { + "a": "[Place of publication not identified]", + "b": edition.publisher, + "c": "1894", + }, + Indicators(" ", "1"), + ) + + # If there's no publisher, the field is left out. 
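The second indicator of the 264 is what separates this field from the distributor field checked in the next test: "1" marks publication, "2" distribution. A sketch of the two variants, with illustrative values:

    from pymarc import Field, Indicators, Subfield

    publisher_264 = Field(
        tag="264",
        indicators=Indicators(" ", "1"),  # function: publication
        subfields=[
            Subfield("a", "[Place of publication not identified]"),
            Subfield("b", "Example Publisher"),  # illustrative
            Subfield("c", "1894"),
        ],
    )
    distributor_264 = Field(
        tag="264",
        indicators=Indicators(" ", "2"),  # function: distribution
        subfields=[Subfield("b", "Example Distributor")],  # illustrative
    )

When the edition has no publisher at all, no 264 is emitted, as the next assertions check.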
+ record = annotator_fixture.record() + edition.publisher = None + annotator_fixture.annotator.add_publisher(record, edition) + assert [] == record.get_fields("264") + + def test_add_distributor( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + edition, pool = db.edition(with_license_pool=True) + record = annotator_fixture.record() + annotator_fixture.annotator.add_distributor(record, pool) + annotator_fixture.assert_field( + record, "264", {"b": pool.data_source.name}, Indicators(" ", "2") + ) + + def test_add_physical_description( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + book = db.edition() + book.medium = Edition.BOOK_MEDIUM + audio = db.edition() + audio.medium = Edition.AUDIO_MEDIUM + + record = annotator_fixture.record() + annotator_fixture.annotator.add_physical_description(record, book) + annotator_fixture.assert_field(record, "300", {"a": "1 online resource"}) + annotator_fixture.assert_field( + record, + "336", + { + "a": "text", + "b": "txt", + "2": "rdacontent", + }, + ) + annotator_fixture.assert_field( + record, + "337", + { + "a": "computer", + "b": "c", + "2": "rdamedia", + }, + ) + annotator_fixture.assert_field( + record, + "338", + { + "a": "online resource", + "b": "cr", + "2": "rdacarrier", + }, + ) + annotator_fixture.assert_field( + record, + "347", + { + "a": "text file", + "2": "rda", + }, + ) + annotator_fixture.assert_field( + record, + "380", + { + "a": "eBook", + "2": "tlcgt", + }, + ) + + record = annotator_fixture.record() + annotator_fixture.annotator.add_physical_description(record, audio) + annotator_fixture.assert_field( + record, + "300", + { + "a": "1 sound file", + "b": "digital", + }, + ) + annotator_fixture.assert_field( + record, + "336", + { + "a": "spoken word", + "b": "spw", + "2": "rdacontent", + }, + ) + annotator_fixture.assert_field( + record, + "337", + { + "a": "computer", + "b": "c", + "2": "rdamedia", + }, + ) + annotator_fixture.assert_field( + record, + "338", + { + "a": "online resource", + "b": "cr", + "2": "rdacarrier", + }, + ) + annotator_fixture.assert_field( + record, + "347", + { + "a": "audio file", + "2": "rda", + }, + ) + assert [] == record.get_fields("380") + + def test_add_audience( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + for audience, term in list(annotator_fixture.annotator.AUDIENCE_TERMS.items()): + work = db.work(audience=audience) + record = annotator_fixture.record() + annotator_fixture.annotator.add_audience(record, work) + annotator_fixture.assert_field( + record, + "385", + { + "a": term, + "2": "tlctarget", + }, + ) + + def test_add_series( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + edition = db.edition() + edition.series = db.fresh_str() + edition.series_position = 5 + record = annotator_fixture.record() + annotator_fixture.annotator.add_series(record, edition) + annotator_fixture.assert_field( + record, + "490", + { + "a": edition.series, + "v": str(edition.series_position), + }, + Indicators("0", " "), + ) + + # If there's no series position, the same field is used without + # the v subfield. 
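That behaviour amounts to building the 490 subfield list conditionally; a small sketch, again assuming pymarc's types rather than the annotator's exact code:

    from pymarc import Field, Indicators, Subfield

    def series_490(series: str, series_position: int | None) -> Field:
        subfields = [Subfield("a", series)]
        if series_position is not None:
            subfields.append(Subfield("v", str(series_position)))  # $v only when known
        return Field(tag="490", indicators=Indicators("0", " "), subfields=subfields)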
+ edition.series_position = None + record = annotator_fixture.record() + annotator_fixture.annotator.add_series(record, edition) + annotator_fixture.assert_field( + record, + "490", + { + "a": edition.series, + }, + Indicators("0", " "), + ) + [field] = record.get_fields("490") + assert [] == field.get_subfields("v") + + # If there's no series, the field is left out. + edition.series = None + record = annotator_fixture.record() + annotator_fixture.annotator.add_series(record, edition) + assert [] == record.get_fields("490") + + def test_add_system_details(self, annotator_fixture: AnnotatorFixture): + record = annotator_fixture.record() + annotator_fixture.annotator.add_system_details(record) + annotator_fixture.assert_field( + record, "538", {"a": "Mode of access: World Wide Web."} + ) + + def test_add_formats( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + edition, pool = db.edition(with_license_pool=True) + epub_no_drm, ignore = DeliveryMechanism.lookup( + db.session, Representation.EPUB_MEDIA_TYPE, DeliveryMechanism.NO_DRM + ) + pool.delivery_mechanisms[0].delivery_mechanism = epub_no_drm + LicensePoolDeliveryMechanism.set( + pool.data_source, + pool.identifier, + Representation.PDF_MEDIA_TYPE, + DeliveryMechanism.ADOBE_DRM, + RightsStatus.IN_COPYRIGHT, + ) + + record = annotator_fixture.record() + annotator_fixture.annotator.add_formats(record, pool) + fields = record.get_fields("538") + assert 2 == len(fields) + [pdf, epub] = sorted(fields, key=lambda x: x.get_subfields("a")[0]) + assert "Adobe PDF eBook" == pdf.get_subfields("a")[0] + assert Indicators(" ", " ") == pdf.indicators + assert "EPUB eBook" == epub.get_subfields("a")[0] + assert Indicators(" ", " ") == epub.indicators + + def test_add_summary( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + work = db.work(with_license_pool=True) + work.summary_text = "

Summary

" + + # Build and validate a record with a `520|a` summary. + record = annotator_fixture.record() + annotator_fixture.annotator.add_summary(record, work) + annotator_fixture.assert_field(record, "520", {"a": " Summary "}) + exported_record = record.as_marc() + + # Round trip the exported record to validate it. + marc_reader = MARCReader(exported_record) + round_tripped_record = next(marc_reader) + annotator_fixture.assert_field(round_tripped_record, "520", {"a": " Summary "}) + + def test_add_simplified_genres( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + work = db.work(with_license_pool=True) + fantasy, ignore = Genre.lookup(db.session, "Fantasy", autocreate=True) + romance, ignore = Genre.lookup(db.session, "Romance", autocreate=True) + work.genres = [fantasy, romance] + + record = annotator_fixture.record() + annotator_fixture.annotator.add_genres(record, work) + fields = record.get_fields("650") + [fantasy_field, romance_field] = sorted( + fields, key=lambda x: x.get_subfields("a")[0] + ) + assert Indicators("0", "7") == fantasy_field.indicators + assert "Fantasy" == fantasy_field.get_subfields("a")[0] + assert "Library Simplified" == fantasy_field.get_subfields("2")[0] + assert Indicators("0", "7") == romance_field.indicators + assert "Romance" == romance_field.get_subfields("a")[0] + assert "Library Simplified" == romance_field.get_subfields("2")[0] + + def test_add_ebooks_subject(self, annotator_fixture: AnnotatorFixture): + record = annotator_fixture.record() + annotator_fixture.annotator.add_ebooks_subject(record) + annotator_fixture.assert_field( + record, "655", {"a": "Electronic books."}, Indicators(" ", "0") + ) + + def test_add_web_client_urls_empty(self, annotator_fixture: AnnotatorFixture): + record = MagicMock(spec=Record) + identifier = MagicMock() + annotator_fixture.annotator.add_web_client_urls(record, identifier, "", "", []) + assert record.add_field.call_count == 0 + + def test_add_web_client_urls( + self, + db: DatabaseTransactionFixture, + annotator_fixture: AnnotatorFixture, + ): + record = annotator_fixture.record() + identifier = db.identifier() + short_name = "short_name" + cm_url = "http://cm.url" + web_client_urls = ["http://webclient1.url", "http://webclient2.url"] + annotator_fixture.annotator.add_web_client_urls( + record, identifier, short_name, cm_url, web_client_urls + ) + fields = record.get_fields("856") + assert len(fields) == 2 + [field1, field2] = fields + assert field1.indicators == Indicators("4", "0") + assert field2.indicators == Indicators("4", "0") + + # The URL for a work is constructed as: + # - //works/ + work_link_template = "{cm_base}/{lib}/works/{qid}" + # It is then encoded and the web client URL is constructed in this form: + # - /book/ + client_url_template = "{client_base}/book/{work_link}" + + qualified_identifier = urllib.parse.quote( + identifier.type + "/" + identifier.identifier, safe="" + ) + + expected_work_link = work_link_template.format( + cm_base=cm_url, lib=short_name, qid=qualified_identifier + ) + encoded_work_link = urllib.parse.quote(expected_work_link, safe="") + + expected_client_url_1 = client_url_template.format( + client_base=web_client_urls[0], work_link=encoded_work_link + ) + expected_client_url_2 = client_url_template.format( + client_base=web_client_urls[1], work_link=encoded_work_link + ) + + # A few checks to ensure that our setup is useful. 
+ assert web_client_urls[0] != web_client_urls[1] + assert expected_client_url_1 != expected_client_url_2 + assert expected_client_url_1.startswith(web_client_urls[0]) + assert expected_client_url_2.startswith(web_client_urls[1]) + + assert field1.get_subfields("u")[0] == expected_client_url_1 + assert field2.get_subfields("u")[0] == expected_client_url_2 diff --git a/tests/manager/marc/test_exporter.py b/tests/manager/marc/test_exporter.py new file mode 100644 index 0000000000..c09c6c80a5 --- /dev/null +++ b/tests/manager/marc/test_exporter.py @@ -0,0 +1,425 @@ +import datetime +from functools import partial +from unittest.mock import ANY, call, create_autospec +from uuid import UUID + +import pytest +from freezegun import freeze_time + +from palace.manager.marc.exporter import LibraryInfo, MarcExporter +from palace.manager.marc.settings import MarcExporterLibrarySettings +from palace.manager.marc.uploader import MarcUploadManager +from palace.manager.sqlalchemy.model.discovery_service_registration import ( + DiscoveryServiceRegistration, +) +from palace.manager.sqlalchemy.model.marcfile import MarcFile +from palace.manager.sqlalchemy.util import create, get_one +from palace.manager.util.datetime_helpers import datetime_utc, utc_now +from tests.fixtures.database import DatabaseTransactionFixture +from tests.fixtures.marc import MarcExporterFixture + + +class TestMarcExporter: + def test__s3_key(self, marc_exporter_fixture: MarcExporterFixture) -> None: + library = marc_exporter_fixture.library1 + collection = marc_exporter_fixture.collection1 + + uuid = UUID("c2370bf2-28e1-40ff-9f04-4864306bd11c") + now = datetime_utc(2024, 8, 27) + since = datetime_utc(2024, 8, 20) + + s3_key = partial(MarcExporter._s3_key, library, collection, now, uuid) + + assert ( + s3_key() + == f"marc/{library.short_name}/{collection.name}.full.2024-08-27.wjcL8ijhQP-fBEhkMGvRHA.mrc" + ) + + assert ( + s3_key(since_time=since) + == f"marc/{library.short_name}/{collection.name}.delta.2024-08-20.2024-08-27.wjcL8ijhQP-fBEhkMGvRHA.mrc" + ) + + @freeze_time("2020-02-20T10:00:00Z") + @pytest.mark.parametrize( + "last_updated_time, update_frequency, expected", + [ + (None, 60, True), + (None, 1, True), + (datetime.datetime.fromisoformat("2020-02-20T09:00:00"), 1, False), + (datetime.datetime.fromisoformat("2020-02-19T10:02:00"), 1, True), + (datetime.datetime.fromisoformat("2020-01-31T10:02:00"), 20, True), + (datetime.datetime.fromisoformat("2020-02-01T10:00:00"), 20, False), + ], + ) + def test__needs_update( + self, + last_updated_time: datetime.datetime, + update_frequency: int, + expected: bool, + ): + assert ( + MarcExporter._needs_update(last_updated_time, update_frequency) == expected + ) + + def test__web_client_urls( + self, + db: DatabaseTransactionFixture, + marc_exporter_fixture: MarcExporterFixture, + ): + library = marc_exporter_fixture.library1 + web_client_urls = partial(MarcExporter._web_client_urls, db.session, library) + + # No web client URLs are returned if there are no discovery service registrations. + assert web_client_urls() == () + + # If we pass in a configured web client URL, that URL is returned. + assert web_client_urls(url="http://web-client") == ("http://web-client",) + + # Add a URL from a library registry. 
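The ordering asserted below (registry-supplied URL first, then the configured one) suggests a merge along these lines; a toy sketch only, since the real _web_client_urls reads DiscoveryServiceRegistration rows:

    def merge_web_client_urls(
        registry_urls: tuple[str, ...], configured_url: str | None
    ) -> tuple[str, ...]:
        # Registry-advertised web clients first, then the library's own setting.
        return registry_urls + ((configured_url,) if configured_url else ())

    assert merge_web_client_urls(
        ("http://web-client/registry",), "http://web-client"
    ) == ("http://web-client/registry", "http://web-client")

The registration created next supplies exactly such a registry URL.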
+ registry = db.discovery_service_integration() + create( + db.session, + DiscoveryServiceRegistration, + library=library, + integration=registry, + web_client="http://web-client/registry", + ) + assert web_client_urls() == ("http://web-client/registry",) + + # URL from library registry and configured URL are both returned. + assert web_client_urls(url="http://web-client") == ( + "http://web-client/registry", + "http://web-client", + ) + + def test__enabled_collections_and_libraries( + self, + db: DatabaseTransactionFixture, + marc_exporter_fixture: MarcExporterFixture, + ) -> None: + enabled_collections_and_libraries = partial( + MarcExporter._enabled_collections_and_libraries, + db.session, + marc_exporter_fixture.registry, + ) + + assert enabled_collections_and_libraries() == set() + + # Marc export is enabled on the collections, but since the libraries don't have a marc exporter, they are + # not included. + marc_exporter_fixture.collection1.export_marc_records = True + marc_exporter_fixture.collection2.export_marc_records = True + assert enabled_collections_and_libraries() == set() + + # Marc export is enabled, but no libraries are added to it + marc_integration = marc_exporter_fixture.integration() + assert enabled_collections_and_libraries() == set() + + # Add a marc exporter to library1 + marc_l1_config = db.integration_library_configuration( + marc_integration, marc_exporter_fixture.library1 + ) + assert enabled_collections_and_libraries() == { + (marc_exporter_fixture.collection1, marc_l1_config), + (marc_exporter_fixture.collection2, marc_l1_config), + } + + # Add a marc exporter to library2 + marc_l2_config = db.integration_library_configuration( + marc_integration, marc_exporter_fixture.library2 + ) + assert enabled_collections_and_libraries() == { + (marc_exporter_fixture.collection1, marc_l1_config), + (marc_exporter_fixture.collection1, marc_l2_config), + (marc_exporter_fixture.collection2, marc_l1_config), + } + + # Enable marc export on collection3 + marc_exporter_fixture.collection3.export_marc_records = True + assert enabled_collections_and_libraries() == { + (marc_exporter_fixture.collection1, marc_l1_config), + (marc_exporter_fixture.collection1, marc_l2_config), + (marc_exporter_fixture.collection2, marc_l1_config), + (marc_exporter_fixture.collection3, marc_l2_config), + } + + # We can also filter by a collection id + assert enabled_collections_and_libraries( + collection_id=marc_exporter_fixture.collection1.id + ) == { + (marc_exporter_fixture.collection1, marc_l1_config), + (marc_exporter_fixture.collection1, marc_l2_config), + } + + def test__last_updated(self, marc_exporter_fixture: MarcExporterFixture) -> None: + library = marc_exporter_fixture.library1 + collection = marc_exporter_fixture.collection1 + + last_updated = partial( + MarcExporter._last_updated, + marc_exporter_fixture.session, + library, + collection, + ) + + # If there is no cached file, we return None. + assert last_updated() is None + + # If there is a cached file, we return the time it was created. + file1 = MarcFile( + library=library, + collection=collection, + created=datetime_utc(1984, 5, 8), + key="file1", + ) + marc_exporter_fixture.session.add(file1) + assert last_updated() == file1.created + + # If there are multiple cached files, we return the time of the most recent one. 
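Behaviourally this is just "the most recent MarcFile.created for the library/collection pair". One way to express that, as a sketch under that assumption rather than the query _last_updated necessarily issues:

    import datetime

    from sqlalchemy import func, select
    from sqlalchemy.orm import Session

    from palace.manager.sqlalchemy.model.marcfile import MarcFile

    def last_updated_sketch(
        session: Session, library, collection
    ) -> datetime.datetime | None:
        # None when the pair has never been exported; otherwise the newest created timestamp.
        return session.execute(
            select(func.max(MarcFile.created)).where(
                MarcFile.library == library, MarcFile.collection == collection
            )
        ).scalar()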
+ file2 = MarcFile( + library=library, + collection=collection, + created=utc_now(), + key="file2", + ) + marc_exporter_fixture.session.add(file2) + assert last_updated() == file2.created + + def test_enabled_collections( + self, + db: DatabaseTransactionFixture, + marc_exporter_fixture: MarcExporterFixture, + ): + enabled_collections = partial( + MarcExporter.enabled_collections, + db.session, + marc_exporter_fixture.registry, + ) + + assert enabled_collections() == set() + + # Marc export is enabled on the collections, but since the libraries don't have a marc exporter, they are + # not included. + marc_exporter_fixture.collection1.export_marc_records = True + marc_exporter_fixture.collection2.export_marc_records = True + assert enabled_collections() == set() + + # Marc export is enabled, but no libraries are added to it + marc_integration = marc_exporter_fixture.integration() + assert enabled_collections() == set() + + # Add a marc exporter to library2 + db.integration_library_configuration( + marc_integration, marc_exporter_fixture.library2 + ) + assert enabled_collections() == {marc_exporter_fixture.collection1} + + # Enable marc export on collection3 + marc_exporter_fixture.collection3.export_marc_records = True + assert enabled_collections() == { + marc_exporter_fixture.collection1, + marc_exporter_fixture.collection3, + } + + def test_enabled_libraries( + self, + db: DatabaseTransactionFixture, + marc_exporter_fixture: MarcExporterFixture, + ): + assert marc_exporter_fixture.collection1.id is not None + enabled_libraries = partial( + MarcExporter.enabled_libraries, + db.session, + marc_exporter_fixture.registry, + collection_id=marc_exporter_fixture.collection1.id, + ) + + assert enabled_libraries() == [] + + # Collections have marc export enabled, and the marc exporter integration is setup, but + # no libraries are configured to use it. 
+ marc_exporter_fixture.collection1.export_marc_records = True + marc_exporter_fixture.collection2.export_marc_records = True + marc_integration = marc_exporter_fixture.integration() + assert enabled_libraries() == [] + + # Add a marc exporter to library2 + db.integration_library_configuration( + marc_integration, + marc_exporter_fixture.library2, + MarcExporterLibrarySettings( + organization_code="org", web_client_url="http://web-client" + ), + ) + [library_2_info] = enabled_libraries() + + def assert_library_2(library_info: LibraryInfo) -> None: + assert library_info.library_id == marc_exporter_fixture.library2.id + assert ( + library_info.library_short_name + == marc_exporter_fixture.library2.short_name + ) + assert library_info.last_updated is None + assert library_info.needs_update + assert library_info.organization_code == "org" + assert library_info.include_summary is False + assert library_info.include_genres is False + assert library_info.web_client_urls == ("http://web-client",) + assert library_info.s3_key_full.startswith("marc/library2/collection1.full") + assert library_info.s3_key_delta is None + + assert_library_2(library_2_info) + + # Add a marc exporter to library1 + db.integration_library_configuration( + marc_integration, + marc_exporter_fixture.library1, + MarcExporterLibrarySettings( + organization_code="org2", include_summary=True, include_genres=True + ), + ) + [library_1_info, library_2_info] = enabled_libraries() + assert_library_2(library_2_info) + + assert library_1_info.library_id == marc_exporter_fixture.library1.id + assert ( + library_1_info.library_short_name + == marc_exporter_fixture.library1.short_name + ) + assert library_1_info.last_updated is None + assert library_1_info.needs_update + assert library_1_info.organization_code == "org2" + assert library_1_info.include_summary is True + assert library_1_info.include_genres is True + assert library_1_info.web_client_urls == () + assert library_1_info.s3_key_full.startswith("marc/library1/collection1.full") + assert library_1_info.s3_key_delta is None + + def test_query_works(self, marc_exporter_fixture: MarcExporterFixture) -> None: + assert marc_exporter_fixture.collection1.id is not None + query_works = partial( + MarcExporter.query_works, + marc_exporter_fixture.session, + collection_id=marc_exporter_fixture.collection1.id, + work_id_offset=None, + batch_size=3, + ) + + assert query_works() == [] + + works = marc_exporter_fixture.works() + + assert query_works() == works[:3] + assert query_works(work_id_offset=works[3].id) == works[4:] + + def test_collection(self, marc_exporter_fixture: MarcExporterFixture) -> None: + collection_id = marc_exporter_fixture.collection1.id + assert collection_id is not None + collection = MarcExporter.collection( + marc_exporter_fixture.session, collection_id + ) + assert collection == marc_exporter_fixture.collection1 + + marc_exporter_fixture.session.delete(collection) + collection = MarcExporter.collection( + marc_exporter_fixture.session, collection_id + ) + assert collection is None + + def test_process_work(self, marc_exporter_fixture: MarcExporterFixture) -> None: + marc_exporter_fixture.configure_export() + + collection = marc_exporter_fixture.collection1 + work = marc_exporter_fixture.work(collection) + enabled_libraries = marc_exporter_fixture.enabled_libraries(collection) + + mock_upload_manager = create_autospec(MarcUploadManager) + + process_work = partial( + MarcExporter.process_work, + work, + enabled_libraries, + "http://base.url", + 
upload_manager=mock_upload_manager, + ) + + process_work() + mock_upload_manager.add_record.assert_has_calls( + [ + call(enabled_libraries[0].s3_key_full, ANY), + call(enabled_libraries[0].s3_key_delta, ANY), + call(enabled_libraries[1].s3_key_full, ANY), + ] + ) + + # If the work has no license pools, it is skipped. + mock_upload_manager.reset_mock() + work.license_pools = [] + process_work() + mock_upload_manager.add_record.assert_not_called() + + def test_create_marc_upload_records( + self, marc_exporter_fixture: MarcExporterFixture + ) -> None: + marc_exporter_fixture.configure_export() + + collection = marc_exporter_fixture.collection1 + assert collection.id is not None + enabled_libraries = marc_exporter_fixture.enabled_libraries(collection) + + marc_exporter_fixture.session.query(MarcFile).delete() + + start_time = utc_now() + + # If there are no uploads, then no records are created. + MarcExporter.create_marc_upload_records( + marc_exporter_fixture.session, + start_time, + collection.id, + enabled_libraries, + set(), + ) + + assert len(marc_exporter_fixture.session.query(MarcFile).all()) == 0 + + # If there are uploads, then records are created. + assert enabled_libraries[0].s3_key_delta is not None + MarcExporter.create_marc_upload_records( + marc_exporter_fixture.session, + start_time, + collection.id, + enabled_libraries, + { + enabled_libraries[0].s3_key_full, + enabled_libraries[1].s3_key_full, + enabled_libraries[0].s3_key_delta, + }, + ) + + assert len(marc_exporter_fixture.session.query(MarcFile).all()) == 3 + + assert get_one( + marc_exporter_fixture.session, + MarcFile, + collection=collection, + library_id=enabled_libraries[0].library_id, + key=enabled_libraries[0].s3_key_full, + ) + + assert get_one( + marc_exporter_fixture.session, + MarcFile, + collection=collection, + library_id=enabled_libraries[1].library_id, + key=enabled_libraries[1].s3_key_full, + ) + + assert get_one( + marc_exporter_fixture.session, + MarcFile, + collection=collection, + library_id=enabled_libraries[0].library_id, + key=enabled_libraries[0].s3_key_delta, + since=enabled_libraries[0].last_updated, + ) diff --git a/tests/manager/marc/test_uploader.py b/tests/manager/marc/test_uploader.py new file mode 100644 index 0000000000..bb7898e34c --- /dev/null +++ b/tests/manager/marc/test_uploader.py @@ -0,0 +1,334 @@ +from unittest.mock import MagicMock, call + +import pytest +from celery.exceptions import Ignore, Retry + +from palace.manager.marc.uploader import MarcUploadManager +from palace.manager.service.redis.models.marc import ( + MarcFileUpload, + MarcFileUploadSession, +) +from palace.manager.sqlalchemy.model.resource import Representation +from tests.fixtures.redis import RedisFixture +from tests.fixtures.s3 import S3ServiceFixture + + +class MarcUploadManagerFixture: + def __init__( + self, redis_fixture: RedisFixture, s3_service_fixture: S3ServiceFixture + ): + self._redis_fixture = redis_fixture + self._s3_service_fixture = s3_service_fixture + + self.test_key1 = "test.123" + self.test_record1 = b"test_record_123" + self.test_key2 = "test*456" + self.test_record2 = b"test_record_456" + self.test_key3 = "test--?789" + self.test_record3 = b"test_record_789" + + self.mock_s3_service = s3_service_fixture.mock_service() + # Reduce the minimum upload size to make testing easier + self.mock_s3_service.MINIMUM_MULTIPART_UPLOAD_SIZE = len(self.test_record1) * 4 + self.redis_client = redis_fixture.client + + self.mock_collection_id = 52 + + self.uploads = MarcFileUploadSession(self.redis_client, 
self.mock_collection_id) + self.uploader = MarcUploadManager(self.mock_s3_service, self.uploads) + + +@pytest.fixture +def marc_upload_manager_fixture( + redis_fixture: RedisFixture, s3_service_fixture: S3ServiceFixture +): + return MarcUploadManagerFixture(redis_fixture, s3_service_fixture) + + +class TestMarcUploadManager: + def test_begin( + self, + marc_upload_manager_fixture: MarcUploadManagerFixture, + redis_fixture: RedisFixture, + ): + uploader = marc_upload_manager_fixture.uploader + + assert uploader.locked is False + assert marc_upload_manager_fixture.uploads.locked(by_us=True) is False + + with uploader.begin() as u: + # The context manager returns the uploader object + assert u is uploader + + # It directly tells us the lock status + assert uploader.locked is True + + # The lock is also reflected in the uploads object + assert marc_upload_manager_fixture.uploads.locked(by_us=True) is True # type: ignore[unreachable] + + # The lock is released after the context manager exits + assert uploader.locked is False # type: ignore[unreachable] + assert marc_upload_manager_fixture.uploads.locked(by_us=True) is False + + # If an exception occurs, the lock is deleted and the exception is raised by calling + # the _abort method + mock_abort = MagicMock(wraps=uploader._abort) + uploader._abort = mock_abort + with pytest.raises(Exception): + with uploader.begin(): + assert uploader.locked is True + raise Exception() + assert ( + redis_fixture.client.json().get(marc_upload_manager_fixture.uploads.key) + is None + ) + mock_abort.assert_called_once() + + # If a expected celery exception occurs, the lock is released, but not deleted + # and the abort method isn't called + mock_abort.reset_mock() + for exception in Retry, Ignore: + with pytest.raises(exception): + with uploader.begin(): + assert uploader.locked is True + raise exception() + assert marc_upload_manager_fixture.uploads.locked(by_us=True) is False + assert ( + redis_fixture.client.json().get(marc_upload_manager_fixture.uploads.key) + is not None + ) + mock_abort.assert_not_called() + + def test_add_record(self, marc_upload_manager_fixture: MarcUploadManagerFixture): + uploader = marc_upload_manager_fixture.uploader + + uploader.add_record( + marc_upload_manager_fixture.test_key1, + marc_upload_manager_fixture.test_record1, + ) + assert ( + uploader._buffers[marc_upload_manager_fixture.test_key1] + == marc_upload_manager_fixture.test_record1.decode() + ) + + uploader.add_record( + marc_upload_manager_fixture.test_key1, + marc_upload_manager_fixture.test_record1, + ) + assert ( + uploader._buffers[marc_upload_manager_fixture.test_key1] + == marc_upload_manager_fixture.test_record1.decode() * 2 + ) + + def test_sync(self, marc_upload_manager_fixture: MarcUploadManagerFixture): + uploader = marc_upload_manager_fixture.uploader + + uploader.add_record( + marc_upload_manager_fixture.test_key1, + marc_upload_manager_fixture.test_record1, + ) + uploader.add_record( + marc_upload_manager_fixture.test_key2, + marc_upload_manager_fixture.test_record2 * 2, + ) + with uploader.begin(): + uploader.sync() + + # Sync clears the local buffer + assert uploader._buffers == {} + + # And pushes the local records to redis + assert marc_upload_manager_fixture.uploads.get() == { + marc_upload_manager_fixture.test_key1: MarcFileUpload( + buffer=marc_upload_manager_fixture.test_record1 + ), + marc_upload_manager_fixture.test_key2: MarcFileUpload( + buffer=marc_upload_manager_fixture.test_record2 * 2 + ), + } + + # Because the buffer did not contain enough 
data, it was not uploaded to S3 + assert marc_upload_manager_fixture.mock_s3_service.upload_in_progress == {} + + # Add enough data for test_key1 to be uploaded to S3 + uploader.add_record( + marc_upload_manager_fixture.test_key1, + marc_upload_manager_fixture.test_record1 * 2, + ) + uploader.add_record( + marc_upload_manager_fixture.test_key1, + marc_upload_manager_fixture.test_record1 * 2, + ) + uploader.add_record( + marc_upload_manager_fixture.test_key2, + marc_upload_manager_fixture.test_record2, + ) + + with uploader.begin(): + uploader.sync() + + # The buffer is cleared + assert uploader._buffers == {} + + # Because the data for test_key1 was large enough, it was uploaded to S3, and its redis data structure was + # updated to reflect this. test_key2 was not large enough to upload, so it remains in redis and not in s3. + redis_data = marc_upload_manager_fixture.uploads.get() + assert redis_data[marc_upload_manager_fixture.test_key2] == MarcFileUpload( + buffer=marc_upload_manager_fixture.test_record2 * 3 + ) + redis_data_test1 = redis_data[marc_upload_manager_fixture.test_key1] + assert redis_data_test1.buffer == "" + + assert len(marc_upload_manager_fixture.mock_s3_service.upload_in_progress) == 1 + assert ( + marc_upload_manager_fixture.test_key1 + in marc_upload_manager_fixture.mock_s3_service.upload_in_progress + ) + upload = marc_upload_manager_fixture.mock_s3_service.upload_in_progress[ + marc_upload_manager_fixture.test_key1 + ] + assert upload.upload_id is not None + assert upload.content_type is Representation.MARC_MEDIA_TYPE + [part] = upload.parts + assert part.content == marc_upload_manager_fixture.test_record1 * 5 + + # And the s3 part data and upload_id is synced to redis + assert redis_data_test1.parts == [part.part_data] + assert redis_data_test1.upload_id == upload.upload_id + + def test_complete(self, marc_upload_manager_fixture: MarcUploadManagerFixture): + uploader = marc_upload_manager_fixture.uploader + + # Wrap the clear method so we can check if it was called + mock_clear_uploads = MagicMock( + wraps=marc_upload_manager_fixture.uploads.clear_uploads + ) + marc_upload_manager_fixture.uploads.clear_uploads = mock_clear_uploads + + # Set up the records for the test + uploader.add_record( + marc_upload_manager_fixture.test_key1, + marc_upload_manager_fixture.test_record1 * 5, + ) + uploader.add_record( + marc_upload_manager_fixture.test_key2, + marc_upload_manager_fixture.test_record2 * 5, + ) + with uploader.begin(): + uploader.sync() + + uploader.add_record( + marc_upload_manager_fixture.test_key1, + marc_upload_manager_fixture.test_record1 * 5, + ) + with uploader.begin(): + uploader.sync() + + uploader.add_record( + marc_upload_manager_fixture.test_key2, + marc_upload_manager_fixture.test_record2, + ) + + uploader.add_record( + marc_upload_manager_fixture.test_key3, + marc_upload_manager_fixture.test_record3, + ) + + # Complete the uploads + with uploader.begin(): + completed = uploader.complete() + + # The complete method should return the keys that were completed + assert completed == { + marc_upload_manager_fixture.test_key1, + marc_upload_manager_fixture.test_key2, + marc_upload_manager_fixture.test_key3, + } + + # The local buffers should be empty + assert uploader._buffers == {} + + # The redis record should have the completed uploads cleared + mock_clear_uploads.assert_called_once() + + # The s3 service should have the completed uploads + assert len(marc_upload_manager_fixture.mock_s3_service.uploads) == 3 + assert 
len(marc_upload_manager_fixture.mock_s3_service.upload_in_progress) == 0 + + test_key1_upload = marc_upload_manager_fixture.mock_s3_service.uploads[ + marc_upload_manager_fixture.test_key1 + ] + assert test_key1_upload.key == marc_upload_manager_fixture.test_key1 + assert test_key1_upload.content == marc_upload_manager_fixture.test_record1 * 10 + assert test_key1_upload.media_type == Representation.MARC_MEDIA_TYPE + + test_key2_upload = marc_upload_manager_fixture.mock_s3_service.uploads[ + marc_upload_manager_fixture.test_key2 + ] + assert test_key2_upload.key == marc_upload_manager_fixture.test_key2 + assert test_key2_upload.content == marc_upload_manager_fixture.test_record2 * 6 + assert test_key2_upload.media_type == Representation.MARC_MEDIA_TYPE + + test_key3_upload = marc_upload_manager_fixture.mock_s3_service.uploads[ + marc_upload_manager_fixture.test_key3 + ] + assert test_key3_upload.key == marc_upload_manager_fixture.test_key3 + assert test_key3_upload.content == marc_upload_manager_fixture.test_record3 + assert test_key3_upload.media_type == Representation.MARC_MEDIA_TYPE + + def test__abort( + self, + marc_upload_manager_fixture: MarcUploadManagerFixture, + caplog: pytest.LogCaptureFixture, + ): + uploader = marc_upload_manager_fixture.uploader + + # Set up the records for the test + uploader.add_record( + marc_upload_manager_fixture.test_key1, + marc_upload_manager_fixture.test_record1 * 10, + ) + uploader.add_record( + marc_upload_manager_fixture.test_key2, + marc_upload_manager_fixture.test_record2 * 10, + ) + with uploader.begin(): + uploader.sync() + + # Mock the multipart_abort method so we can check if it was called and have it + # raise an exception on the first call + mock_abort = MagicMock(side_effect=[Exception("Boom"), None]) + marc_upload_manager_fixture.mock_s3_service.multipart_abort = mock_abort + + # Wrap the delete method so we can check if it was called + mock_delete = MagicMock(wraps=marc_upload_manager_fixture.uploads.delete) + marc_upload_manager_fixture.uploads.delete = mock_delete + + upload_id_1 = marc_upload_manager_fixture.mock_s3_service.upload_in_progress[ + marc_upload_manager_fixture.test_key1 + ].upload_id + upload_id_2 = marc_upload_manager_fixture.mock_s3_service.upload_in_progress[ + marc_upload_manager_fixture.test_key2 + ].upload_id + + # Abort the uploads, the original exception should propagate, and the exception + # thrown by the first call to abort should be logged + with pytest.raises(Exception) as exc_info: + with uploader.begin(): + raise Exception("Bang") + assert str(exc_info.value) == "Bang" + + assert ( + f"Failed to abort upload {marc_upload_manager_fixture.test_key1} (UploadID: {upload_id_1}) due to exception (Boom)" + in caplog.text + ) + + mock_abort.assert_has_calls( + [ + call(marc_upload_manager_fixture.test_key1, upload_id_1), + call(marc_upload_manager_fixture.test_key2, upload_id_2), + ] + ) + + # The redis record should have been deleted + mock_delete.assert_called_once() diff --git a/tests/manager/scripts/test_marc.py b/tests/manager/scripts/test_marc.py deleted file mode 100644 index 3b83d359fb..0000000000 --- a/tests/manager/scripts/test_marc.py +++ /dev/null @@ -1,466 +0,0 @@ -from __future__ import annotations - -import datetime -import logging -from unittest.mock import MagicMock, call, create_autospec - -import pytest -from _pytest.logging import LogCaptureFixture -from sqlalchemy.exc import NoResultFound - -from palace.manager.core.config import CannotLoadConfiguration -from palace.manager.core.marc import ( - 
MARCExporter, - MarcExporterLibrarySettings, - MarcExporterSettings, -) -from palace.manager.integration.goals import Goals -from palace.manager.scripts.marc import CacheMARCFiles -from palace.manager.sqlalchemy.model.discovery_service_registration import ( - DiscoveryServiceRegistration, -) -from palace.manager.sqlalchemy.model.integration import IntegrationConfiguration -from palace.manager.sqlalchemy.model.library import Library -from palace.manager.sqlalchemy.model.marcfile import MarcFile -from palace.manager.sqlalchemy.util import create -from palace.manager.util.datetime_helpers import datetime_utc, utc_now -from tests.fixtures.database import DatabaseTransactionFixture -from tests.fixtures.services import ServicesFixture - - -class CacheMARCFilesFixture: - def __init__( - self, db: DatabaseTransactionFixture, services_fixture: ServicesFixture - ): - self.db = db - self.services_fixture = services_fixture - self.base_url = "http://test-circulation-manager" - services_fixture.set_base_url(self.base_url) - self.exporter = MagicMock(spec=MARCExporter) - self.library = self.db.default_library() - self.collection = self.db.collection() - self.collection.export_marc_records = True - self.collection.libraries += [self.library] - - def integration(self, library: Library | None = None) -> IntegrationConfiguration: - if library is None: - library = self.library - - return self.db.integration_configuration( - protocol=MARCExporter, - goal=Goals.CATALOG_GOAL, - libraries=[library], - ) - - def script(self, cmd_args: list[str] | None = None) -> CacheMARCFiles: - cmd_args = cmd_args or [] - return CacheMARCFiles( - self.db.session, - exporter=self.exporter, - services=self.services_fixture.services, - cmd_args=cmd_args, - ) - - -@pytest.fixture -def cache_marc_files( - db: DatabaseTransactionFixture, services_fixture: ServicesFixture -) -> CacheMARCFilesFixture: - return CacheMARCFilesFixture(db, services_fixture) - - -class TestCacheMARCFiles: - def test_constructor(self, cache_marc_files: CacheMARCFilesFixture): - cache_marc_files.services_fixture.set_base_url(None) - with pytest.raises(CannotLoadConfiguration): - cache_marc_files.script() - - cache_marc_files.services_fixture.set_base_url("http://test.com") - script = cache_marc_files.script() - assert script.base_url == "http://test.com" - - def test_settings(self, cache_marc_files: CacheMARCFilesFixture): - # Test that the script gets the correct settings. 
- test_library = cache_marc_files.library - other_library = cache_marc_files.db.library() - - expected_settings = MarcExporterSettings(update_frequency=3) - expected_library_settings = MarcExporterLibrarySettings( - organization_code="test", - include_summary=True, - include_genres=True, - ) - - other_library_settings = MarcExporterLibrarySettings( - organization_code="other", - ) - - integration = cache_marc_files.integration(test_library) - integration.libraries += [other_library] - - test_library_integration = integration.for_library(test_library) - assert test_library_integration is not None - other_library_integration = integration.for_library(other_library) - assert other_library_integration is not None - MARCExporter.settings_update(integration, expected_settings) - MARCExporter.library_settings_update( - test_library_integration, expected_library_settings - ) - MARCExporter.library_settings_update( - other_library_integration, other_library_settings - ) - - script = cache_marc_files.script() - actual_settings, actual_library_settings = script.settings(test_library) - - assert actual_settings == expected_settings - assert actual_library_settings == expected_library_settings - - def test_settings_none(self, cache_marc_files: CacheMARCFilesFixture): - # If there are no settings, the setting function raises an exception. - test_library = cache_marc_files.library - script = cache_marc_files.script() - with pytest.raises(NoResultFound): - script.settings(test_library) - - def test_process_libraries_no_storage( - self, cache_marc_files: CacheMARCFilesFixture, caplog: LogCaptureFixture - ): - # If there is no storage integration, the script logs an error and returns. - script = cache_marc_files.script() - script.storage_service = None - caplog.set_level(logging.INFO) - script.process_libraries([MagicMock(), MagicMock()]) - assert "No storage service was found" in caplog.text - - def test_get_collections(self, cache_marc_files: CacheMARCFilesFixture): - # Test that the script gets the correct collections. - test_library = cache_marc_files.library - collection1 = cache_marc_files.collection - - # Second collection is configured to export MARC records. - collection2 = cache_marc_files.db.collection() - collection2.export_marc_records = True - collection2.libraries += [test_library] - - # Third collection is not configured to export MARC records. - collection3 = cache_marc_files.db.collection() - collection3.export_marc_records = False - collection3.libraries += [test_library] - - # Fourth collection is configured to export MARC records, but is - # configured to export only to a different library. - other_library = cache_marc_files.db.library() - other_collection = cache_marc_files.db.collection() - other_collection.export_marc_records = True - other_collection.libraries += [other_library] - - script = cache_marc_files.script() - - # We should get back the two collections that are configured to export - # MARC records to this library. - collections = script.get_collections(test_library) - assert set(collections) == {collection1, collection2} - - # Set collection3 to export MARC records to this library. - collection3.export_marc_records = True - - # We should get back all three collections that are configured to export - # MARC records to this library. 
- collections = script.get_collections(test_library) - assert set(collections) == {collection1, collection2, collection3} - - def test_get_web_client_urls( - self, - db: DatabaseTransactionFixture, - cache_marc_files: CacheMARCFilesFixture, - ): - # No web client URLs are returned if there are no discovery service registrations. - script = cache_marc_files.script() - assert script.get_web_client_urls(cache_marc_files.library) == [] - - # If we pass in a configured web client URL, that URL is returned. - assert script.get_web_client_urls( - cache_marc_files.library, "http://web-client" - ) == ["http://web-client"] - - # Add a URL from a library registry. - registry = db.discovery_service_integration() - create( - db.session, - DiscoveryServiceRegistration, - library=cache_marc_files.library, - integration=registry, - web_client="http://web-client-url/", - ) - assert script.get_web_client_urls(cache_marc_files.library) == [ - "http://web-client-url/" - ] - - # URL from library registry and configured URL are both returned. - assert script.get_web_client_urls( - cache_marc_files.library, "http://web-client" - ) == [ - "http://web-client-url/", - "http://web-client", - ] - - def test_process_library_not_configured( - self, - cache_marc_files: CacheMARCFilesFixture, - ): - script = cache_marc_files.script() - mock_process_collection = create_autospec(script.process_collection) - script.process_collection = mock_process_collection - mock_settings = create_autospec(script.settings) - script.settings = mock_settings - mock_settings.side_effect = NoResultFound - - # If there is no integration configuration for the library, the script - # does nothing. - script.process_library(cache_marc_files.library) - mock_process_collection.assert_not_called() - - def test_process_library(self, cache_marc_files: CacheMARCFilesFixture): - script = cache_marc_files.script() - mock_annotator_cls = MagicMock() - mock_process_collection = create_autospec(script.process_collection) - script.process_collection = mock_process_collection - mock_settings = create_autospec(script.settings) - script.settings = mock_settings - settings = MarcExporterSettings(update_frequency=3) - library_settings = MarcExporterLibrarySettings( - organization_code="test", - web_client_url="http://web-client-url/", - include_summary=True, - include_genres=False, - ) - mock_settings.return_value = ( - settings, - library_settings, - ) - - before_call_time = utc_now() - - # If there is an integration configuration for the library, the script - # processes all the collections for that library. 
- script.process_library( - cache_marc_files.library, annotator_cls=mock_annotator_cls - ) - - after_call_time = utc_now() - - mock_annotator_cls.assert_called_once_with( - cache_marc_files.base_url, - cache_marc_files.library.short_name, - [library_settings.web_client_url], - library_settings.organization_code, - library_settings.include_summary, - library_settings.include_genres, - ) - - assert mock_process_collection.call_count == 1 - ( - library, - collection, - annotator, - update_frequency, - creation_time, - ) = mock_process_collection.call_args.args - assert library == cache_marc_files.library - assert collection == cache_marc_files.collection - assert annotator == mock_annotator_cls.return_value - assert update_frequency == settings.update_frequency - assert creation_time > before_call_time - assert creation_time < after_call_time - - def test_last_updated( - self, db: DatabaseTransactionFixture, cache_marc_files: CacheMARCFilesFixture - ): - script = cache_marc_files.script() - - # If there is no cached file, we return None. - assert ( - script.last_updated(cache_marc_files.library, cache_marc_files.collection) - is None - ) - - # If there is a cached file, we return the time it was created. - file1 = MarcFile( - library=cache_marc_files.library, - collection=cache_marc_files.collection, - created=datetime_utc(1984, 5, 8), - key="file1", - ) - db.session.add(file1) - assert ( - script.last_updated(cache_marc_files.library, cache_marc_files.collection) - == file1.created - ) - - # If there are multiple cached files, we return the time of the most recent one. - file2 = MarcFile( - library=cache_marc_files.library, - collection=cache_marc_files.collection, - created=utc_now(), - key="file2", - ) - db.session.add(file2) - assert ( - script.last_updated(cache_marc_files.library, cache_marc_files.collection) - == file2.created - ) - - def test_force(self, cache_marc_files: CacheMARCFilesFixture): - script = cache_marc_files.script() - assert script.force is False - - script = cache_marc_files.script(cmd_args=["--force"]) - assert script.force is True - - @pytest.mark.parametrize( - "last_updated, force, update_frequency, run_exporter", - [ - pytest.param(None, False, 10, True, id="never_run_before"), - pytest.param(None, False, 10, True, id="never_run_before_w_force"), - pytest.param( - utc_now() - datetime.timedelta(days=5), - False, - 10, - False, - id="recently_run", - ), - pytest.param( - utc_now() - datetime.timedelta(days=5), - True, - 10, - True, - id="recently_run_w_force", - ), - pytest.param( - utc_now() - datetime.timedelta(days=5), - False, - 0, - True, - id="recently_run_w_frequency_0", - ), - pytest.param( - utc_now() - datetime.timedelta(days=15), - False, - 10, - True, - id="not_recently_run", - ), - pytest.param( - utc_now() - datetime.timedelta(days=15), - True, - 10, - True, - id="not_recently_run_w_force", - ), - pytest.param( - utc_now() - datetime.timedelta(days=15), - False, - 0, - True, - id="not_recently_run_w_frequency_0", - ), - ], - ) - def test_process_collection_skip( - self, - cache_marc_files: CacheMARCFilesFixture, - caplog: LogCaptureFixture, - last_updated: datetime.datetime | None, - force: bool, - update_frequency: int, - run_exporter: bool, - ): - script = cache_marc_files.script() - script.exporter = MagicMock() - now = utc_now() - caplog.set_level(logging.INFO) - - script.force = force - script.last_updated = MagicMock(return_value=last_updated) - script.process_collection( - cache_marc_files.library, - cache_marc_files.collection, - MagicMock(), 
- update_frequency, - now, - ) - - if run_exporter: - assert script.exporter.records.call_count > 0 - assert "Processed collection" in caplog.text - else: - assert script.exporter.records.call_count == 0 - assert "Skipping collection" in caplog.text - - def test_process_collection_never_called( - self, cache_marc_files: CacheMARCFilesFixture, caplog: LogCaptureFixture - ): - # If the collection has not been processed before, the script processes - # the collection and created a full export. - caplog.set_level(logging.INFO) - script = cache_marc_files.script() - mock_exporter = MagicMock(spec=MARCExporter) - script.exporter = mock_exporter - script.last_updated = MagicMock(return_value=None) - mock_annotator = MagicMock() - creation_time = utc_now() - script.process_collection( - cache_marc_files.library, - cache_marc_files.collection, - mock_annotator, - 10, - creation_time, - ) - mock_exporter.records.assert_called_once_with( - cache_marc_files.library, - cache_marc_files.collection, - mock_annotator, - creation_time=creation_time, - ) - assert "Processed collection" in caplog.text - - def test_process_collection_with_last_updated( - self, cache_marc_files: CacheMARCFilesFixture, caplog: LogCaptureFixture - ): - # If the collection has been processed before, the script processes - # the collection, created a full export and a delta export. - caplog.set_level(logging.INFO) - script = cache_marc_files.script() - mock_exporter = MagicMock(spec=MARCExporter) - script.exporter = mock_exporter - last_updated = utc_now() - datetime.timedelta(days=20) - script.last_updated = MagicMock(return_value=last_updated) - mock_annotator = MagicMock() - creation_time = utc_now() - script.process_collection( - cache_marc_files.library, - cache_marc_files.collection, - mock_annotator, - 10, - creation_time, - ) - assert "Processed collection" in caplog.text - assert mock_exporter.records.call_count == 2 - - full_call = call( - cache_marc_files.library, - cache_marc_files.collection, - mock_annotator, - creation_time=creation_time, - ) - - delta_call = call( - cache_marc_files.library, - cache_marc_files.collection, - mock_annotator, - creation_time=creation_time, - since_time=last_updated, - ) - - mock_exporter.records.assert_has_calls([full_call, delta_call]) diff --git a/tests/manager/service/redis/models/test_lock.py b/tests/manager/service/redis/models/test_lock.py index ca7aa956d6..c317db7a84 100644 --- a/tests/manager/service/redis/models/test_lock.py +++ b/tests/manager/service/redis/models/test_lock.py @@ -1,10 +1,17 @@ from datetime import timedelta +from typing import Any from unittest.mock import create_autospec import pytest from palace.manager.celery.task import Task -from palace.manager.service.redis.models.lock import LockError, RedisLock, TaskLock +from palace.manager.service.redis.models.lock import ( + LockError, + RedisJsonLock, + RedisLock, + TaskLock, +) +from palace.manager.service.redis.redis import Redis from tests.fixtures.redis import RedisFixture @@ -182,3 +189,147 @@ def test___init__(self, redis_fixture: RedisFixture): # If we provide a lock_name, we should use that instead task_lock = TaskLock(redis_fixture.client, mock_task, lock_name="test_lock") assert task_lock.key.endswith("::TaskLock::test_lock") + + +class MockJsonLock(RedisJsonLock): + def __init__( + self, + redis_client: Redis, + key: str = "test", + timeout: int = 1000, + random_value: str | None = None, + ): + self._key = redis_client.get_key(key) + self._timeout = timeout + super().__init__(redis_client, 
random_value) + + @property + def key(self) -> str: + return self._key + + @property + def _lock_timeout_ms(self) -> int: + return self._timeout + + +class JsonLockFixture: + def __init__(self, redis_fixture: RedisFixture) -> None: + self.client = redis_fixture.client + self.lock = MockJsonLock(redis_fixture.client) + self.other_lock = MockJsonLock(redis_fixture.client) + + def get_key(self, key: str, json_key: str) -> Any: + ret_val = self.client.json().get(key, json_key) + if ret_val is None or len(ret_val) != 1: + return None + return ret_val[0] + + def assert_locked(self, lock: RedisJsonLock) -> None: + assert self.get_key(lock.key, lock._lock_json_key) == lock._random_value + + +@pytest.fixture +def json_lock_fixture(redis_fixture: RedisFixture) -> JsonLockFixture: + return JsonLockFixture(redis_fixture) + + +class TestJsonLock: + def test_acquire(self, json_lock_fixture: JsonLockFixture): + # We can acquire the lock. And acquiring the lock sets a timeout on the key, so the lock + # will expire eventually if something goes wrong. + assert json_lock_fixture.lock.acquire() + assert json_lock_fixture.client.ttl(json_lock_fixture.lock.key) > 0 + json_lock_fixture.assert_locked(json_lock_fixture.lock) + + # Acquiring the lock again with the same random value should return True + # and extend the timeout for the lock + json_lock_fixture.client.pexpire(json_lock_fixture.lock.key, 500) + timeout = json_lock_fixture.client.pttl(json_lock_fixture.lock.key) + assert json_lock_fixture.lock.acquire() + assert json_lock_fixture.client.pttl(json_lock_fixture.lock.key) > timeout + + # Acquiring the lock again with a different random value should return False + assert not json_lock_fixture.other_lock.acquire() + json_lock_fixture.assert_locked(json_lock_fixture.lock) + + def test_release(self, json_lock_fixture: JsonLockFixture): + # If the lock doesn't exist, we can't release it + assert json_lock_fixture.lock.release() is False + + # If you acquire a lock another client cannot release it + assert json_lock_fixture.lock.acquire() + assert json_lock_fixture.other_lock.release() is False + + # Make sure the key is set in redis + json_lock_fixture.assert_locked(json_lock_fixture.lock) + + # But the client that acquired the lock can release it + assert json_lock_fixture.lock.release() is True + + # And the key should still exist, but the lock key in the json is removed from redis + assert json_lock_fixture.get_key(json_lock_fixture.lock.key, "$") == {} + + def test_delete(self, json_lock_fixture: JsonLockFixture): + assert json_lock_fixture.lock.delete() is False + + # If you acquire a lock another client cannot delete it + assert json_lock_fixture.lock.acquire() + assert json_lock_fixture.other_lock.delete() is False + + # Make sure the key is set in redis + assert json_lock_fixture.get_key(json_lock_fixture.lock.key, "$") is not None + json_lock_fixture.assert_locked(json_lock_fixture.lock) + + # But the client that acquired the lock can delete it + assert json_lock_fixture.lock.delete() is True + + # And unlike release, delete removes the whole key from redis + assert json_lock_fixture.get_key(json_lock_fixture.lock.key, "$") is None + + def test_extend_timeout(self, json_lock_fixture: JsonLockFixture): + assert json_lock_fixture.lock.extend_timeout() is False + + # If the lock has a timeout, the acquiring client can extend it, but another client cannot + assert json_lock_fixture.lock.acquire() + json_lock_fixture.client.pexpire(json_lock_fixture.lock.key, 500) + assert
json_lock_fixture.other_lock.extend_timeout() is False + assert json_lock_fixture.client.pttl(json_lock_fixture.lock.key) <= 500 + + # The key should have a new timeout + assert json_lock_fixture.lock.extend_timeout() is True + assert json_lock_fixture.client.pttl(json_lock_fixture.lock.key) > 500 + + def test_locked(self, json_lock_fixture: JsonLockFixture): + # If the lock is not acquired, it should not be locked + assert json_lock_fixture.lock.locked() is False + + # If the lock is acquired, it should be locked + assert json_lock_fixture.lock.acquire() + assert json_lock_fixture.lock.locked() is True + assert json_lock_fixture.other_lock.locked() is True + assert json_lock_fixture.lock.locked(by_us=True) is True + assert json_lock_fixture.other_lock.locked(by_us=True) is False + + # If the lock is released, it should not be locked + assert json_lock_fixture.lock.release() is True + assert json_lock_fixture.lock.locked() is False + assert json_lock_fixture.other_lock.locked() is False + + def test__parse_value(self): + assert RedisJsonLock._parse_value(None) is None + assert RedisJsonLock._parse_value([]) is None + assert RedisJsonLock._parse_value(["value"]) == "value" + + def test__parse_multi(self): + assert RedisJsonLock._parse_multi(None) == {} + assert RedisJsonLock._parse_multi({}) == {} + assert RedisJsonLock._parse_multi( + {"key": ["value"], "key2": ["value2"], "key3": []} + ) == {"key": "value", "key2": "value2", "key3": None} + + def test__parse_value_or_raise(self): + with pytest.raises(LockError): + RedisJsonLock._parse_value_or_raise(None) + with pytest.raises(LockError): + RedisJsonLock._parse_value_or_raise([]) + assert RedisJsonLock._parse_value_or_raise(["value"]) == "value" diff --git a/tests/manager/service/redis/models/test_marc.py b/tests/manager/service/redis/models/test_marc.py new file mode 100644 index 0000000000..3013b2906d --- /dev/null +++ b/tests/manager/service/redis/models/test_marc.py @@ -0,0 +1,469 @@ +import pytest + +from palace.manager.service.redis.models.marc import ( + MarcFileUpload, + MarcFileUploadSession, + MarcFileUploadSessionError, + MarcFileUploadState, +) +from palace.manager.service.redis.redis import Pipeline +from palace.manager.service.storage.s3 import MultipartS3UploadPart +from tests.fixtures.redis import RedisFixture + + +class MarcFileUploadSessionFixture: + def __init__(self, redis_fixture: RedisFixture): + self._redis_fixture = redis_fixture + + self.mock_collection_id = 1 + + self.uploads = MarcFileUploadSession( + self._redis_fixture.client, self.mock_collection_id + ) + + self.mock_upload_key_1 = "test1" + self.mock_upload_key_2 = "test2" + self.mock_upload_key_3 = "test3" + + self.mock_unset_upload_key = "test4" + + self.test_data = { + self.mock_upload_key_1: "test", + self.mock_upload_key_2: "another_test", + self.mock_upload_key_3: "another_another_test", + } + + self.part_1 = MultipartS3UploadPart(etag="abc", part_number=1) + self.part_2 = MultipartS3UploadPart(etag="def", part_number=2) + + def load_test_data(self) -> dict[str, int]: + lock_acquired = False + if not self.uploads.locked(): + self.uploads.acquire() + lock_acquired = True + + return_value = self.uploads.append_buffers(self.test_data) + + if lock_acquired: + self.uploads.release() + + return return_value + + def test_data_records(self, *keys: str): + return {key: MarcFileUpload(buffer=self.test_data[key]) for key in keys} + + +@pytest.fixture +def marc_file_upload_session_fixture(redis_fixture: RedisFixture): + return 
MarcFileUploadSessionFixture(redis_fixture) + + +class TestMarcFileUploadSession: + def test__pipeline( + self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture + ): + uploads = marc_file_upload_session_fixture.uploads + + # Using the _pipeline() context manager makes sure that we hold the lock + with pytest.raises(MarcFileUploadSessionError) as exc_info: + with uploads._pipeline(): + pass + assert "Must hold lock" in str(exc_info.value) + + uploads.acquire() + + # It also checks that the update_number is correct + uploads._update_number = 1 + with pytest.raises(MarcFileUploadSessionError) as exc_info: + with uploads._pipeline(): + pass + assert "Update number mismatch" in str(exc_info.value) + + uploads._update_number = 0 + with uploads._pipeline() as pipe: + # If the lock and update number are correct, we should get a pipeline object + assert isinstance(pipe, Pipeline) + + # We are watching the key for this object, so that we know all the data within the + # transaction is consistent, and we are still holding the lock when the pipeline + # executes + assert pipe.watching is True + + # By default it starts the pipeline transaction + assert pipe.explicit_transaction is True + + # We can also start the pipeline without a transaction + with uploads._pipeline(begin_transaction=False) as pipe: + assert pipe.explicit_transaction is False + + def test__execute_pipeline( + self, + marc_file_upload_session_fixture: MarcFileUploadSessionFixture, + redis_fixture: RedisFixture, + ): + client = redis_fixture.client + uploads = marc_file_upload_session_fixture.uploads + uploads.acquire() + + # If we try to execute a pipeline without a transaction, we should get an error + with pytest.raises(MarcFileUploadSessionError) as exc_info: + with uploads._pipeline(begin_transaction=False) as pipe: + uploads._execute_pipeline(pipe, 0) + assert "Pipeline should be in explicit transaction mode" in str(exc_info.value) + + # The _execute_pipeline function takes care of extending the timeout and incrementing + # the update number and setting the state of the session + [update_number] = client.json().get( + uploads.key, uploads._update_number_json_key + ) + client.pexpire(uploads.key, 500) + old_state = uploads.state() + with uploads._pipeline() as pipe: + # If we execute the pipeline, we should get a list of results, excluding the + # operations that _execute_pipeline does. 
+ assert uploads._execute_pipeline(pipe, 2) == [] + [new_update_number] = client.json().get( + uploads.key, uploads._update_number_json_key + ) + assert new_update_number == update_number + 2 + assert client.pttl(uploads.key) > 500 + assert uploads.state() != old_state + assert uploads.state() == MarcFileUploadState.UPLOADING + + # If we try to execute a pipeline that has been modified by another process, we should get an error + with uploads._pipeline() as pipe: + client.json().set( + uploads.key, uploads._update_number_json_key, update_number + ) + with pytest.raises(MarcFileUploadSessionError) as exc_info: + uploads._execute_pipeline(pipe, 1) + assert "Another process is modifying the buffers" in str(exc_info.value) + + def test_append_buffers( + self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture + ): + uploads = marc_file_upload_session_fixture.uploads + + # If we try to update buffers without acquiring the lock, we should get an error + with pytest.raises(MarcFileUploadSessionError) as exc_info: + uploads.append_buffers( + {marc_file_upload_session_fixture.mock_upload_key_1: "test"} + ) + assert "Must hold lock" in str(exc_info.value) + + # Acquire the lock and try to update buffers + with uploads.lock() as locked: + assert locked + assert uploads.append_buffers({}) == {} + + assert uploads.append_buffers( + { + marc_file_upload_session_fixture.mock_upload_key_1: "test", + marc_file_upload_session_fixture.mock_upload_key_2: "another_test", + } + ) == { + marc_file_upload_session_fixture.mock_upload_key_1: 4, + marc_file_upload_session_fixture.mock_upload_key_2: 12, + } + assert uploads._update_number == 2 + + assert uploads.append_buffers( + { + marc_file_upload_session_fixture.mock_upload_key_1: "x", + marc_file_upload_session_fixture.mock_upload_key_2: "y", + marc_file_upload_session_fixture.mock_upload_key_3: "new", + } + ) == { + marc_file_upload_session_fixture.mock_upload_key_1: 5, + marc_file_upload_session_fixture.mock_upload_key_2: 13, + marc_file_upload_session_fixture.mock_upload_key_3: 3, + } + assert uploads._update_number == 5 + + # If we try to update buffers with an old update number, we should get an error + uploads._update_number = 4 + with pytest.raises(MarcFileUploadSessionError) as exc_info: + uploads.append_buffers(marc_file_upload_session_fixture.test_data) + assert "Update number mismatch" in str(exc_info.value) + + # Exiting the context manager should release the lock + assert not uploads.locked() + + def test_get(self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture): + uploads = marc_file_upload_session_fixture.uploads + + assert uploads.get() == {} + assert uploads.get(marc_file_upload_session_fixture.mock_upload_key_1) == {} + + marc_file_upload_session_fixture.load_test_data() + + # You don't need to acquire the lock to get the uploads, but you should if you + # are using the data to do updates. 
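For context, the access pattern that comment describes can be sketched as follows. This is only an illustrative sketch built from the MarcFileUploadSession methods exercised in these tests (lock(), get(), append_buffers()); the redis client, collection id, and buffer key below are placeholder assumptions, not part of the patch.

    from palace.manager.service.redis.models.marc import MarcFileUploadSession

    session = MarcFileUploadSession(redis_client, collection_id)  # placeholder arguments

    # Reading the buffered uploads does not require the session lock...
    snapshot = session.get()

    # ...but a read-modify-write cycle should run while the lock is held, so the
    # session's update-number check can catch concurrent modification.
    with session.lock() as acquired:
        if acquired:
            session.append_buffers({"upload-key": "more MARC data"})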
+ + # You can get a subset of the uploads + assert uploads.get( + marc_file_upload_session_fixture.mock_upload_key_1, + ) == marc_file_upload_session_fixture.test_data_records( + marc_file_upload_session_fixture.mock_upload_key_1 + ) + + # Or multiple uploads, any that don't exist are not included in the result dict + assert uploads.get( + [ + marc_file_upload_session_fixture.mock_upload_key_1, + marc_file_upload_session_fixture.mock_upload_key_2, + marc_file_upload_session_fixture.mock_unset_upload_key, + ] + ) == marc_file_upload_session_fixture.test_data_records( + marc_file_upload_session_fixture.mock_upload_key_1, + marc_file_upload_session_fixture.mock_upload_key_2, + ) + + # Or you can get all the uploads + assert uploads.get() == marc_file_upload_session_fixture.test_data_records( + marc_file_upload_session_fixture.mock_upload_key_1, + marc_file_upload_session_fixture.mock_upload_key_2, + marc_file_upload_session_fixture.mock_upload_key_3, + ) + + def test_set_upload_id( + self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture + ): + uploads = marc_file_upload_session_fixture.uploads + + # must hold lock to do update + with pytest.raises(MarcFileUploadSessionError) as exc_info: + uploads.set_upload_id( + marc_file_upload_session_fixture.mock_upload_key_1, "xyz" + ) + assert "Must hold lock" in str(exc_info.value) + + uploads.acquire() + + # We are unable to set an upload id for an item that hasn't been initialized + with pytest.raises(MarcFileUploadSessionError) as exc_info: + uploads.set_upload_id( + marc_file_upload_session_fixture.mock_upload_key_1, "xyz" + ) + assert "Failed to set upload ID" in str(exc_info.value) + + marc_file_upload_session_fixture.load_test_data() + uploads.set_upload_id(marc_file_upload_session_fixture.mock_upload_key_1, "def") + uploads.set_upload_id(marc_file_upload_session_fixture.mock_upload_key_2, "abc") + + all_uploads = uploads.get() + assert ( + all_uploads[marc_file_upload_session_fixture.mock_upload_key_1].upload_id + == "def" + ) + assert ( + all_uploads[marc_file_upload_session_fixture.mock_upload_key_2].upload_id + == "abc" + ) + + # We can't change the upload id for a library that has already been set + with pytest.raises(MarcFileUploadSessionError) as exc_info: + uploads.set_upload_id( + marc_file_upload_session_fixture.mock_upload_key_1, "ghi" + ) + assert "Failed to set upload ID" in str(exc_info.value) + + all_uploads = uploads.get() + assert ( + all_uploads[marc_file_upload_session_fixture.mock_upload_key_1].upload_id + == "def" + ) + assert ( + all_uploads[marc_file_upload_session_fixture.mock_upload_key_2].upload_id + == "abc" + ) + + def test_clear_uploads( + self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture + ): + uploads = marc_file_upload_session_fixture.uploads + + # must hold lock to do update + with pytest.raises(MarcFileUploadSessionError) as exc_info: + uploads.clear_uploads() + assert "Must hold lock" in str(exc_info.value) + + uploads.acquire() + + # We are unable to clear the uploads for an item that hasn't been initialized + with pytest.raises(MarcFileUploadSessionError) as exc_info: + uploads.clear_uploads() + assert "Failed to clear uploads" in str(exc_info.value) + + marc_file_upload_session_fixture.load_test_data() + assert uploads.get() != {} + + uploads.clear_uploads() + assert uploads.get() == {} + + def test_get_upload_ids( + self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture + ): + uploads = marc_file_upload_session_fixture.uploads + + # If the id is not set, we 
should get None + assert uploads.get_upload_ids( + [marc_file_upload_session_fixture.mock_upload_key_1] + ) == {marc_file_upload_session_fixture.mock_upload_key_1: None} + + marc_file_upload_session_fixture.load_test_data() + + # If the buffer has been set, but the upload id has not, we should still get None + assert uploads.get_upload_ids( + [marc_file_upload_session_fixture.mock_upload_key_1] + ) == {marc_file_upload_session_fixture.mock_upload_key_1: None} + + with uploads.lock() as locked: + assert locked + uploads.set_upload_id( + marc_file_upload_session_fixture.mock_upload_key_1, "abc" + ) + uploads.set_upload_id( + marc_file_upload_session_fixture.mock_upload_key_2, "def" + ) + assert uploads.get_upload_ids( + marc_file_upload_session_fixture.mock_upload_key_1 + ) == {marc_file_upload_session_fixture.mock_upload_key_1: "abc"} + assert uploads.get_upload_ids( + [ + marc_file_upload_session_fixture.mock_upload_key_1, + marc_file_upload_session_fixture.mock_upload_key_2, + ] + ) == { + marc_file_upload_session_fixture.mock_upload_key_1: "abc", + marc_file_upload_session_fixture.mock_upload_key_2: "def", + } + + def test_add_part_and_clear_buffer( + self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture + ): + uploads = marc_file_upload_session_fixture.uploads + + # If we try to add parts without acquiring the lock, we should get an error + with pytest.raises(MarcFileUploadSessionError) as exc_info: + uploads.add_part_and_clear_buffer( + marc_file_upload_session_fixture.mock_upload_key_1, + marc_file_upload_session_fixture.part_1, + ) + assert "Must hold lock" in str(exc_info.value) + + # Acquire the lock + uploads.acquire() + + # We are unable to add parts to a library whose buffers haven't been initialized + with pytest.raises(MarcFileUploadSessionError) as exc_info: + uploads.add_part_and_clear_buffer( + marc_file_upload_session_fixture.mock_upload_key_1, + marc_file_upload_session_fixture.part_1, + ) + assert "Failed to add part and clear buffer" in str(exc_info.value) + + marc_file_upload_session_fixture.load_test_data() + + # We are able to add parts to a library that exists + uploads.add_part_and_clear_buffer( + marc_file_upload_session_fixture.mock_upload_key_1, + marc_file_upload_session_fixture.part_1, + ) + uploads.add_part_and_clear_buffer( + marc_file_upload_session_fixture.mock_upload_key_2, + marc_file_upload_session_fixture.part_1, + ) + uploads.add_part_and_clear_buffer( + marc_file_upload_session_fixture.mock_upload_key_1, + marc_file_upload_session_fixture.part_2, + ) + + all_uploads = uploads.get() + # The parts are added in order and the buffers are cleared + assert all_uploads[ + marc_file_upload_session_fixture.mock_upload_key_1 + ].parts == [ + marc_file_upload_session_fixture.part_1, + marc_file_upload_session_fixture.part_2, + ] + assert all_uploads[ + marc_file_upload_session_fixture.mock_upload_key_2 + ].parts == [marc_file_upload_session_fixture.part_1] + assert ( + all_uploads[marc_file_upload_session_fixture.mock_upload_key_1].buffer == "" + ) + assert ( + all_uploads[marc_file_upload_session_fixture.mock_upload_key_2].buffer == "" + ) + + def test_get_part_num_and_buffer( + self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture + ): + uploads = marc_file_upload_session_fixture.uploads + + # If the key has not been initialized, we get an exception + with pytest.raises(MarcFileUploadSessionError) as exc_info: + uploads.get_part_num_and_buffer( + marc_file_upload_session_fixture.mock_upload_key_1 + ) + assert "Failed to get part 
number and buffer data" in str(exc_info.value) + + marc_file_upload_session_fixture.load_test_data() + + # If the buffer has been set, but no parts have been added + assert uploads.get_part_num_and_buffer( + marc_file_upload_session_fixture.mock_upload_key_1 + ) == ( + 0, + marc_file_upload_session_fixture.test_data[ + marc_file_upload_session_fixture.mock_upload_key_1 + ], + ) + + with uploads.lock() as locked: + assert locked + uploads.add_part_and_clear_buffer( + marc_file_upload_session_fixture.mock_upload_key_1, + marc_file_upload_session_fixture.part_1, + ) + uploads.add_part_and_clear_buffer( + marc_file_upload_session_fixture.mock_upload_key_1, + marc_file_upload_session_fixture.part_2, + ) + uploads.append_buffers( + { + marc_file_upload_session_fixture.mock_upload_key_1: "1234567", + } + ) + + assert uploads.get_part_num_and_buffer( + marc_file_upload_session_fixture.mock_upload_key_1 + ) == (2, "1234567") + + def test_state( + self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture + ): + uploads = marc_file_upload_session_fixture.uploads + + # If the session doesn't exist, the state should be None + assert uploads.state() is None + + # Once the state is created, by locking for example, the state should be SessionState.INITIAL + with uploads.lock(): + assert uploads.state() == MarcFileUploadState.INITIAL + + def test_set_state( + self, marc_file_upload_session_fixture: MarcFileUploadSessionFixture + ): + uploads = marc_file_upload_session_fixture.uploads + + # If we don't hold the lock, we can't set the state + with pytest.raises(MarcFileUploadSessionError) as exc_info: + uploads.set_state(MarcFileUploadState.UPLOADING) + assert "Must hold lock" in str(exc_info.value) + + # Once the state is created, by locking for example, we can set the state + with uploads.lock(): + uploads.set_state(MarcFileUploadState.UPLOADING) + assert uploads.state() == MarcFileUploadState.UPLOADING diff --git a/tests/manager/service/storage/test_s3.py b/tests/manager/service/storage/test_s3.py index 28086f7a17..c946aa01e3 100644 --- a/tests/manager/service/storage/test_s3.py +++ b/tests/manager/service/storage/test_s3.py @@ -1,28 +1,15 @@ from __future__ import annotations import functools -import uuid -from collections.abc import Generator from io import BytesIO -from typing import TYPE_CHECKING from unittest.mock import MagicMock import pytest from botocore.exceptions import BotoCoreError, ClientError -from pydantic import AnyHttpUrl from palace.manager.core.config import CannotLoadConfiguration -from palace.manager.service.configuration.service_configuration import ( - ServiceConfiguration, -) -from palace.manager.service.storage.container import Storage from palace.manager.service.storage.s3 import S3Service -from tests.fixtures.config import FixtureTestUrlConfiguration - -if TYPE_CHECKING: - from mypy_boto3_s3 import S3Client - - from tests.fixtures.s3 import S3ServiceFixture +from tests.fixtures.s3 import S3ServiceFixture, S3ServiceIntegrationFixture class TestS3Service: @@ -239,88 +226,6 @@ def test_multipart_upload_exception(self, s3_service_fixture: S3ServiceFixture): upload.upload_part(b"foo") -class S3UploaderIntegrationConfiguration(FixtureTestUrlConfiguration): - url: AnyHttpUrl - user: str - password: str - - class Config(ServiceConfiguration.Config): - env_prefix = "PALACE_TEST_MINIO_" - - -class S3ServiceIntegrationFixture: - def __init__(self): - self.container = Storage() - self.configuration = S3UploaderIntegrationConfiguration.from_env() - self.analytics_bucket = 
self.random_name("analytics") - self.public_access_bucket = self.random_name("public") - self.container.config.from_dict( - { - "access_key": self.configuration.user, - "secret_key": self.configuration.password, - "endpoint_url": self.configuration.url, - "region": "us-east-1", - "analytics_bucket": self.analytics_bucket, - "public_access_bucket": self.public_access_bucket, - "url_template": self.configuration.url + "/{bucket}/{key}", - } - ) - self.buckets = [] - self.create_buckets() - - @classmethod - def random_name(cls, prefix: str = "test"): - return f"{prefix}-{uuid.uuid4()}" - - @property - def s3_client(self) -> S3Client: - return self.container.s3_client() - - @property - def public(self) -> S3Service: - return self.container.public() - - @property - def analytics(self) -> S3Service: - return self.container.analytics() - - def create_bucket(self, bucket_name: str) -> None: - client = self.s3_client - client.create_bucket(Bucket=bucket_name) - self.buckets.append(bucket_name) - - def get_bucket(self, bucket_name: str) -> str: - if bucket_name == "public": - return self.public_access_bucket - elif bucket_name == "analytics": - return self.analytics_bucket - else: - raise ValueError(f"Unknown bucket name: {bucket_name}") - - def create_buckets(self) -> None: - for bucket in [self.analytics_bucket, self.public_access_bucket]: - self.create_bucket(bucket) - - def close(self): - for bucket in self.buckets: - response = self.s3_client.list_objects(Bucket=bucket) - - for object in response.get("Contents", []): - object_key = object["Key"] - self.s3_client.delete_object(Bucket=bucket, Key=object_key) - - self.s3_client.delete_bucket(Bucket=bucket) - - -@pytest.fixture -def s3_service_integration_fixture() -> ( - Generator[S3ServiceIntegrationFixture, None, None] -): - fixture = S3ServiceIntegrationFixture() - yield fixture - fixture.close() - - @pytest.mark.minio class TestS3ServiceIntegration: def test_delete(self, s3_service_integration_fixture: S3ServiceIntegrationFixture): diff --git a/tox.ini b/tox.ini index 51ab13f7ec..8aa20dcf75 100644 --- a/tox.ini +++ b/tox.ini @@ -76,7 +76,7 @@ host_var = PALACE_TEST_MINIO_URL_HOST [docker:redis-circ] -image = redis:7 +image = redis/redis-stack-server:7.4.0-v0 expose = PALACE_TEST_REDIS_URL_PORT=6379/tcp host_var =