Parameterizing Axum Middleware

Pantheon is a project at Develop for Good that I’ve been working on for a while now. It’s an app that centralizes all of our data sources and lets us take actions on that data that the data providers themselves can’t. For example, we use the Google Workspace for Nonprofits plan to manage our volunteers, but volunteer applications arrive through Airtable, and we need a convenient way to batch-export from Airtable into Workspace. That is the kind of service Pantheon provides.


Pantheon is written in Rust and uses Axum to do the heavy lifting on tasks like routing, subrouting, and parameter extraction. Axum also provides a few ways to inject custom middleware. The from_fn and from_fn_with_state functions in the axum::middleware module build middleware from an async function (more precisely, a function that returns a future): from_fn takes just the function, while from_fn_with_state takes some state and the function. Basic usage is fairly simple. For example (adapted from the Axum docs):

src/main.rs
```rust
use anyhow::Result;
use axum::{
    extract::Request,
    middleware::{self, Next},
    response::Response,
    routing::get,
    Router,
};
use tokio::net::TcpListener;

async fn some_middleware(request: Request, next: Next) -> Response {
    // do something with `request`...

    let response = next.run(request).await;

    // do something with `response`...

    response
}

#[tokio::main]
async fn main() -> Result<()> {
    let app = Router::new()
        .route("/", get(|| async { /* ... */ }))
        .layer(middleware::from_fn(some_middleware));

    // Example address; adjust as needed.
    let listener = TcpListener::bind("127.0.0.1:3000").await?;

    axum::serve(listener, app).await?;

    Ok(())
}
```

The code block above shows how to generate middleware using from_fn. The from_fn_with_state function works the same way, except that you also pass in a state value, which the middleware function can then access through the State extractor.
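To make the before/next/after flow concrete without depending on Axum, here is a deliberately simplified, synchronous toy model. Request, Response, AppState, counting_middleware, and handler are hypothetical stand-ins, not Axum types: in real Axum the state arrives through the State extractor and next.run(request).await is asynchronous. The shape is the same, though: the middleware receives the state, the request, and the rest of the stack, and can act both before and after calling it.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

// Hypothetical stand-ins for Axum's request and response types.
#[derive(Debug, Clone)]
pub struct Request {
    pub path: String,
    pub headers: Vec<(String, String)>,
}

#[derive(Debug, Clone, PartialEq)]
pub struct Response {
    pub status: u16,
    pub body: String,
}

// Shared state, analogous to the extra value given to `from_fn_with_state`.
// Cloning it clones the Arc, so every clone shares one counter.
#[derive(Clone, Default)]
pub struct AppState {
    pub requests_seen: Arc<AtomicUsize>,
}

// Same shape as `some_middleware`, plus the state: state, request, then
// "the rest of the stack" (`next`).
pub fn counting_middleware(
    state: AppState,
    mut request: Request,
    next: &dyn Fn(Request) -> Response,
) -> Response {
    // Before the handler: record the request and tag it with a header.
    state.requests_seen.fetch_add(1, Ordering::SeqCst);
    request
        .headers
        .push(("x-seen-by".to_string(), "counter".to_string()));

    // Run the rest of the stack.
    let mut response = next(request);

    // After the handler: adjust the response on its way out.
    response.body.push_str(" (counted)");
    response
}

// Innermost handler: echoes the path and how many headers it received.
pub fn handler(request: Request) -> Response {
    Response {
        status: 200,
        body: format!("{} with {} header(s)", request.path, request.headers.len()),
    }
}
```

The handler sees the header the middleware added, and the response it returns passes back through the middleware before reaching the caller.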


The difficulty arises when you need to partially apply the middleware, that is, fix some of its arguments ahead of time. Let’s discuss this using authentication middleware as an example. At Develop for Good, we have different positions, including volunteers, product leads, and management team members. All of them, as well as our clients, have access to some parts of Pantheon. Additionally, our set of roles is not stable: as a young organization, we will likely add or remove roles. Building role-based access control, then, requires flexibility.

The base of our authentication middleware (without role-based access control capabilities) might look like:

src/middleware.rs
```rust
use std::sync::Arc;

use axum::{
    extract::{Request, State},
    middleware::Next,
    response::Response,
};
use axum_extra::{
    headers::{authorization::Bearer, Authorization},
    TypedHeader,
};

use crate::app::context::Context;

pub async fn auth(
    // Some opaque app state that we extract. In Pantheon we store an authenticator implementation in `Context`.
    State(ctx): State<Arc<Context>>,
    // The `Authorization: Bearer <token>` header, parsed by axum-extra.
    header: TypedHeader<Authorization<Bearer>>,
    mut request: Request,
    next: Next,
) -> Response {
    // Authenticate the user using the JWT in `header` and get the user data.
    // let data = <result of authentication if successful>

    // Insert the data into the request for downstream handlers:
    // request.extensions_mut().insert(data);

    let response = next.run(request).await;

    response
}
```

This authentication middleware extracts the app state, the JWT from the Authorization header, the request itself, and the remainder of the middleware stack (next). It doesn’t analyze the user’s roles (presumably contained in the token’s claims). The issue is that the middleware runs before the endpoint is reached, and each endpoint may allow a different set of users. We could duplicate the code for each role (or generate the copies at compile time with procedural macros), writing functions such as volunteer_guard, product_lead_guard, and management_guard, but this gets messy, and it doesn’t handle combinations of roles. Ideally, we could supply a list of roles to the authentication middleware and it would use that list for the role-checking part of the process. The new middleware might look like this:

src/middleware.rs
```rust
use std::sync::Arc;

use axum::{
    extract::{Request, State},
    middleware::Next,
    response::Response,
};
use axum_extra::{
    headers::{authorization::Bearer, Authorization},
    TypedHeader,
};

use crate::app::context::Context;
use crate::errors::AppError;

pub async fn rbac(
    // Some opaque app state that we extract. In Pantheon we store an authenticator implementation in `Context`.
    State(ctx): State<Arc<Context>>,
    header: TypedHeader<Authorization<Bearer>>,
    mut request: Request,
    next: Next,
    // Not an extractor: this is the parameter we want to fix per route.
    roles: Vec<String>,
) -> Result<Response, AppError> {
    // Authenticate the user using the JWT in `header` and get the user data.
    // let data = <result of authentication if successful>

    // Compare the roles in `data` against `roles`. If the user lacks them,
    // return a 403 Forbidden response instead of calling `next`.

    // Insert the data into the request: request.extensions_mut().insert(data);

    let response = next.run(request).await;

    Ok(response)
}
```

However, rbac can’t be used as-is: Axum can’t extract roles (a Vec<String> is not an extractor), and nothing in its current form lets us parameterize it. Axum’s from_fn and from_fn_with_state take the function you want to turn into middleware, not the result of calling that function. What we need is a function that takes some roles and returns a partially applied middleware, one that calls rbac with those roles already filled in. Here is how you can do that:

src/middleware.rs
```rust
use std::{future::Future, pin::Pin, sync::Arc};

use axum::{
    extract::{Request, State},
    middleware::Next,
    response::Response,
};
use axum_extra::{
    headers::{authorization::Bearer, Authorization},
    TypedHeader,
};

use crate::app::context::Context;
use crate::errors::AppError;

pub async fn rbac(
    // Some opaque app state that we extract. In Pantheon we store an authenticator implementation in `Context`.
    State(ctx): State<Arc<Context>>,
    header: TypedHeader<Authorization<Bearer>>,
    mut request: Request,
    next: Next,
    roles: Vec<String>,
) -> Result<Response, AppError> {
    // Authenticate the user using the JWT in `header` and get the user data.
    // let data = <result of authentication if successful>

    // Compare the roles in `data` against `roles`. If the user lacks them,
    // return a 403 Forbidden response instead of calling `next`.

    // Insert the data into the request: request.extensions_mut().insert(data);

    let response = next.run(request).await;

    Ok(response)
}

/// Returns `rbac` with `roles` already filled in, in the shape
/// `from_fn_with_state` expects: extractors, then `Request`, then `Next`.
pub async fn make_rbac(
    roles: Vec<String>,
) -> impl Fn(
    State<Arc<Context>>,
    TypedHeader<Authorization<Bearer>>,
    Request,
    Next,
) -> Pin<Box<dyn Future<Output = Result<Response, AppError>> + Send>>
       + Clone {
    // `move` gives the closure ownership of `roles`; each call clones it
    // because `rbac` takes it by value and the closure runs once per request.
    move |state, header, request, next| Box::pin(rbac(state, header, request, next, roles.clone()))
}
```

This took me a while to figure out, and I want to break it down.


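The make_rbac function takes a vector of roles and returns a function. That function takes an Axum State wrapping an Arc (an atomically reference-counted smart pointer) around the context we want to be available, the Authorization header of the form “Bearer <token>”, the request, and the remainder of the middleware stack. It returns a pinned, boxed future that resolves to a Result whose success value implements IntoResponse and whose error is an AppError (AppError is just an error struct that we use in Pantheon).

The boxing is needed because the return type has to be written out in make_rbac’s signature, and the concrete future type produced by an async fn can’t be named; Box<dyn Future> erases it. The pinning is needed because futures generated by async fns may be self-referential (they can hold references to their own local variables across .await points), so they must not move once polled, which is why Future::poll takes Pin<&mut Self>. The Send bound lets the future move between threads, which Axum requires because Tokio’s multi-threaded runtime may resume a task on a different worker thread. The Clone bound is also required by Axum, which clones the middleware function when handling requests. The closure body is one line: the move keyword gives the closure ownership of roles, and each call clones roles (since rbac takes it by value and the closure runs for every request) and returns the pinned future. That is what parameterizes rbac with the roles.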
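To see the pattern in isolation, here is a std-only toy model of make_rbac that runs without Axum. Request, Next, Response, and block_on are hypothetical stand-ins (a real app uses Axum’s types and Tokio’s runtime), and the role check assumes any-of semantics. The point is the shape: make_rbac captures roles and returns a Clone closure whose signature no longer mentions roles, producing a Pin<Box<dyn Future + Send>> so the return type can be written down. The toy version is a plain fn rather than an async fn, since it only builds a closure and never awaits anything.

```rust
use std::future::Future;
use std::pin::{pin, Pin};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Hypothetical stand-ins for Axum's Request / Response / Next.
#[derive(Debug, Clone)]
pub struct Request {
    pub user_roles: Vec<String>,
}

pub type Response = (u16, String);

// Toy `Next`: "the rest of the stack" is just one async handler here.
pub struct Next;

impl Next {
    pub async fn run(self, _request: Request) -> Response {
        (200, "handler ran".to_string())
    }
}

// Same shape as `rbac`: framework-supplied arguments first, then the
// extra `roles` parameter that the framework cannot supply.
pub async fn rbac(request: Request, next: Next, roles: Vec<String>) -> Response {
    // Any-of semantics (assumption): one matching role is enough.
    let allowed = request.user_roles.iter().any(|r| roles.contains(r));
    if !allowed {
        return (403, "forbidden".to_string());
    }
    next.run(request).await
}

// The nameable, type-erased future the returned closure produces.
pub type BoxFuture = Pin<Box<dyn Future<Output = Response> + Send>>;

// Partial application: capture `roles`, return a closure without it.
pub fn make_rbac(roles: Vec<String>) -> impl Fn(Request, Next) -> BoxFuture + Clone {
    move |request: Request, next: Next| -> BoxFuture {
        // Clone per call: `rbac` takes `roles` by value, and this closure
        // may be called many times.
        Box::pin(rbac(request, next, roles.clone()))
    }
}

// Minimal executor for the example: polls with a no-op waker until ready.
// Adequate here only because none of these futures ever return Pending.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

pub fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut fut = pin!(fut);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}
```

Because the closure only captures a Vec<String>, it is automatically Clone and Send, which is what lets the real make_rbac satisfy Axum’s bounds.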

Now it can be used as follows:

src/main.rs
```rust
mod app;
mod errors;
mod middleware;

use std::sync::Arc;

use anyhow::Result;
use axum::{middleware::from_fn_with_state, routing::get, Router};
use tokio::net::TcpListener;

use crate::app::context::Context;
use crate::middleware::make_rbac;

#[tokio::main]
async fn main() -> Result<()> {
    // A middleware that only lets users with the "admin" role through.
    let admin_rbac = make_rbac(vec!["admin".to_string()]).await;

    // The state must match the extractor in `rbac`, i.e. `Arc<Context>`.
    let ctx = Arc::new(Context { /*...*/ });

    let app = Router::new()
        .route("/", get(|| async { /* ... */ }))
        .route_layer(from_fn_with_state(ctx.clone(), admin_rbac))
        .with_state(ctx);

    // Example address; adjust as needed.
    let listener = TcpListener::bind("127.0.0.1:3000").await?;

    axum::serve(listener, app).await?;

    Ok(())
}
```
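The rbac body leaves the role comparison as a comment. A minimal sketch of that step, assuming the user’s roles have already been decoded from the token into a Vec<String>: has_any_role implements any-of semantics (the user needs at least one listed role), which is what lets a single guard cover combinations like “volunteers or product leads”; has_all_roles is the stricter alternative. Both helpers, and roles, are hypothetical names, not part of Axum or Pantheon.

```rust
use std::collections::HashSet;

// Hypothetical helper: true if the user holds at least one allowed role.
pub fn has_any_role(user_roles: &[String], allowed: &[String]) -> bool {
    let user: HashSet<&str> = user_roles.iter().map(String::as_str).collect();
    allowed.iter().any(|role| user.contains(role.as_str()))
}

// Hypothetical helper: true only if the user holds every required role.
pub fn has_all_roles(user_roles: &[String], required: &[String]) -> bool {
    let user: HashSet<&str> = user_roles.iter().map(String::as_str).collect();
    required.iter().all(|role| user.contains(role.as_str()))
}

// Convenience for building role lists from string literals.
pub fn roles(names: &[&str]) -> Vec<String> {
    names.iter().map(|s| s.to_string()).collect()
}
```

With any-of semantics, make_rbac(vec!["volunteer".to_string(), "product_lead".to_string()]) would admit either role. Note that any over an empty list is false, so a guard built with no roles would reject everyone, whereas all over an empty list is true.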

Lastly, this is the AppError struct we use at Pantheon. It just wraps an anyhow::Error.

src/errors.rs
```rust
//! Generic error constructs to use for API error responses

use anyhow::Error;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};

// This just takes a status code and some JSON body and turns it into an Axum response
use crate::app::api_response;

/// A generic app error sent by all handlers on failure
pub struct AppError(Error);

impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        // Log the real error; clients only see a generic message.
        log::error!("{}", self.0);
        api_response::error(StatusCode::INTERNAL_SERVER_ERROR, "Something went wrong!")
    }
}

// Lets handlers use `?` on anything that converts into an anyhow::Error.
impl<E> From<E> for AppError
where
    E: Into<Error>,
{
    fn from(err: E) -> Self {
        Self(err.into())
    }
}
```
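The blanket From impl is what lets handler code use ? on almost any error and still return AppError. Here is a std-only sketch of the same pattern so it can run without dependencies: Box<dyn Error + Send + Sync> stands in for anyhow::Error, a (status, body) tuple stands in for an Axum Response, and parse_user_id is a hypothetical handler body.

```rust
use std::error::Error;

// Std-only analogue of AppError: a boxed error instead of anyhow::Error.
#[derive(Debug)]
pub struct AppError(Box<dyn Error + Send + Sync>);

// Any standard error converts into AppError, so `?` works in handlers.
impl<E> From<E> for AppError
where
    E: Error + Send + Sync + 'static,
{
    fn from(err: E) -> Self {
        Self(Box::new(err))
    }
}

impl AppError {
    // Analogue of `into_response`: log the details, return a generic 500.
    pub fn into_response(self) -> (u16, String) {
        eprintln!("error: {}", self.0);
        (500, "Something went wrong!".to_string())
    }
}

// Hypothetical handler body: `?` turns ParseIntError into AppError.
pub fn parse_user_id(raw: &str) -> Result<u32, AppError> {
    let id: u32 = raw.parse()?;
    Ok(id)
}
```

One consequence of this blanket impl (true of the anyhow-based version as well) is that AppError cannot itself implement std::error::Error: doing so would make From<AppError> for AppError overlap with the standard library’s blanket From<T> for T.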
