Merge pull request #177 from princeton-nlp/fix-run_live

Fix run live scripts

carlosejimenez authored Jul 11, 2024
2 parents 9085553 + 5221670 · commit ba367d7

Showing 5 changed files with 43 additions and 43 deletions.
2 changes: 2 additions & 0 deletions setup.py
@@ -53,6 +53,8 @@
             'torch',
             'flash_attn',
             'triton',
+            'jedi',
+            'tenacity',
         ],
     },
     include_package_data=True,
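For context, a minimal sketch of the pattern this hunk sits in, assuming the two new packages land in an `extras_require` group (the group name `inference` and the surrounding fields are assumptions, not shown in this diff):

```python
# Hypothetical, abbreviated setup.py; the real file lists many more
# dependencies, and the "inference" group name is an assumption here.
from setuptools import setup, find_packages

setup(
    name="swebench",
    packages=find_packages(),
    extras_require={
        "inference": [
            "jedi",      # added by this commit
            "tenacity",  # added by this commit; used for API retries
        ],
    },
    include_package_data=True,
)
```

With a layout like this, the optional dependencies install via `pip install -e ".[inference]"`.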
2 changes: 1 addition & 1 deletion swebench/inference/README.md
@@ -58,7 +58,7 @@ You can also specify further options:
 
 ## Run live inference on open GitHub issues
 
-Follow instructions [here](https://github.com/castorini/pyserini/blob/master/docs/installation.md) to install [Pyserini](https://github.com/castorini/pyserini), to perform BM25 retrieval.
+Follow instructions [here](https://github.com/castorini/pyserini/blob/master/docs/installation.md) to install [Pyserini](https://github.com/castorini/pyserini), to perform BM25 retrieval, and [here](https://github.com/facebookresearch/faiss/blob/main/INSTALL.md) to install [Faiss](https://github.com/facebookresearch/faiss).
 
 Then run `run_live.py` to try solving a new issue. For example, you can try solving [this issue](https://github.com/huggingface/transformers/issues/26706) by running the following command:
 
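As a rough sketch of what the BM25 retrieval step looks like with Pyserini (the index path and query below are placeholders; the repo's own wrapper lives in `swebench/inference/make_datasets/bm25_retrieval.py`):

```python
# Hedged sketch of BM25 search over a prebuilt Lucene index with Pyserini.
# The index path and the query string are placeholders, not repo values.
from pyserini.search.lucene import LuceneSearcher

searcher = LuceneSearcher("path/to/lucene-index")
hits = searcher.search("tokenizer crashes on empty input", k=10)
for hit in hits:
    print(hit.docid, round(hit.score, 3))  # top-ranked documents for the query
```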
24 changes: 12 additions & 12 deletions swebench/inference/make_datasets/bm25_retrieval.py
@@ -398,13 +398,13 @@ def get_index_paths_worker(
         repo_dir = clone_repo(repo, root_dir_name, token)
         query = instance["problem_statement"]
         index_path = make_index(
-            repo_dir,
-            root_dir_name,
-            query,
-            commit,
-            document_encoding_func,
-            python,
-            instance_id,
+            repo_dir=repo_dir,
+            root_dir=root_dir_name,
+            query=query,
+            commit=commit,
+            document_encoding_func=document_encoding_func,
+            python=python,
+            instance_id=instance_id,
         )
     except:
         logger.error(f"Failed to process {repo}/{commit} (instance {instance_id})")
@@ -438,11 +438,11 @@ def get_index_paths(
     all_index_paths = dict()
     for instance in tqdm(remaining_instances, desc="Indexing"):
         instance_id, index_path = get_index_paths_worker(
-            instance,
-            root_dir_name,
-            document_encoding_func,
-            python,
-            token,
+            instance=instance,
+            root_dir_name=root_dir_name,
+            document_encoding_func=document_encoding_func,
+            python=python,
+            token=token,
         )
         if index_path is None:
             continue
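Both hunks make the same defensive change: positional calls become keyword calls. The payoff shows up in `run_live.py` below, where a positional call had drifted out of sync with `make_index`'s signature. A small illustration, with stand-in signatures rather than the repo's real ones, of why keywords fail loudly where positionals mis-bind silently:

```python
# Stand-in signature (not the repo's real one) mirroring the kind of bug this
# PR fixes in run_live.py: make_index gained `query` while another parameter
# was dropped, so the arity stayed the same and old positional call sites
# kept "working" with every later argument shifted into the wrong slot.
def make_index(repo_dir, root_dir, query, commit, python, instance_id):
    return {"query": query, "commit": commit, "python": python}

# Old-style positional call, written for the order (repo_dir, root_dir,
# commit, python, thread_id, instance_id): no error, but commit lands in
# `query` and everything after it shifts by one.
print(make_index("repo", "root", "abc123", "python3", 0, "id-1"))
# -> {'query': 'abc123', 'commit': 'python3', 'python': 0}

# Keyword calls instead raise TypeError the moment a parameter is renamed,
# added, or removed, which is why this diff spells every name out.
print(make_index(repo_dir="repo", root_dir="root", query="q",
                 commit="abc123", python="python3", instance_id="id-1"))
```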
10 changes: 5 additions & 5 deletions swebench/inference/run_api.py
@@ -12,8 +12,8 @@
 from pathlib import Path
 from tqdm.auto import tqdm
 import numpy as np
-import openai
 import tiktoken
+import openai
 from anthropic import HUMAN_PROMPT, AI_PROMPT, Anthropic
 from tenacity import (
     retry,
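`run_api.py` pulls `retry` from `tenacity`, one of the two dependencies added to `setup.py` above. A hedged sketch of the pattern (the wait/stop settings and the wrapped function are placeholders, not the repo's actual configuration):

```python
# Hedged sketch of a tenacity retry wrapper; settings are placeholders.
from tenacity import retry, stop_after_attempt, wait_random_exponential

@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def flaky_api_call(prompt: str) -> str:
    # Any exception raised here triggers an exponential-backoff retry,
    # up to six attempts in total before the error propagates.
    raise RuntimeError("transient failure")  # stand-in for a real API call
```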
@@ -127,7 +127,7 @@ def call_chat(model_name_or_path, inputs, use_azure, temperature, top_p, **model
     user_message = inputs.split("\n", 1)[1]
     try:
         if use_azure:
-            response = openai.ChatCompletion.create(
+            response = openai.chat.completions.create(
                 engine=ENGINES[model_name_or_path] if use_azure else None,
                 messages=[
                     {"role": "system", "content": system_messages},
@@ … @@ def call_chat(model_name_or_path, inputs, use_azure, temperature, top_p, **model
                 **model_args,
             )
         else:
-            response = openai.ChatCompletion.create(
+            response = openai.chat.completions.create(
                 model=model_name_or_path,
                 messages=[
                     {"role": "system", "content": system_messages},
@@ … @@ def call_chat(model_name_or_path, inputs, use_azure, temperature, top_p, **model
         output_tokens = response.usage.completion_tokens
         cost = calc_cost(response.model, input_tokens, output_tokens)
         return response, cost
-    except openai.error.InvalidRequestError as e:
+    except openai.BadRequestError as e:
         if e.code == "context_length_exceeded":
             print("Context length exceeded")
             return None
@@ -231,7 +231,7 @@ def openai_inference(
             temperature,
             top_p,
         )
-        completion = response.choices[0]["message"]["content"]
+        completion = response.choices[0].message.content
         total_cost += cost
         print(f"Total Cost: {total_cost:.2f}")
         output_dict["full_output"] = completion
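These edits track the `openai>=1.0` interface: `openai.ChatCompletion.create` became `openai.chat.completions.create`, `openai.error.InvalidRequestError` became the top-level `openai.BadRequestError`, and responses are typed objects, so `choices[0]["message"]["content"]` becomes attribute access. A hedged sketch of the new-style call (the model name and messages are placeholders, not values from the repo):

```python
# Hedged sketch of the openai>=1.0 client API this diff migrates to.
import openai

client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment
try:
    response = client.chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello."},
        ],
        temperature=0,
        top_p=1,
    )
    print(response.choices[0].message.content)  # attributes, not dict keys
    print(response.usage.completion_tokens)
except openai.BadRequestError as e:  # was openai.error.InvalidRequestError
    if e.code == "context_length_exceeded":
        print("Context length exceeded")
```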
48 changes: 23 additions & 25 deletions swebench/inference/run_live.py
@@ -106,15 +106,15 @@ def make_instance(
     commit = subprocess.check_output(
         ["git", "rev-parse", "HEAD"], cwd=repo_dir
     ).decode("utf-8").strip()
-    logger.info(f"Buidling BM25 retrieval index for {owner}/{repo}@{commit}")
+    logger.info(f"Building BM25 retrieval index for {owner}/{repo}@{commit}")
     index_dir = make_index(
-        repo_dir,
-        root_dir,
-        commit,
-        document_encoding_func,
-        python,
-        thread_id,
-        instance_id,
+        repo_dir=repo_dir,
+        root_dir=root_dir,
+        query=query,
+        commit=commit,
+        document_encoding_func=document_encoding_func,
+        python=python,
+        instance_id=instance_id,
     )
     results = search(instance, index_dir)
     hits = results["hits"]
@@ -193,31 +193,29 @@ def main(
     instance_id = f"{owner}__{repo}-{issue_num}"
     logger.info(f"Creating instance {instance_id}")
     instance = make_instance(
-        owner,
-        repo,
-        problem_statement,
-        commit,
-        root_dir,
-        gh_token,
-        document_encoding_func,
-        python,
-        instance_id,
-        tokenizer,
-        tokenizer_func,
-        prompt_style,
-        max_context_length,
-        include_readmes,
+        owner=owner,
+        repo=repo,
+        query=problem_statement,
+        commit=commit,
+        root_dir=root_dir,
+        token=gh_token,
+        document_encoding_func=document_encoding_func,
+        python=python,
+        instance_id=instance_id,
+        tokenizer=tokenizer,
+        tokenizer_func=tokenizer_func,
+        prompt_style=prompt_style,
+        max_context_len=max_context_length,
+        include_readmes=include_readmes,
     )
     logger.info(f"Calling model {model_name}")
     start = time.time()
     if model_name.startswith("gpt"):
         import openai
         openai.api_key = os.environ.get("OPENAI_API_KEY", None)
         inputs = instance["text_inputs"]
         response, _ = call_chat(
             model_name, inputs, use_azure=False, temperature=0, top_p=1
         )
-        completion = response.choices[0]["message"]["content"]
+        completion = response.choices[0].message.content
         logger.info(f'Generated {response.usage.completion_tokens} tokens in {(time.time() - start):.2f} seconds')
     else:
         from anthropic import Anthropic
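The `run_live.py` changes are the same keyword-argument migration, plus two real fixes: the `query` argument is now actually passed to `make_index` (the old positional call sent `commit` into that slot), and the `max_context_len` keyword matches the parameter name `make_instance` evidently expects. One lightweight way to catch this kind of call-site drift without executing anything, sketched with a stub rather than the repo's real function:

```python
# Hedged sketch: inspect.signature(...).bind(...) validates a call site
# against a function's signature without running it. make_instance here is
# a stub standing in for the repo's real function; all values are dummies.
import inspect

def make_instance(owner, repo, query, commit, root_dir, token,
                  document_encoding_func, python, instance_id, tokenizer,
                  tokenizer_func, prompt_style, max_context_len,
                  include_readmes):
    ...

# Raises TypeError on any renamed, missing, or extra parameter, so a unit
# test built on this catches drift like max_context_length vs max_context_len.
inspect.signature(make_instance).bind(
    owner="huggingface", repo="transformers", query="...", commit="HEAD",
    root_dir="/tmp", token=None, document_encoding_func=str, python="python3",
    instance_id="huggingface__transformers-26706", tokenizer=None,
    tokenizer_func=None, prompt_style="style-3", max_context_len=16000,
    include_readmes=False,
)
print("call site matches signature")
```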
