libsignal_protocol/state/
session.rs

1//
2// Copyright 2020-2022 Signal Messenger, LLC.
3// SPDX-License-Identifier: AGPL-3.0-only
4//
5
6use std::result::Result;
7use std::time::{Duration, SystemTime};
8
9use bitflags::bitflags;
10use prost::Message;
11use rand::{CryptoRng, Rng};
12use subtle::ConstantTimeEq;
13
14use crate::proto::storage::{RecordStructure, SessionStructure, session_structure};
15use crate::protocol::CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION;
16use crate::ratchet::{ChainKey, MessageKeyGenerator, RootKey};
17use crate::state::{KyberPreKeyId, PreKeyId, SignedPreKeyId};
18use crate::{IdentityKey, KeyPair, PrivateKey, PublicKey, SignalProtocolError, consts, kem};
19
/// Describes why a stored session structure could not be interpreted.
///
/// Deliberately a separate type from the crate-wide error so that deserialization failures are
/// not propagated by accident; it only carries a static description of the problem.
#[derive(Debug)]
pub(crate) struct InvalidSessionError(&'static str);

impl std::fmt::Display for InvalidSessionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Formatter::pad` is what `str`'s own `Display` uses, so width/alignment flags are
        // honored exactly as if the message string were formatted directly.
        let InvalidSessionError(message) = self;
        f.pad(message)
    }
}
29
30impl From<InvalidSessionError> for SignalProtocolError {
31    fn from(e: InvalidSessionError) -> Self {
32        Self::InvalidSessionStructure(e.0)
33    }
34}
35
36#[derive(Debug, Clone)]
37pub(crate) struct UnacknowledgedPreKeyMessageItems<'a> {
38    pre_key_id: Option<PreKeyId>,
39    signed_pre_key_id: SignedPreKeyId,
40    base_key: PublicKey,
41    // Although we require PQXDH for all new sessions now,
42    // we may in theory have an existing X3DH unacknowledged session,
43    // so we leave these optional for now.
44    kyber_pre_key_id: Option<KyberPreKeyId>,
45    kyber_ciphertext: Option<&'a [u8]>,
46    timestamp: SystemTime,
47}
48
49impl<'a> UnacknowledgedPreKeyMessageItems<'a> {
50    fn new(
51        pre_key_id: Option<PreKeyId>,
52        signed_pre_key_id: SignedPreKeyId,
53        base_key: PublicKey,
54        pending_kyber_pre_key: Option<&'a session_structure::PendingKyberPreKey>,
55        timestamp: SystemTime,
56    ) -> Self {
57        let (kyber_pre_key_id, kyber_ciphertext) = pending_kyber_pre_key
58            .map(|pending| (pending.pre_key_id.into(), pending.ciphertext.as_slice()))
59            .unzip();
60        Self {
61            pre_key_id,
62            signed_pre_key_id,
63            base_key,
64            kyber_pre_key_id,
65            kyber_ciphertext,
66            timestamp,
67        }
68    }
69
70    pub(crate) fn pre_key_id(&self) -> Option<PreKeyId> {
71        self.pre_key_id
72    }
73
74    pub(crate) fn signed_pre_key_id(&self) -> SignedPreKeyId {
75        self.signed_pre_key_id
76    }
77
78    pub(crate) fn base_key(&self) -> &PublicKey {
79        &self.base_key
80    }
81
82    pub(crate) fn kyber_pre_key_id(&self) -> Option<KyberPreKeyId> {
83        self.kyber_pre_key_id
84    }
85
86    pub(crate) fn kyber_ciphertext(&self) -> Option<&'a [u8]> {
87        self.kyber_ciphertext
88    }
89
90    pub(crate) fn timestamp(&self) -> SystemTime {
91        self.timestamp
92    }
93}
94
95bitflags! {
96    /// Specifies which criteria make a session "usable" beyond simply having a present sender
97    /// chain.
98    ///
99    /// These requirements are conjunctive, i.e. specifying `NotStale | EstablishedWithPqxdh` means
100    /// the session must be neither stale nor established with X3DH.
101    ///
102    /// This struct is generated using the `bitflags` crate; the [`Flags`](::bitflags::Flags) trait
103    /// provides most of its API surface. It can also use "classic" C bitflag syntax, with `|` for
104    /// union and `&` for intersection (as shown above).
105    #[repr(transparent)]
106    #[derive(Clone, Copy, PartialEq, Eq)]
107    pub struct SessionUsabilityRequirements : u32 {
108        /// Requires that a session not be stale.
109        ///
110        /// A non-stale session is one of the following:
111        /// - "incoming", i.e. started by the peer
112        /// - "acknowledged", i.e. started locally but received a response
113        /// - no more than a few weeks old (the exact time is chosen by libsignal)
114        const NotStale = 1 << 0;
115        /// Requires that a session was established using PQXDH (or newer) rather than X3DH/X4DH.
116        ///
117        /// This includes unacknowledged sessions that are using PQXDH, since if they get a
118        /// response, the peer is confirmed to be using PQXDH as well.
119        const EstablishedWithPqxdh = 1 << 1;
120        /// Requires that a session is using SPQR.
121        ///
122        /// **Warning:** This allows unacknowledged sessions that include SPQR in their PreKey
123        /// messages. If the peer downgrades the session (by discarding the SPQR information) and
124        /// the local client allows it, a session that is previously considered "usable" can become
125        /// "not usable" upon receiving a response. Therefore, this should not be used to determine
126        /// whether a session is usable unless future downgrades will also be rejected.
127        const Spqr = 1 << 2;
128    }
129}
130
131#[derive(Clone, Debug)]
132pub(crate) struct SessionState {
133    session: SessionStructure,
134}
135
136impl SessionState {
137    pub(crate) fn from_session_structure(session: SessionStructure) -> Self {
138        Self { session }
139    }
140
141    pub(crate) fn new(
142        version: u8,
143        our_identity: &IdentityKey,
144        their_identity: &IdentityKey,
145        root_key: &RootKey,
146        alice_base_key: &PublicKey,
147        pq_ratchet_state: spqr::SerializedState,
148    ) -> Self {
149        Self {
150            session: SessionStructure {
151                session_version: version as u32,
152                local_identity_public: our_identity.public_key().serialize().into_vec(),
153                remote_identity_public: their_identity.serialize().into_vec(),
154                root_key: root_key.key().to_vec(),
155                previous_counter: 0,
156                sender_chain: None,
157                receiver_chains: vec![],
158                pending_pre_key: None,
159                pending_kyber_pre_key: None,
160                remote_registration_id: 0,
161                local_registration_id: 0,
162                alice_base_key: alice_base_key.serialize().into_vec(),
163                pq_ratchet_state,
164            },
165        }
166    }
167
168    pub(crate) fn alice_base_key(&self) -> &[u8] {
169        // Check the length before returning?
170        &self.session.alice_base_key
171    }
172
173    pub(crate) fn session_version(&self) -> Result<u32, InvalidSessionError> {
174        match self.session.session_version {
175            0 => Ok(2),
176            v => Ok(v),
177        }
178    }
179
180    pub(crate) fn remote_identity_key(&self) -> Result<Option<IdentityKey>, InvalidSessionError> {
181        match self.session.remote_identity_public.len() {
182            0 => Ok(None),
183            _ => Ok(Some(
184                IdentityKey::decode(&self.session.remote_identity_public)
185                    .map_err(|_| InvalidSessionError("invalid remote identity key"))?,
186            )),
187        }
188    }
189
190    pub(crate) fn remote_identity_key_bytes(&self) -> Result<Option<Vec<u8>>, InvalidSessionError> {
191        Ok(self.remote_identity_key()?.map(|k| k.serialize().to_vec()))
192    }
193
194    pub(crate) fn local_identity_key(&self) -> Result<IdentityKey, InvalidSessionError> {
195        IdentityKey::decode(&self.session.local_identity_public)
196            .map_err(|_| InvalidSessionError("invalid local identity key"))
197    }
198
199    pub(crate) fn local_identity_key_bytes(&self) -> Result<Vec<u8>, InvalidSessionError> {
200        Ok(self.local_identity_key()?.serialize().to_vec())
201    }
202
203    pub(crate) fn session_with_self(&self) -> Result<bool, InvalidSessionError> {
204        if let Some(remote_id) = self.remote_identity_key_bytes()? {
205            let local_id = self.local_identity_key_bytes()?;
206            return Ok(remote_id == local_id);
207        }
208
209        // If remote ID is not set then we can't be sure but treat as non-self
210        Ok(false)
211    }
212
213    pub(crate) fn previous_counter(&self) -> u32 {
214        self.session.previous_counter
215    }
216
217    pub(crate) fn set_previous_counter(&mut self, ctr: u32) {
218        self.session.previous_counter = ctr;
219    }
220
221    pub(crate) fn root_key(&self) -> Result<RootKey, InvalidSessionError> {
222        let root_key_bytes = self.session.root_key[..]
223            .try_into()
224            .map_err(|_| InvalidSessionError("invalid root key"))?;
225        Ok(RootKey::new(root_key_bytes))
226    }
227
228    pub(crate) fn set_root_key(&mut self, root_key: &RootKey) {
229        self.session.root_key = root_key.key().to_vec();
230    }
231
232    pub(crate) fn sender_ratchet_key(&self) -> Result<PublicKey, InvalidSessionError> {
233        match self.session.sender_chain {
234            None => Err(InvalidSessionError("missing sender chain")),
235            Some(ref c) => PublicKey::deserialize(&c.sender_ratchet_key)
236                .map_err(|_| InvalidSessionError("invalid sender chain ratchet key")),
237        }
238    }
239
240    pub(crate) fn sender_ratchet_key_for_logging(&self) -> Result<String, InvalidSessionError> {
241        Ok(hex::encode(self.sender_ratchet_key()?.public_key_bytes()))
242    }
243
244    pub(crate) fn sender_ratchet_private_key(&self) -> Result<PrivateKey, InvalidSessionError> {
245        match self.session.sender_chain {
246            None => Err(InvalidSessionError("missing sender chain")),
247            Some(ref c) => PrivateKey::deserialize(&c.sender_ratchet_key_private)
248                .map_err(|_| InvalidSessionError("invalid sender chain private ratchet key")),
249        }
250    }
251
252    pub fn has_usable_sender_chain(
253        &self,
254        now: SystemTime,
255        requirements: SessionUsabilityRequirements,
256    ) -> Result<bool, InvalidSessionError> {
257        if self.session.sender_chain.is_none() {
258            return Ok(false);
259        }
260        if requirements.contains(SessionUsabilityRequirements::NotStale) {
261            if let Some(pending_pre_key) = &self.session.pending_pre_key {
262                let creation_timestamp =
263                    SystemTime::UNIX_EPOCH + Duration::from_secs(pending_pre_key.timestamp);
264                if creation_timestamp + consts::MAX_UNACKNOWLEDGED_SESSION_AGE < now {
265                    return Ok(false);
266                }
267            }
268        }
269        #[allow(clippy::collapsible_if)]
270        if requirements.contains(SessionUsabilityRequirements::EstablishedWithPqxdh) {
271            if self.session_version()? <= CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION.into() {
272                return Ok(false);
273            }
274        }
275        #[allow(clippy::collapsible_if)]
276        if requirements.contains(SessionUsabilityRequirements::Spqr) {
277            if self.pq_ratchet_state().is_empty() {
278                return Ok(false);
279            }
280        }
281        Ok(true)
282    }
283
284    pub(crate) fn all_receiver_chain_logging_info(&self) -> Vec<(Vec<u8>, Option<u32>)> {
285        let mut results = vec![];
286        for chain in self.session.receiver_chains.iter() {
287            let sender_ratchet_public = chain.sender_ratchet_key.clone();
288
289            let chain_key_idx = chain.chain_key.as_ref().map(|chain_key| chain_key.index);
290
291            results.push((sender_ratchet_public, chain_key_idx))
292        }
293        results
294    }
295
296    pub(crate) fn get_receiver_chain(
297        &self,
298        sender: &PublicKey,
299    ) -> Result<Option<(session_structure::Chain, usize)>, InvalidSessionError> {
300        for (idx, chain) in self.session.receiver_chains.iter().enumerate() {
301            // If we compared bytes directly it would be faster, but may miss non-canonical points.
302            // It's unclear if supporting such points is desirable.
303            let chain_ratchet_key = PublicKey::deserialize(&chain.sender_ratchet_key)
304                .map_err(|_| InvalidSessionError("invalid receiver chain ratchet key"))?;
305
306            if &chain_ratchet_key == sender {
307                return Ok(Some((chain.clone(), idx)));
308            }
309        }
310
311        Ok(None)
312    }
313
314    pub(crate) fn get_receiver_chain_key(
315        &self,
316        sender: &PublicKey,
317    ) -> Result<Option<ChainKey>, InvalidSessionError> {
318        match self.get_receiver_chain(sender)? {
319            None => Ok(None),
320            Some((chain, _)) => match chain.chain_key {
321                None => Err(InvalidSessionError("missing receiver chain key")),
322                Some(c) => {
323                    let chain_key_bytes = c.key[..]
324                        .try_into()
325                        .map_err(|_| InvalidSessionError("invalid receiver chain key"))?;
326                    Ok(Some(ChainKey::new(chain_key_bytes, c.index)))
327                }
328            },
329        }
330    }
331
332    pub(crate) fn add_receiver_chain(&mut self, sender: &PublicKey, chain_key: &ChainKey) {
333        let chain_key = session_structure::chain::ChainKey {
334            index: chain_key.index(),
335            key: chain_key.key().to_vec(),
336        };
337
338        let chain = session_structure::Chain {
339            sender_ratchet_key: sender.serialize().to_vec(),
340            sender_ratchet_key_private: vec![],
341            chain_key: Some(chain_key),
342            message_keys: vec![],
343        };
344
345        self.session.receiver_chains.push(chain);
346
347        if self.session.receiver_chains.len() > consts::MAX_RECEIVER_CHAINS {
348            log::info!(
349                "Trimming excessive receiver_chain for session with base key {}, chain count: {}",
350                self.sender_ratchet_key_for_logging()
351                    .unwrap_or_else(|e| format!("<error: {}>", e.0)),
352                self.session.receiver_chains.len()
353            );
354            self.session.receiver_chains.remove(0);
355        }
356    }
357
358    pub(crate) fn with_receiver_chain(mut self, sender: &PublicKey, chain_key: &ChainKey) -> Self {
359        self.add_receiver_chain(sender, chain_key);
360        self
361    }
362
363    pub(crate) fn set_sender_chain(&mut self, sender: &KeyPair, next_chain_key: &ChainKey) {
364        let chain_key = session_structure::chain::ChainKey {
365            index: next_chain_key.index(),
366            key: next_chain_key.key().to_vec(),
367        };
368
369        let new_chain = session_structure::Chain {
370            sender_ratchet_key: sender.public_key.serialize().to_vec(),
371            sender_ratchet_key_private: sender.private_key.serialize().to_vec(),
372            chain_key: Some(chain_key),
373            message_keys: vec![],
374        };
375
376        self.session.sender_chain = Some(new_chain);
377    }
378
379    pub(crate) fn with_sender_chain(mut self, sender: &KeyPair, next_chain_key: &ChainKey) -> Self {
380        self.set_sender_chain(sender, next_chain_key);
381        self
382    }
383
384    pub(crate) fn get_sender_chain_key(&self) -> Result<ChainKey, InvalidSessionError> {
385        let sender_chain = self
386            .session
387            .sender_chain
388            .as_ref()
389            .ok_or(InvalidSessionError("missing sender chain"))?;
390
391        let chain_key = sender_chain
392            .chain_key
393            .as_ref()
394            .ok_or(InvalidSessionError("missing sender chain key"))?;
395
396        let chain_key_bytes = chain_key.key[..]
397            .try_into()
398            .map_err(|_| InvalidSessionError("invalid sender chain key"))?;
399
400        Ok(ChainKey::new(chain_key_bytes, chain_key.index))
401    }
402
403    pub(crate) fn get_sender_chain_key_bytes(&self) -> Result<Vec<u8>, InvalidSessionError> {
404        Ok(self.get_sender_chain_key()?.key().to_vec())
405    }
406
407    pub(crate) fn set_sender_chain_key(&mut self, next_chain_key: &ChainKey) {
408        let chain_key = session_structure::chain::ChainKey {
409            index: next_chain_key.index(),
410            key: next_chain_key.key().to_vec(),
411        };
412
413        // Is it actually valid to call this function with sender_chain == None?
414
415        let new_chain = match self.session.sender_chain.take() {
416            None => session_structure::Chain {
417                sender_ratchet_key: vec![],
418                sender_ratchet_key_private: vec![],
419                chain_key: Some(chain_key),
420                message_keys: vec![],
421            },
422            Some(mut c) => {
423                c.chain_key = Some(chain_key);
424                c
425            }
426        };
427
428        self.session.sender_chain = Some(new_chain);
429    }
430
431    pub(crate) fn get_message_keys(
432        &mut self,
433        sender: &PublicKey,
434        counter: u32,
435    ) -> Result<Option<MessageKeyGenerator>, InvalidSessionError> {
436        if let Some(mut chain_and_index) = self.get_receiver_chain(sender)? {
437            let message_key_idx = chain_and_index
438                .0
439                .message_keys
440                .iter()
441                .position(|m| m.index == counter);
442
443            if let Some(position) = message_key_idx {
444                let message_key = chain_and_index.0.message_keys.remove(position);
445                let keys =
446                    MessageKeyGenerator::from_pb(message_key).map_err(InvalidSessionError)?;
447
448                // Update with message key removed
449                self.session.receiver_chains[chain_and_index.1] = chain_and_index.0;
450                return Ok(Some(keys));
451            }
452        }
453
454        Ok(None)
455    }
456
457    pub(crate) fn set_message_keys(
458        &mut self,
459        sender: &PublicKey,
460        message_keys: MessageKeyGenerator,
461    ) -> Result<(), InvalidSessionError> {
462        let chain_and_index = self
463            .get_receiver_chain(sender)?
464            .expect("called set_message_keys for a non-existent chain");
465        let mut updated_chain = chain_and_index.0;
466        updated_chain.message_keys.insert(0, message_keys.into_pb());
467
468        if updated_chain.message_keys.len() > consts::MAX_MESSAGE_KEYS {
469            updated_chain.message_keys.pop();
470        }
471
472        self.session.receiver_chains[chain_and_index.1] = updated_chain;
473
474        Ok(())
475    }
476
477    pub(crate) fn set_receiver_chain_key(
478        &mut self,
479        sender: &PublicKey,
480        chain_key: &ChainKey,
481    ) -> Result<(), InvalidSessionError> {
482        let chain_and_index = self
483            .get_receiver_chain(sender)?
484            .expect("called set_receiver_chain_key for a non-existent chain");
485        let mut updated_chain = chain_and_index.0;
486        updated_chain.chain_key = Some(session_structure::chain::ChainKey {
487            index: chain_key.index(),
488            key: chain_key.key().to_vec(),
489        });
490
491        self.session.receiver_chains[chain_and_index.1] = updated_chain;
492
493        Ok(())
494    }
495
496    pub(crate) fn set_unacknowledged_pre_key_message(
497        &mut self,
498        pre_key_id: Option<PreKeyId>,
499        signed_ec_pre_key_id: SignedPreKeyId,
500        base_key: &PublicKey,
501        now: SystemTime,
502    ) {
503        let signed_ec_pre_key_id: u32 = signed_ec_pre_key_id.into();
504        let pending = session_structure::PendingPreKey {
505            pre_key_id: pre_key_id.map(PreKeyId::into),
506            signed_pre_key_id: signed_ec_pre_key_id as i32,
507            base_key: base_key.serialize().to_vec(),
508            timestamp: now
509                .duration_since(SystemTime::UNIX_EPOCH)
510                .unwrap_or_default()
511                .as_secs(),
512        };
513        self.session.pending_pre_key = Some(pending);
514    }
515
516    pub(crate) fn set_kyber_ciphertext(&mut self, ciphertext: kem::SerializedCiphertext) {
517        let pending = session_structure::PendingKyberPreKey {
518            pre_key_id: u32::MAX, // has to be set to the actual value separately
519            ciphertext: ciphertext.into_vec(),
520        };
521        self.session.pending_kyber_pre_key = Some(pending);
522    }
523
524    pub(crate) fn set_unacknowledged_kyber_pre_key_id(
525        &mut self,
526        signed_kyber_pre_key_id: KyberPreKeyId,
527    ) {
528        let pending = self
529            .session
530            .pending_kyber_pre_key
531            .as_mut()
532            .expect("must have been set if kyber pre key is present");
533        pending.pre_key_id = signed_kyber_pre_key_id.into();
534    }
535
536    pub(crate) fn unacknowledged_pre_key_message_items(
537        &self,
538    ) -> Result<Option<UnacknowledgedPreKeyMessageItems<'_>>, InvalidSessionError> {
539        if let Some(ref pending_pre_key) = self.session.pending_pre_key {
540            Ok(Some(UnacknowledgedPreKeyMessageItems::new(
541                pending_pre_key.pre_key_id.map(Into::into),
542                (pending_pre_key.signed_pre_key_id as u32).into(),
543                PublicKey::deserialize(&pending_pre_key.base_key)
544                    .map_err(|_| InvalidSessionError("invalid pending PreKey message base key"))?,
545                self.session.pending_kyber_pre_key.as_ref(),
546                SystemTime::UNIX_EPOCH + Duration::from_secs(pending_pre_key.timestamp),
547            )))
548        } else {
549            Ok(None)
550        }
551    }
552
553    pub(crate) fn clear_unacknowledged_pre_key_message(&mut self) {
554        // Explicitly destructuring the SessionStructure in case there are new
555        // pending fields that need to be cleared.
556        let SessionStructure {
557            session_version: _session_version,
558            local_identity_public: _local_identity_public,
559            remote_identity_public: _remote_identity_public,
560            root_key: _root_key,
561            previous_counter: _previous_counter,
562            sender_chain: _sender_chain,
563            receiver_chains: _receiver_chains,
564            pending_pre_key: _pending_pre_key,
565            pending_kyber_pre_key: _pending_kyber_pre_key,
566            remote_registration_id: _remote_registration_id,
567            local_registration_id: _local_registration_id,
568            alice_base_key: _alice_base_key,
569            pq_ratchet_state: _pq_ratchet_state,
570        } = &self.session;
571        // ####### IMPORTANT #######
572        // Don't forget to clean up new pending fields.
573        // ####### IMPORTANT #######
574        self.session.pending_pre_key = None;
575        self.session.pending_kyber_pre_key = None;
576    }
577
578    pub(crate) fn set_remote_registration_id(&mut self, registration_id: u32) {
579        self.session.remote_registration_id = registration_id;
580    }
581
582    pub(crate) fn remote_registration_id(&self) -> u32 {
583        self.session.remote_registration_id
584    }
585
586    pub(crate) fn set_local_registration_id(&mut self, registration_id: u32) {
587        self.session.local_registration_id = registration_id;
588    }
589
590    pub(crate) fn local_registration_id(&self) -> u32 {
591        self.session.local_registration_id
592    }
593
594    pub(crate) fn get_kyber_ciphertext(&self) -> Option<&Vec<u8>> {
595        self.session
596            .pending_kyber_pre_key
597            .as_ref()
598            .map(|pending| &pending.ciphertext)
599    }
600
601    pub(crate) fn pq_ratchet_recv(
602        &mut self,
603        msg: &spqr::SerializedMessage,
604    ) -> Result<spqr::MessageKey, spqr::Error> {
605        let _trace = libsignal_debug::trace_block!("SessionState::pq_ratchet_recv");
606        let spqr::Recv { state, key } = spqr::recv(&self.session.pq_ratchet_state, msg)?;
607        self.session.pq_ratchet_state = state;
608        Ok(key)
609    }
610
611    pub(crate) fn pq_ratchet_send<R: Rng + CryptoRng>(
612        &mut self,
613        csprng: &mut R,
614    ) -> Result<(spqr::SerializedMessage, spqr::MessageKey), spqr::Error> {
615        let _trace = libsignal_debug::trace_block!("SessionState::pq_ratchet_send");
616        let spqr::Send { state, key, msg } = spqr::send(&self.session.pq_ratchet_state, csprng)?;
617        self.session.pq_ratchet_state = state;
618        Ok((msg, key))
619    }
620
621    pub(crate) fn pq_ratchet_state(&self) -> &spqr::SerializedState {
622        &self.session.pq_ratchet_state
623    }
624}
625
626impl From<SessionStructure> for SessionState {
627    fn from(value: SessionStructure) -> SessionState {
628        SessionState::from_session_structure(value)
629    }
630}
631
632impl From<SessionState> for SessionStructure {
633    fn from(value: SessionState) -> SessionStructure {
634        value.session
635    }
636}
637
638impl From<&SessionState> for SessionStructure {
639    fn from(value: &SessionState) -> SessionStructure {
640        value.session.clone()
641    }
642}
643
644#[derive(Clone)]
645pub struct SessionRecord {
646    current_session: Option<SessionState>,
647    previous_sessions: Vec<Vec<u8>>,
648}
649
650impl SessionRecord {
651    pub fn new_fresh() -> Self {
652        Self {
653            current_session: None,
654            previous_sessions: Vec::new(),
655        }
656    }
657
658    pub(crate) fn new(state: SessionState) -> Self {
659        Self {
660            current_session: Some(state),
661            previous_sessions: Vec::new(),
662        }
663    }
664
665    pub fn deserialize(bytes: &[u8]) -> Result<Self, SignalProtocolError> {
666        let record = RecordStructure::decode(bytes)
667            .map_err(|_| InvalidSessionError("failed to decode session record protobuf"))?;
668
669        Ok(Self {
670            current_session: record.current_session.map(|s| s.into()),
671            previous_sessions: record.previous_sessions,
672        })
673    }
674
675    /// If there's a session with a matching version and `alice_base_key`, ensures that it is the
676    /// current session, promoting if necessary.
677    ///
678    /// Returns `Ok(true)` if such a session was found, `Ok(false)` if not, and
679    /// `Err(InvalidSessionError)` if an invalid session was found during the search (whether
680    /// current or not).
681    pub(crate) fn promote_matching_session(
682        &mut self,
683        version: u32,
684        alice_base_key: &[u8],
685    ) -> Result<bool, InvalidSessionError> {
686        if let Some(current_session) = &self.current_session {
687            if current_session.session_version()? == version
688                && alice_base_key
689                    .ct_eq(current_session.alice_base_key())
690                    .into()
691            {
692                return Ok(true);
693            }
694        }
695
696        let mut session_to_promote = None;
697        for (i, previous) in self.previous_session_states().enumerate() {
698            let previous = previous?;
699            if previous.session_version()? == version
700                && alice_base_key.ct_eq(previous.alice_base_key()).into()
701            {
702                session_to_promote = Some((i, previous));
703                break;
704            }
705        }
706
707        if let Some((i, state)) = session_to_promote {
708            self.promote_old_session(i, state);
709            return Ok(true);
710        }
711
712        Ok(false)
713    }
714
715    pub(crate) fn session_state(&self) -> Option<&SessionState> {
716        self.current_session.as_ref()
717    }
718
719    pub(crate) fn session_state_mut(&mut self) -> Option<&mut SessionState> {
720        self.current_session.as_mut()
721    }
722
723    pub(crate) fn set_session_state(&mut self, session: SessionState) {
724        self.current_session = Some(session);
725    }
726
727    pub(crate) fn previous_session_states(
728        &self,
729    ) -> impl ExactSizeIterator<Item = Result<SessionState, InvalidSessionError>> + '_ {
730        self.previous_sessions.iter().map(|bytes| {
731            Ok(SessionStructure::decode(&bytes[..])
732                .map_err(|_| InvalidSessionError("failed to decode previous session protobuf"))?
733                .into())
734        })
735    }
736
737    pub(crate) fn promote_old_session(
738        &mut self,
739        old_session: usize,
740        updated_session: SessionState,
741    ) {
742        self.previous_sessions.remove(old_session);
743        self.promote_state(updated_session)
744    }
745
746    pub(crate) fn promote_state(&mut self, new_state: SessionState) {
747        self.archive_current_state_inner();
748        self.current_session = Some(new_state);
749    }
750
751    // A non-fallible version of archive_current_state.
752    //
753    // Returns `true` if there was a session to archive, `false` if not.
754    fn archive_current_state_inner(&mut self) -> bool {
755        if let Some(mut current_session) = self.current_session.take() {
756            if self.previous_sessions.len() >= consts::ARCHIVED_STATES_MAX_LENGTH {
757                self.previous_sessions.pop();
758            }
759            current_session.clear_unacknowledged_pre_key_message();
760            self.previous_sessions
761                .insert(0, current_session.session.encode_to_vec());
762            true
763        } else {
764            false
765        }
766    }
767
768    pub fn archive_current_state(&mut self) -> Result<(), SignalProtocolError> {
769        if !self.archive_current_state_inner() {
770            log::info!("Skipping archive, current session state is fresh");
771        }
772        Ok(())
773    }
774
775    pub fn serialize(&self) -> Result<Vec<u8>, SignalProtocolError> {
776        let record = RecordStructure {
777            current_session: self.current_session.as_ref().map(|s| s.into()),
778            previous_sessions: self.previous_sessions.clone(),
779        };
780        Ok(record.encode_to_vec())
781    }
782
783    pub fn current_pq_state(&self) -> Option<&spqr::SerializedState> {
784        self.current_session.as_ref().map(|s| s.pq_ratchet_state())
785    }
786
787    pub fn remote_registration_id(&self) -> Result<u32, SignalProtocolError> {
788        Ok(self
789            .session_state()
790            .ok_or_else(|| {
791                SignalProtocolError::InvalidState(
792                    "remote_registration_id",
793                    "No current session".into(),
794                )
795            })?
796            .remote_registration_id())
797    }
798
799    pub fn local_registration_id(&self) -> Result<u32, SignalProtocolError> {
800        Ok(self
801            .session_state()
802            .ok_or_else(|| {
803                SignalProtocolError::InvalidState(
804                    "local_registration_id",
805                    "No current session".into(),
806                )
807            })?
808            .local_registration_id())
809    }
810
811    pub fn session_version(&self) -> Result<u32, SignalProtocolError> {
812        Ok(self
813            .session_state()
814            .ok_or_else(|| {
815                SignalProtocolError::InvalidState("session_version", "No current session".into())
816            })?
817            .session_version()?)
818    }
819
820    pub fn local_identity_key_bytes(&self) -> Result<Vec<u8>, SignalProtocolError> {
821        Ok(self
822            .session_state()
823            .ok_or_else(|| {
824                SignalProtocolError::InvalidState(
825                    "local_identity_key_bytes",
826                    "No current session".into(),
827                )
828            })?
829            .local_identity_key_bytes()?)
830    }
831
832    pub fn remote_identity_key_bytes(&self) -> Result<Option<Vec<u8>>, SignalProtocolError> {
833        Ok(self
834            .session_state()
835            .ok_or_else(|| {
836                SignalProtocolError::InvalidState(
837                    "remote_identity_key_bytes",
838                    "No current session".into(),
839                )
840            })?
841            .remote_identity_key_bytes()?)
842    }
843
844    pub fn has_usable_sender_chain(
845        &self,
846        now: SystemTime,
847        requirements: SessionUsabilityRequirements,
848    ) -> Result<bool, SignalProtocolError> {
849        match &self.current_session {
850            Some(session) => Ok(session.has_usable_sender_chain(now, requirements)?),
851            None => Ok(false),
852        }
853    }
854
855    pub fn alice_base_key(&self) -> Result<&[u8], SignalProtocolError> {
856        Ok(self
857            .session_state()
858            .ok_or_else(|| {
859                SignalProtocolError::InvalidState("alice_base_key", "No current session".into())
860            })?
861            .alice_base_key())
862    }
863
864    pub fn get_receiver_chain_key_bytes(
865        &self,
866        sender: &PublicKey,
867    ) -> Result<Option<Box<[u8]>>, SignalProtocolError> {
868        Ok(self
869            .session_state()
870            .ok_or_else(|| {
871                SignalProtocolError::InvalidState(
872                    "get_receiver_chain_key",
873                    "No current session".into(),
874                )
875            })?
876            .get_receiver_chain_key(sender)?
877            .map(|chain| chain.key()[..].into()))
878    }
879
880    pub fn get_sender_chain_key_bytes(&self) -> Result<Vec<u8>, SignalProtocolError> {
881        Ok(self
882            .session_state()
883            .ok_or_else(|| {
884                SignalProtocolError::InvalidState(
885                    "get_sender_chain_key_bytes",
886                    "No current session".into(),
887                )
888            })?
889            .get_sender_chain_key_bytes()?)
890    }
891
892    pub fn current_ratchet_key_matches(
893        &self,
894        key: &PublicKey,
895    ) -> Result<bool, SignalProtocolError> {
896        match &self.current_session {
897            Some(session) => Ok(&session.sender_ratchet_key()? == key),
898            None => Ok(false),
899        }
900    }
901
902    pub fn get_kyber_ciphertext(&self) -> Result<Option<&Vec<u8>>, SignalProtocolError> {
903        Ok(self
904            .session_state()
905            .ok_or_else(|| {
906                SignalProtocolError::InvalidState(
907                    "get_kyber_ciphertext",
908                    "No current session".into(),
909                )
910            })?
911            .get_kyber_ciphertext())
912    }
913}