libsignal_protocol/
sender_keys.rs

1//
2// Copyright 2020-2021 Signal Messenger, LLC.
3// SPDX-License-Identifier: AGPL-3.0-only
4//
5
6use std::collections::VecDeque;
7
8use itertools::Itertools;
9use prost::Message;
10
11use crate::crypto::hmac_sha256;
12use crate::proto::storage as storage_proto;
13use crate::{consts, PrivateKey, PublicKey, SignalProtocolError};
14
/// Error describing a problem with stored sender key state.
///
/// This is its own type, rather than a reuse of a deserialization error, so that such errors
/// are not accidentally propagated as deserialization errors.
#[derive(Debug)]
pub(crate) struct InvalidSessionError(&'static str);

impl std::fmt::Display for InvalidSessionError {
    /// Writes the static description, honoring width/fill/precision flags exactly as `str` does.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.0)
    }
}
24
25#[derive(Debug, Clone)]
26pub(crate) struct SenderMessageKey {
27    iteration: u32,
28    iv: Vec<u8>,
29    cipher_key: Vec<u8>,
30    seed: Vec<u8>,
31}
32
33impl SenderMessageKey {
34    pub(crate) fn new(iteration: u32, seed: Vec<u8>) -> Self {
35        let mut derived = [0; 48];
36        hkdf::Hkdf::<sha2::Sha256>::new(None, &seed)
37            .expand(b"WhisperGroup", &mut derived)
38            .expect("valid output length");
39        Self {
40            iteration,
41            seed,
42            iv: derived[0..16].to_vec(),
43            cipher_key: derived[16..48].to_vec(),
44        }
45    }
46
47    pub(crate) fn from_protobuf(
48        smk: storage_proto::sender_key_state_structure::SenderMessageKey,
49    ) -> Self {
50        Self::new(smk.iteration, smk.seed)
51    }
52
53    pub(crate) fn iteration(&self) -> u32 {
54        self.iteration
55    }
56
57    pub(crate) fn iv(&self) -> &[u8] {
58        &self.iv
59    }
60
61    pub(crate) fn cipher_key(&self) -> &[u8] {
62        &self.cipher_key
63    }
64
65    pub(crate) fn as_protobuf(
66        &self,
67    ) -> storage_proto::sender_key_state_structure::SenderMessageKey {
68        storage_proto::sender_key_state_structure::SenderMessageKey {
69            iteration: self.iteration,
70            seed: self.seed.clone(),
71        }
72    }
73}
74
75#[derive(Debug, Clone)]
76pub(crate) struct SenderChainKey {
77    iteration: u32,
78    chain_key: Vec<u8>,
79}
80
81impl SenderChainKey {
82    const MESSAGE_KEY_SEED: u8 = 0x01;
83    const CHAIN_KEY_SEED: u8 = 0x02;
84
85    pub(crate) fn new(iteration: u32, chain_key: Vec<u8>) -> Self {
86        Self {
87            iteration,
88            chain_key,
89        }
90    }
91
92    pub(crate) fn iteration(&self) -> u32 {
93        self.iteration
94    }
95
96    pub(crate) fn seed(&self) -> &[u8] {
97        &self.chain_key
98    }
99
100    pub(crate) fn next(&self) -> SenderChainKey {
101        SenderChainKey::new(
102            self.iteration + 1,
103            self.get_derivative(Self::CHAIN_KEY_SEED),
104        )
105    }
106
107    pub(crate) fn sender_message_key(&self) -> SenderMessageKey {
108        SenderMessageKey::new(self.iteration, self.get_derivative(Self::MESSAGE_KEY_SEED))
109    }
110
111    fn get_derivative(&self, label: u8) -> Vec<u8> {
112        let label = [label];
113        hmac_sha256(&self.chain_key, &label).to_vec()
114    }
115
116    pub(crate) fn as_protobuf(&self) -> storage_proto::sender_key_state_structure::SenderChainKey {
117        storage_proto::sender_key_state_structure::SenderChainKey {
118            iteration: self.iteration,
119            seed: self.chain_key.clone(),
120        }
121    }
122}
123
124#[derive(Debug, Clone)]
125pub(crate) struct SenderKeyState {
126    state: storage_proto::SenderKeyStateStructure,
127}
128
129impl SenderKeyState {
130    pub(crate) fn new(
131        message_version: u8,
132        chain_id: u32,
133        iteration: u32,
134        chain_key: &[u8],
135        signature_key: PublicKey,
136        signature_private_key: Option<PrivateKey>,
137    ) -> SenderKeyState {
138        let state = storage_proto::SenderKeyStateStructure {
139            message_version: message_version as u32,
140            chain_id,
141            sender_chain_key: Some(
142                SenderChainKey::new(iteration, chain_key.to_vec()).as_protobuf(),
143            ),
144            sender_signing_key: Some(
145                storage_proto::sender_key_state_structure::SenderSigningKey {
146                    public: signature_key.serialize().to_vec(),
147                    private: match signature_private_key {
148                        None => vec![],
149                        Some(k) => k.serialize().to_vec(),
150                    },
151                },
152            ),
153            sender_message_keys: vec![],
154        };
155
156        Self { state }
157    }
158
159    pub(crate) fn from_protobuf(state: storage_proto::SenderKeyStateStructure) -> Self {
160        Self { state }
161    }
162
163    pub(crate) fn message_version(&self) -> u32 {
164        match self.state.message_version {
165            0 => 3, // the first SenderKey version
166            v => v,
167        }
168    }
169
170    pub(crate) fn chain_id(&self) -> u32 {
171        self.state.chain_id
172    }
173
174    pub(crate) fn sender_chain_key(&self) -> Option<SenderChainKey> {
175        let sender_chain = self.state.sender_chain_key.as_ref()?;
176        Some(SenderChainKey::new(
177            sender_chain.iteration,
178            sender_chain.seed.clone(),
179        ))
180    }
181
182    pub(crate) fn set_sender_chain_key(&mut self, chain_key: SenderChainKey) {
183        self.state.sender_chain_key = Some(chain_key.as_protobuf());
184    }
185
186    pub(crate) fn signing_key_public(&self) -> Result<PublicKey, InvalidSessionError> {
187        if let Some(ref signing_key) = self.state.sender_signing_key {
188            PublicKey::try_from(&signing_key.public[..])
189                .map_err(|_| InvalidSessionError("invalid public signing key"))
190        } else {
191            Err(InvalidSessionError("missing signing key"))
192        }
193    }
194
195    pub(crate) fn signing_key_private(&self) -> Result<PrivateKey, InvalidSessionError> {
196        if let Some(ref signing_key) = self.state.sender_signing_key {
197            PrivateKey::deserialize(&signing_key.private)
198                .map_err(|_| InvalidSessionError("invalid private signing key"))
199        } else {
200            Err(InvalidSessionError("missing signing key"))
201        }
202    }
203
204    pub(crate) fn as_protobuf(&self) -> storage_proto::SenderKeyStateStructure {
205        self.state.clone()
206    }
207
208    pub(crate) fn add_sender_message_key(&mut self, sender_message_key: &SenderMessageKey) {
209        self.state
210            .sender_message_keys
211            .push(sender_message_key.as_protobuf());
212        while self.state.sender_message_keys.len() > consts::MAX_MESSAGE_KEYS {
213            self.state.sender_message_keys.remove(0);
214        }
215    }
216
217    pub(crate) fn remove_sender_message_key(&mut self, iteration: u32) -> Option<SenderMessageKey> {
218        if let Some(index) = self
219            .state
220            .sender_message_keys
221            .iter()
222            .position(|x| x.iteration == iteration)
223        {
224            let smk = self.state.sender_message_keys.remove(index);
225            Some(SenderMessageKey::from_protobuf(smk))
226        } else {
227            None
228        }
229    }
230}
231
232#[derive(Debug, Clone)]
233pub struct SenderKeyRecord {
234    states: VecDeque<SenderKeyState>,
235}
236
237impl SenderKeyRecord {
238    pub(crate) fn new_empty() -> Self {
239        Self {
240            states: VecDeque::with_capacity(consts::MAX_SENDER_KEY_STATES),
241        }
242    }
243
244    pub fn deserialize(buf: &[u8]) -> Result<SenderKeyRecord, SignalProtocolError> {
245        let skr = storage_proto::SenderKeyRecordStructure::decode(buf)
246            .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
247
248        let mut states = VecDeque::with_capacity(skr.sender_key_states.len());
249        for state in skr.sender_key_states {
250            states.push_back(SenderKeyState::from_protobuf(state))
251        }
252        Ok(Self { states })
253    }
254
255    pub(crate) fn sender_key_state(&self) -> Result<&SenderKeyState, InvalidSessionError> {
256        if !self.states.is_empty() {
257            return Ok(&self.states[0]);
258        }
259        Err(InvalidSessionError("empty sender key state"))
260    }
261
262    pub(crate) fn sender_key_state_mut(
263        &mut self,
264    ) -> Result<&mut SenderKeyState, InvalidSessionError> {
265        if !self.states.is_empty() {
266            return Ok(&mut self.states[0]);
267        }
268        Err(InvalidSessionError("empty sender key state"))
269    }
270
271    pub(crate) fn sender_key_state_for_chain_id(
272        &mut self,
273        chain_id: u32,
274    ) -> Option<&mut SenderKeyState> {
275        for i in 0..self.states.len() {
276            if self.states[i].chain_id() == chain_id {
277                return Some(&mut self.states[i]);
278            }
279        }
280        None
281    }
282
283    pub(crate) fn chain_ids_for_logging(&self) -> impl ExactSizeIterator<Item = u32> + '_ {
284        self.states.iter().map(|state| state.chain_id())
285    }
286
287    pub(crate) fn add_sender_key_state(
288        &mut self,
289        message_version: u8,
290        chain_id: u32,
291        iteration: u32,
292        chain_key: &[u8],
293        signature_key: PublicKey,
294        signature_private_key: Option<PrivateKey>,
295    ) {
296        let existing_state = self.remove_state(chain_id, signature_key);
297
298        if self.remove_states_with_chain_id(chain_id) > 0 {
299            log::warn!(
300                "Removed a matching chain_id ({}) found with a different public key",
301                chain_id
302            );
303        }
304
305        let state = match existing_state {
306            None => SenderKeyState::new(
307                message_version,
308                chain_id,
309                iteration,
310                chain_key,
311                signature_key,
312                signature_private_key,
313            ),
314            Some(state) => state,
315        };
316
317        while self.states.len() >= consts::MAX_SENDER_KEY_STATES {
318            self.states.pop_back();
319        }
320
321        self.states.push_front(state);
322    }
323
324    /// Remove the state with the matching `chain_id` and `signature_key`.
325    ///
326    /// Skips any bad protobufs.
327    fn remove_state(&mut self, chain_id: u32, signature_key: PublicKey) -> Option<SenderKeyState> {
328        let (index, _state) = self.states.iter().find_position(|state| {
329            state.chain_id() == chain_id && state.signing_key_public().ok() == Some(signature_key)
330        })?;
331
332        self.states.remove(index)
333    }
334
335    /// Returns the number of removed states.
336    ///
337    /// Skips any bad protobufs.
338    fn remove_states_with_chain_id(&mut self, chain_id: u32) -> usize {
339        let initial_length = self.states.len();
340        self.states.retain(|state| state.chain_id() != chain_id);
341        initial_length - self.states.len()
342    }
343
344    pub(crate) fn as_protobuf(&self) -> storage_proto::SenderKeyRecordStructure {
345        let mut states = Vec::with_capacity(self.states.len());
346        for state in &self.states {
347            states.push(state.as_protobuf());
348        }
349
350        storage_proto::SenderKeyRecordStructure {
351            sender_key_states: states,
352        }
353    }
354
355    pub fn serialize(&self) -> Result<Vec<u8>, SignalProtocolError> {
356        Ok(self.as_protobuf().encode_to_vec())
357    }
358}
359
#[cfg(test)]
mod sender_key_record_add_sender_key_state_tests {
    use itertools::Itertools;
    use rand::rngs::OsRng;

