libsignal_protocol/
incremental_mac.rs

1//
2// Copyright 2023 Signal Messenger, LLC.
3// SPDX-License-Identifier: AGPL-3.0-only
4//
5
6use aes::cipher::Unsigned;
7use hmac::Mac;
8use hmac::digest::generic_array::{ArrayLength, GenericArray};
9use sha2::digest::{FixedOutput, MacError, Output};
10
11#[derive(Clone)]
12pub struct Incremental<M: Mac + Clone> {
13    mac: M,
14    chunk_size: usize,
15    unused_length: usize,
16}
17
18#[derive(Clone)]
19pub struct Validating<M: Mac + Clone> {
20    incremental: Incremental<M>,
21    // Expected MACs in reversed order, to efficiently pop them from off the end
22    expected: Vec<Output<M>>,
23}
24
/// Smallest chunk size `calculate_chunk_size` will return: 64 KiB.
const MINIMUM_CHUNK_SIZE: usize = 64 << 10;
/// Largest chunk size `calculate_chunk_size` will return: 2 MiB.
const MAXIMUM_CHUNK_SIZE: usize = 2 << 20;
/// Combined size of all chunk digests that `calculate_chunk_size` aims for: 8 KiB.
const TARGET_TOTAL_DIGEST_SIZE: usize = 8 << 10;
28
29pub const fn calculate_chunk_size<D>(data_size: usize) -> usize
30where
31    D: FixedOutput,
32    D::OutputSize: ArrayLength<u8>,
33{
34    assert!(
35        0 == TARGET_TOTAL_DIGEST_SIZE % D::OutputSize::USIZE,
36        "Target digest size should be a multiple of digest size"
37    );
38    let target_chunk_count = TARGET_TOTAL_DIGEST_SIZE / D::OutputSize::USIZE;
39    if data_size < target_chunk_count * MINIMUM_CHUNK_SIZE {
40        return MINIMUM_CHUNK_SIZE;
41    }
42    if data_size < target_chunk_count * MAXIMUM_CHUNK_SIZE {
43        return data_size.div_ceil(target_chunk_count);
44    }
45    MAXIMUM_CHUNK_SIZE
46}
47
48impl<M: Mac + Clone> Incremental<M> {
49    pub fn new(mac: M, chunk_size: usize) -> Self {
50        assert!(chunk_size > 0, "chunk size must be positive");
51        Self {
52            mac,
53            chunk_size,
54            unused_length: chunk_size,
55        }
56    }
57
58    pub fn validating<'a, A, I>(self, macs: I) -> Validating<M>
59    where
60        // This is a clunky way to spell an iterator over `&'a [u8; M::OutputSize]`, which requires
61        // feature(generic_const_exprs). The Sized requirement is because the older version of
62        // generic-array we're using (via the digest crate) provides From<&[u8]>; it was removed in
63        // generic-array 1.0.
64        A: Into<&'a GenericArray<u8, M::OutputSize>> + std::ops::Deref<Target: Sized>,
65        I: IntoIterator<Item = A, IntoIter: DoubleEndedIterator>,
66    {
67        let expected = macs
68            .into_iter()
69            .map(|mac| mac.into().to_owned())
70            .rev()
71            .collect();
72        Validating {
73            incremental: self,
74            expected,
75        }
76    }
77
78    pub fn update<'a>(&'a mut self, bytes: &'a [u8]) -> impl Iterator<Item = Output<M>> + 'a {
79        let split_point = std::cmp::min(bytes.len(), self.unused_length);
80        let (to_write, overflow) = bytes.split_at(split_point);
81
82        std::iter::once(to_write)
83            .chain(overflow.chunks(self.chunk_size))
84            .flat_map(move |chunk| self.update_chunk(chunk))
85    }
86
87    pub fn finalize(self) -> Output<M> {
88        self.mac.finalize().into_bytes()
89    }
90
91    fn update_chunk(&mut self, bytes: &[u8]) -> Option<Output<M>> {
92        assert!(bytes.len() <= self.unused_length);
93        self.mac.update(bytes);
94        self.unused_length -= bytes.len();
95        if self.unused_length == 0 {
96            self.unused_length = self.chunk_size;
97            let mac = self.mac.clone();
98            Some(mac.finalize().into_bytes())
99        } else {
100            None
101        }
102    }
103
104    fn pending_bytes_size(&self) -> usize {
105        self.chunk_size - self.unused_length
106    }
107}
108
109impl<M: Mac + Clone> Validating<M> {
110    pub fn update(&mut self, bytes: &[u8]) -> Result<usize, MacError> {
111        let mut result = Ok(0);
112        let macs = self.incremental.update(bytes);
113
114        let mut whole_chunks = 0;
115        for mac in macs {
116            match self.expected.last() {
117                Some(expected) if expected == &mac => {
118                    whole_chunks += 1;
119                    self.expected.pop();
120                }
121                _ => {
122                    result = Err(MacError);
123                }
124            }
125        }
126        let validated_bytes = whole_chunks * self.incremental.chunk_size;
127        result.map(|_| validated_bytes)
128    }
129
130    pub fn finalize(self) -> Result<usize, MacError> {
131        let pending_bytes_size = self.incremental.pending_bytes_size();
132        let mac = self.incremental.finalize();
133        match &self.expected[..] {
134            [expected] if expected == &mac => Ok(pending_bytes_size),
135            _ => Err(MacError),
136        }
137    }
138}
139
#[cfg(test)]
mod test {
    use const_str::hex;
    use hmac::Hmac;
    use proptest::prelude::*;
    use rand::distr::uniform::{UniformSampler as _, UniformUsize};
    use rand::prelude::{Rng, ThreadRng};
    use sha2::Sha256;
    use sha2::digest::OutputSizeUser;

