1use std::time::SystemTime;
7
8use rand::{CryptoRng, Rng};
9
10use crate::consts::{MAX_FORWARD_JUMPS, MAX_UNACKNOWLEDGED_SESSION_AGE};
11use crate::ratchet::{ChainKey, MessageKeyGenerator, UsePQRatchet};
12use crate::state::{InvalidSessionError, SessionState};
13use crate::{
14 session, CiphertextMessage, CiphertextMessageType, Direction, IdentityKeyStore, KeyPair,
15 KyberPayload, KyberPreKeyStore, PreKeySignalMessage, PreKeyStore, ProtocolAddress, PublicKey,
16 Result, SessionRecord, SessionStore, SignalMessage, SignalProtocolError, SignedPreKeyStore,
17};
18
/// Encrypts `ptext` for `remote_address` using the current session stored in `session_store`.
///
/// Steps, in order:
/// 1. Load the session record and its current session state.
/// 2. Run the post-quantum ratchet send step, then derive this message's keys from the
///    current sender chain key together with the post-quantum key.
/// 3. Encrypt `ptext` with AES-256-CBC under those keys and build a `SignalMessage`
///    (given the MAC key and both the local and remote identity keys).
/// 4. If the session still has unacknowledged pre-key message items, wrap that message in
///    a `PreKeySignalMessage` carrying the pre-key ids, optional Kyber payload and base key.
/// 5. Advance the sender chain key, check that the remote identity is trusted for
///    sending, save the identity, and finally store the updated session record.
///
/// `session_store.store_session` is the last step, so an error from any earlier step means
/// the updated record (including the advanced chain) is not written back.
///
/// # Errors
/// - `SessionNotFound` if there is no record or no current session state, or if the
///   unacknowledged pre-key items are older than `MAX_UNACKNOWLEDGED_SESSION_AGE` at `now`.
/// - `InvalidState` if the post-quantum ratchet send fails or the session has no remote
///   identity key.
/// - `InvalidSessionStructure` if the session version does not fit in a `u8` or the
///   derived sender keys are rejected by AES-256-CBC.
/// - `UntrustedIdentity` if `identity_store` does not trust the remote identity key.
/// - Errors from the stores and from message construction are propagated.
pub async fn message_encrypt<R: Rng + CryptoRng>(
    ptext: &[u8],
    remote_address: &ProtocolAddress,
    session_store: &mut dyn SessionStore,
    identity_store: &mut dyn IdentityKeyStore,
    now: SystemTime,
    csprng: &mut R,
) -> Result<CiphertextMessage> {
    let mut session_record = session_store
        .load_session(remote_address)
        .await?
        .ok_or_else(|| SignalProtocolError::SessionNotFound(remote_address.clone()))?;
    // A record without a current session state is reported the same way as a missing record.
    let session_state = session_record
        .session_state_mut()
        .ok_or_else(|| SignalProtocolError::SessionNotFound(remote_address.clone()))?;

    let chain_key = session_state.get_sender_chain_key()?;

    // The post-quantum ratchet yields a message to send alongside this one (`pqr_msg`)
    // and key material (`pqr_key`) that is passed into the message-key derivation.
    let (pqr_msg, pqr_key) = session_state.pq_ratchet_send(csprng).map_err(|e| {
        SignalProtocolError::InvalidState(
            "message_encrypt",
            format!("post-quantum ratchet send error: {e}"),
        )
    })?;
    let message_keys = chain_key.message_keys().generate_keys(pqr_key);

    let sender_ephemeral = session_state.sender_ratchet_key()?;
    let previous_counter = session_state.previous_counter();
    // The wire format stores the version as a u8.
    let session_version = session_state
        .session_version()?
        .try_into()
        .map_err(|_| SignalProtocolError::InvalidSessionStructure("version does not fit in u8"))?;

    let local_identity_key = session_state.local_identity_key()?;
    let their_identity_key = session_state.remote_identity_key()?.ok_or_else(|| {
        SignalProtocolError::InvalidState(
            "message_encrypt",
            format!("no remote identity key for {remote_address}"),
        )
    })?;

    let ctext =
        signal_crypto::aes_256_cbc_encrypt(ptext, message_keys.cipher_key(), message_keys.iv())
            .map_err(|_| {
                log::error!("session state corrupt for {remote_address}");
                SignalProtocolError::InvalidSessionStructure("invalid sender chain message keys")
            })?;

    // While the remote side has not yet acknowledged the session, every outgoing message
    // is sent as a PreKeySignalMessage so the receiver can still set up the session.
    let message = if let Some(items) = session_state.unacknowledged_pre_key_message_items()? {
        // Only used for logging; a timestamp before the epoch is logged as 0.
        let timestamp_as_unix_time = items
            .timestamp()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        // Refuse to keep sending pre-key messages for a session that has stayed
        // unacknowledged for longer than MAX_UNACKNOWLEDGED_SESSION_AGE.
        if items.timestamp() + MAX_UNACKNOWLEDGED_SESSION_AGE < now {
            log::warn!(
                "stale unacknowledged session for {remote_address} (created at {timestamp_as_unix_time})"
            );
            return Err(SignalProtocolError::SessionNotFound(remote_address.clone()));
        }

        let local_registration_id = session_state.local_registration_id();

        log::info!(
            "Building PreKeyWhisperMessage for: {} with preKeyId: {} (session created at {})",
            remote_address,
            items
                .pre_key_id()
                .map_or_else(|| "<none>".to_string(), |id| id.to_string()),
            timestamp_as_unix_time,
        );

        let message = SignalMessage::new(
            session_version,
            message_keys.mac_key(),
            sender_ephemeral,
            chain_key.index(),
            previous_counter,
            &ctext,
            &local_identity_key,
            &their_identity_key,
            &pqr_msg,
        )?;

        // The Kyber payload is only attached when both the id and the ciphertext are present.
        let kyber_payload = items
            .kyber_pre_key_id()
            .zip(items.kyber_ciphertext())
            .map(|(id, ciphertext)| KyberPayload::new(id, ciphertext.into()));

        CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::new(
            session_version,
            local_registration_id,
            items.pre_key_id(),
            items.signed_pre_key_id(),
            kyber_payload,
            *items.base_key(),
            local_identity_key,
            message,
        )?)
    } else {
        CiphertextMessage::SignalMessage(SignalMessage::new(
            session_version,
            message_keys.mac_key(),
            sender_ephemeral,
            chain_key.index(),
            previous_counter,
            &ctext,
            &local_identity_key,
            &their_identity_key,
            &pqr_msg,
        )?)
    };

    // Advance the sending chain in memory so the next message derives fresh keys.
    session_state.set_sender_chain_key(&chain_key.next_chain_key());

    // NOTE(review): the trust check runs only after the message has been built. If it
    // fails, nothing is saved or stored, so the advanced chain is discarded.
    if !identity_store
        .is_trusted_identity(remote_address, &their_identity_key, Direction::Sending)
        .await?
    {
        log::warn!(
            "Identity key {} is not trusted for remote address {}",
            hex::encode(their_identity_key.public_key().public_key_bytes()),
            remote_address,
        );
        return Err(SignalProtocolError::UntrustedIdentity(
            remote_address.clone(),
        ));
    }

    identity_store
        .save_identity(remote_address, &their_identity_key)
        .await?;

    session_store
        .store_session(remote_address, &session_record)
        .await?;
    Ok(message)
}
160
161#[allow(clippy::too_many_arguments)]
162pub async fn message_decrypt<R: Rng + CryptoRng>(
163 ciphertext: &CiphertextMessage,
164 remote_address: &ProtocolAddress,
165 session_store: &mut dyn SessionStore,
166 identity_store: &mut dyn IdentityKeyStore,
167 pre_key_store: &mut dyn PreKeyStore,
168 signed_pre_key_store: &dyn SignedPreKeyStore,
169 kyber_pre_key_store: &mut dyn KyberPreKeyStore,
170 csprng: &mut R,
171 use_pq_ratchet: UsePQRatchet,
172) -> Result<Vec<u8>> {
173 match ciphertext {
174 CiphertextMessage::SignalMessage(m) => {
175 message_decrypt_signal(m, remote_address, session_store, identity_store, csprng).await
176 }
177 CiphertextMessage::PreKeySignalMessage(m) => {
178 message_decrypt_prekey(
179 m,
180 remote_address,
181 session_store,
182 identity_store,
183 pre_key_store,
184 signed_pre_key_store,
185 kyber_pre_key_store,
186 csprng,
187 use_pq_ratchet,
188 )
189 .await
190 }
191 _ => Err(SignalProtocolError::InvalidArgument(format!(
192 "message_decrypt cannot be used to decrypt {:?} messages",
193 ciphertext.message_type()
194 ))),
195 }
196}
197
/// Decrypts a `PreKeySignalMessage` from `remote_address`, using or setting up a session.
///
/// Steps, in order:
/// 1. Load the existing record, or start from `SessionRecord::new_fresh` if there is none.
/// 2. Let `session::process_prekey` consume the pre-key information. It returns which
///    pre-keys were used and which identity to save. Presumably it also installs the new
///    session state into `session_record`; confirm in `session`.
/// 3. Decrypt the embedded `SignalMessage` against the record as a `PreKey` message.
/// 4. Save the remote identity, store the record, then remove the one-time pre-key and
///    mark the Kyber pre-key as used (each only if one was used).
///
/// # Errors
/// - Errors from `process_prekey` are logged with a diagnostic summary and returned unchanged.
/// - Decryption errors from `decrypt_message_with_record`.
/// - Errors from the identity, session, pre-key and Kyber pre-key stores are propagated.
#[allow(clippy::too_many_arguments)]
pub async fn message_decrypt_prekey<R: Rng + CryptoRng>(
    ciphertext: &PreKeySignalMessage,
    remote_address: &ProtocolAddress,
    session_store: &mut dyn SessionStore,
    identity_store: &mut dyn IdentityKeyStore,
    pre_key_store: &mut dyn PreKeyStore,
    signed_pre_key_store: &dyn SignedPreKeyStore,
    kyber_pre_key_store: &mut dyn KyberPreKeyStore,
    csprng: &mut R,
    use_pq_ratchet: UsePQRatchet,
) -> Result<Vec<u8>> {
    let mut session_record = session_store
        .load_session(remote_address)
        .await?
        .unwrap_or_else(SessionRecord::new_fresh);

    let process_prekey_result = session::process_prekey(
        ciphertext,
        remote_address,
        &mut session_record,
        identity_store,
        pre_key_store,
        signed_pre_key_store,
        kyber_pre_key_store,
        use_pq_ratchet,
    )
    .await;

    let (pre_key_used, identity_to_save) = match process_prekey_result {
        Ok(result) => result,
        Err(e) => {
            // Wrap the error in a one-element array so it can be passed to the logger
            // as a slice, then move it back out to return it unchanged.
            // create_decryption_failure_log always returns Ok, so the `?` does not
            // replace `e` with a different error.
            let errs = [e];
            log::error!(
                "{}",
                create_decryption_failure_log(
                    remote_address,
                    &errs,
                    &session_record,
                    ciphertext.message()
                )?
            );
            let [e] = errs;
            return Err(e);
        }
    };

    let ptext = decrypt_message_with_record(
        remote_address,
        &mut session_record,
        ciphertext.message(),
        CiphertextMessageType::PreKey,
        csprng,
    )?;

    identity_store
        .save_identity(
            identity_to_save.remote_address,
            identity_to_save.their_identity_key,
        )
        .await?;

    session_store
        .store_session(remote_address, &session_record)
        .await?;

    // Pre-keys are consumed only after decryption succeeded and the record was stored.
    // Presumably this is so a failure earlier on does not burn a pre-key.
    if let Some(pre_key_id) = pre_key_used.pre_key_id {
        pre_key_store.remove_pre_key(pre_key_id).await?;
    }

    if let Some(kyber_pre_key_id) = pre_key_used.kyber_pre_key_id {
        kyber_pre_key_store
            .mark_kyber_pre_key_used(kyber_pre_key_id)
            .await?;
    }

    Ok(ptext)
}
277
/// Decrypts a (non-pre-key) `SignalMessage` from `remote_address` using an existing session.
///
/// The message is tried against the stored record: the current state first, then the
/// previous states. Only after a successful decryption is the remote identity key checked
/// with `is_trusted_identity` (direction `Receiving`) and saved. That key is read from the
/// record's current session state. The updated record is stored last.
///
/// # Errors
/// - `SessionNotFound` if `session_store` has no record for `remote_address`.
/// - Decryption errors from `decrypt_message_with_record`, such as `InvalidMessage` or
///   `DuplicatedMessage`.
/// - `UntrustedIdentity` if the identity is not trusted. In that case the updated record
///   is not stored.
/// - Store errors are propagated.
///
/// # Panics
/// Panics if, after a successful decryption, the record has no current session state, or
/// that state's remote identity key is missing or unreadable.
pub async fn message_decrypt_signal<R: Rng + CryptoRng>(
    ciphertext: &SignalMessage,
    remote_address: &ProtocolAddress,
    session_store: &mut dyn SessionStore,
    identity_store: &mut dyn IdentityKeyStore,
    csprng: &mut R,
) -> Result<Vec<u8>> {
    let mut session_record = session_store
        .load_session(remote_address)
        .await?
        .ok_or_else(|| SignalProtocolError::SessionNotFound(remote_address.clone()))?;

    let ptext = decrypt_message_with_record(
        remote_address,
        &mut session_record,
        ciphertext,
        CiphertextMessageType::Whisper,
        csprng,
    )?;

    // On success, decrypt_message_with_record either writes the decrypting state back as
    // the current state or promotes the previous state that decrypted. Promotion presumably
    // makes that state current; confirm in SessionRecord::promote_old_session. These
    // `expect`s therefore encode an invariant rather than validate input.
    // remote_identity_key() returns Result<Option<_>>, hence the two expects.
    let their_identity_key = session_record
        .session_state()
        .expect("successfully decrypted; must have a current state")
        .remote_identity_key()
        .expect("successfully decrypted; must have a remote identity key")
        .expect("successfully decrypted; must have a remote identity key");

    // NOTE(review): trust is checked after decryption. If it fails, the ratcheted record
    // is not stored.
    if !identity_store
        .is_trusted_identity(remote_address, &their_identity_key, Direction::Receiving)
        .await?
    {
        log::warn!(
            "Identity key {} is not trusted for remote address {}",
            hex::encode(their_identity_key.public_key().public_key_bytes()),
            remote_address,
        );
        return Err(SignalProtocolError::UntrustedIdentity(
            remote_address.clone(),
        ));
    }

    identity_store
        .save_identity(remote_address, &their_identity_key)
        .await?;

    session_store
        .store_session(remote_address, &session_record)
        .await?;

    Ok(ptext)
}
330
331fn create_decryption_failure_log(
332 remote_address: &ProtocolAddress,
333 mut errs: &[SignalProtocolError],
334 record: &SessionRecord,
335 ciphertext: &SignalMessage,
336) -> Result<String> {
337 fn append_session_summary(
338 lines: &mut Vec<String>,
339 idx: usize,
340 state: std::result::Result<&SessionState, InvalidSessionError>,
341 err: Option<&SignalProtocolError>,
342 ) {
343 let chains = state.map(|state| state.all_receiver_chain_logging_info());
344 match (err, &chains) {
345 (Some(err), Ok(chains)) => {
346 lines.push(format!(
347 "Candidate session {} failed with '{}', had {} receiver chains",
348 idx,
349 err,
350 chains.len()
351 ));
352 }
353 (Some(err), Err(state_err)) => {
354 lines.push(format!(
355 "Candidate session {idx} failed with '{err}'; cannot get receiver chain info ({state_err})",
356 ));
357 }
358 (None, Ok(chains)) => {
359 lines.push(format!(
360 "Candidate session {} had {} receiver chains",
361 idx,
362 chains.len()
363 ));
364 }
365 (None, Err(state_err)) => {
366 lines.push(format!(
367 "Candidate session {idx}: cannot get receiver chain info ({state_err})",
368 ));
369 }
370 }
371
372 if let Ok(chains) = chains {
373 for chain in chains {
374 let chain_idx = match chain.1 {
375 Some(i) => i.to_string(),
376 None => "missing in protobuf".to_string(),
377 };
378
379 lines.push(format!(
380 "Receiver chain with sender ratchet public key {} chain key index {}",
381 hex::encode(chain.0),
382 chain_idx
383 ));
384 }
385 }
386 }
387
388 let mut lines = vec![];
389
390 lines.push(format!(
391 "Message from {} failed to decrypt; sender ratchet public key {} message counter {}",
392 remote_address,
393 hex::encode(ciphertext.sender_ratchet_key().public_key_bytes()),
394 ciphertext.counter()
395 ));
396
397 if let Some(current_session) = record.session_state() {
398 let err = errs.first();
399 if err.is_some() {
400 errs = &errs[1..];
401 }
402 append_session_summary(&mut lines, 0, Ok(current_session), err);
403 } else {
404 lines.push("No current session".to_string());
405 }
406
407 for (idx, (state, err)) in record
408 .previous_session_states()
409 .zip(errs.iter().map(Some).chain(std::iter::repeat(None)))
410 .enumerate()
411 {
412 let state = match state {
413 Ok(ref state) => Ok(state),
414 Err(err) => Err(err),
415 };
416 append_session_summary(&mut lines, idx + 1, state, err);
417 }
418
419 Ok(lines.join("\n"))
420}
421
422fn decrypt_message_with_record<R: Rng + CryptoRng>(
423 remote_address: &ProtocolAddress,
424 record: &mut SessionRecord,
425 ciphertext: &SignalMessage,
426 original_message_type: CiphertextMessageType,
427 csprng: &mut R,
428) -> Result<Vec<u8>> {
429 debug_assert!(matches!(
430 original_message_type,
431 CiphertextMessageType::Whisper | CiphertextMessageType::PreKey
432 ));
433 let log_decryption_failure = |state: &SessionState, error: &SignalProtocolError| {
434 log::warn!(
436 "Failed to decrypt {:?} message with ratchet key: {} and counter: {}. \
437 Session loaded for {}. Local session has base key: {} and counter: {}. {}",
438 original_message_type,
439 hex::encode(ciphertext.sender_ratchet_key().public_key_bytes()),
440 ciphertext.counter(),
441 remote_address,
442 state
443 .sender_ratchet_key_for_logging()
444 .unwrap_or_else(|e| format!("<error: {e}>")),
445 state.previous_counter(),
446 error
447 );
448 };
449
450 let mut errs = vec![];
451
452 if let Some(current_state) = record.session_state() {
453 let mut current_state = current_state.clone();
454 let result = decrypt_message_with_state(
455 CurrentOrPrevious::Current,
456 &mut current_state,
457 ciphertext,
458 original_message_type,
459 remote_address,
460 csprng,
461 );
462
463 match result {
464 Ok(ptext) => {
465 log::info!(
466 "decrypted {:?} message from {} with current session state (base key {})",
467 original_message_type,
468 remote_address,
469 current_state
470 .sender_ratchet_key_for_logging()
471 .expect("successful decrypt always has a valid base key"),
472 );
473 record.set_session_state(current_state); return Ok(ptext);
475 }
476 Err(SignalProtocolError::DuplicatedMessage(_, _)) => {
477 return result;
478 }
479 Err(e) => {
480 log_decryption_failure(¤t_state, &e);
481 errs.push(e);
482 match original_message_type {
483 CiphertextMessageType::PreKey => {
484 log::error!(
487 "{}",
488 create_decryption_failure_log(
489 remote_address,
490 &errs,
491 record,
492 ciphertext
493 )?
494 );
495 return Err(SignalProtocolError::InvalidMessage(
498 original_message_type,
499 "decryption failed",
500 ));
501 }
502 CiphertextMessageType::Whisper => {}
503 CiphertextMessageType::SenderKey | CiphertextMessageType::Plaintext => {
504 unreachable!("should not be using Double Ratchet for these")
505 }
506 }
507 }
508 }
509 }
510
511 let mut updated_session = None;
513
514 for (idx, previous) in record.previous_session_states().enumerate() {
515 let mut previous = previous?;
516
517 let result = decrypt_message_with_state(
518 CurrentOrPrevious::Previous,
519 &mut previous,
520 ciphertext,
521 original_message_type,
522 remote_address,
523 csprng,
524 );
525
526 match result {
527 Ok(ptext) => {
528 log::info!(
529 "decrypted {:?} message from {} with PREVIOUS session state (base key {})",
530 original_message_type,
531 remote_address,
532 previous
533 .sender_ratchet_key_for_logging()
534 .expect("successful decrypt always has a valid base key"),
535 );
536 updated_session = Some((ptext, idx, previous));
537 break;
538 }
539 Err(SignalProtocolError::DuplicatedMessage(_, _)) => {
540 return result;
541 }
542 Err(e) => {
543 log_decryption_failure(&previous, &e);
544 errs.push(e);
545 }
546 }
547 }
548
549 if let Some((ptext, idx, updated_session)) = updated_session {
550 record.promote_old_session(idx, updated_session);
551 Ok(ptext)
552 } else {
553 let previous_state_count = || record.previous_session_states().len();
554
555 if let Some(current_state) = record.session_state() {
556 log::error!(
557 "No valid session for recipient: {}, current session base key {}, number of previous states: {}",
558 remote_address,
559 current_state.sender_ratchet_key_for_logging()
560 .unwrap_or_else(|e| format!("<error: {e}>")),
561 previous_state_count(),
562 );
563 } else {
564 log::error!(
565 "No valid session for recipient: {}, (no current session state), number of previous states: {}",
566 remote_address,
567 previous_state_count(),
568 );
569 }
570 log::error!(
571 "{}",
572 create_decryption_failure_log(remote_address, &errs, record, ciphertext)?
573 );
574 Err(SignalProtocolError::InvalidMessage(
575 original_message_type,
576 "decryption failed",
577 ))
578 }
579}
580
/// Says which kind of session state a decryption attempt used.
///
/// In this file it is only used to label log messages.
#[derive(Clone, Copy)]
enum CurrentOrPrevious {
    /// The record's current session state.
    Current,
    /// One of the record's previous session states.
    Previous,
}