    use super::*;
    use crate::KeyPair;

    fn random_public_key() -> PublicKey {
        KeyPair::generate(&mut OsRng).public_key
    }

    /// Builds a distinct 16-byte chain key from `i` (big-endian encoding).
    fn chain_key(i: u128) -> Vec<u8> {
        i.to_be_bytes().to_vec()
    }

    struct TestContext {
        sender_key_record: SenderKeyRecord,
    }

    impl TestContext {
        fn new() -> Self {
            Self {
                sender_key_record: SenderKeyRecord::new_empty(),
            }
        }

        /// Associates the `record_key` with the `chain_key` via `add_sender_key_state`, which is
        /// the method under test in this module.
        fn add_sender_key_state_record(&mut self, record_key: (PublicKey, u32), chain_key: &[u8]) {
            let (public_key, chain_id) = record_key;
            self.sender_key_record
                .add_sender_key_state(1, chain_id, 1, chain_key, public_key, None);
        }

        fn assert_number_of_states(&self, expected: usize) {
            assert_eq!(expected, self.sender_key_record.states.len());
        }

        /// Asserts that the chain key stored for `record_key` is `expected_chain_key`.
        ///
        /// Checks both the lookup by `chain_id` alone and the unique state matching `public_key`
        /// and `chain_id` together.
        fn assert_records_chain_key(
            &mut self,
            record_key: (PublicKey, u32),
            expected_chain_key: &[u8],
        ) {
            let (public_key, chain_id) = record_key;
            let found_chain_key = self
                .sender_key_record
                .sender_key_state_for_chain_id(chain_id)
                .expect("Expect to find chain id")
                .sender_chain_key()
                .expect("Expect to find chain key")
                .chain_key;

            assert_eq!(found_chain_key, expected_chain_key);

            let matching_state = self
                .sender_key_record
                .states
                .iter()
                .filter(|state| {
                    state.chain_id() == chain_id
                        && state.signing_key_public().expect("expect public key") == public_key
                })
                .exactly_one()
                .expect("Expected exactly one record key match");

            assert_eq!(
                &matching_state
                    .sender_chain_key()
                    .expect("Expect to find chain key")
                    .chain_key,
                expected_chain_key
            );
        }

