1use std::time::SystemTime;
7
8use rand::{CryptoRng, Rng};
9
10use crate::consts::{MAX_FORWARD_JUMPS, MAX_UNACKNOWLEDGED_SESSION_AGE};
11use crate::ratchet::{ChainKey, MessageKeys};
12use crate::state::{InvalidSessionError, SessionState};
13use crate::{
14 session, CiphertextMessage, CiphertextMessageType, Direction, IdentityKeyStore, KeyPair,
15 KyberPayload, KyberPreKeyStore, PreKeySignalMessage, PreKeyStore, ProtocolAddress, PublicKey,
16 Result, SessionRecord, SessionStore, SignalMessage, SignalProtocolError, SignedPreKeyStore,
17};
18
19pub async fn message_encrypt(
20 ptext: &[u8],
21 remote_address: &ProtocolAddress,
22 session_store: &mut dyn SessionStore,
23 identity_store: &mut dyn IdentityKeyStore,
24 now: SystemTime,
25) -> Result<CiphertextMessage> {
26 let mut session_record = session_store
27 .load_session(remote_address)
28 .await?
29 .ok_or_else(|| SignalProtocolError::SessionNotFound(remote_address.clone()))?;
30 let session_state = session_record
31 .session_state_mut()
32 .ok_or_else(|| SignalProtocolError::SessionNotFound(remote_address.clone()))?;
33
34 let chain_key = session_state.get_sender_chain_key()?;
35
36 let message_keys = chain_key.message_keys();
37
38 let sender_ephemeral = session_state.sender_ratchet_key()?;
39 let previous_counter = session_state.previous_counter();
40 let session_version = session_state.session_version()? as u8;
41
42 let local_identity_key = session_state.local_identity_key()?;
43 let their_identity_key = session_state.remote_identity_key()?.ok_or_else(|| {
44 SignalProtocolError::InvalidState(
45 "message_encrypt",
46 format!("no remote identity key for {}", remote_address),
47 )
48 })?;
49
50 let ctext =
51 signal_crypto::aes_256_cbc_encrypt(ptext, message_keys.cipher_key(), message_keys.iv())
52 .map_err(|_| {
53 log::error!("session state corrupt for {}", remote_address);
54 SignalProtocolError::InvalidSessionStructure("invalid sender chain message keys")
55 })?;
56
57 let message = if let Some(items) = session_state.unacknowledged_pre_key_message_items()? {
58 let timestamp_as_unix_time = items
59 .timestamp()
60 .duration_since(SystemTime::UNIX_EPOCH)
61 .unwrap_or_default()
62 .as_secs();
63 if items.timestamp() + MAX_UNACKNOWLEDGED_SESSION_AGE < now {
64 log::warn!(
65 "stale unacknowledged session for {} (created at {})",
66 remote_address,
67 timestamp_as_unix_time
68 );
69 return Err(SignalProtocolError::SessionNotFound(remote_address.clone()));
70 }
71
72 let local_registration_id = session_state.local_registration_id();
73
74 log::info!(
75 "Building PreKeyWhisperMessage for: {} with preKeyId: {} (session created at {})",
76 remote_address,
77 items
78 .pre_key_id()
79 .map_or_else(|| "<none>".to_string(), |id| id.to_string()),
80 timestamp_as_unix_time,
81 );
82
83 let message = SignalMessage::new(
84 session_version,
85 message_keys.mac_key(),
86 sender_ephemeral,
87 chain_key.index(),
88 previous_counter,
89 &ctext,
90 &local_identity_key,
91 &their_identity_key,
92 )?;
93
94 let kyber_payload = items
95 .kyber_pre_key_id()
96 .zip(items.kyber_ciphertext())
97 .map(|(id, ciphertext)| KyberPayload::new(id, ciphertext.into()));
98
99 CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::new(
100 session_version,
101 local_registration_id,
102 items.pre_key_id(),
103 items.signed_pre_key_id(),
104 kyber_payload,
105 *items.base_key(),
106 local_identity_key,
107 message,
108 )?)
109 } else {
110 CiphertextMessage::SignalMessage(SignalMessage::new(
111 session_version,
112 message_keys.mac_key(),
113 sender_ephemeral,
114 chain_key.index(),
115 previous_counter,
116 &ctext,
117 &local_identity_key,
118 &their_identity_key,
119 )?)
120 };
121
122 session_state.set_sender_chain_key(&chain_key.next_chain_key());
123
124 if !identity_store
126 .is_trusted_identity(remote_address, &their_identity_key, Direction::Sending)
127 .await?
128 {
129 log::warn!(
130 "Identity key {} is not trusted for remote address {}",
131 their_identity_key
132 .public_key()
133 .public_key_bytes()
134 .map_or_else(|e| format!("<error: {}>", e), hex::encode),
135 remote_address,
136 );
137 return Err(SignalProtocolError::UntrustedIdentity(
138 remote_address.clone(),
139 ));
140 }
141
142 identity_store
144 .save_identity(remote_address, &their_identity_key)
145 .await?;
146
147 session_store
148 .store_session(remote_address, &session_record)
149 .await?;
150 Ok(message)
151}
152
153#[allow(clippy::too_many_arguments)]
154pub async fn message_decrypt<R: Rng + CryptoRng>(
155 ciphertext: &CiphertextMessage,
156 remote_address: &ProtocolAddress,
157 session_store: &mut dyn SessionStore,
158 identity_store: &mut dyn IdentityKeyStore,
159 pre_key_store: &mut dyn PreKeyStore,
160 signed_pre_key_store: &dyn SignedPreKeyStore,
161 kyber_pre_key_store: &mut dyn KyberPreKeyStore,
162 csprng: &mut R,
163) -> Result<Vec<u8>> {
164 match ciphertext {
165 CiphertextMessage::SignalMessage(m) => {
166 message_decrypt_signal(m, remote_address, session_store, identity_store, csprng).await
167 }
168 CiphertextMessage::PreKeySignalMessage(m) => {
169 message_decrypt_prekey(
170 m,
171 remote_address,
172 session_store,
173 identity_store,
174 pre_key_store,
175 signed_pre_key_store,
176 kyber_pre_key_store,
177 csprng,
178 )
179 .await
180 }
181 _ => Err(SignalProtocolError::InvalidArgument(format!(
182 "message_decrypt cannot be used to decrypt {:?} messages",
183 ciphertext.message_type()
184 ))),
185 }
186}
187
188#[allow(clippy::too_many_arguments)]
189pub async fn message_decrypt_prekey<R: Rng + CryptoRng>(
190 ciphertext: &PreKeySignalMessage,
191 remote_address: &ProtocolAddress,
192 session_store: &mut dyn SessionStore,
193 identity_store: &mut dyn IdentityKeyStore,
194 pre_key_store: &mut dyn PreKeyStore,
195 signed_pre_key_store: &dyn SignedPreKeyStore,
196 kyber_pre_key_store: &mut dyn KyberPreKeyStore,
197 csprng: &mut R,
198) -> Result<Vec<u8>> {
199 let mut session_record = session_store
200 .load_session(remote_address)
201 .await?
202 .unwrap_or_else(SessionRecord::new_fresh);
203
204 let pre_key_used_or_err = session::process_prekey(
206 ciphertext,
207 remote_address,
208 &mut session_record,
209 identity_store,
210 pre_key_store,
211 signed_pre_key_store,
212 kyber_pre_key_store,
213 )
214 .await;
215
216 let pre_key_used = match pre_key_used_or_err {
217 Ok(result) => result,
218 Err(e) => {
219 let errs = [e];
220 log::error!(
221 "{}",
222 create_decryption_failure_log(
223 remote_address,
224 &errs,
225 &session_record,
226 ciphertext.message()
227 )?
228 );
229 let [e] = errs;
230 return Err(e);
231 }
232 };
233
234 let ptext = decrypt_message_with_record(
235 remote_address,
236 &mut session_record,
237 ciphertext.message(),
238 CiphertextMessageType::PreKey,
239 csprng,
240 )?;
241
242 session_store
243 .store_session(remote_address, &session_record)
244 .await?;
245
246 if let Some(pre_key_id) = pre_key_used.pre_key_id {
247 pre_key_store.remove_pre_key(pre_key_id).await?;
248 }
249
250 if let Some(kyber_pre_key_id) = pre_key_used.kyber_pre_key_id {
251 kyber_pre_key_store
252 .mark_kyber_pre_key_used(kyber_pre_key_id)
253 .await?;
254 }
255
256 Ok(ptext)
257}
258
259pub async fn message_decrypt_signal<R: Rng + CryptoRng>(
260 ciphertext: &SignalMessage,
261 remote_address: &ProtocolAddress,
262 session_store: &mut dyn SessionStore,
263 identity_store: &mut dyn IdentityKeyStore,
264 csprng: &mut R,
265) -> Result<Vec<u8>> {
266 let mut session_record = session_store
267 .load_session(remote_address)
268 .await?
269 .ok_or_else(|| SignalProtocolError::SessionNotFound(remote_address.clone()))?;
270
271 let ptext = decrypt_message_with_record(
272 remote_address,
273 &mut session_record,
274 ciphertext,
275 CiphertextMessageType::Whisper,
276 csprng,
277 )?;
278
279 let their_identity_key = session_record
281 .session_state()
282 .expect("successfully decrypted; must have a current state")
283 .remote_identity_key()
284 .expect("successfully decrypted; must have a remote identity key")
285 .expect("successfully decrypted; must have a remote identity key");
286
287 if !identity_store
288 .is_trusted_identity(remote_address, &their_identity_key, Direction::Receiving)
289 .await?
290 {
291 log::warn!(
292 "Identity key {} is not trusted for remote address {}",
293 their_identity_key
294 .public_key()
295 .public_key_bytes()
296 .map_or_else(|e| format!("<error: {}>", e), hex::encode),
297 remote_address,
298 );
299 return Err(SignalProtocolError::UntrustedIdentity(
300 remote_address.clone(),
301 ));
302 }
303
304 identity_store
305 .save_identity(remote_address, &their_identity_key)
306 .await?;
307
308 session_store
309 .store_session(remote_address, &session_record)
310 .await?;
311
312 Ok(ptext)
313}
314
315fn create_decryption_failure_log(
316 remote_address: &ProtocolAddress,
317 mut errs: &[SignalProtocolError],
318 record: &SessionRecord,
319 ciphertext: &SignalMessage,
320) -> Result<String> {
321 fn append_session_summary(
322 lines: &mut Vec<String>,
323 idx: usize,
324 state: std::result::Result<&SessionState, InvalidSessionError>,
325 err: Option<&SignalProtocolError>,
326 ) {
327 let chains = state.map(|state| state.all_receiver_chain_logging_info());
328 match (err, &chains) {
329 (Some(err), Ok(chains)) => {
330 lines.push(format!(
331 "Candidate session {} failed with '{}', had {} receiver chains",
332 idx,
333 err,
334 chains.len()
335 ));
336 }
337 (Some(err), Err(state_err)) => {
338 lines.push(format!(
339 "Candidate session {} failed with '{}'; cannot get receiver chain info ({})",
340 idx, err, state_err,
341 ));
342 }
343 (None, Ok(chains)) => {
344 lines.push(format!(
345 "Candidate session {} had {} receiver chains",
346 idx,
347 chains.len()
348 ));
349 }
350 (None, Err(state_err)) => {
351 lines.push(format!(
352 "Candidate session {}: cannot get receiver chain info ({})",
353 idx, state_err,
354 ));
355 }
356 }
357
358 if let Ok(chains) = chains {
359 for chain in chains {
360 let chain_idx = match chain.1 {
361 Some(i) => i.to_string(),
362 None => "missing in protobuf".to_string(),
363 };
364
365 lines.push(format!(
366 "Receiver chain with sender ratchet public key {} chain key index {}",
367 hex::encode(chain.0),
368 chain_idx
369 ));
370 }
371 }
372 }
373
374 let mut lines = vec![];
375
376 lines.push(format!(
377 "Message from {} failed to decrypt; sender ratchet public key {} message counter {}",
378 remote_address,
379 hex::encode(ciphertext.sender_ratchet_key().public_key_bytes()?),
380 ciphertext.counter()
381 ));
382
383 if let Some(current_session) = record.session_state() {
384 let err = errs.first();
385 if err.is_some() {
386 errs = &errs[1..];
387 }
388 append_session_summary(&mut lines, 0, Ok(current_session), err);
389 } else {
390 lines.push("No current session".to_string());
391 }
392
393 for (idx, (state, err)) in record
394 .previous_session_states()
395 .zip(errs.iter().map(Some).chain(std::iter::repeat(None)))
396 .enumerate()
397 {
398 let state = match state {
399 Ok(ref state) => Ok(state),
400 Err(err) => Err(err),
401 };
402 append_session_summary(&mut lines, idx + 1, state, err);
403 }
404
405 Ok(lines.join("\n"))
406}
407
408fn decrypt_message_with_record<R: Rng + CryptoRng>(
409 remote_address: &ProtocolAddress,
410 record: &mut SessionRecord,
411 ciphertext: &SignalMessage,
412 original_message_type: CiphertextMessageType,
413 csprng: &mut R,
414) -> Result<Vec<u8>> {
415 debug_assert!(matches!(
416 original_message_type,
417 CiphertextMessageType::Whisper | CiphertextMessageType::PreKey
418 ));
419 let log_decryption_failure = |state: &SessionState, error: &SignalProtocolError| {
420 log::warn!(
422 "Failed to decrypt {:?} message with ratchet key: {} and counter: {}. \
423 Session loaded for {}. Local session has base key: {} and counter: {}. {}",
424 original_message_type,
425 ciphertext
426 .sender_ratchet_key()
427 .public_key_bytes()
428 .map_or_else(|e| format!("<error: {}>", e), hex::encode),
429 ciphertext.counter(),
430 remote_address,
431 state
432 .sender_ratchet_key_for_logging()
433 .unwrap_or_else(|e| format!("<error: {}>", e)),
434 state.previous_counter(),
435 error
436 );
437 };
438
439 let mut errs = vec![];
440
441 if let Some(current_state) = record.session_state() {
442 let mut current_state = current_state.clone();
443 let result = decrypt_message_with_state(
444 CurrentOrPrevious::Current,
445 &mut current_state,
446 ciphertext,
447 original_message_type,
448 remote_address,
449 csprng,
450 );
451
452 match result {
453 Ok(ptext) => {
454 log::info!(
455 "decrypted {:?} message from {} with current session state (base key {})",
456 original_message_type,
457 remote_address,
458 current_state
459 .sender_ratchet_key_for_logging()
460 .expect("successful decrypt always has a valid base key"),
461 );
462 record.set_session_state(current_state); return Ok(ptext);
464 }
465 Err(SignalProtocolError::DuplicatedMessage(_, _)) => {
466 return result;
467 }
468 Err(e) => {
469 log_decryption_failure(¤t_state, &e);
470 errs.push(e);
471 }
472 }
473 }
474
475 let mut updated_session = None;
477
478 for (idx, previous) in record.previous_session_states().enumerate() {
479 let mut previous = previous?;
480
481 let result = decrypt_message_with_state(
482 CurrentOrPrevious::Previous,
483 &mut previous,
484 ciphertext,
485 original_message_type,
486 remote_address,
487 csprng,
488 );
489
490 match result {
491 Ok(ptext) => {
492 log::info!(
493 "decrypted {:?} message from {} with PREVIOUS session state (base key {})",
494 original_message_type,
495 remote_address,
496 previous
497 .sender_ratchet_key_for_logging()
498 .expect("successful decrypt always has a valid base key"),
499 );
500 updated_session = Some((ptext, idx, previous));
501 break;
502 }
503 Err(SignalProtocolError::DuplicatedMessage(_, _)) => {
504 return result;
505 }
506 Err(e) => {
507 log_decryption_failure(&previous, &e);
508 errs.push(e);
509 }
510 }
511 }
512
513 if let Some((ptext, idx, updated_session)) = updated_session {
514 record.promote_old_session(idx, updated_session);
515 Ok(ptext)
516 } else {
517 let previous_state_count = || record.previous_session_states().len();
518
519 if let Some(current_state) = record.session_state() {
520 log::error!(
521 "No valid session for recipient: {}, current session base key {}, number of previous states: {}",
522 remote_address,
523 current_state.sender_ratchet_key_for_logging()
524 .unwrap_or_else(|e| format!("<error: {}>", e)),
525 previous_state_count(),
526 );
527 } else {
528 log::error!(
529 "No valid session for recipient: {}, (no current session state), number of previous states: {}",
530 remote_address,
531 previous_state_count(),
532 );
533 }
534 log::error!(
535 "{}",
536 create_decryption_failure_log(remote_address, &errs, record, ciphertext)?
537 );
538 Err(SignalProtocolError::InvalidMessage(
539 original_message_type,
540 "decryption failed",
541 ))
542 }
543}
544
/// Which kind of session state a decryption attempt used; shown in log messages.
#[derive(Copy, Clone)]
enum CurrentOrPrevious {
    /// The record's current session state.
    Current,
    /// One of the record's previous session states.
    Previous,
}

