1use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10use crate::documents::Document;
11use crate::embeddings::{Embedding, Embeddings};
12use crate::errors::Result;
13use crate::impl_serializable;
14
/// One hit from a vector-store similarity search: the matched document plus its score.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct VectorSearchResult {
    /// The stored document that matched the query.
    pub document: Document,
    /// Similarity between the query and `document`. `InMemoryVectorStore` fills this with
    /// the cosine similarity of the embeddings and returns results highest-score first;
    /// other stores may use a different scale.
    pub score: f32,
    /// Extra per-result data. When the field is absent during deserialization it
    /// defaults to an empty map.
    #[serde(default)]
    pub metadata: HashMap<String, serde_json::Value>,
}
26
27impl VectorSearchResult {
28 pub fn new(document: Document, score: f32) -> Self {
30 Self {
31 document,
32 score,
33 metadata: HashMap::new(),
34 }
35 }
36
37 pub fn new_with_metadata(
39 document: Document,
40 score: f32,
41 metadata: HashMap<String, serde_json::Value>,
42 ) -> Self {
43 Self {
44 document,
45 score,
46 metadata,
47 }
48 }
49}
50
// Registers `VectorSearchResult` with the crate's `impl_serializable!` macro under the
// identifier path below.
// NOTE(review): presumably this implements `crate::serializable::Serializable` for the type
// (the tests call `to_json` / `from_json` on it) — confirm against the macro definition.
impl_serializable!(
    VectorSearchResult,
    ["ferriclink", "vectorstores", "search_result"]
);
55
56#[async_trait]
58pub trait VectorStore: Send + Sync + 'static {
59 async fn add_documents(
61 &self,
62 documents: Vec<Document>,
63 embeddings: Option<Vec<Embedding>>,
64 ) -> Result<Vec<String>>;
65
66 async fn add_texts(
68 &self,
69 texts: Vec<String>,
70 metadatas: Option<Vec<HashMap<String, serde_json::Value>>>,
71 embeddings: Option<Vec<Embedding>>,
72 ) -> Result<Vec<String>> {
73 let documents: Vec<Document> = texts
74 .into_iter()
75 .zip(metadatas.unwrap_or_default().into_iter().cycle())
76 .map(|(text, metadata)| Document::new_with_metadata(text, metadata))
77 .collect();
78
79 self.add_documents(documents, embeddings).await
80 }
81
82 async fn similarity_search(
84 &self,
85 query: &str,
86 k: usize,
87 filter: Option<HashMap<String, serde_json::Value>>,
88 ) -> Result<Vec<VectorSearchResult>>;
89
90 async fn similarity_search_by_embedding(
92 &self,
93 query_embedding: &Embedding,
94 k: usize,
95 filter: Option<HashMap<String, serde_json::Value>>,
96 ) -> Result<Vec<VectorSearchResult>>;
97
98 async fn delete(&self, ids: Vec<String>) -> Result<()>;
100
101 async fn len(&self) -> Result<usize>;
103
104 async fn is_empty(&self) -> Result<bool> {
106 Ok(self.len().await? == 0)
107 }
108
109 async fn clear(&self) -> Result<()>;
111
112 fn embedding_model(&self) -> Option<&dyn Embeddings> {
114 None
115 }
116}
117
/// A vector store that keeps all of its data in process memory.
///
/// Entries live in three parallel vectors: position `i` of `documents`, `embeddings` and
/// `ids` describes the same entry. Each vector sits behind its own `tokio` `RwLock` inside an
/// `Arc`; the writers in this file acquire them in the order documents → embeddings → ids.
pub struct InMemoryVectorStore {
    /// Stored documents, in insertion order.
    documents: std::sync::Arc<tokio::sync::RwLock<Vec<Document>>>,
    /// Embedding of the document at the same index.
    embeddings: std::sync::Arc<tokio::sync::RwLock<Vec<Embedding>>>,
    /// Generated id (a UUID v4 string) of the entry at the same index.
    ids: std::sync::Arc<tokio::sync::RwLock<Vec<String>>>,
    /// Optional embedding model; when `None`, texts are embedded with a deterministic
    /// character-based placeholder.
    embedding_model: Option<Box<dyn Embeddings>>,
}
125
126impl InMemoryVectorStore {
127 pub fn new() -> Self {
129 Self {
130 documents: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
131 embeddings: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
132 ids: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
133 embedding_model: None,
134 }
135 }
136
137 pub fn new_with_embeddings(embedding_model: Box<dyn Embeddings>) -> Self {
139 Self {
140 documents: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
141 embeddings: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
142 ids: std::sync::Arc::new(tokio::sync::RwLock::new(Vec::new())),
143 embedding_model: Some(embedding_model),
144 }
145 }
146
147 #[allow(dead_code)]
149 async fn generate_embeddings(&self, texts: &[String]) -> Result<Vec<Embedding>> {
150 if let Some(model) = &self.embedding_model {
151 model.embed_documents(texts).await
152 } else {
153 let mut embeddings = Vec::new();
155 for text in texts {
156 let mut values = Vec::new();
157 for (i, c) in text.chars().enumerate() {
158 values.push((c as u32 as f32 + i as f32) / 100.0);
159 }
160 while values.len() < 128 {
162 values.push(0.0);
163 }
164 embeddings.push(Embedding::new(values));
165 }
166 Ok(embeddings)
167 }
168 }
169
170 fn cosine_similarity(&self, a: &Embedding, b: &Embedding) -> f32 {
172 a.cosine_similarity(b)
173 }
174}
175
176#[async_trait]
177impl VectorStore for InMemoryVectorStore {
178 async fn add_documents(
179 &self,
180 documents: Vec<Document>,
181 embeddings: Option<Vec<Embedding>>,
182 ) -> Result<Vec<String>> {
183 let mut doc_store = self.documents.write().await;
184 let mut emb_store = self.embeddings.write().await;
185 let mut id_store = self.ids.write().await;
186
187 let generated_embeddings = embeddings.unwrap_or_else(|| {
188 documents
189 .iter()
190 .map(|doc| {
191 let mut values = Vec::new();
192 for (i, c) in doc.page_content.chars().enumerate() {
193 values.push((c as u32 as f32 + i as f32) / 100.0);
194 }
195 while values.len() < 128 {
196 values.push(0.0);
197 }
198 Embedding::new(values)
199 })
200 .collect()
201 });
202
203 let mut ids = Vec::new();
204 for (i, document) in documents.into_iter().enumerate() {
205 let id = uuid::Uuid::new_v4().to_string();
206 ids.push(id.clone());
207 id_store.push(id);
208 doc_store.push(document);
209 emb_store.push(generated_embeddings[i].clone());
210 }
211
212 Ok(ids)
213 }
214
215 async fn similarity_search(
216 &self,
217 query: &str,
218 k: usize,
219 _filter: Option<HashMap<String, serde_json::Value>>,
220 ) -> Result<Vec<VectorSearchResult>> {
221 let query_embedding = if let Some(model) = &self.embedding_model {
223 model.embed_query(query).await?
224 } else {
225 let mut values = Vec::new();
227 for (i, c) in query.chars().enumerate() {
228 values.push((c as u32 as f32 + i as f32) / 100.0);
229 }
230 while values.len() < 128 {
231 values.push(0.0);
232 }
233 Embedding::new(values)
234 };
235
236 self.similarity_search_by_embedding(&query_embedding, k, None)
237 .await
238 }
239
240 async fn similarity_search_by_embedding(
241 &self,
242 query_embedding: &Embedding,
243 k: usize,
244 _filter: Option<HashMap<String, serde_json::Value>>,
245 ) -> Result<Vec<VectorSearchResult>> {
246 let documents = self.documents.read().await;
247 let embeddings = self.embeddings.read().await;
248
249 if documents.is_empty() {
250 return Ok(Vec::new());
251 }
252
253 let mut similarities: Vec<(usize, f32)> = embeddings
255 .iter()
256 .enumerate()
257 .map(|(i, emb)| (i, self.cosine_similarity(query_embedding, emb)))
258 .collect();
259
260 similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
262
263 let results: Vec<VectorSearchResult> = similarities
265 .into_iter()
266 .take(k)
267 .map(|(idx, score)| VectorSearchResult::new(documents[idx].clone(), score))
268 .collect();
269
270 Ok(results)
271 }
272
273 async fn delete(&self, ids: Vec<String>) -> Result<()> {
274 let mut doc_store = self.documents.write().await;
275 let mut emb_store = self.embeddings.write().await;
276 let mut id_store = self.ids.write().await;
277
278 let mut indices_to_remove: Vec<usize> = Vec::new();
280 for id in &ids {
281 if let Some(pos) = id_store.iter().position(|x| x == id) {
282 indices_to_remove.push(pos);
283 }
284 }
285
286 indices_to_remove.sort_by(|a, b| b.cmp(a));
288
289 for &idx in &indices_to_remove {
291 if idx < doc_store.len() {
292 doc_store.remove(idx);
293 emb_store.remove(idx);
294 id_store.remove(idx);
295 }
296 }
297
298 Ok(())
299 }
300
301 async fn len(&self) -> Result<usize> {
302 Ok(self.documents.read().await.len())
303 }
304
305 async fn clear(&self) -> Result<()> {
306 let mut doc_store = self.documents.write().await;
307 let mut emb_store = self.embeddings.write().await;
308 let mut id_store = self.ids.write().await;
309
310 doc_store.clear();
311 emb_store.clear();
312 id_store.clear();
313
314 Ok(())
315 }
316
317 fn embedding_model(&self) -> Option<&dyn Embeddings> {
318 self.embedding_model.as_ref().map(|m| m.as_ref())
319 }
320}
321
322impl Default for InMemoryVectorStore {
323 fn default() -> Self {
324 Self::new()
325 }
326}
327
328pub fn in_memory_vector_store() -> InMemoryVectorStore {
330 InMemoryVectorStore::new()
331}
332
333pub fn in_memory_vector_store_with_embeddings(
335 embedding_model: Box<dyn Embeddings>,
336) -> InMemoryVectorStore {
337 InMemoryVectorStore::new_with_embeddings(embedding_model)
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::MockEmbeddings;
    use crate::serializable::Serializable;

    /// Builds one plain document per input string, preserving order.
    fn docs_from(texts: &[&str]) -> Vec<Document> {
        texts.iter().map(|text| Document::new(*text)).collect()
    }

    #[tokio::test]
    async fn test_in_memory_vector_store() {
        let store = InMemoryVectorStore::new();
        assert!(store.is_empty().await.unwrap());
        assert_eq!(store.len().await.unwrap(), 0);

        let ids = store
            .add_documents(docs_from(&["Hello world", "Rust is awesome"]), None)
            .await
            .unwrap();

        assert_eq!(ids.len(), 2);
        assert_eq!(store.len().await.unwrap(), 2);
        assert!(!store.is_empty().await.unwrap());
    }

    #[tokio::test]
    async fn test_similarity_search() {
        let store = InMemoryVectorStore::new();
        let corpus = docs_from(&[
            "Hello world",
            "Rust programming language",
            "Python is great",
        ]);
        store.add_documents(corpus, None).await.unwrap();

        let results = store.similarity_search("Hello", 2, None).await.unwrap();

        assert_eq!(results.len(), 2);
        // Results come back best match first.
        assert!(results[0].score >= results[1].score);
    }

    #[tokio::test]
    async fn test_delete_documents() {
        let store = InMemoryVectorStore::new();
        let ids = store
            .add_documents(docs_from(&["Doc 1", "Doc 2", "Doc 3"]), None)
            .await
            .unwrap();
        assert_eq!(store.len().await.unwrap(), 3);

        let first_id = ids[0].clone();
        store.delete(vec![first_id]).await.unwrap();
        assert_eq!(store.len().await.unwrap(), 2);

        store.clear().await.unwrap();
        assert!(store.is_empty().await.unwrap());
    }

    #[tokio::test]
    async fn test_with_embedding_model() {
        let model = Box::new(MockEmbeddings::new("test-model", 128));
        let store = InMemoryVectorStore::new_with_embeddings(model);

        let ids = store
            .add_documents(docs_from(&["Test document"]), None)
            .await
            .unwrap();
        assert_eq!(ids.len(), 1);

        let hits = store.similarity_search("Test", 1, None).await.unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn test_vector_search_result() {
        let doc = Document::new("Test document");
        let result = VectorSearchResult::new(doc.clone(), 0.95);

        assert_eq!(result.document, doc);
        assert_eq!(result.score, 0.95);
        assert!(result.metadata.is_empty());
    }

    #[test]
    fn test_serialization() {
        let original = VectorSearchResult::new(Document::new("Test document"), 0.95);

        let json = original.to_json().unwrap();
        let round_tripped: VectorSearchResult = VectorSearchResult::from_json(&json).unwrap();

        assert_eq!(original, round_tripped);
    }
}