ferriclink_core/
rate_limiters.rs

1//! Interface for rate limiters and an in-memory rate limiter.
2//!
3//! This module provides rate limiting functionality for FerricLink, similar to
4//! LangChain's rate_limiters.py with Rust-specific optimizations.
5
6use async_trait::async_trait;
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::Mutex;
11use tokio::time::sleep;
12
13use crate::errors::Result;
14use crate::impl_serializable;
15
16/// Base trait for all rate limiters.
17///
18/// Usage of the base limiter is through the acquire and aacquire methods depending
19/// on whether running in a sync or async context.
20///
21/// Implementations are free to add a timeout parameter to their initialize method
22/// to allow users to specify a timeout for acquiring the necessary tokens when
23/// using a blocking call.
24///
25/// Current limitations:
26///
27/// - Rate limiting information is not surfaced in tracing or callbacks. This means
28///   that the total time it takes to invoke a chat model will encompass both
29///   the time spent waiting for tokens and the time spent making the request.
30#[async_trait]
31pub trait BaseRateLimiter: Send + Sync {
32    /// Attempt to acquire the necessary tokens for the rate limiter.
33    ///
34    /// This method blocks until the required tokens are available if `blocking`
35    /// is set to true.
36    ///
37    /// If `blocking` is set to false, the method will immediately return the result
38    /// of the attempt to acquire the tokens.
39    ///
40    /// # Arguments
41    ///
42    /// * `blocking` - If true, the method will block until the tokens are available.
43    ///   If false, the method will return immediately with the result of
44    ///   the attempt. Defaults to true.
45    ///
46    /// # Returns
47    ///
48    /// True if the tokens were successfully acquired, false otherwise.
49    fn acquire(&self, blocking: bool) -> Result<bool>;
50
51    /// Attempt to acquire the necessary tokens for the rate limiter. Async version.
52    ///
53    /// This method blocks until the required tokens are available if `blocking`
54    /// is set to true.
55    ///
56    /// If `blocking` is set to false, the method will return immediately with the result
57    /// of the attempt to acquire the tokens.
58    ///
59    /// # Arguments
60    ///
61    /// * `blocking` - If true, the method will block until the tokens are available.
62    ///   If false, the method will return immediately with the result of
63    ///   the attempt. Defaults to true.
64    ///
65    /// # Returns
66    ///
67    /// True if the tokens were successfully acquired, false otherwise.
68    async fn aacquire(&self, blocking: bool) -> Result<bool>;
69}
70
71/// An in-memory rate limiter based on a token bucket algorithm.
72///
73/// This is an in-memory rate limiter, so it cannot rate limit across
74/// different processes.
75///
76/// The rate limiter only allows time-based rate limiting and does not
77/// take into account any information about the input or the output, so it
78/// cannot be used to rate limit based on the size of the request.
79///
80/// It is thread safe and can be used in either a sync or async context.
81///
82/// The in-memory rate limiter is based on a token bucket. The bucket is filled
83/// with tokens at a given rate. Each request consumes a token. If there are
84/// not enough tokens in the bucket, the request is blocked until there are
85/// enough tokens.
86///
87/// These *tokens* have NOTHING to do with LLM tokens. They are just
88/// a way to keep track of how many requests can be made at a given time.
89///
90/// Current limitations:
91///
92/// - The rate limiter is not designed to work across different processes. It is
93///   an in-memory rate limiter, but it is thread safe.
94/// - The rate limiter only supports time-based rate limiting. It does not take
95///   into account the size of the request or any other factors.
96///
97/// # Example
98///
99/// ```rust
100/// use ferriclink_core::rate_limiters::InMemoryRateLimiter;
101/// use std::time::Duration;
102///
103/// let rate_limiter = InMemoryRateLimiter::new(
104///     0.1,  // Can only make a request once every 10 seconds
105///     0.1,  // Wake up every 100 ms to check whether allowed to make a request
106///     10.0, // Controls the maximum burst size
107/// );
108///
109/// // Use with a language model
110/// // let model = ChatAnthropic::new()
111/// //     .with_rate_limiter(rate_limiter);
112/// ```
113#[derive(Debug, Clone)]
114pub struct InMemoryRateLimiter {
115    /// Number of requests that we can make per second
116    requests_per_second: f64,
117    /// Number of tokens in the bucket
118    available_tokens: Arc<Mutex<f64>>,
119    /// Maximum number of tokens that can be in the bucket
120    max_bucket_size: f64,
121    /// The last time we tried to consume tokens
122    last: Arc<Mutex<Option<Instant>>>,
123    /// Check whether tokens are available every this many seconds
124    check_every_n_seconds: f64,
125}
126
127/// Serializable version of InMemoryRateLimiter for configuration
128#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct InMemoryRateLimiterConfig {
130    /// Number of requests that we can make per second
131    pub requests_per_second: f64,
132    /// Maximum number of tokens that can be in the bucket
133    pub max_bucket_size: f64,
134    /// Check whether tokens are available every this many seconds
135    pub check_every_n_seconds: f64,
136}
137
138impl InMemoryRateLimiter {
139    /// Create a new in-memory rate limiter based on a token bucket.
140    ///
141    /// These *tokens* have NOTHING to do with LLM tokens. They are just
142    /// a way to keep track of how many requests can be made at a given time.
143    ///
144    /// This rate limiter is designed to work in a threaded environment.
145    ///
146    /// It works by filling up a bucket with tokens at a given rate. Each
147    /// request consumes a given number of tokens. If there are not enough
148    /// tokens in the bucket, the request is blocked until there are enough
149    /// tokens.
150    ///
151    /// # Arguments
152    ///
153    /// * `requests_per_second` - The number of tokens to add per second to the bucket.
154    ///   The tokens represent "credit" that can be used to make requests.
155    /// * `check_every_n_seconds` - Check whether the tokens are available
156    ///   every this many seconds. Can be a float to represent
157    ///   fractions of a second.
158    /// * `max_bucket_size` - The maximum number of tokens that can be in the bucket.
159    ///   Must be at least 1. Used to prevent bursts of requests.
160    ///
161    /// # Panics
162    ///
163    /// Panics if `max_bucket_size` is less than 1.0.
164    pub fn new(requests_per_second: f64, check_every_n_seconds: f64, max_bucket_size: f64) -> Self {
165        assert!(
166            max_bucket_size >= 1.0,
167            "max_bucket_size must be at least 1.0"
168        );
169        assert!(
170            requests_per_second > 0.0,
171            "requests_per_second must be greater than 0.0"
172        );
173        assert!(
174            check_every_n_seconds > 0.0,
175            "check_every_n_seconds must be greater than 0.0"
176        );
177
178        Self {
179            requests_per_second,
180            available_tokens: Arc::new(Mutex::new(1.0)), // Start with 1 token to allow first request
181            max_bucket_size,
182            last: Arc::new(Mutex::new(None)),
183            check_every_n_seconds,
184        }
185    }
186
187    /// Try to consume a token.
188    ///
189    /// # Returns
190    ///
191    /// True means that the tokens were consumed, and the caller can proceed to
192    /// make the request. False means that the tokens were not consumed, and
193    /// the caller should try again later.
194    async fn consume(&self) -> Result<bool> {
195        let mut available_tokens = self.available_tokens.lock().await;
196        let mut last = self.last.lock().await;
197
198        let now = Instant::now();
199
200        // Initialize on first call to avoid a burst
201        if last.is_none() {
202            *last = Some(now);
203        }
204
205        let elapsed = now.duration_since(last.unwrap()).as_secs_f64();
206
207        if elapsed * self.requests_per_second >= 1.0 {
208            *available_tokens += elapsed * self.requests_per_second;
209            *last = Some(now);
210        }
211
212        // Make sure that we don't exceed the bucket size.
213        // This is used to prevent bursts of requests.
214        *available_tokens = (*available_tokens).min(self.max_bucket_size);
215
216        // As long as we have at least one token, we can proceed.
217        if *available_tokens >= 1.0 {
218            *available_tokens -= 1.0;
219            Ok(true)
220        } else {
221            Ok(false)
222        }
223    }
224
225    /// Get the current number of available tokens
226    pub async fn available_tokens(&self) -> f64 {
227        *self.available_tokens.lock().await
228    }
229
230    /// Get the maximum bucket size
231    pub fn max_bucket_size(&self) -> f64 {
232        self.max_bucket_size
233    }
234
235    /// Get the requests per second rate
236    pub fn requests_per_second(&self) -> f64 {
237        self.requests_per_second
238    }
239
240    /// Get the check interval in seconds
241    pub fn check_every_n_seconds(&self) -> f64 {
242        self.check_every_n_seconds
243    }
244
245    /// Convert to a serializable configuration
246    pub fn to_config(&self) -> InMemoryRateLimiterConfig {
247        InMemoryRateLimiterConfig {
248            requests_per_second: self.requests_per_second,
249            max_bucket_size: self.max_bucket_size,
250            check_every_n_seconds: self.check_every_n_seconds,
251        }
252    }
253
254    /// Create from a serializable configuration
255    pub fn from_config(config: InMemoryRateLimiterConfig) -> Self {
256        Self::new(
257            config.requests_per_second,
258            config.check_every_n_seconds,
259            config.max_bucket_size,
260        )
261    }
262}
263
264impl_serializable!(
265    InMemoryRateLimiterConfig,
266    [
267        "ferriclink",
268        "rate_limiters",
269        "in_memory_rate_limiter_config"
270    ]
271);
272
273#[async_trait]
274impl BaseRateLimiter for InMemoryRateLimiter {
275    fn acquire(&self, blocking: bool) -> Result<bool> {
276        // For sync context, we need to use a blocking approach
277        if !blocking {
278            // Use tokio::runtime::Handle::try_current() to run async code in sync context
279            if let Ok(handle) = tokio::runtime::Handle::try_current() {
280                return handle.block_on(self.consume());
281            } else {
282                // If we're not in an async context, create a new runtime
283                let rt = tokio::runtime::Runtime::new().map_err(|e| {
284                    crate::errors::FerricLinkError::runtime(format!(
285                        "Failed to create runtime: {e}",
286                    ))
287                })?;
288                return rt.block_on(self.consume());
289            }
290        }
291
292        // For blocking mode, we need to poll until we can acquire
293        loop {
294            let acquired = if let Ok(handle) = tokio::runtime::Handle::try_current() {
295                handle.block_on(self.consume())?
296            } else {
297                let rt = tokio::runtime::Runtime::new().map_err(|e| {
298                    crate::errors::FerricLinkError::runtime(format!(
299                        "Failed to create runtime: {e}",
300                    ))
301                })?;
302                rt.block_on(self.consume())?
303            };
304
305            if acquired {
306                return Ok(true);
307            }
308
309            // Sleep for the check interval
310            std::thread::sleep(Duration::from_secs_f64(self.check_every_n_seconds));
311        }
312    }
313
314    async fn aacquire(&self, blocking: bool) -> Result<bool> {
315        if !blocking {
316            return self.consume().await;
317        }
318
319        loop {
320            if self.consume().await? {
321                return Ok(true);
322            }
323
324            sleep(Duration::from_secs_f64(self.check_every_n_seconds)).await;
325        }
326    }
327}
328
329/// A more advanced rate limiter that supports different rate limiting strategies.
330#[derive(Debug, Clone)]
331pub struct AdvancedRateLimiter {
332    /// The underlying rate limiter
333    inner: InMemoryRateLimiter,
334    /// Additional configuration
335    config: RateLimiterConfig,
336}
337
338/// Configuration for the advanced rate limiter
339#[derive(Debug, Clone, Serialize, Deserialize)]
340pub struct RateLimiterConfig {
341    /// Whether to use exponential backoff on rate limit errors
342    pub use_exponential_backoff: bool,
343    /// Maximum backoff duration
344    pub max_backoff_duration: Duration,
345    /// Initial backoff duration
346    pub initial_backoff_duration: Duration,
347    /// Maximum number of retries
348    pub max_retries: u32,
349    /// Whether to log rate limiting events
350    pub log_events: bool,
351}
352
353impl Default for RateLimiterConfig {
354    fn default() -> Self {
355        Self {
356            use_exponential_backoff: true,
357            max_backoff_duration: Duration::from_secs(60),
358            initial_backoff_duration: Duration::from_millis(100),
359            max_retries: 5,
360            log_events: false,
361        }
362    }
363}
364
365impl AdvancedRateLimiter {
366    /// Create a new advanced rate limiter
367    pub fn new(
368        requests_per_second: f64,
369        check_every_n_seconds: f64,
370        max_bucket_size: f64,
371        config: RateLimiterConfig,
372    ) -> Self {
373        Self {
374            inner: InMemoryRateLimiter::new(
375                requests_per_second,
376                check_every_n_seconds,
377                max_bucket_size,
378            ),
379            config,
380        }
381    }
382
383    /// Acquire with retry logic and exponential backoff
384    pub async fn acquire_with_retry(&self, blocking: bool) -> Result<bool> {
385        let mut backoff_duration = self.config.initial_backoff_duration;
386        let mut retries = 0;
387
388        loop {
389            match self.inner.aacquire(blocking).await {
390                Ok(true) => {
391                    if self.config.log_events {
392                        println!("Rate limiter: Token acquired successfully");
393                    }
394                    return Ok(true);
395                }
396                Ok(false) => {
397                    if !blocking {
398                        return Ok(false);
399                    }
400
401                    if retries >= self.config.max_retries {
402                        return Err(crate::errors::FerricLinkError::model_rate_limit(
403                            "Max retries exceeded for rate limiter",
404                        ));
405                    }
406
407                    if self.config.log_events {
408                        println!(
409                            "Rate limiter: Token not available, retrying in {:?} (attempt {})",
410                            backoff_duration,
411                            retries + 1
412                        );
413                    }
414
415                    sleep(backoff_duration).await;
416
417                    if self.config.use_exponential_backoff {
418                        backoff_duration = backoff_duration
419                            .mul_f64(2.0)
420                            .min(self.config.max_backoff_duration);
421                    }
422
423                    retries += 1;
424                }
425                Err(e) => return Err(e),
426            }
427        }
428    }
429
430    /// Get the current configuration
431    pub fn config(&self) -> &RateLimiterConfig {
432        &self.config
433    }
434
435    /// Update the configuration
436    pub fn update_config(&mut self, config: RateLimiterConfig) {
437        self.config = config;
438    }
439}
440
441#[async_trait]
442impl BaseRateLimiter for AdvancedRateLimiter {
443    fn acquire(&self, blocking: bool) -> Result<bool> {
444        self.inner.acquire(blocking)
445    }
446
447    async fn aacquire(&self, blocking: bool) -> Result<bool> {
448        self.acquire_with_retry(blocking).await
449    }
450}
451
452impl_serializable!(
453    RateLimiterConfig,
454    ["ferriclink", "rate_limiters", "rate_limiter_config"]
455);
456
#[cfg(test)]
mod tests {
    use super::*;

