From 6e9ca4cf1d0c8626be18bf1abefcc1e2b8236b55 Mon Sep 17 00:00:00 2001 From: Pierre LE GUEN <26087574+PierreLeGuen@users.noreply.github.com> Date: Thu, 5 Oct 2023 13:41:21 +0200 Subject: [PATCH] Add support for custom metadata --- src/main.rs | 77 ++++++++++++++++++++++++++------------ src/tta/models.rs | 3 ++ src/tta/tta_impl.rs | 91 ++++++++++++++++++++++++++++++++++----------- 3 files changed, 126 insertions(+), 45 deletions(-) diff --git a/src/main.rs b/src/main.rs index 31a8b6f..67cd560 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,5 @@ -use anyhow::Result; -use axum::{response::IntoResponse, Router}; use csv::Writer; -use hyper::{Body, Response}; +use hyper::Body; use tower::ServiceBuilder; use tower_http::{ cors::{Any, CorsLayer}, @@ -12,15 +10,24 @@ use tta::models::ReportRow; use axum::{ extract::{Query, State}, + http::StatusCode, + response::{IntoResponse, Response}, routing::get, + routing::post, + Json, Router, }; + use chrono::DateTime; use dotenvy::dotenv; use near_jsonrpc_client::{JsonRpcClient, NEAR_MAINNET_ARCHIVAL_RPC_URL}; use serde::Deserialize; use sqlx::postgres::PgPoolOptions; -use std::{collections::HashSet, env, fmt, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + env, + sync::{Arc, RwLock}, +}; use tokio::{spawn, sync::Semaphore}; use tracing::*; use tracing_subscriber::{prelude::__tracing_subscriber_SubscriberExt, EnvFilter, FmtSubscriber}; @@ -31,7 +38,7 @@ use crate::tta::{ft_metadata::FtService, sql::sql_queries::SqlClient}; pub mod tta; #[tokio::main] -async fn main() -> Result<()> { +async fn main() -> anyhow::Result<()> { info!("Starting up"); match dotenv() { @@ -92,7 +99,7 @@ fn init_tracing() -> anyhow::Result<()> { Ok(()) } -async fn router() -> Result { +async fn router() -> anyhow::Result { let pool = PgPoolOptions::new() .max_connections(30) .connect(env!("DATABASE_URL")) @@ -110,12 +117,16 @@ async fn router() -> Result { let middleware = ServiceBuilder::new().layer(trace).layer(cors); Ok(Router::new() + .route("/tta", post(get_txns_report)) .route("/tta", get(get_txns_report)) .with_state(tta_service) .layer(middleware)) } // HTTP layer +type AccountID = String; +type TransactionID = String; +type Metadata = HashMap>; #[derive(Debug, Deserialize)] struct TxnsReportParams { @@ -125,10 +136,16 @@ struct TxnsReportParams { pub include_balances: Option, } +#[derive(Debug, Deserialize, Default, Clone)] +struct TxnsReportWithMetadata { + pub metadata: Metadata, +} + async fn get_txns_report( Query(params): Query, State(tta_service): State, -) -> Result> { + metadata_body: Option>, +) -> Result, AppError> { let start_date: DateTime = DateTime::parse_from_rfc3339(¶ms.start_date) .unwrap() .into(); @@ -145,41 +162,63 @@ async fn get_txns_report( let include_balances = params.include_balances.unwrap_or(false); + let metadata = Arc::new(RwLock::new(metadata_body.unwrap_or_default().0)); + let csv_data = tta_service .get_txns_report( start_date.timestamp_nanos() as u128, end_date.timestamp_nanos() as u128, accounts, include_balances, + metadata, ) - .await - .unwrap(); + .await?; // Create a Writer with a Vec as the underlying writer let mut wtr = Writer::from_writer(Vec::new()); // Write the headers - wtr.write_record(&ReportRow::get_vec_headers()).unwrap(); + wtr.write_record(&ReportRow::get_vec_headers())?; // Write each row for row in csv_data { let record: Vec = row.to_vec(); - wtr.write_record(&record).unwrap(); + wtr.write_record(&record)?; } // Get the CSV data - let csv_data = wtr.into_inner().unwrap(); + let csv_data = wtr.into_inner()?; // Create a response with the CSV data let response = Response::builder() .header("Content-Type", "text/csv") .header("Content-Disposition", "attachment; filename=data.csv") - .body(Body::from(csv_data)) - .unwrap(); + .body(Body::from(csv_data))?; Ok(response) } +struct AppError(anyhow::Error); + +impl IntoResponse for AppError { + fn into_response(self) -> Response { + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Something went wrong: {}", self.0), + ) + .into_response() + } +} + +impl From for AppError +where + E: Into, +{ + fn from(err: E) -> Self { + Self(err.into()) + } +} + #[cfg(test)] mod tests { use super::*; @@ -189,10 +228,6 @@ mod tests { #[tokio::test] async fn test_tta_router() { - let subscriber = FmtSubscriber::builder().finish(); - - tracing::subscriber::set_global_default(subscriber).unwrap(); - let router = router().await.unwrap(); let client = TestClient::new(router); let res = client.get("/tta?start_date=2023-01-01T00:00:00Z&end_date=2023-02-01T00:00:00Z&accounts=nf-payments.near&include_balances=false").send().await; @@ -201,12 +236,8 @@ mod tests { #[tokio::test] async fn loadtest_tta() { - let subscriber = FmtSubscriber::builder().finish(); - - tracing::subscriber::set_global_default(subscriber).unwrap(); let router = router().await.unwrap(); - - let request_url = "/tta?start_date=2023-01-01T00:00:00Z&end_date=2023-02-01T00:00:00Z&accounts=nf-payments.near&include_balances=true"; + let request_url = "/tta?start_date=2023-01-01T00:00:00Z&end_date=2023-02-01T00:00:00Z&accounts=nf-payments.near&include_balances=false"; let futures = (0..20) .map(|_| { diff --git a/src/tta/models.rs b/src/tta/models.rs index ba3354a..f5d6aac 100644 --- a/src/tta/models.rs +++ b/src/tta/models.rs @@ -22,6 +22,7 @@ pub struct ReportRow { pub amount_staked: f64, pub onchain_balance: Option, pub onchain_balance_token: Option, + pub metadata: Option, } // Define the extension trait @@ -57,6 +58,7 @@ impl ReportRow { "amount_staked".to_string(), "onchain_balance".to_string(), "onchain_balance_token".to_string(), + "metadata".to_string(), ] } @@ -83,6 +85,7 @@ impl ReportRow { self.onchain_balance .map_or(String::new(), |v| v.to_5dp_string()), self.onchain_balance_token.clone().unwrap_or_default(), + self.metadata.clone().unwrap_or_default(), ] } } diff --git a/src/tta/tta_impl.rs b/src/tta/tta_impl.rs index 4aa6586..90e31c8 100644 --- a/src/tta/tta_impl.rs +++ b/src/tta/tta_impl.rs @@ -1,11 +1,15 @@ -use std::{collections::HashSet, sync::Arc, vec}; +use std::{ + collections::HashSet, + sync::{Arc, RwLock}, + vec, +}; use anyhow::{bail, Context, Result}; use futures_util::future::join_all; use near_sdk::ONE_NEAR; -use crate::tta::utils::get_associated_lockup; +use crate::{tta::utils::get_associated_lockup, TxnsReportWithMetadata}; use base64::{engine::general_purpose, Engine as _}; use chrono::{NaiveDateTime, Utc}; @@ -88,6 +92,7 @@ impl TTA { end_date: u128, accounts: HashSet, include_balances: bool, + metadata: Arc>, ) -> Result> { info!(?start_date, ?end_date, ?accounts, "Got request"); @@ -116,6 +121,8 @@ impl TTA { let wallets_for_account = wallets_for_account.clone(); let t = t.clone(); let for_account = acc.clone(); + let metadata = metadata.clone(); + async move { let _s = s; t.handle_txns( @@ -125,6 +132,7 @@ impl TTA { start_date, end_date, include_balances, + metadata, ) .await } @@ -143,6 +151,8 @@ impl TTA { let wallets_for_account = wallets_for_account.clone(); let t = t.clone(); let for_account = acc.clone(); + let metadata = metadata.clone(); + async move { let _s = s; t.handle_txns( @@ -152,6 +162,7 @@ impl TTA { start_date, end_date, include_balances, + metadata, ) .await } @@ -170,6 +181,8 @@ impl TTA { let wallets_for_account = wallets_for_account.clone(); let t = t.clone(); let a = acc.clone(); + let metadata = metadata.clone(); + async move { let _s = s; @@ -180,6 +193,7 @@ impl TTA { start_date, end_date, include_balances, + metadata, ) .await } @@ -196,7 +210,7 @@ impl TTA { Ok(res) => match res { Ok(partial_report) => { let mut p = vec![]; - // Aply filtering + // Apply filtering for ele in partial_report { if let Some(ele) = assert_moves_token(ele) { p.push(ele) @@ -240,6 +254,7 @@ impl TTA { start_date: u128, end_date: u128, include_balances: bool, + metadata: Arc>, ) -> Result> { let mut report: Vec = vec![]; let (tx, mut rx) = channel(100); @@ -259,6 +274,7 @@ impl TTA { while let Some(txn) = rx.recv().await { let t2: TTA = self.clone(); let f2 = for_account.clone(); + let metadata = metadata.clone(); let row = tokio::spawn(async move { if txn.ara_action_kind != "FUNCTION_CALL" && txn.ara_action_kind != "TRANSFER" { return Ok(None); @@ -329,6 +345,13 @@ impl TTA { ); } + let data = metadata + .read() + .unwrap() + .metadata + .get(&f2) + .and_then(|m| m.get(&txn.t_transaction_hash).cloned()); + Ok(Some(ReportRow { account_id: f2.clone(), date: get_transaction_date(&txn), @@ -348,6 +371,7 @@ impl TTA { amount_staked: 0.0, onchain_balance, onchain_balance_token, + metadata: data, })) }); rows_handle.push(row); @@ -608,41 +632,32 @@ fn assert_moves_token(row: ReportRow) -> Option { #[cfg(test)] mod tests { + use std::collections::HashMap; + use chrono::DateTime; use near_jsonrpc_client::{JsonRpcClient, NEAR_MAINNET_ARCHIVAL_RPC_URL}; use sqlx::postgres::PgPoolOptions; use super::*; - async fn get_tta_service() { + async fn setup() -> Result<(SqlClient, FtService, TTA)> { let pool = PgPoolOptions::new() .max_connections(30) .connect(env!("DATABASE_URL")) - .await - .unwrap(); + .await?; let sql_client = SqlClient::new(pool); let near_client = JsonRpcClient::connect(NEAR_MAINNET_ARCHIVAL_RPC_URL); let ft_service = FtService::new(near_client); let semaphore = Arc::new(Semaphore::new(30)); + let tta_service = TTA::new(sql_client.clone(), ft_service.clone(), semaphore); - let tta_service = TTA::new(sql_client, ft_service, semaphore); + Ok((sql_client, ft_service, tta_service)) } #[tokio::test] - async fn tta() { - let pool = PgPoolOptions::new() - .max_connections(30) - .connect(env!("DATABASE_URL")) - .await - .unwrap(); - - let sql_client = SqlClient::new(pool); - let near_client = JsonRpcClient::connect(NEAR_MAINNET_ARCHIVAL_RPC_URL); - let ft_service = FtService::new(near_client); - let semaphore = Arc::new(Semaphore::new(30)); - - let tta_service = TTA::new(sql_client, ft_service, semaphore); + async fn tta() -> Result<()> { + let (_, _, tta_service) = setup().await?; let start_date = DateTime::parse_from_rfc3339("2022-01-01T00:00:00Z") .unwrap() @@ -656,8 +671,40 @@ mod tests { .collect(); let include_balances = false; - let res = tta_service.get_txns_report(start_date, end_date, accounts, include_balances); + let mut accounts_metadata = HashMap::new(); + let mut account_txns = HashMap::new(); + + account_txns.insert( + "51VVGwLAFX6K62jB84E6qVHdF4GbhEMB2CoZJ9ZziiEt".to_string(), + "unit test".to_string(), + ); + + accounts_metadata.insert("nf-payments.near".to_string(), account_txns); - assert!(res.await.is_ok()); + let metadata_struct = Arc::new(RwLock::new(TxnsReportWithMetadata { + metadata: accounts_metadata, + })); + + let res = tta_service + .get_txns_report( + start_date, + end_date, + accounts, + include_balances, + metadata_struct, + ) + .await + .unwrap(); + + assert!(!res.is_empty()); + + for row in res { + if row.transaction_hash == "51VVGwLAFX6K62jB84E6qVHdF4GbhEMB2CoZJ9ZziiEt" { + assert_eq!(row.metadata, Some("unit test".to_string())); + } else { + assert_eq!(row.metadata, None); + } + } + Ok(()) } }