use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::pin::Pin;

use crate::errors::Result;
use crate::impl_serializable;
use crate::messages::AnyMessage;
use crate::runnables::RunnableConfig;

/// Configuration options for a single generation request.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct GenerationConfig {
    /// Sampling temperature; higher values produce more random output.
    #[serde(default)]
    pub temperature: Option<f32>,
    /// Maximum number of tokens to generate.
    #[serde(default)]
    pub max_tokens: Option<u32>,
    /// Stop sequences that end generation when encountered.
    #[serde(default)]
    pub stop: Vec<String>,
    /// Nucleus sampling probability mass.
    #[serde(default)]
    pub top_p: Option<f32>,
    /// Top-k sampling cutoff.
    #[serde(default)]
    pub top_k: Option<u32>,
    /// Penalty applied to tokens that have already appeared.
    #[serde(default)]
    pub presence_penalty: Option<f32>,
    /// Penalty proportional to a token's existing frequency.
    #[serde(default)]
    pub frequency_penalty: Option<f32>,
    /// Whether the response should be streamed.
    #[serde(default)]
    pub stream: bool,
    /// Provider-specific parameters that have no dedicated field.
    #[serde(default)]
    pub extra: HashMap<String, serde_json::Value>,
}

impl Default for GenerationConfig {
    fn default() -> Self {
        Self {
            temperature: Some(0.7),
            max_tokens: None,
            stop: Vec::new(),
            top_p: None,
            top_k: None,
            presence_penalty: None,
            frequency_penalty: None,
            stream: false,
            extra: HashMap::new(),
        }
    }
}

impl GenerationConfig {
    /// Create a config with the default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the sampling temperature.
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.temperature = Some(temperature);
        self
    }

    /// Set the maximum number of tokens to generate.
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.max_tokens = Some(max_tokens);
        self
    }

    /// Add a stop sequence.
    pub fn with_stop(mut self, stop: impl Into<String>) -> Self {
        self.stop.push(stop.into());
        self
    }

    /// Enable or disable streaming.
    pub fn with_streaming(mut self, stream: bool) -> Self {
        self.stream = stream;
        self
    }

    /// Set a provider-specific parameter.
    pub fn with_extra(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.extra.insert(key.into(), value);
        self
    }
}

impl_serializable!(
    GenerationConfig,
    ["ferriclink", "language_models", "generation_config"]
);

/// A single piece of generated text plus optional provider metadata.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Generation {
    /// The generated text.
    pub text: String,
    /// Additional information reported by the provider.
    #[serde(default)]
    pub generation_info: HashMap<String, serde_json::Value>,
}

impl Generation {
    /// Create a generation with no metadata.
    pub fn new(text: impl Into<String>) -> Self {
        Self {
            text: text.into(),
            generation_info: HashMap::new(),
        }
    }

    /// Create a generation with the given metadata.
    pub fn new_with_info(
        text: impl Into<String>,
        generation_info: HashMap<String, serde_json::Value>,
    ) -> Self {
        Self {
            text: text.into(),
            generation_info,
        }
    }
}

impl_serializable!(Generation, ["ferriclink", "language_models", "generation"]);

/// The result of an LLM call: one list of candidate generations per prompt.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct LLMResult {
    /// Candidate generations, grouped by input prompt.
    pub generations: Vec<Vec<Generation>>,
    /// Provider-level output metadata.
    #[serde(default)]
    pub llm_output: HashMap<String, serde_json::Value>,
}

impl LLMResult {
    /// Create a result with no provider metadata.
    pub fn new(generations: Vec<Vec<Generation>>) -> Self {
        Self {
            generations,
            llm_output: HashMap::new(),
        }
    }

    /// Create a result with provider metadata.
    pub fn new_with_output(
        generations: Vec<Vec<Generation>>,
        llm_output: HashMap<String, serde_json::Value>,
    ) -> Self {
        Self {
            generations,
            llm_output,
        }
    }

    /// The text of the first generation for the first prompt, if any.
    pub fn first_text(&self) -> Option<&str> {
        self.generations.first()?.first().map(|g| g.text.as_str())
    }

    /// All generated texts, flattened across prompts.
    pub fn all_texts(&self) -> Vec<&str> {
        self.generations
            .iter()
            .flat_map(|gens| gens.iter().map(|g| g.text.as_str()))
            .collect()
    }
}

impl_serializable!(LLMResult, ["ferriclink", "language_models", "llm_result"]);

/// Common behavior shared by all language models.
#[async_trait]
pub trait BaseLanguageModel: Send + Sync + 'static {
    /// The name of the underlying model.
    fn model_name(&self) -> &str;

    /// A short identifier for the kind of model.
    fn model_type(&self) -> &str;

    /// Whether this model supports streaming output. Defaults to `false`.
    fn supports_streaming(&self) -> bool {
        false
    }

    /// JSON schema describing the model's input, if available.
    fn input_schema(&self) -> Option<serde_json::Value> {
        None
    }

    /// JSON schema describing the model's output, if available.
    fn output_schema(&self) -> Option<serde_json::Value> {
        None
    }
}

/// Text-completion models that map a prompt string to generated text.
#[async_trait]
pub trait BaseLLM: BaseLanguageModel {
    /// Generate completions for a single prompt.
    async fn generate(
        &self,
        prompt: &str,
        config: Option<GenerationConfig>,
        runnable_config: Option<RunnableConfig>,
    ) -> Result<LLMResult>;

    /// Generate completions for several prompts. The default implementation
    /// calls `generate` sequentially for each prompt.
    async fn generate_batch(
        &self,
        prompts: Vec<String>,
        config: Option<GenerationConfig>,
        runnable_config: Option<RunnableConfig>,
    ) -> Result<Vec<LLMResult>> {
        let mut results = Vec::new();
        for prompt in prompts {
            let result = self
                .generate(&prompt, config.clone(), runnable_config.clone())
                .await?;
            results.push(result);
        }
        Ok(results)
    }

    /// Stream a completion. The default implementation is a fallback that
    /// yields the fully generated text as a single stream item.
    async fn stream_generate(
        &self,
        prompt: &str,
        config: Option<GenerationConfig>,
        runnable_config: Option<RunnableConfig>,
    ) -> Result<Pin<Box<dyn futures::Stream<Item = Result<Generation>> + Send>>> {
        let result = self.generate(prompt, config, runnable_config).await?;
        let generation = result
            .generations
            .into_iter()
            .next()
            .and_then(|gens| gens.into_iter().next())
            .unwrap_or_else(|| Generation::new(""));
        let stream = futures::stream::once(async { Ok(generation) });
        Ok(Box::pin(stream))
    }
}

/// Chat models that map a list of messages to a reply message.
#[async_trait]
pub trait BaseChatModel: BaseLanguageModel {
    /// Generate a reply to a single conversation.
    async fn generate_chat(
        &self,
        messages: Vec<AnyMessage>,
        config: Option<GenerationConfig>,
        runnable_config: Option<RunnableConfig>,
    ) -> Result<AnyMessage>;

    /// Generate replies to several conversations. The default implementation
    /// calls `generate_chat` sequentially for each conversation.
    async fn generate_chat_batch(
        &self,
        conversations: Vec<Vec<AnyMessage>>,
        config: Option<GenerationConfig>,
        runnable_config: Option<RunnableConfig>,
    ) -> Result<Vec<AnyMessage>> {
        let mut results = Vec::new();
        for messages in conversations {
            let result = self
                .generate_chat(messages, config.clone(), runnable_config.clone())
                .await?;
            results.push(result);
        }
        Ok(results)
    }

    /// Stream a reply. The default implementation is a fallback that yields
    /// the complete reply as a single stream item.
    async fn stream_chat(
        &self,
        messages: Vec<AnyMessage>,
        config: Option<GenerationConfig>,
        runnable_config: Option<RunnableConfig>,
    ) -> Result<Pin<Box<dyn futures::Stream<Item = Result<AnyMessage>> + Send>>> {
        let result = self
            .generate_chat(messages, config, runnable_config)
            .await?;
        let stream = futures::stream::once(async { Ok(result) });
        Ok(Box::pin(stream))
    }
}

/// A mock LLM for tests that cycles through a fixed list of responses.
pub struct MockLLM {
    model_name: String,
    responses: Vec<String>,
    current_index: std::sync::atomic::AtomicUsize,
}

impl MockLLM {
    /// Create a mock LLM with no canned responses.
    pub fn new(model_name: impl Into<String>) -> Self {
        Self {
            model_name: model_name.into(),
            responses: Vec::new(),
            current_index: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Add a canned response to the rotation (builder style).
    pub fn add_response(mut self, response: impl Into<String>) -> Self {
        self.responses.push(response.into());
        self
    }

    /// Return the next canned response, cycling through the list.
    fn get_next_response(&self) -> String {
        if self.responses.is_empty() {
            "Mock response".to_string()
        } else {
            let index = self
                .current_index
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            self.responses[index % self.responses.len()].clone()
        }
    }
}

#[async_trait]
impl BaseLanguageModel for MockLLM {
    fn model_name(&self) -> &str {
        &self.model_name
    }

    fn model_type(&self) -> &str {
        "mock_llm"
    }

    fn supports_streaming(&self) -> bool {
        true
    }
}

#[async_trait]
impl BaseLLM for MockLLM {
    async fn generate(
        &self,
        _prompt: &str,
        _config: Option<GenerationConfig>,
        _runnable_config: Option<RunnableConfig>,
    ) -> Result<LLMResult> {
        let response = self.get_next_response();
        let generation = Generation::new(response);
        Ok(LLMResult::new(vec![vec![generation]]))
    }
}

/// A mock chat model for tests that cycles through a fixed list of replies.
pub struct MockChatModel {
    model_name: String,
    responses: Vec<String>,
    current_index: std::sync::atomic::AtomicUsize,
}

impl MockChatModel {
    /// Create a mock chat model with no canned replies.
    pub fn new(model_name: impl Into<String>) -> Self {
        Self {
            model_name: model_name.into(),
            responses: Vec::new(),
            current_index: std::sync::atomic::AtomicUsize::new(0),
        }
    }

    /// Add a canned reply to the rotation (builder style).
    pub fn add_response(mut self, response: impl Into<String>) -> Self {
        self.responses.push(response.into());
        self
    }

    /// Return the next canned reply, cycling through the list.
    fn get_next_response(&self) -> String {
        if self.responses.is_empty() {
            "Mock chat response".to_string()
        } else {
            let index = self
                .current_index
                .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
            self.responses[index % self.responses.len()].clone()
        }
    }
}

#[async_trait]
impl BaseLanguageModel for MockChatModel {
    fn model_name(&self) -> &str {
        &self.model_name
    }

    fn model_type(&self) -> &str {
        "mock_chat_model"
    }

    fn supports_streaming(&self) -> bool {
        true
    }
}

#[async_trait]
impl BaseChatModel for MockChatModel {
    async fn generate_chat(
        &self,
        _messages: Vec<AnyMessage>,
        _config: Option<GenerationConfig>,
        _runnable_config: Option<RunnableConfig>,
    ) -> Result<AnyMessage> {
        let response = self.get_next_response();
        Ok(AnyMessage::ai(response))
    }
}

/// Convenience constructor for a [`MockLLM`].
pub fn mock_llm(model_name: impl Into<String>) -> MockLLM {
    MockLLM::new(model_name)
}

/// Convenience constructor for a [`MockChatModel`].
pub fn mock_chat_model(model_name: impl Into<String>) -> MockChatModel {
    MockChatModel::new(model_name)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::BaseMessage;
    use crate::serializable::Serializable;

    #[test]
    fn test_generation_config() {
        let config = GenerationConfig::new()
            .with_temperature(0.8)
            .with_max_tokens(100)
            .with_stop("STOP")
            .with_streaming(true);

        assert_eq!(config.temperature, Some(0.8));
        assert_eq!(config.max_tokens, Some(100));
        assert!(config.stop.contains(&"STOP".to_string()));
        assert!(config.stream);
    }
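
    // The defaults are part of the public contract, so pin them down; a
    // minimal check grounded in the `Default` impl above.
    #[test]
    fn test_generation_config_defaults() {
        let config = GenerationConfig::default();
        assert_eq!(config.temperature, Some(0.7));
        assert_eq!(config.max_tokens, None);
        assert!(config.stop.is_empty());
        assert!(!config.stream);
        assert!(config.extra.is_empty());
    }

    // `with_extra` carries provider-specific parameters as raw JSON; the
    // "seed" key here is purely illustrative, not a recognized parameter.
    #[test]
    fn test_generation_config_extra() {
        let config = GenerationConfig::new().with_extra("seed", serde_json::json!(42));
        assert_eq!(config.extra.get("seed"), Some(&serde_json::json!(42)));
    }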

    #[test]
    fn test_generation() {
        let generation = Generation::new("Hello, world!");
        assert_eq!(generation.text, "Hello, world!");
        assert!(generation.generation_info.is_empty());
    }
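
    // `new_with_info` attaches arbitrary metadata to a generation; the
    // "finish_reason" key is a hypothetical example, not a required field.
    #[test]
    fn test_generation_with_info() {
        let mut info = HashMap::new();
        info.insert("finish_reason".to_string(), serde_json::json!("stop"));
        let generation = Generation::new_with_info("Hello", info);
        assert_eq!(generation.text, "Hello");
        assert_eq!(
            generation.generation_info.get("finish_reason"),
            Some(&serde_json::json!("stop"))
        );
    }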

    #[test]
    fn test_llm_result() {
        let generations = vec![
            vec![Generation::new("Hello")],
            vec![Generation::new("World")],
        ];
        let result = LLMResult::new(generations);

        assert_eq!(result.generations.len(), 2);
        assert_eq!(result.first_text(), Some("Hello"));
        assert_eq!(result.all_texts(), vec!["Hello", "World"]);
    }
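
    // Edge cases for the accessors: an empty result has no first text, and
    // `new_with_output` preserves provider metadata ("model" is illustrative).
    #[test]
    fn test_llm_result_empty_and_output() {
        let empty = LLMResult::new(Vec::new());
        assert_eq!(empty.first_text(), None);
        assert!(empty.all_texts().is_empty());

        let mut output = HashMap::new();
        output.insert("model".to_string(), serde_json::json!("test-model"));
        let result = LLMResult::new_with_output(vec![vec![Generation::new("Hi")]], output);
        assert_eq!(result.first_text(), Some("Hi"));
        assert_eq!(
            result.llm_output.get("model"),
            Some(&serde_json::json!("test-model"))
        );
    }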

    #[tokio::test]
    async fn test_mock_llm() {
        let llm = MockLLM::new("test-model")
            .add_response("Response 1")
            .add_response("Response 2");

        assert_eq!(llm.model_name(), "test-model");
        assert_eq!(llm.model_type(), "mock_llm");
        assert!(llm.supports_streaming());

        let result = llm.generate("test prompt", None, None).await.unwrap();
        assert_eq!(result.first_text(), Some("Response 1"));

        let result2 = llm.generate("test prompt", None, None).await.unwrap();
        assert_eq!(result2.first_text(), Some("Response 2"));
    }
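
    // The default `stream_generate` yields the whole completion as a single
    // stream item; this uses `StreamExt::next`, assuming the already-imported
    // `futures` dependency is built with its default features.
    #[tokio::test]
    async fn test_stream_generate_default_impl() {
        use futures::StreamExt;

        let llm = MockLLM::new("test-model").add_response("Streamed");
        let mut stream = llm.stream_generate("prompt", None, None).await.unwrap();
        let first = stream.next().await.unwrap().unwrap();
        assert_eq!(first.text, "Streamed");
        assert!(stream.next().await.is_none());
    }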

    #[tokio::test]
    async fn test_mock_chat_model() {
        let chat_model = MockChatModel::new("test-chat-model")
            .add_response("Chat response 1")
            .add_response("Chat response 2");

        assert_eq!(chat_model.model_name(), "test-chat-model");
        assert_eq!(chat_model.model_type(), "mock_chat_model");
        assert!(chat_model.supports_streaming());

        let messages = vec![AnyMessage::human("Hello")];
        let result = chat_model
            .generate_chat(messages, None, None)
            .await
            .unwrap();
        assert!(result.is_ai());
        assert_eq!(result.text(), "Chat response 1");
    }
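
    // Mirror of the streaming test for chat models: the default `stream_chat`
    // wraps the complete reply in a one-item stream.
    #[tokio::test]
    async fn test_stream_chat_default_impl() {
        use futures::StreamExt;

        let chat_model = MockChatModel::new("test-chat-model").add_response("Streamed chat");
        let messages = vec![AnyMessage::human("Hello")];
        let mut stream = chat_model.stream_chat(messages, None, None).await.unwrap();
        let first = stream.next().await.unwrap().unwrap();
        assert_eq!(first.text(), "Streamed chat");
        assert!(stream.next().await.is_none());
    }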

    #[tokio::test]
    async fn test_llm_batch_generation() {
        let llm = MockLLM::new("test-model")
            .add_response("Response 1")
            .add_response("Response 2");

        let prompts = vec!["prompt 1".to_string(), "prompt 2".to_string()];
        let results = llm.generate_batch(prompts, None, None).await.unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].first_text(), Some("Response 1"));
        assert_eq!(results[1].first_text(), Some("Response 2"));
    }
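
    // `get_next_response` cycles modulo the response count and falls back to
    // a canned string when no responses were configured.
    #[tokio::test]
    async fn test_mock_llm_cycling_and_fallback() {
        let llm = MockLLM::new("test-model")
            .add_response("Response 1")
            .add_response("Response 2");
        llm.generate("p", None, None).await.unwrap();
        llm.generate("p", None, None).await.unwrap();
        let third = llm.generate("p", None, None).await.unwrap();
        assert_eq!(third.first_text(), Some("Response 1"));

        let empty = MockLLM::new("empty-model");
        let result = empty.generate("p", None, None).await.unwrap();
        assert_eq!(result.first_text(), Some("Mock response"));
    }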

    #[tokio::test]
    async fn test_chat_batch_generation() {
        let chat_model = MockChatModel::new("test-chat-model")
            .add_response("Chat 1")
            .add_response("Chat 2");

        let conversations = vec![
            vec![AnyMessage::human("Hello 1")],
            vec![AnyMessage::human("Hello 2")],
        ];
        let results = chat_model
            .generate_chat_batch(conversations, None, None)
            .await
            .unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].text(), "Chat 1");
        assert_eq!(results[1].text(), "Chat 2");
    }
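
    // The free functions at the bottom of the module are thin wrappers
    // around the mocks' constructors.
    #[test]
    fn test_convenience_constructors() {
        let llm = mock_llm("helper-llm");
        assert_eq!(llm.model_name(), "helper-llm");

        let chat_model = mock_chat_model("helper-chat");
        assert_eq!(chat_model.model_name(), "helper-chat");
    }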

    #[test]
    fn test_serialization() {
        let config = GenerationConfig::new().with_temperature(0.8);
        let json = config.to_json().unwrap();
        let deserialized: GenerationConfig = GenerationConfig::from_json(&json).unwrap();
        assert_eq!(config, deserialized);
    }
}