Skip to content
This repository has been archived by the owner on Sep 10, 2024. It is now read-only.

Add GraphQL mutation to do self-service user registration #3050

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/cli/src/commands/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ impl Options {
homeserver_connection.clone(),
site_config.clone(),
password_manager.clone(),
http_client_factory.clone(),
url_builder.clone(),
);

let state = {
Expand Down
6 changes: 5 additions & 1 deletion crates/handlers/src/captcha.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

use std::net::IpAddr;

use async_graphql::InputObject;
use axum::BoxError;
use hyper::Request;
use mas_axum_utils::http_client_factory::HttpClientFactory;
Expand Down Expand Up @@ -58,8 +59,11 @@ pub enum Error {
RequestFailed(#[source] BoxError),
}

/// Form (or GraphQL input) containing a CAPTCHA provider's response
/// for one of the providers.
#[allow(clippy::struct_field_names)]
#[derive(Debug, Deserialize, Default)]
#[derive(Debug, Deserialize, Default, InputObject)]
#[graphql(input_name = "CaptchaForm")]
#[serde(rename_all = "kebab-case")]
pub struct Form {
g_recaptcha_response: Option<String>,
Expand Down
61 changes: 50 additions & 11 deletions crates/handlers/src/graphql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ use futures_util::TryStreamExt;
use headers::{authorization::Bearer, Authorization, ContentType, HeaderValue};
use hyper::header::CACHE_CONTROL;
use mas_axum_utils::{
cookies::CookieJar, sentry::SentryEventID, FancyError, SessionInfo, SessionInfoExt,
cookies::CookieJar, http_client_factory::HttpClientFactory, sentry::SentryEventID, FancyError,
SessionInfo, SessionInfoExt,
};
use mas_data_model::{BrowserSession, Session, SiteConfig, User};
use mas_data_model::{BrowserSession, Session, SiteConfig, User, UserAgent};
use mas_matrix::HomeserverConnection;
use mas_policy::{InstantiateError, Policy, PolicyFactory};
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng, Clock, RepositoryError, SystemClock};
use mas_storage_pg::PgRepository;
use opentelemetry_semantic_conventions::trace::{GRAPHQL_DOCUMENT, GRAPHQL_OPERATION_NAME};
Expand All @@ -59,8 +61,11 @@ use self::{
model::{CreationEvent, Node},
mutations::Mutation,
query::Query,
state::GraphQLCookieJar,
};
use crate::{
impl_from_error_for_route, passwords::PasswordManager, BoundActivityTracker, PreferredLanguage,
};
use crate::{impl_from_error_for_route, passwords::PasswordManager, BoundActivityTracker};

#[cfg(test)]
mod tests;
Expand All @@ -71,6 +76,8 @@ struct GraphQLState {
policy_factory: Arc<PolicyFactory>,
site_config: SiteConfig,
password_manager: PasswordManager,
http_client_factory: HttpClientFactory,
url_builder: UrlBuilder,
}

#[async_trait]
Expand Down Expand Up @@ -111,6 +118,14 @@ impl state::State for GraphQLState {
let rng = ChaChaRng::from_rng(rng).expect("Failed to seed rng");
Box::new(rng)
}

fn http_client_factory(&self) -> &HttpClientFactory {
&self.http_client_factory
}

fn url_builder(&self) -> &UrlBuilder {
&self.url_builder
}
}

#[must_use]
Expand All @@ -120,13 +135,17 @@ pub fn schema(
homeserver_connection: impl HomeserverConnection<Error = anyhow::Error> + 'static,
site_config: SiteConfig,
password_manager: PasswordManager,
http_client_factory: HttpClientFactory,
url_builder: UrlBuilder,
) -> Schema {
let state = GraphQLState {
pool: pool.clone(),
policy_factory: Arc::clone(policy_factory),
homeserver_connection: Arc::new(homeserver_connection),
site_config,
password_manager,
http_client_factory,
url_builder,
};
let state: BoxState = Box::new(state);

Expand Down Expand Up @@ -281,31 +300,39 @@ async fn get_requester(

pub async fn post(
AxumState(schema): AxumState<Schema>,
PreferredLanguage(locale): PreferredLanguage,
clock: BoxClock,
repo: BoxRepository,
activity_tracker: BoundActivityTracker,
cookie_jar: CookieJar,
content_type: Option<TypedHeader<ContentType>>,
authorization: Option<TypedHeader<Authorization<Bearer>>>,
user_agent: Option<TypedHeader<headers::UserAgent>>,
body: Body,
) -> Result<impl IntoResponse, RouteError> {
let body = body.into_data_stream();
let token = authorization
.as_ref()
.map(|TypedHeader(Authorization(bearer))| bearer.token());
let (session_info, _cookie_jar) = cookie_jar.session_info();
let (session_info, cookie_jar) = cookie_jar.session_info();
let requester = get_requester(&clock, &activity_tracker, repo, session_info, token).await?;

let content_type = content_type.map(|TypedHeader(h)| h.to_string());

let gql_cookie_jar = Arc::new(GraphQLCookieJar::new(cookie_jar));

let request = async_graphql::http::receive_body(
content_type,
body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
.into_async_read(),
MultipartOptions::default(),
)
.await?
.data(requester); // XXX: this should probably return another error response?
.await? // XXX: this should probably return another error response?
.data(requester)
.data(user_agent.map(|ua| UserAgent::parse(ua.as_str().to_owned())))
.data(locale)
.data(activity_tracker)
.data(gql_cookie_jar.clone());

let span = span_for_graphql_request(&request);
let response = schema.execute(request).instrument(span).await;
Expand All @@ -318,7 +345,10 @@ pub async fn post(

let headers = response.http_headers.clone();

Ok((headers, cache_control, Json(response)))
// unwrap: the cookie jar only has one reference (ours) after the request
let cookie_jar = Arc::into_inner(gql_cookie_jar).unwrap().into_inner();

Ok((headers, cache_control, cookie_jar, Json(response)))
}

pub async fn get(
Expand All @@ -328,16 +358,22 @@ pub async fn get(
activity_tracker: BoundActivityTracker,
cookie_jar: CookieJar,
authorization: Option<TypedHeader<Authorization<Bearer>>>,
user_agent: Option<TypedHeader<headers::UserAgent>>,
RawQuery(query): RawQuery,
) -> Result<impl IntoResponse, FancyError> {
let token = authorization
.as_ref()
.map(|TypedHeader(Authorization(bearer))| bearer.token());
let (session_info, _cookie_jar) = cookie_jar.session_info();
let (session_info, cookie_jar) = cookie_jar.session_info();
let requester = get_requester(&clock, &activity_tracker, repo, session_info, token).await?;

let request =
async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester);
let gql_cookie_jar = Arc::new(GraphQLCookieJar::new(cookie_jar));

let request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?
.data(requester)
.data(activity_tracker)
.data(user_agent)
.data(gql_cookie_jar.clone());

let span = span_for_graphql_request(&request);
let response = schema.execute(request).instrument(span).await;
Expand All @@ -350,7 +386,10 @@ pub async fn get(

let headers = response.http_headers.clone();

Ok((headers, cache_control, Json(response)))
// unwrap: the cookie jar only has one reference (ours) after the request
let cookie_jar = Arc::into_inner(gql_cookie_jar).unwrap().into_inner();

Ok((headers, cache_control, cookie_jar, Json(response)))
}

pub async fn playground() -> impl IntoResponse {
Expand Down
Loading