1use std::collections::VecDeque;
7
8use itertools::Itertools;
9use prost::Message;
10
11use crate::crypto::hmac_sha256;
12use crate::proto::storage as storage_proto;
13use crate::{consts, PrivateKey, PublicKey, SignalProtocolError};
14
/// Error describing why stored sender-key session data could not be used.
///
/// Wraps a static, human-readable reason; `Display` prints that reason.
#[derive(Debug)]
pub(crate) struct InvalidSessionError(&'static str);

impl std::fmt::Display for InvalidSessionError {
    /// Writes the wrapped reason through `str`'s own `Display` implementation,
    /// so any width/padding flags set on the formatter are honoured.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The field is a `Copy` reference, so it can be bound by value here.
        let Self(reason) = *self;
        std::fmt::Display::fmt(reason, formatter)
    }
}
24
25#[derive(Debug, Clone)]
26pub(crate) struct SenderMessageKey {
27 iteration: u32,
28 iv: Vec<u8>,
29 cipher_key: Vec<u8>,
30 seed: Vec<u8>,
31}
32
33impl SenderMessageKey {
34 pub(crate) fn new(iteration: u32, seed: Vec<u8>) -> Self {
35 let mut derived = [0; 48];
36 hkdf::Hkdf::<sha2::Sha256>::new(None, &seed)
37 .expand(b"WhisperGroup", &mut derived)
38 .expect("valid output length");
39 Self {
40 iteration,
41 seed,
42 iv: derived[0..16].to_vec(),
43 cipher_key: derived[16..48].to_vec(),
44 }
45 }
46
47 pub(crate) fn from_protobuf(
48 smk: storage_proto::sender_key_state_structure::SenderMessageKey,
49 ) -> Self {
50 Self::new(smk.iteration, smk.seed)
51 }
52
53 pub(crate) fn iteration(&self) -> u32 {
54 self.iteration
55 }
56
57 pub(crate) fn iv(&self) -> &[u8] {
58 &self.iv
59 }
60
61 pub(crate) fn cipher_key(&self) -> &[u8] {
62 &self.cipher_key
63 }
64
65 pub(crate) fn as_protobuf(
66 &self,
67 ) -> storage_proto::sender_key_state_structure::SenderMessageKey {
68 storage_proto::sender_key_state_structure::SenderMessageKey {
69 iteration: self.iteration,
70 seed: self.seed.clone(),
71 }
72 }
73}
74
75#[derive(Debug, Clone)]
76pub(crate) struct SenderChainKey {
77 iteration: u32,
78 chain_key: Vec<u8>,
79}
80
81impl SenderChainKey {
82 const MESSAGE_KEY_SEED: u8 = 0x01;
83 const CHAIN_KEY_SEED: u8 = 0x02;
84
85 pub(crate) fn new(iteration: u32, chain_key: Vec<u8>) -> Self {
86 Self {
87 iteration,
88 chain_key,
89 }
90 }
91
92 pub(crate) fn iteration(&self) -> u32 {
93 self.iteration
94 }
95
96 pub(crate) fn seed(&self) -> &[u8] {
97 &self.chain_key
98 }
99
100 pub(crate) fn next(&self) -> SenderChainKey {
101 SenderChainKey::new(
102 self.iteration + 1,
103 self.get_derivative(Self::CHAIN_KEY_SEED),
104 )
105 }
106
107 pub(crate) fn sender_message_key(&self) -> SenderMessageKey {
108 SenderMessageKey::new(self.iteration, self.get_derivative(Self::MESSAGE_KEY_SEED))
109 }
110
111 fn get_derivative(&self, label: u8) -> Vec<u8> {
112 let label = [label];
113 hmac_sha256(&self.chain_key, &label).to_vec()
114 }
115
116 pub(crate) fn as_protobuf(&self) -> storage_proto::sender_key_state_structure::SenderChainKey {
117 storage_proto::sender_key_state_structure::SenderChainKey {
118 iteration: self.iteration,
119 seed: self.chain_key.clone(),
120 }
121 }
122}
123
124#[derive(Debug, Clone)]
125pub(crate) struct SenderKeyState {
126 state: storage_proto::SenderKeyStateStructure,
127}
128
129impl SenderKeyState {
130 pub(crate) fn new(
131 message_version: u8,
132 chain_id: u32,
133 iteration: u32,
134 chain_key: &[u8],
135 signature_key: PublicKey,
136 signature_private_key: Option<PrivateKey>,
137 ) -> SenderKeyState {
138 let state = storage_proto::SenderKeyStateStructure {
139 message_version: message_version as u32,
140 chain_id,
141 sender_chain_key: Some(
142 SenderChainKey::new(iteration, chain_key.to_vec()).as_protobuf(),
143 ),
144 sender_signing_key: Some(
145 storage_proto::sender_key_state_structure::SenderSigningKey {
146 public: signature_key.serialize().to_vec(),
147 private: match signature_private_key {
148 None => vec![],
149 Some(k) => k.serialize().to_vec(),
150 },
151 },
152 ),
153 sender_message_keys: vec![],
154 };
155
156 Self { state }
157 }
158
159 pub(crate) fn from_protobuf(state: storage_proto::SenderKeyStateStructure) -> Self {
160 Self { state }
161 }
162
163 pub(crate) fn message_version(&self) -> u32 {
164 match self.state.message_version {
165 0 => 3, v => v,
167 }
168 }
169
170 pub(crate) fn chain_id(&self) -> u32 {
171 self.state.chain_id
172 }
173
174 pub(crate) fn sender_chain_key(&self) -> Option<SenderChainKey> {
175 let sender_chain = self.state.sender_chain_key.as_ref()?;
176 Some(SenderChainKey::new(
177 sender_chain.iteration,
178 sender_chain.seed.clone(),
179 ))
180 }
181
182 pub(crate) fn set_sender_chain_key(&mut self, chain_key: SenderChainKey) {
183 self.state.sender_chain_key = Some(chain_key.as_protobuf());
184 }
185
186 pub(crate) fn signing_key_public(&self) -> Result<PublicKey, InvalidSessionError> {
187 if let Some(ref signing_key) = self.state.sender_signing_key {
188 PublicKey::try_from(&signing_key.public[..])
189 .map_err(|_| InvalidSessionError("invalid public signing key"))
190 } else {
191 Err(InvalidSessionError("missing signing key"))
192 }
193 }
194
195 pub(crate) fn signing_key_private(&self) -> Result<PrivateKey, InvalidSessionError> {
196 if let Some(ref signing_key) = self.state.sender_signing_key {
197 PrivateKey::deserialize(&signing_key.private)
198 .map_err(|_| InvalidSessionError("invalid private signing key"))
199 } else {
200 Err(InvalidSessionError("missing signing key"))
201 }
202 }
203
204 pub(crate) fn as_protobuf(&self) -> storage_proto::SenderKeyStateStructure {
205 self.state.clone()
206 }
207
208 pub(crate) fn add_sender_message_key(&mut self, sender_message_key: &SenderMessageKey) {
209 self.state
210 .sender_message_keys
211 .push(sender_message_key.as_protobuf());
212 while self.state.sender_message_keys.len() > consts::MAX_MESSAGE_KEYS {
213 self.state.sender_message_keys.remove(0);
214 }
215 }
216
217 pub(crate) fn remove_sender_message_key(&mut self, iteration: u32) -> Option<SenderMessageKey> {
218 if let Some(index) = self
219 .state
220 .sender_message_keys
221 .iter()
222 .position(|x| x.iteration == iteration)
223 {
224 let smk = self.state.sender_message_keys.remove(index);
225 Some(SenderMessageKey::from_protobuf(smk))
226 } else {
227 None
228 }
229 }
230}
231
232#[derive(Debug, Clone)]
233pub struct SenderKeyRecord {
234 states: VecDeque<SenderKeyState>,
235}
236
237impl SenderKeyRecord {
238 pub(crate) fn new_empty() -> Self {
239 Self {
240 states: VecDeque::with_capacity(consts::MAX_SENDER_KEY_STATES),
241 }
242 }
243
244 pub fn deserialize(buf: &[u8]) -> Result<SenderKeyRecord, SignalProtocolError> {
245 let skr = storage_proto::SenderKeyRecordStructure::decode(buf)
246 .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
247
248 let mut states = VecDeque::with_capacity(skr.sender_key_states.len());
249 for state in skr.sender_key_states {
250 states.push_back(SenderKeyState::from_protobuf(state))
251 }
252 Ok(Self { states })
253 }
254
255 pub(crate) fn sender_key_state(&self) -> Result<&SenderKeyState, InvalidSessionError> {
256 if !self.states.is_empty() {
257 return Ok(&self.states[0]);
258 }
259 Err(InvalidSessionError("empty sender key state"))
260 }
261
262 pub(crate) fn sender_key_state_mut(
263 &mut self,
264 ) -> Result<&mut SenderKeyState, InvalidSessionError> {
265 if !self.states.is_empty() {
266 return Ok(&mut self.states[0]);
267 }
268 Err(InvalidSessionError("empty sender key state"))
269 }
270
271 pub(crate) fn sender_key_state_for_chain_id(
272 &mut self,
273 chain_id: u32,
274 ) -> Option<&mut SenderKeyState> {
275 for i in 0..self.states.len() {
276 if self.states[i].chain_id() == chain_id {
277 return Some(&mut self.states[i]);
278 }
279 }
280 None
281 }
282
283 pub(crate) fn chain_ids_for_logging(&self) -> impl ExactSizeIterator<Item = u32> + '_ {
284 self.states.iter().map(|state| state.chain_id())
285 }
286
287 pub(crate) fn add_sender_key_state(
288 &mut self,
289 message_version: u8,
290 chain_id: u32,
291 iteration: u32,
292 chain_key: &[u8],
293 signature_key: PublicKey,
294 signature_private_key: Option<PrivateKey>,
295 ) {
296 let existing_state = self.remove_state(chain_id, signature_key);
297
298 if self.remove_states_with_chain_id(chain_id) > 0 {
299 log::warn!(
300 "Removed a matching chain_id ({}) found with a different public key",
301 chain_id
302 );
303 }
304
305 let state = match existing_state {
306 None => SenderKeyState::new(
307 message_version,
308 chain_id,
309 iteration,
310 chain_key,
311 signature_key,
312 signature_private_key,
313 ),
314 Some(state) => state,
315 };
316
317 while self.states.len() >= consts::MAX_SENDER_KEY_STATES {
318 self.states.pop_back();
319 }
320
321 self.states.push_front(state);
322 }
323
324 fn remove_state(&mut self, chain_id: u32, signature_key: PublicKey) -> Option<SenderKeyState> {
328 let (index, _state) = self.states.iter().find_position(|state| {
329 state.chain_id() == chain_id && state.signing_key_public().ok() == Some(signature_key)
330 })?;
331
332 self.states.remove(index)
333 }
334
335 fn remove_states_with_chain_id(&mut self, chain_id: u32) -> usize {
339 let initial_length = self.states.len();
340 self.states.retain(|state| state.chain_id() != chain_id);
341 initial_length - self.states.len()
342 }
343
344 pub(crate) fn as_protobuf(&self) -> storage_proto::SenderKeyRecordStructure {
345 let mut states = Vec::with_capacity(self.states.len());
346 for state in &self.states {
347 states.push(state.as_protobuf());
348 }
349
350 storage_proto::SenderKeyRecordStructure {
351 sender_key_states: states,
352 }
353 }
354
355 pub fn serialize(&self) -> Result<Vec<u8>, SignalProtocolError> {
356 Ok(self.as_protobuf().encode_to_vec())
357 }
358}
359
#[cfg(test)]
mod sender_key_record_add_sender_key_state_tests {
    use itertools::Itertools;
    use rand::rngs::OsRng;

