// ferriclink_core/rate_limiters.rs
//! Interface for rate limiters and an in-memory rate limiter.
//!
//! This module provides rate limiting functionality for FerricLink, similar to
//! LangChain's `rate_limiters.py`, with Rust-specific optimizations.

use std::sync::Arc;
use std::time::{Duration, Instant};

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use tokio::time::sleep;

use crate::errors::Result;
use crate::impl_serializable;
15
/// Base trait for all rate limiters.
///
/// Usage of the base limiter is through the [`BaseRateLimiter::acquire`] and
/// [`BaseRateLimiter::aacquire`] methods, depending on whether the caller is
/// running in a sync or async context.
///
/// Implementations are free to add a timeout parameter to their constructor
/// to let users bound how long a blocking call may wait for tokens.
///
/// Current limitations:
///
/// - Rate limiting information is not surfaced in tracing or callbacks. This
///   means that the total time it takes to invoke a chat model encompasses
///   both the time spent waiting for tokens and the time spent making the
///   request.
#[async_trait]
pub trait BaseRateLimiter: Send + Sync {
    /// Attempt to acquire the necessary tokens for the rate limiter
    /// (synchronous version).
    ///
    /// If `blocking` is true, the call blocks the current thread until the
    /// required tokens are available. If `blocking` is false, it returns
    /// immediately with the result of a single attempt.
    ///
    /// # Arguments
    ///
    /// * `blocking` - Whether to wait until tokens are available (`true`) or
    ///   return immediately with the result of one attempt (`false`).
    ///
    /// # Returns
    ///
    /// `Ok(true)` if the tokens were successfully acquired, `Ok(false)`
    /// otherwise.
    fn acquire(&self, blocking: bool) -> Result<bool>;

