1use std::cmp::PartialEq;
2use std::str::FromStr;
3
4use sqruff_lib_core::dialects::syntax::SyntaxKind;
5use sqruff_lib_core::helpers::capitalize;
6use sqruff_lib_core::lint_fix::LintFix;
7use sqruff_lib_core::parser::segments::base::{ErasedSegment, Tables};
8use strum_macros::{AsRefStr, EnumString};
9
10use super::elements::{ReflowElement, ReflowSequenceType};
11use crate::core::rules::base::LintResult;
12use crate::utils::reflow::depth_map::StackPositionType;
13use crate::utils::reflow::elements::ReflowPoint;
14use crate::utils::reflow::helpers::{deduce_line_indent, fixes_from_results};
15
/// A run of reflow elements that carries a `line_position` preference.
///
/// Produced by `identify_rebreak_spans`; `start_idx` and `end_idx` index into
/// the element buffer the span was found in.
#[derive(Debug)]
pub struct RebreakSpan {
    /// Segment whose position relative to line breaks is being checked.
    pub(crate) target: ErasedSegment,
    /// Index of the first block of the span.
    pub(crate) start_idx: usize,
    /// Index of the last block of the span (equal to `start_idx` when the
    /// span is a single block).
    pub(crate) end_idx: usize,
    /// The configured preferred position.
    pub(crate) line_position: LinePosition,
    /// Whether the configuration carried the `strict` modifier.
    pub(crate) strict: bool,
}
24
/// Point indices found by walking away from a span in one direction.
///
/// Built by `RebreakIndices::from_elements`; every index refers to the element
/// buffer that was scanned.
#[derive(Debug)]
pub struct RebreakIndices {
    /// Scan direction: `1` forwards, `-1` backwards. Only set (and visible via
    /// `Debug`); not otherwise read in this file.
    _dir: i32,
    /// The point immediately adjacent to the span.
    adj_pt_idx: isize,
    /// The point whose newline count decides whether this side is already
    /// broken; the scan can hop over non-code blocks (e.g. comments).
    newline_pt_idx: isize,
    /// The point directly adjacent to the next code-bearing element in this
    /// direction.
    pre_code_pt_idx: isize,
}
32
33impl RebreakIndices {
34 fn from_elements(elements: &ReflowSequenceType, start_idx: usize, dir: i32) -> Option<Self> {
35 assert!(dir == 1 || dir == -1);
36 let limit = if dir == -1 { 0 } else { elements.len() };
37 let adj_point_idx = start_idx as isize + dir as isize;
38
39 if adj_point_idx < 0 || adj_point_idx >= elements.len() as isize {
40 return None;
41 }
42
43 let mut newline_point_idx = adj_point_idx;
44 while (dir == 1 && newline_point_idx < limit as isize)
45 || (dir == -1 && newline_point_idx >= 0)
46 {
47 if elements[newline_point_idx as usize]
48 .class_types()
49 .contains(SyntaxKind::Newline)
50 || elements[(newline_point_idx + dir as isize) as usize]
51 .segments()
52 .iter()
53 .any(|seg| seg.is_code())
54 {
55 break;
56 }
57 newline_point_idx += 2 * dir as isize;
58 }
59
60 let mut pre_code_point_idx = newline_point_idx;
61 while (dir == 1 && pre_code_point_idx < limit as isize)
62 || (dir == -1 && pre_code_point_idx >= 0)
63 {
64 if elements[(pre_code_point_idx + dir as isize) as usize]
65 .segments()
66 .iter()
67 .any(|seg| seg.is_code())
68 {
69 break;
70 }
71 pre_code_point_idx += 2 * dir as isize;
72 }
73
74 RebreakIndices {
75 _dir: dir,
76 adj_pt_idx: adj_point_idx,
77 newline_pt_idx: newline_point_idx,
78 pre_code_pt_idx: pre_code_point_idx,
79 }
80 .into()
81 }
82}
83
/// A `RebreakSpan` resolved to the relevant point indices on both sides.
#[derive(Debug)]
pub struct RebreakLocation {
    /// Segment being positioned.
    target: ErasedSegment,
    /// Indices found scanning backwards from the span's first element.
    prev: RebreakIndices,
    /// Indices found scanning forwards from the span's last element.
    next: RebreakIndices,
    /// Preferred position of `target` relative to line breaks.
    line_position: LinePosition,
    /// Whether the `strict` modifier applies.
    strict: bool,
}
92
/// Where a span should sit relative to line breaks.
///
/// Parsed from lowercase configuration strings (e.g. `"leading"`) through the
/// strum `EnumString` derive.
#[derive(Debug, PartialEq, Clone, Copy, AsRefStr, EnumString)]
#[strum(serialize_all = "lowercase")]
pub enum LinePosition {
    /// The span should start a line (line break before it).
    Leading,
    /// The span should end a line (line break after it).
    Trailing,
    /// The span should have line breaks both before and after it.
    Alone,
    /// Used as the `strict` modifier (detected as the last entry of a block's
    /// `line_position`); `rebreak_sequence` does not handle it as a position.
    Strict,
}
101
102impl RebreakLocation {
103 pub fn from_span(span: RebreakSpan, elements: &ReflowSequenceType) -> Option<Self> {
105 Self {
106 target: span.target,
107 prev: RebreakIndices::from_elements(elements, span.start_idx, -1)?,
108 next: RebreakIndices::from_elements(elements, span.end_idx, 1)?,
109 line_position: span.line_position,
110 strict: span.strict,
111 }
112 .into()
113 }
114
115 fn has_inappropriate_newlines(&self, elements: &ReflowSequenceType, strict: bool) -> bool {
116 let n_prev_newlines = elements[self.prev.newline_pt_idx as usize].num_newlines();
117 let n_next_newlines = elements[self.next.newline_pt_idx as usize].num_newlines();
118
119 let newlines_on_neither_side = n_prev_newlines + n_next_newlines == 0;
120 let newlines_on_both_sides = n_prev_newlines > 0 && n_next_newlines > 0;
121
122 (newlines_on_neither_side && !strict) || newlines_on_both_sides
123 }
124
125 fn pretty_target_name(&self) -> String {
126 format!("{} {}", self.target.get_type().as_str(), self.target.raw())
127 }
128}
129
130pub fn identify_rebreak_spans(
131 element_buffer: &ReflowSequenceType,
132 root_segment: ErasedSegment,
133) -> Vec<RebreakSpan> {
134 let mut spans = Vec::new();
135
136 for (idx, elem) in element_buffer
137 .iter()
138 .enumerate()
139 .take(element_buffer.len() - 2)
140 .skip(2)
141 {
142 let ReflowElement::Block(block) = elem else {
143 continue;
144 };
145
146 if let Some(original_line_position) = block.line_position() {
147 let line_position = original_line_position.first().unwrap();
148 spans.push(RebreakSpan {
149 target: elem.segments().first().cloned().unwrap(),
150 start_idx: idx,
151 end_idx: idx,
152 line_position: *line_position,
153 strict: original_line_position.last() == Some(&LinePosition::Strict),
154 });
155 }
156
157 for key in block.line_position_configs().keys() {
158 let mut final_idx = None;
159 if block.depth_info().stack_positions[key].idx != 0 {
160 continue;
161 }
162
163 for end_idx in idx..element_buffer.len() {
164 let end_elem = &element_buffer[end_idx];
165 let ReflowElement::Block(end_block) = end_elem else {
166 continue;
167 };
168
169 if !end_block.depth_info().stack_positions.contains_key(key) {
170 final_idx = (end_idx - 2).into();
171 } else if matches!(
172 end_block.depth_info().stack_positions[key].type_,
173 Some(StackPositionType::End) | Some(StackPositionType::Solo)
174 ) {
175 final_idx = end_idx.into();
176 }
177
178 if let Some(final_idx) = final_idx {
179 let target_depth = block
180 .depth_info()
181 .stack_hashes
182 .iter()
183 .position(|it| it == key)
184 .unwrap();
185 let target = root_segment.path_to(&element_buffer[idx].segments()[0])
186 [target_depth]
187 .segment
188 .clone();
189
190 let line_position_configs = block.line_position_configs()[key]
191 .split(':')
192 .next()
193 .unwrap();
194 let line_position = LinePosition::from_str(line_position_configs).unwrap();
195
196 spans.push(RebreakSpan {
197 target,
198 start_idx: idx,
199 end_idx: final_idx,
200 line_position,
201 strict: block.line_position_configs()[key].ends_with("strict"),
202 });
203
204 break;
205 }
206 }
207 }
208 }
209
210 spans
211}
212
213pub fn rebreak_sequence(
214 tables: &Tables,
215 elements: ReflowSequenceType,
216 root_segment: ErasedSegment,
217) -> (ReflowSequenceType, Vec<LintResult>) {
218 let mut lint_results = Vec::new();
219 let mut fixes = Vec::new();
220 let mut elem_buff = elements.clone();
221
222 let spans = identify_rebreak_spans(&elem_buff, root_segment.clone());
231
232 let mut locations = Vec::new();
233 for span in spans {
234 if let Some(loc) = RebreakLocation::from_span(span, &elements) {
235 locations.push(loc);
236 }
237 }
238
239 for loc in locations {
241 if loc.has_inappropriate_newlines(&elements, loc.strict) {
242 continue;
243 }
244
245 let prev_point = elem_buff[loc.prev.adj_pt_idx as usize]
251 .as_point()
252 .unwrap()
253 .clone();
254 let next_point = elem_buff[loc.next.adj_pt_idx as usize]
255 .as_point()
256 .unwrap()
257 .clone();
258
259 let new_results = if loc.line_position == LinePosition::Leading {
261 if elem_buff[loc.prev.newline_pt_idx as usize].num_newlines() != 0 {
262 continue;
264 }
265
266 let pretty_name = loc.pretty_target_name();
268 let _desc = if loc.strict {
269 format!(
270 "{} should always start a new line.",
271 capitalize(&pretty_name)
272 )
273 } else {
274 format!(
275 "Found trailing {}. Expected only leading near line breaks.",
276 pretty_name
277 )
278 };
279
280 if loc.next.adj_pt_idx == loc.next.pre_code_pt_idx
281 && elem_buff[loc.next.newline_pt_idx as usize].num_newlines() == 1
282 {
283 let desired_indent = next_point.get_indent().unwrap_or_default();
288
289 let (new_results, prev_point) = prev_point.indent_to(
290 tables,
291 &desired_indent,
292 None,
293 loc.target.clone().into(),
294 None,
295 None,
296 );
297
298 let (new_results, next_point) = next_point.respace_point(
299 tables,
300 elem_buff[loc.next.adj_pt_idx as usize - 1].as_block(),
301 elem_buff[loc.next.adj_pt_idx as usize + 1].as_block(),
302 &root_segment,
303 new_results,
304 true,
305 "before",
306 );
307
308 elem_buff[loc.prev.adj_pt_idx as usize] = prev_point.into();
310 elem_buff[loc.next.adj_pt_idx as usize] = next_point.into();
311
312 new_results
313 } else {
314 fixes.push(LintFix::delete(loc.target.clone()));
315 for seg in elem_buff[loc.prev.adj_pt_idx as usize].segments() {
316 fixes.push(LintFix::delete(seg.clone()));
317 }
318
319 let (new_results, new_point) = ReflowPoint::new(Vec::new()).respace_point(
320 tables,
321 elem_buff[(loc.next.adj_pt_idx - 1) as usize].as_block(),
322 elem_buff[(loc.next.pre_code_pt_idx + 1) as usize].as_block(),
323 &root_segment,
324 Vec::new(),
325 false,
326 "after",
327 );
328
329 let mut create_anchor = None;
330 for i in 0..loc.next.pre_code_pt_idx {
331 let idx = loc.next.pre_code_pt_idx - i;
332 if let Some(elem) = elem_buff.get(idx as usize) {
333 if let Some(segments) = elem.segments().last() {
334 create_anchor = Some(segments.clone());
335 break;
336 }
337 }
338 }
339
340 if create_anchor.is_none() {
341 panic!("Could not find anchor for creation.");
342 }
343
344 fixes.push(LintFix::create_after(
345 create_anchor.unwrap(),
346 vec![loc.target.clone()],
347 None,
348 ));
349
350 rearrange_and_insert(&mut elem_buff, &loc, new_point);
351
352 new_results
353 }
354 } else if loc.line_position == LinePosition::Trailing {
355 if elem_buff[loc.next.newline_pt_idx as usize].num_newlines() != 0 {
356 continue;
357 }
358
359 let pretty_name = loc.pretty_target_name();
360 let _desc = if loc.strict {
361 format!(
362 "{} should always be at the end of a line.",
363 capitalize(&pretty_name)
364 )
365 } else {
366 format!(
367 "Found leading {}. Expected only trailing near line breaks.",
368 pretty_name
369 )
370 };
371
372 if loc.prev.adj_pt_idx == loc.prev.pre_code_pt_idx
373 && elem_buff[loc.prev.newline_pt_idx as usize].num_newlines() == 1
374 {
375 let (new_results, next_point) = next_point.indent_to(
376 tables,
377 prev_point.get_indent().as_deref().unwrap_or_default(),
378 Some(loc.target.clone()),
379 None,
380 None,
381 None,
382 );
383
384 let (new_results, prev_point) = prev_point.respace_point(
385 tables,
386 elem_buff[loc.prev.adj_pt_idx as usize - 1].as_block(),
387 elem_buff[loc.prev.adj_pt_idx as usize + 1].as_block(),
388 &root_segment,
389 new_results,
390 true,
391 "before",
392 );
393
394 elem_buff[loc.prev.adj_pt_idx as usize] = prev_point.into();
396 elem_buff[loc.next.adj_pt_idx as usize] = next_point.into();
397
398 new_results
399 } else {
400 fixes.push(LintFix::delete(loc.target.clone()));
401 for seg in elem_buff[loc.next.adj_pt_idx as usize].segments() {
402 fixes.push(LintFix::delete(seg.clone()));
403 }
404
405 let (new_results, new_point) = ReflowPoint::new(Vec::new()).respace_point(
406 tables,
407 elem_buff[(loc.prev.pre_code_pt_idx - 1) as usize].as_block(),
408 elem_buff[(loc.prev.adj_pt_idx + 1) as usize].as_block(),
409 &root_segment,
410 Vec::new(),
411 false,
412 "before",
413 );
414
415 fixes.push(LintFix::create_before(
416 elem_buff[loc.prev.pre_code_pt_idx as usize].segments()[0].clone(),
417 vec![loc.target.clone()],
418 ));
419
420 reorder_and_insert(&mut elem_buff, &loc, new_point);
421
422 new_results
423 }
424 } else if loc.line_position == LinePosition::Alone {
425 let mut new_results = Vec::new();
426
427 if elem_buff[loc.next.newline_pt_idx as usize].num_newlines() == 0 {
428 let (results, next_point) = next_point.indent_to(
429 tables,
430 &deduce_line_indent(
431 loc.target.get_raw_segments().last().unwrap(),
432 &root_segment,
433 ),
434 loc.target.clone().into(),
435 None,
436 None,
437 None,
438 );
439
440 new_results = results;
441 elem_buff[loc.next.adj_pt_idx as usize] = next_point.into();
442 }
443
444 if elem_buff[loc.prev.adj_pt_idx as usize].num_newlines() == 0 {
445 let (results, prev_point) = prev_point.indent_to(
446 tables,
447 &deduce_line_indent(
448 loc.target.get_raw_segments().first().unwrap(),
449 &root_segment,
450 ),
451 None,
452 loc.target.clone().into(),
453 None,
454 None,
455 );
456
457 new_results = results;
458 elem_buff[loc.prev.adj_pt_idx as usize] = prev_point.into();
459 }
460
461 new_results
462 } else {
463 unimplemented!(
464 "Unexpected line_position config: {}",
465 loc.line_position.as_ref()
466 )
467 };
468
469 let fixes = fixes_from_results(new_results.into_iter())
470 .chain(std::mem::take(&mut fixes))
471 .collect();
472 lint_results.push(LintResult::new(
473 loc.target.clone().into(),
474 fixes,
475 None,
476 None,
477 ));
478 }
479
480 (elem_buff, lint_results)
481}
482
483fn rearrange_and_insert(
484 elem_buff: &mut Vec<ReflowElement>,
485 loc: &RebreakLocation,
486 new_point: ReflowPoint,
487) {
488 let mut new_buff = Vec::with_capacity(elem_buff.len() + 1);
489
490 new_buff.extend_from_slice(&elem_buff[..loc.prev.adj_pt_idx as usize]);
492
493 new_buff.extend_from_slice(
495 &elem_buff[loc.next.adj_pt_idx as usize..=loc.next.pre_code_pt_idx as usize],
496 );
497
498 if loc.prev.adj_pt_idx + 1 < loc.next.adj_pt_idx {
501 new_buff.extend_from_slice(
502 &elem_buff[loc.prev.adj_pt_idx as usize + 1..loc.next.adj_pt_idx as usize],
503 );
504 }
505
506 new_buff.push(new_point.into());
508
509 if loc.next.pre_code_pt_idx as usize + 1 < elem_buff.len() {
511 new_buff.extend_from_slice(&elem_buff[loc.next.pre_code_pt_idx as usize + 1..]);
512 }
513
514 *elem_buff = new_buff;
516}
517
518fn reorder_and_insert(
519 elem_buff: &mut Vec<ReflowElement>,
520 loc: &RebreakLocation,
521 new_point: ReflowPoint,
522) {
523 let mut new_buff = Vec::with_capacity(elem_buff.len() + 1);
524
525 new_buff.extend_from_slice(&elem_buff[..loc.prev.pre_code_pt_idx as usize]);
527
528 new_buff.push(new_point.into());
530
531 if loc.prev.adj_pt_idx + 1 < loc.next.adj_pt_idx {
534 new_buff.extend_from_slice(
535 &elem_buff[loc.prev.adj_pt_idx as usize + 1..loc.next.adj_pt_idx as usize],
536 );
537 }
538
539 new_buff.extend_from_slice(
542 &elem_buff[loc.prev.pre_code_pt_idx as usize..=loc.prev.adj_pt_idx as usize],
543 );
544
545 if loc.next.adj_pt_idx as usize + 1 < elem_buff.len() {
547 new_buff.extend_from_slice(&elem_buff[loc.next.adj_pt_idx as usize + 1..]);
548 }
549
550 *elem_buff = new_buff;
552}
553
#[cfg(test)]
mod tests {
    use sqruff_lib::core::test_functions::parse_ansi_string;
    use sqruff_lib_core::helpers::enter_panic;
    use sqruff_lib_core::parser::segments::base::Tables;

