1mod aggregate_type;
20pub(crate) mod entity;
21mod options;
22
23pub use aggregate_type::{AggregateType, AggregateTypeList};
24pub use options::{FinalizeModify, ParallelOption};
25
26use crate::enrich::CodeEnrichment;
27use crate::enrich::ToEntityGraphTokens;
28use crate::enrich::ToRustCodeTokens;
29use convert_case::{Case, Casing};
30use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
31use quote::quote;
32use syn::parse::{Parse, ParseStream};
33use syn::punctuated::Punctuated;
34use syn::spanned::Spanned;
35use syn::{
36 parse_quote, Expr, ImplItemConst, ImplItemFn, ImplItemType, ItemFn, ItemImpl, Path, Type,
37};
38
39use crate::ToSqlConfig;
40
41use super::UsedType;
42
/// Positional fallback parameter names for the generated `#[pg_extern]` wrappers.
///
/// Entry `i` names the `i`-th aggregate argument whenever that argument has no explicit name.
/// Aggregates with more arguments than this table holds cannot be given fallback names.
const ARG_NAMES: [&'static str; 32] = [
    // 1 - 10
    "arg_one", "arg_two", "arg_three", "arg_four", "arg_five",
    "arg_six", "arg_seven", "arg_eight", "arg_nine", "arg_ten",
    // 11 - 20
    "arg_eleven", "arg_twelve", "arg_thirteen", "arg_fourteen", "arg_fifteen",
    "arg_sixteen", "arg_seventeen", "arg_eighteen", "arg_nineteen", "arg_twenty",
    // 21 - 30
    "arg_twenty_one", "arg_twenty_two", "arg_twenty_three", "arg_twenty_four",
    "arg_twenty_five", "arg_twenty_six", "arg_twenty_seven", "arg_twenty_eight",
    "arg_twenty_nine", "arg_thirty",
    // 31 - 32
    "arg_thirty_one", "arg_thirty_two",
];
78
/// Parsed form of an `impl Aggregate for T` block annotated with `#[pg_aggregate]`.
///
/// Built by [`PgAggregate::new`]. It is rendered back into Rust code (the augmented impl plus
/// the generated wrappers) via `ToRustCodeTokens`, and into an entity-graph registration
/// function via `ToEntityGraphTokens`.
#[derive(Debug, Clone)]
pub struct PgAggregate {
    /// The user's impl block, with defaults appended for any optional items that were omitted.
    item_impl: ItemImpl,
    /// Expression for the aggregate's SQL name: the user's `NAME` string literal, or
    /// `stringify!(T)` when `NAME` is omitted.
    name: Expr,
    /// Final path segment of the implementing type.
    target_ident: Ident,
    /// One generated `#[pg_extern]` wrapper per support function the user implemented.
    pg_externs: Vec<ItemFn>,
    /// Parsed `Args` associated type.
    type_args: AggregateTypeList,
    /// Parsed `OrderedSetArgs` associated type (the direct arguments), if declared.
    type_ordered_set_args: Option<AggregateTypeList>,
    /// `MovingState` associated type, if declared.
    type_moving_state: Option<UsedType>,
    /// `State` associated type (defaulting to `Self`), with `Self` rewritten to the target type.
    type_stype: AggregateType,
    /// Value of `ORDERED_SET`; `false` when absent or not a boolean literal.
    const_ordered_set: bool,
    /// Raw `PARALLEL` expression, if declared.
    const_parallel: Option<syn::Expr>,
    /// Raw `FINALIZE_MODIFY` expression, if declared.
    const_finalize_modify: Option<syn::Expr>,
    /// Raw `MOVING_FINALIZE_MODIFY` expression, if declared.
    const_moving_finalize_modify: Option<syn::Expr>,
    /// String extracted from `INITIAL_CONDITION`, if any.
    const_initial_condition: Option<String>,
    /// String extracted from `SORT_OPERATOR`, if any.
    const_sort_operator: Option<String>,
    /// String extracted from `MOVING_INITIAL_CONDITION`, if any (note the "intial" spelling of
    /// this field's name).
    const_moving_intial_condition: Option<String>,
    /// Name of the generated wrapper for the required `state` function.
    fn_state: Ident,
    /// Name of the generated `finalize` wrapper; `None` if the user did not implement it.
    fn_finalize: Option<Ident>,
    /// Name of the generated `combine` wrapper; `None` if the user did not implement it.
    fn_combine: Option<Ident>,
    /// Name of the generated `serial` wrapper; `None` if the user did not implement it.
    fn_serial: Option<Ident>,
    /// Name of the generated `deserial` wrapper; `None` if the user did not implement it.
    fn_deserial: Option<Ident>,
    /// Name of the generated `moving_state` wrapper; `None` if the user did not implement it.
    fn_moving_state: Option<Ident>,
    /// Name of the generated `moving_state_inverse` wrapper; `None` if not implemented.
    fn_moving_state_inverse: Option<Ident>,
    /// Name of the generated `moving_finalize` wrapper; `None` if not implemented.
    fn_moving_finalize: Option<Ident>,
    /// Value of `HYPOTHETICAL`; `false` when absent.
    hypothetical: bool,
    /// SQL generation options parsed from the impl's attributes.
    to_sql_config: ToSqlConfig,
}
110
111impl PgAggregate {
112 pub fn new(mut item_impl: ItemImpl) -> Result<CodeEnrichment<Self>, syn::Error> {
113 let to_sql_config =
114 ToSqlConfig::from_attributes(item_impl.attrs.as_slice())?.unwrap_or_default();
115 let target_path = get_target_path(&item_impl)?;
116 let target_ident = get_target_ident(&target_path)?;
117
118 let snake_case_target_ident =
119 Ident::new(&target_ident.to_string().to_case(Case::Snake), target_ident.span());
120 crate::ident_is_acceptable_to_postgres(&snake_case_target_ident)?;
121
122 let mut pg_externs = Vec::default();
123 let item_impl_snapshot = item_impl.clone();
126
127 if let Some((_, ref path, _)) = item_impl.trait_ {
128 if let Some(last) = path.segments.last() {
130 if last.ident != "Aggregate" {
131 return Err(syn::Error::new(
132 last.ident.span(),
133 "`#[pg_aggregate]` only works with the `Aggregate` trait.",
134 ));
135 }
136 }
137 }
138
139 let name = match get_impl_const_by_name(&item_impl_snapshot, "NAME") {
140 Some(item_const) => match &item_const.expr {
141 syn::Expr::Lit(ref expr) => {
142 if let syn::Lit::Str(_) = &expr.lit {
143 item_const.expr.clone()
144 } else {
145 return Err(syn::Error::new(
146 expr.span(),
147 "`NAME` must be a `&'static str` for Aggregate implementations.",
148 ));
149 }
150 }
151 e => {
152 return Err(syn::Error::new(
153 e.span(),
154 "`NAME` must be a `&'static str` for Aggregate implementations.",
155 ));
156 }
157 },
158 None => {
159 item_impl.items.push(parse_quote! {
160 const NAME: &'static str = stringify!(Self);
161 });
162 parse_quote! {
163 stringify!(#target_ident)
164 }
165 }
166 };
167
168 let type_state = get_impl_type_by_name(&item_impl_snapshot, "State");
170 let _type_state_value = type_state.map(|v| v.ty.clone());
171
172 let type_state_without_self = if let Some(inner) = type_state {
173 let mut remapped = inner.ty.clone();
174 remap_self_to_target(&mut remapped, &target_ident);
175 remapped
176 } else {
177 item_impl.items.push(parse_quote! {
178 type State = Self;
179 });
180 let mut remapped = parse_quote!(Self);
181 remap_self_to_target(&mut remapped, &target_ident);
182 remapped
183 };
184 let type_stype = AggregateType {
185 used_ty: UsedType::new(type_state_without_self.clone())?,
186 name: Some("state".into()),
187 };
188
189 let impl_type_moving_state = get_impl_type_by_name(&item_impl_snapshot, "MovingState");
191 let type_moving_state;
192 let type_moving_state_value = if let Some(impl_type_moving_state) = impl_type_moving_state {
193 type_moving_state = impl_type_moving_state.ty.clone();
194 Some(UsedType::new(type_moving_state.clone())?)
195 } else {
196 item_impl.items.push(parse_quote! {
197 type MovingState = ();
198 });
199 type_moving_state = parse_quote! { () };
200 None
201 };
202
203 let type_ordered_set_args = get_impl_type_by_name(&item_impl_snapshot, "OrderedSetArgs");
205 let type_ordered_set_args_value =
206 type_ordered_set_args.map(|v| AggregateTypeList::new(v.ty.clone())).transpose()?;
207 if type_ordered_set_args.is_none() {
208 item_impl.items.push(parse_quote! {
209 type OrderedSetArgs = ();
210 })
211 }
212 let (direct_args_with_names, direct_arg_names) = if let Some(ref order_by_direct_args) =
213 type_ordered_set_args_value
214 {
215 let direct_args = order_by_direct_args
216 .found
217 .iter()
218 .map(|x| {
219 (x.name.clone(), x.used_ty.resolved_ty.clone(), x.used_ty.original_ty.clone())
220 })
221 .collect::<Vec<_>>();
222 let direct_arg_names = ARG_NAMES[0..direct_args.len()]
223 .iter()
224 .zip(direct_args.iter())
225 .map(|(default_name, (custom_name, _ty, _orig))| {
226 Ident::new(
227 &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
228 Span::mixed_site(),
229 )
230 })
231 .collect::<Vec<_>>();
232 let direct_args_with_names = direct_args
233 .iter()
234 .zip(direct_arg_names.iter())
235 .map(|(arg, name)| {
236 let arg_ty = &arg.2; parse_quote! {
238 #name: #arg_ty
239 }
240 })
241 .collect::<Vec<syn::FnArg>>();
242 (direct_args_with_names, direct_arg_names)
243 } else {
244 (Vec::default(), Vec::default())
245 };
246
247 let type_args = get_impl_type_by_name(&item_impl_snapshot, "Args").ok_or_else(|| {
249 syn::Error::new(
250 item_impl_snapshot.span(),
251 "`#[pg_aggregate]` requires the `Args` type defined.",
252 )
253 })?;
254 let type_args_value = AggregateTypeList::new(type_args.ty.clone())?;
255 let args = type_args_value
256 .found
257 .iter()
258 .map(|x| (x.name.clone(), x.used_ty.original_ty.clone()))
259 .collect::<Vec<_>>();
260 let arg_names = ARG_NAMES[0..args.len()]
261 .iter()
262 .zip(args.iter())
263 .map(|(default_name, (custom_name, ty))| {
264 Ident::new(
265 &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
266 ty.span(),
267 )
268 })
269 .collect::<Vec<_>>();
270 let args_with_names = args
271 .iter()
272 .zip(arg_names.iter())
273 .map(|(arg, name)| {
274 let arg_ty = &arg.1;
275 quote! {
276 #name: #arg_ty
277 }
278 })
279 .collect::<Vec<_>>();
280
281 let impl_type_finalize = get_impl_type_by_name(&item_impl_snapshot, "Finalize");
283 let type_finalize: syn::Type = if let Some(type_finalize) = impl_type_finalize {
284 type_finalize.ty.clone()
285 } else {
286 item_impl.items.push(parse_quote! {
287 type Finalize = ();
288 });
289 parse_quote! { () }
290 };
291
292 let fn_state = get_impl_func_by_name(&item_impl_snapshot, "state");
293
294 let fn_state_name = if let Some(found) = fn_state {
295 let fn_name =
296 Ident::new(&format!("{}_state", snake_case_target_ident), found.sig.ident.span());
297 let pg_extern_attr = pg_extern_attr(found);
298
299 pg_externs.push(parse_quote! {
300 #[allow(non_snake_case, clippy::too_many_arguments)]
301 #pg_extern_attr
302 fn #fn_name(this: #type_state_without_self, #(#args_with_names),*, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
303 unsafe {
304 <#target_path as ::pgrx::aggregate::Aggregate>::in_memory_context(
305 fcinfo,
306 move |_context| <#target_path as ::pgrx::aggregate::Aggregate>::state(this, (#(#arg_names),*), fcinfo)
307 )
308 }
309 }
310 });
311 fn_name
312 } else {
313 return Err(syn::Error::new(
314 item_impl.span(),
315 "Aggregate implementation must include state function.",
316 ));
317 };
318
319 let fn_combine = get_impl_func_by_name(&item_impl_snapshot, "combine");
320 let fn_combine_name = if let Some(found) = fn_combine {
321 let fn_name =
322 Ident::new(&format!("{}_combine", snake_case_target_ident), found.sig.ident.span());
323 let pg_extern_attr = pg_extern_attr(found);
324 pg_externs.push(parse_quote! {
325 #[allow(non_snake_case, clippy::too_many_arguments)]
326 #pg_extern_attr
327 fn #fn_name(this: #type_state_without_self, v: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
328 unsafe {
329 <#target_path as ::pgrx::aggregate::Aggregate>::in_memory_context(
330 fcinfo,
331 move |_context| <#target_path as ::pgrx::aggregate::Aggregate>::combine(this, v, fcinfo)
332 )
333 }
334 }
335 });
336 Some(fn_name)
337 } else {
338 item_impl.items.push(parse_quote! {
339 fn combine(current: #type_state_without_self, _other: #type_state_without_self, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
340 unimplemented!("Call to combine on an aggregate which does not support it.")
341 }
342 });
343 None
344 };
345
346 let fn_finalize = get_impl_func_by_name(&item_impl_snapshot, "finalize");
347 let fn_finalize_name = if let Some(found) = fn_finalize {
348 let fn_name = Ident::new(
349 &format!("{}_finalize", snake_case_target_ident),
350 found.sig.ident.span(),
351 );
352 let pg_extern_attr = pg_extern_attr(found);
353
354 if !direct_args_with_names.is_empty() {
355 pg_externs.push(parse_quote! {
356 #[allow(non_snake_case, clippy::too_many_arguments)]
357 #pg_extern_attr
358 fn #fn_name(this: #type_state_without_self, #(#direct_args_with_names),*, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
359 unsafe {
360 <#target_path as ::pgrx::aggregate::Aggregate>::in_memory_context(
361 fcinfo,
362 move |_context| <#target_path as ::pgrx::aggregate::Aggregate>::finalize(this, (#(#direct_arg_names),*), fcinfo)
363 )
364 }
365 }
366 });
367 } else {
368 pg_externs.push(parse_quote! {
369 #[allow(non_snake_case, clippy::too_many_arguments)]
370 #pg_extern_attr
371 fn #fn_name(this: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
372 unsafe {
373 <#target_path as ::pgrx::aggregate::Aggregate>::in_memory_context(
374 fcinfo,
375 move |_context| <#target_path as ::pgrx::aggregate::Aggregate>::finalize(this, (), fcinfo)
376 )
377 }
378 }
379 });
380 };
381 Some(fn_name)
382 } else {
383 item_impl.items.push(parse_quote! {
384 fn finalize(current: Self::State, direct_args: Self::OrderedSetArgs, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
385 unimplemented!("Call to finalize on an aggregate which does not support it.")
386 }
387 });
388 None
389 };
390
391 let fn_serial = get_impl_func_by_name(&item_impl_snapshot, "serial");
392 let fn_serial_name = if let Some(found) = fn_serial {
393 let fn_name =
394 Ident::new(&format!("{}_serial", snake_case_target_ident), found.sig.ident.span());
395 let pg_extern_attr = pg_extern_attr(found);
396 pg_externs.push(parse_quote! {
397 #[allow(non_snake_case, clippy::too_many_arguments)]
398 #pg_extern_attr
399 fn #fn_name(this: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Vec<u8> {
400 unsafe {
401 <#target_path as ::pgrx::aggregate::Aggregate>::in_memory_context(
402 fcinfo,
403 move |_context| <#target_path as ::pgrx::aggregate::Aggregate>::serial(this, fcinfo)
404 )
405 }
406 }
407 });
408 Some(fn_name)
409 } else {
410 item_impl.items.push(parse_quote! {
411 fn serial(current: #type_state_without_self, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Vec<u8> {
412 unimplemented!("Call to serial on an aggregate which does not support it.")
413 }
414 });
415 None
416 };
417
418 let fn_deserial = get_impl_func_by_name(&item_impl_snapshot, "deserial");
419 let fn_deserial_name = if let Some(found) = fn_deserial {
420 let fn_name = Ident::new(
421 &format!("{}_deserial", snake_case_target_ident),
422 found.sig.ident.span(),
423 );
424 let pg_extern_attr = pg_extern_attr(found);
425 pg_externs.push(parse_quote! {
426 #[allow(non_snake_case, clippy::too_many_arguments)]
427 #pg_extern_attr
428 fn #fn_name(this: #type_state_without_self, buf: Vec<u8>, internal: ::pgrx::pgbox::PgBox<#type_state_without_self>, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> ::pgrx::pgbox::PgBox<#type_state_without_self> {
429 unsafe {
430 <#target_path as ::pgrx::aggregate::Aggregate>::in_memory_context(
431 fcinfo,
432 move |_context| <#target_path as ::pgrx::aggregate::Aggregate>::deserial(this, buf, internal, fcinfo)
433 )
434 }
435 }
436 });
437 Some(fn_name)
438 } else {
439 item_impl.items.push(parse_quote! {
440 fn deserial(current: #type_state_without_self, _buf: Vec<u8>, _internal: ::pgrx::pgbox::PgBox<#type_state_without_self>, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> ::pgrx::pgbox::PgBox<#type_state_without_self> {
441 unimplemented!("Call to deserial on an aggregate which does not support it.")
442 }
443 });
444 None
445 };
446
447 let fn_moving_state = get_impl_func_by_name(&item_impl_snapshot, "moving_state");
448 let fn_moving_state_name = if let Some(found) = fn_moving_state {
449 let fn_name = Ident::new(
450 &format!("{}_moving_state", snake_case_target_ident),
451 found.sig.ident.span(),
452 );
453 let pg_extern_attr = pg_extern_attr(found);
454
455 pg_externs.push(parse_quote! {
456 #[allow(non_snake_case, clippy::too_many_arguments)]
457 #pg_extern_attr
458 fn #fn_name(
459 mstate: #type_moving_state,
460 #(#args_with_names),*,
461 fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
462 ) -> #type_moving_state {
463 unsafe {
464 <#target_path as ::pgrx::aggregate::Aggregate>::in_memory_context(
465 fcinfo,
466 move |_context| <#target_path as ::pgrx::aggregate::Aggregate>::moving_state(mstate, (#(#arg_names),*), fcinfo)
467 )
468 }
469 }
470 });
471 Some(fn_name)
472 } else {
473 item_impl.items.push(parse_quote! {
474 fn moving_state(
475 _mstate: <#target_path as ::pgrx::aggregate::Aggregate>::MovingState,
476 _v: Self::Args,
477 _fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
478 ) -> <#target_path as ::pgrx::aggregate::Aggregate>::MovingState {
479 unimplemented!("Call to moving_state on an aggregate which does not support it.")
480 }
481 });
482 None
483 };
484
485 let fn_moving_state_inverse =
486 get_impl_func_by_name(&item_impl_snapshot, "moving_state_inverse");
487 let fn_moving_state_inverse_name = if let Some(found) = fn_moving_state_inverse {
488 let fn_name = Ident::new(
489 &format!("{}_moving_state_inverse", snake_case_target_ident),
490 found.sig.ident.span(),
491 );
492 let pg_extern_attr = pg_extern_attr(found);
493 pg_externs.push(parse_quote! {
494 #[allow(non_snake_case, clippy::too_many_arguments)]
495 #pg_extern_attr
496 fn #fn_name(
497 mstate: #type_moving_state,
498 #(#args_with_names),*,
499 fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
500 ) -> #type_moving_state {
501 unsafe {
502 <#target_path as ::pgrx::aggregate::Aggregate>::in_memory_context(
503 fcinfo,
504 move |_context| <#target_path as ::pgrx::aggregate::Aggregate>::moving_state_inverse(mstate, (#(#arg_names),*), fcinfo)
505 )
506 }
507 }
508 });
509 Some(fn_name)
510 } else {
511 item_impl.items.push(parse_quote! {
512 fn moving_state_inverse(
513 _mstate: #type_moving_state,
514 _v: Self::Args,
515 _fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
516 ) -> #type_moving_state {
517 unimplemented!("Call to moving_state on an aggregate which does not support it.")
518 }
519 });
520 None
521 };
522
523 let fn_moving_finalize = get_impl_func_by_name(&item_impl_snapshot, "moving_finalize");
524 let fn_moving_finalize_name = if let Some(found) = fn_moving_finalize {
525 let fn_name = Ident::new(
526 &format!("{}_moving_finalize", snake_case_target_ident),
527 found.sig.ident.span(),
528 );
529 let pg_extern_attr = pg_extern_attr(found);
530 let maybe_comma: Option<syn::Token![,]> =
531 if !direct_args_with_names.is_empty() { Some(parse_quote! {,}) } else { None };
532
533 pg_externs.push(parse_quote! {
534 #[allow(non_snake_case, clippy::too_many_arguments)]
535 #pg_extern_attr
536 fn #fn_name(mstate: #type_moving_state, #(#direct_args_with_names),* #maybe_comma fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
537 unsafe {
538 <#target_path as ::pgrx::aggregate::Aggregate>::in_memory_context(
539 fcinfo,
540 move |_context| <#target_path as ::pgrx::aggregate::Aggregate>::moving_finalize(mstate, (#(#direct_arg_names),*), fcinfo)
541 )
542 }
543 }
544 });
545 Some(fn_name)
546 } else {
547 item_impl.items.push(parse_quote! {
548 fn moving_finalize(_mstate: Self::MovingState, direct_args: Self::OrderedSetArgs, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Self::Finalize {
549 unimplemented!("Call to moving_finalize on an aggregate which does not support it.")
550 }
551 });
552 None
553 };
554
555 Ok(CodeEnrichment(Self {
556 item_impl,
557 target_ident,
558 pg_externs,
559 name,
560 type_args: type_args_value,
561 type_ordered_set_args: type_ordered_set_args_value,
562 type_moving_state: type_moving_state_value,
563 type_stype,
564 const_parallel: get_impl_const_by_name(&item_impl_snapshot, "PARALLEL")
565 .map(|x| x.expr.clone()),
566 const_finalize_modify: get_impl_const_by_name(&item_impl_snapshot, "FINALIZE_MODIFY")
567 .map(|x| x.expr.clone()),
568 const_moving_finalize_modify: get_impl_const_by_name(
569 &item_impl_snapshot,
570 "MOVING_FINALIZE_MODIFY",
571 )
572 .map(|x| x.expr.clone()),
573 const_initial_condition: get_impl_const_by_name(
574 &item_impl_snapshot,
575 "INITIAL_CONDITION",
576 )
577 .and_then(|e| get_const_litstr(e).transpose())
578 .transpose()?,
579 const_ordered_set: get_impl_const_by_name(&item_impl_snapshot, "ORDERED_SET")
580 .and_then(get_const_litbool)
581 .unwrap_or(false),
582 const_sort_operator: get_impl_const_by_name(&item_impl_snapshot, "SORT_OPERATOR")
583 .and_then(|e| get_const_litstr(e).transpose())
584 .transpose()?,
585 const_moving_intial_condition: get_impl_const_by_name(
586 &item_impl_snapshot,
587 "MOVING_INITIAL_CONDITION",
588 )
589 .and_then(|e| get_const_litstr(e).transpose())
590 .transpose()?,
591 fn_state: fn_state_name,
592 fn_finalize: fn_finalize_name,
593 fn_combine: fn_combine_name,
594 fn_serial: fn_serial_name,
595 fn_deserial: fn_deserial_name,
596 fn_moving_state: fn_moving_state_name,
597 fn_moving_state_inverse: fn_moving_state_inverse_name,
598 fn_moving_finalize: fn_moving_finalize_name,
599 hypothetical: if let Some(value) =
600 get_impl_const_by_name(&item_impl_snapshot, "HYPOTHETICAL")
601 {
602 match &value.expr {
603 syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
604 syn::Lit::Bool(lit) => lit.value,
605 _ => return Err(syn::Error::new(value.span(), "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.")),
606 },
607 _ => return Err(syn::Error::new(value.span(), "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.")),
608 }
609 } else {
610 false
611 },
612 to_sql_config,
613 }))
614 }
615}
616
617impl ToEntityGraphTokens for PgAggregate {
618 fn to_entity_graph_tokens(&self) -> TokenStream2 {
619 let target_ident = &self.target_ident;
620 let snake_case_target_ident =
621 Ident::new(&target_ident.to_string().to_case(Case::Snake), target_ident.span());
622 let sql_graph_entity_fn_name = syn::Ident::new(
623 &format!("__pgrx_internals_aggregate_{}", snake_case_target_ident),
624 target_ident.span(),
625 );
626
627 let name = &self.name;
628 let type_args_iter = &self.type_args.entity_tokens();
629 let type_order_by_args_iter = self.type_ordered_set_args.iter().map(|x| x.entity_tokens());
630
631 let type_moving_state_entity_tokens =
632 self.type_moving_state.clone().map(|v| v.entity_tokens());
633 let type_moving_state_entity_tokens_iter = type_moving_state_entity_tokens.iter();
634 let type_stype = self.type_stype.entity_tokens();
635 let const_ordered_set = self.const_ordered_set;
636 let const_parallel_iter = self.const_parallel.iter();
637 let const_finalize_modify_iter = self.const_finalize_modify.iter();
638 let const_moving_finalize_modify_iter = self.const_moving_finalize_modify.iter();
639 let const_initial_condition_iter = self.const_initial_condition.iter();
640 let const_sort_operator_iter = self.const_sort_operator.iter();
641 let const_moving_intial_condition_iter = self.const_moving_intial_condition.iter();
642 let hypothetical = self.hypothetical;
643 let fn_state = &self.fn_state;
644 let fn_finalize_iter = self.fn_finalize.iter();
645 let fn_combine_iter = self.fn_combine.iter();
646 let fn_serial_iter = self.fn_serial.iter();
647 let fn_deserial_iter = self.fn_deserial.iter();
648 let fn_moving_state_iter = self.fn_moving_state.iter();
649 let fn_moving_state_inverse_iter = self.fn_moving_state_inverse.iter();
650 let fn_moving_finalize_iter = self.fn_moving_finalize.iter();
651 let to_sql_config = &self.to_sql_config;
652
653 quote! {
654 #[no_mangle]
655 #[doc(hidden)]
656 #[allow(unknown_lints, clippy::no_mangle_with_rust_abi)]
657 pub extern "Rust" fn #sql_graph_entity_fn_name() -> ::pgrx::pgrx_sql_entity_graph::SqlGraphEntity {
658 let submission = ::pgrx::pgrx_sql_entity_graph::PgAggregateEntity {
659 full_path: ::core::any::type_name::<#target_ident>(),
660 module_path: module_path!(),
661 file: file!(),
662 line: line!(),
663 name: #name,
664 ordered_set: #const_ordered_set,
665 ty_id: ::core::any::TypeId::of::<#target_ident>(),
666 args: #type_args_iter,
667 direct_args: None #( .unwrap_or(Some(#type_order_by_args_iter)) )*,
668 stype: #type_stype,
669 sfunc: stringify!(#fn_state),
670 combinefunc: None #( .unwrap_or(Some(stringify!(#fn_combine_iter))) )*,
671 finalfunc: None #( .unwrap_or(Some(stringify!(#fn_finalize_iter))) )*,
672 finalfunc_modify: None #( .unwrap_or(#const_finalize_modify_iter) )*,
673 initcond: None #( .unwrap_or(Some(#const_initial_condition_iter)) )*,
674 serialfunc: None #( .unwrap_or(Some(stringify!(#fn_serial_iter))) )*,
675 deserialfunc: None #( .unwrap_or(Some(stringify!(#fn_deserial_iter))) )*,
676 msfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_state_iter))) )*,
677 minvfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_state_inverse_iter))) )*,
678 mstype: None #( .unwrap_or(Some(#type_moving_state_entity_tokens_iter)) )*,
679 mfinalfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_finalize_iter))) )*,
680 mfinalfunc_modify: None #( .unwrap_or(#const_moving_finalize_modify_iter) )*,
681 minitcond: None #( .unwrap_or(Some(#const_moving_intial_condition_iter)) )*,
682 sortop: None #( .unwrap_or(Some(#const_sort_operator_iter)) )*,
683 parallel: None #( .unwrap_or(#const_parallel_iter) )*,
684 hypothetical: #hypothetical,
685 to_sql_config: #to_sql_config,
686 };
687 ::pgrx::pgrx_sql_entity_graph::SqlGraphEntity::Aggregate(submission)
688 }
689 }
690 }
691}
692
693impl ToRustCodeTokens for PgAggregate {
694 fn to_rust_code_tokens(&self) -> TokenStream2 {
695 let impl_item = &self.item_impl;
696 let pg_externs = self.pg_externs.iter();
697 quote! {
698 #impl_item
699 #(#pg_externs)*
700 }
701 }
702}
703
704impl Parse for CodeEnrichment<PgAggregate> {
705 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
706 PgAggregate::new(input.parse()?)
707 }
708}
709
710fn get_target_ident(path: &Path) -> Result<Ident, syn::Error> {
711 let last = path.segments.last().ok_or_else(|| {
712 syn::Error::new(
713 path.span(),
714 "`#[pg_aggregate]` only works with types whose path have a final segment.",
715 )
716 })?;
717 Ok(last.ident.clone())
718}
719
720fn get_target_path(item_impl: &ItemImpl) -> Result<Path, syn::Error> {
721 let target_ident = match &*item_impl.self_ty {
722 syn::Type::Path(ref type_path) => {
723 let last_segment = type_path.path.segments.last().ok_or_else(|| {
724 syn::Error::new(
725 type_path.span(),
726 "`#[pg_aggregate]` only works with types whose path have a final segment.",
727 )
728 })?;
729 if last_segment.ident == "PgVarlena" {
730 match &last_segment.arguments {
731 syn::PathArguments::AngleBracketed(angled) => {
732 let first = angled.args.first().ok_or_else(|| syn::Error::new(
733 type_path.span(),
734 "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
735 ))?;
736 match &first {
737 syn::GenericArgument::Type(Type::Path(ty_path)) => ty_path.path.clone(),
738 _ => return Err(syn::Error::new(
739 type_path.span(),
740 "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type path contained.",
741 )),
742 }
743 },
744 _ => return Err(syn::Error::new(
745 type_path.span(),
746 "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
747 )),
748 }
749 } else {
750 type_path.path.clone()
751 }
752 }
753 something_else => {
754 return Err(syn::Error::new(
755 something_else.span(),
756 "`#[pg_aggregate]` only works with types.",
757 ))
758 }
759 };
760 Ok(target_ident)
761}
762
763fn pg_extern_attr(item: &ImplItemFn) -> syn::Attribute {
764 let mut found = None;
765 for attr in item.attrs.iter() {
766 match attr.path().segments.last() {
767 Some(segment) if segment.ident == "pgrx" => {
768 found = Some(attr);
769 break;
770 }
771 _ => (),
772 };
773 }
774
775 let attrs = if let Some(attr) = found {
776 let parser = Punctuated::<super::pg_extern::Attribute, syn::Token![,]>::parse_terminated;
777 let attrs = attr.parse_args_with(parser);
778 attrs.ok()
779 } else {
780 None
781 };
782
783 match attrs {
784 Some(args) => parse_quote! {
785 #[::pgrx::pg_extern(#args)]
786 },
787 None => parse_quote! {
788 #[::pgrx::pg_extern]
789 },
790 }
791}
792
793fn get_impl_type_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemType> {
794 let mut needle = None;
795 for impl_item_type in item_impl.items.iter().filter_map(|impl_item| match impl_item {
796 syn::ImplItem::Type(iitype) => Some(iitype),
797 _ => None,
798 }) {
799 let ident_string = impl_item_type.ident.to_string();
800 if ident_string == name {
801 needle = Some(impl_item_type);
802 }
803 }
804 needle
805}
806
807fn get_impl_func_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemFn> {
808 let mut needle = None;
809 for impl_item_fn in item_impl.items.iter().filter_map(|impl_item| match impl_item {
810 syn::ImplItem::Fn(iifn) => Some(iifn),
811 _ => None,
812 }) {
813 let ident_string = impl_item_fn.sig.ident.to_string();
814 if ident_string == name {
815 needle = Some(impl_item_fn);
816 }
817 }
818 needle
819}
820
821fn get_impl_const_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemConst> {
822 let mut needle = None;
823 for impl_item_const in item_impl.items.iter().filter_map(|impl_item| match impl_item {
824 syn::ImplItem::Const(iiconst) => Some(iiconst),
825 _ => None,
826 }) {
827 let ident_string = impl_item_const.ident.to_string();
828 if ident_string == name {
829 needle = Some(impl_item_const);
830 }
831 }
832 needle
833}
834
835fn get_const_litbool(item: &ImplItemConst) -> Option<bool> {
836 match &item.expr {
837 syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
838 syn::Lit::Bool(lit) => Some(lit.value()),
839 _ => None,
840 },
841 _ => None,
842 }
843}
844
845fn get_const_litstr(item: &ImplItemConst) -> syn::Result<Option<String>> {
846 match &item.expr {
847 syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
848 syn::Lit::Str(lit) => Ok(Some(lit.value())),
849 _ => Ok(None),
850 },
851 syn::Expr::Call(expr_call) => match &*expr_call.func {
852 syn::Expr::Path(expr_path) => {
853 let Some(last) = expr_path.path.segments.last() else {
854 return Ok(None);
855 };
856 if last.ident == "Some" {
857 match expr_call.args.first() {
858 Some(syn::Expr::Lit(expr_lit)) => match &expr_lit.lit {
859 syn::Lit::Str(lit) => Ok(Some(lit.value())),
860 _ => Ok(None),
861 },
862 _ => Ok(None),
863 }
864 } else {
865 Ok(None)
866 }
867 }
868 _ => Ok(None),
869 },
870 ex => Err(syn::Error::new(ex.span(), "")),
871 }
872}
873
874fn remap_self_to_target(ty: &mut syn::Type, target: &syn::Ident) {
875 if let Type::Path(ref mut ty_path) = ty {
876 for segment in ty_path.path.segments.iter_mut() {
877 if segment.ident == "Self" {
878 segment.ident = target.clone()
879 }
880 use syn::{GenericArgument, PathArguments};
881 match segment.arguments {
882 PathArguments::AngleBracketed(ref mut angle_args) => {
883 for arg in angle_args.args.iter_mut() {
884 if let GenericArgument::Type(inner_ty) = arg {
885 remap_self_to_target(inner_ty, target)
886 }
887 }
888 }
889 PathArguments::Parenthesized(_) => (),
890 PathArguments::None => (),
891 }
892 }
893 }
894}
895
896fn get_pgrx_attr_macro(attr_name: impl AsRef<str>, ty: &syn::Type) -> Option<TokenStream2> {
897 match &ty {
898 syn::Type::Macro(ty_macro) => {
899 let mut found_pgrx = false;
900 let mut found_attr = false;
901 for (idx, segment) in ty_macro.mac.path.segments.iter().enumerate() {
903 match segment.ident.to_string().as_str() {
904 "pgrx" if idx == 0 => found_pgrx = true,
905 attr if attr == attr_name.as_ref() => found_attr = true,
906 _ => (),
907 }
908 }
909 if (found_pgrx || ty_macro.mac.path.segments.len() == 1) && found_attr {
910 Some(ty_macro.mac.tokens.clone())
911 } else {
912 None
913 }
914 }
915 _ => None,
916 }
917}
918
#[cfg(test)]
mod tests {
    use super::PgAggregate;
    use eyre::Result;
    use quote::ToTokens;
    use syn::{parse_quote, ItemImpl};