    /// The initial token is granted, the bucket is then empty, and a new token
    /// becomes available after waiting longer than one refill period.
    #[tokio::test]
    async fn test_in_memory_rate_limiter_basic() {
        let limiter = InMemoryRateLimiter::new(1.0, 0.1, 2.0);

        // First request should succeed
        let first = limiter.aacquire(true).await.unwrap();
        assert!(first);

        // Second request should fail immediately (no blocking)
        let second = limiter.aacquire(false).await.unwrap();
        assert!(!second);

        // Wait a bit and try again
        sleep(Duration::from_millis(1100)).await;
        let third = limiter.aacquire(true).await.unwrap();
        assert!(third);
    }

    /// After idling, accumulated tokens allow more than one request in a row.
    #[tokio::test]
    async fn test_in_memory_rate_limiter_burst() {
        let limiter = InMemoryRateLimiter::new(10.0, 0.01, 5.0);

        // First call initializes the timer and should succeed
        assert!(limiter.aacquire(false).await.unwrap());

        // Wait enough time to accumulate more tokens
        sleep(Duration::from_millis(200)).await;

        // Count the request above plus every follow-up attempt that is granted
        let mut granted = 1;
        for _attempt in 0..4 {
            let ok = limiter.aacquire(false).await.unwrap();
            if ok {
                granted += 1;
            }
        }

        // Should have made at least 2 successful requests total
        assert!(granted >= 2);
    }