impl std::fmt::Display for CurrentOrPrevious {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            CurrentOrPrevious::Current => "current",
            CurrentOrPrevious::Previous => "previous",
        };
        f.write_str(label)
    }
}
595
/// Decrypts `ciphertext` with a single session `state`, mutating `state` as the ratchet
/// advances.
///
/// Steps, in order:
/// 1. Check that the state has a root key and that the message version matches.
/// 2. Get or create the receiver chain for the sender's ratchet key. Creating one performs
///    a ratchet step that also replaces our sender chain.
/// 3. Derive the message key generator for the message counter, storing keys for any
///    skipped counters.
/// 4. Run the post-quantum ratchet receive step and finish the message keys with its output.
/// 5. Verify the MAC (with both identity keys), then AES-256-CBC-decrypt the body.
/// 6. On success, clear the unacknowledged pre-key message marker.
///
/// Steps 2–4 mutate `state` before the MAC is checked, so `state` may be modified even when
/// this returns an error. Callers in this file pass a scratch copy and only persist it on
/// success.
///
/// # Errors
/// - `InvalidMessage(.., "No session available to decrypt")` if `state.root_key()` fails.
/// - `UnrecognizedMessageVersion` if the message version differs from the session version.
/// - `DuplicatedMessage` or `InvalidMessage` from the message-key lookup.
/// - `InvalidState` for a post-quantum `StateDecode` error; `InvalidMessage` for other
///   post-quantum errors.
/// - `InvalidSessionStructure` if the remote identity key is missing or the derived keys are
///   rejected by AES-256-CBC.
/// - `InvalidMessage(.., "MAC verification failed")` or `InvalidMessage(.., "failed to decrypt")`.
fn decrypt_message_with_state<R: Rng + CryptoRng>(
    current_or_previous: CurrentOrPrevious,
    state: &mut SessionState,
    ciphertext: &SignalMessage,
    original_message_type: CiphertextMessageType,
    remote_address: &ProtocolAddress,
    csprng: &mut R,
) -> Result<Vec<u8>> {
    // Only used as a presence check: a state whose root key cannot be read is treated as
    // "no session".
    let _ = state.root_key().map_err(|_| {
        SignalProtocolError::InvalidMessage(
            original_message_type,
            "No session available to decrypt",
        )
    })?;

    let ciphertext_version = ciphertext.message_version() as u32;
    if ciphertext_version != state.session_version()? {
        return Err(SignalProtocolError::UnrecognizedMessageVersion(
            ciphertext_version,
        ));
    }

    let their_ephemeral = ciphertext.sender_ratchet_key();
    let counter = ciphertext.counter();
    let chain_key = get_or_create_chain_key(state, their_ephemeral, remote_address, csprng)?;
    let message_key_gen = get_or_create_message_key(
        state,
        their_ephemeral,
        remote_address,
        original_message_type,
        &chain_key,
        counter,
    )?;
    // A StateDecode failure is reported as a local state problem. Any other
    // post-quantum error is attributed to the incoming message.
    let pqr_key = state
        .pq_ratchet_recv(ciphertext.pq_ratchet())
        .map_err(|e| match e {
            spqr::Error::StateDecode => SignalProtocolError::InvalidState(
                "decrypt_message_with_state",
                format!("post-quantum ratchet error: {e}"),
            ),
            _ => {
                log::info!("post-quantum ratchet error in decrypt_message_with_state: {e}");
                SignalProtocolError::InvalidMessage(
                    original_message_type,
                    "post-quantum ratchet error",
                )
            }
        })?;
    let message_keys = message_key_gen.generate_keys(pqr_key);

    let their_identity_key =
        state
            .remote_identity_key()?
            .ok_or(SignalProtocolError::InvalidSessionStructure(
                "cannot decrypt without remote identity key",
            ))?;

    let mac_valid = ciphertext.verify_mac(
        &their_identity_key,
        &state.local_identity_key()?,
        message_keys.mac_key(),
    )?;

    // The MAC is checked before any decryption of the body is attempted.
    if !mac_valid {
        return Err(SignalProtocolError::InvalidMessage(
            original_message_type,
            "MAC verification failed",
        ));
    }

    let ptext = match signal_crypto::aes_256_cbc_decrypt(
        ciphertext.body(),
        message_keys.cipher_key(),
        message_keys.iv(),
    ) {
        Ok(ptext) => ptext,
        // Bad key/IV means our derived keys are unusable, so blame the local session state.
        Err(signal_crypto::DecryptionError::BadKeyOrIv) => {
            log::warn!("{current_or_previous} session state corrupt for {remote_address}",);
            return Err(SignalProtocolError::InvalidSessionStructure(
                "invalid receiver chain message keys",
            ));
        }
        Err(signal_crypto::DecryptionError::BadCiphertext(msg)) => {
            log::warn!("failed to decrypt 1:1 message: {msg}");
            return Err(SignalProtocolError::InvalidMessage(
                original_message_type,
                "failed to decrypt",
            ));
        }
    };

    // A message decrypted with this state, so the pre-key message no longer needs to
    // be marked as unacknowledged.
    state.clear_unacknowledged_pre_key_message();

    Ok(ptext)
}
692
693fn get_or_create_chain_key<R: Rng + CryptoRng>(
694 state: &mut SessionState,
695 their_ephemeral: &PublicKey,
696 remote_address: &ProtocolAddress,
697 csprng: &mut R,
698) -> Result<ChainKey> {
699 if let Some(chain) = state.get_receiver_chain_key(their_ephemeral)? {
700 log::debug!("{remote_address} has existing receiver chain.");
701 return Ok(chain);
702 }
703
704 log::info!("{remote_address} creating new chains.");
705
706 let root_key = state.root_key()?;
707 let our_ephemeral = state.sender_ratchet_private_key()?;
708 let receiver_chain = root_key.create_chain(their_ephemeral, &our_ephemeral)?;
709 let our_new_ephemeral = KeyPair::generate(csprng);
710 let sender_chain = receiver_chain
711 .0
712 .create_chain(their_ephemeral, &our_new_ephemeral.private_key)?;
713
714 state.set_root_key(&sender_chain.0);
715 state.add_receiver_chain(their_ephemeral, &receiver_chain.1);
716
717 let current_index = state.get_sender_chain_key()?.index();
718 let previous_index = if current_index > 0 {
719 current_index - 1
720 } else {
721 0
722 };
723 state.set_previous_counter(previous_index);
724 state.set_sender_chain(&our_new_ephemeral, &sender_chain.1);
725
726 Ok(receiver_chain.1)
727}
728
729fn get_or_create_message_key(
730 state: &mut SessionState,
731 their_ephemeral: &PublicKey,
732 remote_address: &ProtocolAddress,
733 original_message_type: CiphertextMessageType,
734 chain_key: &ChainKey,
735 counter: u32,
736) -> Result<MessageKeyGenerator> {
737 let chain_index = chain_key.index();
738
739 if chain_index > counter {
740 return match state.get_message_keys(their_ephemeral, counter)? {
741 Some(keys) => Ok(keys),
742 None => {
743 log::info!("{remote_address} Duplicate message for counter: {counter}");
744 Err(SignalProtocolError::DuplicatedMessage(chain_index, counter))
745 }
746 };
747 }
748
749 assert!(chain_index <= counter);
750
751 let jump = (counter - chain_index) as usize;
752
753 if jump > MAX_FORWARD_JUMPS {
754 if state.session_with_self()? {
755 log::info!(
756 "{remote_address} Jumping ahead {jump} messages (index: {chain_index}, counter: {counter})"
757 );
758 } else {
759 log::error!(
760 "{remote_address} Exceeded future message limit: {MAX_FORWARD_JUMPS}, index: {chain_index}, counter: {counter})"
761 );
762 return Err(SignalProtocolError::InvalidMessage(
763 original_message_type,
764 "message from too far into the future",
765 ));
766 }
767 }
768
769 let mut chain_key = chain_key.clone();
770
771 while chain_key.index() < counter {
772 let message_keys = chain_key.message_keys();
773 state.set_message_keys(their_ephemeral, message_keys)?;
774 chain_key = chain_key.next_chain_key();
775 }
776
777 state.set_receiver_chain_key(their_ephemeral, &chain_key.next_chain_key())?;
778 Ok(chain_key.message_keys())
779}