ferriclink_core/
callbacks.rs

1//! Callback system for FerricLink Core
2//!
3//! This module provides a comprehensive callback system for monitoring
4//! and tracing the execution of FerricLink components.
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11
12use crate::errors::Result;
13use crate::impl_serializable;
14use crate::utils::{colors, print_bold_text, print_colored_text};
15
16/// A run ID for tracking execution
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
18pub struct RunId {
19    /// The unique identifier
20    pub id: String,
21    /// The timestamp when the run was created
22    pub created_at: chrono::DateTime<chrono::Utc>,
23}
24
25impl RunId {
26    /// Create a new run ID
27    pub fn new() -> Self {
28        Self {
29            id: uuid::Uuid::new_v4().to_string(),
30            created_at: chrono::Utc::now(),
31        }
32    }
33
34    /// Create a new run ID with a custom ID
35    pub fn new_with_id(id: impl Into<String>) -> Self {
36        Self {
37            id: id.into(),
38            created_at: chrono::Utc::now(),
39        }
40    }
41}
42
43impl Default for RunId {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl_serializable!(RunId, ["ferriclink", "callbacks", "run_id"]);
50
51/// Information about a run
52#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
53pub struct RunInfo {
54    /// The run ID
55    pub run_id: RunId,
56    /// The name of the component being run
57    pub name: String,
58    /// The type of the component
59    pub component_type: String,
60    /// Input to the component
61    pub input: serde_json::Value,
62    /// Output from the component (if completed)
63    pub output: Option<serde_json::Value>,
64    /// Error that occurred (if any)
65    pub error: Option<String>,
66    /// Start time of the run
67    pub start_time: chrono::DateTime<chrono::Utc>,
68    /// End time of the run (if completed)
69    pub end_time: Option<chrono::DateTime<chrono::Utc>>,
70    /// Duration of the run (if completed)
71    pub duration: Option<Duration>,
72    /// Tags associated with the run
73    #[serde(default)]
74    pub tags: Vec<String>,
75    /// Metadata associated with the run
76    #[serde(default)]
77    pub metadata: HashMap<String, serde_json::Value>,
78    /// Parent run ID (if this is a sub-run)
79    pub parent_run_id: Option<RunId>,
80    /// Child run IDs
81    #[serde(default)]
82    pub child_run_ids: Vec<RunId>,
83}
84
85impl RunInfo {
86    /// Create a new run info
87    pub fn new(
88        run_id: RunId,
89        name: impl Into<String>,
90        component_type: impl Into<String>,
91        input: serde_json::Value,
92    ) -> Self {
93        Self {
94            run_id,
95            name: name.into(),
96            component_type: component_type.into(),
97            input,
98            output: None,
99            error: None,
100            start_time: chrono::Utc::now(),
101            end_time: None,
102            duration: None,
103            tags: Vec::new(),
104            metadata: HashMap::new(),
105            parent_run_id: None,
106            child_run_ids: Vec::new(),
107        }
108    }
109
110    /// Mark the run as completed with output
111    pub fn complete_with_output(mut self, output: serde_json::Value) -> Self {
112        self.output = Some(output);
113        self.end_time = Some(chrono::Utc::now());
114        self.duration = Some(
115            (self.end_time.unwrap() - self.start_time)
116                .to_std()
117                .unwrap_or_default(),
118        );
119        self
120    }
121
122    /// Mark the run as failed with error
123    pub fn complete_with_error(mut self, error: impl Into<String>) -> Self {
124        self.error = Some(error.into());
125        self.end_time = Some(chrono::Utc::now());
126        self.duration = Some(
127            (self.end_time.unwrap() - self.start_time)
128                .to_std()
129                .unwrap_or_default(),
130        );
131        self
132    }
133
134    /// Add a tag to the run
135    pub fn add_tag(mut self, tag: impl Into<String>) -> Self {
136        self.tags.push(tag.into());
137        self
138    }
139
140    /// Add metadata to the run
141    pub fn add_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
142        self.metadata.insert(key.into(), value);
143        self
144    }
145
146    /// Set the parent run ID
147    pub fn with_parent(mut self, parent_run_id: RunId) -> Self {
148        self.parent_run_id = Some(parent_run_id);
149        self
150    }
151
152    /// Add a child run ID
153    pub fn add_child(mut self, child_run_id: RunId) -> Self {
154        self.child_run_ids.push(child_run_id);
155        self
156    }
157
158    /// Check if the run is completed
159    pub fn is_completed(&self) -> bool {
160        self.end_time.is_some()
161    }
162
163    /// Check if the run failed
164    pub fn is_failed(&self) -> bool {
165        self.error.is_some()
166    }
167
168    /// Check if the run succeeded
169    pub fn is_successful(&self) -> bool {
170        self.is_completed() && !self.is_failed()
171    }
172}
173
174impl_serializable!(RunInfo, ["ferriclink", "callbacks", "run_info"]);
175
176/// Base trait for all callback handlers
177#[async_trait]
178pub trait CallbackHandler: Send + Sync + 'static {
179    /// Called when a run starts
180    async fn on_run_start(&self, run_info: &RunInfo) -> Result<()> {
181        let _ = run_info;
182        Ok(())
183    }
184
185    /// Called when a run completes successfully
186    async fn on_run_success(&self, run_info: &RunInfo) -> Result<()> {
187        let _ = run_info;
188        Ok(())
189    }
190
191    /// Called when a run fails
192    async fn on_run_error(&self, run_info: &RunInfo) -> Result<()> {
193        let _ = run_info;
194        Ok(())
195    }
196
197    /// Called when a run produces streaming output
198    async fn on_run_stream(&self, run_info: &RunInfo, chunk: &serde_json::Value) -> Result<()> {
199        let _ = (run_info, chunk);
200        Ok(())
201    }
202
203    /// Called when a run is cancelled
204    async fn on_run_cancel(&self, run_info: &RunInfo) -> Result<()> {
205        let _ = run_info;
206        Ok(())
207    }
208}
209
/// A console callback handler that prints run information to stdout.
///
/// The default value (also produced by [`ConsoleCallbackHandler::new`]) is non-verbose and
/// has no color configured.
#[derive(Debug, Clone, Default)]
pub struct ConsoleCallbackHandler {
    /// Whether to print detailed information (run input, tags and output)
    pub verbose: bool,
    /// The color to use for text output (matching LangChain's color scheme)
    pub color: Option<String>,
}

