Skip to content

Commit

Permalink
Add MedAlign scenario (#3038)
Browse files Browse the repository at this point in the history
Co-authored-by: Yifan Mai <[email protected]>
  • Loading branch information
aunell and yifanmai authored Dec 20, 2024
1 parent 66eba9f commit 5e9bf74
Show file tree
Hide file tree
Showing 7 changed files with 542 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/helm/benchmark/adaptation/adapter_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL: str = "multiple_choice_separate_original"
ADAPT_MULTIPLE_CHOICE_SEPARATE_CALIBRATED: str = "multiple_choice_separate_calibrated"
ADAPT_RANKING_BINARY: str = "ranking_binary"

ADAPT_EHR_INSTRUCTION: str = "ehr_instruction"
ADAPT_MULTIPLE_CHOICE_SEPARATE_METHODS: List[str] = [
ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL,
ADAPT_MULTIPLE_CHOICE_SEPARATE_CALIBRATED,
Expand Down
6 changes: 5 additions & 1 deletion src/helm/benchmark/adaptation/adapters/adapter_factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from helm.benchmark.adaptation.adapter_spec import (
ADAPT_EHR_INSTRUCTION,
ADAPT_GENERATION,
ADAPT_CHAT,
ADAPT_GENERATION_MULTIMODAL,
Expand Down Expand Up @@ -27,6 +28,7 @@
)
from helm.benchmark.adaptation.adapters.multiple_choice_separate_adapter import MultipleChoiceSeparateAdapter
from helm.benchmark.window_services.tokenizer_service import TokenizerService
from helm.benchmark.adaptation.adapters.ehr_instruction_adapter import EHRInstructionAdapter


class AdapterFactory:
Expand All @@ -38,7 +40,9 @@ def get_adapter(adapter_spec: AdapterSpec, tokenizer_service: TokenizerService)
method: str = adapter_spec.method
adapter: Adapter

if method == ADAPT_GENERATION:
if method == ADAPT_EHR_INSTRUCTION:
adapter = EHRInstructionAdapter(adapter_spec, tokenizer_service)
elif method == ADAPT_GENERATION:
adapter = GenerationAdapter(adapter_spec, tokenizer_service)
elif method == ADAPT_CHAT:
adapter = ChatAdapter(adapter_spec, tokenizer_service)
Expand Down
108 changes: 108 additions & 0 deletions src/helm/benchmark/adaptation/adapters/ehr_instruction_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from typing import List, Optional

from helm.benchmark.adaptation.adapters.generation_adapter import GenerationAdapter
from helm.benchmark.adaptation.prompt import Prompt
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.scenarios.scenario import TRAIN_SPLIT, Instance
from helm.benchmark.window_services.window_service import EncodeResult
from helm.common.tokenization_request import TokenizationToken


# in the prompt templates for EHR instructions, this is the placeholder for the EHR part
# which we use to compute accurate tokenized sequence lengths
PROMPT_TEMPLATE_EHR_PLACEHOLDER = "{ehr}"


class EHRInstructionAdapter(GenerationAdapter):
"""
Each instance consists of the following:
EHRInstructionInput:
question: the question to answer or instruction to follow
ehr: the XML-tagged EHR to use as context to answer the question
prompt_template: a string template for how to combine the question + ehr
Reference output:
text: the 'golden' clinician response to the question
This Adapter combines the above into RequestStates with logic to truncate the EHR specifically
to fit in the context window with enough room for the instruction/question and the specified
amount of generated tokens.
"""

def adapt(self, instances: List[Instance], parallelism: int) -> List[RequestState]:
"""
Main adaptation method which takes all instances and turns them into `RequestState` objects.
"""
# sanity check, since for now we assume that there are no training instances at all
if any(instance.split == TRAIN_SPLIT for instance in instances):
raise RuntimeError(f"Got train instances for {self.__class__.__name__} - expected only eval instances.")

# use superclass implementation here
return super().adapt(instances, parallelism)

def construct_prompt(
self,
train_instances: List[Instance], # unused
eval_instance: Instance,
include_output: bool, # unused
reference_index: Optional[int], # unused
) -> Prompt:
"""
Uses the instance to construct a prompt for a given eval instance.
Parameters
----------
eval_instance: Instance
the instance we wish to use to construct the prompt
"""
# start by simply getting the inputs
question = eval_instance.input.text
assert eval_instance.extra_data is not None
ehr_text: str = eval_instance.extra_data["ehr"]
prompt_template: str = eval_instance.extra_data["prompt_template"]
full_prompt_text = prompt_template.format(question=question, ehr=ehr_text)

# insert the question and see how many tokens we have so far
prompt_with_instr_no_ehr_placeholder = prompt_template.format(question=question, ehr="")
num_tokens_no_ehr = self.window_service.get_num_tokens(prompt_with_instr_no_ehr_placeholder)

# number of tokens we can allow the EHR part to be
target_ehr_num_tokens = (
self.window_service.max_request_length - self.adapter_spec.max_tokens - num_tokens_no_ehr
)

# round-trip tokenization to get the correct token length we need
# NOTE: we truncate from the left side so that the most recent pieces of the EHR are included in the context
# as opposed to the canonical way of truncating from the right. This is done to match the MedAlign method.
full_ehr_tokens: EncodeResult = self.window_service.encode(ehr_text, max_length=None, truncation=False)
truncated_ehr_tokens: List[TokenizationToken] = full_ehr_tokens.tokens[-target_ehr_num_tokens:]
ehr_truncated: str
ehr_truncated = self.window_service.decode(truncated_ehr_tokens)

# create the truncated prompt
truncated_prompt_text = prompt_template.format(question=question, ehr=ehr_truncated)
num_truncations = 1
while (
num_extra_tokens := self.adapter_spec.max_tokens
+ self.window_service.get_num_tokens(truncated_prompt_text)
- self.window_service.max_request_length
) > 0:
truncated_ehr_tokens = truncated_ehr_tokens[num_extra_tokens:]
ehr_truncated = self.window_service.decode(truncated_ehr_tokens)
truncated_prompt_text = prompt_template.format(question=question, ehr=ehr_truncated)
num_truncations += 1

# naively construct the full non-truncated prompt
prompt = Prompt(
global_prefix=self.adapter_spec.global_prefix,
global_suffix=self.adapter_spec.global_suffix,
instance_prefix=self.adapter_spec.instance_prefix,
substitutions=self.adapter_spec.substitutions,
instructions_block=self.adapter_spec.instructions,
train_instance_blocks=[],
eval_instance_block=full_prompt_text,
truncated_text=truncated_prompt_text,
)

return prompt
125 changes: 125 additions & 0 deletions src/helm/benchmark/metrics/comet_metric.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import logging
from typing import List

import comet
from torch import nn

from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.adaptation.scenario_state import ScenarioState
from helm.benchmark.metrics.metric import Metric, MetricResult
from helm.benchmark.metrics.metric_name import MetricName
from helm.benchmark.metrics.metric_service import MetricService
from helm.benchmark.metrics.statistic import Stat
from helm.common.hierarchical_logger import hlog
from helm.common.request import RequestResult


class CometMetric(Metric):
"""COMET machine translation metric using a regression model.
The model takes a triplet of source sentence, translation, and reference
and computes a score in the range [0, 1] reflecting the quality of the predicted
translation.
Paper:
@inproceedings{rei-etal-2022-comet,
title = "{COMET}-22: Unbabel-{IST} 2022 Submission for the Metrics Shared Task",
author = "Rei, Ricardo and
C. de Souza, Jos{\'e} G. and
Alves, Duarte and
Zerva, Chrysoula and
Farinha, Ana C and
Glushkova, Taisiya and
Lavie, Alon and
Coheur, Luisa and
Martins, Andr{\'e} F. T.",
editor = {Koehn, Philipp and
Barrault, Lo{\"\i}c and
Bojar, Ond{\v{r}}ej and
Bougares, Fethi and
Chatterjee, Rajen and
Costa-juss{\`a}, Marta R. and
Federmann, Christian and
Fishel, Mark and
Fraser, Alexander and
Freitag, Markus and
Graham, Yvette and
Grundkiewicz, Roman and
Guzman, Paco and
Haddow, Barry and
Huck, Matthias and
Jimeno Yepes, Antonio and
Kocmi, Tom and
Martins, Andr{\'e} and
Morishita, Makoto and
Monz, Christof and
Nagata, Masaaki and
Nakazawa, Toshiaki and
Negri, Matteo and
N{\'e}v{\'e}ol, Aur{\'e}lie and
Neves, Mariana and
Popel, Martin and
Turchi, Marco and
Zampieri, Marcos},
booktitle = "Proceedings of the Seventh Conference on Machine Translation (WMT)",
month = dec,
year = "2022",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/2022.wmt-1.52",
}
"""

METRIC_NAME = "comet"

def __init__(self, task: str, model_name: str = "Unbabel/wmt22-comet-da", device: str = "cpu"):
self.model_name = model_name
self.comet_scorer: nn.Module = self._load_model(model_name)
self.num_gpus = 0 if device == "cpu" else 1

# suppress warnings from PyTorch Lightning which spams terminal
logging.getLogger("lightning.pytorch.utilities.rank_zero").setLevel(logging.WARNING)
logging.getLogger("lightning.pytorch.accelerators.cuda").setLevel(logging.WARNING)

@staticmethod
def _load_model(model_name: str) -> nn.Module:
"""Load Comet model from the checkpoint.
Returns:
The loaded model.
"""
return comet.load_from_checkpoint(comet.download_model(model_name))

def evaluate(
self, scenario_state: ScenarioState, metric_service: MetricService, eval_cache_path: str, parallelism: int
) -> MetricResult:
hlog(
f"Setting parallelism from {parallelism} to 1, since "
f"evaluating {self.__class__.__name__} with parallelism > 1 seg faults."
)
return super().evaluate(scenario_state, metric_service, eval_cache_path, parallelism=1)

def evaluate_generation(
self,
adapter_spec: AdapterSpec,
request_state: RequestState,
metric_service: MetricService,
eval_cache_path: str,
) -> List[Stat]:
"""Compute the COMET score for this instance"""
assert len(request_state.instance.references) == 1
ref = request_state.instance.references[0].output.text
src = request_state.instance.input.text

result = request_state.result
if not isinstance(result, RequestResult):
raise TypeError(f"Expected a valid result, but got {result}!")
mt = result.completions[0].text.strip()

# comet requires this exac5 format
data = [dict(ref=ref, src=src, mt=mt)]
output = self.comet_scorer.predict(data, gpus=self.num_gpus, progress_bar=False) # type: ignore
comet_score = output[0][0] # extract the actual score

metric_result = [Stat(MetricName(self.METRIC_NAME)).add(comet_score)]

return metric_result
36 changes: 36 additions & 0 deletions src/helm/benchmark/run_specs/classic_run_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL,
ADAPT_RANKING_BINARY,
AdapterSpec,
ADAPT_EHR_INSTRUCTION,
)
from helm.benchmark.adaptation.adapters.binary_ranking_adapter import BinaryRankingAdapter
from helm.benchmark.adaptation.common_adapter_specs import (
Expand Down Expand Up @@ -1218,6 +1219,41 @@ def get_medication_qa_spec() -> RunSpec:
)


def get_medalign_adapter_spec() -> AdapterSpec:
return AdapterSpec(
method=ADAPT_EHR_INSTRUCTION,
max_train_instances=0,
max_eval_instances=None,
num_outputs=1,
max_tokens=256, # MedAlign default number of generation tokens
)


def get_comet_metric_specs(args: Dict[str, Any]) -> List[MetricSpec]:
return [MetricSpec(class_name="helm.benchmark.metrics.comet_metric.CometMetric", args=args)]


@run_spec_function("medalign")
def get_medalign_spec(prompt_template: str = "generic.txt") -> RunSpec:
from helm.common.gpu_utils import get_torch_device_name

scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.medalign_scenario.MedAlignScenario",
args={"prompt_template": prompt_template},
)

adapter_spec = get_medalign_adapter_spec()

metric_args = {"task": "medalign", "device": get_torch_device_name()}
return RunSpec(
name="medalign",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=get_summarization_metric_specs(metric_args) + get_comet_metric_specs(metric_args),
groups=["medalign", "med_helm"],
)


@run_spec_function("lextreme")
def get_lextreme_spec(subset: str) -> RunSpec:
from helm.benchmark.scenarios.lextreme_scenario import (
Expand Down
Loading

0 comments on commit 5e9bf74

Please sign in to comment.