Skip to main content

libsignal_protocol/state/
session.rs

1//
2// Copyright 2020-2022 Signal Messenger, LLC.
3// SPDX-License-Identifier: AGPL-3.0-only
4//
5
6use std::result::Result;
7use std::time::{Duration, SystemTime};
8
9use bitflags::bitflags;
10use prost::Message;
11#[cfg(test)]
12use rand::{CryptoRng, Rng};
13use subtle::ConstantTimeEq;
14
15use crate::proto::storage::{RecordStructure, SessionStructure, session_structure};
16use crate::protocol::CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION;
17#[cfg(test)]
18use crate::ratchet::MessageKeyGenerator;
19use crate::ratchet::{ChainKey, RootKey};
20use crate::state::{KyberPreKeyId, PreKeyId, SignedPreKeyId};
21use crate::{IdentityKey, KeyPair, PrivateKey, PublicKey, SignalProtocolError, consts, kem};
22
/// An error explaining why a stored session is invalid.
///
/// This is deliberately a separate type from `SignalProtocolError`, so that session
/// deserialization failures are never propagated by accident; converting one requires the
/// explicit `From` impl (e.g. via `?`).
#[derive(Debug)]
pub(crate) struct InvalidSessionError(pub(crate) &'static str);

impl std::fmt::Display for InvalidSessionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render exactly like the bare message, honoring any width/precision/fill flags.
        f.pad(self.0)
    }
}
32
33impl From<InvalidSessionError> for SignalProtocolError {
34    fn from(e: InvalidSessionError) -> Self {
35        Self::InvalidSessionStructure(e.0)
36    }
37}
38
39#[derive(Debug, Clone)]
40pub(crate) struct UnacknowledgedPreKeyMessageItems<'a> {
41    pre_key_id: Option<PreKeyId>,
42    signed_pre_key_id: SignedPreKeyId,
43    base_key: PublicKey,
44    // Although we require PQXDH for all new sessions now,
45    // we may in theory have an existing X3DH unacknowledged session,
46    // so we leave these optional for now.
47    kyber_pre_key_id: Option<KyberPreKeyId>,
48    kyber_ciphertext: Option<&'a [u8]>,
49    timestamp: SystemTime,
50}
51
52impl<'a> UnacknowledgedPreKeyMessageItems<'a> {
53    fn new(
54        pre_key_id: Option<PreKeyId>,
55        signed_pre_key_id: SignedPreKeyId,
56        base_key: PublicKey,
57        pending_kyber_pre_key: Option<&'a session_structure::PendingKyberPreKey>,
58        timestamp: SystemTime,
59    ) -> Self {
60        let (kyber_pre_key_id, kyber_ciphertext) = pending_kyber_pre_key
61            .map(|pending| (pending.pre_key_id.into(), pending.ciphertext.as_slice()))
62            .unzip();
63        Self {
64            pre_key_id,
65            signed_pre_key_id,
66            base_key,
67            kyber_pre_key_id,
68            kyber_ciphertext,
69            timestamp,
70        }
71    }
72
73    pub(crate) fn pre_key_id(&self) -> Option<PreKeyId> {
74        self.pre_key_id
75    }
76
77    pub(crate) fn signed_pre_key_id(&self) -> SignedPreKeyId {
78        self.signed_pre_key_id
79    }
80
81    pub(crate) fn base_key(&self) -> &PublicKey {
82        &self.base_key
83    }
84
85    pub(crate) fn kyber_pre_key_id(&self) -> Option<KyberPreKeyId> {
86        self.kyber_pre_key_id
87    }
88
89    pub(crate) fn kyber_ciphertext(&self) -> Option<&'a [u8]> {
90        self.kyber_ciphertext
91    }
92
93    pub(crate) fn timestamp(&self) -> SystemTime {
94        self.timestamp
95    }
96}
97
bitflags! {
    /// Specifies which criteria make a session "usable" beyond simply having a present sender
    /// chain.
    ///
    /// These requirements are conjunctive, i.e. specifying `NotStale | EstablishedWithPqxdh` means
    /// the session must be neither stale nor established with X3DH.
    ///
    /// This struct is generated using the `bitflags` crate; the [`Flags`](::bitflags::Flags) trait
    /// provides most of its API surface. It can also use "classic" C bitflag syntax, with `|` for
    /// union and `&` for intersection (the `|` form is shown above).
    ///
    /// The requirements are evaluated by `SessionState::has_usable_sender_chain`.
    #[repr(transparent)]
    #[derive(Clone, Copy, PartialEq, Eq)]
    pub struct SessionUsabilityRequirements : u32 {
        /// Requires that a session not be stale.
        ///
        /// A non-stale session is one of the following:
        /// - "incoming", i.e. started by the peer
        /// - "acknowledged", i.e. started locally but received a response
        /// - no more than a few weeks old (the exact time is chosen by libsignal)
        const NotStale = 1 << 0;
        /// Requires that a session was established using PQXDH (or newer) rather than X3DH/X4DH.
        ///
        /// This includes unacknowledged sessions that are using PQXDH, since if they get a
        /// response, the peer is confirmed to be using PQXDH as well.
        const EstablishedWithPqxdh = 1 << 1;
        /// Requires that a session is using SPQR.
        ///
        /// **Warning:** This allows unacknowledged sessions that include SPQR in their PreKey
        /// messages. If the peer downgrades the session (by discarding the SPQR information) and
        /// the local client allows it, a session that was previously considered "usable" can become
        /// "not usable" upon receiving a response. Therefore, this should not be used to determine
        /// whether a session is usable unless future downgrades will also be rejected.
        const Spqr = 1 << 2;
    }
}
133
134#[derive(Clone, Debug)]
135pub(crate) struct SessionState {
136    session: SessionStructure,
137}
138
139impl SessionState {
140    pub(crate) fn from_session_structure(session: SessionStructure) -> Self {
141        Self { session }
142    }
143
144    pub(crate) fn new(
145        version: u8,
146        our_identity: &IdentityKey,
147        their_identity: &IdentityKey,
148        root_key: &RootKey,
149        alice_base_key: &PublicKey,
150        pq_ratchet_state: spqr::SerializedState,
151    ) -> Self {
152        Self {
153            session: SessionStructure {
154                session_version: version as u32,
155                local_identity_public: our_identity.public_key().serialize().into_vec(),
156                remote_identity_public: their_identity.serialize().into_vec(),
157                root_key: root_key.key().to_vec(),
158                previous_counter: 0,
159                sender_chain: None,
160                receiver_chains: vec![],
161                pending_pre_key: None,
162                pending_kyber_pre_key: None,
163                remote_registration_id: 0,
164                local_registration_id: 0,
165                alice_base_key: alice_base_key.serialize().into_vec(),
166                pq_ratchet_state,
167            },
168        }
169    }
170
171    pub(crate) fn alice_base_key(&self) -> &[u8] {
172        // Check the length before returning?
173        &self.session.alice_base_key
174    }
175
176    pub(crate) fn session_version(&self) -> Result<u32, InvalidSessionError> {
177        match self.session.session_version {
178            0 => Ok(2),
179            v => Ok(v),
180        }
181    }
182
183    pub(crate) fn remote_identity_key(&self) -> Result<Option<IdentityKey>, InvalidSessionError> {
184        match self.session.remote_identity_public.len() {
185            0 => Ok(None),
186            _ => Ok(Some(
187                IdentityKey::decode(&self.session.remote_identity_public)
188                    .map_err(|_| InvalidSessionError("invalid remote identity key"))?,
189            )),
190        }
191    }
192
193    pub(crate) fn remote_identity_key_bytes(&self) -> Result<Option<Vec<u8>>, InvalidSessionError> {
194        Ok(self.remote_identity_key()?.map(|k| k.serialize().to_vec()))
195    }
196
197    pub(crate) fn local_identity_key(&self) -> Result<IdentityKey, InvalidSessionError> {
198        IdentityKey::decode(&self.session.local_identity_public)
199            .map_err(|_| InvalidSessionError("invalid local identity key"))
200    }
201
202    pub(crate) fn local_identity_key_bytes(&self) -> Result<Vec<u8>, InvalidSessionError> {
203        Ok(self.local_identity_key()?.serialize().to_vec())
204    }
205
206    #[deprecated]
207    #[allow(unused)]
208    pub(crate) fn session_with_self(&self) -> Result<bool, InvalidSessionError> {
209        if let Some(remote_id) = self.remote_identity_key_bytes()? {
210            let local_id = self.local_identity_key_bytes()?;
211            return Ok(remote_id == local_id);
212        }
213
214        // If remote ID is not set then we can't be sure but treat as non-self
215        Ok(false)
216    }
217
218    pub(crate) fn previous_counter(&self) -> u32 {
219        self.session.previous_counter
220    }
221
222    #[cfg(test)]
223    pub(crate) fn set_previous_counter(&mut self, ctr: u32) {
224        self.session.previous_counter = ctr;
225    }
226
227    #[cfg(test)]
228    pub(crate) fn root_key(&self) -> Result<RootKey, InvalidSessionError> {
229        let root_key_bytes = self.session.root_key[..]
230            .try_into()
231            .map_err(|_| InvalidSessionError("invalid root key"))?;
232        Ok(RootKey::new(root_key_bytes))
233    }
234
235    #[cfg(test)]
236    pub(crate) fn set_root_key(&mut self, root_key: &RootKey) {
237        self.session.root_key = root_key.key().to_vec();
238    }
239
240    pub(crate) fn sender_ratchet_key(&self) -> Result<PublicKey, InvalidSessionError> {
241        match self.session.sender_chain {
242            None => Err(InvalidSessionError("missing sender chain")),
243            Some(ref c) => PublicKey::deserialize(&c.sender_ratchet_key)
244                .map_err(|_| InvalidSessionError("invalid sender chain ratchet key")),
245        }
246    }
247
248    pub(crate) fn sender_ratchet_key_for_logging(&self) -> Result<String, InvalidSessionError> {
249        Ok(hex::encode(self.sender_ratchet_key()?.public_key_bytes()))
250    }
251
252    pub(crate) fn sender_ratchet_private_key(&self) -> Result<PrivateKey, InvalidSessionError> {
253        match self.session.sender_chain {
254            None => Err(InvalidSessionError("missing sender chain")),
255            Some(ref c) => PrivateKey::deserialize(&c.sender_ratchet_key_private)
256                .map_err(|_| InvalidSessionError("invalid sender chain private ratchet key")),
257        }
258    }
259
260    pub fn has_usable_sender_chain(
261        &self,
262        now: SystemTime,
263        requirements: SessionUsabilityRequirements,
264    ) -> Result<bool, InvalidSessionError> {
265        if self.session.sender_chain.is_none() {
266            return Ok(false);
267        }
268        if requirements.contains(SessionUsabilityRequirements::NotStale) {
269            if let Some(pending_pre_key) = &self.session.pending_pre_key {
270                let creation_timestamp =
271                    SystemTime::UNIX_EPOCH + Duration::from_secs(pending_pre_key.timestamp);
272                if creation_timestamp + consts::MAX_UNACKNOWLEDGED_SESSION_AGE < now {
273                    return Ok(false);
274                }
275            }
276        }
277        #[allow(clippy::collapsible_if)]
278        if requirements.contains(SessionUsabilityRequirements::EstablishedWithPqxdh) {
279            if self.session_version()? <= CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION.into() {
280                return Ok(false);
281            }
282        }
283        #[allow(clippy::collapsible_if)]
284        if requirements.contains(SessionUsabilityRequirements::Spqr) {
285            if self.pq_ratchet_state().is_empty() {
286                return Ok(false);
287            }
288        }
289        Ok(true)
290    }
291
292    pub(crate) fn all_receiver_chain_logging_info(&self) -> Vec<(Vec<u8>, Option<u32>)> {
293        let mut results = vec![];
294        for chain in self.session.receiver_chains.iter() {
295            let sender_ratchet_public = chain.sender_ratchet_key.clone();
296
297            let chain_key_idx = chain.chain_key.as_ref().map(|chain_key| chain_key.index);
298
299            results.push((sender_ratchet_public, chain_key_idx))
300        }
301        results
302    }
303
304    pub(crate) fn get_receiver_chain(
305        &self,
306        sender: &PublicKey,
307    ) -> Result<Option<(session_structure::Chain, usize)>, InvalidSessionError> {
308        for (idx, chain) in self.session.receiver_chains.iter().enumerate() {
309            // If we compared bytes directly it would be faster, but may miss non-canonical points.
310            // It's unclear if supporting such points is desirable.
311            let chain_ratchet_key = PublicKey::deserialize(&chain.sender_ratchet_key)
312                .map_err(|_| InvalidSessionError("invalid receiver chain ratchet key"))?;
313
314            if &chain_ratchet_key == sender {
315                return Ok(Some((chain.clone(), idx)));
316            }
317        }
318
319        Ok(None)
320    }
321
322    pub(crate) fn get_receiver_chain_key(
323        &self,
324        sender: &PublicKey,
325    ) -> Result<Option<ChainKey>, InvalidSessionError> {
326        match self.get_receiver_chain(sender)? {
327            None => Ok(None),
328            Some((chain, _)) => match chain.chain_key {
329                None => Err(InvalidSessionError("missing receiver chain key")),
330                Some(c) => {
331                    let chain_key_bytes = c.key[..]
332                        .try_into()
333                        .map_err(|_| InvalidSessionError("invalid receiver chain key"))?;
334                    Ok(Some(ChainKey::new(chain_key_bytes, c.index)))
335                }
336            },
337        }
338    }
339
340    pub(crate) fn add_receiver_chain(&mut self, sender: &PublicKey, chain_key: &ChainKey) {
341        let chain_key = session_structure::chain::ChainKey {
342            index: chain_key.index(),
343            key: chain_key.key().to_vec(),
344        };
345
346        let chain = session_structure::Chain {
347            sender_ratchet_key: sender.serialize().to_vec(),
348            sender_ratchet_key_private: vec![],
349            chain_key: Some(chain_key),
350            message_keys: vec![],
351        };
352
353        self.session.receiver_chains.push(chain);
354
355        if self.session.receiver_chains.len() > consts::MAX_RECEIVER_CHAINS {
356            log::info!(
357                "Trimming excessive receiver_chain for session with base key {}, chain count: {}",
358                self.sender_ratchet_key_for_logging()
359                    .unwrap_or_else(|e| format!("<error: {}>", e.0)),
360                self.session.receiver_chains.len()
361            );
362            self.session.receiver_chains.remove(0);
363        }
364    }
365
366    pub(crate) fn with_receiver_chain(mut self, sender: &PublicKey, chain_key: &ChainKey) -> Self {
367        self.add_receiver_chain(sender, chain_key);
368        self
369    }
370
371    pub(crate) fn set_sender_chain(&mut self, sender: &KeyPair, next_chain_key: &ChainKey) {
372        let chain_key = session_structure::chain::ChainKey {
373            index: next_chain_key.index(),
374            key: next_chain_key.key().to_vec(),
375        };
376
377        let new_chain = session_structure::Chain {
378            sender_ratchet_key: sender.public_key.serialize().to_vec(),
379            sender_ratchet_key_private: sender.private_key.serialize().to_vec(),
380            chain_key: Some(chain_key),
381            message_keys: vec![],
382        };
383
384        self.session.sender_chain = Some(new_chain);
385    }
386
387    pub(crate) fn with_sender_chain(mut self, sender: &KeyPair, next_chain_key: &ChainKey) -> Self {
388        self.set_sender_chain(sender, next_chain_key);
389        self
390    }
391
392    pub(crate) fn get_sender_chain_key(&self) -> Result<ChainKey, InvalidSessionError> {
393        let sender_chain = self
394            .session
395            .sender_chain
396            .as_ref()
397            .ok_or(InvalidSessionError("missing sender chain"))?;
398
399        let chain_key = sender_chain
400            .chain_key
401            .as_ref()
402            .ok_or(InvalidSessionError("missing sender chain key"))?;
403
404        let chain_key_bytes = chain_key.key[..]
405            .try_into()
406            .map_err(|_| InvalidSessionError("invalid sender chain key"))?;
407
408        Ok(ChainKey::new(chain_key_bytes, chain_key.index))
409    }
410
411    pub(crate) fn get_sender_chain_key_bytes(&self) -> Result<Vec<u8>, InvalidSessionError> {
412        Ok(self.get_sender_chain_key()?.key().to_vec())
413    }
414
415    pub(crate) fn set_sender_chain_key(&mut self, next_chain_key: &ChainKey) {
416        let chain_key = session_structure::chain::ChainKey {
417            index: next_chain_key.index(),
418            key: next_chain_key.key().to_vec(),
419        };
420
421        // Is it actually valid to call this function with sender_chain == None?
422
423        let new_chain = match self.session.sender_chain.take() {
424            None => session_structure::Chain {
425                sender_ratchet_key: vec![],
426                sender_ratchet_key_private: vec![],
427                chain_key: Some(chain_key),
428                message_keys: vec![],
429            },
430            Some(mut c) => {
431                c.chain_key = Some(chain_key);
432                c
433            }
434        };
435
436        self.session.sender_chain = Some(new_chain);
437    }
438
439    #[cfg(test)]
440    pub(crate) fn get_message_keys(
441        &mut self,
442        sender: &PublicKey,
443        counter: u32,
444    ) -> Result<Option<MessageKeyGenerator>, InvalidSessionError> {
445        if let Some(mut chain_and_index) = self.get_receiver_chain(sender)? {
446            let message_key_idx = chain_and_index
447                .0
448                .message_keys
449                .iter()
450                .position(|m| m.index == counter);
451
452            if let Some(position) = message_key_idx {
453                let message_key = chain_and_index.0.message_keys.remove(position);
454                let keys =
455                    MessageKeyGenerator::from_pb(message_key).map_err(InvalidSessionError)?;
456
457                // Update with message key removed
458                self.session.receiver_chains[chain_and_index.1] = chain_and_index.0;
459                return Ok(Some(keys));
460            }
461        }
462
463        Ok(None)
464    }
465
466    #[cfg(test)]
467    pub(crate) fn set_message_keys(
468        &mut self,
469        sender: &PublicKey,
470        message_keys: MessageKeyGenerator,
471    ) -> Result<(), InvalidSessionError> {
472        let chain_and_index = self
473            .get_receiver_chain(sender)?
474            .expect("called set_message_keys for a non-existent chain");
475        let mut updated_chain = chain_and_index.0;
476        updated_chain.message_keys.insert(0, message_keys.into_pb());
477
478        if updated_chain.message_keys.len() > consts::MAX_MESSAGE_KEYS {
479            updated_chain.message_keys.pop();
480        }
481
482        self.session.receiver_chains[chain_and_index.1] = updated_chain;
483
484        Ok(())
485    }
486
487    #[cfg(test)]
488    pub(crate) fn set_receiver_chain_key(
489        &mut self,
490        sender: &PublicKey,
491        chain_key: &ChainKey,
492    ) -> Result<(), InvalidSessionError> {
493        let chain_and_index = self
494            .get_receiver_chain(sender)?
495            .expect("called set_receiver_chain_key for a non-existent chain");
496        let mut updated_chain = chain_and_index.0;
497        updated_chain.chain_key = Some(session_structure::chain::ChainKey {
498            index: chain_key.index(),
499            key: chain_key.key().to_vec(),
500        });
501
502        self.session.receiver_chains[chain_and_index.1] = updated_chain;
503
504        Ok(())
505    }
506
507    pub(crate) fn set_unacknowledged_pre_key_message(
508        &mut self,
509        pre_key_id: Option<PreKeyId>,
510        signed_ec_pre_key_id: SignedPreKeyId,
511        base_key: &PublicKey,
512        now: SystemTime,
513    ) {
514        let signed_ec_pre_key_id: u32 = signed_ec_pre_key_id.into();
515        let pending = session_structure::PendingPreKey {
516            pre_key_id: pre_key_id.map(PreKeyId::into),
517            signed_pre_key_id: signed_ec_pre_key_id as i32,
518            base_key: base_key.serialize().to_vec(),
519            timestamp: now
520                .duration_since(SystemTime::UNIX_EPOCH)
521                .unwrap_or_default()
522                .as_secs(),
523        };
524        self.session.pending_pre_key = Some(pending);
525    }
526
527    pub(crate) fn set_kyber_ciphertext(&mut self, ciphertext: kem::SerializedCiphertext) {
528        let pending = session_structure::PendingKyberPreKey {
529            pre_key_id: u32::MAX, // has to be set to the actual value separately
530            ciphertext: ciphertext.into_vec(),
531        };
532        self.session.pending_kyber_pre_key = Some(pending);
533    }
534
535    pub(crate) fn set_unacknowledged_kyber_pre_key_id(
536        &mut self,
537        signed_kyber_pre_key_id: KyberPreKeyId,
538    ) {
539        let pending = self
540            .session
541            .pending_kyber_pre_key
542            .as_mut()
543            .expect("must have been set if kyber pre key is present");
544        pending.pre_key_id = signed_kyber_pre_key_id.into();
545    }
546
547    pub(crate) fn unacknowledged_pre_key_message_items(
548        &self,
549    ) -> Result<Option<UnacknowledgedPreKeyMessageItems<'_>>, InvalidSessionError> {
550        if let Some(ref pending_pre_key) = self.session.pending_pre_key {
551            Ok(Some(UnacknowledgedPreKeyMessageItems::new(
552                pending_pre_key.pre_key_id.map(Into::into),
553                (pending_pre_key.signed_pre_key_id as u32).into(),
554                PublicKey::deserialize(&pending_pre_key.base_key)
555                    .map_err(|_| InvalidSessionError("invalid pending PreKey message base key"))?,
556                self.session.pending_kyber_pre_key.as_ref(),
557                SystemTime::UNIX_EPOCH + Duration::from_secs(pending_pre_key.timestamp),
558            )))
559        } else {
560            Ok(None)
561        }
562    }
563
564    pub(crate) fn clear_unacknowledged_pre_key_message(&mut self) {
565        // Explicitly destructuring the SessionStructure in case there are new
566        // pending fields that need to be cleared.
567        let SessionStructure {
568            session_version: _session_version,
569            local_identity_public: _local_identity_public,
570            remote_identity_public: _remote_identity_public,
571            root_key: _root_key,
572            previous_counter: _previous_counter,
573            sender_chain: _sender_chain,
574            receiver_chains: _receiver_chains,
575            pending_pre_key: _pending_pre_key,
576            pending_kyber_pre_key: _pending_kyber_pre_key,
577            remote_registration_id: _remote_registration_id,
578            local_registration_id: _local_registration_id,
579            alice_base_key: _alice_base_key,
580            pq_ratchet_state: _pq_ratchet_state,
581        } = &self.session;
582        // ####### IMPORTANT #######
583        // Don't forget to clean up new pending fields.
584        // ####### IMPORTANT #######
585        self.session.pending_pre_key = None;
586        self.session.pending_kyber_pre_key = None;
587    }
588
589    pub(crate) fn set_remote_registration_id(&mut self, registration_id: u32) {
590        self.session.remote_registration_id = registration_id;
591    }
592
593    pub(crate) fn remote_registration_id(&self) -> u32 {
594        self.session.remote_registration_id
595    }
596
597    pub(crate) fn set_local_registration_id(&mut self, registration_id: u32) {
598        self.session.local_registration_id = registration_id;
599    }
600
601    pub(crate) fn local_registration_id(&self) -> u32 {
602        self.session.local_registration_id
603    }
604
605    pub(crate) fn get_kyber_ciphertext(&self) -> Option<&Vec<u8>> {
606        self.session
607            .pending_kyber_pre_key
608            .as_ref()
609            .map(|pending| &pending.ciphertext)
610    }
611
612    /// Extract the Double Ratchet state from this session.
613    ///
614    /// This **moves** the receiver chains out of the session protobuf,
615    /// leaving them empty. The caller must either apply the ratchet state
616    /// back via [`apply_ratchet_state`](Self::apply_ratchet_state) or
617    /// discard this `SessionState` entirely. Reading receiver chains from
618    /// a taken session is invalid and will produce incorrect results.
619    ///
620    /// The caller must supply `self_session` (whether this is a note-to-self
621    /// session) since it requires identity key comparison, which is above the
622    /// ratchet layer.
623    pub(crate) fn take_ratchet_state(
624        &mut self,
625        self_session: bool,
626    ) -> crate::error::Result<crate::double_ratchet::RatchetState> {
627        let receiver_chains = std::mem::take(&mut self.session.receiver_chains);
628        Ok(crate::double_ratchet::RatchetState::from_pb(
629            &self.session,
630            self_session,
631            receiver_chains,
632        )?)
633    }
634
635    /// Write modified Double Ratchet state back into this session.
636    ///
637    /// Only the ratchet-relevant fields are updated; all other session
638    /// fields (identity keys, pending pre-key, SPQR state, etc.) are
639    /// left unchanged.
640    pub(crate) fn apply_ratchet_state(&mut self, ratchet: crate::double_ratchet::RatchetState) {
641        ratchet.apply_to_pb(&mut self.session);
642    }
643
644    #[cfg(test)]
645    pub(crate) fn pq_ratchet_recv(
646        &mut self,
647        msg: &spqr::SerializedMessage,
648    ) -> Result<spqr::MessageKey, spqr::Error> {
649        let _trace = libsignal_debug::trace_block!("SessionState::pq_ratchet_recv");
650        let spqr::Recv { state, key } = spqr::recv(&self.session.pq_ratchet_state, msg)?;
651        self.session.pq_ratchet_state = state;
652        Ok(key)
653    }
654
655    #[cfg(test)]
656    pub(crate) fn pq_ratchet_send<R: Rng + CryptoRng>(
657        &mut self,
658        csprng: &mut R,
659    ) -> Result<(spqr::SerializedMessage, spqr::MessageKey), spqr::Error> {
660        let _trace = libsignal_debug::trace_block!("SessionState::pq_ratchet_send");
661        let spqr::Send { state, key, msg } = spqr::send(&self.session.pq_ratchet_state, csprng)?;
662        self.session.pq_ratchet_state = state;
663        Ok((msg, key))
664    }
665
666    pub(crate) fn pq_ratchet_state(&self) -> &spqr::SerializedState {
667        &self.session.pq_ratchet_state
668    }
669
670    /// Move the PQ ratchet state out, leaving an empty vec in its place.
671    ///
672    /// The caller must either write a new PQ ratchet state back via
673    /// [`set_pq_ratchet_state`](Self::set_pq_ratchet_state) or discard
674    /// this `SessionState` entirely.
675    pub(crate) fn take_pq_ratchet_state(&mut self) -> spqr::SerializedState {
676        std::mem::take(&mut self.session.pq_ratchet_state)
677    }
678
679    pub(crate) fn set_pq_ratchet_state(&mut self, state: spqr::SerializedState) {
680        self.session.pq_ratchet_state = state;
681    }
682}
683
684impl From<SessionStructure> for SessionState {
685    fn from(value: SessionStructure) -> SessionState {
686        SessionState::from_session_structure(value)
687    }
688}
689
690impl From<SessionState> for SessionStructure {
691    fn from(value: SessionState) -> SessionStructure {
692        value.session
693    }
694}
695
696impl From<&SessionState> for SessionStructure {
697    fn from(value: &SessionState) -> SessionStructure {
698        value.session.clone()
699    }
700}
701
/// A session record: the current session (if any) plus archived previous sessions.
///
/// Archived sessions are kept as encoded `SessionStructure` protobuf bytes and are only
/// decoded on demand.
#[derive(Clone)]
pub struct SessionRecord {
    // `None` for a fresh record with no session yet.
    current_session: Option<SessionState>,
    // Encoded `SessionStructure` protobufs, most recently archived first.
    previous_sessions: Vec<Vec<u8>>,
}
707
708impl SessionRecord {
709    pub fn new_fresh() -> Self {
710        Self {
711            current_session: None,
712            previous_sessions: Vec::new(),
713        }
714    }
715
716    pub(crate) fn new(state: SessionState) -> Self {
717        Self {
718            current_session: Some(state),
719            previous_sessions: Vec::new(),
720        }
721    }
722
723    pub fn deserialize(bytes: &[u8]) -> Result<Self, SignalProtocolError> {
724        let record = RecordStructure::decode(bytes)
725            .map_err(|_| InvalidSessionError("failed to decode session record protobuf"))?;
726
727        Ok(Self {
728            current_session: record.current_session.map(|s| s.into()),
729            previous_sessions: record.previous_sessions,
730        })
731    }
732
733    /// If there's a session with a matching version and `alice_base_key`, ensures that it is the
734    /// current session, promoting if necessary.
735    ///
736    /// Returns `Ok(true)` if such a session was found, `Ok(false)` if not, and
737    /// `Err(InvalidSessionError)` if an invalid session was found during the search (whether
738    /// current or not).
739    pub(crate) fn promote_matching_session(
740        &mut self,
741        version: u32,
742        alice_base_key: &[u8],
743    ) -> Result<bool, InvalidSessionError> {
744        if let Some(current_session) = &self.current_session {
745            if current_session.session_version()? == version
746                && alice_base_key
747                    .ct_eq(current_session.alice_base_key())
748                    .into()
749            {
750                return Ok(true);
751            }
752        }
753
754        let mut session_to_promote = None;
755        for (i, previous) in self.previous_session_states().enumerate() {
756            let previous = previous?;
757            if previous.session_version()? == version
758                && alice_base_key.ct_eq(previous.alice_base_key()).into()
759            {
760                session_to_promote = Some((i, previous));
761                break;
762            }
763        }
764
765        if let Some((i, state)) = session_to_promote {
766            self.promote_old_session(i, state);
767            return Ok(true);
768        }
769
770        Ok(false)
771    }
772
773    pub(crate) fn session_state(&self) -> Option<&SessionState> {
774        self.current_session.as_ref()
775    }
776
777    pub(crate) fn session_state_mut(&mut self) -> Option<&mut SessionState> {
778        self.current_session.as_mut()
779    }
780
781    pub(crate) fn set_session_state(&mut self, session: SessionState) {
782        self.current_session = Some(session);
783    }
784
785    pub(crate) fn previous_session_states(
786        &self,
787    ) -> impl ExactSizeIterator<Item = Result<SessionState, InvalidSessionError>> + '_ {
788        self.previous_sessions.iter().map(|bytes| {
789            Ok(SessionStructure::decode(&bytes[..])
790                .map_err(|_| InvalidSessionError("failed to decode previous session protobuf"))?
791                .into())
792        })
793    }
794
795    pub(crate) fn promote_old_session(
796        &mut self,
797        old_session: usize,
798        updated_session: SessionState,
799    ) {
800        self.previous_sessions.remove(old_session);
801        self.promote_state(updated_session)
802    }
803
804    pub(crate) fn promote_state(&mut self, new_state: SessionState) {
805        self.archive_current_state_inner();
806        self.current_session = Some(new_state);
807    }
808
809    // A non-fallible version of archive_current_state.
810    //
811    // Returns `true` if there was a session to archive, `false` if not.
812    fn archive_current_state_inner(&mut self) -> bool {
813        if let Some(mut current_session) = self.current_session.take() {
814            if self.previous_sessions.len() >= consts::ARCHIVED_STATES_MAX_LENGTH {
815                self.previous_sessions.pop();
816            }
817            current_session.clear_unacknowledged_pre_key_message();
818            self.previous_sessions
819                .insert(0, current_session.session.encode_to_vec());
820            true
821        } else {
822            false
823        }
824    }
825
826    pub fn archive_current_state(&mut self) -> Result<(), SignalProtocolError> {
827        if !self.archive_current_state_inner() {
828            log::info!("Skipping archive, current session state is fresh");
829        }
830        Ok(())
831    }
832
833    pub fn serialize(&self) -> Result<Vec<u8>, SignalProtocolError> {
834        let record = RecordStructure {
835            current_session: self.current_session.as_ref().map(|s| s.into()),
836            previous_sessions: self.previous_sessions.clone(),
837        };
838        Ok(record.encode_to_vec())
839    }
840
841    pub fn current_pq_state(&self) -> Option<&spqr::SerializedState> {
842        self.current_session.as_ref().map(|s| s.pq_ratchet_state())
843    }
844
845    pub fn remote_registration_id(&self) -> Result<u32, SignalProtocolError> {
846        Ok(self
847            .session_state()
848            .ok_or_else(|| {
849                SignalProtocolError::InvalidState(
850                    "remote_registration_id",
851                    "No current session".into(),
852                )
853            })?
854            .remote_registration_id())
855    }
856
857    pub fn local_registration_id(&self) -> Result<u32, SignalProtocolError> {
858        Ok(self
859            .session_state()
860            .ok_or_else(|| {
861                SignalProtocolError::InvalidState(
862                    "local_registration_id",
863                    "No current session".into(),
864                )
865            })?
866            .local_registration_id())
867    }
868
869    pub fn session_version(&self) -> Result<u32, SignalProtocolError> {
870        Ok(self
871            .session_state()
872            .ok_or_else(|| {
873                SignalProtocolError::InvalidState("session_version", "No current session".into())
874            })?
875            .session_version()?)
876    }
877
878    pub fn local_identity_key_bytes(&self) -> Result<Vec<u8>, SignalProtocolError> {
879        Ok(self
880            .session_state()
881            .ok_or_else(|| {
882                SignalProtocolError::InvalidState(
883                    "local_identity_key_bytes",
884                    "No current session".into(),
885                )
886            })?
887            .local_identity_key_bytes()?)
888    }
889
890    pub fn remote_identity_key_bytes(&self) -> Result<Option<Vec<u8>>, SignalProtocolError> {
891        Ok(self
892            .session_state()
893            .ok_or_else(|| {
894                SignalProtocolError::InvalidState(
895                    "remote_identity_key_bytes",
896                    "No current session".into(),
897                )
898            })?
899            .remote_identity_key_bytes()?)
900    }
901
902    pub fn has_usable_sender_chain(
903        &self,
904        now: SystemTime,
905        requirements: SessionUsabilityRequirements,
906    ) -> Result<bool, SignalProtocolError> {
907        match &self.current_session {
908            Some(session) => Ok(session.has_usable_sender_chain(now, requirements)?),
909            None => Ok(false),
910        }
911    }
912
913    pub fn alice_base_key(&self) -> Result<&[u8], SignalProtocolError> {
914        Ok(self
915            .session_state()
916            .ok_or_else(|| {
917                SignalProtocolError::InvalidState("alice_base_key", "No current session".into())
918            })?
919            .alice_base_key())
920    }
921
922    pub fn get_receiver_chain_key_bytes(
923        &self,
924        sender: &PublicKey,
925    ) -> Result<Option<Box<[u8]>>, SignalProtocolError> {
926        Ok(self
927            .session_state()
928            .ok_or_else(|| {
929                SignalProtocolError::InvalidState(
930                    "get_receiver_chain_key",
931                    "No current session".into(),
932                )
933            })?
934            .get_receiver_chain_key(sender)?
935            .map(|chain| chain.key()[..].into()))
936    }
937
938    pub fn get_sender_chain_key_bytes(&self) -> Result<Vec<u8>, SignalProtocolError> {
939        Ok(self
940            .session_state()
941            .ok_or_else(|| {
942                SignalProtocolError::InvalidState(
943                    "get_sender_chain_key_bytes",
944                    "No current session".into(),
945                )
946            })?
947            .get_sender_chain_key_bytes()?)
948    }
949
950    pub fn current_ratchet_key_matches(
951        &self,
952        key: &PublicKey,
953    ) -> Result<bool, SignalProtocolError> {
954        match &self.current_session {
955            Some(session) => Ok(&session.sender_ratchet_key()? == key),
956            None => Ok(false),
957        }
958    }
959
960    pub fn get_kyber_ciphertext(&self) -> Result<Option<&Vec<u8>>, SignalProtocolError> {
961        Ok(self
962            .session_state()
963            .ok_or_else(|| {
964                SignalProtocolError::InvalidState(
965                    "get_kyber_ciphertext",
966                    "No current session".into(),
967                )
968            })?
969            .get_kyber_ciphertext())
970    }
971}