ferriclink_core/
runnables.rs

//! Runnable trait and related abstractions for FerricLink Core
//!
//! This module provides the core Runnable trait that powers the FerricLink ecosystem,
//! similar to LangChain's Runnable interface.
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::pin::Pin;
10use std::sync::Arc;
11
12use crate::errors::Result;
13use crate::impl_serializable;
14use crate::utils::{colors, print_colored_text};
15
16/// Configuration for running a Runnable
17#[derive(Clone, Serialize, Deserialize, Default)]
18pub struct RunnableConfig {
19    /// Tags for this run
20    #[serde(default)]
21    pub tags: Vec<String>,
22    /// Metadata for this run
23    #[serde(default)]
24    pub metadata: HashMap<String, serde_json::Value>,
25    /// Whether to run in debug mode
26    #[serde(default)]
27    pub debug: bool,
28    /// Whether to run in verbose mode
29    #[serde(default)]
30    pub verbose: bool,
31    /// Callback handlers for this run
32    #[serde(skip)]
33    pub callbacks: Vec<Arc<dyn CallbackHandler>>,
34}
35
36impl RunnableConfig {
37    /// Create a new empty configuration
38    pub fn new() -> Self {
39        Self::default()
40    }
41
42    /// Add a tag to the configuration
43    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
44        self.tags.push(tag.into());
45        self
46    }
47
48    /// Add metadata to the configuration
49    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
50        self.metadata.insert(key.into(), value);
51        self
52    }
53
54    /// Enable debug mode
55    pub fn with_debug(mut self, debug: bool) -> Self {
56        self.debug = debug;
57        self
58    }
59
60    /// Enable verbose mode
61    pub fn with_verbose(mut self, verbose: bool) -> Self {
62        self.verbose = verbose;
63        self
64    }
65
66    /// Add a callback handler
67    pub fn with_callback(mut self, callback: Arc<dyn CallbackHandler>) -> Self {
68        self.callbacks.push(callback);
69        self
70    }
71}
72
73impl PartialEq for RunnableConfig {
74    fn eq(&self, other: &Self) -> bool {
75        self.tags == other.tags
76            && self.metadata == other.metadata
77            && self.debug == other.debug
78            && self.verbose == other.verbose
79        // Skip callbacks comparison
80    }
81}
82
83impl_serializable!(RunnableConfig, ["ferriclink", "runnables", "config"]);
84
85/// Trait for callback handlers that can be used during runnable execution
86#[async_trait]
87pub trait CallbackHandler: Send + Sync {
88    /// Called when a runnable starts
89    async fn on_start(&self, run_id: &str, input: &serde_json::Value) -> Result<()> {
90        let _ = (run_id, input);
91        Ok(())
92    }
93
94    /// Called when a runnable completes successfully
95    async fn on_success(&self, run_id: &str, output: &serde_json::Value) -> Result<()> {
96        let _ = (run_id, output);
97        Ok(())
98    }
99
100    /// Called when a runnable fails
101    async fn on_error(&self, run_id: &str, error: &crate::errors::FerricLinkError) -> Result<()> {
102        let _ = (run_id, error);
103        Ok(())
104    }
105
106    /// Called when a runnable produces streaming output
107    async fn on_stream(&self, run_id: &str, chunk: &serde_json::Value) -> Result<()> {
108        let _ = (run_id, chunk);
109        Ok(())
110    }
111}
112
/// Callback handler that prints every run event to the console.
///
/// Start, success and stream messages use [`ConsoleCallbackHandler::color`];
/// error messages are always printed with `colors::RED`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConsoleCallbackHandler {
    /// Color passed to `print_colored_text` for start, success and stream
    /// messages; `None` passes no color.
    pub color: Option<String>,
}
118
119impl ConsoleCallbackHandler {
120    /// Create a new console callback handler
121    pub fn new() -> Self {
122        Self { color: None }
123    }
124
125    /// Create a new console callback handler with color
126    pub fn new_with_color(color: impl Into<String>) -> Self {
127        Self {
128            color: Some(color.into()),
129        }
130    }
131}
132
133impl Default for ConsoleCallbackHandler {
134    fn default() -> Self {
135        Self::new()
136    }
137}
138
139#[async_trait]
140impl CallbackHandler for ConsoleCallbackHandler {
141    async fn on_start(&self, run_id: &str, input: &serde_json::Value) -> Result<()> {
142        let message = format!("Starting run {run_id} with input: {input}");
143        print_colored_text(&message, self.color.as_deref());
144        Ok(())
145    }
146
147    async fn on_success(&self, run_id: &str, output: &serde_json::Value) -> Result<()> {
148        let message = format!("Run {run_id} completed with output: {output}");
149        print_colored_text(&message, self.color.as_deref());
150        Ok(())
151    }
152
153    async fn on_error(&self, run_id: &str, error: &crate::errors::FerricLinkError) -> Result<()> {
154        let message = format!("Run {run_id} failed with error: {error}");
155        print_colored_text(&message, Some(colors::RED));
156        Ok(())
157    }
158
159    async fn on_stream(&self, run_id: &str, chunk: &serde_json::Value) -> Result<()> {
160        let message = format!("Run {run_id} streamed: {chunk}");
161        print_colored_text(&message, self.color.as_deref());
162        Ok(())
163    }
164}
165
166/// The core Runnable trait that all FerricLink components implement
167#[async_trait]
168pub trait Runnable<Input, Output>: Send + Sync + 'static
169where
170    Input: Send + Sync + 'static,
171    Output: Send + Sync + 'static,
172{
173    /// Invoke the runnable with a single input
174    async fn invoke(&self, input: Input, config: Option<RunnableConfig>) -> Result<Output>;
175
176    /// Invoke the runnable with a single input (convenience method with default config)
177    async fn invoke_simple(&self, input: Input) -> Result<Output> {
178        self.invoke(input, None).await
179    }
180
181    /// Batch invoke the runnable with multiple inputs
182    async fn batch(
183        &self,
184        inputs: Vec<Input>,
185        config: Option<RunnableConfig>,
186    ) -> Result<Vec<Output>> {
187        let mut results = Vec::new();
188        for input in inputs {
189            let result = self.invoke(input, config.clone()).await?;
190            results.push(result);
191        }
192        Ok(results)
193    }
194
195    /// Stream the output of the runnable
196    async fn stream(
197        &self,
198        input: Input,
199        config: Option<RunnableConfig>,
200    ) -> Result<Pin<Box<dyn futures::Stream<Item = Result<Output>> + Send>>> {
201        // Default implementation just yields the single result
202        let result = self.invoke(input, config).await;
203        let stream = futures::stream::once(async { result });
204        Ok(Box::pin(stream))
205    }
206
207    /// Get the input schema for this runnable
208    fn input_schema(&self) -> Option<serde_json::Value> {
209        None
210    }
211
212    /// Get the output schema for this runnable
213    fn output_schema(&self) -> Option<serde_json::Value> {
214        None
215    }
216
217    /// Get the configuration schema for this runnable
218    fn config_schema(&self) -> Option<serde_json::Value> {
219        None
220    }
221}
222
/// A [`Runnable`] backed by a synchronous closure or function.
///
/// The trait bounds live on the `impl` blocks rather than on the struct, so the
/// type can be named (e.g. in other generic code) without repeating them.
pub struct RunnableLambda<F, Input, Output> {
    /// The wrapped function, called once per input.
    func: F,
    /// Records `Input` and `Output` in the type without storing values of them.
    _phantom: std::marker::PhantomData<(Input, Output)>,
}
233
234impl<F, Input, Output> RunnableLambda<F, Input, Output>
235where
236    F: Fn(Input) -> Result<Output> + Send + Sync + 'static,
237    Input: Send + Sync + 'static,
238    Output: Send + Sync + 'static,
239{
240    /// Create a new runnable lambda
241    pub fn new(func: F) -> Self {
242        Self {
243            func,
244            _phantom: std::marker::PhantomData,
245        }
246    }
247}
248
249#[async_trait]
250impl<F, Input, Output> Runnable<Input, Output> for RunnableLambda<F, Input, Output>
251where
252    F: Fn(Input) -> Result<Output> + Send + Sync + 'static,
253    Input: Send + Sync + 'static,
254    Output: Send + Sync + 'static,
255{
256    async fn invoke(&self, input: Input, _config: Option<RunnableConfig>) -> Result<Output> {
257        (self.func)(input)
258    }
259}
260
/// A [`Runnable`] backed by a function that returns a future.
///
/// The trait bounds live on the `impl` blocks rather than on the struct, so the
/// type can be named without repeating them.
pub struct RunnableAsync<F, Input, Output, Fut> {
    /// The wrapped function; each call produces the future that gets awaited.
    func: F,
    /// Records `Input`, `Output` and `Fut` in the type without storing values.
    ///
    /// `Fut` appears only behind `fn() -> Fut`, which is always `Send + Sync`.
    /// The wrapper therefore stays `Sync` even when the future type is not; the
    /// impls only require `Fut: Send`.
    _phantom: std::marker::PhantomData<(Input, Output, fn() -> Fut)>,
}
272
273impl<F, Input, Output, Fut> RunnableAsync<F, Input, Output, Fut>
274where
275    F: Fn(Input) -> Fut + Send + Sync + 'static,
276    Fut: std::future::Future<Output = Result<Output>> + Send + 'static,
277    Input: Send + Sync + 'static,
278    Output: Send + Sync + 'static,
279{
280    /// Create a new async runnable
281    pub fn new(func: F) -> Self {
282        Self {
283            func,
284            _phantom: std::marker::PhantomData,
285        }
286    }
287}
288
289#[async_trait]
290impl<F, Input, Output, Fut> Runnable<Input, Output> for RunnableAsync<F, Input, Output, Fut>
291where
292    F: Fn(Input) -> Fut + Send + Sync + 'static,
293    Fut: std::future::Future<Output = Result<Output>> + Send + 'static,
294    Input: Send + Sync + 'static,
295    Output: Send + Sync + 'static,
296{
297    async fn invoke(&self, input: Input, _config: Option<RunnableConfig>) -> Result<Output> {
298        (self.func)(input).await
299    }
300}
301
302/// A runnable sequence that chains multiple runnables together
303pub struct RunnableSequence<Input, Intermediate, Output> {
304    first: Arc<dyn Runnable<Input, Intermediate>>,
305    second: Arc<dyn Runnable<Intermediate, Output>>,
306}
307
308impl<Input, Intermediate, Output> RunnableSequence<Input, Intermediate, Output>
309where
310    Input: Send + Sync + 'static,
311    Intermediate: Send + Sync + 'static,
312    Output: Send + Sync + 'static,
313{
314    /// Create a new runnable sequence
315    pub fn new(
316        first: Arc<dyn Runnable<Input, Intermediate>>,
317        second: Arc<dyn Runnable<Intermediate, Output>>,
318    ) -> Self {
319        Self { first, second }
320    }
321}
322
323#[async_trait]
324impl<Input, Intermediate, Output> Runnable<Input, Output>
325    for RunnableSequence<Input, Intermediate, Output>
326where
327    Input: Send + Sync + 'static,
328    Intermediate: Send + Sync + 'static,
329    Output: Send + Sync + 'static,
330{
331    async fn invoke(&self, input: Input, config: Option<RunnableConfig>) -> Result<Output> {
332        let intermediate = self.first.invoke(input, config.clone()).await?;
333        self.second.invoke(intermediate, config).await
334    }
335}
336
337/// A runnable that runs multiple runnables in parallel
338pub struct RunnableParallel<Input, Output> {
339    runnables: Vec<Arc<dyn Runnable<Input, Output>>>,
340}
341
342impl<Input, Output> RunnableParallel<Input, Output>
343where
344    Input: Send + Sync + 'static + Clone,
345    Output: Send + Sync + 'static,
346{
347    /// Create a new runnable parallel
348    pub fn new(runnables: Vec<Arc<dyn Runnable<Input, Output>>>) -> Self {
349        Self { runnables }
350    }
351
352    /// Add a runnable to the parallel execution
353    pub fn add_runnable(&mut self, runnable: Arc<dyn Runnable<Input, Output>>) {
354        self.runnables.push(runnable);
355    }
356}
357
358#[async_trait]
359impl<Input, Output> Runnable<Input, Vec<Output>> for RunnableParallel<Input, Output>
360where
361    Input: Send + Sync + 'static + Clone,
362    Output: Send + Sync + 'static,
363{
364    async fn invoke(&self, input: Input, config: Option<RunnableConfig>) -> Result<Vec<Output>> {
365        let mut handles = Vec::new();
366
367        for runnable in &self.runnables {
368            let runnable = runnable.clone();
369            let input = input.clone();
370            let config = config.clone();
371
372            let handle = tokio::spawn(async move { runnable.invoke(input, config).await });
373
374            handles.push(handle);
375        }
376
377        let mut results = Vec::new();
378        for handle in handles {
379            let result = handle.await.map_err(|e| {
380                crate::errors::FerricLinkError::runtime(format!("Task failed: {e}"))
381            })?;
382            results.push(result?);
383        }
384
385        Ok(results)
386    }
387}
388
389/// Helper function to create a runnable from a simple function
390pub fn runnable<F, Input, Output>(func: F) -> Arc<dyn Runnable<Input, Output>>
391where
392    F: Fn(Input) -> Result<Output> + Send + Sync + 'static,
393    Input: Send + Sync + 'static,
394    Output: Send + Sync + 'static,
395{
396    Arc::new(RunnableLambda::new(func))
397}
398
399/// Helper function to create an async runnable from a function
400pub fn runnable_async<F, Input, Output, Fut>(func: F) -> Arc<dyn Runnable<Input, Output>>
401where
402    F: Fn(Input) -> Fut + Send + Sync + 'static,
403    Fut: std::future::Future<Output = Result<Output>> + Send + 'static,
404    Input: Send + Sync + 'static,
405    Output: Send + Sync + 'static,
406{
407    Arc::new(RunnableAsync::new(func))
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    #[tokio::test]
    async fn test_runnable_lambda() {
        let runnable = RunnableLambda::new(|x: i32| Ok(x * 2));
        let result = runnable.invoke_simple(5).await.unwrap();
        assert_eq!(result, 10);
    }

    #[tokio::test]
    async fn test_runnable_async() {
        let runnable = RunnableAsync::new(|x: i32| async move {
            tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
            Ok(x * 3)
        });
        let result = runnable.invoke_simple(4).await.unwrap();
        assert_eq!(result, 12);
    }

    #[tokio::test]
    async fn test_runnable_sequence() {
        let first = Arc::new(RunnableLambda::new(|x: i32| Ok(x + 1)));
        let second = Arc::new(RunnableLambda::new(|x: i32| Ok(x * 2)));
        let sequence = RunnableSequence::new(first, second);

        let result = sequence.invoke_simple(5).await.unwrap();
        assert_eq!(result, 12); // (5 + 1) * 2
    }

    #[tokio::test]
    async fn test_runnable_parallel() {
        let runnable1 = Arc::new(RunnableLambda::new(|x: i32| Ok(x * 2)));
        let runnable2 = Arc::new(RunnableLambda::new(|x: i32| Ok(x * 3)));
        let parallel = RunnableParallel::new(vec![runnable1, runnable2]);

        let results = parallel.invoke_simple(5).await.unwrap();
        // Outputs are reported in branch order, not completion order.
        assert_eq!(results, vec![10, 15]);
    }

    #[tokio::test]
    async fn test_runnable_parallel_empty() {
        let parallel: RunnableParallel<i32, i32> = RunnableParallel::new(Vec::new());
        let results = parallel.invoke_simple(5).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_runnable_parallel_propagates_error() {
        let ok: Arc<dyn Runnable<i32, i32>> = Arc::new(RunnableLambda::new(|x: i32| Ok(x)));
        let failing: Arc<dyn Runnable<i32, i32>> =
            Arc::new(RunnableLambda::new(|_x: i32| -> Result<i32> {
                Err(crate::errors::FerricLinkError::generic("boom"))
            }));
        let parallel = RunnableParallel::new(vec![ok, failing]);

        assert!(parallel.invoke_simple(5).await.is_err());
    }

    #[tokio::test]
    async fn test_runnable_batch() {
        let runnable = RunnableLambda::new(|x: i32| Ok(x * 2));
        let results = runnable.batch(vec![1, 2, 3], None).await.unwrap();
        assert_eq!(results, vec![2, 4, 6]);
    }

    #[tokio::test]
    async fn test_runnable_stream_default() {
        let runnable = RunnableLambda::new(|x: i32| Ok(x * 2));
        let mut stream = runnable.stream(5, None).await.unwrap();

        // The default stream yields exactly one item: the `invoke` result.
        assert_eq!(stream.next().await.unwrap().unwrap(), 10);
        assert!(stream.next().await.is_none());
    }

    #[test]
    fn test_runnable_config() {
        let config = RunnableConfig::new()
            .with_tag("test")
            .with_metadata("key", serde_json::Value::String("value".to_string()))
            .with_debug(true);

        assert!(config.tags.contains(&"test".to_string()));
        assert_eq!(
            config.metadata.get("key"),
            Some(&serde_json::Value::String("value".to_string()))
        );
        assert!(config.debug);
    }

    #[test]
    fn test_runnable_config_eq_ignores_callbacks() {
        let with_callback = RunnableConfig::new()
            .with_tag("a")
            .with_callback(Arc::new(ConsoleCallbackHandler::new()));
        let without_callback = RunnableConfig::new().with_tag("a");

        // `RunnableConfig` is not `Debug`, so compare with plain `assert!`.
        assert!(with_callback == without_callback);
        assert!(with_callback != RunnableConfig::new().with_tag("b"));
    }

    #[test]
    fn test_runnable_config_serde() {
        let config = RunnableConfig::new()
            .with_tag("t")
            .with_metadata("k", serde_json::Value::from(1))
            .with_verbose(true)
            .with_callback(Arc::new(ConsoleCallbackHandler::new()));

        let value = serde_json::to_value(&config).unwrap();
        assert!(value.get("callbacks").is_none());

        let restored: RunnableConfig = serde_json::from_value(value).unwrap();
        assert!(restored == config);
        assert!(restored.callbacks.is_empty());

        // Every serialized field is optional on input.
        let empty: RunnableConfig = serde_json::from_str("{}").unwrap();
        assert!(empty == RunnableConfig::default());
    }

    #[tokio::test]
    async fn test_console_callback_handler() {
        let handler = ConsoleCallbackHandler::new();
        let run_id = "test-run";
        let input = serde_json::Value::String("test input".to_string());
        let output = serde_json::Value::String("test output".to_string());
        let error = crate::errors::FerricLinkError::generic("test error");

        // These should not panic
        handler.on_start(run_id, &input).await.unwrap();
        handler.on_success(run_id, &output).await.unwrap();
        handler.on_error(run_id, &error).await.unwrap();
        handler.on_stream(run_id, &output).await.unwrap();
    }

    #[tokio::test]
    async fn test_helper_functions() {
        let sync_runnable = runnable(|x: i32| Ok(x + 1));
        let result1 = sync_runnable.invoke_simple(5).await.unwrap();
        assert_eq!(result1, 6);

        let async_runnable = runnable_async(|x: i32| async move { Ok(x * 2) });
        let result2 = async_runnable.invoke_simple(3).await.unwrap();
        assert_eq!(result2, 6);
    }
}