    use super::*;
    use crate::KeyPair;

    /// Identifies a state in these tests: (signing public key, chain id).
    type RecordKey = (PublicKey, u32);

    /// Returns the public half of a freshly generated key pair.
    fn random_public_key() -> PublicKey {
        let pair = KeyPair::generate(&mut OsRng);
        pair.public_key
    }

    /// Produces a distinguishable 16-byte chain key from `i` (big-endian).
    fn chain_key(i: u128) -> Vec<u8> {
        Vec::from(i.to_be_bytes())
    }

    /// Owns the record under test and bundles the shared assertions.
    struct TestContext {
        sender_key_record: SenderKeyRecord,
    }

    impl TestContext {
        fn new() -> Self {
            let sender_key_record = SenderKeyRecord::new_empty();
            Self { sender_key_record }
        }

        /// Adds a state for `record_key` with message version 1, iteration 1
        /// and no private signing key.
        fn add_sender_key_state_record(&mut self, record_key: (PublicKey, u32), chain_key: &[u8]) {
            let (signing_key, id) = record_key;
            self.sender_key_record
                .add_sender_key_state(1, id, 1, chain_key, signing_key, None);
        }

        fn assert_number_of_states(&self, expected: usize) {
            let actual = self.sender_key_record.states.len();
            assert_eq!(expected, actual);
        }

        /// Checks the chain key both via lookup by chain id and via the unique
        /// state matching the full record key.
        fn assert_records_chain_key(
            &mut self,
            record_key: (PublicKey, u32),
            expected_chain_key: &[u8],
        ) {
            let (signing_key, id) = record_key;

            let by_chain_id = self
                .sender_key_record
                .sender_key_state_for_chain_id(id)
                .expect("Expect to find chain id")
                .sender_chain_key()
                .expect("Expect to find chain key");
            assert_eq!(by_chain_id.chain_key, expected_chain_key);

            let by_record_key = self
                .sender_key_record
                .states
                .iter()
                .filter(|state| {
                    state.chain_id() == id
                        && state.signing_key_public().expect("expect public key") == signing_key
                })
                .exactly_one()
                .expect("Expected exactly one record key match")
                .sender_chain_key()
                .expect("Expect to find chain key");
            assert_eq!(by_record_key.chain_key, expected_chain_key);
        }

        /// Checks that the states, front to back, carry exactly `order`.
        fn assert_record_order(&self, order: Vec<(PublicKey, u32)>) {
            let mut actual: Vec<RecordKey> = Vec::new();
            for state in &self.sender_key_record.states {
                let signing_key = state.signing_key_public().expect("expect public key");
                actual.push((signing_key, state.chain_id()));
            }

            assert_eq!(actual, order);
        }
    }

    #[test]
    fn add_single_state() {
        let mut context = TestContext::new();
        let record_key: RecordKey = (random_public_key(), 1);
        let key_bytes = chain_key(1);

        context.add_sender_key_state_record(record_key, &key_bytes);

        context.assert_number_of_states(1);
        context.assert_records_chain_key(record_key, &key_bytes);
    }

    #[test]
    fn add_second_state() {
        let mut context = TestContext::new();
        let first: RecordKey = (random_public_key(), 1);
        let second: RecordKey = (random_public_key(), 2);
        let first_bytes = chain_key(1);
        let second_bytes = chain_key(2);

        context.add_sender_key_state_record(first, &first_bytes);
        context.add_sender_key_state_record(second, &second_bytes);

        context.assert_number_of_states(2);
        context.assert_records_chain_key(first, &first_bytes);
        context.assert_records_chain_key(second, &second_bytes);
    }

    #[test]
    fn when_exceed_maximum_states_then_oldest_is_ejected() {
        assert_eq!(
            5,
            consts::MAX_SENDER_KEY_STATES,
            "Test written to expect this limit"
        );

        let mut context = TestContext::new();

        // Six distinct record keys with chain ids 1..=6.
        let keys: Vec<RecordKey> = (1..=6).map(|id| (random_public_key(), id)).collect();

        for (key, seed) in keys.iter().take(5).zip(1u128..) {
            context.add_sender_key_state_record(*key, &chain_key(seed));
        }

        // Most recently added first.
        context.assert_record_order(keys[..5].iter().rev().copied().collect());

        context.add_sender_key_state_record(keys[5], &chain_key(6));

        // The first-added state is gone; the rest keep their order.
        context.assert_record_order(keys[1..].iter().rev().copied().collect());
    }

    #[test]
    fn when_second_state_with_same_public_key_and_chain_id_added_then_it_keeps_first_data() {
        let mut context = TestContext::new();
        let record_key: RecordKey = (random_public_key(), 1);
        let original_bytes = chain_key(1);
        let ignored_bytes = chain_key(2);

        context.add_sender_key_state_record(record_key, &original_bytes);
        context.add_sender_key_state_record(record_key, &ignored_bytes);

        context.assert_number_of_states(1);
        context.assert_records_chain_key(record_key, &original_bytes);
    }

    #[test]
    fn when_second_state_with_different_public_key_but_same_chain_id_added_then_it_gets_replaced() {
        let mut context = TestContext::new();
        let shared_chain_id = 1;
        let replaced: RecordKey = (random_public_key(), shared_chain_id);
        let replacement: RecordKey = (random_public_key(), shared_chain_id);
        let replaced_bytes = chain_key(1);
        let replacement_bytes = chain_key(2);

        context.add_sender_key_state_record(replaced, &replaced_bytes);
        context.add_sender_key_state_record(replacement, &replacement_bytes);

        context.assert_number_of_states(1);
        context.assert_records_chain_key(replacement, &replacement_bytes);
    }

    #[test]
    fn when_second_state_with_same_public_key_and_chain_id_added_then_it_becomes_the_most_recent() {
        let mut context = TestContext::new();
        let first: RecordKey = (random_public_key(), 1);
        let second: RecordKey = (random_public_key(), 2);

        context.add_sender_key_state_record(first, &chain_key(1));
        context.add_sender_key_state_record(second, &chain_key(2));

        context.assert_record_order(vec![second, first]);

        // Re-adding an existing record key moves it to the front.
        context.add_sender_key_state_record(first, &chain_key(3));

        context.assert_record_order(vec![first, second]);
    }
}