1use async_trait::async_trait;
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16
17use crate::errors::Result;
18use crate::impl_serializable;
19use crate::vectorstores::{VectorSearchResult, VectorStore};
20
/// A single example (or a set of input variables): string keys mapped to string values.
pub type Example = HashMap<String, String>;
23
24#[async_trait]
30pub trait BaseExampleSelector: Send + Sync {
31 fn add_example(&mut self, example: Example) -> Result<()>;
42
43 async fn aadd_example(&mut self, example: Example) -> Result<()> {
54 self.add_example(example)
55 }
56
57 fn select_examples(&self, input_variables: &Example) -> Result<Vec<Example>>;
68
69 async fn aselect_examples(&self, input_variables: &Example) -> Result<Vec<Example>> {
80 self.select_examples(input_variables)
81 }
82}
83
/// Selects examples greedily, in order, until a length budget is used up.
///
/// The budget (`max_length`) is shared between the input text and the selected examples.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LengthBasedExampleSelector {
    /// Candidate examples, considered in insertion order.
    pub examples: Vec<Example>,
    /// Length metric applied to an example's text.
    ///
    /// Function pointers cannot be serialized, so serde skips this field and restores it
    /// to `default_text_length` (whitespace word count) on deserialization.
    #[serde(skip, default = "default_text_length")]
    pub get_text_length: fn(&str) -> usize,
    /// Maximum combined length of the input plus the selected examples.
    pub max_length: usize,
    /// Cached `get_text_length` result for each entry of `examples`, index-aligned.
    ///
    /// Serde skips this field, so it is empty after deserialization. It is populated by
    /// `new` (and the `with_*` wrappers) and extended by `add_example`.
    #[serde(skip)]
    pub example_text_lengths: Vec<usize>,
}
101
102fn default_text_length() -> fn(&str) -> usize {
104 LengthBasedExampleSelector::default_text_length
105}
106
107impl LengthBasedExampleSelector {
108 pub fn new(
116 examples: Vec<Example>,
117 max_length: usize,
118 get_text_length: Option<fn(&str) -> usize>,
119 ) -> Self {
120 let get_text_length = get_text_length.unwrap_or(Self::default_text_length);
121 let mut selector = Self {
122 examples,
123 get_text_length,
124 max_length,
125 example_text_lengths: Vec::new(),
126 };
127 selector.update_lengths();
128 selector
129 }
130
131 pub fn with_word_count(examples: Vec<Example>, max_length: usize) -> Self {
133 Self::new(examples, max_length, Some(Self::default_text_length))
134 }
135
136 pub fn with_char_count(examples: Vec<Example>, max_length: usize) -> Self {
138 Self::new(examples, max_length, Some(Self::char_length))
139 }
140
141 pub fn default_text_length(text: &str) -> usize {
143 text.split_whitespace().count()
144 }
145
146 pub fn char_length(text: &str) -> usize {
148 text.len()
149 }
150
151 fn update_lengths(&mut self) {
153 self.example_text_lengths = self
154 .examples
155 .iter()
156 .map(|example| {
157 let text = self.example_to_text(example);
158 (self.get_text_length)(&text)
159 })
160 .collect();
161 }
162
163 fn example_to_text(&self, example: &Example) -> String {
165 let mut values: Vec<_> = example.values().cloned().collect();
166 values.sort();
167 values.join(" ")
168 }
169
170 pub fn total_length(&self) -> usize {
172 self.example_text_lengths.iter().sum()
173 }
174
175 pub fn len(&self) -> usize {
177 self.examples.len()
178 }
179
180 pub fn is_empty(&self) -> bool {
182 self.examples.is_empty()
183 }
184}
185
186#[async_trait]
187impl BaseExampleSelector for LengthBasedExampleSelector {
188 fn add_example(&mut self, example: Example) -> Result<()> {
189 self.examples.push(example.clone());
190 let text = self.example_to_text(&example);
191 self.example_text_lengths
192 .push((self.get_text_length)(&text));
193 Ok(())
194 }
195
196 fn select_examples(&self, input_variables: &Example) -> Result<Vec<Example>> {
197 let input_text = self.example_to_text(input_variables);
198 let input_length = (self.get_text_length)(&input_text);
199 let remaining_length = self.max_length.saturating_sub(input_length);
200
201 let mut selected = Vec::new();
202 let mut current_length = 0;
203
204 for (i, example) in self.examples.iter().enumerate() {
205 let example_length = self.example_text_lengths[i];
206 if current_length + example_length <= remaining_length {
207 selected.push(example.clone());
208 current_length += example_length;
209 } else {
210 break;
211 }
212 }
213
214 Ok(selected)
215 }
216}
217
// Registers `LengthBasedExampleSelector` with the crate's `impl_serializable!` machinery.
// NOTE(review): the string path presumably identifies this type in serialized output
// (a "ferriclink / example_selectors / length_based" namespace) — confirm against the
// macro definition.
impl_serializable!(
    LengthBasedExampleSelector,
    ["ferriclink", "example_selectors", "length_based"]
);
222
/// Selects the examples whose text is most similar to the input, via a vector store.
///
/// Only the async methods (`aadd_example`, `aselect_examples`) are supported; the sync
/// methods return an error.
pub struct SemanticSimilarityExampleSelector {
    /// Store that indexes examples (`add_texts`) and answers queries (`similarity_search`).
    pub vectorstore: Box<dyn VectorStore>,
    /// Number of results requested from `similarity_search`.
    pub k: usize,
    /// If set, returned examples are restricted to these keys.
    pub example_keys: Option<Vec<String>>,
    /// If set, only these keys' values are used to build the indexed and query text.
    pub input_keys: Option<Vec<String>>,
    /// Extra options passed through unchanged to `similarity_search`.
    pub vectorstore_kwargs: Option<HashMap<String, serde_json::Value>>,
}
239
240impl SemanticSimilarityExampleSelector {
241 pub fn new(
251 vectorstore: Box<dyn VectorStore>,
252 k: usize,
253 example_keys: Option<Vec<String>>,
254 input_keys: Option<Vec<String>>,
255 vectorstore_kwargs: Option<HashMap<String, serde_json::Value>>,
256 ) -> Self {
257 Self {
258 vectorstore,
259 k,
260 example_keys,
261 input_keys,
262 vectorstore_kwargs,
263 }
264 }
265
266 fn example_to_text(&self, example: &Example, input_keys: Option<&[String]>) -> String {
268 let filtered_example = if let Some(keys) = input_keys {
269 example
270 .iter()
271 .filter(|(k, _)| keys.contains(k))
272 .map(|(k, v)| (k.clone(), v.clone()))
273 .collect()
274 } else {
275 example.clone()
276 };
277
278 let mut values: Vec<_> = filtered_example.values().cloned().collect();
279 values.sort();
280 values.join(" ")
281 }
282
283 fn search_results_to_examples(&self, results: Vec<VectorSearchResult>) -> Vec<Example> {
285 let mut examples = Vec::new();
286
287 for result in results {
288 let mut example = HashMap::new();
289
290 for (key, value) in result.metadata {
292 if let Some(str_value) = value.as_str() {
293 example.insert(key, str_value.to_string());
294 }
295 }
296
297 if let Some(ref example_keys) = self.example_keys {
299 example.retain(|k, _| example_keys.contains(k));
300 }
301
302 if !example.is_empty() {
303 examples.push(example);
304 }
305 }
306
307 examples
308 }
309}
310
311#[async_trait]
312impl BaseExampleSelector for SemanticSimilarityExampleSelector {
313 fn add_example(&mut self, _example: Example) -> Result<()> {
314 Err(crate::errors::FerricLinkError::generic(
318 "Sync add_example not supported for SemanticSimilarityExampleSelector. Use aadd_example instead.",
319 ))
320 }
321
322 async fn aadd_example(&mut self, example: Example) -> Result<()> {
323 let text = self.example_to_text(&example, self.input_keys.as_deref());
324 let metadata: HashMap<String, serde_json::Value> = example
326 .into_iter()
327 .map(|(k, v)| (k, serde_json::Value::String(v)))
328 .collect();
329
330 self.vectorstore
332 .add_texts(vec![text], Some(vec![metadata]), None)
333 .await?;
334
335 Ok(())
336 }
337
338 fn select_examples(&self, _input_variables: &Example) -> Result<Vec<Example>> {
339 Err(crate::errors::FerricLinkError::generic(
343 "Sync select_examples not supported for SemanticSimilarityExampleSelector. Use aselect_examples instead.",
344 ))
345 }
346
347 async fn aselect_examples(&self, input_variables: &Example) -> Result<Vec<Example>> {
348 let query_text = self.example_to_text(input_variables, self.input_keys.as_deref());
349
350 let results = self
352 .vectorstore
353 .similarity_search(&query_text, self.k, self.vectorstore_kwargs.clone())
354 .await?;
355
356 Ok(self.search_results_to_examples(results))
357 }
358}
359
/// Example selector backed by a vector store, intended for max-marginal-relevance selection.
///
/// Only the async methods (`aadd_example`, `aselect_examples`) are supported; the sync
/// methods return an error.
pub struct MaxMarginalRelevanceExampleSelector {
    /// Store that indexes examples (`add_texts`) and answers queries (`similarity_search`).
    pub vectorstore: Box<dyn VectorStore>,
    /// Number of results requested from the vector store.
    pub k: usize,
    /// Presumably the size of the candidate pool for MMR re-ranking.
    /// NOTE(review): `aselect_examples` currently passes only `k` to `similarity_search`
    /// and never reads this field — confirm the intended behavior.
    pub fetch_k: usize,
    /// If set, returned examples are restricted to these keys.
    pub example_keys: Option<Vec<String>>,
    /// If set, only these keys' values are used to build the indexed and query text.
    pub input_keys: Option<Vec<String>>,
    /// Extra options passed through unchanged to the vector-store search.
    pub vectorstore_kwargs: Option<HashMap<String, serde_json::Value>>,
}
381
382impl MaxMarginalRelevanceExampleSelector {
383 pub fn new(
394 vectorstore: Box<dyn VectorStore>,
395 k: usize,
396 fetch_k: usize,
397 example_keys: Option<Vec<String>>,
398 input_keys: Option<Vec<String>>,
399 vectorstore_kwargs: Option<HashMap<String, serde_json::Value>>,
400 ) -> Self {
401 Self {
402 vectorstore,
403 k,
404 fetch_k,
405 example_keys,
406 input_keys,
407 vectorstore_kwargs,
408 }
409 }
410
411 fn example_to_text(&self, example: &Example, input_keys: Option<&[String]>) -> String {
413 let filtered_example = if let Some(keys) = input_keys {
414 example
415 .iter()
416 .filter(|(k, _)| keys.contains(k))
417 .map(|(k, v)| (k.clone(), v.clone()))
418 .collect()
419 } else {
420 example.clone()
421 };
422
423 let mut values: Vec<_> = filtered_example.values().cloned().collect();
424 values.sort();
425 values.join(" ")
426 }
427
428 fn search_results_to_examples(&self, results: Vec<VectorSearchResult>) -> Vec<Example> {
430 let mut examples = Vec::new();
431
432 for result in results {
433 let mut example = HashMap::new();
434
435 for (key, value) in result.metadata {
437 if let Some(str_value) = value.as_str() {
438 example.insert(key, str_value.to_string());
439 }
440 }
441
442 if let Some(ref example_keys) = self.example_keys {
444 example.retain(|k, _| example_keys.contains(k));
445 }
446
447 if !example.is_empty() {
448 examples.push(example);
449 }
450 }
451
452 examples
453 }
454}
455
456#[async_trait]
457impl BaseExampleSelector for MaxMarginalRelevanceExampleSelector {
458 fn add_example(&mut self, _example: Example) -> Result<()> {
459 Err(crate::errors::FerricLinkError::generic(
461 "Sync add_example not supported for MaxMarginalRelevanceExampleSelector. Use aadd_example instead.",
462 ))
463 }
464
465 async fn aadd_example(&mut self, example: Example) -> Result<()> {
466 let text = self.example_to_text(&example, self.input_keys.as_deref());
467 let metadata: HashMap<String, serde_json::Value> = example
469 .into_iter()
470 .map(|(k, v)| (k, serde_json::Value::String(v)))
471 .collect();
472
473 self.vectorstore
475 .add_texts(vec![text], Some(vec![metadata]), None)
476 .await?;
477
478 Ok(())
479 }
480
481 fn select_examples(&self, _input_variables: &Example) -> Result<Vec<Example>> {
482 Err(crate::errors::FerricLinkError::generic(
484 "Sync select_examples not supported for MaxMarginalRelevanceExampleSelector. Use aselect_examples instead.",
485 ))
486 }
487
488 async fn aselect_examples(&self, input_variables: &Example) -> Result<Vec<Example>> {
489 let query_text = self.example_to_text(input_variables, self.input_keys.as_deref());
490
491 let results = self
494 .vectorstore
495 .similarity_search(&query_text, self.k, self.vectorstore_kwargs.clone())
496 .await?;
497
498 Ok(self.search_results_to_examples(results))
499 }
500}
501
502pub fn sorted_values(values: &Example) -> Vec<String> {
513 let mut sorted_pairs: Vec<_> = values.iter().collect();
514 sorted_pairs.sort_by_key(|(key, _)| *key);
515 sorted_pairs
516 .into_iter()
517 .map(|(_, value)| value.clone())
518 .collect()
519}
520
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an example with a single `"input"` entry.
    fn input_example(text: &str) -> Example {
        let mut example = Example::new();
        example.insert("input".to_string(), text.to_string());
        example
    }

    /// Three fixtures measuring 3, 5 and 3 words respectively.
    fn create_test_examples() -> Vec<Example> {
        [
            "What is AI?",
            "How does machine learning work?",
            "Explain neural networks",
        ]
        .iter()
        .copied()
        .map(input_example)
        .collect()
    }

    #[test]
    fn test_length_based_selector_basic() {
        let selector = LengthBasedExampleSelector::with_word_count(create_test_examples(), 10);

        let selected = selector
            .select_examples(&input_example("Tell me about AI"))
            .unwrap();

        assert!(!selected.is_empty());
        assert!(selected.len() <= 3);
    }

    #[test]
    fn test_length_based_selector_add_example() {
        let mut selector =
            LengthBasedExampleSelector::with_word_count(create_test_examples(), 20);

        selector
            .add_example(input_example("What is deep learning?"))
            .unwrap();

        assert_eq!(selector.len(), 4);
    }

    #[test]
    fn test_length_based_selector_max_length() {
        let selector = LengthBasedExampleSelector::with_word_count(create_test_examples(), 5);

        let selected = selector
            .select_examples(&input_example("AI question"))
            .unwrap();

        assert!(selected.len() <= 3);
    }

    #[test]
    fn test_sorted_values() {
        let example: Example = [("z", "last"), ("a", "first"), ("m", "middle")]
            .iter()
            .map(|(key, value)| (key.to_string(), value.to_string()))
            .collect();

        assert_eq!(sorted_values(&example), vec!["first", "middle", "last"]);
    }

    #[test]
    fn test_length_based_selector_empty() {
        let selector = LengthBasedExampleSelector::with_word_count(Vec::new(), 10);
        assert!(selector.is_empty());

        let selected = selector.select_examples(&input_example("test")).unwrap();
        assert!(selected.is_empty());
    }

    #[tokio::test]
    async fn test_length_based_selector_async() {
        let selector = LengthBasedExampleSelector::with_word_count(create_test_examples(), 15);

        let selected = selector
            .aselect_examples(&input_example("AI question"))
            .await
            .unwrap();

        assert!(!selected.is_empty());
    }
}