From 5fa4b1bb1343d63219e2e06af40fcebc278d845d Mon Sep 17 00:00:00 2001 From: Abubakar Abid Date: Mon, 23 Dec 2024 16:36:42 -0500 Subject: [PATCH] more type fixes --- demo/agent_chatbot/run.py | 2 +- gradio/external.py | 7 ++++--- gradio/themes/base.py | 4 +++- test/test_utils.py | 2 +- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/demo/agent_chatbot/run.py b/demo/agent_chatbot/run.py index 4e6cf36214536..ad8d7906c4a6b 100644 --- a/demo/agent_chatbot/run.py +++ b/demo/agent_chatbot/run.py @@ -20,7 +20,7 @@ def interact_with_agent(prompt, history): messages = [] yield messages for msg in stream_to_gradio(agent, prompt): - messages.append(asdict(msg)) + messages.append(asdict(msg)) # type: ignore yield messages yield messages diff --git a/gradio/external.py b/gradio/external.py index d410dcdce304b..2e2d0ccf76860 100644 --- a/gradio/external.py +++ b/gradio/external.py @@ -615,7 +615,7 @@ def load_chat( [{"role": "system", "content": system_message}] if system_message else [] ) - def open_api(message: str, history: list | None) -> str: + def open_api(message: str, history: list | None) -> str | None: history = history or start_message if len(history) > 0 and isinstance(history[0], (list, tuple)): history = ChatInterface._tuples_to_messages(history) @@ -641,7 +641,8 @@ def open_api_stream( ) response = "" for chunk in stream: - response += chunk.choices[0].delta.content - yield response + if chunk.choices[0].delta.content is not None: + response += chunk.choices[0].delta.content + yield response return ChatInterface(open_api_stream if streaming else open_api, type="messages") diff --git a/gradio/themes/base.py b/gradio/themes/base.py index d3ca3dc3a3f17..31d571c49d959 100644 --- a/gradio/themes/base.py +++ b/gradio/themes/base.py @@ -124,7 +124,9 @@ def to_dict(self): schema = {"theme": {}} for prop in dir(self): if ( - not prop.startswith("_") or prop.startswith("_font") or prop in ("_stylesheets", "name") + not prop.startswith("_") + or prop.startswith("_font") + or prop in ("_stylesheets", "name") ) and isinstance(getattr(self, prop), (list, str)): schema["theme"][prop] = getattr(self, prop) return schema diff --git a/test/test_utils.py b/test/test_utils.py index 7b28c101a2a8a..7f60554bc28c6 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -304,7 +304,7 @@ class GenericObject: for x in test_objs: hints = get_type_hints(x) assert len(hints) == 1 - assert hints["s"] == str + assert hints["s"] is str assert len(get_type_hints(GenericObject())) == 0