Co-authored-by: Yifan Mai <[email protected]>
Showing 7 changed files with 542 additions and 2 deletions.
108 changes: 108 additions & 0 deletions
src/helm/benchmark/adaptation/adapters/ehr_instruction_adapter.py
@@ -0,0 +1,108 @@
from typing import List, Optional

from helm.benchmark.adaptation.adapters.generation_adapter import GenerationAdapter
from helm.benchmark.adaptation.prompt import Prompt
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.scenarios.scenario import TRAIN_SPLIT, Instance
from helm.benchmark.window_services.window_service import EncodeResult
from helm.common.tokenization_request import TokenizationToken


# In the prompt templates for EHR instructions, this is the placeholder for the EHR part,
# which we use to compute accurate tokenized sequence lengths.
PROMPT_TEMPLATE_EHR_PLACEHOLDER = "{ehr}"


class EHRInstructionAdapter(GenerationAdapter):
    """
    Each instance consists of the following:

    EHRInstructionInput:
        question: the question to answer or instruction to follow
        ehr: the XML-tagged EHR to use as context to answer the question
        prompt_template: a string template for how to combine the question + EHR

    Reference output:
        text: the "golden" clinician response to the question

    This Adapter combines the above into RequestStates, with logic to truncate the EHR
    specifically so that it fits in the context window with enough room for the
    instruction/question and the specified number of generated tokens.
    """

    def adapt(self, instances: List[Instance], parallelism: int) -> List[RequestState]:
        """
        Main adaptation method, which takes all instances and turns them into `RequestState` objects.
        """
        # Sanity check: for now, we assume that there are no training instances at all.
        if any(instance.split == TRAIN_SPLIT for instance in instances):
            raise RuntimeError(f"Got train instances for {self.__class__.__name__} - expected only eval instances.")

        # Use the superclass implementation from here.
        return super().adapt(instances, parallelism)

    def construct_prompt(
        self,
        train_instances: List[Instance],  # unused
        eval_instance: Instance,
        include_output: bool,  # unused
        reference_index: Optional[int],  # unused
    ) -> Prompt:
        """
        Construct the prompt for a given eval instance.

        Parameters
        ----------
        eval_instance: Instance
            the instance we wish to use to construct the prompt
        """
        # Start by simply getting the inputs.
        question = eval_instance.input.text
        assert eval_instance.extra_data is not None
        ehr_text: str = eval_instance.extra_data["ehr"]
        prompt_template: str = eval_instance.extra_data["prompt_template"]
        full_prompt_text = prompt_template.format(question=question, ehr=ehr_text)

        # Insert the question and see how many tokens we have so far.
        prompt_with_instr_no_ehr_placeholder = prompt_template.format(question=question, ehr="")
        num_tokens_no_ehr = self.window_service.get_num_tokens(prompt_with_instr_no_ehr_placeholder)

        # Number of tokens the EHR part is allowed to occupy.
        target_ehr_num_tokens = (
            self.window_service.max_request_length - self.adapter_spec.max_tokens - num_tokens_no_ehr
        )

        # Round-trip tokenization to get the correct token length we need.
        # NOTE: we truncate from the left side so that the most recent pieces of the EHR are included
        # in the context, as opposed to the canonical way of truncating from the right.
        # This is done to match the MedAlign method.
        full_ehr_tokens: EncodeResult = self.window_service.encode(ehr_text, max_length=None, truncation=False)
        truncated_ehr_tokens: List[TokenizationToken] = full_ehr_tokens.tokens[-target_ehr_num_tokens:]
        ehr_truncated: str = self.window_service.decode(truncated_ehr_tokens)

        # Create the truncated prompt. Decoding and re-tokenizing can change the token count,
        # so keep dropping tokens from the left until the full request fits in the context window.
        truncated_prompt_text = prompt_template.format(question=question, ehr=ehr_truncated)
        num_truncations = 1
        while (
            num_extra_tokens := self.adapter_spec.max_tokens
            + self.window_service.get_num_tokens(truncated_prompt_text)
            - self.window_service.max_request_length
        ) > 0:
            truncated_ehr_tokens = truncated_ehr_tokens[num_extra_tokens:]
            ehr_truncated = self.window_service.decode(truncated_ehr_tokens)
            truncated_prompt_text = prompt_template.format(question=question, ehr=ehr_truncated)
            num_truncations += 1

        # Construct the Prompt object, carrying both the full text and the truncated text.
        prompt = Prompt(
            global_prefix=self.adapter_spec.global_prefix,
            global_suffix=self.adapter_spec.global_suffix,
            instance_prefix=self.adapter_spec.instance_prefix,
            substitutions=self.adapter_spec.substitutions,
            instructions_block=self.adapter_spec.instructions,
            train_instance_blocks=[],
            eval_instance_block=full_prompt_text,
            truncated_text=truncated_prompt_text,
        )

        return prompt
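
The decode/re-encode loop above is needed because token counts are not stable across a decode and re-encode round trip, so a single slice of the token list is not guaranteed to fit. Below is a minimal, self-contained sketch of the same left-truncation idea, using a whitespace "tokenizer" as a stand-in for the real window service; every name in it is hypothetical and not part of HELM.

from typing import List


def tokenize(text: str) -> List[str]:
    # Stand-in for window_service.encode: whitespace tokens only.
    return text.split()


def detokenize(tokens: List[str]) -> str:
    # Stand-in for window_service.decode.
    return " ".join(tokens)


def left_truncate_ehr(template: str, question: str, ehr: str, max_request_length: int, max_tokens: int) -> str:
    """Drop tokens from the *start* of the EHR until the prompt plus the generation budget fits."""
    num_tokens_no_ehr = len(tokenize(template.format(question=question, ehr="")))
    budget = max_request_length - max_tokens - num_tokens_no_ehr
    ehr_tokens = tokenize(ehr)[-budget:] if budget > 0 else []
    prompt = template.format(question=question, ehr=detokenize(ehr_tokens))
    # Re-check after detokenization, since round-tripping can change the count.
    # (The real adapter drops the whole surplus in one step instead of one token at a time.)
    while len(tokenize(prompt)) + max_tokens > max_request_length:
        ehr_tokens = ehr_tokens[1:]
        prompt = template.format(question=question, ehr=detokenize(ehr_tokens))
    return prompt


# Keeps the *last* tokens of the EHR, mirroring the MedAlign-style truncation above.
print(left_truncate_ehr("{question} {ehr}", "Summarize.", "a b c d e f g h", max_request_length=6, max_tokens=2))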
@@ -0,0 +1,125 @@
import logging
from typing import List

import comet
from torch import nn

from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.adaptation.scenario_state import ScenarioState
from helm.benchmark.metrics.metric import Metric, MetricResult
from helm.benchmark.metrics.metric_name import MetricName
from helm.benchmark.metrics.metric_service import MetricService
from helm.benchmark.metrics.statistic import Stat
from helm.common.hierarchical_logger import hlog
from helm.common.request import RequestResult


class CometMetric(Metric):
    """COMET machine translation metric using a regression model.

    The model takes a triplet of source sentence, translation, and reference,
    and computes a score in the range [0, 1] reflecting the quality of the
    predicted translation.

    Paper:
    @inproceedings{rei-etal-2022-comet,
        title = "{COMET}-22: Unbabel-{IST} 2022 Submission for the Metrics Shared Task",
        author = "Rei, Ricardo and
            C. de Souza, Jos{\'e} G. and
            Alves, Duarte and
            Zerva, Chrysoula and
            Farinha, Ana C and
            Glushkova, Taisiya and
            Lavie, Alon and
            Coheur, Luisa and
            Martins, Andr{\'e} F. T.",
        editor = {Koehn, Philipp and
            Barrault, Lo{\"\i}c and
            Bojar, Ond{\v{r}}ej and
            Bougares, Fethi and
            Chatterjee, Rajen and
            Costa-juss{\`a}, Marta R. and
            Federmann, Christian and
            Fishel, Mark and
            Fraser, Alexander and
            Freitag, Markus and
            Graham, Yvette and
            Grundkiewicz, Roman and
            Guzman, Paco and
            Haddow, Barry and
            Huck, Matthias and
            Jimeno Yepes, Antonio and
            Kocmi, Tom and
            Martins, Andr{\'e} and
            Morishita, Makoto and
            Monz, Christof and
            Nagata, Masaaki and
            Nakazawa, Toshiaki and
            Negri, Matteo and
            N{\'e}v{\'e}ol, Aur{\'e}lie and
            Neves, Mariana and
            Popel, Martin and
            Turchi, Marco and
            Zampieri, Marcos},
        booktitle = "Proceedings of the Seventh Conference on Machine Translation (WMT)",
        month = dec,
        year = "2022",
        publisher = "Association for Computational Linguistics",
        url = "https://aclanthology.org/2022.wmt-1.52",
    }
    """

    METRIC_NAME = "comet"

    def __init__(self, task: str, model_name: str = "Unbabel/wmt22-comet-da", device: str = "cpu"):
        self.model_name = model_name
        self.comet_scorer: nn.Module = self._load_model(model_name)
        self.num_gpus = 0 if device == "cpu" else 1

        # Suppress warnings from PyTorch Lightning, which spams the terminal.
        logging.getLogger("lightning.pytorch.utilities.rank_zero").setLevel(logging.WARNING)
        logging.getLogger("lightning.pytorch.accelerators.cuda").setLevel(logging.WARNING)

    @staticmethod
    def _load_model(model_name: str) -> nn.Module:
        """Load the COMET model from its checkpoint.

        Returns:
            The loaded model.
        """
        return comet.load_from_checkpoint(comet.download_model(model_name))

    def evaluate(
        self, scenario_state: ScenarioState, metric_service: MetricService, eval_cache_path: str, parallelism: int
    ) -> MetricResult:
        hlog(
            f"Setting parallelism from {parallelism} to 1, since "
            f"evaluating {self.__class__.__name__} with parallelism > 1 segfaults."
        )
        return super().evaluate(scenario_state, metric_service, eval_cache_path, parallelism=1)

    def evaluate_generation(
        self,
        adapter_spec: AdapterSpec,
        request_state: RequestState,
        metric_service: MetricService,
        eval_cache_path: str,
    ) -> List[Stat]:
        """Compute the COMET score for this instance."""
        assert len(request_state.instance.references) == 1
        ref = request_state.instance.references[0].output.text
        src = request_state.instance.input.text

        result = request_state.result
        if not isinstance(result, RequestResult):
            raise TypeError(f"Expected a valid result, but got {result}!")
        mt = result.completions[0].text.strip()

        # COMET requires this exact format: a list of dicts with "src", "mt", and "ref" keys.
        data = [dict(ref=ref, src=src, mt=mt)]
        output = self.comet_scorer.predict(data, gpus=self.num_gpus, progress_bar=False)  # type: ignore
        comet_score = output[0][0]  # extract the actual score

        metric_result = [Stat(MetricName(self.METRIC_NAME)).add(comet_score)]

        return metric_result
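
For reference, the unbabel-comet package used here can also be exercised directly outside HELM. A minimal sketch, assuming the same default checkpoint as above; the shape of predict's return value differs across comet versions, so the score extraction below is an assumption rather than guaranteed API, and the sample sentences are made up.

import comet

# Download (cached after the first call) and load the regression model.
model = comet.load_from_checkpoint(comet.download_model("Unbabel/wmt22-comet-da"))

# Each sample is a dict with source ("src"), machine translation ("mt"),
# and reference ("ref") keys, the same format that evaluate_generation builds above.
data = [
    {
        "src": "Der Patient wurde gestern entlassen.",
        "mt": "The patient was discharged yesterday.",
        "ref": "The patient was released from the hospital yesterday.",
    }
]

output = model.predict(data, gpus=0, progress_bar=False)
# NOTE: depending on the comet version, per-segment scores live at output[0]
# (tuple-style return) or output.scores (newer Prediction object).
print(output[0])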