libsignal_protocol/state/
session.rs

1//
2// Copyright 2020-2022 Signal Messenger, LLC.
3// SPDX-License-Identifier: AGPL-3.0-only
4//
5
6use std::result::Result;
7use std::time::{Duration, SystemTime};
8
9use bitflags::bitflags;
10use prost::Message;
11use rand::{CryptoRng, Rng};
12use subtle::ConstantTimeEq;
13
14use crate::proto::storage::{RecordStructure, SessionStructure, session_structure};
15use crate::protocol::CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION;
16use crate::ratchet::{ChainKey, MessageKeyGenerator, RootKey};
17use crate::state::{KyberPreKeyId, PreKeyId, SignedPreKeyId};
18use crate::{IdentityKey, KeyPair, PrivateKey, PublicKey, SignalProtocolError, consts, kem};
19
/// Error describing malformed session data.
///
/// Kept distinct from `SignalProtocolError` so that deserialization failures are not
/// propagated by accident; the wrapped message is a static description of what was wrong.
#[derive(Debug)]
pub(crate) struct InvalidSessionError(&'static str);

impl std::fmt::Display for InvalidSessionError {
    /// Writes the wrapped message, honouring any width/alignment flags like a plain `str`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let InvalidSessionError(message) = self;
        f.pad(message)
    }
}
29
30impl From<InvalidSessionError> for SignalProtocolError {
31    fn from(e: InvalidSessionError) -> Self {
32        Self::InvalidSessionStructure(e.0)
33    }
34}
35
36#[derive(Debug, Clone)]
37pub(crate) struct UnacknowledgedPreKeyMessageItems<'a> {
38    pre_key_id: Option<PreKeyId>,
39    signed_pre_key_id: SignedPreKeyId,
40    base_key: PublicKey,
41    // Although we require PQXDH for all new sessions now,
42    // we may in theory have an existing X3DH unacknowledged session,
43    // so we leave these optional for now.
44    kyber_pre_key_id: Option<KyberPreKeyId>,
45    kyber_ciphertext: Option<&'a [u8]>,
46    timestamp: SystemTime,
47}
48
49impl<'a> UnacknowledgedPreKeyMessageItems<'a> {
50    fn new(
51        pre_key_id: Option<PreKeyId>,
52        signed_pre_key_id: SignedPreKeyId,
53        base_key: PublicKey,
54        pending_kyber_pre_key: Option<&'a session_structure::PendingKyberPreKey>,
55        timestamp: SystemTime,
56    ) -> Self {
57        let (kyber_pre_key_id, kyber_ciphertext) = pending_kyber_pre_key
58            .map(|pending| (pending.pre_key_id.into(), pending.ciphertext.as_slice()))
59            .unzip();
60        Self {
61            pre_key_id,
62            signed_pre_key_id,
63            base_key,
64            kyber_pre_key_id,
65            kyber_ciphertext,
66            timestamp,
67        }
68    }
69
70    pub(crate) fn pre_key_id(&self) -> Option<PreKeyId> {
71        self.pre_key_id
72    }
73
74    pub(crate) fn signed_pre_key_id(&self) -> SignedPreKeyId {
75        self.signed_pre_key_id
76    }
77
78    pub(crate) fn base_key(&self) -> &PublicKey {
79        &self.base_key
80    }
81
82    pub(crate) fn kyber_pre_key_id(&self) -> Option<KyberPreKeyId> {
83        self.kyber_pre_key_id
84    }
85
86    pub(crate) fn kyber_ciphertext(&self) -> Option<&'a [u8]> {
87        self.kyber_ciphertext
88    }
89
90    pub(crate) fn timestamp(&self) -> SystemTime {
91        self.timestamp
92    }
93}
94
95bitflags! {
96    /// Specifies which criteria make a session "usable" beyond simply having a present sender
97    /// chain.
98    ///
99    /// These requirements are conjunctive, i.e. specifying `NotStale | EstablishedWithPqxdh` means
100    /// the session must be neither stale nor established with X3DH.
101    ///
102    /// This struct is generated using the `bitflags` crate; the [`Flags`](::bitflags::Flags) trait
103    /// provides most of its API surface. It can also use "classic" C bitflag syntax, with `|` for
104    /// union and `&` for intersection (as shown above).
105    #[repr(transparent)]
106    #[derive(Clone, Copy, PartialEq, Eq)]
107    pub struct SessionUsabilityRequirements : u32 {
108        /// Requires that a session not be stale.
109        ///
110        /// A non-stale session is one of the following:
111        /// - "incoming", i.e. started by the peer
112        /// - "acknowledged", i.e. started locally but received a response
113        /// - no more than a few weeks old (the exact time is chosen by libsignal)
114        const NotStale = 1 << 0;
115        /// Requires that a session was established using PQXDH (or newer) rather than X3DH/X4DH.
116        ///
117        /// This includes unacknowledged sessions that are using PQXDH, since if they get a
118        /// response, the peer is confirmed to be using PQXDH as well.
119        const EstablishedWithPqxdh = 1 << 1;
120        /// Requires that a session is using SPQR.
121        ///
122        /// **Warning:** This allows unacknowledged sessions that include SPQR in their PreKey
123        /// messages. If the peer downgrades the session (by discarding the SPQR information) and
124        /// the local client allows it, a session that is previously considered "usable" can become
125        /// "not usable" upon receiving a response. Therefore, this should not be used to determine
126        /// whether a session is usable unless future downgrades will also be rejected.
127        const Spqr = 1 << 2;
128    }
129}
130
131#[derive(Clone, Debug)]
132pub(crate) struct SessionState {
133    session: SessionStructure,
134}
135
136impl SessionState {
137    pub(crate) fn from_session_structure(session: SessionStructure) -> Self {
138        Self { session }
139    }
140
141    pub(crate) fn new(
142        version: u8,
143        our_identity: &IdentityKey,
144        their_identity: &IdentityKey,
145        root_key: &RootKey,
146        alice_base_key: &PublicKey,
147        pq_ratchet_state: spqr::SerializedState,
148    ) -> Self {
149        Self {
150            session: SessionStructure {
151                session_version: version as u32,
152                local_identity_public: our_identity.public_key().serialize().into_vec(),
153                remote_identity_public: their_identity.serialize().into_vec(),
154                root_key: root_key.key().to_vec(),
155                previous_counter: 0,
156                sender_chain: None,
157                receiver_chains: vec![],
158                pending_pre_key: None,
159                pending_kyber_pre_key: None,
160                remote_registration_id: 0,
161                local_registration_id: 0,
162                alice_base_key: alice_base_key.serialize().into_vec(),
163                pq_ratchet_state,
164            },
165        }
166    }
167
168    pub(crate) fn alice_base_key(&self) -> &[u8] {
169        // Check the length before returning?
170        &self.session.alice_base_key
171    }
172
173    pub(crate) fn session_version(&self) -> Result<u32, InvalidSessionError> {
174        match self.session.session_version {
175            0 => Ok(2),
176            v => Ok(v),
177        }
178    }
179
180    pub(crate) fn remote_identity_key(&self) -> Result<Option<IdentityKey>, InvalidSessionError> {
181        match self.session.remote_identity_public.len() {
182            0 => Ok(None),
183            _ => Ok(Some(
184                IdentityKey::decode(&self.session.remote_identity_public)
185                    .map_err(|_| InvalidSessionError("invalid remote identity key"))?,
186            )),
187        }
188    }
189
190    pub(crate) fn remote_identity_key_bytes(&self) -> Result<Option<Vec<u8>>, InvalidSessionError> {
191        Ok(self.remote_identity_key()?.map(|k| k.serialize().to_vec()))
192    }
193
194    pub(crate) fn local_identity_key(&self) -> Result<IdentityKey, InvalidSessionError> {
195        IdentityKey::decode(&self.session.local_identity_public)
196            .map_err(|_| InvalidSessionError("invalid local identity key"))
197    }
198
199    pub(crate) fn local_identity_key_bytes(&self) -> Result<Vec<u8>, InvalidSessionError> {
200        Ok(self.local_identity_key()?.serialize().to_vec())
201    }
202
203    pub(crate) fn session_with_self(&self) -> Result<bool, InvalidSessionError> {
204        if let Some(remote_id) = self.remote_identity_key_bytes()? {
205            let local_id = self.local_identity_key_bytes()?;
206            return Ok(remote_id == local_id);
207        }
208
209        // If remote ID is not set then we can't be sure but treat as non-self
210        Ok(false)
211    }
212
213    pub(crate) fn previous_counter(&self) -> u32 {
214        self.session.previous_counter
215    }
216
217    pub(crate) fn set_previous_counter(&mut self, ctr: u32) {
218        self.session.previous_counter = ctr;
219    }
220
221    pub(crate) fn root_key(&self) -> Result<RootKey, InvalidSessionError> {
222        let root_key_bytes = self.session.root_key[..]
223            .try_into()
224            .map_err(|_| InvalidSessionError("invalid root key"))?;
225        Ok(RootKey::new(root_key_bytes))
226    }
227
228    pub(crate) fn set_root_key(&mut self, root_key: &RootKey) {
229        self.session.root_key = root_key.key().to_vec();
230    }
231
232    pub(crate) fn sender_ratchet_key(&self) -> Result<PublicKey, InvalidSessionError> {
233        match self.session.sender_chain {
234            None => Err(InvalidSessionError("missing sender chain")),
235            Some(ref c) => PublicKey::deserialize(&c.sender_ratchet_key)
236                .map_err(|_| InvalidSessionError("invalid sender chain ratchet key")),
237        }
238    }
239
240    pub(crate) fn sender_ratchet_key_for_logging(&self) -> Result<String, InvalidSessionError> {
241        Ok(hex::encode(self.sender_ratchet_key()?.public_key_bytes()))
242    }
243
244    pub(crate) fn sender_ratchet_private_key(&self) -> Result<PrivateKey, InvalidSessionError> {
245        match self.session.sender_chain {
246            None => Err(InvalidSessionError("missing sender chain")),
247            Some(ref c) => PrivateKey::deserialize(&c.sender_ratchet_key_private)
248                .map_err(|_| InvalidSessionError("invalid sender chain private ratchet key")),
249        }
250    }
251
252    pub fn has_usable_sender_chain(
253        &self,
254        now: SystemTime,
255        requirements: SessionUsabilityRequirements,
256    ) -> Result<bool, InvalidSessionError> {
257        if self.session.sender_chain.is_none() {
258            return Ok(false);
259        }
260        if requirements.contains(SessionUsabilityRequirements::NotStale) {
261            if let Some(pending_pre_key) = &self.session.pending_pre_key {
262                let creation_timestamp =
263                    SystemTime::UNIX_EPOCH + Duration::from_secs(pending_pre_key.timestamp);
264                if creation_timestamp + consts::MAX_UNACKNOWLEDGED_SESSION_AGE < now {
265                    return Ok(false);
266                }
267            }
268        }
269        #[allow(clippy::collapsible_if)]
270        if requirements.contains(SessionUsabilityRequirements::EstablishedWithPqxdh) {
271            if self.session_version()? <= CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION.into() {
272                return Ok(false);
273            }
274        }
275        #[allow(clippy::collapsible_if)]
276        if requirements.contains(SessionUsabilityRequirements::Spqr) {
277            if self.pq_ratchet_state().is_empty() {
278                return Ok(false);
279            }
280        }
281        Ok(true)
282    }
283
284    pub(crate) fn all_receiver_chain_logging_info(&self) -> Vec<(Vec<u8>, Option<u32>)> {
285        let mut results = vec![];
286        for chain in self.session.receiver_chains.iter() {
287            let sender_ratchet_public = chain.sender_ratchet_key.clone();
288
289            let chain_key_idx = chain.chain_key.as_ref().map(|chain_key| chain_key.index);
290
291            results.push((sender_ratchet_public, chain_key_idx))
292        }
293        results
294    }
295
296    pub(crate) fn get_receiver_chain(
297        &self,
298        sender: &PublicKey,
299    ) -> Result<Option<(session_structure::Chain, usize)>, InvalidSessionError> {
300        for (idx, chain) in self.session.receiver_chains.iter().enumerate() {
301            // If we compared bytes directly it would be faster, but may miss non-canonical points.
302            // It's unclear if supporting such points is desirable.
303            let chain_ratchet_key = PublicKey::deserialize(&chain.sender_ratchet_key)
304                .map_err(|_| InvalidSessionError("invalid receiver chain ratchet key"))?;
305
306            if &chain_ratchet_key == sender {
307                return Ok(Some((chain.clone(), idx)));
308            }
309        }
310
311        Ok(None)
312    }
313
314    pub(crate) fn get_receiver_chain_key(
315        &self,
316        sender: &PublicKey,
317    ) -> Result<Option<ChainKey>, InvalidSessionError> {
318        match self.get_receiver_chain(sender)? {
319            None => Ok(None),
320            Some((chain, _)) => match chain.chain_key {
321                None => Err(InvalidSessionError("missing receiver chain key")),
322                Some(c) => {
323                    let chain_key_bytes = c.key[..]
324                        .try_into()
325                        .map_err(|_| InvalidSessionError("invalid receiver chain key"))?;
326                    Ok(Some(ChainKey::new(chain_key_bytes, c.index)))
327                }
328            },
329        }
330    }
331
332    pub(crate) fn add_receiver_chain(&mut self, sender: &PublicKey, chain_key: &ChainKey) {
333        let chain_key = session_structure::chain::ChainKey {
334            index: chain_key.index(),
335            key: chain_key.key().to_vec(),
336        };
337
338        let chain = session_structure::Chain {
339            sender_ratchet_key: sender.serialize().to_vec(),
340            sender_ratchet_key_private: vec![],
341            chain_key: Some(chain_key),
342            message_keys: vec![],
343        };
344
345        self.session.receiver_chains.push(chain);
346
347        if self.session.receiver_chains.len() > consts::MAX_RECEIVER_CHAINS {
348            log::info!(
349                "Trimming excessive receiver_chain for session with base key {}, chain count: {}",
350                self.sender_ratchet_key_for_logging()
351                    .unwrap_or_else(|e| format!("<error: {}>", e.0)),
352                self.session.receiver_chains.len()
353            );
354            self.session.receiver_chains.remove(0);
355        }
356    }
357
358    pub(crate) fn with_receiver_chain(mut self, sender: &PublicKey, chain_key: &ChainKey) -> Self {
359        self.add_receiver_chain(sender, chain_key);
360        self
361    }
362
363    pub(crate) fn set_sender_chain(&mut self, sender: &KeyPair, next_chain_key: &ChainKey) {
364        let chain_key = session_structure::chain::ChainKey {
365            index: next_chain_key.index(),
366            key: next_chain_key.key().to_vec(),
367        };
368
369        let new_chain = session_structure::Chain {
370            sender_ratchet_key: sender.public_key.serialize().to_vec(),
371            sender_ratchet_key_private: sender.private_key.serialize().to_vec(),
372            chain_key: Some(chain_key),
373            message_keys: vec![],
374        };
375
376        self.session.sender_chain = Some(new_chain);
377    }
378
379    pub(crate) fn with_sender_chain(mut self, sender: &KeyPair, next_chain_key: &ChainKey) -> Self {
380        self.set_sender_chain(sender, next_chain_key);
381        self
382    }
383
384    pub(crate) fn get_sender_chain_key(&self) -> Result<ChainKey, InvalidSessionError> {
385        let sender_chain = self
386            .session
387            .sender_chain
388            .as_ref()
389            .ok_or(InvalidSessionError("missing sender chain"))?;
390
391        let chain_key = sender_chain
392            .chain_key
393            .as_ref()
394            .ok_or(InvalidSessionError("missing sender chain key"))?;
395
396        let chain_key_bytes = chain_key.key[..]
397            .try_into()
398            .map_err(|_| InvalidSessionError("invalid sender chain key"))?;
399
400        Ok(ChainKey::new(chain_key_bytes, chain_key.index))
401    }
402
403    pub(crate) fn get_sender_chain_key_bytes(&self) -> Result<Vec<u8>, InvalidSessionError> {
404        Ok(self.get_sender_chain_key()?.key().to_vec())
405    }
406
407    pub(crate) fn set_sender_chain_key(&mut self, next_chain_key: &ChainKey) {
408        let chain_key = session_structure::chain::ChainKey {
409            index: next_chain_key.index(),
410            key: next_chain_key.key().to_vec(),
411        };
412
413        // Is it actually valid to call this function with sender_chain == None?
414
415        let new_chain = match self.session.sender_chain.take() {
416            None => session_structure::Chain {
417                sender_ratchet_key: vec![],
418                sender_ratchet_key_private: vec![],
419                chain_key: Some(chain_key),
420                message_keys: vec![],
421            },
422            Some(mut c) => {
423                c.chain_key = Some(chain_key);
424                c
425            }
426        };
427
428        self.session.sender_chain = Some(new_chain);
429    }
430
431    pub(crate) fn get_message_keys(
432        &mut self,
433        sender: &PublicKey,
434        counter: u32,
435    ) -> Result<Option<MessageKeyGenerator>, InvalidSessionError> {
436        if let Some(mut chain_and_index) = self.get_receiver_chain(sender)? {
437            let message_key_idx = chain_and_index
438                .0
439                .message_keys
440                .iter()
441                .position(|m| m.index == counter);
442
443            if let Some(position) = message_key_idx {
444                let message_key = chain_and_index.0.message_keys.remove(position);
445                let keys =
446                    MessageKeyGenerator::from_pb(message_key).map_err(InvalidSessionError)?;
447
448                // Update with message key removed
449                self.session.receiver_chains[chain_and_index.1] = chain_and_index.0;
450                return Ok(Some(keys));
451            }
452        }
453
454        Ok(None)
455    }
456
457    pub(crate) fn set_message_keys(
458        &mut self,
459        sender: &PublicKey,
460        message_keys: MessageKeyGenerator,
461    ) -> Result<(), InvalidSessionError> {
462        let chain_and_index = self
463            .get_receiver_chain(sender)?
464            .expect("called set_message_keys for a non-existent chain");
465        let mut updated_chain = chain_and_index.0;
466        updated_chain.message_keys.insert(0, message_keys.into_pb());
467
468        if updated_chain.message_keys.len() > consts::MAX_MESSAGE_KEYS {
469            updated_chain.message_keys.pop();
470        }
471
472        self.session.receiver_chains[chain_and_index.1] = updated_chain;
473
474        Ok(())
475    }
476
477    pub(crate) fn set_receiver_chain_key(
478        &mut self,
479        sender: &PublicKey,
480        chain_key: &ChainKey,
481    ) -> Result<(), InvalidSessionError> {
482        let chain_and_index = self
483            .get_receiver_chain(sender)?
484            .expect("called set_receiver_chain_key for a non-existent chain");
485        let mut updated_chain = chain_and_index.0;
486        updated_chain.chain_key = Some(session_structure::chain::ChainKey {
487            index: chain_key.index(),
488            key: chain_key.key().to_vec(),
489        });
490
491        self.session.receiver_chains[chain_and_index.1] = updated_chain;
492
493        Ok(())
494    }
495
496    pub(crate) fn set_unacknowledged_pre_key_message(
497        &mut self,
498        pre_key_id: Option<PreKeyId>,
499        signed_ec_pre_key_id: SignedPreKeyId,
500        base_key: &PublicKey,
501        now: SystemTime,
502    ) {
503        let signed_ec_pre_key_id: u32 = signed_ec_pre_key_id.into();
504        let pending = session_structure::PendingPreKey {
505            pre_key_id: pre_key_id.map(PreKeyId::into),
506            signed_pre_key_id: signed_ec_pre_key_id as i32,
507            base_key: base_key.serialize().to_vec(),
508            timestamp: now
509                .duration_since(SystemTime::UNIX_EPOCH)
510                .unwrap_or_default()
511                .as_secs(),
512        };
513        self.session.pending_pre_key = Some(pending);
514    }
515
516    pub(crate) fn set_kyber_ciphertext(&mut self, ciphertext: kem::SerializedCiphertext) {
517        let pending = session_structure::PendingKyberPreKey {
518            pre_key_id: u32::MAX, // has to be set to the actual value separately
519            ciphertext: ciphertext.into_vec(),
520        };
521        self.session.pending_kyber_pre_key = Some(pending);
522    }
523
524    pub(crate) fn set_unacknowledged_kyber_pre_key_id(
525        &mut self,
526        signed_kyber_pre_key_id: KyberPreKeyId,
527    ) {
528        let pending = self
529            .session
530            .pending_kyber_pre_key
531            .as_mut()
532            .expect("must have been set if kyber pre key is present");
533        pending.pre_key_id = signed_kyber_pre_key_id.into();
534    }
535
536    pub(crate) fn unacknowledged_pre_key_message_items(
537        &self,
538    ) -> Result<Option<UnacknowledgedPreKeyMessageItems<'_>>, InvalidSessionError> {
539        if let Some(ref pending_pre_key) = self.session.pending_pre_key {
540            Ok(Some(UnacknowledgedPreKeyMessageItems::new(
541                pending_pre_key.pre_key_id.map(Into::into),
542                (pending_pre_key.signed_pre_key_id as u32).into(),
543                PublicKey::deserialize(&pending_pre_key.base_key)
544                    .map_err(|_| InvalidSessionError("invalid pending PreKey message base key"))?,
545                self.session.pending_kyber_pre_key.as_ref(),
546                SystemTime::UNIX_EPOCH + Duration::from_secs(pending_pre_key.timestamp),
547            )))
548        } else {
549            Ok(None)
550        }
551    }
552
553    pub(crate) fn clear_unacknowledged_pre_key_message(&mut self) {
554        // Explicitly destructuring the SessionStructure in case there are new
555        // pending fields that need to be cleared.
556        let SessionStructure {
557            session_version: _session_version,
558            local_identity_public: _local_identity_public,
559            remote_identity_public: _remote_identity_public,
560            root_key: _root_key,
561            previous_counter: _previous_counter,
562            sender_chain: _sender_chain,
563            receiver_chains: _receiver_chains,
564            pending_pre_key: _pending_pre_key,
565            pending_kyber_pre_key: _pending_kyber_pre_key,
566            remote_registration_id: _remote_registration_id,
567            local_registration_id: _local_registration_id,
568            alice_base_key: _alice_base_key,
569            pq_ratchet_state: _pq_ratchet_state,
570        } = &self.session;
571        // ####### IMPORTANT #######
572        // Don't forget to clean up new pending fields.
573        // ####### IMPORTANT #######
574        self.session.pending_pre_key = None;
575        self.session.pending_kyber_pre_key = None;
576    }
577
578    pub(crate) fn set_remote_registration_id(&mut self, registration_id: u32) {
579        self.session.remote_registration_id = registration_id;
580    }
581
582    pub(crate) fn remote_registration_id(&self) -> u32 {
583        self.session.remote_registration_id
584    }
585
586    pub(crate) fn set_local_registration_id(&mut self, registration_id: u32) {
587        self.session.local_registration_id = registration_id;
588    }
589
590    pub(crate) fn local_registration_id(&self) -> u32 {
591        self.session.local_registration_id
592    }
593
594    pub(crate) fn get_kyber_ciphertext(&self) -> Option<&Vec<u8>> {
595        self.session
596            .pending_kyber_pre_key
597            .as_ref()
598            .map(|pending| &pending.ciphertext)
599    }
600
601    pub(crate) fn pq_ratchet_recv(
602        &mut self,
603        msg: &spqr::SerializedMessage,
604    ) -> Result<spqr::MessageKey, spqr::Error> {
605        let spqr::Recv { state, key } = spqr::recv(&self.session.pq_ratchet_state, msg)?;
606        self.session.pq_ratchet_state = state;
607        Ok(key)
608    }
609
610    pub(crate) fn pq_ratchet_send<R: Rng + CryptoRng>(
611        &mut self,
612        csprng: &mut R,
613    ) -> Result<(spqr::SerializedMessage, spqr::MessageKey), spqr::Error> {
614        let spqr::Send { state, key, msg } = spqr::send(&self.session.pq_ratchet_state, csprng)?;
615        self.session.pq_ratchet_state = state;
616        Ok((msg, key))
617    }
618
619    pub(crate) fn pq_ratchet_state(&self) -> &spqr::SerializedState {
620        &self.session.pq_ratchet_state
621    }
622}
623
624impl From<SessionStructure> for SessionState {
625    fn from(value: SessionStructure) -> SessionState {
626        SessionState::from_session_structure(value)
627    }
628}
629
630impl From<SessionState> for SessionStructure {
631    fn from(value: SessionState) -> SessionStructure {
632        value.session
633    }
634}
635
636impl From<&SessionState> for SessionStructure {
637    fn from(value: &SessionState) -> SessionStructure {
638        value.session.clone()
639    }
640}
641
642#[derive(Clone)]
643pub struct SessionRecord {
644    current_session: Option<SessionState>,
645    previous_sessions: Vec<Vec<u8>>,
646}
647
648impl SessionRecord {
649    pub fn new_fresh() -> Self {
650        Self {
651            current_session: None,
652            previous_sessions: Vec::new(),
653        }
654    }
655
656    pub(crate) fn new(state: SessionState) -> Self {
657        Self {
658            current_session: Some(state),
659            previous_sessions: Vec::new(),
660        }
661    }
662
663    pub fn deserialize(bytes: &[u8]) -> Result<Self, SignalProtocolError> {
664        let record = RecordStructure::decode(bytes)
665            .map_err(|_| InvalidSessionError("failed to decode session record protobuf"))?;
666
667        Ok(Self {
668            current_session: record.current_session.map(|s| s.into()),
669            previous_sessions: record.previous_sessions,
670        })
671    }
672
673    /// If there's a session with a matching version and `alice_base_key`, ensures that it is the
674    /// current session, promoting if necessary.
675    ///
676    /// Returns `Ok(true)` if such a session was found, `Ok(false)` if not, and
677    /// `Err(InvalidSessionError)` if an invalid session was found during the search (whether
678    /// current or not).
679    pub(crate) fn promote_matching_session(
680        &mut self,
681        version: u32,
682        alice_base_key: &[u8],
683    ) -> Result<bool, InvalidSessionError> {
684        if let Some(current_session) = &self.current_session {
685            if current_session.session_version()? == version
686                && alice_base_key
687                    .ct_eq(current_session.alice_base_key())
688                    .into()
689            {
690                return Ok(true);
691            }
692        }
693
694        let mut session_to_promote = None;
695        for (i, previous) in self.previous_session_states().enumerate() {
696            let previous = previous?;
697            if previous.session_version()? == version
698                && alice_base_key.ct_eq(previous.alice_base_key()).into()
699            {
700                session_to_promote = Some((i, previous));
701                break;
702            }
703        }
704
705        if let Some((i, state)) = session_to_promote {
706            self.promote_old_session(i, state);
707            return Ok(true);
708        }
709
710        Ok(false)
711    }
712
713    pub(crate) fn session_state(&self) -> Option<&SessionState> {
714        self.current_session.as_ref()
715    }
716
717    pub(crate) fn session_state_mut(&mut self) -> Option<&mut SessionState> {
718        self.current_session.as_mut()
719    }
720
721    pub(crate) fn set_session_state(&mut self, session: SessionState) {
722        self.current_session = Some(session);
723    }
724
725    pub(crate) fn previous_session_states(
726        &self,
727    ) -> impl ExactSizeIterator<Item = Result<SessionState, InvalidSessionError>> + '_ {
728        self.previous_sessions.iter().map(|bytes| {
729            Ok(SessionStructure::decode(&bytes[..])
730                .map_err(|_| InvalidSessionError("failed to decode previous session protobuf"))?
731                .into())
732        })
733    }
734
735    pub(crate) fn promote_old_session(
736        &mut self,
737        old_session: usize,
738        updated_session: SessionState,
739    ) {
740        self.previous_sessions.remove(old_session);
741        self.promote_state(updated_session)
742    }
743
744    pub(crate) fn promote_state(&mut self, new_state: SessionState) {
745        self.archive_current_state_inner();
746        self.current_session = Some(new_state);
747    }
748
749    // A non-fallible version of archive_current_state.
750    //
751    // Returns `true` if there was a session to archive, `false` if not.
752    fn archive_current_state_inner(&mut self) -> bool {
753        if let Some(mut current_session) = self.current_session.take() {
754            if self.previous_sessions.len() >= consts::ARCHIVED_STATES_MAX_LENGTH {
755                self.previous_sessions.pop();
756            }
757            current_session.clear_unacknowledged_pre_key_message();
758            self.previous_sessions
759                .insert(0, current_session.session.encode_to_vec());
760            true
761        } else {
762            false
763        }
764    }
765
766    pub fn archive_current_state(&mut self) -> Result<(), SignalProtocolError> {
767        if !self.archive_current_state_inner() {
768            log::info!("Skipping archive, current session state is fresh");
769        }
770        Ok(())
771    }
772
773    pub fn serialize(&self) -> Result<Vec<u8>, SignalProtocolError> {
774        let record = RecordStructure {
775            current_session: self.current_session.as_ref().map(|s| s.into()),
776            previous_sessions: self.previous_sessions.clone(),
777        };
778        Ok(record.encode_to_vec())
779    }
780
781    pub fn current_pq_state(&self) -> Option<&spqr::SerializedState> {
782        self.current_session.as_ref().map(|s| s.pq_ratchet_state())
783    }
784
785    pub fn remote_registration_id(&self) -> Result<u32, SignalProtocolError> {
786        Ok(self
787            .session_state()
788            .ok_or_else(|| {
789                SignalProtocolError::InvalidState(
790                    "remote_registration_id",
791                    "No current session".into(),
792                )
793            })?
794            .remote_registration_id())
795    }
796
797    pub fn local_registration_id(&self) -> Result<u32, SignalProtocolError> {
798        Ok(self
799            .session_state()
800            .ok_or_else(|| {
801                SignalProtocolError::InvalidState(
802                    "local_registration_id",
803                    "No current session".into(),
804                )
805            })?
806            .local_registration_id())
807    }
808
809    pub fn session_version(&self) -> Result<u32, SignalProtocolError> {
810        Ok(self
811            .session_state()
812            .ok_or_else(|| {
813                SignalProtocolError::InvalidState("session_version", "No current session".into())
814            })?
815            .session_version()?)
816    }
817
818    pub fn local_identity_key_bytes(&self) -> Result<Vec<u8>, SignalProtocolError> {
819        Ok(self
820            .session_state()
821            .ok_or_else(|| {
822                SignalProtocolError::InvalidState(
823                    "local_identity_key_bytes",
824                    "No current session".into(),
825                )
826            })?
827            .local_identity_key_bytes()?)
828    }
829
830    pub fn remote_identity_key_bytes(&self) -> Result<Option<Vec<u8>>, SignalProtocolError> {
831        Ok(self
832            .session_state()
833            .ok_or_else(|| {
834                SignalProtocolError::InvalidState(
835                    "remote_identity_key_bytes",
836                    "No current session".into(),
837                )
838            })?
839            .remote_identity_key_bytes()?)
840    }
841
842    pub fn has_usable_sender_chain(
843        &self,
844        now: SystemTime,
845        requirements: SessionUsabilityRequirements,
846    ) -> Result<bool, SignalProtocolError> {
847        match &self.current_session {
848            Some(session) => Ok(session.has_usable_sender_chain(now, requirements)?),
849            None => Ok(false),
850        }
851    }
852
853    pub fn alice_base_key(&self) -> Result<&[u8], SignalProtocolError> {
854        Ok(self
855            .session_state()
856            .ok_or_else(|| {
857                SignalProtocolError::InvalidState("alice_base_key", "No current session".into())
858            })?
859            .alice_base_key())
860    }
861
862    pub fn get_receiver_chain_key_bytes(
863        &self,
864        sender: &PublicKey,
865    ) -> Result<Option<Box<[u8]>>, SignalProtocolError> {
866        Ok(self
867            .session_state()
868            .ok_or_else(|| {
869                SignalProtocolError::InvalidState(
870                    "get_receiver_chain_key",
871                    "No current session".into(),
872                )
873            })?
874            .get_receiver_chain_key(sender)?
875            .map(|chain| chain.key()[..].into()))
876    }
877
878    pub fn get_sender_chain_key_bytes(&self) -> Result<Vec<u8>, SignalProtocolError> {
879        Ok(self
880            .session_state()
881            .ok_or_else(|| {
882                SignalProtocolError::InvalidState(
883                    "get_sender_chain_key_bytes",
884                    "No current session".into(),
885                )
886            })?
887            .get_sender_chain_key_bytes()?)
888    }
889
890    pub fn current_ratchet_key_matches(
891        &self,
892        key: &PublicKey,
893    ) -> Result<bool, SignalProtocolError> {
894        match &self.current_session {
895            Some(session) => Ok(&session.sender_ratchet_key()? == key),
896            None => Ok(false),
897        }
898    }
899
900    pub fn get_kyber_ciphertext(&self) -> Result<Option<&Vec<u8>>, SignalProtocolError> {
901        Ok(self
902            .session_state()
903            .ok_or_else(|| {
904                SignalProtocolError::InvalidState(
905                    "get_kyber_ciphertext",
906                    "No current session".into(),
907                )
908            })?
909            .get_kyber_ciphertext())
910    }
911}