usage/spec/
flag.rs

1use itertools::Itertools;
2use kdl::{KdlDocument, KdlEntry, KdlNode};
3use serde::Serialize;
4use std::fmt::Display;
5use std::hash::Hash;
6use std::str::FromStr;
7
8use crate::error::UsageErr::InvalidFlag;
9use crate::error::{Result, UsageErr};
10use crate::spec::context::ParsingContext;
11use crate::spec::helpers::NodeHelper;
12use crate::spec::is_false;
13use crate::{string, SpecArg, SpecChoices};
14
/// A single flag (option) definition of a usage spec.
///
/// Instances are produced by `SpecFlag::parse` (from a KDL node), by `FromStr`
/// (from a compact string such as `"-f --flag <arg>"`) and, with the `clap`
/// feature enabled, from a `clap::Arg`. Equality and hashing only look at
/// `name`.
#[derive(Debug, Default, Clone, Serialize)]
pub struct SpecFlag {
    /// Identifier of the flag. When no explicit `name:` prefix is given it is
    /// the first long name, or else the first short name.
    pub name: String,
    /// Rendered usage string (see `SpecFlag::usage`); populated by `parse`
    /// and `from_str`.
    pub usage: String,
    /// Short help text (`help` key).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub help: Option<String>,
    /// Long help text (`long_help` / `help_long` key).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub help_long: Option<String>,
    /// Help text from the `help_md` key.
    /// NOTE(review): presumably markdown-formatted help — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub help_md: Option<String>,
    /// First line of `help`, computed with `string::first_line`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub help_first_line: Option<String>,
    /// Single-character names (`-f`), in declaration order.
    pub short: Vec<char>,
    /// Long names (`--flag`), stored without the leading dashes.
    pub long: Vec<String>,
    /// Value of the `required` key; `parse` clears it when a `default`
    /// property is given.
    #[serde(skip_serializing_if = "is_false")]
    pub required: bool,
    /// Deprecation message. Holds the literal `"deprecated"` when the spec
    /// only says `deprecated` = `true`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub deprecated: Option<String>,
    /// Set by the `var` key or by a `...` that follows the flag names (before
    /// any `<arg>`) in the compact string.
    #[serde(skip_serializing_if = "is_false")]
    pub var: bool,
    /// Value of the `hide` key.
    /// NOTE(review): presumably hides the flag from generated help — confirm.
    pub hide: bool,
    /// Value of the `global` key.
    /// NOTE(review): presumably makes the flag available to subcommands — confirm.
    pub global: bool,
    /// Value of the `count` key; also set for clap's `ArgAction::Count`.
    #[serde(skip_serializing_if = "is_false")]
    pub count: bool,
    /// The value the flag takes, if any (`<arg>` or `[arg]` in the compact
    /// string, or an `arg` child node).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arg: Option<SpecArg>,
    /// Value of the `default` key.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default: Option<String>,
    /// Value of the `negate` key.
    /// NOTE(review): presumably the spelling that negates this flag
    /// (e.g. `--no-foo`) — confirm against the parser.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub negate: Option<String>,
}
46
47impl SpecFlag {
48    pub(crate) fn parse(ctx: &ParsingContext, node: &NodeHelper) -> Result<Self> {
49        let mut flag: Self = node.arg(0)?.ensure_string()?.parse()?;
50        for (k, v) in node.props() {
51            match k {
52                "help" => flag.help = Some(v.ensure_string()?),
53                "long_help" => flag.help_long = Some(v.ensure_string()?),
54                "help_long" => flag.help_long = Some(v.ensure_string()?),
55                "help_md" => flag.help_md = Some(v.ensure_string()?),
56                "required" => flag.required = v.ensure_bool()?,
57                "var" => flag.var = v.ensure_bool()?,
58                "hide" => flag.hide = v.ensure_bool()?,
59                "deprecated" => {
60                    flag.deprecated = match v.value.as_bool() {
61                        Some(true) => Some("deprecated".into()),
62                        Some(false) => None,
63                        None => Some(v.ensure_string()?),
64                    }
65                }
66                "global" => flag.global = v.ensure_bool()?,
67                "count" => flag.count = v.ensure_bool()?,
68                "default" => flag.default = v.ensure_string().map(Some)?,
69                "negate" => flag.negate = v.ensure_string().map(Some)?,
70                k => bail_parse!(ctx, v.entry.span(), "unsupported flag key {k}"),
71            }
72        }
73        if flag.default.is_some() {
74            flag.required = false;
75        }
76        for child in node.children() {
77            match child.name() {
78                "arg" => flag.arg = Some(SpecArg::parse(ctx, &child)?),
79                "help" => flag.help = Some(child.arg(0)?.ensure_string()?),
80                "long_help" => flag.help_long = Some(child.arg(0)?.ensure_string()?),
81                "help_long" => flag.help_long = Some(child.arg(0)?.ensure_string()?),
82                "help_md" => flag.help_md = Some(child.arg(0)?.ensure_string()?),
83                "required" => flag.required = child.arg(0)?.ensure_bool()?,
84                "var" => flag.var = child.arg(0)?.ensure_bool()?,
85                "hide" => flag.hide = child.arg(0)?.ensure_bool()?,
86                "deprecated" => {
87                    flag.deprecated = match child.arg(0)?.ensure_bool() {
88                        Ok(true) => Some("deprecated".into()),
89                        Ok(false) => None,
90                        _ => Some(child.arg(0)?.ensure_string()?),
91                    }
92                }
93                "global" => flag.global = child.arg(0)?.ensure_bool()?,
94                "count" => flag.count = child.arg(0)?.ensure_bool()?,
95                "default" => flag.default = child.arg(0)?.ensure_string().map(Some)?,
96                "choices" => {
97                    if let Some(arg) = &mut flag.arg {
98                        arg.choices = Some(SpecChoices::parse(ctx, &child)?);
99                    } else {
100                        bail_parse!(
101                            ctx,
102                            child.node.name().span(),
103                            "flag must have value to have choices"
104                        )
105                    }
106                }
107                k => bail_parse!(ctx, child.node.name().span(), "unsupported flag child {k}"),
108            }
109        }
110        flag.usage = flag.usage();
111        flag.help_first_line = flag.help.as_ref().map(|s| string::first_line(s));
112        Ok(flag)
113    }
114    pub fn usage(&self) -> String {
115        let mut parts = vec![];
116        let name = get_name_from_short_and_long(&self.short, &self.long).unwrap_or_default();
117        if name != self.name {
118            parts.push(format!("{}:", self.name));
119        }
120        if let Some(short) = self.short.first() {
121            parts.push(format!("-{}", short));
122        }
123        if let Some(long) = self.long.first() {
124            parts.push(format!("--{}", long));
125        }
126        let mut out = parts.join(" ");
127        if self.var {
128            out = format!("{}...", out);
129        }
130        if let Some(arg) = &self.arg {
131            out = format!("{} {}", out, arg.usage());
132        }
133        out
134    }
135}
136
137impl From<&SpecFlag> for KdlNode {
138    fn from(flag: &SpecFlag) -> KdlNode {
139        let mut node = KdlNode::new("flag");
140        let name = flag
141            .short
142            .iter()
143            .map(|c| format!("-{c}"))
144            .chain(flag.long.iter().map(|s| format!("--{s}")))
145            .collect_vec()
146            .join(" ");
147        node.push(KdlEntry::new(name));
148        if let Some(desc) = &flag.help {
149            node.push(KdlEntry::new_prop("help", desc.clone()));
150        }
151        if let Some(desc) = &flag.help_long {
152            let children = node.children_mut().get_or_insert_with(KdlDocument::new);
153            let mut node = KdlNode::new("long_help");
154            node.entries_mut().push(KdlEntry::new(desc.clone()));
155            children.nodes_mut().push(node);
156        }
157        if let Some(desc) = &flag.help_md {
158            let children = node.children_mut().get_or_insert_with(KdlDocument::new);
159            let mut node = KdlNode::new("help_md");
160            node.entries_mut().push(KdlEntry::new(desc.clone()));
161            children.nodes_mut().push(node);
162        }
163        if flag.required {
164            node.push(KdlEntry::new_prop("required", true));
165        }
166        if flag.var {
167            node.push(KdlEntry::new_prop("var", true));
168        }
169        if flag.hide {
170            node.push(KdlEntry::new_prop("hide", true));
171        }
172        if flag.global {
173            node.push(KdlEntry::new_prop("global", true));
174        }
175        if flag.count {
176            node.push(KdlEntry::new_prop("count", true));
177        }
178        if let Some(negate) = &flag.negate {
179            node.push(KdlEntry::new_prop("negate", negate.clone()));
180        }
181        if let Some(deprecated) = &flag.deprecated {
182            node.push(KdlEntry::new_prop("deprecated", deprecated.clone()));
183        }
184        if let Some(arg) = &flag.arg {
185            let children = node.children_mut().get_or_insert_with(KdlDocument::new);
186            children.nodes_mut().push(arg.into());
187        }
188        node
189    }
190}
191
192impl FromStr for SpecFlag {
193    type Err = UsageErr;
194    fn from_str(input: &str) -> Result<Self> {
195        let mut flag = Self::default();
196        let input = input.replace("...", " ... ");
197        for part in input.split_whitespace() {
198            if let Some(name) = part.strip_suffix(':') {
199                flag.name = name.to_string();
200            } else if let Some(long) = part.strip_prefix("--") {
201                flag.long.push(long.to_string());
202            } else if let Some(short) = part.strip_prefix('-') {
203                if short.len() != 1 {
204                    return Err(InvalidFlag(
205                        short.to_string(),
206                        (0, input.len()).into(),
207                        input.to_string(),
208                    ));
209                }
210                flag.short.push(short.chars().next().unwrap());
211            } else if part == "..." {
212                if let Some(arg) = &mut flag.arg {
213                    arg.var = true;
214                } else {
215                    flag.var = true;
216                }
217            } else if part.starts_with('<') && part.ends_with('>')
218                || part.starts_with('[') && part.ends_with(']')
219            {
220                flag.arg = Some(part.to_string().parse()?);
221            } else {
222                return Err(InvalidFlag(
223                    part.to_string(),
224                    (0, input.len()).into(),
225                    input.to_string(),
226                ));
227            }
228        }
229        if flag.name.is_empty() {
230            flag.name = get_name_from_short_and_long(&flag.short, &flag.long).unwrap_or_default();
231        }
232        flag.usage = flag.usage();
233        Ok(flag)
234    }
235}
236
#[cfg(feature = "clap")]
impl From<&clap::Arg> for SpecFlag {
    /// Converts a clap argument definition into a `SpecFlag`.
    ///
    /// * Short/long names include clap's visible aliases.
    /// * `ArgAction::Count` sets both `count` and `var`; `ArgAction::Append`
    ///   sets `var`.
    /// * `ArgAction::Set` and `ArgAction::Append` produce an `arg`, named
    ///   after clap's value names (joined by spaces) or, if there are none,
    ///   after the flag itself. Possible values, including their aliases,
    ///   become the arg's `choices`.
    /// * Only the first clap default value is kept.
    /// * `usage` is rendered from the finished flag, as in `parse`/`from_str`.
    fn from(c: &clap::Arg) -> Self {
        let action = c.get_action();
        let count = matches!(action, clap::ArgAction::Count);
        let var = count || matches!(action, clap::ArgAction::Append);
        let takes_value = matches!(action, clap::ArgAction::Set | clap::ArgAction::Append);

        let help = c.get_help().map(|s| s.to_string());
        let help_long = c.get_long_help().map(|s| s.to_string());
        let help_first_line = help.as_ref().map(|s| string::first_line(s));
        let default = c
            .get_default_values()
            .first()
            .map(|s| s.to_string_lossy().to_string());

        let short = c.get_short_and_visible_aliases().unwrap_or_default();
        let long = c
            .get_long_and_visible_aliases()
            .unwrap_or_default()
            .into_iter()
            .map(|s| s.to_string())
            .collect::<Vec<_>>();
        let name = get_name_from_short_and_long(&short, &long).unwrap_or_default();

        let arg = if takes_value {
            // Only fall back to (and clone) the flag name when clap has no value names.
            let value_name = c
                .get_value_names()
                .map(|names| names.iter().map(|n| n.to_string()).join(" "))
                .unwrap_or_else(|| name.clone());
            let mut arg = SpecArg::from(value_name.as_str());

            let choices = c
                .get_possible_values()
                .iter()
                .flat_map(|v| v.get_name_and_aliases().map(|s| s.to_string()))
                .collect::<Vec<_>>();
            if !choices.is_empty() {
                arg.choices = Some(SpecChoices { choices });
            }

            Some(arg)
        } else {
            None
        };

        let mut flag = Self {
            name,
            usage: String::new(),
            short,
            long,
            required: c.is_required_set(),
            help,
            help_long,
            help_md: None,
            help_first_line,
            var,
            hide: c.is_hide_set(),
            global: c.is_global_set(),
            arg,
            count,
            default,
            deprecated: None,
            negate: None,
        };
        // Previously left empty; every other constructor fills this in.
        flag.usage = flag.usage();
        flag
    }
}
303
304// #[cfg(feature = "clap")]
305// impl From<&SpecFlag> for clap::Arg {
306//     fn from(flag: &SpecFlag) -> Self {
307//         let mut a = clap::Arg::new(&flag.name);
308//         if let Some(desc) = &flag.help {
309//             a = a.help(desc);
310//         }
311//         if flag.required {
312//             a = a.required(true);
313//         }
314//         if let Some(arg) = &flag.arg {
315//             a = a.value_name(&arg.name);
316//             if arg.var {
317//                 a = a.action(clap::ArgAction::Append)
318//             } else {
319//                 a = a.action(clap::ArgAction::Set)
320//             }
321//         } else {
322//             a = a.action(clap::ArgAction::SetTrue)
323//         }
324//         // let mut a = clap::Arg::new(&flag.name)
325//         //     .required(flag.required)
326//         //     .action(clap::ArgAction::SetTrue);
327//         if let Some(short) = flag.short.first() {
328//             a = a.short(*short);
329//         }
330//         if let Some(long) = flag.long.first() {
331//             a = a.long(long);
332//         }
333//         for short in flag.short.iter().skip(1) {
334//             a = a.visible_short_alias(*short);
335//         }
336//         for long in flag.long.iter().skip(1) {
337//             a = a.visible_alias(long);
338//         }
339//         // cmd = cmd.arg(a);
340//         // if flag.multiple {
341//         //     a = a.multiple(true);
342//         // }
343//         // if flag.hide {
344//         //     a = a.hide_possible_values(true);
345//         // }
346//         a
347//     }
348// }
349
350impl Display for SpecFlag {
351    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
352        write!(f, "{}", self.usage())
353    }
354}
355impl PartialEq for SpecFlag {
356    fn eq(&self, other: &Self) -> bool {
357        self.name == other.name
358    }
359}
// Marker impl: the name-based `PartialEq` above is a full equivalence relation.
impl Eq for SpecFlag {}
361impl Hash for SpecFlag {
362    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
363        self.name.hash(state);
364    }
365}
366
/// Derives a flag's default name: the first long name if there is one,
/// otherwise the first short name, otherwise `None`.
fn get_name_from_short_and_long(short: &[char], long: &[String]) -> Option<String> {
    match (long.first(), short.first()) {
        (Some(first_long), _) => Some(first_long.clone()),
        (None, Some(first_short)) => Some(first_short.to_string()),
        (None, None) => None,
    }
}
372
#[cfg(test)]
mod tests {
    use super::*;
    use insta::assert_snapshot;

