Skip to content

Commit

Permalink
Make tools attribute private; fix article example
Browse files Browse the repository at this point in the history
  • Loading branch information
cpsievert committed Dec 11, 2024
1 parent a27b064 commit 3388547
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
18 changes: 9 additions & 9 deletions chatlas/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
"""
self.provider = provider
self._turns: list[Turn] = list(turns or [])
self.tools: dict[str, Tool] = {}
self._tools: dict[str, Tool] = {}
self._echo_options: EchoOptions = {
"rich_markdown": {},
"rich_console": {},
Expand Down Expand Up @@ -696,7 +696,7 @@ def add(a: int, b: int) -> int:
name and docstring of the function.
"""
tool = Tool(func, model=model)
self.tools[tool.name] = tool
self._tools[tool.name] = tool

def export(
self,
Expand Down Expand Up @@ -869,7 +869,7 @@ def _submit_turns(
data_model: type[BaseModel] | None = None,
kwargs: Optional[SubmitInputArgsT] = None,
) -> Generator[str, None, None]:
if any(x._is_async for x in self.tools.values()):
if any(x._is_async for x in self._tools.values()):
raise ValueError("Cannot use async tools in a synchronous chat")

def emit(text: str | Content):
Expand All @@ -884,7 +884,7 @@ def emit(text: str | Content):
response = self.provider.chat_perform(
stream=True,
turns=[*self._turns, user_turn],
tools=self.tools,
tools=self._tools,
data_model=data_model,
kwargs=kwargs,
)
Expand All @@ -910,7 +910,7 @@ def emit(text: str | Content):
response = self.provider.chat_perform(
stream=False,
turns=[*self._turns, user_turn],
tools=self.tools,
tools=self._tools,
data_model=data_model,
kwargs=kwargs,
)
Expand Down Expand Up @@ -948,7 +948,7 @@ def emit(text: str | Content):
response = await self.provider.chat_perform_async(
stream=True,
turns=[*self._turns, user_turn],
tools=self.tools,
tools=self._tools,
data_model=data_model,
kwargs=kwargs,
)
Expand All @@ -974,7 +974,7 @@ def emit(text: str | Content):
response = await self.provider.chat_perform_async(
stream=False,
turns=[*self._turns, user_turn],
tools=self.tools,
tools=self._tools,
data_model=data_model,
kwargs=kwargs,
)
Expand All @@ -999,7 +999,7 @@ def _invoke_tools(self) -> Turn | None:
results: list[ContentToolResult] = []
for x in turn.contents:
if isinstance(x, ContentToolRequest):
tool_def = self.tools.get(x.name, None)
tool_def = self._tools.get(x.name, None)
func = tool_def.func if tool_def is not None else None
results.append(self._invoke_tool(func, x.arguments, x.id))

Expand All @@ -1016,7 +1016,7 @@ async def _invoke_tools_async(self) -> Turn | None:
results: list[ContentToolResult] = []
for x in turn.contents:
if isinstance(x, ContentToolRequest):
tool_def = self.tools.get(x.name, None)
tool_def = self._tools.get(x.name, None)
func = tool_def.func if tool_def is not None else None
results.append(await self._invoke_tool_async(func, x.arguments, x.id))

Expand Down
2 changes: 1 addition & 1 deletion docs/tool-calling.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def get_current_temperature(latitude: float, longitude: float):
"Get the current weather given a latitude and longitude."
raise ValueError("Failed to get current temperature")
chat.tools = [get_current_temperature]
chat.register_tool(get_current_temperature)
_ = chat.chat("What's the weather like today in Duluth, MN?")
```
Expand Down

0 comments on commit 3388547

Please sign in to comment.