1use std::collections::HashMap;
7use std::sync::Arc;
8
9use chrono::{DateTime, Utc};
10use parking_lot::RwLock;
11use serde::{Deserialize, Serialize};
12use tracing::{debug, info};
13use uuid::Uuid;
14
15use crate::{OtpError, OtpResult};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct OtpRecord {
23 pub id: Uuid,
25 pub token_hash: Vec<u8>,
27 pub entity_id: String,
29 pub label: String,
31 pub profile: String,
33 pub created_at: DateTime<Utc>,
35 pub expires_at: DateTime<Utc>,
37 pub max_uses: u32,
39 pub current_uses: u32,
41 pub revoked: bool,
43}
44
45pub trait OtpStore: Send + Sync {
49 fn insert(&self, record: OtpRecord) -> impl std::future::Future<Output = OtpResult<()>> + Send;
51
52 fn find_by_hash(
54 &self,
55 hash: &[u8],
56 ) -> impl std::future::Future<Output = OtpResult<Option<OtpRecord>>> + Send;
57
58 fn increment_uses(
60 &self,
61 id: &Uuid,
62 new_count: u32,
63 ) -> impl std::future::Future<Output = OtpResult<()>> + Send;
64
65 fn revoke(&self, id: &Uuid) -> impl std::future::Future<Output = OtpResult<()>> + Send;
67
68 fn cleanup_expired(&self) -> impl std::future::Future<Output = OtpResult<u64>> + Send;
70}
71
72#[derive(Clone)]
80pub struct InMemoryOtpStore {
81 records: Arc<RwLock<HashMap<Uuid, OtpRecord>>>,
82}
83
84impl InMemoryOtpStore {
85 pub fn new() -> Self {
87 Self {
88 records: Arc::new(RwLock::new(HashMap::new())),
89 }
90 }
91}
92
93impl Default for InMemoryOtpStore {
94 fn default() -> Self {
95 Self::new()
96 }
97}
98
99impl OtpStore for InMemoryOtpStore {
100 async fn insert(&self, record: OtpRecord) -> OtpResult<()> {
101 debug!(id = %record.id, entity_id = %record.entity_id, "inserting OTP record");
102 self.records.write().insert(record.id, record);
103 Ok(())
104 }
105
106 async fn find_by_hash(&self, hash: &[u8]) -> OtpResult<Option<OtpRecord>> {
107 let records = self.records.read();
108 Ok(records.values().find(|r| r.token_hash == hash).cloned())
109 }
110
111 async fn increment_uses(&self, id: &Uuid, new_count: u32) -> OtpResult<()> {
112 let mut records = self.records.write();
113 match records.get_mut(id) {
114 Some(r) => {
115 r.current_uses = new_count;
116 Ok(())
117 }
118 None => Err(OtpError::NotFound),
119 }
120 }
121
122 async fn revoke(&self, id: &Uuid) -> OtpResult<()> {
123 let mut records = self.records.write();
124 match records.get_mut(id) {
125 Some(r) => {
126 r.revoked = true;
127 Ok(())
128 }
129 None => Err(OtpError::NotFound),
130 }
131 }
132
133 async fn cleanup_expired(&self) -> OtpResult<u64> {
134 let now = Utc::now();
135 let mut records = self.records.write();
136 let before = records.len();
137 records.retain(|_, r| r.expires_at > now);
138 let removed = (before - records.len()) as u64;
139 if removed > 0 {
140 info!(removed, "cleaned up expired OTP records");
141 }
142 Ok(removed)
143 }
144}
145
/// Placeholder for a database-backed [`OtpStore`].
///
/// Not implemented yet: every trait method on this type returns an `OtpError::StorageError`
/// saying the database OTP store is not yet implemented.
#[derive(Debug)]
pub struct DbOtpStore {
    // NOTE(review): the name suggests a future connection pool; it is currently unit.
    _pool: (),
}
158
159impl DbOtpStore {
160 pub fn new() -> Self {
167 Self { _pool: () }
168 }
169}
170
171impl Default for DbOtpStore {
172 fn default() -> Self {
173 Self::new()
174 }
175}
176
177impl OtpStore for DbOtpStore {
178 async fn insert(&self, _record: OtpRecord) -> OtpResult<()> {
179 Err(OtpError::StorageError(
181 "database OTP store not yet implemented".into(),
182 ))
183 }
184
185 async fn find_by_hash(&self, _hash: &[u8]) -> OtpResult<Option<OtpRecord>> {
186 Err(OtpError::StorageError(
187 "database OTP store not yet implemented".into(),
188 ))
189 }
190
191 async fn increment_uses(&self, _id: &Uuid, _new_count: u32) -> OtpResult<()> {
192 Err(OtpError::StorageError(
193 "database OTP store not yet implemented".into(),
194 ))
195 }
196
197 async fn revoke(&self, _id: &Uuid) -> OtpResult<()> {
198 Err(OtpError::StorageError(
199 "database OTP store not yet implemented".into(),
200 ))
201 }
202
203 async fn cleanup_expired(&self) -> OtpResult<u64> {
204 Err(OtpError::StorageError(
205 "database OTP store not yet implemented".into(),
206 ))
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generate::{OtpGenerator, OtpGeneratorConfig};

    /// Builds an unused, unrevoked record with a fresh id, the given hash and the given expiry.
    fn sample_record(token_hash: Vec<u8>, expires_at: DateTime<Utc>) -> OtpRecord {
        OtpRecord {
            id: Uuid::new_v4(),
            token_hash,
            entity_id: "sample.example.com".into(),
            label: "sample".into(),
            profile: "default".into(),
            created_at: expires_at - chrono::Duration::hours(1),
            expires_at,
            max_uses: 1,
            current_uses: 0,
            revoked: false,
        }
    }

    #[tokio::test]
    async fn in_memory_store_round_trip() {
        let store = InMemoryOtpStore::new();
        let generator = OtpGenerator::new(OtpGeneratorConfig::default()).unwrap();
        let otp = generator
            .generate("host.example.com", "test", "default")
            .unwrap();

        let record = OtpRecord {
            id: otp.metadata.id,
            token_hash: otp.token_hash.clone(),
            entity_id: otp.metadata.entity_id.clone(),
            label: otp.metadata.label.clone(),
            profile: otp.metadata.profile.clone(),
            created_at: otp.metadata.created_at,
            expires_at: otp.metadata.expires_at,
            max_uses: otp.metadata.max_uses,
            current_uses: 0,
            revoked: false,
        };

        store.insert(record).await.unwrap();

        let found = store
            .find_by_hash(&otp.token_hash)
            .await
            .unwrap()
            .expect("record inserted above must be found by its hash");
        assert_eq!(found.entity_id, "host.example.com");
        assert_eq!(found.id, otp.metadata.id);
    }

    #[tokio::test]
    async fn find_by_unknown_hash_returns_none() {
        let store = InMemoryOtpStore::new();
        assert!(store.find_by_hash(&[9u8; 32]).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn cleanup_removes_expired() {
        let store = InMemoryOtpStore::new();
        let expired = sample_record(vec![0u8; 32], Utc::now() - chrono::Duration::hours(1));
        let live = sample_record(vec![1u8; 32], Utc::now() + chrono::Duration::hours(1));
        store.insert(expired).await.unwrap();
        store.insert(live).await.unwrap();

        let removed = store.cleanup_expired().await.unwrap();
        assert_eq!(removed, 1);
        // Only the expired record may be removed; the live one must survive cleanup.
        assert!(store.find_by_hash(&[0u8; 32]).await.unwrap().is_none());
        assert!(store.find_by_hash(&[1u8; 32]).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn increment_and_revoke_update_record() {
        let store = InMemoryOtpStore::new();
        let record = sample_record(vec![2u8; 32], Utc::now() + chrono::Duration::hours(1));
        let id = record.id;
        store.insert(record).await.unwrap();

        store.increment_uses(&id, 1).await.unwrap();
        store.revoke(&id).await.unwrap();

        let stored = store
            .find_by_hash(&[2u8; 32])
            .await
            .unwrap()
            .expect("record inserted above must still be present");
        assert_eq!(stored.current_uses, 1);
        assert!(stored.revoked);
    }

    #[tokio::test]
    async fn unknown_id_is_not_found() {
        let store = InMemoryOtpStore::new();
        let missing = Uuid::new_v4();
        assert!(matches!(
            store.increment_uses(&missing, 1).await,
            Err(OtpError::NotFound)
        ));
        assert!(matches!(
            store.revoke(&missing).await,
            Err(OtpError::NotFound)
        ));
    }
}