        /// Asserts that the record's states, most recent first, are exactly `order`.
        fn assert_record_order(&self, order: &[(PublicKey, u32)]) {
            let record_keys = self
                .sender_key_record
                .states
                .iter()
                .map(|state| {
                    (
                        state.signing_key_public().expect("expect public key"),
                        state.chain_id(),
                    )
                })
                .collect::<Vec<_>>();

            assert_eq!(record_keys, order);
        }
    }

    #[test]
    fn add_single_state() {
        let mut context = TestContext::new();

        let public_key = random_public_key();
        let chain_id = 1;
        let chain_key = chain_key(1);
        let record_key = (public_key, chain_id);

        context.add_sender_key_state_record(record_key, &chain_key);

        context.assert_number_of_states(1);
        context.assert_records_chain_key(record_key, &chain_key);
    }

    #[test]
    fn add_second_state() {
        let mut context = TestContext::new();

        let chain_id_1 = 1;
        let chain_id_2 = 2;
        let record_key_1 = (random_public_key(), chain_id_1);
        let record_key_2 = (random_public_key(), chain_id_2);
        let chain_key_1 = chain_key(1);
        let chain_key_2 = chain_key(2);

        context.add_sender_key_state_record(record_key_1, &chain_key_1);
        context.add_sender_key_state_record(record_key_2, &chain_key_2);

        context.assert_number_of_states(2);
        context.assert_records_chain_key(record_key_1, &chain_key_1);
        context.assert_records_chain_key(record_key_2, &chain_key_2);
    }

