stem_rs/descriptor/
cache.rs

1//! Descriptor caching for improved performance.
2//!
3//! This module provides in-memory caching of Tor descriptors to avoid
4//! repeated downloads from the Tor process. Caching significantly improves
5//! performance for applications that frequently query descriptor information.
6//!
7//! # Overview
8//!
//! The descriptor cache stores parsed descriptors, each expiring a configurable
//! time-to-live (TTL) after it is stored. The default TTLs approximate each
//! descriptor type's typical validity period:
12//!
13//! - **Consensus documents**: 3 hours (typical validity period)
14//! - **Server descriptors**: 24 hours (published daily)
15//! - **Microdescriptors**: 24 hours (referenced by consensus)
16//!
17//! # Thread Safety
18//!
//! The cache is thread-safe. `DescriptorCache` is cheap to clone and every
//! clone shares the same underlying storage, so it can be handed to multiple
//! tasks directly (wrapping it in `Arc` is unnecessary). All operations go
//! through an internal `RwLock`.
22//!
23//! # Example
24//!
25//! ```rust
26//! use stem_rs::descriptor::cache::DescriptorCache;
27//! use std::time::Duration;
28//!
29//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
30//! let cache = DescriptorCache::new()
31//!     .with_consensus_ttl(Duration::from_secs(3 * 3600))
32//!     .with_max_entries(1000);
33//!
34//! // Cache is automatically used by Controller methods
35//! // when enabled via Controller::with_descriptor_cache()
36//! # Ok(())
37//! # }
38//! ```
39//!
40//! # Memory Management
41//!
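//! Expired entries are removed lazily — for example when they are looked up,
//! or in bulk via `DescriptorCache::evict_expired`. A maximum entry limit
//! (covering server descriptors and microdescriptors; the single consensus
//! slot is not counted) prevents unbounded memory growth: when the limit is
//! reached, the least recently used entries are evicted.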
45
46use std::collections::HashMap;
47use std::sync::{Arc, RwLock};
48use std::time::{Duration, Instant};
49
50use super::{Microdescriptor, NetworkStatusDocument, ServerDescriptor};
51
/// Number of seconds in one hour; the default TTLs below are expressed in hours.
const SECS_PER_HOUR: u64 = 60 * 60;

/// Default time-to-live for a cached consensus document: three hours.
const DEFAULT_CONSENSUS_TTL: Duration = Duration::from_secs(3 * SECS_PER_HOUR);

/// Default time-to-live for a cached server descriptor: one day.
const DEFAULT_SERVER_DESCRIPTOR_TTL: Duration = Duration::from_secs(24 * SECS_PER_HOUR);

/// Default time-to-live for a cached microdescriptor: one day.
const DEFAULT_MICRODESCRIPTOR_TTL: Duration = Duration::from_secs(24 * SECS_PER_HOUR);

/// Default cap on the number of cached entries.
const DEFAULT_MAX_ENTRIES: usize = 1_000;
63
64/// A cached entry with expiration time.
65#[derive(Debug, Clone)]
66struct CacheEntry<T> {
67    value: T,
68    expires_at: Instant,
69    last_accessed: Instant,
70}
71
72impl<T> CacheEntry<T> {
73    fn new(value: T, ttl: Duration) -> Self {
74        let now = Instant::now();
75        Self {
76            value,
77            expires_at: now + ttl,
78            last_accessed: now,
79        }
80    }
81
82    fn is_expired(&self) -> bool {
83        Instant::now() >= self.expires_at
84    }
85
86    fn touch(&mut self) {
87        self.last_accessed = Instant::now();
88    }
89}
90
91/// In-memory cache for Tor descriptors.
92///
93/// Provides automatic expiration and LRU eviction for efficient memory usage.
94/// The cache is thread-safe and can be shared across multiple tasks.
95///
96/// # Example
97///
98/// ```rust
99/// use stem_rs::descriptor::cache::DescriptorCache;
100/// use std::time::Duration;
101///
102/// let cache = DescriptorCache::new()
103///     .with_consensus_ttl(Duration::from_secs(3600))
104///     .with_max_entries(500);
105///
106/// // Cache is used automatically by Controller when enabled
107/// ```
108#[derive(Debug, Clone)]
109pub struct DescriptorCache {
110    inner: Arc<RwLock<CacheInner>>,
111}
112
113#[derive(Debug)]
114struct CacheInner {
115    consensus: Option<CacheEntry<NetworkStatusDocument>>,
116    server_descriptors: HashMap<String, CacheEntry<ServerDescriptor>>,
117    microdescriptors: HashMap<String, CacheEntry<Microdescriptor>>,
118    consensus_ttl: Duration,
119    server_descriptor_ttl: Duration,
120    microdescriptor_ttl: Duration,
121    max_entries: usize,
122    stats: CacheStats,
123}
124
/// Counters describing how effective the cache has been.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Lookups that returned a cached, unexpired value.
    pub hits: u64,
    /// Lookups that found nothing usable (absent or expired).
    pub misses: u64,
    /// Entries dropped because their TTL had elapsed.
    pub expirations: u64,
    /// Entries dropped to stay within the configured entry limit.
    pub evictions: u64,
}

impl CacheStats {
    /// Percentage of lookups that were hits, between `0.0` and `100.0`.
    ///
    /// Yields `0.0` when no lookups have been recorded yet.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => (self.hits as f64 / lookups as f64) * 100.0,
        }
    }
}
149
150impl DescriptorCache {
151    /// Creates a new descriptor cache with default settings.
152    ///
153    /// Default settings:
154    /// - Consensus TTL: 3 hours
155    /// - Server descriptor TTL: 24 hours
156    /// - Microdescriptor TTL: 24 hours
157    /// - Max entries: 1000
158    pub fn new() -> Self {
159        Self {
160            inner: Arc::new(RwLock::new(CacheInner {
161                consensus: None,
162                server_descriptors: HashMap::new(),
163                microdescriptors: HashMap::new(),
164                consensus_ttl: DEFAULT_CONSENSUS_TTL,
165                server_descriptor_ttl: DEFAULT_SERVER_DESCRIPTOR_TTL,
166                microdescriptor_ttl: DEFAULT_MICRODESCRIPTOR_TTL,
167                max_entries: DEFAULT_MAX_ENTRIES,
168                stats: CacheStats::default(),
169            })),
170        }
171    }
172
173    /// Sets the TTL for consensus documents.
174    pub fn with_consensus_ttl(self, ttl: Duration) -> Self {
175        self.inner.write().unwrap().consensus_ttl = ttl;
176        self
177    }
178
179    /// Sets the TTL for server descriptors.
180    pub fn with_server_descriptor_ttl(self, ttl: Duration) -> Self {
181        self.inner.write().unwrap().server_descriptor_ttl = ttl;
182        self
183    }
184
185    /// Sets the TTL for microdescriptors.
186    pub fn with_microdescriptor_ttl(self, ttl: Duration) -> Self {
187        self.inner.write().unwrap().microdescriptor_ttl = ttl;
188        self
189    }
190
191    /// Sets the maximum number of cached entries.
192    ///
193    /// When this limit is reached, the least recently used entries are evicted.
194    pub fn with_max_entries(self, max: usize) -> Self {
195        self.inner.write().unwrap().max_entries = max;
196        self
197    }
198
199    /// Retrieves the cached consensus document if available and not expired.
200    pub fn get_consensus(&self) -> Option<NetworkStatusDocument> {
201        let mut inner = self.inner.write().unwrap();
202
203        if let Some(entry) = &mut inner.consensus {
204            if entry.is_expired() {
205                inner.consensus = None;
206                inner.stats.expirations += 1;
207                inner.stats.misses += 1;
208                return None;
209            }
210            entry.touch();
211            let value = entry.value.clone();
212            inner.stats.hits += 1;
213            return Some(value);
214        }
215
216        inner.stats.misses += 1;
217        None
218    }
219
220    /// Stores a consensus document in the cache.
221    pub fn put_consensus(&self, consensus: NetworkStatusDocument) {
222        let mut inner = self.inner.write().unwrap();
223        let ttl = inner.consensus_ttl;
224        inner.consensus = Some(CacheEntry::new(consensus, ttl));
225    }
226
227    /// Retrieves a cached server descriptor by fingerprint.
228    pub fn get_server_descriptor(&self, fingerprint: &str) -> Option<ServerDescriptor> {
229        let mut inner = self.inner.write().unwrap();
230
231        let is_expired = inner
232            .server_descriptors
233            .get(fingerprint)
234            .map(|entry| entry.is_expired())
235            .unwrap_or(false);
236
237        if is_expired {
238            inner.server_descriptors.remove(fingerprint);
239            inner.stats.expirations += 1;
240            inner.stats.misses += 1;
241            return None;
242        }
243
244        if let Some(entry) = inner.server_descriptors.get_mut(fingerprint) {
245            entry.touch();
246            let value = entry.value.clone();
247            inner.stats.hits += 1;
248            return Some(value);
249        }
250
251        inner.stats.misses += 1;
252        None
253    }
254
255    /// Stores a server descriptor in the cache.
256    pub fn put_server_descriptor(&self, fingerprint: String, descriptor: ServerDescriptor) {
257        let mut inner = self.inner.write().unwrap();
258        let ttl = inner.server_descriptor_ttl;
259
260        inner.evict_if_needed();
261        inner
262            .server_descriptors
263            .insert(fingerprint, CacheEntry::new(descriptor, ttl));
264    }
265
266    /// Retrieves a cached microdescriptor by digest.
267    pub fn get_microdescriptor(&self, digest: &str) -> Option<Microdescriptor> {
268        let mut inner = self.inner.write().unwrap();
269
270        let is_expired = inner
271            .microdescriptors
272            .get(digest)
273            .map(|entry| entry.is_expired())
274            .unwrap_or(false);
275
276        if is_expired {
277            inner.microdescriptors.remove(digest);
278            inner.stats.expirations += 1;
279            inner.stats.misses += 1;
280            return None;
281        }
282
283        if let Some(entry) = inner.microdescriptors.get_mut(digest) {
284            entry.touch();
285            let value = entry.value.clone();
286            inner.stats.hits += 1;
287            return Some(value);
288        }
289
290        inner.stats.misses += 1;
291        None
292    }
293
294    /// Stores a microdescriptor in the cache.
295    pub fn put_microdescriptor(&self, digest: String, descriptor: Microdescriptor) {
296        let mut inner = self.inner.write().unwrap();
297        let ttl = inner.microdescriptor_ttl;
298
299        inner.evict_if_needed();
300        inner
301            .microdescriptors
302            .insert(digest, CacheEntry::new(descriptor, ttl));
303    }
304
305    /// Clears all cached entries.
306    pub fn clear(&self) {
307        let mut inner = self.inner.write().unwrap();
308        inner.consensus = None;
309        inner.server_descriptors.clear();
310        inner.microdescriptors.clear();
311    }
312
313    /// Removes expired entries from the cache.
314    pub fn evict_expired(&self) {
315        let mut inner = self.inner.write().unwrap();
316
317        if let Some(entry) = &inner.consensus {
318            if entry.is_expired() {
319                inner.consensus = None;
320                inner.stats.expirations += 1;
321            }
322        }
323
324        let mut expired_server_keys = Vec::new();
325        for (key, entry) in &inner.server_descriptors {
326            if entry.is_expired() {
327                expired_server_keys.push(key.clone());
328            }
329        }
330        for key in expired_server_keys {
331            inner.server_descriptors.remove(&key);
332            inner.stats.expirations += 1;
333        }
334
335        let mut expired_micro_keys = Vec::new();
336        for (key, entry) in &inner.microdescriptors {
337            if entry.is_expired() {
338                expired_micro_keys.push(key.clone());
339            }
340        }
341        for key in expired_micro_keys {
342            inner.microdescriptors.remove(&key);
343            inner.stats.expirations += 1;
344        }
345    }
346
347    /// Returns the current cache statistics.
348    pub fn stats(&self) -> CacheStats {
349        self.inner.read().unwrap().stats.clone()
350    }
351
352    /// Returns the number of entries currently in the cache.
353    pub fn len(&self) -> usize {
354        let inner = self.inner.read().unwrap();
355        let consensus_count = if inner.consensus.is_some() { 1 } else { 0 };
356        consensus_count + inner.server_descriptors.len() + inner.microdescriptors.len()
357    }
358
359    /// Returns true if the cache is empty.
360    pub fn is_empty(&self) -> bool {
361        self.len() == 0
362    }
363}
364
365impl CacheInner {
366    fn evict_if_needed(&mut self) {
367        let total_entries = self.server_descriptors.len() + self.microdescriptors.len();
368
369        if total_entries >= self.max_entries {
370            self.evict_lru();
371        }
372    }
373
374    fn evict_lru(&mut self) {
375        let mut all_entries: Vec<(String, Instant, bool)> = Vec::new();
376
377        for (key, entry) in &self.server_descriptors {
378            all_entries.push((key.clone(), entry.last_accessed, false));
379        }
380
381        for (key, entry) in &self.microdescriptors {
382            all_entries.push((key.clone(), entry.last_accessed, true));
383        }
384
385        all_entries.sort_by_key(|(_, accessed, _)| *accessed);
386
387        let total_entries = self.server_descriptors.len() + self.microdescriptors.len();
388        let to_evict = if total_entries >= self.max_entries {
389            (total_entries - self.max_entries + 1).max(1)
390        } else {
391            return;
392        };
393
394        for (key, _, is_micro) in all_entries.iter().take(to_evict) {
395            if *is_micro {
396                self.microdescriptors.remove(key);
397            } else {
398                self.server_descriptors.remove(key);
399            }
400            self.stats.evictions += 1;
401        }
402    }
403}
404
405impl Default for DescriptorCache {
406    fn default() -> Self {
407        Self::new()
408    }
409}
410
/// Unit tests for `DescriptorCache`.
///
/// Expiry is exercised with a zero TTL (an entry is stale as soon as it is
/// stored), so these tests do not depend on scheduler timing.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::descriptor::Descriptor;
    use chrono::Utc;

    fn create_test_consensus() -> NetworkStatusDocument {
        NetworkStatusDocument::parse(
            r#"network-status-version 3
vote-status consensus
consensus-method 1
valid-after 2023-01-01 00:00:00
fresh-until 2023-01-01 01:00:00
valid-until 2023-01-01 03:00:00
"#,
        )
        .unwrap()
    }

    fn create_test_server_descriptor() -> ServerDescriptor {
        ServerDescriptor::new(
            "TestRelay".to_string(),
            "192.168.1.1".parse().unwrap(),
            9001,
            Utc::now(),
            "test".to_string(),
        )
    }

    fn create_test_microdescriptor() -> Microdescriptor {
        Microdescriptor::parse(
            "onion-key\n-----BEGIN RSA PUBLIC KEY-----\ntest\n-----END RSA PUBLIC KEY-----\n",
        )
        .unwrap()
    }

    #[test]
    fn test_cache_consensus() {
        let cache = DescriptorCache::new();
        let consensus = create_test_consensus();

        assert!(cache.get_consensus().is_none());

        cache.put_consensus(consensus.clone());

        let cached = cache.get_consensus();
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().consensus_method, consensus.consensus_method);
    }

    #[test]
    fn test_cache_server_descriptor() {
        let cache = DescriptorCache::new();
        let descriptor = create_test_server_descriptor();
        let fingerprint = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA".to_string();

        assert!(cache.get_server_descriptor(&fingerprint).is_none());

        cache.put_server_descriptor(fingerprint.clone(), descriptor.clone());

        let cached = cache.get_server_descriptor(&fingerprint);
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().nickname, descriptor.nickname);
    }

    #[test]
    fn test_cache_microdescriptor() {
        let cache = DescriptorCache::new();
        let digest = "test_digest".to_string();

        assert!(cache.get_microdescriptor(&digest).is_none());

        cache.put_microdescriptor(digest.clone(), create_test_microdescriptor());

        assert!(cache.get_microdescriptor(&digest).is_some());
    }

    #[test]
    fn test_cache_expiration() {
        // With the default TTL the consensus is served back.
        let fresh = DescriptorCache::new();
        fresh.put_consensus(create_test_consensus());
        assert!(fresh.get_consensus().is_some());

        // A zero TTL makes the entry stale immediately: deterministic, no sleep.
        let stale = DescriptorCache::new().with_consensus_ttl(Duration::ZERO);
        stale.put_consensus(create_test_consensus());
        assert!(stale.get_consensus().is_none());

        let stats = stale.stats();
        assert_eq!(stats.expirations, 1);
        assert_eq!(stats.misses, 1);
        assert!(stale.is_empty());
    }

    #[test]
    fn test_cache_clear() {
        let cache = DescriptorCache::new();

        cache.put_consensus(create_test_consensus());
        cache.put_server_descriptor("fp1".to_string(), create_test_server_descriptor());
        cache.put_microdescriptor("digest1".to_string(), create_test_microdescriptor());

        assert_eq!(cache.len(), 3);

        cache.clear();

        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_stats() {
        let cache = DescriptorCache::new();

        cache.put_consensus(create_test_consensus());

        assert!(cache.get_consensus().is_some());
        assert!(cache.get_consensus().is_some());

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 0);
        assert!(stats.hit_rate() > 99.0);
    }

    #[test]
    fn test_cache_eviction() {
        let cache = DescriptorCache::new().with_max_entries(5);

        for i in 0..10 {
            cache.put_server_descriptor(format!("fp{}", i), create_test_server_descriptor());
        }

        // Each of the last five inserts evicts exactly one entry.
        assert_eq!(cache.len(), 5);
        assert_eq!(cache.stats().evictions, 5);
    }

    #[test]
    fn test_evict_expired() {
        let cache = DescriptorCache::new().with_server_descriptor_ttl(Duration::ZERO);

        cache.put_server_descriptor("fp1".to_string(), create_test_server_descriptor());
        cache.put_server_descriptor("fp2".to_string(), create_test_server_descriptor());

        assert_eq!(cache.len(), 2);

        cache.evict_expired();

        assert_eq!(cache.len(), 0);
        assert_eq!(cache.stats().expirations, 2);
    }

    #[test]
    fn test_lru_eviction() {
        let cache = DescriptorCache::new().with_max_entries(3);

        cache.put_server_descriptor("fp1".to_string(), create_test_server_descriptor());
        cache.put_server_descriptor("fp2".to_string(), create_test_server_descriptor());
        cache.put_server_descriptor("fp3".to_string(), create_test_server_descriptor());

        // Guarantee the touches below are strictly later than fp3's insert,
        // even on platforms with a coarse monotonic clock.
        std::thread::sleep(Duration::from_millis(2));

        assert!(cache.get_server_descriptor("fp1").is_some());
        assert!(cache.get_server_descriptor("fp2").is_some());

        cache.put_server_descriptor("fp4".to_string(), create_test_server_descriptor());

        // fp3 was the least recently used and must be the one evicted.
        assert!(cache.get_server_descriptor("fp3").is_none());
        assert!(cache.get_server_descriptor("fp1").is_some());
        assert!(cache.get_server_descriptor("fp2").is_some());
        assert!(cache.get_server_descriptor("fp4").is_some());
        assert_eq!(cache.len(), 3);
    }

    #[test]
    fn test_cache_hit_rate() {
        let cache = DescriptorCache::new();

        cache.put_consensus(create_test_consensus());

        cache.get_consensus();
        cache.get_consensus();
        cache.get_server_descriptor("missing");

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 66.67).abs() < 0.1);
    }
}