    /// Attempt to acquire the necessary tokens for the rate limiter
    /// (asynchronous version).
    ///
    /// If `blocking` is true, the future resolves only once the required
    /// tokens are available. If `blocking` is false, it resolves immediately
    /// with the result of a single attempt.
    ///
    /// # Arguments
    ///
    /// * `blocking` - Whether to wait until tokens are available (`true`) or
    ///   return immediately with the result of one attempt (`false`).
    ///
    /// # Returns
    ///
    /// `Ok(true)` if the tokens were successfully acquired, `Ok(false)`
    /// otherwise.
    async fn aacquire(&self, blocking: bool) -> Result<bool>;
}
70
/// An in-memory rate limiter based on a token bucket algorithm.
///
/// This is an in-memory rate limiter, so it cannot rate limit across
/// different processes.
///
/// The rate limiter only allows time-based rate limiting and does not
/// take into account any information about the input or the output, so it
/// cannot be used to rate limit based on the size of the request.
///
/// It is thread safe and can be used in either a sync or async context.
///
/// The in-memory rate limiter is based on a token bucket. The bucket is
/// filled with tokens at a given rate. Each request consumes a token. If
/// there are not enough tokens in the bucket, the request is blocked until
/// there are enough tokens.
///
/// These *tokens* have NOTHING to do with LLM tokens. They are just
/// a way to keep track of how many requests can be made at a given time.
///
/// Note that `Clone` is shallow with respect to the bucket: the token count
/// and refill timestamp live behind `Arc`s, so clones share one bucket and
/// rate limit collectively.
///
/// Current limitations:
///
/// - The rate limiter is not designed to work across different processes. It
///   is an in-memory rate limiter, but it is thread safe.
/// - The rate limiter only supports time-based rate limiting. It does not
///   take into account the size of the request or any other factors.
///
/// # Example
///
/// ```rust
/// use ferriclink_core::rate_limiters::InMemoryRateLimiter;
/// use std::time::Duration;
///
/// let rate_limiter = InMemoryRateLimiter::new(
///     0.1,  // Can only make a request once every 10 seconds
///     0.1,  // Wake up every 100 ms to check whether allowed to make a request
///     10.0, // Controls the maximum burst size
/// );
///
/// // Use with a language model
/// // let model = ChatAnthropic::new()
/// //     .with_rate_limiter(rate_limiter);
/// ```
#[derive(Debug, Clone)]
pub struct InMemoryRateLimiter {
    /// Refill rate: number of tokens added to the bucket per second.
    requests_per_second: f64,
    /// Tokens currently in the bucket; shared across clones via `Arc`.
    available_tokens: Arc<Mutex<f64>>,
    /// Capacity of the bucket; bounds the maximum burst size.
    max_bucket_size: f64,
    /// Timestamp of the last refill; `None` until the first `consume` call.
    last: Arc<Mutex<Option<Instant>>>,
    /// Interval, in seconds, between availability checks when blocking.
    check_every_n_seconds: f64,
}
126
/// Serializable snapshot of an [`InMemoryRateLimiter`]'s parameters, used for
/// configuration round-trips (see `to_config` / `from_config`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InMemoryRateLimiterConfig {
    /// Refill rate: number of tokens added to the bucket per second.
    pub requests_per_second: f64,
    /// Capacity of the bucket; bounds the maximum burst size.
    pub max_bucket_size: f64,
    /// Interval, in seconds, between availability checks when blocking.
    pub check_every_n_seconds: f64,
}
137
138impl InMemoryRateLimiter {
139 /// Create a new in-memory rate limiter based on a token bucket.
140 ///
141 /// These *tokens* have NOTHING to do with LLM tokens. They are just
142 /// a way to keep track of how many requests can be made at a given time.
143 ///
144 /// This rate limiter is designed to work in a threaded environment.
145 ///
146 /// It works by filling up a bucket with tokens at a given rate. Each
147 /// request consumes a given number of tokens. If there are not enough
148 /// tokens in the bucket, the request is blocked until there are enough
149 /// tokens.
150 ///
151 /// # Arguments
152 ///
153 /// * `requests_per_second` - The number of tokens to add per second to the bucket.
154 /// The tokens represent "credit" that can be used to make requests.
155 /// * `check_every_n_seconds` - Check whether the tokens are available
156 /// every this many seconds. Can be a float to represent
157 /// fractions of a second.
158 /// * `max_bucket_size` - The maximum number of tokens that can be in the bucket.
159 /// Must be at least 1. Used to prevent bursts of requests.
160 ///
161 /// # Panics
162 ///
163 /// Panics if `max_bucket_size` is less than 1.0.
164 pub fn new(requests_per_second: f64, check_every_n_seconds: f64, max_bucket_size: f64) -> Self {
165 assert!(
166 max_bucket_size >= 1.0,
167 "max_bucket_size must be at least 1.0"
168 );
169 assert!(
170 requests_per_second > 0.0,
171 "requests_per_second must be greater than 0.0"
172 );
173 assert!(
174 check_every_n_seconds > 0.0,
175 "check_every_n_seconds must be greater than 0.0"
176 );
177
178 Self {
179 requests_per_second,
180 available_tokens: Arc::new(Mutex::new(1.0)), // Start with 1 token to allow first request
181 max_bucket_size,
182 last: Arc::new(Mutex::new(None)),
183 check_every_n_seconds,
184 }
185 }
186
187 /// Try to consume a token.
188 ///
189 /// # Returns
190 ///
191 /// True means that the tokens were consumed, and the caller can proceed to
192 /// make the request. False means that the tokens were not consumed, and
193 /// the caller should try again later.
194 async fn consume(&self) -> Result<bool> {
195 let mut available_tokens = self.available_tokens.lock().await;
196 let mut last = self.last.lock().await;
197
198 let now = Instant::now();
199
200 // Initialize on first call to avoid a burst
201 if last.is_none() {
202 *last = Some(now);
203 }
204
205 let elapsed = now.duration_since(last.unwrap()).as_secs_f64();
206
207 if elapsed * self.requests_per_second >= 1.0 {
208 *available_tokens += elapsed * self.requests_per_second;
209 *last = Some(now);
210 }
211
212 // Make sure that we don't exceed the bucket size.
213 // This is used to prevent bursts of requests.
214 *available_tokens = (*available_tokens).min(self.max_bucket_size);
215
216 // As long as we have at least one token, we can proceed.
217 if *available_tokens >= 1.0 {
218 *available_tokens -= 1.0;
219 Ok(true)
220 } else {
221 Ok(false)
222 }
223 }
224
225 /// Get the current number of available tokens
226 pub async fn available_tokens(&self) -> f64 {
227 *self.available_tokens.lock().await
228 }
229
230 /// Get the maximum bucket size
231 pub fn max_bucket_size(&self) -> f64 {
232 self.max_bucket_size
233 }
234
235 /// Get the requests per second rate
236 pub fn requests_per_second(&self) -> f64 {
237 self.requests_per_second
238 }
239
240 /// Get the check interval in seconds
241 pub fn check_every_n_seconds(&self) -> f64 {
242 self.check_every_n_seconds
243 }
244
245 /// Convert to a serializable configuration
246 pub fn to_config(&self) -> InMemoryRateLimiterConfig {
247 InMemoryRateLimiterConfig {
248 requests_per_second: self.requests_per_second,
249 max_bucket_size: self.max_bucket_size,
250 check_every_n_seconds: self.check_every_n_seconds,
251 }
252 }
253
254 /// Create from a serializable configuration
255 pub fn from_config(config: InMemoryRateLimiterConfig) -> Self {
256 Self::new(
257 config.requests_per_second,
258 config.check_every_n_seconds,
259 config.max_bucket_size,
260 )
261 }
262}
263
// Registers the FerricLink serialization namespace/identifier for the config.
impl_serializable!(
    InMemoryRateLimiterConfig,
    [
        "ferriclink",
        "rate_limiters",
        "in_memory_rate_limiter_config"
    ]
);
272
273#[async_trait]
274impl BaseRateLimiter for InMemoryRateLimiter {
275 fn acquire(&self, blocking: bool) -> Result<bool> {
276 // For sync context, we need to use a blocking approach
277 if !blocking {
278 // Use tokio::runtime::Handle::try_current() to run async code in sync context
279 if let Ok(handle) = tokio::runtime::Handle::try_current() {
280 return handle.block_on(self.consume());
281 } else {
282 // If we're not in an async context, create a new runtime
283 let rt = tokio::runtime::Runtime::new().map_err(|e| {
284 crate::errors::FerricLinkError::runtime(format!(
285 "Failed to create runtime: {e}",
286 ))
287 })?;
288 return rt.block_on(self.consume());
289 }
290 }
291
292 // For blocking mode, we need to poll until we can acquire
293 loop {
294 let acquired = if let Ok(handle) = tokio::runtime::Handle::try_current() {
295 handle.block_on(self.consume())?
296 } else {
297 let rt = tokio::runtime::Runtime::new().map_err(|e| {
298 crate::errors::FerricLinkError::runtime(format!(
299 "Failed to create runtime: {e}",
300 ))
301 })?;
302 rt.block_on(self.consume())?
303 };
304
305 if acquired {
306 return Ok(true);
307 }
308
309 // Sleep for the check interval
310 std::thread::sleep(Duration::from_secs_f64(self.check_every_n_seconds));
311 }
312 }
313
314 async fn aacquire(&self, blocking: bool) -> Result<bool> {
315 if !blocking {
316 return self.consume().await;
317 }
318
319 loop {
320 if self.consume().await? {
321 return Ok(true);
322 }
323
324 sleep(Duration::from_secs_f64(self.check_every_n_seconds)).await;
325 }
326 }
327}
328
/// A more advanced rate limiter that supports different rate limiting
/// strategies.
///
/// Wraps an [`InMemoryRateLimiter`] and layers retry/backoff behavior
/// (configured via [`RateLimiterConfig`]) on top of its async acquire path.
#[derive(Debug, Clone)]
pub struct AdvancedRateLimiter {
    /// The underlying token-bucket limiter doing the actual rate limiting.
    inner: InMemoryRateLimiter,
    /// Retry/backoff and logging configuration.
    config: RateLimiterConfig,
}
337
/// Configuration for the advanced rate limiter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimiterConfig {
    /// Whether to double the backoff delay after each failed attempt.
    pub use_exponential_backoff: bool,
    /// Upper bound on the backoff delay between attempts.
    pub max_backoff_duration: Duration,
    /// Delay before the first retry.
    pub initial_backoff_duration: Duration,
    /// Maximum number of retries before giving up with a rate-limit error.
    pub max_retries: u32,
    /// Whether to print rate limiting events (to stdout, via `println!`).
    pub log_events: bool,
}
352
353impl Default for RateLimiterConfig {
354 fn default() -> Self {
355 Self {
356 use_exponential_backoff: true,
357 max_backoff_duration: Duration::from_secs(60),
358 initial_backoff_duration: Duration::from_millis(100),
359 max_retries: 5,
360 log_events: false,
361 }
362 }
363}
364
impl AdvancedRateLimiter {
    /// Create a new advanced rate limiter.
    ///
    /// # Arguments
    ///
    /// * `requests_per_second` - Token refill rate for the inner bucket.
    /// * `check_every_n_seconds` - Polling interval for the inner limiter.
    /// * `max_bucket_size` - Capacity of the inner bucket (burst bound).
    /// * `config` - Retry/backoff and logging configuration.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`InMemoryRateLimiter::new`].
    pub fn new(
        requests_per_second: f64,
        check_every_n_seconds: f64,
        max_bucket_size: f64,
        config: RateLimiterConfig,
    ) -> Self {
        Self {
            inner: InMemoryRateLimiter::new(
                requests_per_second,
                check_every_n_seconds,
                max_bucket_size,
            ),
            config,
        }
    }

