1use std::collections::VecDeque;
7
8use itertools::Itertools;
9use prost::Message;
10
11use crate::crypto::hmac_sha256;
12use crate::proto::storage as storage_proto;
13use crate::{consts, PrivateKey, PublicKey, SignalProtocolError};
14
/// Error carrying a static, human-readable reason why session state was rejected.
///
/// The `Display` output is exactly the wrapped message.
#[derive(Debug)]
pub(crate) struct InvalidSessionError(&'static str);

impl std::fmt::Display for InvalidSessionError {
    /// Writes the wrapped message, delegating to the string's own `Display`
    /// implementation so that width/precision flags on `f` keep working.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(reason) = self;
        std::fmt::Display::fmt(reason, f)
    }
}
24
25#[derive(Debug, Clone)]
26pub(crate) struct SenderMessageKey {
27 iteration: u32,
28 iv: Vec<u8>,
29 cipher_key: Vec<u8>,
30 seed: Vec<u8>,
31}
32
33impl SenderMessageKey {
34 pub(crate) fn new(iteration: u32, seed: Vec<u8>) -> Self {
35 let mut derived = [0; 48];
36 hkdf::Hkdf::<sha2::Sha256>::new(None, &seed)
37 .expand(b"WhisperGroup", &mut derived)
38 .expect("valid output length");
39 Self {
40 iteration,
41 seed,
42 iv: derived[0..16].to_vec(),
43 cipher_key: derived[16..48].to_vec(),
44 }
45 }
46
47 pub(crate) fn from_protobuf(
48 smk: storage_proto::sender_key_state_structure::SenderMessageKey,
49 ) -> Self {
50 Self::new(smk.iteration, smk.seed)
51 }
52
53 pub(crate) fn iteration(&self) -> u32 {
54 self.iteration
55 }
56
57 pub(crate) fn iv(&self) -> &[u8] {
58 &self.iv
59 }
60
61 pub(crate) fn cipher_key(&self) -> &[u8] {
62 &self.cipher_key
63 }
64
65 pub(crate) fn as_protobuf(
66 &self,
67 ) -> storage_proto::sender_key_state_structure::SenderMessageKey {
68 storage_proto::sender_key_state_structure::SenderMessageKey {
69 iteration: self.iteration,
70 seed: self.seed.clone(),
71 }
72 }
73}
74
75#[derive(Debug, Clone)]
76pub(crate) struct SenderChainKey {
77 iteration: u32,
78 chain_key: Vec<u8>,
79}
80
81impl SenderChainKey {
82 const MESSAGE_KEY_SEED: u8 = 0x01;
83 const CHAIN_KEY_SEED: u8 = 0x02;
84
85 pub(crate) fn new(iteration: u32, chain_key: Vec<u8>) -> Self {
86 Self {
87 iteration,
88 chain_key,
89 }
90 }
91
92 pub(crate) fn iteration(&self) -> u32 {
93 self.iteration
94 }
95
96 pub(crate) fn seed(&self) -> &[u8] {
97 &self.chain_key
98 }
99
100 pub(crate) fn next(&self) -> Result<SenderChainKey, SignalProtocolError> {
101 let new_iteration = self.iteration.checked_add(1).ok_or_else(|| {
102 SignalProtocolError::InvalidState(
103 "sender_chain_key_next",
104 "Sender chain is too long".into(),
105 )
106 })?;
107
108 Ok(SenderChainKey::new(
109 new_iteration,
110 self.get_derivative(Self::CHAIN_KEY_SEED),
111 ))
112 }
113
114 pub(crate) fn sender_message_key(&self) -> SenderMessageKey {
115 SenderMessageKey::new(self.iteration, self.get_derivative(Self::MESSAGE_KEY_SEED))
116 }
117
118 fn get_derivative(&self, label: u8) -> Vec<u8> {
119 let label = [label];
120 hmac_sha256(&self.chain_key, &label).to_vec()
121 }
122
123 pub(crate) fn as_protobuf(&self) -> storage_proto::sender_key_state_structure::SenderChainKey {
124 storage_proto::sender_key_state_structure::SenderChainKey {
125 iteration: self.iteration,
126 seed: self.chain_key.clone(),
127 }
128 }
129}
130
131#[derive(Debug, Clone)]
132pub(crate) struct SenderKeyState {
133 state: storage_proto::SenderKeyStateStructure,
134}
135
136impl SenderKeyState {
137 pub(crate) fn new(
138 message_version: u8,
139 chain_id: u32,
140 iteration: u32,
141 chain_key: &[u8],
142 signature_key: PublicKey,
143 signature_private_key: Option<PrivateKey>,
144 ) -> SenderKeyState {
145 let state = storage_proto::SenderKeyStateStructure {
146 message_version: message_version as u32,
147 chain_id,
148 sender_chain_key: Some(
149 SenderChainKey::new(iteration, chain_key.to_vec()).as_protobuf(),
150 ),
151 sender_signing_key: Some(
152 storage_proto::sender_key_state_structure::SenderSigningKey {
153 public: signature_key.serialize().to_vec(),
154 private: match signature_private_key {
155 None => vec![],
156 Some(k) => k.serialize().to_vec(),
157 },
158 },
159 ),
160 sender_message_keys: vec![],
161 };
162
163 Self { state }
164 }
165
166 pub(crate) fn from_protobuf(state: storage_proto::SenderKeyStateStructure) -> Self {
167 Self { state }
168 }
169
170 pub(crate) fn message_version(&self) -> u32 {
171 match self.state.message_version {
172 0 => 3, v => v,
174 }
175 }
176
177 pub(crate) fn chain_id(&self) -> u32 {
178 self.state.chain_id
179 }
180
181 pub(crate) fn sender_chain_key(&self) -> Option<SenderChainKey> {
182 let sender_chain = self.state.sender_chain_key.as_ref()?;
183 Some(SenderChainKey::new(
184 sender_chain.iteration,
185 sender_chain.seed.clone(),
186 ))
187 }
188
189 pub(crate) fn set_sender_chain_key(&mut self, chain_key: SenderChainKey) {
190 self.state.sender_chain_key = Some(chain_key.as_protobuf());
191 }
192
193 pub(crate) fn signing_key_public(&self) -> Result<PublicKey, InvalidSessionError> {
194 if let Some(ref signing_key) = self.state.sender_signing_key {
195 PublicKey::try_from(&signing_key.public[..])
196 .map_err(|_| InvalidSessionError("invalid public signing key"))
197 } else {
198 Err(InvalidSessionError("missing signing key"))
199 }
200 }
201
202 pub(crate) fn signing_key_private(&self) -> Result<PrivateKey, InvalidSessionError> {
203 if let Some(ref signing_key) = self.state.sender_signing_key {
204 PrivateKey::deserialize(&signing_key.private)
205 .map_err(|_| InvalidSessionError("invalid private signing key"))
206 } else {
207 Err(InvalidSessionError("missing signing key"))
208 }
209 }
210
211 pub(crate) fn as_protobuf(&self) -> storage_proto::SenderKeyStateStructure {
212 self.state.clone()
213 }
214
215 pub(crate) fn add_sender_message_key(&mut self, sender_message_key: &SenderMessageKey) {
216 self.state
217 .sender_message_keys
218 .push(sender_message_key.as_protobuf());
219 while self.state.sender_message_keys.len() > consts::MAX_MESSAGE_KEYS {
220 self.state.sender_message_keys.remove(0);
221 }
222 }
223
224 pub(crate) fn remove_sender_message_key(&mut self, iteration: u32) -> Option<SenderMessageKey> {
225 if let Some(index) = self
226 .state
227 .sender_message_keys
228 .iter()
229 .position(|x| x.iteration == iteration)
230 {
231 let smk = self.state.sender_message_keys.remove(index);
232 Some(SenderMessageKey::from_protobuf(smk))
233 } else {
234 None
235 }
236 }
237}
238
239#[derive(Debug, Clone)]
240pub struct SenderKeyRecord {
241 states: VecDeque<SenderKeyState>,
242}
243
244impl SenderKeyRecord {
245 pub(crate) fn new_empty() -> Self {
246 Self {
247 states: VecDeque::with_capacity(consts::MAX_SENDER_KEY_STATES),
248 }
249 }
250
251 pub fn deserialize(buf: &[u8]) -> Result<SenderKeyRecord, SignalProtocolError> {
252 let skr = storage_proto::SenderKeyRecordStructure::decode(buf)
253 .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
254
255 let mut states = VecDeque::with_capacity(skr.sender_key_states.len());
256 for state in skr.sender_key_states {
257 states.push_back(SenderKeyState::from_protobuf(state))
258 }
259 Ok(Self { states })
260 }
261
262 pub(crate) fn sender_key_state(&self) -> Result<&SenderKeyState, InvalidSessionError> {
263 if !self.states.is_empty() {
264 return Ok(&self.states[0]);
265 }
266 Err(InvalidSessionError("empty sender key state"))
267 }
268
269 pub(crate) fn sender_key_state_mut(
270 &mut self,
271 ) -> Result<&mut SenderKeyState, InvalidSessionError> {
272 if !self.states.is_empty() {
273 return Ok(&mut self.states[0]);
274 }
275 Err(InvalidSessionError("empty sender key state"))
276 }
277
278 pub(crate) fn sender_key_state_for_chain_id(
279 &mut self,
280 chain_id: u32,
281 ) -> Option<&mut SenderKeyState> {
282 for i in 0..self.states.len() {
283 if self.states[i].chain_id() == chain_id {
284 return Some(&mut self.states[i]);
285 }
286 }
287 None
288 }
289
290 pub(crate) fn chain_ids_for_logging(&self) -> impl ExactSizeIterator<Item = u32> + '_ {
291 self.states.iter().map(|state| state.chain_id())
292 }
293
294 pub(crate) fn add_sender_key_state(
295 &mut self,
296 message_version: u8,
297 chain_id: u32,
298 iteration: u32,
299 chain_key: &[u8],
300 signature_key: PublicKey,
301 signature_private_key: Option<PrivateKey>,
302 ) {
303 let existing_state = self.remove_state(chain_id, signature_key);
304
305 if self.remove_states_with_chain_id(chain_id) > 0 {
306 log::warn!(
307 "Removed a matching chain_id ({chain_id}) found with a different public key"
308 );
309 }
310
311 let state = match existing_state {
312 None => SenderKeyState::new(
313 message_version,
314 chain_id,
315 iteration,
316 chain_key,
317 signature_key,
318 signature_private_key,
319 ),
320 Some(state) => state,
321 };
322
323 while self.states.len() >= consts::MAX_SENDER_KEY_STATES {
324 self.states.pop_back();
325 }
326
327 self.states.push_front(state);
328 }
329
330 fn remove_state(&mut self, chain_id: u32, signature_key: PublicKey) -> Option<SenderKeyState> {
334 let (index, _state) = self.states.iter().find_position(|state| {
335 state.chain_id() == chain_id && state.signing_key_public().ok() == Some(signature_key)
336 })?;
337
338 self.states.remove(index)
339 }
340
341 fn remove_states_with_chain_id(&mut self, chain_id: u32) -> usize {
345 let initial_length = self.states.len();
346 self.states.retain(|state| state.chain_id() != chain_id);
347 initial_length - self.states.len()
348 }
349
350 pub(crate) fn as_protobuf(&self) -> storage_proto::SenderKeyRecordStructure {
351 let mut states = Vec::with_capacity(self.states.len());
352 for state in &self.states {
353 states.push(state.as_protobuf());
354 }
355
356 storage_proto::SenderKeyRecordStructure {
357 sender_key_states: states,
358 }
359 }
360
361 pub fn serialize(&self) -> Result<Vec<u8>, SignalProtocolError> {
362 Ok(self.as_protobuf().encode_to_vec())
363 }
364}
365
#[cfg(test)]
mod sender_key_record_add_sender_key_state_tests {
    use itertools::Itertools;
    use rand::rngs::OsRng;
    use rand::TryRngCore as _;

