// libsignal_protocol/state/session.rs
//
// Copyright 2020-2022 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//

6use std::result::Result;
7use std::time::{Duration, SystemTime};
8
9use prost::Message;
10use rand::{CryptoRng, Rng};
11use subtle::ConstantTimeEq;
12
13use crate::proto::storage::{session_structure, RecordStructure, SessionStructure};
14use crate::ratchet::{ChainKey, MessageKeyGenerator, RootKey};
15use crate::state::{KyberPreKeyId, PreKeyId, SignedPreKeyId};
16use crate::{consts, kem, IdentityKey, KeyPair, PrivateKey, PublicKey, SignalProtocolError};
17
/// Error raised when a stored session turns out to be malformed.
///
/// Kept as its own type (rather than reusing a decoding error) so that deserialization
/// failures cannot be propagated by accident; it carries only a static description.
#[derive(Debug)]
pub(crate) struct InvalidSessionError(&'static str);

impl std::fmt::Display for InvalidSessionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` is what `str`'s own `Display` uses, so width/precision flags still apply.
        let InvalidSessionError(description) = self;
        f.pad(description)
    }
}
27
28impl From<InvalidSessionError> for SignalProtocolError {
29    fn from(e: InvalidSessionError) -> Self {
30        Self::InvalidSessionStructure(e.0)
31    }
32}
33
34#[derive(Debug, Clone)]
35pub(crate) struct UnacknowledgedPreKeyMessageItems<'a> {
36    pre_key_id: Option<PreKeyId>,
37    signed_pre_key_id: SignedPreKeyId,
38    base_key: PublicKey,
39    // Although we require PQXDH for all new sessions now,
40    // we may in theory have an existing X3DH unacknowledged session,
41    // so we leave these optional for now.
42    kyber_pre_key_id: Option<KyberPreKeyId>,
43    kyber_ciphertext: Option<&'a [u8]>,
44    timestamp: SystemTime,
45}
46
47impl<'a> UnacknowledgedPreKeyMessageItems<'a> {
48    fn new(
49        pre_key_id: Option<PreKeyId>,
50        signed_pre_key_id: SignedPreKeyId,
51        base_key: PublicKey,
52        pending_kyber_pre_key: Option<&'a session_structure::PendingKyberPreKey>,
53        timestamp: SystemTime,
54    ) -> Self {
55        let (kyber_pre_key_id, kyber_ciphertext) = pending_kyber_pre_key
56            .map(|pending| (pending.pre_key_id.into(), pending.ciphertext.as_slice()))
57            .unzip();
58        Self {
59            pre_key_id,
60            signed_pre_key_id,
61            base_key,
62            kyber_pre_key_id,
63            kyber_ciphertext,
64            timestamp,
65        }
66    }
67
68    pub(crate) fn pre_key_id(&self) -> Option<PreKeyId> {
69        self.pre_key_id
70    }
71
72    pub(crate) fn signed_pre_key_id(&self) -> SignedPreKeyId {
73        self.signed_pre_key_id
74    }
75
76    pub(crate) fn base_key(&self) -> &PublicKey {
77        &self.base_key
78    }
79
80    pub(crate) fn kyber_pre_key_id(&self) -> Option<KyberPreKeyId> {
81        self.kyber_pre_key_id
82    }
83
84    pub(crate) fn kyber_ciphertext(&self) -> Option<&'a [u8]> {
85        self.kyber_ciphertext
86    }
87
88    pub(crate) fn timestamp(&self) -> SystemTime {
89        self.timestamp
90    }
91}
92
93#[derive(Clone, Debug)]
94pub(crate) struct SessionState {
95    session: SessionStructure,
96}
97
98impl SessionState {
99    pub(crate) fn from_session_structure(session: SessionStructure) -> Self {
100        Self { session }
101    }
102
103    pub(crate) fn new(
104        version: u8,
105        our_identity: &IdentityKey,
106        their_identity: &IdentityKey,
107        root_key: &RootKey,
108        alice_base_key: &PublicKey,
109        pq_ratchet_state: spqr::SerializedState,
110    ) -> Self {
111        Self {
112            session: SessionStructure {
113                session_version: version as u32,
114                local_identity_public: our_identity.public_key().serialize().into_vec(),
115                remote_identity_public: their_identity.serialize().into_vec(),
116                root_key: root_key.key().to_vec(),
117                previous_counter: 0,
118                sender_chain: None,
119                receiver_chains: vec![],
120                pending_pre_key: None,
121                pending_kyber_pre_key: None,
122                remote_registration_id: 0,
123                local_registration_id: 0,
124                alice_base_key: alice_base_key.serialize().into_vec(),
125                pq_ratchet_state,
126            },
127        }
128    }
129
130    pub(crate) fn alice_base_key(&self) -> &[u8] {
131        // Check the length before returning?
132        &self.session.alice_base_key
133    }
134
135    pub(crate) fn session_version(&self) -> Result<u32, InvalidSessionError> {
136        match self.session.session_version {
137            0 => Ok(2),
138            v => Ok(v),
139        }
140    }
141
142    pub(crate) fn remote_identity_key(&self) -> Result<Option<IdentityKey>, InvalidSessionError> {
143        match self.session.remote_identity_public.len() {
144            0 => Ok(None),
145            _ => Ok(Some(
146                IdentityKey::decode(&self.session.remote_identity_public)
147                    .map_err(|_| InvalidSessionError("invalid remote identity key"))?,
148            )),
149        }
150    }
151
152    pub(crate) fn remote_identity_key_bytes(&self) -> Result<Option<Vec<u8>>, InvalidSessionError> {
153        Ok(self.remote_identity_key()?.map(|k| k.serialize().to_vec()))
154    }
155
156    pub(crate) fn local_identity_key(&self) -> Result<IdentityKey, InvalidSessionError> {
157        IdentityKey::decode(&self.session.local_identity_public)
158            .map_err(|_| InvalidSessionError("invalid local identity key"))
159    }
160
161    pub(crate) fn local_identity_key_bytes(&self) -> Result<Vec<u8>, InvalidSessionError> {
162        Ok(self.local_identity_key()?.serialize().to_vec())
163    }
164
165    pub(crate) fn session_with_self(&self) -> Result<bool, InvalidSessionError> {
166        if let Some(remote_id) = self.remote_identity_key_bytes()? {
167            let local_id = self.local_identity_key_bytes()?;
168            return Ok(remote_id == local_id);
169        }
170
171        // If remote ID is not set then we can't be sure but treat as non-self
172        Ok(false)
173    }
174
175    pub(crate) fn previous_counter(&self) -> u32 {
176        self.session.previous_counter
177    }
178
179    pub(crate) fn set_previous_counter(&mut self, ctr: u32) {
180        self.session.previous_counter = ctr;
181    }
182
183    pub(crate) fn root_key(&self) -> Result<RootKey, InvalidSessionError> {
184        let root_key_bytes = self.session.root_key[..]
185            .try_into()
186            .map_err(|_| InvalidSessionError("invalid root key"))?;
187        Ok(RootKey::new(root_key_bytes))
188    }
189
190    pub(crate) fn set_root_key(&mut self, root_key: &RootKey) {
191        self.session.root_key = root_key.key().to_vec();
192    }
193
194    pub(crate) fn sender_ratchet_key(&self) -> Result<PublicKey, InvalidSessionError> {
195        match self.session.sender_chain {
196            None => Err(InvalidSessionError("missing sender chain")),
197            Some(ref c) => PublicKey::deserialize(&c.sender_ratchet_key)
198                .map_err(|_| InvalidSessionError("invalid sender chain ratchet key")),
199        }
200    }
201
202    pub(crate) fn sender_ratchet_key_for_logging(&self) -> Result<String, InvalidSessionError> {
203        Ok(hex::encode(self.sender_ratchet_key()?.public_key_bytes()))
204    }
205
206    pub(crate) fn sender_ratchet_private_key(&self) -> Result<PrivateKey, InvalidSessionError> {
207        match self.session.sender_chain {
208            None => Err(InvalidSessionError("missing sender chain")),
209            Some(ref c) => PrivateKey::deserialize(&c.sender_ratchet_key_private)
210                .map_err(|_| InvalidSessionError("invalid sender chain private ratchet key")),
211        }
212    }
213
214    pub fn has_usable_sender_chain(&self, now: SystemTime) -> Result<bool, InvalidSessionError> {
215        if self.session.sender_chain.is_none() {
216            return Ok(false);
217        }
218        if let Some(pending_pre_key) = &self.session.pending_pre_key {
219            let creation_timestamp =
220                SystemTime::UNIX_EPOCH + Duration::from_secs(pending_pre_key.timestamp);
221            if creation_timestamp + consts::MAX_UNACKNOWLEDGED_SESSION_AGE < now {
222                return Ok(false);
223            }
224        }
225        Ok(true)
226    }
227
228    pub(crate) fn all_receiver_chain_logging_info(&self) -> Vec<(Vec<u8>, Option<u32>)> {
229        let mut results = vec![];
230        for chain in self.session.receiver_chains.iter() {
231            let sender_ratchet_public = chain.sender_ratchet_key.clone();
232
233            let chain_key_idx = chain.chain_key.as_ref().map(|chain_key| chain_key.index);
234
235            results.push((sender_ratchet_public, chain_key_idx))
236        }
237        results
238    }
239
240    pub(crate) fn get_receiver_chain(
241        &self,
242        sender: &PublicKey,
243    ) -> Result<Option<(session_structure::Chain, usize)>, InvalidSessionError> {
244        for (idx, chain) in self.session.receiver_chains.iter().enumerate() {
245            // If we compared bytes directly it would be faster, but may miss non-canonical points.
246            // It's unclear if supporting such points is desirable.
247            let chain_ratchet_key = PublicKey::deserialize(&chain.sender_ratchet_key)
248                .map_err(|_| InvalidSessionError("invalid receiver chain ratchet key"))?;
249
250            if &chain_ratchet_key == sender {
251                return Ok(Some((chain.clone(), idx)));
252            }
253        }
254
255        Ok(None)
256    }
257
258    pub(crate) fn get_receiver_chain_key(
259        &self,
260        sender: &PublicKey,
261    ) -> Result<Option<ChainKey>, InvalidSessionError> {
262        match self.get_receiver_chain(sender)? {
263            None => Ok(None),
264            Some((chain, _)) => match chain.chain_key {
265                None => Err(InvalidSessionError("missing receiver chain key")),
266                Some(c) => {
267                    let chain_key_bytes = c.key[..]
268                        .try_into()
269                        .map_err(|_| InvalidSessionError("invalid receiver chain key"))?;
270                    Ok(Some(ChainKey::new(chain_key_bytes, c.index)))
271                }
272            },
273        }
274    }
275
276    pub(crate) fn add_receiver_chain(&mut self, sender: &PublicKey, chain_key: &ChainKey) {
277        let chain_key = session_structure::chain::ChainKey {
278            index: chain_key.index(),
279            key: chain_key.key().to_vec(),
280        };
281
282        let chain = session_structure::Chain {
283            sender_ratchet_key: sender.serialize().to_vec(),
284            sender_ratchet_key_private: vec![],
285            chain_key: Some(chain_key),
286            message_keys: vec![],
287        };
288
289        self.session.receiver_chains.push(chain);
290
291        if self.session.receiver_chains.len() > consts::MAX_RECEIVER_CHAINS {
292            log::info!(
293                "Trimming excessive receiver_chain for session with base key {}, chain count: {}",
294                self.sender_ratchet_key_for_logging()
295                    .unwrap_or_else(|e| format!("<error: {}>", e.0)),
296                self.session.receiver_chains.len()
297            );
298            self.session.receiver_chains.remove(0);
299        }
300    }
301
302    pub(crate) fn with_receiver_chain(mut self, sender: &PublicKey, chain_key: &ChainKey) -> Self {
303        self.add_receiver_chain(sender, chain_key);
304        self
305    }
306
307    pub(crate) fn set_sender_chain(&mut self, sender: &KeyPair, next_chain_key: &ChainKey) {
308        let chain_key = session_structure::chain::ChainKey {
309            index: next_chain_key.index(),
310            key: next_chain_key.key().to_vec(),
311        };
312
313        let new_chain = session_structure::Chain {
314            sender_ratchet_key: sender.public_key.serialize().to_vec(),
315            sender_ratchet_key_private: sender.private_key.serialize().to_vec(),
316            chain_key: Some(chain_key),
317            message_keys: vec![],
318        };
319
320        self.session.sender_chain = Some(new_chain);
321    }
322
323    pub(crate) fn with_sender_chain(mut self, sender: &KeyPair, next_chain_key: &ChainKey) -> Self {
324        self.set_sender_chain(sender, next_chain_key);
325        self
326    }
327
328    pub(crate) fn get_sender_chain_key(&self) -> Result<ChainKey, InvalidSessionError> {
329        let sender_chain = self
330            .session
331            .sender_chain
332            .as_ref()
333            .ok_or(InvalidSessionError("missing sender chain"))?;
334
335        let chain_key = sender_chain
336            .chain_key
337            .as_ref()
338            .ok_or(InvalidSessionError("missing sender chain key"))?;
339
340        let chain_key_bytes = chain_key.key[..]
341            .try_into()
342            .map_err(|_| InvalidSessionError("invalid sender chain key"))?;
343
344        Ok(ChainKey::new(chain_key_bytes, chain_key.index))
345    }
346
347    pub(crate) fn get_sender_chain_key_bytes(&self) -> Result<Vec<u8>, InvalidSessionError> {
348        Ok(self.get_sender_chain_key()?.key().to_vec())
349    }
350
351    pub(crate) fn set_sender_chain_key(&mut self, next_chain_key: &ChainKey) {
352        let chain_key = session_structure::chain::ChainKey {
353            index: next_chain_key.index(),
354            key: next_chain_key.key().to_vec(),
355        };
356
357        // Is it actually valid to call this function with sender_chain == None?
358
359        let new_chain = match self.session.sender_chain.take() {
360            None => session_structure::Chain {
361                sender_ratchet_key: vec![],
362                sender_ratchet_key_private: vec![],
363                chain_key: Some(chain_key),
364                message_keys: vec![],
365            },
366            Some(mut c) => {
367                c.chain_key = Some(chain_key);
368                c
369            }
370        };
371
372        self.session.sender_chain = Some(new_chain);
373    }
374
375    pub(crate) fn get_message_keys(
376        &mut self,
377        sender: &PublicKey,
378        counter: u32,
379    ) -> Result<Option<MessageKeyGenerator>, InvalidSessionError> {
380        if let Some(mut chain_and_index) = self.get_receiver_chain(sender)? {
381            let message_key_idx = chain_and_index
382                .0
383                .message_keys
384                .iter()
385                .position(|m| m.index == counter);
386
387            if let Some(position) = message_key_idx {
388                let message_key = chain_and_index.0.message_keys.remove(position);
389                let keys =
390                    MessageKeyGenerator::from_pb(message_key).map_err(InvalidSessionError)?;
391
392                // Update with message key removed
393                self.session.receiver_chains[chain_and_index.1] = chain_and_index.0;
394                return Ok(Some(keys));
395            }
396        }
397
398        Ok(None)
399    }
400
401    pub(crate) fn set_message_keys(
402        &mut self,
403        sender: &PublicKey,
404        message_keys: MessageKeyGenerator,
405    ) -> Result<(), InvalidSessionError> {
406        let chain_and_index = self
407            .get_receiver_chain(sender)?
408            .expect("called set_message_keys for a non-existent chain");
409        let mut updated_chain = chain_and_index.0;
410        updated_chain.message_keys.insert(0, message_keys.into_pb());
411
412        if updated_chain.message_keys.len() > consts::MAX_MESSAGE_KEYS {
413            updated_chain.message_keys.pop();
414        }
415
416        self.session.receiver_chains[chain_and_index.1] = updated_chain;
417
418        Ok(())
419    }
420
421    pub(crate) fn set_receiver_chain_key(
422        &mut self,
423        sender: &PublicKey,
424        chain_key: &ChainKey,
425    ) -> Result<(), InvalidSessionError> {
426        let chain_and_index = self
427            .get_receiver_chain(sender)?
428            .expect("called set_receiver_chain_key for a non-existent chain");
429        let mut updated_chain = chain_and_index.0;
430        updated_chain.chain_key = Some(session_structure::chain::ChainKey {
431            index: chain_key.index(),
432            key: chain_key.key().to_vec(),
433        });
434
435        self.session.receiver_chains[chain_and_index.1] = updated_chain;
436
437        Ok(())
438    }
439
440    pub(crate) fn set_unacknowledged_pre_key_message(
441        &mut self,
442        pre_key_id: Option<PreKeyId>,
443        signed_ec_pre_key_id: SignedPreKeyId,
444        base_key: &PublicKey,
445        now: SystemTime,
446    ) {
447        let signed_ec_pre_key_id: u32 = signed_ec_pre_key_id.into();
448        let pending = session_structure::PendingPreKey {
449            pre_key_id: pre_key_id.map(PreKeyId::into),
450            signed_pre_key_id: signed_ec_pre_key_id as i32,
451            base_key: base_key.serialize().to_vec(),
452            timestamp: now
453                .duration_since(SystemTime::UNIX_EPOCH)
454                .unwrap_or_default()
455                .as_secs(),
456        };
457        self.session.pending_pre_key = Some(pending);
458    }
459
460    pub(crate) fn set_kyber_ciphertext(&mut self, ciphertext: kem::SerializedCiphertext) {
461        let pending = session_structure::PendingKyberPreKey {
462            pre_key_id: u32::MAX, // has to be set to the actual value separately
463            ciphertext: ciphertext.into_vec(),
464        };
465        self.session.pending_kyber_pre_key = Some(pending);
466    }
467
468    pub(crate) fn set_unacknowledged_kyber_pre_key_id(
469        &mut self,
470        signed_kyber_pre_key_id: KyberPreKeyId,
471    ) {
472        let pending = self
473            .session
474            .pending_kyber_pre_key
475            .as_mut()
476            .expect("must have been set if kyber pre key is present");
477        pending.pre_key_id = signed_kyber_pre_key_id.into();
478    }
479
480    pub(crate) fn unacknowledged_pre_key_message_items(
481        &self,
482    ) -> Result<Option<UnacknowledgedPreKeyMessageItems>, InvalidSessionError> {
483        if let Some(ref pending_pre_key) = self.session.pending_pre_key {
484            Ok(Some(UnacknowledgedPreKeyMessageItems::new(
485                pending_pre_key.pre_key_id.map(Into::into),
486                (pending_pre_key.signed_pre_key_id as u32).into(),
487                PublicKey::deserialize(&pending_pre_key.base_key)
488                    .map_err(|_| InvalidSessionError("invalid pending PreKey message base key"))?,
489                self.session.pending_kyber_pre_key.as_ref(),
490                SystemTime::UNIX_EPOCH + Duration::from_secs(pending_pre_key.timestamp),
491            )))
492        } else {
493            Ok(None)
494        }
495    }
496
497    pub(crate) fn clear_unacknowledged_pre_key_message(&mut self) {
498        // Explicitly destructuring the SessionStructure in case there are new
499        // pending fields that need to be cleared.
500        let SessionStructure {
501            session_version: _session_version,
502            local_identity_public: _local_identity_public,
503            remote_identity_public: _remote_identity_public,
504            root_key: _root_key,
505            previous_counter: _previous_counter,
506            sender_chain: _sender_chain,
507            receiver_chains: _receiver_chains,
508            pending_pre_key: _pending_pre_key,
509            pending_kyber_pre_key: _pending_kyber_pre_key,
510            remote_registration_id: _remote_registration_id,
511            local_registration_id: _local_registration_id,
512            alice_base_key: _alice_base_key,
513            pq_ratchet_state: _pq_ratchet_state,
514        } = &self.session;
515        // ####### IMPORTANT #######
516        // Don't forget to clean up new pending fields.
517        // ####### IMPORTANT #######
518        self.session.pending_pre_key = None;
519        self.session.pending_kyber_pre_key = None;
520    }
521
522    pub(crate) fn set_remote_registration_id(&mut self, registration_id: u32) {
523        self.session.remote_registration_id = registration_id;
524    }
525
526    pub(crate) fn remote_registration_id(&self) -> u32 {
527        self.session.remote_registration_id
528    }
529
530    pub(crate) fn set_local_registration_id(&mut self, registration_id: u32) {
531        self.session.local_registration_id = registration_id;
532    }
533
534    pub(crate) fn local_registration_id(&self) -> u32 {
535        self.session.local_registration_id
536    }
537
538    pub(crate) fn get_kyber_ciphertext(&self) -> Option<&Vec<u8>> {
539        self.session
540            .pending_kyber_pre_key
541            .as_ref()
542            .map(|pending| &pending.ciphertext)
543    }
544
545    pub(crate) fn pq_ratchet_recv(
546        &mut self,
547        msg: &spqr::SerializedMessage,
548    ) -> Result<spqr::MessageKey, spqr::Error> {
549        let spqr::Recv { state, key } = spqr::recv(&self.session.pq_ratchet_state, msg)?;
550        self.session.pq_ratchet_state = state;
551        Ok(key)
552    }
553
554    pub(crate) fn pq_ratchet_send<R: Rng + CryptoRng>(
555        &mut self,
556        csprng: &mut R,
557    ) -> Result<(spqr::SerializedMessage, spqr::MessageKey), spqr::Error> {
558        let spqr::Send { state, key, msg } = spqr::send(&self.session.pq_ratchet_state, csprng)?;
559        self.session.pq_ratchet_state = state;
560        Ok((msg, key))
561    }
562
563    pub(crate) fn pq_ratchet_state(&self) -> &spqr::SerializedState {
564        &self.session.pq_ratchet_state
565    }
566}
567
568impl From<SessionStructure> for SessionState {
569    fn from(value: SessionStructure) -> SessionState {
570        SessionState::from_session_structure(value)
571    }
572}
573
574impl From<SessionState> for SessionStructure {
575    fn from(value: SessionState) -> SessionStructure {
576        value.session
577    }
578}
579
580impl From<&SessionState> for SessionStructure {
581    fn from(value: &SessionState) -> SessionStructure {
582        value.session.clone()
583    }
584}
585
586#[derive(Clone)]
587pub struct SessionRecord {
588    current_session: Option<SessionState>,
589    previous_sessions: Vec<Vec<u8>>,
590}
591
592impl SessionRecord {
593    pub fn new_fresh() -> Self {
594        Self {
595            current_session: None,
596            previous_sessions: Vec::new(),
597        }
598    }
599
600    pub(crate) fn new(state: SessionState) -> Self {
601        Self {
602            current_session: Some(state),
603            previous_sessions: Vec::new(),
604        }
605    }
606
607    pub fn deserialize(bytes: &[u8]) -> Result<Self, SignalProtocolError> {
608        let record = RecordStructure::decode(bytes)
609            .map_err(|_| InvalidSessionError("failed to decode session record protobuf"))?;
610
611        Ok(Self {
612            current_session: record.current_session.map(|s| s.into()),
613            previous_sessions: record.previous_sessions,
614        })
615    }
616
617    /// If there's a session with a matching version and `alice_base_key`, ensures that it is the
618    /// current session, promoting if necessary.
619    ///
620    /// Returns `Ok(true)` if such a session was found, `Ok(false)` if not, and
621    /// `Err(InvalidSessionError)` if an invalid session was found during the search (whether
622    /// current or not).
623    pub(crate) fn promote_matching_session(
624        &mut self,
625        version: u32,
626        alice_base_key: &[u8],
627    ) -> Result<bool, InvalidSessionError> {
628        if let Some(current_session) = &self.current_session {
629            if current_session.session_version()? == version
630                && alice_base_key
631                    .ct_eq(current_session.alice_base_key())
632                    .into()
633            {
634                return Ok(true);
635            }
636        }
637
638        let mut session_to_promote = None;
639        for (i, previous) in self.previous_session_states().enumerate() {
640            let previous = previous?;
641            if previous.session_version()? == version
642                && alice_base_key.ct_eq(previous.alice_base_key()).into()
643            {
644                session_to_promote = Some((i, previous));
645                break;
646            }
647        }
648
649        if let Some((i, state)) = session_to_promote {
650            self.promote_old_session(i, state);
651            return Ok(true);
652        }
653
654        Ok(false)
655    }
656
657    pub(crate) fn session_state(&self) -> Option<&SessionState> {
658        self.current_session.as_ref()
659    }
660
661    pub(crate) fn session_state_mut(&mut self) -> Option<&mut SessionState> {
662        self.current_session.as_mut()
663    }
664
665    pub(crate) fn set_session_state(&mut self, session: SessionState) {
666        self.current_session = Some(session);
667    }
668
669    pub(crate) fn previous_session_states(
670        &self,
671    ) -> impl ExactSizeIterator<Item = Result<SessionState, InvalidSessionError>> + '_ {
672        self.previous_sessions.iter().map(|bytes| {
673            Ok(SessionStructure::decode(&bytes[..])
674                .map_err(|_| InvalidSessionError("failed to decode previous session protobuf"))?
675                .into())
676        })
677    }
678
679    pub(crate) fn promote_old_session(
680        &mut self,
681        old_session: usize,
682        updated_session: SessionState,
683    ) {
684        self.previous_sessions.remove(old_session);
685        self.promote_state(updated_session)
686    }
687
688    pub(crate) fn promote_state(&mut self, new_state: SessionState) {
689        self.archive_current_state_inner();
690        self.current_session = Some(new_state);
691    }
692
693    // A non-fallible version of archive_current_state.
694    //
695    // Returns `true` if there was a session to archive, `false` if not.
696    fn archive_current_state_inner(&mut self) -> bool {
697        if let Some(mut current_session) = self.current_session.take() {
698            if self.previous_sessions.len() >= consts::ARCHIVED_STATES_MAX_LENGTH {
699                self.previous_sessions.pop();
700            }
701            current_session.clear_unacknowledged_pre_key_message();
702            self.previous_sessions
703                .insert(0, current_session.session.encode_to_vec());
704            true
705        } else {
706            false
707        }
708    }
709
710    pub fn archive_current_state(&mut self) -> Result<(), SignalProtocolError> {
711        if !self.archive_current_state_inner() {
712            log::info!("Skipping archive, current session state is fresh");
713        }
714        Ok(())
715    }
716
717    pub fn serialize(&self) -> Result<Vec<u8>, SignalProtocolError> {
718        let record = RecordStructure {
719            current_session: self.current_session.as_ref().map(|s| s.into()),
720            previous_sessions: self.previous_sessions.clone(),
721        };
722        Ok(record.encode_to_vec())
723    }
724
725    pub fn current_pq_state(&self) -> Option<&spqr::SerializedState> {
726        self.current_session.as_ref().map(|s| s.pq_ratchet_state())
727    }
728
729    pub fn remote_registration_id(&self) -> Result<u32, SignalProtocolError> {
730        Ok(self
731            .session_state()
732            .ok_or_else(|| {
733                SignalProtocolError::InvalidState(
734                    "remote_registration_id",
735                    "No current session".into(),
736                )
737            })?
738            .remote_registration_id())
739    }
740
741    pub fn local_registration_id(&self) -> Result<u32, SignalProtocolError> {
742        Ok(self
743            .session_state()
744            .ok_or_else(|| {
745                SignalProtocolError::InvalidState(
746                    "local_registration_id",
747                    "No current session".into(),
748                )
749            })?
750            .local_registration_id())
751    }
752
753    pub fn session_version(&self) -> Result<u32, SignalProtocolError> {
754        Ok(self
755            .session_state()
756            .ok_or_else(|| {
757                SignalProtocolError::InvalidState("session_version", "No current session".into())
758            })?
759            .session_version()?)
760    }
761
762    pub fn local_identity_key_bytes(&self) -> Result<Vec<u8>, SignalProtocolError> {
763        Ok(self
764            .session_state()
765            .ok_or_else(|| {
766                SignalProtocolError::InvalidState(
767                    "local_identity_key_bytes",
768                    "No current session".into(),
769                )
770            })?
771            .local_identity_key_bytes()?)
772    }
773
774    pub fn remote_identity_key_bytes(&self) -> Result<Option<Vec<u8>>, SignalProtocolError> {
775        Ok(self
776            .session_state()
777            .ok_or_else(|| {
778                SignalProtocolError::InvalidState(
779                    "remote_identity_key_bytes",
780                    "No current session".into(),
781                )
782            })?
783            .remote_identity_key_bytes()?)
784    }
785
786    pub fn has_usable_sender_chain(&self, now: SystemTime) -> Result<bool, SignalProtocolError> {
787        match &self.current_session {
788            Some(session) => Ok(session.has_usable_sender_chain(now)?),
789            None => Ok(false),
790        }
791    }
792
793    pub fn alice_base_key(&self) -> Result<&[u8], SignalProtocolError> {
794        Ok(self
795            .session_state()
796            .ok_or_else(|| {
797                SignalProtocolError::InvalidState("alice_base_key", "No current session".into())
798            })?
799            .alice_base_key())
800    }
801
802    pub fn get_receiver_chain_key_bytes(
803        &self,
804        sender: &PublicKey,
805    ) -> Result<Option<Box<[u8]>>, SignalProtocolError> {
806        Ok(self
807            .session_state()
808            .ok_or_else(|| {
809                SignalProtocolError::InvalidState(
810                    "get_receiver_chain_key",
811                    "No current session".into(),
812                )
813            })?
814            .get_receiver_chain_key(sender)?
815            .map(|chain| chain.key()[..].into()))
816    }
817
818    pub fn get_sender_chain_key_bytes(&self) -> Result<Vec<u8>, SignalProtocolError> {
819        Ok(self
820            .session_state()
821            .ok_or_else(|| {
822                SignalProtocolError::InvalidState(
823                    "get_sender_chain_key_bytes",
824                    "No current session".into(),
825                )
826            })?
827            .get_sender_chain_key_bytes()?)
828    }
829
830    pub fn current_ratchet_key_matches(
831        &self,
832        key: &PublicKey,
833    ) -> Result<bool, SignalProtocolError> {
834        match &self.current_session {
835            Some(session) => Ok(&session.sender_ratchet_key()? == key),
836            None => Ok(false),
837        }
838    }
839
840    pub fn get_kyber_ciphertext(&self) -> Result<Option<&Vec<u8>>, SignalProtocolError> {
841        Ok(self
842            .session_state()
843            .ok_or_else(|| {
844                SignalProtocolError::InvalidState(
845                    "get_kyber_ciphertext",
846                    "No current session".into(),
847                )
848            })?
849            .get_kyber_ciphertext())
850    }
851}