    /// Two consecutive blocking acquisitions through the advanced limiter succeed.
    #[tokio::test]
    async fn test_advanced_rate_limiter() {
        let settings = RateLimiterConfig {
            use_exponential_backoff: true,
            max_backoff_duration: Duration::from_secs(1),
            initial_backoff_duration: Duration::from_millis(10),
            max_retries: 3,
            log_events: false,
        };
        let limiter = AdvancedRateLimiter::new(1.0, 0.01, 1.0, settings);

        // First request should succeed
        let first = limiter.aacquire(true).await.unwrap();
        assert!(first);

        // Second request should succeed once a token is available again
        let second = limiter.aacquire(true).await.unwrap();
        assert!(second);
    }

    /// Constructor arguments are reported back unchanged by the getters.
    #[test]
    fn test_rate_limiter_creation() {
        let limiter = InMemoryRateLimiter::new(2.0, 0.1, 5.0);

        let rate = limiter.requests_per_second();
        let interval = limiter.check_every_n_seconds();
        let capacity = limiter.max_bucket_size();

        assert_eq!(rate, 2.0);
        assert_eq!(interval, 0.1);
        assert_eq!(capacity, 5.0);
    }

    /// A bucket smaller than one token is rejected.
    #[test]
    #[should_panic(expected = "max_bucket_size must be at least 1.0")]
    fn test_invalid_max_bucket_size() {
        let _ = InMemoryRateLimiter::new(1.0, 0.1, 0.5);
    }

