Skip to content

Commit

Permalink
Merge pull request #4 from bennjii/json_errors
Browse files Browse the repository at this point in the history
feat: Appropriate JSON Schema Validation
  • Loading branch information
bennjii authored Dec 23, 2023
2 parents d8d0e66 + e1460fd commit 27d5b78
Show file tree
Hide file tree
Showing 29 changed files with 418 additions and 153 deletions.
60 changes: 60 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ rocket_dyn_templates = { version = "0.1.0-rc.3", features = ["handlebars", "tera
rocket_okapi = { version = "0.8.0-rc.3", features = ["uuid", "swagger", "rocket_db_pools", "rocket_dyn_templates"] }
schemars = { version = "0.8.15" , features = ["chrono"]}
okapi = { version = "0.7.0-rc.1", features = ["impl_json_schema"] }
rocket-validation = "0.1.3"
validator="0.16.1"

[dev-dependencies]
reqwest = "0.11.13"
Expand Down
188 changes: 182 additions & 6 deletions src/catchers.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,202 @@
use crate::{guards::UserErrorMessage};
use rocket::{serde::json::{json, Value}, Request, catch};
use crate::guards::JsonValidationError;
use rocket::{serde::json::{json, Value}, Request, catch, Data, form};
use rocket::data::{FromData, Outcome as DataOutcome};
use rocket::form::{DataField, FromForm, ValueField};
use rocket::http::Status;
use rocket::outcome::Outcome;
use rocket::request::FromRequest;
use rocket::serde::json::Json;
use rocket_okapi::gen::OpenApiGenerator;
use rocket_okapi::request::OpenApiFromData;
use schemars::JsonSchema;
use validator::{Validate, ValidationErrors};
use okapi::{
openapi3::{MediaType, RequestBody},
Map,
};

/*
The below code is a mix between json_validator
and serde handling, in order to handle serde validations
Credit to a large portion of it is to: owlnext-fr
https://github.com/owlnext-fr/rust-microservice-skeleton/blob/main/src/core/validation.rs
*/


#[derive(Clone, Debug, JsonSchema)]
pub struct Validated<T>(pub T);

#[derive(Clone)]
pub struct CachedValidationErrors(pub Option<ValidationErrors>);

#[derive(Clone)]
pub struct CachedParseErrors(pub Option<String>);

macro_rules! fn_request_body {
($gen:ident, $ty:path, $mime_type:expr) => {{
let schema = $gen.json_schema::<$ty>();
Ok(RequestBody {
content: {
let mut map = Map::new();
map.insert(
$mime_type.to_owned(),
MediaType {
schema: Some(schema),
..MediaType::default()
},
);
map
},
required: true,
..okapi::openapi3::RequestBody::default()
})
}};
}

impl<'r, D: validator::Validate + rocket::serde::Deserialize<'r> + JsonSchema> OpenApiFromData<'r> for Validated<Json<D>> {
fn request_body(gen: &mut OpenApiGenerator) -> rocket_okapi::Result<RequestBody> {
fn_request_body!(gen, D, "application/json")
}
}

#[rocket::async_trait]
impl<'r, D: validator::Validate + rocket::serde::Deserialize<'r> + JsonSchema> FromData<'r> for Validated<Json<D>> {
type Error = Result<ValidationErrors, rocket::serde::json::Error<'r>>;

async fn from_data(req: &'r Request<'_>, data: Data<'r>) -> DataOutcome<'r, Self> {
let data_outcome = <Json<D> as FromData<'r>>::from_data(req, data).await;

match data_outcome {
Outcome::Failure((status, err)) => {
req.local_cache(|| CachedParseErrors(Some(err.to_string())));
Outcome::Failure((status, Err(err)))
}
Outcome::Forward(err) => Outcome::Forward(err),
Outcome::Success(data) => match data.validate() {
Ok(_) => Outcome::Success(Validated(data)),
Err(err) => {
req.local_cache(|| CachedValidationErrors(Some(err.to_owned())));
Outcome::Failure((Status::BadRequest, Ok(err)))
}
},
}
}
}

#[rocket::async_trait]
impl<'r, D: Validate + FromRequest<'r>> FromRequest<'r> for Validated<D> {
type Error = Result<ValidationErrors, D::Error>;
async fn from_request(req: &'r Request<'_>) -> rocket::request::Outcome<Self, Self::Error> {
let data_outcome = D::from_request(req).await;

match data_outcome {
Outcome::Failure((status, err)) => {
let error_message = format!("{err:?}");
req.local_cache(|| CachedParseErrors(Some(error_message)));
Outcome::Failure((status, Err(err)))
}
Outcome::Forward(err) => Outcome::Forward(err),
Outcome::Success(data) => match data.validate() {
Ok(_) => Outcome::Success(Validated(data)),
Err(err) => {
req.local_cache(|| CachedValidationErrors(Some(err.to_owned())));
Outcome::Failure((Status::BadRequest, Ok(err)))
}
},
}
}
}


