mas_handlers/admin/v1/user_registration_tokens/add.rs

1// Copyright 2025 New Vector Ltd.
2// Copyright 2025 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
5// Please see LICENSE files in the repository root for full details.
6
7use aide::{NoApi, OperationIo, transform::TransformOperation};
8use axum::{Json, response::IntoResponse};
9use chrono::{DateTime, Utc};
10use hyper::StatusCode;
11use mas_axum_utils::record_error;
12use mas_storage::BoxRng;
13use rand::distributions::{Alphanumeric, DistString};
14use schemars::JsonSchema;
15use serde::Deserialize;
16
17use crate::{
18    admin::{
19        call_context::CallContext,
20        model::UserRegistrationToken,
21        response::{ErrorResponse, SingleResponse},
22    },
23    impl_from_error_for_route,
24};
25
26#[derive(Debug, thiserror::Error, OperationIo)]
27#[aide(output_with = "Json<ErrorResponse>")]
28pub enum RouteError {
29    #[error("A registration token with the same token already exists")]
30    Conflict(mas_data_model::UserRegistrationToken),
31
32    #[error(transparent)]
33    Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
34}
35
36impl_from_error_for_route!(mas_storage::RepositoryError);
37
38impl IntoResponse for RouteError {
39    fn into_response(self) -> axum::response::Response {
40        let error = ErrorResponse::from_error(&self);
41        let sentry_event_id = record_error!(self, Self::Internal(_));
42        let status = match self {
43            Self::Conflict(_) => StatusCode::CONFLICT,
44            Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
45        };
46        (status, sentry_event_id, Json(error)).into_response()
47    }
48}
49
// Request body deserialized by `handler`.
// The `///` doc comments in this item are read by the `JsonSchema` derive and
// surface in the generated API schema, so edits to them are user-visible.
// Unknown JSON keys are ignored (serde's default; no `deny_unknown_fields`).
/// # JSON payload for the `POST /api/admin/v1/user-registration-tokens`
#[derive(Deserialize, JsonSchema)]
// Container-level rename: presumably gives the generated schema a distinct
// name; it does not affect field names.
#[serde(rename = "AddUserRegistrationTokenRequest")]
pub struct Request {
    /// The token string. If not provided, a random token will be generated.
    // NOTE(review): no length/content validation is applied; an empty string
    // is accepted and used verbatim by `handler`.
    token: Option<String>,

    /// Maximum number of times this token can be used. If not provided, the
    /// token can be used an unlimited number of times.
    // NOTE(review): `0` is not rejected here.
    usage_limit: Option<u32>,

    /// When the token expires. If not provided, the token never expires.
    // NOTE(review): a timestamp in the past is not rejected here.
    expires_at: Option<DateTime<Utc>>,
}
64
65pub fn doc(operation: TransformOperation) -> TransformOperation {
66    operation
67        .id("addUserRegistrationToken")
68        .summary("Create a new user registration token")
69        .tag("user-registration-token")
70        .response_with::<201, Json<SingleResponse<UserRegistrationToken>>, _>(|t| {
71            let [sample, ..] = UserRegistrationToken::samples();
72            let response = SingleResponse::new_canonical(sample);
73            t.description("A new user registration token was created")
74                .example(response)
75        })
76}
77
78#[tracing::instrument(name = "handler.admin.v1.user_registration_tokens.post", skip_all)]
79pub async fn handler(
80    CallContext {
81        mut repo, clock, ..
82    }: CallContext,
83    NoApi(mut rng): NoApi<BoxRng>,
84    Json(params): Json<Request>,
85) -> Result<(StatusCode, Json<SingleResponse<UserRegistrationToken>>), RouteError> {
86    // Generate a random token if none was provided
87    let token = params
88        .token
89        .unwrap_or_else(|| Alphanumeric.sample_string(&mut rng, 12));
90
91    // See if we have an existing token with the same token
92    let existing_token = repo.user_registration_token().find_by_token(&token).await?;
93    if let Some(existing_token) = existing_token {
94        return Err(RouteError::Conflict(existing_token));
95    }
96
97    let registration_token = repo
98        .user_registration_token()
99        .add(
100            &mut rng,
101            &clock,
102            token,
103            params.usage_limit,
104            params.expires_at,
105        )
106        .await?;
107
108    repo.save().await?;
109
110    Ok((
111        StatusCode::CREATED,
112        Json(SingleResponse::new_canonical(UserRegistrationToken::new(
113            registration_token,
114            clock.now(),
115        ))),
116    ))
117}
118
// Integration tests for `POST /api/admin/v1/user-registration-tokens`.
// Each test runs against a fresh, migrated Postgres pool provided by
// `sqlx::test`, and authenticates with a token carrying the `urn:mas:admin`
// scope.
#[cfg(test)]
mod tests {
    use hyper::{Request, StatusCode};
    use insta::assert_json_snapshot;
    use sqlx::PgPool;

