Skip to content

Commit

Permalink
Lint
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanmai committed Dec 20, 2024
1 parent 455bd90 commit c434f2f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from helm.benchmark.adaptation.adapters.generation_adapter import GenerationAdapter
from helm.benchmark.adaptation.prompt import Prompt
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.scenarios.scenario import TRAIN_SPLIT, EHRInstructionInput, Instance
from helm.benchmark.scenarios.scenario import TRAIN_SPLIT, Instance
from helm.benchmark.window_services.window_service import EncodeResult
from helm.common.tokenization_request import TokenizationToken

Expand Down Expand Up @@ -58,8 +58,9 @@ def construct_prompt(
"""
# start by simply getting the inputs
question = eval_instance.input.text
assert eval_instance.extra_data is not None
ehr_text: str = eval_instance.extra_data["ehr"]
prompt_template: str = eval_instance.extra_data["prmpt_template"]
prompt_template: str = eval_instance.extra_data["prompt_template"]
full_prompt_text = prompt_template.format(question=question, ehr=ehr_text)

# insert the question and see how many tokens we have so far
Expand Down
14 changes: 7 additions & 7 deletions src/helm/benchmark/scenarios/medalign_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,9 @@ def get_instances(self, output_path: str) -> List[Instance]:
relevant_ehr = ehrs[pt_id] # type: ignore

# get the clinican response which serves as the reference
clinician_response_rows = list(clinician_responses_df[clinician_responses_df.instruction_id == instruction_id].iterrows())
clinician_response_rows = list(
clinician_responses_df[clinician_responses_df.instruction_id == instruction_id].iterrows()
)
assert len(clinician_response_rows) == 1
clinician_response = clinician_response_rows[0][1].clinician_response

Expand All @@ -230,12 +232,10 @@ def get_instances(self, output_path: str) -> List[Instance]:
text=instruction, # type: ignore
),
references=[Reference(Output(clinician_response), tags=[CORRECT_TAG])],
extra_data=
{
"prompt_template": prompt_template,
"ehr": relevant_ehr,
}
,
extra_data={
"prompt_template": prompt_template,
"ehr": relevant_ehr,
},
split=TEST_SPLIT,
)
)
Expand Down

0 comments on commit c434f2f

Please sign in to comment.