    /// Parses compact flag strings with `FromStr` and pins what `Display`
    /// (i.e. `SpecFlag::usage`) renders for each, using inline snapshots.
    #[test]
    fn from_str() {
        // Short only, long only, and both together.
        assert_snapshot!("-f".parse::<SpecFlag>().unwrap(), @"-f");
        assert_snapshot!("--flag".parse::<SpecFlag>().unwrap(), @"--flag");
        assert_snapshot!("-f --flag".parse::<SpecFlag>().unwrap(), @"-f --flag");
        // `...` before any arg belongs to the flag, with or without a separating space.
        assert_snapshot!("-f --flag...".parse::<SpecFlag>().unwrap(), @"-f --flag...");
        assert_snapshot!("-f --flag ...".parse::<SpecFlag>().unwrap(), @"-f --flag...");
        // Flags that take a value.
        assert_snapshot!("--flag <arg>".parse::<SpecFlag>().unwrap(), @"--flag <arg>");
        assert_snapshot!("-f --flag <arg>".parse::<SpecFlag>().unwrap(), @"-f --flag <arg>");
        // `...` before the arg stays on the flag; `...` after it is rendered on the arg.
        assert_snapshot!("-f --flag... <arg>".parse::<SpecFlag>().unwrap(), @"-f --flag... <arg>");
        assert_snapshot!("-f --flag <arg>...".parse::<SpecFlag>().unwrap(), @"-f --flag <arg>...");
        // An explicit `name:` prefix is kept when it differs from the derived name.
        assert_snapshot!("myflag: -f".parse::<SpecFlag>().unwrap(), @"myflag: -f");
        assert_snapshot!("myflag: -f --flag <arg>".parse::<SpecFlag>().unwrap(), @"myflag: -f --flag <arg>");
    }
}