1use async_trait::async_trait;
21use serde::{Deserialize, Serialize};
22use std::collections::HashMap;
23use std::sync::Arc;
24use std::time::{Duration, Instant};
25use tokio::sync::RwLock;
26
27use crate::errors::Result;
28use crate::impl_serializable;
29use crate::language_models::Generation;
30
/// The value stored under one cache key: the list of generations recorded for
/// a single `(prompt, llm_string)` pair.
pub type CachedGenerations = Vec<Generation>;
33
/// Interface for a cache of LLM generations keyed by a prompt and an
/// `llm_string`.
///
/// Every operation exists in a synchronous form and an `a`-prefixed
/// asynchronous form. Implementors must be `Send + Sync`.
#[async_trait]
pub trait BaseCache: Send + Sync {
    /// Returns the generations cached for `(prompt, llm_string)`, or
    /// `Ok(None)` when nothing is cached for that pair.
    fn lookup(&self, prompt: &str, llm_string: &str) -> Result<Option<CachedGenerations>>;

    /// Stores `return_val` as the cached generations for
    /// `(prompt, llm_string)`.
    fn update(&self, prompt: &str, llm_string: &str, return_val: CachedGenerations) -> Result<()>;

    /// Removes every entry from the cache.
    fn clear(&self) -> Result<()>;

    /// Asynchronous counterpart of [`BaseCache::lookup`].
    async fn alookup(&self, prompt: &str, llm_string: &str) -> Result<Option<CachedGenerations>>;

    /// Asynchronous counterpart of [`BaseCache::update`].
    async fn aupdate(
        &self,
        prompt: &str,
        llm_string: &str,
        return_val: CachedGenerations,
    ) -> Result<()>;

