// ferriclink_core/language_models.rs

//! Language model abstractions for FerricLink Core
//!
//! This module provides the core abstractions for language models, including
//! base traits for LLMs and chat models.

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::pin::Pin;

use crate::errors::Result;
use crate::impl_serializable;
use crate::messages::AnyMessage;
use crate::runnables::RunnableConfig;

/// Configuration for language model generation
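///
/// # Examples
///
/// A minimal builder sketch; the `ferriclink_core::language_models` import
/// path is assumed here rather than confirmed:
///
/// ```
/// use ferriclink_core::language_models::GenerationConfig;
///
/// let config = GenerationConfig::new()
///     .with_temperature(0.2)
///     .with_max_tokens(256)
///     .with_stop("END")
///     .with_streaming(true);
///
/// assert_eq!(config.temperature, Some(0.2));
/// assert_eq!(config.max_tokens, Some(256));
/// assert!(config.stream);
/// ```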
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GenerationConfig {
    /// Temperature for generation (0.0 to 1.0)
    #[serde(default)]
    pub temperature: Option<f32>,
    /// Maximum number of tokens to generate
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// Stop sequences
    #[serde(default)]
    pub stop: Vec<String>,
    /// Top-p sampling parameter
    #[serde(default)]
    pub top_p: Option<f32>,
    /// Top-k sampling parameter
    #[serde(default)]
    pub top_k: Option<u32>,
    /// Presence penalty
    #[serde(default)]
    pub presence_penalty: Option<f32>,
    /// Frequency penalty
    #[serde(default)]
    pub frequency_penalty: Option<f32>,
    /// Whether to stream the response
    #[serde(default)]
    pub stream: bool,
    /// Additional parameters
    #[serde(default)]
    pub extra: HashMap<String, serde_json::Value>,
}

impl Default for GenerationConfig {
    fn default() -> Self {
        Self {
            temperature: Some(0.7),
            max_tokens: None,
            stop: Vec::new(),
            top_p: None,
            top_k: None,
            presence_penalty: None,
            frequency_penalty: None,
            stream: false,
            extra: HashMap::new(),
        }
    }
}

impl GenerationConfig {
    /// Create a new generation config
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the temperature
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Set the maximum tokens
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Add a stop sequence
    pub fn with_stop(mut self, stop: impl Into<String>) -> Self {
        self.stop.push(stop.into());
        self
    }

    /// Enable streaming
    pub fn with_streaming(mut self, stream: bool) -> Self {
        self.stream = stream;
        self
    }

    /// Add an extra parameter
    pub fn with_extra(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.extra.insert(key.into(), value);
        self
    }
}

impl_serializable!(
    GenerationConfig,
    ["ferriclink", "language_models", "generation_config"]
);

/// A generation result from a language model
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Generation {
    /// The generated text
    pub text: String,
    /// Generation metadata
    #[serde(default)]
    pub generation_info: HashMap<String, serde_json::Value>,
}

impl Generation {
    /// Create a new generation
    pub fn new(text: impl Into<String>) -> Self {
        Self {
            text: text.into(),
            generation_info: HashMap::new(),
        }
    }

    /// Create a new generation with metadata
    pub fn new_with_info(
        text: impl Into<String>,
        generation_info: HashMap<String, serde_json::Value>,
    ) -> Self {
        Self {
            text: text.into(),
            generation_info,
        }
    }
}

impl_serializable!(Generation, ["ferriclink", "language_models", "generation"]);

/// A result containing multiple generations
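///
/// # Examples
///
/// A small sketch mirroring the unit tests below (import path assumed):
///
/// ```
/// use ferriclink_core::language_models::{Generation, LLMResult};
///
/// let result = LLMResult::new(vec![
///     vec![Generation::new("Hello")],
///     vec![Generation::new("World")],
/// ]);
///
/// assert_eq!(result.first_text(), Some("Hello"));
/// assert_eq!(result.all_texts(), vec!["Hello", "World"]);
/// ```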
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct LLMResult {
    /// The generations
    pub generations: Vec<Vec<Generation>>,
    /// Result metadata
    #[serde(default)]
    pub llm_output: HashMap<String, serde_json::Value>,
}

impl LLMResult {
    /// Create a new LLM result
    pub fn new(generations: Vec<Vec<Generation>>) -> Self {
        Self {
            generations,
            llm_output: HashMap::new(),
        }
    }

    /// Create a new LLM result with metadata
    pub fn new_with_output(
        generations: Vec<Vec<Generation>>,
        llm_output: HashMap<String, serde_json::Value>,
    ) -> Self {
        Self {
            generations,
            llm_output,
        }
    }

    /// Get the first generation text
    pub fn first_text(&self) -> Option<&str> {
        self.generations.first()?.first().map(|g| g.text.as_str())
    }

    /// Get all generation texts
    pub fn all_texts(&self) -> Vec<&str> {
        self.generations
            .iter()
            .flat_map(|gens| gens.iter().map(|g| g.text.as_str()))
            .collect()
    }
}

impl_serializable!(LLMResult, ["ferriclink", "language_models", "llm_result"]);

/// Base trait for all language models
#[async_trait]
pub trait BaseLanguageModel: Send + Sync + 'static {
    /// Get the model name
    fn model_name(&self) -> &str;

    /// Get the model type
    fn model_type(&self) -> &str;

    /// Check if the model supports streaming
    fn supports_streaming(&self) -> bool {
        false
    }

    /// Get the input schema for this model
    fn input_schema(&self) -> Option<serde_json::Value> {
        None
    }

    /// Get the output schema for this model
    fn output_schema(&self) -> Option<serde_json::Value> {
        None
    }
}

/// Trait for language models that generate text from text input
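///
/// Only `generate` must be implemented; `generate_batch` and `stream_generate`
/// have default implementations built on top of it.
///
/// # Examples
///
/// A minimal implementor sketch. The `UppercaseLLM` type is illustrative, the
/// `ferriclink_core` import paths are assumed rather than confirmed, and
/// `futures::executor::block_on` assumes the `futures` crate's default
/// `executor` feature:
///
/// ```
/// use async_trait::async_trait;
/// use ferriclink_core::errors::Result;
/// use ferriclink_core::language_models::{
///     BaseLanguageModel, BaseLLM, Generation, GenerationConfig, LLMResult,
/// };
/// use ferriclink_core::runnables::RunnableConfig;
///
/// struct UppercaseLLM;
///
/// #[async_trait]
/// impl BaseLanguageModel for UppercaseLLM {
///     fn model_name(&self) -> &str {
///         "uppercase"
///     }
///
///     fn model_type(&self) -> &str {
///         "uppercase_llm"
///     }
/// }
///
/// #[async_trait]
/// impl BaseLLM for UppercaseLLM {
///     async fn generate(
///         &self,
///         prompt: &str,
///         _config: Option<GenerationConfig>,
///         _runnable_config: Option<RunnableConfig>,
///     ) -> Result<LLMResult> {
///         // Echo the prompt back in upper case as a stand-in for a real model call.
///         Ok(LLMResult::new(vec![vec![Generation::new(prompt.to_uppercase())]]))
///     }
/// }
///
/// let result = futures::executor::block_on(
///     UppercaseLLM.generate("hello", None, None),
/// )
/// .unwrap();
/// assert_eq!(result.first_text(), Some("HELLO"));
/// ```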
#[async_trait]
pub trait BaseLLM: BaseLanguageModel {
    /// Generate text from a prompt
    async fn generate(
        &self,
        prompt: &str,
        config: Option<GenerationConfig>,
        runnable_config: Option<RunnableConfig>,
    ) -> Result<LLMResult>;

    /// Generate text from multiple prompts
    async fn generate_batch(
        &self,
        prompts: Vec<String>,
        config: Option<GenerationConfig>,
        runnable_config: Option<RunnableConfig>,
    ) -> Result<Vec<LLMResult>> {
        let mut results = Vec::new();
        for prompt in prompts {
            let result = self
                .generate(&prompt, config.clone(), runnable_config.clone())
                .await?;
            results.push(result);
        }
        Ok(results)
    }

    /// Stream text generation
    async fn stream_generate(
        &self,
        prompt: &str,
        config: Option<GenerationConfig>,
        runnable_config: Option<RunnableConfig>,
    ) -> Result<Pin<Box<dyn futures::Stream<Item = Result<Generation>> + Send>>> {
        // Default implementation just yields the single result
        let result = self.generate(prompt, config, runnable_config).await?;
        let generation = result
            .generations
            .into_iter()
            .next()
            .and_then(|gens| gens.into_iter().next())
            .unwrap_or_else(|| Generation::new(""));
        let stream = futures::stream::once(async { Ok(generation) });
        Ok(Box::pin(stream))
    }
}

/// Trait for language models that work with chat messages
#[async_trait]
pub trait BaseChatModel: BaseLanguageModel {
    /// Generate a response from chat messages
    async fn generate_chat(
        &self,
        messages: Vec<AnyMessage>,
        config: Option<GenerationConfig>,
        runnable_config: Option<RunnableConfig>,
    ) -> Result<AnyMessage>;

    /// Generate responses from multiple chat conversations
    async fn generate_chat_batch(
        &self,
        conversations: Vec<Vec<AnyMessage>>,
        config: Option<GenerationConfig>,
        runnable_config: Option<RunnableConfig>,
    ) -> Result<Vec<AnyMessage>> {
        let mut results = Vec::new();
        for messages in conversations {
            let result = self
                .generate_chat(messages, config.clone(), runnable_config.clone())
                .await?;
            results.push(result);
        }
        Ok(results)
    }

    /// Stream chat generation
    async fn stream_chat(
        &self,
        messages: Vec<AnyMessage>,
        config: Option<GenerationConfig>,
        runnable_config: Option<RunnableConfig>,
    ) -> Result<Pin<Box<dyn futures::Stream<Item = Result<AnyMessage>> + Send>>> {
        // Default implementation just yields the single result
        let result = self
            .generate_chat(messages, config, runnable_config)
            .await?;
        let stream = futures::stream::once(async { Ok(result) });
        Ok(Box::pin(stream))
    }
}

/// A simple mock LLM for testing
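///
/// Responses are returned in the order they were added and cycle once
/// exhausted; if no responses are configured, `"Mock response"` is returned.
///
/// # Examples
///
/// A usage sketch (import path assumed, and `futures::executor::block_on`
/// assumes the `futures` crate's default `executor` feature):
///
/// ```
/// use ferriclink_core::language_models::{BaseLanguageModel, BaseLLM, MockLLM};
///
/// let llm = MockLLM::new("test-model")
///     .add_response("Response 1")
///     .add_response("Response 2");
///
/// assert_eq!(llm.model_name(), "test-model");
///
/// let first = futures::executor::block_on(llm.generate("prompt", None, None)).unwrap();
/// assert_eq!(first.first_text(), Some("Response 1"));
/// ```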
pub struct MockLLM {
    model_name: String,
    responses: Vec<String>,
    current_index: std::sync::atomic::AtomicUsize,
}

impl MockLLM {
    /// Create a new mock LLM
    pub fn new(model_name: impl Into<String>) -> Self {
        Self {
            model_name: model_name.into(),
            responses: Vec::new(),
            current_index: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Add a response to the mock
    pub fn add_response(mut self, response: impl Into<String>) -> Self {
        self.responses.push(response.into());
        self
    }

    /// Get the next response
    fn get_next_response(&self) -> String {
        if self.responses.is_empty() {
            "Mock response".to_string()
        } else {
            let index = self
                .current_index
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            self.responses[index % self.responses.len()].clone()
        }
    }
}

#[async_trait]
impl BaseLanguageModel for MockLLM {
    fn model_name(&self) -> &str {
        &self.model_name
    }

    fn model_type(&self) -> &str {
        "mock_llm"
    }

    fn supports_streaming(&self) -> bool {
        true
    }
}

#[async_trait]
impl BaseLLM for MockLLM {
    async fn generate(
        &self,
        _prompt: &str,
        _config: Option<GenerationConfig>,
        _runnable_config: Option<RunnableConfig>,
    ) -> Result<LLMResult> {
        let response = self.get_next_response();
        let generation = Generation::new(response);
        Ok(LLMResult::new(vec![vec![generation]]))
    }
}

/// A simple mock chat model for testing
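///
/// Behaves like [`MockLLM`] but replies with AI chat messages; if no responses
/// are configured, `"Mock chat response"` is returned.
///
/// # Examples
///
/// A usage sketch (import paths assumed, and `futures::executor::block_on`
/// assumes the `futures` crate's default `executor` feature):
///
/// ```
/// use ferriclink_core::language_models::{BaseChatModel, MockChatModel};
/// use ferriclink_core::messages::{AnyMessage, BaseMessage};
///
/// let chat_model = MockChatModel::new("test-chat-model").add_response("Hi there!");
///
/// let reply = futures::executor::block_on(
///     chat_model.generate_chat(vec![AnyMessage::human("Hello")], None, None),
/// )
/// .unwrap();
/// assert!(reply.is_ai());
/// ```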
pub struct MockChatModel {
    model_name: String,
    responses: Vec<String>,
    current_index: std::sync::atomic::AtomicUsize,
}

impl MockChatModel {
    /// Create a new mock chat model
    pub fn new(model_name: impl Into<String>) -> Self {
        Self {
            model_name: model_name.into(),
            responses: Vec::new(),
            current_index: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Add a response to the mock
    pub fn add_response(mut self, response: impl Into<String>) -> Self {
        self.responses.push(response.into());
        self
    }

    /// Get the next response
    fn get_next_response(&self) -> String {
        if self.responses.is_empty() {
            "Mock chat response".to_string()
        } else {
            let index = self
                .current_index
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            self.responses[index % self.responses.len()].clone()
        }
    }
}

#[async_trait]
impl BaseLanguageModel for MockChatModel {
    fn model_name(&self) -> &str {
        &self.model_name
    }

    fn model_type(&self) -> &str {
        "mock_chat_model"
    }

    fn supports_streaming(&self) -> bool {
        true
    }
}

#[async_trait]
impl BaseChatModel for MockChatModel {
    async fn generate_chat(
        &self,
        _messages: Vec<AnyMessage>,
        _config: Option<GenerationConfig>,
        _runnable_config: Option<RunnableConfig>,
    ) -> Result<AnyMessage> {
        let response = self.get_next_response();
        Ok(AnyMessage::ai(response))
    }
}

/// Helper function to create a mock LLM
pub fn mock_llm(model_name: impl Into<String>) -> MockLLM {
    MockLLM::new(model_name)
}

/// Helper function to create a mock chat model
pub fn mock_chat_model(model_name: impl Into<String>) -> MockChatModel {
    MockChatModel::new(model_name)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::BaseMessage;
    use crate::serializable::Serializable;

    #[test]
    fn test_generation_config() {
        let config = GenerationConfig::new()
            .with_temperature(0.8)
            .with_max_tokens(100)
            .with_stop("STOP")
            .with_streaming(true);

        assert_eq!(config.temperature, Some(0.8));
        assert_eq!(config.max_tokens, Some(100));
        assert!(config.stop.contains(&"STOP".to_string()));
        assert!(config.stream);
    }

    #[test]
    fn test_generation() {
        let generation = Generation::new("Hello, world!");
        assert_eq!(generation.text, "Hello, world!");
        assert!(generation.generation_info.is_empty());
    }

    #[test]
    fn test_llm_result() {
        let generations = vec![
            vec![Generation::new("Hello")],
            vec![Generation::new("World")],
        ];
        let result = LLMResult::new(generations);

        assert_eq!(result.generations.len(), 2);
        assert_eq!(result.first_text(), Some("Hello"));
        assert_eq!(result.all_texts(), vec!["Hello", "World"]);
    }

    #[tokio::test]
    async fn test_mock_llm() {
        let llm = MockLLM::new("test-model")
            .add_response("Response 1")
            .add_response("Response 2");

        assert_eq!(llm.model_name(), "test-model");
        assert_eq!(llm.model_type(), "mock_llm");
        assert!(llm.supports_streaming());

        let result = llm.generate("test prompt", None, None).await.unwrap();
        assert_eq!(result.first_text(), Some("Response 1"));

        let result2 = llm.generate("test prompt", None, None).await.unwrap();
        assert_eq!(result2.first_text(), Some("Response 2"));
    }

    #[tokio::test]
    async fn test_mock_chat_model() {
        let chat_model = MockChatModel::new("test-chat-model")
            .add_response("Chat response 1")
            .add_response("Chat response 2");

        assert_eq!(chat_model.model_name(), "test-chat-model");
        assert_eq!(chat_model.model_type(), "mock_chat_model");
        assert!(chat_model.supports_streaming());

        let messages = vec![AnyMessage::human("Hello")];
        let result = chat_model
            .generate_chat(messages, None, None)
            .await
            .unwrap();
        assert!(result.is_ai());
        assert_eq!(result.text(), "Chat response 1");
    }

    #[tokio::test]
    async fn test_llm_batch_generation() {
        let llm = MockLLM::new("test-model")
            .add_response("Response 1")
            .add_response("Response 2");

        let prompts = vec!["prompt 1".to_string(), "prompt 2".to_string()];
        let results = llm.generate_batch(prompts, None, None).await.unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].first_text(), Some("Response 1"));
        assert_eq!(results[1].first_text(), Some("Response 2"));
    }

    #[tokio::test]
    async fn test_chat_batch_generation() {
        let chat_model = MockChatModel::new("test-chat-model")
            .add_response("Chat 1")
            .add_response("Chat 2");

        let conversations = vec![
            vec![AnyMessage::human("Hello 1")],
            vec![AnyMessage::human("Hello 2")],
        ];
        let results = chat_model
            .generate_chat_batch(conversations, None, None)
            .await
            .unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].text(), "Chat 1");
        assert_eq!(results[1].text(), "Chat 2");
    }

    #[test]
    fn test_serialization() {
        let config = GenerationConfig::new().with_temperature(0.8);
        let json = config.to_json().unwrap();
        let deserialized: GenerationConfig = GenerationConfig::from_json(&json).unwrap();
        assert_eq!(config, deserialized);
    }
}