diff --git a/src/helm/benchmark/adaptation/adapter_spec.py b/src/helm/benchmark/adaptation/adapter_spec.py
index ccb85ae62e..23b9deda11 100644
--- a/src/helm/benchmark/adaptation/adapter_spec.py
+++ b/src/helm/benchmark/adaptation/adapter_spec.py
@@ -13,7 +13,7 @@
 ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL: str = "multiple_choice_separate_original"
 ADAPT_MULTIPLE_CHOICE_SEPARATE_CALIBRATED: str = "multiple_choice_separate_calibrated"
 ADAPT_RANKING_BINARY: str = "ranking_binary"
-
+ADAPT_EHR_INSTRUCTION: str = "ehr_instruction"
 ADAPT_MULTIPLE_CHOICE_SEPARATE_METHODS: List[str] = [
     ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL,
     ADAPT_MULTIPLE_CHOICE_SEPARATE_CALIBRATED,
diff --git a/src/helm/benchmark/adaptation/adapters/adapter_factory.py b/src/helm/benchmark/adaptation/adapters/adapter_factory.py
index f5f0df89f0..b8580e2f8c 100644
--- a/src/helm/benchmark/adaptation/adapters/adapter_factory.py
+++ b/src/helm/benchmark/adaptation/adapters/adapter_factory.py
@@ -1,4 +1,5 @@
 from helm.benchmark.adaptation.adapter_spec import (
+    ADAPT_EHR_INSTRUCTION,
     ADAPT_GENERATION,
     ADAPT_CHAT,
     ADAPT_GENERATION_MULTIMODAL,
@@ -27,6 +28,7 @@
 )
 from helm.benchmark.adaptation.adapters.multiple_choice_separate_adapter import MultipleChoiceSeparateAdapter
 from helm.benchmark.window_services.tokenizer_service import TokenizerService
+from helm.benchmark.adaptation.adapters.ehr_instruction_adapter import EHRInstructionAdapter


 class AdapterFactory:
@@ -38,7 +40,9 @@ def get_adapter(adapter_spec: AdapterSpec, tokenizer_service: TokenizerService)
         method: str = adapter_spec.method
         adapter: Adapter

-        if method == ADAPT_GENERATION:
+        if method == ADAPT_EHR_INSTRUCTION:
+            adapter = EHRInstructionAdapter(adapter_spec, tokenizer_service)
+        elif method == ADAPT_GENERATION:
             adapter = GenerationAdapter(adapter_spec, tokenizer_service)
         elif method == ADAPT_CHAT:
             adapter = ChatAdapter(adapter_spec, tokenizer_service)
diff --git a/src/helm/benchmark/adaptation/adapters/ehr_instruction_adapter.py b/src/helm/benchmark/adaptation/adapters/ehr_instruction_adapter.py
new file mode 100644
index 0000000000..cbb66f1a5b
--- /dev/null
+++ b/src/helm/benchmark/adaptation/adapters/ehr_instruction_adapter.py
@@ -0,0 +1,108 @@
+from typing import List, Optional
+
+from helm.benchmark.adaptation.adapters.generation_adapter import GenerationAdapter
+from helm.benchmark.adaptation.prompt import Prompt
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.scenarios.scenario import TRAIN_SPLIT, Instance
+from helm.benchmark.window_services.window_service import EncodeResult
+from helm.common.tokenization_request import TokenizationToken
+
+
+# in the prompt templates for EHR instructions, this is the placeholder for the EHR part
+# which we use to compute accurate tokenized sequence lengths
+PROMPT_TEMPLATE_EHR_PLACEHOLDER = "{ehr}"
+
+
+class EHRInstructionAdapter(GenerationAdapter):
+    """
+    Each instance consists of the following:
+
+    EHRInstructionInput:
+        question: the question to answer or instruction to follow
+        ehr: the XML-tagged EHR to use as context to answer the question
+        prompt_template: a string template for how to combine the question + ehr
+
+    Reference output:
+        text: the 'golden' clinician response to the question
+
+    This Adapter combines the above into RequestStates with logic to truncate the EHR specifically
+    to fit in the context window with enough room for the instruction/question and the specified
+    amount of generated tokens.
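+
+    For example (illustrative only; actual templates are read from the MedAlign prompt_templates directory),
+    a prompt template might look like: "{ehr}\n\nInstruction: {question}\n\nAnswer:"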
+ """ + + def adapt(self, instances: List[Instance], parallelism: int) -> List[RequestState]: + """ + Main adaptation method which takes all instances and turns them into `RequestState` objects. + """ + # sanity check, since for now we assume that there are no training instances at all + if any(instance.split == TRAIN_SPLIT for instance in instances): + raise RuntimeError(f"Got train instances for {self.__class__.__name__} - expected only eval instances.") + + # use superclass implementation here + return super().adapt(instances, parallelism) + + def construct_prompt( + self, + train_instances: List[Instance], # unused + eval_instance: Instance, + include_output: bool, # unused + reference_index: Optional[int], # unused + ) -> Prompt: + """ + Uses the instance to construct a prompt for a given eval instance. + + Parameters + ---------- + eval_instance: Instance + the instance we wish to use to construct the prompt + """ + # start by simply getting the inputs + question = eval_instance.input.text + assert eval_instance.extra_data is not None + ehr_text: str = eval_instance.extra_data["ehr"] + prompt_template: str = eval_instance.extra_data["prompt_template"] + full_prompt_text = prompt_template.format(question=question, ehr=ehr_text) + + # insert the question and see how many tokens we have so far + prompt_with_instr_no_ehr_placeholder = prompt_template.format(question=question, ehr="") + num_tokens_no_ehr = self.window_service.get_num_tokens(prompt_with_instr_no_ehr_placeholder) + + # number of tokens we can allow the EHR part to be + target_ehr_num_tokens = ( + self.window_service.max_request_length - self.adapter_spec.max_tokens - num_tokens_no_ehr + ) + + # round-trip tokenization to get the correct token length we need + # NOTE: we truncate from the left side so that the most recent pieces of the EHR are included in the context + # as opposed to the canonical way of truncating from the right. This is done to match the MedAlign method. 
+        full_ehr_tokens: EncodeResult = self.window_service.encode(ehr_text, max_length=None, truncation=False)
+        truncated_ehr_tokens: List[TokenizationToken] = full_ehr_tokens.tokens[-target_ehr_num_tokens:]
+        ehr_truncated: str
+        ehr_truncated = self.window_service.decode(truncated_ehr_tokens)
+
+        # create the truncated prompt
+        truncated_prompt_text = prompt_template.format(question=question, ehr=ehr_truncated)
+        num_truncations = 1
+        while (
+            num_extra_tokens := self.adapter_spec.max_tokens
+            + self.window_service.get_num_tokens(truncated_prompt_text)
+            - self.window_service.max_request_length
+        ) > 0:
+            truncated_ehr_tokens = truncated_ehr_tokens[num_extra_tokens:]
+            ehr_truncated = self.window_service.decode(truncated_ehr_tokens)
+            truncated_prompt_text = prompt_template.format(question=question, ehr=ehr_truncated)
+            num_truncations += 1
+
+        # naively construct the full non-truncated prompt
+        prompt = Prompt(
+            global_prefix=self.adapter_spec.global_prefix,
+            global_suffix=self.adapter_spec.global_suffix,
+            instance_prefix=self.adapter_spec.instance_prefix,
+            substitutions=self.adapter_spec.substitutions,
+            instructions_block=self.adapter_spec.instructions,
+            train_instance_blocks=[],
+            eval_instance_block=full_prompt_text,
+            truncated_text=truncated_prompt_text,
+        )
+
+        return prompt
diff --git a/src/helm/benchmark/metrics/comet_metric.py b/src/helm/benchmark/metrics/comet_metric.py
new file mode 100644
index 0000000000..048eebd4c4
--- /dev/null
+++ b/src/helm/benchmark/metrics/comet_metric.py
@@ -0,0 +1,125 @@
+import logging
+from typing import List
+
+import comet
+from torch import nn
+
+from helm.benchmark.adaptation.adapter_spec import AdapterSpec
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.adaptation.scenario_state import ScenarioState
+from helm.benchmark.metrics.metric import Metric, MetricResult
+from helm.benchmark.metrics.metric_name import MetricName
+from helm.benchmark.metrics.metric_service import MetricService
+from helm.benchmark.metrics.statistic import Stat
+from helm.common.hierarchical_logger import hlog
+from helm.common.request import RequestResult
+
+
+class CometMetric(Metric):
+    """COMET machine translation metric using a regression model.
+    The model takes a triplet of source sentence, translation, and reference
+    and computes a score in the range [0, 1] reflecting the quality of the predicted
+    translation.
+
+    Paper:
+    @inproceedings{rei-etal-2022-comet,
+        title = "{COMET}-22: Unbabel-{IST} 2022 Submission for the Metrics Shared Task",
+        author = "Rei, Ricardo and
+            C. de Souza, Jos{\'e} G. and
+            Alves, Duarte and
+            Zerva, Chrysoula and
+            Farinha, Ana C and
+            Glushkova, Taisiya and
+            Lavie, Alon and
+            Coheur, Luisa and
+            Martins, Andr{\'e} F. T.",
+        editor = {Koehn, Philipp and
+            Barrault, Lo{\"\i}c and
+            Bojar, Ond{\v{r}}ej and
+            Bougares, Fethi and
+            Chatterjee, Rajen and
+            Costa-juss{\`a}, Marta R. and
+            Federmann, Christian and
+            Fishel, Mark and
+            Fraser, Alexander and
+            Freitag, Markus and
+            Graham, Yvette and
+            Grundkiewicz, Roman and
+            Guzman, Paco and
+            Haddow, Barry and
+            Huck, Matthias and
+            Jimeno Yepes, Antonio and
+            Kocmi, Tom and
+            Martins, Andr{\'e} and
+            Morishita, Makoto and
+            Monz, Christof and
+            Nagata, Masaaki and
+            Nakazawa, Toshiaki and
+            Negri, Matteo and
+            N{\'e}v{\'e}ol, Aur{\'e}lie and
+            Neves, Mariana and
+            Popel, Martin and
+            Turchi, Marco and
+            Zampieri, Marcos},
+        booktitle = "Proceedings of the Seventh Conference on Machine Translation (WMT)",
+        month = dec,
+        year = "2022",
+        publisher = "Association for Computational Linguistics",
+        url = "https://aclanthology.org/2022.wmt-1.52",
+    }
+    """
+
+    METRIC_NAME = "comet"
+
+    def __init__(self, task: str, model_name: str = "Unbabel/wmt22-comet-da", device: str = "cpu"):
+        self.model_name = model_name
+        self.comet_scorer: nn.Module = self._load_model(model_name)
+        self.num_gpus = 0 if device == "cpu" else 1
+
+        # suppress warnings from PyTorch Lightning which spams terminal
+        logging.getLogger("lightning.pytorch.utilities.rank_zero").setLevel(logging.WARNING)
+        logging.getLogger("lightning.pytorch.accelerators.cuda").setLevel(logging.WARNING)
+
+    @staticmethod
+    def _load_model(model_name: str) -> nn.Module:
+        """Load Comet model from the checkpoint.
+
+        Returns:
+            The loaded model.
+        """
+        return comet.load_from_checkpoint(comet.download_model(model_name))
+
+    def evaluate(
+        self, scenario_state: ScenarioState, metric_service: MetricService, eval_cache_path: str, parallelism: int
+    ) -> MetricResult:
+        hlog(
+            f"Setting parallelism from {parallelism} to 1, since "
+            f"evaluating {self.__class__.__name__} with parallelism > 1 seg faults."
+        )
+        return super().evaluate(scenario_state, metric_service, eval_cache_path, parallelism=1)
+
+    def evaluate_generation(
+        self,
+        adapter_spec: AdapterSpec,
+        request_state: RequestState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+    ) -> List[Stat]:
+        """Compute the COMET score for this instance"""
+        assert len(request_state.instance.references) == 1
+        ref = request_state.instance.references[0].output.text
+        src = request_state.instance.input.text
+
+        result = request_state.result
+        if not isinstance(result, RequestResult):
+            raise TypeError(f"Expected a valid result, but got {result}!")
+        mt = result.completions[0].text.strip()
+
+        # comet requires this exact format
+        data = [dict(ref=ref, src=src, mt=mt)]
+        output = self.comet_scorer.predict(data, gpus=self.num_gpus, progress_bar=False)  # type: ignore
+        comet_score = output[0][0]  # extract the actual score
+
+        metric_result = [Stat(MetricName(self.METRIC_NAME)).add(comet_score)]
+
+        return metric_result
diff --git a/src/helm/benchmark/run_specs/classic_run_specs.py b/src/helm/benchmark/run_specs/classic_run_specs.py
index b9e80e90b9..630529a81f 100644
--- a/src/helm/benchmark/run_specs/classic_run_specs.py
+++ b/src/helm/benchmark/run_specs/classic_run_specs.py
@@ -14,6 +14,7 @@
     ADAPT_MULTIPLE_CHOICE_SEPARATE_ORIGINAL,
     ADAPT_RANKING_BINARY,
     AdapterSpec,
+    ADAPT_EHR_INSTRUCTION,
 )
 from helm.benchmark.adaptation.adapters.binary_ranking_adapter import BinaryRankingAdapter
 from helm.benchmark.adaptation.common_adapter_specs import (
@@ -1218,6 +1219,41 @@ def get_medication_qa_spec() -> RunSpec:
     )

+
+def get_medalign_adapter_spec() -> AdapterSpec:
+    return AdapterSpec(
+        method=ADAPT_EHR_INSTRUCTION,
+        max_train_instances=0,
+        max_eval_instances=None,
+        num_outputs=1,
+        max_tokens=256,  # MedAlign default number of generation tokens
+    )
+
+
+def get_comet_metric_specs(args: Dict[str, Any]) -> List[MetricSpec]:
+    return [MetricSpec(class_name="helm.benchmark.metrics.comet_metric.CometMetric", args=args)]
+
+
+@run_spec_function("medalign")
+def get_medalign_spec(prompt_template: str = "generic.txt") -> RunSpec:
+    from helm.common.gpu_utils import get_torch_device_name
+
+    scenario_spec = ScenarioSpec(
+        class_name="helm.benchmark.scenarios.medalign_scenario.MedAlignScenario",
+        args={"prompt_template": prompt_template},
+    )
+
+    adapter_spec = get_medalign_adapter_spec()
+
+    metric_args = {"task": "medalign", "device": get_torch_device_name()}
+    return RunSpec(
+        name="medalign",
+        scenario_spec=scenario_spec,
+        adapter_spec=adapter_spec,
+        metric_specs=get_summarization_metric_specs(metric_args) + get_comet_metric_specs(metric_args),
+        groups=["medalign", "med_helm"],
+    )
+
+
 @run_spec_function("lextreme")
 def get_lextreme_spec(subset: str) -> RunSpec:
     from helm.benchmark.scenarios.lextreme_scenario import (
diff --git a/src/helm/benchmark/scenarios/medalign_scenario.py b/src/helm/benchmark/scenarios/medalign_scenario.py
new file mode 100644
index 0000000000..c84846ceff
--- /dev/null
+++ b/src/helm/benchmark/scenarios/medalign_scenario.py
@@ -0,0 +1,242 @@
+import os
+from pathlib import Path
+import re
+from typing import Dict, List, Optional, Tuple, Union
+
+import pandas as pd
+
+from helm.benchmark.scenarios.scenario import (
+    CORRECT_TAG,
+    TEST_SPLIT,
+    Input,
+    Instance,
+    Output,
+    Reference,
+    Scenario,
+)
+
+
+# /share/pi/nigam/data/MedAlign on Carina
+EHR_BASE_PATH = "/local-scratch/shahlab/aunell/helm/medalign_data/ehr_unzipped/full_ehrs"
+INSTRUCTIONS_PATH = "/local-scratch/shahlab/aunell/helm/medalign_data/ehr-relevance-labels.csv"
+CLINICIAN_RESPONSES_PATH = "/local-scratch/shahlab/aunell/helm/medalign_data/clinician-instruction-responses.csv"
+PROMPT_TEMPLATES_BASE_PATH = "/local-scratch/shahlab/aunell/helm/medalign_data/prompt_templates/"
+
+
+def extract_patient_id_from_fname(fname: str) -> Optional[int]:
+    """
+    Extracts and returns the patient ID from a given filename.
+
+    The function expects filenames in the format 'EHR_<patient_id>.xml',
+    where <patient_id> is a sequence of digits.
+
+    Parameters:
+        fname (str): The filename from which to extract the patient ID.
+
+    Returns:
+        Optional[int]: The extracted patient ID as an integer, or None if
+        the filename doesn't match the expected format.
+    """
+    regex_result = re.search(r"EHR_(\d+)\.xml", fname)
+    if regex_result is None:
+        return None
+    return int(regex_result.group(1))
+
+
+def get_ehrs(path_to_ehrs: str) -> Dict[int, str]:
+    """
+    Builds a map from patient ID to EHR (Electronic Health Record) timeline.
+
+    EHR timelines are in string format and EHR files are read in from the
+    user-specified directory. Each file in the directory should be named
+    'EHR_<patient_id>.xml', where <patient_id> is a sequence of digits.
+
+    See https://stanfordmedicine.box.com/s/r28wfwwude9rpjtu0szhzegmku8qv2pe
+
+    Parameters:
+        path_to_ehrs (str): The path to the directory containing the EHR files.
+
+    Returns:
+        Dict[int, str]: A dictionary mapping patient IDs to EHR timelines.
+
+    Raises:
+        FileNotFoundError: If the specified directory does not exist.
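+
+    Example:
+        A directory entry named "EHR_12345.xml" produces the mapping {12345: "<contents of EHR_12345.xml>"}.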
+ """ + if not os.path.isdir(path_to_ehrs): + raise FileNotFoundError(f"The specified directory {path_to_ehrs} does not exist.") + + ehr_map = {} + for fname in os.listdir(path_to_ehrs): + pt_id = extract_patient_id_from_fname(fname) + if pt_id is None: + print(f"Warning: File '{fname}' does not match the expected format " "and will be skipped.") + continue + + file_path = os.path.join(path_to_ehrs, fname) + with open(file_path, encoding="utf-8", mode="r") as f: + ehr = f.read() + + ehr_map[pt_id] = ehr + return ehr_map + + +def get_instructions(path_to_instructions: str) -> Dict[int, Dict[str, Union[int, str]]]: + """ + Builds map from Instruction ID to instruction details + + The needed information for creating the map is accomplished by reading + a CSV file from the user-specified specified path. + + The CSV file is expected to contain at least the following columns: + - instruction_id: The ID of the instruction. + - question: The text of the instruction. + - person_id: The ID of the associated patient. + - is_selected_ehr: A flag indicating whether the instruction is selected. + + See https://stanfordmedicine.box.com/s/0om9qav2sklb9vaitn0ibye65vgbfx0e + + Parameters: + path_to_instructions (str): Path to CSV file containing instructions. + + Returns: + Dict[int, Dict[str, Any]]: A dictionary mapping instruction IDs to a + dictionary containing instruction text and associated patient ID. + + Raises: + FileNotFoundError: If the specified file does not exist. + ValueError: If the CSV file does not contain the expected columns. + """ + if not os.path.exists(path_to_instructions): + raise FileNotFoundError(f"The specified file {path_to_instructions} does not exist.") + + instructions_df = pd.read_csv(path_to_instructions) + required_columns = { + "instruction_id", + "question", + "person_id", + "is_selected_ehr", + } + if not required_columns.issubset(instructions_df.columns): + raise ValueError(f"The CSV file is missing one or more of the required columns: {required_columns}") + + selected_instructions_df = instructions_df.query("is_selected_ehr == 'yes'") + instructions_map = { + row["instruction_id"]: { + "instruction": row["question"], + "patient_id": row["person_id"], + } + for _, row in selected_instructions_df.iterrows() + } + return instructions_map + + +class MedAlignScenario(Scenario): + """Scenario defining the MedAlign task as defined in the following work by Fleming et al: + @article{fleming2023medalign, + title={MedAlign: A Clinician-Generated Dataset for Instruction Following with Electronic Medical Records}, + author={Scott L. Fleming + and Alejandro Lozano + and William J. Haberkorn + and Jenelle A. Jindal + and Eduardo P. Reis + and Rahul Thapa + and Louis Blankemeier + and Julian Z. Genkins + and Ethan Steinberg + and Ashwin Nayak + and Birju S. Patel + and Chia-Chun Chiang + and Alison Callahan + and Zepeng Huo + and Sergios Gatidis + and Scott J. Adams + and Oluseyi Fayanju + and Shreya J. Shah + and Thomas Savage + and Ethan Goh + and Akshay S. Chaudhari + and Nima Aghaeepour + and Christopher Sharp + and Michael A. Pfeffer + and Percy Liang + and Jonathan H. Chen + and Keith E. Morse + and Emma P. Brunskill + and Jason A. Fries + and Nigam H. 
+        journal={arXiv preprint arXiv:2308.14089},
+        year={2023}
+    }
+
+    Each instance includes:
+    - input: the instruction and patient record
+    - reference: the clinical 'gold standard' completion for the instruction for the given patient record
+
+    This is a clinical instruction-following task, wherein a generative language model must follow
+    the instructions using the provided patient record. As explained in the MedAlign work, each example
+    is guaranteed to be completable for the given patient record.
+
+    This task is evaluated using COMET and BERTScore metrics.
+    """
+
+    name = "medalign"
+    description = "MedAlign clinical instruction following task and dataset"
+    tags = ["instruction_following", "generation"]
+
+    def __init__(self, prompt_template: str = "generic.txt"):
+        super().__init__()
+        self.prompt_template = prompt_template
+
+    def _load_dataset(self) -> Tuple[Dict[int, Dict[str, Union[int, str]]], Dict[int, str], pd.DataFrame]:
+        assert os.path.exists(INSTRUCTIONS_PATH)
+        assert os.path.exists(CLINICIAN_RESPONSES_PATH)
+        assert os.path.exists(EHR_BASE_PATH)
+
+        instructions = get_instructions(INSTRUCTIONS_PATH)
+        ehrs = get_ehrs(EHR_BASE_PATH)
+        gold_df = pd.read_csv(CLINICIAN_RESPONSES_PATH)
+
+        # required filtering to match MedAlign code.
+        # TODO: clean this up either in the data files or with better logic
+        gold_df = gold_df[gold_df.annotator_num == "Annotator_1"]
+        return instructions, ehrs, gold_df
+
+    def get_instances(self, output_path: str) -> List[Instance]:
+        instructions, ehrs, clinician_responses_df = self._load_dataset()
+        prompt_template_path = Path(PROMPT_TEMPLATES_BASE_PATH) / self.prompt_template
+        if not (prompt_template_path.exists() and prompt_template_path.is_file()):
+            raise RuntimeError(f"Prompt template path {str(prompt_template_path)} not found!")
+
+        with open(prompt_template_path, "r", encoding="utf-8") as fh:
+            prompt_template = fh.read()
+
+        instances: List[Instance] = []
+        for instruction_id, instruction_dict in instructions.items():
+            # get the actual instruction
+            instruction: Union[str, int] = instruction_dict["instruction"]  # question or task
+
+            # get the patient EHR selected for this instruction
+            pt_id: Union[str, int] = instruction_dict["patient_id"]
+            relevant_ehr = ehrs[pt_id]  # type: ignore
+
+            # get the clinician response which serves as the reference
+            clinician_response_rows = list(
+                clinician_responses_df[clinician_responses_df.instruction_id == instruction_id].iterrows()
+            )
+            assert len(clinician_response_rows) == 1
+            clinician_response = clinician_response_rows[0][1].clinician_response
+
+            instances.append(
+                Instance(
+                    input=Input(
+                        text=instruction,  # type: ignore
+                    ),
+                    references=[Reference(Output(clinician_response), tags=[CORRECT_TAG])],
+                    extra_data={
+                        "prompt_template": prompt_template,
+                        "ehr": relevant_ehr,
+                    },
+                    split=TEST_SPLIT,
+                )
+            )
+        return instances
diff --git a/src/helm/benchmark/static/schema_medical.yaml b/src/helm/benchmark/static/schema_medical.yaml
index cd37759ea3..aff838dc23 100644
--- a/src/helm/benchmark/static/schema_medical.yaml
+++ b/src/helm/benchmark/static/schema_medical.yaml
@@ -82,6 +82,15 @@ metrics:
     display_name: Judge Score
     description: LLM-as-judge score
     lower_is_better: false
+  - name: comet
+    display_name: COMET Score
+    short_display_name: comet
+    description: A model-based score of similarity of a machine translation based on a source, predicted translation, and reference translation.
+    lower_is_better: false
+  - name: BERTScore-F
+    display_name: BERTScore F1
+    description: BERTScore F1 score.
+    lower_is_better: false

   # Toxicity metrics
   - name: expected_max_toxicity
@@ -150,6 +159,7 @@
     - mmlu
     - live_qa
     - medication_qa
+    - medalign

   - name: med_qa
     display_name: MedQA
@@ -253,3 +263,18 @@
     who: n/a
     when: n/a
     language: English
+  - name: medalign
+    display_name: MedAlign
+    short_display_name: MedAlign
+    description: A question answering dataset for clinical questions, each paired with a relevant patient EHR and a clinician-generated gold response.
+    metric_groups:
+      - accuracy
+    environment:
+      main_name: BERTScore-F
+      main_split: test
+    taxonomy:
+      task: question answering
+      what: "?"
+      who: "?"
+      when: "?"
+      language: English
\ No newline at end of file
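
Usage sketch (illustrative only; assumes the run-spec function registered above is importable and directly callable):

    from helm.benchmark.run_specs.classic_run_specs import get_medalign_spec

    # Build the MedAlign RunSpec with the default prompt template.
    run_spec = get_medalign_spec(prompt_template="generic.txt")
    assert run_spec.adapter_spec.method == "ehr_instruction"
    assert any("comet_metric" in spec.class_name for spec in run_spec.metric_specs)

Running the scenario end to end additionally requires the MedAlign data paths configured at the top of medalign_scenario.py to exist locally.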