ferriclink_core/
embeddings.rs1use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7
8use crate::errors::Result;
9use crate::impl_serializable;
10
11#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
13pub struct Embedding {
14 pub values: Vec<f32>,
16 #[serde(default)]
18 pub metadata: std::collections::HashMap<String, serde_json::Value>,
19}
20
21impl Embedding {
22 pub fn new(values: Vec<f32>) -> Self {
24 Self {
25 values,
26 metadata: std::collections::HashMap::new(),
27 }
28 }
29
30 pub fn new_with_metadata(
32 values: Vec<f32>,
33 metadata: std::collections::HashMap<String, serde_json::Value>,
34 ) -> Self {
35 Self { values, metadata }
36 }
37
38 pub fn dimension(&self) -> usize {
40 self.values.len()
41 }
42
43 pub fn cosine_similarity(&self, other: &Embedding) -> f32 {
45 if self.values.len() != other.values.len() {
46 return 0.0;
47 }
48
49 let dot_product: f32 = self
50 .values
51 .iter()
52 .zip(other.values.iter())
53 .map(|(a, b)| a * b)
54 .sum();
55
56 let norm_a: f32 = self.values.iter().map(|x| x * x).sum::<f32>().sqrt();
57 let norm_b: f32 = other.values.iter().map(|x| x * x).sum::<f32>().sqrt();
58
59 if norm_a == 0.0 || norm_b == 0.0 {
60 0.0
61 } else {
62 dot_product / (norm_a * norm_b)
63 }
64 }
65
66 pub fn euclidean_distance(&self, other: &Embedding) -> f32 {
68 if self.values.len() != other.values.len() {
69 return f32::INFINITY;
70 }
71
72 let sum_squared_diffs: f32 = self
73 .values
74 .iter()
75 .zip(other.values.iter())
76 .map(|(a, b)| (a - b).powi(2))
77 .sum();
78
79 sum_squared_diffs.sqrt()
80 }
81}
82
// Hooks `Embedding` into the crate's serialization support via the
// `crate::impl_serializable` macro. The tests in this file call
// `to_json`/`from_json` from `crate::serializable::Serializable` on `Embedding`,
// presumably provided by this invocation — confirm against the macro definition.
// NOTE(review): the string list looks like a namespace/identifier path for the
// type; verify its exact meaning in the macro before changing it.
impl_serializable!(Embedding, ["ferriclink", "embeddings", "embedding"]);
84
85#[async_trait]
87pub trait Embeddings: Send + Sync + 'static {
88 fn dimension(&self) -> usize;
90
91 async fn embed_query(&self, text: &str) -> Result<Embedding>;
93
94 async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Embedding>> {
96 let mut embeddings = Vec::new();
97 for text in texts {
98 let embedding = self.embed_query(text).await?;
99 embeddings.push(embedding);
100 }
101 Ok(embeddings)
102 }
103
104 fn model_name(&self) -> &str;
106
107 fn model_type(&self) -> &str {
109 "embeddings"
110 }
111}
112
/// A deterministic, dependency-free [`Embeddings`] stand-in for tests.
///
/// Vectors are derived from a hash of the input text, so no model is called.
#[derive(Debug, Clone)]
pub struct MockEmbeddings {
    /// Name reported by `model_name`.
    model_name: String,
    /// Length of every generated vector.
    dimension: usize,
}
118
119impl MockEmbeddings {
120 pub fn new(model_name: impl Into<String>, dimension: usize) -> Self {
122 Self {
123 model_name: model_name.into(),
124 dimension,
125 }
126 }
127
128 fn generate_mock_embedding(&self, text: &str) -> Embedding {
130 use std::collections::hash_map::DefaultHasher;
131 use std::hash::{Hash, Hasher};
132
133 let mut hasher = DefaultHasher::new();
134 text.hash(&mut hasher);
135 let hash = hasher.finish();
136
137 let mut values = Vec::with_capacity(self.dimension);
138 for i in 0..self.dimension {
139 let seed = hash.wrapping_add(i as u64);
140 let value = (seed as f32 / u64::MAX as f32) * 2.0 - 1.0; values.push(value);
142 }
143
144 Embedding::new(values)
145 }
146}
147
148#[async_trait]
149impl Embeddings for MockEmbeddings {
150 fn dimension(&self) -> usize {
151 self.dimension
152 }
153
154 async fn embed_query(&self, text: &str) -> Result<Embedding> {
155 Ok(self.generate_mock_embedding(text))
156 }
157
158 fn model_name(&self) -> &str {
159 &self.model_name
160 }
161
162 fn model_type(&self) -> &str {
163 "mock_embeddings"
164 }
165}
166
167pub fn mock_embeddings(model_name: impl Into<String>, dimension: usize) -> MockEmbeddings {
169 MockEmbeddings::new(model_name, dimension)
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::serializable::Serializable;

    #[test]
    fn test_embedding_creation() {
        let values = vec![0.1, 0.2, 0.3, 0.4];
        let embedding = Embedding::new(values.clone());

        assert_eq!(embedding.values, values);
        assert_eq!(embedding.dimension(), 4);
        assert!(embedding.metadata.is_empty());
    }

    #[test]
    fn test_embedding_with_metadata() {
        let mut metadata = std::collections::HashMap::new();
        metadata.insert(
            "source".to_string(),
            serde_json::Value::String("doc-1".to_string()),
        );
        let embedding = Embedding::new_with_metadata(vec![1.0, 2.0], metadata.clone());

        assert_eq!(embedding.dimension(), 2);
        assert_eq!(embedding.metadata, metadata);
    }

    #[test]
    fn test_embedding_cosine_similarity() {
        let embedding1 = Embedding::new(vec![1.0, 0.0, 0.0]);
        let embedding2 = Embedding::new(vec![0.0, 1.0, 0.0]);
        let embedding3 = Embedding::new(vec![1.0, 0.0, 0.0]);

        // Orthogonal vectors.
        assert!((embedding1.cosine_similarity(&embedding2) - 0.0).abs() < 1e-6);
        // Identical vectors.
        assert!((embedding1.cosine_similarity(&embedding3) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_degenerate_inputs() {
        let unit = Embedding::new(vec![1.0, 0.0, 0.0]);

        // Dimension mismatch and zero-magnitude vectors are defined to score 0.0.
        assert_eq!(unit.cosine_similarity(&Embedding::new(vec![1.0, 0.0])), 0.0);
        assert_eq!(
            unit.cosine_similarity(&Embedding::new(vec![0.0, 0.0, 0.0])),
            0.0
        );

        // Opposite directions score -1 regardless of magnitude.
        let opposite = Embedding::new(vec![-3.0, 0.0, 0.0]);
        assert!((unit.cosine_similarity(&opposite) + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_embedding_euclidean_distance() {
        let embedding1 = Embedding::new(vec![0.0, 0.0]);
        let embedding2 = Embedding::new(vec![3.0, 4.0]);

        // 3-4-5 right triangle.
        assert!((embedding1.euclidean_distance(&embedding2) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_euclidean_distance_edge_cases() {
        let a = Embedding::new(vec![0.0, 0.0]);
        let b = Embedding::new(vec![0.0, 0.0, 0.0]);

        // A dimension mismatch is reported as an infinite distance.
        assert_eq!(a.euclidean_distance(&b), f32::INFINITY);
        assert_eq!(a.euclidean_distance(&a), 0.0);
    }

    #[tokio::test]
    async fn test_mock_embeddings() {
        let embeddings = MockEmbeddings::new("test-model", 128);

        assert_eq!(embeddings.dimension(), 128);
        assert_eq!(embeddings.model_name(), "test-model");
        assert_eq!(embeddings.model_type(), "mock_embeddings");

        let embedding = embeddings.embed_query("test text").await.unwrap();
        assert_eq!(embedding.dimension(), 128);

        // Same input must give the same vector.
        let embedding2 = embeddings.embed_query("test text").await.unwrap();
        assert_eq!(embedding.values, embedding2.values);

        // Different input should give a different vector.
        let embedding3 = embeddings.embed_query("different text").await.unwrap();
        assert_ne!(embedding.values, embedding3.values);
    }

    #[tokio::test]
    async fn test_embed_documents() {
        let embeddings = MockEmbeddings::new("test-model", 64);
        let texts = vec!["text1".to_string(), "text2".to_string()];

        let results = embeddings.embed_documents(&texts).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].dimension(), 64);
        assert_eq!(results[1].dimension(), 64);

        // Results must line up with the inputs, in order.
        let first = embeddings.embed_query("text1").await.unwrap();
        let second = embeddings.embed_query("text2").await.unwrap();
        assert_eq!(results[0].values, first.values);
        assert_eq!(results[1].values, second.values);
    }

    #[tokio::test]
    async fn test_embed_documents_empty_input() {
        let embeddings = MockEmbeddings::new("test-model", 8);
        let results = embeddings.embed_documents(&[]).await.unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_serialization() {
        let embedding = Embedding::new(vec![0.1, 0.2, 0.3]);
        let json = embedding.to_json().unwrap();
        let deserialized: Embedding = Embedding::from_json(&json).unwrap();
        assert_eq!(embedding, deserialized);
    }
}