pgrx_sql_entity_graph/aggregate/
entity.rs

1//LICENSE Portions Copyright 2019-2021 ZomboDB, LLC.
2//LICENSE
3//LICENSE Portions Copyright 2021-2023 Technology Concepts & Design, Inc.
4//LICENSE
5//LICENSE Portions Copyright 2023-2023 PgCentral Foundation, Inc. <contact@pgcentral.org>
6//LICENSE
7//LICENSE All rights reserved.
8//LICENSE
9//LICENSE Use of this source code is governed by the MIT license that can be found in the LICENSE file.
10/*!
11
`#[pg_aggregate]`-related entities for Rust-to-SQL translation.
13
14> Like all of the [`sql_entity_graph`][crate] APIs, this is considered **internal**
15> to the `pgrx` framework and very subject to change between versions. While you may use this, please do it with caution.
16
17
18*/
19use crate::aggregate::options::{FinalizeModify, ParallelOption};
20use crate::fmt;
21use crate::metadata::SqlMapping;
22use crate::pgrx_sql::PgrxSql;
23use crate::to_sql::entity::ToSqlConfigEntity;
24use crate::to_sql::ToSql;
25use crate::type_keyed;
26use crate::{SqlGraphEntity, SqlGraphIdentifier, UsedTypeEntity};
27use core::any::TypeId;
28use eyre::{eyre, WrapErr};
29
30#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
31pub struct AggregateTypeEntity {
32    pub used_ty: UsedTypeEntity,
33    pub name: Option<&'static str>,
34}
35
36#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
37pub struct PgAggregateEntity {
38    pub full_path: &'static str,
39    pub module_path: &'static str,
40    pub file: &'static str,
41    pub line: u32,
42    pub ty_id: TypeId,
43
44    pub name: &'static str,
45
46    /// If the aggregate is an ordered set aggregate.
47    ///
48    /// See [the PostgreSQL ordered set docs](https://www.postgresql.org/docs/current/xaggr.html#XAGGR-ORDERED-SET-AGGREGATES).
49    pub ordered_set: bool,
50
51    /// The `arg_data_type` list.
52    ///
53    /// Corresponds to `Args` in `pgrx::aggregate::Aggregate`.
54    pub args: Vec<AggregateTypeEntity>,
55
56    /// The direct argument list, appearing before `ORDER BY` in ordered set aggregates.
57    ///
58    /// Corresponds to `OrderBy` in `pgrx::aggregate::Aggregate`.
59    pub direct_args: Option<Vec<AggregateTypeEntity>>,
60
61    /// The `STYPE` and `name` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
62    ///
63    /// The implementor of an `pgrx::aggregate::Aggregate`.
64    pub stype: AggregateTypeEntity,
65
66    /// The `SFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
67    ///
68    /// Corresponds to `state` in `pgrx::aggregate::Aggregate`.
69    pub sfunc: &'static str,
70
71    /// The `FINALFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
72    ///
73    /// Corresponds to `finalize` in `pgrx::aggregate::Aggregate`.
74    pub finalfunc: Option<&'static str>,
75
76    /// The `FINALFUNC_MODIFY` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
77    ///
78    /// Corresponds to `FINALIZE_MODIFY` in `pgrx::aggregate::Aggregate`.
79    pub finalfunc_modify: Option<FinalizeModify>,
80
81    /// The `COMBINEFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
82    ///
83    /// Corresponds to `combine` in `pgrx::aggregate::Aggregate`.
84    pub combinefunc: Option<&'static str>,
85
86    /// The `SERIALFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
87    ///
88    /// Corresponds to `serial` in `pgrx::aggregate::Aggregate`.
89    pub serialfunc: Option<&'static str>,
90
91    /// The `DESERIALFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
92    ///
93    /// Corresponds to `deserial` in `pgrx::aggregate::Aggregate`.
94    pub deserialfunc: Option<&'static str>,
95
96    /// The `INITCOND` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
97    ///
98    /// Corresponds to `INITIAL_CONDITION` in `pgrx::aggregate::Aggregate`.
99    pub initcond: Option<&'static str>,
100
101    /// The `MSFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
102    ///
103    /// Corresponds to `moving_state` in `pgrx::aggregate::Aggregate`.
104    pub msfunc: Option<&'static str>,
105
106    /// The `MINVFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
107    ///
108    /// Corresponds to `moving_state_inverse` in `pgrx::aggregate::Aggregate`.
109    pub minvfunc: Option<&'static str>,
110
111    /// The `MSTYPE` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
112    ///
113    /// Corresponds to `MovingState` in `pgrx::aggregate::Aggregate`.
114    pub mstype: Option<UsedTypeEntity>,
115
116    // The `MSSPACE` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
117    //
118    // TODO: Currently unused.
119    // pub msspace: &'static str,
120    /// The `MFINALFUNC` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
121    ///
122    /// Corresponds to `moving_state_finalize` in `pgrx::aggregate::Aggregate`.
123    pub mfinalfunc: Option<&'static str>,
124
125    /// The `MFINALFUNC_MODIFY` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
126    ///
127    /// Corresponds to `MOVING_FINALIZE_MODIFY` in `pgrx::aggregate::Aggregate`.
128    pub mfinalfunc_modify: Option<FinalizeModify>,
129
130    /// The `MINITCOND` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
131    ///
132    /// Corresponds to `MOVING_INITIAL_CONDITION` in `pgrx::aggregate::Aggregate`.
133    pub minitcond: Option<&'static str>,
134
135    /// The `SORTOP` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
136    ///
137    /// Corresponds to `SORT_OPERATOR` in `pgrx::aggregate::Aggregate`.
138    pub sortop: Option<&'static str>,
139
140    /// The `PARALLEL` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
141    ///
142    /// Corresponds to `PARALLEL` in `pgrx::aggregate::Aggregate`.
143    pub parallel: Option<ParallelOption>,
144
145    /// The `HYPOTHETICAL` parameter for [`CREATE AGGREGATE`](https://www.postgresql.org/docs/current/sql-createaggregate.html)
146    ///
147    /// Corresponds to `hypothetical` in `pgrx::aggregate::Aggregate`.
148    pub hypothetical: bool,
149    pub to_sql_config: ToSqlConfigEntity,
150}
151
152impl From<PgAggregateEntity> for SqlGraphEntity {
153    fn from(val: PgAggregateEntity) -> Self {
154        SqlGraphEntity::Aggregate(val)
155    }
156}
157
158impl SqlGraphIdentifier for PgAggregateEntity {
159    fn dot_identifier(&self) -> String {
160        format!("aggregate {}", self.full_path)
161    }
162    fn rust_identifier(&self) -> String {
163        self.full_path.to_string()
164    }
165    fn file(&self) -> Option<&'static str> {
166        Some(self.file)
167    }
168    fn line(&self) -> Option<u32> {
169        Some(self.line)
170    }
171}
172
173impl ToSql for PgAggregateEntity {
174    fn to_sql(&self, context: &PgrxSql) -> eyre::Result<String> {
175        let self_index = context.aggregates[self];
176        let mut optional_attributes = Vec::new();
177        let schema = context.schema_prefix_for(&self_index);
178
179        if let Some(value) = self.finalfunc {
180            optional_attributes.push((
181                format!("\tFINALFUNC = {schema}\"{value}\""),
182                format!("/* {}::final */", self.full_path),
183            ));
184        }
185        if let Some(value) = self.finalfunc_modify {
186            optional_attributes.push((
187                format!("\tFINALFUNC_MODIFY = {}", value.to_sql(context)?),
188                format!("/* {}::FINALIZE_MODIFY */", self.full_path),
189            ));
190        }
191        if let Some(value) = self.combinefunc {
192            optional_attributes.push((
193                format!("\tCOMBINEFUNC = {schema}\"{value}\""),
194                format!("/* {}::combine */", self.full_path),
195            ));
196        }
197        if let Some(value) = self.serialfunc {
198            optional_attributes.push((
199                format!("\tSERIALFUNC = {schema}\"{value}\""),
200                format!("/* {}::serial */", self.full_path),
201            ));
202        }
203        if let Some(value) = self.deserialfunc {
204            optional_attributes.push((
205                format!("\tDESERIALFUNC ={schema} \"{value}\""),
206                format!("/* {}::deserial */", self.full_path),
207            ));
208        }
209        if let Some(value) = self.initcond {
210            optional_attributes.push((
211                format!("\tINITCOND = '{value}'"),
212                format!("/* {}::INITIAL_CONDITION */", self.full_path),
213            ));
214        }
215        if let Some(value) = self.msfunc {
216            optional_attributes.push((
217                format!("\tMSFUNC = {schema}\"{value}\""),
218                format!("/* {}::moving_state */", self.full_path),
219            ));
220        }
221        if let Some(value) = self.minvfunc {
222            optional_attributes.push((
223                format!("\tMINVFUNC = {schema}\"{value}\""),
224                format!("/* {}::moving_state_inverse */", self.full_path),
225            ));
226        }
227        if let Some(value) = self.mfinalfunc {
228            optional_attributes.push((
229                format!("\tMFINALFUNC = {schema}\"{value}\""),
230                format!("/* {}::moving_state_finalize */", self.full_path),
231            ));
232        }
233        if let Some(value) = self.mfinalfunc_modify {
234            optional_attributes.push((
235                format!("\tMFINALFUNC_MODIFY = {}", value.to_sql(context)?),
236                format!("/* {}::MOVING_FINALIZE_MODIFY */", self.full_path),
237            ));
238        }
239        if let Some(value) = self.minitcond {
240            optional_attributes.push((
241                format!("\tMINITCOND = '{value}'"),
242                format!("/* {}::MOVING_INITIAL_CONDITION */", self.full_path),
243            ));
244        }
245        if let Some(value) = self.sortop {
246            optional_attributes.push((
247                format!("\tSORTOP = \"{value}\""),
248                format!("/* {}::SORT_OPERATOR */", self.full_path),
249            ));
250        }
251        if let Some(value) = self.parallel {
252            optional_attributes.push((
253                format!("\tPARALLEL = {}", value.to_sql(context)?),
254                format!("/* {}::PARALLEL */", self.full_path),
255            ));
256        }
257        if self.hypothetical {
258            optional_attributes.push((
259                String::from("\tHYPOTHETICAL"),
260                format!("/* {}::hypothetical */", self.full_path),
261            ))
262        }
263
264        let map_ty = |used_ty: &UsedTypeEntity| -> eyre::Result<String> {
265            match used_ty.metadata.argument_sql {
266                Ok(SqlMapping::As(ref argument_sql)) => Ok(argument_sql.to_string()),
267                Ok(SqlMapping::Composite { array_brackets }) => used_ty
268                    .composite_type
269                    .map(|v| fmt::with_array_brackets(v.into(), array_brackets))
270                    .ok_or_else(|| {
271                        eyre!("Macro expansion time suggested a composite_type!() in return")
272                    }),
273                Ok(SqlMapping::Skip) => {
274                    Err(eyre!("Cannot use skipped SQL translatable type as aggregate const type"))
275                }
276                Err(err) => Err(err).wrap_err("While mapping argument"),
277            }
278        };
279
280        let stype_sql = map_ty(&self.stype.used_ty).wrap_err("Mapping state type")?;
281        let stype_schema = context
282            .types
283            .iter()
284            .map(type_keyed)
285            .chain(context.enums.iter().map(type_keyed))
286            .find(|(ty, _)| ty.id_matches(&self.stype.used_ty.ty_id))
287            .map(|(_, ty_index)| context.schema_prefix_for(ty_index))
288            .unwrap_or_default();
289
290        if let Some(value) = &self.mstype {
291            let mstype_sql = map_ty(value).wrap_err("Mapping moving state type")?;
292            optional_attributes.push((
293                format!("\tMSTYPE = {mstype_sql}"),
294                format!("/* {}::MovingState = {} */", self.full_path, value.full_path),
295            ));
296        }
297
298        let mut optional_attributes_string = String::new();
299        for (index, (optional_attribute, comment)) in optional_attributes.iter().enumerate() {
300            let optional_attribute_string = format!(
301                "{optional_attribute}{maybe_comma} {comment}{maybe_newline}",
302                optional_attribute = optional_attribute,
303                maybe_comma = if index == optional_attributes.len() - 1 { "" } else { "," },
304                comment = comment,
305                maybe_newline = if index == optional_attributes.len() - 1 { "" } else { "\n" }
306            );
307            optional_attributes_string += &optional_attribute_string;
308        }
309
310        let args = {
311            let mut args = Vec::new();
312            for (idx, arg) in self.args.iter().enumerate() {
313                let graph_index = context
314                    .graph
315                    .neighbors_undirected(self_index)
316                    .find(|neighbor| context.graph[*neighbor].type_matches(&arg.used_ty))
317                    .ok_or_else(|| {
318                        eyre!("Could not find arg type in graph. Got: {:?}", arg.used_ty)
319                    })?;
320                let needs_comma = idx < (self.args.len() - 1);
321                let buf = format!("\
322                       \t{name}{variadic}{schema_prefix}{sql_type}{maybe_comma}/* {full_path} */\
323                   ",
324                       schema_prefix = context.schema_prefix_for(&graph_index),
325                       // First try to match on [`TypeId`] since it's most reliable.
326                       sql_type = match arg.used_ty.metadata.argument_sql {
327                            Ok(SqlMapping::As(ref argument_sql)) => {
328                                argument_sql.to_string()
329                            }
330                            Ok(SqlMapping::Composite {
331                                array_brackets,
332                            }) => {
333                                arg.used_ty
334                                    .composite_type
335                                    .map(|v| {
336                                        fmt::with_array_brackets(v.into(), array_brackets)
337                                    })
338                                    .ok_or_else(|| {
339                                        eyre!(
340                                        "Macro expansion time suggested a composite_type!() in return"
341                                    )
342                                    })?
343                            }
344                            Ok(SqlMapping::Skip) => return Err(eyre!("Got a skipped SQL translatable type in aggregate args, this is not permitted")),
345                            Err(err) => return Err(err).wrap_err("While mapping argument")
346                        },
347                       variadic = if arg.used_ty.variadic { "VARIADIC " } else { "" },
348                       maybe_comma = if needs_comma { ", " } else { " " },
349                       full_path = arg.used_ty.full_path,
350                       name = if let Some(name) = arg.name {
351                           format!(r#""{name}" "#)
352                       } else { "".to_string() },
353                );
354                args.push(buf);
355            }
356            "\n".to_string() + &args.join("\n") + "\n"
357        };
358        let direct_args = if let Some(direct_args) = &self.direct_args {
359            let mut args = Vec::new();
360            for (idx, arg) in direct_args.iter().enumerate() {
361                let graph_index = context
362                    .graph
363                    .neighbors_undirected(self_index)
364                    .find(|neighbor| context.graph[*neighbor].type_matches(&arg.used_ty))
365                    .ok_or_else(|| eyre!("Could not find arg type in graph. Got: {:?}", arg))?;
366                let needs_comma = idx < (direct_args.len() - 1);
367                let buf = format!(
368                    "\
369                    \t{maybe_name}{schema_prefix}{sql_type}{maybe_comma}/* {full_path} */\
370                   ",
371                    schema_prefix = context.schema_prefix_for(&graph_index),
372                    // First try to match on [`TypeId`] since it's most reliable.
373                    sql_type = map_ty(&arg.used_ty).wrap_err("Mapping direct arg type")?,
374                    maybe_name = if let Some(name) = arg.name {
375                        "\"".to_string() + name + "\" "
376                    } else {
377                        "".to_string()
378                    },
379                    maybe_comma = if needs_comma { ", " } else { " " },
380                    full_path = arg.used_ty.full_path,
381                );
382                args.push(buf);
383            }
384            "\n".to_string() + &args.join("\n") + "\n"
385        } else {
386            String::default()
387        };
388
389        let PgAggregateEntity { name, full_path, file, line, sfunc, .. } = self;
390
391        let sql = format!(
392            "\n\
393                -- {file}:{line}\n\
394                -- {full_path}\n\
395                CREATE AGGREGATE {schema}{name} ({direct_args}{maybe_order_by}{args})\n\
396                (\n\
397                    \tSFUNC = {schema}\"{sfunc}\", /* {full_path}::state */\n\
398                    \tSTYPE = {stype_schema}{stype_sql}{maybe_comma_after_stype} /* {stype_full_path} */\
399                    {optional_attributes}\
400                );\
401            ",
402            stype_full_path = self.stype.used_ty.full_path,
403            maybe_comma_after_stype = if optional_attributes.is_empty() { "" } else { "," },
404            maybe_order_by = if self.ordered_set { "\tORDER BY" } else { "" },
405            optional_attributes = String::from("\n")
406                + &optional_attributes_string
407                + if optional_attributes.is_empty() { "" } else { "\n" },
408        );
409        Ok(sql)
410    }
411}