    use crate::utils::reflow::sequence::{ReflowSequence, TargetSide};

    /// Rebreaking a whole parsed root: operators end up leading, commas
    /// trailing, and already-correct SQL is left unchanged.
    #[test]
    fn test_reflow_sequence_rebreak_root() {
        let cases = [
            ("select 1", "select 1"),
            ("select 1\n+2", "select 1\n+2"),
            ("select 1+\n2", "select 1\n+ 2"),
            ("select\n 1 +\n 2", "select\n 1\n + 2"),
            (
                "select\n 1 +\n -- comment\n 2",
                "select\n 1\n -- comment\n + 2",
            ),
            ("select a,b", "select a,b"),
            ("select a\n,b", "select a,\nb"),
            ("select\n a\n , b", "select\n a,\n b"),
            ("select\n a\n , b", "select\n a,\n b"),
            ("select\n a\n , b", "select\n a,\n b"),
            (
                "select\n a\n -- comment\n , b",
                "select\n a,\n -- comment\n b",
            ),
        ];

        let tables = Tables::default();
        for (raw_sql_in, raw_sql_out) in cases {
            // Label any panic with the failing input.
            let _panic = enter_panic(format!("{raw_sql_in:?}"));

            let root = parse_ansi_string(raw_sql_in);
            let config = <_>::default();
            let rebroken = ReflowSequence::from_root(root, &config).rebreak(&tables);

            assert_eq!(rebroken.raw(), raw_sql_out);
        }
    }

    /// Rebreaking a sequence built around a single raw segment (selected by
    /// its index among the root's raw segments).
    #[test]
    fn test_reflow_sequence_rebreak_target() {
        let cases = [
            ("select 1+\n(2+3)", 4, "1+\n(", "1\n+ ("),
            ("select a,\n(b+c)", 4, "a,\n(", "a,\n("),
            ("select a\n , (b+c)", 6, "a\n , (", "a,\n ("),
            ("select a,\n(b+c)", 6, ",\n(b", ",\n(b"),
            ("select a<=b", 4, "a<=", "a<="),
        ];

        let tables = Tables::default();
        for (raw_sql_in, target_idx, seq_sql_in, seq_sql_out) in cases {
            let root = parse_ansi_string(raw_sql_in);
            let target = &root.get_raw_segments()[target_idx];
            let config = <_>::default();
            let seq = ReflowSequence::from_around_target(target, root, TargetSide::Both, &config);

            // The sequence must cover exactly the expected slice of SQL
            // before rebreaking.
            assert_eq!(seq.raw(), seq_sql_in);

            let rebroken = seq.rebreak(&tables);
            assert_eq!(rebroken.raw(), seq_sql_out);
        }
    }
}