1use std::collections::HashMap;
47use std::sync::{Arc, RwLock};
48use std::time::{Duration, Instant};
49
50use super::{Microdescriptor, NetworkStatusDocument, ServerDescriptor};
51
/// Default lifetime of a cached consensus document: three hours.
const DEFAULT_CONSENSUS_TTL: Duration = Duration::from_secs(3 * 60 * 60);

/// Default lifetime of a cached server descriptor: one day.
const DEFAULT_SERVER_DESCRIPTOR_TTL: Duration = Duration::from_secs(24 * 60 * 60);

/// Default lifetime of a cached microdescriptor: one day.
const DEFAULT_MICRODESCRIPTOR_TTL: Duration = Duration::from_secs(24 * 60 * 60);

/// Default cap on descriptor entries (server descriptors plus microdescriptors combined).
const DEFAULT_MAX_ENTRIES: usize = 1_000;
63
/// A cached value together with its expiry deadline and last-access time.
#[derive(Debug, Clone)]
struct CacheEntry<T> {
    /// The cached item.
    value: T,
    /// Instant at (or after) which the entry counts as expired.
    expires_at: Instant,
    /// When the entry was inserted or last read; drives LRU eviction.
    last_accessed: Instant,
}

impl<T> CacheEntry<T> {
    /// Wraps `value` so that it expires `ttl` from now.
    ///
    /// A `ttl` too large to add to the current `Instant` (for example `Duration::MAX`, a
    /// natural way to say "never expire") is clamped to a far-future deadline instead of
    /// panicking on `Instant` overflow.
    fn new(value: T, ttl: Duration) -> Self {
        let now = Instant::now();
        Self {
            value,
            expires_at: Self::saturating_deadline(now, ttl),
            last_accessed: now,
        }
    }

    /// Returns `now + ttl`, halving `ttl` until the sum is representable.
    fn saturating_deadline(now: Instant, ttl: Duration) -> Instant {
        let mut ttl = ttl;
        loop {
            if let Some(deadline) = now.checked_add(ttl) {
                return deadline;
            }
            // `now + 0` always succeeds, so the loop terminates.
            ttl /= 2;
        }
    }

    /// Returns `true` once the current time has reached `expires_at`.
    fn is_expired(&self) -> bool {
        Instant::now() >= self.expires_at
    }

    /// Records an access by setting `last_accessed` to the current time.
    fn touch(&mut self) {
        self.last_accessed = Instant::now();
    }
}
90
91#[derive(Debug, Clone)]
109pub struct DescriptorCache {
110 inner: Arc<RwLock<CacheInner>>,
111}
112
113#[derive(Debug)]
114struct CacheInner {
115 consensus: Option<CacheEntry<NetworkStatusDocument>>,
116 server_descriptors: HashMap<String, CacheEntry<ServerDescriptor>>,
117 microdescriptors: HashMap<String, CacheEntry<Microdescriptor>>,
118 consensus_ttl: Duration,
119 server_descriptor_ttl: Duration,
120 microdescriptor_ttl: Duration,
121 max_entries: usize,
122 stats: CacheStats,
123}
124
/// Usage counters for a descriptor cache.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups that returned a cached, unexpired value.
    pub hits: u64,
    /// Lookups that found no entry, or only an expired one.
    pub misses: u64,
    /// Entries removed because their TTL had elapsed.
    pub expirations: u64,
    /// Entries removed by LRU eviction to respect the entry limit.
    pub evictions: u64,
}

