1use ahash::AHashMap;
2use lsp_server::{Connection, Message, Request, RequestId, Response};
3use lsp_types::notification::{
4 DidChangeTextDocument, DidCloseTextDocument, DidOpenTextDocument, DidSaveTextDocument,
5 Notification, PublishDiagnostics,
6};
7use lsp_types::request::{Formatting, Request as _};
8use lsp_types::{
9 Diagnostic, DiagnosticSeverity, DidChangeTextDocumentParams, DidCloseTextDocumentParams,
10 DidOpenTextDocumentParams, DidSaveTextDocumentParams, DocumentFormattingParams,
11 InitializeParams, InitializeResult, NumberOrString, OneOf, Position, PublishDiagnosticsParams,
12 Registration, ServerCapabilities, TextDocumentIdentifier, TextDocumentItem,
13 TextDocumentSyncCapability, TextDocumentSyncKind, Uri, VersionedTextDocumentIdentifier,
14};
15use serde_json::Value;
16use sqruff_lib::core::config::FluffConfig;
17use sqruff_lib::core::linter::core::Linter;
18use wasm_bindgen::prelude::*;
19
20#[cfg(not(target_arch = "wasm32"))]
21fn load_config() -> FluffConfig {
22 FluffConfig::from_root(None, false, None).unwrap_or_default()
23}
24
25#[cfg(target_arch = "wasm32")]
26fn load_config() -> FluffConfig {
27 FluffConfig::default()
28}
29
30fn server_initialize_result() -> InitializeResult {
31 InitializeResult {
32 capabilities: ServerCapabilities {
33 text_document_sync: TextDocumentSyncCapability::Kind(TextDocumentSyncKind::FULL).into(),
34 document_formatting_provider: OneOf::Left(true).into(),
35 ..Default::default()
36 },
37 server_info: None,
38 }
39}
40
41pub struct LanguageServer {
42 linter: Linter,
43 send_diagnostics_callback: Box<dyn Fn(PublishDiagnosticsParams)>,
44 documents: AHashMap<Uri, String>,
45}
46
47#[wasm_bindgen]
48pub struct Wasm(LanguageServer);
49
50#[wasm_bindgen]
51impl Wasm {
52 #[wasm_bindgen(constructor)]
53 pub fn new(send_diagnostics_callback: js_sys::Function) -> Self {
54 console_error_panic_hook::set_once();
55
56 let send_diagnostics_callback = Box::leak(Box::new(send_diagnostics_callback));
57
58 Self(LanguageServer::new(|diagnostics| {
59 let diagnostics = serde_wasm_bindgen::to_value(&diagnostics).unwrap();
60 send_diagnostics_callback
61 .call1(&JsValue::null(), &diagnostics)
62 .unwrap();
63 }))
64 }
65
66 #[wasm_bindgen(js_name = saveRegistrationOptions)]
67 pub fn save_registration_options() -> JsValue {
68 serde_wasm_bindgen::to_value(&save_registration_options()).unwrap()
69 }
70
71 #[wasm_bindgen(js_name = updateConfig)]
72 pub fn update_config(&mut self, source: &str) {
73 *self.0.linter.config_mut() = FluffConfig::from_source(source, None);
74 self.0.recheck_files();
75 }
76
77 #[wasm_bindgen(js_name = onInitialize)]
78 pub fn on_initialize(&self) -> JsValue {
79 serde_wasm_bindgen::to_value(&server_initialize_result()).unwrap()
80 }
81
82 #[wasm_bindgen(js_name = onNotification)]
83 pub fn on_notification(&mut self, method: &str, params: JsValue) {
84 self.0
85 .on_notification(method, serde_wasm_bindgen::from_value(params).unwrap())
86 }
87
88 #[wasm_bindgen]
89 pub fn format(&mut self, uri: JsValue) -> JsValue {
90 let uri = serde_wasm_bindgen::from_value(uri).unwrap();
91 let edits = self.0.format(uri);
92 serde_wasm_bindgen::to_value(&edits).unwrap()
93 }
94}
95
96impl LanguageServer {
97 pub fn new(send_diagnostics_callback: impl Fn(PublishDiagnosticsParams) + 'static) -> Self {
98 Self {
99 linter: Linter::new(load_config(), None, None, false),
100 send_diagnostics_callback: Box::new(send_diagnostics_callback),
101 documents: AHashMap::new(),
102 }
103 }
104
105 fn on_request(&mut self, id: RequestId, method: &str, params: Value) -> Option<Response> {
106 match method {
107 Formatting::METHOD => {
108 let DocumentFormattingParams {
109 text_document: TextDocumentIdentifier { uri },
110 ..
111 } = serde_json::from_value(params).unwrap();
112
113 let edits = self.format(uri);
114 Some(Response::new_ok(id, edits))
115 }
116 _ => None,
117 }
118 }
119
120 fn format(&mut self, uri: Uri) -> Vec<lsp_types::TextEdit> {
121 let text = &self.documents[&uri];
122 let tree = self.linter.lint_string(text, None, true);
123
124 let new_text = tree.fix_string();
125 let start_position = Position {
126 line: 0,
127 character: 0,
128 };
129 let end_position = Position {
130 line: new_text.lines().count() as u32,
131 character: new_text.chars().count() as u32,
132 };
133
134 let result = vec![lsp_types::TextEdit {
135 range: lsp_types::Range::new(start_position, end_position),
136 new_text,
137 }];
138 result
139 }
140
141 pub fn on_notification(&mut self, method: &str, params: Value) {
142 match method {
143 DidOpenTextDocument::METHOD => {
144 let params: DidOpenTextDocumentParams = serde_json::from_value(params).unwrap();
145 let TextDocumentItem {
146 uri,
147 language_id: _,
148 version: _,
149 text,
150 } = params.text_document;
151
152 self.check_file(uri.clone(), &text);
153 self.documents.insert(uri, text);
154 }
155 DidChangeTextDocument::METHOD => {
156 let params: DidChangeTextDocumentParams = serde_json::from_value(params).unwrap();
157
158 let content = params.content_changes[0].text.clone();
159 let VersionedTextDocumentIdentifier { uri, version: _ } = params.text_document;
160
161 self.check_file(uri.clone(), &content);
162 self.documents.insert(uri, content);
163 }
164 DidCloseTextDocument::METHOD => {
165 let params: DidCloseTextDocumentParams = serde_json::from_value(params).unwrap();
166 self.documents.remove(¶ms.text_document.uri);
167 }
168 DidSaveTextDocument::METHOD => {
169 let params: DidSaveTextDocumentParams = serde_json::from_value(params).unwrap();
170 let uri = params.text_document.uri.as_str();
171
172 if uri.ends_with(".sqlfluff") || uri.ends_with(".sqruff") {
173 *self.linter.config_mut() = load_config();
174
175 self.recheck_files();
176 }
177 }
178 _ => {}
179 }
180 }
181
182 fn recheck_files(&mut self) {
183 for (uri, text) in self.documents.iter() {
184 self.check_file(uri.clone(), text);
185 }
186 }
187
188 fn check_file(&self, uri: Uri, text: &str) {
189 let result = self.linter.lint_string(text, None, false);
190
191 let diagnostics = result
192 .violations
193 .into_iter()
194 .map(|violation| {
195 let range = {
196 let pos = Position::new(
197 (violation.line_no as u32).saturating_sub(1),
198 (violation.line_pos as u32).saturating_sub(1),
199 );
200 lsp_types::Range::new(pos, pos)
201 };
202
203 let code = violation
204 .rule
205 .map(|rule| NumberOrString::String(rule.code.to_string()));
206
207 Diagnostic::new(
208 range,
209 DiagnosticSeverity::WARNING.into(),
210 code,
211 Some("sqruff".to_string()),
212 violation.description,
213 None,
214 None,
215 )
216 })
217 .collect();
218
219 let diagnostics = PublishDiagnosticsParams::new(uri.clone(), diagnostics, None);
220 (self.send_diagnostics_callback)(diagnostics);
221 }
222}
223
224pub fn run() {
225 let (connection, io_threads) = Connection::stdio();
226 let (id, params) = connection.initialize_start().unwrap();
227
228 let init_param: InitializeParams = serde_json::from_value(params).unwrap();
229 let initialize_result = serde_json::to_value(server_initialize_result()).unwrap();
230 connection.initialize_finish(id, initialize_result).unwrap();
231
232 main_loop(connection, init_param);
233
234 io_threads.join().unwrap();
235}
236
237fn main_loop(connection: Connection, _init_param: InitializeParams) {
238 let sender = connection.sender.clone();
239 let mut lsp = LanguageServer::new(move |diagnostics| {
240 let notification = new_notification::<PublishDiagnostics>(diagnostics);
241 sender.send(Message::Notification(notification)).unwrap();
242 });
243
244 let params = save_registration_options();
245 connection
246 .sender
247 .send(Message::Request(Request::new(
248 "textDocument-didSave".to_owned().into(),
249 "client/registerCapability".to_owned(),
250 params,
251 )))
252 .unwrap();
253
254 for message in &connection.receiver {
255 match message {
256 Message::Request(request) => {
257 if connection.handle_shutdown(&request).unwrap() {
258 return;
259 }
260
261 if let Some(response) = lsp.on_request(request.id, &request.method, request.params)
262 {
263 connection.sender.send(Message::Response(response)).unwrap();
264 }
265 }
266 Message::Response(_) => {}
267 Message::Notification(notification) => {
268 lsp.on_notification(¬ification.method, notification.params);
269 }
270 }
271 }
272}
273
274pub fn save_registration_options() -> lsp_types::RegistrationParams {
275 let save_registration_options = lsp_types::TextDocumentSaveRegistrationOptions {
276 include_text: false.into(),
277 text_document_registration_options: lsp_types::TextDocumentRegistrationOptions {
278 document_selector: Some(vec![
279 lsp_types::DocumentFilter {
280 language: None,
281 scheme: None,
282 pattern: Some("**/.sqlfluff".into()),
283 },
284 lsp_types::DocumentFilter {
285 language: None,
286 scheme: None,
287 pattern: Some("**/.sqruff".into()),
288 },
289 ]),
290 },
291 };
292
293 lsp_types::RegistrationParams {
294 registrations: vec![Registration {
295 id: "textDocument/didSave".into(),
296 method: "textDocument/didSave".into(),
297 register_options: serde_json::to_value(save_registration_options)
298 .unwrap()
299 .into(),
300 }],
301 }
302}
303
304fn new_notification<T>(params: T::Params) -> lsp_server::Notification
305where
306 T: Notification,
307{
308 lsp_server::Notification {
309 method: T::METHOD.to_owned(),
310 params: serde_json::to_value(¶ms).unwrap(),
311 }
312}