    /// A zero refill rate is rejected.
    #[test]
    #[should_panic(expected = "requests_per_second must be greater than 0.0")]
    fn test_invalid_requests_per_second() {
        let _ = InMemoryRateLimiter::new(0.0, 0.1, 1.0);
    }

    /// A zero polling interval is rejected.
    #[test]
    #[should_panic(expected = "check_every_n_seconds must be greater than 0.0")]
    fn test_invalid_check_interval() {
        let _ = InMemoryRateLimiter::new(1.0, 0.0, 1.0);
    }

    /// The stored balance starts at one, drops to zero after one acquisition,
    /// and never goes negative after a refill followed by another acquisition.
    #[tokio::test]
    async fn test_available_tokens() {
        let limiter = InMemoryRateLimiter::new(1.0, 0.1, 2.0);

        // Initially should have 1 token (to allow first request)
        let initial = limiter.available_tokens().await;
        assert_eq!(initial, 1.0);

        // First call should succeed and consume the token
        limiter.aacquire(false).await.unwrap();
        let after_first = limiter.available_tokens().await;
        assert_eq!(after_first, 0.0);

        // Wait enough time to accumulate more tokens (1 second for 1 token at 1 req/sec)
        sleep(Duration::from_millis(1100)).await;

        // Refills only happen on an acquisition attempt, so make one
        let acquired = limiter.aacquire(false).await.unwrap();
        assert!(acquired, "Should have acquired a token after waiting");

        // The remaining balance need not be exactly zero, but must not be negative
        let remaining = limiter.available_tokens().await;
        assert!(remaining >= 0.0);
    }

    /// A limiter survives a config -> JSON -> config -> limiter round trip.
    #[tokio::test]
    async fn test_serialization() {
        let original = InMemoryRateLimiter::new(2.0, 0.1, 5.0);

        let json = serde_json::to_string(&original.to_config()).unwrap();
        let parsed: InMemoryRateLimiterConfig = serde_json::from_str(&json).unwrap();
        let restored = InMemoryRateLimiter::from_config(parsed);

        assert_eq!(
            original.requests_per_second(),
            restored.requests_per_second()
        );
        assert_eq!(
            original.check_every_n_seconds(),
            restored.check_every_n_seconds()
        );
        assert_eq!(original.max_bucket_size(), restored.max_bucket_size());
    }
}