    use super::*;
    use crate::crypto::hmac_sha256;

    /// Fixed HMAC key shared by the tests.
    const TEST_HMAC_KEY: &[u8] =
        &hex!("a83481457efecc69ad1342e21d9c0297f71debbf5c9304b4c1b2e433c1a78f98");

    /// Deliberately tiny chunk size so that short inputs span several chunks.
    const TEST_CHUNK_SIZE: usize = 32;

    /// Builds an HMAC-SHA256 [`Incremental`] keyed with `key`.
    fn new_incremental(key: &[u8], chunk_size: usize) -> Incremental<Hmac<Sha256>> {
        let keyed = Hmac::<Sha256>::new_from_slice(key)
            .expect("Should be able to create a new HMAC instance");
        Incremental::new(keyed, chunk_size)
    }

    /// Runs `bytes` through a fresh [`Incremental`] in a single `update` call and returns every
    /// chunk MAC followed by the final MAC.
    fn reference_macs(key: &[u8], bytes: &[u8]) -> Vec<[u8; 32]> {
        let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);
        let mut macs: Vec<[u8; 32]> = incremental.update(bytes).map(|mac| mac.into()).collect();
        macs.push(incremental.finalize().into());
        macs
    }

    #[test]
    #[should_panic]
    fn chunk_size_zero() {
        new_incremental(&[], 0);
    }

    #[test]
    fn simple_test() {
        let input = "this is a simple test input string which is longer than the chunk";
        let message = input.as_bytes();
        let reference = hmac_sha256(TEST_HMAC_KEY, message);

        let mut incremental = new_incremental(TEST_HMAC_KEY, TEST_CHUNK_SIZE);
        // `update` is lazy, so the iterator has to be drained for the bytes to be absorbed.
        incremental.update(message).for_each(drop);
        let computed: [u8; 32] = incremental.finalize().into();
        assert_eq!(computed, reference);
    }

    #[test]
    fn final_result_should_be_equal_to_non_incremental_hmac() {
        proptest!(|(input in ".{0,100}")| {
            let message = input.as_bytes();
            let reference = hmac_sha256(TEST_HMAC_KEY, message);

            let mut incremental = new_incremental(TEST_HMAC_KEY, TEST_CHUNK_SIZE);
            incremental.update(message).for_each(drop);
            let computed: [u8; 32] = incremental.finalize().into();
            assert_eq!(computed, reference);
        });
    }

    #[test]
    fn incremental_macs_are_valid() {
        let key = TEST_HMAC_KEY;

        proptest!(|(input in ".{50,100}")| {
            let bytes = input.as_bytes();
            let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);
            let chunk_size = incremental.chunk_size;

            // Reference values: the plain HMAC of each prefix of the input that ends on a chunk
            // boundary, plus the HMAC of the whole input when it ends mid-chunk.
            let expected: Vec<Vec<u8>> = (0..bytes.len().div_ceil(chunk_size))
                .map(|index| {
                    let prefix_end = usize::min((index + 1) * chunk_size, bytes.len());
                    hmac_sha256(key, &bytes[..prefix_end]).to_vec()
                })
                .collect();

            // Feed the same input in randomly sized pieces and gather the produced MACs.
            let mut actual: Vec<Vec<u8>> = Vec::new();
            for piece in bytes.random_chunks(chunk_size) {
                actual.extend(incremental.update(piece).map(|mac| {
                    let raw: [u8; 32] = mac.into();
                    raw.to_vec()
                }));
            }
            // A trailing partial chunk is only covered by the final MAC.
            if bytes.len() % chunk_size != 0 {
                let last_hmac: [u8; 32] = incremental.finalize().into();
                actual.push(last_hmac.to_vec());
            }
            assert_eq!(actual, expected);
        });
    }

    #[test]
    fn validating_simple_test() {
        let key = TEST_HMAC_KEY;
        let input = "this is a simple test input string";

        let bytes = input.as_bytes();
        let expected_bytes = reference_macs(key, bytes);

        // Untampered MACs validate end to end.
        {
            let mut validating = new_incremental(key, TEST_CHUNK_SIZE).validating(&expected_bytes);
            validating
                .update(bytes)
                .expect("update: validation should succeed");
            validating
                .finalize()
                .expect("finalize: validation should succeed");
        }

        // A corrupted first MAC is caught by `update`.
        {
            let mut tampered_first = expected_bytes.clone();
            let first_mac = tampered_first
                .first_mut()
                .expect("there must be at least one mac");
            first_mac[0] ^= 0xff;
            let mut validating = new_incremental(key, TEST_CHUNK_SIZE).validating(&tampered_first);
            validating.update(bytes).expect_err("MacError");
        }

        // A corrupted last MAC is only caught by `finalize`.
        {
            let mut tampered_last = expected_bytes.clone();
            let last_mac = tampered_last
                .last_mut()
                .expect("there must be at least one mac");
            last_mac[0] ^= 0xff;
            let mut validating = new_incremental(key, TEST_CHUNK_SIZE).validating(&tampered_last);
            validating.update(bytes).expect("update should succeed");
            validating.finalize().expect_err("MacError");
        }

        // Dropping the last MAC makes `finalize` fail.
        {
            let without_last = &expected_bytes[0..expected_bytes.len() - 1];
            let mut validating = new_incremental(key, TEST_CHUNK_SIZE).validating(without_last);
            validating.update(bytes).expect("update should succeed");
            validating.finalize().expect_err("MacError");
        }

        // Dropping the first MAC makes `update` fail.
        {
            let without_first: Vec<[u8; 32]> = expected_bytes[1..].to_vec();
            let mut validating = new_incremental(key, TEST_CHUNK_SIZE).validating(&without_first);
            validating.update(bytes).expect_err("MacError");
        }
        // To make clippy happy and allow extending the test in the future
        std::hint::black_box(expected_bytes);
    }

    #[test]
    fn validating_returns_right_size() {
        let key = TEST_HMAC_KEY;
        let input = "this is a simple test input string";

        let bytes = input.as_bytes();
        let expected_bytes = reference_macs(key, bytes);

        let mut validating = new_incremental(key, TEST_CHUNK_SIZE).validating(&expected_bytes);

        // The 34-byte input split into pieces of 16 gives three pieces (16, 16 and 2 bytes).
        // The second piece completes the one full incremental chunk; the 2 remaining bytes are
        // authenticated by the call to finalize.
        let input_chunks = bytes.chunks(16).collect::<Vec<_>>();
        assert_eq!(3, input_chunks.len());
        let expected_remainder = bytes.len() - TEST_CHUNK_SIZE;

        let expected_sizes = [0, TEST_CHUNK_SIZE, 0];
        for (piece, expected_size) in input_chunks.into_iter().zip(expected_sizes) {
            let validated = validating
                .update(piece)
                .expect("update: validation should succeed");
            assert_eq!(expected_size, validated);
        }

        let validated_at_end = validating
            .finalize()
            .expect("finalize: validation should succeed");
        assert_eq!(expected_remainder, validated_at_end);
    }

    #[test]
    fn produce_and_validate() {
        let key = TEST_HMAC_KEY;

        proptest!(|(input in ".{0,100}")| {
            let bytes = input.as_bytes();
            let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);
            let input_chunks = bytes.random_chunks(incremental.chunk_size*2);

            // Produce the MACs while feeding the input in randomly sized pieces.
            let mut produced: Vec<[u8; 32]> = Vec::new();
            for piece in input_chunks.clone() {
                produced.extend(incremental.update(piece).map(|mac| -> [u8; 32] { mac.into() }));
            }
            produced.push(incremental.finalize().into());

            // The produced MACs must validate the same input, again fed in random pieces.
            let mut validating = new_incremental(key, TEST_CHUNK_SIZE).validating(&produced);
            for piece in input_chunks {
                validating.update(piece).expect("update: validation should succeed");
            }
            validating.finalize().expect("finalize: validation should succeed");
        });
    }

    const KIBIBYTES: usize = 1024;
    const MEBIBYTES: usize = 1024 * KIBIBYTES;
    const GIBIBYTES: usize = 1024 * MEBIBYTES;

    #[test]
    fn chunk_sizes_sha256() {
        // (data size, expected chunk size) for a 32-byte digest.
        let cases = [
            (0, MINIMUM_CHUNK_SIZE),
            (KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (10 * KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (100 * KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (MEBIBYTES, MINIMUM_CHUNK_SIZE),
            (10 * MEBIBYTES, MINIMUM_CHUNK_SIZE),
            (20 * MEBIBYTES, 80 * KIBIBYTES),
            (100 * MEBIBYTES, 400 * KIBIBYTES),
            (200 * MEBIBYTES, 800 * KIBIBYTES),
            (256 * MEBIBYTES, MEBIBYTES),
            (512 * MEBIBYTES, 2 * MEBIBYTES),
            (GIBIBYTES, 2 * MEBIBYTES),
            (2 * GIBIBYTES, 2 * MEBIBYTES),
        ];
        for (data_size, expected) in cases {
            let actual = calculate_chunk_size::<Sha256>(data_size);
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn chunk_sizes_sha512() {
        // (data size, expected chunk size) for a 64-byte digest.
        let cases = [
            (0, MINIMUM_CHUNK_SIZE),
            (KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (10 * KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (100 * KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (MEBIBYTES, MINIMUM_CHUNK_SIZE),
            (10 * MEBIBYTES, 80 * KIBIBYTES),
            (20 * MEBIBYTES, 160 * KIBIBYTES),
            (100 * MEBIBYTES, 800 * KIBIBYTES),
            (200 * MEBIBYTES, 1600 * KIBIBYTES),
            (256 * MEBIBYTES, 2 * MEBIBYTES),
            (512 * MEBIBYTES, 2 * MEBIBYTES),
            (GIBIBYTES, 2 * MEBIBYTES),
        ];
        for (data_size, expected) in cases {
            let actual = calculate_chunk_size::<sha2::Sha512>(data_size);
            assert_eq!(actual, expected);
        }
    }

    #[test]
    fn total_digest_size_is_never_too_big() {
        /// Total size of all SHA-256 chunk digests for an input of `data_size` bytes; at least
        /// one digest is always counted.
        fn total_digest_size(data_size: usize) -> usize {
            let chunk_size = calculate_chunk_size::<Sha256>(data_size);
            let num_chunks = data_size.div_ceil(chunk_size).max(1);
            num_chunks * <Sha256 as OutputSizeUser>::OutputSize::USIZE
        }
        let config = ProptestConfig::with_cases(10_000);
        // Below 256 MiB the digests fit into 8 KiB.
        proptest!(config, |(data_size in 256..256*MEBIBYTES)| {
            assert!(total_digest_size(data_size) <= 8*KIBIBYTES)
        });
        // From 256 MiB up to 2 GiB the digests stay within 32 KiB.
        proptest!(|(data_size_mib in 256_usize..2048)| {
            assert!(total_digest_size(data_size_mib*MEBIBYTES) <= 32*KIBIBYTES)
        });
    }

    /// Iterator that cuts a slice into consecutive pieces of randomly drawn length.
    #[derive(Clone)]
    struct RandomChunks<'a, T, R: Rng> {
        /// The part of the slice that has not been handed out yet.
        base: &'a [T],
        /// Distribution the piece lengths are drawn from.
        distribution: UniformUsize,
        rng: R,
    }

    impl<'a, T, R: Rng> Iterator for RandomChunks<'a, T, R> {
        type Item = &'a [T];

        fn next(&mut self) -> Option<Self::Item> {
            if self.base.is_empty() {
                return None;
            }
            // The drawn length may be zero (yielding an empty piece) and is capped at what is
            // left of the slice.
            let drawn = self.distribution.sample(&mut self.rng);
            let piece_len = std::cmp::min(drawn, self.base.len());
            let (piece, rest) = self.base.split_at(piece_len);
            self.base = rest;
            Some(piece)
        }
    }

    /// Extension trait that adds `random_chunks` to slices.
    trait RandomChunksIterator<T> {
        fn random_chunks(&self, max_size: usize) -> RandomChunks<'_, T, ThreadRng>;
    }

    impl<T> RandomChunksIterator<T> for [T] {
        /// Splits the slice into pieces whose lengths are drawn from `0..=max_size`.
        fn random_chunks(&self, max_size: usize) -> RandomChunks<'_, T, ThreadRng> {
            assert!(max_size > 0, "Maximal chunk size should be positive");
            let distribution = UniformUsize::new_inclusive(0, max_size).expect("valid range");
            RandomChunks {
                base: self,
                distribution,
                rng: Default::default(),
            }
        }
    }
}