1use serde::{Deserialize, Serialize};
10
11pub trait Visitor {
13 type Output;
15
16 fn allowed_comparators(&self) -> Option<&[Comparator]> {
18 None
19 }
20
21 fn allowed_operators(&self) -> Option<&[Operator]> {
23 None
24 }
25
26 fn validate_func(&self, func: &dyn Function) -> Result<(), String> {
28 if let Some(allowed_operators) = self.allowed_operators() {
29 if let Some(op) = func.as_operator() {
30 if !allowed_operators.contains(&op) {
31 return Err(format!(
32 "Received disallowed operator {op:?}. Allowed operators are {allowed_operators:?}"
33 ));
34 }
35 }
36 }
37
38 if let Some(allowed_comparators) = self.allowed_comparators() {
39 if let Some(comp) = func.as_comparator() {
40 if !allowed_comparators.contains(&comp) {
41 return Err(format!(
42 "Received disallowed comparator {comp:?}. Allowed comparators are {allowed_comparators:?}"
43 ));
44 }
45 }
46 }
47
48 Ok(())
49 }
50
51 fn visit_operation(&self, operation: &Operation) -> Result<Self::Output, String>;
53
54 fn visit_comparison(&self, comparison: &Comparison) -> Result<Self::Output, String>;
56
57 fn visit_structured_query(
59 &self,
60 structured_query: &StructuredQuery,
61 ) -> Result<Self::Output, String>;
62}
63
64pub trait Function {
66 fn as_operator(&self) -> Option<Operator>;
68 fn as_comparator(&self) -> Option<Comparator>;
70}
71
72impl Function for Operator {
73 fn as_operator(&self) -> Option<Operator> {
74 Some(*self)
75 }
76
77 fn as_comparator(&self) -> Option<Comparator> {
78 None
79 }
80}
81
82impl Function for Comparator {
83 fn as_operator(&self) -> Option<Operator> {
84 None
85 }
86
87 fn as_comparator(&self) -> Option<Comparator> {
88 Some(*self)
89 }
90}
91
92pub trait Expr {
94 fn accept<V: Visitor>(&self, visitor: &V) -> Result<V::Output, String>;
96}
97
/// Logical operator applied to the child directives of an [`Operation`].
///
/// `Display` renders the lowercase keyword (`and`, `or`, `not`). serde uses the
/// variant name as written here (e.g. `"And"`), because no rename attribute is set.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Operator {
    /// `and` — conjunction of the arguments.
    And,
    /// `or` — disjunction of the arguments.
    Or,
    /// `not` — negation; `builders::not` constructs it with exactly one argument,
    /// but the type itself does not enforce that arity.
    Not,
}
108
109impl std::fmt::Display for Operator {
110 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
111 match self {
112 Operator::And => write!(f, "and"),
113 Operator::Or => write!(f, "or"),
114 Operator::Not => write!(f, "not"),
115 }
116 }
117}
118
/// Comparison applied between an attribute and a value in a [`Comparison`].
///
/// `Display` renders the lowercase keyword shown on each variant. serde uses the
/// variant name itself (e.g. `"Gte"`), because no rename attribute is set.
/// NOTE(review): the precise semantics (type coercion, pattern syntax for `Like`,
/// what shape of `value` `In`/`Nin` expect) are defined by each `Visitor`
/// implementation, not here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Comparator {
    /// `eq` — equal to.
    Eq,
    /// `ne` — not equal to.
    Ne,
    /// `gt` — greater than.
    Gt,
    /// `gte` — greater than or equal to.
    Gte,
    /// `lt` — less than.
    Lt,
    /// `lte` — less than or equal to.
    Lte,
    /// `contain` — presumably "attribute contains value"; confirm per visitor.
    Contain,
    /// `like` — presumably a pattern match; confirm per visitor.
    Like,
    /// `in` — presumably membership in a list-valued `value`; confirm per visitor.
    In,
    /// `nin` — presumably the negation of `in`; confirm per visitor.
    Nin,
}
143
144impl std::fmt::Display for Comparator {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 match self {
147 Comparator::Eq => write!(f, "eq"),
148 Comparator::Ne => write!(f, "ne"),
149 Comparator::Gt => write!(f, "gt"),
150 Comparator::Gte => write!(f, "gte"),
151 Comparator::Lt => write!(f, "lt"),
152 Comparator::Lte => write!(f, "lte"),
153 Comparator::Contain => write!(f, "contain"),
154 Comparator::Like => write!(f, "like"),
155 Comparator::In => write!(f, "in"),
156 Comparator::Nin => write!(f, "nin"),
157 }
158 }
159}
160
/// One node of a filter tree: either a leaf [`Comparison`] or an [`Operation`]
/// combining further directives.
///
/// No serde representation attribute is set, so serde's default externally
/// tagged form is used (e.g. `{"Comparison": {...}}`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum FilterDirective {
    /// A single attribute/value comparison.
    Comparison(Comparison),
    /// An operator applied to child directives.
    Operation(Operation),
}
169
170impl FilterDirective {
171 pub fn comparison(comparator: Comparator, attribute: String, value: serde_json::Value) -> Self {
173 Self::Comparison(Comparison::new(comparator, attribute, value))
174 }
175
176 pub fn operation(operator: Operator, arguments: Vec<FilterDirective>) -> Self {
178 Self::Operation(Operation::new(operator, arguments))
179 }
180}
181
182impl Expr for FilterDirective {
183 fn accept<V: Visitor>(&self, visitor: &V) -> Result<V::Output, String> {
184 match self {
185 FilterDirective::Comparison(comp) => comp.accept(visitor),
186 FilterDirective::Operation(op) => op.accept(visitor),
187 }
188 }
189}
190
/// A leaf filter: compares one attribute against a JSON value.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Comparison {
    /// How `attribute` is compared with `value`.
    pub comparator: Comparator,
    /// Name of the attribute being compared (the tests use e.g. `"name"`, `"age"`).
    pub attribute: String,
    /// The operand: any JSON value. How it is interpreted is up to the visitor.
    pub value: serde_json::Value,
}
201
202impl Comparison {
203 pub fn new(comparator: Comparator, attribute: String, value: serde_json::Value) -> Self {
211 Self {
212 comparator,
213 attribute,
214 value,
215 }
216 }
217}
218
219impl Expr for Comparison {
220 fn accept<V: Visitor>(&self, visitor: &V) -> Result<V::Output, String> {
221 visitor.validate_func(&self.comparator)?;
222 visitor.visit_comparison(self)
223 }
224}
225
/// An interior filter node applying an [`Operator`] to child directives.
///
/// Arity is not checked here: nothing prevents, for example, a `Not` with
/// several arguments. `builders::not` is the constructor that produces exactly one.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Operation {
    /// The operator combining `arguments`.
    pub operator: Operator,
    /// Child directives, in the order they were supplied.
    pub arguments: Vec<FilterDirective>,
}
234
235impl Operation {
236 pub fn new(operator: Operator, arguments: Vec<FilterDirective>) -> Self {
243 Self {
244 operator,
245 arguments,
246 }
247 }
248}
249
250impl Expr for Operation {
251 fn accept<V: Visitor>(&self, visitor: &V) -> Result<V::Output, String> {
252 visitor.validate_func(&self.operator)?;
253 visitor.visit_operation(self)
254 }
255}
256
/// A complete query: query text plus an optional filter tree and optional limit.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct StructuredQuery {
    /// The query string (presumably free text for the search backend — confirm).
    pub query: String,
    /// Optional filter tree. `None` when the query carries no filter.
    pub filter: Option<FilterDirective>,
    /// Optional limit. Presumably a maximum result count; interpretation is up
    /// to the visitor — confirm.
    pub limit: Option<u32>,
}
267
268impl StructuredQuery {
269 pub fn new(query: String, filter: Option<FilterDirective>, limit: Option<u32>) -> Self {
277 Self {
278 query,
279 filter,
280 limit,
281 }
282 }
283
284 pub fn simple(query: String) -> Self {
290 Self {
291 query,
292 filter: None,
293 limit: None,
294 }
295 }
296}
297
298impl Expr for StructuredQuery {
299 fn accept<V: Visitor>(&self, visitor: &V) -> Result<V::Output, String> {
300 visitor.visit_structured_query(self)
301 }
302}
303
304pub mod builders {
306 use super::*;
307
308 pub fn eq(attribute: &str, value: serde_json::Value) -> FilterDirective {
315 FilterDirective::comparison(Comparator::Eq, attribute.to_string(), value)
316 }
317
318 pub fn ne(attribute: &str, value: serde_json::Value) -> FilterDirective {
325 FilterDirective::comparison(Comparator::Ne, attribute.to_string(), value)
326 }
327
328 pub fn gt(attribute: &str, value: serde_json::Value) -> FilterDirective {
335 FilterDirective::comparison(Comparator::Gt, attribute.to_string(), value)
336 }
337
338 pub fn gte(attribute: &str, value: serde_json::Value) -> FilterDirective {
345 FilterDirective::comparison(Comparator::Gte, attribute.to_string(), value)
346 }
347
348 pub fn lt(attribute: &str, value: serde_json::Value) -> FilterDirective {
355 FilterDirective::comparison(Comparator::Lt, attribute.to_string(), value)
356 }
357
358 pub fn lte(attribute: &str, value: serde_json::Value) -> FilterDirective {
365 FilterDirective::comparison(Comparator::Lte, attribute.to_string(), value)
366 }
367
368 pub fn contain(attribute: &str, value: serde_json::Value) -> FilterDirective {
375 FilterDirective::comparison(Comparator::Contain, attribute.to_string(), value)
376 }
377
378 pub fn like(attribute: &str, value: serde_json::Value) -> FilterDirective {
385 FilterDirective::comparison(Comparator::Like, attribute.to_string(), value)
386 }
387
388 pub fn r#in(attribute: &str, value: serde_json::Value) -> FilterDirective {
395 FilterDirective::comparison(Comparator::In, attribute.to_string(), value)
396 }
397
398 pub fn nin(attribute: &str, value: serde_json::Value) -> FilterDirective {
405 FilterDirective::comparison(Comparator::Nin, attribute.to_string(), value)
406 }
407
408 pub fn and(arguments: Vec<FilterDirective>) -> FilterDirective {
414 FilterDirective::operation(Operator::And, arguments)
415 }
416
417 pub fn or(arguments: Vec<FilterDirective>) -> FilterDirective {
423 FilterDirective::operation(Operator::Or, arguments)
424 }
425
426 pub fn not(argument: FilterDirective) -> FilterDirective {
432 FilterDirective::operation(Operator::Not, vec![argument])
433 }
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Test visitor that renders nodes as short debug strings and can be
    /// configured with optional comparator/operator allow-lists.
    struct MockVisitor {
        allowed_comparators: Option<Vec<Comparator>>,
        allowed_operators: Option<Vec<Operator>>,
    }

    impl MockVisitor {
        /// A visitor that accepts every operator and comparator.
        fn new() -> Self {
            Self {
                allowed_comparators: None,
                allowed_operators: None,
            }
        }

        /// A visitor restricted to `comparators`; operators are unrestricted.
        fn with_allowed_comparators(comparators: Vec<Comparator>) -> Self {
            Self {
                allowed_comparators: Some(comparators),
                allowed_operators: None,
            }
        }

        /// A visitor restricted to `operators`; comparators are unrestricted.
        fn with_allowed_operators(operators: Vec<Operator>) -> Self {
            Self {
                allowed_comparators: None,
                allowed_operators: Some(operators),
            }
        }
    }

    impl Visitor for MockVisitor {
        type Output = String;

        fn allowed_comparators(&self) -> Option<&[Comparator]> {
            self.allowed_comparators.as_deref()
        }

        fn allowed_operators(&self) -> Option<&[Operator]> {
            self.allowed_operators.as_deref()
        }

        fn visit_operation(&self, operation: &Operation) -> Result<Self::Output, String> {
            Ok(format!(
                "Operation({:?}, {} args)",
                operation.operator,
                operation.arguments.len()
            ))
        }

        fn visit_comparison(&self, comparison: &Comparison) -> Result<Self::Output, String> {
            Ok(format!(
                "Comparison({:?}, {}, {})",
                comparison.comparator, comparison.attribute, comparison.value
            ))
        }

        fn visit_structured_query(
            &self,
            structured_query: &StructuredQuery,
        ) -> Result<Self::Output, String> {
            let filter_str = if structured_query.filter.is_some() {
                "with filter"
            } else {
                "no filter"
            };
            // `limit_str` carries its own leading ", ". The format string must not
            // add another, or the output gets `, , limit` / a dangling `, )`.
            let limit_str = structured_query
                .limit
                .map(|limit| format!(", limit: {limit}"))
                .unwrap_or_default();
            Ok(format!(
                "Query(\"{}\", {}{})",
                structured_query.query, filter_str, limit_str
            ))
        }
    }

    #[test]
    fn test_comparison_creation() {
        let comp = Comparison::new(Comparator::Eq, "name".to_string(), json!("John"));
        assert_eq!(comp.comparator, Comparator::Eq);
        assert_eq!(comp.attribute, "name");
        assert_eq!(comp.value, json!("John"));
    }

    #[test]
    fn test_filter_directive_creation() {
        let comp = FilterDirective::comparison(Comparator::Eq, "name".to_string(), json!("John"));
        match comp {
            FilterDirective::Comparison(c) => {
                assert_eq!(c.comparator, Comparator::Eq);
                assert_eq!(c.attribute, "name");
            }
            _ => panic!("Expected Comparison variant"),
        }
    }

    #[test]
    fn test_operation_creation() {
        let comp1 = FilterDirective::comparison(Comparator::Eq, "name".to_string(), json!("John"));
        let comp2 = FilterDirective::comparison(Comparator::Gt, "age".to_string(), json!(18));
        let op = FilterDirective::operation(Operator::And, vec![comp1, comp2]);
        match op {
            FilterDirective::Operation(o) => {
                assert_eq!(o.operator, Operator::And);
                assert_eq!(o.arguments.len(), 2);
            }
            _ => panic!("Expected Operation variant"),
        }
    }

    #[test]
    fn test_structured_query_creation() {
        let query = StructuredQuery::simple("test query".to_string());
        assert_eq!(query.query, "test query");
        assert!(query.filter.is_none());
        assert!(query.limit.is_none());
    }

    #[test]
    fn test_structured_query_visit_formatting() {
        let visitor = MockVisitor::new();

        let bare = StructuredQuery::simple("q".to_string());
        assert_eq!(bare.accept(&visitor).unwrap(), "Query(\"q\", no filter)");

        let full = StructuredQuery::new(
            "q".to_string(),
            Some(builders::eq("a", json!(1))),
            Some(5),
        );
        assert_eq!(
            full.accept(&visitor).unwrap(),
            "Query(\"q\", with filter, limit: 5)"
        );
    }

    #[test]
    fn test_visitor_acceptance() {
        let visitor = MockVisitor::new();
        let comp = FilterDirective::comparison(Comparator::Eq, "name".to_string(), json!("John"));
        let result = comp.accept(&visitor).unwrap();
        assert!(result.contains("Comparison"));
        assert!(result.contains("Eq"));
        assert!(result.contains("name"));
    }

    #[test]
    fn test_visitor_validation_success() {
        let visitor = MockVisitor::with_allowed_comparators(vec![Comparator::Eq, Comparator::Gt]);
        let comp = FilterDirective::comparison(Comparator::Eq, "name".to_string(), json!("John"));
        let result = comp.accept(&visitor);
        assert!(result.is_ok());
    }

    #[test]
    fn test_visitor_validation_failure() {
        let visitor = MockVisitor::with_allowed_comparators(vec![Comparator::Eq, Comparator::Gt]);
        let comp = FilterDirective::comparison(Comparator::Lt, "age".to_string(), json!(18));
        let result = comp.accept(&visitor);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("disallowed comparator"));
    }

    #[test]
    fn test_operator_validation_success() {
        let visitor = MockVisitor::with_allowed_operators(vec![Operator::And]);
        let op = builders::and(vec![builders::eq("a", json!(1))]);
        assert_eq!(op.accept(&visitor).unwrap(), "Operation(And, 1 args)");
    }

    #[test]
    fn test_operator_validation_failure() {
        let visitor = MockVisitor::with_allowed_operators(vec![Operator::And]);
        let op = builders::or(vec![builders::eq("a", json!(1)), builders::eq("b", json!(2))]);
        let err = op.accept(&visitor).unwrap_err();
        assert!(err.contains("disallowed operator"));
        assert!(err.contains("Or"));
    }

    #[test]
    fn test_comparator_allow_list_does_not_restrict_operators() {
        let visitor = MockVisitor::with_allowed_comparators(vec![Comparator::Eq]);
        let op = builders::not(builders::eq("a", json!(1)));
        assert_eq!(op.accept(&visitor).unwrap(), "Operation(Not, 1 args)");
    }

    #[test]
    fn test_builders() {
        let eq_comp = builders::eq("name", json!("John"));
        match &eq_comp {
            FilterDirective::Comparison(c) => {
                assert_eq!(c.comparator, Comparator::Eq);
                assert_eq!(c.attribute, "name");
            }
            _ => panic!("Expected Comparison variant"),
        }

        let gt_comp = builders::gt("age", json!(18));
        match &gt_comp {
            FilterDirective::Comparison(c) => {
                assert_eq!(c.comparator, Comparator::Gt);
                assert_eq!(c.attribute, "age");
            }
            _ => panic!("Expected Comparison variant"),
        }

        let and_op = builders::and(vec![eq_comp, gt_comp]);
        match and_op {
            FilterDirective::Operation(o) => {
                assert_eq!(o.operator, Operator::And);
                assert_eq!(o.arguments.len(), 2);
            }
            _ => panic!("Expected Operation variant"),
        }
    }

    #[test]
    fn test_not_builder_wraps_single_argument() {
        match builders::not(builders::lt("age", json!(18))) {
            FilterDirective::Operation(o) => {
                assert_eq!(o.operator, Operator::Not);
                assert_eq!(o.arguments.len(), 1);
            }
            _ => panic!("Expected Operation variant"),
        }
    }

    #[test]
    fn test_serialization() {
        let comp = FilterDirective::comparison(Comparator::Eq, "name".to_string(), json!("John"));
        let serialized = serde_json::to_string(&comp).unwrap();
        let deserialized: FilterDirective = serde_json::from_str(&serialized).unwrap();
        assert_eq!(comp, deserialized);
    }

    #[test]
    fn test_nested_structured_query_serialization_roundtrip() {
        let filter = builders::or(vec![
            builders::and(vec![
                builders::gte("age", json!(18)),
                builders::r#in("tag", json!(["a", "b"])),
            ]),
            builders::not(builders::like("name", json!("J%"))),
        ]);
        let query = StructuredQuery::new("people".to_string(), Some(filter), Some(10));
        let serialized = serde_json::to_string(&query).unwrap();
        let deserialized: StructuredQuery = serde_json::from_str(&serialized).unwrap();
        assert_eq!(query, deserialized);
    }

    #[test]
    fn test_display_traits() {
        assert_eq!(Operator::And.to_string(), "and");
        assert_eq!(Operator::Or.to_string(), "or");
        assert_eq!(Operator::Not.to_string(), "not");

        assert_eq!(Comparator::Eq.to_string(), "eq");
        assert_eq!(Comparator::Ne.to_string(), "ne");
        assert_eq!(Comparator::Gt.to_string(), "gt");
        assert_eq!(Comparator::Gte.to_string(), "gte");
        assert_eq!(Comparator::Lt.to_string(), "lt");
        assert_eq!(Comparator::Lte.to_string(), "lte");
        assert_eq!(Comparator::Contain.to_string(), "contain");
        assert_eq!(Comparator::Like.to_string(), "like");
        assert_eq!(Comparator::In.to_string(), "in");
        assert_eq!(Comparator::Nin.to_string(), "nin");
    }
}