Commit: formatting

liamjxu committed Dec 15, 2024
1 parent 70e7937 commit 054c4aa
Showing 3 changed files with 11 additions and 8 deletions.
12 changes: 6 additions & 6 deletions src/helm/benchmark/annotation/bigcodebench_annotator.py
@@ -72,7 +72,6 @@ def predict_with_retry(self, filename: str):
         pass_at_one = evals["pass@1"]
         return results, pass_at_one
 
-
     def annotate_all(self, request_states: List[RequestState]) -> List[Dict[str, Any]]:
         assert all(request_state.result for request_state in request_states)
         assert all(len(request_state.result.completions) == 1 for request_state in request_states)
@@ -89,9 +88,7 @@ def annotate_all(self, request_states: List[RequestState]) -> List[Dict[str, Any]]:
                 model_output_text = request_state.result.completions[0].text
                 solution = code_extract(model_output_text)
                 idx = int(request_state.instance.id.split("/")[-1])
-                res[idx] = json.dumps(
-                    {"task_id": request_state.instance.id, "solution": solution}
-                ) + "\n"
+                res[idx] = json.dumps({"task_id": request_state.instance.id, "solution": solution}) + "\n"
             for line in res:
                 file.write(line)

@@ -102,7 +99,10 @@ def annotate_all(self, request_states: List[RequestState]) -> List[Dict[str, Any]]:
             pass_at_one = 0.0
             results = []
         if len(results):
-            ret = [{'bigcodebench': {"pass_at_one": results['eval'][state.instance.id][0]['status'] == 'pass'}} for state in request_states]
+            ret = [
+                {"bigcodebench": {"pass_at_one": results["eval"][state.instance.id][0]["status"] == "pass"}}
+                for state in request_states
+            ]
         else:
-            ret = [{'bigcodebench': {"pass_at_one": False}} for state in request_states]
+            ret = [{"bigcodebench": {"pass_at_one": False}} for state in request_states]
         return ret
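For reference, a minimal, hypothetical sketch of the results structure that the comprehension above assumes: a dict whose "eval" entry maps each instance id to a list of evaluation records, the first of which carries a "status" field. The concrete ids and statuses below are invented for illustration.

# Hypothetical stand-in for the evaluation output read by annotate_all; the ids and
# statuses are made up, only the nesting ("eval" -> task id -> first record -> "status")
# mirrors the access pattern in the comprehension.
results = {
    "eval": {
        "BigCodeBench/0": [{"status": "pass"}],
        "BigCodeBench/1": [{"status": "fail"}],
    }
}

# Same shape of annotation dict as the code above, one per task.
annotations = [
    {"bigcodebench": {"pass_at_one": results["eval"][task_id][0]["status"] == "pass"}}
    for task_id in ["BigCodeBench/0", "BigCodeBench/1"]
]
print(annotations)
# [{'bigcodebench': {'pass_at_one': True}}, {'bigcodebench': {'pass_at_one': False}}]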
5 changes: 4 additions & 1 deletion src/helm/benchmark/annotation_executor.py
@@ -92,7 +92,10 @@ def execute(self, scenario_state: ScenarioState) -> ScenarioState:
             hlog("No annotators to run.")
             return scenario_state
 
-        if all(getattr(self.factory.get_annotator(spec), "use_global_metric", False) for spec in scenario_state.annotator_specs):
+        if all(
+            getattr(self.factory.get_annotator(spec), "use_global_metric", False)
+            for spec in scenario_state.annotator_specs
+        ):
             # Do it!
             request_states = self.process_all(
                 scenario_state.annotator_specs, scenario_state.request_states  # processing all request together
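For context, a standalone sketch of the dispatch pattern the reformatted condition expresses: when every annotator reports use_global_metric, all request states are processed together; otherwise they are handled per request. The DummyAnnotator class and the list below are assumptions for illustration, not code from this repository.

# Illustrative only: DummyAnnotator stands in for whatever get_annotator(spec) returns;
# the real check reads the same "use_global_metric" attribute with a False default.
class DummyAnnotator:
    def __init__(self, use_global_metric: bool):
        self.use_global_metric = use_global_metric

annotators = [DummyAnnotator(True), DummyAnnotator(True)]

if all(getattr(a, "use_global_metric", False) for a in annotators):
    print("annotate all request states together in one batch")
else:
    print("annotate each request state on its own")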
2 changes: 1 addition & 1 deletion src/helm/benchmark/scenarios/bigcodebench_scenario.py
@@ -50,7 +50,7 @@ def get_instances(self, output_path: str) -> List[Instance]:
                     input=input,
                     references=[],
                     split=TEST_SPLIT,
-                    id=row['task_id'],
+                    id=row["task_id"],
                     extra_data={"task_id": row["task_id"]},
                 )
                 instances.append(instance)
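The id set here is what the annotator in the first file later turns into a list index via int(request_state.instance.id.split("/")[-1]). A small sketch under the assumption that task ids follow a "BigCodeBench/<n>" pattern; the sample id is invented.

# Hypothetical id in the row["task_id"] style; only the "<name>/<number>" shape matters,
# since the annotator keeps the numeric suffix as the solution's position in the JSONL file.
task_id = "BigCodeBench/17"
idx = int(task_id.split("/")[-1])
print(idx)  # 17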
