Skip to content

Commit

Permalink
update test
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshuaPurtell committed Nov 15, 2024
1 parent 338882b commit 5405265
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 98 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "synth-sdk"
version = "0.2.73"
version = "0.2.74"
description = ""
authors = [{name = "Synth AI", email = "[email protected]"}]
license = {text = "MIT"}
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="synth-sdk",
version="0.2.73",
version="0.2.74",
packages=find_packages(),
install_requires=[
"opentelemetry-api",
Expand Down
2 changes: 0 additions & 2 deletions synth_sdk/tracing/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,6 @@ async def async_wrapper(*args, **kwargs):
else:
logger.warning(f"Unhandled traced output item: {item}")

print("COMPUTE STEPS BY ORIGIN", compute_steps_by_origin)
# Capture compute end time
compute_ended = time.time()

# Create compute steps grouped by origin
Expand Down
101 changes: 50 additions & 51 deletions testing/traces_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from zyk import LM
from synth_sdk.tracing.decorators import trace_system, _local
from synth_sdk.tracing.trackers import SynthTracker
from synth_sdk.tracing.upload import upload
from synth_sdk.tracing.upload import validate_json
from synth_sdk.tracing.upload import createPayload
from synth_sdk.tracing.upload import upload, validate_json, createPayload
from synth_sdk.tracing.abstractions import (
TrainingQuestion, RewardSignal, Dataset, SystemTrace, EventPartitionElement,
MessageInputs, MessageOutputs, ArbitraryInputs, ArbitraryOutputs
Expand All @@ -14,7 +12,7 @@
import time
import logging
import pytest
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import MagicMock, patch
import requests

# Configure logging
Expand All @@ -24,39 +22,38 @@
)
logger = logging.getLogger(__name__)


# Unit Test Configuration:
# ===============================
# Unit Test Configuration:
# ===============================
questions = ["What's the capital of France?"]
mock_llm_response = "The capital of France is Paris."

eventPartition_test = [EventPartitionElement(0, []),
EventPartitionElement(1, [])]

trace_test = [SystemTrace(
system_id="test_agent_upload",
partition=eventPartition_test,
current_partition_index=1)]
eventPartition_test = [
EventPartitionElement(0, []),
EventPartitionElement(1, [])
]

trace_test = [
SystemTrace(
system_id="test_agent_upload",
partition=eventPartition_test,
current_partition_index=1
)
]

# This function generates a payload from the data in the dataset to compare the sent payload against
def generate_payload_from_data(dataset: Dataset, traces: List[SystemTrace]) -> Dict:

payload = {
"traces": [
trace.to_dict() for trace in traces
], # Convert SystemTrace objects to dicts
"traces": [trace.to_dict() for trace in traces],
"dataset": dataset.to_dict(),
}
return payload

def createPayload_wrapper(dataset: Dataset, traces: str, base_url: str, api_key: str) -> Dict:
payload = createPayload(dataset, traces)

response = requests.Response()
response.status_code = 200

return response, payload

# ===============================
# Utility Functions
def createUploadDataset(agent):
Expand All @@ -79,23 +76,20 @@ def createUploadDataset(agent):
for i in range(len(questions))
],
)

logger.debug(
"Dataset created with %d questions and %d reward signals",
len(dataset.questions),
len(dataset.reward_signals),
"Dataset created with %d questions and %d reward signals",
len(dataset.questions),
len(dataset.reward_signals),
)

return dataset

def ask_questions(agent):
# Make multiple LM calls with environment processing
# Make multiple LM calls with environment processing
responses = []
for i, question in enumerate(questions):
logger.info("Processing question %d: %s", i, question)
env_result = agent.process_environment(question)
logger.debug("Environment processing result: %s", env_result)

