lsp_server/
msg.rs

1use std::{
2    fmt,
3    io::{self, BufRead, Write},
4};
5
6use serde::de::DeserializeOwned;
7use serde_derive::{Deserialize, Serialize};
8
9use crate::error::ExtractError;
10
11#[derive(Serialize, Deserialize, Debug, Clone)]
12#[serde(untagged)]
13pub enum Message {
14    Request(Request),
15    Response(Response),
16    Notification(Notification),
17}
18
19impl From<Request> for Message {
20    fn from(request: Request) -> Message {
21        Message::Request(request)
22    }
23}
24
25impl From<Response> for Message {
26    fn from(response: Response) -> Message {
27        Message::Response(response)
28    }
29}
30
31impl From<Notification> for Message {
32    fn from(notification: Notification) -> Message {
33        Message::Notification(notification)
34    }
35}
36
37#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
38#[serde(transparent)]
39pub struct RequestId(IdRepr);
40
41#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
42#[serde(untagged)]
43enum IdRepr {
44    I32(i32),
45    String(String),
46}
47
48impl From<i32> for RequestId {
49    fn from(id: i32) -> RequestId {
50        RequestId(IdRepr::I32(id))
51    }
52}
53
54impl From<String> for RequestId {
55    fn from(id: String) -> RequestId {
56        RequestId(IdRepr::String(id))
57    }
58}
59
60impl fmt::Display for RequestId {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62        match &self.0 {
63            IdRepr::I32(it) => fmt::Display::fmt(it, f),
64            // Use debug here, to make it clear that `92` and `"92"` are
65            // different, and to reduce WTF factor if the sever uses `" "` as an
66            // ID.
67            IdRepr::String(it) => fmt::Debug::fmt(it, f),
68        }
69    }
70}
71
72#[derive(Debug, Serialize, Deserialize, Clone)]
73pub struct Request {
74    pub id: RequestId,
75    pub method: String,
76    #[serde(default = "serde_json::Value::default")]
77    #[serde(skip_serializing_if = "serde_json::Value::is_null")]
78    pub params: serde_json::Value,
79}
80
81#[derive(Debug, Serialize, Deserialize, Clone)]
82pub struct Response {
83    // JSON RPC allows this to be null if it was impossible
84    // to decode the request's id. Ignore this special case
85    // and just die horribly.
86    pub id: RequestId,
87    #[serde(skip_serializing_if = "Option::is_none")]
88    pub result: Option<serde_json::Value>,
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub error: Option<ResponseError>,
91}
92
93#[derive(Debug, Serialize, Deserialize, Clone)]
94pub struct ResponseError {
95    pub code: i32,
96    pub message: String,
97    #[serde(skip_serializing_if = "Option::is_none")]
98    pub data: Option<serde_json::Value>,
99}
100
/// Error codes defined by JSON-RPC 2.0 and by the Language Server Protocol.
///
/// Each discriminant is the numeric value used on the wire, e.g.
/// `ErrorCode::MethodNotFound as i32 == -32601`.
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub enum ErrorCode {
    // Defined by JSON RPC:
    /// Invalid JSON was received; an error occurred while parsing the JSON text.
    ParseError = -32700,
    /// The JSON sent is not a valid request object.
    InvalidRequest = -32600,
    /// The method does not exist or is not available.
    MethodNotFound = -32601,
    /// Invalid method parameter(s).
    InvalidParams = -32602,
    /// Internal JSON-RPC error.
    InternalError = -32603,
    /// Lower bound of the range that JSON-RPC reserves for
    /// implementation-defined server errors (`-32099..=-32000`).
    ServerErrorStart = -32099,
    /// Upper bound of the range that JSON-RPC reserves for
    /// implementation-defined server errors (`-32099..=-32000`).
    ServerErrorEnd = -32000,

    // Defined by the protocol, inside the JSON-RPC server-error range:
    /// Error code indicating that a server received a notification or
    /// request before the server has received the `initialize` request.
    ServerNotInitialized = -32002,
    /// Catch-all code listed by the LSP specification; it carries no more
    /// specific meaning.
    UnknownErrorCode = -32001,

    // Defined by the protocol:
    /// The client has canceled a request and a server has detected
    /// the cancel.
    RequestCanceled = -32800,

    /// The server detected that the content of a document got
    /// modified outside normal conditions. A server should
    /// NOT send this error code if it detects a content change
    /// in its unprocessed messages. The result, even if computed
    /// on an older state, might still be useful for the client.
    ///
    /// If a client decides that a result is not of any use anymore
    /// the client should cancel the request.
    ContentModified = -32801,

    /// The server cancelled the request. This error code should
    /// only be used for requests that explicitly support being
    /// server cancellable.
    ///
    /// @since 3.17.0
    ServerCancelled = -32802,

    /// A request failed but it was syntactically correct, e.g. the
    /// method name was known and the parameters were valid. The error
    /// message should contain human readable information about why
    /// the request failed.
    ///
    /// @since 3.17.0
    RequestFailed = -32803,
}
148
149#[derive(Debug, Serialize, Deserialize, Clone)]
150pub struct Notification {
151    pub method: String,
152    #[serde(default = "serde_json::Value::default")]
153    #[serde(skip_serializing_if = "serde_json::Value::is_null")]
154    pub params: serde_json::Value,
155}
156
/// Wraps `error` (a message or any boxed-able error) in an `io::Error` of
/// kind `InvalidData`.
fn invalid_data(error: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, error)
}
160
/// `format!`-style shorthand: formats its arguments into a `String` and
/// passes it to the `invalid_data` function.
macro_rules! invalid_data {
    ($($arg:tt)*) => {
        invalid_data(format!($($arg)*))
    };
}
164
165impl Message {
166    pub fn read(r: &mut impl BufRead) -> io::Result<Option<Message>> {
167        Message::_read(r)
168    }
169    fn _read(r: &mut dyn BufRead) -> io::Result<Option<Message>> {
170        let text = match read_msg_text(r)? {
171            None => return Ok(None),
172            Some(text) => text,
173        };
174
175        let msg = match serde_json::from_str(&text) {
176            Ok(msg) => msg,
177            Err(e) => {
178                return Err(invalid_data!("malformed LSP payload: {:?}", e));
179            }
180        };
181
182        Ok(Some(msg))
183    }
184    pub fn write(self, w: &mut impl Write) -> io::Result<()> {
185        self._write(w)
186    }
187    fn _write(self, w: &mut dyn Write) -> io::Result<()> {
188        #[derive(Serialize)]
189        struct JsonRpc {
190            jsonrpc: &'static str,
191            #[serde(flatten)]
192            msg: Message,
193        }
194        let text = serde_json::to_string(&JsonRpc { jsonrpc: "2.0", msg: self })?;
195        write_msg_text(w, &text)
196    }
197}
198
199impl Response {
200    pub fn new_ok<R: serde::Serialize>(id: RequestId, result: R) -> Response {
201        Response { id, result: Some(serde_json::to_value(result).unwrap()), error: None }
202    }
203    pub fn new_err(id: RequestId, code: i32, message: String) -> Response {
204        let error = ResponseError { code, message, data: None };
205        Response { id, result: None, error: Some(error) }
206    }
207}
208
209impl Request {
210    pub fn new<P: serde::Serialize>(id: RequestId, method: String, params: P) -> Request {
211        Request { id, method, params: serde_json::to_value(params).unwrap() }
212    }
213    pub fn extract<P: DeserializeOwned>(
214        self,
215        method: &str,
216    ) -> Result<(RequestId, P), ExtractError<Request>> {
217        if self.method != method {
218            return Err(ExtractError::MethodMismatch(self));
219        }
220        match serde_json::from_value(self.params) {
221            Ok(params) => Ok((self.id, params)),
222            Err(error) => Err(ExtractError::JsonError { method: self.method, error }),
223        }
224    }
225
226    pub(crate) fn is_shutdown(&self) -> bool {
227        self.method == "shutdown"
228    }
229    pub(crate) fn is_initialize(&self) -> bool {
230        self.method == "initialize"
231    }
232}
233
234impl Notification {
235    pub fn new(method: String, params: impl serde::Serialize) -> Notification {
236        Notification { method, params: serde_json::to_value(params).unwrap() }
237    }
238    pub fn extract<P: DeserializeOwned>(
239        self,
240        method: &str,
241    ) -> Result<P, ExtractError<Notification>> {
242        if self.method != method {
243            return Err(ExtractError::MethodMismatch(self));
244        }
245        match serde_json::from_value(self.params) {
246            Ok(params) => Ok(params),
247            Err(error) => Err(ExtractError::JsonError { method: self.method, error }),
248        }
249    }
250    pub(crate) fn is_exit(&self) -> bool {
251        self.method == "exit"
252    }
253    pub(crate) fn is_initialized(&self) -> bool {
254        self.method == "initialized"
255    }
256}
257
258fn read_msg_text(inp: &mut dyn BufRead) -> io::Result<Option<String>> {
259    let mut size = None;
260    let mut buf = String::new();
261    loop {
262        buf.clear();
263        if inp.read_line(&mut buf)? == 0 {
264            return Ok(None);
265        }
266        if !buf.ends_with("\r\n") {
267            return Err(invalid_data!("malformed header: {:?}", buf));
268        }
269        let buf = &buf[..buf.len() - 2];
270        if buf.is_empty() {
271            break;
272        }
273        let mut parts = buf.splitn(2, ": ");
274        let header_name = parts.next().unwrap();
275        let header_value =
276            parts.next().ok_or_else(|| invalid_data!("malformed header: {:?}", buf))?;
277        if header_name.eq_ignore_ascii_case("Content-Length") {
278            size = Some(header_value.parse::<usize>().map_err(invalid_data)?);
279        }
280    }
281    let size: usize = size.ok_or_else(|| invalid_data!("no Content-Length"))?;
282    let mut buf = buf.into_bytes();
283    buf.resize(size, 0);
284    inp.read_exact(&mut buf)?;
285    let buf = String::from_utf8(buf).map_err(invalid_data)?;
286    log::debug!("< {}", buf);
287    Ok(Some(buf))
288}
289
290fn write_msg_text(out: &mut dyn Write, msg: &str) -> io::Result<()> {
291    log::debug!("> {}", msg);
292    write!(out, "Content-Length: {}\r\n\r\n", msg.len())?;
293    out.write_all(msg.as_bytes())?;
294    out.flush()?;
295    Ok(())
296}
297
#[cfg(test)]
mod tests {
    use std::io;

