Introduce .set_echo_options(); refactor context manager logic
cpsievert committed Nov 23, 2024
1 parent 445ace2 commit 4bb1e46
Showing 3 changed files with 132 additions and 51 deletions.
147 changes: 109 additions & 38 deletions chatlas/_chat.py
@@ -16,6 +16,7 @@
Sequence,
TypeVar,
)
+from uuid import uuid4

from pydantic import BaseModel
from rich.live import Live
@@ -77,6 +78,11 @@ def __init__(
self.provider = provider
self._turns: list[Turn] = list(turns or [])
self.tools: dict[str, Tool] = {}
+        self._echo_options: EchoOptions = {
+            "rich_markdown": {},
+            "rich_console": {},
+            "css_styles": {},
+        }

def turns(
self,
Expand Down Expand Up @@ -455,7 +461,7 @@ def extract_data(
The extracted data.
"""

-        with MaybeLiveDisplay() as display:
+        with self._display_context() as display:
response = ChatResponse(
self._submit_turns(
user_turn(*args),
@@ -513,7 +519,7 @@ async def extract_data_async(
The extracted data.
"""

-        with MaybeLiveDisplay() as display:
+        with self._display_context() as display:
response = ChatResponseAsync(
self._submit_turns_async(
user_turn(*args),
@@ -635,7 +641,7 @@ def _chat_impl(
) -> Generator[str, None, None]:
user_turn_result: Turn | None = user_turn

-        with MaybeLiveDisplay() as display:
+        with self._display_context() as display:
while user_turn_result is not None:
for chunk in self._submit_turns(
user_turn_result,
@@ -656,7 +662,7 @@ async def _chat_impl_async(
) -> AsyncGenerator[str, None]:
user_turn_result: Turn | None = user_turn

-        with MaybeLiveDisplay() as display:
+        with self._display_context() as display:
while user_turn_result is not None:
async for chunk in self._submit_turns_async(
user_turn_result,
@@ -672,7 +678,7 @@ def _submit_turns(
self,
user_turn: Turn,
echo: Literal["text", "all", "none"],
-        display: LiveMarkdownDisplay | None,
+        display: LiveMarkdownDisplay | IPyMarkdownDisplay,
stream: bool,
data_model: type[BaseModel] | None = None,
kwargs: Optional[SubmitInputArgsT] = None,
@@ -734,7 +740,7 @@ async def _submit_turns_async(
self,
user_turn: Turn,
echo: Literal["text", "all", "none"],
-        display: LiveMarkdownDisplay | None,
+        display: LiveMarkdownDisplay | IPyMarkdownDisplay,
stream: bool,
data_model: type[BaseModel] | None = None,
kwargs: Optional[SubmitInputArgsT] = None,
@@ -861,6 +867,56 @@ async def _invoke_tool_async(
except Exception as e:
return ContentToolResult(id_, None, str(e))

+    @contextmanager
+    def _display_context(
+        self,
+    ) -> Generator[LiveMarkdownDisplay | IPyMarkdownDisplay, None, None]:
+        opts = self._echo_options
+        display = LiveMarkdownDisplay(opts)
+
+        # rich seems to be pretty good at detecting a (Jupyter) notebook
+        # context, so utilize that, but use IPython.display.Markdown instead if
+        # we're in a notebook (or Quarto) since that's a much more responsive
+        # way to display markdown
+        is_web = (
+            display.live.console.is_jupyter
+            or os.getenv("QUARTO_PYTHON", None) is not None
+        )
+
+        if is_web:
+            with IPyMarkdownDisplay(opts) as d:
+                yield d
+        else:
+            with display:
+                yield display
+
+    def set_echo_options(
+        self,
+        rich_markdown: Optional[dict[str, Any]] = None,
+        rich_console: Optional[dict[str, Any]] = None,
+        css_styles: Optional[dict[str, str]] = None,
+    ):
+        """
+        Set echo styling options for the chat.
+
+        Parameters
+        ----------
+        rich_markdown
+            A dictionary of options to pass to `rich.markdown.Markdown()`.
+            This is only relevant when outputting to the console.
+        rich_console
+            A dictionary of options to pass to `rich.console.Console()`.
+            This is only relevant when outputting to the console.
+        css_styles
+            A dictionary of CSS styles to apply to `IPython.display.Markdown()`.
+            This is only relevant when outputting to the browser.
+        """
+        self._echo_options: EchoOptions = {
+            "rich_markdown": rich_markdown or {},
+            "rich_console": rich_console or {},
+            "css_styles": css_styles or {},
+        }

def __str__(self):
turns = self.turns(include_system_prompt=True)
tokens = sum(sum(turn.tokens) for turn in turns)
@@ -979,38 +1035,15 @@ def consumed(self) -> bool:
# ----------------------------------------------------------------------------


-@contextmanager
-def MaybeLiveDisplay() -> Generator[LiveMarkdownDisplay | None, None, None]:
-    display = LiveMarkdownDisplay()
-
-    # rich seems to be pretty good at detecting a (Jupyter) notebook
-    # context, so utilize that, but use IPython.display.Markdown instead if
-    # we're in a notebook (or Quarto) since that's a much more responsive
-    # way to display markdown
-    is_web = (
-        display.live.console.is_jupyter or os.getenv("QUARTO_PYTHON", None) is not None
-    )
-
-    if is_web:
-        yield None
-    else:
-        with display:
-            yield display


def emitter(
echo: Literal["text", "all", "none"],
-    display: LiveMarkdownDisplay | None,
+    display: LiveMarkdownDisplay | IPyMarkdownDisplay,
) -> Callable[[Content | str], None]:
if echo == "none":
return lambda _: None

-    if display is not None:
+    else:
return lambda x: display.update(str(x))

-    ipy_display = IPyMarkdownDisplay()
-    return lambda x: ipy_display.update(str(x))


def emit_user_contents(
x: Turn,
@@ -1044,36 +1077,52 @@ def emit_other_contents(


class LiveMarkdownDisplay:
-    def __init__(self):
+    def __init__(self, echo_options: EchoOptions):
from rich.console import Console

self.content: str = ""
-        self.live = Live(auto_refresh=False, vertical_overflow="visible")
+        self.live = Live(
+            auto_refresh=False,
+            vertical_overflow="visible",
+            console=Console(
+                **echo_options["rich_console"],
+            ),
+        )
+        self._markdown_options = echo_options["rich_markdown"]

def update(self, content: str):
from rich.markdown import Markdown

self.content += content
-        self.live.update(Markdown(self.content), refresh=True)
+        self.live.update(
+            Markdown(
+                self.content,
+                **self._markdown_options,
+            ),
+            refresh=True,
+        )

def __enter__(self):
self.live.__enter__()
return self

def __exit__(self, exc_type, exc_value, traceback):
self.content = ""
return self.live.__exit__(exc_type, exc_value, traceback)


class IPyMarkdownDisplay:
-    def __init__(self):
+    def __init__(self, echo_options: EchoOptions):
self.content: str = ""
-        self.ipy_display_id = self._init_display()
+        self._css_styles = echo_options["css_styles"]

def update(self, content: str):
from IPython.display import Markdown, update_display

self.content += content
update_display(
Markdown(self.content),
-            display_id=self.ipy_display_id,
+            display_id=self._ipy_display_id,
)

def _init_display(self) -> str:
@@ -1085,8 +1134,30 @@ def _init_display(self) -> str:
"Install it with `pip install ipython`."
)

display(HTML("<div class='chatlas-markdown'></div>"))
if self._css_styles:
id_ = uuid4().hex
css = "".join(f"{k}: {v}; " for k, v in self._css_styles.items())
display(HTML(f"<style>#{id_} + .chatlas-markdown {{ {css} }}</style>"))
display(HTML(f"<div id='{id_}' class='chatlas-markdown'>"))
else:
# Unfortunately, there doesn't seem to be a proper way to wrap
# Markdown() in a div?
display(HTML("<div class='chatlas-markdown'>"))

handle = display(Markdown(""), display_id=True)
if handle is None:
raise ValueError("Failed to create display handle")
return handle.display_id

+    def __enter__(self):
+        self._ipy_display_id = self._init_display()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self._ipy_display_id = None


+class EchoOptions(TypedDict):
+    rich_markdown: dict[str, Any]
+    rich_console: dict[str, Any]
+    css_styles: dict[str, str]
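
For context, here's a minimal usage sketch of the new API. The option values below are illustrative, not defaults, and `ChatAnthropic` is used only because the docs in this commit use it:

```python
from chatlas import ChatAnthropic

chat = ChatAnthropic(model="claude-3-5-sonnet-20241022")

# Style echoed output once, up front. In a terminal, the first two dicts are
# forwarded to rich.markdown.Markdown() and rich.console.Console(); in a
# notebook (or Quarto), css_styles is applied to the rendered markdown's div.
chat.set_echo_options(
    rich_markdown={"code_theme": "monokai"},  # any rich.markdown.Markdown() kwarg
    rich_console={"width": 80},               # any rich.console.Console() kwarg
    css_styles={"font-size": "0.9rem"},       # CSS property -> value pairs
)

_ = chat.chat("What's the preferred way to read a CSV in Python?")
```

The options stick for the lifetime of the `Chat` instance: `_display_context()` re-reads `self._echo_options` each time a chat or extraction call builds a display.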
34 changes: 22 additions & 12 deletions docs/prompt-design.qmd
@@ -43,8 +43,7 @@ question = """

### Basic flavour

-When I don't provide a system prompt, I get a solution with R code (your results may vary):
+When I don't provide a system prompt, I sometimes get answers in a different language (like R):

```{python}
#| eval: false
@@ -59,15 +58,16 @@ chat = ChatAnthropic(model="claude-3-5-sonnet-20241022")
_ = chat.chat(question, kwargs={"temperature": 0})
```


-So I can specify that I want the LLM to be a Python programmer:
+I can ensure that I always get Python code by providing a system prompt:

```{python}
chat.system_prompt = "You are a helpful Python (not R) programming assistant."
_ = chat.chat(question)
```

-Since I'm mostly interested in the code, I ask it to drop the explanation and also:
+Note that I'm using both a system prompt (which defines the general behaviour) and a user prompt (which asks the specific question). You could put all of the content in the user prompt and get similar results, but I think it's helpful to use both to cleanly divide the general framing of the response from the specific questions that you want to ask.
+
+Since I'm mostly interested in the code, I ask it to drop the explanation:

```{python}
chat.system_prompt = """
@@ -90,7 +90,7 @@ _ = chat.chat(question)

### Be explicit

-If there's something about the output that you don't like, you can try being more explicit about it. For example, the code isn't styled quite how I like, so I can be a bit more explicit:
+If there's something about the output that you don't like, you can try being more explicit about it. For example, the code isn't styled quite how I like, so I provide more details about what I do want:

```{python}
chat.system_prompt = """
@@ -103,16 +103,16 @@
_ = chat.chat(question)
```

-Note that the LLM doesn't follow these instructions exactly, but it does seem to lead to code that looks a bit more the way that I want. If you were investing more time into this, you might provide more specific examples of how you're looking for code to be formatted.
+This still doesn't yield exactly the code that I'd write, but it's pretty close.

-Or maybe you're looking for more explanation of the code:
+You could provide a different prompt if you were looking for more explanation of the code:

```{python}
chat.system_prompt = """
You are an expert Python (not R) programmer and a warm and supportive teacher.
Help me understand the code you produce by explaining each function call with
-a brief comment. Add more details for more complicated calls.
-Just give me the code without any text explanation.
+a brief comment. For more complicated calls, add documentation to each
+argument. Just give me the code without any text explanation.
"""
_ = chat.chat(question)
```
@@ -175,9 +175,9 @@ If you don't have strong feelings about what the data structure should look like
```{python}
instruct_json = """
You're an expert baker who also loves JSON. I am going to give you a list of
-ingredients and your job is to return nicely structured JSON.
+ingredients and your job is to return nicely structured JSON. Just return the
+JSON and no other commentary.
"""
-# Just return the JSON and no other commentary.
chat.system_prompt = instruct_json
_ = chat.chat(ingredients)
```
@@ -329,3 +329,13 @@ Including the input text in the output makes it easier to see if it's doing a go
chat.system_prompt = instruct_json + "\n" + instruct_weight_input
_ = chat.chat(ingredients)
```


+When I ran it while writing this vignette, it seemed to be working out the weight of the ingredients specified in volume, even though the prompt specifically asks it not to do that. This may suggest I need to broaden my examples.
+
+## Token usage
+
+```{python}
+from chatlas import token_usage
+token_usage()
+```
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,7 +1,7 @@
[project]
name = "chatlas"
description = "A simple and consistent interface for chatting with LLMs"
version = "0.1.1.9004"
version = "0.1.1.9005"
readme = "README.md"
requires-python = ">=3.9"
dependencies = [