    use super::*;
    use crate::KeyPair;

    // Generates a fresh key pair and keeps only its public half.
    fn random_public_key() -> PublicKey {
        let mut rng = OsRng.unwrap_err();
        KeyPair::generate(&mut rng).public_key
    }

    // Distinct, deterministic chain key bytes for each `i`.
    fn chain_key(i: u128) -> Vec<u8> {
        Vec::from(i.to_be_bytes())
    }

    // Thin wrapper around a record offering assertions keyed by
    // `(public signing key, chain id)` pairs.
    struct TestContext {
        sender_key_record: SenderKeyRecord,
    }

    impl TestContext {
        fn new() -> Self {
            let sender_key_record = SenderKeyRecord::new_empty();
            Self { sender_key_record }
        }

        // Adds a state with message version 1, iteration 1 and no private key.
        fn add_sender_key_state_record(&mut self, record_key: (PublicKey, u32), chain_key: &[u8]) {
            self.sender_key_record.add_sender_key_state(
                1,
                record_key.1,
                1,
                chain_key,
                record_key.0,
                None,
            );
        }

        fn assert_number_of_states(&self, expected: usize) {
            let actual = self.sender_key_record.states.len();
            assert_eq!(expected, actual);
        }

        // Checks the chain key both through the chain-id lookup and through a
        // scan for the unique state matching the full record key.
        fn assert_records_chain_key(
            &mut self,
            record_key: (PublicKey, u32),
            expected_chain_key: &[u8],
        ) {
            let (public_key, chain_id) = record_key;

            let state_by_chain_id = self
                .sender_key_record
                .sender_key_state_for_chain_id(chain_id)
                .expect("Expect to find chain id");
            let found_chain_key = state_by_chain_id
                .sender_chain_key()
                .expect("Expect to find chain key")
                .chain_key;

            assert_eq!(found_chain_key, expected_chain_key);

            let matching_state = self
                .sender_key_record
                .states
                .iter()
                .filter(|state| {
                    state.chain_id() == chain_id
                        && state.signing_key_public().expect("expect public key") == public_key
                })
                .exactly_one()
                .expect("Expected exactly one record key match");

            let matching_chain_key = matching_state
                .sender_chain_key()
                .expect("Expect to find chain key");

            assert_eq!(&matching_chain_key.chain_key, expected_chain_key);
        }