response = agent.make_lm_call(question)
responses.append(response)
logger.debug("Response received and stored: %s", response)
Expand All @@ -122,6 +116,7 @@ def __init__(self):
def make_lm_call(self, user_message: str) -> str:
# Create MessageInputs
message_input = MessageInputs(messages=[{"role": "user", "content": user_message}])
# Track LM interaction using the new SynthTracker form
SynthTracker.track_lm(
messages=message_input.messages,
model_name=self.lm.model_name,
Expand All @@ -135,6 +130,7 @@ def make_lm_call(self, user_message: str) -> str:

# Create MessageOutputs
message_output = MessageOutputs(messages=[{"role": "assistant", "content": response}])
# Track state using the new SynthTracker form
SynthTracker.track_state(
variable_name="response",
variable_value=message_output.messages,
Expand All @@ -154,6 +150,7 @@ def make_lm_call(self, user_message: str) -> str:
def process_environment(self, input_data: str) -> dict:
# Create ArbitraryInputs
arbitrary_input = ArbitraryInputs(inputs={"input_data": input_data})
# Track state using the new SynthTracker form
SynthTracker.track_state(
variable_name="input_data",
variable_value=arbitrary_input.inputs,
Expand All @@ -165,6 +162,7 @@ def process_environment(self, input_data: str) -> dict:

# Create ArbitraryOutputs
arbitrary_output = ArbitraryOutputs(outputs=result)
# Track state using the new SynthTracker form
SynthTracker.track_state(
variable_name="result",
variable_value=arbitrary_output.outputs,
Expand All @@ -173,13 +171,12 @@ def process_environment(self, input_data: str) -> dict:
)
return result


# Use the new SynthTracker finalize method appropriately
@patch("synth_sdk.tracing.upload.send_system_traces", side_effect=createPayload_wrapper)
def test_generate_traces_sync(mock_send_system_traces):
logger.info("Starting run_test")
agent = TestAgent() # Create test agent

logger.debug("Test questions initialized: %s", questions) # List of test questions
logger.info("Starting test_generate_traces_sync")
agent = TestAgent() # Create test agent
logger.debug("Test questions initialized: %s", questions) # List of test questions

# Ask questions
responses = ask_questions(agent)
Expand All @@ -189,28 +186,30 @@ def test_generate_traces_sync(mock_send_system_traces):
dataset = createUploadDataset(agent)

# Upload traces
logger.info("Attempting to upload traces, async version")
# TODO: uploads traces directly from list of traces (override the event_store)

logger.info("Attempting to upload traces (sync version)")
# Pytest assertion
payload_ground_truth = generate_payload_from_data(dataset, trace_test)
#event_store.get_system_traces())

_ , payload_default_trace = upload(dataset=dataset, verbose=True, show_payload=True)
_, payload_default_trace = upload(dataset=dataset, verbose=True, show_payload=True)
assert payload_ground_truth == payload_default_trace

_ , payload_trace_passed = upload(dataset=dataset, traces=trace_test,
verbose=True, show_payload=True)
_, payload_trace_passed = upload(
dataset=dataset,
traces=trace_test,
verbose=True,
show_payload=True
)
assert payload_ground_truth == payload_trace_passed

# Finalize the tracker
SynthTracker.finalize()
logger.info("Resetting event store 0")
event_store.__init__()

@pytest.mark.asyncio(loop_scope="session")
@pytest.mark.asyncio
@patch("synth_sdk.tracing.upload.send_system_traces", side_effect=createPayload_wrapper)
async def test_generate_traces_async(mock_send_system_traces):
logger.info("Starting run_test")
logger.info("Starting test_generate_traces_async")
agent = TestAgent()

# Ask questions
Expand All @@ -221,27 +220,27 @@ async def test_generate_traces_async(mock_send_system_traces):
dataset = createUploadDataset(agent)

# Upload traces
logger.info("Attempting to upload traces, non-async version")
# TODO: uploads traces directly from list of traces (override the event_store)

logger.info("Attempting to upload traces (async version)")
# Pytest assertion
payload_ground_truth = generate_payload_from_data(dataset, trace_test)
#event_store.get_system_traces())

_ , payload_default_trace = await upload(dataset=dataset, verbose=True, show_payload=True)
_, payload_default_trace = await upload(dataset=dataset, verbose=True, show_payload=True)
assert payload_ground_truth == payload_default_trace

_ , payload_trace_passed = await upload(dataset=dataset, traces=trace_test,
verbose=True, show_payload=True)
_, payload_trace_passed = await upload(
dataset=dataset,
traces=trace_test,
verbose=True,
show_payload=True
)
assert payload_ground_truth == payload_trace_passed

# Finalize the tracker
SynthTracker.finalize()

logger.info("Resetting event store 1")
event_store.__init__()

# Run a sample agent using the sync decorator and tracker
# This file tests that upload function
# Run the tests
if __name__ == "__main__":
logger.info("Starting main execution")
asyncio.run(test_generate_traces_async())
Expand Down
Loading

0 comments on commit 5405265

Please sign in to comment.