impl ConsoleCallbackHandler {
    /// Create a new, non-verbose console callback handler without a color
    pub fn new() -> Self {
        Self::default()
    }

    /// Create a new console callback handler with verbosity setting
    pub fn new_with_verbose(verbose: bool) -> Self {
        Self {
            verbose,
            ..Self::default()
        }
    }

    /// Create a new (non-verbose) console callback handler with color
    pub fn new_with_color(color: impl Into<String>) -> Self {
        Self {
            color: Some(color.into()),
            ..Self::default()
        }
    }

    /// Create a new console callback handler with verbosity and color
    pub fn new_with_verbose_and_color(verbose: bool, color: impl Into<String>) -> Self {
        Self {
            verbose,
            color: Some(color.into()),
        }
    }
}
257
258#[async_trait]
259impl CallbackHandler for ConsoleCallbackHandler {
260    async fn on_run_start(&self, run_info: &RunInfo) -> Result<()> {
261        // Match LangChain's format: "> Entering new {name} chain..."
262        let message = format!("\n\n> Entering new {} chain...", run_info.name);
263        print_bold_text(&message);
264
265        if self.verbose {
266            println!("   Input: {}", run_info.input);
267            if !run_info.tags.is_empty() {
268                println!("   Tags: {:?}", run_info.tags);
269            }
270        }
271
272        Ok(())
273    }
274
275    async fn on_run_success(&self, run_info: &RunInfo) -> Result<()> {
276        // Match LangChain's format: "> Finished chain."
277        print_bold_text("\n> Finished chain.");
278
279        if self.verbose {
280            if let Some(output) = &run_info.output {
281                let color = self.color.as_deref();
282                print_colored_text(&format!("   Output: {output}"), color);
283            }
284        }
285
286        Ok(())
287    }
288
289    async fn on_run_error(&self, run_info: &RunInfo) -> Result<()> {
290        let error_msg = run_info.error.as_deref().unwrap_or("Unknown error");
291        print_colored_text(&format!("\n> Error: {error_msg}"), Some(colors::RED));
292
293        Ok(())
294    }
295
296    async fn on_run_stream(&self, _run_info: &RunInfo, chunk: &serde_json::Value) -> Result<()> {
297        let color = self.color.as_deref();
298        print_colored_text(&format!("{chunk}"), color);
299        Ok(())
300    }
301
302    async fn on_run_cancel(&self, run_info: &RunInfo) -> Result<()> {
303        print_colored_text(
304            &format!("\n> Run {} was cancelled", run_info.run_id.id),
305            Some(colors::YELLOW),
306        );
307        Ok(())
308    }
309}
310
311/// A callback handler that collects run information in memory
312pub struct MemoryCallbackHandler {
313    runs: Arc<tokio::sync::RwLock<Vec<RunInfo>>>,
314}
315
316impl MemoryCallbackHandler {
317    /// Create a new memory callback handler
318    pub fn new() -> Self {
319        Self {
320            runs: Arc::new(tokio::sync::RwLock::new(Vec::new())),
321        }
322    }
323
324    /// Get all runs
325    pub async fn get_runs(&self) -> Vec<RunInfo> {
326        self.runs.read().await.clone()
327    }
328
329    /// Get runs by name
330    pub async fn get_runs_by_name(&self, name: &str) -> Vec<RunInfo> {
331        self.runs
332            .read()
333            .await
334            .iter()
335            .filter(|run| run.name == name)
336            .cloned()
337            .collect()
338    }
339
340    /// Get runs by component type
341    pub async fn get_runs_by_type(&self, component_type: &str) -> Vec<RunInfo> {
342        self.runs
343            .read()
344            .await
345            .iter()
346            .filter(|run| run.component_type == component_type)
347            .cloned()
348            .collect()
349    }
350
351    /// Get successful runs
352    pub async fn get_successful_runs(&self) -> Vec<RunInfo> {
353        self.runs
354            .read()
355            .await
356            .iter()
357            .filter(|run| run.is_successful())
358            .cloned()
359            .collect()
360    }
361
362    /// Get failed runs
363    pub async fn get_failed_runs(&self) -> Vec<RunInfo> {
364        self.runs
365            .read()
366            .await
367            .iter()
368            .filter(|run| run.is_failed())
369            .cloned()
370            .collect()
371    }
372
373    /// Clear all runs
374    pub async fn clear(&self) {
375        self.runs.write().await.clear();
376    }
377
378    /// Get the number of runs
379    pub async fn len(&self) -> usize {
380        self.runs.read().await.len()
381    }
382
383    /// Check if there are any runs
384    pub async fn is_empty(&self) -> bool {
385        self.runs.read().await.is_empty()
386    }
387}
388
389impl Default for MemoryCallbackHandler {
390    fn default() -> Self {
391        Self::new()
392    }
393}
394
395#[async_trait]
396impl CallbackHandler for MemoryCallbackHandler {
397    async fn on_run_start(&self, run_info: &RunInfo) -> Result<()> {
398        self.runs.write().await.push(run_info.clone());
399        Ok(())
400    }
401
402    async fn on_run_success(&self, run_info: &RunInfo) -> Result<()> {
403        // Update the existing run info
404        if let Some(existing_run) = self
405            .runs
406            .write()
407            .await
408            .iter_mut()
409            .find(|run| run.run_id.id == run_info.run_id.id)
410        {
411            *existing_run = run_info.clone();
412        }
413        Ok(())
414    }
415
416    async fn on_run_error(&self, run_info: &RunInfo) -> Result<()> {
417        // Update the existing run info
418        if let Some(existing_run) = self
419            .runs
420            .write()
421            .await
422            .iter_mut()
423            .find(|run| run.run_id.id == run_info.run_id.id)
424        {
425            *existing_run = run_info.clone();
426        }
427        Ok(())
428    }
429}
430
431/// A callback manager that manages multiple callback handlers
432pub struct CallbackManager {
433    handlers: Vec<Arc<dyn CallbackHandler>>,
434}
435
436impl CallbackManager {
437    /// Create a new callback manager
438    pub fn new() -> Self {
439        Self {
440            handlers: Vec::new(),
441        }
442    }
443
444    /// Add a callback handler
445    pub fn add_handler(&mut self, handler: Arc<dyn CallbackHandler>) {
446        self.handlers.push(handler);
447    }
448
449    /// Remove all handlers
450    pub fn clear(&mut self) {
451        self.handlers.clear();
452    }
453
454    /// Get the number of handlers
455    pub fn len(&self) -> usize {
456        self.handlers.len()
457    }
458
459    /// Check if there are any handlers
460    pub fn is_empty(&self) -> bool {
461        self.handlers.is_empty()
462    }
463
464    /// Call all handlers for run start
465    pub async fn on_run_start(&self, run_info: &RunInfo) -> Result<()> {
466        for handler in &self.handlers {
467            handler.on_run_start(run_info).await?;
468        }
469        Ok(())
470    }
471
472    /// Call all handlers for run success
473    pub async fn on_run_success(&self, run_info: &RunInfo) -> Result<()> {
474        for handler in &self.handlers {
475            handler.on_run_success(run_info).await?;
476        }
477        Ok(())
478    }
479
480    /// Call all handlers for run error
481    pub async fn on_run_error(&self, run_info: &RunInfo) -> Result<()> {
482        for handler in &self.handlers {
483            handler.on_run_error(run_info).await?;
484        }
485        Ok(())
486    }
487
488    /// Call all handlers for run stream
489    pub async fn on_run_stream(&self, run_info: &RunInfo, chunk: &serde_json::Value) -> Result<()> {
490        for handler in &self.handlers {
491            handler.on_run_stream(run_info, chunk).await?;
492        }
493        Ok(())
494    }
495
496    /// Call all handlers for run cancel
497    pub async fn on_run_cancel(&self, run_info: &RunInfo) -> Result<()> {
498        for handler in &self.handlers {
499            handler.on_run_cancel(run_info).await?;
500        }
501        Ok(())
502    }
503}
504
505impl Default for CallbackManager {
506    fn default() -> Self {
507        Self::new()
508    }
509}
510
511/// Helper function to create a console callback handler
512pub fn console_callback_handler() -> Arc<ConsoleCallbackHandler> {
513    Arc::new(ConsoleCallbackHandler::new())
514}
515
516/// Helper function to create a verbose console callback handler
517pub fn verbose_console_callback_handler() -> Arc<ConsoleCallbackHandler> {
518    Arc::new(ConsoleCallbackHandler::new_with_verbose(true))
519}
520
521/// Helper function to create a colored console callback handler
522pub fn colored_console_callback_handler(color: impl Into<String>) -> Arc<ConsoleCallbackHandler> {
523    Arc::new(ConsoleCallbackHandler::new_with_color(color))
524}
525
526/// Helper function to create a verbose colored console callback handler
527pub fn verbose_colored_console_callback_handler(
528    verbose: bool,
529    color: impl Into<String>,
530) -> Arc<ConsoleCallbackHandler> {
531    Arc::new(ConsoleCallbackHandler::new_with_verbose_and_color(
532        verbose, color,
533    ))
534}
535
536/// Helper function to create a memory callback handler
537pub fn memory_callback_handler() -> Arc<MemoryCallbackHandler> {
538    Arc::new(MemoryCallbackHandler::new())
539}
540
#[cfg(test)]
mod tests {
    use super::*;
    use crate::serializable::Serializable;

