diff --git a/.gitignore b/.gitignore index 13711bc..64d7b6c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,8 @@ __pycache__ +.ipynb_checkpoints/ +Untitled*.ipynb + .venv uv.lock diff --git a/chatlas/_chat.py b/chatlas/_chat.py index d953e96..de8a0aa 100644 --- a/chatlas/_chat.py +++ b/chatlas/_chat.py @@ -1,8 +1,8 @@ from __future__ import annotations import os +from contextlib import contextmanager from typing import ( - TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, @@ -18,6 +18,7 @@ ) from pydantic import BaseModel +from rich.live import Live from ._content import ( Content, @@ -31,9 +32,6 @@ from ._turn import Turn, user_turn from ._typing_extensions import TypedDict -if TYPE_CHECKING: - import rich.live - class AnyTypeDict(TypedDict, total=False): pass @@ -457,14 +455,16 @@ def extract_data( The extracted data. """ - response = ChatResponse( - self._submit_turns( - user_turn(*args), - data_model=data_model, - echo=echo, - stream=stream, + with MaybeLiveDisplay() as display: + response = ChatResponse( + self._submit_turns( + user_turn(*args), + data_model=data_model, + echo=echo, + display=display, + stream=stream, + ) ) - ) for _ in response: pass @@ -512,14 +512,17 @@ async def extract_data_async( dict[str, Any] The extracted data. """ - response = ChatResponseAsync( - self._submit_turns_async( - user_turn(*args), - data_model=data_model, - echo=echo, - stream=stream, + + with MaybeLiveDisplay() as display: + response = ChatResponseAsync( + self._submit_turns_async( + user_turn(*args), + data_model=data_model, + echo=echo, + display=display, + stream=stream, + ) ) - ) async for _ in response: pass @@ -631,15 +634,18 @@ def _chat_impl( kwargs: Optional[SubmitInputArgsT] = None, ) -> Generator[str, None, None]: user_turn_result: Turn | None = user_turn - while user_turn_result is not None: - for chunk in self._submit_turns( - user_turn_result, - echo=echo, - stream=stream, - kwargs=kwargs, - ): - yield chunk - user_turn_result = self._invoke_tools() + + with MaybeLiveDisplay() as display: + while user_turn_result is not None: + for chunk in self._submit_turns( + user_turn_result, + echo=echo, + display=display, + stream=stream, + kwargs=kwargs, + ): + yield chunk + user_turn_result = self._invoke_tools() async def _chat_impl_async( self, @@ -649,20 +655,24 @@ async def _chat_impl_async( kwargs: Optional[SubmitInputArgsT] = None, ) -> AsyncGenerator[str, None]: user_turn_result: Turn | None = user_turn - while user_turn_result is not None: - async for chunk in self._submit_turns_async( - user_turn_result, - echo=echo, - stream=stream, - kwargs=kwargs, - ): - yield chunk - user_turn_result = await self._invoke_tools_async() + + with MaybeLiveDisplay() as display: + while user_turn_result is not None: + async for chunk in self._submit_turns_async( + user_turn_result, + echo=echo, + display=display, + stream=stream, + kwargs=kwargs, + ): + yield chunk + user_turn_result = await self._invoke_tools_async() def _submit_turns( self, user_turn: Turn, echo: Literal["text", "all", "none"], + display: LiveMarkdownDisplay | None, stream: bool, data_model: type[BaseModel] | None = None, kwargs: Optional[SubmitInputArgsT] = None, @@ -670,7 +680,7 @@ def _submit_turns( if any(x._is_async for x in self.tools.values()): raise ValueError("Cannot use async tools in a synchronous chat") - emit = emitter(echo) + emit = emitter(echo, display) if echo == "all": emit_user_contents(user_turn, emit) @@ -724,11 +734,12 @@ async def _submit_turns_async( self, user_turn: Turn, echo: Literal["text", "all", "none"], + display: LiveMarkdownDisplay | None, stream: bool, data_model: type[BaseModel] | None = None, kwargs: Optional[SubmitInputArgsT] = None, ) -> AsyncGenerator[str, None]: - emit = emitter(echo) + emit = emitter(echo, display) if echo == "all": emit_user_contents(user_turn, emit) @@ -968,13 +979,37 @@ def consumed(self) -> bool: # ---------------------------------------------------------------------------- -def emitter(echo: Literal["text", "all", "none"]) -> Callable[[Content | str], None]: +@contextmanager +def MaybeLiveDisplay() -> Generator[LiveMarkdownDisplay | None, None, None]: + display = LiveMarkdownDisplay() + + # rich seems to be pretty good at detecting a (Jupyter) notebook + # context, so utilize that, but use IPython.display.Markdown instead if + # we're in a notebook (or Quarto) since that's a much more responsive + # way to display markdown + is_web = ( + display.live.console.is_jupyter or os.getenv("QUARTO_PYTHON", None) is not None + ) + + if is_web: + yield None + else: + with display: + yield display + + +def emitter( + echo: Literal["text", "all", "none"], + display: LiveMarkdownDisplay | None, +) -> Callable[[Content | str], None]: if echo == "none": return lambda _: None - stream = StreamingMarkdown() + if display is not None: + return lambda x: display.update(str(x)) - return lambda x: stream.update(str(x)) + ipy_display = IPyMarkdownDisplay() + return lambda x: ipy_display.update(str(x)) def emit_user_contents( @@ -1008,50 +1043,38 @@ def emit_other_contents( emit(f" {str(content)}\n\n") -class StreamingMarkdown: - """ - Stream markdown content. - - This uses rich for non-notebook contexts, and IPython.display.Markdown for - notebook (+Quarto) contexts. - """ - - content: str = "" - ipy_display_id: Optional[str] = None - live: "Optional[rich.live.Live]" = None - +class LiveMarkdownDisplay: def __init__(self): - self.content = "" - - from rich.console import Console - from rich.live import Live - - # rich seems to be pretty good at detecting a (Jupyter) notebook - # context, so utilize that, but use IPython.display.Markdown instead if - # we're in a notebook (or Quarto) since that's a much more responsive - # way to display markdown - console = Console() - if console.is_jupyter or os.getenv("QUARTO_PYTHON", None) is not None: - self.ipy_display_id = self._init_display() - else: - live = Live(auto_refresh=False, vertical_overflow="visible") - live.start() - self.live = live + self.content: str = "" + self.live = Live(auto_refresh=False, vertical_overflow="visible") def update(self, content: str): + from rich.markdown import Markdown + self.content += content + self.live.update(Markdown(self.content), refresh=True) - if self.ipy_display_id is not None: - from IPython.display import Markdown, update_display + def __enter__(self): + self.live.__enter__() + return self + + def __exit__(self, exc_type, exc_value, traceback): + return self.live.__exit__(exc_type, exc_value, traceback) - update_display( - Markdown(self.content), - display_id=self.ipy_display_id, - ) - elif self.live is not None: - from rich.markdown import Markdown - self.live.update(Markdown(self.content), refresh=True) +class IPyMarkdownDisplay: + def __init__(self): + self.content: str = "" + self.ipy_display_id = self._init_display() + + def update(self, content: str): + from IPython.display import Markdown, update_display + + self.content += content + update_display( + Markdown(self.content), + display_id=self.ipy_display_id, + ) def _init_display(self) -> str: try: @@ -1067,11 +1090,3 @@ def _init_display(self) -> str: if handle is None: raise ValueError("Failed to create display handle") return handle.display_id - - def __del__(self): - if self.live is not None: - self.live.stop() - self.live = None - if self.ipy_display_id is not None: - # I don't think there's any more cleanup to do here? - self.ipy_display_id = None diff --git a/pyproject.toml b/pyproject.toml index 88107f1..3eb8672 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "chatlas" description = "A simple and consistent interface for chatting with LLMs" -version = "0.1.1.9003" +version = "0.1.1.9004" readme = "README.md" requires-python = ">=3.9" dependencies = [