libsignal_protocol/
sender_keys.rs

//
// Copyright 2020-2021 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//
5
6use std::collections::VecDeque;
7
8use itertools::Itertools;
9use prost::Message;
10
11use crate::crypto::hmac_sha256;
12use crate::proto::storage as storage_proto;
13use crate::{consts, PrivateKey, PublicKey, SignalProtocolError};
14
/// Error describing a malformed or incomplete stored sender-key state.
///
/// This is deliberately its own type (rather than [`SignalProtocolError`]) so that
/// deserialization errors are not accidentally propagated.
#[derive(Debug)]
pub(crate) struct InvalidSessionError(&'static str);

impl std::fmt::Display for InvalidSessionError {
    /// Writes the static description, honoring any width/fill/alignment flags set on `f`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description: &str = self.0;
        std::fmt::Display::fmt(description, f)
    }
}
24
25#[derive(Debug, Clone)]
26pub(crate) struct SenderMessageKey {
27    iteration: u32,
28    iv: Vec<u8>,
29    cipher_key: Vec<u8>,
30    seed: Vec<u8>,
31}
32
33impl SenderMessageKey {
34    pub(crate) fn new(iteration: u32, seed: Vec<u8>) -> Self {
35        let mut derived = [0; 48];
36        hkdf::Hkdf::<sha2::Sha256>::new(None, &seed)
37            .expand(b"WhisperGroup", &mut derived)
38            .expect("valid output length");
39        Self {
40            iteration,
41            seed,
42            iv: derived[0..16].to_vec(),
43            cipher_key: derived[16..48].to_vec(),
44        }
45    }
46
47    pub(crate) fn from_protobuf(
48        smk: storage_proto::sender_key_state_structure::SenderMessageKey,
49    ) -> Self {
50        Self::new(smk.iteration, smk.seed)
51    }
52
53    pub(crate) fn iteration(&self) -> u32 {
54        self.iteration
55    }
56
57    pub(crate) fn iv(&self) -> &[u8] {
58        &self.iv
59    }
60
61    pub(crate) fn cipher_key(&self) -> &[u8] {
62        &self.cipher_key
63    }
64
65    pub(crate) fn as_protobuf(
66        &self,
67    ) -> storage_proto::sender_key_state_structure::SenderMessageKey {
68        storage_proto::sender_key_state_structure::SenderMessageKey {
69            iteration: self.iteration,
70            seed: self.seed.clone(),
71        }
72    }
73}
74
75#[derive(Debug, Clone)]
76pub(crate) struct SenderChainKey {
77    iteration: u32,
78    chain_key: Vec<u8>,
79}
80
81impl SenderChainKey {
82    const MESSAGE_KEY_SEED: u8 = 0x01;
83    const CHAIN_KEY_SEED: u8 = 0x02;
84
85    pub(crate) fn new(iteration: u32, chain_key: Vec<u8>) -> Self {
86        Self {
87            iteration,
88            chain_key,
89        }
90    }
91
92    pub(crate) fn iteration(&self) -> u32 {
93        self.iteration
94    }
95
96    pub(crate) fn seed(&self) -> &[u8] {
97        &self.chain_key
98    }
99
100    pub(crate) fn next(&self) -> Result<SenderChainKey, SignalProtocolError> {
101        let new_iteration = self.iteration.checked_add(1).ok_or_else(|| {
102            SignalProtocolError::InvalidState(
103                "sender_chain_key_next",
104                "Sender chain is too long".into(),
105            )
106        })?;
107
108        Ok(SenderChainKey::new(
109            new_iteration,
110            self.get_derivative(Self::CHAIN_KEY_SEED),
111        ))
112    }
113
114    pub(crate) fn sender_message_key(&self) -> SenderMessageKey {
115        SenderMessageKey::new(self.iteration, self.get_derivative(Self::MESSAGE_KEY_SEED))
116    }
117
118    fn get_derivative(&self, label: u8) -> Vec<u8> {
119        let label = [label];
120        hmac_sha256(&self.chain_key, &label).to_vec()
121    }
122
123    pub(crate) fn as_protobuf(&self) -> storage_proto::sender_key_state_structure::SenderChainKey {
124        storage_proto::sender_key_state_structure::SenderChainKey {
125            iteration: self.iteration,
126            seed: self.chain_key.clone(),
127        }
128    }
129}
130
131#[derive(Debug, Clone)]
132pub(crate) struct SenderKeyState {
133    state: storage_proto::SenderKeyStateStructure,
134}
135
136impl SenderKeyState {
137    pub(crate) fn new(
138        message_version: u8,
139        chain_id: u32,
140        iteration: u32,
141        chain_key: &[u8],
142        signature_key: PublicKey,
143        signature_private_key: Option<PrivateKey>,
144    ) -> SenderKeyState {
145        let state = storage_proto::SenderKeyStateStructure {
146            message_version: message_version as u32,
147            chain_id,
148            sender_chain_key: Some(
149                SenderChainKey::new(iteration, chain_key.to_vec()).as_protobuf(),
150            ),
151            sender_signing_key: Some(
152                storage_proto::sender_key_state_structure::SenderSigningKey {
153                    public: signature_key.serialize().to_vec(),
154                    private: match signature_private_key {
155                        None => vec![],
156                        Some(k) => k.serialize().to_vec(),
157                    },
158                },
159            ),
160            sender_message_keys: vec![],
161        };
162
163        Self { state }
164    }
165
166    pub(crate) fn from_protobuf(state: storage_proto::SenderKeyStateStructure) -> Self {
167        Self { state }
168    }
169
170    pub(crate) fn message_version(&self) -> u32 {
171        match self.state.message_version {
172            0 => 3, // the first SenderKey version
173            v => v,
174        }
175    }
176
177    pub(crate) fn chain_id(&self) -> u32 {
178        self.state.chain_id
179    }
180
181    pub(crate) fn sender_chain_key(&self) -> Option<SenderChainKey> {
182        let sender_chain = self.state.sender_chain_key.as_ref()?;
183        Some(SenderChainKey::new(
184            sender_chain.iteration,
185            sender_chain.seed.clone(),
186        ))
187    }
188
189    pub(crate) fn set_sender_chain_key(&mut self, chain_key: SenderChainKey) {
190        self.state.sender_chain_key = Some(chain_key.as_protobuf());
191    }
192
193    pub(crate) fn signing_key_public(&self) -> Result<PublicKey, InvalidSessionError> {
194        if let Some(ref signing_key) = self.state.sender_signing_key {
195            PublicKey::try_from(&signing_key.public[..])
196                .map_err(|_| InvalidSessionError("invalid public signing key"))
197        } else {
198            Err(InvalidSessionError("missing signing key"))
199        }
200    }
201
202    pub(crate) fn signing_key_private(&self) -> Result<PrivateKey, InvalidSessionError> {
203        if let Some(ref signing_key) = self.state.sender_signing_key {
204            PrivateKey::deserialize(&signing_key.private)
205                .map_err(|_| InvalidSessionError("invalid private signing key"))
206        } else {
207            Err(InvalidSessionError("missing signing key"))
208        }
209    }
210
211    pub(crate) fn as_protobuf(&self) -> storage_proto::SenderKeyStateStructure {
212        self.state.clone()
213    }
214
215    pub(crate) fn add_sender_message_key(&mut self, sender_message_key: &SenderMessageKey) {
216        self.state
217            .sender_message_keys
218            .push(sender_message_key.as_protobuf());
219        while self.state.sender_message_keys.len() > consts::MAX_MESSAGE_KEYS {
220            self.state.sender_message_keys.remove(0);
221        }
222    }
223
224    pub(crate) fn remove_sender_message_key(&mut self, iteration: u32) -> Option<SenderMessageKey> {
225        if let Some(index) = self
226            .state
227            .sender_message_keys
228            .iter()
229            .position(|x| x.iteration == iteration)
230        {
231            let smk = self.state.sender_message_keys.remove(index);
232            Some(SenderMessageKey::from_protobuf(smk))
233        } else {
234            None
235        }
236    }
237}
238
239#[derive(Debug, Clone)]
240pub struct SenderKeyRecord {
241    states: VecDeque<SenderKeyState>,
242}
243
244impl SenderKeyRecord {
245    pub(crate) fn new_empty() -> Self {
246        Self {
247            states: VecDeque::with_capacity(consts::MAX_SENDER_KEY_STATES),
248        }
249    }
250
251    pub fn deserialize(buf: &[u8]) -> Result<SenderKeyRecord, SignalProtocolError> {
252        let skr = storage_proto::SenderKeyRecordStructure::decode(buf)
253            .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
254
255        let mut states = VecDeque::with_capacity(skr.sender_key_states.len());
256        for state in skr.sender_key_states {
257            states.push_back(SenderKeyState::from_protobuf(state))
258        }
259        Ok(Self { states })
260    }
261
262    pub(crate) fn sender_key_state(&self) -> Result<&SenderKeyState, InvalidSessionError> {
263        if !self.states.is_empty() {
264            return Ok(&self.states[0]);
265        }
266        Err(InvalidSessionError("empty sender key state"))
267    }
268
269    pub(crate) fn sender_key_state_mut(
270        &mut self,
271    ) -> Result<&mut SenderKeyState, InvalidSessionError> {
272        if !self.states.is_empty() {
273            return Ok(&mut self.states[0]);
274        }
275        Err(InvalidSessionError("empty sender key state"))
276    }
277
278    pub(crate) fn sender_key_state_for_chain_id(
279        &mut self,
280        chain_id: u32,
281    ) -> Option<&mut SenderKeyState> {
282        for i in 0..self.states.len() {
283            if self.states[i].chain_id() == chain_id {
284                return Some(&mut self.states[i]);
285            }
286        }
287        None
288    }
289
290    pub(crate) fn chain_ids_for_logging(&self) -> impl ExactSizeIterator<Item = u32> + '_ {
291        self.states.iter().map(|state| state.chain_id())
292    }
293
294    pub(crate) fn add_sender_key_state(
295        &mut self,
296        message_version: u8,
297        chain_id: u32,
298        iteration: u32,
299        chain_key: &[u8],
300        signature_key: PublicKey,
301        signature_private_key: Option<PrivateKey>,
302    ) {
303        let existing_state = self.remove_state(chain_id, signature_key);
304
305        if self.remove_states_with_chain_id(chain_id) > 0 {
306            log::warn!(
307                "Removed a matching chain_id ({chain_id}) found with a different public key"
308            );
309        }
310
311        let state = match existing_state {
312            None => SenderKeyState::new(
313                message_version,
314                chain_id,
315                iteration,
316                chain_key,
317                signature_key,
318                signature_private_key,
319            ),
320            Some(state) => state,
321        };
322
323        while self.states.len() >= consts::MAX_SENDER_KEY_STATES {
324            self.states.pop_back();
325        }
326
327        self.states.push_front(state);
328    }
329
330    /// Remove the state with the matching `chain_id` and `signature_key`.
331    ///
332    /// Skips any bad protobufs.
333    fn remove_state(&mut self, chain_id: u32, signature_key: PublicKey) -> Option<SenderKeyState> {
334        let (index, _state) = self.states.iter().find_position(|state| {
335            state.chain_id() == chain_id && state.signing_key_public().ok() == Some(signature_key)
336        })?;
337
338        self.states.remove(index)
339    }
340
341    /// Returns the number of removed states.
342    ///
343    /// Skips any bad protobufs.
344    fn remove_states_with_chain_id(&mut self, chain_id: u32) -> usize {
345        let initial_length = self.states.len();
346        self.states.retain(|state| state.chain_id() != chain_id);
347        initial_length - self.states.len()
348    }
349
350    pub(crate) fn as_protobuf(&self) -> storage_proto::SenderKeyRecordStructure {
351        let mut states = Vec::with_capacity(self.states.len());
352        for state in &self.states {
353            states.push(state.as_protobuf());
354        }
355
356        storage_proto::SenderKeyRecordStructure {
357            sender_key_states: states,
358        }
359    }
360
361    pub fn serialize(&self) -> Result<Vec<u8>, SignalProtocolError> {
362        Ok(self.as_protobuf().encode_to_vec())
363    }
364}
365
#[cfg(test)]
mod sender_key_record_add_sender_key_state_tests {
    use itertools::Itertools;
    use rand::rngs::OsRng;
    use rand::TryRngCore as _;

