Skip to content

Commit

Permalink
code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
DoKu88 committed Dec 16, 2024
1 parent cc0f0f7 commit 2f05ef2
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 124 deletions.
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ dependencies = [
"craftaxlm>=0.0.5",
"boto3>=1.35.71",
"botocore>=1.35.71",
"tqdm>=4.66.4",
"aiohttp>=3.8.6"
"tqdm>=4.66.4"
]
classifiers = []

Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,3 @@ zyk>=0.2.10
boto3>=1.35.71
botocore>=1.35.71
tqdm>=4.66.4
aiohttp>=3.8.6
121 changes: 0 additions & 121 deletions synth_sdk/tracing/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import sys
from pympler import asizeof
from tqdm import tqdm
import aiohttp
import boto3
from datetime import datetime

Expand Down Expand Up @@ -136,63 +135,6 @@ async def _async_operations():
loop = asyncio.get_event_loop()
return loop.run_until_complete(_async_operations())

async def send_system_traces(dataset: Dataset, traces: List[SystemTrace], base_url: str, api_key: str, upload_id: str):
    """Authenticate against the API and POST one payload of traces.

    An access token is requested from ``{base_url}/v1/auth/token`` using
    ``api_key``. The payload built by ``createPayload(dataset, traces)`` is
    passed through ``validate_json`` and then posted to
    ``{base_url}/v1/uploads/{upload_id}`` with the token as a bearer header.

    Returns:
        A ``(response, payload)`` tuple: the aiohttp response object of the
        POST and the payload that was sent.

    Raises:
        aiohttp.ClientResponseError: raised by ``raise_for_status()``. For the
            POST it is logged before being re-raised; for the token request
            it propagates without logging.
        Exception: any other failure during the POST is logged and re-raised.
    """
    async with aiohttp.ClientSession() as session:
        # Step 1: exchange the customer API key for a short-lived access token.
        auth_headers = {"customer_specific_api_key": api_key}
        async with session.get(f"{base_url}/v1/auth/token", headers=auth_headers) as token_response:
            token_response.raise_for_status()
            access_token = (await token_response.json())["access_token"]

        # Step 2: build and validate the payload before touching the upload
        # endpoint, so an invalid payload never reaches the server.
        api_url = f"{base_url}/v1/uploads/{upload_id}"
        payload = createPayload(dataset, traces)
        validate_json(payload)

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {access_token}",
        }

        # Step 3: send the payload; failures are logged and then re-raised.
        try:
            async with session.post(api_url, json=payload, headers=headers) as response:
                response.raise_for_status()
                response_data = await response.json()
                logging.info(f"Upload ID: {response_data.get('upload_id')}")
                return response, payload
        except aiohttp.ClientResponseError as e:
            logging.error(f"HTTP error occurred: Status {e.status} - {e.message}")
            raise
        except Exception as e:
            logging.error(f"An error occurred: {e}")
            raise

def chunk_traces(traces: List[SystemTrace], chunk_size_kb: int = 1024):
    """Group traces into consecutive chunks of roughly ``chunk_size_kb`` each.

    The size of a trace is estimated as the in-memory size of
    ``trace.to_dict()`` reported by ``asizeof.asizeof`` (converted to KB).
    This is only a proxy for the serialized size.

    Args:
        traces: Traces to split; their order is preserved.
        chunk_size_kb: Soft upper bound, in KB, for the summed size of the
            traces in one chunk.

    Returns:
        A list of non-empty lists of traces. A single trace larger than
        ``chunk_size_kb`` is placed in a chunk of its own. An empty input
        yields an empty list.
    """
    chunks = []
    current_chunk = []
    current_size = 0

    for trace in traces:
        trace_size = asizeof.asizeof(trace.to_dict()) / 1024  # Memory size in KB
        logging.info(f"Trace size (in memory): {trace_size:.2f} KB")

        # Only flush when something has been accumulated. Without the
        # `current_chunk` guard an oversized first trace would emit an empty
        # chunk, which would then be uploaded as an empty payload.
        if current_chunk and current_size + trace_size > chunk_size_kb:
            chunks.append(current_chunk)
            current_chunk = []
            current_size = 0

        current_chunk.append(trace)
        current_size += trace_size

    # Flush the trailing, partially filled chunk.
    if current_chunk:
        chunks.append(current_chunk)

    return chunks