    /// Only the mandatory pieces (`State`, `Args`, `NAME`, `state`) are present.
    /// Exactly one wrapper, for the state transition function, should be generated.
    #[test]
    fn agg_required_only() -> Result<()> {
        let input: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;
                const NAME: &'static str = "DEMO";

                fn state(mut current: Self::State, arg: Self::Args) -> Self::State {
                    todo!()
                }
            }
        };

        let enriched = PgAggregate::new(input)?;
        let externs = &enriched.0.pg_externs;
        assert_eq!(externs.len(), 1);
        assert_eq!(externs[0].sig.ident.to_string(), "demo_agg_state");

        // Rendering the enriched impl must not panic.
        let _ = enriched.to_token_stream();
        Ok(())
    }

    /// Every optional function is implemented, so all eight wrappers should be generated,
    /// starting with the state function.
    #[test]
    fn agg_all_options() -> Result<()> {
        let input: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;
                type OrderBy = i32;
                type MovingState = i32;

                const NAME: &'static str = "DEMO";

                const PARALLEL: Option<ParallelOption> = Some(ParallelOption::Safe);
                const FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const MOVING_FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const SORT_OPERATOR: Option<&'static str> = Some("sortop");
                const MOVING_INITIAL_CONDITION: Option<&'static str> = Some("1,1");
                const HYPOTHETICAL: bool = true;

                fn state(current: Self::State, v: Self::Args) -> Self::State {
                    todo!()
                }

                fn finalize(current: Self::State) -> Self::Finalize {
                    todo!()
                }

                fn combine(current: Self::State, _other: Self::State) -> Self::State {
                    todo!()
                }

                fn serial(current: Self::State) -> Vec<u8> {
                    todo!()
                }

                fn deserial(current: Self::State, _buf: Vec<u8>, _internal: PgBox<Self>) -> PgBox<Self> {
                    todo!()
                }

                fn moving_state(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_state_inverse(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_finalize(_mstate: Self::MovingState) -> Self::Finalize {
                    todo!()
                }
            }
        };

        let enriched = PgAggregate::new(input)?;
        let externs = &enriched.0.pg_externs;
        assert_eq!(externs.len(), 8);
        assert_eq!(externs[0].sig.ident.to_string(), "demo_agg_state");

        // Rendering the enriched impl must not panic.
        let _ = enriched.to_token_stream();
        Ok(())
    }

    /// An empty impl lacks the required `Args` type and `state` function and must be rejected.
    #[test]
    fn agg_missing_required() -> Result<()> {
        let input: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for IntegerAvgState {
            }
        };
        assert!(PgAggregate::new(input).is_err());
        Ok(())
    }
}