impl std::fmt::Display for CurrentOrPrevious {
    /// Writes the lowercase label `current` or `previous`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            CurrentOrPrevious::Current => "current",
            CurrentOrPrevious::Previous => "previous",
        };
        f.write_str(label)
    }
}
559
560fn decrypt_message_with_state<R: Rng + CryptoRng>(
561 current_or_previous: CurrentOrPrevious,
562 state: &mut SessionState,
563 ciphertext: &SignalMessage,
564 original_message_type: CiphertextMessageType,
565 remote_address: &ProtocolAddress,
566 csprng: &mut R,
567) -> Result<Vec<u8>> {
568 let _ = state.root_key().map_err(|_| {
570 SignalProtocolError::InvalidMessage(
571 original_message_type,
572 "No session available to decrypt",
573 )
574 })?;
575
576 let ciphertext_version = ciphertext.message_version() as u32;
577 if ciphertext_version != state.session_version()? {
578 return Err(SignalProtocolError::UnrecognizedMessageVersion(
579 ciphertext_version,
580 ));
581 }
582
583 let their_ephemeral = ciphertext.sender_ratchet_key();
584 let counter = ciphertext.counter();
585 let chain_key = get_or_create_chain_key(state, their_ephemeral, remote_address, csprng)?;
586 let message_keys = get_or_create_message_key(
587 state,
588 their_ephemeral,
589 remote_address,
590 original_message_type,
591 &chain_key,
592 counter,
593 )?;
594
595 let their_identity_key =
596 state
597 .remote_identity_key()?
598 .ok_or(SignalProtocolError::InvalidSessionStructure(
599 "cannot decrypt without remote identity key",
600 ))?;
601
602 let mac_valid = ciphertext.verify_mac(
603 &their_identity_key,
604 &state.local_identity_key()?,
605 message_keys.mac_key(),
606 )?;
607
608 if !mac_valid {
609 return Err(SignalProtocolError::InvalidMessage(
610 original_message_type,
611 "MAC verification failed",
612 ));
613 }
614
615 let ptext = match signal_crypto::aes_256_cbc_decrypt(
616 ciphertext.body(),
617 message_keys.cipher_key(),
618 message_keys.iv(),
619 ) {
620 Ok(ptext) => ptext,
621 Err(signal_crypto::DecryptionError::BadKeyOrIv) => {
622 log::warn!(
623 "{} session state corrupt for {}",
624 current_or_previous,
625 remote_address,
626 );
627 return Err(SignalProtocolError::InvalidSessionStructure(
628 "invalid receiver chain message keys",
629 ));
630 }
631 Err(signal_crypto::DecryptionError::BadCiphertext(msg)) => {
632 log::warn!("failed to decrypt 1:1 message: {}", msg);
633 return Err(SignalProtocolError::InvalidMessage(
634 original_message_type,
635 "failed to decrypt",
636 ));
637 }
638 };
639
640 state.clear_unacknowledged_pre_key_message();
641
642 Ok(ptext)
643}
644
645fn get_or_create_chain_key<R: Rng + CryptoRng>(
646 state: &mut SessionState,
647 their_ephemeral: &PublicKey,
648 remote_address: &ProtocolAddress,
649 csprng: &mut R,
650) -> Result<ChainKey> {
651 if let Some(chain) = state.get_receiver_chain_key(their_ephemeral)? {
652 log::debug!("{} has existing receiver chain.", remote_address);
653 return Ok(chain);
654 }
655
656 log::info!("{} creating new chains.", remote_address);
657
658 let root_key = state.root_key()?;
659 let our_ephemeral = state.sender_ratchet_private_key()?;
660 let receiver_chain = root_key.create_chain(their_ephemeral, &our_ephemeral)?;
661 let our_new_ephemeral = KeyPair::generate(csprng);
662 let sender_chain = receiver_chain
663 .0
664 .create_chain(their_ephemeral, &our_new_ephemeral.private_key)?;
665
666 state.set_root_key(&sender_chain.0);
667 state.add_receiver_chain(their_ephemeral, &receiver_chain.1);
668
669 let current_index = state.get_sender_chain_key()?.index();
670 let previous_index = if current_index > 0 {
671 current_index - 1
672 } else {
673 0
674 };
675 state.set_previous_counter(previous_index);
676 state.set_sender_chain(&our_new_ephemeral, &sender_chain.1);
677
678 Ok(receiver_chain.1)
679}
680
681fn get_or_create_message_key(
682 state: &mut SessionState,
683 their_ephemeral: &PublicKey,
684 remote_address: &ProtocolAddress,
685 original_message_type: CiphertextMessageType,
686 chain_key: &ChainKey,
687 counter: u32,
688) -> Result<MessageKeys> {
689 let chain_index = chain_key.index();
690
691 if chain_index > counter {
692 return match state.get_message_keys(their_ephemeral, counter)? {
693 Some(keys) => Ok(keys),
694 None => {
695 log::info!(
696 "{} Duplicate message for counter: {}",
697 remote_address,
698 counter
699 );
700 Err(SignalProtocolError::DuplicatedMessage(chain_index, counter))
701 }
702 };
703 }
704
705 assert!(chain_index <= counter);
706
707 let jump = (counter - chain_index) as usize;
708
709 if jump > MAX_FORWARD_JUMPS {
710 if state.session_with_self()? {
711 log::info!(
712 "{} Jumping ahead {} messages (index: {}, counter: {})",
713 remote_address,
714 jump,
715 chain_index,
716 counter
717 );
718 } else {
719 log::error!(
720 "{} Exceeded future message limit: {}, index: {}, counter: {})",
721 remote_address,
722 MAX_FORWARD_JUMPS,
723 chain_index,
724 counter
725 );
726 return Err(SignalProtocolError::InvalidMessage(
727 original_message_type,
728 "message from too far into the future",
729 ));
730 }
731 }
732
733 let mut chain_key = chain_key.clone();
734
735 while chain_key.index() < counter {
736 let message_keys = chain_key.message_keys();
737 state.set_message_keys(their_ephemeral, &message_keys)?;
738 chain_key = chain_key.next_chain_key();
739 }
740
741 state.set_receiver_chain_key(their_ephemeral, &chain_key.next_chain_key())?;
742 Ok(chain_key.message_keys())
743}