ferriclink_core/
caches.rs

//! Cache implementations for FerricLink Core.
//!
//! **Cache** provides an optional caching layer for LLMs.
//!
//! Cache is useful for two reasons:
//!
//! - It can save you money by reducing the number of API calls you make to the LLM
//!   provider if you're often requesting the same completion multiple times.
//! - It can speed up your application by reducing the number of API calls you make
//!   to the LLM provider.
//!
//! Cache directly competes with Memory. See the documentation for pros and cons.
//!
//! **Type hierarchy:**
//!
//! ```text
//! BaseCache --> <name>Cache  # Examples: InMemoryCache, RedisCache, GPTCache
//! ```
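//!
//! # Example
//!
//! A minimal usage sketch (assuming `Generation` exposes the two fields
//! constructed in this module's tests):
//!
//! ```rust,ignore
//! use ferriclink_core::caches::{BaseCache, InMemoryCache};
//! use ferriclink_core::language_models::Generation;
//!
//! let cache = InMemoryCache::new();
//! let generations = vec![Generation {
//!     text: "Hello, world!".to_string(),
//!     generation_info: std::collections::HashMap::new(),
//! }];
//!
//! // Store a completion under (prompt, llm_string), then read it back.
//! cache.update("prompt", "llm-config", generations.clone()).unwrap();
//! assert_eq!(cache.lookup("prompt", "llm-config").unwrap(), Some(generations));
//! ```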

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;

use crate::errors::Result;
use crate::impl_serializable;
use crate::language_models::Generation;

/// Type alias for cached return values
pub type CachedGenerations = Vec<Generation>;

/// Interface for a caching layer for LLMs and chat models.
///
/// The cache interface consists of the following methods:
///
/// - `lookup`: Look up a value based on a prompt and llm_string.
/// - `update`: Update the cache based on a prompt and llm_string.
/// - `clear`: Clear the cache.
///
/// In addition, the trait requires an async version of each method
/// (`alookup`, `aupdate`, `aclear`). Implementations that wrap a synchronous
/// store may delegate to the synchronous methods, but genuinely asynchronous
/// implementations avoid blocking the async executor.
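///
/// # Example
///
/// A minimal sketch of a no-op implementation (always misses, discards
/// updates); the async methods simply delegate to the synchronous ones:
///
/// ```rust,ignore
/// use async_trait::async_trait;
///
/// struct NullCache;
///
/// #[async_trait]
/// impl BaseCache for NullCache {
///     fn lookup(&self, _prompt: &str, _llm_string: &str) -> Result<Option<CachedGenerations>> {
///         Ok(None) // always a miss
///     }
///
///     fn update(&self, _prompt: &str, _llm_string: &str, _val: CachedGenerations) -> Result<()> {
///         Ok(()) // drop the value
///     }
///
///     fn clear(&self) -> Result<()> {
///         Ok(())
///     }
///
///     async fn alookup(&self, prompt: &str, llm_string: &str) -> Result<Option<CachedGenerations>> {
///         self.lookup(prompt, llm_string)
///     }
///
///     async fn aupdate(&self, prompt: &str, llm_string: &str, val: CachedGenerations) -> Result<()> {
///         self.update(prompt, llm_string, val)
///     }
///
///     async fn aclear(&self) -> Result<()> {
///         self.clear()
///     }
/// }
/// ```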
#[async_trait]
pub trait BaseCache: Send + Sync {
    /// Look up based on prompt and llm_string.
    ///
    /// A cache implementation is expected to generate a key from the 2-tuple
    /// of prompt and llm_string (e.g., by concatenating them with a delimiter).
    ///
    /// # Arguments
    ///
    /// * `prompt` - A string representation of the prompt.
    ///   In the case of a chat model, the prompt is a non-trivial
    ///   serialization of the messages sent to the model.
    /// * `llm_string` - A string representation of the LLM configuration.
    ///   This is used to capture the invocation parameters of the LLM
    ///   (e.g., model name, temperature, stop tokens, max tokens, etc.).
    ///   These invocation parameters are serialized into a string
    ///   representation.
    ///
    /// # Returns
    ///
    /// On a cache miss, return `None`. On a cache hit, return the cached
    /// value: a vector of `Generation`s.
    fn lookup(&self, prompt: &str, llm_string: &str) -> Result<Option<CachedGenerations>>;

    /// Update the cache based on prompt and llm_string.
    ///
    /// The prompt and llm_string are used to generate a key for the cache.
    /// The key should match that of the lookup method.
    ///
    /// # Arguments
    ///
    /// * `prompt` - A string representation of the prompt.
    ///   In the case of a chat model, the prompt is a non-trivial
    ///   serialization of the messages sent to the model.
    /// * `llm_string` - A string representation of the LLM configuration.
    ///   This is used to capture the invocation parameters of the LLM
    ///   (e.g., model name, temperature, stop tokens, max tokens, etc.).
    ///   These invocation parameters are serialized into a string
    ///   representation.
    /// * `return_val` - The value to be cached: a vector of `Generation`s.
    fn update(&self, prompt: &str, llm_string: &str, return_val: CachedGenerations) -> Result<()>;

    /// Clear the cache.
    fn clear(&self) -> Result<()>;

    /// Async look up based on prompt and llm_string.
    ///
    /// A cache implementation is expected to generate a key from the 2-tuple
    /// of prompt and llm_string (e.g., by concatenating them with a delimiter).
    ///
    /// # Arguments
    ///
    /// * `prompt` - A string representation of the prompt.
    ///   In the case of a chat model, the prompt is a non-trivial
    ///   serialization of the messages sent to the model.
    /// * `llm_string` - A string representation of the LLM configuration.
    ///   This is used to capture the invocation parameters of the LLM
    ///   (e.g., model name, temperature, stop tokens, max tokens, etc.).
    ///   These invocation parameters are serialized into a string
    ///   representation.
    ///
    /// # Returns
    ///
    /// On a cache miss, return `None`. On a cache hit, return the cached
    /// value: a vector of `Generation`s.
    async fn alookup(&self, prompt: &str, llm_string: &str) -> Result<Option<CachedGenerations>>;

    /// Async update the cache based on prompt and llm_string.
    ///
    /// The prompt and llm_string are used to generate a key for the cache.
    /// The key should match that of the lookup method.
    ///
    /// # Arguments
    ///
    /// * `prompt` - A string representation of the prompt.
    ///   In the case of a chat model, the prompt is a non-trivial
    ///   serialization of the messages sent to the model.
    /// * `llm_string` - A string representation of the LLM configuration.
    ///   This is used to capture the invocation parameters of the LLM
    ///   (e.g., model name, temperature, stop tokens, max tokens, etc.).
    ///   These invocation parameters are serialized into a string
    ///   representation.
    /// * `return_val` - The value to be cached: a vector of `Generation`s.
    async fn aupdate(
        &self,
        prompt: &str,
        llm_string: &str,
        return_val: CachedGenerations,
    ) -> Result<()>;

    /// Async clear the cache.
    async fn aclear(&self) -> Result<()>;
}

/// Cache that stores things in memory.
///
/// This is a simple in-memory cache implementation that stores cached values
/// in a `HashMap`. It supports optional size limits and LRU eviction.
#[derive(Debug)]
pub struct InMemoryCache {
    /// The actual cache storage
    cache: Arc<RwLock<HashMap<String, CacheEntry>>>,
    /// Maximum number of items to store in the cache
    max_size: Option<usize>,
    /// Statistics for monitoring
    stats: Arc<RwLock<CacheStats>>,
}

/// A cache entry that includes the data and metadata
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached generations
    data: CachedGenerations,
    /// When this entry was created
    created_at: Instant,
    /// When this entry was last accessed
    last_accessed: Instant,
    /// Number of times this entry has been accessed
    access_count: u64,
}

/// Cache statistics for monitoring
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
    /// Number of cache hits
    pub hits: u64,
    /// Number of cache misses
    pub misses: u64,
    /// Number of cache updates
    pub updates: u64,
    /// Number of cache clears
    pub clears: u64,
    /// Total number of entries currently in cache
    pub current_size: usize,
    /// Maximum size the cache has reached
    pub max_size_reached: usize,
}

impl CacheStats {
    /// Get the hit rate as a percentage.
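    ///
    /// A worked example: 3 hits and 1 miss give a hit rate of
    /// 3 / (3 + 1) * 100 = 75%.
    ///
    /// ```rust,ignore
    /// let stats = CacheStats { hits: 3, misses: 1, ..Default::default() };
    /// assert_eq!(stats.hit_rate(), 75.0);
    /// ```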
    pub fn hit_rate(&self) -> f64 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            (self.hits as f64 / total as f64) * 100.0
        }
    }

    /// Get the total number of requests
    pub fn total_requests(&self) -> u64 {
        self.hits + self.misses
    }
}

impl InMemoryCache {
    /// Create a new in-memory cache with no size limit.
    pub fn new() -> Self {
        Self::with_max_size(None)
    }

    /// Create a new in-memory cache with a maximum size.
    ///
    /// # Arguments
    ///
    /// * `max_size` - The maximum number of items to store in the cache.
    ///   If `None`, the cache has no maximum size.
    ///   If the cache would exceed the maximum size, the least recently
    ///   used entry is evicted.
    ///
    /// # Panics
    ///
    /// Panics if `max_size` is `Some(0)`.
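    ///
    /// # Example
    ///
    /// A sketch of a bounded cache: with room for two entries, inserting a
    /// third evicts the least recently used one.
    ///
    /// ```rust,ignore
    /// let cache = InMemoryCache::with_max_size(Some(2));
    /// cache.update("a", "llm", vec![]).unwrap();
    /// cache.update("b", "llm", vec![]).unwrap();
    /// cache.update("c", "llm", vec![]).unwrap(); // evicts "a"
    /// assert!(cache.lookup("a", "llm").unwrap().is_none());
    /// ```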
    pub fn with_max_size(max_size: Option<usize>) -> Self {
        if let Some(size) = max_size {
            assert!(size > 0, "max_size must be greater than 0");
        }

        Self {
            cache: Arc::new(RwLock::new(HashMap::new())),
            max_size,
            stats: Arc::new(RwLock::new(CacheStats::default())),
        }
    }

    /// Generate a cache key from prompt and llm_string.
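    ///
    /// For example, `generate_key("hi", "gpt-4|temp=0.7")` yields
    /// `"hi|||gpt-4|temp=0.7"`.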
    fn generate_key(prompt: &str, llm_string: &str) -> String {
        // Use a simple concatenation with a delimiter.
        // In production, you might want to use a hash function.
        format!("{prompt}|||{llm_string}")
    }

    /// Get cache statistics.
    pub async fn stats(&self) -> CacheStats {
        let stats = self.stats.read().await;
        let cache = self.cache.read().await;
        CacheStats {
            current_size: cache.len(),
            ..stats.clone()
        }
    }

    /// Get the current cache size.
    pub async fn size(&self) -> usize {
        let cache = self.cache.read().await;
        cache.len()
    }

    /// Check if the cache is empty.
    pub async fn is_empty(&self) -> bool {
        let cache = self.cache.read().await;
        cache.is_empty()
    }

    /// Get the maximum cache size.
    pub fn max_size(&self) -> Option<usize> {
        self.max_size
    }
}

impl Default for InMemoryCache {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl BaseCache for InMemoryCache {
    fn lookup(&self, prompt: &str, llm_string: &str) -> Result<Option<CachedGenerations>> {
        let key = Self::generate_key(prompt, llm_string);
        // `try_write` never blocks: in a synchronous context we cannot await
        // the lock, so contention surfaces as a runtime error instead.
        let mut cache = self.cache.try_write().map_err(|e| {
            crate::errors::FerricLinkError::runtime(format!("Cache lock error: {e}"))
        })?;
        let mut stats = self.stats.try_write().map_err(|e| {
            crate::errors::FerricLinkError::runtime(format!("Stats lock error: {e}"))
        })?;

        if let Some(entry) = cache.get_mut(&key) {
            // Update access information
            entry.last_accessed = Instant::now();
            entry.access_count += 1;
            stats.hits += 1;
            Ok(Some(entry.data.clone()))
        } else {
            stats.misses += 1;
            Ok(None)
        }
    }

    fn update(&self, prompt: &str, llm_string: &str, return_val: CachedGenerations) -> Result<()> {
        let key = Self::generate_key(prompt, llm_string);
        let mut cache = self.cache.try_write().map_err(|e| {
            crate::errors::FerricLinkError::runtime(format!("Cache lock error: {e}"))
        })?;
        let mut stats = self.stats.try_write().map_err(|e| {
            crate::errors::FerricLinkError::runtime(format!("Stats lock error: {e}"))
        })?;

        // Evict if needed before adding a new entry. Overwriting an existing
        // key does not grow the cache, so it must not trigger eviction.
        if let Some(max_size) = self.max_size {
            if cache.len() >= max_size && !cache.contains_key(&key) {
                // Find and remove the least recently used entry
                let mut oldest_key = None;
                let mut oldest_time = Instant::now();

                for (key, entry) in cache.iter() {
                    if entry.last_accessed < oldest_time {
                        oldest_time = entry.last_accessed;
                        oldest_key = Some(key.clone());
                    }
                }

                if let Some(key) = oldest_key {
                    cache.remove(&key);
                }
            }
        }

        let entry = CacheEntry {
            data: return_val,
            created_at: Instant::now(),
            last_accessed: Instant::now(),
            access_count: 0,
        };

        cache.insert(key, entry);
        stats.updates += 1;
        stats.max_size_reached = stats.max_size_reached.max(cache.len());

        Ok(())
    }

    fn clear(&self) -> Result<()> {
        let mut cache = self.cache.try_write().map_err(|e| {
            crate::errors::FerricLinkError::runtime(format!("Cache lock error: {e}"))
        })?;
        let mut stats = self.stats.try_write().map_err(|e| {
            crate::errors::FerricLinkError::runtime(format!("Stats lock error: {e}"))
        })?;

        cache.clear();
        stats.clears += 1;
        Ok(())
    }

    async fn alookup(&self, prompt: &str, llm_string: &str) -> Result<Option<CachedGenerations>> {
        let key = Self::generate_key(prompt, llm_string);
        let mut cache = self.cache.write().await;
        let mut stats = self.stats.write().await;

        if let Some(entry) = cache.get_mut(&key) {
            // Update access information
            entry.last_accessed = Instant::now();
            entry.access_count += 1;
            stats.hits += 1;
            Ok(Some(entry.data.clone()))
        } else {
            stats.misses += 1;
            Ok(None)
        }
    }

    async fn aupdate(
        &self,
        prompt: &str,
        llm_string: &str,
        return_val: CachedGenerations,
    ) -> Result<()> {
        let key = Self::generate_key(prompt, llm_string);
        let mut cache = self.cache.write().await;
        let mut stats = self.stats.write().await;

        // Evict if needed before adding a new entry. Overwriting an existing
        // key does not grow the cache, so it must not trigger eviction.
        if let Some(max_size) = self.max_size {
            if cache.len() >= max_size && !cache.contains_key(&key) {
                // Find and remove the least recently used entry
                let mut oldest_key = None;
                let mut oldest_time = Instant::now();

                for (key, entry) in cache.iter() {
                    if entry.last_accessed < oldest_time {
                        oldest_time = entry.last_accessed;
                        oldest_key = Some(key.clone());
                    }
                }

                if let Some(key) = oldest_key {
                    cache.remove(&key);
                }
            }
        }

        let entry = CacheEntry {
            data: return_val,
            created_at: Instant::now(),
            last_accessed: Instant::now(),
            access_count: 0,
        };

        cache.insert(key, entry);
        stats.updates += 1;
        stats.max_size_reached = stats.max_size_reached.max(cache.len());

        Ok(())
    }

    async fn aclear(&self) -> Result<()> {
        let mut cache = self.cache.write().await;
        let mut stats = self.stats.write().await;

        cache.clear();
        stats.clears += 1;
        Ok(())
    }
}

impl_serializable!(CacheStats, ["ferriclink", "caches", "cache_stats"]);

/// A more advanced cache with TTL (time-to-live) support.
///
/// Entries older than the configured TTL are treated as misses and removed
/// lazily on lookup.
#[derive(Debug)]
pub struct TtlCache {
    /// The underlying cache
    inner: InMemoryCache,
    /// Default TTL for cache entries
    default_ttl: Duration,
}

impl TtlCache {
    /// Create a new TTL cache with the specified TTL.
    ///
    /// # Arguments
    ///
    /// * `default_ttl` - The default time-to-live for cache entries.
    /// * `max_size` - Optional maximum size for the cache.
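    ///
    /// # Example
    ///
    /// A sketch: entries expire 60 seconds after they are written, and the
    /// cache holds at most 1,000 of them.
    ///
    /// ```rust,ignore
    /// use std::time::Duration;
    ///
    /// let cache = TtlCache::new(Duration::from_secs(60), Some(1_000));
    /// assert_eq!(cache.default_ttl(), Duration::from_secs(60));
    /// ```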
    pub fn new(default_ttl: Duration, max_size: Option<usize>) -> Self {
        Self {
            inner: InMemoryCache::with_max_size(max_size),
            default_ttl,
        }
    }

    /// Get the default TTL.
    pub fn default_ttl(&self) -> Duration {
        self.default_ttl
    }

    /// Get cache statistics.
    pub async fn stats(&self) -> CacheStats {
        self.inner.stats().await
    }

    /// Check if an entry has expired.
    fn is_expired(entry: &CacheEntry, ttl: Duration) -> bool {
        entry.created_at.elapsed() > ttl
    }
}

#[async_trait]
impl BaseCache for TtlCache {
    fn lookup(&self, prompt: &str, llm_string: &str) -> Result<Option<CachedGenerations>> {
        // For the TTL cache, we also need to check expiration
        let key = InMemoryCache::generate_key(prompt, llm_string);
        let mut cache = self.inner.cache.try_write().map_err(|e| {
            crate::errors::FerricLinkError::runtime(format!("Cache lock error: {e}"))
        })?;
        let mut stats = self.inner.stats.try_write().map_err(|e| {
            crate::errors::FerricLinkError::runtime(format!("Stats lock error: {e}"))
        })?;

        if let Some(entry) = cache.get(&key) {
            if Self::is_expired(entry, self.default_ttl) {
                // Entry has expired; remove it and count a miss
                cache.remove(&key);
                stats.misses += 1;
                Ok(None)
            } else {
                // Entry is still valid; update access info
                let mut entry = entry.clone();
                entry.last_accessed = Instant::now();
                entry.access_count += 1;
                let data = entry.data.clone();
                cache.insert(key, entry);
                stats.hits += 1;
                Ok(Some(data))
            }
        } else {
            stats.misses += 1;
            Ok(None)
        }
    }

    fn update(&self, prompt: &str, llm_string: &str, return_val: CachedGenerations) -> Result<()> {
        self.inner.update(prompt, llm_string, return_val)
    }

    fn clear(&self) -> Result<()> {
        self.inner.clear()
    }

    async fn alookup(&self, prompt: &str, llm_string: &str) -> Result<Option<CachedGenerations>> {
        // For the TTL cache, we also need to check expiration
        let key = InMemoryCache::generate_key(prompt, llm_string);
        let mut cache = self.inner.cache.write().await;
        let mut stats = self.inner.stats.write().await;

        if let Some(entry) = cache.get(&key) {
            if Self::is_expired(entry, self.default_ttl) {
                // Entry has expired; remove it and count a miss
                cache.remove(&key);
                stats.misses += 1;
                Ok(None)
            } else {
                // Entry is still valid; update access info
                let mut entry = entry.clone();
                entry.last_accessed = Instant::now();
                entry.access_count += 1;
                let data = entry.data.clone();
                cache.insert(key, entry);
                stats.hits += 1;
                Ok(Some(data))
            }
        } else {
            stats.misses += 1;
            Ok(None)
        }
    }

    async fn aupdate(
        &self,
        prompt: &str,
        llm_string: &str,
        return_val: CachedGenerations,
    ) -> Result<()> {
        self.inner.aupdate(prompt, llm_string, return_val).await
    }

    async fn aclear(&self) -> Result<()> {
        self.inner.aclear().await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::language_models::Generation;

    fn create_test_generation(text: &str) -> Generation {
        Generation {
            text: text.to_string(),
            generation_info: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_in_memory_cache_basic() {
        let cache = InMemoryCache::new();

        // Test empty cache
        assert!(cache.lookup("test", "llm").unwrap().is_none());

        // Test update and lookup
        let generations = vec![create_test_generation("Hello, world!")];
        cache.update("test", "llm", generations.clone()).unwrap();

        let result = cache.lookup("test", "llm").unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap(), generations);
    }

    #[test]
    fn test_in_memory_cache_with_max_size() {
        let cache = InMemoryCache::with_max_size(Some(2));

        // Add entries up to max size
        cache
            .update("test1", "llm", vec![create_test_generation("1")])
            .unwrap();
        cache
            .update("test2", "llm", vec![create_test_generation("2")])
            .unwrap();

        // This should evict the first (least recently used) entry
        cache
            .update("test3", "llm", vec![create_test_generation("3")])
            .unwrap();

        assert!(cache.lookup("test1", "llm").unwrap().is_none());
        assert!(cache.lookup("test2", "llm").unwrap().is_some());
        assert!(cache.lookup("test3", "llm").unwrap().is_some());
    }

    #[test]
    fn test_in_memory_cache_clear() {
        let cache = InMemoryCache::new();

        cache
            .update("test", "llm", vec![create_test_generation("Hello")])
            .unwrap();
        assert!(cache.lookup("test", "llm").unwrap().is_some());

        cache.clear().unwrap();
        assert!(cache.lookup("test", "llm").unwrap().is_none());
    }

    #[tokio::test]
    async fn test_in_memory_cache_async() {
        let cache = InMemoryCache::new();

        // Test async operations
        let generations = vec![create_test_generation("Async test")];
        cache
            .aupdate("test", "llm", generations.clone())
            .await
            .unwrap();

        let result = cache.alookup("test", "llm").await.unwrap();
        assert!(result.is_some());
        assert_eq!(result.unwrap(), generations);
    }

    #[tokio::test]
    async fn test_cache_stats() {
        let cache = InMemoryCache::new();

        // Initial stats
        let stats = cache.stats().await;
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.updates, 0);

        // Test miss
        cache.alookup("test", "llm").await.unwrap();
        let stats = cache.stats().await;
        assert_eq!(stats.misses, 1);

        // Test update
        cache
            .aupdate("test", "llm", vec![create_test_generation("Hello")])
            .await
            .unwrap();
        let stats = cache.stats().await;
        assert_eq!(stats.updates, 1);

        // Test hit
        cache.alookup("test", "llm").await.unwrap();
        let stats = cache.stats().await;
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    #[test]
    fn test_ttl_cache() {
        let cache = TtlCache::new(Duration::from_millis(100), None);

        // Add entry
        cache
            .update("test", "llm", vec![create_test_generation("TTL test")])
            .unwrap();
        assert!(cache.lookup("test", "llm").unwrap().is_some());

        // Wait for expiration
        std::thread::sleep(Duration::from_millis(150));
        assert!(cache.lookup("test", "llm").unwrap().is_none());
    }

    #[test]
    fn test_cache_key_generation() {
        let key1 = InMemoryCache::generate_key("prompt1", "llm1");
        let key2 = InMemoryCache::generate_key("prompt2", "llm1");
        let key3 = InMemoryCache::generate_key("prompt1", "llm2");
        let key4 = InMemoryCache::generate_key("prompt1", "llm1");

        assert_ne!(key1, key2);
        assert_ne!(key1, key3);
        assert_eq!(key1, key4);
    }
}