pgrx_sql_entity_graph/aggregate/
entity.rs1use crate::aggregate::options::{FinalizeModify, ParallelOption};
20use crate::fmt;
21use crate::metadata::SqlMapping;
22use crate::pgrx_sql::PgrxSql;
23use crate::to_sql::entity::ToSqlConfigEntity;
24use crate::to_sql::ToSql;
25use crate::type_keyed;
26use crate::{SqlGraphEntity, SqlGraphIdentifier, UsedTypeEntity};
27use core::any::TypeId;
28use eyre::{eyre, WrapErr};
29
/// A type used by an aggregate — an aggregated argument, a direct argument, or the
/// state type — together with an optional SQL-level argument name.
///
/// Field order matters: it determines the derived `PartialOrd`/`Ord`/`Hash` behavior.
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct AggregateTypeEntity {
    /// The underlying Rust type and its SQL mapping metadata (`metadata.argument_sql`).
    pub used_ty: UsedTypeEntity,
    /// Optional argument name; when present it is emitted double-quoted, followed by a
    /// space, in front of the argument's SQL type.
    pub name: Option<&'static str>,
}
35
/// Description of an aggregate that is rendered to a `CREATE AGGREGATE` statement by the
/// `ToSql` implementation in this file.
///
/// Support-function fields (`sfunc`, `finalfunc`, `combinefunc`, …) hold bare function
/// names. `to_sql` emits them double-quoted and prefixed with the aggregate's own schema.
///
/// Field order matters: it determines the derived `PartialOrd`/`Ord`/`Hash` behavior.
#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct PgAggregateEntity {
    /// Full Rust path of the aggregate. It is used in the generated SQL comments and is
    /// returned as the entity's `rust_identifier`.
    pub full_path: &'static str,
    /// Rust module path of the aggregate. NOTE(review): not read anywhere in this file.
    pub module_path: &'static str,
    /// Source file, emitted in the `-- file:line` header of the generated SQL.
    pub file: &'static str,
    /// Source line, emitted in the `-- file:line` header of the generated SQL.
    pub line: u32,
    /// `TypeId` recorded for the aggregate. NOTE(review): not read in this file;
    /// presumably the Rust type implementing the aggregate — confirm.
    pub ty_id: TypeId,

    /// SQL name of the aggregate, emitted as `CREATE AGGREGATE <schema><name>`.
    pub name: &'static str,

    /// When `true`, `ORDER BY` is emitted between the direct arguments and the
    /// aggregated arguments (ordered-set aggregate syntax).
    pub ordered_set: bool,

    /// Aggregated argument types. They are emitted in the argument list, after
    /// `ORDER BY` when `ordered_set` is set.
    pub args: Vec<AggregateTypeEntity>,

    /// Direct argument types, emitted at the start of the argument list (before any
    /// `ORDER BY`). `None` emits nothing.
    pub direct_args: Option<Vec<AggregateTypeEntity>>,

    /// State type, emitted as `STYPE`.
    pub stype: AggregateTypeEntity,

    /// State transition function name, emitted as `SFUNC`.
    pub sfunc: &'static str,

    /// Final function name, emitted as `FINALFUNC` when set.
    pub finalfunc: Option<&'static str>,

    /// Emitted as `FINALFUNC_MODIFY` when set.
    pub finalfunc_modify: Option<FinalizeModify>,

    /// Combine function name, emitted as `COMBINEFUNC` when set.
    pub combinefunc: Option<&'static str>,

    /// Serialization function name, emitted as `SERIALFUNC` when set.
    pub serialfunc: Option<&'static str>,

    /// Deserialization function name, emitted as `DESERIALFUNC` when set.
    pub deserialfunc: Option<&'static str>,

    /// Initial state, emitted verbatim inside single quotes as `INITCOND`.
    /// No escaping is applied, so the text must not contain a bare `'`.
    pub initcond: Option<&'static str>,

    /// Moving-aggregate transition function name, emitted as `MSFUNC` when set.
    pub msfunc: Option<&'static str>,

    /// Moving-aggregate inverse transition function name, emitted as `MINVFUNC`.
    pub minvfunc: Option<&'static str>,

    /// Moving-aggregate state type, emitted as `MSTYPE` when set.
    pub mstype: Option<UsedTypeEntity>,

    /// Moving-aggregate final function name, emitted as `MFINALFUNC` when set.
    pub mfinalfunc: Option<&'static str>,

    /// Emitted as `MFINALFUNC_MODIFY` when set.
    pub mfinalfunc_modify: Option<FinalizeModify>,

    /// Moving-aggregate initial state, emitted verbatim inside single quotes as
    /// `MINITCOND`. No escaping is applied.
    pub minitcond: Option<&'static str>,

    /// Sort operator, emitted double-quoted as `SORTOP` when set.
    pub sortop: Option<&'static str>,

    /// Emitted as `PARALLEL` when set.
    pub parallel: Option<ParallelOption>,

    /// When `true`, the bare `HYPOTHETICAL` flag is emitted.
    pub hypothetical: bool,
    /// SQL generation configuration. NOTE(review): not read in this file.
    pub to_sql_config: ToSqlConfigEntity,
}
151
152impl From<PgAggregateEntity> for SqlGraphEntity {
153 fn from(val: PgAggregateEntity) -> Self {
154 SqlGraphEntity::Aggregate(val)
155 }
156}
157
158impl SqlGraphIdentifier for PgAggregateEntity {
159 fn dot_identifier(&self) -> String {
160 format!("aggregate {}", self.full_path)
161 }
162 fn rust_identifier(&self) -> String {
163 self.full_path.to_string()
164 }
165 fn file(&self) -> Option<&'static str> {
166 Some(self.file)
167 }
168 fn line(&self) -> Option<u32> {
169 Some(self.line)
170 }
171}
172
173impl ToSql for PgAggregateEntity {
174 fn to_sql(&self, context: &PgrxSql) -> eyre::Result<String> {
175 let self_index = context.aggregates[self];
176 let mut optional_attributes = Vec::new();
177 let schema = context.schema_prefix_for(&self_index);
178
179 if let Some(value) = self.finalfunc {
180 optional_attributes.push((
181 format!("\tFINALFUNC = {schema}\"{value}\""),
182 format!("/* {}::final */", self.full_path),
183 ));
184 }
185 if let Some(value) = self.finalfunc_modify {
186 optional_attributes.push((
187 format!("\tFINALFUNC_MODIFY = {}", value.to_sql(context)?),
188 format!("/* {}::FINALIZE_MODIFY */", self.full_path),
189 ));
190 }
191 if let Some(value) = self.combinefunc {
192 optional_attributes.push((
193 format!("\tCOMBINEFUNC = {schema}\"{value}\""),
194 format!("/* {}::combine */", self.full_path),
195 ));
196 }
197 if let Some(value) = self.serialfunc {
198 optional_attributes.push((
199 format!("\tSERIALFUNC = {schema}\"{value}\""),
200 format!("/* {}::serial */", self.full_path),
201 ));
202 }
203 if let Some(value) = self.deserialfunc {
204 optional_attributes.push((
205 format!("\tDESERIALFUNC ={schema} \"{value}\""),
206 format!("/* {}::deserial */", self.full_path),
207 ));
208 }
209 if let Some(value) = self.initcond {
210 optional_attributes.push((
211 format!("\tINITCOND = '{value}'"),
212 format!("/* {}::INITIAL_CONDITION */", self.full_path),
213 ));
214 }
215 if let Some(value) = self.msfunc {
216 optional_attributes.push((
217 format!("\tMSFUNC = {schema}\"{value}\""),
218 format!("/* {}::moving_state */", self.full_path),
219 ));
220 }
221 if let Some(value) = self.minvfunc {
222 optional_attributes.push((
223 format!("\tMINVFUNC = {schema}\"{value}\""),
224 format!("/* {}::moving_state_inverse */", self.full_path),
225 ));
226 }
227 if let Some(value) = self.mfinalfunc {
228 optional_attributes.push((
229 format!("\tMFINALFUNC = {schema}\"{value}\""),
230 format!("/* {}::moving_state_finalize */", self.full_path),
231 ));
232 }
233 if let Some(value) = self.mfinalfunc_modify {
234 optional_attributes.push((
235 format!("\tMFINALFUNC_MODIFY = {}", value.to_sql(context)?),
236 format!("/* {}::MOVING_FINALIZE_MODIFY */", self.full_path),
237 ));
238 }
239 if let Some(value) = self.minitcond {
240 optional_attributes.push((
241 format!("\tMINITCOND = '{value}'"),
242 format!("/* {}::MOVING_INITIAL_CONDITION */", self.full_path),
243 ));
244 }
245 if let Some(value) = self.sortop {
246 optional_attributes.push((
247 format!("\tSORTOP = \"{value}\""),
248 format!("/* {}::SORT_OPERATOR */", self.full_path),
249 ));
250 }
251 if let Some(value) = self.parallel {
252 optional_attributes.push((
253 format!("\tPARALLEL = {}", value.to_sql(context)?),
254 format!("/* {}::PARALLEL */", self.full_path),
255 ));
256 }
257 if self.hypothetical {
258 optional_attributes.push((
259 String::from("\tHYPOTHETICAL"),
260 format!("/* {}::hypothetical */", self.full_path),
261 ))
262 }
263
264 let map_ty = |used_ty: &UsedTypeEntity| -> eyre::Result<String> {
265 match used_ty.metadata.argument_sql {
266 Ok(SqlMapping::As(ref argument_sql)) => Ok(argument_sql.to_string()),
267 Ok(SqlMapping::Composite { array_brackets }) => used_ty
268 .composite_type
269 .map(|v| fmt::with_array_brackets(v.into(), array_brackets))
270 .ok_or_else(|| {
271 eyre!("Macro expansion time suggested a composite_type!() in return")
272 }),
273 Ok(SqlMapping::Skip) => {
274 Err(eyre!("Cannot use skipped SQL translatable type as aggregate const type"))
275 }
276 Err(err) => Err(err).wrap_err("While mapping argument"),
277 }
278 };
279
280 let stype_sql = map_ty(&self.stype.used_ty).wrap_err("Mapping state type")?;
281 let stype_schema = context
282 .types
283 .iter()
284 .map(type_keyed)
285 .chain(context.enums.iter().map(type_keyed))
286 .find(|(ty, _)| ty.id_matches(&self.stype.used_ty.ty_id))
287 .map(|(_, ty_index)| context.schema_prefix_for(ty_index))
288 .unwrap_or_default();
289
290 if let Some(value) = &self.mstype {
291 let mstype_sql = map_ty(value).wrap_err("Mapping moving state type")?;
292 optional_attributes.push((
293 format!("\tMSTYPE = {mstype_sql}"),
294 format!("/* {}::MovingState = {} */", self.full_path, value.full_path),
295 ));
296 }
297
298 let mut optional_attributes_string = String::new();
299 for (index, (optional_attribute, comment)) in optional_attributes.iter().enumerate() {
300 let optional_attribute_string = format!(
301 "{optional_attribute}{maybe_comma} {comment}{maybe_newline}",
302 optional_attribute = optional_attribute,
303 maybe_comma = if index == optional_attributes.len() - 1 { "" } else { "," },
304 comment = comment,
305 maybe_newline = if index == optional_attributes.len() - 1 { "" } else { "\n" }
306 );
307 optional_attributes_string += &optional_attribute_string;
308 }
309
310 let args = {
311 let mut args = Vec::new();
312 for (idx, arg) in self.args.iter().enumerate() {
313 let graph_index = context
314 .graph
315 .neighbors_undirected(self_index)
316 .find(|neighbor| context.graph[*neighbor].type_matches(&arg.used_ty))
317 .ok_or_else(|| {
318 eyre!("Could not find arg type in graph. Got: {:?}", arg.used_ty)
319 })?;
320 let needs_comma = idx < (self.args.len() - 1);
321 let buf = format!("\
322 \t{name}{variadic}{schema_prefix}{sql_type}{maybe_comma}/* {full_path} */\
323 ",
324 schema_prefix = context.schema_prefix_for(&graph_index),
325 sql_type = match arg.used_ty.metadata.argument_sql {
327 Ok(SqlMapping::As(ref argument_sql)) => {
328 argument_sql.to_string()
329 }
330 Ok(SqlMapping::Composite {
331 array_brackets,
332 }) => {
333 arg.used_ty
334 .composite_type
335 .map(|v| {
336 fmt::with_array_brackets(v.into(), array_brackets)
337 })
338 .ok_or_else(|| {
339 eyre!(
340 "Macro expansion time suggested a composite_type!() in return"
341 )
342 })?
343 }
344 Ok(SqlMapping::Skip) => return Err(eyre!("Got a skipped SQL translatable type in aggregate args, this is not permitted")),
345 Err(err) => return Err(err).wrap_err("While mapping argument")
346 },
347 variadic = if arg.used_ty.variadic { "VARIADIC " } else { "" },
348 maybe_comma = if needs_comma { ", " } else { " " },
349 full_path = arg.used_ty.full_path,
350 name = if let Some(name) = arg.name {
351 format!(r#""{name}" "#)
352 } else { "".to_string() },
353 );
354 args.push(buf);
355 }
356 "\n".to_string() + &args.join("\n") + "\n"
357 };
358 let direct_args = if let Some(direct_args) = &self.direct_args {
359 let mut args = Vec::new();
360 for (idx, arg) in direct_args.iter().enumerate() {
361 let graph_index = context
362 .graph
363 .neighbors_undirected(self_index)
364 .find(|neighbor| context.graph[*neighbor].type_matches(&arg.used_ty))
365 .ok_or_else(|| eyre!("Could not find arg type in graph. Got: {:?}", arg))?;
366 let needs_comma = idx < (direct_args.len() - 1);
367 let buf = format!(
368 "\
369 \t{maybe_name}{schema_prefix}{sql_type}{maybe_comma}/* {full_path} */\
370 ",
371 schema_prefix = context.schema_prefix_for(&graph_index),
372 sql_type = map_ty(&arg.used_ty).wrap_err("Mapping direct arg type")?,
374 maybe_name = if let Some(name) = arg.name {
375 "\"".to_string() + name + "\" "
376 } else {
377 "".to_string()
378 },
379 maybe_comma = if needs_comma { ", " } else { " " },
380 full_path = arg.used_ty.full_path,
381 );
382 args.push(buf);
383 }
384 "\n".to_string() + &args.join("\n") + "\n"
385 } else {
386 String::default()
387 };
388
389 let PgAggregateEntity { name, full_path, file, line, sfunc, .. } = self;
390
391 let sql = format!(
392 "\n\
393 -- {file}:{line}\n\
394 -- {full_path}\n\
395 CREATE AGGREGATE {schema}{name} ({direct_args}{maybe_order_by}{args})\n\
396 (\n\
397 \tSFUNC = {schema}\"{sfunc}\", /* {full_path}::state */\n\
398 \tSTYPE = {stype_schema}{stype_sql}{maybe_comma_after_stype} /* {stype_full_path} */\
399 {optional_attributes}\
400 );\
401 ",
402 stype_full_path = self.stype.used_ty.full_path,
403 maybe_comma_after_stype = if optional_attributes.is_empty() { "" } else { "," },
404 maybe_order_by = if self.ordered_set { "\tORDER BY" } else { "" },
405 optional_attributes = String::from("\n")
406 + &optional_attributes_string
407 + if optional_attributes.is_empty() { "" } else { "\n" },
408 );
409 Ok(sql)
410 }
411}