    /// Build a fresh, not-yet-completed run with the given name and a fixed type/input.
    fn sample_run(name: &str) -> RunInfo {
        RunInfo::new(
            RunId::new(),
            name,
            "test_type",
            serde_json::json!({"input": "test"}),
        )
    }

    #[test]
    fn test_run_id() {
        let run_id = RunId::new();
        assert!(!run_id.id.is_empty());
        assert!(run_id.created_at <= chrono::Utc::now());

        // Generated IDs are random, custom IDs are kept verbatim.
        assert_ne!(RunId::new().id, RunId::new().id);
        assert_eq!(RunId::new_with_id("custom").id, "custom");
    }

    #[test]
    fn test_run_info() {
        let run_id = RunId::new();
        let run_info = RunInfo::new(
            run_id.clone(),
            "test_component",
            "test_type",
            serde_json::json!({"input": "test"}),
        );

        assert_eq!(run_info.run_id, run_id);
        assert_eq!(run_info.name, "test_component");
        assert_eq!(run_info.component_type, "test_type");
        assert!(!run_info.is_completed());
        assert!(!run_info.is_failed());
        assert!(!run_info.is_successful());
    }

    #[test]
    fn test_run_info_completion() {
        let run_info = sample_run("test_component")
            .complete_with_output(serde_json::json!({"output": "result"}));

        assert!(run_info.is_completed());
        assert!(!run_info.is_failed());
        assert!(run_info.is_successful());
        assert_eq!(run_info.output, Some(serde_json::json!({"output": "result"})));
        assert!(run_info.end_time.unwrap() >= run_info.start_time);
        assert!(run_info.duration.is_some());
    }

