Skip to content

Commit

Permalink
updating logic
Browse files Browse the repository at this point in the history
  • Loading branch information
liamjxu committed Dec 14, 2024
1 parent 75e3ded commit a6788b7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/helm/benchmark/annotation/annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def annotate(self, request_state: RequestState) -> Any:
that are implementation specific."""
pass

def annotate_all(self, request_states: List[RequestState]) -> Any:
def annotate_all(self, request_states: List[RequestState]) -> List[Dict[str, Any]]:
"""Fills the annotations field of all request states with additional information
that are implementation specific."""
pass
Expand Down
10 changes: 6 additions & 4 deletions src/helm/benchmark/annotation/bigcodebench_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from helm.common.request import Request
from helm.common.hierarchical_logger import hlog

from typing import Any, List
from typing import Any, List, Dict
from gradio_client import Client, handle_file
from tempfile import TemporaryDirectory
from tenacity import retry, stop_after_attempt, wait_fixed
Expand Down Expand Up @@ -74,7 +74,7 @@ def predict_with_retry(self, filename):
return results, pass_at_one


def annotate_all(self, request_states: List[RequestState]) -> Any:
def annotate_all(self, request_states: List[RequestState]) -> List[Dict[str, Any]]:
assert all(request_state.result for request_state in request_states)
assert all(len(request_state.result.completions) == 1 for request_state in request_states)
assert all(request_state.instance.extra_data for request_state in request_states)
Expand Down Expand Up @@ -103,6 +103,8 @@ def annotate_all(self, request_states: List[RequestState]) -> Any:
hlog("Failed to complete the operation after 3 attempts.")
pass_at_one = 0.0
results = []

ret = [{"pass_at_one": results['eval'][state.instance.id][0]['status'] == 'pass'} for state in request_states]
if len(results):
ret = [{"pass_at_one": results['eval'][state.instance.id][0]['status'] == 'pass'} for state in request_states]
else:
ret = [{"pass_at_one": False} for state in request_states]
return ret

0 comments on commit a6788b7

Please sign in to comment.