sqruff_lib/utils/reflow/
respace.rs

1use itertools::{Itertools, enumerate};
2use rustc_hash::FxHashMap;
3use sqruff_lib_core::dialects::syntax::{SyntaxKind, SyntaxSet};
4use sqruff_lib_core::edit_type::EditType;
5use sqruff_lib_core::lint_fix::LintFix;
6use sqruff_lib_core::parser::markers::PositionMarker;
7use sqruff_lib_core::parser::segments::base::{ErasedSegment, SegmentBuilder, Tables};
8
9use super::elements::ReflowBlock;
10use crate::core::rules::base::LintResult;
11use crate::utils::reflow::config::Spacing;
12use crate::utils::reflow::helpers::pretty_segment_name;
13
14fn unpack_constraint(constraint: Spacing, strip_newlines: bool) -> (Spacing, bool) {
15    match constraint {
16        Spacing::TouchInline => (Spacing::Touch, true),
17        Spacing::SingleInline => (Spacing::Single, true),
18        _ => (constraint, strip_newlines),
19    }
20}
21
22pub fn determine_constraints(
23    prev_block: Option<&ReflowBlock>,
24    next_block: Option<&ReflowBlock>,
25    strip_newlines: bool,
26) -> (Spacing, Spacing, bool) {
27    // Start with the defaults
28    let (mut pre_constraint, strip_newlines) = unpack_constraint(
29        if let Some(prev_block) = prev_block {
30            prev_block.spacing_after()
31        } else {
32            Spacing::Single
33        },
34        strip_newlines,
35    );
36
37    let (mut post_constraint, mut strip_newlines) = unpack_constraint(
38        if let Some(next_block) = next_block {
39            next_block.spacing_before()
40        } else {
41            Spacing::Single
42        },
43        strip_newlines,
44    );
45
46    let mut within_spacing = None;
47    let mut idx = None;
48
49    if let Some((prev_block, next_block)) = prev_block.zip(next_block) {
50        let common = prev_block.depth_info().common_with(next_block.depth_info());
51        let last_common = common.last().unwrap();
52        idx = prev_block
53            .depth_info()
54            .stack_hashes
55            .iter()
56            .position(|p| p == last_common)
57            .unwrap()
58            .into();
59
60        let within_constraint = prev_block.stack_spacing_configs().get(last_common);
61        if let Some(within_constraint) = within_constraint {
62            let (within_spacing_inner, strip_newlines_inner) =
63                unpack_constraint(*within_constraint, strip_newlines);
64
65            within_spacing = Some(within_spacing_inner);
66            strip_newlines = strip_newlines_inner;
67        }
68    }
69
70    match within_spacing {
71        Some(Spacing::Touch) => {
72            if pre_constraint != Spacing::Any {
73                pre_constraint = Spacing::Touch;
74            }
75            if post_constraint != Spacing::Any {
76                post_constraint = Spacing::Touch;
77            }
78        }
79        Some(Spacing::Any) => {
80            pre_constraint = Spacing::Any;
81            post_constraint = Spacing::Any;
82        }
83        Some(Spacing::Single) => {}
84        Some(spacing) => {
85            panic!(
86                "Unexpected within constraint: {:?} for {:?}",
87                spacing,
88                prev_block.unwrap().depth_info().stack_class_types[idx.unwrap()]
89            );
90        }
91        _ => {}
92    }
93
94    (pre_constraint, post_constraint, strip_newlines)
95}
96
97pub fn process_spacing(
98    segment_buffer: &[ErasedSegment],
99    strip_newlines: bool,
100) -> (Vec<ErasedSegment>, Option<ErasedSegment>, Vec<LintResult>) {
101    let mut removal_buffer = Vec::new();
102    let mut result_buffer = Vec::new();
103    let mut last_whitespace = Vec::new();
104
105    // Loop through the existing segments looking for spacing.
106    for seg in segment_buffer {
107        // If it's whitespace, store it.
108        if seg.is_type(SyntaxKind::Whitespace) {
109            last_whitespace.push(seg.clone());
110        }
111        // If it's a newline, react accordingly.
112        // NOTE: This should only trigger on literal newlines.
113        else if matches!(seg.get_type(), SyntaxKind::Newline | SyntaxKind::EndOfFile) {
114            if seg
115                .get_position_marker()
116                .is_some_and(|pos_marker| !pos_marker.is_literal())
117            {
118                last_whitespace.clear();
119                continue;
120            }
121
122            if strip_newlines && seg.is_type(SyntaxKind::Newline) {
123                removal_buffer.push(seg.clone());
124                result_buffer.push(LintResult::new(
125                    seg.clone().into(),
126                    vec![LintFix::delete(seg.clone())],
127                    Some("Unexpected line break.".into()),
128                    None,
129                ));
130                continue;
131            }
132
133            if !last_whitespace.is_empty() {
134                for ws in last_whitespace.drain(..) {
135                    removal_buffer.push(ws.clone());
136                    result_buffer.push(LintResult::new(
137                        ws.clone().into(),
138                        vec![LintFix::delete(ws)],
139                        Some("Unnecessary trailing whitespace.".into()),
140                        None,
141                    ))
142                }
143            }
144        }
145    }
146
147    if last_whitespace.len() >= 2 {
148        let seg = segment_buffer.last().unwrap();
149
150        for ws in last_whitespace.iter().skip(1).cloned() {
151            removal_buffer.push(ws.clone());
152            result_buffer.push(LintResult::new(
153                seg.clone().into(),
154                vec![LintFix::delete(ws)],
155                "Unnecessary trailing whitespace.".to_owned().into(),
156                None,
157            ));
158        }
159    }
160
161    // Turn the removal buffer updated segment buffer, last whitespace and
162    // associated fixes.
163
164    let filtered_segment_buffer = segment_buffer
165        .iter()
166        .filter(|s| !removal_buffer.contains(s))
167        .cloned()
168        .collect_vec();
169
170    (
171        filtered_segment_buffer,
172        last_whitespace.first().cloned(),
173        result_buffer,
174    )
175}
176
177fn determine_aligned_inline_spacing(
178    root_segment: &ErasedSegment,
179    whitespace_seg: &ErasedSegment,
180    next_seg: &ErasedSegment,
181    mut next_pos: PositionMarker,
182    segment_type: SyntaxKind,
183    align_within: Option<SyntaxKind>,
184    align_scope: Option<SyntaxKind>,
185) -> String {
186    // Find the level of segment that we're aligning.
187    let mut parent_segment = None;
188
189    // Edge case: if next_seg has no position, we should use the position
190    // of the whitespace for searching.
191    if let Some(align_within) = align_within {
192        for ps in root_segment
193            .path_to(if next_seg.get_position_marker().is_some() {
194                next_seg
195            } else {
196                whitespace_seg
197            })
198            .iter()
199            .rev()
200        {
201            if ps.segment.is_type(align_within) {
202                parent_segment = Some(ps.segment.clone());
203            }
204            if let Some(align_scope) = align_scope {
205                if ps.segment.is_type(align_scope) {
206                    break;
207                }
208            }
209        }
210    }
211
212    if parent_segment.is_none() {
213        return " ".to_string();
214    }
215
216    let parent_segment = parent_segment.unwrap();
217
218    // We've got a parent. Find some siblings.
219    let mut siblings = Vec::new();
220    for sibling in parent_segment.recursive_crawl(
221        &SyntaxSet::single(segment_type),
222        true,
223        &SyntaxSet::EMPTY,
224        true,
225    ) {
226        // Purge any siblings with a boundary between them
227        if align_scope.is_none()
228            || !parent_segment
229                .path_to(&sibling)
230                .iter()
231                .any(|ps| ps.segment.is_type(align_scope.unwrap()))
232        {
233            siblings.push(sibling);
234        }
235    }
236
237    // If the segment we're aligning, has position. Use that position.
238    // If it doesn't, then use the provided one. We can't do sibling analysis
239    // without it.
240    if let Some(pos_marker) = next_seg.get_position_marker() {
241        next_pos = pos_marker.clone();
242    }
243
244    // Purge any siblings which are either self, or on the same line but after it.
245    let mut earliest_siblings: FxHashMap<usize, usize> = FxHashMap::default();
246    siblings.retain(|sibling| {
247        let pos_marker = sibling.get_position_marker().unwrap();
248        let best_seen = earliest_siblings.get(&pos_marker.working_line_no).copied();
249        if let Some(best_seen) = best_seen {
250            if pos_marker.working_line_pos > best_seen {
251                return false;
252            }
253        }
254        earliest_siblings.insert(pos_marker.working_line_no, pos_marker.working_line_pos);
255
256        if pos_marker.working_line_no == next_pos.working_line_no
257            && pos_marker.working_line_pos != next_pos.working_line_pos
258        {
259            return false;
260        }
261        true
262    });
263
264    // If there's only one sibling, we have nothing to compare to. Default to a
265    // single space.
266    if siblings.len() <= 1 {
267        return " ".to_string();
268    }
269
270    let mut last_code: Option<ErasedSegment> = None;
271    let mut max_desired_line_pos = 0;
272
273    for seg in parent_segment.get_raw_segments() {
274        for sibling in &siblings {
275            if let (Some(seg_pos), Some(sibling_pos)) =
276                (&seg.get_position_marker(), &sibling.get_position_marker())
277            {
278                if seg_pos.working_loc() == sibling_pos.working_loc() {
279                    if let Some(last_code) = &last_code {
280                        let loc = last_code
281                            .get_position_marker()
282                            .unwrap()
283                            .working_loc_after(last_code.raw());
284
285                        if loc.1 > max_desired_line_pos {
286                            max_desired_line_pos = loc.1;
287                        }
288                    }
289                }
290            }
291        }
292
293        if seg.is_code() {
294            last_code = Some(seg.clone());
295        }
296    }
297
298    " ".repeat(
299        1 + max_desired_line_pos
300            - whitespace_seg
301                .get_position_marker()
302                .as_ref()
303                .unwrap()
304                .working_line_pos,
305    )
306}
307
308#[allow(clippy::too_many_arguments)]
309pub fn handle_respace_inline_with_space(
310    tables: &Tables,
311    pre_constraint: Spacing,
312    post_constraint: Spacing,
313    prev_block: Option<&ReflowBlock>,
314    next_block: Option<&ReflowBlock>,
315    root_segment: &ErasedSegment,
316    mut segment_buffer: Vec<ErasedSegment>,
317    last_whitespace: ErasedSegment,
318) -> (Vec<ErasedSegment>, Vec<LintResult>) {
319    // Get some indices so that we can reference around them
320    let ws_idx = segment_buffer
321        .iter()
322        .position(|it| it == &last_whitespace)
323        .unwrap();
324
325    if pre_constraint == Spacing::Any || post_constraint == Spacing::Any {
326        return (segment_buffer, vec![]);
327    }
328
329    if [pre_constraint, post_constraint].contains(&Spacing::Touch) {
330        segment_buffer.remove(ws_idx);
331
332        let description = if let Some(next_block) = next_block {
333            format!(
334                "Unexpected whitespace before {}.",
335                pretty_segment_name(next_block.segment())
336            )
337        } else {
338            "Unexpected whitespace".to_string()
339        };
340
341        let lint_result = LintResult::new(
342            last_whitespace.clone().into(),
343            vec![LintFix::delete(last_whitespace)],
344            Some(description),
345            None,
346        );
347
348        // Return the segment buffer and the lint result
349        return (segment_buffer, vec![lint_result]);
350    }
351
352    // Handle left alignment & singles
353    if (matches!(post_constraint, Spacing::Align { .. }) && next_block.is_some())
354        || pre_constraint == Spacing::Single && post_constraint == Spacing::Single
355    {
356        let (desc, desired_space) = match (post_constraint, next_block) {
357            (
358                Spacing::Align {
359                    seg_type,
360                    within,
361                    scope,
362                },
363                Some(next_block),
364            ) => {
365                let next_pos = if let Some(pos_marker) = next_block.segment().get_position_marker()
366                {
367                    Some(pos_marker.clone())
368                } else if let Some(pos_marker) = last_whitespace.get_position_marker() {
369                    Some(pos_marker.end_point_marker())
370                } else if let Some(prev_block) = prev_block {
371                    prev_block
372                        .segment()
373                        .get_position_marker()
374                        .map(|pos_marker| pos_marker.end_point_marker())
375                } else {
376                    None
377                };
378
379                if let Some(next_pos) = next_pos {
380                    let desired_space = determine_aligned_inline_spacing(
381                        root_segment,
382                        &last_whitespace,
383                        next_block.segment(),
384                        next_pos,
385                        seg_type,
386                        within,
387                        scope,
388                    );
389                    ("Item misaligned".to_string(), desired_space)
390                } else {
391                    ("Item misaligned".to_string(), " ".to_string())
392                }
393            }
394            _ => {
395                let desc = if let Some(next_block) = next_block {
396                    format!(
397                        "Expected only single space before {:?}. Found {:?}.",
398                        next_block.segment().raw(),
399                        last_whitespace.raw()
400                    )
401                } else {
402                    format!(
403                        "Expected only single space. Found {:?}.",
404                        last_whitespace.raw()
405                    )
406                };
407                let desired_space = " ".to_string();
408                (desc, desired_space)
409            }
410        };
411
412        let mut new_results = Vec::new();
413        if last_whitespace.raw().as_str() != desired_space {
414            let new_seg = last_whitespace.edit(tables.next_id(), desired_space.into(), None);
415
416            new_results.push(LintResult::new(
417                last_whitespace.clone().into(),
418                vec![LintFix::replace(
419                    last_whitespace,
420                    vec![new_seg.clone()],
421                    None,
422                )],
423                Some(desc),
424                None,
425            ));
426            segment_buffer[ws_idx] = new_seg;
427        }
428
429        return (segment_buffer, new_results);
430    }
431
432    unimplemented!("Unexpected Constraints: {pre_constraint:?}, {post_constraint:?}");
433}
434
435#[allow(clippy::too_many_arguments)]
436pub fn handle_respace_inline_without_space(
437    tables: &Tables,
438    pre_constraint: Spacing,
439    post_constraint: Spacing,
440    prev_block: Option<&ReflowBlock>,
441    next_block: Option<&ReflowBlock>,
442    mut segment_buffer: Vec<ErasedSegment>,
443    mut existing_results: Vec<LintResult>,
444    anchor_on: &str,
445) -> (Vec<ErasedSegment>, Vec<LintResult>, bool) {
446    let constraints = [Spacing::Touch, Spacing::Any];
447
448    if constraints.contains(&pre_constraint) || constraints.contains(&post_constraint) {
449        return (segment_buffer, existing_results, false);
450    }
451
452    let added_whitespace = SegmentBuilder::whitespace(tables.next_id(), " ");
453
454    // Add it to the buffer first (the easy bit). The hard bit is to then determine
455    // how to generate the appropriate LintFix objects.
456    segment_buffer.push(added_whitespace.clone());
457
458    // So special handling here. If segments either side already exist then we don't
459    // care which we anchor on but if one is already an insertion (as shown by a
460    // lack) of pos_marker, then we should piggyback on that pre-existing fix.
461    let mut existing_fix = None;
462    let mut insertion = None;
463
464    if let Some(block) = prev_block {
465        if block.segment().get_position_marker().is_none() {
466            existing_fix = Some("after");
467            insertion = Some(block.segment().clone());
468        }
469    } else if let Some(block) = next_block {
470        if block.segment().get_position_marker().is_none() {
471            existing_fix = Some("before");
472            insertion = Some(block.segment().clone());
473        }
474    }
475
476    if let Some(existing_fix) = existing_fix {
477        let mut res_found = None;
478        let mut fix_found = None;
479
480        'outer: for (result_idx, res) in enumerate(&existing_results) {
481            for (fix_idx, fix) in enumerate(&res.fixes) {
482                if fix
483                    .edit
484                    .iter()
485                    .any(|e| e.id() == insertion.as_ref().unwrap().id())
486                {
487                    res_found = Some(result_idx);
488                    fix_found = Some(fix_idx);
489                    break 'outer;
490                }
491            }
492        }
493
494        let res = res_found.unwrap();
495        let fix = fix_found.unwrap();
496
497        let fix = &mut existing_results[res].fixes[fix];
498
499        if existing_fix == "before" {
500            unimplemented!()
501        } else if existing_fix == "after" {
502            fix.edit.push(added_whitespace);
503        }
504
505        return (segment_buffer, existing_results, true);
506    }
507
508    let desc = if let Some((prev_block, next_block)) = prev_block.zip(next_block) {
509        format!(
510            "Expected single whitespace between {:?} and {:?}.",
511            prev_block.segment().raw(),
512            next_block.segment().raw()
513        )
514    } else {
515        "Expected single whitespace.".to_owned()
516    };
517
518    let new_result = if prev_block.is_some() && anchor_on != "after" {
519        let prev_block = prev_block.unwrap();
520        let anchor = if let Some(block) = next_block {
521            // If next_block is Some, get the first segment
522            block.segment().clone()
523        } else {
524            prev_block.segment().clone()
525        };
526
527        LintResult::new(
528            anchor.into(),
529            vec![LintFix {
530                edit_type: EditType::CreateAfter,
531                anchor: prev_block.segment().clone(),
532                edit: vec![added_whitespace],
533                source: vec![],
534            }],
535            desc.into(),
536            None,
537        )
538    } else if let Some(next_block) = next_block {
539        LintResult::new(
540            next_block.segment().clone().into(),
541            vec![LintFix::create_before(
542                next_block.segment().clone(),
543                vec![SegmentBuilder::whitespace(tables.next_id(), " ")],
544            )],
545            Some(desc),
546            None,
547        )
548    } else {
549        unimplemented!("Not set up to handle a missing _after_ and _before_.")
550    };
551
552    existing_results.push(new_result);
553    (segment_buffer, existing_results, false)
554}
555
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use pretty_assertions::assert_eq;
    use smol_str::ToSmolStr;
    use sqruff_lib::core::test_functions::parse_ansi_string;
    use sqruff_lib_core::edit_type::EditType;
    use sqruff_lib_core::helpers::enter_panic;

    use crate::utils::reflow::helpers::fixes_from_results;
    use crate::utils::reflow::respace::Tables;
    use crate::utils::reflow::sequence::{Filter, ReflowSequence};

    /// Respaces a whole sequence and compares the resulting raw SQL.
    #[test]
    fn test_reflow_sequence_respace() {
        // Each case: (input SQL, (strip_newlines, filter), expected output SQL).
        let cases = [
            // Missing spaces around an operator are added.
            ("select 1+2", (false, Filter::All), "select 1 + 2"),
            // Excess spaces are collapsed, trailing ones removed.
            (
                "select    1   +   2    ",
                (false, Filter::All),
                "select 1 + 2",
            ),
            // Newlines are kept unless we are asked to strip them.
            (
                "select\n    1   +   2",
                (false, Filter::All),
                "select\n    1 + 2",
            ),
            ("select\n    1   +   2", (true, Filter::All), "select 1 + 2"),
            // The filter restricts which points get respaced.
            (
                "select  \n  1   +   2 \n ",
                (false, Filter::All),
                "select\n  1 + 2\n",
            ),
            (
                "select  \n  1   +   2 \n ",
                (false, Filter::Inline),
                "select  \n  1 + 2 \n ",
            ),
            (
                "select  \n  1   +   2 \n ",
                (false, Filter::Newline),
                "select\n  1   +   2\n",
            ),
        ];

        let tables = Tables::default();
        for (sql_in, (strip_newlines, filter), expected_sql) in cases {
            let config = <_>::default();
            let sequence = ReflowSequence::from_root(parse_ansi_string(sql_in), &config);

            let respaced = sequence.respace(&tables, strip_newlines, filter);
            assert_eq!(respaced.raw(), expected_sql);
        }
    }

    /// Respaces a single point and checks both its new raw value and the
    /// `(edit type, anchor raw)` pairs of the fixes generated for it.
    #[test]
    fn test_reflow_point_respace_point() {
        // Each case: (input SQL, index of the point within the sequence,
        // strip_newlines, expected raw of the new point, expected fixes).
        let cases = [
            // Excess whitespace is replaced.
            (
                "select    1",
                1,
                false,
                " ",
                vec![(EditType::Replace, "    ".into())],
            ),
            // Missing whitespace is created after the preceding block.
            (
                "select 1+2",
                3,
                false,
                " ",
                vec![(EditType::CreateAfter, "1".into())],
            ),
            // Nothing to do just inside a bracket.
            ("select (1+2)", 3, false, "", vec![]),
            // Whitespace just inside a bracket is deleted.
            (
                "select (  1+2)",
                3,
                false,
                "",
                vec![(EditType::Delete, "  ".into())],
            ),
            // Newlines (and the indent following them) are left alone.
            ("select\n1", 1, false, "\n", vec![]),
            ("select\n  1", 1, false, "\n  ", vec![]),
            // Whitespace before a newline is deleted.
            (
                "select  \n  1",
                1,
                false,
                "\n  ",
                vec![(EditType::Delete, "  ".into())],
            ),
            // Stripping newlines collapses everything into a single space.
            (
                "select  \n 1",
                1,
                true,
                " ",
                vec![
                    (EditType::Delete, "\n".into()),
                    (EditType::Delete, " ".into()),
                    (EditType::Replace, "  ".into()),
                ],
            ),
            // Stripping newlines just inside a bracket removes everything.
            (
                "select ( \n  1)",
                3,
                true,
                "",
                vec![
                    (EditType::Delete, "\n".into()),
                    (EditType::Delete, "  ".into()),
                    (EditType::Delete, " ".into()),
                ],
            ),
        ];

        let tables = Tables::default();
        for (sql_in, point_idx, strip_newlines, expected_point_raw, expected_fixes) in cases {
            // Name the failing case if anything below panics.
            let _panic = enter_panic(format!("{sql_in:?}"));

            let root = parse_ansi_string(sql_in);
            let config = <_>::default();
            let sequence = ReflowSequence::from_root(root.clone(), &config);
            let elements = sequence.elements();
            let point = elements[point_idx].as_point().unwrap();

            let (results, respaced_point) = point.respace_point(
                &tables,
                elements[point_idx - 1].as_block(),
                elements[point_idx + 1].as_block(),
                &root,
                Vec::new(),
                strip_newlines,
                "before",
            );

            assert_eq!(respaced_point.raw(), expected_point_raw);

            let actual_fixes = fixes_from_results(results.into_iter())
                .map(|fix| (fix.edit_type, fix.anchor.raw().to_smolstr()))
                .collect_vec();

            assert_eq!(actual_fixes, expected_fixes);
        }
    }
}