    #[test]
    fn test_run_info_error() {
        let run_info = sample_run("test_component").complete_with_error("Test error");

        assert!(run_info.is_completed());
        assert!(run_info.is_failed());
        assert!(!run_info.is_successful());
        assert_eq!(run_info.error.as_deref(), Some("Test error"));
        assert!(run_info.output.is_none());
    }

    #[test]
    fn test_run_info_builders() {
        let parent = RunId::new();
        let child = RunId::new();
        let run_info = sample_run("test")
            .add_tag("a")
            .add_tag("b")
            .add_metadata("key", serde_json::json!(1))
            .with_parent(parent.clone())
            .add_child(child.clone());

        assert_eq!(run_info.tags, vec!["a".to_string(), "b".to_string()]);
        assert_eq!(run_info.metadata.get("key"), Some(&serde_json::json!(1)));
        assert_eq!(run_info.parent_run_id, Some(parent));
        assert_eq!(run_info.child_run_ids, vec![child]);
    }

    #[tokio::test]
    async fn test_console_callback_handler() {
        let handler = ConsoleCallbackHandler::new();
        let run_info = sample_run("test");

        // These should not panic
        handler.on_run_start(&run_info).await.unwrap();
        handler.on_run_success(&run_info).await.unwrap();
        handler.on_run_error(&run_info).await.unwrap();
        handler
            .on_run_stream(&run_info, &serde_json::json!("chunk"))
            .await
            .unwrap();
        handler
            .on_run_stream(&run_info, &serde_json::json!({"token": 1}))
            .await
            .unwrap();
        handler.on_run_cancel(&run_info).await.unwrap();
    }