impl CacheStats {
    /// Hit rate as a percentage in `0.0..=100.0`, or `0.0` if no lookups have happened yet.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => (self.hits as f64 / lookups as f64) * 100.0,
        }
    }
}
149
150impl DescriptorCache {
151 pub fn new() -> Self {
159 Self {
160 inner: Arc::new(RwLock::new(CacheInner {
161 consensus: None,
162 server_descriptors: HashMap::new(),
163 microdescriptors: HashMap::new(),
164 consensus_ttl: DEFAULT_CONSENSUS_TTL,
165 server_descriptor_ttl: DEFAULT_SERVER_DESCRIPTOR_TTL,
166 microdescriptor_ttl: DEFAULT_MICRODESCRIPTOR_TTL,
167 max_entries: DEFAULT_MAX_ENTRIES,
168 stats: CacheStats::default(),
169 })),
170 }
171 }
172
173 pub fn with_consensus_ttl(self, ttl: Duration) -> Self {
175 self.inner.write().unwrap().consensus_ttl = ttl;
176 self
177 }
178
179 pub fn with_server_descriptor_ttl(self, ttl: Duration) -> Self {
181 self.inner.write().unwrap().server_descriptor_ttl = ttl;
182 self
183 }
184
185 pub fn with_microdescriptor_ttl(self, ttl: Duration) -> Self {
187 self.inner.write().unwrap().microdescriptor_ttl = ttl;
188 self
189 }
190
191 pub fn with_max_entries(self, max: usize) -> Self {
195 self.inner.write().unwrap().max_entries = max;
196 self
197 }
198
199 pub fn get_consensus(&self) -> Option<NetworkStatusDocument> {
201 let mut inner = self.inner.write().unwrap();
202
203 if let Some(entry) = &mut inner.consensus {
204 if entry.is_expired() {
205 inner.consensus = None;
206 inner.stats.expirations += 1;
207 inner.stats.misses += 1;
208 return None;
209 }
210 entry.touch();
211 let value = entry.value.clone();
212 inner.stats.hits += 1;
213 return Some(value);
214 }
215
216 inner.stats.misses += 1;
217 None
218 }
219
220 pub fn put_consensus(&self, consensus: NetworkStatusDocument) {
222 let mut inner = self.inner.write().unwrap();
223 let ttl = inner.consensus_ttl;
224 inner.consensus = Some(CacheEntry::new(consensus, ttl));
225 }
226
227 pub fn get_server_descriptor(&self, fingerprint: &str) -> Option<ServerDescriptor> {
229 let mut inner = self.inner.write().unwrap();
230
231 let is_expired = inner
232 .server_descriptors
233 .get(fingerprint)
234 .map(|entry| entry.is_expired())
235 .unwrap_or(false);
236
237 if is_expired {
238 inner.server_descriptors.remove(fingerprint);
239 inner.stats.expirations += 1;
240 inner.stats.misses += 1;
241 return None;
242 }
243
244 if let Some(entry) = inner.server_descriptors.get_mut(fingerprint) {
245 entry.touch();
246 let value = entry.value.clone();
247 inner.stats.hits += 1;
248 return Some(value);
249 }
250
251 inner.stats.misses += 1;
252 None
253 }
254
255 pub fn put_server_descriptor(&self, fingerprint: String, descriptor: ServerDescriptor) {
257 let mut inner = self.inner.write().unwrap();
258 let ttl = inner.server_descriptor_ttl;
259
260 inner.evict_if_needed();
261 inner
262 .server_descriptors
263 .insert(fingerprint, CacheEntry::new(descriptor, ttl));
264 }
265
266 pub fn get_microdescriptor(&self, digest: &str) -> Option<Microdescriptor> {
268 let mut inner = self.inner.write().unwrap();
269
270 let is_expired = inner
271 .microdescriptors
272 .get(digest)
273 .map(|entry| entry.is_expired())
274 .unwrap_or(false);
275
276 if is_expired {
277 inner.microdescriptors.remove(digest);
278 inner.stats.expirations += 1;
279 inner.stats.misses += 1;
280 return None;
281 }
282
283 if let Some(entry) = inner.microdescriptors.get_mut(digest) {
284 entry.touch();
285 let value = entry.value.clone();
286 inner.stats.hits += 1;
287 return Some(value);
288 }
289
290 inner.stats.misses += 1;
291 None
292 }
293
294 pub fn put_microdescriptor(&self, digest: String, descriptor: Microdescriptor) {
296 let mut inner = self.inner.write().unwrap();
297 let ttl = inner.microdescriptor_ttl;
298
299 inner.evict_if_needed();
300 inner
301 .microdescriptors
302 .insert(digest, CacheEntry::new(descriptor, ttl));
303 }
304
305 pub fn clear(&self) {
307 let mut inner = self.inner.write().unwrap();
308 inner.consensus = None;
309 inner.server_descriptors.clear();
310 inner.microdescriptors.clear();
311 }
312
313 pub fn evict_expired(&self) {
315 let mut inner = self.inner.write().unwrap();
316
317 if let Some(entry) = &inner.consensus {
318 if entry.is_expired() {
319 inner.consensus = None;
320 inner.stats.expirations += 1;
321 }
322 }
323
324 let mut expired_server_keys = Vec::new();
325 for (key, entry) in &inner.server_descriptors {
326 if entry.is_expired() {
327 expired_server_keys.push(key.clone());
328 }
329 }
330 for key in expired_server_keys {
331 inner.server_descriptors.remove(&key);
332 inner.stats.expirations += 1;
333 }
334
335 let mut expired_micro_keys = Vec::new();
336 for (key, entry) in &inner.microdescriptors {
337 if entry.is_expired() {
338 expired_micro_keys.push(key.clone());
339 }
340 }
341 for key in expired_micro_keys {
342 inner.microdescriptors.remove(&key);
343 inner.stats.expirations += 1;
344 }
345 }
346
347 pub fn stats(&self) -> CacheStats {
349 self.inner.read().unwrap().stats.clone()
350 }
351
352 pub fn len(&self) -> usize {
354 let inner = self.inner.read().unwrap();
355 let consensus_count = if inner.consensus.is_some() { 1 } else { 0 };
356 consensus_count + inner.server_descriptors.len() + inner.microdescriptors.len()
357 }
358
359 pub fn is_empty(&self) -> bool {
361 self.len() == 0
362 }
363}
364
365impl CacheInner {
366 fn evict_if_needed(&mut self) {
367 let total_entries = self.server_descriptors.len() + self.microdescriptors.len();
368
369 if total_entries >= self.max_entries {
370 self.evict_lru();
371 }
372 }
373
374 fn evict_lru(&mut self) {
375 let mut all_entries: Vec<(String, Instant, bool)> = Vec::new();
376
377 for (key, entry) in &self.server_descriptors {
378 all_entries.push((key.clone(), entry.last_accessed, false));
379 }
380
381 for (key, entry) in &self.microdescriptors {
382 all_entries.push((key.clone(), entry.last_accessed, true));
383 }
384
385 all_entries.sort_by_key(|(_, accessed, _)| *accessed);
386
387 let total_entries = self.server_descriptors.len() + self.microdescriptors.len();
388 let to_evict = if total_entries >= self.max_entries {
389 (total_entries - self.max_entries + 1).max(1)
390 } else {
391 return;
392 };
393
394 for (key, _, is_micro) in all_entries.iter().take(to_evict) {
395 if *is_micro {
396 self.microdescriptors.remove(key);
397 } else {
398 self.server_descriptors.remove(key);
399 }
400 self.stats.evictions += 1;
401 }
402 }
403}
404
405impl Default for DescriptorCache {
406 fn default() -> Self {
407 Self::new()
408 }
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use crate::descriptor::Descriptor;
    use chrono::Utc;

    /// Parses a minimal consensus document used as cache payload.
    fn create_test_consensus() -> NetworkStatusDocument {
        NetworkStatusDocument::parse(
            r#"network-status-version 3
vote-status consensus
consensus-method 1
valid-after 2023-01-01 00:00:00
fresh-until 2023-01-01 01:00:00
valid-until 2023-01-01 03:00:00
"#,
        )
        .unwrap()
    }

    /// Builds a minimal server descriptor used as cache payload.
    fn create_test_server_descriptor() -> ServerDescriptor {
        ServerDescriptor::new(
            "TestRelay".to_string(),
            "192.168.1.1".parse().unwrap(),
            9001,
            Utc::now(),
            "test".to_string(),
        )
    }

    /// Parses a minimal microdescriptor used as cache payload.
    fn create_test_microdescriptor() -> Microdescriptor {
        Microdescriptor::parse(
            "onion-key\n-----BEGIN RSA PUBLIC KEY-----\ntest\n-----END RSA PUBLIC KEY-----\n",
        )
        .unwrap()
    }

    #[test]
    fn test_cache_consensus() {
        let cache = DescriptorCache::new();
        let consensus = create_test_consensus();

        assert!(cache.get_consensus().is_none());

        cache.put_consensus(consensus.clone());

        let cached = cache.get_consensus();
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().consensus_method, consensus.consensus_method);
    }

    #[test]
    fn test_cache_server_descriptor() {
        let cache = DescriptorCache::new();
        let descriptor = create_test_server_descriptor();
        let fingerprint = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA".to_string();

        assert!(cache.get_server_descriptor(&fingerprint).is_none());

        cache.put_server_descriptor(fingerprint.clone(), descriptor.clone());

        let cached = cache.get_server_descriptor(&fingerprint);
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().nickname, descriptor.nickname);
    }

    #[test]
    fn test_cache_microdescriptor() {
        let cache = DescriptorCache::new();
        let descriptor = create_test_microdescriptor();
        let digest = "test_digest".to_string();

        assert!(cache.get_microdescriptor(&digest).is_none());

        cache.put_microdescriptor(digest.clone(), descriptor.clone());

        let cached = cache.get_microdescriptor(&digest);
        assert!(cached.is_some());
    }

    #[test]
    fn test_cache_expiration() {
        let cache = DescriptorCache::new().with_consensus_ttl(Duration::from_millis(10));

        let consensus = create_test_consensus();
        cache.put_consensus(consensus);

        assert!(cache.get_consensus().is_some());

        std::thread::sleep(Duration::from_millis(20));

        assert!(cache.get_consensus().is_none());
    }

    #[test]
    fn test_cache_clear() {
        let cache = DescriptorCache::new();

        cache.put_consensus(create_test_consensus());
        cache.put_server_descriptor("fp1".to_string(), create_test_server_descriptor());
        cache.put_microdescriptor("digest1".to_string(), create_test_microdescriptor());

        assert_eq!(cache.len(), 3);

        cache.clear();

        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_stats() {
        let cache = DescriptorCache::new();
        let consensus = create_test_consensus();

        cache.put_consensus(consensus);

        assert!(cache.get_consensus().is_some());
        assert!(cache.get_consensus().is_some());

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 0);
        assert!(stats.hit_rate() > 99.0);
    }

    #[test]
    fn test_cache_eviction() {
        let cache = DescriptorCache::new().with_max_entries(5);

        for i in 0..10 {
            cache.put_server_descriptor(format!("fp{}", i), create_test_server_descriptor());
        }

        // Each of the last five distinct inserts evicts exactly one entry.
        assert_eq!(cache.len(), 5);

        let stats = cache.stats();
        assert_eq!(stats.evictions, 5);
    }

    #[test]
    fn test_evict_expired() {
        let cache = DescriptorCache::new().with_server_descriptor_ttl(Duration::from_millis(10));

        cache.put_server_descriptor("fp1".to_string(), create_test_server_descriptor());
        cache.put_server_descriptor("fp2".to_string(), create_test_server_descriptor());

        assert_eq!(cache.len(), 2);

        std::thread::sleep(Duration::from_millis(20));

        cache.evict_expired();

        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn test_lru_eviction() {
        let cache = DescriptorCache::new().with_max_entries(3);

        cache.put_server_descriptor("fp1".to_string(), create_test_server_descriptor());
        cache.put_server_descriptor("fp2".to_string(), create_test_server_descriptor());
        cache.put_server_descriptor("fp3".to_string(), create_test_server_descriptor());

        // Sleep *before* touching fp1/fp2. Their access times are then strictly later than
        // fp3's insertion time even on a coarse clock, so fp3 is unambiguously the LRU entry.
        std::thread::sleep(Duration::from_millis(10));

        cache.get_server_descriptor("fp1");
        cache.get_server_descriptor("fp2");

        cache.put_server_descriptor("fp4".to_string(), create_test_server_descriptor());

        assert!(cache.get_server_descriptor("fp1").is_some());
        assert!(cache.get_server_descriptor("fp2").is_some());
        assert!(cache.get_server_descriptor("fp3").is_none());
        assert!(cache.get_server_descriptor("fp4").is_some());
        assert_eq!(cache.stats().evictions, 1);
    }

    #[test]
    fn test_cache_hit_rate() {
        let cache = DescriptorCache::new();

        cache.put_consensus(create_test_consensus());

        cache.get_consensus();
        cache.get_consensus();
        cache.get_server_descriptor("missing");

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 66.67).abs() < 0.1);
    }
}
600}