Pass runtime to LLM provider's ainvoke
klntsky committed Nov 22, 2024
1 parent 4393c8d commit 013ba66
Showing 8 changed files with 29 additions and 12 deletions.
2 changes: 1 addition & 1 deletion examples/model-selection-demo.metaprompt
@@ -1,4 +1,4 @@
 [:MODEL=[:use ./choose-model :prompt=[:prompt]]]
 Selected model: [:MODEL]
-[:STATUS=evaluating the prompt]
+[:STATUS=evaluating the prompt using [:MODEL]]
 [$[:prompt]]
3 changes: 2 additions & 1 deletion python/src/eval.py
@@ -55,8 +55,9 @@ async def stream_invoke(
         chat: List[{"role": str, "content": str}],
         history: List[{"role": str, "content": str}] = [],
     ) -> AsyncGenerator[str, None]:
+        nonlocal runtime
         provider = get_current_model_provider()
-        async for chunk in provider.ainvoke(chat, history):
+        async for chunk in provider.ainvoke(chat, history, runtime=runtime):
             yield chunk
 
     async def invoke(chat, history) -> str:
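
The nonlocal runtime declaration works because stream_invoke is a helper nested inside the evaluator, so it can capture the evaluator's runtime and forward it to whichever provider is active. A minimal runnable sketch of that closure pattern, with hypothetical stand-in names (EchoProvider and eval_with are not project code):

import asyncio

class EchoProvider:
    # Stand-in for a provider that accepts the new runtime kwarg.
    async def ainvoke(self, chat, history=[], runtime=None):
        yield f"runtime={runtime!r}, last message={chat[-1]['content']!r}"

async def eval_with(runtime):
    provider = EchoProvider()

    async def stream_invoke(chat, history=[]):
        # Reading runtime from the enclosing scope needs no declaration;
        # nonlocal is only required if this function rebinds the name.
        async for chunk in provider.ainvoke(chat, history, runtime=runtime):
            yield chunk

    async for chunk in stream_invoke([{"role": "user", "content": "hi"}]):
        print(chunk)

asyncio.run(eval_with(runtime="cli"))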
8 changes: 4 additions & 4 deletions python/src/main.py
@@ -78,13 +78,13 @@ async def _main():
     for file_path in args.INPUT_FILES:
         if os.path.isfile(file_path):
             with open(file_path, "r") as file:
-                content = file.read()
-                metaprompt = parse_metaprompt(content)
-                env = Env(env=config.parameters)
                 runtime = CliRuntime()
-                runtime.cwd = os.path.dirname(file_path)
                 # TODO: use file loading from runtime
                 runtime.set_status("running " + file_path)
+                runtime.cwd = os.path.dirname(file_path)
+                content = file.read()
+                metaprompt = parse_metaprompt(content)
+                env = Env(env=config.parameters)
                 async for chunk in eval_ast(metaprompt, config, runtime):
                     runtime.print_chunk(chunk)
                 runtime.finalize()
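
The reorder makes the runtime exist and display a status line before any file reading or parsing happens, so slow steps are reported as they run. A stub sketch of the new ordering (StubRuntime stands in for CliRuntime and is not project code):

import os

class StubRuntime:
    # Hypothetical minimal runtime: just echoes status changes.
    def set_status(self, status):
        print(f"[status] {status}")

def run_one(file_path):
    runtime = StubRuntime()
    runtime.set_status("running " + file_path)  # status first
    runtime.cwd = os.path.dirname(file_path)
    with open(file_path, "r") as file:          # only then read and parse
        return file.read()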
5 changes: 3 additions & 2 deletions python/src/providers/interactive.py
@@ -27,7 +27,8 @@ def __init__(self, api_key: str = None):
 
     async def ainvoke(
         self,
         chat: List[{ "role": str, "content": str }],
-        history = [] # TODO: make interactive provider respect history?
+        history = [], # TODO: make interactive provider respect history?
+        runtime = None,
     ) -> AsyncGenerator[str, None]:
         """Asynchronously invoke the OpenAI API and yield results in chunks.
@@ -38,5 +39,5 @@ async def ainvoke(
             str: Chunks of the response as they're received.
         """
         prompt = serialize_chat_history(chat)
-        output = input("Input:\n" + prompt + "\nYour answer: ")
+        output = (input if runtime is None else runtime.input)(prompt)
         yield output
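
The rewritten line picks a reader callable and then invokes it: the built-in input when no runtime was supplied, otherwise the runtime's own input so the prompt cooperates with the status line. A tiny sketch of the same dispatch (FakeRuntime is a hypothetical stand-in):

class FakeRuntime:
    # Pretend runtime that owns the terminal.
    def input(self, prompt):
        return f"(read via runtime) {prompt!r}"

def ask(prompt, runtime=None):
    read = input if runtime is None else runtime.input
    return read(prompt)

print(ask("Your answer: ", runtime=FakeRuntime()))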
3 changes: 2 additions & 1 deletion python/src/providers/mock.py
@@ -56,7 +56,8 @@ def match_rule(rule, history, chat):
 
     async def ainvoke(
         self,
         chat: List[{ "role": str, "content": str }],
-        history = [] # TODO: make interactive provider respect history?
+        history: List[{ "role": str, "content": str }] = [],
+        runtime = None,
     ) -> AsyncGenerator[str, None]:
         """Asynchronously invoke the OpenAI API and yield results in chunks.
3 changes: 2 additions & 1 deletion python/src/providers/openai.py
@@ -47,7 +47,8 @@ def __init__(self, model: str, api_key: str = None):
 
     async def ainvoke(
         self,
         chat: List[{ "role": str, "content": str }],
-        history: List[{ "role": str, "content": str }] = []
+        history: List[{ "role": str, "content": str }] = [],
+        runtime = None
     ) -> AsyncGenerator[str, None]:
         """Asynchronously invoke the OpenAI API and yield results in chunks.
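
Since runtime defaults to None in every provider signature, call sites written before this commit keep working unchanged. A quick sketch of that backward compatibility (Provider is a hypothetical stub, not the project's OpenAI class):

import asyncio

class Provider:
    async def ainvoke(self, chat, history=[], runtime=None):
        yield "plain" if runtime is None else f"via {runtime}"

async def demo():
    p = Provider()
    async for chunk in p.ainvoke([{"role": "user", "content": "hi"}]):
        print(chunk)  # old-style call: prints "plain"
    async for chunk in p.ainvoke([{"role": "user", "content": "hi"}], runtime="cli"):
        print(chunk)  # new-style call: prints "via cli"

asyncio.run(demo())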
7 changes: 7 additions & 0 deletions python/src/runtime.py
@@ -17,3 +17,10 @@ def set_status(self, status: List[{"text": str, "color": str}]):
     @abstractmethod
     def print_chunk(self, chunk: str):
         pass
+
+    @abstractmethod
+    def input(self, prompt):
+        """Used to request input from the user, which can be done by some BaseLLMProvider
+        subclasses
+        """
+        pass
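
Declaring input as an @abstractmethod means every concrete runtime must now implement it before it can be instantiated. A minimal sketch of the contract (StubRuntime is hypothetical):

from abc import ABC, abstractmethod

class BaseRuntime(ABC):
    @abstractmethod
    def input(self, prompt):
        pass

class StubRuntime(BaseRuntime):
    def input(self, prompt):
        return "canned answer"

print(StubRuntime().input("? "))  # fine
# BaseRuntime()  # TypeError: can't instantiate abstract class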
10 changes: 8 additions & 2 deletions python/src/runtimes/cli_runtime.py
@@ -84,8 +84,14 @@ def print_chunk(self, chunk: str):
         self.print_status()
 
     def input(self, prompt):
-        print("\r", end="", flush=True)
-        res = input(prompt)
+        self.print_chunk(prompt + "\n")
+        print(
+            "\r" + self.padding_for(self.status, "") + "\r",
+            end="",
+            flush=True
+        )
+        print("$ ", end="", flush=True)
+        res = input()
         self.print_status()
         return res
 
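
The extra printing choreography erases the transient status line before handing the terminal to input(): return to column 0, overwrite the status with padding, return again, then show a bare "$ " prompt. padding_for is project code; a generic sketch of the same trick using plain spaces:

import shutil

def clear_status_line():
    # Return to column 0, blank out whatever the status line printed,
    # then return to column 0 again.
    width = shutil.get_terminal_size().columns
    print("\r" + " " * width + "\r", end="", flush=True)

clear_status_line()
print("$ ", end="", flush=True)
answer = input()  # the cleared line now serves as a clean prompt
print("got:", answer)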
