1use std::result::Result;
7use std::time::{Duration, SystemTime};
8
9use bitflags::bitflags;
10use prost::Message;
11use rand::{CryptoRng, Rng};
12use subtle::ConstantTimeEq;
13
14use crate::proto::storage::{RecordStructure, SessionStructure, session_structure};
15use crate::protocol::CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION;
16use crate::ratchet::{ChainKey, MessageKeyGenerator, RootKey};
17use crate::state::{KyberPreKeyId, PreKeyId, SignedPreKeyId};
18use crate::{IdentityKey, KeyPair, PrivateKey, PublicKey, SignalProtocolError, consts, kem};
19
/// Error describing why stored session data could not be used.
///
/// Wraps a static, human-readable description of the problem (for example a
/// missing chain or an undecodable key).
#[derive(Debug)]
pub(crate) struct InvalidSessionError(
    /// Static description of what was wrong with the session data.
    &'static str,
);
23
24impl std::fmt::Display for InvalidSessionError {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 self.0.fmt(f)
27 }
28}
29
30impl From<InvalidSessionError> for SignalProtocolError {
31 fn from(e: InvalidSessionError) -> Self {
32 Self::InvalidSessionStructure(e.0)
33 }
34}
35
36#[derive(Debug, Clone)]
37pub(crate) struct UnacknowledgedPreKeyMessageItems<'a> {
38 pre_key_id: Option<PreKeyId>,
39 signed_pre_key_id: SignedPreKeyId,
40 base_key: PublicKey,
41 kyber_pre_key_id: Option<KyberPreKeyId>,
45 kyber_ciphertext: Option<&'a [u8]>,
46 timestamp: SystemTime,
47}
48
49impl<'a> UnacknowledgedPreKeyMessageItems<'a> {
50 fn new(
51 pre_key_id: Option<PreKeyId>,
52 signed_pre_key_id: SignedPreKeyId,
53 base_key: PublicKey,
54 pending_kyber_pre_key: Option<&'a session_structure::PendingKyberPreKey>,
55 timestamp: SystemTime,
56 ) -> Self {
57 let (kyber_pre_key_id, kyber_ciphertext) = pending_kyber_pre_key
58 .map(|pending| (pending.pre_key_id.into(), pending.ciphertext.as_slice()))
59 .unzip();
60 Self {
61 pre_key_id,
62 signed_pre_key_id,
63 base_key,
64 kyber_pre_key_id,
65 kyber_ciphertext,
66 timestamp,
67 }
68 }
69
70 pub(crate) fn pre_key_id(&self) -> Option<PreKeyId> {
71 self.pre_key_id
72 }
73
74 pub(crate) fn signed_pre_key_id(&self) -> SignedPreKeyId {
75 self.signed_pre_key_id
76 }
77
78 pub(crate) fn base_key(&self) -> &PublicKey {
79 &self.base_key
80 }
81
82 pub(crate) fn kyber_pre_key_id(&self) -> Option<KyberPreKeyId> {
83 self.kyber_pre_key_id
84 }
85
86 pub(crate) fn kyber_ciphertext(&self) -> Option<&'a [u8]> {
87 self.kyber_ciphertext
88 }
89
90 pub(crate) fn timestamp(&self) -> SystemTime {
91 self.timestamp
92 }
93}
94
95bitflags! {
96 #[repr(transparent)]
106 #[derive(Clone, Copy, PartialEq, Eq)]
107 pub struct SessionUsabilityRequirements : u32 {
108 const NotStale = 1 << 0;
115 const EstablishedWithPqxdh = 1 << 1;
120 const Spqr = 1 << 2;
128 }
129}
130
131#[derive(Clone, Debug)]
132pub(crate) struct SessionState {
133 session: SessionStructure,
134}
135
136impl SessionState {
137 pub(crate) fn from_session_structure(session: SessionStructure) -> Self {
138 Self { session }
139 }
140
141 pub(crate) fn new(
142 version: u8,
143 our_identity: &IdentityKey,
144 their_identity: &IdentityKey,
145 root_key: &RootKey,
146 alice_base_key: &PublicKey,
147 pq_ratchet_state: spqr::SerializedState,
148 ) -> Self {
149 Self {
150 session: SessionStructure {
151 session_version: version as u32,
152 local_identity_public: our_identity.public_key().serialize().into_vec(),
153 remote_identity_public: their_identity.serialize().into_vec(),
154 root_key: root_key.key().to_vec(),
155 previous_counter: 0,
156 sender_chain: None,
157 receiver_chains: vec![],
158 pending_pre_key: None,
159 pending_kyber_pre_key: None,
160 remote_registration_id: 0,
161 local_registration_id: 0,
162 alice_base_key: alice_base_key.serialize().into_vec(),
163 pq_ratchet_state,
164 },
165 }
166 }
167
168 pub(crate) fn alice_base_key(&self) -> &[u8] {
169 &self.session.alice_base_key
171 }
172
173 pub(crate) fn session_version(&self) -> Result<u32, InvalidSessionError> {
174 match self.session.session_version {
175 0 => Ok(2),
176 v => Ok(v),
177 }
178 }
179
180 pub(crate) fn remote_identity_key(&self) -> Result<Option<IdentityKey>, InvalidSessionError> {
181 match self.session.remote_identity_public.len() {
182 0 => Ok(None),
183 _ => Ok(Some(
184 IdentityKey::decode(&self.session.remote_identity_public)
185 .map_err(|_| InvalidSessionError("invalid remote identity key"))?,
186 )),
187 }
188 }
189
190 pub(crate) fn remote_identity_key_bytes(&self) -> Result<Option<Vec<u8>>, InvalidSessionError> {
191 Ok(self.remote_identity_key()?.map(|k| k.serialize().to_vec()))
192 }
193
194 pub(crate) fn local_identity_key(&self) -> Result<IdentityKey, InvalidSessionError> {
195 IdentityKey::decode(&self.session.local_identity_public)
196 .map_err(|_| InvalidSessionError("invalid local identity key"))
197 }
198
199 pub(crate) fn local_identity_key_bytes(&self) -> Result<Vec<u8>, InvalidSessionError> {
200 Ok(self.local_identity_key()?.serialize().to_vec())
201 }
202
203 pub(crate) fn session_with_self(&self) -> Result<bool, InvalidSessionError> {
204 if let Some(remote_id) = self.remote_identity_key_bytes()? {
205 let local_id = self.local_identity_key_bytes()?;
206 return Ok(remote_id == local_id);
207 }
208
209 Ok(false)
211 }
212
213 pub(crate) fn previous_counter(&self) -> u32 {
214 self.session.previous_counter
215 }
216
217 pub(crate) fn set_previous_counter(&mut self, ctr: u32) {
218 self.session.previous_counter = ctr;
219 }
220
221 pub(crate) fn root_key(&self) -> Result<RootKey, InvalidSessionError> {
222 let root_key_bytes = self.session.root_key[..]
223 .try_into()
224 .map_err(|_| InvalidSessionError("invalid root key"))?;
225 Ok(RootKey::new(root_key_bytes))
226 }
227
228 pub(crate) fn set_root_key(&mut self, root_key: &RootKey) {
229 self.session.root_key = root_key.key().to_vec();
230 }
231
232 pub(crate) fn sender_ratchet_key(&self) -> Result<PublicKey, InvalidSessionError> {
233 match self.session.sender_chain {
234 None => Err(InvalidSessionError("missing sender chain")),
235 Some(ref c) => PublicKey::deserialize(&c.sender_ratchet_key)
236 .map_err(|_| InvalidSessionError("invalid sender chain ratchet key")),
237 }
238 }
239
240 pub(crate) fn sender_ratchet_key_for_logging(&self) -> Result<String, InvalidSessionError> {
241 Ok(hex::encode(self.sender_ratchet_key()?.public_key_bytes()))
242 }
243
244 pub(crate) fn sender_ratchet_private_key(&self) -> Result<PrivateKey, InvalidSessionError> {
245 match self.session.sender_chain {
246 None => Err(InvalidSessionError("missing sender chain")),
247 Some(ref c) => PrivateKey::deserialize(&c.sender_ratchet_key_private)
248 .map_err(|_| InvalidSessionError("invalid sender chain private ratchet key")),
249 }
250 }
251
252 pub fn has_usable_sender_chain(
253 &self,
254 now: SystemTime,
255 requirements: SessionUsabilityRequirements,
256 ) -> Result<bool, InvalidSessionError> {
257 if self.session.sender_chain.is_none() {
258 return Ok(false);
259 }
260 if requirements.contains(SessionUsabilityRequirements::NotStale) {
261 if let Some(pending_pre_key) = &self.session.pending_pre_key {
262 let creation_timestamp =
263 SystemTime::UNIX_EPOCH + Duration::from_secs(pending_pre_key.timestamp);
264 if creation_timestamp + consts::MAX_UNACKNOWLEDGED_SESSION_AGE < now {
265 return Ok(false);
266 }
267 }
268 }
269 #[allow(clippy::collapsible_if)]
270 if requirements.contains(SessionUsabilityRequirements::EstablishedWithPqxdh) {
271 if self.session_version()? <= CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION.into() {
272 return Ok(false);
273 }
274 }
275 #[allow(clippy::collapsible_if)]
276 if requirements.contains(SessionUsabilityRequirements::Spqr) {
277 if self.pq_ratchet_state().is_empty() {
278 return Ok(false);
279 }
280 }
281 Ok(true)
282 }
283
284 pub(crate) fn all_receiver_chain_logging_info(&self) -> Vec<(Vec<u8>, Option<u32>)> {
285 let mut results = vec![];
286 for chain in self.session.receiver_chains.iter() {
287 let sender_ratchet_public = chain.sender_ratchet_key.clone();
288
289 let chain_key_idx = chain.chain_key.as_ref().map(|chain_key| chain_key.index);
290
291 results.push((sender_ratchet_public, chain_key_idx))
292 }
293 results
294 }
295
296 pub(crate) fn get_receiver_chain(
297 &self,
298 sender: &PublicKey,
299 ) -> Result<Option<(session_structure::Chain, usize)>, InvalidSessionError> {
300 for (idx, chain) in self.session.receiver_chains.iter().enumerate() {
301 let chain_ratchet_key = PublicKey::deserialize(&chain.sender_ratchet_key)
304 .map_err(|_| InvalidSessionError("invalid receiver chain ratchet key"))?;
305
306 if &chain_ratchet_key == sender {
307 return Ok(Some((chain.clone(), idx)));
308 }
309 }
310
311 Ok(None)
312 }
313
314 pub(crate) fn get_receiver_chain_key(
315 &self,
316 sender: &PublicKey,
317 ) -> Result<Option<ChainKey>, InvalidSessionError> {
318 match self.get_receiver_chain(sender)? {
319 None => Ok(None),
320 Some((chain, _)) => match chain.chain_key {
321 None => Err(InvalidSessionError("missing receiver chain key")),
322 Some(c) => {
323 let chain_key_bytes = c.key[..]
324 .try_into()
325 .map_err(|_| InvalidSessionError("invalid receiver chain key"))?;
326 Ok(Some(ChainKey::new(chain_key_bytes, c.index)))
327 }
328 },
329 }
330 }
331
332 pub(crate) fn add_receiver_chain(&mut self, sender: &PublicKey, chain_key: &ChainKey) {
333 let chain_key = session_structure::chain::ChainKey {
334 index: chain_key.index(),
335 key: chain_key.key().to_vec(),
336 };
337
338 let chain = session_structure::Chain {
339 sender_ratchet_key: sender.serialize().to_vec(),
340 sender_ratchet_key_private: vec![],
341 chain_key: Some(chain_key),
342 message_keys: vec![],
343 };
344
345 self.session.receiver_chains.push(chain);
346
347 if self.session.receiver_chains.len() > consts::MAX_RECEIVER_CHAINS {
348 log::info!(
349 "Trimming excessive receiver_chain for session with base key {}, chain count: {}",
350 self.sender_ratchet_key_for_logging()
351 .unwrap_or_else(|e| format!("<error: {}>", e.0)),
352 self.session.receiver_chains.len()
353 );
354 self.session.receiver_chains.remove(0);
355 }
356 }
357
358 pub(crate) fn with_receiver_chain(mut self, sender: &PublicKey, chain_key: &ChainKey) -> Self {
359 self.add_receiver_chain(sender, chain_key);
360 self
361 }
362
363 pub(crate) fn set_sender_chain(&mut self, sender: &KeyPair, next_chain_key: &ChainKey) {
364 let chain_key = session_structure::chain::ChainKey {
365 index: next_chain_key.index(),
366 key: next_chain_key.key().to_vec(),
367 };
368
369 let new_chain = session_structure::Chain {
370 sender_ratchet_key: sender.public_key.serialize().to_vec(),
371 sender_ratchet_key_private: sender.private_key.serialize().to_vec(),
372 chain_key: Some(chain_key),
373 message_keys: vec![],
374 };
375
376 self.session.sender_chain = Some(new_chain);
377 }
378
379 pub(crate) fn with_sender_chain(mut self, sender: &KeyPair, next_chain_key: &ChainKey) -> Self {
380 self.set_sender_chain(sender, next_chain_key);
381 self
382 }
383
384 pub(crate) fn get_sender_chain_key(&self) -> Result<ChainKey, InvalidSessionError> {
385 let sender_chain = self
386 .session
387 .sender_chain
388 .as_ref()
389 .ok_or(InvalidSessionError("missing sender chain"))?;
390
391 let chain_key = sender_chain
392 .chain_key
393 .as_ref()
394 .ok_or(InvalidSessionError("missing sender chain key"))?;
395
396 let chain_key_bytes = chain_key.key[..]
397 .try_into()
398 .map_err(|_| InvalidSessionError("invalid sender chain key"))?;
399
400 Ok(ChainKey::new(chain_key_bytes, chain_key.index))
401 }
402
403 pub(crate) fn get_sender_chain_key_bytes(&self) -> Result<Vec<u8>, InvalidSessionError> {
404 Ok(self.get_sender_chain_key()?.key().to_vec())
405 }
406
407 pub(crate) fn set_sender_chain_key(&mut self, next_chain_key: &ChainKey) {
408 let chain_key = session_structure::chain::ChainKey {
409 index: next_chain_key.index(),
410 key: next_chain_key.key().to_vec(),
411 };
412
413 let new_chain = match self.session.sender_chain.take() {
416 None => session_structure::Chain {
417 sender_ratchet_key: vec![],
418 sender_ratchet_key_private: vec![],
419 chain_key: Some(chain_key),
420 message_keys: vec![],
421 },
422 Some(mut c) => {
423 c.chain_key = Some(chain_key);
424 c
425 }
426 };
427
428 self.session.sender_chain = Some(new_chain);
429 }
430
431 pub(crate) fn get_message_keys(
432 &mut self,
433 sender: &PublicKey,
434 counter: u32,
435 ) -> Result<Option<MessageKeyGenerator>, InvalidSessionError> {
436 if let Some(mut chain_and_index) = self.get_receiver_chain(sender)? {
437 let message_key_idx = chain_and_index
438 .0
439 .message_keys
440 .iter()
441 .position(|m| m.index == counter);
442
443 if let Some(position) = message_key_idx {
444 let message_key = chain_and_index.0.message_keys.remove(position);
445 let keys =
446 MessageKeyGenerator::from_pb(message_key).map_err(InvalidSessionError)?;
447
448 self.session.receiver_chains[chain_and_index.1] = chain_and_index.0;
450 return Ok(Some(keys));
451 }
452 }
453
454 Ok(None)
455 }
456
457 pub(crate) fn set_message_keys(
458 &mut self,
459 sender: &PublicKey,
460 message_keys: MessageKeyGenerator,
461 ) -> Result<(), InvalidSessionError> {
462 let chain_and_index = self
463 .get_receiver_chain(sender)?
464 .expect("called set_message_keys for a non-existent chain");
465 let mut updated_chain = chain_and_index.0;
466 updated_chain.message_keys.insert(0, message_keys.into_pb());
467
468 if updated_chain.message_keys.len() > consts::MAX_MESSAGE_KEYS {
469 updated_chain.message_keys.pop();
470 }
471
472 self.session.receiver_chains[chain_and_index.1] = updated_chain;
473
474 Ok(())
475 }
476
477 pub(crate) fn set_receiver_chain_key(
478 &mut self,
479 sender: &PublicKey,
480 chain_key: &ChainKey,
481 ) -> Result<(), InvalidSessionError> {
482 let chain_and_index = self
483 .get_receiver_chain(sender)?
484 .expect("called set_receiver_chain_key for a non-existent chain");
485 let mut updated_chain = chain_and_index.0;
486 updated_chain.chain_key = Some(session_structure::chain::ChainKey {
487 index: chain_key.index(),
488 key: chain_key.key().to_vec(),
489 });
490
491 self.session.receiver_chains[chain_and_index.1] = updated_chain;
492
493 Ok(())
494 }
495
496 pub(crate) fn set_unacknowledged_pre_key_message(
497 &mut self,
498 pre_key_id: Option<PreKeyId>,
499 signed_ec_pre_key_id: SignedPreKeyId,
500 base_key: &PublicKey,
501 now: SystemTime,
502 ) {
503 let signed_ec_pre_key_id: u32 = signed_ec_pre_key_id.into();
504 let pending = session_structure::PendingPreKey {
505 pre_key_id: pre_key_id.map(PreKeyId::into),
506 signed_pre_key_id: signed_ec_pre_key_id as i32,
507 base_key: base_key.serialize().to_vec(),
508 timestamp: now
509 .duration_since(SystemTime::UNIX_EPOCH)
510 .unwrap_or_default()
511 .as_secs(),
512 };
513 self.session.pending_pre_key = Some(pending);
514 }
515
516 pub(crate) fn set_kyber_ciphertext(&mut self, ciphertext: kem::SerializedCiphertext) {
517 let pending = session_structure::PendingKyberPreKey {
518 pre_key_id: u32::MAX, ciphertext: ciphertext.into_vec(),
520 };
521 self.session.pending_kyber_pre_key = Some(pending);
522 }
523
524 pub(crate) fn set_unacknowledged_kyber_pre_key_id(
525 &mut self,
526 signed_kyber_pre_key_id: KyberPreKeyId,
527 ) {
528 let pending = self
529 .session
530 .pending_kyber_pre_key
531 .as_mut()
532 .expect("must have been set if kyber pre key is present");
533 pending.pre_key_id = signed_kyber_pre_key_id.into();
534 }
535
536 pub(crate) fn unacknowledged_pre_key_message_items(
537 &self,
538 ) -> Result<Option<UnacknowledgedPreKeyMessageItems<'_>>, InvalidSessionError> {
539 if let Some(ref pending_pre_key) = self.session.pending_pre_key {
540 Ok(Some(UnacknowledgedPreKeyMessageItems::new(
541 pending_pre_key.pre_key_id.map(Into::into),
542 (pending_pre_key.signed_pre_key_id as u32).into(),
543 PublicKey::deserialize(&pending_pre_key.base_key)
544 .map_err(|_| InvalidSessionError("invalid pending PreKey message base key"))?,
545 self.session.pending_kyber_pre_key.as_ref(),
546 SystemTime::UNIX_EPOCH + Duration::from_secs(pending_pre_key.timestamp),
547 )))
548 } else {
549 Ok(None)
550 }
551 }
552
553 pub(crate) fn clear_unacknowledged_pre_key_message(&mut self) {
554 let SessionStructure {
557 session_version: _session_version,
558 local_identity_public: _local_identity_public,
559 remote_identity_public: _remote_identity_public,
560 root_key: _root_key,
561 previous_counter: _previous_counter,
562 sender_chain: _sender_chain,
563 receiver_chains: _receiver_chains,
564 pending_pre_key: _pending_pre_key,
565 pending_kyber_pre_key: _pending_kyber_pre_key,
566 remote_registration_id: _remote_registration_id,
567 local_registration_id: _local_registration_id,
568 alice_base_key: _alice_base_key,
569 pq_ratchet_state: _pq_ratchet_state,
570 } = &self.session;
571 self.session.pending_pre_key = None;
575 self.session.pending_kyber_pre_key = None;
576 }
577
578 pub(crate) fn set_remote_registration_id(&mut self, registration_id: u32) {
579 self.session.remote_registration_id = registration_id;
580 }
581
582 pub(crate) fn remote_registration_id(&self) -> u32 {
583 self.session.remote_registration_id
584 }
585
586 pub(crate) fn set_local_registration_id(&mut self, registration_id: u32) {
587 self.session.local_registration_id = registration_id;
588 }
589
590 pub(crate) fn local_registration_id(&self) -> u32 {
591 self.session.local_registration_id
592 }
593
594 pub(crate) fn get_kyber_ciphertext(&self) -> Option<&Vec<u8>> {
595 self.session
596 .pending_kyber_pre_key
597 .as_ref()
598 .map(|pending| &pending.ciphertext)
599 }
600
601 pub(crate) fn pq_ratchet_recv(
602 &mut self,
603 msg: &spqr::SerializedMessage,
604 ) -> Result<spqr::MessageKey, spqr::Error> {
605 let _trace = libsignal_debug::trace_block!("SessionState::pq_ratchet_recv");
606 let spqr::Recv { state, key } = spqr::recv(&self.session.pq_ratchet_state, msg)?;
607 self.session.pq_ratchet_state = state;
608 Ok(key)
609 }
610
611 pub(crate) fn pq_ratchet_send<R: Rng + CryptoRng>(
612 &mut self,
613 csprng: &mut R,
614 ) -> Result<(spqr::SerializedMessage, spqr::MessageKey), spqr::Error> {
615 let _trace = libsignal_debug::trace_block!("SessionState::pq_ratchet_send");
616 let spqr::Send { state, key, msg } = spqr::send(&self.session.pq_ratchet_state, csprng)?;
617 self.session.pq_ratchet_state = state;
618 Ok((msg, key))
619 }
620
621 pub(crate) fn pq_ratchet_state(&self) -> &spqr::SerializedState {
622 &self.session.pq_ratchet_state
623 }
624}
625
626impl From<SessionStructure> for SessionState {
627 fn from(value: SessionStructure) -> SessionState {
628 SessionState::from_session_structure(value)
629 }
630}
631
632impl From<SessionState> for SessionStructure {
633 fn from(value: SessionState) -> SessionStructure {
634 value.session
635 }
636}
637
638impl From<&SessionState> for SessionStructure {
639 fn from(value: &SessionState) -> SessionStructure {
640 value.session.clone()
641 }
642}
643
644#[derive(Clone)]
645pub struct SessionRecord {
646 current_session: Option<SessionState>,
647 previous_sessions: Vec<Vec<u8>>,
648}
649
650impl SessionRecord {
651 pub fn new_fresh() -> Self {
652 Self {
653 current_session: None,
654 previous_sessions: Vec::new(),
655 }
656 }
657
658 pub(crate) fn new(state: SessionState) -> Self {
659 Self {
660 current_session: Some(state),
661 previous_sessions: Vec::new(),
662 }
663 }
664
665 pub fn deserialize(bytes: &[u8]) -> Result<Self, SignalProtocolError> {
666 let record = RecordStructure::decode(bytes)
667 .map_err(|_| InvalidSessionError("failed to decode session record protobuf"))?;
668
669 Ok(Self {
670 current_session: record.current_session.map(|s| s.into()),
671 previous_sessions: record.previous_sessions,
672 })
673 }
674
675 pub(crate) fn promote_matching_session(
682 &mut self,
683 version: u32,
684 alice_base_key: &[u8],
685 ) -> Result<bool, InvalidSessionError> {
686 if let Some(current_session) = &self.current_session {
687 if current_session.session_version()? == version
688 && alice_base_key
689 .ct_eq(current_session.alice_base_key())
690 .into()
691 {
692 return Ok(true);
693 }
694 }
695
696 let mut session_to_promote = None;
697 for (i, previous) in self.previous_session_states().enumerate() {
698 let previous = previous?;
699 if previous.session_version()? == version
700 && alice_base_key.ct_eq(previous.alice_base_key()).into()
701 {
702 session_to_promote = Some((i, previous));
703 break;
704 }
705 }
706
707 if let Some((i, state)) = session_to_promote {
708 self.promote_old_session(i, state);
709 return Ok(true);
710 }
711
712 Ok(false)
713 }
714
715 pub(crate) fn session_state(&self) -> Option<&SessionState> {
716 self.current_session.as_ref()
717 }
718
719 pub(crate) fn session_state_mut(&mut self) -> Option<&mut SessionState> {
720 self.current_session.as_mut()
721 }
722
723 pub(crate) fn set_session_state(&mut self, session: SessionState) {
724 self.current_session = Some(session);
725 }
726
727 pub(crate) fn previous_session_states(
728 &self,
729 ) -> impl ExactSizeIterator<Item = Result<SessionState, InvalidSessionError>> + '_ {
730 self.previous_sessions.iter().map(|bytes| {
731 Ok(SessionStructure::decode(&bytes[..])
732 .map_err(|_| InvalidSessionError("failed to decode previous session protobuf"))?
733 .into())
734 })
735 }
736
737 pub(crate) fn promote_old_session(
738 &mut self,
739 old_session: usize,
740 updated_session: SessionState,
741 ) {
742 self.previous_sessions.remove(old_session);
743 self.promote_state(updated_session)
744 }
745
746 pub(crate) fn promote_state(&mut self, new_state: SessionState) {
747 self.archive_current_state_inner();
748 self.current_session = Some(new_state);
749 }
750
751 fn archive_current_state_inner(&mut self) -> bool {
755 if let Some(mut current_session) = self.current_session.take() {
756 if self.previous_sessions.len() >= consts::ARCHIVED_STATES_MAX_LENGTH {
757 self.previous_sessions.pop();
758 }
759 current_session.clear_unacknowledged_pre_key_message();
760 self.previous_sessions
761 .insert(0, current_session.session.encode_to_vec());
762 true
763 } else {
764 false
765 }
766 }
767
768 pub fn archive_current_state(&mut self) -> Result<(), SignalProtocolError> {
769 if !self.archive_current_state_inner() {
770 log::info!("Skipping archive, current session state is fresh");
771 }
772 Ok(())
773 }
774
775 pub fn serialize(&self) -> Result<Vec<u8>, SignalProtocolError> {
776 let record = RecordStructure {
777 current_session: self.current_session.as_ref().map(|s| s.into()),
778 previous_sessions: self.previous_sessions.clone(),
779 };
780 Ok(record.encode_to_vec())
781 }
782
783 pub fn current_pq_state(&self) -> Option<&spqr::SerializedState> {
784 self.current_session.as_ref().map(|s| s.pq_ratchet_state())
785 }
786
787 pub fn remote_registration_id(&self) -> Result<u32, SignalProtocolError> {
788 Ok(self
789 .session_state()
790 .ok_or_else(|| {
791 SignalProtocolError::InvalidState(
792 "remote_registration_id",
793 "No current session".into(),
794 )
795 })?
796 .remote_registration_id())
797 }
798
799 pub fn local_registration_id(&self) -> Result<u32, SignalProtocolError> {
800 Ok(self
801 .session_state()
802 .ok_or_else(|| {
803 SignalProtocolError::InvalidState(
804 "local_registration_id",
805 "No current session".into(),
806 )
807 })?
808 .local_registration_id())
809 }
810
811 pub fn session_version(&self) -> Result<u32, SignalProtocolError> {
812 Ok(self
813 .session_state()
814 .ok_or_else(|| {
815 SignalProtocolError::InvalidState("session_version", "No current session".into())
816 })?
817 .session_version()?)
818 }
819
820 pub fn local_identity_key_bytes(&self) -> Result<Vec<u8>, SignalProtocolError> {
821 Ok(self
822 .session_state()
823 .ok_or_else(|| {
824 SignalProtocolError::InvalidState(
825 "local_identity_key_bytes",
826 "No current session".into(),
827 )
828 })?
829 .local_identity_key_bytes()?)
830 }
831
832 pub fn remote_identity_key_bytes(&self) -> Result<Option<Vec<u8>>, SignalProtocolError> {
833 Ok(self
834 .session_state()
835 .ok_or_else(|| {
836 SignalProtocolError::InvalidState(
837 "remote_identity_key_bytes",
838 "No current session".into(),
839 )
840 })?
841 .remote_identity_key_bytes()?)
842 }
843
844 pub fn has_usable_sender_chain(
845 &self,
846 now: SystemTime,
847 requirements: SessionUsabilityRequirements,
848 ) -> Result<bool, SignalProtocolError> {
849 match &self.current_session {
850 Some(session) => Ok(session.has_usable_sender_chain(now, requirements)?),
851 None => Ok(false),
852 }
853 }
854
855 pub fn alice_base_key(&self) -> Result<&[u8], SignalProtocolError> {
856 Ok(self
857 .session_state()
858 .ok_or_else(|| {
859 SignalProtocolError::InvalidState("alice_base_key", "No current session".into())
860 })?
861 .alice_base_key())
862 }
863
864 pub fn get_receiver_chain_key_bytes(
865 &self,
866 sender: &PublicKey,
867 ) -> Result<Option<Box<[u8]>>, SignalProtocolError> {
868 Ok(self
869 .session_state()
870 .ok_or_else(|| {
871 SignalProtocolError::InvalidState(
872 "get_receiver_chain_key",
873 "No current session".into(),
874 )
875 })?
876 .get_receiver_chain_key(sender)?
877 .map(|chain| chain.key()[..].into()))
878 }
879
880 pub fn get_sender_chain_key_bytes(&self) -> Result<Vec<u8>, SignalProtocolError> {
881 Ok(self
882 .session_state()
883 .ok_or_else(|| {
884 SignalProtocolError::InvalidState(
885 "get_sender_chain_key_bytes",
886 "No current session".into(),
887 )
888 })?
889 .get_sender_chain_key_bytes()?)
890 }
891
892 pub fn current_ratchet_key_matches(
893 &self,
894 key: &PublicKey,
895 ) -> Result<bool, SignalProtocolError> {
896 match &self.current_session {
897 Some(session) => Ok(&session.sender_ratchet_key()? == key),
898 None => Ok(false),
899 }
900 }
901
902 pub fn get_kyber_ciphertext(&self) -> Result<Option<&Vec<u8>>, SignalProtocolError> {
903 Ok(self
904 .session_state()
905 .ok_or_else(|| {
906 SignalProtocolError::InvalidState(
907 "get_kyber_ciphertext",
908 "No current session".into(),
909 )
910 })?
911 .get_kyber_ciphertext())
912 }
913}