ferriclink_core/
tools.rs

//! Tool abstractions for FerricLink Core.
//!
//! This module provides the core abstractions for tools that can be invoked
//! by language models and other components.
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10use crate::errors::Result;
11use crate::impl_serializable;
12use crate::runnables::{Runnable, RunnableConfig};
13
14/// A tool call made by a language model
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16pub struct ToolCall {
17    /// Unique identifier for this tool call
18    pub id: String,
19    /// Name of the tool being called
20    pub name: String,
21    /// Arguments passed to the tool
22    pub args: HashMap<String, serde_json::Value>,
23}
24
25impl ToolCall {
26    /// Create a new tool call
27    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
28        Self {
29            id: id.into(),
30            name: name.into(),
31            args: HashMap::new(),
32        }
33    }
34
35    /// Create a new tool call with arguments
36    pub fn new_with_args(
37        id: impl Into<String>,
38        name: impl Into<String>,
39        args: HashMap<String, serde_json::Value>,
40    ) -> Self {
41        Self {
42            id: id.into(),
43            name: name.into(),
44            args,
45        }
46    }
47
48    /// Add an argument to the tool call
49    pub fn add_arg(&mut self, key: impl Into<String>, value: serde_json::Value) {
50        self.args.insert(key.into(), value);
51    }
52
53    /// Get an argument value
54    pub fn get_arg(&self, key: &str) -> Option<&serde_json::Value> {
55        self.args.get(key)
56    }
57}
58
59impl_serializable!(ToolCall, ["ferriclink", "tools", "tool_call"]);
60
61/// A tool result returned by a tool
62#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
63pub struct ToolResult {
64    /// The tool call ID this result corresponds to
65    pub tool_call_id: String,
66    /// The result content
67    pub content: String,
68    /// Additional metadata
69    #[serde(default)]
70    pub metadata: HashMap<String, serde_json::Value>,
71}
72
73impl ToolResult {
74    /// Create a new tool result
75    pub fn new(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
76        Self {
77            tool_call_id: tool_call_id.into(),
78            content: content.into(),
79            metadata: HashMap::new(),
80        }
81    }
82
83    /// Create a new tool result with metadata
84    pub fn new_with_metadata(
85        tool_call_id: impl Into<String>,
86        content: impl Into<String>,
87        metadata: HashMap<String, serde_json::Value>,
88    ) -> Self {
89        Self {
90            tool_call_id: tool_call_id.into(),
91            content: content.into(),
92            metadata,
93        }
94    }
95
96    /// Add metadata to the result
97    pub fn add_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
98        self.metadata.insert(key.into(), value);
99    }
100}
101
102impl_serializable!(ToolResult, ["ferriclink", "tools", "tool_result"]);
103
104/// Schema for a tool's input parameters
105#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
106pub struct ToolSchema {
107    /// Name of the tool
108    pub name: String,
109    /// Description of the tool
110    pub description: String,
111    /// JSON schema for the input parameters
112    pub input_schema: serde_json::Value,
113}
114
115impl ToolSchema {
116    /// Create a new tool schema
117    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
118        Self {
119            name: name.into(),
120            description: description.into(),
121            input_schema: serde_json::json!({
122                "type": "object",
123                "properties": {},
124                "required": []
125            }),
126        }
127    }
128
129    /// Create a new tool schema with input schema
130    pub fn new_with_schema(
131        name: impl Into<String>,
132        description: impl Into<String>,
133        input_schema: serde_json::Value,
134    ) -> Self {
135        Self {
136            name: name.into(),
137            description: description.into(),
138            input_schema,
139        }
140    }
141}
142
143impl_serializable!(ToolSchema, ["ferriclink", "tools", "tool_schema"]);
144
145/// Base trait for all tools
146#[async_trait]
147pub trait BaseTool: Send + Sync + 'static {
148    /// Get the name of the tool
149    fn name(&self) -> &str;
150
151    /// Get the description of the tool
152    fn description(&self) -> &str;
153
154    /// Get the schema for this tool
155    fn schema(&self) -> ToolSchema;
156
157    /// Check if the tool is available
158    fn is_available(&self) -> bool {
159        true
160    }
161
162    /// Get the input schema for this tool
163    fn input_schema(&self) -> Option<serde_json::Value> {
164        Some(self.schema().input_schema.clone())
165    }
166
167    /// Get the output schema for this tool
168    fn output_schema(&self) -> Option<serde_json::Value> {
169        None
170    }
171}
172
173/// Trait for tools that can be invoked with a single input
174#[async_trait]
175pub trait Tool: BaseTool {
176    /// Invoke the tool with the given input
177    async fn invoke(
178        &self,
179        input: HashMap<String, serde_json::Value>,
180        config: Option<RunnableConfig>,
181    ) -> Result<ToolResult>;
182}
183
184/// A simple tool that wraps a function
185pub struct FunctionTool<F> {
186    name: String,
187    description: String,
188    schema: ToolSchema,
189    func: F,
190}
191
192impl<F> FunctionTool<F>
193where
194    F: Fn(HashMap<String, serde_json::Value>) -> Result<String> + Send + Sync + 'static,
195{
196    /// Create a new function tool
197    pub fn new(name: impl Into<String>, description: impl Into<String>, func: F) -> Self {
198        let name = name.into();
199        let description = description.into();
200        let schema = ToolSchema::new(&name, &description);
201
202        Self {
203            name,
204            description,
205            schema,
206            func,
207        }
208    }
209
210    /// Create a new function tool with custom schema
211    pub fn new_with_schema(
212        name: impl Into<String>,
213        description: impl Into<String>,
214        schema: ToolSchema,
215        func: F,
216    ) -> Self {
217        Self {
218            name: name.into(),
219            description: description.into(),
220            schema,
221            func,
222        }
223    }
224}
225
226#[async_trait]
227impl<F> BaseTool for FunctionTool<F>
228where
229    F: Fn(HashMap<String, serde_json::Value>) -> Result<String> + Send + Sync + 'static,
230{
231    fn name(&self) -> &str {
232        &self.name
233    }
234
235    fn description(&self) -> &str {
236        &self.description
237    }
238
239    fn schema(&self) -> ToolSchema {
240        self.schema.clone()
241    }
242}
243
244#[async_trait]
245impl<F> Tool for FunctionTool<F>
246where
247    F: Fn(HashMap<String, serde_json::Value>) -> Result<String> + Send + Sync + 'static,
248{
249    async fn invoke(
250        &self,
251        input: HashMap<String, serde_json::Value>,
252        _config: Option<RunnableConfig>,
253    ) -> Result<ToolResult> {
254        let content = (self.func)(input)?;
255        Ok(ToolResult::new("", content))
256    }
257}
258
259/// A tool that can be used as a runnable
260pub struct RunnableTool<T> {
261    tool: T,
262    tool_call_id: String,
263}
264
265impl<T> RunnableTool<T>
266where
267    T: Tool,
268{
269    /// Create a new runnable tool
270    pub fn new(tool: T, tool_call_id: impl Into<String>) -> Self {
271        Self {
272            tool,
273            tool_call_id: tool_call_id.into(),
274        }
275    }
276}
277
278#[async_trait]
279impl<T> Runnable<HashMap<String, serde_json::Value>, ToolResult> for RunnableTool<T>
280where
281    T: Tool,
282{
283    async fn invoke(
284        &self,
285        input: HashMap<String, serde_json::Value>,
286        config: Option<RunnableConfig>,
287    ) -> Result<ToolResult> {
288        let mut result = self.tool.invoke(input, config).await?;
289        result.tool_call_id = self.tool_call_id.clone();
290        Ok(result)
291    }
292}
293
294/// A collection of tools
295pub struct ToolCollection {
296    tools: HashMap<String, Box<dyn Tool>>,
297}
298
299impl ToolCollection {
300    /// Create a new empty tool collection
301    pub fn new() -> Self {
302        Self {
303            tools: HashMap::new(),
304        }
305    }
306
307    /// Add a tool to the collection
308    pub fn add_tool<T>(&mut self, tool: T)
309    where
310        T: Tool + 'static,
311    {
312        let name = tool.name().to_string();
313        self.tools.insert(name, Box::new(tool));
314    }
315
316    /// Get a tool by name
317    pub fn get_tool(&self, name: &str) -> Option<&dyn Tool> {
318        self.tools.get(name).map(|t| t.as_ref())
319    }
320
321    /// Get all tool names
322    pub fn tool_names(&self) -> Vec<&str> {
323        self.tools.keys().map(|s| s.as_str()).collect()
324    }
325
326    /// Get all tools
327    pub fn tools(&self) -> &HashMap<String, Box<dyn Tool>> {
328        &self.tools
329    }
330
331    /// Get the number of tools
332    pub fn len(&self) -> usize {
333        self.tools.len()
334    }
335
336    /// Check if the collection is empty
337    pub fn is_empty(&self) -> bool {
338        self.tools.is_empty()
339    }
340
341    /// Invoke a tool by name
342    pub async fn invoke_tool(
343        &self,
344        name: &str,
345        input: HashMap<String, serde_json::Value>,
346        config: Option<RunnableConfig>,
347    ) -> Result<ToolResult> {
348        let tool = self.get_tool(name).ok_or_else(|| {
349            crate::errors::FerricLinkError::generic(format!("Tool '{name}' not found"))
350        })?;
351
352        tool.invoke(input, config).await
353    }
354}
355
356impl Default for ToolCollection {
357    fn default() -> Self {
358        Self::new()
359    }
360}
361
362/// Helper function to create a simple function tool
363pub fn function_tool<F>(
364    name: impl Into<String>,
365    description: impl Into<String>,
366    func: F,
367) -> FunctionTool<F>
368where
369    F: Fn(HashMap<String, serde_json::Value>) -> Result<String> + Send + Sync + 'static,
370{
371    FunctionTool::new(name, description, func)
372}
373
374/// Helper function to create a tool with custom schema
375pub fn function_tool_with_schema<F>(
376    name: impl Into<String>,
377    description: impl Into<String>,
378    schema: ToolSchema,
379    func: F,
380) -> FunctionTool<F>
381where
382    F: Fn(HashMap<String, serde_json::Value>) -> Result<String> + Send + Sync + 'static,
383{
384    FunctionTool::new_with_schema(name, description, schema, func)
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use crate::serializable::Serializable;

    #[test]
    fn test_tool_call() {
        let mut call = ToolCall::new("call_123", "test_tool");
        call.add_arg("param1", serde_json::Value::String("value1".to_string()));

        assert_eq!(call.id, "call_123");
        assert_eq!(call.name, "test_tool");
        assert_eq!(
            call.get_arg("param1"),
            Some(&serde_json::Value::String("value1".to_string()))
        );
        assert_eq!(call.get_arg("missing"), None);
    }

    #[test]
    fn test_tool_result() {
        let mut result = ToolResult::new("call_123", "Tool executed successfully");
        result.add_metadata(
            "execution_time",
            serde_json::Value::Number(serde_json::Number::from(100)),
        );

        assert_eq!(result.tool_call_id, "call_123");
        assert_eq!(result.content, "Tool executed successfully");
        assert_eq!(
            result.metadata.get("execution_time"),
            Some(&serde_json::Value::Number(serde_json::Number::from(100)))
        );
    }

    #[test]
    fn test_tool_schema() {
        let schema = ToolSchema::new("test_tool", "A test tool");

        assert_eq!(schema.name, "test_tool");
        assert_eq!(schema.description, "A test tool");
        assert!(schema.input_schema.is_object());
        // The default input schema is an object with no properties or required fields.
        assert_eq!(schema.input_schema["type"], "object");
        assert_eq!(schema.input_schema["properties"], serde_json::json!({}));
        assert_eq!(schema.input_schema["required"], serde_json::json!([]));
    }

    #[tokio::test]
    async fn test_function_tool() {
        let tool = function_tool("add", "Add two numbers", |args| {
            let a = args.get("a").and_then(|v| v.as_f64()).ok_or_else(|| {
                crate::errors::FerricLinkError::validation("Missing or invalid 'a' parameter")
            })?;
            let b = args.get("b").and_then(|v| v.as_f64()).ok_or_else(|| {
                crate::errors::FerricLinkError::validation("Missing or invalid 'b' parameter")
            })?;
            Ok((a + b).to_string())
        });

        assert_eq!(tool.name(), "add");
        assert_eq!(tool.description(), "Add two numbers");
        assert_eq!(tool.schema().name, "add");
        assert_eq!(tool.input_schema(), Some(tool.schema().input_schema));
        assert!(tool.output_schema().is_none());

        let mut args = HashMap::new();
        args.insert(
            "a".to_string(),
            serde_json::Value::Number(serde_json::Number::from(5)),
        );
        args.insert(
            "b".to_string(),
            serde_json::Value::Number(serde_json::Number::from(3)),
        );

        let result = tool.invoke(args, None).await.unwrap();
        assert_eq!(result.content, "8");
        // A bare function tool does not know its call ID.
        assert!(result.tool_call_id.is_empty());
    }

    #[tokio::test]
    async fn test_function_tool_propagates_error() {
        let tool = function_tool("echo", "Echo the 'text' argument", |args| {
            let text = args.get("text").and_then(|v| v.as_str()).ok_or_else(|| {
                crate::errors::FerricLinkError::validation("Missing or invalid 'text' parameter")
            })?;
            Ok(text.to_string())
        });

        assert!(tool.invoke(HashMap::new(), None).await.is_err());

        let mut args = HashMap::new();
        args.insert("text".to_string(), serde_json::json!("hi"));
        let result = tool.invoke(args, None).await.unwrap();
        assert_eq!(result.content, "hi");
    }

    #[tokio::test]
    async fn test_tool_collection() {
        let mut collection = ToolCollection::new();

        let add_tool = function_tool("add", "Add two numbers", |args| {
            let a = args.get("a").and_then(|v| v.as_f64()).unwrap_or(0.0);
            let b = args.get("b").and_then(|v| v.as_f64()).unwrap_or(0.0);
            Ok((a + b).to_string())
        });

        let multiply_tool = function_tool("multiply", "Multiply two numbers", |args| {
            let a = args.get("a").and_then(|v| v.as_f64()).unwrap_or(1.0);
            let b = args.get("b").and_then(|v| v.as_f64()).unwrap_or(1.0);
            Ok((a * b).to_string())
        });

        collection.add_tool(add_tool);
        collection.add_tool(multiply_tool);

        assert_eq!(collection.len(), 2);
        assert!(!collection.is_empty());
        assert!(collection.tool_names().contains(&"add"));
        assert!(collection.tool_names().contains(&"multiply"));

        let mut args = HashMap::new();
        args.insert(
            "a".to_string(),
            serde_json::Value::Number(serde_json::Number::from(4)),
        );
        args.insert(
            "b".to_string(),
            serde_json::Value::Number(serde_json::Number::from(5)),
        );

        let result = collection.invoke_tool("add", args, None).await.unwrap();
        assert_eq!(result.content, "9");
    }

    #[tokio::test]
    async fn test_tool_collection_missing_tool() {
        let collection = ToolCollection::default();

        assert!(collection.is_empty());
        assert!(collection.get_tool("missing").is_none());

        let result = collection
            .invoke_tool("missing", HashMap::new(), None)
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_add_tool_replaces_existing_name() {
        let mut collection = ToolCollection::new();
        collection.add_tool(function_tool("echo", "first", |_| Ok("first".to_string())));
        collection.add_tool(function_tool("echo", "second", |_| Ok("second".to_string())));

        assert_eq!(collection.len(), 1);
        let result = collection
            .invoke_tool("echo", HashMap::new(), None)
            .await
            .unwrap();
        assert_eq!(result.content, "second");
    }

    #[tokio::test]
    async fn test_runnable_tool() {
        let tool = function_tool("test", "Test tool", |_| Ok("test result".to_string()));
        let runnable_tool = RunnableTool::new(tool, "call_123");

        let args = HashMap::new();
        let result = runnable_tool.invoke(args, None).await.unwrap();

        assert_eq!(result.tool_call_id, "call_123");
        assert_eq!(result.content, "test result");
    }

    #[test]
    fn test_serialization() {
        let call = ToolCall::new("call_123", "test_tool");
        let json = call.to_json().unwrap();
        let deserialized: ToolCall = ToolCall::from_json(&json).unwrap();
        assert_eq!(call, deserialized);
    }

    #[test]
    fn test_tool_result_serialization() {
        let mut result = ToolResult::new("call_123", "done");
        result.add_metadata("attempts", serde_json::json!(2));

        let json = result.to_json().unwrap();
        let deserialized: ToolResult = ToolResult::from_json(&json).unwrap();
        assert_eq!(result, deserialized);
    }
}