1use itertools::Itertools;
2use kdl::{KdlDocument, KdlEntry, KdlNode};
3use serde::Serialize;
4use std::fmt::Display;
5use std::hash::Hash;
6use std::str::FromStr;
7
8use crate::error::UsageErr::InvalidFlag;
9use crate::error::{Result, UsageErr};
10use crate::spec::context::ParsingContext;
11use crate::spec::helpers::NodeHelper;
12use crate::spec::is_false;
13use crate::{string, SpecArg, SpecChoices};
14
/// A flag (option) accepted by a command, e.g. `-f --flag <arg>`.
///
/// Constructed from a KDL `flag` node (`SpecFlag::parse`), from a declaration string
/// (`FromStr`), or, with the `clap` feature, from a `clap::Arg`.
#[derive(Debug, Default, Clone, Serialize)]
pub struct SpecFlag {
    /// Identifier of the flag. Set explicitly with a `name:` prefix in the declaration
    /// string; otherwise derived from the first long name, falling back to the first short name.
    pub name: String,
    /// Rendered usage string (see `SpecFlag::usage`), cached by `parse` and `FromStr`.
    pub usage: String,
    /// Short help text.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub help: Option<String>,
    /// Long help text (accepted as either `long_help` or `help_long` when parsing).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub help_long: Option<String>,
    /// Help text from the `help_md` key (presumably markdown — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub help_md: Option<String>,
    /// First line of `help`, computed with `string::first_line`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub help_first_line: Option<String>,
    /// Short names, stored without the leading `-`.
    pub short: Vec<char>,
    /// Long names, stored without the leading `--`.
    pub long: Vec<String>,
    /// Whether the flag is required; `parse` clears this when a default value is given.
    #[serde(skip_serializing_if = "is_false")]
    pub required: bool,
    /// Deprecation message. A boolean `true` in KDL stores the literal `"deprecated"`,
    /// `false` stores `None`, and a string is stored as-is.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub deprecated: Option<String>,
    /// Whether the flag may be repeated; rendered as a `...` suffix in the usage string.
    #[serde(skip_serializing_if = "is_false")]
    pub var: bool,
    /// Presumably hides the flag from help output — confirm against the renderer.
    // NOTE(review): unlike the other booleans, `hide` and `global` are always serialized.
    pub hide: bool,
    /// Presumably makes the flag available to subcommands as well — confirm.
    pub global: bool,
    /// Set for clap `ArgAction::Count` flags (presumably occurrences are counted — confirm).
    #[serde(skip_serializing_if = "is_false")]
    pub count: bool,
    /// The value argument, if the flag takes one (e.g. the `<arg>` in `--flag <arg>`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arg: Option<SpecArg>,
    /// Default value, kept as a string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default: Option<String>,
    /// Value of the `negate` key (presumably the name of a negating flag such as
    /// `--no-foo` — TODO confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub negate: Option<String>,
}
46
47impl SpecFlag {
48 pub(crate) fn parse(ctx: &ParsingContext, node: &NodeHelper) -> Result<Self> {
49 let mut flag: Self = node.arg(0)?.ensure_string()?.parse()?;
50 for (k, v) in node.props() {
51 match k {
52 "help" => flag.help = Some(v.ensure_string()?),
53 "long_help" => flag.help_long = Some(v.ensure_string()?),
54 "help_long" => flag.help_long = Some(v.ensure_string()?),
55 "help_md" => flag.help_md = Some(v.ensure_string()?),
56 "required" => flag.required = v.ensure_bool()?,
57 "var" => flag.var = v.ensure_bool()?,
58 "hide" => flag.hide = v.ensure_bool()?,
59 "deprecated" => {
60 flag.deprecated = match v.value.as_bool() {
61 Some(true) => Some("deprecated".into()),
62 Some(false) => None,
63 None => Some(v.ensure_string()?),
64 }
65 }
66 "global" => flag.global = v.ensure_bool()?,
67 "count" => flag.count = v.ensure_bool()?,
68 "default" => flag.default = v.ensure_string().map(Some)?,
69 "negate" => flag.negate = v.ensure_string().map(Some)?,
70 k => bail_parse!(ctx, v.entry.span(), "unsupported flag key {k}"),
71 }
72 }
73 if flag.default.is_some() {
74 flag.required = false;
75 }
76 for child in node.children() {
77 match child.name() {
78 "arg" => flag.arg = Some(SpecArg::parse(ctx, &child)?),
79 "help" => flag.help = Some(child.arg(0)?.ensure_string()?),
80 "long_help" => flag.help_long = Some(child.arg(0)?.ensure_string()?),
81 "help_long" => flag.help_long = Some(child.arg(0)?.ensure_string()?),
82 "help_md" => flag.help_md = Some(child.arg(0)?.ensure_string()?),
83 "required" => flag.required = child.arg(0)?.ensure_bool()?,
84 "var" => flag.var = child.arg(0)?.ensure_bool()?,
85 "hide" => flag.hide = child.arg(0)?.ensure_bool()?,
86 "deprecated" => {
87 flag.deprecated = match child.arg(0)?.ensure_bool() {
88 Ok(true) => Some("deprecated".into()),
89 Ok(false) => None,
90 _ => Some(child.arg(0)?.ensure_string()?),
91 }
92 }
93 "global" => flag.global = child.arg(0)?.ensure_bool()?,
94 "count" => flag.count = child.arg(0)?.ensure_bool()?,
95 "default" => flag.default = child.arg(0)?.ensure_string().map(Some)?,
96 "choices" => {
97 if let Some(arg) = &mut flag.arg {
98 arg.choices = Some(SpecChoices::parse(ctx, &child)?);
99 } else {
100 bail_parse!(
101 ctx,
102 child.node.name().span(),
103 "flag must have value to have choices"
104 )
105 }
106 }
107 k => bail_parse!(ctx, child.node.name().span(), "unsupported flag child {k}"),
108 }
109 }
110 flag.usage = flag.usage();
111 flag.help_first_line = flag.help.as_ref().map(|s| string::first_line(s));
112 Ok(flag)
113 }
114 pub fn usage(&self) -> String {
115 let mut parts = vec![];
116 let name = get_name_from_short_and_long(&self.short, &self.long).unwrap_or_default();
117 if name != self.name {
118 parts.push(format!("{}:", self.name));
119 }
120 if let Some(short) = self.short.first() {
121 parts.push(format!("-{}", short));
122 }
123 if let Some(long) = self.long.first() {
124 parts.push(format!("--{}", long));
125 }
126 let mut out = parts.join(" ");
127 if self.var {
128 out = format!("{}...", out);
129 }
130 if let Some(arg) = &self.arg {
131 out = format!("{} {}", out, arg.usage());
132 }
133 out
134 }
135}
136
137impl From<&SpecFlag> for KdlNode {
138 fn from(flag: &SpecFlag) -> KdlNode {
139 let mut node = KdlNode::new("flag");
140 let name = flag
141 .short
142 .iter()
143 .map(|c| format!("-{c}"))
144 .chain(flag.long.iter().map(|s| format!("--{s}")))
145 .collect_vec()
146 .join(" ");
147 node.push(KdlEntry::new(name));
148 if let Some(desc) = &flag.help {
149 node.push(KdlEntry::new_prop("help", desc.clone()));
150 }
151 if let Some(desc) = &flag.help_long {
152 let children = node.children_mut().get_or_insert_with(KdlDocument::new);
153 let mut node = KdlNode::new("long_help");
154 node.entries_mut().push(KdlEntry::new(desc.clone()));
155 children.nodes_mut().push(node);
156 }
157 if let Some(desc) = &flag.help_md {
158 let children = node.children_mut().get_or_insert_with(KdlDocument::new);
159 let mut node = KdlNode::new("help_md");
160 node.entries_mut().push(KdlEntry::new(desc.clone()));
161 children.nodes_mut().push(node);
162 }
163 if flag.required {
164 node.push(KdlEntry::new_prop("required", true));
165 }
166 if flag.var {
167 node.push(KdlEntry::new_prop("var", true));
168 }
169 if flag.hide {
170 node.push(KdlEntry::new_prop("hide", true));
171 }
172 if flag.global {
173 node.push(KdlEntry::new_prop("global", true));
174 }
175 if flag.count {
176 node.push(KdlEntry::new_prop("count", true));
177 }
178 if let Some(negate) = &flag.negate {
179 node.push(KdlEntry::new_prop("negate", negate.clone()));
180 }
181 if let Some(deprecated) = &flag.deprecated {
182 node.push(KdlEntry::new_prop("deprecated", deprecated.clone()));
183 }
184 if let Some(arg) = &flag.arg {
185 let children = node.children_mut().get_or_insert_with(KdlDocument::new);
186 children.nodes_mut().push(arg.into());
187 }
188 node
189 }
190}
191
192impl FromStr for SpecFlag {
193 type Err = UsageErr;
194 fn from_str(input: &str) -> Result<Self> {
195 let mut flag = Self::default();
196 let input = input.replace("...", " ... ");
197 for part in input.split_whitespace() {
198 if let Some(name) = part.strip_suffix(':') {
199 flag.name = name.to_string();
200 } else if let Some(long) = part.strip_prefix("--") {
201 flag.long.push(long.to_string());
202 } else if let Some(short) = part.strip_prefix('-') {
203 if short.len() != 1 {
204 return Err(InvalidFlag(
205 short.to_string(),
206 (0, input.len()).into(),
207 input.to_string(),
208 ));
209 }
210 flag.short.push(short.chars().next().unwrap());
211 } else if part == "..." {
212 if let Some(arg) = &mut flag.arg {
213 arg.var = true;
214 } else {
215 flag.var = true;
216 }
217 } else if part.starts_with('<') && part.ends_with('>')
218 || part.starts_with('[') && part.ends_with(']')
219 {
220 flag.arg = Some(part.to_string().parse()?);
221 } else {
222 return Err(InvalidFlag(
223 part.to_string(),
224 (0, input.len()).into(),
225 input.to_string(),
226 ));
227 }
228 }
229 if flag.name.is_empty() {
230 flag.name = get_name_from_short_and_long(&flag.short, &flag.long).unwrap_or_default();
231 }
232 flag.usage = flag.usage();
233 Ok(flag)
234 }
235}
236
#[cfg(feature = "clap")]
impl From<&clap::Arg> for SpecFlag {
    /// Converts a clap argument definition into a [`SpecFlag`].
    ///
    /// Conversion rules:
    /// - The name is derived from the first long, then the first short, name.
    /// - `Set` and `Append` actions produce a value argument, named after clap's value names
    ///   or else the flag name. It carries clap's possible values (and their aliases) as choices.
    /// - `Count` and `Append` actions mark the flag as repeatable.
    /// - Only the first default value is kept.
    fn from(c: &clap::Arg) -> Self {
        let required = c.is_required_set();
        let help = c.get_help().map(|s| s.to_string());
        let help_long = c.get_long_help().map(|s| s.to_string());
        let help_first_line = help.as_ref().map(|s| string::first_line(s));
        let hide = c.is_hide_set();
        let var = matches!(
            c.get_action(),
            clap::ArgAction::Count | clap::ArgAction::Append
        );
        let default = c
            .get_default_values()
            .first()
            .map(|s| s.to_string_lossy().to_string());
        let short = c.get_short_and_visible_aliases().unwrap_or_default();
        let long = c
            .get_long_and_visible_aliases()
            .unwrap_or_default()
            .into_iter()
            .map(|s| s.to_string())
            .collect::<Vec<_>>();
        let name = get_name_from_short_and_long(&short, &long).unwrap_or_default();
        let arg = if matches!(
            c.get_action(),
            clap::ArgAction::Set | clap::ArgAction::Append
        ) {
            let mut arg = SpecArg::from(
                c.get_value_names()
                    .map(|s| s.iter().map(|s| s.to_string()).join(" "))
                    .unwrap_or_else(|| name.clone())
                    .as_str(),
            );

            let choices = c
                .get_possible_values()
                .iter()
                .flat_map(|v| v.get_name_and_aliases().map(|s| s.to_string()))
                .collect::<Vec<_>>();
            if !choices.is_empty() {
                arg.choices = Some(SpecChoices { choices });
            }

            Some(arg)
        } else {
            None
        };
        let mut flag = Self {
            name,
            usage: String::new(),
            short,
            long,
            required,
            help,
            help_long,
            help_md: None,
            help_first_line,
            var,
            hide,
            global: c.is_global_set(),
            arg,
            count: matches!(c.get_action(), clap::ArgAction::Count),
            default,
            deprecated: None,
            negate: None,
        };
        // Populate the cached usage string, as `parse` and `FromStr` do.
        flag.usage = flag.usage();
        flag
    }
}
303
304impl Display for SpecFlag {
351 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
352 write!(f, "{}", self.usage())
353 }
354}
355impl PartialEq for SpecFlag {
356 fn eq(&self, other: &Self) -> bool {
357 self.name == other.name
358 }
359}
360impl Eq for SpecFlag {}
361impl Hash for SpecFlag {
362 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
363 self.name.hash(state);
364 }
365}
366
/// Derives a flag name from its names: the first long name if any, otherwise the first
/// short name. Returns `None` when the flag has neither.
fn get_name_from_short_and_long(short: &[char], long: &[String]) -> Option<String> {
    match (long.first(), short.first()) {
        (Some(first_long), _) => Some(first_long.clone()),
        (None, Some(first_short)) => Some(first_short.to_string()),
        (None, None) => None,
    }
}
372
#[cfg(test)]
mod tests {
    use super::*;
    use insta::assert_snapshot;

    /// Parses a flag declaration, panicking on invalid input (test helper).
    fn flag(decl: &str) -> SpecFlag {
        decl.parse().unwrap()
    }

    #[test]
    fn from_str() {
        // Short and long names, alone and combined.
        assert_snapshot!(flag("-f"), @"-f");
        assert_snapshot!(flag("--flag"), @"--flag");
        assert_snapshot!(flag("-f --flag"), @"-f --flag");
        // Repeatable flag, with `...` attached or separate.
        assert_snapshot!(flag("-f --flag..."), @"-f --flag...");
        assert_snapshot!(flag("-f --flag ..."), @"-f --flag...");
        // Value arguments, including a variadic argument.
        assert_snapshot!(flag("--flag <arg>"), @"--flag <arg>");
        assert_snapshot!(flag("-f --flag <arg>"), @"-f --flag <arg>");
        assert_snapshot!(flag("-f --flag... <arg>"), @"-f --flag... <arg>");
        assert_snapshot!(flag("-f --flag <arg>..."), @"-f --flag <arg>...");
        // Explicit names.
        assert_snapshot!(flag("myflag: -f"), @"myflag: -f");
        assert_snapshot!(flag("myflag: -f --flag <arg>"), @"myflag: -f --flag <arg>");
    }
}