Skip to content

Commit

Permalink
Delayed step
Browse files Browse the repository at this point in the history
  • Loading branch information
imbolc committed Jun 25, 2024
1 parent 5cfdf16 commit 0e6bc52
Show file tree
Hide file tree
Showing 9 changed files with 140 additions and 47 deletions.
40 changes: 17 additions & 23 deletions examples/counter.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use pg_task::{Step, StepResult};
use pg_task::{NextStep, Step, StepResult};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use std::{env, time::Duration};
Expand All @@ -21,7 +21,7 @@ async fn main() -> anyhow::Result<()> {
init_logging()?;

// Let's schedule a few tasks
pg_task::enqueue(&db, &Tasks::Count(Start { up_to: 2 }.into())).await?;
pg_task::enqueue(&db, &Tasks::Count(Start { up_to: 1000 }.into())).await?;

// And run a worker
pg_task::Worker::<Tasks>::new(db).run().await;
Expand All @@ -35,16 +35,13 @@ pub struct Start {
}
#[async_trait]
impl Step<Count> for Start {
async fn step(self, _db: &PgPool) -> StepResult<Option<Count>> {
async fn step(self, _db: &PgPool) -> StepResult<Count> {
println!("1..{}: start", self.up_to);
Ok(Some(
Proceed {
up_to: self.up_to,
started_at: Utc::now(),
cur: 0,
}
.into(),
))
NextStep::now(Proceed {
up_to: self.up_to,
started_at: Utc::now(),
cur: 0,
})
}
}

Expand All @@ -59,7 +56,7 @@ impl Step<Count> for Proceed {
const RETRY_LIMIT: i32 = 5;
const RETRY_DELAY: Duration = Duration::from_secs(1);

async fn step(self, _db: &PgPool) -> StepResult<Option<Count>> {
async fn step(self, _db: &PgPool) -> StepResult<Count> {
// return Err(anyhow::anyhow!("bailing").into());
let Self {
up_to,
Expand All @@ -69,16 +66,13 @@ impl Step<Count> for Proceed {
cur += 1;
// println!("1..{up_to}: {cur}");
if cur < up_to {
Ok(Some(
Proceed {
up_to,
started_at,
cur,
}
.into(),
))
NextStep::now(Proceed {
up_to,
started_at,
cur,
})
} else {
Ok(Some(Finish { up_to, started_at }.into()))
NextStep::now(Finish { up_to, started_at })
}
}
}
Expand All @@ -90,7 +84,7 @@ pub struct Finish {
}
#[async_trait]
impl Step<Count> for Finish {
async fn step(self, _db: &PgPool) -> StepResult<Option<Count>> {
async fn step(self, _db: &PgPool) -> StepResult<Count> {
let took = Utc::now() - self.started_at;
let secs = num_seconds(took);
let per_sec = self.up_to as f64 / secs;
Expand All @@ -100,7 +94,7 @@ impl Step<Count> for Finish {
secs,
per_sec.round()
);
Ok(None)
NextStep::none()
}
}

Expand Down
53 changes: 53 additions & 0 deletions examples/delay.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use async_trait::async_trait;
use pg_task::{NextStep, Step, StepResult};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use std::{env, time::Duration};

// It wraps the task step into an enum which proxies necessary methods
pg_task::task!(FooBar { Foo, Bar });

// Also we need a enum representing all the possible tasks
pg_task::scheduler!(Tasks { FooBar });

#[tokio::main]
async fn main() -> anyhow::Result<()> {
let db = connect().await?;

// Let's schedule a few tasks
for delay in [3, 1, 2] {
pg_task::enqueue(&db, &Tasks::FooBar(Foo(delay).into())).await?;
}

// And run a worker
pg_task::Worker::<Tasks>::new(db).run().await;

Ok(())
}

#[derive(Debug, Deserialize, Serialize)]
pub struct Foo(u64);
#[async_trait]
impl Step<FooBar> for Foo {
async fn step(self, _db: &PgPool) -> StepResult<FooBar> {
println!("Sleeping for {} sec", self.0);
NextStep::delay(Bar(self.0), Duration::from_secs(self.0))
}
}

#[derive(Debug, Deserialize, Serialize)]
pub struct Bar(u64);
#[async_trait]
impl Step<FooBar> for Bar {
async fn step(self, _db: &PgPool) -> StepResult<FooBar> {
println!("Woke up after {} sec", self.0);
NextStep::none()
}
}

async fn connect() -> anyhow::Result<sqlx::PgPool> {
dotenv::dotenv().ok();
let db = PgPool::connect(&env::var("DATABASE_URL")?).await?;
sqlx::migrate!().run(&db).await?;
Ok(db)
}
3 changes: 2 additions & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::NextStep;
use std::{error::Error as StdError, result::Result as StdResult};
use tracing::error;

Expand All @@ -19,4 +20,4 @@ pub type Result<T> = StdResult<T, Error>;
pub type StepError = Box<dyn StdError + 'static>;

/// Result returning from task steps
pub type StepResult<T> = StdResult<T, StepError>;
pub type StepResult<T> = StdResult<NextStep<T>, StepError>;
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

mod error;
mod macros;
mod next_step;
mod traits;
mod util;
mod worker;

pub use error::{Error, Result, StepError, StepResult};
pub use next_step::NextStep;
pub use traits::{Scheduler, Step};
pub use worker::Worker;

Expand Down
14 changes: 10 additions & 4 deletions src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,16 @@ macro_rules! task {

#[async_trait::async_trait]
impl $crate::Step<$enum> for $enum {
async fn step(self, db: &sqlx::PgPool) -> $crate::StepResult<Option<$enum>> {
Ok(match self {
$(Self::$variant(inner) => inner.step(db).await?.map(Into::into),)*
})
async fn step(self, db: &sqlx::PgPool) -> $crate::StepResult<$enum> {
match self {
$(Self::$variant(inner) => inner.step(db).await.map(|next|
match next {
NextStep::None => NextStep::None,
NextStep::Now(x) => NextStep::Now(x.into()),
NextStep::Delayed(x, d) => NextStep::Delayed(x.into(), d),
}
),)*
}
}

fn retry_limit(&self) -> i32 {
Expand Down
29 changes: 29 additions & 0 deletions src/next_step.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use crate::StepResult;
use std::time::Duration;

/// Represents next step of the task
pub enum NextStep<T> {
/// The task is done
None,
/// Run the next step immediately
Now(T),
/// Delay the next step
Delayed(T, Duration),
}

impl<T> NextStep<T> {
/// The task is done
pub fn none() -> StepResult<T> {
Ok(Self::None)
}

/// Run the next step immediately
pub fn now(step: impl Into<T>) -> StepResult<T> {
Ok(Self::Now(step.into()))
}

/// Delay the next step
pub fn delay(step: impl Into<T>, delay: Duration) -> StepResult<T> {
Ok(Self::Delayed(step.into(), delay))
}
}
10 changes: 5 additions & 5 deletions src/traits.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{Error, Result, StepResult};
use crate::{Error, StepResult};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{de::DeserializeOwned, Serialize};
Expand All @@ -19,7 +19,7 @@ where
const RETRY_DELAY: Duration = Duration::from_secs(1);

/// Processes the current step and returns the next if any
async fn step(self, db: &PgPool) -> StepResult<Option<Task>>;
async fn step(self, db: &PgPool) -> StepResult<Task>;

/// Proxies the `RETRY` const, doesn't mean to be changed in impls
fn retry_limit(&self) -> i32 {
Expand All @@ -36,19 +36,19 @@ where
#[async_trait]
pub trait Scheduler: fmt::Debug + DeserializeOwned + Serialize + Sized + Sync {
/// Enqueues the task to be run immediately
async fn enqueue(&self, db: &PgPool) -> Result<Uuid> {
async fn enqueue(&self, db: &PgPool) -> crate::Result<Uuid> {
self.schedule(db, Utc::now()).await
}

/// Schedules a task to be run after a specified delay
async fn delay(&self, db: &PgPool, delay: Duration) -> Result<Uuid> {
async fn delay(&self, db: &PgPool, delay: Duration) -> crate::Result<Uuid> {
let delay =
chrono::Duration::from_std(delay).unwrap_or_else(|_| chrono::Duration::max_value());
self.schedule(db, Utc::now() + delay).await
}

/// Schedules a task to run at a specified time in the future
async fn schedule(&self, db: &PgPool, at: DateTime<Utc>) -> Result<Uuid> {
async fn schedule(&self, db: &PgPool, at: DateTime<Utc>) -> crate::Result<Uuid> {
let step = serde_json::to_string(self).map_err(Error::SerializeStep)?;
sqlx::query!(
"INSERT INTO pg_task (step, wakeup_at) VALUES ($1, $2) RETURNING id",
Expand Down
11 changes: 11 additions & 0 deletions src/util.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
/// Converts a chrono duration to std, it uses absolute value of the chrono duration
pub fn chrono_duration_to_std(chrono_duration: chrono::Duration) -> std::time::Duration {
let seconds = chrono_duration.num_seconds();
let nanos = chrono_duration.num_nanoseconds().unwrap_or(0) % 1_000_000_000;
std::time::Duration::new(seconds.unsigned_abs(), nanos.unsigned_abs() as u32)
}

/// Converts a std duration to chrono
pub fn std_duration_to_chrono(std_duration: std::time::Duration) -> chrono::Duration {
chrono::Duration::from_std(std_duration).unwrap_or_else(|_| chrono::Duration::max_value())
}
24 changes: 10 additions & 14 deletions src/worker.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::{Step, StepError};
use crate::{util, NextStep, Step, StepError};
use chrono::{DateTime, Utc};
use code_path::code_path;
use sqlx::{
Expand Down Expand Up @@ -65,7 +65,7 @@ impl Task {
if delay <= chrono::Duration::zero() {
Duration::ZERO
} else {
chrono_duration_to_std(delay)
util::chrono_duration_to_std(delay)
}
}
}
Expand Down Expand Up @@ -129,8 +129,9 @@ impl<S: Step<S>> Worker<S> {
self.process_error(id, tried, retry_limit, retry_delay, e)
.await?
}
Ok(None) => self.finish_task(id).await?,
Ok(Some(step)) => self.update_task_step(id, step).await?,
Ok(NextStep::None) => self.finish_task(id).await?,
Ok(NextStep::Now(step)) => self.update_task_step(id, step, Duration::ZERO).await?,
Ok(NextStep::Delayed(step, delay)) => self.update_task_step(id, step, delay).await?,
};
Ok(())
}
Expand Down Expand Up @@ -171,7 +172,9 @@ impl<S: Step<S>> Worker<S> {
tx.commit()
.await
.map_err(sqlx_error!("commit on wait for a period"))?;
waiter.wait_for(chrono_duration_to_std(time_to_run)).await?;
waiter
.wait_for(util::chrono_duration_to_std(time_to_run))
.await?;
} else {
tx.commit()
.await
Expand All @@ -182,7 +185,7 @@ impl<S: Step<S>> Worker<S> {
}

/// Updates the tasks step
async fn update_task_step(&self, task_id: Uuid, step: S) -> Result<()> {
async fn update_task_step(&self, task_id: Uuid, step: S, delay: Duration) -> Result<()> {
let step = match serde_json::to_string(&step)
.map_err(|e| ErrorReport::SerializeStep(e, format!("{:?}", step)))
{
Expand All @@ -209,7 +212,7 @@ impl<S: Step<S>> Worker<S> {
",
task_id,
step,
Utc::now(),
Utc::now() + util::std_duration_to_chrono(delay),
)
.execute(&self.db)
.await
Expand Down Expand Up @@ -403,13 +406,6 @@ impl TaskWaiter {
}
}

/// Converts a chrono duration to std, it uses absolute value of the chrono duration
fn chrono_duration_to_std(chrono_duration: chrono::Duration) -> std::time::Duration {
let seconds = chrono_duration.num_seconds();
let nanos = chrono_duration.num_nanoseconds().unwrap_or(0) % 1_000_000_000;
std::time::Duration::new(seconds.unsigned_abs(), nanos.unsigned_abs() as u32)
}

/// Returns the ordinal string of a given integer
fn ordinal(n: i32) -> String {
match n.abs() {
Expand Down

0 comments on commit 0e6bc52

Please sign in to comment.