1use aes::cipher::Unsigned;
7use hmac::digest::generic_array::{ArrayLength, GenericArray};
8use hmac::Mac;
9use sha2::digest::{FixedOutput, MacError, Output};
10
11#[derive(Clone)]
12pub struct Incremental<M: Mac + Clone> {
13 mac: M,
14 chunk_size: usize,
15 unused_length: usize,
16}
17
18#[derive(Clone)]
19pub struct Validating<M: Mac + Clone> {
20 incremental: Incremental<M>,
21 expected: Vec<Output<M>>,
23}
24
/// Lower bound on the chunk size returned by `calculate_chunk_size`: 64 KiB.
const MINIMUM_CHUNK_SIZE: usize = 64 << 10;
/// Upper bound on the chunk size returned by `calculate_chunk_size`: 2 MiB.
const MAXIMUM_CHUNK_SIZE: usize = 2 << 20;
/// Desired combined size of all chunk digests for one input: 8 KiB.
const TARGET_TOTAL_DIGEST_SIZE: usize = 8 << 10;
28
29pub const fn calculate_chunk_size<D>(data_size: usize) -> usize
30where
31 D: FixedOutput,
32 D::OutputSize: ArrayLength<u8>,
33{
34 assert!(
35 0 == TARGET_TOTAL_DIGEST_SIZE % D::OutputSize::USIZE,
36 "Target digest size should be a multiple of digest size"
37 );
38 let target_chunk_count = TARGET_TOTAL_DIGEST_SIZE / D::OutputSize::USIZE;
39 if data_size < target_chunk_count * MINIMUM_CHUNK_SIZE {
40 return MINIMUM_CHUNK_SIZE;
41 }
42 if data_size < target_chunk_count * MAXIMUM_CHUNK_SIZE {
43 return (data_size + target_chunk_count - 1) / target_chunk_count;
44 }
45 MAXIMUM_CHUNK_SIZE
46}
47
48impl<M: Mac + Clone> Incremental<M> {
49 pub fn new(mac: M, chunk_size: usize) -> Self {
50 assert!(chunk_size > 0, "chunk size must be positive");
51 Self {
52 mac,
53 chunk_size,
54 unused_length: chunk_size,
55 }
56 }
57
58 pub fn validating<A, I>(self, macs: I) -> Validating<M>
59 where
60 A: AsRef<[u8]>,
61 I: IntoIterator<Item = A>,
62 <I as IntoIterator>::IntoIter: DoubleEndedIterator,
63 {
64 let expected = macs
65 .into_iter()
66 .map(|mac| GenericArray::<u8, M::OutputSize>::from_slice(mac.as_ref()).to_owned())
67 .rev()
68 .collect();
69 Validating {
70 incremental: self,
71 expected,
72 }
73 }
74
75 pub fn update<'a>(&'a mut self, bytes: &'a [u8]) -> impl Iterator<Item = Output<M>> + 'a {
76 let split_point = std::cmp::min(bytes.len(), self.unused_length);
77 let (to_write, overflow) = bytes.split_at(split_point);
78
79 std::iter::once(to_write)
80 .chain(overflow.chunks(self.chunk_size))
81 .flat_map(move |chunk| self.update_chunk(chunk))
82 }
83
84 pub fn finalize(self) -> Output<M> {
85 self.mac.finalize().into_bytes()
86 }
87
88 fn update_chunk(&mut self, bytes: &[u8]) -> Option<Output<M>> {
89 assert!(bytes.len() <= self.unused_length);
90 self.mac.update(bytes);
91 self.unused_length -= bytes.len();
92 if self.unused_length == 0 {
93 self.unused_length = self.chunk_size;
94 let mac = self.mac.clone();
95 Some(mac.finalize().into_bytes())
96 } else {
97 None
98 }
99 }
100
101 fn pending_bytes_size(&self) -> usize {
102 self.chunk_size - self.unused_length
103 }
104}
105
106impl<M: Mac + Clone> Validating<M> {
107 pub fn update(&mut self, bytes: &[u8]) -> Result<usize, MacError> {
108 let mut result = Ok(0);
109 let macs = self.incremental.update(bytes);
110
111 let mut whole_chunks = 0;
112 for mac in macs {
113 match self.expected.last() {
114 Some(expected) if expected == &mac => {
115 whole_chunks += 1;
116 self.expected.pop();
117 }
118 _ => {
119 result = Err(MacError);
120 }
121 }
122 }
123 let validated_bytes = whole_chunks * self.incremental.chunk_size;
124 result.map(|_| validated_bytes)
125 }
126
127 pub fn finalize(self) -> Result<usize, MacError> {
128 let pending_bytes_size = self.incremental.pending_bytes_size();
129 let mac = self.incremental.finalize();
130 match &self.expected[..] {
131 [expected] if expected == &mac => Ok(pending_bytes_size),
132 _ => Err(MacError),
133 }
134 }
135}
136
#[cfg(test)]
mod test {
    use hex_literal::hex;
    use hmac::Hmac;
    use proptest::prelude::*;
    use rand::distributions::Uniform;
    use rand::prelude::{Rng, ThreadRng};
    use sha2::digest::OutputSizeUser;
    use sha2::Sha256;

