1use aes::cipher::Unsigned;
7use hmac::Mac;
8use hmac::digest::generic_array::{ArrayLength, GenericArray};
9use sha2::digest::{FixedOutput, MacError, Output};
10
11#[derive(Clone)]
12pub struct Incremental<M: Mac + Clone> {
13    mac: M,
14    chunk_size: usize,
15    unused_length: usize,
16}
17
18#[derive(Clone)]
19pub struct Validating<M: Mac + Clone> {
20    incremental: Incremental<M>,
21    expected: Vec<Output<M>>,
23}
24
/// Smallest chunk size `calculate_chunk_size` returns: 64 KiB.
const MINIMUM_CHUNK_SIZE: usize = 64 << 10;
/// Largest chunk size `calculate_chunk_size` returns: 2 MiB.
const MAXIMUM_CHUNK_SIZE: usize = 2 << 20;
/// Desired combined size, in bytes, of all per-chunk digests: 8 KiB.
const TARGET_TOTAL_DIGEST_SIZE: usize = 8 << 10;
28
29pub const fn calculate_chunk_size<D>(data_size: usize) -> usize
30where
31    D: FixedOutput,
32    D::OutputSize: ArrayLength<u8>,
33{
34    assert!(
35        0 == TARGET_TOTAL_DIGEST_SIZE % D::OutputSize::USIZE,
36        "Target digest size should be a multiple of digest size"
37    );
38    let target_chunk_count = TARGET_TOTAL_DIGEST_SIZE / D::OutputSize::USIZE;
39    if data_size < target_chunk_count * MINIMUM_CHUNK_SIZE {
40        return MINIMUM_CHUNK_SIZE;
41    }
42    if data_size < target_chunk_count * MAXIMUM_CHUNK_SIZE {
43        return data_size.div_ceil(target_chunk_count);
44    }
45    MAXIMUM_CHUNK_SIZE
46}
47
48impl<M: Mac + Clone> Incremental<M> {
49    pub fn new(mac: M, chunk_size: usize) -> Self {
50        assert!(chunk_size > 0, "chunk size must be positive");
51        Self {
52            mac,
53            chunk_size,
54            unused_length: chunk_size,
55        }
56    }
57
58    pub fn validating<A, I>(self, macs: I) -> Validating<M>
59    where
60        A: AsRef<[u8]>,
61        I: IntoIterator<Item = A>,
62        <I as IntoIterator>::IntoIter: DoubleEndedIterator,
63    {
64        let expected = macs
65            .into_iter()
66            .map(|mac| GenericArray::<u8, M::OutputSize>::from_slice(mac.as_ref()).to_owned())
67            .rev()
68            .collect();
69        Validating {
70            incremental: self,
71            expected,
72        }
73    }
74
75    pub fn update<'a>(&'a mut self, bytes: &'a [u8]) -> impl Iterator<Item = Output<M>> + 'a {
76        let split_point = std::cmp::min(bytes.len(), self.unused_length);
77        let (to_write, overflow) = bytes.split_at(split_point);
78
79        std::iter::once(to_write)
80            .chain(overflow.chunks(self.chunk_size))
81            .flat_map(move |chunk| self.update_chunk(chunk))
82    }
83
84    pub fn finalize(self) -> Output<M> {
85        self.mac.finalize().into_bytes()
86    }
87
88    fn update_chunk(&mut self, bytes: &[u8]) -> Option<Output<M>> {
89        assert!(bytes.len() <= self.unused_length);
90        self.mac.update(bytes);
91        self.unused_length -= bytes.len();
92        if self.unused_length == 0 {
93            self.unused_length = self.chunk_size;
94            let mac = self.mac.clone();
95            Some(mac.finalize().into_bytes())
96        } else {
97            None
98        }
99    }
100
101    fn pending_bytes_size(&self) -> usize {
102        self.chunk_size - self.unused_length
103    }
104}
105
106impl<M: Mac + Clone> Validating<M> {
107    pub fn update(&mut self, bytes: &[u8]) -> Result<usize, MacError> {
108        let mut result = Ok(0);
109        let macs = self.incremental.update(bytes);
110
111        let mut whole_chunks = 0;
112        for mac in macs {
113            match self.expected.last() {
114                Some(expected) if expected == &mac => {
115                    whole_chunks += 1;
116                    self.expected.pop();
117                }
118                _ => {
119                    result = Err(MacError);
120                }
121            }
122        }
123        let validated_bytes = whole_chunks * self.incremental.chunk_size;
124        result.map(|_| validated_bytes)
125    }
126
127    pub fn finalize(self) -> Result<usize, MacError> {
128        let pending_bytes_size = self.incremental.pending_bytes_size();
129        let mac = self.incremental.finalize();
130        match &self.expected[..] {
131            [expected] if expected == &mac => Ok(pending_bytes_size),
132            _ => Err(MacError),
133        }
134    }
135}
136
// Tests for `Incremental`, `Validating` and `calculate_chunk_size`.
//
// Many tests compare against `crate::crypto::hmac_sha256`. It is presumably a one-shot
// HMAC-SHA256 helper whose result compares equal to `[u8; 32]` — TODO confirm.
#[cfg(test)]
mod test {
    use const_str::hex;
    use hmac::Hmac;
    use proptest::prelude::*;
    use rand::distr::uniform::{UniformSampler as _, UniformUsize};
    use rand::prelude::{Rng, ThreadRng};
    use sha2::Sha256;
    use sha2::digest::OutputSizeUser;