    use super::*;
    use crate::KeyPair;

    /// Generates a fresh public key to use as the signing key of a record key.
    fn random_public_key() -> PublicKey {
        let mut rng = OsRng.unwrap_err();
        KeyPair::generate(&mut rng).public_key
    }

    /// Builds a distinct 16-byte chain key from `i` (big-endian encoding).
    fn chain_key(i: u128) -> Vec<u8> {
        Vec::from(i.to_be_bytes())
    }

    /// Holds the record under test plus assertion helpers.
    struct TestContext {
        sender_key_record: SenderKeyRecord,
    }

    impl TestContext {
        fn new() -> Self {
            let sender_key_record = SenderKeyRecord::new_empty();
            Self { sender_key_record }
        }

        /// Calls `add_sender_key_state` (the method under test in this module) for
        /// `record_key`, a `(public_key, chain_id)` pair, with the given `chain_key`.
        fn add_sender_key_state_record(&mut self, record_key: (PublicKey, u32), chain_key: &[u8]) {
            let (signing_key, chain_id) = record_key;
            let (message_version, iteration) = (1, 1);
            self.sender_key_record.add_sender_key_state(
                message_version,
                chain_id,
                iteration,
                chain_key,
                signing_key,
                None,
            );
        }

        fn assert_number_of_states(&self, expected: usize) {
            let actual = self.sender_key_record.states.len();
            assert_eq!(expected, actual);
        }

