libsignal_protocol/
session_cipher.rs

//
// Copyright 2020-2022 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//
5
6use std::time::SystemTime;
7
8use rand::{CryptoRng, Rng};
9
10use crate::consts::{MAX_FORWARD_JUMPS, MAX_UNACKNOWLEDGED_SESSION_AGE};
11use crate::ratchet::{ChainKey, MessageKeys};
12use crate::state::{InvalidSessionError, SessionState};
13use crate::{
14    session, CiphertextMessage, CiphertextMessageType, Direction, IdentityKeyStore, KeyPair,
15    KyberPayload, KyberPreKeyStore, PreKeySignalMessage, PreKeyStore, ProtocolAddress, PublicKey,
16    Result, SessionRecord, SessionStore, SignalMessage, SignalProtocolError, SignedPreKeyStore,
17};
18
19pub async fn message_encrypt(
20    ptext: &[u8],
21    remote_address: &ProtocolAddress,
22    session_store: &mut dyn SessionStore,
23    identity_store: &mut dyn IdentityKeyStore,
24    now: SystemTime,
25) -> Result<CiphertextMessage> {
26    let mut session_record = session_store
27        .load_session(remote_address)
28        .await?
29        .ok_or_else(|| SignalProtocolError::SessionNotFound(remote_address.clone()))?;
30    let session_state = session_record
31        .session_state_mut()
32        .ok_or_else(|| SignalProtocolError::SessionNotFound(remote_address.clone()))?;
33
34    let chain_key = session_state.get_sender_chain_key()?;
35
36    let message_keys = chain_key.message_keys();
37
38    let sender_ephemeral = session_state.sender_ratchet_key()?;
39    let previous_counter = session_state.previous_counter();
40    let session_version = session_state.session_version()? as u8;
41
42    let local_identity_key = session_state.local_identity_key()?;
43    let their_identity_key = session_state.remote_identity_key()?.ok_or_else(|| {
44        SignalProtocolError::InvalidState(
45            "message_encrypt",
46            format!("no remote identity key for {}", remote_address),
47        )
48    })?;
49
50    let ctext =
51        signal_crypto::aes_256_cbc_encrypt(ptext, message_keys.cipher_key(), message_keys.iv())
52            .map_err(|_| {
53                log::error!("session state corrupt for {}", remote_address);
54                SignalProtocolError::InvalidSessionStructure("invalid sender chain message keys")
55            })?;
56
57    let message = if let Some(items) = session_state.unacknowledged_pre_key_message_items()? {
58        let timestamp_as_unix_time = items
59            .timestamp()
60            .duration_since(SystemTime::UNIX_EPOCH)
61            .unwrap_or_default()
62            .as_secs();
63        if items.timestamp() + MAX_UNACKNOWLEDGED_SESSION_AGE < now {
64            log::warn!(
65                "stale unacknowledged session for {} (created at {})",
66                remote_address,
67                timestamp_as_unix_time
68            );
69            return Err(SignalProtocolError::SessionNotFound(remote_address.clone()));
70        }
71
72        let local_registration_id = session_state.local_registration_id();
73
74        log::info!(
75            "Building PreKeyWhisperMessage for: {} with preKeyId: {} (session created at {})",
76            remote_address,
77            items
78                .pre_key_id()
79                .map_or_else(|| "<none>".to_string(), |id| id.to_string()),
80            timestamp_as_unix_time,
81        );
82
83        let message = SignalMessage::new(
84            session_version,
85            message_keys.mac_key(),
86            sender_ephemeral,
87            chain_key.index(),
88            previous_counter,
89            &ctext,
90            &local_identity_key,
91            &their_identity_key,
92        )?;
93
94        let kyber_payload = items
95            .kyber_pre_key_id()
96            .zip(items.kyber_ciphertext())
97            .map(|(id, ciphertext)| KyberPayload::new(id, ciphertext.into()));
98
99        CiphertextMessage::PreKeySignalMessage(PreKeySignalMessage::new(
100            session_version,
101            local_registration_id,
102            items.pre_key_id(),
103            items.signed_pre_key_id(),
104            kyber_payload,
105            *items.base_key(),
106            local_identity_key,
107            message,
108        )?)
109    } else {
110        CiphertextMessage::SignalMessage(SignalMessage::new(
111            session_version,
112            message_keys.mac_key(),
113            sender_ephemeral,
114            chain_key.index(),
115            previous_counter,
116            &ctext,
117            &local_identity_key,
118            &their_identity_key,
119        )?)
120    };
121
122    session_state.set_sender_chain_key(&chain_key.next_chain_key());
123
124    // XXX why is this check after everything else?!!
125    if !identity_store
126        .is_trusted_identity(remote_address, &their_identity_key, Direction::Sending)
127        .await?
128    {
129        log::warn!(
130            "Identity key {} is not trusted for remote address {}",
131            their_identity_key
132                .public_key()
133                .public_key_bytes()
134                .map_or_else(|e| format!("<error: {}>", e), hex::encode),
135            remote_address,
136        );
137        return Err(SignalProtocolError::UntrustedIdentity(
138            remote_address.clone(),
139        ));
140    }
141
142    // XXX this could be combined with the above call to the identity store (in a new API)
143    identity_store
144        .save_identity(remote_address, &their_identity_key)
145        .await?;
146
147    session_store
148        .store_session(remote_address, &session_record)
149        .await?;
150    Ok(message)
151}
152
153#[allow(clippy::too_many_arguments)]
154pub async fn message_decrypt<R: Rng + CryptoRng>(
155    ciphertext: &CiphertextMessage,
156    remote_address: &ProtocolAddress,
157    session_store: &mut dyn SessionStore,
158    identity_store: &mut dyn IdentityKeyStore,
159    pre_key_store: &mut dyn PreKeyStore,
160    signed_pre_key_store: &dyn SignedPreKeyStore,
161    kyber_pre_key_store: &mut dyn KyberPreKeyStore,
162    csprng: &mut R,
163) -> Result<Vec<u8>> {
164    match ciphertext {
165        CiphertextMessage::SignalMessage(m) => {
166            message_decrypt_signal(m, remote_address, session_store, identity_store, csprng).await
167        }
168        CiphertextMessage::PreKeySignalMessage(m) => {
169            message_decrypt_prekey(
170                m,
171                remote_address,
172                session_store,
173                identity_store,
174                pre_key_store,
175                signed_pre_key_store,
176                kyber_pre_key_store,
177                csprng,
178            )
179            .await
180        }
181        _ => Err(SignalProtocolError::InvalidArgument(format!(
182            "message_decrypt cannot be used to decrypt {:?} messages",
183            ciphertext.message_type()
184        ))),
185    }
186}
187
188#[allow(clippy::too_many_arguments)]
189pub async fn message_decrypt_prekey<R: Rng + CryptoRng>(
190    ciphertext: &PreKeySignalMessage,
191    remote_address: &ProtocolAddress,
192    session_store: &mut dyn SessionStore,
193    identity_store: &mut dyn IdentityKeyStore,
194    pre_key_store: &mut dyn PreKeyStore,
195    signed_pre_key_store: &dyn SignedPreKeyStore,
196    kyber_pre_key_store: &mut dyn KyberPreKeyStore,
197    csprng: &mut R,
198) -> Result<Vec<u8>> {
199    let mut session_record = session_store
200        .load_session(remote_address)
201        .await?
202        .unwrap_or_else(SessionRecord::new_fresh);
203
204    // Make sure we log the session state if we fail to process the pre-key.
205    let pre_key_used_or_err = session::process_prekey(
206        ciphertext,
207        remote_address,
208        &mut session_record,
209        identity_store,
210        pre_key_store,
211        signed_pre_key_store,
212        kyber_pre_key_store,
213    )
214    .await;
215
216    let pre_key_used = match pre_key_used_or_err {
217        Ok(result) => result,
218        Err(e) => {
219            let errs = [e];
220            log::error!(
221                "{}",
222                create_decryption_failure_log(
223                    remote_address,
224                    &errs,
225                    &session_record,
226                    ciphertext.message()
227                )?
228            );
229            let [e] = errs;
230            return Err(e);
231        }
232    };
233
234    let ptext = decrypt_message_with_record(
235        remote_address,
236        &mut session_record,
237        ciphertext.message(),
238        CiphertextMessageType::PreKey,
239        csprng,
240    )?;
241
242    session_store
243        .store_session(remote_address, &session_record)
244        .await?;
245
246    if let Some(pre_key_id) = pre_key_used.pre_key_id {
247        pre_key_store.remove_pre_key(pre_key_id).await?;
248    }
249
250    if let Some(kyber_pre_key_id) = pre_key_used.kyber_pre_key_id {
251        kyber_pre_key_store
252            .mark_kyber_pre_key_used(kyber_pre_key_id)
253            .await?;
254    }
255
256    Ok(ptext)
257}
258
259pub async fn message_decrypt_signal<R: Rng + CryptoRng>(
260    ciphertext: &SignalMessage,
261    remote_address: &ProtocolAddress,
262    session_store: &mut dyn SessionStore,
263    identity_store: &mut dyn IdentityKeyStore,
264    csprng: &mut R,
265) -> Result<Vec<u8>> {
266    let mut session_record = session_store
267        .load_session(remote_address)
268        .await?
269        .ok_or_else(|| SignalProtocolError::SessionNotFound(remote_address.clone()))?;
270
271    let ptext = decrypt_message_with_record(
272        remote_address,
273        &mut session_record,
274        ciphertext,
275        CiphertextMessageType::Whisper,
276        csprng,
277    )?;
278
279    // Why are we performing this check after decryption instead of before?
280    let their_identity_key = session_record
281        .session_state()
282        .expect("successfully decrypted; must have a current state")
283        .remote_identity_key()
284        .expect("successfully decrypted; must have a remote identity key")
285        .expect("successfully decrypted; must have a remote identity key");
286
287    if !identity_store
288        .is_trusted_identity(remote_address, &their_identity_key, Direction::Receiving)
289        .await?
290    {
291        log::warn!(
292            "Identity key {} is not trusted for remote address {}",
293            their_identity_key
294                .public_key()
295                .public_key_bytes()
296                .map_or_else(|e| format!("<error: {}>", e), hex::encode),
297            remote_address,
298        );
299        return Err(SignalProtocolError::UntrustedIdentity(
300            remote_address.clone(),
301        ));
302    }
303
304    identity_store
305        .save_identity(remote_address, &their_identity_key)
306        .await?;
307
308    session_store
309        .store_session(remote_address, &session_record)
310        .await?;
311
312    Ok(ptext)
313}
314
315fn create_decryption_failure_log(
316    remote_address: &ProtocolAddress,
317    mut errs: &[SignalProtocolError],
318    record: &SessionRecord,
319    ciphertext: &SignalMessage,
320) -> Result<String> {
321    fn append_session_summary(
322        lines: &mut Vec<String>,
323        idx: usize,
324        state: std::result::Result<&SessionState, InvalidSessionError>,
325        err: Option<&SignalProtocolError>,
326    ) {
327        let chains = state.map(|state| state.all_receiver_chain_logging_info());
328        match (err, &chains) {
329            (Some(err), Ok(chains)) => {
330                lines.push(format!(
331                    "Candidate session {} failed with '{}', had {} receiver chains",
332                    idx,
333                    err,
334                    chains.len()
335                ));
336            }
337            (Some(err), Err(state_err)) => {
338                lines.push(format!(
339                    "Candidate session {} failed with '{}'; cannot get receiver chain info ({})",
340                    idx, err, state_err,
341                ));
342            }
343            (None, Ok(chains)) => {
344                lines.push(format!(
345                    "Candidate session {} had {} receiver chains",
346                    idx,
347                    chains.len()
348                ));
349            }
350            (None, Err(state_err)) => {
351                lines.push(format!(
352                    "Candidate session {}: cannot get receiver chain info ({})",
353                    idx, state_err,
354                ));
355            }
356        }
357
358        if let Ok(chains) = chains {
359            for chain in chains {
360                let chain_idx = match chain.1 {
361                    Some(i) => i.to_string(),
362                    None => "missing in protobuf".to_string(),
363                };
364
365                lines.push(format!(
366                    "Receiver chain with sender ratchet public key {} chain key index {}",
367                    hex::encode(chain.0),
368                    chain_idx
369                ));
370            }
371        }
372    }
373
374    let mut lines = vec![];
375
376    lines.push(format!(
377        "Message from {} failed to decrypt; sender ratchet public key {} message counter {}",
378        remote_address,
379        hex::encode(ciphertext.sender_ratchet_key().public_key_bytes()?),
380        ciphertext.counter()
381    ));
382
383    if let Some(current_session) = record.session_state() {
384        let err = errs.first();
385        if err.is_some() {
386            errs = &errs[1..];
387        }
388        append_session_summary(&mut lines, 0, Ok(current_session), err);
389    } else {
390        lines.push("No current session".to_string());
391    }
392
393    for (idx, (state, err)) in record
394        .previous_session_states()
395        .zip(errs.iter().map(Some).chain(std::iter::repeat(None)))
396        .enumerate()
397    {
398        let state = match state {
399            Ok(ref state) => Ok(state),
400            Err(err) => Err(err),
401        };
402        append_session_summary(&mut lines, idx + 1, state, err);
403    }
404
405    Ok(lines.join("\n"))
406}
407
408fn decrypt_message_with_record<R: Rng + CryptoRng>(
409    remote_address: &ProtocolAddress,
410    record: &mut SessionRecord,
411    ciphertext: &SignalMessage,
412    original_message_type: CiphertextMessageType,
413    csprng: &mut R,
414) -> Result<Vec<u8>> {
415    debug_assert!(matches!(
416        original_message_type,
417        CiphertextMessageType::Whisper | CiphertextMessageType::PreKey
418    ));
419    let log_decryption_failure = |state: &SessionState, error: &SignalProtocolError| {
420        // A warning rather than an error because we try multiple sessions.
421        log::warn!(
422            "Failed to decrypt {:?} message with ratchet key: {} and counter: {}. \
423             Session loaded for {}. Local session has base key: {} and counter: {}. {}",
424            original_message_type,
425            ciphertext
426                .sender_ratchet_key()
427                .public_key_bytes()
428                .map_or_else(|e| format!("<error: {}>", e), hex::encode),
429            ciphertext.counter(),
430            remote_address,
431            state
432                .sender_ratchet_key_for_logging()
433                .unwrap_or_else(|e| format!("<error: {}>", e)),
434            state.previous_counter(),
435            error
436        );
437    };
438
439    let mut errs = vec![];
440
441    if let Some(current_state) = record.session_state() {
442        let mut current_state = current_state.clone();
443        let result = decrypt_message_with_state(
444            CurrentOrPrevious::Current,
445            &mut current_state,
446            ciphertext,
447            original_message_type,
448            remote_address,
449            csprng,
450        );
451
452        match result {
453            Ok(ptext) => {
454                log::info!(
455                    "decrypted {:?} message from {} with current session state (base key {})",
456                    original_message_type,
457                    remote_address,
458                    current_state
459                        .sender_ratchet_key_for_logging()
460                        .expect("successful decrypt always has a valid base key"),
461                );
462                record.set_session_state(current_state); // update the state
463                return Ok(ptext);
464            }
465            Err(SignalProtocolError::DuplicatedMessage(_, _)) => {
466                return result;
467            }
468            Err(e) => {
469                log_decryption_failure(&current_state, &e);
470                errs.push(e);
471            }
472        }
473    }
474
475    // Try some old sessions:
476    let mut updated_session = None;
477
478    for (idx, previous) in record.previous_session_states().enumerate() {
479        let mut previous = previous?;
480
481        let result = decrypt_message_with_state(
482            CurrentOrPrevious::Previous,
483            &mut previous,
484            ciphertext,
485            original_message_type,
486            remote_address,
487            csprng,
488        );
489
490        match result {
491            Ok(ptext) => {
492                log::info!(
493                    "decrypted {:?} message from {} with PREVIOUS session state (base key {})",
494                    original_message_type,
495                    remote_address,
496                    previous
497                        .sender_ratchet_key_for_logging()
498                        .expect("successful decrypt always has a valid base key"),
499                );
500                updated_session = Some((ptext, idx, previous));
501                break;
502            }
503            Err(SignalProtocolError::DuplicatedMessage(_, _)) => {
504                return result;
505            }
506            Err(e) => {
507                log_decryption_failure(&previous, &e);
508                errs.push(e);
509            }
510        }
511    }
512
513    if let Some((ptext, idx, updated_session)) = updated_session {
514        record.promote_old_session(idx, updated_session);
515        Ok(ptext)
516    } else {
517        let previous_state_count = || record.previous_session_states().len();
518
519        if let Some(current_state) = record.session_state() {
520            log::error!(
521                "No valid session for recipient: {}, current session base key {}, number of previous states: {}",
522                remote_address,
523                current_state.sender_ratchet_key_for_logging()
524                .unwrap_or_else(|e| format!("<error: {}>", e)),
525                previous_state_count(),
526            );
527        } else {
528            log::error!(
529                "No valid session for recipient: {}, (no current session state), number of previous states: {}",
530                remote_address,
531                previous_state_count(),
532            );
533        }
534        log::error!(
535            "{}",
536            create_decryption_failure_log(remote_address, &errs, record, ciphertext)?
537        );
538        Err(SignalProtocolError::InvalidMessage(
539            original_message_type,
540            "decryption failed",
541        ))
542    }
543}
544
/// Labels which of a record's session states a decryption attempt is using; it is only used
/// to make log messages more specific.
#[derive(Clone, Copy)]
enum CurrentOrPrevious {
    /// The record's current session state.
    Current,
    /// One of the record's previous session states.
    Previous,
}