    /// Acquire with retry logic and exponential backoff.
    ///
    /// NOTE(review): when `blocking` is true, `inner.aacquire(true)` loops
    /// internally until a token is available and therefore only ever returns
    /// `Ok(true)` or `Err`, making the `Ok(false)` backoff/retry arm below
    /// unreachable in that mode. When `blocking` is false, the `Ok(false)`
    /// arm returns immediately. As written, the retry/backoff configuration
    /// never actually engages — confirm the intended semantics before
    /// relying on `max_retries` or the backoff settings.
    pub async fn acquire_with_retry(&self, blocking: bool) -> Result<bool> {
        let mut backoff_duration = self.config.initial_backoff_duration;
        let mut retries = 0;

        loop {
            match self.inner.aacquire(blocking).await {
                Ok(true) => {
                    if self.config.log_events {
                        // NOTE(review): `println!` in library code; a logging
                        // facade (e.g. `log`/`tracing`) would be preferable.
                        println!("Rate limiter: Token acquired successfully");
                    }
                    return Ok(true);
                }
                Ok(false) => {
                    // Non-blocking callers get the failure straight back.
                    if !blocking {
                        return Ok(false);
                    }

                    if retries >= self.config.max_retries {
                        return Err(crate::errors::FerricLinkError::model_rate_limit(
                            "Max retries exceeded for rate limiter",
                        ));
                    }

                    if self.config.log_events {
                        println!(
                            "Rate limiter: Token not available, retrying in {:?} (attempt {})",
                            backoff_duration,
                            retries + 1
                        );
                    }

                    sleep(backoff_duration).await;

                    // Double the delay, capped at the configured maximum.
                    if self.config.use_exponential_backoff {
                        backoff_duration = backoff_duration
                            .mul_f64(2.0)
                            .min(self.config.max_backoff_duration);
                    }

                    retries += 1;
                }
                Err(e) => return Err(e),
            }
        }
    }

