1use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10use crate::documents::Document;
11use crate::errors::Result;
12use crate::impl_serializable;
13use crate::runnables::{Runnable, RunnableConfig};
14use crate::vectorstores::VectorStore;
15
/// The outcome of a single retrieval call: the documents that were found plus
/// free-form metadata describing how they were obtained.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RetrieverResult {
    /// Retrieved documents, in the order the producing retriever returned them.
    pub documents: Vec<Document>,
    /// Extra information about the retrieval (for example `VectorStoreRetriever`
    /// records `"search_type"` and `"k"` here). When the field is missing from
    /// serialized input it deserializes as an empty map (`#[serde(default)]`).
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
}
25
26impl RetrieverResult {
27 pub fn new(documents: Vec<Document>) -> Self {
29 Self {
30 documents,
31 metadata: HashMap::new(),
32 }
33 }
34
35 pub fn new_with_metadata(
37 documents: Vec<Document>,
38 metadata: HashMap<String, serde_json::Value>,
39 ) -> Self {
40 Self {
41 documents,
42 metadata,
43 }
44 }
45
46 pub fn add_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
48 self.metadata.insert(key.into(), value);
49 }
50
51 pub fn len(&self) -> usize {
53 self.documents.len()
54 }
55
56 pub fn is_empty(&self) -> bool {
58 self.documents.is_empty()
59 }
60}
61
// Hooks `RetrieverResult` into the crate's serialization support under the
// namespace `["ferriclink", "retrievers", "retriever_result"]`. The generated
// items come from `impl_serializable!`; presumably this implements
// `crate::serializable::Serializable`, since the test module calls
// `to_json` / `from_json` through that trait — confirm against the macro.
impl_serializable!(
    RetrieverResult,
    ["ferriclink", "retrievers", "retriever_result"]
);
66
67#[async_trait]
69pub trait BaseRetriever: Send + Sync + 'static {
70 async fn get_relevant_documents(
72 &self,
73 query: &str,
74 config: Option<RunnableConfig>,
75 ) -> Result<RetrieverResult>;
76
77 async fn get_relevant_documents_batch(
79 &self,
80 queries: Vec<String>,
81 config: Option<RunnableConfig>,
82 ) -> Result<Vec<RetrieverResult>> {
83 let mut results = Vec::new();
84 for query in queries {
85 let result = self.get_relevant_documents(&query, config.clone()).await?;
86 results.push(result);
87 }
88 Ok(results)
89 }
90
91 fn input_schema(&self) -> Option<serde_json::Value> {
93 None
94 }
95
96 fn output_schema(&self) -> Option<serde_json::Value> {
98 None
99 }
100}
101
/// A retriever backed by a [`VectorStore`] similarity search.
///
/// Search behaviour is tuned through `search_kwargs`. The keys consulted are
/// `"k"` (how many results to request from the store) and `"filter"` (a JSON
/// object handed to the store's similarity search).
pub struct VectorStoreRetriever {
    /// The store queried on every retrieval.
    vector_store: Box<dyn VectorStore>,
    /// Extra search parameters; see the type-level docs for the recognised keys.
    search_kwargs: HashMap<String, serde_json::Value>,
}
107
108impl VectorStoreRetriever {
109 pub fn new(vector_store: Box<dyn VectorStore>) -> Self {
111 Self {
112 vector_store,
113 search_kwargs: HashMap::new(),
114 }
115 }
116
117 pub fn new_with_kwargs(
119 vector_store: Box<dyn VectorStore>,
120 search_kwargs: HashMap<String, serde_json::Value>,
121 ) -> Self {
122 Self {
123 vector_store,
124 search_kwargs,
125 }
126 }
127
128 pub fn add_search_kwarg(&mut self, key: impl Into<String>, value: serde_json::Value) {
130 self.search_kwargs.insert(key.into(), value);
131 }
132
133 fn get_k(&self) -> usize {
135 self.search_kwargs
136 .get("k")
137 .and_then(|v| v.as_u64())
138 .map(|k| k as usize)
139 .unwrap_or(4)
140 }
141
142 fn get_filter(&self) -> Option<HashMap<String, serde_json::Value>> {
144 self.search_kwargs
145 .get("filter")
146 .and_then(|v| v.as_object())
147 .map(|obj| obj.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
148 }
149}
150
151#[async_trait]
152impl BaseRetriever for VectorStoreRetriever {
153 async fn get_relevant_documents(
154 &self,
155 query: &str,
156 _config: Option<RunnableConfig>,
157 ) -> Result<RetrieverResult> {
158 let k = self.get_k();
159 let filter = self.get_filter();
160
161 let search_results = self
162 .vector_store
163 .similarity_search(query, k, filter)
164 .await?;
165
166 let documents: Vec<Document> = search_results
167 .into_iter()
168 .map(|result| result.document)
169 .collect();
170
171 let mut retriever_result = RetrieverResult::new(documents);
172 retriever_result.add_metadata(
173 "search_type",
174 serde_json::Value::String("similarity".to_string()),
175 );
176 retriever_result.add_metadata("k", serde_json::Value::Number(serde_json::Number::from(k)));
177
178 Ok(retriever_result)
179 }
180}
181
/// Adapter that exposes a [`BaseRetriever`] through the
/// `Runnable<String, RetrieverResult>` interface.
pub struct RunnableRetriever<R> {
    /// The wrapped retriever that performs the actual lookup.
    retriever: R,
}
186
187impl<R> RunnableRetriever<R>
188where
189 R: BaseRetriever,
190{
191 pub fn new(retriever: R) -> Self {
193 Self { retriever }
194 }
195}
196
197#[async_trait]
198impl<R> Runnable<String, RetrieverResult> for RunnableRetriever<R>
199where
200 R: BaseRetriever,
201{
202 async fn invoke(
203 &self,
204 input: String,
205 config: Option<RunnableConfig>,
206 ) -> Result<RetrieverResult> {
207 self.retriever.get_relevant_documents(&input, config).await
208 }
209}
210
/// Sends the same query to several retrievers and merges their results
/// according to a [`CombineMethod`].
pub struct MultiRetriever {
    /// Child retrievers, queried in this order on every call.
    retrievers: Vec<Box<dyn BaseRetriever>>,
    /// Strategy used to merge the per-retriever results.
    combine_method: CombineMethod,
}
216
217#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
219pub enum CombineMethod {
220 #[default]
222 Union,
223 Intersection,
225 First,
227 Last,
229}
230
231impl MultiRetriever {
232 pub fn new(retrievers: Vec<Box<dyn BaseRetriever>>) -> Self {
234 Self {
235 retrievers,
236 combine_method: CombineMethod::Union,
237 }
238 }
239
240 pub fn new_with_method(
242 retrievers: Vec<Box<dyn BaseRetriever>>,
243 combine_method: CombineMethod,
244 ) -> Self {
245 Self {
246 retrievers,
247 combine_method,
248 }
249 }
250
251 pub fn add_retriever(&mut self, retriever: Box<dyn BaseRetriever>) {
253 self.retrievers.push(retriever);
254 }
255
256 pub fn set_combine_method(&mut self, method: CombineMethod) {
258 self.combine_method = method;
259 }
260
261 fn combine_results(&self, results: Vec<RetrieverResult>) -> RetrieverResult {
263 match self.combine_method {
264 CombineMethod::Union => {
265 let mut all_documents = Vec::new();
266 let mut combined_metadata = HashMap::new();
267
268 for result in results {
269 all_documents.extend(result.documents);
270 combined_metadata.extend(result.metadata);
271 }
272
273 RetrieverResult::new_with_metadata(all_documents, combined_metadata)
274 }
275 CombineMethod::Intersection => {
276 if results.is_empty() {
277 return RetrieverResult::new(Vec::new());
278 }
279
280 let mut intersection = results[0].documents.clone();
281 for result in results.iter().skip(1) {
282 intersection.retain(|doc| {
283 result
284 .documents
285 .iter()
286 .any(|other_doc| doc.page_content == other_doc.page_content)
287 });
288 }
289
290 RetrieverResult::new(intersection)
291 }
292 CombineMethod::First => results
293 .into_iter()
294 .next()
295 .unwrap_or_else(|| RetrieverResult::new(Vec::new())),
296 CombineMethod::Last => results
297 .into_iter()
298 .last()
299 .unwrap_or_else(|| RetrieverResult::new(Vec::new())),
300 }
301 }
302}
303
304#[async_trait]
305impl BaseRetriever for MultiRetriever {
306 async fn get_relevant_documents(
307 &self,
308 query: &str,
309 config: Option<RunnableConfig>,
310 ) -> Result<RetrieverResult> {
311 let mut results = Vec::new();
312
313 for retriever in &self.retrievers {
314 let result = retriever
315 .get_relevant_documents(query, config.clone())
316 .await?;
317 results.push(result);
318 }
319
320 Ok(self.combine_results(results))
321 }
322}
323
324pub fn vector_store_retriever(vector_store: Box<dyn VectorStore>) -> VectorStoreRetriever {
326 VectorStoreRetriever::new(vector_store)
327}
328
329pub fn runnable_retriever<R>(retriever: R) -> RunnableRetriever<R>
331where
332 R: BaseRetriever,
333{
334 RunnableRetriever::new(retriever)
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use crate::serializable::Serializable;
    use crate::vectorstores::InMemoryVectorStore;

    #[test]
    fn test_retriever_result() {
        let docs = vec![Document::new("Document 1"), Document::new("Document 2")];
        let result = RetrieverResult::new(docs.clone());

        assert_eq!(result.documents, docs);
        assert_eq!(result.len(), 2);
        assert!(!result.is_empty());
    }

    #[tokio::test]
    async fn test_vector_store_retriever() {
        let vector_store = Box::new(InMemoryVectorStore::new());

        // Fewer documents than the default k, so every document comes back.
        let docs = vec![
            Document::new("Hello world"),
            Document::new("Rust is awesome"),
            Document::new("Python is great"),
        ];
        vector_store.add_documents(docs, None).await.unwrap();

        let retriever = VectorStoreRetriever::new(vector_store);

        let result = retriever
            .get_relevant_documents("Hello", None)
            .await
            .unwrap();
        assert!(!result.is_empty());
        assert_eq!(result.len(), 3);
    }

    #[tokio::test]
    async fn test_vector_store_retriever_default_k() {
        let vector_store = Box::new(InMemoryVectorStore::new());

        // More documents than the default k (4), so the result is capped.
        let docs = vec![
            Document::new("Hello world"),
            Document::new("Rust is awesome"),
            Document::new("Python is great"),
            Document::new("Check out FerricLink!"),
            Document::new(
                "FerricLink is a Rust library for building AI applications, inspired by LangChain.",
            ),
        ];
        vector_store.add_documents(docs, None).await.unwrap();

        let retriever = VectorStoreRetriever::new(vector_store);

        let result = retriever
            .get_relevant_documents("Hello", None)
            .await
            .unwrap();
        assert!(!result.is_empty());
        assert_eq!(result.len(), 4);
    }

    #[tokio::test]
    async fn test_vector_store_retriever_with_kwargs() {
        let vector_store = Box::new(InMemoryVectorStore::new());

        let docs = vec![Document::new("Document 1"), Document::new("Document 2")];
        vector_store.add_documents(docs, None).await.unwrap();

        let mut search_kwargs = HashMap::new();
        search_kwargs.insert(
            "k".to_string(),
            serde_json::Value::Number(serde_json::Number::from(1)),
        );

        let retriever = VectorStoreRetriever::new_with_kwargs(vector_store, search_kwargs);

        let result = retriever
            .get_relevant_documents("Document", None)
            .await
            .unwrap();
        assert_eq!(result.len(), 1);
    }

    #[tokio::test]
    async fn test_runnable_retriever() {
        let vector_store = Box::new(InMemoryVectorStore::new());
        let retriever = VectorStoreRetriever::new(vector_store);
        let runnable_retriever = RunnableRetriever::new(retriever);

        let result = runnable_retriever
            .invoke("test query".to_string(), None)
            .await
            .unwrap();
        // The store is empty, so nothing can be retrieved.
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_multi_retriever_union() {
        let vector_store1 = Box::new(InMemoryVectorStore::new());
        let vector_store2 = Box::new(InMemoryVectorStore::new());

        vector_store1
            .add_documents(vec![Document::new("Store 1 doc")], None)
            .await
            .unwrap();
        vector_store2
            .add_documents(vec![Document::new("Store 2 doc")], None)
            .await
            .unwrap();

        let retriever1 = VectorStoreRetriever::new(vector_store1);
        let retriever2 = VectorStoreRetriever::new(vector_store2);

        let multi_retriever = MultiRetriever::new(vec![Box::new(retriever1), Box::new(retriever2)]);

        let result = multi_retriever
            .get_relevant_documents("doc", None)
            .await
            .unwrap();
        // One document from each store.
        assert_eq!(result.len(), 2);
    }

    #[tokio::test]
    async fn test_multi_retriever_first() {
        let vector_store1 = Box::new(InMemoryVectorStore::new());
        let vector_store2 = Box::new(InMemoryVectorStore::new());

        vector_store1
            .add_documents(vec![Document::new("First doc")], None)
            .await
            .unwrap();
        vector_store2
            .add_documents(vec![Document::new("Second doc")], None)
            .await
            .unwrap();

        let retriever1 = VectorStoreRetriever::new(vector_store1);
        let retriever2 = VectorStoreRetriever::new(vector_store2);

        let multi_retriever = MultiRetriever::new_with_method(
            vec![Box::new(retriever1), Box::new(retriever2)],
            CombineMethod::First,
        );

        let result = multi_retriever
            .get_relevant_documents("doc", None)
            .await
            .unwrap();
        assert_eq!(result.len(), 1);
    }

    #[tokio::test]
    async fn test_multi_retriever_last() {
        let vector_store1 = Box::new(InMemoryVectorStore::new());
        let vector_store2 = Box::new(InMemoryVectorStore::new());

        vector_store1
            .add_documents(vec![Document::new("First doc")], None)
            .await
            .unwrap();
        // The last store holds a different number of documents, so the length
        // tells the two results apart.
        vector_store2
            .add_documents(
                vec![Document::new("Second doc"), Document::new("Another second doc")],
                None,
            )
            .await
            .unwrap();

        let retriever1 = VectorStoreRetriever::new(vector_store1);
        let retriever2 = VectorStoreRetriever::new(vector_store2);

        let multi_retriever = MultiRetriever::new_with_method(
            vec![Box::new(retriever1), Box::new(retriever2)],
            CombineMethod::Last,
        );

        let result = multi_retriever
            .get_relevant_documents("doc", None)
            .await
            .unwrap();
        assert_eq!(result.len(), 2);
    }

    #[tokio::test]
    async fn test_multi_retriever_intersection() {
        let vector_store1 = Box::new(InMemoryVectorStore::new());
        let vector_store2 = Box::new(InMemoryVectorStore::new());

        vector_store1
            .add_documents(
                vec![Document::new("Shared doc"), Document::new("Only in store 1")],
                None,
            )
            .await
            .unwrap();
        vector_store2
            .add_documents(
                vec![Document::new("Shared doc"), Document::new("Only in store 2")],
                None,
            )
            .await
            .unwrap();

        let retriever1 = VectorStoreRetriever::new(vector_store1);
        let retriever2 = VectorStoreRetriever::new(vector_store2);

        let multi_retriever = MultiRetriever::new_with_method(
            vec![Box::new(retriever1), Box::new(retriever2)],
            CombineMethod::Intersection,
        );

        let result = multi_retriever
            .get_relevant_documents("doc", None)
            .await
            .unwrap();
        // Only the document present in both stores survives.
        assert_eq!(result.len(), 1);
        assert_eq!(result.documents[0].page_content, "Shared doc");
    }

    #[test]
    fn test_combine_methods() {
        assert_eq!(CombineMethod::default(), CombineMethod::Union);
    }

    #[test]
    fn test_serialization() {
        let docs = vec![Document::new("Test document")];
        let result = RetrieverResult::new(docs);
        let json = result.to_json().unwrap();
        let deserialized: RetrieverResult = RetrieverResult::from_json(&json).unwrap();
        assert_eq!(result, deserialized);
    }
}