        /// Asserts that `record_key` maps to `expected_chain_key`, checked two ways:
        /// 1. via `sender_key_state_for_chain_id` (lookup by chain ID alone);
        /// 2. by requiring exactly one stored state whose chain ID *and* public key both match.
        fn assert_records_chain_key(
            &mut self,
            record_key: (PublicKey, u32),
            expected_chain_key: &[u8],
        ) {
            let (public_key, chain_id) = record_key;

            let by_chain_id = self
                .sender_key_record
                .sender_key_state_for_chain_id(chain_id)
                .expect("Expect to find chain id")
                .sender_chain_key()
                .expect("Expect to find chain key");
            assert_eq!(by_chain_id.chain_key, expected_chain_key);

            let by_record_key = self
                .sender_key_record
                .states
                .iter()
                .filter(|state| {
                    state.chain_id() == chain_id
                        && state.signing_key_public().expect("expect public key") == public_key
                })
                .exactly_one()
                .expect("Expected exactly one record key match")
                .sender_chain_key()
                .expect("Expect to find chain key");
            assert_eq!(by_record_key.chain_key, expected_chain_key);
        }

        /// Asserts the stored states, front to back, have exactly these record keys.
        fn assert_record_order(&self, order: Vec<(PublicKey, u32)>) {
            let mut actual = Vec::with_capacity(self.sender_key_record.states.len());
            for state in &self.sender_key_record.states {
                let public_key = state.signing_key_public().expect("expect public key");
                actual.push((public_key, state.chain_id()));
            }
            assert_eq!(actual, order);
        }
    }