    use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};

    // An explicit token string is stored and echoed back unchanged, together
    // with the requested usage limit.
    // NOTE(review): the snapshot's `"type": "user-registration_token"` mixes
    // `-` and `_`; it comes from the model, outside this file — confirm it is
    // intended.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_create(pool: PgPool) {
        setup();
        let mut state = TestState::from_pool(pool).await.unwrap();
        let token = state.token_with_scope("urn:mas:admin").await;

        let request = Request::post("/api/admin/v1/user-registration-tokens")
            .bearer(&token)
            .json(serde_json::json!({
                "token": "test_token_123",
                "usage_limit": 5,
            }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);
        let body: serde_json::Value = response.json();

        assert_json_snapshot!(body, @r#"
        {
          "data": {
            "type": "user-registration_token",
            "id": "01FSHN9AG0MZAA6S4AF7CTV32E",
            "attributes": {
              "token": "test_token_123",
              "valid": true,
              "usage_limit": 5,
              "times_used": 0,
              "created_at": "2022-01-16T14:40:00Z",
              "last_used_at": null,
              "expires_at": null,
              "revoked_at": null
            },
            "links": {
              "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E"
            }
          },
          "links": {
            "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E"
          }
        }
        "#);
    }

    // Omitting `token` makes the handler generate one. The exact value is
    // pinned by the snapshot, which relies on the test state's RNG being
    // deterministic (presumably seeded by `TestState`).
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_create_auto_token(pool: PgPool) {
        setup();
        let mut state = TestState::from_pool(pool).await.unwrap();
        let token = state.token_with_scope("urn:mas:admin").await;

        let request = Request::post("/api/admin/v1/user-registration-tokens")
            .bearer(&token)
            .json(serde_json::json!({
                "usage_limit": 1
            }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);

        let body: serde_json::Value = response.json();

        assert_json_snapshot!(body, @r#"
        {
          "data": {
            "type": "user-registration_token",
            "id": "01FSHN9AG0QMGC989M0XSFVF2X",
            "attributes": {
              "token": "42oTpLoieH5I",
              "valid": true,
              "usage_limit": 1,
              "times_used": 0,
              "created_at": "2022-01-16T14:40:00Z",
              "last_used_at": null,
              "expires_at": null,
              "revoked_at": null
            },
            "links": {
              "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0QMGC989M0XSFVF2X"
            }
          },
          "links": {
            "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0QMGC989M0XSFVF2X"
          }
        }
        "#);
    }

    // Creating the same token string twice: the first request succeeds, the
    // second is rejected with 409 Conflict.
    // NOTE(review): the body of the 409 response is not asserted.
    #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
    async fn test_create_conflict(pool: PgPool) {
        setup();
        let mut state = TestState::from_pool(pool).await.unwrap();
        let token = state.token_with_scope("urn:mas:admin").await;

        let request = Request::post("/api/admin/v1/user-registration-tokens")
            .bearer(&token)
            .json(serde_json::json!({
                "token": "test_token_123",
                "usage_limit": 5
            }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::CREATED);

        let body: serde_json::Value = response.json();

        assert_json_snapshot!(body, @r#"
        {
          "data": {
            "type": "user-registration_token",
            "id": "01FSHN9AG0MZAA6S4AF7CTV32E",
            "attributes": {
              "token": "test_token_123",
              "valid": true,
              "usage_limit": 5,
              "times_used": 0,
              "created_at": "2022-01-16T14:40:00Z",
              "last_used_at": null,
              "expires_at": null,
              "revoked_at": null
            },
            "links": {
              "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E"
            }
          },
          "links": {
            "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E"
          }
        }
        "#);

        // Same token string again: must hit the `RouteError::Conflict` path.
        let request = Request::post("/api/admin/v1/user-registration-tokens")
            .bearer(&token)
            .json(serde_json::json!({
                "token": "test_token_123",
                "usage_limit": 5
            }));
        let response = state.request(request).await;
        response.assert_status(StatusCode::CONFLICT);
    }
}