1use std::io::Write;
2
3use crate::{entry, extension, write::util::CountBytes, State, Version};
4
/// Selects which index extensions may be emitted by `State::write_to()`.
///
/// Note that the sparse-directory extension is not governed by this setting: it is written
/// whenever the index is sparse.
#[derive(Clone, Copy, Debug, Default)]
pub enum Extensions {
    /// Approve every extension signature (the default).
    #[default]
    All,
    /// Approve only the extensions whose flag is `true`; any other signature is rejected.
    Given {
        /// Whether the tree-cache extension (`extension::tree`) may be written.
        tree_cache: bool,
        /// Whether the end-of-index-entry extension (`extension::end_of_index_entry`) may be written.
        end_of_index_entry: bool,
    },
    /// Reject every extension signature.
    None,
}
26
27impl Extensions {
28 pub fn should_write(&self, signature: extension::Signature) -> Option<extension::Signature> {
30 match self {
31 Extensions::None => None,
32 Extensions::All => Some(signature),
33 Extensions::Given {
34 tree_cache,
35 end_of_index_entry,
36 } => match signature {
37 extension::tree::SIGNATURE => tree_cache,
38 extension::end_of_index_entry::SIGNATURE => end_of_index_entry,
39 _ => &false,
40 }
41 .then(|| signature),
42 }
43 }
44}
45
46#[derive(Debug, Default, Clone, Copy)]
50pub struct Options {
51 pub extensions: Extensions,
53 pub skip_hash: bool,
59}
60
61impl State {
62 pub fn write_to(
64 &self,
65 out: impl std::io::Write,
66 Options {
67 extensions,
68 skip_hash: _,
69 }: Options,
70 ) -> std::io::Result<Version> {
71 let _span = gix_features::trace::detail!("gix_index::State::write()");
72 let version = self.detect_required_version();
73
74 let mut write = CountBytes::new(out);
75 let num_entries: u32 = self
76 .entries()
77 .len()
78 .try_into()
79 .expect("definitely not 4billion entries");
80 let removed_entries: u32 = self
81 .entries()
82 .iter()
83 .filter(|e| e.flags.contains(entry::Flags::REMOVE))
84 .count()
85 .try_into()
86 .expect("definitely not too many entries");
87
88 let offset_to_entries = header(&mut write, version, num_entries - removed_entries)?;
89 let offset_to_extensions = entries(&mut write, self, offset_to_entries)?;
90 let (extension_toc, out) = self.write_extensions(write, offset_to_extensions, extensions)?;
91
92 if num_entries > 0
93 && extensions
94 .should_write(extension::end_of_index_entry::SIGNATURE)
95 .is_some()
96 && !extension_toc.is_empty()
97 {
98 extension::end_of_index_entry::write_to(out, self.object_hash, offset_to_extensions, extension_toc)?;
99 }
100
101 Ok(version)
102 }
103
104 fn write_extensions<T>(
105 &self,
106 mut write: CountBytes<T>,
107 offset_to_extensions: u32,
108 extensions: Extensions,
109 ) -> std::io::Result<(Vec<(extension::Signature, u32)>, T)>
110 where
111 T: std::io::Write,
112 {
113 type WriteExtFn<'a> = &'a dyn Fn(&mut dyn std::io::Write) -> Option<std::io::Result<extension::Signature>>;
114 let extensions: &[WriteExtFn<'_>] = &[
115 &|write| {
116 extensions
117 .should_write(extension::tree::SIGNATURE)
118 .and_then(|signature| self.tree().map(|tree| tree.write_to(write).map(|_| signature)))
119 },
120 &|write| {
121 self.is_sparse()
122 .then(|| extension::sparse::write_to(write).map(|_| extension::sparse::SIGNATURE))
123 },
124 ];
125
126 let mut offset_to_previous_ext = offset_to_extensions;
127 let mut out = Vec::with_capacity(5);
128 for write_ext in extensions {
129 if let Some(signature) = write_ext(&mut write).transpose()? {
130 let offset_past_ext = write.count;
131 let ext_size = offset_past_ext - offset_to_previous_ext - (extension::MIN_SIZE as u32);
132 offset_to_previous_ext = offset_past_ext;
133 out.push((signature, ext_size));
134 }
135 }
136 Ok((out, write.inner))
137 }
138}
139
140impl State {
141 fn detect_required_version(&self) -> Version {
142 self.entries
143 .iter()
144 .find_map(|e| e.flags.contains(entry::Flags::EXTENDED).then_some(Version::V3))
145 .unwrap_or(Version::V2)
146 }
147}
148
149fn header<T: std::io::Write>(
150 out: &mut CountBytes<T>,
151 version: Version,
152 num_entries: u32,
153) -> Result<u32, std::io::Error> {
154 let version = match version {
155 Version::V2 => 2_u32.to_be_bytes(),
156 Version::V3 => 3_u32.to_be_bytes(),
157 Version::V4 => 4_u32.to_be_bytes(),
158 };
159
160 out.write_all(crate::decode::header::SIGNATURE)?;
161 out.write_all(&version)?;
162 out.write_all(&num_entries.to_be_bytes())?;
163
164 Ok(out.count)
165}
166
167fn entries<T: std::io::Write>(out: &mut CountBytes<T>, state: &State, header_size: u32) -> Result<u32, std::io::Error> {
168 for entry in state.entries() {
169 if entry.flags.contains(entry::Flags::REMOVE) {
170 continue;
171 }
172 entry.write_to(&mut *out, state)?;
173 match (out.count - header_size) % 8 {
174 0 => {}
175 n => {
176 let eight_null_bytes = [0u8; 8];
177 out.write_all(&eight_null_bytes[n as usize..])?;
178 }
179 };
180 }
181
182 Ok(out.count)
183}
184
mod util {
    /// A writer adapter that forwards to `inner` and counts the bytes `inner` accepted.
    ///
    /// The counter is a `u32`, so at most `u32::MAX` bytes can pass through. Any attempt to
    /// write beyond that limit fails with an I/O error rather than panicking or overflowing.
    pub struct CountBytes<T> {
        /// Number of bytes `inner` has accepted so far.
        pub count: u32,
        /// The wrapped writer.
        pub inner: T,
    }

