Skip to content

Commit

Permalink
Return to using contextmanager approach for rich
Browse files Browse the repository at this point in the history
  • Loading branch information
cpsievert committed Nov 22, 2024
1 parent 5bf4990 commit 445ace2
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 87 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
__pycache__

.ipynb_checkpoints/
Untitled*.ipynb

.venv
uv.lock

Expand Down
187 changes: 101 additions & 86 deletions chatlas/_chat.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

import os
from contextlib import contextmanager
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
AsyncIterator,
Expand All @@ -18,6 +18,7 @@
)

from pydantic import BaseModel
from rich.live import Live

from ._content import (
Content,
Expand All @@ -31,9 +32,6 @@
from ._turn import Turn, user_turn
from ._typing_extensions import TypedDict

if TYPE_CHECKING:
import rich.live


class AnyTypeDict(TypedDict, total=False):
    """Permissive ``TypedDict``: ``total=False`` makes every key optional,
    and the empty body declares no required keys."""

    pass
Expand Down Expand Up @@ -457,14 +455,16 @@ def extract_data(
The extracted data.
"""

response = ChatResponse(
self._submit_turns(
user_turn(*args),
data_model=data_model,
echo=echo,
stream=stream,
with MaybeLiveDisplay() as display:
response = ChatResponse(
self._submit_turns(
user_turn(*args),
data_model=data_model,
echo=echo,
display=display,
stream=stream,
)
)
)

for _ in response:
pass
Expand Down Expand Up @@ -512,14 +512,17 @@ async def extract_data_async(
dict[str, Any]
The extracted data.
"""
response = ChatResponseAsync(
self._submit_turns_async(
user_turn(*args),
data_model=data_model,
echo=echo,
stream=stream,

with MaybeLiveDisplay() as display:
response = ChatResponseAsync(
self._submit_turns_async(
user_turn(*args),
data_model=data_model,
echo=echo,
display=display,
stream=stream,
)
)
)

async for _ in response:
pass
Expand Down Expand Up @@ -631,15 +634,18 @@ def _chat_impl(
kwargs: Optional[SubmitInputArgsT] = None,
) -> Generator[str, None, None]:
user_turn_result: Turn | None = user_turn
while user_turn_result is not None:
for chunk in self._submit_turns(
user_turn_result,
echo=echo,
stream=stream,
kwargs=kwargs,
):
yield chunk
user_turn_result = self._invoke_tools()

with MaybeLiveDisplay() as display:
while user_turn_result is not None:
for chunk in self._submit_turns(
user_turn_result,
echo=echo,
display=display,
stream=stream,
kwargs=kwargs,
):
yield chunk
user_turn_result = self._invoke_tools()

async def _chat_impl_async(
self,
Expand All @@ -649,28 +655,32 @@ async def _chat_impl_async(
kwargs: Optional[SubmitInputArgsT] = None,
) -> AsyncGenerator[str, None]:
user_turn_result: Turn | None = user_turn
while user_turn_result is not None:
async for chunk in self._submit_turns_async(
user_turn_result,
echo=echo,
stream=stream,
kwargs=kwargs,
):
yield chunk
user_turn_result = await self._invoke_tools_async()

with MaybeLiveDisplay() as display:
while user_turn_result is not None:
async for chunk in self._submit_turns_async(
user_turn_result,
echo=echo,
display=display,
stream=stream,
kwargs=kwargs,
):
yield chunk
user_turn_result = await self._invoke_tools_async()

def _submit_turns(
self,
user_turn: Turn,
echo: Literal["text", "all", "none"],
display: LiveMarkdownDisplay | None,
stream: bool,
data_model: type[BaseModel] | None = None,
kwargs: Optional[SubmitInputArgsT] = None,
) -> Generator[str, None, None]:
if any(x._is_async for x in self.tools.values()):
raise ValueError("Cannot use async tools in a synchronous chat")

emit = emitter(echo)
emit = emitter(echo, display)

if echo == "all":
emit_user_contents(user_turn, emit)
Expand Down Expand Up @@ -724,11 +734,12 @@ async def _submit_turns_async(
self,
user_turn: Turn,
echo: Literal["text", "all", "none"],
display: LiveMarkdownDisplay | None,
stream: bool,
data_model: type[BaseModel] | None = None,
kwargs: Optional[SubmitInputArgsT] = None,
) -> AsyncGenerator[str, None]:
emit = emitter(echo)
emit = emitter(echo, display)

if echo == "all":
emit_user_contents(user_turn, emit)
Expand Down Expand Up @@ -968,13 +979,37 @@ def consumed(self) -> bool:
# ----------------------------------------------------------------------------


def emitter(echo: Literal["text", "all", "none"]) -> Callable[[Content | str], None]:
@contextmanager
def MaybeLiveDisplay() -> Generator[LiveMarkdownDisplay | None, None, None]:
    """Yield a live terminal markdown display, or ``None`` in web contexts.

    rich is good at detecting a (Jupyter) notebook context, and Quarto sets
    the ``QUARTO_PYTHON`` environment variable. In either of those "web"
    contexts we yield ``None`` so the caller can fall back to
    ``IPython.display.Markdown``, which is a much more responsive way to
    display markdown there. Otherwise, the ``LiveMarkdownDisplay`` is
    entered for the duration of the ``with`` block and yielded.
    """
    display = LiveMarkdownDisplay()

    in_notebook = display.live.console.is_jupyter
    in_quarto = os.getenv("QUARTO_PYTHON") is not None

    if in_notebook or in_quarto:
        # Web frontend: signal the caller to use IPython's display machinery.
        yield None
        return

    with display:
        yield display


def emitter(
    echo: Literal["text", "all", "none"],
    display: LiveMarkdownDisplay | None,
) -> Callable[[Content | str], None]:
    """Build the callable used to echo streamed chat content.

    Parameters
    ----------
    echo
        Echo mode; ``"none"`` suppresses all output.
    display
        An active live terminal display, or ``None`` when running in a
        notebook/Quarto context (see ``MaybeLiveDisplay``).

    Returns
    -------
    A function that takes a ``Content`` or ``str`` chunk and renders it:
    a no-op for ``"none"``, the live display's ``update`` when a terminal
    display is active, otherwise an ``IPyMarkdownDisplay`` updater.
    """
    if echo == "none":
        return lambda _: None

    if display is not None:
        return lambda x: display.update(str(x))

    # No live terminal display: we're in a notebook/Quarto context, so
    # stream through IPython's Markdown display instead.
    ipy_display = IPyMarkdownDisplay()
    return lambda x: ipy_display.update(str(x))


def emit_user_contents(
Expand Down Expand Up @@ -1008,50 +1043,38 @@ def emit_other_contents(
emit(f" {str(content)}\n\n")


class StreamingMarkdown:
"""
Stream markdown content.
This uses rich for non-notebook contexts, and IPython.display.Markdown for
notebook (+Quarto) contexts.
"""

content: str = ""
ipy_display_id: Optional[str] = None
live: "Optional[rich.live.Live]" = None

class LiveMarkdownDisplay:
    """Incrementally render accumulating markdown in the terminal via rich.

    Chunks are appended with :meth:`update`; the full accumulated markdown
    is re-rendered through a ``rich.live.Live`` region on each call. Use as
    a context manager so the live region is started and stopped cleanly.
    """

    def __init__(self):
        # Full markdown text received so far; re-rendered on every update.
        self.content: str = ""
        # auto_refresh=False: we refresh explicitly on each update.
        # vertical_overflow="visible": don't clip output taller than the screen.
        self.live = Live(auto_refresh=False, vertical_overflow="visible")

    def update(self, content: str):
        """Append ``content`` and re-render the accumulated markdown."""
        from rich.markdown import Markdown

        self.content += content
        self.live.update(Markdown(self.content), refresh=True)

    def __enter__(self):
        # Delegate lifecycle to the underlying Live display.
        self.live.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        return self.live.__exit__(exc_type, exc_value, traceback)
class IPyMarkdownDisplay:
def __init__(self):
self.content: str = ""
self.ipy_display_id = self._init_display()

def update(self, content: str):
from IPython.display import Markdown, update_display

self.content += content
update_display(
Markdown(self.content),
display_id=self.ipy_display_id,
)

def _init_display(self) -> str:
try:
Expand All @@ -1067,11 +1090,3 @@ def _init_display(self) -> str:
if handle is None:
raise ValueError("Failed to create display handle")
return handle.display_id

def __del__(self):
if self.live is not None:
self.live.stop()
self.live = None
if self.ipy_display_id is not None:
# I don't think there's any more cleanup to do here?
self.ipy_display_id = None
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "chatlas"
description = "A simple and consistent interface for chatting with LLMs"
version = "0.1.1.9003"
version = "0.1.1.9004"
readme = "README.md"
requires-python = ">=3.9"
dependencies = [
Expand Down

0 comments on commit 445ace2

Please sign in to comment.