    #[test]
    fn when_exceed_maximum_states_then_oldest_is_ejected() {
        assert_eq!(
            5,
            consts::MAX_SENDER_KEY_STATES,
            "Test written to expect this limit"
        );

        let mut context = TestContext::new();

        let record_key_1 = (random_public_key(), 1);
        let record_key_2 = (random_public_key(), 2);
        let record_key_3 = (random_public_key(), 3);
        let record_key_4 = (random_public_key(), 4);
        let record_key_5 = (random_public_key(), 5);
        let record_key_6 = (random_public_key(), 6);

        context.add_sender_key_state_record(record_key_1, &chain_key(1));
        context.add_sender_key_state_record(record_key_2, &chain_key(2));
        context.add_sender_key_state_record(record_key_3, &chain_key(3));
        context.add_sender_key_state_record(record_key_4, &chain_key(4));
        context.add_sender_key_state_record(record_key_5, &chain_key(5));

        context.assert_record_order(&[
            record_key_5,
            record_key_4,
            record_key_3,
            record_key_2,
            record_key_1,
        ]);

        context.add_sender_key_state_record(record_key_6, &chain_key(6));

        context.assert_record_order(&[
            record_key_6,
            record_key_5,
            record_key_4,
            record_key_3,
            record_key_2,
        ]);
    }

