parse_book_source/http_client/
mod.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
use crate::{HttpConfig, Result};
use anyhow::anyhow;
use rate_limiter::TokenBucket;
use reqwest::{
    header::{HeaderMap, HeaderName, HeaderValue},
    Body, Client, ClientBuilder,
};
use std::time::Duration;
pub mod rate_limiter;

/// HTTP client bound to a single base URL, with optional request rate limiting.
///
/// Built by [`HttpClient::new`]; relative URLs passed to `get`/`post` are joined
/// onto `base_url`.
#[derive(Debug, Clone)]
pub struct HttpClient {
    /// Underlying `reqwest` client, built with its cookie store enabled plus any
    /// configured default headers and timeout.
    pub client: Client,
    /// Prefix that relative request URLs are joined onto.
    pub base_url: String,
    /// Token bucket awaited before every `get`/`post`; `None` disables rate limiting.
    // NOTE(review): whether clones of `HttpClient` share one bucket depends on
    // `TokenBucket`'s `Clone` impl — confirm in the `rate_limiter` module.
    pub rate_limiter: Option<TokenBucket>,
}

impl HttpClient {
    /// Builds a client for the site rooted at `base_url`.
    ///
    /// The underlying `reqwest` client always has its cookie store enabled. On top of
    /// that, `config` may supply default headers, a request timeout (in milliseconds),
    /// and a token-bucket rate limit.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - a configured header name or value is not valid HTTP;
    /// - the rate-limit `fill_duration` is negative, non-finite, or too large to
    ///   represent as a [`Duration`];
    /// - the `reqwest` client fails to build.
    pub fn new(base_url: &str, config: &HttpConfig) -> Result<Self> {
        let mut client = ClientBuilder::new().cookie_store(true);

        if let Some(header) = &config.header {
            let mut headers = HeaderMap::new();

            for (k, v) in header {
                headers.insert(
                    HeaderName::try_from(k)
                        .map_err(|e| anyhow!("header name is not valid: {}", e))?,
                    HeaderValue::from_str(v)
                        .map_err(|e| anyhow!("header value is not valid: {}", e))?,
                );
            }
            client = client.default_headers(headers);
        }

        if let Some(timeout) = config.timeout {
            client = client.timeout(Duration::from_millis(timeout));
        }

        // `Duration::from_secs_f64` panics on negative, NaN or overflowing input. The
        // value comes from user-supplied config, so report it as an error instead.
        let rate_limiter = match &config.rate_limit {
            Some(rate_limit) => {
                let fill_duration = Duration::try_from_secs_f64(rate_limit.fill_duration)
                    .map_err(|e| anyhow!("rate limit fill duration is not valid: {}", e))?;
                // NOTE(review): `as usize` silently wraps/truncates if `max_count` is a
                // signed or wider type — confirm its type in `HttpConfig`.
                Some(TokenBucket::new(rate_limit.max_count as usize, fill_duration))
            }
            None => None,
        };

        Ok(Self {
            client: client.build()?,
            base_url: base_url.to_string(),
            rate_limiter,
        })
    }

    /// Returns `true` if `url` starts with `http://` or `https://`.
    ///
    /// The scheme is compared ASCII case-insensitively.
    fn is_absolute_http_url(url: &str) -> bool {
        let bytes = url.as_bytes();
        ["http://", "https://"].iter().any(|scheme| {
            bytes.len() >= scheme.len()
                && bytes[..scheme.len()].eq_ignore_ascii_case(scheme.as_bytes())
        })
    }

    /// Turns `url` into a full request URL, using `self.base_url` where needed.
    ///
    /// Cases:
    /// - Absolute `http://` / `https://` URLs are returned unchanged.
    /// - Protocol-relative URLs (`//host/path`) take the scheme of `base_url`, when
    ///   it has one.
    /// - Anything else is appended to `base_url` with exactly one `/` between them,
    ///   whether or not `base_url` ends with `/` or `url` starts with one.
    fn url_with_base(&self, url: &str) -> String {
        if Self::is_absolute_http_url(url) {
            return url.to_string();
        }

        // A leading `//` is a network-path reference: keep the host it names and only
        // borrow the scheme from the base, instead of appending it as a path.
        if url.starts_with("//") {
            if let Some((scheme, _)) = self.base_url.split_once("://") {
                return format!("{}:{}", scheme, url);
            }
        }

        format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            url.trim_start_matches('/')
        )
    }

    /// Waits for a token from the rate limiter, if one is configured.
    async fn throttle(&self) {
        if let Some(rate_limiter) = &self.rate_limiter {
            rate_limiter.acquire().await;
        }
    }

    /// Sends a GET request to `url`, resolved against the base URL.
    ///
    /// Waits on the rate limiter, if any, before sending. The response status is not
    /// checked here.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `reqwest` while sending the request.
    pub async fn get(&self, url: &str) -> Result<reqwest::Response> {
        self.throttle().await;

        let url = self.url_with_base(url);

        Ok(self.client.get(url).send().await?)
    }

    /// Sends a POST request with `body` to `url`, resolved against the base URL.
    ///
    /// Waits on the rate limiter, if any, before sending. The response status is not
    /// checked here.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `reqwest` while sending the request.
    pub async fn post<T: Into<Body>>(&self, url: &str, body: T) -> Result<reqwest::Response> {
        self.throttle().await;

        let url = self.url_with_base(url);

        Ok(self.client.post(url).body(body).send().await?)
    }
}