Skip to content

Commit

Permalink
Merge pull request #351 from XpressAI/paul/async
Browse files Browse the repository at this point in the history
Add asyncio support with AsyncComponent
  • Loading branch information
MFA-X-AI authored Dec 18, 2024
2 parents d19e0ff + 5ba551c commit 5ed979a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 8 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ dependencies = [
"pygithub",
"tqdm",
"toml",
"importlib_resources"
"importlib_resources",
"asgiref"
]
dynamic = ["version", "description", "authors", "urls", "keywords"]

Expand Down
46 changes: 39 additions & 7 deletions xai_components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
from typing import TypeVar, Generic, Tuple, NamedTuple, Callable, List
from copy import deepcopy

from asgiref.sync import async_to_sync, sync_to_async

T = TypeVar('T')


class OutArg(Generic[T]):
def __init__(self, value: T = None, getter: Callable[[T], any] = lambda x: x) -> None:
self._value = value
Expand Down Expand Up @@ -67,6 +70,7 @@ def __deepcopy__(self, memo):
memo[id_self] = _copy
return _copy


class InCompArg(Generic[T]):
def __init__(self, value: T = None, getter: Callable[[T], any] = lambda x: x) -> None:
self._value = value
Expand Down Expand Up @@ -98,6 +102,7 @@ def __deepcopy__(self, memo):
memo[id_self] = _copy
return _copy


def xai_component(*args, **kwargs):
# Passthrough element without any changes.
# This is used for parser metadata only.
Expand All @@ -108,14 +113,17 @@ def xai_component(*args, **kwargs):
# @xai_components(...) form
def passthrough(f):
return f

return passthrough


class ExecutionContext:
args: Namespace

def __init__(self, args: Namespace):
self.args = args


class BaseComponent:
def __init__(self):
all_ports = self.__annotations__
Expand Down Expand Up @@ -165,11 +173,12 @@ def __deepcopy__(self, memo):
setattr(_copy, key, deepcopy(getattr(self, key), memo))
return _copy


class Component(BaseComponent):
next: BaseComponent

def do(self, ctx) -> BaseComponent:
print(f"\nExecuting: {self.__class__.__name__}")
print(f"\nExecuting: {self.__class__.__name__}", flush=True)
self.execute(ctx)

return self.next
Expand All @@ -178,18 +187,35 @@ def debug_repr(self) -> str:
return "<h1>Component</h1>"


class AsyncComponent(BaseComponent):
next: BaseComponent

@async_to_sync
async def do(self, ctx) -> BaseComponent:
print(f"\nExecuting: {self.__class__.__name__}", flush=True)
await self.execute(ctx)
return self.next

def debug_repr(self) -> str:
return "<h1>AsyncComponent</h1>"


class SubGraphExecutor:

def __init__(self, component):
self.comp = component

def do(self, ctx):
comp = self.comp

while comp is not None:
comp = comp.do(ctx)
return None

@sync_to_async
def do_async(self, ctx):
return self.do(ctx)


def execute_graph(args: Namespace, start: BaseComponent, ctx) -> None:
BaseComponent.set_execution_context(ExecutionContext(args))
Expand All @@ -207,18 +233,21 @@ def execute_graph(args: Namespace, start: BaseComponent, ctx) -> None:
next_component = start.do(ctx)
while next_component:
next_component = next_component.do(ctx)


class secret:
pass


class message(NamedTuple):
role: str
content: str


class chat(NamedTuple):
messages: List[message]



class dynalist(list):
def __init__(self, *args):
super().__init__(args)
Expand All @@ -229,15 +258,18 @@ def getter(x):
return []
return [item.value if isinstance(item, (InArg, OutArg)) else item for item in x]


class dynatuple(tuple):
def __init__(self, *args):
super().__init__(args)

@staticmethod
def getter(x):
if x is None:
return tuple()

def resolve(item):
if isinstance(item, (InArg, InCompArg,OutArg)):
if isinstance(item, (InArg, InCompArg, OutArg)):
return item.value
else:
return item
Expand Down

0 comments on commit 5ed979a

Please sign in to comment.