Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save chat history WIP #10191

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/tiny-areas-train.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
"gradio": minor
---

feat:Save chat history WIP
1 change: 1 addition & 0 deletions demo/chatinterface_save_history/run.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: chatinterface_save_history"]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "\n", "def echo_multimodal(message, history):\n", " response = []\n", " response.append(\"You wrote: '\" + message[\"text\"] + \"' and uploaded:\")\n", " if message.get(\"files\"):\n", " for file in message[\"files\"]:\n", " response.append(gr.File(value=file))\n", " return response\n", "\n", "demo = gr.ChatInterface(\n", " echo_multimodal,\n", " type=\"messages\",\n", " multimodal=True,\n", " textbox=gr.MultimodalTextbox(file_count=\"multiple\"),\n", " save_history=True,\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
20 changes: 20 additions & 0 deletions demo/chatinterface_save_history/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import gradio as gr

def echo_multimodal(message, history):
    """Echo the user's text message and attach each uploaded file.

    Returns a list whose first item is the echoed text, followed by one
    gr.File component per uploaded file (if any were provided).
    """
    reply = ["You wrote: '" + message["text"] + "' and uploaded:"]
    reply.extend(gr.File(value=path) for path in (message.get("files") or []))
    return reply

# Multimodal chat demo with browser-side (localStorage) history saving enabled.
multimodal_input = gr.MultimodalTextbox(file_count="multiple")

demo = gr.ChatInterface(
    echo_multimodal,
    type="messages",
    multimodal=True,
    textbox=multimodal_input,
    save_history=True,
)

if __name__ == "__main__":
    demo.launch()
223 changes: 163 additions & 60 deletions gradio/chat_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from gradio import utils
from gradio.blocks import Blocks
from gradio.components import (
HTML,
JSON,
BrowserState,
Button,
Chatbot,
Component,
Expand All @@ -40,9 +42,43 @@
from gradio.helpers import create_examples as Examples # noqa: N812
from gradio.helpers import special_args, update
from gradio.layouts import Accordion, Column, Group, Row
from gradio.renderable import render
from gradio.routes import Request
from gradio.themes import ThemeClass as Theme

# Inline CSS + header markup injected through a gr.HTML component when
# ChatInterface(save_history=True) is set: styles the clickable rows of the
# browser-local chat-history sidebar. Class names carry a leading underscore
# to reduce the chance of colliding with user-supplied CSS.
save_history_css = """
<style>
._gradio-save-history {
width: 100%;
background-color: var(--background-fill-primary);
padding: 3px 6px;
user-select: none;
cursor: pointer;
transition: background-color 0.3s ease;
display: -webkit-box;
-webkit-line-clamp: 2;
-webkit-box-orient: vertical;
overflow: hidden;
font-size: 0.8rem;
text-overflow: ellipsis;
}
._gradio-save-history:hover {
background-color: var(--background-fill-secondary);
}
._gradio-save-history:active {
background-color: var(--color-accent-soft);
}
._gradio-save-history-header {
font-weight: bold;
font-size: 0.8rem;
text-align: center;
padding: 10px 0px;
}
</style>
<div class="_gradio-save-history-header">
Local chat history
</div>
"""

@document()
class ChatInterface(Blocks):
Expand Down Expand Up @@ -101,6 +137,7 @@ def __init__(
fill_height: bool = True,
fill_width: bool = False,
api_name: str | Literal[False] = "chat",
save_history: bool = False,
):
"""
Parameters:
Expand Down Expand Up @@ -138,6 +175,7 @@ def __init__(
fill_height: if True, the chat interface will expand to the height of window.
fill_width: Whether to horizontally expand to fill container fully. If False, centers and constrains app to a maximum width.
api_name: the name of the API endpoint to use for the chat interface. Defaults to "chat". Set to False to disable the API endpoint.
save_history: if True, will save the chat history to the browser's local storage. Defaults to False.
"""
super().__init__(
analytics_enabled=analytics_enabled,
Expand Down Expand Up @@ -176,6 +214,7 @@ def __init__(
self.cache_examples = cache_examples
self.cache_mode = cache_mode
self.editable = editable
self.save_history = save_history
self.additional_inputs = [
get_component_instance(i)
for i in utils.none_or_singleton_to_list(additional_inputs)
Expand Down Expand Up @@ -222,69 +261,120 @@ def __init__(
)
if description:
Markdown(description)
if chatbot:
if self.type:
if self.type != chatbot.type:
warnings.warn(
"The type of the gr.Chatbot does not match the type of the gr.ChatInterface."
f"The type of the gr.ChatInterface, '{self.type}', will be used."
with Row():
if save_history:
with Column(scale=1, min_width=100):
HTML(save_history_css, container=True, padding=False)
if not hasattr(self, 'saved_history'):
self.saved_history = BrowserState([])
with Group():
@render(inputs=self.saved_history)
def create_history(conversations):
    """Render one clickable row per saved conversation.

    Clicking a row loads that conversation back into the chatbot. Each
    row shows the first message's content (assumes entries are non-empty
    "messages"-format lists — TODO confirm upstream guarantees this).
    """
    for chat_conversation in conversations:
        h = HTML(
            chat_conversation[0]["content"],
            padding=False,
            elem_classes=["_gradio-save-history"],
        )
        # Bind the conversation as a default argument. A plain closure over
        # the loop variable is late-binding: every row's click handler would
        # otherwise return the LAST conversation instead of its own.
        h.click(lambda _, conv=chat_conversation: conv, h, self.chatbot)

with Column(scale=6):
if chatbot:
if self.type:
if self.type != chatbot.type:
warnings.warn(
"The type of the gr.Chatbot does not match the type of the gr.ChatInterface."
f"The type of the gr.ChatInterface, '{self.type}', will be used."
)
chatbot.type = self.type
chatbot._setup_data_model()
else:
warnings.warn(
f"The gr.ChatInterface was not provided with a type, so the type of the gr.Chatbot, '{chatbot.type}', will be used."
)
self.type = chatbot.type
self.chatbot = cast(
Chatbot, get_component_instance(chatbot, render=True)
)
chatbot.type = self.type
chatbot._setup_data_model()
else:
warnings.warn(
f"The gr.ChatInterface was not provided with a type, so the type of the gr.Chatbot, '{chatbot.type}', will be used."
)
self.type = chatbot.type
self.chatbot = cast(
Chatbot, get_component_instance(chatbot, render=True)
)
if self.chatbot.examples and self.examples_messages:
warnings.warn(
"The ChatInterface already has examples set. The examples provided in the chatbot will be ignored."
)
self.chatbot.examples = (
self.examples_messages
if not self._additional_inputs_in_examples
else None
)
self.chatbot._setup_examples()
else:
self.type = self.type or "tuples"
self.chatbot = Chatbot(
label="Chatbot",
scale=1,
height=200 if fill_height else None,
type=self.type,
autoscroll=autoscroll,
examples=self.examples_messages
if not self._additional_inputs_in_examples
else None,
)
with Group():
with Row():
if textbox:
textbox.show_label = False
textbox_ = get_component_instance(textbox, render=True)
if not isinstance(textbox_, (Textbox, MultimodalTextbox)):
raise TypeError(
f"Expected a gr.Textbox or gr.MultimodalTextbox component, but got {builtins.type(textbox_)}"
if self.chatbot.examples and self.examples_messages:
warnings.warn(
"The ChatInterface already has examples set. The examples provided in the chatbot will be ignored."
)
self.textbox = textbox_
self.chatbot.examples = (
self.examples_messages
if not self._additional_inputs_in_examples
else None
)
self.chatbot._setup_examples()
else:
textbox_component = (
MultimodalTextbox if self.multimodal else Textbox
self.type = self.type or "tuples"
self.chatbot = Chatbot(
label="Chatbot",
scale=1,
height=400 if fill_height else None,
type=self.type,
autoscroll=autoscroll,
examples=self.examples_messages
if not self._additional_inputs_in_examples
else None,
)
self.textbox = textbox_component(
show_label=False,
label="Message",
placeholder="Type a message...",
scale=7,
autofocus=autofocus,
submit_btn=submit_btn,
stop_btn=stop_btn,
with Group():
with Row():
if textbox:
textbox.show_label = False
textbox_ = get_component_instance(
textbox, render=True
)
if not isinstance(
textbox_, (Textbox, MultimodalTextbox)
):
raise TypeError(
f"Expected a gr.Textbox or gr.MultimodalTextbox component, but got {builtins.type(textbox_)}"
)
self.textbox = textbox_
else:
textbox_component = (
MultimodalTextbox
if self.multimodal
else Textbox
)
self.textbox = textbox_component(
show_label=False,
label="Message",
placeholder="Type a message...",
scale=7,
autofocus=autofocus,
submit_btn=submit_btn,
stop_btn=stop_btn,
)

# Hide the stop button at the beginning, and show it with the given value during the generator execution.
self.original_stop_btn = self.textbox.stop_btn
self.textbox.stop_btn = False

self.fake_api_btn = Button("Fake API", visible=False)
self.fake_response_textbox = Textbox(
label="Response", visible=False
) # Used to store the response from the API call

if self.examples:
self.examples_handler = Examples(
examples=self.examples,
inputs=[self.textbox] + self.additional_inputs,
outputs=self.chatbot,
fn=self._examples_stream_fn
if self.is_generator
else self._examples_fn,
cache_examples=self.cache_examples,
cache_mode=self.cache_mode,
visible=self._additional_inputs_in_examples,
preprocess=self._additional_inputs_in_examples,
)

any_unrendered_inputs = any(
not inp.is_rendered for inp in self.additional_inputs
)
if self.additional_inputs and any_unrendered_inputs:
with Accordion(**self.additional_inputs_accordion_params): # type: ignore
for input_component in self.additional_inputs:
if not input_component.is_rendered:
input_component.render()
# Hide the stop button at the beginning, and show it with the given value during the generator execution.
self.original_stop_btn = self.textbox.stop_btn
self.textbox.stop_btn = False
Expand Down Expand Up @@ -322,6 +412,8 @@ def __init__(
self.chatbot_state = (
State(self.chatbot.value) if self.chatbot.value else State([])
)
if self.chatbot.value:
self.saved_history = BrowserState(self.chatbot.value)
self.show_progress = show_progress
self._setup_events()

Expand Down Expand Up @@ -368,10 +460,21 @@ def _setup_events(self) -> None:
if hasattr(self.fn, "zerogpu"):
submit_fn.__func__.zerogpu = self.fn.zerogpu # type: ignore

def test(chatbot, chatbot_state, saved_history=None):
    """Mirror the chatbot value into session state and, when history saving
    is wired up, into the browser-stored history.

    Returns either `chatbot` alone (no saved_history input) or
    `(chatbot, updated_saved_history)` so the output arity always matches
    the event's output components.
    """
    if saved_history is None:
        # self.saved_history was never created, so it is not among the
        # event inputs/outputs: just sync the state.
        return chatbot
    if not self.save_history:
        # saved_history exists (e.g. seeded from chatbot.value) but saving
        # is disabled: pass it through unchanged so return arity still
        # matches the two output components.
        return chatbot, saved_history
    if isinstance(chatbot_state, list) and len(chatbot_state) == 0:
        # First exchange of a new conversation: append it as a new entry.
        return chatbot, saved_history + [chatbot]
    if saved_history:
        # Ongoing conversation: replace the most recent saved entry.
        saved_history[-1] = chatbot
    else:
        # Defensive: state is non-empty but no saved entry exists yet
        # (pre-seeded chatbot value) — append rather than raise IndexError
        # on saved_history[-1].
        saved_history.append(chatbot)
    return chatbot, saved_history

synchronize_chat_state_kwargs = {
"fn": lambda x: x,
"inputs": [self.chatbot],
"outputs": [self.chatbot_state],
"fn": test,
"inputs": [self.chatbot, self.chatbot_state] + ([self.saved_history] if hasattr(self, 'saved_history') else []),
"outputs": [self.chatbot_state] + ([self.saved_history] if hasattr(self, 'saved_history') else []),
"show_api": False,
"queue": False,
}
Expand Down
Loading