Skip to content

Commit

Permalink
Merge branch 'local-changes' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
DoKu88 authored Dec 16, 2024
2 parents c309626 + 4eba13c commit 009dfa4
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 34 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies = [
"pydantic",
"requests",
"asyncio",
"zyk==0.2.21",
"zyk>=0.2.24",
"build>=1.2.2.post1",
"pypi",
"twine>=4.0.0",
Expand All @@ -23,10 +23,10 @@ dependencies = [
"pydantic-openapi-schema>=1.5.1",
"pytest-asyncio>=0.24.0",
"apropos-ai>=0.4.5",
"craftaxlm>=0.0.5",
"boto3>=1.35.71",
"botocore>=1.35.71",
"tqdm>=4.66.4"
"craftaxlm>=0.0.7",
]
classifiers = []

Expand Down
6 changes: 6 additions & 0 deletions synth_sdk/tracing/abstractions.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ class AgentComputeStep(ComputeStep):
compute_input: List[Union[MessageInputs, ArbitraryInputs]]
compute_output: List[Union[MessageOutputs, ArbitraryOutputs]]

def to_dict(self):
    """Serialize this compute step, extending the parent serialization
    with the agent's model name."""
    serialized = super().to_dict()
    serialized["model_name"] = self.model_name
    return serialized


@dataclass
class EnvironmentComputeStep(ComputeStep):
Expand Down Expand Up @@ -128,6 +133,7 @@ def to_dict(self):
"system_id": self.system_id,
"partition": [element.to_dict() for element in self.partition],
"current_partition_index": self.current_partition_index,
"metadata": self.metadata if self.metadata else None
}


Expand Down
1 change: 1 addition & 0 deletions synth_sdk/tracing/events/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def _get_or_create():
#logger.debug(f"Creating new system trace for {system_id}")
self._traces[system_id] = SystemTrace(
system_id=system_id,
metadata={},
partition=[EventPartitionElement(partition_index=0, events=[])],
current_partition_index=0,
)
Expand Down
88 changes: 56 additions & 32 deletions synth_sdk/tracing/upload.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
from typing import List, Dict, Any, Union, Tuple, Coroutine
from pydantic import BaseModel, validator
import synth_sdk.config.settings
import requests
import asyncio
import json
import logging
import os
import time
from synth_sdk.tracing.events.store import event_store
from synth_sdk.tracing.abstractions import Dataset, SystemTrace
import json
from pprint import pprint
import asyncio
import sys
Expand All @@ -16,21 +11,31 @@
import boto3
from datetime import datetime

from typing import Any, Dict, List

import requests
from pydantic import BaseModel, validator

from synth_sdk.tracing.abstractions import Dataset, SystemTrace
from synth_sdk.tracing.events.store import event_store


# NOTE: This may cause memory issues in the future
def validate_json(data: dict) -> None:
    """Validate that a dictionary contains only JSON-serializable values.

    Args:
        data: Dictionary to validate for JSON serialization.

    Raises:
        ValueError: If the dictionary contains non-serializable values.
    """
    try:
        # json.dumps raises TypeError for unsupported objects and
        # OverflowError for out-of-range numbers (e.g. huge floats).
        json.dumps(data)
    except (TypeError, OverflowError) as e:
        raise ValueError(f"Contains non-JSON-serializable values: {e}. {data}")