    /// Asynchronous counterpart of [`BaseCache::clear`].
    async fn aclear(&self) -> Result<()>;
}
142
/// A cache of generations held in a process-local `HashMap`, with an optional
/// cap on the number of entries and running usage statistics.
#[derive(Debug)]
pub struct InMemoryCache {
    /// Entries keyed by the string built from the prompt and `llm_string`.
    cache: Arc<RwLock<HashMap<String, CacheEntry>>>,
    /// Maximum number of entries; `None` means unbounded. When the limit is
    /// reached, an insert first removes the least recently accessed entry.
    max_size: Option<usize>,
    /// Hit/miss/update/clear counters, guarded by their own lock.
    stats: Arc<RwLock<CacheStats>>,
}
156
/// One cached value together with the bookkeeping used for eviction and
/// expiry.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached generations.
    data: CachedGenerations,
    /// When the entry was inserted; `TtlCache` measures expiry from this.
    created_at: Instant,
    /// When the entry was last inserted or returned by a lookup; the entry
    /// with the oldest value is the one evicted when the cache is full.
    last_accessed: Instant,
    /// Number of lookups that returned this entry.
    access_count: u64,
}
169
/// Usage counters for a cache.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
    /// Lookups that found an entry.
    pub hits: u64,
    /// Lookups that found nothing (for `TtlCache`, including expired entries).
    pub misses: u64,
    /// Number of inserts performed.
    pub updates: u64,
    /// Number of times the cache was cleared.
    pub clears: u64,
    /// Number of entries held; filled in from the live map by
    /// `InMemoryCache::stats` rather than maintained on every insert.
    pub current_size: usize,
    /// Largest entry count observed immediately after an insert.
    pub max_size_reached: usize,
}
186
187impl CacheStats {
188 pub fn hit_rate(&self) -> f64 {
190 let total = self.hits + self.misses;
191 if total == 0 {
192 0.0
193 } else {
194 (self.hits as f64 / total as f64) * 100.0
195 }
196 }
197
198 pub fn total_requests(&self) -> u64 {
200 self.hits + self.misses
201 }
202}
203
204impl InMemoryCache {
205 pub fn new() -> Self {
207 Self::with_max_size(None)
208 }
209
210 pub fn with_max_size(max_size: Option<usize>) -> Self {
222 if let Some(size) = max_size {
223 assert!(size > 0, "max_size must be greater than 0");
224 }
225
226 Self {
227 cache: Arc::new(RwLock::new(HashMap::new())),
228 max_size,
229 stats: Arc::new(RwLock::new(CacheStats::default())),
230 }
231 }
232
233 fn generate_key(prompt: &str, llm_string: &str) -> String {
235 format!("{prompt}|||{llm_string}")
238 }
239
240 pub async fn stats(&self) -> CacheStats {
242 let stats = self.stats.read().await;
243 let cache = self.cache.read().await;
244 CacheStats {
245 current_size: cache.len(),
246 ..stats.clone()
247 }
248 }
249
250 pub async fn size(&self) -> usize {
252 let cache = self.cache.read().await;
253 cache.len()
254 }
255
256 pub async fn is_empty(&self) -> bool {
258 let cache = self.cache.read().await;
259 cache.is_empty()
260 }
261
262 pub fn max_size(&self) -> Option<usize> {
264 self.max_size
265 }
266}
267
268impl Default for InMemoryCache {
269 fn default() -> Self {
270 Self::new()
271 }
272}
273
274#[async_trait]
275impl BaseCache for InMemoryCache {
276 fn lookup(&self, prompt: &str, llm_string: &str) -> Result<Option<CachedGenerations>> {
277 let key = Self::generate_key(prompt, llm_string);
278 let mut cache = self.cache.try_write().map_err(|e| {
279 crate::errors::FerricLinkError::runtime(format!("Cache lock error: {e}"))
280 })?;
281 let mut stats = self.stats.try_write().map_err(|e| {
282 crate::errors::FerricLinkError::runtime(format!("Stats lock error: {e}"))
283 })?;
284
285 if let Some(entry) = cache.get_mut(&key) {
286 entry.last_accessed = Instant::now();
288 entry.access_count += 1;
289 stats.hits += 1;
290 Ok(Some(entry.data.clone()))
291 } else {
292 stats.misses += 1;
293 Ok(None)
294 }
295 }
296
297 fn update(&self, prompt: &str, llm_string: &str, return_val: CachedGenerations) -> Result<()> {
298 let key = Self::generate_key(prompt, llm_string);
299 let mut cache = self.cache.try_write().map_err(|e| {
300 crate::errors::FerricLinkError::runtime(format!("Cache lock error: {e}"))
301 })?;
302 let mut stats = self.stats.try_write().map_err(|e| {
303 crate::errors::FerricLinkError::runtime(format!("Stats lock error: {e}"))
304 })?;
305
306 if let Some(max_size) = self.max_size {
308 if cache.len() >= max_size {
309 let mut oldest_key = None;
311 let mut oldest_time = Instant::now();
312
313 for (key, entry) in cache.iter() {
314 if entry.last_accessed < oldest_time {
315 oldest_time = entry.last_accessed;
316 oldest_key = Some(key.clone());
317 }
318 }
319
320 if let Some(key) = oldest_key {
321 cache.remove(&key);
322 }
323 }
324 }
325
326 let entry = CacheEntry {
327 data: return_val,
328 created_at: Instant::now(),
329 last_accessed: Instant::now(),
330 access_count: 0,
331 };
332
333 cache.insert(key, entry);
334 stats.updates += 1;
335 stats.max_size_reached = stats.max_size_reached.max(cache.len());
336
337 Ok(())
338 }
339
340 fn clear(&self) -> Result<()> {
341 let mut cache = self.cache.try_write().map_err(|e| {
342 crate::errors::FerricLinkError::runtime(format!("Cache lock error: {e}"))
343 })?;
344 let mut stats = self.stats.try_write().map_err(|e| {
345 crate::errors::FerricLinkError::runtime(format!("Stats lock error: {e}"))
346 })?;
347
348 cache.clear();
349 stats.clears += 1;
350 Ok(())
351 }
352
353 async fn alookup(&self, prompt: &str, llm_string: &str) -> Result<Option<CachedGenerations>> {
354 let key = Self::generate_key(prompt, llm_string);
355 let mut cache = self.cache.write().await;
356 let mut stats = self.stats.write().await;
357
358 if let Some(entry) = cache.get_mut(&key) {
359 entry.last_accessed = Instant::now();
361 entry.access_count += 1;
362 stats.hits += 1;
363 Ok(Some(entry.data.clone()))
364 } else {
365 stats.misses += 1;
366 Ok(None)
367 }
368 }
369
370 async fn aupdate(
371 &self,
372 prompt: &str,
373 llm_string: &str,
374 return_val: CachedGenerations,
375 ) -> Result<()> {
376 let key = Self::generate_key(prompt, llm_string);
377 let mut cache = self.cache.write().await;
378 let mut stats = self.stats.write().await;
379
380 if let Some(max_size) = self.max_size {
382 if cache.len() >= max_size {
383 let mut oldest_key = None;
385 let mut oldest_time = Instant::now();
386
387 for (key, entry) in cache.iter() {
388 if entry.last_accessed < oldest_time {
389 oldest_time = entry.last_accessed;
390 oldest_key = Some(key.clone());
391 }
392 }
393
394 if let Some(key) = oldest_key {
395 cache.remove(&key);
396 }
397 }
398 }
399
400 let entry = CacheEntry {
401 data: return_val,
402 created_at: Instant::now(),
403 last_accessed: Instant::now(),
404 access_count: 0,
405 };
406
407 cache.insert(key, entry);
408 stats.updates += 1;
409 stats.max_size_reached = stats.max_size_reached.max(cache.len());
410
411 Ok(())
412 }
413
414 async fn aclear(&self) -> Result<()> {
415 let mut cache = self.cache.write().await;
416 let mut stats = self.stats.write().await;
417
418 cache.clear();
419 stats.clears += 1;
420 Ok(())
421 }
422}
423
// Invokes the crate's `impl_serializable!` macro for `CacheStats`, passing the
// path segments "ferriclink" / "caches" / "cache_stats".
// NOTE(review): presumably these segments form the type's serialization
// namespace; confirm against the macro definition.
impl_serializable!(CacheStats, ["ferriclink", "caches", "cache_stats"]);
425
/// A cache whose entries stop being returned once they are older than a fixed
/// time-to-live, layered on top of [`InMemoryCache`].
#[derive(Debug)]
pub struct TtlCache {
    /// Underlying storage; also provides size-based eviction and statistics.
    inner: InMemoryCache,
    /// Maximum age of an entry, measured from its insertion time.
    default_ttl: Duration,
}
434
435impl TtlCache {
436 pub fn new(default_ttl: Duration, max_size: Option<usize>) -> Self {
443 Self {
444 inner: InMemoryCache::with_max_size(max_size),
445 default_ttl,
446 }
447 }
448
449 pub fn default_ttl(&self) -> Duration {
451 self.default_ttl
452 }
453
454 pub async fn stats(&self) -> CacheStats {
456 self.inner.stats().await
457 }
458
459 fn is_expired(entry: &CacheEntry, ttl: Duration) -> bool {
461 entry.created_at.elapsed() > ttl
462 }
463}
464
465#[async_trait]
466impl BaseCache for TtlCache {
467 fn lookup(&self, prompt: &str, llm_string: &str) -> Result<Option<CachedGenerations>> {
468 let key = InMemoryCache::generate_key(prompt, llm_string);
470 let mut cache = self.inner.cache.try_write().map_err(|e| {
471 crate::errors::FerricLinkError::runtime(format!("Cache lock error: {e}"))
472 })?;
473 let mut stats = self.inner.stats.try_write().map_err(|e| {
474 crate::errors::FerricLinkError::runtime(format!("Stats lock error: {e}"))
475 })?;
476
477 if let Some(entry) = cache.get(&key) {
478 if Self::is_expired(entry, self.default_ttl) {
479 cache.remove(&key);
481 stats.misses += 1;
482 Ok(None)
483 } else {
484 let mut entry = entry.clone();
486 entry.last_accessed = Instant::now();
487 entry.access_count += 1;
488 let data = entry.data.clone();
489 cache.insert(key, entry);
490 stats.hits += 1;
491 Ok(Some(data))
492 }
493 } else {
494 stats.misses += 1;
495 Ok(None)
496 }
497 }
498
499 fn update(&self, prompt: &str, llm_string: &str, return_val: CachedGenerations) -> Result<()> {
500 self.inner.update(prompt, llm_string, return_val)
501 }
502
503 fn clear(&self) -> Result<()> {
504 self.inner.clear()
505 }
506
507 async fn alookup(&self, prompt: &str, llm_string: &str) -> Result<Option<CachedGenerations>> {
508 let key = InMemoryCache::generate_key(prompt, llm_string);
510 let mut cache = self.inner.cache.write().await;
511 let mut stats = self.inner.stats.write().await;
512
513 if let Some(entry) = cache.get(&key) {
514 if Self::is_expired(entry, self.default_ttl) {
515 cache.remove(&key);
517 stats.misses += 1;
518 Ok(None)
519 } else {
520 let mut entry = entry.clone();
522 entry.last_accessed = Instant::now();
523 entry.access_count += 1;
524 let data = entry.data.clone();
525 cache.insert(key, entry);
526 stats.hits += 1;
527 Ok(Some(data))
528 }
529 } else {
530 stats.misses += 1;
531 Ok(None)
532 }
533 }
534
535 async fn aupdate(
536 &self,
537 prompt: &str,
538 llm_string: &str,
539 return_val: CachedGenerations,
540 ) -> Result<()> {
541 self.inner.aupdate(prompt, llm_string, return_val).await
542 }
543
544 async fn aclear(&self) -> Result<()> {
545 self.inner.aclear().await
546 }
547}
548
#[cfg(test)]
mod tests {
    use super::*;
    use crate::language_models::Generation;

