1use crate::client::auth::no_auth::{NoAuthScheme, NO_AUTH_SCHEME_ID};
7use crate::client::defaults::{default_plugins, DefaultPluginParams};
8use crate::client::http::connection_poisoning::ConnectionPoisoningInterceptor;
9use crate::client::identity::no_auth::NoAuthIdentityResolver;
10use crate::client::identity::IdentityCache;
11use crate::client::orchestrator::endpoints::StaticUriEndpointResolver;
12use crate::client::retries::strategy::{NeverRetryStrategy, StandardRetryStrategy};
13use aws_smithy_async::rt::sleep::AsyncSleep;
14use aws_smithy_async::time::TimeSource;
15use aws_smithy_runtime_api::client::auth::static_resolver::StaticAuthSchemeOptionResolver;
16use aws_smithy_runtime_api::client::auth::{
17 AuthSchemeOptionResolverParams, SharedAuthScheme, SharedAuthSchemeOptionResolver,
18};
19use aws_smithy_runtime_api::client::endpoint::{EndpointResolverParams, SharedEndpointResolver};
20use aws_smithy_runtime_api::client::http::HttpClient;
21use aws_smithy_runtime_api::client::identity::SharedIdentityResolver;
22use aws_smithy_runtime_api::client::interceptors::context::{Error, Input, Output};
23use aws_smithy_runtime_api::client::interceptors::Intercept;
24use aws_smithy_runtime_api::client::orchestrator::HttpResponse;
25use aws_smithy_runtime_api::client::orchestrator::{HttpRequest, OrchestratorError};
26use aws_smithy_runtime_api::client::result::SdkError;
27use aws_smithy_runtime_api::client::retries::classifiers::ClassifyRetry;
28use aws_smithy_runtime_api::client::retries::SharedRetryStrategy;
29use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder;
30use aws_smithy_runtime_api::client::runtime_plugin::{
31 RuntimePlugin, RuntimePlugins, SharedRuntimePlugin, StaticRuntimePlugin,
32};
33use aws_smithy_runtime_api::client::ser_de::{
34 DeserializeResponse, SerializeRequest, SharedRequestSerializer, SharedResponseDeserializer,
35};
36use aws_smithy_runtime_api::shared::IntoShared;
37use aws_smithy_runtime_api::{
38 box_error::BoxError, client::stalled_stream_protection::StalledStreamProtectionConfig,
39};
40use aws_smithy_types::config_bag::{ConfigBag, Layer};
41use aws_smithy_types::retry::RetryConfig;
42use aws_smithy_types::timeout::TimeoutConfig;
43use std::borrow::Cow;
44use std::fmt;
45use std::marker::PhantomData;
46
/// Wraps a plain closure `f` so it can be used as a [`SerializeRequest`].
///
/// The `SerializeRequest` impl requires `F: Fn(I) -> Result<HttpRequest, BoxError>`.
/// `I` is the concrete (un-erased) input type; it is tracked only at the type
/// level through `PhantomData`.
struct FnSerializer<F, I> {
    f: F,
    _phantom: PhantomData<I>,
}
51impl<F, I> FnSerializer<F, I> {
52 fn new(f: F) -> Self {
53 Self {
54 f,
55 _phantom: Default::default(),
56 }
57 }
58}
59impl<F, I> SerializeRequest for FnSerializer<F, I>
60where
61 F: Fn(I) -> Result<HttpRequest, BoxError> + Send + Sync,
62 I: fmt::Debug + Send + Sync + 'static,
63{
64 fn serialize_input(&self, input: Input, _cfg: &mut ConfigBag) -> Result<HttpRequest, BoxError> {
65 let input: I = input.downcast().expect("correct type");
66 (self.f)(input)
67 }
68}
69impl<F, I> fmt::Debug for FnSerializer<F, I> {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 write!(f, "FnSerializer")
72 }
73}
74
/// Wraps a plain closure `f` so it can be used as a [`DeserializeResponse`].
///
/// The `DeserializeResponse` impl requires
/// `F: Fn(&HttpResponse) -> Result<O, OrchestratorError<E>>`. `O` and `E` are the
/// concrete output and operation-error types; they are tracked only at the type
/// level through `PhantomData`.
struct FnDeserializer<F, O, E> {
    f: F,
    _phantom: PhantomData<(O, E)>,
}
79impl<F, O, E> FnDeserializer<F, O, E> {
80 fn new(deserializer: F) -> Self {
81 Self {
82 f: deserializer,
83 _phantom: Default::default(),
84 }
85 }
86}
87impl<F, O, E> DeserializeResponse for FnDeserializer<F, O, E>
88where
89 F: Fn(&HttpResponse) -> Result<O, OrchestratorError<E>> + Send + Sync,
90 O: fmt::Debug + Send + Sync + 'static,
91 E: std::error::Error + fmt::Debug + Send + Sync + 'static,
92{
93 fn deserialize_nonstreaming(
94 &self,
95 response: &HttpResponse,
96 ) -> Result<Output, OrchestratorError<Error>> {
97 (self.f)(response)
98 .map(|output| Output::erase(output))
99 .map_err(|err| err.map_operation_error(Error::erase))
100 }
101}
102impl<F, O, E> fmt::Debug for FnDeserializer<F, O, E> {
103 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
104 write!(f, "FnDeserializer")
105 }
106}
107
/// A strongly-typed operation that is executed by the orchestrator.
///
/// Create one with [`Operation::builder`]. `I`, `O`, and `E` are the operation's input,
/// output, and error types. They exist only at the type level (`PhantomData`) because
/// `invoke` erases the input before handing it to the orchestrator and downcasts the
/// results afterwards.
#[derive(Debug)]
pub struct Operation<I, O, E> {
    service_name: Cow<'static, str>,
    operation_name: Cow<'static, str>,
    // Fully assembled plugin set (defaults + builder config + extra plugins).
    runtime_plugins: RuntimePlugins,
    _phantom: PhantomData<(I, O, E)>,
}
116
117impl<I, O, E> Clone for Operation<I, O, E> {
119 fn clone(&self) -> Self {
120 Self {
121 service_name: self.service_name.clone(),
122 operation_name: self.operation_name.clone(),
123 runtime_plugins: self.runtime_plugins.clone(),
124 _phantom: self._phantom,
125 }
126 }
127}
128
129impl Operation<(), (), ()> {
130 pub fn builder() -> OperationBuilder {
132 OperationBuilder::new()
133 }
134}
135
136impl<I, O, E> Operation<I, O, E>
137where
138 I: fmt::Debug + Send + Sync + 'static,
139 O: fmt::Debug + Send + Sync + 'static,
140 E: std::error::Error + fmt::Debug + Send + Sync + 'static,
141{
142 pub async fn invoke(&self, input: I) -> Result<O, SdkError<E, HttpResponse>> {
145 let input = Input::erase(input);
146
147 let output = super::invoke(
148 &self.service_name,
149 &self.operation_name,
150 input,
151 &self.runtime_plugins,
152 )
153 .await
154 .map_err(|err| err.map_service_error(|e| e.downcast().expect("correct type")))?;
155
156 Ok(output.downcast().expect("correct type"))
157 }
158}
159
/// Builder for [`Operation`].
///
/// The type parameters default to `()`. They are replaced when the serializer (`I`)
/// and deserializer (`O`, `E`) are configured.
#[derive(Debug)]
pub struct OperationBuilder<I = (), O = (), E = ()> {
    // Required; `build` panics if unset.
    service_name: Option<Cow<'static, str>>,
    // Required; `build` panics if unset.
    operation_name: Option<Cow<'static, str>>,
    // Operation-level config layer, frozen into a `StaticRuntimePlugin` by `build`.
    config: Layer,
    runtime_components: RuntimeComponentsBuilder,
    // Extra client plugins added through `runtime_plugin`.
    runtime_plugins: Vec<SharedRuntimePlugin>,
    _phantom: PhantomData<(I, O, E)>,
}
170
171impl Default for OperationBuilder<(), (), ()> {
172 fn default() -> Self {
173 Self::new()
174 }
175}
176
177impl OperationBuilder<(), (), ()> {
178 pub fn new() -> Self {
180 Self {
181 service_name: None,
182 operation_name: None,
183 config: Layer::new("operation"),
184 runtime_components: RuntimeComponentsBuilder::new("operation"),
185 runtime_plugins: Vec::new(),
186 _phantom: Default::default(),
187 }
188 }
189}
190
191impl<I, O, E> OperationBuilder<I, O, E> {
192 pub fn service_name(mut self, service_name: impl Into<Cow<'static, str>>) -> Self {
194 self.service_name = Some(service_name.into());
195 self
196 }
197
198 pub fn operation_name(mut self, operation_name: impl Into<Cow<'static, str>>) -> Self {
200 self.operation_name = Some(operation_name.into());
201 self
202 }
203
204 pub fn http_client(mut self, connector: impl HttpClient + 'static) -> Self {
206 self.runtime_components.set_http_client(Some(connector));
207 self
208 }
209
210 pub fn endpoint_url(mut self, url: &str) -> Self {
212 self.config.store_put(EndpointResolverParams::new(()));
213 self.runtime_components
214 .set_endpoint_resolver(Some(SharedEndpointResolver::new(
215 StaticUriEndpointResolver::uri(url),
216 )));
217 self
218 }
219
220 pub fn retry_classifier(mut self, retry_classifier: impl ClassifyRetry + 'static) -> Self {
222 self.runtime_components
223 .push_retry_classifier(retry_classifier);
224 self
225 }
226
227 pub fn no_retry(mut self) -> Self {
229 self.runtime_components
230 .set_retry_strategy(Some(SharedRetryStrategy::new(NeverRetryStrategy::new())));
231 self
232 }
233
234 pub fn standard_retry(mut self, retry_config: &RetryConfig) -> Self {
236 self.config.store_put(retry_config.clone());
237 self.runtime_components
238 .set_retry_strategy(Some(SharedRetryStrategy::new(StandardRetryStrategy::new())));
239 self
240 }
241
242 pub fn timeout_config(mut self, timeout_config: TimeoutConfig) -> Self {
244 self.config.store_put(timeout_config);
245 self
246 }
247
248 pub fn no_auth(mut self) -> Self {
250 self.config
251 .store_put(AuthSchemeOptionResolverParams::new(()));
252 self.runtime_components
253 .set_auth_scheme_option_resolver(Some(SharedAuthSchemeOptionResolver::new(
254 StaticAuthSchemeOptionResolver::new(vec![NO_AUTH_SCHEME_ID]),
255 )));
256 self.runtime_components
257 .push_auth_scheme(SharedAuthScheme::new(NoAuthScheme::default()));
258 self.runtime_components
259 .set_identity_cache(Some(IdentityCache::no_cache()));
260 self.runtime_components.set_identity_resolver(
261 NO_AUTH_SCHEME_ID,
262 SharedIdentityResolver::new(NoAuthIdentityResolver::new()),
263 );
264 self
265 }
266
267 pub fn sleep_impl(mut self, async_sleep: impl AsyncSleep + 'static) -> Self {
269 self.runtime_components
270 .set_sleep_impl(Some(async_sleep.into_shared()));
271 self
272 }
273
274 pub fn time_source(mut self, time_source: impl TimeSource + 'static) -> Self {
276 self.runtime_components
277 .set_time_source(Some(time_source.into_shared()));
278 self
279 }
280
281 pub fn interceptor(mut self, interceptor: impl Intercept + 'static) -> Self {
283 self.runtime_components.push_interceptor(interceptor);
284 self
285 }
286
287 pub fn with_connection_poisoning(self) -> Self {
289 self.interceptor(ConnectionPoisoningInterceptor::new())
290 }
291
292 pub fn runtime_plugin(mut self, runtime_plugin: impl RuntimePlugin + 'static) -> Self {
294 self.runtime_plugins.push(runtime_plugin.into_shared());
295 self
296 }
297
298 pub fn stalled_stream_protection(
300 mut self,
301 stalled_stream_protection: StalledStreamProtectionConfig,
302 ) -> Self {
303 self.config.store_put(stalled_stream_protection);
304 self
305 }
306
307 pub fn serializer<I2>(
309 mut self,
310 serializer: impl Fn(I2) -> Result<HttpRequest, BoxError> + Send + Sync + 'static,
311 ) -> OperationBuilder<I2, O, E>
312 where
313 I2: fmt::Debug + Send + Sync + 'static,
314 {
315 self.config
316 .store_put(SharedRequestSerializer::new(FnSerializer::new(serializer)));
317 OperationBuilder {
318 service_name: self.service_name,
319 operation_name: self.operation_name,
320 config: self.config,
321 runtime_components: self.runtime_components,
322 runtime_plugins: self.runtime_plugins,
323 _phantom: Default::default(),
324 }
325 }
326
327 pub fn deserializer<O2, E2>(
329 mut self,
330 deserializer: impl Fn(&HttpResponse) -> Result<O2, OrchestratorError<E2>>
331 + Send
332 + Sync
333 + 'static,
334 ) -> OperationBuilder<I, O2, E2>
335 where
336 O2: fmt::Debug + Send + Sync + 'static,
337 E2: std::error::Error + fmt::Debug + Send + Sync + 'static,
338 {
339 self.config
340 .store_put(SharedResponseDeserializer::new(FnDeserializer::new(
341 deserializer,
342 )));
343 OperationBuilder {
344 service_name: self.service_name,
345 operation_name: self.operation_name,
346 config: self.config,
347 runtime_components: self.runtime_components,
348 runtime_plugins: self.runtime_plugins,
349 _phantom: Default::default(),
350 }
351 }
352
353 #[allow(clippy::implied_bounds_in_impls)] pub fn deserializer_impl<O2, E2>(
356 mut self,
357 deserializer: impl DeserializeResponse + Send + Sync + 'static,
358 ) -> OperationBuilder<I, O2, E2>
359 where
360 O2: fmt::Debug + Send + Sync + 'static,
361 E2: std::error::Error + fmt::Debug + Send + Sync + 'static,
362 {
363 let deserializer: SharedResponseDeserializer = deserializer.into_shared();
364 self.config.store_put(deserializer);
365
366 OperationBuilder {
367 service_name: self.service_name,
368 operation_name: self.operation_name,
369 config: self.config,
370 runtime_components: self.runtime_components,
371 runtime_plugins: self.runtime_plugins,
372 _phantom: Default::default(),
373 }
374 }
375
376 pub fn build(self) -> Operation<I, O, E> {
378 let service_name = self.service_name.expect("service_name required");
379 let operation_name = self.operation_name.expect("operation_name required");
380
381 let mut runtime_plugins = RuntimePlugins::new()
382 .with_client_plugins(default_plugins(
383 DefaultPluginParams::new().with_retry_partition_name(service_name.clone()),
384 ))
385 .with_client_plugin(
386 StaticRuntimePlugin::new()
387 .with_config(self.config.freeze())
388 .with_runtime_components(self.runtime_components),
389 );
390 for runtime_plugin in self.runtime_plugins {
391 runtime_plugins = runtime_plugins.with_client_plugin(runtime_plugin);
392 }
393
394 #[cfg(debug_assertions)]
395 {
396 let mut config = ConfigBag::base();
397 let components = runtime_plugins
398 .apply_client_configuration(&mut config)
399 .expect("the runtime plugins should succeed");
400
401 assert!(
402 components.http_client().is_some(),
403 "a http_client is required. Enable the `default-https-client` crate feature or configure an HTTP client to fix this."
404 );
405 assert!(
406 components.endpoint_resolver().is_some(),
407 "a endpoint_resolver is required"
408 );
409 assert!(
410 components.retry_strategy().is_some(),
411 "a retry_strategy is required"
412 );
413 assert!(
414 config.load::<SharedRequestSerializer>().is_some(),
415 "a serializer is required"
416 );
417 assert!(
418 config.load::<SharedResponseDeserializer>().is_some(),
419 "a deserializer is required"
420 );
421 assert!(
422 config.load::<EndpointResolverParams>().is_some(),
423 "endpoint resolver params are required"
424 );
425 assert!(
426 config.load::<TimeoutConfig>().is_some(),
427 "timeout config is required"
428 );
429 }
430
431 Operation {
432 service_name,
433 operation_name,
434 runtime_plugins,
435 _phantom: Default::default(),
436 }
437 }
438}
439
// Tests for `Operation` / `OperationBuilder`. They are compiled only with the
// `test-util` or `legacy-test-util` feature, which provides the capture/replay HTTP clients.
#[cfg(all(test, any(feature = "test-util", feature = "legacy-test-util")))]
mod tests {
    use super::*;
    use crate::client::retries::classifiers::HttpStatusCodeClassifier;
    use aws_smithy_async::rt::sleep::{SharedAsyncSleep, TokioSleep};
    use aws_smithy_http_client::test_util::{capture_request, ReplayEvent, StaticReplayClient};
    use aws_smithy_runtime_api::client::result::ConnectorError;
    use aws_smithy_types::body::SdkBody;
    use std::convert::Infallible;

