use aes::cipher::Unsigned;
use hmac::Mac;
use hmac::digest::generic_array::{ArrayLength, GenericArray};
use sha2::digest::{CtOutput, FixedOutput, MacError, Output};
10
11#[derive(Clone)]
12pub struct Incremental<M: Mac + Clone> {
13 mac: M,
14 chunk_size: usize,
15 unused_length: usize,
16}
17
18#[derive(Clone)]
19pub struct Validating<M: Mac + Clone> {
20 incremental: Incremental<M>,
21 expected: Vec<Output<M>>,
23}
24
/// Lower bound (64 KiB) on the chunk size chosen by [`calculate_chunk_size`].
const MINIMUM_CHUNK_SIZE: usize = 64 << 10;
/// Upper bound (2 MiB) on the chunk size chosen by [`calculate_chunk_size`].
const MAXIMUM_CHUNK_SIZE: usize = 2 << 20;
/// Desired combined size (8 KiB) of all chunk digests for one input.
const TARGET_TOTAL_DIGEST_SIZE: usize = 8 << 10;
28
29pub const fn calculate_chunk_size<D>(data_size: usize) -> usize
30where
31 D: FixedOutput,
32 D::OutputSize: ArrayLength<u8>,
33{
34 assert!(
35 0 == TARGET_TOTAL_DIGEST_SIZE % D::OutputSize::USIZE,
36 "Target digest size should be a multiple of digest size"
37 );
38 let target_chunk_count = TARGET_TOTAL_DIGEST_SIZE / D::OutputSize::USIZE;
39 if data_size < target_chunk_count * MINIMUM_CHUNK_SIZE {
40 return MINIMUM_CHUNK_SIZE;
41 }
42 if data_size < target_chunk_count * MAXIMUM_CHUNK_SIZE {
43 return data_size.div_ceil(target_chunk_count);
44 }
45 MAXIMUM_CHUNK_SIZE
46}
47
48impl<M: Mac + Clone> Incremental<M> {
49 pub fn new(mac: M, chunk_size: usize) -> Self {
50 assert!(chunk_size > 0, "chunk size must be positive");
51 Self {
52 mac,
53 chunk_size,
54 unused_length: chunk_size,
55 }
56 }
57
58 pub fn validating<'a, A, I>(self, macs: I) -> Validating<M>
59 where
60 A: Into<&'a GenericArray<u8, M::OutputSize>> + std::ops::Deref<Target: Sized>,
65 I: IntoIterator<Item = A, IntoIter: DoubleEndedIterator>,
66 {
67 let expected = macs
68 .into_iter()
69 .map(|mac| mac.into().to_owned())
70 .rev()
71 .collect();
72 Validating {
73 incremental: self,
74 expected,
75 }
76 }
77
78 pub fn update<'a>(&'a mut self, bytes: &'a [u8]) -> impl Iterator<Item = Output<M>> + 'a {
79 let split_point = std::cmp::min(bytes.len(), self.unused_length);
80 let (to_write, overflow) = bytes.split_at(split_point);
81
82 std::iter::once(to_write)
83 .chain(overflow.chunks(self.chunk_size))
84 .flat_map(move |chunk| self.update_chunk(chunk))
85 }
86
87 pub fn finalize(self) -> Output<M> {
88 self.mac.finalize().into_bytes()
89 }
90
91 fn update_chunk(&mut self, bytes: &[u8]) -> Option<Output<M>> {
92 assert!(bytes.len() <= self.unused_length);
93 self.mac.update(bytes);
94 self.unused_length -= bytes.len();
95 if self.unused_length == 0 {
96 self.unused_length = self.chunk_size;
97 let mac = self.mac.clone();
98 Some(mac.finalize().into_bytes())
99 } else {
100 None
101 }
102 }
103
104 fn pending_bytes_size(&self) -> usize {
105 self.chunk_size - self.unused_length
106 }
107}
108
109impl<M: Mac + Clone> Validating<M> {
110 pub fn update(&mut self, bytes: &[u8]) -> Result<usize, MacError> {
111 let mut result = Ok(0);
112 let macs = self.incremental.update(bytes);
113
114 let mut whole_chunks = 0;
115 for mac in macs {
116 match self.expected.last() {
117 Some(expected) if expected == &mac => {
118 whole_chunks += 1;
119 self.expected.pop();
120 }
121 _ => {
122 result = Err(MacError);
123 }
124 }
125 }
126 let validated_bytes = whole_chunks * self.incremental.chunk_size;
127 result.map(|_| validated_bytes)
128 }
129
130 pub fn finalize(self) -> Result<usize, MacError> {
131 let pending_bytes_size = self.incremental.pending_bytes_size();
132 let mac = self.incremental.finalize();
133 match &self.expected[..] {
134 [expected] if expected == &mac => Ok(pending_bytes_size),
135 _ => Err(MacError),
136 }
137 }
138}
139
#[cfg(test)]
mod test {
    use const_str::hex;
    use hmac::Hmac;
    use proptest::prelude::*;
    use rand::distr::uniform::{UniformSampler as _, UniformUsize};
    use rand::prelude::{Rng, ThreadRng};
    use sha2::Sha256;
    use sha2::digest::OutputSizeUser;

