mas_axum_utils/
cookies.rs1use std::convert::Infallible;
10
11use axum::{
12    extract::{FromRef, FromRequestParts},
13    response::{IntoResponseParts, ResponseParts},
14};
15use axum_extra::extract::cookie::{Cookie, Key, PrivateCookieJar, SameSite};
16use http::request::Parts;
17use serde::{Serialize, de::DeserializeOwned};
18use thiserror::Error;
19use url::Url;
20
21#[derive(Debug, Error)]
22#[error("could not decode cookie")]
23pub enum CookieDecodeError {
24    Deserialize(#[from] serde_json::Error),
25}
26
27#[derive(Clone)]
32pub struct CookieManager {
33    options: CookieOption,
34    key: Key,
35}
36
37impl CookieManager {
38    #[must_use]
39    pub const fn new(base_url: Url, key: Key) -> Self {
40        let options = CookieOption::new(base_url);
41        Self { options, key }
42    }
43
44    #[must_use]
45    pub fn derive_from(base_url: Url, key: &[u8]) -> Self {
46        let key = Key::derive_from(key);
47        Self::new(base_url, key)
48    }
49
50    #[must_use]
51    pub fn cookie_jar(&self) -> CookieJar {
52        let inner = PrivateCookieJar::new(self.key.clone());
53        let options = self.options.clone();
54
55        CookieJar { inner, options }
56    }
57
58    #[must_use]
59    pub fn cookie_jar_from_headers(&self, headers: &http::HeaderMap) -> CookieJar {
60        let inner = PrivateCookieJar::from_headers(headers, self.key.clone());
61        let options = self.options.clone();
62
63        CookieJar { inner, options }
64    }
65}
66
67impl<S> FromRequestParts<S> for CookieJar
68where
69    CookieManager: FromRef<S>,
70    S: Send + Sync,
71{
72    type Rejection = Infallible;
73
74    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
75        let cookie_manager = CookieManager::from_ref(state);
76        Ok(cookie_manager.cookie_jar_from_headers(&parts.headers))
77    }
78}
79
80#[derive(Debug, Clone)]
81struct CookieOption {
82    base_url: Url,
83}
84
85impl CookieOption {
86    const fn new(base_url: Url) -> Self {
87        Self { base_url }
88    }
89
90    fn secure(&self) -> bool {
91        self.base_url.scheme() == "https"
92    }
93
94    fn path(&self) -> &str {
95        self.base_url.path()
96    }
97
98    fn apply<'a>(&self, mut cookie: Cookie<'a>) -> Cookie<'a> {
99        cookie.set_http_only(true);
100        cookie.set_secure(self.secure());
101        cookie.set_path(self.path().to_owned());
102        cookie.set_same_site(SameSite::Lax);
103        cookie
104    }
105}
106
107pub struct CookieJar {
109    inner: PrivateCookieJar<Key>,
110    options: CookieOption,
111}
112
113impl CookieJar {
114    #[must_use]
122    pub fn save<T: Serialize>(mut self, key: &str, payload: &T, permanent: bool) -> Self {
123        let serialized =
124            serde_json::to_string(payload).expect("failed to serialize cookie payload");
125
126        let cookie = Cookie::new(key.to_owned(), serialized);
127        let mut cookie = self.options.apply(cookie);
128
129        if permanent {
130            cookie.make_permanent();
132        }
133
134        self.inner = self.inner.add(cookie);
135
136        self
137    }
138
139    #[must_use]
141    pub fn remove(mut self, key: &str) -> Self {
142        self.inner = self.inner.remove(key.to_owned());
143        self
144    }
145
146    pub fn load<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>, CookieDecodeError> {
154        let Some(cookie) = self.inner.get(key) else {
155            return Ok(None);
156        };
157
158        let decoded = serde_json::from_str(cookie.value())?;
159        Ok(Some(decoded))
160    }
161}
162
163impl IntoResponseParts for CookieJar {
164    type Error = Infallible;
165
166    fn into_response_parts(self, res: ResponseParts) -> Result<ResponseParts, Self::Error> {
167        self.inner.into_response_parts(res)
168    }
169}