def createPayload(dataset: Dataset, traces: List[SystemTrace]) -> Dict[str, Any]:
payload = {
"traces": [
Expand Down Expand Up @@ -176,6 +181,11 @@ def validate_traces(cls, traces):
if "partition" not in trace:
raise ValueError("Each trace must have a partition")

# Validate metadata if present
if "metadata" in trace and trace["metadata"] is not None:
if not isinstance(trace["metadata"], dict):
raise ValueError("Metadata must be a dictionary")

# Validate partition structure
partition = trace["partition"]
if not isinstance(partition, list):
Expand Down Expand Up @@ -242,8 +252,8 @@ def validate_dataset(cls, dataset):


def validate_upload(traces: List[Dict[str, Any]], dataset: Dict[str, Any]):
#Validate the upload format before sending to server.
#Raises ValueError if validation fails.
# Validate the upload format before sending to server.
# Raises ValueError if validation fails.
try:
UploadValidator(traces=traces, dataset=dataset)
return True
Expand All @@ -259,55 +269,69 @@ def is_event_loop_running():
# This exception is raised if no event loop is running
return False


def format_upload_output(dataset, traces):
    """Format a dataset and its traces into the JSON-ready upload arrays.

    Args:
        dataset: Object with `questions` and `reward_signals` sequences.
        traces: Sequence of trace objects, each with `system_id`,
            `metadata`, and a `partition` of event containers.

    Returns:
        Tuple of (questions_data, reward_signals_data, traces_data),
        each a list of plain dicts suitable for JSON serialization.
    """
    # Format questions array
    questions_data = [
        {"intent": q.intent, "criteria": q.criteria, "question_id": q.question_id}
        for q in dataset.questions
    ]

    # Format reward signals array; `annotation` is optional on the
    # reward-signal objects, so fall back to None when absent.
    reward_signals_data = [
        {
            "system_id": rs.system_id,
            "reward": rs.reward,
            "question_id": rs.question_id,
            "annotation": rs.annotation if hasattr(rs, "annotation") else None,
        }
        for rs in dataset.reward_signals
    ]

    # Format traces array; empty/falsy metadata is normalized to None.
    traces_data = [
        {
            "system_id": t.system_id,
            "metadata": t.metadata if t.metadata else None,
            "partition": [
                {
                    "partition_index": p.partition_index,
                    "events": [e.to_dict() for e in p.events],
                }
                for p in t.partition
            ],
        }
        for t in traces
    ]

    return questions_data, reward_signals_data, traces_data


# Supports calls from both async and sync contexts
def upload(
    dataset: Dataset,
    traces: List[SystemTrace] = [],  # noqa: B006 — kept for interface compat; copied below
    verbose: bool = False,
    show_payload: bool = False,
):
    """Upload all system traces and dataset to the server.

    Supports calls from both async and sync contexts.

    Args:
        dataset: The Dataset of questions and reward signals to upload.
        traces: Optional extra SystemTraces to upload alongside the
            traces already recorded in the event store.
        verbose: If True, print progress/diagnostic information.
        show_payload: If True, display the payload being sent.

    Returns:
        A tuple of (response, questions_json, reward_signals_json, traces_json):
        response is the response from the server;
        questions_json is the formatted questions array;
        reward_signals_json is the formatted reward signals array;
        traces_json is the formatted traces array.

    Note that you can directly upload questions, reward_signals, and traces
    to the server using the Website.
    """
    # Copy defensively so the shared mutable default list can never be
    # observed (or mutated) across calls.
    return upload_helper(dataset, list(traces), verbose, show_payload)

def upload_helper(dataset: Dataset, traces: List[SystemTrace]=[], verbose: bool = False, show_payload: bool = False):

def upload_helper(
dataset: Dataset,
traces: List[SystemTrace] = [],
verbose: bool = False,
show_payload: bool = False,
):
api_key = os.getenv("SYNTH_API_KEY")
if not api_key:
raise ValueError("SYNTH_API_KEY environment variable not set")
Expand All @@ -330,8 +354,8 @@ def upload_helper(dataset: Dataset, traces: List[SystemTrace]=[], verbose: bool

# Also close any unclosed events in existing traces
logged_traces = event_store.get_system_traces()
traces = logged_traces+ traces
#traces = event_store.get_system_traces() if len(traces) == 0 else traces
traces = logged_traces + traces
# traces = event_store.get_system_traces() if len(traces) == 0 else traces
current_time = time.time()
for trace in traces:
for partition in trace.partition:
Expand Down Expand Up @@ -378,7 +402,7 @@ def upload_helper(dataset: Dataset, traces: List[SystemTrace]=[], verbose: bool

questions_json, reward_signals_json, traces_json = format_upload_output(dataset, traces)
return response, questions_json, reward_signals_json, traces_json

except ValueError as e:
if verbose:
print("Validation error:", str(e))
Expand Down

0 comments on commit 009dfa4

Please sign in to comment.