    // Single round trip against a capturing client that answers 418 "I'm a teapot!".
    // Checks that the deserializer's output is returned to the caller and that the
    // serialized body and resolved URI reach the wire.
    #[tokio::test]
    async fn operation() {
        let (connector, request_rx) = capture_request(Some(
            http_1x::Response::builder()
                .status(418)
                .body(SdkBody::from(&b"I'm a teapot!"[..]))
                .unwrap(),
        ));
        let operation = Operation::builder()
            .service_name("test")
            .operation_name("test")
            .http_client(connector)
            .endpoint_url("http://localhost:1234")
            .no_auth()
            .no_retry()
            .timeout_config(TimeoutConfig::disabled())
            // Serializer: the input string becomes the request body verbatim.
            .serializer(|input: String| Ok(HttpRequest::new(SdkBody::from(input.as_bytes()))))
            // Deserializer: the response body is returned as a `String`; `Infallible`
            // means this operation has no modeled error type.
            .deserializer::<_, Infallible>(|response| {
                assert_eq!(418, u16::from(response.status()));
                Ok(std::str::from_utf8(response.body().bytes().unwrap())
                    .unwrap()
                    .to_string())
            })
            .build();

        let output = operation
            .invoke("what are you?".to_string())
            .await
            .expect("success");
        assert_eq!("I'm a teapot!", output);

        // The static endpoint resolver is expected to add a trailing `/` to the bare host URL.
        let request = request_rx.expect_request();
        assert_eq!("http://localhost:1234/", request.uri());
        assert_eq!(b"what are you?", request.body().bytes().unwrap());
    }