    use super::*;
    use crate::crypto::hmac_sha256;

    // Fixed 32-byte (64 hex digit) HMAC key shared by the tests.
    const TEST_HMAC_KEY: &[u8] =
        &hex!("a83481457efecc69ad1342e21d9c0297f71debbf5c9304b4c1b2e433c1a78f98");

    // Deliberately tiny chunk size so short test strings span several chunks.
    const TEST_CHUNK_SIZE: usize = 32;

    // Builds an HMAC-SHA256 based `Incremental` keyed with `key`.
    fn new_incremental(key: &[u8], chunk_size: usize) -> Incremental<Hmac<Sha256>> {
        let hmac = Hmac::<Sha256>::new_from_slice(key)
            .expect("Should be able to create a new HMAC instance");
        Incremental::new(hmac, chunk_size)
    }

    // `Incremental::new` must reject a zero chunk size.
    #[test]
    #[should_panic]
    fn chunk_size_zero() {
        new_incremental(&[], 0);
    }

    // Finalizing after a multi-chunk update must give the same MAC as a one-shot
    // computation over the same bytes.
    #[test]
    fn simple_test() {
        let key = TEST_HMAC_KEY;
        let input = "this is a simple test input string which is longer than the chunk";

        let bytes = input.as_bytes();
        let expected = hmac_sha256(key, bytes);
        let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);
        // `update` is lazy; collecting drives it so every byte is absorbed.
        let _ = incremental.update(bytes).collect::<Vec<_>>();
        let digest = incremental.finalize();
        let actual: [u8; 32] = digest.into();
        assert_eq!(actual, expected);
    }

    // Property version of `simple_test` over random strings of up to 100 characters.
    #[test]
    fn final_result_should_be_equal_to_non_incremental_hmac() {
        let key = TEST_HMAC_KEY;
        proptest!(|(input in ".{0,100}")| {
            let bytes = input.as_bytes();
            let expected = hmac_sha256(key, bytes);
            let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);
            let _ = incremental.update(bytes).collect::<Vec<_>>();
            let actual: [u8; 32] = incremental.finalize().into();
            assert_eq!(actual, expected);
        });
    }

    // Each intermediate MAC must equal a one-shot MAC over the whole prefix ending at
    // that chunk boundary, no matter how the input is split across `update` calls.
    #[test]
    fn incremental_macs_are_valid() {
        let key = TEST_HMAC_KEY;

        proptest!(|(input in ".{50,100}")| {
            let bytes = input.as_bytes();
            let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);

            // Reference MACs: one per chunk, each over all bytes up to the chunk's end.
            let expected: Vec<_> = bytes
                .chunks(incremental.chunk_size)
                .scan(Vec::new(), |acc, chunk| {
                    acc.extend(chunk.iter());
                    Some(hmac_sha256(key, acc).to_vec())
                })
                .collect();

            // Feed the same bytes in randomly sized pieces.
            let mut actual: Vec<Vec<u8>> = bytes
                .random_chunks(incremental.chunk_size)
                .flat_map(|chunk| incremental.update(chunk).collect::<Vec<_>>())
                .map(|out| out.into())
                .map(|bs: [u8; 32]| bs.to_vec())
                .collect();
            // A trailing partial chunk only gets its MAC from `finalize`.
            if bytes.len() % incremental.chunk_size != 0 {
                let last_hmac: [u8; 32] = incremental.finalize().into();
                actual.push(last_hmac.to_vec());
            }
            assert_eq!(actual, expected);
        });
    }

    // The input is 34 bytes, so with 32-byte chunks there is one chunk MAC from
    // `update` plus the final MAC from `finalize`. Each scope checks one scenario.
    #[test]
    fn validating_simple_test() {
        let key = TEST_HMAC_KEY;
        let input = "this is a simple test input string";

        let bytes = input.as_bytes();
        let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);
        let mut expected_macs: Vec<_> = incremental.update(bytes).collect();
        expected_macs.push(incremental.finalize());

        let expected_bytes: Vec<[u8; 32]> =
            expected_macs.into_iter().map(|mac| mac.into()).collect();

        // All MACs correct: both phases succeed.
        {
            let mut validating =
                new_incremental(key, TEST_CHUNK_SIZE).validating(expected_bytes.clone());
            validating
                .update(bytes)
                .expect("update: validation should succeed");
            validating
                .finalize()
                .expect("finalize: validation should succeed");
        }

        // First (chunk) MAC corrupted: detected by `update`.
        {
            let mut failing_first_update = expected_bytes.clone();
            failing_first_update
                .first_mut()
                .expect("there must be at least one mac")[0] ^= 0xff;
            let mut validating =
                new_incremental(key, TEST_CHUNK_SIZE).validating(failing_first_update);
            validating.update(bytes).expect_err("MacError");
        }

        // Last (whole-stream) MAC corrupted: `update` passes, `finalize` fails.
        {
            let mut failing_finalize = expected_bytes.clone();
            failing_finalize
                .last_mut()
                .expect("there must be at least one mac")[0] ^= 0xff;
            let mut validating = new_incremental(key, TEST_CHUNK_SIZE).validating(failing_finalize);
            validating.update(bytes).expect("update should succeed");
            validating.finalize().expect_err("MacError");
        }

        // Last MAC missing: nothing is left for `finalize` to check against.
        {
            let missing_last_mac = &expected_bytes[0..expected_bytes.len() - 1];
            let mut validating = new_incremental(key, TEST_CHUNK_SIZE).validating(missing_last_mac);
            validating.update(bytes).expect("update should succeed");
            validating.finalize().expect_err("MacError");
        }

        // First MAC missing: the chunk MAC is compared against the whole-stream MAC.
        {
            let missing_first_mac: Vec<_> = expected_bytes.clone().into_iter().skip(1).collect();
            let mut validating =
                new_incremental(key, TEST_CHUNK_SIZE).validating(missing_first_mac);
            validating.update(bytes).expect_err("MacError");
        }
        std::hint::black_box(expected_bytes);
    }

    // `update` reports the bytes of whole chunks it validated; `finalize` reports the
    // size of the trailing partial chunk. The 34-byte input is fed as 16 + 16 + 2.
    #[test]
    fn validating_returns_right_size() {
        let key = TEST_HMAC_KEY;
        let input = "this is a simple test input string";

        let bytes = input.as_bytes();
        let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);
        let mut expected_macs: Vec<_> = incremental.update(bytes).collect();
        expected_macs.push(incremental.finalize());

        let expected_bytes: Vec<[u8; 32]> =
            expected_macs.into_iter().map(|mac| mac.into()).collect();

        let mut validating = new_incremental(key, TEST_CHUNK_SIZE).validating(expected_bytes);

        let input_chunks = bytes.chunks(16).collect::<Vec<_>>();
        assert_eq!(3, input_chunks.len());
        let expected_remainder = bytes.len() - TEST_CHUNK_SIZE;

        // Only the second piece completes the first 32-byte chunk.
        for (expected_size, input) in std::iter::zip([0, TEST_CHUNK_SIZE, 0], input_chunks) {
            assert_eq!(
                expected_size,
                validating
                    .update(input)
                    .expect("update: validation should succeed")
            );
        }
        assert_eq!(
            expected_remainder,
            validating
                .finalize()
                .expect("finalize: validation should succeed")
        );
    }

    // Round trip: MACs produced by `Incremental` must validate with `Validating`
    // for arbitrary inputs fed in random-sized pieces.
    #[test]
    fn produce_and_validate() {
        let key = TEST_HMAC_KEY;

        proptest!(|(input in ".{0,100}")| {
            let bytes = input.as_bytes();
            let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);
            // NOTE(review): the clones below share a `ThreadRng` handle, so they probably
            // do not replay the same piece sizes. The MACs do not depend on how the
            // input is split, so the test holds either way — confirm if that matters.
            let input_chunks = bytes.random_chunks(incremental.chunk_size*2);

            let mut produced: Vec<[u8; 32]> = input_chunks.clone()
                .flat_map(|chunk| incremental.update(chunk).collect::<Vec<_>>())
                .map(|out| out.into())
                .collect();
            produced.push(incremental.finalize().into());

            let mut validating = new_incremental(key, TEST_CHUNK_SIZE).validating(produced);
            for chunk in input_chunks.clone() {
                validating.update(chunk).expect("update: validation should succeed");
            }
            validating.finalize().expect("finalize: validation should succeed");
        });
    }

    const KIBIBYTES: usize = 1024;
    const MEBIBYTES: usize = 1024 * KIBIBYTES;
    const GIBIBYTES: usize = 1024 * MEBIBYTES;

    // SHA-256 (32-byte digests): the target is 256 chunks. The chunk size is the minimum
    // below 16 MiB, grows linearly up to 512 MiB, and is capped at 2 MiB beyond that.
    #[test]
    fn chunk_sizes_sha256() {
        for (data_size, expected) in [
            (0, MINIMUM_CHUNK_SIZE),
            (KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (10 * KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (100 * KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (MEBIBYTES, MINIMUM_CHUNK_SIZE),
            (10 * MEBIBYTES, MINIMUM_CHUNK_SIZE),
            (20 * MEBIBYTES, 80 * KIBIBYTES),
            (100 * MEBIBYTES, 400 * KIBIBYTES),
            (200 * MEBIBYTES, 800 * KIBIBYTES),
            (256 * MEBIBYTES, MEBIBYTES),
            (512 * MEBIBYTES, 2 * MEBIBYTES),
            (GIBIBYTES, 2 * MEBIBYTES),
            (2 * GIBIBYTES, 2 * MEBIBYTES),
        ] {
            let actual = calculate_chunk_size::<Sha256>(data_size);
            assert_eq!(actual, expected);
        }
    }

    // SHA-512 (64-byte digests): the target is 128 chunks, so the linear range is
    // 8 MiB to 256 MiB.
    #[test]
    fn chunk_sizes_sha512() {
        for (data_size, expected) in [
            (0, MINIMUM_CHUNK_SIZE),
            (KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (10 * KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (100 * KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (MEBIBYTES, MINIMUM_CHUNK_SIZE),
            (10 * MEBIBYTES, 80 * KIBIBYTES),
            (20 * MEBIBYTES, 160 * KIBIBYTES),
            (100 * MEBIBYTES, 800 * KIBIBYTES),
            (200 * MEBIBYTES, 1600 * KIBIBYTES),
            (256 * MEBIBYTES, 2 * MEBIBYTES),
            (512 * MEBIBYTES, 2 * MEBIBYTES),
            (GIBIBYTES, 2 * MEBIBYTES),
        ] {
            let actual = calculate_chunk_size::<sha2::Sha512>(data_size);
            assert_eq!(actual, expected);
        }
    }

    // Below 256 MiB the combined SHA-256 digest size stays within the 8 KiB target.
    // Between 256 MiB and 2 GiB, where the chunk size is capped, it stays within 32 KiB.
    #[test]
    fn total_digest_size_is_never_too_big() {
        fn total_digest_size(data_size: usize) -> usize {
            let chunk_size = calculate_chunk_size::<Sha256>(data_size);
            // Even empty data produces one (final) digest.
            let num_chunks = std::cmp::max(1, data_size.div_ceil(chunk_size));
            num_chunks * <Sha256 as OutputSizeUser>::OutputSize::USIZE
        }
        let config = ProptestConfig::with_cases(10_000);
        proptest!(config, |(data_size in 256..256*MEBIBYTES)| {
            assert!(total_digest_size(data_size) <= 8*KIBIBYTES)
        });
        proptest!(|(data_size_mib in 256_usize..2048)| {
            assert!(total_digest_size(data_size_mib*MEBIBYTES) <= 32*KIBIBYTES)
        });
    }

    // Iterator that splits `base` into consecutive pieces of random length drawn from
    // `distribution`. Pieces may be empty, and the last one is truncated to whatever
    // remains.
    #[derive(Clone)]
    struct RandomChunks<'a, T, R: Rng> {
        base: &'a [T],
        distribution: UniformUsize,
        rng: R,
    }

    impl<'a, T, R: Rng> Iterator for RandomChunks<'a, T, R> {
        type Item = &'a [T];

        fn next(&mut self) -> Option<Self::Item> {
            if self.base.is_empty() {
                None
            } else {
                let candidate = self.distribution.sample(&mut self.rng);
                let chunk_size = std::cmp::min(candidate, self.base.len());
                let (before, after) = self.base.split_at(chunk_size);
                self.base = after;
                Some(before)
            }
        }
    }

    // Extension trait adding `random_chunks` to slices.
    trait RandomChunksIterator<T> {
        fn random_chunks(&self, max_size: usize) -> RandomChunks<'_, T, ThreadRng>;
    }

    impl<T> RandomChunksIterator<T> for [T] {
        // Piece lengths are sampled from `0..=max_size` (inclusive, so zero is possible).
        fn random_chunks(&self, max_size: usize) -> RandomChunks<'_, T, ThreadRng> {
            assert!(max_size > 0, "Maximal chunk size should be positive");
            RandomChunks {
                base: self,
                distribution: UniformUsize::new_inclusive(0, max_size).expect("valid range"),
                rng: Default::default(),
            }
        }
    }
}