sqruff_lsp/
lib.rs

1use ahash::AHashMap;
2use lsp_server::{Connection, Message, Request, RequestId, Response};
3use lsp_types::notification::{
4    DidChangeTextDocument, DidCloseTextDocument, DidOpenTextDocument, DidSaveTextDocument,
5    Notification, PublishDiagnostics,
6};
7use lsp_types::request::{Formatting, Request as _};
8use lsp_types::{
9    Diagnostic, DiagnosticSeverity, DidChangeTextDocumentParams, DidCloseTextDocumentParams,
10    DidOpenTextDocumentParams, DidSaveTextDocumentParams, DocumentFormattingParams,
11    InitializeParams, InitializeResult, NumberOrString, OneOf, Position, PublishDiagnosticsParams,
12    Registration, ServerCapabilities, TextDocumentIdentifier, TextDocumentItem,
13    TextDocumentSyncCapability, TextDocumentSyncKind, Uri, VersionedTextDocumentIdentifier,
14};
15use serde_json::Value;
16use sqruff_lib::core::config::FluffConfig;
17use sqruff_lib::core::linter::core::Linter;
18use wasm_bindgen::prelude::*;
19
20#[cfg(not(target_arch = "wasm32"))]
21fn load_config() -> FluffConfig {
22    FluffConfig::from_root(None, false, None).unwrap_or_default()
23}
24
25#[cfg(target_arch = "wasm32")]
26fn load_config() -> FluffConfig {
27    FluffConfig::default()
28}
29
30fn server_initialize_result() -> InitializeResult {
31    InitializeResult {
32        capabilities: ServerCapabilities {
33            text_document_sync: TextDocumentSyncCapability::Kind(TextDocumentSyncKind::FULL).into(),
34            document_formatting_provider: OneOf::Left(true).into(),
35            ..Default::default()
36        },
37        server_info: None,
38    }
39}
40
/// Transport-agnostic core of the sqruff language server.
///
/// Holds the linter, the latest text of each open document, and a callback used to push
/// diagnostics to whichever transport hosts the server (stdio in `main_loop`, or JS via the
/// `Wasm` wrapper).
pub struct LanguageServer {
    /// Linter used for both diagnostics and formatting; its config is replaced at runtime
    /// through `config_mut()`.
    linter: Linter,
    /// Invoked with freshly computed diagnostics for a single document.
    send_diagnostics_callback: Box<dyn Fn(PublishDiagnosticsParams)>,
    /// Latest full text of each open document, keyed by URI.
    documents: AHashMap<Uri, String>,
}
46
47#[wasm_bindgen]
48pub struct Wasm(LanguageServer);
49
50#[wasm_bindgen]
51impl Wasm {
52    #[wasm_bindgen(constructor)]
53    pub fn new(send_diagnostics_callback: js_sys::Function) -> Self {
54        console_error_panic_hook::set_once();
55
56        let send_diagnostics_callback = Box::leak(Box::new(send_diagnostics_callback));
57
58        Self(LanguageServer::new(|diagnostics| {
59            let diagnostics = serde_wasm_bindgen::to_value(&diagnostics).unwrap();
60            send_diagnostics_callback
61                .call1(&JsValue::null(), &diagnostics)
62                .unwrap();
63        }))
64    }
65
66    #[wasm_bindgen(js_name = saveRegistrationOptions)]
67    pub fn save_registration_options() -> JsValue {
68        serde_wasm_bindgen::to_value(&save_registration_options()).unwrap()
69    }
70
71    #[wasm_bindgen(js_name = updateConfig)]
72    pub fn update_config(&mut self, source: &str) {
73        *self.0.linter.config_mut() = FluffConfig::from_source(source, None);
74        self.0.recheck_files();
75    }
76
77    #[wasm_bindgen(js_name = onInitialize)]
78    pub fn on_initialize(&self) -> JsValue {
79        serde_wasm_bindgen::to_value(&server_initialize_result()).unwrap()
80    }
81
82    #[wasm_bindgen(js_name = onNotification)]
83    pub fn on_notification(&mut self, method: &str, params: JsValue) {
84        self.0
85            .on_notification(method, serde_wasm_bindgen::from_value(params).unwrap())
86    }
87
88    #[wasm_bindgen]
89    pub fn format(&mut self, uri: JsValue) -> JsValue {
90        let uri = serde_wasm_bindgen::from_value(uri).unwrap();
91        let edits = self.0.format(uri);
92        serde_wasm_bindgen::to_value(&edits).unwrap()
93    }
94}
95
96impl LanguageServer {
97    pub fn new(send_diagnostics_callback: impl Fn(PublishDiagnosticsParams) + 'static) -> Self {
98        Self {
99            linter: Linter::new(load_config(), None, None, false),
100            send_diagnostics_callback: Box::new(send_diagnostics_callback),
101            documents: AHashMap::new(),
102        }
103    }
104
105    fn on_request(&mut self, id: RequestId, method: &str, params: Value) -> Option<Response> {
106        match method {
107            Formatting::METHOD => {
108                let DocumentFormattingParams {
109                    text_document: TextDocumentIdentifier { uri },
110                    ..
111                } = serde_json::from_value(params).unwrap();
112
113                let edits = self.format(uri);
114                Some(Response::new_ok(id, edits))
115            }
116            _ => None,
117        }
118    }
119
120    fn format(&mut self, uri: Uri) -> Vec<lsp_types::TextEdit> {
121        let text = &self.documents[&uri];
122        let tree = self.linter.lint_string(text, None, true);
123
124        let new_text = tree.fix_string();
125        let start_position = Position {
126            line: 0,
127            character: 0,
128        };
129        let end_position = Position {
130            line: new_text.lines().count() as u32,
131            character: new_text.chars().count() as u32,
132        };
133
134        let result = vec![lsp_types::TextEdit {
135            range: lsp_types::Range::new(start_position, end_position),
136            new_text,
137        }];
138        result
139    }
140
141    pub fn on_notification(&mut self, method: &str, params: Value) {
142        match method {
143            DidOpenTextDocument::METHOD => {
144                let params: DidOpenTextDocumentParams = serde_json::from_value(params).unwrap();
145                let TextDocumentItem {
146                    uri,
147                    language_id: _,
148                    version: _,
149                    text,
150                } = params.text_document;
151
152                self.check_file(uri.clone(), &text);
153                self.documents.insert(uri, text);
154            }
155            DidChangeTextDocument::METHOD => {
156                let params: DidChangeTextDocumentParams = serde_json::from_value(params).unwrap();
157
158                let content = params.content_changes[0].text.clone();
159                let VersionedTextDocumentIdentifier { uri, version: _ } = params.text_document;
160
161                self.check_file(uri.clone(), &content);
162                self.documents.insert(uri, content);
163            }
164            DidCloseTextDocument::METHOD => {
165                let params: DidCloseTextDocumentParams = serde_json::from_value(params).unwrap();
166                self.documents.remove(&params.text_document.uri);
167            }
168            DidSaveTextDocument::METHOD => {
169                let params: DidSaveTextDocumentParams = serde_json::from_value(params).unwrap();
170                let uri = params.text_document.uri.as_str();
171
172                if uri.ends_with(".sqlfluff") || uri.ends_with(".sqruff") {
173                    *self.linter.config_mut() = load_config();
174
175                    self.recheck_files();
176                }
177            }
178            _ => {}
179        }
180    }
181
182    fn recheck_files(&mut self) {
183        for (uri, text) in self.documents.iter() {
184            self.check_file(uri.clone(), text);
185        }
186    }
187
188    fn check_file(&self, uri: Uri, text: &str) {
189        let result = self.linter.lint_string(text, None, false);
190
191        let diagnostics = result
192            .violations
193            .into_iter()
194            .map(|violation| {
195                let range = {
196                    let pos = Position::new(
197                        (violation.line_no as u32).saturating_sub(1),
198                        (violation.line_pos as u32).saturating_sub(1),
199                    );
200                    lsp_types::Range::new(pos, pos)
201                };
202
203                let code = violation
204                    .rule
205                    .map(|rule| NumberOrString::String(rule.code.to_string()));
206
207                Diagnostic::new(
208                    range,
209                    DiagnosticSeverity::WARNING.into(),
210                    code,
211                    Some("sqruff".to_string()),
212                    violation.description,
213                    None,
214                    None,
215                )
216            })
217            .collect();
218
219        let diagnostics = PublishDiagnosticsParams::new(uri.clone(), diagnostics, None);
220        (self.send_diagnostics_callback)(diagnostics);
221    }
222}
223
224pub fn run() {
225    let (connection, io_threads) = Connection::stdio();
226    let (id, params) = connection.initialize_start().unwrap();
227
228    let init_param: InitializeParams = serde_json::from_value(params).unwrap();
229    let initialize_result = serde_json::to_value(server_initialize_result()).unwrap();
230    connection.initialize_finish(id, initialize_result).unwrap();
231
232    main_loop(connection, init_param);
233
234    io_threads.join().unwrap();
235}
236
237fn main_loop(connection: Connection, _init_param: InitializeParams) {
238    let sender = connection.sender.clone();
239    let mut lsp = LanguageServer::new(move |diagnostics| {
240        let notification = new_notification::<PublishDiagnostics>(diagnostics);
241        sender.send(Message::Notification(notification)).unwrap();
242    });
243
244    let params = save_registration_options();
245    connection
246        .sender
247        .send(Message::Request(Request::new(
248            "textDocument-didSave".to_owned().into(),
249            "client/registerCapability".to_owned(),
250            params,
251        )))
252        .unwrap();
253
254    for message in &connection.receiver {
255        match message {
256            Message::Request(request) => {
257                if connection.handle_shutdown(&request).unwrap() {
258                    return;
259                }
260
261                if let Some(response) = lsp.on_request(request.id, &request.method, request.params)
262                {
263                    connection.sender.send(Message::Response(response)).unwrap();
264                }
265            }
266            Message::Response(_) => {}
267            Message::Notification(notification) => {
268                lsp.on_notification(&notification.method, notification.params);
269            }
270        }
271    }
272}
273
274pub fn save_registration_options() -> lsp_types::RegistrationParams {
275    let save_registration_options = lsp_types::TextDocumentSaveRegistrationOptions {
276        include_text: false.into(),
277        text_document_registration_options: lsp_types::TextDocumentRegistrationOptions {
278            document_selector: Some(vec![
279                lsp_types::DocumentFilter {
280                    language: None,
281                    scheme: None,
282                    pattern: Some("**/.sqlfluff".into()),
283                },
284                lsp_types::DocumentFilter {
285                    language: None,
286                    scheme: None,
287                    pattern: Some("**/.sqruff".into()),
288                },
289            ]),
290        },
291    };
292
293    lsp_types::RegistrationParams {
294        registrations: vec![Registration {
295            id: "textDocument/didSave".into(),
296            method: "textDocument/didSave".into(),
297            register_options: serde_json::to_value(save_registration_options)
298                .unwrap()
299                .into(),
300        }],
301    }
302}
303
304fn new_notification<T>(params: T::Params) -> lsp_server::Notification
305where
306    T: Notification,
307{
308    lsp_server::Notification {
309        method: T::METHOD.to_owned(),
310        params: serde_json::to_value(&params).unwrap(),
311    }
312}