1use std::{
2 fmt,
3 io::{self, BufRead, Write},
4};
5
6use serde::de::DeserializeOwned;
7use serde_derive::{Deserialize, Serialize};
8
9use crate::error::ExtractError;
10
11#[derive(Serialize, Deserialize, Debug, Clone)]
12#[serde(untagged)]
13pub enum Message {
14 Request(Request),
15 Response(Response),
16 Notification(Notification),
17}
18
19impl From<Request> for Message {
20 fn from(request: Request) -> Message {
21 Message::Request(request)
22 }
23}
24
25impl From<Response> for Message {
26 fn from(response: Response) -> Message {
27 Message::Response(response)
28 }
29}
30
31impl From<Notification> for Message {
32 fn from(notification: Notification) -> Message {
33 Message::Notification(notification)
34 }
35}
36
37#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
38#[serde(transparent)]
39pub struct RequestId(IdRepr);
40
41#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
42#[serde(untagged)]
43enum IdRepr {
44 I32(i32),
45 String(String),
46}
47
48impl From<i32> for RequestId {
49 fn from(id: i32) -> RequestId {
50 RequestId(IdRepr::I32(id))
51 }
52}
53
54impl From<String> for RequestId {
55 fn from(id: String) -> RequestId {
56 RequestId(IdRepr::String(id))
57 }
58}
59
60impl fmt::Display for RequestId {
61 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 match &self.0 {
63 IdRepr::I32(it) => fmt::Display::fmt(it, f),
64 IdRepr::String(it) => fmt::Debug::fmt(it, f),
68 }
69 }
70}
71
72#[derive(Debug, Serialize, Deserialize, Clone)]
73pub struct Request {
74 pub id: RequestId,
75 pub method: String,
76 #[serde(default = "serde_json::Value::default")]
77 #[serde(skip_serializing_if = "serde_json::Value::is_null")]
78 pub params: serde_json::Value,
79}
80
81#[derive(Debug, Serialize, Deserialize, Clone)]
82pub struct Response {
83 pub id: RequestId,
87 #[serde(skip_serializing_if = "Option::is_none")]
88 pub result: Option<serde_json::Value>,
89 #[serde(skip_serializing_if = "Option::is_none")]
90 pub error: Option<ResponseError>,
91}
92
93#[derive(Debug, Serialize, Deserialize, Clone)]
94pub struct ResponseError {
95 pub code: i32,
96 pub message: String,
97 #[serde(skip_serializing_if = "Option::is_none")]
98 pub data: Option<serde_json::Value>,
99}
100
/// Predefined values for [`ResponseError::code`], as listed by JSON-RPC 2.0
/// and the Language Server Protocol.
///
/// Convert to the wire value with `as i32`, e.g. `ErrorCode::MethodNotFound as i32`.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum ErrorCode {
    // Codes defined by JSON-RPC 2.0.
    /// Invalid JSON was received.
    ParseError = -32700,
    /// The JSON sent is not a valid request object.
    InvalidRequest = -32600,
    /// The method does not exist or is not available.
    MethodNotFound = -32601,
    /// Invalid method parameters.
    InvalidParams = -32602,
    /// Internal JSON-RPC error.
    InternalError = -32603,
    /// Lower bound of the range JSON-RPC reserves for implementation-defined server errors.
    ServerErrorStart = -32099,
    /// Upper bound of that reserved server-error range.
    ServerErrorEnd = -32000,

    // Codes defined by the Language Server Protocol.
    /// A request arrived before the server received `initialize`.
    ServerNotInitialized = -32002,
    /// Catch-all for errors without a more specific code.
    UnknownErrorCode = -32001,
    /// The client cancelled the request.
    RequestCanceled = -32800,
    /// The document changed in a way that invalidated the request's result.
    ContentModified = -32801,
    /// The server cancelled the request.
    ServerCancelled = -32802,
    /// The request was well-formed but the server could not fulfil it.
    RequestFailed = -32803,
}
148
149#[derive(Debug, Serialize, Deserialize, Clone)]
150pub struct Notification {
151 pub method: String,
152 #[serde(default = "serde_json::Value::default")]
153 #[serde(skip_serializing_if = "serde_json::Value::is_null")]
154 pub params: serde_json::Value,
155}
156
/// Wraps `error` in an [`io::Error`] of kind [`io::ErrorKind::InvalidData`].
///
/// Accepts anything convertible into a boxed error, including `String` and
/// `&str` messages.
fn invalid_data(error: impl Into<Box<dyn std::error::Error + Send + Sync>>) -> io::Error {
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, error)
}
160
/// `format!`-style shorthand for the `invalid_data` function: builds an
/// `InvalidData` `io::Error` whose message is the formatted string.
macro_rules! invalid_data {
    ($($arg:tt)*) => {
        invalid_data(format!($($arg)*))
    };
}
164
165impl Message {
166 pub fn read(r: &mut impl BufRead) -> io::Result<Option<Message>> {
167 Message::_read(r)
168 }
169 fn _read(r: &mut dyn BufRead) -> io::Result<Option<Message>> {
170 let text = match read_msg_text(r)? {
171 None => return Ok(None),
172 Some(text) => text,
173 };
174
175 let msg = match serde_json::from_str(&text) {
176 Ok(msg) => msg,
177 Err(e) => {
178 return Err(invalid_data!("malformed LSP payload: {:?}", e));
179 }
180 };
181
182 Ok(Some(msg))
183 }
184 pub fn write(self, w: &mut impl Write) -> io::Result<()> {
185 self._write(w)
186 }
187 fn _write(self, w: &mut dyn Write) -> io::Result<()> {
188 #[derive(Serialize)]
189 struct JsonRpc {
190 jsonrpc: &'static str,
191 #[serde(flatten)]
192 msg: Message,
193 }
194 let text = serde_json::to_string(&JsonRpc { jsonrpc: "2.0", msg: self })?;
195 write_msg_text(w, &text)
196 }
197}
198
199impl Response {
200 pub fn new_ok<R: serde::Serialize>(id: RequestId, result: R) -> Response {
201 Response { id, result: Some(serde_json::to_value(result).unwrap()), error: None }
202 }
203 pub fn new_err(id: RequestId, code: i32, message: String) -> Response {
204 let error = ResponseError { code, message, data: None };
205 Response { id, result: None, error: Some(error) }
206 }
207}
208
209impl Request {
210 pub fn new<P: serde::Serialize>(id: RequestId, method: String, params: P) -> Request {
211 Request { id, method, params: serde_json::to_value(params).unwrap() }
212 }
213 pub fn extract<P: DeserializeOwned>(
214 self,
215 method: &str,
216 ) -> Result<(RequestId, P), ExtractError<Request>> {
217 if self.method != method {
218 return Err(ExtractError::MethodMismatch(self));
219 }
220 match serde_json::from_value(self.params) {
221 Ok(params) => Ok((self.id, params)),
222 Err(error) => Err(ExtractError::JsonError { method: self.method, error }),
223 }
224 }
225
226 pub(crate) fn is_shutdown(&self) -> bool {
227 self.method == "shutdown"
228 }
229 pub(crate) fn is_initialize(&self) -> bool {
230 self.method == "initialize"
231 }
232}
233
234impl Notification {
235 pub fn new(method: String, params: impl serde::Serialize) -> Notification {
236 Notification { method, params: serde_json::to_value(params).unwrap() }
237 }
238 pub fn extract<P: DeserializeOwned>(
239 self,
240 method: &str,
241 ) -> Result<P, ExtractError<Notification>> {
242 if self.method != method {
243 return Err(ExtractError::MethodMismatch(self));
244 }
245 match serde_json::from_value(self.params) {
246 Ok(params) => Ok(params),
247 Err(error) => Err(ExtractError::JsonError { method: self.method, error }),
248 }
249 }
250 pub(crate) fn is_exit(&self) -> bool {
251 self.method == "exit"
252 }
253 pub(crate) fn is_initialized(&self) -> bool {
254 self.method == "initialized"
255 }
256}
257
258fn read_msg_text(inp: &mut dyn BufRead) -> io::Result<Option<String>> {
259 let mut size = None;
260 let mut buf = String::new();
261 loop {
262 buf.clear();
263 if inp.read_line(&mut buf)? == 0 {
264 return Ok(None);
265 }
266 if !buf.ends_with("\r\n") {
267 return Err(invalid_data!("malformed header: {:?}", buf));
268 }
269 let buf = &buf[..buf.len() - 2];
270 if buf.is_empty() {
271 break;
272 }
273 let mut parts = buf.splitn(2, ": ");
274 let header_name = parts.next().unwrap();
275 let header_value =
276 parts.next().ok_or_else(|| invalid_data!("malformed header: {:?}", buf))?;
277 if header_name.eq_ignore_ascii_case("Content-Length") {
278 size = Some(header_value.parse::<usize>().map_err(invalid_data)?);
279 }
280 }
281 let size: usize = size.ok_or_else(|| invalid_data!("no Content-Length"))?;
282 let mut buf = buf.into_bytes();
283 buf.resize(size, 0);
284 inp.read_exact(&mut buf)?;
285 let buf = String::from_utf8(buf).map_err(invalid_data)?;
286 log::debug!("< {}", buf);
287 Ok(Some(buf))
288}
289
290fn write_msg_text(out: &mut dyn Write, msg: &str) -> io::Result<()> {
291 log::debug!("> {}", msg);
292 write!(out, "Content-Length: {}\r\n\r\n", msg.len())?;
293 out.write_all(msg.as_bytes())?;
294 out.flush()?;
295 Ok(())
296}
297
#[cfg(test)]
mod tests {
    use super::{Message, Notification, Request, RequestId};

