Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshuaPurtell committed Dec 16, 2024
1 parent fce310b commit 4eba13c
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 42 deletions.
7 changes: 3 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "synth-sdk"
version = "0.2.84"
version = "0.2.93"
description = ""
authors = [{name = "Synth AI", email = "[email protected]"}]
license = {text = "MIT"}
Expand All @@ -12,7 +12,7 @@ dependencies = [
"pydantic",
"requests",
"asyncio",
"zyk==0.2.21",
"zyk>=0.2.24",
"build>=1.2.2.post1",
"pypi",
"twine>=4.0.0",
Expand All @@ -22,8 +22,7 @@ dependencies = [
"pytest>=8.3.3",
"pydantic-openapi-schema>=1.5.1",
"pytest-asyncio>=0.24.0",
"apropos-ai>=0.4.5",
"craftaxlm>=0.0.5",
"craftaxlm>=0.0.7",
]
classifiers = []

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="synth-sdk",
version="0.2.84",
version="0.2.93",
packages=find_packages(),
install_requires=[
"opentelemetry-api",
Expand Down
6 changes: 6 additions & 0 deletions synth_sdk/tracing/abstractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ class AgentComputeStep(ComputeStep):
compute_input: List[Union[MessageInputs, ArbitraryInputs]]
compute_output: List[Union[MessageOutputs, ArbitraryOutputs]]

def to_dict(self):
    """Serialize this agent compute step, adding the agent's model name.

    Extends the parent serialization with the ``model_name`` field so the
    uploaded trace records which model produced this step.
    """
    serialized = super().to_dict()  # parent serializes the shared fields
    serialized["model_name"] = self.model_name
    return serialized


@dataclass
class EnvironmentComputeStep(ComputeStep):
Expand Down Expand Up @@ -128,6 +133,7 @@ def to_dict(self):
"system_id": self.system_id,
"partition": [element.to_dict() for element in self.partition],
"current_partition_index": self.current_partition_index,
"metadata": self.metadata if self.metadata else None
}


Expand Down
1 change: 1 addition & 0 deletions synth_sdk/tracing/events/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def _get_or_create():
#logger.debug(f"Creating new system trace for {system_id}")
self._traces[system_id] = SystemTrace(
system_id=system_id,
metadata={},
partition=[EventPartitionElement(partition_index=0, events=[])],
current_partition_index=0,
)
Expand Down
101 changes: 64 additions & 37 deletions synth_sdk/tracing/upload.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,33 @@
from typing import List, Dict, Any, Union, Tuple, Coroutine
from pydantic import BaseModel, validator
import synth_sdk.config.settings
import requests
import asyncio
import json
import logging
import os
import time
from synth_sdk.tracing.events.store import event_store
from synth_sdk.tracing.abstractions import Dataset, SystemTrace
import json
from pprint import pprint
import asyncio
from typing import Any, Dict, List

import requests
from pydantic import BaseModel, validator

from synth_sdk.tracing.abstractions import Dataset, SystemTrace
from synth_sdk.tracing.events.store import event_store


def validate_json(data: dict) -> None:
    """Validate that a dictionary contains only JSON-serializable values.

    Args:
        data: Dictionary to validate for JSON serialization.

    Raises:
        ValueError: If the dictionary contains non-serializable values.
    """
    try:
        json.dumps(data)
    except (TypeError, OverflowError) as e:
        # Chain the original error so callers can inspect the root cause.
        raise ValueError(f"Contains non-JSON-serializable values: {e}. {data}") from e


def createPayload(dataset: Dataset, traces: List[SystemTrace]) -> Dict[str, Any]:
payload = {
"traces": [
Expand All @@ -35,8 +37,12 @@ def createPayload(dataset: Dataset, traces: List[SystemTrace]) -> Dict[str, Any]
}
return payload


def send_system_traces(
dataset: Dataset, traces: List[SystemTrace], base_url: str, api_key: str,
dataset: Dataset,
traces: List[SystemTrace],
base_url: str,
api_key: str,
):
# Send all system traces and dataset metadata to the server.
# Get the token using the API key
Expand All @@ -50,7 +56,7 @@ def send_system_traces(
# Send the traces with the token
api_url = f"{base_url}/v1/uploads/"

payload = createPayload(dataset, traces) # Create the payload
payload = createPayload(dataset, traces) # Create the payload

validate_json(payload) # Validate the entire payload

Expand Down Expand Up @@ -91,6 +97,11 @@ def validate_traces(cls, traces):
if "partition" not in trace:
raise ValueError("Each trace must have a partition")

# Validate metadata if present
if "metadata" in trace and trace["metadata"] is not None:
if not isinstance(trace["metadata"], dict):
raise ValueError("Metadata must be a dictionary")

# Validate partition structure
partition = trace["partition"]
if not isinstance(partition, list):
Expand Down Expand Up @@ -157,8 +168,8 @@ def validate_dataset(cls, dataset):


def validate_upload(traces: List[Dict[str, Any]], dataset: Dict[str, Any]):
#Validate the upload format before sending to server.
#Raises ValueError if validation fails.
# Validate the upload format before sending to server.
# Raises ValueError if validation fails.
try:
UploadValidator(traces=traces, dataset=dataset)
return True
Expand All @@ -174,55 +185,69 @@ def is_event_loop_running():
# This exception is raised if no event loop is running
return False


def format_upload_output(dataset, traces):
    """Format a dataset and traces into the JSON-ready upload arrays.

    Args:
        dataset: Object with ``questions`` and ``reward_signals`` sequences.
        traces: List of trace objects, each with ``system_id``, ``metadata``,
            and a ``partition`` of elements carrying ``partition_index`` and
            ``events`` (each event must provide ``to_dict()``).

    Returns:
        Tuple of (questions_data, reward_signals_data, traces_data), each a
        list of plain dicts ready for JSON serialization.
    """
    # Format questions array
    questions_data = [
        {"intent": q.intent, "criteria": q.criteria, "question_id": q.question_id}
        for q in dataset.questions
    ]

    # Format reward signals array; annotation is optional on reward signals,
    # so fall back to None when the attribute is absent.
    reward_signals_data = [
        {
            "system_id": rs.system_id,
            "reward": rs.reward,
            "question_id": rs.question_id,
            "annotation": rs.annotation if hasattr(rs, "annotation") else None,
        }
        for rs in dataset.reward_signals
    ]

    # Format traces array; empty/falsy metadata is normalized to None.
    traces_data = [
        {
            "system_id": t.system_id,
            "metadata": t.metadata if t.metadata else None,
            "partition": [
                {
                    "partition_index": p.partition_index,
                    "events": [e.to_dict() for e in p.events],
                }
                for p in t.partition
            ],
        }
        for t in traces
    ]

    return questions_data, reward_signals_data, traces_data


# Supports calls from both async and sync contexts
# Supports calls from both async and sync contexts
def upload(
    dataset: Dataset,
    traces: List[SystemTrace] = None,  # None sentinel avoids a shared mutable default
    verbose: bool = False,
    show_payload: bool = False,
):
    """Upload all system traces and dataset to the server.

    Args:
        dataset: Dataset of questions and reward signals to upload.
        traces: Optional list of SystemTrace objects to include; defaults to
            an empty list (traces recorded in the event store are added by
            upload_helper regardless).
        verbose: If True, print diagnostic output during the upload.
        show_payload: If True, print the payload sent to the server.

    Returns:
        A tuple of (response, questions_json, reward_signals_json, traces_json).
        Note that you can directly upload questions, reward_signals, and traces
        to the server using the Website.
        response is the response from the server.
        questions_json is the formatted questions array.
        reward_signals_json is the formatted reward signals array.
        traces_json is the formatted traces array.
    """
    return upload_helper(dataset, [] if traces is None else traces, verbose, show_payload)

def upload_helper(dataset: Dataset, traces: List[SystemTrace]=[], verbose: bool = False, show_payload: bool = False):

def upload_helper(
dataset: Dataset,
traces: List[SystemTrace] = [],
verbose: bool = False,
show_payload: bool = False,
):
api_key = os.getenv("SYNTH_API_KEY")
if not api_key:
raise ValueError("SYNTH_API_KEY environment variable not set")
Expand All @@ -245,8 +270,8 @@ def upload_helper(dataset: Dataset, traces: List[SystemTrace]=[], verbose: bool

# Also close any unclosed events in existing traces
logged_traces = event_store.get_system_traces()
traces = logged_traces+ traces
#traces = event_store.get_system_traces() if len(traces) == 0 else traces
traces = logged_traces + traces
# traces = event_store.get_system_traces() if len(traces) == 0 else traces
current_time = time.time()
for trace in traces:
for partition in trace.partition:
Expand Down Expand Up @@ -291,10 +316,12 @@ def upload_helper(dataset: Dataset, traces: List[SystemTrace]=[], verbose: bool
print("Payload sent to server: ")
pprint(payload)

#return response, payload, dataset, traces
questions_json, reward_signals_json, traces_json = format_upload_output(dataset, traces)
# return response, payload, dataset, traces
questions_json, reward_signals_json, traces_json = format_upload_output(
dataset, traces
)
return response, questions_json, reward_signals_json, traces_json

except ValueError as e:
if verbose:
print("Validation error:", str(e))
Expand Down

0 comments on commit 4eba13c

Please sign in to comment.