ferriclink_core/
retrievers.rs

1//! Retriever abstractions for FerricLink Core
2//!
3//! This module provides the core abstractions for retrievers that can
4//! fetch relevant documents based on queries.
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10use crate::documents::Document;
11use crate::errors::Result;
12use crate::impl_serializable;
13use crate::runnables::{Runnable, RunnableConfig};
14use crate::vectorstores::VectorStore;
15
16/// A retriever result containing documents and metadata
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18pub struct RetrieverResult {
19    /// The retrieved documents
20    pub documents: Vec<Document>,
21    /// Additional metadata about the retrieval
22    #[serde(default)]
23    pub metadata: HashMap<String, serde_json::Value>,
24}
25
26impl RetrieverResult {
27    /// Create a new retriever result
28    pub fn new(documents: Vec<Document>) -> Self {
29        Self {
30            documents,
31            metadata: HashMap::new(),
32        }
33    }
34
35    /// Create a new retriever result with metadata
36    pub fn new_with_metadata(
37        documents: Vec<Document>,
38        metadata: HashMap<String, serde_json::Value>,
39    ) -> Self {
40        Self {
41            documents,
42            metadata,
43        }
44    }
45
46    /// Add metadata to the result
47    pub fn add_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
48        self.metadata.insert(key.into(), value);
49    }
50
51    /// Get the number of documents
52    pub fn len(&self) -> usize {
53        self.documents.len()
54    }
55
56    /// Check if the result is empty
57    pub fn is_empty(&self) -> bool {
58        self.documents.is_empty()
59    }
60}
61
62impl_serializable!(
63    RetrieverResult,
64    ["ferriclink", "retrievers", "retriever_result"]
65);
66
67/// Base trait for all retrievers
68#[async_trait]
69pub trait BaseRetriever: Send + Sync + 'static {
70    /// Retrieve documents based on a query
71    async fn get_relevant_documents(
72        &self,
73        query: &str,
74        config: Option<RunnableConfig>,
75    ) -> Result<RetrieverResult>;
76
77    /// Retrieve documents for multiple queries
78    async fn get_relevant_documents_batch(
79        &self,
80        queries: Vec<String>,
81        config: Option<RunnableConfig>,
82    ) -> Result<Vec<RetrieverResult>> {
83        let mut results = Vec::new();
84        for query in queries {
85            let result = self.get_relevant_documents(&query, config.clone()).await?;
86            results.push(result);
87        }
88        Ok(results)
89    }
90
91    /// Get the input schema for this retriever
92    fn input_schema(&self) -> Option<serde_json::Value> {
93        None
94    }
95
96    /// Get the output schema for this retriever
97    fn output_schema(&self) -> Option<serde_json::Value> {
98        None
99    }
100}
101
102/// A retriever that wraps a vector store
103pub struct VectorStoreRetriever {
104    vector_store: Box<dyn VectorStore>,
105    search_kwargs: HashMap<String, serde_json::Value>,
106}
107
108impl VectorStoreRetriever {
109    /// Create a new vector store retriever
110    pub fn new(vector_store: Box<dyn VectorStore>) -> Self {
111        Self {
112            vector_store,
113            search_kwargs: HashMap::new(),
114        }
115    }
116
117    /// Create a new vector store retriever with search parameters
118    pub fn new_with_kwargs(
119        vector_store: Box<dyn VectorStore>,
120        search_kwargs: HashMap<String, serde_json::Value>,
121    ) -> Self {
122        Self {
123            vector_store,
124            search_kwargs,
125        }
126    }
127
128    /// Add a search parameter
129    pub fn add_search_kwarg(&mut self, key: impl Into<String>, value: serde_json::Value) {
130        self.search_kwargs.insert(key.into(), value);
131    }
132
133    /// Get the number of documents to retrieve
134    fn get_k(&self) -> usize {
135        self.search_kwargs
136            .get("k")
137            .and_then(|v| v.as_u64())
138            .map(|k| k as usize)
139            .unwrap_or(4)
140    }
141
142    /// Get the filter for the search
143    fn get_filter(&self) -> Option<HashMap<String, serde_json::Value>> {
144        self.search_kwargs
145            .get("filter")
146            .and_then(|v| v.as_object())
147            .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
148    }
149}
150
151#[async_trait]
152impl BaseRetriever for VectorStoreRetriever {
153    async fn get_relevant_documents(
154        &self,
155        query: &str,
156        _config: Option<RunnableConfig>,
157    ) -> Result<RetrieverResult> {
158        let k = self.get_k();
159        let filter = self.get_filter();
160
161        let search_results = self
162            .vector_store
163            .similarity_search(query, k, filter)
164            .await?;
165
166        let documents: Vec<Document> = search_results
167            .into_iter()
168            .map(|result| result.document)
169            .collect();
170
171        let mut retriever_result = RetrieverResult::new(documents);
172        retriever_result.add_metadata(
173            "search_type",
174            serde_json::Value::String("similarity".to_string()),
175        );
176        retriever_result.add_metadata("k", serde_json::Value::Number(serde_json::Number::from(k)));
177
178        Ok(retriever_result)
179    }
180}
181
182/// A retriever that can be used as a runnable
183pub struct RunnableRetriever<R> {
184    retriever: R,
185}
186
187impl<R> RunnableRetriever<R>
188where
189    R: BaseRetriever,
190{
191    /// Create a new runnable retriever
192    pub fn new(retriever: R) -> Self {
193        Self { retriever }
194    }
195}
196
197#[async_trait]
198impl<R> Runnable<String, RetrieverResult> for RunnableRetriever<R>
199where
200    R: BaseRetriever,
201{
202    async fn invoke(
203        &self,
204        input: String,
205        config: Option<RunnableConfig>,
206    ) -> Result<RetrieverResult> {
207        self.retriever.get_relevant_documents(&input, config).await
208    }
209}
210
211/// A retriever that combines multiple retrievers
212pub struct MultiRetriever {
213    retrievers: Vec<Box<dyn BaseRetriever>>,
214    combine_method: CombineMethod,
215}
216
217/// Method for combining results from multiple retrievers
218#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
219pub enum CombineMethod {
220    /// Take the union of all results
221    #[default]
222    Union,
223    /// Take the intersection of all results
224    Intersection,
225    /// Take the first retriever's results
226    First,
227    /// Take the last retriever's results
228    Last,
229}
230
231impl MultiRetriever {
232    /// Create a new multi-retriever
233    pub fn new(retrievers: Vec<Box<dyn BaseRetriever>>) -> Self {
234        Self {
235            retrievers,
236            combine_method: CombineMethod::Union,
237        }
238    }
239
240    /// Create a new multi-retriever with a specific combine method
241    pub fn new_with_method(
242        retrievers: Vec<Box<dyn BaseRetriever>>,
243        combine_method: CombineMethod,
244    ) -> Self {
245        Self {
246            retrievers,
247            combine_method,
248        }
249    }
250
251    /// Add a retriever to the multi-retriever
252    pub fn add_retriever(&mut self, retriever: Box<dyn BaseRetriever>) {
253        self.retrievers.push(retriever);
254    }
255
256    /// Set the combine method
257    pub fn set_combine_method(&mut self, method: CombineMethod) {
258        self.combine_method = method;
259    }
260
261    /// Combine results from multiple retrievers
262    fn combine_results(&self, results: Vec<RetrieverResult>) -> RetrieverResult {
263        match self.combine_method {
264            CombineMethod::Union => {
265                let mut all_documents = Vec::new();
266                let mut combined_metadata = HashMap::new();
267
268                for result in results {
269                    all_documents.extend(result.documents);
270                    combined_metadata.extend(result.metadata);
271                }
272
273                RetrieverResult::new_with_metadata(all_documents, combined_metadata)
274            }
275            CombineMethod::Intersection => {
276                if results.is_empty() {
277                    return RetrieverResult::new(Vec::new());
278                }
279
280                let mut intersection = results[0].documents.clone();
281                for result in results.iter().skip(1) {
282                    intersection.retain(|doc| {
283                        result
284                            .documents
285                            .iter()
286                            .any(|other_doc| doc.page_content == other_doc.page_content)
287                    });
288                }
289
290                RetrieverResult::new(intersection)
291            }
292            CombineMethod::First => results
293                .into_iter()
294                .next()
295                .unwrap_or_else(|| RetrieverResult::new(Vec::new())),
296            CombineMethod::Last => results
297                .into_iter()
298                .last()
299                .unwrap_or_else(|| RetrieverResult::new(Vec::new())),
300        }
301    }
302}
303
304#[async_trait]
305impl BaseRetriever for MultiRetriever {
306    async fn get_relevant_documents(
307        &self,
308        query: &str,
309        config: Option<RunnableConfig>,
310    ) -> Result<RetrieverResult> {
311        let mut results = Vec::new();
312
313        for retriever in &self.retrievers {
314            let result = retriever
315                .get_relevant_documents(query, config.clone())
316                .await?;
317            results.push(result);
318        }
319
320        Ok(self.combine_results(results))
321    }
322}
323
324/// Helper function to create a vector store retriever
325pub fn vector_store_retriever(vector_store: Box<dyn VectorStore>) -> VectorStoreRetriever {
326    VectorStoreRetriever::new(vector_store)
327}
328
329/// Helper function to create a runnable retriever
330pub fn runnable_retriever<R>(retriever: R) -> RunnableRetriever<R>
331where
332    R: BaseRetriever,
333{
334    RunnableRetriever::new(retriever)
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use crate::serializable::Serializable;
    use crate::vectorstores::InMemoryVectorStore;

    /// Build a metadata-free result holding one document per entry of `texts`.
    fn result_of(texts: &[&'static str]) -> RetrieverResult {
        RetrieverResult::new(texts.iter().map(|text| Document::new(*text)).collect())
    }

    #[test]
    fn test_retriever_result() {
        let docs = vec![Document::new("Document 1"), Document::new("Document 2")];
        let result = RetrieverResult::new(docs.clone());

        assert_eq!(result.documents, docs);
        assert_eq!(result.len(), 2);
        assert!(!result.is_empty());
    }

    #[tokio::test]
    async fn test_vector_store_retriever() {
        let vector_store = Box::new(InMemoryVectorStore::new());

        // Add some documents to the vector store
        let docs = vec![
            Document::new("Hello world"),
            Document::new("Rust is awesome"),
            Document::new("Python is great"),
        ];
        vector_store.add_documents(docs, None).await.unwrap();

        // Create retriever
        let retriever = VectorStoreRetriever::new(vector_store);

        // Retrieve documents
        let result = retriever
            .get_relevant_documents("Hello", None)
            .await
            .unwrap();
        assert!(!result.is_empty());
        assert_eq!(result.len(), 3); // Only 3 documents in store
    }

    #[tokio::test]
    async fn test_vector_store_retriever_default_k() {
        let vector_store = Box::new(InMemoryVectorStore::new());

        // Add more documents than the default k so the limit is observable
        let docs = vec![
            Document::new("Hello world"),
            Document::new("Rust is awesome"),
            Document::new("Python is great"),
            Document::new("Check out FerricLink!"),
            Document::new(
                "FerricLink is a Rust library for building AI applications, inspired by LangChain.",
            ),
        ];
        vector_store.add_documents(docs, None).await.unwrap();

        // Create retriever
        let retriever = VectorStoreRetriever::new(vector_store);

        // Retrieve documents
        let result = retriever
            .get_relevant_documents("Hello", None)
            .await
            .unwrap();
        assert!(!result.is_empty());
        assert_eq!(result.len(), 4); // Only 4 documents retrieved by default with k=4
    }

    #[tokio::test]
    async fn test_vector_store_retriever_with_kwargs() {
        let vector_store = Box::new(InMemoryVectorStore::new());

        let docs = vec![Document::new("Document 1"), Document::new("Document 2")];
        vector_store.add_documents(docs, None).await.unwrap();

        let mut search_kwargs = HashMap::new();
        search_kwargs.insert(
            "k".to_string(),
            serde_json::Value::Number(serde_json::Number::from(1)),
        );

        let retriever = VectorStoreRetriever::new_with_kwargs(vector_store, search_kwargs);

        let result = retriever
            .get_relevant_documents("Document", None)
            .await
            .unwrap();
        assert_eq!(result.len(), 1);
    }

    #[tokio::test]
    async fn test_runnable_retriever() {
        let vector_store = Box::new(InMemoryVectorStore::new());
        let retriever = VectorStoreRetriever::new(vector_store);
        let runnable_retriever = RunnableRetriever::new(retriever);

        let result = runnable_retriever
            .invoke("test query".to_string(), None)
            .await
            .unwrap();
        assert!(result.is_empty()); // Empty vector store
    }

    #[tokio::test]
    async fn test_multi_retriever_union() {
        let vector_store1 = Box::new(InMemoryVectorStore::new());
        let vector_store2 = Box::new(InMemoryVectorStore::new());

        // Add different documents to each store
        vector_store1
            .add_documents(vec![Document::new("Store 1 doc")], None)
            .await
            .unwrap();
        vector_store2
            .add_documents(vec![Document::new("Store 2 doc")], None)
            .await
            .unwrap();

        let retriever1 = VectorStoreRetriever::new(vector_store1);
        let retriever2 = VectorStoreRetriever::new(vector_store2);

        let multi_retriever = MultiRetriever::new(vec![Box::new(retriever1), Box::new(retriever2)]);

        let result = multi_retriever
            .get_relevant_documents("doc", None)
            .await
            .unwrap();
        assert_eq!(result.len(), 2); // Union of both results
    }

    #[tokio::test]
    async fn test_multi_retriever_first() {
        let vector_store1 = Box::new(InMemoryVectorStore::new());
        let vector_store2 = Box::new(InMemoryVectorStore::new());

        vector_store1
            .add_documents(vec![Document::new("First doc")], None)
            .await
            .unwrap();
        vector_store2
            .add_documents(vec![Document::new("Second doc")], None)
            .await
            .unwrap();

        let retriever1 = VectorStoreRetriever::new(vector_store1);
        let retriever2 = VectorStoreRetriever::new(vector_store2);

        let multi_retriever = MultiRetriever::new_with_method(
            vec![Box::new(retriever1), Box::new(retriever2)],
            CombineMethod::First,
        );

        let result = multi_retriever
            .get_relevant_documents("doc", None)
            .await
            .unwrap();
        assert_eq!(result.len(), 1); // Only first retriever's results
    }

    #[test]
    fn test_combine_results_intersection() {
        // Exercise the merge step directly so the outcome does not depend on
        // how any vector store ranks documents.
        let multi_retriever =
            MultiRetriever::new_with_method(Vec::new(), CombineMethod::Intersection);

        let combined = multi_retriever.combine_results(vec![
            result_of(&["shared", "only in first"]),
            result_of(&["only in second", "shared"]),
        ]);

        assert_eq!(combined.len(), 1); // Only the document present in both
        assert!(combined.documents[0].page_content == Document::new("shared").page_content);
    }

    #[test]
    fn test_combine_results_last() {
        let multi_retriever = MultiRetriever::new_with_method(Vec::new(), CombineMethod::Last);

        let combined =
            multi_retriever.combine_results(vec![result_of(&["a"]), result_of(&["b", "c"])]);

        assert_eq!(combined.len(), 2); // Only the last result's documents
        assert!(combined.documents[0].page_content == Document::new("b").page_content);
    }

    #[test]
    fn test_combine_results_empty_input() {
        let methods = vec![
            CombineMethod::Union,
            CombineMethod::Intersection,
            CombineMethod::First,
            CombineMethod::Last,
        ];

        for method in methods {
            let multi_retriever = MultiRetriever::new_with_method(Vec::new(), method);
            assert!(multi_retriever.combine_results(Vec::new()).is_empty());
        }
    }

    #[test]
    fn test_combine_methods() {
        assert_eq!(CombineMethod::default(), CombineMethod::Union);
    }

    #[test]
    fn test_serialization() {
        let docs = vec![Document::new("Test document")];
        let result = RetrieverResult::new(docs);
        let json = result.to_json().unwrap();
        let deserialized: RetrieverResult = RetrieverResult::from_json(&json).unwrap();
        assert_eq!(result, deserialized);
    }
}