    /// Deserializes `text` into a [`Message`], panicking if that fails.
    fn parse(text: &str) -> Message {
        serde_json::from_str(text).unwrap()
    }

    #[test]
    fn shutdown_with_explicit_null() {
        let msg = parse(r#"{"jsonrpc": "2.0","id": 3,"method": "shutdown", "params": null }"#);

        assert!(
            matches!(msg, Message::Request(req) if req.id == 3.into() && req.method == "shutdown")
        );
    }

    #[test]
    fn shutdown_with_no_params() {
        let msg = parse(r#"{"jsonrpc": "2.0","id": 3,"method": "shutdown"}"#);

        assert!(
            matches!(msg, Message::Request(req) if req.id == 3.into() && req.method == "shutdown")
        );
    }

    #[test]
    fn notification_with_explicit_null() {
        let msg = parse(r#"{"jsonrpc": "2.0","method": "exit", "params": null }"#);

        assert!(matches!(msg, Message::Notification(not) if not.method == "exit"));
    }

    #[test]
    fn notification_with_no_params() {
        let msg = parse(r#"{"jsonrpc": "2.0","method": "exit"}"#);

        assert!(matches!(msg, Message::Notification(not) if not.method == "exit"));
    }

    #[test]
    fn serialize_request_with_null_params() {
        let request = Request {
            id: RequestId::from(3),
            method: "shutdown".into(),
            params: serde_json::Value::Null,
        };
        let serialized = serde_json::to_string(&Message::Request(request)).unwrap();

        // Null params are omitted entirely rather than written as `"params":null`.
        assert_eq!(serialized, r#"{"id":3,"method":"shutdown"}"#);
    }

    #[test]
    fn serialize_notification_with_null_params() {
        let notification =
            Notification { method: "exit".into(), params: serde_json::Value::Null };
        let serialized = serde_json::to_string(&Message::Notification(notification)).unwrap();

        assert_eq!(serialized, r#"{"method":"exit"}"#);
    }
}