igd_next/aio/
tokio.rs

1//! Tokio abstraction for the aio [`Gateway`].
2
3use async_trait::async_trait;
4use bytes::Bytes;
5use futures::prelude::*;
6use http_body_util::{BodyExt, Empty};
7use hyper::header::{CONTENT_LENGTH, CONTENT_TYPE};
8use hyper::Request;
9use hyper_util::client::legacy::Client;
10use std::collections::HashMap;
11use std::net::SocketAddr;
12
13use tokio::{net::UdpSocket, time::timeout};
14
15use super::{Provider, HEADER_NAME, MAX_RESPONSE_SIZE};
16use crate::common::options::{DEFAULT_TIMEOUT, RESPONSE_TIMEOUT};
17use crate::common::{messages, parsing, SearchOptions};
18use crate::errors::SearchError;
19use crate::{aio::Gateway, RequestError};
20use log::debug;
21
/// Tokio provider for the [`Gateway`].
///
/// A stateless marker type that selects the Tokio runtime as the [`Provider`] implementation.
/// It carries no data; each HTTP request issued through it builds its own hyper client on
/// Tokio's executor. Because it is zero-sized it is `Copy` and `Default`, so generic code can
/// construct or duplicate it freely.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct Tokio;
25
26#[async_trait]
27impl Provider for Tokio {
28    async fn send_async(url: &str, action: &str, body: &str) -> Result<String, RequestError> {
29        let client = Client::builder(hyper_util::rt::TokioExecutor::new()).build_http();
30
31        let body = body.to_string();
32
33        let req = Request::builder()
34            .uri(url)
35            .method("POST")
36            .header(HEADER_NAME, action)
37            .header(CONTENT_TYPE, "text/xml")
38            .header(CONTENT_LENGTH, body.len() as u64)
39            .body(body)?;
40
41        let resp = client.request(req).await?;
42        let body = resp.into_body().collect().await?.to_bytes();
43        let string = String::from_utf8(body.to_vec())?;
44        Ok(string)
45    }
46}
47
48/// Search for a gateway with the provided options.
49pub async fn search_gateway(options: SearchOptions) -> Result<Gateway<Tokio>, SearchError> {
50    let search_timeout = options.timeout.unwrap_or(DEFAULT_TIMEOUT);
51    match timeout(search_timeout, search_gateway_inner(options)).await {
52        Ok(Ok(gateway)) => Ok(gateway),
53        Ok(Err(err)) => Err(err),
54        Err(_err) => {
55            // Timeout
56            Err(SearchError::NoResponseWithinTimeout)
57        }
58    }
59}
60
61async fn search_gateway_inner(options: SearchOptions) -> Result<Gateway<Tokio>, SearchError> {
62    // Create socket for future calls
63    let mut socket = UdpSocket::bind(&options.bind_addr).await?;
64
65    send_search_request(&mut socket, options.broadcast_address).await?;
66    let response_timeout = options.single_search_timeout.unwrap_or(RESPONSE_TIMEOUT);
67
68    loop {
69        let search_response = receive_search_response(&mut socket);
70
71        // Receive search response
72        let (response_body, from) = match timeout(response_timeout, search_response).await {
73            Ok(Ok(v)) => v,
74            Ok(Err(err)) => {
75                debug!("error while receiving broadcast response: {err}");
76                continue;
77            }
78            Err(_) => {
79                debug!("timeout while receiving broadcast response");
80                continue;
81            }
82        };
83
84        let (addr, root_url) = match handle_broadcast_resp(&from, &response_body) {
85            Ok(v) => v,
86            Err(e) => {
87                debug!("error handling broadcast response: {}", e);
88                continue;
89            }
90        };
91
92        let (control_schema_url, control_url) = match get_control_urls(&addr, &root_url).await {
93            Ok(v) => v,
94            Err(e) => {
95                debug!("error getting control URLs: {}", e);
96                continue;
97            }
98        };
99
100        let control_schema = match get_control_schemas(&addr, &control_schema_url).await {
101            Ok(v) => v,
102            Err(e) => {
103                debug!("error getting control schemas: {}", e);
104                continue;
105            }
106        };
107
108        return Ok(Gateway {
109            addr,
110            root_url,
111            control_url,
112            control_schema_url,
113            control_schema,
114            provider: Tokio,
115        });
116    }
117}
118
119// Create a new search.
120async fn send_search_request(socket: &mut UdpSocket, addr: SocketAddr) -> Result<(), SearchError> {
121    debug!(
122        "sending broadcast request to: {} on interface: {:?}",
123        addr,
124        socket.local_addr()
125    );
126    socket
127        .send_to(messages::SEARCH_REQUEST.as_bytes(), &addr)
128        .map_ok(|_| ())
129        .map_err(SearchError::from)
130        .await
131}
132
133async fn receive_search_response(socket: &mut UdpSocket) -> Result<(Vec<u8>, SocketAddr), SearchError> {
134    let mut buff = [0u8; MAX_RESPONSE_SIZE];
135    let (n, from) = socket.recv_from(&mut buff).map_err(SearchError::from).await?;
136    debug!("received broadcast response from: {}", from);
137    Ok((buff[..n].to_vec(), from))
138}
139
140// Handle a UDP response message.
141fn handle_broadcast_resp(from: &SocketAddr, data: &[u8]) -> Result<(SocketAddr, String), SearchError> {
142    debug!("handling broadcast response from: {}", from);
143
144    // Convert response to text.
145    let text = std::str::from_utf8(data).map_err(SearchError::from)?;
146
147    // Parse socket address and path.
148    let (addr, root_url) = parsing::parse_search_result(text)?;
149
150    Ok((addr, root_url))
151}
152
153async fn get_control_urls(addr: &SocketAddr, path: &str) -> Result<(String, String), SearchError> {
154    let uri = match format!("http://{addr}{path}").parse() {
155        Ok(uri) => uri,
156        Err(err) => return Err(SearchError::from(err)),
157    };
158
159    debug!("requesting control url from: {uri}");
160    let client: Client<_, Empty<Bytes>> = Client::builder(hyper_util::rt::TokioExecutor::new()).build_http();
161
162    let resp = client.get(uri).await?.into_body().collect().await?.to_bytes();
163
164    debug!("handling control response from: {addr}");
165    let c = std::io::Cursor::new(&resp);
166    parsing::parse_control_urls(c)
167}
168
169async fn get_control_schemas(
170    addr: &SocketAddr,
171    control_schema_url: &str,
172) -> Result<HashMap<String, Vec<String>>, SearchError> {
173    let uri = match format!("http://{addr}{control_schema_url}").parse() {
174        Ok(uri) => uri,
175        Err(err) => return Err(SearchError::from(err)),
176    };
177
178    debug!("requesting control schema from: {uri}");
179    let client: Client<_, Empty<Bytes>> = Client::builder(hyper_util::rt::TokioExecutor::new()).build_http();
180
181    let resp = client.get(uri).await?.into_body().collect().await?.to_bytes();
182
183    debug!("handling schema response from: {addr}");
184    let c = std::io::Cursor::new(&resp);
185    parsing::parse_schemas(c)
186}