async def get_upload_id(base_url: str, api_key: str):
token_url = f"{base_url}/v1/auth/token"
token_response = requests.get(token_url, headers={"customer_specific_api_key": api_key})
Expand All @@ -218,68 +160,6 @@ async def get_upload_id(base_url: str, api_key: str):
logging.error(f"An error occurred: {e}")
raise

def send_system_traces_chunked(dataset: Dataset, traces: List[SystemTrace],
                               base_url: str, api_key: str, chunk_size_kb: int = 1024):
    """Upload traces in chunks to avoid memory issues.

    The traces are split with ``chunk_traces``, a single upload id is obtained
    with ``get_upload_id``, and every chunk is sent concurrently through
    ``send_system_traces`` while a tqdm progress bar tracks completed chunks.

    Args:
        dataset: Dataset sent alongside every chunk.
        traces: Traces to upload.
        base_url: Base URL of the API.
        api_key: API key used for authentication.
        chunk_size_kb: Approximate per-chunk size limit passed to
            ``chunk_traces``.

    Returns:
        The ``(response, payload)`` result of the first chunk, or
        ``(None, None)`` when there are no chunks to upload.

    Raises:
        BaseException: the first exception (in chunk order) raised by any
            chunk upload, after all uploads have settled.
    """

    async def _async_upload():
        trace_chunks = chunk_traces(traces, chunk_size_kb)
        upload_id = await get_upload_id(base_url, api_key)

        tasks = []
        progress_bar = tqdm(total=len(trace_chunks), desc="Uploading chunks", unit="chunk")

        async def upload_with_progress(chunk):
            # Errors simply propagate; the bar is closed once in `finally`.
            result = await send_system_traces(dataset, chunk, base_url, api_key, upload_id)
            progress_bar.update(1)
            return result

        try:
            # Wrap each coroutine in a real Task: bare coroutine objects have
            # no .done()/.cancel(), so the cleanup below would otherwise raise
            # AttributeError and mask the actual result or error.
            tasks = [asyncio.ensure_future(upload_with_progress(chunk)) for chunk in trace_chunks]
            results = await asyncio.gather(*tasks, return_exceptions=True)

            # Surface the first failure. BaseException is checked so that a
            # cancelled chunk is not silently returned as a "result".
            for result in results:
                if isinstance(result, BaseException):
                    raise result

            return results[0] if results else (None, None)
        finally:
            progress_bar.close()
            # Only relevant when gather() itself was interrupted (for example
            # by cancellation) and some uploads are still in flight.
            for task in tasks:
                if not task.done():
                    task.cancel()

    # Handle the event loop
    loop = None
    created_loop = False
    try:
        if not is_event_loop_running():
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            created_loop = True
        else:
            # NOTE(review): run_until_complete() on a loop that is already
            # running raises RuntimeError unless the loop has been patched to
            # allow re-entrancy - confirm how callers inside a running loop
            # are expected to use this function.
            loop = asyncio.get_event_loop()
        return loop.run_until_complete(_async_upload())
    finally:
        # Only close the loop if we created it
        if created_loop and not is_event_loop_running():
            try:
                # Cancel all pending tasks before closing
                pending = asyncio.all_tasks(loop)
                for task in pending:
                    task.cancel()
                loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
            finally:
                loop.close()

class UploadValidator(BaseModel):
traces: List[Dict[str, Any]]
dataset: Dict[str, Any]
Expand Down Expand Up @@ -496,7 +376,6 @@ def upload_helper(dataset: Dataset, traces: List[SystemTrace]=[], verbose: bool
print("Payload sent to server: ")
pprint(payload)

#return response, payload, dataset, traces
questions_json, reward_signals_json, traces_json = format_upload_output(dataset, traces)
return response, questions_json, reward_signals_json, traces_json

Expand Down

0 comments on commit 2f05ef2

Please sign in to comment.