1use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10use crate::errors::Result;
11use crate::impl_serializable;
12use crate::runnables::{Runnable, RunnableConfig};
13
14#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
16pub struct ToolCall {
17 pub id: String,
19 pub name: String,
21 pub args: HashMap<String, serde_json::Value>,
23}
24
25impl ToolCall {
26 pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
28 Self {
29 id: id.into(),
30 name: name.into(),
31 args: HashMap::new(),
32 }
33 }
34
35 pub fn new_with_args(
37 id: impl Into<String>,
38 name: impl Into<String>,
39 args: HashMap<String, serde_json::Value>,
40 ) -> Self {
41 Self {
42 id: id.into(),
43 name: name.into(),
44 args,
45 }
46 }
47
48 pub fn add_arg(&mut self, key: impl Into<String>, value: serde_json::Value) {
50 self.args.insert(key.into(), value);
51 }
52
53 pub fn get_arg(&self, key: &str) -> Option<&serde_json::Value> {
55 self.args.get(key)
56 }
57}
58
59impl_serializable!(ToolCall, ["ferriclink", "tools", "tool_call"]);
60
61#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
63pub struct ToolResult {
64 pub tool_call_id: String,
66 pub content: String,
68 #[serde(default)]
70 pub metadata: HashMap<String, serde_json::Value>,
71}
72
73impl ToolResult {
74 pub fn new(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
76 Self {
77 tool_call_id: tool_call_id.into(),
78 content: content.into(),
79 metadata: HashMap::new(),
80 }
81 }
82
83 pub fn new_with_metadata(
85 tool_call_id: impl Into<String>,
86 content: impl Into<String>,
87 metadata: HashMap<String, serde_json::Value>,
88 ) -> Self {
89 Self {
90 tool_call_id: tool_call_id.into(),
91 content: content.into(),
92 metadata,
93 }
94 }
95
96 pub fn add_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
98 self.metadata.insert(key.into(), value);
99 }
100}
101
102impl_serializable!(ToolResult, ["ferriclink", "tools", "tool_result"]);
103
104#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
106pub struct ToolSchema {
107 pub name: String,
109 pub description: String,
111 pub input_schema: serde_json::Value,
113}
114
115impl ToolSchema {
116 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
118 Self {
119 name: name.into(),
120 description: description.into(),
121 input_schema: serde_json::json!({
122 "type": "object",
123 "properties": {},
124 "required": []
125 }),
126 }
127 }
128
129 pub fn new_with_schema(
131 name: impl Into<String>,
132 description: impl Into<String>,
133 input_schema: serde_json::Value,
134 ) -> Self {
135 Self {
136 name: name.into(),
137 description: description.into(),
138 input_schema,
139 }
140 }
141}
142
143impl_serializable!(ToolSchema, ["ferriclink", "tools", "tool_schema"]);
144
145#[async_trait]
147pub trait BaseTool: Send + Sync + 'static {
148 fn name(&self) -> &str;
150
151 fn description(&self) -> &str;
153
154 fn schema(&self) -> ToolSchema;
156
157 fn is_available(&self) -> bool {
159 true
160 }
161
162 fn input_schema(&self) -> Option<serde_json::Value> {
164 Some(self.schema().input_schema.clone())
165 }
166
167 fn output_schema(&self) -> Option<serde_json::Value> {
169 None
170 }
171}
172
173#[async_trait]
175pub trait Tool: BaseTool {
176 async fn invoke(
178 &self,
179 input: HashMap<String, serde_json::Value>,
180 config: Option<RunnableConfig>,
181 ) -> Result<ToolResult>;
182}
183
184pub struct FunctionTool<F> {
186 name: String,
187 description: String,
188 schema: ToolSchema,
189 func: F,
190}
191
192impl<F> FunctionTool<F>
193where
194 F: Fn(HashMap<String, serde_json::Value>) -> Result<String> + Send + Sync + 'static,
195{
196 pub fn new(name: impl Into<String>, description: impl Into<String>, func: F) -> Self {
198 let name = name.into();
199 let description = description.into();
200 let schema = ToolSchema::new(&name, &description);
201
202 Self {
203 name,
204 description,
205 schema,
206 func,
207 }
208 }
209
210 pub fn new_with_schema(
212 name: impl Into<String>,
213 description: impl Into<String>,
214 schema: ToolSchema,
215 func: F,
216 ) -> Self {
217 Self {
218 name: name.into(),
219 description: description.into(),
220 schema,
221 func,
222 }
223 }
224}
225
226#[async_trait]
227impl<F> BaseTool for FunctionTool<F>
228where
229 F: Fn(HashMap<String, serde_json::Value>) -> Result<String> + Send + Sync + 'static,
230{
231 fn name(&self) -> &str {
232 &self.name
233 }
234
235 fn description(&self) -> &str {
236 &self.description
237 }
238
239 fn schema(&self) -> ToolSchema {
240 self.schema.clone()
241 }
242}
243
244#[async_trait]
245impl<F> Tool for FunctionTool<F>
246where
247 F: Fn(HashMap<String, serde_json::Value>) -> Result<String> + Send + Sync + 'static,
248{
249 async fn invoke(
250 &self,
251 input: HashMap<String, serde_json::Value>,
252 _config: Option<RunnableConfig>,
253 ) -> Result<ToolResult> {
254 let content = (self.func)(input)?;
255 Ok(ToolResult::new("", content))
256 }
257}
258
259pub struct RunnableTool<T> {
261 tool: T,
262 tool_call_id: String,
263}
264
265impl<T> RunnableTool<T>
266where
267 T: Tool,
268{
269 pub fn new(tool: T, tool_call_id: impl Into<String>) -> Self {
271 Self {
272 tool,
273 tool_call_id: tool_call_id.into(),
274 }
275 }
276}
277
278#[async_trait]
279impl<T> Runnable<HashMap<String, serde_json::Value>, ToolResult> for RunnableTool<T>
280where
281 T: Tool,
282{
283 async fn invoke(
284 &self,
285 input: HashMap<String, serde_json::Value>,
286 config: Option<RunnableConfig>,
287 ) -> Result<ToolResult> {
288 let mut result = self.tool.invoke(input, config).await?;
289 result.tool_call_id = self.tool_call_id.clone();
290 Ok(result)
291 }
292}
293
294pub struct ToolCollection {
296 tools: HashMap<String, Box<dyn Tool>>,
297}
298
299impl ToolCollection {
300 pub fn new() -> Self {
302 Self {
303 tools: HashMap::new(),
304 }
305 }
306
307 pub fn add_tool<T>(&mut self, tool: T)
309 where
310 T: Tool + 'static,
311 {
312 let name = tool.name().to_string();
313 self.tools.insert(name, Box::new(tool));
314 }
315
316 pub fn get_tool(&self, name: &str) -> Option<&dyn Tool> {
318 self.tools.get(name).map(|t| t.as_ref())
319 }
320
321 pub fn tool_names(&self) -> Vec<&str> {
323 self.tools.keys().map(|s| s.as_str()).collect()
324 }
325
326 pub fn tools(&self) -> &HashMap<String, Box<dyn Tool>> {
328 &self.tools
329 }
330
331 pub fn len(&self) -> usize {
333 self.tools.len()
334 }
335
336 pub fn is_empty(&self) -> bool {
338 self.tools.is_empty()
339 }
340
341 pub async fn invoke_tool(
343 &self,
344 name: &str,
345 input: HashMap<String, serde_json::Value>,
346 config: Option<RunnableConfig>,
347 ) -> Result<ToolResult> {
348 let tool = self.get_tool(name).ok_or_else(|| {
349 crate::errors::FerricLinkError::generic(format!("Tool '{name}' not found"))
350 })?;
351
352 tool.invoke(input, config).await
353 }
354}
355
356impl Default for ToolCollection {
357 fn default() -> Self {
358 Self::new()
359 }
360}
361
362pub fn function_tool<F>(
364 name: impl Into<String>,
365 description: impl Into<String>,
366 func: F,
367) -> FunctionTool<F>
368where
369 F: Fn(HashMap<String, serde_json::Value>) -> Result<String> + Send + Sync + 'static,
370{
371 FunctionTool::new(name, description, func)
372}
373
374pub fn function_tool_with_schema<F>(
376 name: impl Into<String>,
377 description: impl Into<String>,
378 schema: ToolSchema,
379 func: F,
380) -> FunctionTool<F>
381where
382 F: Fn(HashMap<String, serde_json::Value>) -> Result<String> + Send + Sync + 'static,
383{
384 FunctionTool::new_with_schema(name, description, schema, func)
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use crate::serializable::Serializable;

    /// Builds an argument map with numeric entries `a` and `b`.
    fn numeric_args(a: i32, b: i32) -> HashMap<String, serde_json::Value> {
        [("a", a), ("b", b)]
            .iter()
            .map(|&(key, n)| {
                (
                    key.to_string(),
                    serde_json::Value::Number(serde_json::Number::from(n)),
                )
            })
            .collect()
    }

    #[test]
    fn test_tool_call() {
        let expected = serde_json::Value::String("value1".to_string());
        let mut call = ToolCall::new("call_123", "test_tool");
        call.add_arg("param1", expected.clone());

        assert_eq!(call.id, "call_123");
        assert_eq!(call.name, "test_tool");
        assert_eq!(call.get_arg("param1"), Some(&expected));
    }

    #[test]
    fn test_tool_result() {
        let elapsed = serde_json::Value::Number(serde_json::Number::from(100));
        let mut outcome = ToolResult::new("call_123", "Tool executed successfully");
        outcome.add_metadata("execution_time", elapsed.clone());

        assert_eq!(outcome.tool_call_id, "call_123");
        assert_eq!(outcome.content, "Tool executed successfully");
        assert_eq!(outcome.metadata.get("execution_time"), Some(&elapsed));
    }

    #[test]
    fn test_tool_schema() {
        let described = ToolSchema::new("test_tool", "A test tool");

        assert_eq!(described.name, "test_tool");
        assert_eq!(described.description, "A test tool");
        assert!(described.input_schema.is_object());
    }

    #[tokio::test]
    async fn test_function_tool() {
        let adder = function_tool("add", "Add two numbers", |args| {
            let lookup = |key: &str| args.get(key).and_then(|v| v.as_f64());
            let a = lookup("a").ok_or_else(|| {
                crate::errors::FerricLinkError::validation("Missing or invalid 'a' parameter")
            })?;
            let b = lookup("b").ok_or_else(|| {
                crate::errors::FerricLinkError::validation("Missing or invalid 'b' parameter")
            })?;
            Ok((a + b).to_string())
        });

        assert_eq!(adder.name(), "add");
        assert_eq!(adder.description(), "Add two numbers");

        let sum = adder.invoke(numeric_args(5, 3), None).await.unwrap();
        assert_eq!(sum.content, "8");
    }

    #[tokio::test]
    async fn test_tool_collection() {
        let mut registry = ToolCollection::new();

        registry.add_tool(function_tool("add", "Add two numbers", |args| {
            let a = args.get("a").and_then(|v| v.as_f64()).unwrap_or(0.0);
            let b = args.get("b").and_then(|v| v.as_f64()).unwrap_or(0.0);
            Ok((a + b).to_string())
        }));
        registry.add_tool(function_tool(
            "multiply",
            "Multiply two numbers",
            |args| {
                let a = args.get("a").and_then(|v| v.as_f64()).unwrap_or(1.0);
                let b = args.get("b").and_then(|v| v.as_f64()).unwrap_or(1.0);
                Ok((a * b).to_string())
            },
        ));

        assert_eq!(registry.len(), 2);
        assert!(!registry.is_empty());
        let names = registry.tool_names();
        assert!(names.contains(&"add"));
        assert!(names.contains(&"multiply"));

        let sum = registry
            .invoke_tool("add", numeric_args(4, 5), None)
            .await
            .unwrap();
        assert_eq!(sum.content, "9");
    }

    #[tokio::test]
    async fn test_runnable_tool() {
        let inner = function_tool("test", "Test tool", |_| Ok("test result".to_string()));
        let wrapped = RunnableTool::new(inner, "call_123");

        let outcome = wrapped.invoke(HashMap::new(), None).await.unwrap();

        assert_eq!(outcome.tool_call_id, "call_123");
        assert_eq!(outcome.content, "test result");
    }

    #[test]
    fn test_serialization() {
        let original = ToolCall::new("call_123", "test_tool");
        let encoded = original.to_json().unwrap();
        let decoded: ToolCall = ToolCall::from_json(&encoded).unwrap();
        assert_eq!(original, decoded);
    }
}