    #[tokio::test]
    async fn test_memory_callback_handler() {
        let handler = MemoryCallbackHandler::new();
        assert!(handler.is_empty().await);

        let run_info = sample_run("test");
        handler.on_run_start(&run_info).await.unwrap();
        assert_eq!(handler.len().await, 1);

        let runs = handler.get_runs().await;
        assert_eq!(runs.len(), 1);
        assert_eq!(runs[0].name, "test");

        // A success event replaces the stored entry instead of adding a new one.
        let finished = run_info
            .clone()
            .complete_with_output(serde_json::json!({"output": "ok"}));
        handler.on_run_success(&finished).await.unwrap();
        assert_eq!(handler.len().await, 1);
        assert_eq!(handler.get_successful_runs().await.len(), 1);

        // An error event for a second run updates that run only.
        let other = sample_run("other");
        handler.on_run_start(&other).await.unwrap();
        handler
            .on_run_error(&other.clone().complete_with_error("boom"))
            .await
            .unwrap();
        assert_eq!(handler.len().await, 2);

        let failed = handler.get_failed_runs().await;
        assert_eq!(failed.len(), 1);
        assert_eq!(failed[0].name, "other");
        assert_eq!(handler.get_runs_by_name("other").await.len(), 1);
        assert_eq!(handler.get_runs_by_type("test_type").await.len(), 2);
        assert!(handler.get_runs_by_type("missing").await.is_empty());

        handler.clear().await;
        assert!(handler.is_empty().await);
    }

    #[tokio::test]
    async fn test_callback_manager() {
        let mut manager = CallbackManager::new();
        assert!(manager.is_empty());

        let memory_handler = Arc::new(MemoryCallbackHandler::new());
        manager.add_handler(Arc::new(ConsoleCallbackHandler::new()));
        manager.add_handler(memory_handler.clone());

        assert_eq!(manager.len(), 2);
        assert!(!manager.is_empty());

        let run_info = sample_run("test");
        manager.on_run_start(&run_info).await.unwrap();
        let finished = run_info
            .clone()
            .complete_with_output(serde_json::json!({"output": "ok"}));
        manager.on_run_success(&finished).await.unwrap();

        // The manager must actually deliver events to its handlers.
        assert_eq!(memory_handler.len().await, 1);
        assert_eq!(memory_handler.get_successful_runs().await.len(), 1);

        manager.clear();
        assert!(manager.is_empty());
    }

    #[test]
    fn test_serialization() {
        let run_info = sample_run("test");

        let json = run_info.to_json().unwrap();
        let deserialized: RunInfo = RunInfo::from_json(&json).unwrap();
        assert_eq!(run_info.name, deserialized.name);
        assert_eq!(run_info.component_type, deserialized.component_type);
        assert_eq!(run_info.input, deserialized.input);
    }
}