    use super::{Message, Notification, Request, RequestId};

    #[test]
    fn shutdown_with_explicit_null() {
        let text = "{\"jsonrpc\": \"2.0\",\"id\": 3,\"method\": \"shutdown\", \"params\": null }";
        let msg: Message = serde_json::from_str(text).unwrap();

        assert!(
            matches!(msg, Message::Request(req) if req.id == 3.into() && req.method == "shutdown")
        );
    }

    #[test]
    fn shutdown_with_no_params() {
        let text = "{\"jsonrpc\": \"2.0\",\"id\": 3,\"method\": \"shutdown\"}";
        let msg: Message = serde_json::from_str(text).unwrap();

        assert!(
            matches!(msg, Message::Request(req) if req.id == 3.into() && req.method == "shutdown")
        );
    }

    #[test]
    fn notification_with_explicit_null() {
        let text = "{\"jsonrpc\": \"2.0\",\"method\": \"exit\", \"params\": null }";
        let msg: Message = serde_json::from_str(text).unwrap();

        assert!(matches!(msg, Message::Notification(not) if not.method == "exit"));
    }

    #[test]
    fn notification_with_no_params() {
        let text = "{\"jsonrpc\": \"2.0\",\"method\": \"exit\"}";
        let msg: Message = serde_json::from_str(text).unwrap();

        assert!(matches!(msg, Message::Notification(not) if not.method == "exit"));
    }