    // The replay client answers 503 first and 418 second. With the standard retry
    // strategy and `HttpStatusCodeClassifier`, the first attempt is expected to be
    // classified as retryable and the second to succeed.
    #[tokio::test]
    async fn operation_retries() {
        let connector = StaticReplayClient::new(vec![
            ReplayEvent::new(
                http_1x::Request::builder()
                    .uri("http://localhost:1234/")
                    .body(SdkBody::from(&b"what are you?"[..]))
                    .unwrap(),
                http_1x::Response::builder()
                    .status(503)
                    .body(SdkBody::from(&b""[..]))
                    .unwrap(),
            ),
            ReplayEvent::new(
                http_1x::Request::builder()
                    .uri("http://localhost:1234/")
                    .body(SdkBody::from(&b"what are you?"[..]))
                    .unwrap(),
                http_1x::Response::builder()
                    .status(418)
                    .body(SdkBody::from(&b"I'm a teapot!"[..]))
                    .unwrap(),
            ),
        ]);
        let operation = Operation::builder()
            .service_name("test")
            .operation_name("test")
            .http_client(connector.clone())
            .endpoint_url("http://localhost:1234")
            .no_auth()
            .standard_retry(&RetryConfig::standard())
            .retry_classifier(HttpStatusCodeClassifier::default())
            .timeout_config(TimeoutConfig::disabled())
            // A real sleep implementation is supplied for the retry path.
            .sleep_impl(SharedAsyncSleep::new(TokioSleep::new()))
            .serializer(|input: String| Ok(HttpRequest::new(SdkBody::from(input.as_bytes()))))
            .deserializer::<_, Infallible>(|response| {
                // The 503 attempt is reported as a connector I/O error.
                if u16::from(response.status()) == 503 {
                    Err(OrchestratorError::connector(ConnectorError::io(
                        "test".into(),
                    )))
                } else {
                    assert_eq!(418, u16::from(response.status()));
                    Ok(std::str::from_utf8(response.body().bytes().unwrap())
                        .unwrap()
                        .to_string())
                }
            })
            .build();

        let output = operation
            .invoke("what are you?".to_string())
            .await
            .expect("success");
        assert_eq!("I'm a teapot!", output);

        // Checks the recorded requests against the replay events. The empty slice is
        // presumably the list of headers to ignore (TODO confirm against
        // `StaticReplayClient`).
        connector.assert_requests_match(&[]);
    }
}