pgrx_sql_entity_graph/aggregate/
mod.rs

//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
//LICENSE
//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
//LICENSE
//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
//LICENSE
//LICENSE All rights reserved.
//LICENSE
//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
/*!

`#[pg_aggregate]`-related macro expansion for Rust-to-SQL translation.

> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
> to the `pgrx` framework and is highly subject to change between versions. While you may use this, please do so with caution.

*/
19mod aggregate_type;
20pub(crate) mod entity;
21mod options;
22
23pub use aggregate_type::{AggregateType, AggregateTypeList};
24pub use options::{FinalizeModify, ParallelOption};
25
26use crate::enrich::CodeEnrichment;
27use crate::enrich::ToEntityGraphTokens;
28use crate::enrich::ToRustCodeTokens;
29use convert_case::{Case, Casing};
30use proc_macro2::{Ident, Span, TokenStream as TokenStream2};
31use quote::quote;
32use syn::parse::{Parse, ParseStream};
33use syn::punctuated::Punctuated;
34use syn::spanned::Spanned;
35use syn::{
36    parse_quote, Expr, ImplItemConst, ImplItemFn, ImplItemType, ItemFn, ItemImpl, Path, Type,
37};
38
39use crate::ToSqlConfig;
40
41use super::UsedType;
42
/// Default identifiers for positional aggregate arguments.
///
/// When an entry in `Args` or `OrderedSetArgs` carries no explicit name, the generated
/// wrapper function names it after its position using this table. Consequently at most
/// 32 arguments (one name per position) are supported.
const ARG_NAMES: [&str; 32] = [
    "arg_one", "arg_two", "arg_three", "arg_four",
    "arg_five", "arg_six", "arg_seven", "arg_eight",
    "arg_nine", "arg_ten", "arg_eleven", "arg_twelve",
    "arg_thirteen", "arg_fourteen", "arg_fifteen", "arg_sixteen",
    "arg_seventeen", "arg_eighteen", "arg_nineteen", "arg_twenty",
    "arg_twenty_one", "arg_twenty_two", "arg_twenty_three", "arg_twenty_four",
    "arg_twenty_five", "arg_twenty_six", "arg_twenty_seven", "arg_twenty_eight",
    "arg_twenty_nine", "arg_thirty", "arg_thirty_one", "arg_thirty_two",
];
78
79/** A parsed `#[pg_aggregate]` item.
80*/
81#[derive(Debug, Clone)]
82pub struct PgAggregate {
83    item_impl: ItemImpl,
84    name: Expr,
85    target_ident: Ident,
86    pg_externs: Vec<ItemFn>,
87    // Note these should not be considered *writable*, they're snapshots from construction.
88    type_args: AggregateTypeList,
89    type_ordered_set_args: Option<AggregateTypeList>,
90    type_moving_state: Option<UsedType>,
91    type_stype: AggregateType,
92    const_ordered_set: bool,
93    const_parallel: Option<syn::Expr>,
94    const_finalize_modify: Option<syn::Expr>,
95    const_moving_finalize_modify: Option<syn::Expr>,
96    const_initial_condition: Option<String>,
97    const_sort_operator: Option<String>,
98    const_moving_intial_condition: Option<String>,
99    fn_state: Ident,
100    fn_finalize: Option<Ident>,
101    fn_combine: Option<Ident>,
102    fn_serial: Option<Ident>,
103    fn_deserial: Option<Ident>,
104    fn_moving_state: Option<Ident>,
105    fn_moving_state_inverse: Option<Ident>,
106    fn_moving_finalize: Option<Ident>,
107    hypothetical: bool,
108    to_sql_config: ToSqlConfig,
109}
110
111impl PgAggregate {
112    pub fn new(mut item_impl: ItemImpl) -> Result<CodeEnrichment<Self>, syn::Error> {
113        let to_sql_config =
114            ToSqlConfig::from_attributes(item_impl.attrs.as_slice())?.unwrap_or_default();
115        let target_path = get_target_path(&item_impl)?;
116        let target_ident = get_target_ident(&target_path)?;
117
118        let snake_case_target_ident =
119            Ident::new(&target_ident.to_string().to_case(Case::Snake), target_ident.span());
120        crate::ident_is_acceptable_to_postgres(&snake_case_target_ident)?;
121
122        let mut pg_externs = Vec::default();
123        // We want to avoid having multiple borrows, so we take a snapshot to scan from,
124        // and mutate the actual one.
125        let item_impl_snapshot = item_impl.clone();
126
127        if let Some((_, ref path, _)) = item_impl.trait_ {
128            // TODO: Consider checking the path if there is more than one segment to make sure it's pgrx.
129            if let Some(last) = path.segments.last() {
130                if last.ident != "Aggregate" {
131                    return Err(syn::Error::new(
132                        last.ident.span(),
133                        "`#[pg_aggregate]` only works with the `Aggregate` trait.",
134                    ));
135                }
136            }
137        }
138
139        let name = match get_impl_const_by_name(&item_impl_snapshot, "NAME") {
140            Some(item_const) => match &item_const.expr {
141                syn::Expr::Lit(ref expr) => {
142                    if let syn::Lit::Str(_) = &expr.lit {
143                        item_const.expr.clone()
144                    } else {
145                        return Err(syn::Error::new(
146                            expr.span(),
147                            "`NAME` must be a `&'static str` for Aggregate implementations.",
148                        ));
149                    }
150                }
151                e => {
152                    return Err(syn::Error::new(
153                        e.span(),
154                        "`NAME` must be a `&'static str` for Aggregate implementations.",
155                    ));
156                }
157            },
158            None => {
159                item_impl.items.push(parse_quote! {
160                    const NAME: &'static str = stringify!(Self);
161                });
162                parse_quote! {
163                    stringify!(#target_ident)
164                }
165            }
166        };
167
168        // `State` is an optional value, we default to `Self`.
169        let type_state = get_impl_type_by_name(&item_impl_snapshot, "State");
170        let _type_state_value = type_state.map(|v| v.ty.clone());
171
172        let type_state_without_self = if let Some(inner) = type_state {
173            let mut remapped = inner.ty.clone();
174            remap_self_to_target(&mut remapped, &target_ident);
175            remapped
176        } else {
177            item_impl.items.push(parse_quote! {
178                type State = Self;
179            });
180            let mut remapped = parse_quote!(Self);
181            remap_self_to_target(&mut remapped, &target_ident);
182            remapped
183        };
184        let type_stype = AggregateType {
185            used_ty: UsedType::new(type_state_without_self.clone())?,
186            name: Some("state".into()),
187        };
188
189        // `MovingState` is an optional value, we default to nothing.
190        let impl_type_moving_state = get_impl_type_by_name(&item_impl_snapshot, "MovingState");
191        let type_moving_state;
192        let type_moving_state_value = if let Some(impl_type_moving_state) = impl_type_moving_state {
193            type_moving_state = impl_type_moving_state.ty.clone();
194            Some(UsedType::new(type_moving_state.clone())?)
195        } else {
196            item_impl.items.push(parse_quote! {
197                type MovingState = ();
198            });
199            type_moving_state = parse_quote! { () };
200            None
201        };
202
203        // `OrderBy` is an optional value, we default to nothing.
204        let type_ordered_set_args = get_impl_type_by_name(&item_impl_snapshot, "OrderedSetArgs");
205        let type_ordered_set_args_value =
206            type_ordered_set_args.map(|v| AggregateTypeList::new(v.ty.clone())).transpose()?;
207        if type_ordered_set_args.is_none() {
208            item_impl.items.push(parse_quote! {
209                type OrderedSetArgs = ();
210            })
211        }
212        let (direct_args_with_names, direct_arg_names) = if let Some(ref order_by_direct_args) =
213            type_ordered_set_args_value
214        {
215            let direct_args = order_by_direct_args
216                .found
217                .iter()
218                .map(|x| {
219                    (x.name.clone(), x.used_ty.resolved_ty.clone(), x.used_ty.original_ty.clone())
220                })
221                .collect::<Vec<_>>();
222            let direct_arg_names = ARG_NAMES[0..direct_args.len()]
223                .iter()
224                .zip(direct_args.iter())
225                .map(|(default_name, (custom_name, _ty, _orig))| {
226                    Ident::new(
227                        &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
228                        Span::mixed_site(),
229                    )
230                })
231                .collect::<Vec<_>>();
232            let direct_args_with_names = direct_args
233                .iter()
234                .zip(direct_arg_names.iter())
235                .map(|(arg, name)| {
236                    let arg_ty = &arg.2; // original_type
237                    parse_quote! {
238                        #name: #arg_ty
239                    }
240                })
241                .collect::<Vec<syn::FnArg>>();
242            (direct_args_with_names, direct_arg_names)
243        } else {
244            (Vec::default(), Vec::default())
245        };
246
247        // `Args` is an optional value, we default to nothing.
248        let type_args = get_impl_type_by_name(&item_impl_snapshot, "Args").ok_or_else(|| {
249            syn::Error::new(
250                item_impl_snapshot.span(),
251                "`#[pg_aggregate]` requires the `Args` type defined.",
252            )
253        })?;
254        let type_args_value = AggregateTypeList::new(type_args.ty.clone())?;
255        let args = type_args_value
256            .found
257            .iter()
258            .map(|x| (x.name.clone(), x.used_ty.original_ty.clone()))
259            .collect::<Vec<_>>();
260        let arg_names = ARG_NAMES[0..args.len()]
261            .iter()
262            .zip(args.iter())
263            .map(|(default_name, (custom_name, ty))| {
264                Ident::new(
265                    &custom_name.clone().unwrap_or_else(|| default_name.to_string()),
266                    ty.span(),
267                )
268            })
269            .collect::<Vec<_>>();
270        let args_with_names = args
271            .iter()
272            .zip(arg_names.iter())
273            .map(|(arg, name)| {
274                let arg_ty = &arg.1;
275                quote! {
276                    #name: #arg_ty
277                }
278            })
279            .collect::<Vec<_>>();
280
281        // `Finalize` is an optional value, we default to nothing.
282        let impl_type_finalize = get_impl_type_by_name(&item_impl_snapshot, "Finalize");
283        let type_finalize: syn::Type = if let Some(type_finalize) = impl_type_finalize {
284            type_finalize.ty.clone()
285        } else {
286            item_impl.items.push(parse_quote! {
287                type Finalize = ();
288            });
289            parse_quote! { () }
290        };
291
292        let fn_state = get_impl_func_by_name(&item_impl_snapshot, "state");
293
294        let fn_state_name = if let Some(found) = fn_state {
295            let fn_name =
296                Ident::new(&format!("{}_state", snake_case_target_ident), found.sig.ident.span());
297            let pg_extern_attr = pg_extern_attr(found);
298
299            pg_externs.push(parse_quote! {
300                #[allow(non_snake_case, clippy::too_many_arguments)]
301                #pg_extern_attr
302                fn #fn_name(this: #type_state_without_self, #(#args_with_names),*, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
303                    unsafe {
304                        <#target_path as ::pgrx::aggregate::Aggregate>::in_memory_context(
305                            fcinfo,
306                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate>::state(this, (#(#arg_names),*), fcinfo)
307                        )
308                    }
309                }
310            });
311            fn_name
312        } else {
313            return Err(syn::Error::new(
314                item_impl.span(),
315                "Aggregate implementation must include state function.",
316            ));
317        };
318
319        let fn_combine = get_impl_func_by_name(&item_impl_snapshot, "combine");
320        let fn_combine_name = if let Some(found) = fn_combine {
321            let fn_name =
322                Ident::new(&format!("{}_combine", snake_case_target_ident), found.sig.ident.span());
323            let pg_extern_attr = pg_extern_attr(found);
324            pg_externs.push(parse_quote! {
325                #[allow(non_snake_case, clippy::too_many_arguments)]
326                #pg_extern_attr
327                fn #fn_name(this: #type_state_without_self, v: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
328                    unsafe {
329                        <#target_path as ::pgrx::aggregate::Aggregate>::in_memory_context(
330                            fcinfo,
331                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate>::combine(this, v, fcinfo)
332                        )
333                    }
334                }
335            });
336            Some(fn_name)
337        } else {
338            item_impl.items.push(parse_quote! {
339                fn combine(current: #type_state_without_self, _other: #type_state_without_self, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_state_without_self {
340                    unimplemented!("Call to combine on an aggregate which does not support it.")
341                }
342            });
343            None
344        };
345
346        let fn_finalize = get_impl_func_by_name(&item_impl_snapshot, "finalize");
347        let fn_finalize_name = if let Some(found) = fn_finalize {
348            let fn_name = Ident::new(
349                &format!("{}_finalize", snake_case_target_ident),
350                found.sig.ident.span(),
351            );
352            let pg_extern_attr = pg_extern_attr(found);
353
354            if !direct_args_with_names.is_empty() {
355                pg_externs.push(parse_quote! {
356                    #[allow(non_snake_case, clippy::too_many_arguments)]
357                    #pg_extern_attr
358                    fn #fn_name(this: #type_state_without_self, #(#direct_args_with_names),*, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
359                        unsafe {
360                            <#target_path as ::pgrx::aggregate::Aggregate>::in_memory_context(
361                                fcinfo,
362                                move |_context| <#target_path as ::pgrx::aggregate::Aggregate>::finalize(this, (#(#direct_arg_names),*), fcinfo)
363                            )
364                        }
365                    }
366                });
367            } else {
368                pg_externs.push(parse_quote! {
369                    #[allow(non_snake_case, clippy::too_many_arguments)]
370                    #pg_extern_attr
371                    fn #fn_name(this: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
372                        unsafe {
373                            <#target_path as ::pgrx::aggregate::Aggregate>::in_memory_context(
374                                fcinfo,
375                                move |_context| <#target_path as ::pgrx::aggregate::Aggregate>::finalize(this, (), fcinfo)
376                            )
377                        }
378                    }
379                });
380            };
381            Some(fn_name)
382        } else {
383            item_impl.items.push(parse_quote! {
384                fn finalize(current: Self::State, direct_args: Self::OrderedSetArgs, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
385                    unimplemented!("Call to finalize on an aggregate which does not support it.")
386                }
387            });
388            None
389        };
390
391        let fn_serial = get_impl_func_by_name(&item_impl_snapshot, "serial");
392        let fn_serial_name = if let Some(found) = fn_serial {
393            let fn_name =
394                Ident::new(&format!("{}_serial", snake_case_target_ident), found.sig.ident.span());
395            let pg_extern_attr = pg_extern_attr(found);
396            pg_externs.push(parse_quote! {
397                #[allow(non_snake_case, clippy::too_many_arguments)]
398                #pg_extern_attr
399                fn #fn_name(this: #type_state_without_self, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Vec<u8> {
400                    unsafe {
401                        <#target_path as ::pgrx::aggregate::Aggregate>::in_memory_context(
402                            fcinfo,
403                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate>::serial(this, fcinfo)
404                        )
405                    }
406                }
407            });
408            Some(fn_name)
409        } else {
410            item_impl.items.push(parse_quote! {
411                fn serial(current: #type_state_without_self, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Vec<u8> {
412                    unimplemented!("Call to serial on an aggregate which does not support it.")
413                }
414            });
415            None
416        };
417
418        let fn_deserial = get_impl_func_by_name(&item_impl_snapshot, "deserial");
419        let fn_deserial_name = if let Some(found) = fn_deserial {
420            let fn_name = Ident::new(
421                &format!("{}_deserial", snake_case_target_ident),
422                found.sig.ident.span(),
423            );
424            let pg_extern_attr = pg_extern_attr(found);
425            pg_externs.push(parse_quote! {
426                #[allow(non_snake_case, clippy::too_many_arguments)]
427                #pg_extern_attr
428                fn #fn_name(this: #type_state_without_self, buf: Vec<u8>, internal: ::pgrx::pgbox::PgBox<#type_state_without_self>, fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> ::pgrx::pgbox::PgBox<#type_state_without_self> {
429                    unsafe {
430                        <#target_path as ::pgrx::aggregate::Aggregate>::in_memory_context(
431                            fcinfo,
432                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate>::deserial(this, buf, internal, fcinfo)
433                        )
434                    }
435                }
436            });
437            Some(fn_name)
438        } else {
439            item_impl.items.push(parse_quote! {
440                fn deserial(current: #type_state_without_self, _buf: Vec<u8>, _internal: ::pgrx::pgbox::PgBox<#type_state_without_self>, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> ::pgrx::pgbox::PgBox<#type_state_without_self> {
441                    unimplemented!("Call to deserial on an aggregate which does not support it.")
442                }
443            });
444            None
445        };
446
447        let fn_moving_state = get_impl_func_by_name(&item_impl_snapshot, "moving_state");
448        let fn_moving_state_name = if let Some(found) = fn_moving_state {
449            let fn_name = Ident::new(
450                &format!("{}_moving_state", snake_case_target_ident),
451                found.sig.ident.span(),
452            );
453            let pg_extern_attr = pg_extern_attr(found);
454
455            pg_externs.push(parse_quote! {
456                #[allow(non_snake_case, clippy::too_many_arguments)]
457                #pg_extern_attr
458                fn #fn_name(
459                    mstate: #type_moving_state,
460                    #(#args_with_names),*,
461                    fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
462                ) -> #type_moving_state {
463                    unsafe {
464                        <#target_path as ::pgrx::aggregate::Aggregate>::in_memory_context(
465                            fcinfo,
466                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate>::moving_state(mstate, (#(#arg_names),*), fcinfo)
467                        )
468                    }
469                }
470            });
471            Some(fn_name)
472        } else {
473            item_impl.items.push(parse_quote! {
474                fn moving_state(
475                    _mstate: <#target_path as ::pgrx::aggregate::Aggregate>::MovingState,
476                    _v: Self::Args,
477                    _fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
478                ) -> <#target_path as ::pgrx::aggregate::Aggregate>::MovingState {
479                    unimplemented!("Call to moving_state on an aggregate which does not support it.")
480                }
481            });
482            None
483        };
484
485        let fn_moving_state_inverse =
486            get_impl_func_by_name(&item_impl_snapshot, "moving_state_inverse");
487        let fn_moving_state_inverse_name = if let Some(found) = fn_moving_state_inverse {
488            let fn_name = Ident::new(
489                &format!("{}_moving_state_inverse", snake_case_target_ident),
490                found.sig.ident.span(),
491            );
492            let pg_extern_attr = pg_extern_attr(found);
493            pg_externs.push(parse_quote! {
494                #[allow(non_snake_case, clippy::too_many_arguments)]
495                #pg_extern_attr
496                fn #fn_name(
497                    mstate: #type_moving_state,
498                    #(#args_with_names),*,
499                    fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
500                ) -> #type_moving_state {
501                    unsafe {
502                        <#target_path as ::pgrx::aggregate::Aggregate>::in_memory_context(
503                            fcinfo,
504                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate>::moving_state_inverse(mstate, (#(#arg_names),*), fcinfo)
505                        )
506                    }
507                }
508            });
509            Some(fn_name)
510        } else {
511            item_impl.items.push(parse_quote! {
512                fn moving_state_inverse(
513                    _mstate: #type_moving_state,
514                    _v: Self::Args,
515                    _fcinfo: ::pgrx::pg_sys::FunctionCallInfo,
516                ) -> #type_moving_state {
517                    unimplemented!("Call to moving_state on an aggregate which does not support it.")
518                }
519            });
520            None
521        };
522
523        let fn_moving_finalize = get_impl_func_by_name(&item_impl_snapshot, "moving_finalize");
524        let fn_moving_finalize_name = if let Some(found) = fn_moving_finalize {
525            let fn_name = Ident::new(
526                &format!("{}_moving_finalize", snake_case_target_ident),
527                found.sig.ident.span(),
528            );
529            let pg_extern_attr = pg_extern_attr(found);
530            let maybe_comma: Option<syn::Token![,]> =
531                if !direct_args_with_names.is_empty() { Some(parse_quote! {,}) } else { None };
532
533            pg_externs.push(parse_quote! {
534                #[allow(non_snake_case, clippy::too_many_arguments)]
535                #pg_extern_attr
536                fn #fn_name(mstate: #type_moving_state, #(#direct_args_with_names),* #maybe_comma fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> #type_finalize {
537                    unsafe {
538                        <#target_path as ::pgrx::aggregate::Aggregate>::in_memory_context(
539                            fcinfo,
540                            move |_context| <#target_path as ::pgrx::aggregate::Aggregate>::moving_finalize(mstate, (#(#direct_arg_names),*), fcinfo)
541                        )
542                    }
543                }
544            });
545            Some(fn_name)
546        } else {
547            item_impl.items.push(parse_quote! {
548                fn moving_finalize(_mstate: Self::MovingState, direct_args: Self::OrderedSetArgs, _fcinfo: ::pgrx::pg_sys::FunctionCallInfo) -> Self::Finalize {
549                    unimplemented!("Call to moving_finalize on an aggregate which does not support it.")
550                }
551            });
552            None
553        };
554
555        Ok(CodeEnrichment(Self {
556            item_impl,
557            target_ident,
558            pg_externs,
559            name,
560            type_args: type_args_value,
561            type_ordered_set_args: type_ordered_set_args_value,
562            type_moving_state: type_moving_state_value,
563            type_stype,
564            const_parallel: get_impl_const_by_name(&item_impl_snapshot, "PARALLEL")
565                .map(|x| x.expr.clone()),
566            const_finalize_modify: get_impl_const_by_name(&item_impl_snapshot, "FINALIZE_MODIFY")
567                .map(|x| x.expr.clone()),
568            const_moving_finalize_modify: get_impl_const_by_name(
569                &item_impl_snapshot,
570                "MOVING_FINALIZE_MODIFY",
571            )
572            .map(|x| x.expr.clone()),
573            const_initial_condition: get_impl_const_by_name(
574                &item_impl_snapshot,
575                "INITIAL_CONDITION",
576            )
577            .and_then(|e| get_const_litstr(e).transpose())
578            .transpose()?,
579            const_ordered_set: get_impl_const_by_name(&item_impl_snapshot, "ORDERED_SET")
580                .and_then(get_const_litbool)
581                .unwrap_or(false),
582            const_sort_operator: get_impl_const_by_name(&item_impl_snapshot, "SORT_OPERATOR")
583                .and_then(|e| get_const_litstr(e).transpose())
584                .transpose()?,
585            const_moving_intial_condition: get_impl_const_by_name(
586                &item_impl_snapshot,
587                "MOVING_INITIAL_CONDITION",
588            )
589            .and_then(|e| get_const_litstr(e).transpose())
590            .transpose()?,
591            fn_state: fn_state_name,
592            fn_finalize: fn_finalize_name,
593            fn_combine: fn_combine_name,
594            fn_serial: fn_serial_name,
595            fn_deserial: fn_deserial_name,
596            fn_moving_state: fn_moving_state_name,
597            fn_moving_state_inverse: fn_moving_state_inverse_name,
598            fn_moving_finalize: fn_moving_finalize_name,
599            hypothetical: if let Some(value) =
600                get_impl_const_by_name(&item_impl_snapshot, "HYPOTHETICAL")
601            {
602                match &value.expr {
603                    syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
604                        syn::Lit::Bool(lit) => lit.value,
605                        _ => return Err(syn::Error::new(value.span(), "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.")),
606                    },
607                    _ => return Err(syn::Error::new(value.span(), "`#[pg_aggregate]` required the `HYPOTHETICAL` value to be a literal boolean.")),
608                }
609            } else {
610                false
611            },
612            to_sql_config,
613        }))
614    }
615}
616
617impl ToEntityGraphTokens for PgAggregate {
618    fn to_entity_graph_tokens(&self) -> TokenStream2 {
619        let target_ident = &self.target_ident;
620        let snake_case_target_ident =
621            Ident::new(&target_ident.to_string().to_case(Case::Snake), target_ident.span());
622        let sql_graph_entity_fn_name = syn::Ident::new(
623            &format!("__pgrx_internals_aggregate_{}", snake_case_target_ident),
624            target_ident.span(),
625        );
626
627        let name = &self.name;
628        let type_args_iter = &self.type_args.entity_tokens();
629        let type_order_by_args_iter = self.type_ordered_set_args.iter().map(|x| x.entity_tokens());
630
631        let type_moving_state_entity_tokens =
632            self.type_moving_state.clone().map(|v| v.entity_tokens());
633        let type_moving_state_entity_tokens_iter = type_moving_state_entity_tokens.iter();
634        let type_stype = self.type_stype.entity_tokens();
635        let const_ordered_set = self.const_ordered_set;
636        let const_parallel_iter = self.const_parallel.iter();
637        let const_finalize_modify_iter = self.const_finalize_modify.iter();
638        let const_moving_finalize_modify_iter = self.const_moving_finalize_modify.iter();
639        let const_initial_condition_iter = self.const_initial_condition.iter();
640        let const_sort_operator_iter = self.const_sort_operator.iter();
641        let const_moving_intial_condition_iter = self.const_moving_intial_condition.iter();
642        let hypothetical = self.hypothetical;
643        let fn_state = &self.fn_state;
644        let fn_finalize_iter = self.fn_finalize.iter();
645        let fn_combine_iter = self.fn_combine.iter();
646        let fn_serial_iter = self.fn_serial.iter();
647        let fn_deserial_iter = self.fn_deserial.iter();
648        let fn_moving_state_iter = self.fn_moving_state.iter();
649        let fn_moving_state_inverse_iter = self.fn_moving_state_inverse.iter();
650        let fn_moving_finalize_iter = self.fn_moving_finalize.iter();
651        let to_sql_config = &self.to_sql_config;
652
653        quote! {
654            #[no_mangle]
655            #[doc(hidden)]
656            #[allow(unknown_lints, clippy::no_mangle_with_rust_abi)]
657            pub extern "Rust" fn #sql_graph_entity_fn_name() -> ::pgrx::pgrx_sql_entity_graph::SqlGraphEntity {
658                let submission = ::pgrx::pgrx_sql_entity_graph::PgAggregateEntity {
659                    full_path: ::core::any::type_name::<#target_ident>(),
660                    module_path: module_path!(),
661                    file: file!(),
662                    line: line!(),
663                    name: #name,
664                    ordered_set: #const_ordered_set,
665                    ty_id: ::core::any::TypeId::of::<#target_ident>(),
666                    args: #type_args_iter,
667                    direct_args: None #( .unwrap_or(Some(#type_order_by_args_iter)) )*,
668                    stype: #type_stype,
669                    sfunc: stringify!(#fn_state),
670                    combinefunc: None #( .unwrap_or(Some(stringify!(#fn_combine_iter))) )*,
671                    finalfunc: None #( .unwrap_or(Some(stringify!(#fn_finalize_iter))) )*,
672                    finalfunc_modify: None #( .unwrap_or(#const_finalize_modify_iter) )*,
673                    initcond: None #( .unwrap_or(Some(#const_initial_condition_iter)) )*,
674                    serialfunc: None #( .unwrap_or(Some(stringify!(#fn_serial_iter))) )*,
675                    deserialfunc: None #( .unwrap_or(Some(stringify!(#fn_deserial_iter))) )*,
676                    msfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_state_iter))) )*,
677                    minvfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_state_inverse_iter))) )*,
678                    mstype: None #( .unwrap_or(Some(#type_moving_state_entity_tokens_iter)) )*,
679                    mfinalfunc: None #( .unwrap_or(Some(stringify!(#fn_moving_finalize_iter))) )*,
680                    mfinalfunc_modify: None #( .unwrap_or(#const_moving_finalize_modify_iter) )*,
681                    minitcond: None #( .unwrap_or(Some(#const_moving_intial_condition_iter)) )*,
682                    sortop: None #( .unwrap_or(Some(#const_sort_operator_iter)) )*,
683                    parallel: None #( .unwrap_or(#const_parallel_iter) )*,
684                    hypothetical: #hypothetical,
685                    to_sql_config: #to_sql_config,
686                };
687                ::pgrx::pgrx_sql_entity_graph::SqlGraphEntity::Aggregate(submission)
688            }
689        }
690    }
691}
692
693impl ToRustCodeTokens for PgAggregate {
694    fn to_rust_code_tokens(&self) -> TokenStream2 {
695        let impl_item = &self.item_impl;
696        let pg_externs = self.pg_externs.iter();
697        quote! {
698            #impl_item
699            #(#pg_externs)*
700        }
701    }
702}
703
704impl Parse for CodeEnrichment<PgAggregate> {
705    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
706        PgAggregate::new(input.parse()?)
707    }
708}
709
710fn get_target_ident(path: &Path) -> Result<Ident, syn::Error> {
711    let last = path.segments.last().ok_or_else(|| {
712        syn::Error::new(
713            path.span(),
714            "`#[pg_aggregate]` only works with types whose path have a final segment.",
715        )
716    })?;
717    Ok(last.ident.clone())
718}
719
720fn get_target_path(item_impl: &ItemImpl) -> Result<Path, syn::Error> {
721    let target_ident = match &*item_impl.self_ty {
722        syn::Type::Path(ref type_path) => {
723            let last_segment = type_path.path.segments.last().ok_or_else(|| {
724                syn::Error::new(
725                    type_path.span(),
726                    "`#[pg_aggregate]` only works with types whose path have a final segment.",
727                )
728            })?;
729            if last_segment.ident == "PgVarlena" {
730                match &last_segment.arguments {
731                    syn::PathArguments::AngleBracketed(angled) => {
732                        let first = angled.args.first().ok_or_else(|| syn::Error::new(
733                            type_path.span(),
734                            "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
735                        ))?;
736                        match &first {
737                            syn::GenericArgument::Type(Type::Path(ty_path)) => ty_path.path.clone(),
738                            _ => return Err(syn::Error::new(
739                                type_path.span(),
740                                "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type path contained.",
741                            )),
742                        }
743                    },
744                    _ => return Err(syn::Error::new(
745                        type_path.span(),
746                        "`#[pg_aggregate]` only works with `PgVarlena` declarations if they have a type contained.",
747                    )),
748                }
749            } else {
750                type_path.path.clone()
751            }
752        }
753        something_else => {
754            return Err(syn::Error::new(
755                something_else.span(),
756                "`#[pg_aggregate]` only works with types.",
757            ))
758        }
759    };
760    Ok(target_ident)
761}
762
763fn pg_extern_attr(item: &ImplItemFn) -> syn::Attribute {
764    let mut found = None;
765    for attr in item.attrs.iter() {
766        match attr.path().segments.last() {
767            Some(segment) if segment.ident == "pgrx" => {
768                found = Some(attr);
769                break;
770            }
771            _ => (),
772        };
773    }
774
775    let attrs = if let Some(attr) = found {
776        let parser = Punctuated::<super::pg_extern::Attribute, syn::Token![,]>::parse_terminated;
777        let attrs = attr.parse_args_with(parser);
778        attrs.ok()
779    } else {
780        None
781    };
782
783    match attrs {
784        Some(args) => parse_quote! {
785            #[::pgrx::pg_extern(#args)]
786        },
787        None => parse_quote! {
788            #[::pgrx::pg_extern]
789        },
790    }
791}
792
793fn get_impl_type_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemType> {
794    let mut needle = None;
795    for impl_item_type in item_impl.items.iter().filter_map(|impl_item| match impl_item {
796        syn::ImplItem::Type(iitype) => Some(iitype),
797        _ => None,
798    }) {
799        let ident_string = impl_item_type.ident.to_string();
800        if ident_string == name {
801            needle = Some(impl_item_type);
802        }
803    }
804    needle
805}
806
807fn get_impl_func_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemFn> {
808    let mut needle = None;
809    for impl_item_fn in item_impl.items.iter().filter_map(|impl_item| match impl_item {
810        syn::ImplItem::Fn(iifn) => Some(iifn),
811        _ => None,
812    }) {
813        let ident_string = impl_item_fn.sig.ident.to_string();
814        if ident_string == name {
815            needle = Some(impl_item_fn);
816        }
817    }
818    needle
819}
820
821fn get_impl_const_by_name<'a>(item_impl: &'a ItemImpl, name: &str) -> Option<&'a ImplItemConst> {
822    let mut needle = None;
823    for impl_item_const in item_impl.items.iter().filter_map(|impl_item| match impl_item {
824        syn::ImplItem::Const(iiconst) => Some(iiconst),
825        _ => None,
826    }) {
827        let ident_string = impl_item_const.ident.to_string();
828        if ident_string == name {
829            needle = Some(impl_item_const);
830        }
831    }
832    needle
833}
834
835fn get_const_litbool(item: &ImplItemConst) -> Option<bool> {
836    match &item.expr {
837        syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
838            syn::Lit::Bool(lit) => Some(lit.value()),
839            _ => None,
840        },
841        _ => None,
842    }
843}
844
845fn get_const_litstr(item: &ImplItemConst) -> syn::Result<Option<String>> {
846    match &item.expr {
847        syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
848            syn::Lit::Str(lit) => Ok(Some(lit.value())),
849            _ => Ok(None),
850        },
851        syn::Expr::Call(expr_call) => match &*expr_call.func {
852            syn::Expr::Path(expr_path) => {
853                let Some(last) = expr_path.path.segments.last() else {
854                    return Ok(None);
855                };
856                if last.ident == "Some" {
857                    match expr_call.args.first() {
858                        Some(syn::Expr::Lit(expr_lit)) => match &expr_lit.lit {
859                            syn::Lit::Str(lit) => Ok(Some(lit.value())),
860                            _ => Ok(None),
861                        },
862                        _ => Ok(None),
863                    }
864                } else {
865                    Ok(None)
866                }
867            }
868            _ => Ok(None),
869        },
870        ex => Err(syn::Error::new(ex.span(), "")),
871    }
872}
873
874fn remap_self_to_target(ty: &mut syn::Type, target: &syn::Ident) {
875    if let Type::Path(ref mut ty_path) = ty {
876        for segment in ty_path.path.segments.iter_mut() {
877            if segment.ident == "Self" {
878                segment.ident = target.clone()
879            }
880            use syn::{GenericArgument, PathArguments};
881            match segment.arguments {
882                PathArguments::AngleBracketed(ref mut angle_args) => {
883                    for arg in angle_args.args.iter_mut() {
884                        if let GenericArgument::Type(inner_ty) = arg {
885                            remap_self_to_target(inner_ty, target)
886                        }
887                    }
888                }
889                PathArguments::Parenthesized(_) => (),
890                PathArguments::None => (),
891            }
892        }
893    }
894}
895
896fn get_pgrx_attr_macro(attr_name: impl AsRef<str>, ty: &syn::Type) -> Option<TokenStream2> {
897    match &ty {
898        syn::Type::Macro(ty_macro) => {
899            let mut found_pgrx = false;
900            let mut found_attr = false;
901            // We don't actually have type resolution here, this is a "Best guess".
902            for (idx, segment) in ty_macro.mac.path.segments.iter().enumerate() {
903                match segment.ident.to_string().as_str() {
904                    "pgrx" if idx == 0 => found_pgrx = true,
905                    attr if attr == attr_name.as_ref() => found_attr = true,
906                    _ => (),
907                }
908            }
909            if (found_pgrx || ty_macro.mac.path.segments.len() == 1) && found_attr {
910                Some(ty_macro.mac.tokens.clone())
911            } else {
912                None
913            }
914        }
915        _ => None,
916    }
917}
918
#[cfg(test)]
mod tests {
    use super::PgAggregate;
    use eyre::Result;
    use quote::ToTokens;
    use syn::{parse_quote, ItemImpl};

