1#![deny(clippy::future_not_send)]
8#![allow(
9 clippy::unused_async,
11 clippy::too_many_arguments,
13 clippy::let_with_type_underscore,
16)]
17
18use std::{
19 convert::Infallible,
20 sync::{Arc, LazyLock},
21 time::Duration,
22};
23
24use axum::{
25 Extension, Router,
26 extract::{FromRef, FromRequestParts, OriginalUri, RawQuery, State},
27 http::Method,
28 response::{Html, IntoResponse},
29 routing::{get, post},
30};
31use headers::HeaderName;
32use hyper::{
33 StatusCode, Version,
34 header::{
35 ACCEPT, ACCEPT_LANGUAGE, AUTHORIZATION, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_TYPE,
36 },
37};
38use mas_axum_utils::{InternalError, cookies::CookieJar};
39use mas_data_model::SiteConfig;
40use mas_http::CorsLayerExt;
41use mas_keystore::{Encrypter, Keystore};
42use mas_matrix::HomeserverConnection;
43use mas_policy::Policy;
44use mas_router::{Route, UrlBuilder};
45use mas_storage::{BoxRepository, BoxRepositoryFactory};
46use mas_templates::{ErrorContext, NotFoundContext, TemplateContext, Templates};
47use opentelemetry::metrics::Meter;
48use sqlx::PgPool;
49use tower::util::AndThenLayer;
50use tower_http::cors::{Any, CorsLayer};
51
52use self::{graphql::ExtraRouterParameters, passwords::PasswordManager};
53
54mod admin;
55mod compat;
56mod graphql;
57mod health;
58mod oauth2;
59pub mod passwords;
60pub mod upstream_oauth2;
61mod views;
62
63mod activity_tracker;
64mod captcha;
65mod preferred_language;
66mod rate_limit;
67mod session;
68#[cfg(test)]
69mod test_utils;
70
71static METER: LazyLock<Meter> = LazyLock::new(|| {
72 let scope = opentelemetry::InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
73 .with_version(env!("CARGO_PKG_VERSION"))
74 .with_schema_url(opentelemetry_semantic_conventions::SCHEMA_URL)
75 .build();
76
77 opentelemetry::global::meter_with_scope(scope)
78});
79
/// Implements `From<E>` for a route error type by boxing `E` into the type's
/// `Internal` variant.
///
/// Two forms are accepted:
///
/// - `impl_from_error_for_route!(MyRouteError: SomeError)` implements
///   `From<SomeError> for MyRouteError`;
/// - `impl_from_error_for_route!(SomeError)` is shorthand for targeting the
///   `RouteError` type of the module in which the macro is invoked.
///
/// The target type must accept `Self::Internal(Box::new(error))`, i.e. expose
/// an `Internal` variant (or identically named constructor) taking a box the
/// error can be coerced into (typically a boxed error trait object).
#[macro_export]
macro_rules! impl_from_error_for_route {
    ($route_error:ty : $error:ty) => {
        // Fully qualified std paths so a caller-side `From` or `Box` in scope
        // cannot change what this expands to.
        impl ::std::convert::From<$error> for $route_error {
            fn from(e: $error) -> Self {
                Self::Internal(::std::boxed::Box::new(e))
            }
        }
    };
    ($error:ty) => {
        // Recurse through `$crate` so the shorthand form also works when the
        // caller invoked the macro by path (e.g.
        // `crate::impl_from_error_for_route!(..)`) without importing it.
        $crate::impl_from_error_for_route!(self::RouteError: $error);
    };
}
95
96pub use mas_axum_utils::{ErrorWrapper, cookies::CookieManager};
97use mas_data_model::{BoxClock, BoxRng};
98
99pub use self::{
100 activity_tracker::{ActivityTracker, Bound as BoundActivityTracker},
101 admin::router as admin_api_router,
102 graphql::{
103 Schema as GraphQLSchema, schema as graphql_schema, schema_builder as graphql_schema_builder,
104 },
105 preferred_language::PreferredLanguage,
106 rate_limit::{Limiter, RequesterFingerprint},
107 upstream_oauth2::cache::MetadataCache,
108};
109
110pub fn healthcheck_router<S>() -> Router<S>
111where
112 S: Clone + Send + Sync + 'static,
113 PgPool: FromRef<S>,
114{
115 Router::new().route(mas_router::Healthcheck::route(), get(self::health::get))
116}
117
118pub fn graphql_router<S>(playground: bool, undocumented_oauth2_access: bool) -> Router<S>
119where
120 S: Clone + Send + Sync + 'static,
121 graphql::Schema: FromRef<S>,
122 BoundActivityTracker: FromRequestParts<S>,
123 BoxRepository: FromRequestParts<S>,
124 BoxClock: FromRequestParts<S>,
125 Encrypter: FromRef<S>,
126 CookieJar: FromRequestParts<S>,
127 Limiter: FromRef<S>,
128 RequesterFingerprint: FromRequestParts<S>,
129{
130 let mut router = Router::new()
131 .route(
132 mas_router::GraphQL::route(),
133 get(self::graphql::get).post(self::graphql::post),
134 )
135 .layer(Extension(ExtraRouterParameters {
138 undocumented_oauth2_access,
139 }))
140 .layer(
141 CorsLayer::new()
142 .allow_origin(Any)
143 .allow_methods(Any)
144 .allow_otel_headers([
145 AUTHORIZATION,
146 ACCEPT,
147 ACCEPT_LANGUAGE,
148 CONTENT_LANGUAGE,
149 CONTENT_TYPE,
150 ]),
151 );
152
153 if playground {
154 router = router.route(
155 mas_router::GraphQLPlayground::route(),
156 get(self::graphql::playground),
157 );
158 }
159
160 router
161}
162
163pub fn discovery_router<S>() -> Router<S>
164where
165 S: Clone + Send + Sync + 'static,
166 Keystore: FromRef<S>,
167 SiteConfig: FromRef<S>,
168 UrlBuilder: FromRef<S>,
169 BoxClock: FromRequestParts<S>,
170 BoxRng: FromRequestParts<S>,
171{
172 Router::new()
173 .route(
174 mas_router::OidcConfiguration::route(),
175 get(self::oauth2::discovery::get),
176 )
177 .route(
178 mas_router::Webfinger::route(),
179 get(self::oauth2::webfinger::get),
180 )
181 .layer(
182 CorsLayer::new()
183 .allow_origin(Any)
184 .allow_methods(Any)
185 .allow_otel_headers([
186 AUTHORIZATION,
187 ACCEPT,
188 ACCEPT_LANGUAGE,
189 CONTENT_LANGUAGE,
190 CONTENT_TYPE,
191 ])
192 .max_age(Duration::from_secs(60 * 60)),
193 )
194}
195
196pub fn api_router<S>() -> Router<S>
197where
198 S: Clone + Send + Sync + 'static,
199 Keystore: FromRef<S>,
200 UrlBuilder: FromRef<S>,
201 BoxRepository: FromRequestParts<S>,
202 ActivityTracker: FromRequestParts<S>,
203 BoundActivityTracker: FromRequestParts<S>,
204 Encrypter: FromRef<S>,
205 reqwest::Client: FromRef<S>,
206 SiteConfig: FromRef<S>,
207 Templates: FromRef<S>,
208 Arc<dyn HomeserverConnection>: FromRef<S>,
209 BoxClock: FromRequestParts<S>,
210 BoxRng: FromRequestParts<S>,
211 Policy: FromRequestParts<S>,
212{
213 Router::new()
215 .route(
216 mas_router::OAuth2Keys::route(),
217 get(self::oauth2::keys::get),
218 )
219 .route(
220 mas_router::OidcUserinfo::route(),
221 get(self::oauth2::userinfo::get).post(self::oauth2::userinfo::get),
222 )
223 .route(
224 mas_router::OAuth2Introspection::route(),
225 post(self::oauth2::introspection::post),
226 )
227 .route(
228 mas_router::OAuth2Revocation::route(),
229 post(self::oauth2::revoke::post),
230 )
231 .route(
232 mas_router::OAuth2TokenEndpoint::route(),
233 post(self::oauth2::token::post),
234 )
235 .route(
236 mas_router::OAuth2RegistrationEndpoint::route(),
237 post(self::oauth2::registration::post),
238 )
239 .route(
240 mas_router::OAuth2DeviceAuthorizationEndpoint::route(),
241 post(self::oauth2::device::authorize::post),
242 )
243 .layer(
244 CorsLayer::new()
245 .allow_origin(Any)
246 .allow_methods(Any)
247 .allow_otel_headers([
248 AUTHORIZATION,
249 ACCEPT,
250 ACCEPT_LANGUAGE,
251 CONTENT_LANGUAGE,
252 CONTENT_TYPE,
253 HeaderName::from_static("x-requested-with"),
255 ])
256 .max_age(Duration::from_secs(60 * 60)),
257 )
258}
259
260#[allow(clippy::trait_duplication_in_bounds)]
261pub fn compat_router<S>(templates: Templates) -> Router<S>
262where
263 S: Clone + Send + Sync + 'static,
264 UrlBuilder: FromRef<S>,
265 SiteConfig: FromRef<S>,
266 Arc<dyn HomeserverConnection>: FromRef<S>,
267 PasswordManager: FromRef<S>,
268 Limiter: FromRef<S>,
269 BoxRepositoryFactory: FromRef<S>,
270 BoundActivityTracker: FromRequestParts<S>,
271 RequesterFingerprint: FromRequestParts<S>,
272 BoxRepository: FromRequestParts<S>,
273 BoxClock: FromRequestParts<S>,
274 BoxRng: FromRequestParts<S>,
275{
276 let human_router = Router::new()
278 .route(
279 mas_router::CompatLoginSsoRedirect::route(),
280 get(self::compat::login_sso_redirect::get),
281 )
282 .route(
283 mas_router::CompatLoginSsoRedirectIdp::route(),
284 get(self::compat::login_sso_redirect::get),
285 )
286 .route(
287 mas_router::CompatLoginSsoRedirectSlash::route(),
288 get(self::compat::login_sso_redirect::get),
289 )
290 .layer(AndThenLayer::new(
291 async move |response: axum::response::Response| {
292 Ok::<_, Infallible>(recover_error(&templates, response))
293 },
294 ));
295
296 let api_router = Router::new()
298 .route(
299 mas_router::CompatLogin::route(),
300 get(self::compat::login::get).post(self::compat::login::post),
301 )
302 .route(
303 mas_router::CompatLogout::route(),
304 post(self::compat::logout::post),
305 )
306 .route(
307 mas_router::CompatLogoutAll::route(),
308 post(self::compat::logout_all::post),
309 )
310 .route(
311 mas_router::CompatRefresh::route(),
312 post(self::compat::refresh::post),
313 )
314 .layer(
315 CorsLayer::new()
316 .allow_origin(Any)
317 .allow_methods(Any)
318 .allow_otel_headers([
319 AUTHORIZATION,
320 ACCEPT,
321 ACCEPT_LANGUAGE,
322 CONTENT_LANGUAGE,
323 CONTENT_TYPE,
324 HeaderName::from_static("x-requested-with"),
325 ])
326 .max_age(Duration::from_secs(60 * 60)),
327 );
328
329 Router::new().merge(human_router).merge(api_router)
330}
331
332pub fn human_router<S>(templates: Templates) -> Router<S>
333where
334 S: Clone + Send + Sync + 'static,
335 UrlBuilder: FromRef<S>,
336 PreferredLanguage: FromRequestParts<S>,
337 BoxRepository: FromRequestParts<S>,
338 CookieJar: FromRequestParts<S>,
339 BoundActivityTracker: FromRequestParts<S>,
340 RequesterFingerprint: FromRequestParts<S>,
341 Encrypter: FromRef<S>,
342 Templates: FromRef<S>,
343 Keystore: FromRef<S>,
344 PasswordManager: FromRef<S>,
345 MetadataCache: FromRef<S>,
346 SiteConfig: FromRef<S>,
347 Limiter: FromRef<S>,
348 reqwest::Client: FromRef<S>,
349 Arc<dyn HomeserverConnection>: FromRef<S>,
350 BoxClock: FromRequestParts<S>,
351 BoxRng: FromRequestParts<S>,
352 Policy: FromRequestParts<S>,
353{
354 Router::new()
355 .route(
357 "/account",
358 get(
359 async |State(url_builder): State<UrlBuilder>, RawQuery(query): RawQuery| {
360 let prefix = url_builder.prefix().unwrap_or_default();
361 let route = mas_router::Account::route();
362 let destination = if let Some(query) = query {
363 format!("{prefix}{route}?{query}")
364 } else {
365 format!("{prefix}{route}")
366 };
367
368 axum::response::Redirect::to(&destination)
369 },
370 ),
371 )
372 .route(mas_router::Account::route(), get(self::views::app::get))
373 .route(
374 mas_router::AccountWildcard::route(),
375 get(self::views::app::get),
376 )
377 .route(
378 mas_router::AccountRecoveryFinish::route(),
379 get(self::views::app::get_anonymous),
380 )
381 .route(
382 mas_router::ChangePasswordDiscovery::route(),
383 get(async |State(url_builder): State<UrlBuilder>| {
384 url_builder.redirect(&mas_router::AccountPasswordChange)
385 }),
386 )
387 .route(mas_router::Index::route(), get(self::views::index::get))
388 .route(
389 mas_router::Login::route(),
390 get(self::views::login::get).post(self::views::login::post),
391 )
392 .route(mas_router::Logout::route(), post(self::views::logout::post))
393 .route(
394 mas_router::Register::route(),
395 get(self::views::register::get),
396 )
397 .route(
398 mas_router::PasswordRegister::route(),
399 get(self::views::register::password::get).post(self::views::register::password::post),
400 )
401 .route(
402 mas_router::RegisterVerifyEmail::route(),
403 get(self::views::register::steps::verify_email::get)
404 .post(self::views::register::steps::verify_email::post),
405 )
406 .route(
407 mas_router::RegisterToken::route(),
408 get(self::views::register::steps::registration_token::get)
409 .post(self::views::register::steps::registration_token::post),
410 )
411 .route(
412 mas_router::RegisterDisplayName::route(),
413 get(self::views::register::steps::display_name::get)
414 .post(self::views::register::steps::display_name::post),
415 )
416 .route(
417 mas_router::RegisterFinish::route(),
418 get(self::views::register::steps::finish::get),
419 )
420 .route(
421 mas_router::AccountRecoveryStart::route(),
422 get(self::views::recovery::start::get).post(self::views::recovery::start::post),
423 )
424 .route(
425 mas_router::AccountRecoveryProgress::route(),
426 get(self::views::recovery::progress::get).post(self::views::recovery::progress::post),
427 )
428 .route(
429 mas_router::OAuth2AuthorizationEndpoint::route(),
430 get(self::oauth2::authorization::get),
431 )
432 .route(
433 mas_router::Consent::route(),
434 get(self::oauth2::authorization::consent::get)
435 .post(self::oauth2::authorization::consent::post),
436 )
437 .route(
438 mas_router::CompatLoginSsoComplete::route(),
439 get(self::compat::login_sso_complete::get).post(self::compat::login_sso_complete::post),
440 )
441 .route(
442 mas_router::UpstreamOAuth2Authorize::route(),
443 get(self::upstream_oauth2::authorize::get),
444 )
445 .route(
446 mas_router::UpstreamOAuth2Callback::route(),
447 get(self::upstream_oauth2::callback::handler)
448 .post(self::upstream_oauth2::callback::handler),
449 )
450 .route(
451 mas_router::UpstreamOAuth2Link::route(),
452 get(self::upstream_oauth2::link::get).post(self::upstream_oauth2::link::post),
453 )
454 .route(
455 mas_router::UpstreamOAuth2BackchannelLogout::route(),
456 post(self::upstream_oauth2::backchannel_logout::post),
457 )
458 .route(
459 mas_router::DeviceCodeLink::route(),
460 get(self::oauth2::device::link::get),
461 )
462 .route(
463 mas_router::DeviceCodeConsent::route(),
464 get(self::oauth2::device::consent::get).post(self::oauth2::device::consent::post),
465 )
466 .layer(AndThenLayer::new(
467 async move |response: axum::response::Response| {
468 Ok::<_, Infallible>(recover_error(&templates, response))
469 },
470 ))
471}
472
473fn recover_error(
474 templates: &Templates,
475 response: axum::response::Response,
476) -> axum::response::Response {
477 let ext = response.extensions().get::<ErrorContext>();
479 if let Some(ctx) = ext
480 && let Ok(res) = templates.render_error(ctx)
481 {
482 let (mut parts, _original_body) = response.into_parts();
483 parts.headers.remove(CONTENT_TYPE);
484 parts.headers.remove(CONTENT_LENGTH);
485 return (parts, Html(res)).into_response();
486 }
487
488 response
489}
490
491pub async fn fallback(
497 State(templates): State<Templates>,
498 OriginalUri(uri): OriginalUri,
499 method: Method,
500 version: Version,
501 PreferredLanguage(locale): PreferredLanguage,
502) -> Result<impl IntoResponse, InternalError> {
503 let ctx = NotFoundContext::new(&method, version, &uri).with_language(locale);
504 let res = templates.render_not_found(&ctx)?;
507
508 Ok((StatusCode::NOT_FOUND, Html(res)))
509}