    #[test]
    fn add_single_state() {
        let mut ctx = TestContext::new();
        let record_key = (random_public_key(), 1);
        let key = chain_key(1);

        ctx.add_sender_key_state_record(record_key, &key);

        ctx.assert_number_of_states(1);
        ctx.assert_records_chain_key(record_key, &key);
    }

    #[test]
    fn add_second_state() {
        let mut ctx = TestContext::new();
        let entries = [
            ((random_public_key(), 1), chain_key(1)),
            ((random_public_key(), 2), chain_key(2)),
        ];

        for (record_key, key) in &entries {
            ctx.add_sender_key_state_record(*record_key, key);
        }

        ctx.assert_number_of_states(2);
        for (record_key, key) in &entries {
            ctx.assert_records_chain_key(*record_key, key);
        }
    }

    #[test]
    fn when_exceed_maximum_states_then_oldest_is_ejected() {
        assert_eq!(
            5,
            consts::MAX_SENDER_KEY_STATES,
            "Test written to expect this limit"
        );

        let mut ctx = TestContext::new();
        let record_keys: Vec<(PublicKey, u32)> = (1..=6)
            .map(|chain_id| (random_public_key(), chain_id))
            .collect();

        // Fill the record to capacity with chain IDs 1..=5; newest ends up first.
        for &record_key in &record_keys[..5] {
            let (_, chain_id) = record_key;
            ctx.add_sender_key_state_record(record_key, &chain_key(chain_id.into()));
        }
        ctx.assert_record_order(record_keys[..5].iter().rev().copied().collect());

        // A sixth state pushes out the oldest (chain ID 1).
        ctx.add_sender_key_state_record(record_keys[5], &chain_key(6));
        ctx.assert_record_order(record_keys[1..].iter().rev().copied().collect());
    }

