1use ahash::AHashMap;
2use itertools::chain;
3use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
4use sqruff_lib_core::lint_fix::LintFix;
5use sqruff_lib_core::parser::segments::base::{ErasedSegment, SegmentBuilder, Tables};
6use sqruff_lib_core::utils::functional::segments::Segments;
7use strum_macros::{AsRefStr, EnumString};
8
9use crate::core::config::Value;
10use crate::core::rules::base::{Erased, ErasedRule, LintResult, Rule, RuleGroups};
11use crate::core::rules::context::RuleContext;
12use crate::core::rules::crawlers::{Crawler, SegmentSeekerCrawler};
13use crate::utils::functional::context::FunctionalContext;
14
15#[derive(Debug, Copy, Clone, AsRefStr, EnumString, PartialEq, Default)]
16#[strum(serialize_all = "snake_case")]
17enum TypeCastingStyle {
18 #[default]
19 Consistent,
20 Cast,
21 Convert,
22 Shorthand,
23 None,
24}
25
/// Marker stored in the rule context memory once a `CONVERT` call with more than two
/// arguments has been skipped (it cannot be rewritten as a two-argument cast).
#[derive(Clone, Copy)]
struct PreviousSkipped;
28
29fn get_children(segments: Segments) -> Segments {
30 segments.children(Some(|it: &ErasedSegment| {
31 !it.is_meta()
32 && !matches!(
33 it.get_type(),
34 SyntaxKind::StartBracket
35 | SyntaxKind::EndBracket
36 | SyntaxKind::Whitespace
37 | SyntaxKind::Newline
38 | SyntaxKind::CastingOperator
39 | SyntaxKind::Comma
40 | SyntaxKind::Keyword
41 )
42 }))
43}
44
45fn shorthand_fix_list(
46 tables: &Tables,
47 root_segment: ErasedSegment,
48 shorthand_arg_1: ErasedSegment,
49 shorthand_arg_2: ErasedSegment,
50) -> Vec<LintFix> {
51 let mut edits = if shorthand_arg_1.get_raw_segments().len() > 1 {
52 vec![
53 SegmentBuilder::token(tables.next_id(), "(", SyntaxKind::StartBracket).finish(),
54 shorthand_arg_1,
55 SegmentBuilder::token(tables.next_id(), ")", SyntaxKind::EndBracket).finish(),
56 ]
57 } else {
58 vec![shorthand_arg_1]
59 };
60
61 edits.extend([
62 SegmentBuilder::token(tables.next_id(), "::", SyntaxKind::CastingOperator).finish(),
63 shorthand_arg_2,
64 ]);
65
66 vec![LintFix::replace(root_segment, edits, None)]
67}
68
69#[derive(Clone, Debug, Default)]
70pub struct RuleCV11 {
71 preferred_type_casting_style: TypeCastingStyle,
72}
73
74impl Rule for RuleCV11 {
75 fn load_from_config(&self, config: &AHashMap<String, Value>) -> Result<ErasedRule, String> {
76 Ok(RuleCV11 {
77 preferred_type_casting_style: config["preferred_type_casting_style"]
78 .as_string()
79 .unwrap()
80 .parse()
81 .unwrap(),
82 }
83 .erased())
84 }
85
86 fn name(&self) -> &'static str {
87 "convention.casting_style"
88 }
89
90 fn description(&self) -> &'static str {
91 "Enforce consistent type casting style."
92 }
93
94 fn long_description(&self) -> &'static str {
95 r"
96**Anti-pattern**
97
98Using a mixture of `CONVERT`, `::`, and `CAST` when `preferred_type_casting_style` config is set to `consistent` (default).
99
100```sql
101SELECT
102 CONVERT(int, 1) AS bar,
103 100::int::text,
104 CAST(10 AS text) AS coo
105FROM foo;
106```
107
108**Best Practice**
109
110Use a consistent type casting style.
111
112```sql
113SELECT
114 CAST(1 AS int) AS bar,
115 CAST(CAST(100 AS int) AS text),
116 CAST(10 AS text) AS coo
117FROM foo;
118```
119"
120 }
121
122 fn groups(&self) -> &'static [RuleGroups] {
123 &[RuleGroups::All, RuleGroups::Convention]
124 }
125
126 fn eval(&self, context: &RuleContext) -> Vec<LintResult> {
127 let current_type_casting_style = if context.segment.is_type(SyntaxKind::Function) {
128 let Some(function_name) = context
129 .segment
130 .child(const { &SyntaxSet::new(&[SyntaxKind::FunctionName]) })
131 else {
132 return Vec::new();
133 };
134 if function_name.raw().eq_ignore_ascii_case("CAST") {
135 TypeCastingStyle::Cast
136 } else if function_name.raw().eq_ignore_ascii_case("CONVERT") {
137 TypeCastingStyle::Convert
138 } else {
139 TypeCastingStyle::None
140 }
141 } else if context.segment.is_type(SyntaxKind::CastExpression) {
142 TypeCastingStyle::Shorthand
143 } else {
144 TypeCastingStyle::None
145 };
146
147 let functional_context = FunctionalContext::new(context);
148 match self.preferred_type_casting_style {
149 TypeCastingStyle::Consistent => {
150 let Some(prior_type_casting_style) = context.try_get::<TypeCastingStyle>() else {
151 context.set(current_type_casting_style);
152 return Vec::new();
153 };
154 let previous_skipped = context.try_get::<PreviousSkipped>();
155
156 let mut fixes = Vec::new();
157 match prior_type_casting_style {
158 TypeCastingStyle::Cast => match current_type_casting_style {
159 TypeCastingStyle::Convert => {
160 let convert_content =
161 get_children(functional_context.segment().children(Some(
162 |it: &ErasedSegment| it.is_type(SyntaxKind::Bracketed),
163 )));
164 if convert_content.len() > 2 {
165 if previous_skipped.is_none() {
166 context.set(PreviousSkipped);
167 }
168 return Vec::new();
169 }
170
171 fixes = cast_fix_list(
172 context.tables,
173 context.segment.clone(),
174 &[convert_content[1].clone()],
175 convert_content[0].clone(),
176 None,
177 );
178 }
179 TypeCastingStyle::Shorthand => {
180 let expression_datatype_segment =
181 get_children(functional_context.segment());
182
183 fixes = cast_fix_list(
184 context.tables,
185 context.segment.clone(),
186 &[expression_datatype_segment[0].clone()],
187 expression_datatype_segment[1].clone(),
188 Some(Segments::from_vec(
189 expression_datatype_segment.base[2..].to_vec(),
190 None,
191 )),
192 )
193 }
194 _ => {}
195 },
196 TypeCastingStyle::Convert => match current_type_casting_style {
197 TypeCastingStyle::Cast => {
198 let cast_content = get_children(functional_context.segment().children(
199 Some(|it: &ErasedSegment| it.is_type(SyntaxKind::Bracketed)),
200 ));
201
202 if cast_content.len() > 2 {
203 return Vec::new();
204 }
205
206 fixes = convert_fix_list(
207 context.tables,
208 context.segment.clone(),
209 cast_content[1].clone(),
210 cast_content[0].clone(),
211 None,
212 );
213 }
214 TypeCastingStyle::Shorthand => {
215 let expression_datatype_segment =
216 get_children(functional_context.segment());
217
218 fixes = convert_fix_list(
219 context.tables,
220 context.segment.clone(),
221 expression_datatype_segment[1].clone(),
222 expression_datatype_segment[0].clone(),
223 Some(Segments::from_vec(
224 expression_datatype_segment.base[2..].to_vec(),
225 None,
226 )),
227 );
228 }
229 _ => (),
230 },
231 TypeCastingStyle::Shorthand => {
232 if current_type_casting_style == TypeCastingStyle::Cast {
233 let cast_content = get_children(functional_context.segment().children(
235 Some(|it: &ErasedSegment| it.is_type(SyntaxKind::Bracketed)),
236 ));
237 if cast_content.len() > 2 {
238 return Vec::new();
239 }
240
241 fixes = shorthand_fix_list(
242 context.tables,
243 context.segment.clone(),
244 cast_content[0].clone(),
245 cast_content[1].clone(),
246 );
247 } else if current_type_casting_style == TypeCastingStyle::Convert {
248 let convert_content =
249 get_children(functional_context.segment().children(Some(
250 |it: &ErasedSegment| it.is_type(SyntaxKind::Bracketed),
251 )));
252 if convert_content.len() > 2 {
253 return Vec::new();
254 }
255
256 fixes = shorthand_fix_list(
257 context.tables,
258 context.segment.clone(),
259 convert_content[1].clone(),
260 convert_content[0].clone(),
261 );
262 }
263 }
264 _ => {}
265 }
266
267 if prior_type_casting_style != current_type_casting_style {
268 return vec![LintResult::new(
269 context.segment.clone().into(),
270 fixes,
271 "Inconsistent type casting styles found.".to_owned().into(),
272 None,
273 )];
274 }
275 }
276 _ if current_type_casting_style != self.preferred_type_casting_style => {
277 let mut convert_content = None;
278 let mut cast_content = None;
279 let mut fixes = Vec::new();
280
281 match self.preferred_type_casting_style {
282 TypeCastingStyle::Cast => match current_type_casting_style {
283 TypeCastingStyle::Convert => {
284 let segments = get_children(functional_context.segment().children(
285 Some(|it: &ErasedSegment| it.is_type(SyntaxKind::Bracketed)),
286 ));
287 fixes = cast_fix_list(
288 context.tables,
289 context.segment.clone(),
290 &[segments[1].clone()],
291 segments[0].clone(),
292 None,
293 );
294 convert_content = Some(segments);
295 }
296 TypeCastingStyle::Shorthand => {
297 let expression_datatype_segment =
298 get_children(functional_context.segment());
299 let data_type_idx = expression_datatype_segment
300 .iter()
301 .position(|seg| seg.is_type(SyntaxKind::DataType))
302 .unwrap();
303
304 fixes = cast_fix_list(
305 context.tables,
306 context.segment.clone(),
307 &expression_datatype_segment[..data_type_idx],
308 expression_datatype_segment[data_type_idx].clone(),
309 Some(Segments::from_vec(
310 expression_datatype_segment.base[data_type_idx + 1..].to_vec(),
311 None,
312 )),
313 );
314 }
315 _ => {}
316 },
317 TypeCastingStyle::Convert => match current_type_casting_style {
318 TypeCastingStyle::Cast => {
319 let cast_content = get_children(functional_context.segment().children(
320 Some(|it: &ErasedSegment| it.is_type(SyntaxKind::Bracketed)),
321 ));
322
323 fixes = convert_fix_list(
324 context.tables,
325 context.segment.clone(),
326 cast_content[1].clone(),
327 cast_content[0].clone(),
328 None,
329 );
330 }
331 TypeCastingStyle::Shorthand => {
332 let cast_content = get_children(functional_context.segment());
333
334 fixes = convert_fix_list(
335 context.tables,
336 context.segment.clone(),
337 cast_content[1].clone(),
338 cast_content[0].clone(),
339 Some(Segments::from_vec(cast_content.base[2..].to_vec(), None)),
340 )
341 }
342 _ => {}
343 },
344 TypeCastingStyle::Shorthand => match current_type_casting_style {
345 TypeCastingStyle::Cast => {
346 let segments = get_children(functional_context.segment().children(
347 Some(|it: &ErasedSegment| it.is_type(SyntaxKind::Bracketed)),
348 ));
349
350 fixes = shorthand_fix_list(
351 context.tables,
352 context.segment.clone(),
353 segments[0].clone(),
354 segments[1].clone(),
355 );
356 cast_content = Some(segments);
357 }
358 TypeCastingStyle::Convert => {
359 let segments = get_children(functional_context.segment().children(
360 Some(|it: &ErasedSegment| it.is_type(SyntaxKind::Bracketed)),
361 ));
362
363 fixes = shorthand_fix_list(
364 context.tables,
365 context.segment.clone(),
366 segments[1].clone(),
367 segments[0].clone(),
368 );
369
370 convert_content = Some(segments);
371 }
372 _ => {}
373 },
374 _ => {}
375 }
376
377 if convert_content
378 .filter(|convert_content| convert_content.len() > 2)
379 .is_some()
380 {
381 fixes.clear();
382 }
383
384 if cast_content
385 .filter(|cast_content| cast_content.len() > 2)
386 .is_some()
387 {
388 fixes.clear();
389 }
390
391 return vec![LintResult::new(
392 context.segment.clone().into(),
393 fixes,
394 "Used type casting style is different from the preferred type casting style."
395 .to_owned()
396 .into(),
397 None,
398 )];
399 }
400
401 _ => {}
402 }
403
404 Vec::new()
405 }
406
407 fn is_fix_compatible(&self) -> bool {
408 true
409 }
410
411 fn crawl_behaviour(&self) -> Crawler {
412 SegmentSeekerCrawler::new(
413 const { SyntaxSet::new(&[SyntaxKind::Function, SyntaxKind::CastExpression]) },
414 )
415 .into()
416 }
417}
418
419fn convert_fix_list(
420 tables: &Tables,
421 root: ErasedSegment,
422 convert_arg_1: ErasedSegment,
423 convert_arg_2: ErasedSegment,
424 later_types: Option<Segments>,
425) -> Vec<LintFix> {
426 use sqruff_lib_core::parser::segments::base::ErasedSegment;
427
428 let mut edits: Vec<ErasedSegment> = vec![
429 SegmentBuilder::token(
430 tables.next_id(),
431 "convert",
432 SyntaxKind::FunctionNameIdentifier,
433 )
434 .finish(),
435 SegmentBuilder::token(tables.next_id(), "(", SyntaxKind::StartBracket).finish(),
436 convert_arg_1,
437 SegmentBuilder::token(tables.next_id(), ",", SyntaxKind::Comma).finish(),
438 SegmentBuilder::whitespace(tables.next_id(), " "),
439 convert_arg_2,
440 SegmentBuilder::token(tables.next_id(), ")", SyntaxKind::EndBracket).finish(),
441 ];
442
443 if let Some(later_types) = later_types {
444 let pre_edits: Vec<ErasedSegment> = vec![
445 SegmentBuilder::token(
446 tables.next_id(),
447 "convert",
448 SyntaxKind::FunctionNameIdentifier,
449 )
450 .finish(),
451 SegmentBuilder::symbol(tables.next_id(), "("),
452 ];
453
454 let in_edits: Vec<ErasedSegment> = vec![
455 SegmentBuilder::symbol(tables.next_id(), ","),
456 SegmentBuilder::whitespace(tables.next_id(), " "),
457 ];
458
459 let post_edits: Vec<ErasedSegment> = vec![SegmentBuilder::symbol(tables.next_id(), ")")];
460
461 for _type in later_types.base {
462 edits = chain(
463 chain(pre_edits.clone(), vec![_type]),
464 chain(in_edits.clone(), chain(edits, post_edits.clone())),
465 )
466 .collect();
467 }
468 }
469
470 vec![LintFix::replace(root, edits, None)]
471}
472
473fn cast_fix_list(
474 tables: &Tables,
475 root: ErasedSegment,
476 cast_arg_1: &[ErasedSegment],
477 cast_arg_2: ErasedSegment,
478 later_types: Option<Segments>,
479) -> Vec<LintFix> {
480 let mut edits = vec![
481 SegmentBuilder::token(tables.next_id(), "cast", SyntaxKind::FunctionNameIdentifier)
482 .finish(),
483 SegmentBuilder::token(tables.next_id(), "(", SyntaxKind::StartBracket).finish(),
484 ];
485 edits.extend_from_slice(cast_arg_1);
486 edits.extend([
487 SegmentBuilder::whitespace(tables.next_id(), " "),
488 SegmentBuilder::keyword(tables.next_id(), "as"),
489 SegmentBuilder::whitespace(tables.next_id(), " "),
490 cast_arg_2,
491 SegmentBuilder::token(tables.next_id(), ")", SyntaxKind::EndBracket).finish(),
492 ]);
493
494 if let Some(later_types) = later_types {
495 let pre_edits: Vec<ErasedSegment> = vec![
496 SegmentBuilder::token(tables.next_id(), "cast", SyntaxKind::FunctionNameIdentifier)
497 .finish(),
498 SegmentBuilder::symbol(tables.next_id(), "("),
499 ];
500
501 let in_edits: Vec<ErasedSegment> = vec![
502 SegmentBuilder::whitespace(tables.next_id(), " "),
503 SegmentBuilder::keyword(tables.next_id(), "as"),
504 SegmentBuilder::whitespace(tables.next_id(), " "),
505 ];
506
507 let post_edits: Vec<ErasedSegment> = vec![SegmentBuilder::symbol(tables.next_id(), ")")];
508
509 for _type in later_types.base {
510 let mut xs = Vec::new();
511 xs.extend(pre_edits.clone());
512 xs.extend(edits);
513 xs.extend(in_edits.clone());
514 xs.push(_type);
515 xs.extend(post_edits.clone());
516 edits = xs;
517 }
518 }
519
520 vec![LintFix::replace(root, edits, None)]
521}