From 8b95dfe8ac83cfc65707ad3e25bc067c1b8f7bdf Mon Sep 17 00:00:00 2001 From: doku88 Date: Wed, 20 Nov 2024 00:17:41 -0800 Subject: [PATCH] upload return jsons --- synth_sdk/tracing/upload.py | 57 +++++++++++++++++++++++++++++++------ 1 file changed, 48 insertions(+), 9 deletions(-) diff --git a/synth_sdk/tracing/upload.py b/synth_sdk/tracing/upload.py index 10ef292..a95f64c 100644 --- a/synth_sdk/tracing/upload.py +++ b/synth_sdk/tracing/upload.py @@ -174,19 +174,54 @@ def is_event_loop_running(): # This exception is raised if no event loop is running return False +def format_upload_output(dataset, traces): + # Format questions array + questions_data = [ + { + "intent": q.intent, + "criteria": q.criteria, + "question_id": q.question_id + } for q in dataset.questions + ] + + # Format reward signals array with error handling + reward_signals_data = [ + { + "system_id": rs.system_id, + "reward": rs.reward, + "question_id": rs.question_id, + "annotation": rs.annotation if hasattr(rs, 'annotation') else None + } for rs in dataset.reward_signals + ] + + # Format traces array + traces_data = [ + { + "system_id": t.system_id, + "partition": [ + { + "partition_index": p.partition_index, + "events": [e.to_dict() for e in p.events] + } for p in t.partition + ] + } for t in traces + ] + + return questions_data, reward_signals_data, traces_data + # Supports calls from both async and sync contexts def upload(dataset: Dataset, traces: List[SystemTrace]=[], verbose: bool = False, show_payload: bool = False): """Upload all system traces and dataset to the server. - Returns a tuple of (response, payload, dataset, traces) + Returns a tuple of (response, questions_json, reward_signals_json, traces_json) + Note that you can directly upload questions, reward_signals, and traces to the server using the Website - Response is the response from the server. - Payload is the payload that was sent to the server. - Dataset is the dataset that was uploaded. - Traces is the list of traces that were tracked by the trace_system() decorators and SynthTracker.""" + response is the response from the server. + questions_json is the formatted questions array + reward_signals_json is the formatted reward signals array + traces_json is the formatted traces array""" async def upload_wrapper(dataset, traces, verbose, show_payload): - result = await upload_helper(dataset, traces, verbose, show_payload) - return result - + response, payload, dataset, traces = await upload_helper(dataset, traces, verbose, show_payload) + # If we're in an async context (event loop is running) if is_event_loop_running(): logging.info("Event loop is already running") @@ -266,7 +301,11 @@ async def upload_helper(dataset: Dataset, traces: List[SystemTrace]=[], verbose: if show_payload: print("Payload sent to server: ") pprint(payload) - return response, payload, dataset, traces + + #return response, payload, dataset, traces + questions_json, reward_signals_json, traces_json = format_upload_output(dataset, traces) + return response, questions_json, reward_signals_json, traces_json + except ValueError as e: if verbose: print("Validation error:", str(e))