        // Asserts the front-to-back order of the stored states.
        fn assert_record_order(&self, order: Vec<(PublicKey, u32)>) {
            let record_keys: Vec<(PublicKey, u32)> = self
                .sender_key_record
                .states
                .iter()
                .map(|state| {
                    let public_key = state.signing_key_public().expect("expect public key");
                    (public_key, state.chain_id())
                })
                .collect();

            assert_eq!(record_keys, order);
        }
    }

    #[test]
    fn add_single_state() {
        let mut context = TestContext::new();

        let record_key = (random_public_key(), 1);
        let chain_key = chain_key(1);

        context.add_sender_key_state_record(record_key, &chain_key);

        context.assert_number_of_states(1);
        context.assert_records_chain_key(record_key, &chain_key);
    }

    #[test]
    fn add_second_state() {
        let mut context = TestContext::new();

        let entries = [
            ((random_public_key(), 1), chain_key(1)),
            ((random_public_key(), 2), chain_key(2)),
        ];

        for (record_key, chain_key) in &entries {
            context.add_sender_key_state_record(*record_key, chain_key);
        }

        context.assert_number_of_states(2);
        for (record_key, chain_key) in &entries {
            context.assert_records_chain_key(*record_key, chain_key);
        }
    }

    #[test]
    fn when_exceed_maximum_states_then_oldest_is_ejected() {
        assert_eq!(
            5,
            consts::MAX_SENDER_KEY_STATES,
            "Test written to expect this limit"
        );

        let mut context = TestContext::new();

        // Six record keys with chain ids 1 through 6, generated in that order.
        let keys: [(PublicKey, u32); 6] =
            [1, 2, 3, 4, 5, 6].map(|chain_id| (random_public_key(), chain_id));

        // Fill the record up to the limit; each state uses its chain id as
        // the chain key value.
        for key in &keys[..5] {
            context.add_sender_key_state_record(*key, &chain_key(u128::from(key.1)));
        }

        // Most recently added first.
        context.assert_record_order(keys[..5].iter().rev().copied().collect());

        context.add_sender_key_state_record(keys[5], &chain_key(6));

        // The first key added has been ejected from the back.
        context.assert_record_order(keys[1..].iter().rev().copied().collect());
    }