    #[test]
    fn serialize_request_with_null_params() {
        let msg = Message::Request(Request {
            id: RequestId::from(3),
            method: "shutdown".into(),
            params: serde_json::Value::Null,
        });
        let serialized = serde_json::to_string(&msg).unwrap();

        assert_eq!("{\"id\":3,\"method\":\"shutdown\"}", serialized);
    }

    #[test]
    fn serialize_notification_with_null_params() {
        let msg = Message::Notification(Notification {
            method: "exit".into(),
            params: serde_json::Value::Null,
        });
        let serialized = serde_json::to_string(&msg).unwrap();

        assert_eq!("{\"method\":\"exit\"}", serialized);
    }

    #[test]
    fn write_then_read_roundtrip() {
        let request =
            Request::new(RequestId::from(7), "shutdown".to_string(), serde_json::Value::Null);
        let mut wire = Vec::new();
        Message::from(request).write(&mut wire).unwrap();

        let mut reader = &wire[..];
        let msg = Message::read(&mut reader).unwrap().expect("one framed message");
        assert!(
            matches!(msg, Message::Request(req) if req.id == 7.into() && req.method == "shutdown")
        );
        // The only frame has been consumed, so the stream now reads as clean EOF.
        assert!(Message::read(&mut reader).unwrap().is_none());
    }

    #[test]
    fn read_empty_stream_is_none() {
        let mut reader: &[u8] = b"";
        assert!(Message::read(&mut reader).unwrap().is_none());
    }

    #[test]
    fn read_rejects_missing_content_length() {
        let mut reader: &[u8] = b"Content-Type: application/json\r\n\r\n{}";
        let err = Message::read(&mut reader).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn read_rejects_header_without_crlf() {
        let mut reader: &[u8] = b"Content-Length: 2\n\n{}";
        let err = Message::read(&mut reader).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn read_reports_truncated_body() {
        let mut reader: &[u8] = b"Content-Length: 10\r\n\r\n{}";
        let err = Message::read(&mut reader).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }
}