diff --git a/alembic/versions/20240216_9d2dccb0d6ff_ability_to_suppress_works_per_library.py b/alembic/versions/20240216_9d2dccb0d6ff_ability_to_suppress_works_per_library.py
new file mode 100644
index 0000000000..212c45ac95
--- /dev/null
+++ b/alembic/versions/20240216_9d2dccb0d6ff_ability_to_suppress_works_per_library.py
@@ -0,0 +1,31 @@
+"""Ability to suppress works per library
+
+Revision ID: 9d2dccb0d6ff
+Revises: 1c9f519415b5
+Create Date: 2024-02-16 17:08:52.146860+00:00
+
+"""
+import sqlalchemy as sa
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "9d2dccb0d6ff"
+down_revision = "1c9f519415b5"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.create_table(
+        "work_library_suppressions",
+        sa.Column("work_id", sa.Integer(), nullable=False),
+        sa.Column("library_id", sa.Integer(), nullable=False),
+        sa.ForeignKeyConstraint(["library_id"], ["libraries.id"], ondelete="CASCADE"),
+        sa.ForeignKeyConstraint(["work_id"], ["works.id"], ondelete="CASCADE"),
+        sa.PrimaryKeyConstraint("work_id", "library_id"),
+    )
+
+
+def downgrade() -> None:
+    op.drop_table("work_library_suppressions")
diff --git a/bin/configuration/suppress_work_for_library b/bin/configuration/suppress_work_for_library
new file mode 100755
index 0000000000..16109a64a9
--- /dev/null
+++ b/bin/configuration/suppress_work_for_library
@@ -0,0 +1,10 @@
+#!/usr/bin/env python
+import os
+import sys
+
+bin_dir = os.path.split(__file__)[0]
+package_dir = os.path.join(bin_dir, "..", "..")
+sys.path.append(os.path.abspath(package_dir))
+from core.scripts import SuppressWorkForLibraryScript
+
+SuppressWorkForLibraryScript().run()
diff --git a/core/external_search.py b/core/external_search.py
index f5278a13e9..1af769eda9 100644
--- a/core/external_search.py
+++ b/core/external_search.py
@@ -1575,6 +1575,7 @@ def from_worklist(cls, _db, worklist, facets):
             allow_holds=allow_holds,
             license_datasource=license_datasource_id,
             lane_building=True,
+            library=library,
         )
 
     def __init__(
@@ -1723,6 +1724,9 @@ def __init__(
 
         self.lane_building = kwargs.pop("lane_building", False)
 
+        library = kwargs.pop("library", None)
+        self.library_id = library.id if library else None
+
         # At this point there should be no keyword arguments -- you can't pass
         # whatever you want into this method.
         if kwargs:
@@ -1847,6 +1851,11 @@ def build(self, _chain_filters=None):
         if self.author is not None:
             nested_filters["contributors"].append(self.author_filter)
 
+        if self.library_id:
+            f = chain(
+                f, Bool(must_not=[Terms(**{"suppressed_for": [self.library_id]})])
+            )
+
         if self.media:
             f = chain(f, Terms(medium=scrub_list(self.media)))
 
diff --git a/core/lane.py b/core/lane.py
index 48ff6e4528..ef00b4d510 100644
--- a/core/lane.py
+++ b/core/lane.py
@@ -2387,10 +2387,17 @@ def only_show_ready_deliverable_works(self, _db, query, show_suppressed=False):
         Note that this assumes the query has an active join against
         LicensePool.
         """
-        return Collection.restrict_to_ready_deliverable_works(
+        query = Collection.restrict_to_ready_deliverable_works(
             query, show_suppressed=show_suppressed, collection_ids=self.collection_ids
         )
 
+        if not show_suppressed and self.library_id is not None:
+            query = query.filter(
+                not_(Work.suppressed_for.contains(self.get_library(_db)))
+            )
+
+        return query
+
     def bibliographic_filter_clauses(self, _db, qu):
         """Create a SQLAlchemy filter that excludes books whose bibliographic
         metadata doesn't match what we're looking for.
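For review context, a minimal sketch (not part of the patch) of the query fragment that Filter.build() now contributes when library_id is set. It builds the expected dictionary by hand in plain Python, with 42 as a placeholder library id; the real code produces the same shape through the Bool/Terms DSL objects used above, and the assertion added to tests/core/test_external_search.py below checks exactly this form:

    # Works suppressed for the active library must not match;
    # works suppressed only for other libraries are unaffected.
    library_id = 42  # placeholder library id
    fragment = {
        "bool": {
            "must_not": [
                {"terms": {"suppressed_for": [library_id]}},
            ]
        }
    }

The only_show_ready_deliverable_works() change above applies the equivalent restriction on the database side, so feeds built from the database and from the search index stay in agreement.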
diff --git a/core/model/listeners.py b/core/model/listeners.py
index 25d357db5a..41030fc717 100644
--- a/core/model/listeners.py
+++ b/core/model/listeners.py
@@ -105,6 +105,13 @@ def licensepool_removed_from_work(target, value, initiator):
     target.external_index_needs_updating()
 
 
+@event.listens_for(Work.suppressed_for, "append")
+@event.listens_for(Work.suppressed_for, "remove")
+def work_suppressed_for_library(target, value, initiator):
+    if target:
+        target.external_index_needs_updating()
+
+
 @Listener.before_flush(LicensePool, ListenerState.deleted)
 def licensepool_deleted(session: Session, instance: LicensePool) -> None:
     """A LicensePool is deleted only when its collection is deleted.
diff --git a/core/model/work.py b/core/model/work.py
index c343cc9ed8..48dd74aa09 100644
--- a/core/model/work.py
+++ b/core/model/work.py
@@ -20,6 +20,7 @@
     ForeignKey,
     Integer,
     Numeric,
+    Table,
     Unicode,
 )
 from sqlalchemy.dialects.postgresql import INT4RANGE
@@ -209,6 +210,11 @@ class Work(Base):
     # will be made to make the Work presentation ready.
     presentation_ready_exception = Column(Unicode, default=None, index=True)
 
+    # Suppress this work from appearing in any feeds for a specific library.
+    suppressed_for: Mapped[list[Library]] = relationship(
+        "Library", secondary="work_library_suppressions", passive_deletes=True
+    )
+
     # These fields are potentially large and can be deferred if you
     # don't need all the data in a Work.
     LARGE_FIELDS = [
@@ -1601,6 +1607,7 @@ def _set_value(parent, key, target):
         result["_id"] = getattr(doc, "id")
         result["work_id"] = getattr(doc, "id")
         result["summary"] = getattr(doc, "summary_text")
+        result["suppressed_for"] = [int(l.id) for l in getattr(doc, "suppressed_for")]
         result["fiction"] = (
             "Fiction" if getattr(doc, "fiction") is True else "Nonfiction"
        )
@@ -1806,6 +1813,16 @@ def delete(
         _db.delete(self)
 
 
+work_library_suppressions = Table(
+    "work_library_suppressions",
+    Base.metadata,
+    Column("work_id", ForeignKey("works.id", ondelete="CASCADE"), primary_key=True),
+    Column(
+        "library_id", ForeignKey("libraries.id", ondelete="CASCADE"), primary_key=True
+    ),
+)
+
+
 def add_work_to_customlists_for_collection(pool_or_work: LicensePool | Work) -> None:
     if isinstance(pool_or_work, Work):
         work = pool_or_work
diff --git a/core/scripts.py b/core/scripts.py
index 04502328b0..dc548d607c 100644
--- a/core/scripts.py
+++ b/core/scripts.py
@@ -2668,6 +2668,96 @@ def process_loan(self, loan: Loan):
         self.notifications.send_loan_expiry_message(loan, delta.days, tokens)
 
 
+class SuppressWorkForLibraryScript(Script):
+    """Suppress a work for a library, by identifier."""
+
+    BY_DATABASE_ID = "Database ID"
+
+    @classmethod
+    def arg_parser(cls, _db: Session | None) -> argparse.ArgumentParser:  # type: ignore[override]
+        parser = argparse.ArgumentParser()
+        if _db is None:
+            raise ValueError("No database session provided.")
+        library_name_list = sorted(str(l.short_name) for l in _db.query(Library))
+        library_names = '"' + '", "'.join(library_name_list) + '"'
+        parser.add_argument(
+            "-l",
+            "--library",
+            help="Short name of the library. Libraries on this system: %s."
+            % library_names,
+            required=True,
+            metavar="SHORT_NAME",
+        )
+        parser.add_argument(
+            "-t",
+            "--identifier-type",
+            help="Identifier type (default: ISBN). "
+            f'To look up identifiers by their database ID, use --identifier-type="{cls.BY_DATABASE_ID}".',
+            default="ISBN",
+        )
+        parser.add_argument(
+            "-i",
+            "--identifier",
+            help="The identifier to suppress.",
+            required=True,
+        )
+        return parser
+
+    @classmethod
+    def parse_command_line(
+        cls, _db: Session | None = None, cmd_args: list[str] | None = None
+    ):
+        parser = cls.arg_parser(_db)
+        return parser.parse_known_args(cmd_args)[0]
+
+    def load_library(self, library_short_name: str) -> Library:
+        library_short_name = library_short_name.strip()
+        library = self._db.scalars(
+            select(Library).where(Library.short_name == library_short_name)
+        ).one_or_none()
+        if not library:
+            raise ValueError(f"Unknown library: {library_short_name}")
+        return library
+
+    def load_identifier(self, identifier_type: str, identifier: str) -> Identifier:
+        query = select(Identifier)
+        identifier_type = identifier_type.strip()
+        identifier = identifier.strip()
+        if identifier_type == self.BY_DATABASE_ID:
+            query = query.where(Identifier.id == int(identifier))
+        else:
+            query = query.where(Identifier.type == identifier_type).where(
+                Identifier.identifier == identifier
+            )
+
+        identifier_obj = self._db.scalars(query).unique().one_or_none()
+        if not identifier_obj:
+            raise ValueError(f"Unknown identifier: {identifier_type}/{identifier}")
+
+        return identifier_obj
+
+    def do_run(self, cmd_args: list[str] | None = None) -> None:
+        parsed = self.parse_command_line(self._db, cmd_args=cmd_args)
+
+        library = self.load_library(parsed.library)
+        identifier = self.load_identifier(parsed.identifier_type, parsed.identifier)
+
+        self.suppress_work(library, identifier)
+
+    def suppress_work(self, library: Library, identifier: Identifier) -> None:
+        work = identifier.work
+        if not work:
+            self.log.warning(f"No work found for {identifier}")
+            return
+
+        work.suppressed_for.append(library)
+        self.log.info(
+            f"Suppressed {identifier.type}/{identifier.identifier} (work id: {work.id}) for {library.short_name}."
+        )
+
+        self._db.commit()
+
+
 class MockStdin:
     """Mock a list of identifiers passed in on standard input."""
diff --git a/tests/core/models/test_listeners.py b/tests/core/models/test_listeners.py
index 142ef8f1af..d10aaca3ee 100644
--- a/tests/core/models/test_listeners.py
+++ b/tests/core/models/test_listeners.py
@@ -216,3 +216,22 @@ def test_licensepool_storage_status_change(
             == work.coverage_records[0].operation
         )
         assert WorkCoverageRecord.REGISTERED == work.coverage_records[0].status
+
+    def test_work_suppressed_for_library(self, db: DatabaseTransactionFixture):
+        work = db.work(with_license_pool=True)
+        library = db.library()
+
+        # Clear out any WorkCoverageRecords created as the work was initialized.
+        work.coverage_records = []
+
+        # Act
+        work.suppressed_for.append(library)
+
+        # Assert
+        assert 1 == len(work.coverage_records)
+        assert work.id == work.coverage_records[0].work_id
+        assert (
+            WorkCoverageRecord.UPDATE_SEARCH_INDEX_OPERATION
+            == work.coverage_records[0].operation
+        )
+        assert WorkCoverageRecord.REGISTERED == work.coverage_records[0].status
diff --git a/tests/core/models/test_work.py b/tests/core/models/test_work.py
index a90ffda587..ff1e041c22 100644
--- a/tests/core/models/test_work.py
+++ b/tests/core/models/test_work.py
@@ -4,6 +4,7 @@
 import pytest
 import pytz
 from psycopg2.extras import NumericRange
+from sqlalchemy import select
 
 from core.classifier import Classifier, Fantasy, Romance, Science_Fiction
 from core.equivalents_coverage import EquivalentIdentifiersCoverageProvider
@@ -16,7 +17,7 @@
 from core.model.identifier import Identifier
 from core.model.licensing import LicensePool
 from core.model.resource import Hyperlink, Representation, Resource
-from core.model.work import Work, WorkGenre
+from core.model.work import Work, WorkGenre, work_library_suppressions
 from core.util.datetime_helpers import datetime_utc, from_timestamp, utc_now
 from tests.fixtures.database import DatabaseTransactionFixture
 from tests.fixtures.sample_covers import SampleCoversFixture
@@ -806,6 +807,53 @@ def test_work_updates_info_on_pool_suppressed(self, db: DatabaseTransactionFixtu
         assert "Alice Adder, Bob Bitshifter" == work.author
         assert "Adder, Alice ; Bitshifter, Bob" == work.sort_author
 
+    def test_suppressed_for_delete_work(self, db: DatabaseTransactionFixture):
+        work = db.work()
+        library1 = db.library()
+        library2 = db.library()
+
+        work.suppressed_for.append(library1)
+        work.suppressed_for.append(library2)
+        db.session.flush()
+
+        assert len(db.session.execute(select(work_library_suppressions)).all()) == 2
+
+        db.session.delete(work)
+        db.session.flush()
+
+        # The libraries are not deleted.
+        assert library1 in db.session
+        assert library2 in db.session
+
+        # The work is deleted.
+        assert work not in db.session
+
+        # The work's rows in the work_library_suppressions table are deleted.
+        assert len(db.session.execute(select(work_library_suppressions)).all()) == 0
+
+    def test_suppressed_for_delete_library(self, db: DatabaseTransactionFixture):
+        work = db.work()
+        library1 = db.library()
+        library2 = db.library()
+
+        work.suppressed_for.append(library1)
+        work.suppressed_for.append(library2)
+        db.session.flush()
+
+        assert len(db.session.execute(select(work_library_suppressions)).all()) == 2
+
+        db.session.delete(library1)
+        db.session.flush()
+        db.session.expire_all()
+
+        assert library1 not in db.session
+
+        assert library2 in db.session
+        assert work in db.session
+
+        assert len(db.session.execute(select(work_library_suppressions)).all()) == 1
+        assert work.suppressed_for == [library2]
+
     def test_different_language_means_different_work(
         self, db: DatabaseTransactionFixture
     ):
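Before the search-side tests, a brief sketch (not part of the patch) of what the new field looks like in an indexed document, matching the to_search_document assertions below: a work suppressed for some libraries carries their ids, while an unsuppressed work carries None. The ids are placeholders:

    # Relevant slice of Work.to_search_document() output (placeholder values):
    suppressed_doc = {"work_id": 123, "suppressed_for": [1, 3]}  # suppressed for libraries 1 and 3
    unsuppressed_doc = {"work_id": 456, "suppressed_for": None}  # not suppressed for any library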
diff --git a/tests/core/test_external_search.py b/tests/core/test_external_search.py
index bd60ba9472..a44938e3aa 100644
--- a/tests/core/test_external_search.py
+++ b/tests/core/test_external_search.py
@@ -3611,6 +3611,10 @@ def assert_filter_builds_to(self, expect, filter, _chain_filters=None):
         """
         final_query = {"bool": {"must_not": [RESEARCH.to_dict()]}}
 
+        if filter.library_id:
+            suppressed_for = Terms(**{"suppressed_for": [filter.library_id]})
+            final_query["bool"]["must_not"].insert(0, suppressed_for.to_dict())
+
         if expect:
             final_query["bool"]["must"] = expect
         main, nested = filter.build(_chain_filters)
@@ -3718,6 +3722,7 @@ def test_build(self, filter_fixture: FilterFixture):
 
         chain = self._mock_chain
 
+        filter.library_id = transaction.default_library().id
         filter.collection_ids = [transaction.default_collection()]
         filter.fiction = True
         filter._audiences = "CHILDREN"
@@ -3834,7 +3839,15 @@ def test_build(self, filter_fixture: FilterFixture):
 
         # Every other restriction imposed on the Filter object becomes an
         # Opensearch filter object in this list.
-        (medium, language, fiction, audience, target_age, updated_after) = built
+        (
+            library_suppression,
+            medium,
+            language,
+            fiction,
+            audience,
+            target_age,
+            updated_after,
+        ) = built
 
         # Test them one at a time.
         #
@@ -3852,6 +3865,14 @@ def test_build(self, filter_fixture: FilterFixture):
         assert medium_built == medium.to_dict()
         assert language_built == language.to_dict()
 
+        assert {
+            "bool": {
+                "must_not": [
+                    {"terms": {"suppressed_for": [transaction.default_library().id]}}
+                ]
+            }
+        } == library_suppression.to_dict()
+
         assert {"term": {"fiction": "fiction"}} == fiction.to_dict()
 
         assert {"terms": {"audience": ["children"]}} == audience.to_dict()
@@ -4727,6 +4748,8 @@ def test_operation(
     def test_to_search_document(self, db: DatabaseTransactionFixture):
         """Test the output of the to_search_document method."""
         customlist, editions = db.customlist()
+        library = db.library()
+
         works = [
             db.work(
                 authors=[db.contributor()],
@@ -4738,6 +4761,7 @@ def test_to_search_document(self, db: DatabaseTransactionFixture):
 
         work1: Work = works[0]
         work2: Work = works[1]
+        work2.suppressed_for.append(library)
 
         work1.target_age = NumericRange(lower=18, upper=22, bounds="()")
         work2.target_age = NumericRange(lower=18, upper=99, bounds="[]")
@@ -4786,6 +4810,11 @@ def compare(doc: dict[str, Any], work: Work) -> None:
             assert doc["rating"] == work.rating
             assert doc["popularity"] == work.popularity
 
+            if work.suppressed_for:
+                assert doc["suppressed_for"] == [l.id for l in work.suppressed_for]
+            else:
+                assert doc["suppressed_for"] is None
+
             if work.license_pools:
                 assert len(doc["licensepools"]) == len(work.license_pools)
                 for idx, pool in enumerate(work.license_pools):
diff --git a/tests/core/test_lane.py b/tests/core/test_lane.py
index 674a2589ed..142caa1a23 100644
--- a/tests/core/test_lane.py
+++ b/tests/core/test_lane.py
@@ -4627,6 +4627,87 @@ def test_works_and_works_from_database(
         from_db = fixture.work_ids_from_db(lane)
         assert from_search == from_db
 
+    def test_works_and_works_from_database_with_suppressed(
+        self,
+        work_list_groups_end_to_end_fixture: WorkListGroupsEndToEndFixture,
+    ):
+        db = work_list_groups_end_to_end_fixture.db
+        fixture = work_list_groups_end_to_end_fixture
+        index = fixture.external_search_fixture.external_search_index
+
+        # Create a bunch of lanes and works.
+        data = fixture.populate_works()
+        lane_data = fixture.create_lanes(data)
+
+        decoy_library = db.library()
+        another_library = db.library()
+
+        db.default_collection().libraries += [decoy_library, another_library]
+
+        # Add a couple of suppressed works, to make sure they don't show up in the results.
+        globally_suppressed_work = db.work(
+            title="Suppressed LP",
+            fiction=True,
+            genre="Literary Fiction",
+            with_license_pool=True,
+        )
+        globally_suppressed_work.quality = 0.95
+        for license_pool in globally_suppressed_work.license_pools:
+            license_pool.suppressed = True
+
+        # This work is only suppressed for a specific library.
+        library_suppressed_work = db.work(
+            title="Suppressed 2",
+            fiction=True,
+            genre="Literary Fiction",
+            with_license_pool=True,
+        )
+        library_suppressed_work.quality = 0.95
+        library_suppressed_work.suppressed_for = [fixture.library, decoy_library]
+
+        fixture.populate_search_index()
+
+        for lane_name in fields(lane_data):
+            lane = getattr(lane_data, lane_name.name)
+            from_search = fixture.work_ids_from_search(lane)
+            from_db = fixture.work_ids_from_db(lane)
+
+            # The suppressed works are not included in the results.
+            assert globally_suppressed_work.id not in from_search
+            assert globally_suppressed_work.id not in from_db
+            assert library_suppressed_work.id not in from_search
+            assert library_suppressed_work.id not in from_db
+
+        # Test the decoy library's lane as well.
+        decoy_library_lane = db.lane("Fiction", fiction=True, library=decoy_library)
+        from_search = fixture.work_ids_from_search(decoy_library_lane)
+        from_db = fixture.work_ids_from_db(decoy_library_lane)
+        assert globally_suppressed_work.id not in from_search
+        assert globally_suppressed_work.id not in from_db
+        assert library_suppressed_work.id not in from_search
+        assert library_suppressed_work.id not in from_db
+
+        # Test a lane for a different library. This time the globally suppressed work should
+        # still be absent, but the work suppressed for the other library should be present.
+        another_library_lane = db.lane("Fiction", fiction=True, library=another_library)
+        from_search = fixture.work_ids_from_search(another_library_lane)
+        from_db = fixture.work_ids_from_db(another_library_lane)
+        assert globally_suppressed_work.id not in from_search
+        assert globally_suppressed_work.id not in from_db
+        assert library_suppressed_work.id in from_search
+        assert library_suppressed_work.id in from_db
+
+        # Make sure that the suppressed works are handled correctly when searching in a lane as well.
+        assert library_suppressed_work in another_library_lane.search(
+            db.session, "suppressed", index
+        )
+        assert library_suppressed_work not in lane_data.fiction.search(
+            db.session, "suppressed", index
+        )
+        assert library_suppressed_work not in decoy_library_lane.search(
+            db.session, "suppressed", index
+        )
+
 
 class RandomSeedFixture:
     def __init__(self):
diff --git a/tests/core/test_scripts.py b/tests/core/test_scripts.py
index 1c34d5057d..bf9eb29cac 100644
--- a/tests/core/test_scripts.py
+++ b/tests/core/test_scripts.py
@@ -74,6 +74,7 @@
     ShowIntegrationsScript,
     ShowLanesScript,
     ShowLibrariesScript,
+    SuppressWorkForLibraryScript,
     TimestampScript,
     UpdateCustomListSizeScript,
     UpdateLaneSizeScript,
@@ -2459,6 +2460,100 @@ def test_constructor(
         )
 
 
+class TestSuppressWorkForLibraryScript:
+    @pytest.mark.parametrize(
+        "cmd_args",
+        [
+            "",
+            "--library test",
+            "--library test --identifier-type test",
+            "--identifier-type test",
+            "--identifier test",
+        ],
+    )
+    def test_parse_command_line_error(
+        self, db: DatabaseTransactionFixture, capsys, cmd_args: str
+    ):
+        with pytest.raises(SystemExit):
+            SuppressWorkForLibraryScript.parse_command_line(
+                db.session, cmd_args.split(" ")
+            )
+
+        assert "error: the following arguments are required" in capsys.readouterr().err
+
+    @pytest.mark.parametrize(
+        "cmd_args",
+        [
+            "--library test1 --identifier-type test2 --identifier test3",
+            "-l test1 -t test2 -i test3",
+        ],
+    )
+    def test_parse_command_line(self, db: DatabaseTransactionFixture, cmd_args: str):
+        parsed = SuppressWorkForLibraryScript.parse_command_line(
+            db.session, cmd_args.split(" ")
+        )
+        assert parsed.library == "test1"
+        assert parsed.identifier_type == "test2"
+        assert parsed.identifier == "test3"
+
+    def test_load_library(self, db: DatabaseTransactionFixture):
+        test_library = db.library(short_name="test")
+
+        script = SuppressWorkForLibraryScript(db.session)
+        loaded_library = script.load_library("test")
+        assert loaded_library == test_library
+
+        with pytest.raises(ValueError):
+            script.load_library("test2")
+
+    def test_load_identifier(self, db: DatabaseTransactionFixture):
+        test_identifier = db.identifier()
+
+        script = SuppressWorkForLibraryScript(db.session)
+        loaded_identifier = script.load_identifier(
+            str(test_identifier.type), str(test_identifier.identifier)
+        )
+        assert loaded_identifier == test_identifier
+
+        loaded_identifier = script.load_identifier(
+            script.BY_DATABASE_ID, str(test_identifier.id)
+        )
+        assert loaded_identifier == test_identifier
+
+        with pytest.raises(ValueError):
+            script.load_identifier("test", "test")
+
+    def test_do_run(self, db: DatabaseTransactionFixture):
+        test_library = db.library(short_name="test")
+        test_identifier = db.identifier()
+
+        script = SuppressWorkForLibraryScript(db.session)
+        suppress_work_mock = create_autospec(script.suppress_work)
+        script.suppress_work = suppress_work_mock
+        args = [
+            "--library",
+            test_library.short_name,
+            "--identifier-type",
+            test_identifier.type,
+            "--identifier",
+            test_identifier.identifier,
+        ]
+        script.do_run(args)
+
+        suppress_work_mock.assert_called_once_with(test_library, test_identifier)
+
+    def test_suppress_work(self, db: DatabaseTransactionFixture):
+        test_library = db.library(short_name="test")
+        work = db.work(with_license_pool=True)
+
+        assert work.suppressed_for == []
+
+        script = SuppressWorkForLibraryScript(db.session)
+        script.suppress_work(test_library, work.presentation_edition.primary_identifier)
+
+        assert work.suppressed_for == [test_library]
+
+
 class TestWorkConsolidationScript:
     """TODO"""
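Finally, a hedged sketch of how the new script is meant to be used, first through the bin/configuration entry point and then programmatically with the same methods do_run() calls. The library short name and ISBN are placeholders, and `session` is assumed to be an already-configured SQLAlchemy session for a database that contains them:

    # Command-line form:
    #   bin/configuration/suppress_work_for_library -l my-library -t ISBN -i 9780316075978
    #
    # Programmatic form:
    from core.scripts import SuppressWorkForLibraryScript

    script = SuppressWorkForLibraryScript(session)
    library = script.load_library("my-library")  # raises ValueError if unknown
    identifier = script.load_identifier("ISBN", "9780316075978")  # raises ValueError if unknown
    script.suppress_work(library, identifier)  # appends to work.suppressed_for and commits

    # The patch adds no unsuppress script, but the relationship itself supports removal,
    # and the new listener queues reindexing on both append and remove:
    identifier.work.suppressed_for.remove(library)
    session.commit()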