    #[test]
    fn when_second_state_with_same_public_key_and_chain_id_added_then_it_keeps_first_data() {
        let mut context = TestContext::new();

        let record_key = (random_public_key(), 1);
        let first_chain_key = chain_key(1);
        let second_chain_key = chain_key(2);

        context.add_sender_key_state_record(record_key, &first_chain_key);
        context.add_sender_key_state_record(record_key, &second_chain_key);

        context.assert_number_of_states(1);
        context.assert_records_chain_key(record_key, &first_chain_key);
    }

    #[test]
    fn when_second_state_with_different_public_key_but_same_chain_id_added_then_it_gets_replaced() {
        let mut context = TestContext::new();

        let shared_chain_id = 1;
        let original_key = (random_public_key(), shared_chain_id);
        let replacement_key = (random_public_key(), shared_chain_id);
        let original_chain_key = chain_key(1);
        let replacement_chain_key = chain_key(2);

        context.add_sender_key_state_record(original_key, &original_chain_key);
        context.add_sender_key_state_record(replacement_key, &replacement_chain_key);

        context.assert_number_of_states(1);
        context.assert_records_chain_key(replacement_key, &replacement_chain_key);
    }

    #[test]
    fn when_second_state_with_same_public_key_and_chain_id_added_then_it_becomes_the_most_recent() {
        let mut context = TestContext::new();

        let older_key = (random_public_key(), 1);
        let newer_key = (random_public_key(), 2);

        context.add_sender_key_state_record(older_key, &chain_key(1));
        context.add_sender_key_state_record(newer_key, &chain_key(2));

        context.assert_record_order(vec![newer_key, older_key]);

        // Re-adding an existing record key moves it back to the front.
        context.add_sender_key_state_record(older_key, &chain_key(3));

        context.assert_record_order(vec![older_key, newer_key]);
    }
}
592
#[cfg(test)]
mod sender_chain_key_iteration_tests {
    use std::collections::HashSet;

