Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

StreamFetcher POC #6459

Open
wants to merge 12 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions cli/openbb_cli/controllers/base_platform_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from typing import Dict, List, Optional

import pandas as pd
from anyio.from_thread import start_blocking_portal
from fastapi.responses import StreamingResponse
from openbb import obb
from openbb_charting.core.openbb_figure import OpenBBFigure
from openbb_cli.argparse_translator.argparse_class_processor import (
Expand Down Expand Up @@ -201,6 +203,23 @@ def method(self, other_args: List[str], translator=translator):
df = pd.DataFrame.from_dict(obbject, orient="columns")
print_rich_table(df=df, show_index=True, title=title)

elif isinstance(obbject, StreamingResponse):
received_data = []

async def stream_data(obbject):
async for data in obbject.body_iterator:
received_data.append(data)
session.console.print(data)

with start_blocking_portal() as portal:
try:
portal.start_task_soon(stream_data, obbject)
finally:
portal.call(portal.stop)

df = pd.DataFrame(received_data)
print_rich_table(df=df, show_index=True, title=title)

elif not isinstance(obbject, OBBject):
session.console.print(obbject)

Expand Down
17 changes: 13 additions & 4 deletions openbb_platform/core/openbb_core/api/router/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import inspect
from functools import partial, wraps
from inspect import Parameter, Signature, signature
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union

from fastapi import APIRouter, Depends, Header
from fastapi.responses import StreamingResponse
from fastapi.routing import APIRoute
from openbb_core.app.command_runner import CommandRunner
from openbb_core.app.model.command_context import CommandContext
Expand Down Expand Up @@ -188,17 +189,25 @@ def build_api_wrapper(
func.__annotations__ = new_annotations_map

@wraps(wrapped=func)
async def wrapper(*args: Tuple[Any], **kwargs: Dict[str, Any]) -> OBBject:
async def wrapper(
*args: Tuple[Any], **kwargs: Dict[str, Any]
) -> Union[OBBject, StreamingResponse]:
user_settings: UserSettings = UserSettings.model_validate(
kwargs.pop(
"__authenticated_user_settings",
UserService.read_default_user_settings(),
)
)
execute = partial(command_runner.run, path, user_settings)
output: OBBject = await execute(*args, **kwargs)
output = await execute(*args, **kwargs)

return validate_output(output)
if route.openapi_extra.get("is_stream", False):
return output.results

if isinstance(output, OBBject):
return validate_output(output)

return output

return wrapper

Expand Down
32 changes: 17 additions & 15 deletions openbb_platform/core/openbb_core/app/command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
from inspect import Parameter, signature
from sys import exc_info
from time import perf_counter_ns
from typing import Any, Callable, Dict, List, Optional, Tuple, Type
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from warnings import catch_warnings, showwarning, warn

from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict, create_model

from openbb_core.app.logs.logging_service import LoggingService
Expand Down Expand Up @@ -420,7 +421,7 @@ async def run(
/,
*args,
**kwargs,
) -> OBBject:
) -> Union[OBBject, StreamingResponse]:
"""Run a command and return the OBBject as output."""
timestamp = datetime.now()
start_ns = perf_counter_ns()
Expand All @@ -429,7 +430,7 @@ async def run(
route = execution_context.route

if func := command_map.get_command(route=route):
obbject = await cls._execute_func(
result = await cls._execute_func(
route=route,
args=args, # type: ignore
execution_context=execution_context,
Expand All @@ -442,19 +443,20 @@ async def run(
duration = perf_counter_ns() - start_ns

if execution_context.user_settings.preferences.metadata:
try:
obbject.extra["metadata"] = Metadata(
arguments=kwargs,
duration=duration,
route=route,
timestamp=timestamp,
)
except Exception as e:
if Env().DEBUG_MODE:
raise OpenBBError(e) from e
warn(str(e), OpenBBWarning)
if isinstance(result, OBBject):
try:
result.extra["metadata"] = Metadata(
arguments=kwargs,
duration=duration,
route=route,
timestamp=timestamp,
)
except Exception as e:
if Env().DEBUG_MODE:
raise OpenBBError(e) from e
warn(str(e), OpenBBWarning)

return obbject
return result


class CommandRunner:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,11 @@ def _add_posthog_handler(self):
def _add_stdout_handler(self):
"""Add a stdout handler."""
handler = logging.StreamHandler(sys.stdout)
formatter = FormatterWithExceptions(settings=self._settings)
formatter = logging.Formatter(
"%(asctime)s - %(levelname)s - %(name)s - %(message)s"
)
handler.setFormatter(formatter)
handler.setLevel(logging.DEBUG)
logging.getLogger().addHandler(handler)

def _add_stderr_handler(self):
Expand All @@ -68,6 +71,7 @@ def _add_file_handler(self):
handler = PathTrackingFileHandler(settings=self._settings)
formatter = FormatterWithExceptions(settings=self._settings)
handler.setFormatter(formatter)
handler.setLevel(logging.INFO)
logging.getLogger().addHandler(handler)

def update_handlers(self, settings: LoggingSettings):
Expand Down
25 changes: 24 additions & 1 deletion openbb_platform/core/openbb_core/app/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)

from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from pydantic.v1.validators import find_validators
from typing_extensions import Annotated, ParamSpec, _AnnotatedAlias
Expand Down Expand Up @@ -230,6 +231,26 @@ def __init__(
self._description = description
self._routers: Dict[str, Router] = {}

@overload
def stream(self, func: Callable[P, OBBject]) -> Callable[P, StreamingResponse]:
pass

@overload
def stream(self, **kwargs) -> Callable:
pass

def stream(
self,
func: Optional[Callable[P, OBBject]] = None,
**kwargs,
) -> Optional[Callable]:
"""Stream decorator for routes."""
if func is None:
return lambda f: self.stream(f, **kwargs)

kwargs["is_stream"] = True
return self.command(func, **kwargs)

@overload
def command(self, func: Optional[Callable[P, OBBject]]) -> Callable[P, OBBject]:
pass
Expand Down Expand Up @@ -260,6 +281,8 @@ def command(
examples=kwargs.pop("examples", []),
providers=ProviderInterface().available_providers,
)
kwargs["openapi_extra"]["is_stream"] = kwargs.pop("is_stream", False)

kwargs["operation_id"] = kwargs.get(
"operation_id", SignatureInspector.get_operation_id(func)
)
Expand Down Expand Up @@ -349,7 +372,7 @@ class SignatureInspector:

@classmethod
def complete(
cls, func: Callable[P, OBBject], model: str
cls, func: Callable[P, OBBject], model: str, is_stream: bool = False
) -> Optional[Callable[P, OBBject]]:
"""Complete function signature."""
if isclass(return_type := func.__annotations__["return"]) and not issubclass(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ def get_path_hint_type_list(cls, path: str) -> List[Type]:
if route:
if route.deprecated:
hint_type_list.append(type(route.summary.metadata))

function_hint_type_list = cls.get_function_hint_type_list(func=route.endpoint) # type: ignore
hint_type_list.extend(function_hint_type_list)

Expand Down Expand Up @@ -1479,7 +1480,6 @@ def _get_provider_field_params(
expanded_types[field], is_required, "website"
)
field_type = f"Union[{field_type}, {expanded_type}]"

cleaned_description = (
str(field_info.description)
.strip().replace("\n", " ").replace(" ", " ").replace('"', "'")
Expand All @@ -1506,7 +1506,6 @@ def _get_provider_field_params(
# Manually setting to List[<field_type>] for multiple items
# Should be removed if TYPE_EXPANSION is updated to include this
field_type = f"Union[{field_type}, List[{field_type}]]"

default_value = "" if field_info.default is PydanticUndefined else field_info.default # fmt: skip

provider_field_params.append(
Expand Down
33 changes: 32 additions & 1 deletion openbb_platform/core/openbb_core/provider/abstract/fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing import (
Any,
AsyncIterator,
Dict,
Generic,
Optional,
Expand Down Expand Up @@ -53,14 +54,19 @@ def transform_query(params: Dict[str, Any]) -> Q:
async def aextract_data(query: Q, credentials: Optional[Dict[str, str]]) -> Any:
"""Asynchronously extract the data from the provider."""

@staticmethod
async def atransform_data(
query: Q, data: Any, **kwargs
) -> Union[R, AnnotatedResult[R]]:
"""Asynchronously transform the provider-specific data."""

@staticmethod
def extract_data(query: Q, credentials: Optional[Dict[str, str]]) -> Any:
"""Extract the data from the provider."""

@staticmethod
def transform_data(query: Q, data: Any, **kwargs) -> Union[R, AnnotatedResult[R]]:
"""Transform the provider-specific data."""
raise NotImplementedError

def __init_subclass__(cls, *args, **kwargs):
"""Initialize the subclass."""
Expand All @@ -75,6 +81,15 @@ def __init_subclass__(cls, *args, **kwargs):
" default."
)

if cls.atransform_data != Fetcher.atransform_data:
cls.transform_data = cls.atransform_data
elif cls.transform_data == Fetcher.transform_data:
raise NotImplementedError(
"Fetcher subclass must implement either transform_data or atransform_data"
" method. If both are implemented, atransform_data will be used as the"
" default."
)

@classmethod
async def fetch_data(
cls,
Expand All @@ -89,6 +104,22 @@ async def fetch_data(
)
return cls.transform_data(query=query, data=data, **kwargs)

@classmethod
async def stream_data(
cls,
params: Dict[str, Any],
credentials: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[AsyncIterator[R], AsyncIterator[AnnotatedResult[R]]]:
"""Fetch data from a provider."""
query = cls.transform_query(params=params)
data = await maybe_coroutine(
cls.aextract_data, query=query, credentials=credentials, **kwargs
)
transformed_data = cls.atransform_data(query=query, data=data, **kwargs)
async for d in transformed_data:
yield d

@classproperty
def query_params_type(self) -> Q:
"""Get the type of query."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=W0613:unused-argument
"""Crypto Price Router."""

from fastapi.responses import StreamingResponse
from openbb_core.app.model.command_context import CommandContext
from openbb_core.app.model.example import APIEx
from openbb_core.app.model.obbject import OBBject
Expand All @@ -11,6 +12,9 @@
)
from openbb_core.app.query import Query
from openbb_core.app.router import Router
from providers.binance.openbb_binance.models.crypto_historical import (
BinanceCryptoHistoricalFetcher,
)

router = Router(prefix="/price")

Expand Down Expand Up @@ -56,3 +60,16 @@ async def historical(
) -> OBBject:
"""Get historical price data for cryptocurrency pair(s) within a provider."""
return await OBBject.from_query(Query(**locals()))


@router.stream(methods=["GET"])
async def live(symbol: str = "ethbtc", lifetime: int = 10, tld: str = "us") -> OBBject:
"""Connect to Binance WebSocket Crypto Price data feed."""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This endpoint will not generate any documentation or descriptions because it is not using the ProviderInterface.

generator = BinanceCryptoHistoricalFetcher().stream_data(
params={"symbol": symbol, "lifetime": lifetime, "tld": tld},
credentials=None,
)
return OBBject(
results=StreamingResponse(generator, media_type="application/x-ndjson"),
provider="binance",
)
14 changes: 14 additions & 0 deletions openbb_platform/providers/binance/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# OpenBB Binance Provider

This extension integrates the Binance data provider
into the OpenBB Platform.

## Installation

To install the extension, run the following command in this folder:

```bash
pip install openbb-binance
```

Documentation available [here](https://docs.openbb.co/platform/development/contributing).
1 change: 1 addition & 0 deletions openbb_platform/providers/binance/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Binance provider."""
22 changes: 22 additions & 0 deletions openbb_platform/providers/binance/openbb_binance/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Binance provider module."""

from openbb_core.provider.abstract.provider import Provider

from providers.binance.openbb_binance.models.crypto_historical import (
BinanceCryptoHistoricalFetcher,
)

binance_provider = Provider(
name="binance",
website="https://api.binance.com",
description="""Binance is a cryptocurrency exchange that provides a platform for trading various cryptocurrencies.

The Binance API features both REST and WebSocket endpoints for accessing historical and real-time data.
""",
# credentials=["api_key"],
fetcher_dict={
"CryptoLive": BinanceCryptoHistoricalFetcher,
},
repr_name="Binance",
instructions="",
)
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Binance Provider models."""
Loading
Loading