impl std::fmt::Display for CurrentOrPrevious {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            CurrentOrPrevious::Current => "current",
            CurrentOrPrevious::Previous => "previous",
        };
        f.write_str(label)
    }
}
559
560fn decrypt_message_with_state<R: Rng + CryptoRng>(
561    current_or_previous: CurrentOrPrevious,
562    state: &mut SessionState,
563    ciphertext: &SignalMessage,
564    original_message_type: CiphertextMessageType,
565    remote_address: &ProtocolAddress,
566    csprng: &mut R,
567) -> Result<Vec<u8>> {
568    // Check for a completely empty or invalid session state before we do anything else.
569    let _ = state.root_key().map_err(|_| {
570        SignalProtocolError::InvalidMessage(
571            original_message_type,
572            "No session available to decrypt",
573        )
574    })?;
575
576    let ciphertext_version = ciphertext.message_version() as u32;
577    if ciphertext_version != state.session_version()? {
578        return Err(SignalProtocolError::UnrecognizedMessageVersion(
579            ciphertext_version,
580        ));
581    }
582
583    let their_ephemeral = ciphertext.sender_ratchet_key();
584    let counter = ciphertext.counter();
585    let chain_key = get_or_create_chain_key(state, their_ephemeral, remote_address, csprng)?;
586    let message_keys = get_or_create_message_key(
587        state,
588        their_ephemeral,
589        remote_address,
590        original_message_type,
591        &chain_key,
592        counter,
593    )?;
594
595    let their_identity_key =
596        state
597            .remote_identity_key()?
598            .ok_or(SignalProtocolError::InvalidSessionStructure(
599                "cannot decrypt without remote identity key",
600            ))?;
601
602    let mac_valid = ciphertext.verify_mac(
603        &their_identity_key,
604        &state.local_identity_key()?,
605        message_keys.mac_key(),
606    )?;
607
608    if !mac_valid {
609        return Err(SignalProtocolError::InvalidMessage(
610            original_message_type,
611            "MAC verification failed",
612        ));
613    }
614
615    let ptext = match signal_crypto::aes_256_cbc_decrypt(
616        ciphertext.body(),
617        message_keys.cipher_key(),
618        message_keys.iv(),
619    ) {
620        Ok(ptext) => ptext,
621        Err(signal_crypto::DecryptionError::BadKeyOrIv) => {
622            log::warn!(
623                "{} session state corrupt for {}",
624                current_or_previous,
625                remote_address,
626            );
627            return Err(SignalProtocolError::InvalidSessionStructure(
628                "invalid receiver chain message keys",
629            ));
630        }
631        Err(signal_crypto::DecryptionError::BadCiphertext(msg)) => {
632            log::warn!("failed to decrypt 1:1 message: {}", msg);
633            return Err(SignalProtocolError::InvalidMessage(
634                original_message_type,
635                "failed to decrypt",
636            ));
637        }
638    };
639
640    state.clear_unacknowledged_pre_key_message();
641
642    Ok(ptext)
643}
644
645fn get_or_create_chain_key<R: Rng + CryptoRng>(
646    state: &mut SessionState,
647    their_ephemeral: &PublicKey,
648    remote_address: &ProtocolAddress,
649    csprng: &mut R,
650) -> Result<ChainKey> {
651    if let Some(chain) = state.get_receiver_chain_key(their_ephemeral)? {
652        log::debug!("{} has existing receiver chain.", remote_address);
653        return Ok(chain);
654    }
655
656    log::info!("{} creating new chains.", remote_address);
657
658    let root_key = state.root_key()?;
659    let our_ephemeral = state.sender_ratchet_private_key()?;
660    let receiver_chain = root_key.create_chain(their_ephemeral, &our_ephemeral)?;
661    let our_new_ephemeral = KeyPair::generate(csprng);
662    let sender_chain = receiver_chain
663        .0
664        .create_chain(their_ephemeral, &our_new_ephemeral.private_key)?;
665
666    state.set_root_key(&sender_chain.0);
667    state.add_receiver_chain(their_ephemeral, &receiver_chain.1);
668
669    let current_index = state.get_sender_chain_key()?.index();
670    let previous_index = if current_index > 0 {
671        current_index - 1
672    } else {
673        0
674    };
675    state.set_previous_counter(previous_index);
676    state.set_sender_chain(&our_new_ephemeral, &sender_chain.1);
677
678    Ok(receiver_chain.1)
679}
680
681fn get_or_create_message_key(
682    state: &mut SessionState,
683    their_ephemeral: &PublicKey,
684    remote_address: &ProtocolAddress,
685    original_message_type: CiphertextMessageType,
686    chain_key: &ChainKey,
687    counter: u32,
688) -> Result<MessageKeys> {
689    let chain_index = chain_key.index();
690
691    if chain_index > counter {
692        return match state.get_message_keys(their_ephemeral, counter)? {
693            Some(keys) => Ok(keys),
694            None => {
695                log::info!(
696                    "{} Duplicate message for counter: {}",
697                    remote_address,
698                    counter
699                );
700                Err(SignalProtocolError::DuplicatedMessage(chain_index, counter))
701            }
702        };
703    }
704
705    assert!(chain_index <= counter);
706
707    let jump = (counter - chain_index) as usize;
708
709    if jump > MAX_FORWARD_JUMPS {
710        if state.session_with_self()? {
711            log::info!(
712                "{} Jumping ahead {} messages (index: {}, counter: {})",
713                remote_address,
714                jump,
715                chain_index,
716                counter
717            );
718        } else {
719            log::error!(
720                "{} Exceeded future message limit: {}, index: {}, counter: {})",
721                remote_address,
722                MAX_FORWARD_JUMPS,
723                chain_index,
724                counter
725            );
726            return Err(SignalProtocolError::InvalidMessage(
727                original_message_type,
728                "message from too far into the future",
729            ));
730        }
731    }
732
733    let mut chain_key = chain_key.clone();
734
735    while chain_key.index() < counter {
736        let message_keys = chain_key.message_keys();
737        state.set_message_keys(their_ephemeral, &message_keys)?;
738        chain_key = chain_key.next_chain_key();
739    }
740
741    state.set_receiver_chain_key(their_ephemeral, &chain_key.next_chain_key())?;
742    Ok(chain_key.message_keys())
743}