    /// A minimal aggregate (`State`, `Args`, `NAME`, and `state` only) is accepted and
    /// produces exactly one generated extern.
    #[test]
    fn agg_required_only() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;
                const NAME: &'static str = "DEMO";

                fn state(mut current: Self::State, arg: Self::Args) -> Self::State {
                    todo!()
                }
            }
        };
        let agg = PgAggregate::new(tokens).expect("a minimal aggregate should be accepted");

        // Only `state` is defined, so only one extern is generated...
        assert_eq!(agg.0.pg_externs.len(), 1);
        // ...named from the snake_cased target type followed by the function name.
        let state_extern = &agg.0.pg_externs[0];
        assert_eq!(state_extern.sig.ident.to_string(), "demo_agg_state");

        // Generating the entity tokens must not panic.
        let _ = agg.to_token_stream();
        Ok(())
    }

    /// An aggregate that sets every optional type/const and defines every optional
    /// function is accepted and produces one extern per function.
    #[test]
    fn agg_all_options() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for DemoAgg {
                type State = PgVarlena<Self>;
                type Args = i32;
                type OrderBy = i32;
                type MovingState = i32;

                const NAME: &'static str = "DEMO";

                const PARALLEL: Option<ParallelOption> = Some(ParallelOption::Safe);
                const FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const MOVING_FINALIZE_MODIFY: Option<FinalizeModify> = Some(FinalizeModify::ReadWrite);
                const SORT_OPERATOR: Option<&'static str> = Some("sortop");
                const MOVING_INITIAL_CONDITION: Option<&'static str> = Some("1,1");
                const HYPOTHETICAL: bool = true;

                fn state(current: Self::State, v: Self::Args) -> Self::State {
                    todo!()
                }

                fn finalize(current: Self::State) -> Self::Finalize {
                    todo!()
                }

                fn combine(current: Self::State, _other: Self::State) -> Self::State {
                    todo!()
                }

                fn serial(current: Self::State) -> Vec<u8> {
                    todo!()
                }

                fn deserial(current: Self::State, _buf: Vec<u8>, _internal: PgBox<Self>) -> PgBox<Self> {
                    todo!()
                }

                fn moving_state(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_state_inverse(_mstate: Self::MovingState, _v: Self::Args) -> Self::MovingState {
                    todo!()
                }

                fn moving_finalize(_mstate: Self::MovingState) -> Self::Finalize {
                    todo!()
                }
            }
        };
        let agg = PgAggregate::new(tokens).expect("a fully-specified aggregate should be accepted");

        // state, finalize, combine, serial, deserial, moving_state,
        // moving_state_inverse, moving_finalize: eight externs in total.
        assert_eq!(agg.0.pg_externs.len(), 8);
        // The `state` extern comes first and is named from the target type.
        let state_extern = &agg.0.pg_externs[0];
        assert_eq!(state_extern.sig.ident.to_string(), "demo_agg_state");

        // Generating the entity tokens must not panic.
        let _ = agg.to_token_stream();
        Ok(())
    }

    /// An `impl` missing the required associated types/consts must be rejected.
    #[test]
    fn agg_missing_required() -> Result<()> {
        let tokens: ItemImpl = parse_quote! {
            #[pg_aggregate]
            impl Aggregate for IntegerAvgState {
            }
        };
        assert!(
            PgAggregate::new(tokens).is_err(),
            "an aggregate missing its required items must not be accepted"
        );
        Ok(())
    }
}