Skip to content

Commit

Permalink
Merge pull request #7 from synth-laboratories/simplify
Browse files Browse the repository at this point in the history
SDK Changes: SynthTracker, @trace_system(), upload()
  • Loading branch information
JoshuaPurtell authored Nov 7, 2024
2 parents 3dba64a + 0ebe474 commit 9bb1059
Show file tree
Hide file tree
Showing 7 changed files with 509 additions and 82 deletions.
36 changes: 34 additions & 2 deletions synth_sdk/tracing/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import threading
import time
import logging
import inspect
import contextvars
from pydantic import BaseModel

Expand All @@ -22,7 +21,6 @@
from functools import wraps
import time
import logging
import inspect
from pydantic import BaseModel

from synth_sdk.tracing.abstractions import (
Expand All @@ -35,6 +33,8 @@
from synth_sdk.tracing.trackers import synth_tracker_async
from synth_sdk.tracing.events.manage import set_current_event

import inspect

logger = logging.getLogger(__name__)

# # This decorator is used to trace synchronous functions
Expand Down Expand Up @@ -371,3 +371,35 @@ async def async_wrapper(*args, **kwargs):
return async_wrapper

return decorator

def trace_system(
    origin: Literal["agent", "environment"],
    event_type: str,
    log_result: bool = False,
    manage_event: Literal["create", "end", "lazy_end", None] = None,
    increment_partition: bool = False,
    verbose: bool = False,
) -> Callable:
    """
    Build a tracing decorator that works for both sync and async functions.

    The returned decorator inspects the function it wraps: coroutine
    functions are handed to ``trace_system_async`` and everything else to
    ``trace_system_sync``. All arguments of this factory are forwarded
    positionally, unchanged, to whichever of the two is selected.
    """

    def decorator(func: Callable) -> Callable:
        # Pick the concrete decorator factory once, then apply it in one place.
        if not inspect.iscoroutinefunction(func):
            logger.debug("Using sync tracing")
            factory = trace_system_sync
        else:
            logger.debug("Using async tracing")
            factory = trace_system_async

        concrete_decorator = factory(
            origin, event_type, log_result, manage_event, increment_partition, verbose
        )
        return concrete_decorator(func)

    return decorator

101 changes: 100 additions & 1 deletion synth_sdk/tracing/trackers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Union, Optional, Tuple, Literal
import asyncio
import threading, contextvars
import contextvars
from pydantic import BaseModel
Expand Down Expand Up @@ -162,7 +163,105 @@ def finalize(cls):
trace_outputs_var.set([])
logger.debug("Finalized async trace data")


# Make traces available globally
synth_tracker_sync = SynthTrackerSync
synth_tracker_async = SynthTrackerAsync

# Generalized SynthTracker class: picks the appropriate tracker based on whether an event loop
# is running (i.e. it is called from async code) and whether the corresponding tracker is initialized
class SynthTracker:
    """
    Single entry point in front of ``synth_tracker_async`` and ``synth_tracker_sync``.

    Each tracking call is routed to one of the two underlying trackers. The
    caller may force the choice with ``async_sync``; otherwise the async
    tracker is used when an event loop is running and the async trace is
    initialized, and the sync tracker is used when its thread-local state
    has been initialized.
    """

    @staticmethod
    def is_called_by_async():
        """Return True if an asyncio event loop is running in the current thread."""
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            # get_running_loop() raises when no loop is running: sync context.
            return False
        return True

    # SynthTracker Async & Sync are initialized by the decorators that wrap the
    # respective async & sync functions
    @classmethod
    def initialize(cls):
        """No-op; kept so this class mirrors the interface of the concrete trackers."""
        pass

    @classmethod
    def _select_tracker(cls, async_sync, operation):
        """
        Choose the tracker a ``track_*`` call should be routed to.

        An explicit ``async_sync`` of "async" or "sync" always wins. Any other
        value falls back to auto-detection (async first, then sync), so the
        same event is never added to both trackers.

        Returns:
            ``(tracker, label)`` where ``label`` is "async" or "sync".

        Raises:
            RuntimeError: if no mode was forced and neither tracker is initialized.
        """
        if async_sync == "async":
            return synth_tracker_async, "async"
        if async_sync == "sync":
            return synth_tracker_sync, "sync"
        if cls.is_called_by_async() and trace_initialized_var.get():
            return synth_tracker_async, "async"
        if hasattr(synth_tracker_sync._local, "initialized"):
            return synth_tracker_sync, "sync"
        raise RuntimeError(
            f"{operation}() \n Trace not initialized. Use within a function decorated with @trace_system_async or @trace_system_sync."
        )

    @classmethod
    def track_input(
        cls,
        var: Union[BaseModel, str, dict, int, float, bool, list, None],
        variable_name: str,
        origin: Literal["agent", "environment"],
        annotation: Optional[str] = None,
        async_sync: Literal["async", "sync", ""] = "",  # Force the tracker to be async or sync
    ):
        """
        Record ``var`` as a traced input on the selected tracker.

        Args:
            var: value to record.
            variable_name: name to record the value under.
            origin: "agent" or "environment".
            annotation: optional free-form note forwarded to the tracker.
            async_sync: "async" or "sync" to force a tracker; "" to auto-detect.

        Raises:
            RuntimeError: if auto-detection finds no initialized tracker.
        """
        tracker, mode = cls._select_tracker(async_sync, "track_input")
        logger.debug("Using %s tracker to track input", mode)
        tracker.track_input(var, variable_name, origin, annotation)

    @classmethod
    def track_output(
        cls,
        var: Union[BaseModel, str, dict, int, float, bool, list, None],
        variable_name: str,
        origin: Literal["agent", "environment"],
        annotation: Optional[str] = None,
        async_sync: Literal["async", "sync", ""] = "",  # Force the tracker to be async or sync
    ):
        """
        Record ``var`` as a traced output on the selected tracker.

        Args:
            var: value to record.
            variable_name: name to record the value under.
            origin: "agent" or "environment".
            annotation: optional free-form note forwarded to the tracker.
            async_sync: "async" or "sync" to force a tracker; "" to auto-detect.

        Raises:
            RuntimeError: if auto-detection finds no initialized tracker.
        """
        tracker, mode = cls._select_tracker(async_sync, "track_output")
        logger.debug("Using %s tracker to track output", mode)
        tracker.track_output(var, variable_name, origin, annotation)

    # if both trackers have been used, want to return both sets of data
    @classmethod
    def get_traced_data(
        cls,
        async_sync: Literal["async", "sync", ""] = "",  # Force only async or sync data to be returned
    ) -> Tuple[list, list]:
        """
        Return ``(traced_inputs, traced_outputs)`` collected so far.

        With ``async_sync=""`` the async tracker's data comes first, followed
        by the sync tracker's data; "async" or "sync" restricts the result to
        that tracker only.
        """
        traced_inputs, traced_outputs = [], []

        if async_sync in ("async", ""):
            # Attempt to get the traced data from the async tracker
            logger.debug("Getting traced data from async tracker")
            async_inputs, async_outputs = synth_tracker_async.get_traced_data()
            traced_inputs.extend(async_inputs)
            traced_outputs.extend(async_outputs)

        if async_sync in ("sync", ""):
            # Attempt to get the traced data from the sync tracker
            logger.debug("Getting traced data from sync tracker")
            sync_inputs, sync_outputs = synth_tracker_sync.get_traced_data()
            traced_inputs.extend(sync_inputs)
            traced_outputs.extend(sync_outputs)

        # TODO: Test that the order of the inputs and outputs is correct wrt
        # the order of events since we are combining the two trackers
        return traced_inputs, traced_outputs

    # Finalize both trackers
    @classmethod
    def finalize(
        cls,
        async_sync: Literal["async", "sync", ""] = "",
    ):
        """Finalize the async tracker, the sync tracker, or both (default)."""
        if async_sync in ("async", ""):
            logger.debug("Finalizing async tracker")
            synth_tracker_async.finalize()

        if async_sync in ("sync", ""):
            logger.debug("Finalizing sync tracker")
            synth_tracker_sync.finalize()
46 changes: 37 additions & 9 deletions synth_sdk/tracing/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
import os
import time
from synth_sdk.tracing.events.store import event_store
from synth_sdk.tracing.abstractions import Dataset
from synth_sdk.tracing.abstractions import Dataset, SystemTrace
import json
from pprint import pprint
import asyncio


def validate_json(data: dict) -> None:
Expand All @@ -26,7 +27,7 @@ def validate_json(data: dict) -> None:
except (TypeError, OverflowError) as e:
raise ValueError(f"Contains non-JSON-serializable values: {e}. {data}")

def createPayload(dataset: Dataset, traces: str) -> Dict[str, Any]:
def createPayload(dataset: Dataset, traces: List[SystemTrace]) -> Dict[str, Any]:
payload = {
"traces": [
trace.to_dict() for trace in traces
Expand All @@ -36,8 +37,8 @@ def createPayload(dataset: Dataset, traces: str) -> Dict[str, Any]:
return payload

def send_system_traces(
dataset: Dataset, base_url: str, api_key: str
) -> requests.Response:
dataset: Dataset, traces: List[SystemTrace], base_url: str, api_key: str,
):
"""Send all system traces and dataset metadata to the server."""
# Get the token using the API key
token_url = f"{base_url}/token"
Expand All @@ -46,7 +47,7 @@ def send_system_traces(
)
token_response.raise_for_status()
access_token = token_response.json()["access_token"]
traces = event_store.get_system_traces()

# print("Traces: ", traces)
# Send the traces with the token
api_url = f"{base_url}/upload/"
Expand Down Expand Up @@ -169,7 +170,35 @@ def validate_upload(traces: List[Dict[str, Any]], dataset: Dict[str, Any]):
raise ValueError(f"Upload validation failed: {str(e)}")


async def upload(dataset: Dataset, verbose: bool = False, show_payload: bool = False):
def is_event_loop_running():
    """Report whether an asyncio event loop is currently running in this thread."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # Raised by get_running_loop() when no event loop is running.
        return False
    else:
        return True

# Supports calls from both async and sync contexts
def upload(dataset: Dataset, traces: List[SystemTrace]=[], verbose: bool = False, show_payload: bool = False):
    """
    Upload system traces and dataset metadata from either a sync or an async caller.

    The actual work is done by ``upload_helper``; this function only decides
    how to drive that coroutine.

    Args:
        dataset: dataset metadata forwarded to ``upload_helper``.
        traces: traces forwarded to ``upload_helper``.
        verbose: forwarded to ``upload_helper``.
        show_payload: forwarded to ``upload_helper``.

    Returns:
        * When an event loop is already running: an ``asyncio.Task`` wrapping
          ``upload_helper(...)``. The caller must ``await`` it (or attach a
          done-callback) to obtain the result.
        * Otherwise: the value returned by ``upload_helper(...)``, obtained by
          running it to completion with ``asyncio.run``.
    """
    if is_event_loop_running():
        logging.info("Event loop is already running")
        # A running loop cannot be re-entered: calling run_until_complete() on
        # it raises RuntimeError, so blocking here is impossible. Schedule the
        # upload on the running loop and hand the task back to the caller.
        return asyncio.create_task(upload_helper(dataset, traces, verbose, show_payload))

    logging.info("Event loop is not running")
    return asyncio.run(upload_helper(dataset, traces, verbose, show_payload))

async def upload_helper(dataset: Dataset, traces: List[SystemTrace]=[], verbose: bool = False, show_payload: bool = False):
"""Upload all system traces and dataset to the server."""
api_key = os.getenv("SYNTH_API_KEY")
if not api_key:
Expand All @@ -192,7 +221,7 @@ async def upload(dataset: Dataset, verbose: bool = False, show_payload: bool = F
_local.active_events.clear()

# Also close any unclosed events in existing traces
traces = event_store.get_system_traces()
traces = event_store.get_system_traces() if len(traces) == 0 else traces
current_time = time.time()
for trace in traces:
for partition in trace.partition:
Expand All @@ -205,7 +234,6 @@ async def upload(dataset: Dataset, verbose: bool = False, show_payload: bool = F

try:
# Get traces and convert to dict format
traces = event_store.get_system_traces()
if len(traces) == 0:
raise ValueError("No system traces found")
traces_dict = [trace.to_dict() for trace in traces]
Expand All @@ -221,6 +249,7 @@ async def upload(dataset: Dataset, verbose: bool = False, show_payload: bool = F
# Send to server
response, payload = send_system_traces(
dataset=dataset,
traces=traces,
base_url="https://agent-learning.onrender.com",
api_key=api_key,
)
Expand All @@ -236,7 +265,6 @@ async def upload(dataset: Dataset, verbose: bool = False, show_payload: bool = F
if show_payload:
print("Payload sent to server: ")
pprint(payload)

return response, payload
except ValueError as e:
if verbose:
Expand Down
16 changes: 8 additions & 8 deletions testing/ai_agent_async.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from zyk import LM
from synth_sdk.tracing.decorators import trace_system_async, _local
from synth_sdk.tracing.trackers import SynthTrackerAsync
from synth_sdk.tracing.decorators import trace_system, _local
from synth_sdk.tracing.trackers import SynthTracker
from synth_sdk.tracing.upload import upload
from synth_sdk.tracing.abstractions import TrainingQuestion, RewardSignal, Dataset
from synth_sdk.tracing.events.store import event_store
Expand Down Expand Up @@ -29,7 +29,7 @@ def __init__(self):
)
logger.debug("LM initialized")

@trace_system_async(
@trace_system(
origin="agent",
event_type="lm_call",
manage_event="create",
Expand All @@ -38,32 +38,32 @@ def __init__(self):
)
async def make_lm_call(self, user_message: str) -> str:
# Only pass the user message, not self
SynthTrackerAsync.track_input([user_message], variable_name="user_message", origin="agent")
SynthTracker.track_input([user_message], variable_name="user_message", origin="agent")

logger.debug("Starting LM call with message: %s", user_message)
response = await self.lm.respond_async(
system_message="You are a helpful assistant.", user_message=user_message
)

SynthTrackerAsync.track_output(response, variable_name="response", origin="agent")
SynthTracker.track_output(response, variable_name="response", origin="agent")

logger.debug("LM response received: %s", response)
time.sleep(0.1)
return response

@trace_system_async(
@trace_system(
origin="environment",
event_type="environment_processing",
manage_event="create",
verbose=True,
)
async def process_environment(self, input_data: str) -> dict:
# Only pass the input data, not self
SynthTrackerAsync.track_input([input_data], variable_name="input_data", origin="environment")
SynthTracker.track_input([input_data], variable_name="input_data", origin="environment")

result = {"processed": input_data, "timestamp": time.time()}

SynthTrackerAsync.track_output(result, variable_name="result", origin="environment")
SynthTracker.track_output(result, variable_name="result", origin="environment")
return result


Expand Down
16 changes: 8 additions & 8 deletions testing/ai_agent_sync.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from zyk import LM
from synth_sdk.tracing.decorators import trace_system_sync, _local
from synth_sdk.tracing.trackers import SynthTrackerSync
from synth_sdk.tracing.decorators import trace_system, _local
from synth_sdk.tracing.trackers import SynthTracker
from synth_sdk.tracing.upload import upload
from synth_sdk.tracing.abstractions import TrainingQuestion, RewardSignal, Dataset
from synth_sdk.tracing.events.store import event_store
Expand Down Expand Up @@ -29,7 +29,7 @@ def __init__(self):
)
logger.debug("LM initialized")

@trace_system_sync(
@trace_system(
origin="agent",
event_type="lm_call",
manage_event="create",
Expand All @@ -38,32 +38,32 @@ def __init__(self):
)
def make_lm_call(self, user_message: str) -> str:
# Only pass the user message, not self
SynthTrackerSync.track_input([user_message], variable_name="user_message", origin="agent")
SynthTracker.track_input([user_message], variable_name="user_message", origin="agent")

logger.debug("Starting LM call with message: %s", user_message)
response = self.lm.respond_sync(
system_message="You are a helpful assistant.", user_message=user_message
)

SynthTrackerSync.track_output(response, variable_name="response", origin="agent")
SynthTracker.track_output(response, variable_name="response", origin="agent")

logger.debug("LM response received: %s", response)
time.sleep(0.1)
return response

@trace_system_sync(
@trace_system(
origin="environment",
event_type="environment_processing",
manage_event="create",
verbose=True,
)
def process_environment(self, input_data: str) -> dict:
# Only pass the input data, not self
SynthTrackerSync.track_input([input_data], variable_name="input_data", origin="environment")
SynthTracker.track_input([input_data], variable_name="input_data", origin="environment")

result = {"processed": input_data, "timestamp": time.time()}

SynthTrackerSync.track_output(result, variable_name="result", origin="environment")
SynthTracker.track_output(result, variable_name="result", origin="environment")
return result


Expand Down
Loading

0 comments on commit 9bb1059

Please sign in to comment.