aws_smithy_runtime/client/orchestrator/operation.rs

/*
 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 * SPDX-License-Identifier: Apache-2.0
 */
5
6use crate::client::auth::no_auth::{NoAuthScheme, NO_AUTH_SCHEME_ID};
7use crate::client::defaults::{default_plugins, DefaultPluginParams};
8use crate::client::http::connection_poisoning::ConnectionPoisoningInterceptor;
9use crate::client::identity::no_auth::NoAuthIdentityResolver;
10use crate::client::identity::IdentityCache;
11use crate::client::orchestrator::endpoints::StaticUriEndpointResolver;
12use crate::client::retries::strategy::{NeverRetryStrategy, StandardRetryStrategy};
13use aws_smithy_async::rt::sleep::AsyncSleep;
14use aws_smithy_async::time::TimeSource;
15use aws_smithy_runtime_api::client::auth::static_resolver::StaticAuthSchemeOptionResolver;
16use aws_smithy_runtime_api::client::auth::{
17    AuthSchemeOptionResolverParams, SharedAuthScheme, SharedAuthSchemeOptionResolver,
18};
19use aws_smithy_runtime_api::client::endpoint::{EndpointResolverParams, SharedEndpointResolver};
20use aws_smithy_runtime_api::client::http::HttpClient;
21use aws_smithy_runtime_api::client::identity::SharedIdentityResolver;
22use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output};
23use aws_smithy_runtime_api::client::interceptors::Intercept;
24use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
25use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, OrchestratorError};
26use aws_smithy_runtime_api::client::result::SdkError;
27use aws_smithy_runtime_api::client::retries::classifiers::ClassifyRetry;
28use aws_smithy_runtime_api::client::retries::SharedRetryStrategy;
29use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
30use aws_smithy_runtime_api::client::runtime_plugin::{
31    RuntimePlugin, RuntimePlugins, SharedRuntimePlugin, StaticRuntimePlugin,
32};
33use aws_smithy_runtime_api::client::ser_de::{
34    DeserializeResponse, SerializeRequest, SharedRequestSerializer, SharedResponseDeserializer,
35};
36use aws_smithy_runtime_api::shared::IntoShared;
37use aws_smithy_runtime_api::{
38    box_error::BoxError, client::stalled_stream_protection::StalledStreamProtectionConfig,
39};
40use aws_smithy_types::config_bag::{ConfigBag, Layer};
41use aws_smithy_types::retry::RetryConfig;
42use aws_smithy_types::timeout::TimeoutConfig;
43use std::borrow::Cow;
44use std::fmt;
45use std::marker::PhantomData;
46
47struct FnSerializer<F, I> {
48    f: F,
49    _phantom: PhantomData<I>,
50}
51impl<F, I> FnSerializer<F, I> {
52    fn new(f: F) -> Self {
53        Self {
54            f,
55            _phantom: Default::default(),
56        }
57    }
58}
59impl<F, I> SerializeRequest for FnSerializer<F, I>
60where
61    F: Fn(I) -> Result<HttpRequest, BoxError> + Send + Sync,
62    I: fmt::Debug + Send + Sync + 'static,
63{
64    fn serialize_input(&self, input: Input, _cfg: &mut ConfigBag) -> Result<HttpRequest, BoxError> {
65        let input: I = input.downcast().expect("correct type");
66        (self.f)(input)
67    }
68}
69impl<F, I> fmt::Debug for FnSerializer<F, I> {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        write!(f, "FnSerializer")
72    }
73}
74
75struct FnDeserializer<F, O, E> {
76    f: F,
77    _phantom: PhantomData<(O, E)>,
78}
79impl<F, O, E> FnDeserializer<F, O, E> {
80    fn new(deserializer: F) -> Self {
81        Self {
82            f: deserializer,
83            _phantom: Default::default(),
84        }
85    }
86}
87impl<F, O, E> DeserializeResponse for FnDeserializer<F, O, E>
88where
89    F: Fn(&HttpResponse) -> Result<O, OrchestratorError<E>> + Send + Sync,
90    O: fmt::Debug + Send + Sync + 'static,
91    E: std::error::Error + fmt::Debug + Send + Sync + 'static,
92{
93    fn deserialize_nonstreaming(
94        &self,
95        response: &HttpResponse,
96    ) -> Result<Output, OrchestratorError<Error>> {
97        (self.f)(response)
98            .map(|output| Output::erase(output))
99            .map_err(|err| err.map_operation_error(Error::erase))
100    }
101}
102impl<F, O, E> fmt::Debug for FnDeserializer<F, O, E> {
103    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104        write!(f, "FnDeserializer")
105    }
106}
107
108/// Orchestrates execution of a HTTP request without any modeled input or output.
109#[derive(Debug)]
110pub struct Operation<I, O, E> {
111    service_name: Cow<'static, str>,
112    operation_name: Cow<'static, str>,
113    runtime_plugins: RuntimePlugins,
114    _phantom: PhantomData<(I, O, E)>,
115}
116
117// Manual Clone implementation needed to get rid of Clone bounds on I, O, and E
118impl<I, O, E> Clone for Operation<I, O, E> {
119    fn clone(&self) -> Self {
120        Self {
121            service_name: self.service_name.clone(),
122            operation_name: self.operation_name.clone(),
123            runtime_plugins: self.runtime_plugins.clone(),
124            _phantom: self._phantom,
125        }
126    }
127}
128
129impl Operation<(), (), ()> {
130    /// Returns a new `OperationBuilder` for the `Operation`.
131    pub fn builder() -> OperationBuilder {
132        OperationBuilder::new()
133    }
134}
135
136impl<I, O, E> Operation<I, O, E>
137where
138    I: fmt::Debug + Send + Sync + 'static,
139    O: fmt::Debug + Send + Sync + 'static,
140    E: std::error::Error + fmt::Debug + Send + Sync + 'static,
141{
142    /// Invokes this `Operation` with the given `input` and returns either an output for success
143    /// or an [`SdkError`] for failure
144    pub async fn invoke(&self, input: I) -> Result<O, SdkError<E, HttpResponse>> {
145        let input = Input::erase(input);
146
147        let output = super::invoke(
148            &self.service_name,
149            &self.operation_name,
150            input,
151            &self.runtime_plugins,
152        )
153        .await
154        .map_err(|err| err.map_service_error(|e| e.downcast().expect("correct type")))?;
155
156        Ok(output.downcast().expect("correct type"))
157    }
158}
159
160/// Builder for [`Operation`].
161#[derive(Debug)]
162pub struct OperationBuilder<I = (), O = (), E = ()> {
163    service_name: Option<Cow<'static, str>>,
164    operation_name: Option<Cow<'static, str>>,
165    config: Layer,
166    runtime_components: RuntimeComponentsBuilder,
167    runtime_plugins: Vec<SharedRuntimePlugin>,
168    _phantom: PhantomData<(I, O, E)>,
169}
170
171impl Default for OperationBuilder<(), (), ()> {
172    fn default() -> Self {
173        Self::new()
174    }
175}
176
177impl OperationBuilder<(), (), ()> {
178    /// Creates a new [`OperationBuilder`].
179    pub fn new() -> Self {
180        Self {
181            service_name: None,
182            operation_name: None,
183            config: Layer::new("operation"),
184            runtime_components: RuntimeComponentsBuilder::new("operation"),
185            runtime_plugins: Vec::new(),
186            _phantom: Default::default(),
187        }
188    }
189}
190
191impl<I, O, E> OperationBuilder<I, O, E> {
192    /// Configures the service name for the builder.
193    pub fn service_name(mut self, service_name: impl Into<Cow<'static, str>>) -> Self {
194        self.service_name = Some(service_name.into());
195        self
196    }
197
198    /// Configures the operation name for the builder.
199    pub fn operation_name(mut self, operation_name: impl Into<Cow<'static, str>>) -> Self {
200        self.operation_name = Some(operation_name.into());
201        self
202    }
203
204    /// Configures the http client for the builder.
205    pub fn http_client(mut self, connector: impl HttpClient + 'static) -> Self {
206        self.runtime_components.set_http_client(Some(connector));
207        self
208    }
209
210    /// Configures the endpoint URL for the builder.
211    pub fn endpoint_url(mut self, url: &str) -> Self {
212        self.config.store_put(EndpointResolverParams::new(()));
213        self.runtime_components
214            .set_endpoint_resolver(Some(SharedEndpointResolver::new(
215                StaticUriEndpointResolver::uri(url),
216            )));
217        self
218    }
219
220    /// Configures the retry classifier for the builder.
221    pub fn retry_classifier(mut self, retry_classifier: impl ClassifyRetry + 'static) -> Self {
222        self.runtime_components
223            .push_retry_classifier(retry_classifier);
224        self
225    }
226
227    /// Disables the retry for the operation.
228    pub fn no_retry(mut self) -> Self {
229        self.runtime_components
230            .set_retry_strategy(Some(SharedRetryStrategy::new(NeverRetryStrategy::new())));
231        self
232    }
233
234    /// Configures the standard retry for the builder.
235    pub fn standard_retry(mut self, retry_config: &RetryConfig) -> Self {
236        self.config.store_put(retry_config.clone());
237        self.runtime_components
238            .set_retry_strategy(Some(SharedRetryStrategy::new(StandardRetryStrategy::new())));
239        self
240    }
241
242    /// Configures the timeout configuration for the builder.
243    pub fn timeout_config(mut self, timeout_config: TimeoutConfig) -> Self {
244        self.config.store_put(timeout_config);
245        self
246    }
247
248    /// Disables auth for the operation.
249    pub fn no_auth(mut self) -> Self {
250        self.config
251            .store_put(AuthSchemeOptionResolverParams::new(()));
252        self.runtime_components
253            .set_auth_scheme_option_resolver(Some(SharedAuthSchemeOptionResolver::new(
254                StaticAuthSchemeOptionResolver::new(vec![NO_AUTH_SCHEME_ID]),
255            )));
256        self.runtime_components
257            .push_auth_scheme(SharedAuthScheme::new(NoAuthScheme::default()));
258        self.runtime_components
259            .set_identity_cache(Some(IdentityCache::no_cache()));
260        self.runtime_components.set_identity_resolver(
261            NO_AUTH_SCHEME_ID,
262            SharedIdentityResolver::new(NoAuthIdentityResolver::new()),
263        );
264        self
265    }
266
267    /// Configures the sleep for the builder.
268    pub fn sleep_impl(mut self, async_sleep: impl AsyncSleep + 'static) -> Self {
269        self.runtime_components
270            .set_sleep_impl(Some(async_sleep.into_shared()));
271        self
272    }
273
274    /// Configures the time source for the builder.
275    pub fn time_source(mut self, time_source: impl TimeSource + 'static) -> Self {
276        self.runtime_components
277            .set_time_source(Some(time_source.into_shared()));
278        self
279    }
280
281    /// Configures the interceptor for the builder.
282    pub fn interceptor(mut self, interceptor: impl Intercept + 'static) -> Self {
283        self.runtime_components.push_interceptor(interceptor);
284        self
285    }
286
287    /// Registers the [`ConnectionPoisoningInterceptor`].
288    pub fn with_connection_poisoning(self) -> Self {
289        self.interceptor(ConnectionPoisoningInterceptor::new())
290    }
291
292    /// Configures the runtime plugin for the builder.
293    pub fn runtime_plugin(mut self, runtime_plugin: impl RuntimePlugin + 'static) -> Self {
294        self.runtime_plugins.push(runtime_plugin.into_shared());
295        self
296    }
297
298    /// Configures stalled stream protection with the given config.
299    pub fn stalled_stream_protection(
300        mut self,
301        stalled_stream_protection: StalledStreamProtectionConfig,
302    ) -> Self {
303        self.config.store_put(stalled_stream_protection);
304        self
305    }
306
307    /// Configures the serializer for the builder.
308    pub fn serializer<I2>(
309        mut self,
310        serializer: impl Fn(I2) -> Result<HttpRequest, BoxError> + Send + Sync + 'static,
311    ) -> OperationBuilder<I2, O, E>
312    where
313        I2: fmt::Debug + Send + Sync + 'static,
314    {
315        self.config
316            .store_put(SharedRequestSerializer::new(FnSerializer::new(serializer)));
317        OperationBuilder {
318            service_name: self.service_name,
319            operation_name: self.operation_name,
320            config: self.config,
321            runtime_components: self.runtime_components,
322            runtime_plugins: self.runtime_plugins,
323            _phantom: Default::default(),
324        }
325    }
326
327    /// Configures the deserializer for the builder.
328    pub fn deserializer<O2, E2>(
329        mut self,
330        deserializer: impl Fn(&HttpResponse) -> Result<O2, OrchestratorError<E2>>
331            + Send
332            + Sync
333            + 'static,
334    ) -> OperationBuilder<I, O2, E2>
335    where
336        O2: fmt::Debug + Send + Sync + 'static,
337        E2: std::error::Error + fmt::Debug + Send + Sync + 'static,
338    {
339        self.config
340            .store_put(SharedResponseDeserializer::new(FnDeserializer::new(
341                deserializer,
342            )));
343        OperationBuilder {
344            service_name: self.service_name,
345            operation_name: self.operation_name,
346            config: self.config,
347            runtime_components: self.runtime_components,
348            runtime_plugins: self.runtime_plugins,
349            _phantom: Default::default(),
350        }
351    }
352
353    /// Configures the a deserializer implementation for the builder.
354    #[allow(clippy::implied_bounds_in_impls)] // for `Send` and `Sync`
355    pub fn deserializer_impl<O2, E2>(
356        mut self,
357        deserializer: impl DeserializeResponse + Send + Sync + 'static,
358    ) -> OperationBuilder<I, O2, E2>
359    where
360        O2: fmt::Debug + Send + Sync + 'static,
361        E2: std::error::Error + fmt::Debug + Send + Sync + 'static,
362    {
363        let deserializer: SharedResponseDeserializer = deserializer.into_shared();
364        self.config.store_put(deserializer);
365
366        OperationBuilder {
367            service_name: self.service_name,
368            operation_name: self.operation_name,
369            config: self.config,
370            runtime_components: self.runtime_components,
371            runtime_plugins: self.runtime_plugins,
372            _phantom: Default::default(),
373        }
374    }
375
376    /// Creates an `Operation` from the builder.
377    pub fn build(self) -> Operation<I, O, E> {
378        let service_name = self.service_name.expect("service_name required");
379        let operation_name = self.operation_name.expect("operation_name required");
380
381        let mut runtime_plugins = RuntimePlugins::new()
382            .with_client_plugins(default_plugins(
383                DefaultPluginParams::new().with_retry_partition_name(service_name.clone()),
384            ))
385            .with_client_plugin(
386                StaticRuntimePlugin::new()
387                    .with_config(self.config.freeze())
388                    .with_runtime_components(self.runtime_components),
389            );
390        for runtime_plugin in self.runtime_plugins {
391            runtime_plugins = runtime_plugins.with_client_plugin(runtime_plugin);
392        }
393
394        #[cfg(debug_assertions)]
395        {
396            let mut config = ConfigBag::base();
397            let components = runtime_plugins
398                .apply_client_configuration(&mut config)
399                .expect("the runtime plugins should succeed");
400
401            assert!(
402                components.http_client().is_some(),
403                "a http_client is required. Enable the `default-https-client` crate feature or configure an HTTP client to fix this."
404            );
405            assert!(
406                components.endpoint_resolver().is_some(),
407                "a endpoint_resolver is required"
408            );
409            assert!(
410                components.retry_strategy().is_some(),
411                "a retry_strategy is required"
412            );
413            assert!(
414                config.load::<SharedRequestSerializer>().is_some(),
415                "a serializer is required"
416            );
417            assert!(
418                config.load::<SharedResponseDeserializer>().is_some(),
419                "a deserializer is required"
420            );
421            assert!(
422                config.load::<EndpointResolverParams>().is_some(),
423                "endpoint resolver params are required"
424            );
425            assert!(
426                config.load::<TimeoutConfig>().is_some(),
427                "timeout config is required"
428            );
429        }
430
431        Operation {
432            service_name,
433            operation_name,
434            runtime_plugins,
435            _phantom: Default::default(),
436        }
437    }
438}
439
#[cfg(all(test, any(feature = "test-util", feature = "legacy-test-util")))]
mod tests {
    use super::*;
    use crate::client::retries::classifiers::HttpStatusCodeClassifier;
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_http_client::test_util::{capture_request, ReplayEvent, StaticReplayClient};
    use aws_smithy_runtime_api::client::result::ConnectorError;
    use aws_smithy_types::body::SdkBody;
    use std::convert::Infallible;