    /// Builds a `Generation` carrying `text` and no extra generation info.
    fn create_test_generation(text: &str) -> Generation {
        Generation {
            text: String::from(text),
            generation_info: std::collections::HashMap::new(),
        }
    }

    /// A value is absent before `update` and returned unchanged afterwards.
    #[test]
    fn test_in_memory_cache_basic() {
        let cache = InMemoryCache::new();

        let before = cache.lookup("test", "llm").unwrap();
        assert!(before.is_none());

        let expected = vec![create_test_generation("Hello, world!")];
        cache.update("test", "llm", expected.clone()).unwrap();

        let after = cache.lookup("test", "llm").unwrap();
        assert!(after.is_some());
        assert_eq!(after.unwrap(), expected);
    }

    /// Inserting a third key into a two-entry cache evicts the oldest one.
    #[test]
    fn test_in_memory_cache_with_max_size() {
        let cache = InMemoryCache::with_max_size(Some(2));

        for (prompt, text) in [("test1", "1"), ("test2", "2"), ("test3", "3")] {
            cache
                .update(prompt, "llm", vec![create_test_generation(text)])
                .unwrap();
        }

        let is_cached = |prompt: &str| cache.lookup(prompt, "llm").unwrap().is_some();
        assert!(!is_cached("test1"));
        assert!(is_cached("test2"));
        assert!(is_cached("test3"));
    }

