ferriclink_core/
example_selectors.rs

1//! Example selectors for FerricLink Core.
2//!
3//! **Example selector** implements logic for selecting examples to include them in prompts.
4//! This allows us to select examples that are most relevant to the input.
5//!
6//! Example selectors are crucial for:
7//! - Few-shot learning and prompting
8//! - Dynamic prompt construction
9//! - Semantic similarity-based example selection
10//! - Length-based prompt management
11//! - Max Marginal Relevance algorithms
12
13use async_trait::async_trait;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16
17use crate::errors::Result;
18use crate::impl_serializable;
19use crate::vectorstores::{VectorSearchResult, VectorStore};
20
21/// A type alias for example data
22pub type Example = HashMap<String, String>;
23
24/// Interface for selecting examples to include in prompts.
25///
26/// Example selectors are used to dynamically choose which examples to include
27/// in prompts based on the input variables. This is essential for few-shot
28/// learning and context-aware prompt construction.
29#[async_trait]
30pub trait BaseExampleSelector: Send + Sync {
31    /// Add a new example to the store.
32    ///
33    /// # Arguments
34    ///
35    /// * `example` - A dictionary with keys as input variables
36    ///   and values as their values.
37    ///
38    /// # Returns
39    ///
40    /// Any return value (e.g., example ID for tracking).
41    fn add_example(&mut self, example: Example) -> Result<()>;
42
43    /// Async add a new example to the store.
44    ///
45    /// # Arguments
46    ///
47    /// * `example` - A dictionary with keys as input variables
48    ///   and values as their values.
49    ///
50    /// # Returns
51    ///
52    /// Any return value (e.g., example ID for tracking).
53    async fn aadd_example(&mut self, example: Example) -> Result<()> {
54        self.add_example(example)
55    }
56
57    /// Select which examples to use based on the inputs.
58    ///
59    /// # Arguments
60    ///
61    /// * `input_variables` - A dictionary with keys as input variables
62    ///   and values as their values.
63    ///
64    /// # Returns
65    ///
66    /// A list of examples to include in the prompt.
67    fn select_examples(&self, input_variables: &Example) -> Result<Vec<Example>>;
68
69    /// Async select which examples to use based on the inputs.
70    ///
71    /// # Arguments
72    ///
73    /// * `input_variables` - A dictionary with keys as input variables
74    ///   and values as their values.
75    ///
76    /// # Returns
77    ///
78    /// A list of examples to include in the prompt.
79    async fn aselect_examples(&self, input_variables: &Example) -> Result<Vec<Example>> {
80        self.select_examples(input_variables)
81    }
82}
83
84/// Select examples based on length.
85///
86/// This selector chooses examples that fit within a maximum length constraint,
87/// making it useful for managing prompt size and token limits.
88#[derive(Debug, Clone, Serialize, Deserialize)]
89pub struct LengthBasedExampleSelector {
90    /// A list of the examples that the prompt template expects
91    pub examples: Vec<Example>,
92    /// Function to measure text length. Defaults to word count.
93    #[serde(skip, default = "default_text_length")]
94    pub get_text_length: fn(&str) -> usize,
95    /// Max length for the prompt, beyond which examples are cut
96    pub max_length: usize,
97    /// Length of each example (cached for performance)
98    #[serde(skip)]
99    pub example_text_lengths: Vec<usize>,
100}
101
102/// Default text length function for serialization
103fn default_text_length() -> fn(&str) -> usize {
104    LengthBasedExampleSelector::default_text_length
105}
106
107impl LengthBasedExampleSelector {
108    /// Create a new length-based example selector.
109    ///
110    /// # Arguments
111    ///
112    /// * `examples` - Initial list of examples
113    /// * `max_length` - Maximum total length for selected examples
114    /// * `get_text_length` - Function to measure text length (defaults to word count)
115    pub fn new(
116        examples: Vec<Example>,
117        max_length: usize,
118        get_text_length: Option<fn(&str) -> usize>,
119    ) -> Self {
120        let get_text_length = get_text_length.unwrap_or(Self::default_text_length);
121        let mut selector = Self {
122            examples,
123            get_text_length,
124            max_length,
125            example_text_lengths: Vec::new(),
126        };
127        selector.update_lengths();
128        selector
129    }
130
131    /// Create a new selector with default word count function.
132    pub fn with_word_count(examples: Vec<Example>, max_length: usize) -> Self {
133        Self::new(examples, max_length, Some(Self::default_text_length))
134    }
135
136    /// Create a new selector with character count function.
137    pub fn with_char_count(examples: Vec<Example>, max_length: usize) -> Self {
138        Self::new(examples, max_length, Some(Self::char_length))
139    }
140
141    /// Default text length function (word count).
142    pub fn default_text_length(text: &str) -> usize {
143        text.split_whitespace().count()
144    }
145
146    /// Character count function.
147    pub fn char_length(text: &str) -> usize {
148        text.len()
149    }
150
151    /// Update the cached lengths of examples.
152    fn update_lengths(&mut self) {
153        self.example_text_lengths = self
154            .examples
155            .iter()
156            .map(|example| {
157                let text = self.example_to_text(example);
158                (self.get_text_length)(&text)
159            })
160            .collect();
161    }
162
163    /// Convert an example to text for length calculation.
164    fn example_to_text(&self, example: &Example) -> String {
165        let mut values: Vec<_> = example.values().cloned().collect();
166        values.sort();
167        values.join(" ")
168    }
169
170    /// Get the current total length of examples.
171    pub fn total_length(&self) -> usize {
172        self.example_text_lengths.iter().sum()
173    }
174
175    /// Get the number of examples.
176    pub fn len(&self) -> usize {
177        self.examples.len()
178    }
179
180    /// Check if the selector is empty.
181    pub fn is_empty(&self) -> bool {
182        self.examples.is_empty()
183    }
184}
185
186#[async_trait]
187impl BaseExampleSelector for LengthBasedExampleSelector {
188    fn add_example(&mut self, example: Example) -> Result<()> {
189        self.examples.push(example.clone());
190        let text = self.example_to_text(&example);
191        self.example_text_lengths
192            .push((self.get_text_length)(&text));
193        Ok(())
194    }
195
196    fn select_examples(&self, input_variables: &Example) -> Result<Vec<Example>> {
197        let input_text = self.example_to_text(input_variables);
198        let input_length = (self.get_text_length)(&input_text);
199        let remaining_length = self.max_length.saturating_sub(input_length);
200
201        let mut selected = Vec::new();
202        let mut current_length = 0;
203
204        for (i, example) in self.examples.iter().enumerate() {
205            let example_length = self.example_text_lengths[i];
206            if current_length + example_length <= remaining_length {
207                selected.push(example.clone());
208                current_length += example_length;
209            } else {
210                break;
211            }
212        }
213
214        Ok(selected)
215    }
216}
217
218impl_serializable!(
219    LengthBasedExampleSelector,
220    ["ferriclink", "example_selectors", "length_based"]
221);
222
223/// Select examples based on semantic similarity using vector stores.
224///
225/// This selector uses embeddings and vector similarity search to find
226/// the most relevant examples for a given input.
227pub struct SemanticSimilarityExampleSelector {
228    /// Vector store containing the examples
229    pub vectorstore: Box<dyn VectorStore>,
230    /// Number of examples to select
231    pub k: usize,
232    /// Optional keys to filter examples to
233    pub example_keys: Option<Vec<String>>,
234    /// Optional keys to filter input to
235    pub input_keys: Option<Vec<String>>,
236    /// Extra arguments passed to similarity search
237    pub vectorstore_kwargs: Option<HashMap<String, serde_json::Value>>,
238}
239
240impl SemanticSimilarityExampleSelector {
241    /// Create a new semantic similarity example selector.
242    ///
243    /// # Arguments
244    ///
245    /// * `vectorstore` - Vector store containing the examples
246    /// * `k` - Number of examples to select
247    /// * `example_keys` - Optional keys to filter examples to
248    /// * `input_keys` - Optional keys to filter input to
249    /// * `vectorstore_kwargs` - Extra arguments for similarity search
250    pub fn new(
251        vectorstore: Box<dyn VectorStore>,
252        k: usize,
253        example_keys: Option<Vec<String>>,
254        input_keys: Option<Vec<String>>,
255        vectorstore_kwargs: Option<HashMap<String, serde_json::Value>>,
256    ) -> Self {
257        Self {
258            vectorstore,
259            k,
260            example_keys,
261            input_keys,
262            vectorstore_kwargs,
263        }
264    }
265
266    /// Convert an example to text for embedding.
267    fn example_to_text(&self, example: &Example, input_keys: Option<&[String]>) -> String {
268        let filtered_example = if let Some(keys) = input_keys {
269            example
270                .iter()
271                .filter(|(k, _)| keys.contains(k))
272                .map(|(k, v)| (k.clone(), v.clone()))
273                .collect()
274        } else {
275            example.clone()
276        };
277
278        let mut values: Vec<_> = filtered_example.values().cloned().collect();
279        values.sort();
280        values.join(" ")
281    }
282
283    /// Convert search results to examples.
284    fn search_results_to_examples(&self, results: Vec<VectorSearchResult>) -> Vec<Example> {
285        let mut examples = Vec::new();
286
287        for result in results {
288            let mut example = HashMap::new();
289
290            // Convert metadata to example format
291            for (key, value) in result.metadata {
292                if let Some(str_value) = value.as_str() {
293                    example.insert(key, str_value.to_string());
294                }
295            }
296
297            // Filter by example keys if specified
298            if let Some(ref example_keys) = self.example_keys {
299                example.retain(|k, _| example_keys.contains(k));
300            }
301
302            if !example.is_empty() {
303                examples.push(example);
304            }
305        }
306
307        examples
308    }
309}
310
311#[async_trait]
312impl BaseExampleSelector for SemanticSimilarityExampleSelector {
313    fn add_example(&mut self, _example: Example) -> Result<()> {
314        // For sync version, we can't easily handle async operations
315        // This is a limitation of the current design
316        // In practice, you'd want to use the async version
317        Err(crate::errors::FerricLinkError::generic(
318            "Sync add_example not supported for SemanticSimilarityExampleSelector. Use aadd_example instead.",
319        ))
320    }
321
322    async fn aadd_example(&mut self, example: Example) -> Result<()> {
323        let text = self.example_to_text(&example, self.input_keys.as_deref());
324        // Convert example to the right metadata format
325        let metadata: HashMap<String, serde_json::Value> = example
326            .into_iter()
327            .map(|(k, v)| (k, serde_json::Value::String(v)))
328            .collect();
329
330        // Add to vector store
331        self.vectorstore
332            .add_texts(vec![text], Some(vec![metadata]), None)
333            .await?;
334
335        Ok(())
336    }
337
338    fn select_examples(&self, _input_variables: &Example) -> Result<Vec<Example>> {
339        // For sync version, we can't easily handle async operations
340        // This is a limitation of the current design
341        // In practice, you'd want to use the async version
342        Err(crate::errors::FerricLinkError::generic(
343            "Sync select_examples not supported for SemanticSimilarityExampleSelector. Use aselect_examples instead.",
344        ))
345    }
346
347    async fn aselect_examples(&self, input_variables: &Example) -> Result<Vec<Example>> {
348        let query_text = self.example_to_text(input_variables, self.input_keys.as_deref());
349
350        // Perform async similarity search
351        let results = self
352            .vectorstore
353            .similarity_search(&query_text, self.k, self.vectorstore_kwargs.clone())
354            .await?;
355
356        Ok(self.search_results_to_examples(results))
357    }
358}
359
360/// Select examples based on Max Marginal Relevance (MMR).
361///
362/// MMR balances relevance and diversity in example selection, often
363/// leading to better performance than simple similarity search.
364///
365/// Note: This is a placeholder implementation. MMR requires additional
366/// methods in the VectorStore trait that are not yet implemented.
367pub struct MaxMarginalRelevanceExampleSelector {
368    /// Vector store containing the examples
369    pub vectorstore: Box<dyn VectorStore>,
370    /// Number of examples to select
371    pub k: usize,
372    /// Number of examples to fetch for reranking
373    pub fetch_k: usize,
374    /// Optional keys to filter examples to
375    pub example_keys: Option<Vec<String>>,
376    /// Optional keys to filter input to
377    pub input_keys: Option<Vec<String>>,
378    /// Extra arguments passed to similarity search
379    pub vectorstore_kwargs: Option<HashMap<String, serde_json::Value>>,
380}
381
382impl MaxMarginalRelevanceExampleSelector {
383    /// Create a new MMR example selector.
384    ///
385    /// # Arguments
386    ///
387    /// * `vectorstore` - Vector store containing the examples
388    /// * `k` - Number of examples to select
389    /// * `fetch_k` - Number of examples to fetch for reranking
390    /// * `example_keys` - Optional keys to filter examples to
391    /// * `input_keys` - Optional keys to filter input to
392    /// * `vectorstore_kwargs` - Extra arguments for similarity search
393    pub fn new(
394        vectorstore: Box<dyn VectorStore>,
395        k: usize,
396        fetch_k: usize,
397        example_keys: Option<Vec<String>>,
398        input_keys: Option<Vec<String>>,
399        vectorstore_kwargs: Option<HashMap<String, serde_json::Value>>,
400    ) -> Self {
401        Self {
402            vectorstore,
403            k,
404            fetch_k,
405            example_keys,
406            input_keys,
407            vectorstore_kwargs,
408        }
409    }
410
411    /// Convert an example to text for embedding.
412    fn example_to_text(&self, example: &Example, input_keys: Option<&[String]>) -> String {
413        let filtered_example = if let Some(keys) = input_keys {
414            example
415                .iter()
416                .filter(|(k, _)| keys.contains(k))
417                .map(|(k, v)| (k.clone(), v.clone()))
418                .collect()
419        } else {
420            example.clone()
421        };
422
423        let mut values: Vec<_> = filtered_example.values().cloned().collect();
424        values.sort();
425        values.join(" ")
426    }
427
428    /// Convert search results to examples.
429    fn search_results_to_examples(&self, results: Vec<VectorSearchResult>) -> Vec<Example> {
430        let mut examples = Vec::new();
431
432        for result in results {
433            let mut example = HashMap::new();
434
435            // Convert metadata to example format
436            for (key, value) in result.metadata {
437                if let Some(str_value) = value.as_str() {
438                    example.insert(key, str_value.to_string());
439                }
440            }
441
442            // Filter by example keys if specified
443            if let Some(ref example_keys) = self.example_keys {
444                example.retain(|k, _| example_keys.contains(k));
445            }
446
447            if !example.is_empty() {
448                examples.push(example);
449            }
450        }
451
452        examples
453    }
454}
455
456#[async_trait]
457impl BaseExampleSelector for MaxMarginalRelevanceExampleSelector {
458    fn add_example(&mut self, _example: Example) -> Result<()> {
459        // For sync version, we can't easily handle async operations
460        Err(crate::errors::FerricLinkError::generic(
461            "Sync add_example not supported for MaxMarginalRelevanceExampleSelector. Use aadd_example instead.",
462        ))
463    }
464
465    async fn aadd_example(&mut self, example: Example) -> Result<()> {
466        let text = self.example_to_text(&example, self.input_keys.as_deref());
467        // Convert example to the right metadata format
468        let metadata: HashMap<String, serde_json::Value> = example
469            .into_iter()
470            .map(|(k, v)| (k, serde_json::Value::String(v)))
471            .collect();
472
473        // Add to vector store
474        self.vectorstore
475            .add_texts(vec![text], Some(vec![metadata]), None)
476            .await?;
477
478        Ok(())
479    }
480
481    fn select_examples(&self, _input_variables: &Example) -> Result<Vec<Example>> {
482        // For sync version, we can't easily handle async operations
483        Err(crate::errors::FerricLinkError::generic(
484            "Sync select_examples not supported for MaxMarginalRelevanceExampleSelector. Use aselect_examples instead.",
485        ))
486    }
487
488    async fn aselect_examples(&self, input_variables: &Example) -> Result<Vec<Example>> {
489        let query_text = self.example_to_text(input_variables, self.input_keys.as_deref());
490
491        // For now, fall back to regular similarity search since MMR is not implemented
492        // TODO: Implement proper MMR algorithm when vector store supports it
493        let results = self
494            .vectorstore
495            .similarity_search(&query_text, self.k, self.vectorstore_kwargs.clone())
496            .await?;
497
498        Ok(self.search_results_to_examples(results))
499    }
500}
501
502/// Utility function to return values in a dictionary sorted by key.
503///
504/// # Arguments
505///
506/// * `values` - A dictionary with keys as input variables
507///   and values as their values.
508///
509/// # Returns
510///
511/// A list of values in dict sorted by key.
512pub fn sorted_values(values: &Example) -> Vec<String> {
513    let mut sorted_pairs: Vec<_> = values.iter().collect();
514    sorted_pairs.sort_by_key(|(key, _)| *key);
515    sorted_pairs
516        .into_iter()
517        .map(|(_, value)| value.clone())
518        .collect()
519}
520
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a single-entry example keyed by `"input"`.
    fn example(input: &str) -> Example {
        HashMap::from([("input".to_string(), input.to_string())])
    }

    /// Three examples of 3, 5 and 3 words (11, 31 and 23 ASCII characters).
    fn create_test_examples() -> Vec<Example> {
        vec![
            example("What is AI?"),
            example("How does machine learning work?"),
            example("Explain neural networks"),
        ]
    }

    #[test]
    fn test_length_based_selector_basic() {
        // Budget 10 - 4 input words = 6: the first example (3) fits, the second (5) does not.
        let selector = LengthBasedExampleSelector::with_word_count(create_test_examples(), 10);
        let selected = selector
            .select_examples(&example("Tell me about AI"))
            .unwrap();
        assert_eq!(selected, vec![example("What is AI?")]);
    }

    #[test]
    fn test_length_based_selector_add_example() {
        let mut selector = LengthBasedExampleSelector::with_word_count(create_test_examples(), 20);
        selector
            .add_example(example("What is deep learning?"))
            .unwrap();
        assert_eq!(selector.len(), 4);
        // 3 + 5 + 3 + 4 words.
        assert_eq!(selector.total_length(), 15);
    }

    #[test]
    fn test_length_based_selector_max_length() {
        // Budget 5 - 2 input words = 3: exactly the first example fits.
        let selector = LengthBasedExampleSelector::with_word_count(create_test_examples(), 5);
        let selected = selector.select_examples(&example("AI question")).unwrap();
        assert_eq!(selected, vec![example("What is AI?")]);

        // An input longer than the whole budget leaves no room for any example.
        let tight = LengthBasedExampleSelector::with_word_count(create_test_examples(), 1);
        assert!(tight
            .select_examples(&example("AI question"))
            .unwrap()
            .is_empty());
    }

    #[test]
    fn test_length_based_selector_char_count() {
        let selector = LengthBasedExampleSelector::with_char_count(create_test_examples(), 100);
        // 11 + 31 + 23 ASCII characters.
        assert_eq!(selector.total_length(), 65);
    }

    #[test]
    fn test_sorted_values() {
        let mut example = HashMap::new();
        example.insert("z".to_string(), "last".to_string());
        example.insert("a".to_string(), "first".to_string());
        example.insert("m".to_string(), "middle".to_string());

        let sorted = sorted_values(&example);
        assert_eq!(sorted, vec!["first", "middle", "last"]);
    }

    #[test]
    fn test_length_based_selector_empty() {
        let selector = LengthBasedExampleSelector::with_word_count(vec![], 10);
        assert!(selector.is_empty());
        assert_eq!(selector.total_length(), 0);

        let selected = selector.select_examples(&example("test")).unwrap();
        assert!(selected.is_empty());
    }

    #[tokio::test]
    async fn test_length_based_selector_async() {
        // Budget 15 - 2 input words = 13 fits all three examples (3 + 5 + 3 = 11).
        let selector = LengthBasedExampleSelector::with_word_count(create_test_examples(), 15);
        let selected = selector
            .aselect_examples(&example("AI question"))
            .await
            .unwrap();
        assert_eq!(selected, create_test_examples());
    }
}