    /// Serializer shared by both tests: the `String` input becomes the request body verbatim.
    fn string_body(input: String) -> Result<HttpRequest, BoxError> {
        Ok(HttpRequest::new(SdkBody::from(input.as_bytes())))
    }

    /// Decodes the response body as UTF-8.
    ///
    /// Panics if the body bytes are unavailable or are not valid UTF-8.
    fn body_text(response: &HttpResponse) -> String {
        let bytes = response.body().bytes().unwrap();
        std::str::from_utf8(bytes).unwrap().to_string()
    }

    /// Builds an `http` 1.x response with the given status code and static body.
    fn http_response(status: u16, body: &'static [u8]) -> http_1x::Response<SdkBody> {
        http_1x::Response::builder()
            .status(status)
            .body(SdkBody::from(body))
            .unwrap()
    }

    /// The request the replay client expects on every attempt.
    fn expected_request() -> http_1x::Request<SdkBody> {
        http_1x::Request::builder()
            .uri("http://localhost:1234/")
            .body(SdkBody::from(&b"what are you?"[..]))
            .unwrap()
    }

    #[tokio::test]
    async fn operation() {
        let (connector, request_rx) =
            capture_request(Some(http_response(418, b"I'm a teapot!")));
        let operation = Operation::builder()
            .service_name("test")
            .operation_name("test")
            .http_client(connector)
            .endpoint_url("http://localhost:1234")
            .no_auth()
            .no_retry()
            .timeout_config(TimeoutConfig::disabled())
            .serializer(string_body)
            .deserializer::<_, Infallible>(|response| {
                assert_eq!(418, u16::from(response.status()));
                Ok(body_text(response))
            })
            .build();

        let output = operation
            .invoke("what are you?".to_string())
            .await
            .expect("success");
        assert_eq!("I'm a teapot!", output);

        let request = request_rx.expect_request();
        assert_eq!("http://localhost:1234/", request.uri());
        assert_eq!(b"what are you?", request.body().bytes().unwrap());
    }

    #[tokio::test]
    async fn operation_retries() {
        // First attempt gets a 503 (mapped to a connector error), the retry gets a 418.
        let connector = StaticReplayClient::new(vec![
            ReplayEvent::new(expected_request(), http_response(503, b"")),
            ReplayEvent::new(expected_request(), http_response(418, b"I'm a teapot!")),
        ]);
        let operation = Operation::builder()
            .service_name("test")
            .operation_name("test")
            .http_client(connector.clone())
            .endpoint_url("http://localhost:1234")
            .no_auth()
            .standard_retry(&RetryConfig::standard())
            .retry_classifier(HttpStatusCodeClassifier::default())
            .timeout_config(TimeoutConfig::disabled())
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .serializer(string_body)
            .deserializer::<_, Infallible>(|response| match u16::from(response.status()) {
                503 => Err(OrchestratorError::connector(ConnectorError::io(
                    "test".into(),
                ))),
                status => {
                    assert_eq!(418, status);
                    Ok(body_text(response))
                }
            })
            .build();

        let output = operation
            .invoke("what are you?".to_string())
            .await
            .expect("success");
        assert_eq!("I'm a teapot!", output);

        connector.assert_requests_match(&[]);
    }
}