    use super::*;
    use crate::crypto::hmac_sha256;

    // Fixed HMAC key shared by the tests in this module.
    const TEST_HMAC_KEY: &[u8] =
        &hex!("a83481457efecc69ad1342e21d9c0297f71debbf5c9304b4c1b2e433c1a78f98");

    // Deliberately small chunk size so short test inputs span several chunks.
    const TEST_CHUNK_SIZE: usize = 32;

    /// Builds an HMAC-SHA256 [`Incremental`] keyed with `key` that emits a MAC every
    /// `chunk_size` bytes.
    fn new_incremental(key: &[u8], chunk_size: usize) -> Incremental<Hmac<Sha256>> {
        let hmac = Hmac::<Sha256>::new_from_slice(key)
            .expect("Should be able to create a new HMAC instance");
        Incremental::new(hmac, chunk_size)
    }

    // A zero chunk size must be rejected by `Incremental::new`.
    #[test]
    #[should_panic]
    fn chunk_size_zero() {
        new_incremental(&[], 0);
    }

    // The final MAC of an incremental run over a multi-chunk input equals a one-shot
    // HMAC-SHA256 of the same bytes.
    #[test]
    fn simple_test() {
        let key = TEST_HMAC_KEY;
        let input = "this is a simple test input string which is longer than the chunk";

        let bytes = input.as_bytes();
        let expected = hmac_sha256(key, bytes);
        let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);
        // `update` returns a lazy iterator; collecting it drives the absorption of `bytes`.
        let _ = incremental.update(bytes).collect::<Vec<_>>();
        let digest = incremental.finalize();
        let actual: [u8; 32] = digest.into();
        assert_eq!(actual, expected);
    }

    // Property version of `simple_test`: random strings of 0 to 100 characters (the empty
    // input included) must give the same final MAC as the one-shot HMAC.
    #[test]
    fn final_result_should_be_equal_to_non_incremental_hmac() {
        let key = TEST_HMAC_KEY;
        proptest!(|(input in ".{0,100}")| {
            let bytes = input.as_bytes();
            let expected = hmac_sha256(key, bytes);
            let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);
            let _ = incremental.update(bytes).collect::<Vec<_>>();
            let actual: [u8; 32] = incremental.finalize().into();
            assert_eq!(actual, expected);
        });
    }

    // Each intermediate MAC must equal a one-shot HMAC of the whole prefix ending at that
    // chunk boundary, however the input is split across `update` calls.
    #[test]
    fn incremental_macs_are_valid() {
        let key = TEST_HMAC_KEY;

        proptest!(|(input in ".{50,100}")| {
            let bytes = input.as_bytes();
            let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);

            let expected: Vec<_> = bytes
                // Reference MACs: a one-shot HMAC of each prefix that ends at a chunk
                // boundary, or at the end of the input for a trailing partial chunk.
                .chunks(incremental.chunk_size)
                .scan(Vec::new(), |acc, chunk| {
                    acc.extend(chunk.iter());
                    Some(hmac_sha256(key, acc).to_vec())
                })
                .collect();

            let mut actual: Vec<Vec<u8>> = bytes
                .random_chunks(incremental.chunk_size)
                .flat_map(|chunk| incremental.update(chunk).collect::<Vec<_>>())
                .map(|out| out.into())
                .map(|bs: [u8; 32]| bs.to_vec())
                .collect();
            if bytes.len() % incremental.chunk_size != 0 {
                // A trailing partial chunk only gets its MAC from `finalize`. When the input
                // ends exactly on a boundary, `update` already emitted that MAC.
                let last_hmac: [u8; 32] = incremental.finalize().into();
                actual.push(last_hmac.to_vec());
            }
            assert_eq!(actual, expected);
        });
    }

    // Runs `Validating` against correct MACs and against several tampered or incomplete
    // MAC lists. The 34-byte input yields two MACs:
    // - one for the first full 32-byte chunk (from `update`);
    // - one for the whole input (from `finalize`).
    #[test]
    fn validating_simple_test() {
        let key = TEST_HMAC_KEY;
        let input = "this is a simple test input string";

        let bytes = input.as_bytes();
        let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);
        let mut expected_macs: Vec<_> = incremental.update(bytes).collect();
        expected_macs.push(incremental.finalize());

        let expected_bytes: Vec<[u8; 32]> =
            expected_macs.into_iter().map(|mac| mac.into()).collect();

        // Unmodified MACs: both `update` and `finalize` succeed.
        {
            let mut validating =
                new_incremental(key, TEST_CHUNK_SIZE).validating(expected_bytes.clone());
            validating
                .update(bytes)
                .expect("update: validation should succeed");
            validating
                .finalize()
                .expect("finalize: validation should succeed");
        }

        // Corrupted first MAC: detected when the first chunk completes inside `update`.
        {
            let mut failing_first_update = expected_bytes.clone();
            failing_first_update
                .first_mut()
                .expect("there must be at least one mac")[0] ^= 0xff;
            let mut validating =
                new_incremental(key, TEST_CHUNK_SIZE).validating(failing_first_update);
            validating.update(bytes).expect_err("MacError");
        }

        // Corrupted final MAC: `update` succeeds, only `finalize` fails.
        {
            let mut failing_finalize = expected_bytes.clone();
            failing_finalize
                .last_mut()
                .expect("there must be at least one mac")[0] ^= 0xff;
            let mut validating = new_incremental(key, TEST_CHUNK_SIZE).validating(failing_finalize);
            validating.update(bytes).expect("update should succeed");
            validating.finalize().expect_err("MacError");
        }

        // Final MAC missing: `finalize` has no expected MAC left to compare against.
        {
            let missing_last_mac = &expected_bytes[0..expected_bytes.len() - 1];
            let mut validating = new_incremental(key, TEST_CHUNK_SIZE).validating(missing_last_mac);
            validating.update(bytes).expect("update should succeed");
            validating.finalize().expect_err("MacError");
        }

        // First MAC missing: the first completed chunk is compared against the wrong MAC.
        {
            let missing_first_mac: Vec<_> = expected_bytes.clone().into_iter().skip(1).collect();
            let mut validating =
                new_incremental(key, TEST_CHUNK_SIZE).validating(missing_first_mac);
            validating.update(bytes).expect_err("MacError");
        }
        // NOTE(review): presumably keeps `expected_bytes` in use after the `.clone()` calls
        // above (e.g. to quiet a redundant-clone lint); confirm before removing.
        std::hint::black_box(expected_bytes);
    }

    // `Validating::update` reports bytes only for chunks completed during that call.
    // `finalize` reports the length of the trailing partial chunk.
    #[test]
    fn validating_returns_right_size() {
        let key = TEST_HMAC_KEY;
        let input = "this is a simple test input string";

        let bytes = input.as_bytes();
        let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);
        let mut expected_macs: Vec<_> = incremental.update(bytes).collect();
        expected_macs.push(incremental.finalize());

        let expected_bytes: Vec<[u8; 32]> =
            expected_macs.into_iter().map(|mac| mac.into()).collect();

        let mut validating = new_incremental(key, TEST_CHUNK_SIZE).validating(expected_bytes);

        let input_chunks = bytes.chunks(16).collect::<Vec<_>>();
        // The 34-byte input splits into 16 + 16 + 2 byte pieces. Only the second piece
        // completes the first 32-byte chunk; the last 2 bytes are left for `finalize`.
        assert_eq!(3, input_chunks.len());
        let expected_remainder = bytes.len() - TEST_CHUNK_SIZE;

        for (expected_size, input) in std::iter::zip([0, TEST_CHUNK_SIZE, 0], input_chunks) {
            assert_eq!(
                expected_size,
                validating
                    .update(input)
                    .expect("update: validation should succeed")
            );
        }
        assert_eq!(
            expected_remainder,
            validating
                .finalize()
                .expect("finalize: validation should succeed")
        );
    }

    // Round trip: MACs produced by `Incremental` are accepted by `Validating` for random
    // inputs fed in random-sized pieces.
    #[test]
    fn produce_and_validate() {
        let key = TEST_HMAC_KEY;

        proptest!(|(input in ".{0,100}")| {
            let bytes = input.as_bytes();
            let mut incremental = new_incremental(key, TEST_CHUNK_SIZE);
            // NOTE(review): clones of `ThreadRng` appear to share generator state, so the two
            // `clone()`s below likely split `bytes` differently. That is harmless here,
            // because chunk MACs depend only on chunk boundaries. Confirm this is intended.
            let input_chunks = bytes.random_chunks(incremental.chunk_size*2);

            let mut produced: Vec<[u8; 32]> = input_chunks.clone()
                .flat_map(|chunk| incremental.update(chunk).collect::<Vec<_>>())
                .map(|out| out.into())
                .collect();
            produced.push(incremental.finalize().into());

            let mut validating = new_incremental(key, TEST_CHUNK_SIZE).validating(produced);
            for chunk in input_chunks.clone() {
                validating.update(chunk).expect("update: validation should succeed");
            }
            validating.finalize().expect("finalize: validation should succeed");
        });
    }

    const KIBIBYTES: usize = 1024;
    const MEBIBYTES: usize = 1024 * KIBIBYTES;
    const GIBIBYTES: usize = 1024 * MEBIBYTES;

    // SHA-256 digests are 32 bytes, so the 8 KiB target allows 256 chunks:
    // - the chunk size grows beyond the 64 KiB minimum from 16 MiB of input;
    // - it reaches the 2 MiB cap at 512 MiB.
    #[test]
    fn chunk_sizes_sha256() {
        for (data_size, expected) in [
            (0, MINIMUM_CHUNK_SIZE),
            (KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (10 * KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (100 * KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (MEBIBYTES, MINIMUM_CHUNK_SIZE),
            (10 * MEBIBYTES, MINIMUM_CHUNK_SIZE),
            (20 * MEBIBYTES, 80 * KIBIBYTES),
            (100 * MEBIBYTES, 400 * KIBIBYTES),
            (200 * MEBIBYTES, 800 * KIBIBYTES),
            (256 * MEBIBYTES, MEBIBYTES),
            (512 * MEBIBYTES, 2 * MEBIBYTES),
            (GIBIBYTES, 2 * MEBIBYTES),
            (2 * GIBIBYTES, 2 * MEBIBYTES),
        ] {
            let actual = calculate_chunk_size::<Sha256>(data_size);
            assert_eq!(actual, expected);
        }
    }

    // SHA-512 digests are 64 bytes, so the 8 KiB target allows 128 chunks:
    // - the chunk size grows beyond the minimum from 8 MiB of input;
    // - it reaches the 2 MiB cap at 256 MiB.
    #[test]
    fn chunk_sizes_sha512() {
        for (data_size, expected) in [
            (0, MINIMUM_CHUNK_SIZE),
            (KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (10 * KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (100 * KIBIBYTES, MINIMUM_CHUNK_SIZE),
            (MEBIBYTES, MINIMUM_CHUNK_SIZE),
            (10 * MEBIBYTES, 80 * KIBIBYTES),
            (20 * MEBIBYTES, 160 * KIBIBYTES),
            (100 * MEBIBYTES, 800 * KIBIBYTES),
            (200 * MEBIBYTES, 1600 * KIBIBYTES),
            (256 * MEBIBYTES, 2 * MEBIBYTES),
            (512 * MEBIBYTES, 2 * MEBIBYTES),
            (GIBIBYTES, 2 * MEBIBYTES),
        ] {
            let actual = calculate_chunk_size::<sha2::Sha512>(data_size);
            assert_eq!(actual, expected);
        }
    }

    // Bounds on the combined SHA-256 digest size:
    // - up to 256 MiB of input it stays within the 8 KiB target;
    // - for inputs up to 2 GiB (with capped chunk sizes) it stays within 32 KiB.
    #[test]
    fn total_digest_size_is_never_too_big() {
        // Combined size of all chunk digests for `data_size` input bytes. Always at least
        // one digest, even for empty input.
        fn total_digest_size(data_size: usize) -> usize {
            let chunk_size = calculate_chunk_size::<Sha256>(data_size);
            let num_chunks = std::cmp::max(1, (data_size + chunk_size - 1) / chunk_size);
            num_chunks * <Sha256 as OutputSizeUser>::OutputSize::USIZE
        }
        let config = ProptestConfig::with_cases(10_000);
        proptest!(config, |(data_size in 256..256*MEBIBYTES)| {
            assert!(total_digest_size(data_size) <= 8*KIBIBYTES)
        });
        proptest!(|(data_size_mib in 256_usize..2048)| {
            assert!(total_digest_size(data_size_mib*MEBIBYTES) <= 32*KIBIBYTES)
        });
    }

    /// Iterator over consecutive sub-slices of `base` whose lengths are sampled from
    /// `distribution`, clamped to what remains. A sample of 0 yields an empty sub-slice.
    #[derive(Clone)]
    struct RandomChunks<'a, T, R: Rng> {
        base: &'a [T],
        distribution: Uniform<usize>,
        rng: R,
    }

    impl<'a, T, R: Rng> Iterator for RandomChunks<'a, T, R> {
        type Item = &'a [T];

        // Ends once `base` is exhausted; otherwise splits off the next randomly sized piece.
        fn next(&mut self) -> Option<Self::Item> {
            if self.base.is_empty() {
                None
            } else {
                let candidate = self.rng.sample(self.distribution);
                let chunk_size = std::cmp::min(candidate, self.base.len());
                let (before, after) = self.base.split_at(chunk_size);
                self.base = after;
                Some(before)
            }
        }
    }

    /// Splits a slice into randomly sized pieces, for exercising chunk-boundary handling.
    trait RandomChunksIterator<T> {
        fn random_chunks(&self, max_size: usize) -> RandomChunks<T, ThreadRng>;
    }

    impl<T> RandomChunksIterator<T> for [T] {
        // Piece lengths are drawn from `Uniform::new(0, max_size + 1)`, i.e. `0..=max_size`,
        // so empty pieces can occur. Panics if `max_size` is zero.
        fn random_chunks(&self, max_size: usize) -> RandomChunks<T, ThreadRng> {
            assert!(max_size > 0, "Maximal chunk size should be positive");
            RandomChunks {
                base: self,
                distribution: Uniform::new(0, max_size + 1),
                rng: Default::default(),
            }
        }
    }
}