    #[test]
    fn when_second_state_with_same_public_key_and_chain_id_added_then_it_keeps_first_data() {
        let mut context = TestContext::new();

        let chain_id = 1;
        let record_key = (random_public_key(), chain_id);
        let chain_key_1 = chain_key(1);
        let chain_key_2 = chain_key(2);

        context.add_sender_key_state_record(record_key, &chain_key_1);
        context.add_sender_key_state_record(record_key, &chain_key_2);

        context.assert_number_of_states(1);
        context.assert_records_chain_key(record_key, &chain_key_1);
    }

    #[test]
    fn when_second_state_with_different_public_key_but_same_chain_id_added_then_it_gets_replaced() {
        let mut context = TestContext::new();

        let chain_id = 1;
        let record_key_1 = (random_public_key(), chain_id);
        let record_key_2 = (random_public_key(), chain_id);
        let chain_key_1 = chain_key(1);
        let chain_key_2 = chain_key(2);

        context.add_sender_key_state_record(record_key_1, &chain_key_1);
        context.add_sender_key_state_record(record_key_2, &chain_key_2);

        context.assert_number_of_states(1);
        context.assert_records_chain_key(record_key_2, &chain_key_2);
    }

    #[test]
    fn when_second_state_with_same_public_key_and_chain_id_added_then_it_becomes_the_most_recent() {
        let mut context = TestContext::new();

        let chain_id_1 = 1;
        let chain_id_2 = 2;
        let record_key_1 = (random_public_key(), chain_id_1);
        let record_key_2 = (random_public_key(), chain_id_2);
        let chain_key_1 = chain_key(1);
        let chain_key_2 = chain_key(2);
        let chain_key_3 = chain_key(3);

        context.add_sender_key_state_record(record_key_1, &chain_key_1);
        context.add_sender_key_state_record(record_key_2, &chain_key_2);

        context.assert_record_order(&[record_key_2, record_key_1]);

        context.add_sender_key_state_record(record_key_1, &chain_key_3);

        context.assert_record_order(&[record_key_1, record_key_2]);
        // Re-adding an existing (public key, chain id) pair moves it but keeps its original data.
        context.assert_records_chain_key(record_key_1, &chain_key_1);
    }
}