    /// Get the current configuration.
    pub fn config(&self) -> &RateLimiterConfig {
        &self.config
    }

    /// Replace the configuration.
    pub fn update_config(&mut self, config: RateLimiterConfig) {
        self.config = config;
    }
}
440
#[async_trait]
impl BaseRateLimiter for AdvancedRateLimiter {
    /// Synchronous acquire.
    ///
    /// NOTE(review): delegates straight to the inner limiter, bypassing the
    /// retry/backoff configuration that the async path (`aacquire`) applies —
    /// sync and async callers therefore get different behavior.
    fn acquire(&self, blocking: bool) -> Result<bool> {
        self.inner.acquire(blocking)
    }

    /// Asynchronous acquire with the configured retry/backoff behavior.
    async fn aacquire(&self, blocking: bool) -> Result<bool> {
        self.acquire_with_retry(blocking).await
    }
}
451
// Registers the FerricLink serialization namespace/identifier for the config.
impl_serializable!(
    RateLimiterConfig,
    ["ferriclink", "rate_limiters", "rate_limiter_config"]
);
456
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_in_memory_rate_limiter_basic() {
        let limiter = InMemoryRateLimiter::new(1.0, 0.1, 2.0);

        // The bucket is seeded with one token, so the first call succeeds.
        assert!(limiter.aacquire(true).await.unwrap());

        // With the bucket drained, a non-blocking attempt reports failure.
        assert!(!limiter.aacquire(false).await.unwrap());

        // After just over a second at 1 req/s, a fresh token has accrued.
        sleep(Duration::from_millis(1100)).await;
        assert!(limiter.aacquire(true).await.unwrap());
    }

    #[tokio::test]
    async fn test_in_memory_rate_limiter_burst() {
        let limiter = InMemoryRateLimiter::new(10.0, 0.01, 5.0);

        // The first call both initializes the refill clock and succeeds.
        assert!(limiter.aacquire(false).await.unwrap());

        // Give the bucket time to accumulate a few more tokens.
        sleep(Duration::from_millis(200)).await;

        // Count how many of the next attempts go through.
        let mut granted = 1; // the initial request above
        for _ in 0..4 {
            if limiter.aacquire(false).await.unwrap() {
                granted += 1;
            }
        }

        // At least one more request should have been allowed in total.
        assert!(granted >= 2);
    }

    #[tokio::test]
    async fn test_advanced_rate_limiter() {
        let config = RateLimiterConfig {
            use_exponential_backoff: true,
            max_backoff_duration: Duration::from_secs(1),
            initial_backoff_duration: Duration::from_millis(10),
            max_retries: 3,
            log_events: false,
        };
        let limiter = AdvancedRateLimiter::new(1.0, 0.01, 1.0, config);

        // The seeded token covers the first request.
        assert!(limiter.aacquire(true).await.unwrap());

        // The second call blocks until the bucket refills, then succeeds.
        assert!(limiter.aacquire(true).await.unwrap());
    }

    #[test]
    fn test_rate_limiter_creation() {
        let limiter = InMemoryRateLimiter::new(2.0, 0.1, 5.0);
        assert_eq!(limiter.requests_per_second(), 2.0);
        assert_eq!(limiter.check_every_n_seconds(), 0.1);
        assert_eq!(limiter.max_bucket_size(), 5.0);
    }

    #[test]
    #[should_panic(expected = "max_bucket_size must be at least 1.0")]
    fn test_invalid_max_bucket_size() {
        InMemoryRateLimiter::new(1.0, 0.1, 0.5);
    }

    #[test]
    #[should_panic(expected = "requests_per_second must be greater than 0.0")]
    fn test_invalid_requests_per_second() {
        InMemoryRateLimiter::new(0.0, 0.1, 1.0);
    }

    #[test]
    #[should_panic(expected = "check_every_n_seconds must be greater than 0.0")]
    fn test_invalid_check_interval() {
        InMemoryRateLimiter::new(1.0, 0.0, 1.0);
    }

    #[tokio::test]
    async fn test_available_tokens() {
        let limiter = InMemoryRateLimiter::new(1.0, 0.1, 2.0);

        // The bucket starts with exactly one seeded token.
        assert_eq!(limiter.available_tokens().await, 1.0);

        // Consuming it leaves the bucket empty.
        limiter.aacquire(false).await.unwrap();
        assert_eq!(limiter.available_tokens().await, 0.0);

        // At 1 req/s, a little over a second yields at least one new token.
        sleep(Duration::from_millis(1100)).await;

        // Acquiring triggers the refill bookkeeping and should succeed.
        let got_token = limiter.aacquire(false).await.unwrap();
        assert!(got_token, "Should have acquired a token after waiting");

        // Whatever remains after consumption must be non-negative (extra
        // accrual may leave a fractional surplus).
        let remaining = limiter.available_tokens().await;
        assert!(remaining >= 0.0);
    }

    #[tokio::test]
    async fn test_serialization() {
        let original = InMemoryRateLimiter::new(2.0, 0.1, 5.0);

        // Round-trip the configuration through JSON and rebuild a limiter.
        let json = serde_json::to_string(&original.to_config()).unwrap();
        let config: InMemoryRateLimiterConfig = serde_json::from_str(&json).unwrap();
        let restored = InMemoryRateLimiter::from_config(config);

        assert_eq!(
            original.requests_per_second(),
            restored.requests_per_second()
        );
        assert_eq!(
            original.check_every_n_seconds(),
            restored.check_every_n_seconds()
        );
        assert_eq!(original.max_bucket_size(), restored.max_bucket_size());
    }
}