• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

divviup / divviup-api / 25222136258

01 May 2026 04:15PM UTC coverage: 57.824% (+0.6%) from 57.194%
25222136258

push

github

web-flow
Migrate from Trillium [part 6]: auth routes (#2234)

* Migrate from Trillium [part 6]: auth routes

Move `/login`, `/logout`, and `/callback` from the Trillium router to
Axum handlers. I've pulled in `tower-sessions` 0.15 with the
existing `divviup.sid` cookie name, backed by the same `session`
database table via the `TowerSessionStore` added in part 3.

- `oauth2::redirect` / `oauth2::callback` / `misc::logout` rewritten as
  Axum handlers; the Trillium versions are removed.
- `User` gains `FromRequestParts` and `OptionalFromRequestParts` impls;
  `from_parts` now returns `Option<User>` and the `PermissionsActor`
  extractor is updated to match.
- `Error` gains `CallbackCsrfMismatch`, `CallbackMissingPkce`, and
  `CallbackMissingCode` variants, all mapped to 403 to match the
  previous Trillium behavior.
- `OauthClient` is added to `AxumAppState` via `FromRef`.
- `AxumProxy` disables reqwest's default redirect-following so that 302
  responses from Axum handlers (e.g. `/login` redirecting to Auth0) are
  passed back to the caller instead of followed by the proxy.
- When `debug_assertions` is enabled, an Axum middleware is compiled in that reads an
  `X-Integration-Testing-User` header and injects the decoded `User`
  into request extensions, letting integration tests simulate a
  logged-in user through the proxy path. `TestExt::with_user(&user)`
  replaces `with_state(user)` for routes that have been migrated.
- The logout test no longer asserts `is_destroyed()`; clearing the
  cookie is the responsibility of tower-sessions and will be covered
  end-to-end in part 8, when Trillium is removed.

95 of 136 new or added lines in 9 files covered. (69.85%)

8 existing lines in 2 files now uncovered.

4412 of 7630 relevant lines covered (57.82%)

60.48 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

36.91
/src/handler/oauth2.rs
1
use crate::{handler::Error, Config, User, USER_SESSION_KEY};
2
use axum::{
3
    extract::{Query, State},
4
    http::{header, StatusCode},
5
    response::{IntoResponse, Response},
6
};
7
use oauth2::{
8
    basic::{BasicClient, BasicErrorResponseType},
9
    AsyncHttpClient, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EndpointNotSet,
10
    EndpointSet, HttpRequest, HttpResponse, PkceCodeChallenge, PkceCodeVerifier, RedirectUrl,
11
    RequestTokenError, Scope, StandardErrorResponse, TokenResponse, TokenUrl,
12
};
13
use serde::Deserialize;
14
use std::{future::Future, pin::Pin, sync::Arc};
15
use tower_sessions::Session;
16
use trillium::{KnownHeaderName::Authorization, Status};
17
use trillium_client::{Client, ClientSerdeError};
18
use trillium_http::Headers;
19
use url::Url;
20

21
/// Type alias for an oauth2::Client once we've finished configuring it in `OauthClient::new`.
22
/// Crate oauth's guide to upgrading to 0.5 recommends defining this kind of alias:
23
/// https://github.com/ramosbugs/oauth2-rs/blob/main/UPGRADE.md#add-typestate-generic-types-to-client
24
pub type ConfiguredOauthClient = BasicClient<
25
    EndpointSet,    // HasAuthURL
26
    EndpointNotSet, // HasDeviceAuthURL
27
    EndpointNotSet, // HasIntrospectionURL
28
    EndpointNotSet, // HasRevocationURL
29
    EndpointSet,    // HasTokenURL
30
>;
31

32
#[derive(Clone, Debug)]
33
pub struct Oauth2Config {
34
    pub authorize_url: Url,
35
    pub token_url: Url,
36
    pub client_id: String,
37
    pub client_secret: String,
38
    pub redirect_url: Url,
39
    pub base_url: Url,
40
    pub audience: String,
41
    pub http_client: Client,
42
}
43

44
/// Session key under which `redirect` stores the PKCE verifier for `callback`.
const PKCE_SESSION_KEY: &'static str = "pkce_verifier";
/// Session key under which `redirect` stores the CSRF (`state`) token for `callback`.
const CSRF_SESSION_KEY: &'static str = "csrf_token";
46

47
/// `GET /login` — start the OAuth flow, or short-circuit to the app if the
48
/// user is already logged in.
49
pub async fn redirect(
2✔
50
    State(oauth_client): State<OauthClient>,
2✔
51
    State(config): State<Arc<Config>>,
2✔
52
    session: Session,
2✔
53
    user: Option<User>,
2✔
54
) -> Result<Response, Error> {
2✔
55
    if user.is_some() {
2✔
56
        return Ok(found_redirect(config.app_url.as_ref()));
1✔
57
    }
1✔
58

59
    let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256();
1✔
60

61
    let (auth_url, csrf_token) = oauth_client
1✔
62
        .oauth2_client()
1✔
63
        .authorize_url(CsrfToken::new_random)
1✔
64
        .add_scope(Scope::new(String::from("openid")))
1✔
65
        .add_scope(Scope::new(String::from("profile")))
1✔
66
        .add_scope(Scope::new(String::from("email")))
1✔
67
        .set_pkce_challenge(pkce_challenge)
1✔
68
        .url();
1✔
69

70
    session.insert(PKCE_SESSION_KEY, pkce_verifier).await?;
1✔
71
    session.insert(CSRF_SESSION_KEY, csrf_token).await?;
1✔
72

73
    Ok(found_redirect(auth_url.as_str()))
1✔
74
}
2✔
75

76
#[derive(Debug, Deserialize)]
77
pub struct CallbackParams {
78
    pub code: Option<String>,
79
    pub state: Option<String>,
80
}
81

82
/// `GET /callback` — exchange the authorization code for tokens, then stash
83
/// the user in the session and redirect to the app.
NEW
84
pub async fn callback(
×
NEW
85
    State(oauth_client): State<OauthClient>,
×
NEW
86
    State(config): State<Arc<Config>>,
×
NEW
87
    session: Session,
×
NEW
88
    Query(params): Query<CallbackParams>,
×
NEW
89
) -> Result<Response, Error> {
×
NEW
90
    let auth_code = params
×
NEW
91
        .code
×
NEW
92
        .map(AuthorizationCode::new)
×
NEW
93
        .ok_or(Error::CallbackMissingCode)?;
×
94

NEW
95
    let pkce_verifier: PkceCodeVerifier = session
×
NEW
96
        .get(PKCE_SESSION_KEY)
×
NEW
97
        .await?
×
NEW
98
        .ok_or(Error::CallbackMissingPkce)?;
×
99

NEW
100
    let session_csrf: Option<String> = session.get(CSRF_SESSION_KEY).await?;
×
NEW
101
    match (session_csrf, &params.state) {
×
NEW
102
        (Some(a), Some(b)) if a == *b => {}
×
NEW
103
        _ => return Err(Error::CallbackCsrfMismatch),
×
104
    }
105

NEW
106
    let user = oauth_client
×
NEW
107
        .exchange_code_for_user(auth_code, pkce_verifier)
×
NEW
108
        .await?;
×
109

NEW
110
    session.insert(USER_SESSION_KEY, user).await?;
×
111

NEW
112
    Ok(found_redirect(config.app_url.as_ref()))
×
NEW
113
}
×
114

115
/// `GET /logout` — destroy the session and redirect to Auth0's logout URL so
116
/// the IdP session is also cleared.
117
pub async fn logout(
1✔
118
    State(config): State<Arc<Config>>,
1✔
119
    session: Session,
1✔
120
) -> Result<Response, Error> {
1✔
121
    session.flush().await?;
1✔
122

123
    let mut logout_url = config.auth_url.join("/v2/logout")?;
1✔
124
    logout_url.query_pairs_mut().extend_pairs([
1✔
125
        ("client_id", &*config.auth_client_id),
1✔
126
        ("returnTo", config.app_url.as_ref()),
1✔
127
    ]);
1✔
128

129
    Ok(found_redirect(logout_url.as_ref()))
1✔
130
}
1✔
131

132
fn found_redirect(location: &str) -> Response {
3✔
133
    (
3✔
134
        StatusCode::FOUND,
3✔
135
        [(header::LOCATION, location.to_string())],
3✔
136
    )
3✔
137
        .into_response()
3✔
138
}
3✔
139

140
#[derive(thiserror::Error, Debug)]
141
pub enum OauthError {
142
    #[error(transparent)]
143
    HttpError(#[from] trillium_client::Error),
144
    #[error(transparent)]
145
    InvalidStatusCode(#[from] oauth2::http::status::InvalidStatusCode),
146
    #[error(transparent)]
147
    HeaderConversionError(#[from] trillium_http::http_compat1::HeaderConversionError),
148
    #[error(transparent)]
149
    UrlError(#[from] url::ParseError),
150
    #[error("error response: {0}")]
151
    RequestTokenError(StandardErrorResponse<BasicErrorResponseType>),
152
    #[error(transparent)]
153
    Serde(#[from] serde_json::error::Error),
154
    #[error("Other error: {0}")]
155
    Other(String),
156
    #[error("expected a successful status, but found {0:?}")]
157
    UnexpectedStatus(Option<Status>),
158
    #[error(transparent)]
159
    HttpCrateError(#[from] oauth2::http::Error),
160
}
161

162
impl From<RequestTokenError<OauthError, StandardErrorResponse<BasicErrorResponseType>>>
163
    for OauthError
164
{
165
    fn from(
×
166
        value: RequestTokenError<OauthError, StandardErrorResponse<BasicErrorResponseType>>,
×
167
    ) -> Self {
×
168
        match value {
×
169
            RequestTokenError::ServerResponse(server_response) => {
×
170
                OauthError::RequestTokenError(server_response)
×
171
            }
172
            RequestTokenError::Request(e) => e,
×
173
            RequestTokenError::Parse(error, _path) => OauthError::Serde(error.into_inner()),
×
174
            RequestTokenError::Other(s) => OauthError::Other(s),
×
175
        }
176
    }
×
177
}
178

179
impl From<ClientSerdeError> for OauthError {
180
    fn from(value: ClientSerdeError) -> Self {
×
181
        match value {
×
182
            ClientSerdeError::HttpError(e) => OauthError::HttpError(e),
×
183
            ClientSerdeError::JsonError(e) => OauthError::Serde(e),
×
184
        }
185
    }
×
186
}
187

188
impl From<OauthError> for Error {
NEW
189
    fn from(value: OauthError) -> Self {
×
NEW
190
        Self::Other(Arc::new(value))
×
NEW
191
    }
×
192
}
193

194
#[derive(Clone, Debug)]
195
pub struct OauthClient(Arc<OauthClientInner>);
196

197
#[derive(Debug)]
198
struct OauthClientInner {
199
    oauth_config: Oauth2Config,
200
    oauth2_client: ConfiguredOauthClient,
201
}
202

203
impl OauthClient {
204
    async fn exchange_code_for_user(
×
205
        &self,
×
206
        auth_code: AuthorizationCode,
×
207
        pkce_verifier: PkceCodeVerifier,
×
208
    ) -> Result<User, OauthError> {
×
209
        let http_client = self.http_client().clone();
×
210
        let exchange = self
×
211
            .oauth2_client()
×
212
            .exchange_code(auth_code)
×
213
            .set_pkce_verifier(pkce_verifier)
×
214
            .add_extra_param("audience", &self.0.oauth_config.audience)
×
215
            .request_async(&ClientWrapper(http_client))
×
216
            .await?;
×
217

218
        let mut client_conn = self
×
219
            .http_client()
×
220
            .get(self.0.oauth_config.base_url.join("/userinfo")?)
×
221
            .with_request_header(
×
222
                Authorization,
×
223
                format!("Bearer {}", exchange.access_token().secret()),
×
224
            )
225
            .await?;
×
226
        if !client_conn
×
227
            .status()
×
228
            .as_ref()
×
229
            .map(Status::is_success)
×
230
            .unwrap_or_default()
×
231
        {
232
            return Err(OauthError::UnexpectedStatus(client_conn.status()));
×
233
        }
×
234

235
        Ok(client_conn.response_json().await?)
×
236
    }
×
237

238
    pub fn new(config: &Oauth2Config) -> Self {
278✔
239
        let oauth2_client = BasicClient::new(ClientId::new(config.client_id.clone()))
278✔
240
            .set_client_secret(ClientSecret::new(config.client_secret.clone()))
278✔
241
            .set_auth_uri(AuthUrl::from_url(config.authorize_url.clone()))
278✔
242
            .set_token_uri(TokenUrl::from_url(config.token_url.clone()))
278✔
243
            .set_redirect_uri(RedirectUrl::from_url(config.redirect_url.clone()));
278✔
244

245
        Self(Arc::new(OauthClientInner {
278✔
246
            oauth_config: config.clone(),
278✔
247
            oauth2_client,
278✔
248
        }))
278✔
249
    }
278✔
250

251
    pub fn oauth2_client(&self) -> &ConfiguredOauthClient {
1✔
252
        &self.0.oauth2_client
1✔
253
    }
1✔
254

255
    pub fn http_client(&self) -> &Client {
×
256
        &self.0.oauth_config.http_client
×
257
    }
×
258
}
259

260
// Wraps a [`trillium_client::Client`] so we can implement [`oauth2::AsyncHttpClient`] on it, as
261
// otherwise the orphan rule would forbid this.
262
struct ClientWrapper(Client);
263

264
// Inspired by the impls `oauth2` provides for `reqwest::Client`
265
// https://github.com/ramosbugs/oauth2-rs/blob/23b952b23e6069525bc7e4c4f2c4924b8d28ce3a/src/reqwest.rs
266
impl<'c> AsyncHttpClient<'c> for ClientWrapper {
267
    type Error = OauthError;
268
    type Future = Pin<Box<dyn Future<Output = Result<HttpResponse, Self::Error>> + Send + 'c>>;
269

270
    fn call(&'c self, req: HttpRequest) -> Self::Future {
×
271
        Box::pin(async move {
×
272
            // Translate the oauth2::http::Request into a Trillium request
273
            let mut conn = self
×
274
                .0
×
275
                .build_conn(req.method(), req.uri().to_string().parse::<Url>()?)
×
276
                .with_body(req.body().clone())
×
277
                .with_request_headers(Headers::from(req.headers().clone()))
×
278
                .await?;
×
279
            let status_code: oauth2::http::StatusCode = conn.status().unwrap().try_into()?;
×
280
            let body = conn.response_body().read_bytes().await?;
×
281

282
            // Now transform the Trillium response back into an http::Response
283
            let mut builder = oauth2::http::Response::builder().status(status_code);
×
284
            let http_headers: oauth2::http::HeaderMap =
×
285
                conn.response_headers().clone().try_into()?;
×
286
            builder
×
287
                .headers_mut()
×
288
                .ok_or_else(|| OauthError::Other("no headers in builder?".into()))?
×
289
                .extend(http_headers);
×
290
            Ok::<_, OauthError>(builder.body(body)?)
×
291
        })
×
292
    }
×
293
}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc