1use std::time::SystemTime;
7
8use rand::{CryptoRng, Rng};
9
10use crate::consts::{MAX_FORWARD_JUMPS, MAX_UNACKNOWLEDGED_SESSION_AGE};
11use crate::ratchet::{ChainKey, MessageKeyGenerator};
12use crate::state::{InvalidSessionError, SessionState};
13use crate::{
14 CiphertextMessage, CiphertextMessageType, Direction, IdentityKeyStore, KeyPair, KyberPayload,
15 KyberPreKeyStore, PreKeySignalMessage, PreKeyStore, ProtocolAddress, PublicKey, Result,
16 SessionRecord, SessionStore, SignalMessage, SignalProtocolError, SignedPreKeyStore, session,
17};
18
19pub async fn message_encrypt<R: Rng + CryptoRng>(
20 ptext: &[u8],
21 remote_address: &ProtocolAddress,
22 session_store: &mut dyn SessionStore,
23 identity_store: &mut dyn IdentityKeyStore,
24 now: SystemTime,
25 csprng: &mut R,
26) -> Result<CiphertextMessage> {
27 let mut session_record = session_store
28 .load_session(remote_address)
29 .await?
30 .ok_or_else(|| SignalProtocolError::SessionNotFound(remote_address.clone()))?;
31 let session_state = session_record
32 .session_state_mut()
33 .ok_or_else(|| SignalProtocolError::SessionNotFound(remote_address.clone()))?;
34
35 let chain_key = session_state.get_sender_chain_key()?;
36
37 let (pqr_msg, pqr_key) = session_state.pq_ratchet_send(csprng).map_err(|e| {
38 SignalProtocolError::InvalidState(
40 "message_encrypt",
41 format!("post-quantum ratchet send error: {e}"),
42 )
43 })?;
44 let message_keys = chain_key.message_keys().generate_keys(pqr_key);
45
46 let sender_ephemeral = session_state.sender_ratchet_key()?;
47 let previous_counter = session_state.previous_counter();
48 let session_version = session_state
49 .session_version()?
50 .try_into()
51 .map_err(|_| SignalProtocolError::InvalidSessionStructure("version does not fit in u8"))?;
52
53 let local_identity_key = session_state.local_identity_key()?;
54 let their_identity_key = session_state.remote_identity_key()?.ok_or_else(|| {
55 SignalProtocolError::InvalidState(
56 "message_encrypt",
57 format!("no remote identity key for {remote_address}"),
58 )
59 })?;
60
61 let ctext =
62 signal_crypto::aes_256_cbc_encrypt(ptext, message_keys.cipher_key(), message_keys.iv())
63 .map_err(|_| {
64 log::error!("session state corrupt for {remote_address}");
65 SignalProtocolError::InvalidSessionStructure("invalid sender chain message keys")
66 })?;
67
68 let message = if let Some(items) = session_state.unacknowledged_pre_key_message_items()? {
69 let timestamp_as_unix_time = items
70 .timestamp()
71 .duration_since(SystemTime::UNIX_EPOCH)
72 .unwrap_or_default()
73 .as_secs();
74 if items.timestamp() + MAX_UNACKNOWLEDGED_SESSION_AGE < now {
75 log::warn!(
76 "stale unacknowledged session for {remote_address} (created at {timestamp_as_unix_time})"
77 );
78 return Err(SignalProtocolError::SessionNotFound(remote_address.clone()));
79 }
80
81 let local_registration_id = session_state.local_registration_id();
82
83 log::info!(
84 "Building PreKeyWhisperMessage for: {} with preKeyId: {} (session created at {})",
85 remote_address,
86 items
87 .pre_key_id()
88 .map_or_else(|| "<none>".to_string(), |id| id.to_string()),
89 timestamp_as_unix_time,
90 );
91
92 let message = SignalMessage::new(
93 session_version,
94 message_keys.mac_key(),
95 sender_ephemeral,
96 chain_key.index(),
97 previous_counter,
98 &ctext,
99 &local_identity_key,
100 &their_identity_key,
101 &pqr_msg,
102 )?;
103
104 let kyber_payload = items
105 .kyber_pre_key_id()
106 .zip(items.kyber_ciphertext())
107 .map(|(id, ciphertext)| KyberPayload::new(id, ciphertext.into()));
108
109 CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::new(
110 session_version,
111 local_registration_id,
112 items.pre_key_id(),
113 items.signed_pre_key_id(),
114 kyber_payload,
115 *items.base_key(),
116 local_identity_key,
117 message,
118 )?)
119 } else {
120 CiphertextMessage::SignalMessage(SignalMessage::new(
121 session_version,
122 message_keys.mac_key(),
123 sender_ephemeral,
124 chain_key.index(),
125 previous_counter,
126 &ctext,
127 &local_identity_key,
128 &their_identity_key,
129 &pqr_msg,
130 )?)
131 };
132
133 session_state.set_sender_chain_key(&chain_key.next_chain_key());
134
135 if !identity_store
137 .is_trusted_identity(remote_address, &their_identity_key, Direction::Sending)
138 .await?
139 {
140 log::warn!(
141 "Identity key {} is not trusted for remote address {}",
142 hex::encode(their_identity_key.public_key().public_key_bytes()),
143 remote_address,
144 );
145 return Err(SignalProtocolError::UntrustedIdentity(
146 remote_address.clone(),
147 ));
148 }
149
150 identity_store
152 .save_identity(remote_address, &their_identity_key)
153 .await?;
154
155 session_store
156 .store_session(remote_address, &session_record)
157 .await?;
158 Ok(message)
159}
160
161#[allow(clippy::too_many_arguments)]
162pub async fn message_decrypt<R: Rng + CryptoRng>(
163 ciphertext: &CiphertextMessage,
164 remote_address: &ProtocolAddress,
165 session_store: &mut dyn SessionStore,
166 identity_store: &mut dyn IdentityKeyStore,
167 pre_key_store: &mut dyn PreKeyStore,
168 signed_pre_key_store: &dyn SignedPreKeyStore,
169 kyber_pre_key_store: &mut dyn KyberPreKeyStore,
170 csprng: &mut R,
171) -> Result<Vec<u8>> {
172 match ciphertext {
173 CiphertextMessage::SignalMessage(m) => {
174 message_decrypt_signal(m, remote_address, session_store, identity_store, csprng).await
175 }
176 CiphertextMessage::PreKeySignalMessage(m) => {
177 message_decrypt_prekey(
178 m,
179 remote_address,
180 session_store,
181 identity_store,
182 pre_key_store,
183 signed_pre_key_store,
184 kyber_pre_key_store,
185 csprng,
186 )
187 .await
188 }
189 _ => Err(SignalProtocolError::InvalidArgument(format!(
190 "message_decrypt cannot be used to decrypt {:?} messages",
191 ciphertext.message_type()
192 ))),
193 }
194}
195
196#[allow(clippy::too_many_arguments)]
197pub async fn message_decrypt_prekey<R: Rng + CryptoRng>(
198 ciphertext: &PreKeySignalMessage,
199 remote_address: &ProtocolAddress,
200 session_store: &mut dyn SessionStore,
201 identity_store: &mut dyn IdentityKeyStore,
202 pre_key_store: &mut dyn PreKeyStore,
203 signed_pre_key_store: &dyn SignedPreKeyStore,
204 kyber_pre_key_store: &mut dyn KyberPreKeyStore,
205 csprng: &mut R,
206) -> Result<Vec<u8>> {
207 let mut session_record = session_store
208 .load_session(remote_address)
209 .await?
210 .unwrap_or_else(SessionRecord::new_fresh);
211
212 let process_prekey_result = session::process_prekey(
214 ciphertext,
215 remote_address,
216 &mut session_record,
217 identity_store,
218 pre_key_store,
219 signed_pre_key_store,
220 kyber_pre_key_store,
221 )
222 .await;
223
224 let (pre_key_used, identity_to_save) = match process_prekey_result {
225 Ok(result) => result,
226 Err(e) => {
227 let errs = [e];
228 log::error!(
229 "{}",
230 create_decryption_failure_log(
231 remote_address,
232 &errs,
233 &session_record,
234 ciphertext.message()
235 )?
236 );
237 let [e] = errs;
238 return Err(e);
239 }
240 };
241
242 let ptext = decrypt_message_with_record(
243 remote_address,
244 &mut session_record,
245 ciphertext.message(),
246 CiphertextMessageType::PreKey,
247 csprng,
248 )?;
249
250 identity_store
251 .save_identity(
252 identity_to_save.remote_address,
253 identity_to_save.their_identity_key,
254 )
255 .await?;
256
257 if let Some(pre_key_used) = pre_key_used {
258 if let Some(kyber_pre_key_id) = pre_key_used.kyber_pre_key_id {
259 kyber_pre_key_store
260 .mark_kyber_pre_key_used(
261 kyber_pre_key_id,
262 pre_key_used.signed_ec_pre_key_id,
263 ciphertext.base_key(),
264 )
265 .await?;
266 }
267
268 if let Some(pre_key_id) = pre_key_used.one_time_ec_pre_key_id {
269 pre_key_store.remove_pre_key(pre_key_id).await?;
270 }
271 }
272
273 session_store
274 .store_session(remote_address, &session_record)
275 .await?;
276
277 Ok(ptext)
278}
279
280pub async fn message_decrypt_signal<R: Rng + CryptoRng>(
281 ciphertext: &SignalMessage,
282 remote_address: &ProtocolAddress,
283 session_store: &mut dyn SessionStore,
284 identity_store: &mut dyn IdentityKeyStore,
285 csprng: &mut R,
286) -> Result<Vec<u8>> {
287 let mut session_record = session_store
288 .load_session(remote_address)
289 .await?
290 .ok_or_else(|| SignalProtocolError::SessionNotFound(remote_address.clone()))?;
291
292 let ptext = decrypt_message_with_record(
293 remote_address,
294 &mut session_record,
295 ciphertext,
296 CiphertextMessageType::Whisper,
297 csprng,
298 )?;
299
300 let their_identity_key = session_record
302 .session_state()
303 .expect("successfully decrypted; must have a current state")
304 .remote_identity_key()
305 .expect("successfully decrypted; must have a remote identity key")
306 .expect("successfully decrypted; must have a remote identity key");
307
308 if !identity_store
309 .is_trusted_identity(remote_address, &their_identity_key, Direction::Receiving)
310 .await?
311 {
312 log::warn!(
313 "Identity key {} is not trusted for remote address {}",
314 hex::encode(their_identity_key.public_key().public_key_bytes()),
315 remote_address,
316 );
317 return Err(SignalProtocolError::UntrustedIdentity(
318 remote_address.clone(),
319 ));
320 }
321
322 identity_store
323 .save_identity(remote_address, &their_identity_key)
324 .await?;
325
326 session_store
327 .store_session(remote_address, &session_record)
328 .await?;
329
330 Ok(ptext)
331}
332
333fn create_decryption_failure_log(
334 remote_address: &ProtocolAddress,
335 mut errs: &[SignalProtocolError],
336 record: &SessionRecord,
337 ciphertext: &SignalMessage,
338) -> Result<String> {
339 fn append_session_summary(
340 lines: &mut Vec<String>,
341 idx: usize,
342 state: std::result::Result<&SessionState, InvalidSessionError>,
343 err: Option<&SignalProtocolError>,
344 ) {
345 let chains = state.map(|state| state.all_receiver_chain_logging_info());
346 match (err, &chains) {
347 (Some(err), Ok(chains)) => {
348 lines.push(format!(
349 "Candidate session {} failed with '{}', had {} receiver chains",
350 idx,
351 err,
352 chains.len()
353 ));
354 }
355 (Some(err), Err(state_err)) => {
356 lines.push(format!(
357 "Candidate session {idx} failed with '{err}'; cannot get receiver chain info ({state_err})",
358 ));
359 }
360 (None, Ok(chains)) => {
361 lines.push(format!(
362 "Candidate session {} had {} receiver chains",
363 idx,
364 chains.len()
365 ));
366 }
367 (None, Err(state_err)) => {
368 lines.push(format!(
369 "Candidate session {idx}: cannot get receiver chain info ({state_err})",
370 ));
371 }
372 }
373
374 if let Ok(chains) = chains {
375 for chain in chains {
376 let chain_idx = match chain.1 {
377 Some(i) => i.to_string(),
378 None => "missing in protobuf".to_string(),
379 };
380
381 lines.push(format!(
382 "Receiver chain with sender ratchet public key {} chain key index {}",
383 hex::encode(chain.0),
384 chain_idx
385 ));
386 }
387 }
388 }
389
390 let mut lines = vec![];
391
392 lines.push(format!(
393 "Message from {} failed to decrypt; sender ratchet public key {} message counter {}",
394 remote_address,
395 hex::encode(ciphertext.sender_ratchet_key().public_key_bytes()),
396 ciphertext.counter()
397 ));
398
399 if let Some(current_session) = record.session_state() {
400 let err = errs.first();
401 if err.is_some() {
402 errs = &errs[1..];
403 }
404 append_session_summary(&mut lines, 0, Ok(current_session), err);
405 } else {
406 lines.push("No current session".to_string());
407 }
408
409 for (idx, (state, err)) in record
410 .previous_session_states()
411 .zip(errs.iter().map(Some).chain(std::iter::repeat(None)))
412 .enumerate()
413 {
414 let state = match state {
415 Ok(ref state) => Ok(state),
416 Err(err) => Err(err),
417 };
418 append_session_summary(&mut lines, idx + 1, state, err);
419 }
420
421 Ok(lines.join("\n"))
422}
423
424fn decrypt_message_with_record<R: Rng + CryptoRng>(
425 remote_address: &ProtocolAddress,
426 record: &mut SessionRecord,
427 ciphertext: &SignalMessage,
428 original_message_type: CiphertextMessageType,
429 csprng: &mut R,
430) -> Result<Vec<u8>> {
431 debug_assert!(matches!(
432 original_message_type,
433 CiphertextMessageType::Whisper | CiphertextMessageType::PreKey
434 ));
435 let log_decryption_failure = |state: &SessionState, error: &SignalProtocolError| {
436 log::warn!(
438 "Failed to decrypt {:?} message with ratchet key: {} and counter: {}. \
439 Session loaded for {}. Local session has base key: {} and counter: {}. {}",
440 original_message_type,
441 hex::encode(ciphertext.sender_ratchet_key().public_key_bytes()),
442 ciphertext.counter(),
443 remote_address,
444 state
445 .sender_ratchet_key_for_logging()
446 .unwrap_or_else(|e| format!("<error: {e}>")),
447 state.previous_counter(),
448 error
449 );
450 };
451
452 let mut errs = vec![];
453
454 if let Some(current_state) = record.session_state() {
455 let mut current_state = current_state.clone();
456 let result = decrypt_message_with_state(
457 CurrentOrPrevious::Current,
458 &mut current_state,
459 ciphertext,
460 original_message_type,
461 remote_address,
462 csprng,
463 );
464
465 match result {
466 Ok(ptext) => {
467 log::info!(
468 "decrypted {:?} message from {} with current session state (base key {})",
469 original_message_type,
470 remote_address,
471 current_state
472 .sender_ratchet_key_for_logging()
473 .expect("successful decrypt always has a valid base key"),
474 );
475 record.set_session_state(current_state); return Ok(ptext);
477 }
478 Err(SignalProtocolError::DuplicatedMessage(_, _)) => {
479 return result;
480 }
481 Err(e) => {
482 log_decryption_failure(¤t_state, &e);
483 errs.push(e);
484 match original_message_type {
485 CiphertextMessageType::PreKey => {
486 log::error!(
489 "{}",
490 create_decryption_failure_log(
491 remote_address,
492 &errs,
493 record,
494 ciphertext
495 )?
496 );
497 return Err(SignalProtocolError::InvalidMessage(
500 original_message_type,
501 "decryption failed",
502 ));
503 }
504 CiphertextMessageType::Whisper => {}
505 CiphertextMessageType::SenderKey | CiphertextMessageType::Plaintext => {
506 unreachable!("should not be using Double Ratchet for these")
507 }
508 }
509 }
510 }
511 }
512
513 let mut updated_session = None;
515
516 for (idx, previous) in record.previous_session_states().enumerate() {
517 let mut previous = previous?;
518
519 let result = decrypt_message_with_state(
520 CurrentOrPrevious::Previous,
521 &mut previous,
522 ciphertext,
523 original_message_type,
524 remote_address,
525 csprng,
526 );
527
528 match result {
529 Ok(ptext) => {
530 log::info!(
531 "decrypted {:?} message from {} with PREVIOUS session state (base key {})",
532 original_message_type,
533 remote_address,
534 previous
535 .sender_ratchet_key_for_logging()
536 .expect("successful decrypt always has a valid base key"),
537 );
538 updated_session = Some((ptext, idx, previous));
539 break;
540 }
541 Err(SignalProtocolError::DuplicatedMessage(_, _)) => {
542 return result;
543 }
544 Err(e) => {
545 log_decryption_failure(&previous, &e);
546 errs.push(e);
547 }
548 }
549 }
550
551 if let Some((ptext, idx, updated_session)) = updated_session {
552 record.promote_old_session(idx, updated_session);
553 Ok(ptext)
554 } else {
555 let previous_state_count = || record.previous_session_states().len();
556
557 if let Some(current_state) = record.session_state() {
558 log::error!(
559 "No valid session for recipient: {}, current session base key {}, number of previous states: {}",
560 remote_address,
561 current_state
562 .sender_ratchet_key_for_logging()
563 .unwrap_or_else(|e| format!("<error: {e}>")),
564 previous_state_count(),
565 );
566 } else {
567 log::error!(
568 "No valid session for recipient: {}, (no current session state), number of previous states: {}",
569 remote_address,
570 previous_state_count(),
571 );
572 }
573 log::error!(
574 "{}",
575 create_decryption_failure_log(remote_address, &errs, record, ciphertext)?
576 );
577 Err(SignalProtocolError::InvalidMessage(
578 original_message_type,
579 "decryption failed",
580 ))
581 }
582}
583
/// Labels which of a record's session states is being tried; used in log output.
#[derive(Clone, Copy)]
enum CurrentOrPrevious {
    /// The record's current (active) session state.
    Current,
    /// One of the record's previous (archived) session states.
    Previous,
}

impl std::fmt::Display for CurrentOrPrevious {
    /// Writes the lowercase label `"current"` or `"previous"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            CurrentOrPrevious::Current => "current",
            CurrentOrPrevious::Previous => "previous",
        };
        f.write_str(label)
    }
}
598
599fn decrypt_message_with_state<R: Rng + CryptoRng>(
600 current_or_previous: CurrentOrPrevious,
601 state: &mut SessionState,
602 ciphertext: &SignalMessage,
603 original_message_type: CiphertextMessageType,
604 remote_address: &ProtocolAddress,
605 csprng: &mut R,
606) -> Result<Vec<u8>> {
607 let _ = state.root_key().map_err(|_| {
609 SignalProtocolError::InvalidMessage(
610 original_message_type,
611 "No session available to decrypt",
612 )
613 })?;
614
615 let ciphertext_version = ciphertext.message_version() as u32;
616 if ciphertext_version != state.session_version()? {
617 return Err(SignalProtocolError::UnrecognizedMessageVersion(
618 ciphertext_version,
619 ));
620 }
621
622 let their_ephemeral = ciphertext.sender_ratchet_key();
623 let counter = ciphertext.counter();
624 let chain_key = get_or_create_chain_key(state, their_ephemeral, remote_address, csprng)?;
625 let message_key_gen = get_or_create_message_key(
626 state,
627 their_ephemeral,
628 remote_address,
629 original_message_type,
630 &chain_key,
631 counter,
632 )?;
633 let pqr_key = state
634 .pq_ratchet_recv(ciphertext.pq_ratchet())
635 .map_err(|e| match e {
636 spqr::Error::StateDecode => SignalProtocolError::InvalidState(
637 "decrypt_message_with_state",
638 format!("post-quantum ratchet error: {e}"),
639 ),
640 _ => {
641 log::info!("post-quantum ratchet error in decrypt_message_with_state: {e}");
642 SignalProtocolError::InvalidMessage(
643 original_message_type,
644 "post-quantum ratchet error",
645 )
646 }
647 })?;
648 let message_keys = message_key_gen.generate_keys(pqr_key);
649
650 let their_identity_key =
651 state
652 .remote_identity_key()?
653 .ok_or(SignalProtocolError::InvalidSessionStructure(
654 "cannot decrypt without remote identity key",
655 ))?;
656
657 let mac_valid = ciphertext.verify_mac(
658 &their_identity_key,
659 &state.local_identity_key()?,
660 message_keys.mac_key(),
661 )?;
662
663 if !mac_valid {
664 return Err(SignalProtocolError::InvalidMessage(
665 original_message_type,
666 "MAC verification failed",
667 ));
668 }
669
670 let ptext = match signal_crypto::aes_256_cbc_decrypt(
671 ciphertext.body(),
672 message_keys.cipher_key(),
673 message_keys.iv(),
674 ) {
675 Ok(ptext) => ptext,
676 Err(signal_crypto::DecryptionError::BadKeyOrIv) => {
677 log::warn!("{current_or_previous} session state corrupt for {remote_address}",);
678 return Err(SignalProtocolError::InvalidSessionStructure(
679 "invalid receiver chain message keys",
680 ));
681 }
682 Err(signal_crypto::DecryptionError::BadCiphertext(msg)) => {
683 log::warn!("failed to decrypt 1:1 message: {msg}");
684 return Err(SignalProtocolError::InvalidMessage(
685 original_message_type,
686 "failed to decrypt",
687 ));
688 }
689 };
690
691 state.clear_unacknowledged_pre_key_message();
692
693 Ok(ptext)
694}
695
696fn get_or_create_chain_key<R: Rng + CryptoRng>(
697 state: &mut SessionState,
698 their_ephemeral: &PublicKey,
699 remote_address: &ProtocolAddress,
700 csprng: &mut R,
701) -> Result<ChainKey> {
702 if let Some(chain) = state.get_receiver_chain_key(their_ephemeral)? {
703 log::debug!("{remote_address} has existing receiver chain.");
704 return Ok(chain);
705 }
706
707 log::info!("{remote_address} creating new chains.");
708
709 let root_key = state.root_key()?;
710 let our_ephemeral = state.sender_ratchet_private_key()?;
711 let receiver_chain = root_key.create_chain(their_ephemeral, &our_ephemeral)?;
712 let our_new_ephemeral = KeyPair::generate(csprng);
713 let sender_chain = receiver_chain
714 .0
715 .create_chain(their_ephemeral, &our_new_ephemeral.private_key)?;
716
717 state.set_root_key(&sender_chain.0);
718 state.add_receiver_chain(their_ephemeral, &receiver_chain.1);
719
720 let current_index = state.get_sender_chain_key()?.index();
721 let previous_index = if current_index > 0 {
722 current_index - 1
723 } else {
724 0
725 };
726 state.set_previous_counter(previous_index);
727 state.set_sender_chain(&our_new_ephemeral, &sender_chain.1);
728
729 Ok(receiver_chain.1)
730}
731
732fn get_or_create_message_key(
733 state: &mut SessionState,
734 their_ephemeral: &PublicKey,
735 remote_address: &ProtocolAddress,
736 original_message_type: CiphertextMessageType,
737 chain_key: &ChainKey,
738 counter: u32,
739) -> Result<MessageKeyGenerator> {
740 let chain_index = chain_key.index();
741
742 if chain_index > counter {
743 return match state.get_message_keys(their_ephemeral, counter)? {
744 Some(keys) => Ok(keys),
745 None => {
746 log::info!("{remote_address} Duplicate message for counter: {counter}");
747 Err(SignalProtocolError::DuplicatedMessage(chain_index, counter))
748 }
749 };
750 }
751
752 assert!(chain_index <= counter);
753
754 let jump = (counter - chain_index) as usize;
755
756 if jump > MAX_FORWARD_JUMPS {
757 if state.session_with_self()? {
758 log::info!(
759 "{remote_address} Jumping ahead {jump} messages (index: {chain_index}, counter: {counter})"
760 );
761 } else {
762 log::error!(
763 "{remote_address} Exceeded future message limit: {MAX_FORWARD_JUMPS}, index: {chain_index}, counter: {counter})"
764 );
765 return Err(SignalProtocolError::InvalidMessage(
766 original_message_type,
767 "message from too far into the future",
768 ));
769 }
770 }
771
772 let mut chain_key = chain_key.clone();
773
774 while chain_key.index() < counter {
775 let message_keys = chain_key.message_keys();
776 state.set_message_keys(their_ephemeral, message_keys)?;
777 chain_key = chain_key.next_chain_key();
778 }
779
780 state.set_receiver_chain_key(their_ephemeral, &chain_key.next_chain_key())?;
781 Ok(chain_key.message_keys())
782}