    impl<T> CountBytes<T>
    where
        T: std::io::Write,
    {
        /// Wrap `inner`, starting the byte count at zero.
        pub fn new(inner: T) -> Self {
            CountBytes { inner, count: 0 }
        }
    }

    /// The error reported once the byte counter cannot represent any more data.
    fn index_too_large() -> std::io::Error {
        std::io::Error::new(
            std::io::ErrorKind::Other,
            "Cannot write indices larger than 4 gigabytes",
        )
    }

    impl<T> std::io::Write for CountBytes<T>
    where
        T: std::io::Write,
    {
        /// Forward at most as many bytes as the counter can still represent.
        ///
        /// A buffer that would cross the limit is truncated to a short write, which `write_all`
        /// retries; that retry then fails with an error. This way `count` always matches what
        /// `inner` actually received. A buffer longer than `u32::MAX` cannot cause a panic.
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            let capacity = usize::try_from(u32::MAX - self.count).unwrap_or(usize::MAX);
            if capacity == 0 && !buf.is_empty() {
                return Err(index_too_large());
            }
            let chunk = &buf[..buf.len().min(capacity)];
            let written = self.inner.write(chunk)?;
            // A well-behaved writer never reports more than `chunk.len()`; stay defensive anyway.
            self.count = u32::try_from(written)
                .ok()
                .and_then(|written| self.count.checked_add(written))
                .ok_or_else(index_too_large)?;
            Ok(written)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            self.inner.flush()
        }
    }
}