// ferriclink_core/embeddings.rs

1//! Embedding abstractions for FerricLink Core
2//!
3//! This module provides the core abstractions for text embeddings.
4
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7
8use crate::errors::Result;
9use crate::impl_serializable;
10
/// A text embedding: a dense vector of `f32` components plus optional metadata.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Embedding {
    /// The embedding vector
    pub values: Vec<f32>,
    /// Arbitrary JSON metadata about the embedding (e.g. model or source info).
    /// `#[serde(default)]` lets payloads without this field still deserialize.
    #[serde(default)]
    pub metadata: std::collections::HashMap<String, serde_json::Value>,
}
20
21impl Embedding {
22    /// Create a new embedding
23    pub fn new(values: Vec<f32>) -> Self {
24        Self {
25            values,
26            metadata: std::collections::HashMap::new(),
27        }
28    }
29
30    /// Create a new embedding with metadata
31    pub fn new_with_metadata(
32        values: Vec<f32>,
33        metadata: std::collections::HashMap<String, serde_json::Value>,
34    ) -> Self {
35        Self { values, metadata }
36    }
37
38    /// Get the dimension of the embedding
39    pub fn dimension(&self) -> usize {
40        self.values.len()
41    }
42
43    /// Calculate cosine similarity with another embedding
44    pub fn cosine_similarity(&self, other: &Embedding) -> f32 {
45        if self.values.len() != other.values.len() {
46            return 0.0;
47        }
48
49        let dot_product: f32 = self
50            .values
51            .iter()
52            .zip(other.values.iter())
53            .map(|(a, b)| a * b)
54            .sum();
55
56        let norm_a: f32 = self.values.iter().map(|x| x * x).sum::<f32>().sqrt();
57        let norm_b: f32 = other.values.iter().map(|x| x * x).sum::<f32>().sqrt();
58
59        if norm_a == 0.0 || norm_b == 0.0 {
60            0.0
61        } else {
62            dot_product / (norm_a * norm_b)
63        }
64    }
65
66    /// Calculate Euclidean distance to another embedding
67    pub fn euclidean_distance(&self, other: &Embedding) -> f32 {
68        if self.values.len() != other.values.len() {
69            return f32::INFINITY;
70        }
71
72        let sum_squared_diffs: f32 = self
73            .values
74            .iter()
75            .zip(other.values.iter())
76            .map(|(a, b)| (a - b).powi(2))
77            .sum();
78
79        sum_squared_diffs.sqrt()
80    }
81}
82
// Presumably implements the crate's `Serializable` trait for `Embedding`
// (the tests below call `to_json`/`from_json` via `crate::serializable::Serializable`)
// under the namespace ["ferriclink", "embeddings", "embedding"] — confirm against
// the macro definition in `crate::impl_serializable`.
impl_serializable!(Embedding, ["ferriclink", "embeddings", "embedding"]);
84
85/// Base trait for all embedding models
86#[async_trait]
87pub trait Embeddings: Send + Sync + 'static {
88    /// Get the dimension of the embeddings produced by this model
89    fn dimension(&self) -> usize;
90
91    /// Embed a single text
92    async fn embed_query(&self, text: &str) -> Result<Embedding>;
93
94    /// Embed multiple texts
95    async fn embed_documents(&self, texts: &[String]) -> Result<Vec<Embedding>> {
96        let mut embeddings = Vec::new();
97        for text in texts {
98            let embedding = self.embed_query(text).await?;
99            embeddings.push(embedding);
100        }
101        Ok(embeddings)
102    }
103
104    /// Get the model name
105    fn model_name(&self) -> &str;
106
107    /// Get the model type
108    fn model_type(&self) -> &str {
109        "embeddings"
110    }
111}
112
/// A simple mock embedding model for testing.
///
/// Produces deterministic, hash-derived vectors so tests can rely on
/// "same text => same embedding" without a real model.
pub struct MockEmbeddings {
    // Name reported by `model_name()`.
    model_name: String,
    // Dimension of every generated embedding.
    dimension: usize,
}
118
119impl MockEmbeddings {
120    /// Create a new mock embedding model
121    pub fn new(model_name: impl Into<String>, dimension: usize) -> Self {
122        Self {
123            model_name: model_name.into(),
124            dimension,
125        }
126    }
127
128    /// Generate a mock embedding based on the input text
129    fn generate_mock_embedding(&self, text: &str) -> Embedding {
130        use std::collections::hash_map::DefaultHasher;
131        use std::hash::{Hash, Hasher};
132
133        let mut hasher = DefaultHasher::new();
134        text.hash(&mut hasher);
135        let hash = hasher.finish();
136
137        let mut values = Vec::with_capacity(self.dimension);
138        for i in 0..self.dimension {
139            let seed = hash.wrapping_add(i as u64);
140            let value = (seed as f32 / u64::MAX as f32) * 2.0 - 1.0; // Normalize to [-1, 1]
141            values.push(value);
142        }
143
144        Embedding::new(values)
145    }
146}
147
#[async_trait]
impl Embeddings for MockEmbeddings {
    // Configured output dimension.
    fn dimension(&self) -> usize {
        self.dimension
    }

