1use crate::WeakRpc;
2
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::task::{self, Poll};
7use std::{fmt, mem, str};
8
9use hyper::header::{self, HeaderMap, HeaderValue};
10use hyper::{self, service::Service, Body, Method};
11
12use crate::jsonrpc::serde_json;
13use crate::jsonrpc::{self as core, middleware, Metadata, Middleware};
14use crate::response::Response;
15use crate::server_utils::cors;
16
17use crate::{utils, AllowedHosts, CorsDomains, RequestMiddleware, RequestMiddlewareAction, RestApi};
18
/// `hyper` service that turns incoming HTTP requests into JSON-RPC processing futures.
///
/// Holds the server-wide configuration; every request is converted into a [`Handler`]
/// by the `Service::call` implementation below.
pub struct ServerHandler<M: Metadata = (), S: Middleware<M> = middleware::Noop> {
	/// Weak reference to the RPC handler; once it can no longer be upgraded, requests are
	/// answered with `Response::closing()`.
	jsonrpc_handler: WeakRpc<M, S>,
	/// Hosts accepted by the `Host` check (see `utils::is_host_allowed`).
	allowed_hosts: AllowedHosts,
	/// Origins used to compute the CORS `Access-Control-Allow-Origin` result.
	cors_domains: CorsDomains,
	/// Value sent as `Access-Control-Max-Age`, if any.
	cors_max_age: Option<u32>,
	/// Headers used to compute the CORS `Access-Control-Allow-Headers` result.
	cors_allowed_headers: cors::AccessControlAllowHeaders,
	/// Middleware that sees every request first and may answer it itself.
	middleware: Arc<dyn RequestMiddleware>,
	/// Whether/how REST-style `/<method>/<params>` calls are accepted.
	rest_api: RestApi,
	/// Optional `(path, method)`: a `GET` on `path` calls the RPC `method` as a health check.
	health_api: Option<(String, String)>,
	/// Maximum accepted request body size, in bytes.
	max_request_body_size: usize,
	/// Server-side keep-alive setting, combined with the request in `utils::keep_alive`.
	keep_alive: bool,
}
32
33impl<M: Metadata, S: Middleware<M>> ServerHandler<M, S> {
34 pub fn new(
36 jsonrpc_handler: WeakRpc<M, S>,
37 cors_domains: CorsDomains,
38 cors_max_age: Option<u32>,
39 cors_allowed_headers: cors::AccessControlAllowHeaders,
40 allowed_hosts: AllowedHosts,
41 middleware: Arc<dyn RequestMiddleware>,
42 rest_api: RestApi,
43 health_api: Option<(String, String)>,
44 max_request_body_size: usize,
45 keep_alive: bool,
46 ) -> Self {
47 ServerHandler {
48 jsonrpc_handler,
49 allowed_hosts,
50 cors_domains,
51 cors_max_age,
52 cors_allowed_headers,
53 middleware,
54 rest_api,
55 health_api,
56 max_request_body_size,
57 keep_alive,
58 }
59 }
60}
61
62impl<M: Metadata, S: Middleware<M>> Service<hyper::Request<Body>> for ServerHandler<M, S>
63where
64 S::Future: Unpin,
65 S::CallFuture: Unpin,
66 M: Unpin,
67{
68 type Response = hyper::Response<Body>;
69 type Error = hyper::Error;
70 type Future = Handler<M, S>;
71
72 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> task::Poll<hyper::Result<()>> {
73 task::Poll::Ready(Ok(()))
74 }
75
76 fn call(&mut self, request: hyper::Request<Body>) -> Self::Future {
77 let is_host_allowed = utils::is_host_allowed(&request, &self.allowed_hosts);
78 let action = self.middleware.on_request(request);
79
80 let (should_validate_hosts, should_continue_on_invalid_cors, response) = match action {
81 RequestMiddlewareAction::Proceed {
82 should_continue_on_invalid_cors,
83 request,
84 } => (true, should_continue_on_invalid_cors, Err(request)),
85 RequestMiddlewareAction::Respond {
86 should_validate_hosts,
87 response,
88 } => (should_validate_hosts, false, Ok(response)),
89 };
90
91 if should_validate_hosts && !is_host_allowed {
93 return Handler::Err(Some(Response::host_not_allowed()));
94 }
95
96 match response {
98 Ok(response) => Handler::Middleware(response),
99 Err(request) => {
100 Handler::Rpc(RpcHandler {
101 jsonrpc_handler: self.jsonrpc_handler.clone(),
102 state: RpcHandlerState::ReadingHeaders {
103 request,
104 cors_domains: self.cors_domains.clone(),
105 cors_headers: self.cors_allowed_headers.clone(),
106 continue_on_invalid_cors: should_continue_on_invalid_cors,
107 keep_alive: self.keep_alive,
108 },
109 is_options: false,
110 cors_max_age: self.cors_max_age,
111 cors_allow_origin: cors::AllowCors::NotRequired,
112 cors_allow_headers: cors::AllowCors::NotRequired,
113 rest_api: self.rest_api,
114 health_api: self.health_api.clone(),
115 max_request_body_size: self.max_request_body_size,
116 keep_alive: true,
118 })
119 }
120 }
121 }
122}
123
/// Future produced by [`ServerHandler`] for a single HTTP request.
pub enum Handler<M: Metadata, S: Middleware<M>> {
	/// The request is processed by the JSON-RPC state machine.
	Rpc(RpcHandler<M, S>),
	/// A response decided up front (e.g. host not allowed). It is `Some` until the first
	/// poll takes it out and returns it.
	Err(Option<Response>),
	/// A response future supplied by the request middleware.
	Middleware(Pin<Box<dyn Future<Output = hyper::Result<hyper::Response<Body>>> + Send>>),
}
129
130impl<M: Metadata, S: Middleware<M>> Future for Handler<M, S>
131where
132 S::Future: Unpin,
133 S::CallFuture: Unpin,
134 M: Unpin,
135{
136 type Output = hyper::Result<hyper::Response<Body>>;
137
138 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
139 match Pin::into_inner(self) {
140 Handler::Rpc(ref mut handler) => Pin::new(handler).poll(cx),
141 Handler::Middleware(ref mut middleware) => Pin::new(middleware).poll(cx),
142 Handler::Err(ref mut response) => Poll::Ready(Ok(response
143 .take()
144 .expect("Response always Some initialy. Returning `Ready` so will never be polled again; qed")
145 .into())),
146 }
147 }
148}
149
/// Outcome of one step of the `RpcHandler` state machine.
enum RpcPollState<M> {
	/// The state advanced without waiting; `poll` handles the new state immediately.
	Ready(RpcHandlerState<M>),
	/// Processing has to wait (body or response future pending, or the terminal `Done`
	/// state); `poll` stores the state and returns `Pending`.
	NotReady(RpcHandlerState<M>),
}
154
155impl<M> RpcPollState<M> {
156 fn decompose(self) -> (RpcHandlerState<M>, bool) {
157 use self::RpcPollState::*;
158 match self {
159 Ready(handler) => (handler, true),
160 NotReady(handler) => (handler, false),
161 }
162 }
163}
164
/// Steps of processing one JSON-RPC HTTP request, driven by `RpcHandler::poll`.
enum RpcHandlerState<M> {
	/// Initial state: the request as received, plus the configuration needed to validate it.
	ReadingHeaders {
		request: hyper::Request<Body>,
		cors_domains: CorsDomains,
		cors_headers: cors::AccessControlAllowHeaders,
		continue_on_invalid_cors: bool,
		keep_alive: bool,
	},
	/// Collecting the request body.
	ReadingBody {
		/// Body stream still being read.
		body: hyper::Body,
		/// Request URI, kept only when the REST API is not disabled so that an empty body
		/// can be handled as a REST call.
		uri: Option<hyper::Uri>,
		/// Body bytes read so far.
		request: Vec<u8>,
		metadata: M,
	},
	/// Calling the method encoded in the URI path (`/<method>/<params>...`).
	ProcessRest {
		uri: hyper::Uri,
		metadata: M,
	},
	/// Calling the configured health-check method.
	ProcessHealth {
		method: String,
		metadata: M,
	},
	/// Final response; response headers are added before it is returned.
	Writing(Response),
	/// Waiting for a serialized RPC result; `None` becomes `Response::ok` with an empty body.
	Waiting(Pin<Box<dyn Future<Output = Option<String>> + Send>>),
	/// Waiting for a fully built response (in this file only produced by health checks).
	WaitingForResponse(Pin<Box<dyn Future<Output = Response> + Send>>),
	/// Placeholder left while `poll` works on the taken-out state, and after completion.
	Done,
}
192
193impl<M> fmt::Debug for RpcHandlerState<M> {
194 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
195 use self::RpcHandlerState::*;
196
197 match *self {
198 ReadingHeaders { .. } => write!(fmt, "ReadingHeaders"),
199 ReadingBody { .. } => write!(fmt, "ReadingBody"),
200 ProcessRest { .. } => write!(fmt, "ProcessRest"),
201 ProcessHealth { .. } => write!(fmt, "ProcessHealth"),
202 Writing(ref res) => write!(fmt, "Writing({:?})", res),
203 WaitingForResponse(_) => write!(fmt, "WaitingForResponse"),
204 Waiting(_) => write!(fmt, "Waiting"),
205 Done => write!(fmt, "Done"),
206 }
207 }
208}
209
/// Future processing a single request that the middleware handed on to the JSON-RPC handler.
pub struct RpcHandler<M: Metadata, S: Middleware<M>> {
	/// Weak reference to the RPC handler; a failed upgrade yields `Response::closing()`.
	jsonrpc_handler: WeakRpc<M, S>,
	/// Current processing step.
	state: RpcHandlerState<M>,
	/// Whether the request method is `OPTIONS`; adds `Allow`/`Accept` headers to the response.
	is_options: bool,
	/// CORS origin result, computed from the request on the first poll.
	cors_allow_origin: cors::AllowCors<header::HeaderValue>,
	/// CORS allowed-headers result, computed from the request on the first poll.
	cors_allow_headers: cors::AllowCors<Vec<header::HeaderValue>>,
	/// Value sent as `Access-Control-Max-Age` when a CORS origin is allowed.
	cors_max_age: Option<u32>,
	/// Whether/how REST-style calls are accepted.
	rest_api: RestApi,
	/// Optional `(path, method)` of the health-check endpoint.
	health_api: Option<(String, String)>,
	/// Maximum accepted body size in bytes.
	max_request_body_size: usize,
	/// When `false`, the response carries `Connection: close`; recomputed on the first poll.
	keep_alive: bool,
}
222
impl<M: Metadata, S: Middleware<M>> Future for RpcHandler<M, S>
where
	S::Future: Unpin,
	S::CallFuture: Unpin,
	M: Unpin,
{
	type Output = hyper::Result<hyper::Response<Body>>;

	/// Advances the request state machine until it must wait or a response is ready.
	///
	/// Each step yields an `RpcPollState`: `Ready` states are processed again immediately
	/// (via a recursive poll), `NotReady` states are stored and `Pending` is returned.
	/// Reaching `Writing` finalizes the response by adding CORS/`Allow`/`Connection` headers.
	/// Body stream errors are returned as the future's `Err`.
	fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
		let this = Pin::into_inner(self);

		// Take the current state out, leaving `Done` as a placeholder; each arm produces the
		// state to continue with (or returns early).
		let new_state = match mem::replace(&mut this.state, RpcHandlerState::Done) {
			RpcHandlerState::ReadingHeaders {
				request,
				cors_domains,
				cors_headers,
				continue_on_invalid_cors,
				keep_alive,
			} => {
				// Computed once here; they are used when the final response is written.
				this.cors_allow_origin = utils::cors_allow_origin(&request, &cors_domains);
				this.cors_allow_headers = utils::cors_allow_headers(&request, &cors_headers);
				this.keep_alive = utils::keep_alive(&request, keep_alive);
				this.is_options = *request.method() == Method::OPTIONS;
				RpcPollState::Ready(this.read_headers(request, continue_on_invalid_cors))
			}
			RpcHandlerState::ReadingBody {
				body,
				request,
				metadata,
				uri,
			} => match this.process_body(body, request, uri, metadata, cx) {
				// Malformed or oversized bodies get an HTTP error response, not a future error.
				Err(BodyError::Utf8(ref e)) => {
					let mesg = format!("utf-8 encoding error at byte {} in request body", e.valid_up_to());
					let resp = Response::bad_request(mesg);
					RpcPollState::Ready(RpcHandlerState::Writing(resp))
				}
				Err(BodyError::TooLarge) => {
					let resp = Response::too_large("request body size exceeds allowed maximum");
					RpcPollState::Ready(RpcHandlerState::Writing(resp))
				}
				Err(BodyError::Hyper(e)) => return Poll::Ready(Err(e)),
				Ok(state) => state,
			},
			RpcHandlerState::ProcessRest { uri, metadata } => this.process_rest(uri, metadata)?,
			RpcHandlerState::ProcessHealth { method, metadata } => this.process_health(method, metadata)?,
			RpcHandlerState::WaitingForResponse(mut waiting) => match Pin::new(&mut waiting).poll(cx) {
				Poll::Ready(response) => RpcPollState::Ready(RpcHandlerState::Writing(response)),
				Poll::Pending => RpcPollState::NotReady(RpcHandlerState::WaitingForResponse(waiting)),
			},
			RpcHandlerState::Waiting(mut waiting) => {
				match Pin::new(&mut waiting).poll(cx) {
					Poll::Ready(response) => {
						RpcPollState::Ready(RpcHandlerState::Writing(match response {
							// No output from the RPC handler: empty body.
							None => Response::ok(String::new()),
							Some(result) => Response::ok(format!("{}\n", result)),
						}))
					}
					Poll::Pending => RpcPollState::NotReady(RpcHandlerState::Waiting(waiting)),
				}
			}
			// `Writing` is never stored back into `this.state` (it is consumed below), so this
			// only matches `Done`. NOTE(review): polling after completion therefore returns
			// `Pending` without registering a waker.
			state => RpcPollState::NotReady(state),
		};

		let (new_state, is_ready) = new_state.decompose();
		match new_state {
			RpcHandlerState::Writing(res) => {
				let mut response: hyper::Response<Body> = res.into();
				// Move the CORS results out; they are not needed after this response.
				let cors_allow_origin = mem::replace(&mut this.cors_allow_origin, cors::AllowCors::Invalid);
				let cors_allow_headers = mem::replace(&mut this.cors_allow_headers, cors::AllowCors::Invalid);

				Self::set_response_headers(
					response.headers_mut(),
					this.is_options,
					this.cors_max_age,
					cors_allow_origin.into(),
					cors_allow_headers.into(),
					this.keep_alive,
				);
				Poll::Ready(Ok(response))
			}
			state => {
				this.state = state;
				// A `Ready` step did not wait on anything, so process the next state now; the
				// recursion depth is bounded by the number of state transitions.
				if is_ready {
					Pin::new(this).poll(cx)
				} else {
					Poll::Pending
				}
			}
		}
	}
}
318
/// Reasons reading a request body in `process_body` can fail.
enum BodyError {
	/// The body stream returned an error; `poll` returns it as the future's error.
	Hyper(hyper::Error),
	/// The body is not valid UTF-8; `poll` answers with a bad-request response.
	Utf8(str::Utf8Error),
	/// The body exceeds `max_request_body_size`; `poll` answers with a "too large" response.
	TooLarge,
}
326
327impl From<hyper::Error> for BodyError {
328 fn from(e: hyper::Error) -> BodyError {
329 BodyError::Hyper(e)
330 }
331}
332
333impl<M: Metadata, S: Middleware<M>> RpcHandler<M, S>
334where
335 S::Future: Unpin,
336 S::CallFuture: Unpin,
337{
338 fn read_headers(&self, request: hyper::Request<Body>, continue_on_invalid_cors: bool) -> RpcHandlerState<M> {
339 if self.cors_allow_origin == cors::AllowCors::Invalid && !continue_on_invalid_cors {
340 return RpcHandlerState::Writing(Response::invalid_allow_origin());
341 }
342
343 if self.cors_allow_headers == cors::AllowCors::Invalid && !continue_on_invalid_cors {
344 return RpcHandlerState::Writing(Response::invalid_allow_headers());
345 }
346
347 let handler = match self.jsonrpc_handler.upgrade() {
349 Some(handler) => handler,
350 None => return RpcHandlerState::Writing(Response::closing()),
351 };
352 let metadata = handler.extractor.read_metadata(&request);
353
354 match *request.method() {
356 Method::POST if Self::is_json(request.headers().get("content-type")) => {
359 let uri = if self.rest_api != RestApi::Disabled {
360 Some(request.uri().clone())
361 } else {
362 None
363 };
364 RpcHandlerState::ReadingBody {
365 metadata,
366 request: Default::default(),
367 uri,
368 body: request.into_body(),
369 }
370 }
371 Method::POST if self.rest_api == RestApi::Unsecure && request.uri().path().split('/').count() > 2 => {
372 RpcHandlerState::ProcessRest {
373 metadata,
374 uri: request.uri().clone(),
375 }
376 }
377 Method::POST => RpcHandlerState::Writing(Response::unsupported_content_type()),
379 Method::OPTIONS => RpcHandlerState::Writing(Response::empty()),
381 Method::GET if self.health_api.as_ref().map(|x| &*x.0) == Some(request.uri().path()) => {
383 RpcHandlerState::ProcessHealth {
384 metadata,
385 method: self
386 .health_api
387 .as_ref()
388 .map(|x| x.1.clone())
389 .expect("Health api is defined since the URI matched."),
390 }
391 }
392 _ => RpcHandlerState::Writing(Response::method_not_allowed()),
394 }
395 }
396
397 fn process_health(&self, method: String, metadata: M) -> Result<RpcPollState<M>, hyper::Error> {
398 use self::core::types::{Call, Failure, Id, MethodCall, Output, Params, Request, Success, Version};
399
400 let call = Request::Single(Call::MethodCall(MethodCall {
402 jsonrpc: Some(Version::V2),
403 method,
404 params: Params::None,
405 id: Id::Num(1),
406 }));
407
408 let response = match self.jsonrpc_handler.upgrade() {
409 Some(h) => h.handler.handle_rpc_request(call, metadata),
410 None => return Ok(RpcPollState::Ready(RpcHandlerState::Writing(Response::closing()))),
411 };
412
413 Ok(RpcPollState::Ready(RpcHandlerState::WaitingForResponse(Box::pin(
414 async {
415 match response.await {
416 Some(core::Response::Single(Output::Success(Success { result, .. }))) => {
417 let result = serde_json::to_string(&result).expect("Serialization of result is infallible;qed");
418
419 Response::ok(result)
420 }
421 Some(core::Response::Single(Output::Failure(Failure { error, .. }))) => {
422 let result = serde_json::to_string(&error).expect("Serialization of error is infallible;qed");
423
424 Response::service_unavailable(result)
425 }
426 e => Response::internal_error(format!("Invalid response for health request: {:?}", e)),
427 }
428 },
429 ))))
430 }
431
432 fn process_rest(&self, uri: hyper::Uri, metadata: M) -> Result<RpcPollState<M>, hyper::Error> {
433 use self::core::types::{Call, Id, MethodCall, Params, Request, Value, Version};
434
435 let mut it = uri.path().split('/').skip(1);
437
438 let method = it.next().unwrap_or("");
440 let mut params = Vec::new();
441 for param in it {
442 let v = serde_json::from_str(param)
443 .or_else(|_| serde_json::from_str(&format!("\"{}\"", param)))
444 .unwrap_or(Value::Null);
445 params.push(v)
446 }
447
448 let call = Request::Single(Call::MethodCall(MethodCall {
450 jsonrpc: Some(Version::V2),
451 method: method.into(),
452 params: Params::Array(params),
453 id: Id::Num(1),
454 }));
455
456 let response = match self.jsonrpc_handler.upgrade() {
457 Some(h) => h.handler.handle_rpc_request(call, metadata),
458 None => return Ok(RpcPollState::Ready(RpcHandlerState::Writing(Response::closing()))),
459 };
460
461 Ok(RpcPollState::Ready(RpcHandlerState::Waiting(Box::pin(async {
462 response
463 .await
464 .map(|x| serde_json::to_string(&x).expect("Serialization of response is infallible;qed"))
465 }))))
466 }
467
468 fn process_body(
469 &self,
470 mut body: hyper::Body,
471 mut request: Vec<u8>,
472 uri: Option<hyper::Uri>,
473 metadata: M,
474 cx: &mut task::Context<'_>,
475 ) -> Result<RpcPollState<M>, BodyError> {
476 use futures::Stream;
477
478 loop {
479 let pinned_body = Pin::new(&mut body);
480 match pinned_body.poll_next(cx)? {
481 Poll::Ready(Some(chunk)) => {
482 if request
483 .len()
484 .checked_add(chunk.len())
485 .map(|n| n > self.max_request_body_size)
486 .unwrap_or(true)
487 {
488 return Err(BodyError::TooLarge);
489 }
490 request.extend_from_slice(&*chunk)
491 }
492 Poll::Ready(None) => {
493 if let (Some(uri), true) = (uri, request.is_empty()) {
494 return Ok(RpcPollState::Ready(RpcHandlerState::ProcessRest { uri, metadata }));
495 }
496
497 let content = match str::from_utf8(&request) {
498 Ok(content) => content,
499 Err(err) => {
500 return Err(BodyError::Utf8(err));
502 }
503 };
504
505 let response = match self.jsonrpc_handler.upgrade() {
506 Some(h) => h.handler.handle_request(content, metadata),
507 None => return Ok(RpcPollState::Ready(RpcHandlerState::Writing(Response::closing()))),
508 };
509
510 return Ok(RpcPollState::Ready(RpcHandlerState::Waiting(Box::pin(response))));
512 }
513 Poll::Pending => {
514 return Ok(RpcPollState::NotReady(RpcHandlerState::ReadingBody {
515 body,
516 request,
517 metadata,
518 uri,
519 }));
520 }
521 }
522 }
523 }
524
525 fn set_response_headers(
526 headers: &mut HeaderMap,
527 is_options: bool,
528 cors_max_age: Option<u32>,
529 cors_allow_origin: Option<HeaderValue>,
530 cors_allow_headers: Option<Vec<HeaderValue>>,
531 keep_alive: bool,
532 ) {
533 let as_header = |m: Method| m.as_str().parse().expect("`Method` will always parse; qed");
534 let concat = |headers: &[HeaderValue]| {
535 let separator = b", ";
536 let val = headers
537 .iter()
538 .flat_map(|h| h.as_bytes().iter().chain(separator.iter()))
539 .cloned()
540 .collect::<Vec<_>>();
541 let max_len = if val.is_empty() { 0 } else { val.len() - 2 };
542 HeaderValue::from_bytes(&val[..max_len])
543 .expect("Concatenation of valid headers with `, ` is still valid; qed")
544 };
545
546 let allowed = concat(&[as_header(Method::OPTIONS), as_header(Method::POST)]);
547
548 if is_options {
549 headers.append(header::ALLOW, allowed.clone());
550 headers.append(header::ACCEPT, HeaderValue::from_static("application/json"));
551 }
552
553 if let Some(cors_allow_origin) = cors_allow_origin {
554 headers.append(header::VARY, HeaderValue::from_static("origin"));
555 headers.append(header::ACCESS_CONTROL_ALLOW_METHODS, allowed);
556 headers.append(header::ACCESS_CONTROL_ALLOW_ORIGIN, cors_allow_origin);
557
558 if let Some(cma) = cors_max_age {
559 headers.append(
560 header::ACCESS_CONTROL_MAX_AGE,
561 HeaderValue::from_str(&cma.to_string()).expect("`u32` will always parse; qed"),
562 );
563 }
564
565 if let Some(cors_allow_headers) = cors_allow_headers {
566 if !cors_allow_headers.is_empty() {
567 headers.append(header::ACCESS_CONTROL_ALLOW_HEADERS, concat(&cors_allow_headers));
568 }
569 }
570 }
571
572 if !keep_alive {
573 headers.append(header::CONNECTION, HeaderValue::from_static("close"));
574 }
575 }
576
577 fn is_json(content_type: Option<&header::HeaderValue>) -> bool {
580 match content_type.and_then(|val| val.to_str().ok()) {
581 Some(ref content)
582 if content.eq_ignore_ascii_case("application/json")
583 || content.eq_ignore_ascii_case("application/json; charset=utf-8")
584 || content.eq_ignore_ascii_case("application/json;charset=utf-8") =>
585 {
586 true
587 }
588 _ => false,
589 }
590 }
591}
592
#[cfg(test)]
mod test {
	use super::{hyper, RpcHandler};
	use jsonrpc_core::middleware::Noop;

	/// `is_json` must accept `application/json` with a UTF-8 charset regardless of letter
	/// case, with or without a space after the `;`.
	#[test]
	fn test_case_insensitive_content_type() {
		let request_with = |content_type: &'static str| {
			hyper::Request::builder()
				.header("content-type", content_type)
				.body(())
				.unwrap()
		};

		let spaced = request_with("Application/Json; charset=UTF-8");
		let compact = request_with("Application/Json;charset=UTF-8");

		// The header value is stored exactly as given (case is not normalized).
		assert_eq!(
			spaced.headers().get("content-type").unwrap(),
			&"Application/Json; charset=UTF-8"
		);

		for request in [&spaced, &compact].iter() {
			assert!(RpcHandler::<(), Noop>::is_json(request.headers().get("content-type")));
		}
	}
}