1use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::pin::Pin;
10use std::sync::Arc;
11
12use crate::errors::Result;
13use crate::impl_serializable;
14use crate::utils::{colors, print_colored_text};
15
16#[derive(Clone, Serialize, Deserialize, Default)]
18pub struct RunnableConfig {
19 #[serde(default)]
21 pub tags: Vec<String>,
22 #[serde(default)]
24 pub metadata: HashMap<String, serde_json::Value>,
25 #[serde(default)]
27 pub debug: bool,
28 #[serde(default)]
30 pub verbose: bool,
31 #[serde(skip)]
33 pub callbacks: Vec<Arc<dyn CallbackHandler>>,
34}
35
36impl RunnableConfig {
37 pub fn new() -> Self {
39 Self::default()
40 }
41
42 pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
44 self.tags.push(tag.into());
45 self
46 }
47
48 pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
50 self.metadata.insert(key.into(), value);
51 self
52 }
53
54 pub fn with_debug(mut self, debug: bool) -> Self {
56 self.debug = debug;
57 self
58 }
59
60 pub fn with_verbose(mut self, verbose: bool) -> Self {
62 self.verbose = verbose;
63 self
64 }
65
66 pub fn with_callback(mut self, callback: Arc<dyn CallbackHandler>) -> Self {
68 self.callbacks.push(callback);
69 self
70 }
71}
72
73impl PartialEq for RunnableConfig {
74 fn eq(&self, other: &Self) -> bool {
75 self.tags == other.tags
76 && self.metadata == other.metadata
77 && self.debug == other.debug
78 && self.verbose == other.verbose
79 }
81}
82
83impl_serializable!(RunnableConfig, ["ferriclink", "runnables", "config"]);
84
85#[async_trait]
87pub trait CallbackHandler: Send + Sync {
88 async fn on_start(&self, run_id: &str, input: &serde_json::Value) -> Result<()> {
90 let _ = (run_id, input);
91 Ok(())
92 }
93
94 async fn on_success(&self, run_id: &str, output: &serde_json::Value) -> Result<()> {
96 let _ = (run_id, output);
97 Ok(())
98 }
99
100 async fn on_error(&self, run_id: &str, error: &crate::errors::FerricLinkError) -> Result<()> {
102 let _ = (run_id, error);
103 Ok(())
104 }
105
106 async fn on_stream(&self, run_id: &str, chunk: &serde_json::Value) -> Result<()> {
108 let _ = (run_id, chunk);
109 Ok(())
110 }
111}
112
/// `CallbackHandler` that reports lifecycle events through `print_colored_text`.
///
/// `color` is forwarded for start, success and stream events; error events
/// always use `colors::RED`.
#[derive(Debug, Clone)]
pub struct ConsoleCallbackHandler {
    /// Optional color passed as-is to `print_colored_text` (`None` is forwarded too).
    pub color: Option<String>,
}
118
119impl ConsoleCallbackHandler {
120 pub fn new() -> Self {
122 Self { color: None }
123 }
124
125 pub fn new_with_color(color: impl Into<String>) -> Self {
127 Self {
128 color: Some(color.into()),
129 }
130 }
131}
132
133impl Default for ConsoleCallbackHandler {
134 fn default() -> Self {
135 Self::new()
136 }
137}
138
139#[async_trait]
140impl CallbackHandler for ConsoleCallbackHandler {
141 async fn on_start(&self, run_id: &str, input: &serde_json::Value) -> Result<()> {
142 let message = format!("Starting run {run_id} with input: {input}");
143 print_colored_text(&message, self.color.as_deref());
144 Ok(())
145 }
146
147 async fn on_success(&self, run_id: &str, output: &serde_json::Value) -> Result<()> {
148 let message = format!("Run {run_id} completed with output: {output}");
149 print_colored_text(&message, self.color.as_deref());
150 Ok(())
151 }
152
153 async fn on_error(&self, run_id: &str, error: &crate::errors::FerricLinkError) -> Result<()> {
154 let message = format!("Run {run_id} failed with error: {error}");
155 print_colored_text(&message, Some(colors::RED));
156 Ok(())
157 }
158
159 async fn on_stream(&self, run_id: &str, chunk: &serde_json::Value) -> Result<()> {
160 let message = format!("Run {run_id} streamed: {chunk}");
161 print_colored_text(&message, self.color.as_deref());
162 Ok(())
163 }
164}
165
166#[async_trait]
168pub trait Runnable<Input, Output>: Send + Sync + 'static
169where
170 Input: Send + Sync + 'static,
171 Output: Send + Sync + 'static,
172{
173 async fn invoke(&self, input: Input, config: Option<RunnableConfig>) -> Result<Output>;
175
176 async fn invoke_simple(&self, input: Input) -> Result<Output> {
178 self.invoke(input, None).await
179 }
180
181 async fn batch(
183 &self,
184 inputs: Vec<Input>,
185 config: Option<RunnableConfig>,
186 ) -> Result<Vec<Output>> {
187 let mut results = Vec::new();
188 for input in inputs {
189 let result = self.invoke(input, config.clone()).await?;
190 results.push(result);
191 }
192 Ok(results)
193 }
194
195 async fn stream(
197 &self,
198 input: Input,
199 config: Option<RunnableConfig>,
200 ) -> Result<Pin<Box<dyn futures::Stream<Item = Result<Output>> + Send>>> {
201 let result = self.invoke(input, config).await;
203 let stream = futures::stream::once(async { result });
204 Ok(Box::pin(stream))
205 }
206
207 fn input_schema(&self) -> Option<serde_json::Value> {
209 None
210 }
211
212 fn output_schema(&self) -> Option<serde_json::Value> {
214 None
215 }
216
217 fn config_schema(&self) -> Option<serde_json::Value> {
219 None
220 }
221}
222
/// Adapter that turns a synchronous closure `Fn(Input) -> Result<Output>` into
/// a [`Runnable`].
///
/// The trait bounds are declared on the `impl` blocks that need them rather
/// than on the struct. This way, merely naming the type (in fields, function
/// signatures or other generic code) does not force callers to repeat them.
pub struct RunnableLambda<F, Input, Output> {
    func: F,
    // Ties `Input`/`Output` to the type; no values of either are stored.
    _phantom: std::marker::PhantomData<(Input, Output)>,
}
233
234impl<F, Input, Output> RunnableLambda<F, Input, Output>
235where
236 F: Fn(Input) -> Result<Output> + Send + Sync + 'static,
237 Input: Send + Sync + 'static,
238 Output: Send + Sync + 'static,
239{
240 pub fn new(func: F) -> Self {
242 Self {
243 func,
244 _phantom: std::marker::PhantomData,
245 }
246 }
247}
248
249#[async_trait]
250impl<F, Input, Output> Runnable<Input, Output> for RunnableLambda<F, Input, Output>
251where
252 F: Fn(Input) -> Result<Output> + Send + Sync + 'static,
253 Input: Send + Sync + 'static,
254 Output: Send + Sync + 'static,
255{
256 async fn invoke(&self, input: Input, _config: Option<RunnableConfig>) -> Result<Output> {
257 (self.func)(input)
258 }
259}
260
/// Adapter that turns an async closure (`Fn(Input) -> Fut`, where `Fut`
/// resolves to `Result<Output>`) into a [`Runnable`].
///
/// The trait bounds are declared on the `impl` blocks rather than on the
/// struct, so merely naming the type does not require repeating them.
pub struct RunnableAsync<F, Input, Output, Fut> {
    func: F,
    // `Fut` is not stored in any field, so it is anchored here explicitly.
    // It is recorded as `fn() -> Fut` so the wrapper's `Send`/`Sync` depend
    // only on `F`, `Input` and `Output`: the impls require `Fut: Send` but not
    // `Fut: Sync`.
    _phantom: std::marker::PhantomData<(Input, Output, fn() -> Fut)>,
}
272
273impl<F, Input, Output, Fut> RunnableAsync<F, Input, Output, Fut>
274where
275 F: Fn(Input) -> Fut + Send + Sync + 'static,
276 Fut: std::future::Future<Output = Result<Output>> + Send + 'static,
277 Input: Send + Sync + 'static,
278 Output: Send + Sync + 'static,
279{
280 pub fn new(func: F) -> Self {
282 Self {
283 func,
284 _phantom: std::marker::PhantomData,
285 }
286 }
287}
288
289#[async_trait]
290impl<F, Input, Output, Fut> Runnable<Input, Output> for RunnableAsync<F, Input, Output, Fut>
291where
292 F: Fn(Input) -> Fut + Send + Sync + 'static,
293 Fut: std::future::Future<Output = Result<Output>> + Send + 'static,
294 Input: Send + Sync + 'static,
295 Output: Send + Sync + 'static,
296{
297 async fn invoke(&self, input: Input, _config: Option<RunnableConfig>) -> Result<Output> {
298 (self.func)(input).await
299 }
300}
301
302pub struct RunnableSequence<Input, Intermediate, Output> {
304 first: Arc<dyn Runnable<Input, Intermediate>>,
305 second: Arc<dyn Runnable<Intermediate, Output>>,
306}
307
308impl<Input, Intermediate, Output> RunnableSequence<Input, Intermediate, Output>
309where
310 Input: Send + Sync + 'static,
311 Intermediate: Send + Sync + 'static,
312 Output: Send + Sync + 'static,
313{
314 pub fn new(
316 first: Arc<dyn Runnable<Input, Intermediate>>,
317 second: Arc<dyn Runnable<Intermediate, Output>>,
318 ) -> Self {
319 Self { first, second }
320 }
321}
322
323#[async_trait]
324impl<Input, Intermediate, Output> Runnable<Input, Output>
325 for RunnableSequence<Input, Intermediate, Output>
326where
327 Input: Send + Sync + 'static,
328 Intermediate: Send + Sync + 'static,
329 Output: Send + Sync + 'static,
330{
331 async fn invoke(&self, input: Input, config: Option<RunnableConfig>) -> Result<Output> {
332 let intermediate = self.first.invoke(input, config.clone()).await?;
333 self.second.invoke(intermediate, config).await
334 }
335}
336
337pub struct RunnableParallel<Input, Output> {
339 runnables: Vec<Arc<dyn Runnable<Input, Output>>>,
340}
341
342impl<Input, Output> RunnableParallel<Input, Output>
343where
344 Input: Send + Sync + 'static + Clone,
345 Output: Send + Sync + 'static,
346{
347 pub fn new(runnables: Vec<Arc<dyn Runnable<Input, Output>>>) -> Self {
349 Self { runnables }
350 }
351
352 pub fn add_runnable(&mut self, runnable: Arc<dyn Runnable<Input, Output>>) {
354 self.runnables.push(runnable);
355 }
356}
357
358#[async_trait]
359impl<Input, Output> Runnable<Input, Vec<Output>> for RunnableParallel<Input, Output>
360where
361 Input: Send + Sync + 'static + Clone,
362 Output: Send + Sync + 'static,
363{
364 async fn invoke(&self, input: Input, config: Option<RunnableConfig>) -> Result<Vec<Output>> {
365 let mut handles = Vec::new();
366
367 for runnable in &self.runnables {
368 let runnable = runnable.clone();
369 let input = input.clone();
370 let config = config.clone();
371
372 let handle = tokio::spawn(async move { runnable.invoke(input, config).await });
373
374 handles.push(handle);
375 }
376
377 let mut results = Vec::new();
378 for handle in handles {
379 let result = handle.await.map_err(|e| {
380 crate::errors::FerricLinkError::runtime(format!("Task failed: {e}"))
381 })?;
382 results.push(result?);
383 }
384
385 Ok(results)
386 }
387}
388
389pub fn runnable<F, Input, Output>(func: F) -> Arc<dyn Runnable<Input, Output>>
391where
392 F: Fn(Input) -> Result<Output> + Send + Sync + 'static,
393 Input: Send + Sync + 'static,
394 Output: Send + Sync + 'static,
395{
396 Arc::new(RunnableLambda::new(func))
397}
398
399pub fn runnable_async<F, Input, Output, Fut>(func: F) -> Arc<dyn Runnable<Input, Output>>
401where
402 F: Fn(Input) -> Fut + Send + Sync + 'static,
403 Fut: std::future::Future<Output = Result<Output>> + Send + 'static,
404 Input: Send + Sync + 'static,
405 Output: Send + Sync + 'static,
406{
407 Arc::new(RunnableAsync::new(func))
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_runnable_lambda() {
        let runnable = RunnableLambda::new(|x: i32| Ok(x * 2));
        let result = runnable.invoke_simple(5).await.unwrap();
        assert_eq!(result, 10);
    }

    #[tokio::test]
    async fn test_runnable_async() {
        let runnable = RunnableAsync::new(|x: i32| async move {
            tokio::time::sleep(tokio::time::Duration::from_millis(1)).await;
            Ok(x * 3)
        });
        let result = runnable.invoke_simple(4).await.unwrap();
        assert_eq!(result, 12);
    }

    #[tokio::test]
    async fn test_runnable_sequence() {
        let first = Arc::new(RunnableLambda::new(|x: i32| Ok(x + 1)));
        let second = Arc::new(RunnableLambda::new(|x: i32| Ok(x * 2)));
        let sequence = RunnableSequence::new(first, second);

        let result = sequence.invoke_simple(5).await.unwrap();
        // (5 + 1) * 2: the first stage's output feeds the second stage.
        assert_eq!(result, 12);
    }

    #[tokio::test]
    async fn test_runnable_parallel() {
        let runnable1 = Arc::new(RunnableLambda::new(|x: i32| Ok(x * 2)));
        let runnable2 = Arc::new(RunnableLambda::new(|x: i32| Ok(x * 3)));
        let parallel = RunnableParallel::new(vec![runnable1, runnable2]);

        let results = parallel.invoke_simple(5).await.unwrap();
        // Outputs follow branch order, not task completion order.
        assert_eq!(results, vec![10, 15]);
    }

    #[tokio::test]
    async fn test_runnable_parallel_empty() {
        let parallel: RunnableParallel<i32, i32> = RunnableParallel::new(Vec::new());
        let results = parallel.invoke_simple(5).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_runnable_parallel_propagates_error() {
        let ok = Arc::new(RunnableLambda::new(|x: i32| Ok(x * 2)));
        let failing = Arc::new(RunnableLambda::new(|_x: i32| -> Result<i32> {
            Err(crate::errors::FerricLinkError::generic("branch failed"))
        }));
        let parallel = RunnableParallel::new(vec![ok, failing]);

        assert!(parallel.invoke_simple(5).await.is_err());
    }

    #[tokio::test]
    async fn test_runnable_batch() {
        let runnable = RunnableLambda::new(|x: i32| Ok(x * 2));
        let results = runnable.batch(vec![1, 2, 3], None).await.unwrap();
        assert_eq!(results, vec![2, 4, 6]);
    }

    #[tokio::test]
    async fn test_runnable_stream_default() {
        use futures::StreamExt;

        let runnable = RunnableLambda::new(|x: i32| Ok(x + 10));
        let mut stream = runnable.stream(1, None).await.unwrap();
        // The default `stream` yields exactly one item: the `invoke` result.
        assert_eq!(stream.next().await.unwrap().unwrap(), 11);
        assert!(stream.next().await.is_none());
    }

    #[test]
    fn test_runnable_config() {
        let config = RunnableConfig::new()
            .with_tag("test")
            .with_metadata("key", serde_json::Value::String("value".to_string()))
            .with_debug(true);

        assert!(config.tags.contains(&"test".to_string()));
        assert_eq!(
            config.metadata.get("key"),
            Some(&serde_json::Value::String("value".to_string()))
        );
        assert!(config.debug);
        assert!(!config.verbose);
        assert!(config.callbacks.is_empty());
    }

    #[test]
    fn test_runnable_config_eq_ignores_callbacks() {
        let plain = RunnableConfig::new().with_tag("t");
        let with_callback = RunnableConfig::new()
            .with_tag("t")
            .with_callback(Arc::new(ConsoleCallbackHandler::new()));

        // `RunnableConfig` has no `Debug`, so compare with `assert!` rather than `assert_eq!`.
        assert!(plain == with_callback);
        assert!(plain != RunnableConfig::new().with_tag("other"));
    }

    #[tokio::test]
    async fn test_console_callback_handler() {
        let handler = ConsoleCallbackHandler::new();
        let run_id = "test-run";
        let input = serde_json::Value::String("test input".to_string());
        let output = serde_json::Value::String("test output".to_string());
        let error = crate::errors::FerricLinkError::generic("test error");

        handler.on_start(run_id, &input).await.unwrap();
        handler.on_success(run_id, &output).await.unwrap();
        handler.on_error(run_id, &error).await.unwrap();
        handler.on_stream(run_id, &output).await.unwrap();
    }

    #[tokio::test]
    async fn test_helper_functions() {
        let sync_runnable = runnable(|x: i32| Ok(x + 1));
        let result1 = sync_runnable.invoke_simple(5).await.unwrap();
        assert_eq!(result1, 6);

        let async_runnable = runnable_async(|x: i32| async move { Ok(x * 2) });
        let result2 = async_runnable.invoke_simple(3).await.unwrap();
        assert_eq!(result2, 6);
    }
}