From b00e3704650a617165f80c2a56df9c7269daa82f Mon Sep 17 00:00:00 2001 From: rdk Date: Wed, 7 Aug 2024 11:17:23 +0200 Subject: [PATCH 1/2] column based residue labeler --- build.gradle | 2 +- .../cz/siret/prank/domain/Dataset.groovy | 18 +++++-- .../labeling/ColumnBasedResidueLabeler.groovy | 50 +++++++++++++++++++ .../domain/labeling/ResidueLabeler.groovy | 10 ++-- .../labeling/SprintLabelingLoader.groovy | 3 +- 5 files changed, 70 insertions(+), 13 deletions(-) create mode 100644 src/main/groovy/cz/siret/prank/domain/labeling/ColumnBasedResidueLabeler.groovy diff --git a/build.gradle b/build.gradle index 50c69cc0..a1b2823a 100644 --- a/build.gradle +++ b/build.gradle @@ -18,7 +18,7 @@ apply plugin: 'groovy' apply plugin: 'java' group = 'cz.siret' -version = '2.4.2' +version = '2.4.3-dev.1' description = 'Ligand binding site prediction based on machine learning.' diff --git a/src/main/groovy/cz/siret/prank/domain/Dataset.groovy b/src/main/groovy/cz/siret/prank/domain/Dataset.groovy index 5734f4af..3d3ea7cf 100644 --- a/src/main/groovy/cz/siret/prank/domain/Dataset.groovy +++ b/src/main/groovy/cz/siret/prank/domain/Dataset.groovy @@ -68,6 +68,7 @@ class Dataset implements Parametrized, Writable, Failable { static final String COLUMN_CONSERVATION_FILES_PATTERN = "conservation_files_pattern" static final String COLUMN_APO_PROTEIN = "apo_protein" static final String COLUMN_APO_CHAINS = "apo_chains" + static final String COLUMN_POSITIVE_RESIDUES = "positive_residues" static final List DEFAULT_HEADER = [ COLUMN_PROTEIN ] @@ -423,7 +424,7 @@ class Dataset implements Parametrized, Writable, Failable { * @return true if explicit residue labeling is defined a as a part of th dataset */ boolean hasExplicitResidueLabeling() { - return attributes.containsKey(PARAM_RESIDUE_LABELING_FORMAT) + return attributes.containsKey(PARAM_RESIDUE_LABELING_FORMAT) || header.contains(COLUMN_POSITIVE_RESIDUES) } /** @@ -438,9 +439,16 @@ class Dataset implements Parametrized, Writable, Failable { */ @Nullable ResidueLabeler getResidueLabeler() { + if (residueLabeler == null && hasExplicitResidueLabeling()) { - String labelingFile = dir + "/" + attributes.get(PARAM_RESIDUE_LABELING_FILE) - residueLabeler = ResidueLabeler.loadFromFile(attributes.get(PARAM_RESIDUE_LABELING_FORMAT), labelingFile) + + if (header.contains(COLUMN_POSITIVE_RESIDUES)) { + + } else { + String labelingFile = dir + "/" + attributes.get(PARAM_RESIDUE_LABELING_FILE) + residueLabeler = ResidueLabeler.loadFromFile(attributes.get(PARAM_RESIDUE_LABELING_FORMAT), labelingFile) + } + } return residueLabeler } @@ -883,7 +891,7 @@ class Dataset implements Parametrized, Writable, Failable { @Nullable BinaryLabeling getExplicitBinaryLabeling() { if (originDataset.hasExplicitResidueLabeling()) { - return originDataset.explicitBinaryResidueLabeler.getBinaryLabeling(protein.residues, protein) + return originDataset.explicitBinaryResidueLabeler.getBinaryLabeling(protein.residues, protein, this) } return null } @@ -893,7 +901,7 @@ class Dataset implements Parametrized, Writable, Failable { */ @Nullable BinaryLabeling getBinaryLabeling() { - return originDataset.binaryResidueLabeler.getBinaryLabeling(protein.residues, protein) + return originDataset.binaryResidueLabeler.getBinaryLabeling(protein.residues, protein, this) } ProcessedItemContext getContext() { diff --git a/src/main/groovy/cz/siret/prank/domain/labeling/ColumnBasedResidueLabeler.groovy b/src/main/groovy/cz/siret/prank/domain/labeling/ColumnBasedResidueLabeler.groovy new file mode 100644 index 00000000..119fb32f --- /dev/null +++ b/src/main/groovy/cz/siret/prank/domain/labeling/ColumnBasedResidueLabeler.groovy @@ -0,0 +1,50 @@ +package cz.siret.prank.domain.labeling + +import cz.siret.prank.domain.Dataset +import cz.siret.prank.domain.Protein +import cz.siret.prank.domain.Residue +import cz.siret.prank.domain.Residues +import cz.siret.prank.utils.Sutils +import cz.siret.prank.utils.Writable +import groovy.transform.CompileStatic +import groovy.util.logging.Slf4j + +/** + * + */ +@Slf4j +@CompileStatic +class ColumnBasedResidueLabeler extends ResidueLabeler implements Writable { + + @Override + ResidueLabeling labelResidues(Residues residues, Protein protein, Dataset.Item item) { + + Set positives = Sutils.split(item.columnValues.get(Dataset.COLUMN_POSITIVE_RESIDUES), ',').toSet() + + return createLabelingFromPositiveResidueCodes(residues, positives) + } + + @Override + boolean isBinary() { + return true + } + + @Override + ResidueLabeling getDoubleLabeling() { + return null + } + +//===========================================================================================================// + + static ResidueLabeling createLabelingFromPositiveResidueCodes(Residues residues, Set positives) { + ResidueLabeling labeling = new ResidueLabeling() + + for (Residue residue : residues) { + boolean label = positives.contains(residue.toString()) + labeling.add(residue, label) + } + + return labeling + } + +} diff --git a/src/main/groovy/cz/siret/prank/domain/labeling/ResidueLabeler.groovy b/src/main/groovy/cz/siret/prank/domain/labeling/ResidueLabeler.groovy index 0c2bac53..43ab7e50 100644 --- a/src/main/groovy/cz/siret/prank/domain/labeling/ResidueLabeler.groovy +++ b/src/main/groovy/cz/siret/prank/domain/labeling/ResidueLabeler.groovy @@ -1,6 +1,6 @@ package cz.siret.prank.domain.labeling - +import cz.siret.prank.domain.Dataset import cz.siret.prank.domain.Protein import cz.siret.prank.domain.Residues import cz.siret.prank.program.PrankException @@ -14,23 +14,23 @@ import javax.annotation.Nullable @CompileStatic abstract class ResidueLabeler { - abstract ResidueLabeling labelResidues(Residues residues, Protein protein) + abstract ResidueLabeling labelResidues(Residues residues, Protein protein, Dataset.Item item) abstract boolean isBinary() @Nullable abstract ResidueLabeling getDoubleLabeling() - BinaryLabeling getBinaryLabeling(Residues residues, Protein protein) { + BinaryLabeling getBinaryLabeling(Residues residues, Protein protein, Dataset.Item item) { if (isBinary()) { - (BinaryLabeling) labelResidues(residues, protein) + (BinaryLabeling) labelResidues(residues, protein, item) } else { throw new PrankException("Residue labeler not binary!") } } BinaryLabeling getBinaryLabeling(Protein protein) { - getBinaryLabeling(protein.residues, protein) + getBinaryLabeling(protein.residues, protein, null) } static ResidueLabeler loadFromFile(String format, String fname) { diff --git a/src/main/groovy/cz/siret/prank/domain/labeling/SprintLabelingLoader.groovy b/src/main/groovy/cz/siret/prank/domain/labeling/SprintLabelingLoader.groovy index e92dee5d..312699cc 100644 --- a/src/main/groovy/cz/siret/prank/domain/labeling/SprintLabelingLoader.groovy +++ b/src/main/groovy/cz/siret/prank/domain/labeling/SprintLabelingLoader.groovy @@ -47,12 +47,11 @@ class SprintLabelingLoader extends ResidueLabeler implements Writable { } @Override - ResidueLabeling labelResidues(Residues residues, Protein protein) { + ResidueLabeling labelResidues(Residues residues, Protein protein, Dataset.Item item) { Map labelMap = new HashMap<>() - boolean foundOneChain = false for (ResidueChain chain : protein.residueChains) { String chainCode = toElementCode(protein, chain) From 1101acc35c7bf52239d870557141e1c0f26b461c Mon Sep 17 00:00:00 2001 From: rdk Date: Wed, 7 Aug 2024 11:40:11 +0200 Subject: [PATCH 2/2] column based residue labeler --- .../prank/domain/labeling/LigandBasedResidueLabeler.groovy | 4 ++-- .../prank/domain/labeling/ModelBasedResidueLabeler.groovy | 3 ++- .../cz/siret/prank/domain/labeling/ResidueLabeler.groovy | 6 +++++- .../siret/prank/domain/labeling/StaticResidueLabeler.groovy | 2 +- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/main/groovy/cz/siret/prank/domain/labeling/LigandBasedResidueLabeler.groovy b/src/main/groovy/cz/siret/prank/domain/labeling/LigandBasedResidueLabeler.groovy index 05eb6005..b79b5f72 100644 --- a/src/main/groovy/cz/siret/prank/domain/labeling/LigandBasedResidueLabeler.groovy +++ b/src/main/groovy/cz/siret/prank/domain/labeling/LigandBasedResidueLabeler.groovy @@ -1,6 +1,6 @@ package cz.siret.prank.domain.labeling - +import cz.siret.prank.domain.Dataset import cz.siret.prank.domain.Protein import cz.siret.prank.domain.Residue import cz.siret.prank.domain.Residues @@ -22,7 +22,7 @@ class LigandBasedResidueLabeler extends ResidueLabeler implements Param ResidueLabeling lastLigandDistanceLabeling @Override - ResidueLabeling labelResidues(Residues residues, Protein protein) { + ResidueLabeling labelResidues(Residues residues, Protein protein, Dataset.Item item) { ResidueLabeling ligandDistanceLabeling = new ResidueLabeling<>(residues.count) for (Residue res : residues) { diff --git a/src/main/groovy/cz/siret/prank/domain/labeling/ModelBasedResidueLabeler.groovy b/src/main/groovy/cz/siret/prank/domain/labeling/ModelBasedResidueLabeler.groovy index a5a2a1e0..8ad03e15 100644 --- a/src/main/groovy/cz/siret/prank/domain/labeling/ModelBasedResidueLabeler.groovy +++ b/src/main/groovy/cz/siret/prank/domain/labeling/ModelBasedResidueLabeler.groovy @@ -1,5 +1,6 @@ package cz.siret.prank.domain.labeling +import cz.siret.prank.domain.Dataset import cz.siret.prank.domain.Protein import cz.siret.prank.domain.Residue import cz.siret.prank.domain.Residues @@ -83,7 +84,7 @@ class ModelBasedResidueLabeler extends ResidueLabeler implements Parame } @Override - ResidueLabeling labelResidues(Residues residues, Protein protein) { + ResidueLabeling labelResidues(Residues residues, Protein protein, Dataset.Item item) { ModelBasedPointLabeler predictor = new ModelBasedPointLabeler(model, context).withObserved(observedPoints) diff --git a/src/main/groovy/cz/siret/prank/domain/labeling/ResidueLabeler.groovy b/src/main/groovy/cz/siret/prank/domain/labeling/ResidueLabeler.groovy index 43ab7e50..191d6d7e 100644 --- a/src/main/groovy/cz/siret/prank/domain/labeling/ResidueLabeler.groovy +++ b/src/main/groovy/cz/siret/prank/domain/labeling/ResidueLabeler.groovy @@ -21,7 +21,7 @@ abstract class ResidueLabeler { @Nullable abstract ResidueLabeling getDoubleLabeling() - BinaryLabeling getBinaryLabeling(Residues residues, Protein protein, Dataset.Item item) { + BinaryLabeling getBinaryLabeling(Residues residues, Protein protein, @Nullable Dataset.Item item) { if (isBinary()) { (BinaryLabeling) labelResidues(residues, protein, item) } else { @@ -29,6 +29,10 @@ abstract class ResidueLabeler { } } + BinaryLabeling getBinaryLabeling(Residues residues, Protein protein) { + return getBinaryLabeling(residues, protein, null) + } + BinaryLabeling getBinaryLabeling(Protein protein) { getBinaryLabeling(protein.residues, protein, null) } diff --git a/src/main/groovy/cz/siret/prank/domain/labeling/StaticResidueLabeler.groovy b/src/main/groovy/cz/siret/prank/domain/labeling/StaticResidueLabeler.groovy index 4feef3fb..f7253fcc 100644 --- a/src/main/groovy/cz/siret/prank/domain/labeling/StaticResidueLabeler.groovy +++ b/src/main/groovy/cz/siret/prank/domain/labeling/StaticResidueLabeler.groovy @@ -54,7 +54,7 @@ abstract class StaticResidueLabeler extends ResidueLabeler implements W } @Override - ResidueLabeling labelResidues(Residues residues, Protein protein) { + ResidueLabeling labelResidues(Residues residues, Protein protein, Dataset.Item item) { Map scores = new HashMap<>() for (String line : Futils.readLines(path).tail()) {