    // "Embed" a query by generating a deterministic hash-based vector;
    // never fails.
    async fn embed_query(&self, text: &str) -> Result<Embedding> {
        Ok(self.generate_mock_embedding(text))
    }

    // Configured model name.
    fn model_name(&self) -> &str {
        &self.model_name
    }

    // Tag distinguishing the mock from real embedding models.
    fn model_type(&self) -> &str {
        "mock_embeddings"
    }
}
166
167/// Helper function to create a mock embedding model
168pub fn mock_embeddings(model_name: impl Into<String>, dimension: usize) -> MockEmbeddings {
169    MockEmbeddings::new(model_name, dimension)
170}
171
#[cfg(test)]
mod tests {
    use super::*;
    use crate::serializable::Serializable;

    #[test]
    fn test_embedding_creation() {
        let data = vec![0.1, 0.2, 0.3, 0.4];
        let emb = Embedding::new(data.clone());

        assert_eq!(emb.dimension(), 4);
        assert_eq!(emb.values, data);
        assert!(emb.metadata.is_empty());
    }

    #[test]
    fn test_embedding_cosine_similarity() {
        let x_axis = Embedding::new(vec![1.0, 0.0, 0.0]);
        let y_axis = Embedding::new(vec![0.0, 1.0, 0.0]);
        let x_axis_copy = Embedding::new(vec![1.0, 0.0, 0.0]);

        // Orthogonal vectors -> similarity ~ 0.
        assert!((x_axis.cosine_similarity(&y_axis) - 0.0).abs() < 1e-6);

        // Identical vectors -> similarity ~ 1.
        assert!((x_axis.cosine_similarity(&x_axis_copy) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_embedding_euclidean_distance() {
        let origin = Embedding::new(vec![0.0, 0.0]);
        let point = Embedding::new(vec![3.0, 4.0]);

        // Classic 3-4-5 right triangle: distance must be 5.
        assert!((origin.euclidean_distance(&point) - 5.0).abs() < 1e-6);
    }

    #[tokio::test]
    async fn test_mock_embeddings() {
        let model = MockEmbeddings::new("test-model", 128);

        assert_eq!(model.dimension(), 128);
        assert_eq!(model.model_name(), "test-model");
        assert_eq!(model.model_type(), "mock_embeddings");

        let first = model.embed_query("test text").await.unwrap();
        assert_eq!(first.dimension(), 128);

        // Determinism: the same text must reproduce the same vector.
        let second = model.embed_query("test text").await.unwrap();
        assert_eq!(first.values, second.values);

        // Sensitivity: a different text must yield a different vector.
        let other = model.embed_query("different text").await.unwrap();
        assert_ne!(first.values, other.values);
    }

    #[tokio::test]
    async fn test_embed_documents() {
        let model = MockEmbeddings::new("test-model", 64);
        let texts = vec!["text1".to_string(), "text2".to_string()];

        let batch = model.embed_documents(&texts).await.unwrap();
        assert_eq!(batch.len(), 2);
        for emb in &batch {
            assert_eq!(emb.dimension(), 64);
        }
    }

    #[test]
    fn test_serialization() {
        let original = Embedding::new(vec![0.1, 0.2, 0.3]);
        let json = original.to_json().unwrap();
        let roundtripped: Embedding = Embedding::from_json(&json).unwrap();
        assert_eq!(original, roundtripped);
    }
}