#[rocket::async_trait]
impl<'r, T: Validate + FromForm<'r>> FromForm<'r> for Validated<T> {
type Context = T::Context;

#[inline]
fn init(opts: form::Options) -> Self::Context {
T::init(opts)
}

#[inline]
fn push_value(ctxt: &mut Self::Context, field: ValueField<'r>) {
T::push_value(ctxt, field)
}

#[inline]
async fn push_data(ctxt: &mut Self::Context, field: DataField<'r, '_>) {
T::push_data(ctxt, field).await
}

fn finalize(this: Self::Context) -> form::Result<'r, Self> {
match T::finalize(this) {
Err(err) => Err(err),
Ok(data) => match data.validate() {
Ok(_) => Ok(Validated(data)),
Err(err) => Err(err
.into_errors()
.into_iter()
.map(|e| form::Error {
name: Some(e.0.into()),
kind: form::error::ErrorKind::Validation(std::borrow::Cow::Borrowed(e.0)),
value: None,
entity: form::error::Entity::Value,
})
.collect::<Vec<_>>()
.into()),
},
}
}
}

#[catch(400)]
pub fn general_catcher(req: &Request) -> Value {
json!([{
"code": "error.general",
"message": "Bad Request. The request could not be understood by the server due to malformed syntax.",
"errors": req.local_cache(|| CachedValidationErrors(None)).0.as_ref(),
}])
}

#[catch(403)]
pub fn not_authorized() -> Value {
json!([{"label": "unauthorized", "message": "Not authorized to make request"}])
json!([{"code": "error.unauthorized", "message": "Not authorized to make request"}])
}

#[catch(404)]
pub fn not_found() -> Value {
json!([])
json!([{"code": "error.not_found", "message": "The requested route was not found."}])
}

#[catch(422)]
pub fn unprocessable_entry(req: &Request) -> Value {
json! [{"label": "failed.request", "message": "failed to service request"}]
let possible_parse_violation = req.local_cache(|| CachedParseErrors(None)).0.as_ref();
let validation_errors = req.local_cache(|| CachedValidationErrors(None)).0.as_ref();

let mut message = "Failed to service request, structure parsing failed.".to_string();

if validation_errors.is_some() {
message.clear();

let erros = validation_errors.unwrap().field_errors();

for (_,val) in erros.iter() {
for error in val.iter() {
message.push_str(error.message.as_ref().unwrap());
}
}
} else if possible_parse_violation.is_some() {
message.clear();
message.push_str(possible_parse_violation.unwrap());
}

json! [{ "code": "error.input", "message": &message }]
}

#[catch(500)]
pub fn internal_server_error(req: &Request) -> Value {
let error_message = req
.local_cache(|| Some(UserErrorMessage("Internal server error".to_owned())));

json! [{"label": "internal.error", "message": error_message}]
json! [{"code": "error.internal", "message": error_message}]
}
2 changes: 1 addition & 1 deletion src/guards.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use rocket::{
data::{self, Data, FromData, Limits},
http::Status,
request::{self, local_cache, FromRequest, Request},
request::{local_cache, FromRequest, Request},
};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
Expand Down
1 change: 0 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use std::sync::mpsc;
use std::sync::Arc;
use std::sync::Mutex;
use std::thread;
use rocket::serde::json::Json;

#[cfg(feature = "process")]
pub mod entities;
Expand Down
8 changes: 3 additions & 5 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub use entities::*;
pub use methods::*;
#[cfg(feature = "sql")]
pub use migrator::*;
use open_stock::{catchers, guards};
use open_stock::{catchers};

#[cfg(feature = "sql")]
extern crate argon2;
Expand All @@ -49,9 +49,6 @@ impl Fairing for CORS {
async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
let access_origin = dotenv::var("ACCESS_ORIGIN").unwrap();

// Permit `localhost:3000` when DEMO mode is enabled.
// `request.host().unwrap().domain().eq( &access_origin)`

response.set_header(Header::new("Access-Control-Allow-Origin", access_origin));
response.set_header(Header::new(
"Access-Control-Allow-Methods",
Expand Down Expand Up @@ -79,7 +76,8 @@ fn rocket() -> _ {
catchers::not_authorized,
catchers::internal_server_error,
catchers::not_found,
catchers::unprocessable_entry
catchers::unprocessable_entry,
catchers::general_catcher,
])
.attach(Db::init())
.attach(CORS)
Expand Down
Loading

0 comments on commit 27d5b78

Please sign in to comment.