    #[test]
    fn when_second_state_with_same_public_key_and_chain_id_added_then_it_keeps_first_data() {
        let mut ctx = TestContext::new();
        let record_key = (random_public_key(), 1);
        let (first, second) = (chain_key(1), chain_key(2));

        ctx.add_sender_key_state_record(record_key, &first);
        ctx.add_sender_key_state_record(record_key, &second);

        ctx.assert_number_of_states(1);
        ctx.assert_records_chain_key(record_key, &first);
    }

    #[test]
    fn when_second_state_with_different_public_key_but_same_chain_id_added_then_it_gets_replaced() {
        let mut ctx = TestContext::new();
        let chain_id = 1;
        let original = (random_public_key(), chain_id);
        let replacement = (random_public_key(), chain_id);
        let (original_chain_key, replacement_chain_key) = (chain_key(1), chain_key(2));

        ctx.add_sender_key_state_record(original, &original_chain_key);
        ctx.add_sender_key_state_record(replacement, &replacement_chain_key);

        ctx.assert_number_of_states(1);
        ctx.assert_records_chain_key(replacement, &replacement_chain_key);
    }

    #[test]
    fn when_second_state_with_same_public_key_and_chain_id_added_then_it_becomes_the_most_recent() {
        let mut ctx = TestContext::new();
        let first = (random_public_key(), 1);
        let second = (random_public_key(), 2);

        ctx.add_sender_key_state_record(first, &chain_key(1));
        ctx.add_sender_key_state_record(second, &chain_key(2));
        ctx.assert_record_order(vec![second, first]);

        ctx.add_sender_key_state_record(first, &chain_key(3));
        ctx.assert_record_order(vec![first, second]);
    }
}
592
#[cfg(test)]
mod sender_chain_key_iteration_tests {
    use std::collections::HashSet;

    use assert_matches::assert_matches;

    use super::SenderChainKey;
    use crate::SignalProtocolError;

    const INITIAL_ITERATION: u32 = 0;
    const INITIAL_SEED_KEY: [u8; 4] = [1, 2, 3, 4];

    /// Each step must advance the iteration by exactly one and produce a never-before-seen seed.
    #[test]
    fn iteration() {
        let mut current = SenderChainKey::new(INITIAL_ITERATION, INITIAL_SEED_KEY.to_vec());

        let mut seen_seeds: HashSet<Vec<u8>> = HashSet::new();
        seen_seeds.insert(current.seed().to_vec());

        for i in 1..10 {
            let successor = current
                .next()
                .expect("Expect chain key to not overflow after only a few iterations");

            assert!(
                seen_seeds.insert(successor.seed().to_vec()),
                "Seed has already been seen before for iteration {i}"
            );
            assert_eq!(successor.iteration(), INITIAL_ITERATION + i);

            current = successor;
        }
    }

    /// Advancing past `u32::MAX` must fail with `InvalidState` instead of wrapping around.
    #[test]
    fn when_sender_chain_key_iteration_overflows() {
        let at_limit = SenderChainKey::new(u32::MAX, INITIAL_SEED_KEY.to_vec());
        let outcome = at_limit.next();
        assert_matches!(outcome, Err(SignalProtocolError::InvalidState { .. }));
    }
}
638}