    /// `clear` removes previously stored entries.
    #[test]
    fn test_in_memory_cache_clear() {
        let cache = InMemoryCache::new();
        let stored = vec![create_test_generation("Hello")];

        cache.update("test", "llm", stored).unwrap();
        assert!(cache.lookup("test", "llm").unwrap().is_some());

        cache.clear().unwrap();
        assert!(cache.lookup("test", "llm").unwrap().is_none());
    }

    /// The async update/lookup pair round-trips a value.
    #[tokio::test]
    async fn test_in_memory_cache_async() {
        let cache = InMemoryCache::new();
        let expected = vec![create_test_generation("Async test")];

        cache
            .aupdate("test", "llm", expected.clone())
            .await
            .unwrap();

        let found = cache.alookup("test", "llm").await.unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap(), expected);
    }

    /// Counters track misses, updates and hits as operations happen.
    #[tokio::test]
    async fn test_cache_stats() {
        let cache = InMemoryCache::new();

        let initial = cache.stats().await;
        assert_eq!(initial.hits, 0);
        assert_eq!(initial.misses, 0);
        assert_eq!(initial.updates, 0);

        // A lookup on an empty cache is a miss.
        cache.alookup("test", "llm").await.unwrap();
        let after_miss = cache.stats().await;
        assert_eq!(after_miss.misses, 1);

        // Storing a value counts as one update.
        cache
            .aupdate("test", "llm", vec![create_test_generation("Hello")])
            .await
            .unwrap();
        let after_update = cache.stats().await;
        assert_eq!(after_update.updates, 1);

        // Looking the value up again is a hit; the earlier miss remains.
        cache.alookup("test", "llm").await.unwrap();
        let after_hit = cache.stats().await;
        assert_eq!(after_hit.hits, 1);
        assert_eq!(after_hit.misses, 1);
    }

    /// An entry is returned before its TTL elapses and dropped afterwards.
    #[test]
    fn test_ttl_cache() {
        let ttl = Duration::from_millis(100);
        let cache = TtlCache::new(ttl, None);

        cache
            .update("test", "llm", vec![create_test_generation("TTL test")])
            .unwrap();
        assert!(cache.lookup("test", "llm").unwrap().is_some());

        // Wait past the TTL.
        std::thread::sleep(Duration::from_millis(150));
        assert!(cache.lookup("test", "llm").unwrap().is_none());
    }

    /// Keys differ when either component differs and match when both match.
    #[test]
    fn test_cache_key_generation() {
        let base = InMemoryCache::generate_key("prompt1", "llm1");
        let other_prompt = InMemoryCache::generate_key("prompt2", "llm1");
        let other_llm = InMemoryCache::generate_key("prompt1", "llm2");
        let same_again = InMemoryCache::generate_key("prompt1", "llm1");

        assert_ne!(base, other_prompt);
        assert_ne!(base, other_llm);
        assert_eq!(base, same_again);
    }
}