    use assert_matches::assert_matches;

    use super::SenderChainKey;
    use crate::SignalProtocolError;

    const INITIAL_ITERATION: u32 = 0;
    const INITIAL_SEED_KEY: [u8; 4] = [1, 2, 3, 4];

    // Stepping the chain must bump the iteration by one each time and never
    // repeat a seed.
    #[test]
    fn iteration() {
        let mut current = SenderChainKey::new(INITIAL_ITERATION, INITIAL_SEED_KEY.to_vec());
        let mut seen_seeds = HashSet::from([current.seed().to_vec()]);

        for step in 1..10 {
            let advanced = current
                .next()
                .expect("Expect chain key to not overflow after only a few iterations");

            let is_new_seed = seen_seeds.insert(advanced.seed().to_vec());
            assert!(
                is_new_seed,
                "Seed has already been seen before for iteration {step}"
            );
            assert_eq!(advanced.iteration(), INITIAL_ITERATION + step);

            current = advanced;
        }
    }

    // A chain key already at `u32::MAX` cannot be advanced.
    #[test]
    fn when_sender_chain_key_iteration_overflows() {
        let exhausted: SenderChainKey = SenderChainKey::new(u32::MAX, INITIAL_SEED_KEY.to_vec());
        let outcome = exhausted.next();
        assert_matches!(outcome, Err(SignalProtocolError::InvalidState { .. }));
    }
}