    use super::*;
    use crate::crypto::hmac_sha256;

    // 32-byte HMAC key (64 hex digits) shared by all tests.
    const TEST_HMAC_KEY: &[u8] =
        &hex!("a83481457efecc69ad1342e21d9c0297f71debbf5c9304b4c1b2e433c1a78f98");

    // Deliberately tiny chunk size, so that short test inputs already span several chunks.
    const TEST_CHUNK_SIZE: usize = 32;

    /// Creates an HMAC-SHA256 `Incremental` with the given key and chunk size.
    fn new_incremental(key: &[u8], chunk_size: usize) -> Incremental<Hmac<Sha256>> {
        let hmac = Hmac::<Sha256>::new_from_slice(key)
            .expect("Should be able to create a new HMAC instance");
        Incremental::new(hmac, chunk_size)
    }

    /// `Incremental::new` rejects a zero chunk size.
    #[test]
    #[should_panic]
    fn chunk_size_zero() {
        new_incremental(&[], 0);
    }

    /// The final incremental MAC equals a one-shot HMAC over the whole input.
    #[test]
    fn simple_test() {
        let key = TEST_HMAC_KEY;
        let input = "this is a simple test input string which is longer than the chunk";

        let bytes = input.as_bytes();
        let expected = hmac_sha256(key, bytes);
        let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);
        // The intermediate MACs are not needed here. However, `update` is lazy, so the
        // iterator must be drained for the bytes to reach the MAC.
        let _ = incremental.update(bytes).collect::<Vec<_>>();
        let digest = incremental.finalize();
        let actual: [u8; 32] = digest.into();
        assert_eq!(actual, expected);
    }

    /// Property: for strings of up to 100 characters, the final incremental MAC equals
    /// the one-shot HMAC.
    #[test]
    fn final_result_should_be_equal_to_non_incremental_hmac() {
        let key = TEST_HMAC_KEY;
        proptest!(|(input in ".{0,100}")| {
            let bytes = input.as_bytes();
            let expected = hmac_sha256(key, bytes);
            let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);
            let _ = incremental.update(bytes).collect::<Vec<_>>();
            let actual: [u8; 32] = incremental.finalize().into();
            assert_eq!(actual, expected);
        });
    }

    /// Property: when the input is fed in randomly sized pieces, the intermediate MACs
    /// are the one-shot HMACs of each chunk-aligned prefix.
    #[test]
    fn incremental_macs_are_valid() {
        let key = TEST_HMAC_KEY;

        proptest!(|(input in ".{50,100}")| {
            let bytes = input.as_bytes();
            let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);

            // Reference: the one-shot HMAC of every prefix ending at a chunk boundary.
            // If the length is not a multiple of the chunk size, it also includes the
            // HMAC of the whole input.
            let expected: Vec<_> = bytes
                .chunks(incremental.chunk_size)
                .scan(Vec::new(), |acc, chunk| {
                    acc.extend(chunk.iter());
                    Some(hmac_sha256(key, acc).to_vec())
                })
                .collect();

            // Feed the input in random-length pieces of 0..=chunk_size bytes.
            let mut actual: Vec<Vec<u8>> = bytes
                .random_chunks(incremental.chunk_size)
                .flat_map(|chunk| incremental.update(chunk).collect::<Vec<_>>())
                .map(|out| out.into())
                .map(|bs: [u8; 32]| bs.to_vec())
                .collect();
            if bytes.len() % incremental.chunk_size != 0 {
                // A trailing partial chunk emits no intermediate MAC, so its MAC comes
                // from `finalize`.
                let last_hmac: [u8; 32] = incremental.finalize().into();
                actual.push(last_hmac.to_vec());
            }
            assert_eq!(actual, expected);
        });
    }

    /// `Validating` accepts the exact MAC sequence produced by `Incremental`. It rejects
    /// corrupted, truncated or shifted sequences, each at the expected step.
    #[test]
    fn validating_simple_test() {
        let key = TEST_HMAC_KEY;
        let input = "this is a simple test input string";

        let bytes = input.as_bytes();
        let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);
        // Reference sequence: the intermediate MACs followed by the final MAC.
        let mut expected_macs: Vec<_> = incremental.update(bytes).collect();
        expected_macs.push(incremental.finalize());

        let expected_bytes: Vec<[u8; 32]> =
            expected_macs.into_iter().map(|mac| mac.into()).collect();

        // The untouched sequence validates at every step.
        {
            let mut validating = new_incremental(key, TEST_CHUNK_SIZE).validating(&expected_bytes);
            validating
                .update(bytes)
                .expect("update: validation should succeed");
            validating
                .finalize()
                .expect("finalize: validation should succeed");
        }

        // A corrupted first (intermediate) MAC is detected by `update`.
        {
            let mut failing_first_update = expected_bytes.clone();
            failing_first_update
                .first_mut()
                .expect("there must be at least one mac")[0] ^= 0xff;
            let mut validating =
                new_incremental(key, TEST_CHUNK_SIZE).validating(&failing_first_update);
            validating.update(bytes).expect_err("MacError");
        }

        // A corrupted final MAC passes `update` but is detected by `finalize`.
        {
            let mut failing_finalize = expected_bytes.clone();
            failing_finalize
                .last_mut()
                .expect("there must be at least one mac")[0] ^= 0xff;
            let mut validating =
                new_incremental(key, TEST_CHUNK_SIZE).validating(&failing_finalize);
            validating.update(bytes).expect("update should succeed");
            validating.finalize().expect_err("MacError");
        }

        // Without the final MAC, no expected MAC is left at `finalize`, so it fails.
        {
            let missing_last_mac = &expected_bytes[0..expected_bytes.len() - 1];
            let mut validating = new_incremental(key, TEST_CHUNK_SIZE).validating(missing_last_mac);
            validating.update(bytes).expect("update should succeed");
            validating.finalize().expect_err("MacError");
        }

        // Dropping the first MAC misaligns the sequence, which `update` detects.
        {
            let missing_first_mac: Vec<_> = expected_bytes.clone().into_iter().skip(1).collect();
            let mut validating =
                new_incremental(key, TEST_CHUNK_SIZE).validating(&missing_first_mac);
            validating.update(bytes).expect_err("MacError");
        }
        // NOTE(review): this only keeps `expected_bytes` alive until the end of the test.
        // It does not take part in any assertion.
        std::hint::black_box(expected_bytes);
    }

    /// `Validating::update` reports the bytes of each completed chunk, and `finalize`
    /// reports the bytes of the trailing partial chunk.
    #[test]
    fn validating_returns_right_size() {
        let key = TEST_HMAC_KEY;
        let input = "this is a simple test input string";

        let bytes = input.as_bytes();
        let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);
        let mut expected_macs: Vec<_> = incremental.update(bytes).collect();
        expected_macs.push(incremental.finalize());

        let expected_bytes: Vec<[u8; 32]> =
            expected_macs.into_iter().map(|mac| mac.into()).collect();

        let mut validating = new_incremental(key, TEST_CHUNK_SIZE).validating(&expected_bytes);

        // The 34-byte input is fed as 16 + 16 + 2 bytes. The first 32-byte chunk is
        // completed (and validated) during the second update.
        let input_chunks = bytes.chunks(16).collect::<Vec<_>>();
        assert_eq!(3, input_chunks.len());
        let expected_remainder = bytes.len() - TEST_CHUNK_SIZE;

        for (expected_size, input) in std::iter::zip([0, TEST_CHUNK_SIZE, 0], input_chunks) {
            assert_eq!(
                expected_size,
                validating
                    .update(input)
                    .expect("update: validation should succeed")
            );
        }
        assert_eq!(
            expected_remainder,
            validating
                .finalize()
                .expect("finalize: validation should succeed")
        );
    }

    /// Property: a MAC sequence produced from a random split of the input validates when
    /// the same input is fed through `Validating`.
    #[test]
    fn produce_and_validate() {
        let key = TEST_HMAC_KEY;

        proptest!(|(input in ".{0,100}")| {
            let bytes = input.as_bytes();
            let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);
            // Pieces of up to two chunks, so a single update can complete several chunks.
            let input_chunks = bytes.random_chunks(incremental.chunk_size*2);

            // NOTE(review): cloning `RandomChunks` clones its `ThreadRng` handle. That
            // handle presumably shares the thread-local generator, so the two passes
            // below likely split the input differently. The MACs depend only on byte
            // positions, so validation must succeed either way.
            let mut produced: Vec<[u8; 32]> = input_chunks.clone()
                .flat_map(|chunk| incremental.update(chunk).collect::<Vec<_>>())
                .map(|out| out.into())
                .collect();
            produced.push(incremental.finalize().into());

            let mut validating = new_incremental(key, TEST_CHUNK_SIZE).validating(&produced);
            for chunk in input_chunks.clone() {
                validating.update(chunk).expect("update: validation should succeed");
            }
            validating.finalize().expect("finalize: validation should succeed");
        });
    }

    const KIBIBYTES: usize = 1024;
    const MEBIBYTES: usize = 1024 * KIBIBYTES;
    const GIBIBYTES: usize = 1024 * MEBIBYTES;

    /// SHA-256 digests are 32 bytes, so the 8 KiB target allows 256 chunks:
    /// - below 256 * 64 KiB = 16 MiB, the minimum chunk size is used;
    /// - from 256 * 2 MiB = 512 MiB upward, the maximum is used;
    /// - in between, the chunk size is `data_size / 256`, rounded up.
    #[test]
    fn chunk_sizes_sha256() {
        for (data_size, expected) in [
            (0, MINIMUM_CHUNK_SIZE),
            (KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (10 * KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (100 * KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (MEBIBYTES, MINIMUM_CHUNK_SIZE),
            (10 * MEBIBYTES, MINIMUM_CHUNK_SIZE),
            (20 * MEBIBYTES, 80 * KIBIBYTES),
            (100 * MEBIBYTES, 400 * KIBIBYTES),
            (200 * MEBIBYTES, 800 * KIBIBYTES),
            (256 * MEBIBYTES, MEBIBYTES),
            (512 * MEBIBYTES, 2 * MEBIBYTES),
            (GIBIBYTES, 2 * MEBIBYTES),
            (2 * GIBIBYTES, 2 * MEBIBYTES),
        ] {
            let actual = calculate_chunk_size::<Sha256>(data_size);
            assert_eq!(actual, expected);
        }
    }

    /// SHA-512 digests are 64 bytes, so the 8 KiB target allows 128 chunks:
    /// - below 128 * 64 KiB = 8 MiB, the minimum chunk size is used;
    /// - from 128 * 2 MiB = 256 MiB upward, the maximum is used.
    #[test]
    fn chunk_sizes_sha512() {
        for (data_size, expected) in [
            (0, MINIMUM_CHUNK_SIZE),
            (KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (10 * KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (100 * KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (MEBIBYTES, MINIMUM_CHUNK_SIZE),
            (10 * MEBIBYTES, 80 * KIBIBYTES),
            (20 * MEBIBYTES, 160 * KIBIBYTES),
            (100 * MEBIBYTES, 800 * KIBIBYTES),
            (200 * MEBIBYTES, 1600 * KIBIBYTES),
            (256 * MEBIBYTES, 2 * MEBIBYTES),
            (512 * MEBIBYTES, 2 * MEBIBYTES),
            (GIBIBYTES, 2 * MEBIBYTES),
        ] {
            let actual = calculate_chunk_size::<sha2::Sha512>(data_size);
            assert_eq!(actual, expected);
        }
    }

    /// Property: the combined SHA-256 digest size stays within 8 KiB for inputs below
    /// 256 MiB, and within 32 KiB for inputs below 2 GiB.
    #[test]
    fn total_digest_size_is_never_too_big() {
        // Total bytes of digests for `data_size`. At least one digest is counted, even
        // for empty data.
        fn total_digest_size(data_size: usize) -> usize {
            let chunk_size = calculate_chunk_size::<Sha256>(data_size);
            let num_chunks = std::cmp::max(1, data_size.div_ceil(chunk_size));
            num_chunks * <Sha256 as OutputSizeUser>::OutputSize::USIZE
        }
        let config = ProptestConfig::with_cases(10_000);
        proptest!(config, |(data_size in 256..256*MEBIBYTES)| {
            assert!(total_digest_size(data_size) <= 8*KIBIBYTES)
        });
        proptest!(|(data_size_mib in 256_usize..2048)| {
            assert!(total_digest_size(data_size_mib*MEBIBYTES) <= 32*KIBIBYTES)
        });
    }

    /// Iterator that splits `base` into consecutive chunks of random length.
    /// Lengths are drawn from `distribution` and capped at the remaining length, so
    /// empty chunks are possible.
    #[derive(Clone)]
    struct RandomChunks<'a, T, R: Rng> {
        base: &'a [T],
        distribution: UniformUsize,
        rng: R,
    }

    impl<'a, T, R: Rng> Iterator for RandomChunks<'a, T, R> {
        type Item = &'a [T];

        fn next(&mut self) -> Option<Self::Item> {
            if self.base.is_empty() {
                None
            } else {
                let candidate = self.distribution.sample(&mut self.rng);
                let chunk_size = std::cmp::min(candidate, self.base.len());
                let (before, after) = self.base.split_at(chunk_size);
                self.base = after;
                Some(before)
            }
        }
    }

    /// Extension trait that adds `random_chunks` to slices.
    trait RandomChunksIterator<T> {
        fn random_chunks(&self, max_size: usize) -> RandomChunks<'_, T, ThreadRng>;
    }

    impl<T> RandomChunksIterator<T> for [T] {
        /// Splits the slice into chunks whose lengths are drawn uniformly from
        /// `0..=max_size`, using the thread-local RNG.
        ///
        /// Panics if `max_size` is zero.
        fn random_chunks(&self, max_size: usize) -> RandomChunks<'_, T, ThreadRng> {
            assert!(max_size > 0, "Maximal chunk size should be positive");
            RandomChunks {
                base: self,
                distribution: UniformUsize::new_inclusive(0, max_size).expect("valid range"),
                rng: Default::default(),
            }
        }
    }
}