1use std::result::Result;
7use std::time::{Duration, SystemTime};
8
9use bitflags::bitflags;
10use prost::Message;
11use rand::{CryptoRng, Rng};
12use subtle::ConstantTimeEq;
13
14use crate::proto::storage::{RecordStructure, SessionStructure, session_structure};
15use crate::protocol::CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION;
16use crate::ratchet::{ChainKey, MessageKeyGenerator, RootKey};
17use crate::state::{KyberPreKeyId, PreKeyId, SignedPreKeyId};
18use crate::{IdentityKey, KeyPair, PrivateKey, PublicKey, SignalProtocolError, consts, kem};
19
/// A distinct error type describing malformed or missing data in a stored session.
///
/// Carries a static description; it converts into
/// `SignalProtocolError::InvalidSessionStructure` with the same message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct InvalidSessionError(&'static str);

impl std::fmt::Display for InvalidSessionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Delegate to `str`'s Display (not `write_str`) so width/alignment flags are honored.
        std::fmt::Display::fmt(self.0, f)
    }
}

// Debug + Display are already provided, so the type can be used wherever a standard error
// is expected (e.g. `Box<dyn Error>`); it has no underlying source.
impl std::error::Error for InvalidSessionError {}
29
30impl From<InvalidSessionError> for SignalProtocolError {
31 fn from(e: InvalidSessionError) -> Self {
32 Self::InvalidSessionStructure(e.0)
33 }
34}
35
36#[derive(Debug, Clone)]
37pub(crate) struct UnacknowledgedPreKeyMessageItems<'a> {
38 pre_key_id: Option<PreKeyId>,
39 signed_pre_key_id: SignedPreKeyId,
40 base_key: PublicKey,
41 kyber_pre_key_id: Option<KyberPreKeyId>,
45 kyber_ciphertext: Option<&'a [u8]>,
46 timestamp: SystemTime,
47}
48
49impl<'a> UnacknowledgedPreKeyMessageItems<'a> {
50 fn new(
51 pre_key_id: Option<PreKeyId>,
52 signed_pre_key_id: SignedPreKeyId,
53 base_key: PublicKey,
54 pending_kyber_pre_key: Option<&'a session_structure::PendingKyberPreKey>,
55 timestamp: SystemTime,
56 ) -> Self {
57 let (kyber_pre_key_id, kyber_ciphertext) = pending_kyber_pre_key
58 .map(|pending| (pending.pre_key_id.into(), pending.ciphertext.as_slice()))
59 .unzip();
60 Self {
61 pre_key_id,
62 signed_pre_key_id,
63 base_key,
64 kyber_pre_key_id,
65 kyber_ciphertext,
66 timestamp,
67 }
68 }
69
70 pub(crate) fn pre_key_id(&self) -> Option<PreKeyId> {
71 self.pre_key_id
72 }
73
74 pub(crate) fn signed_pre_key_id(&self) -> SignedPreKeyId {
75 self.signed_pre_key_id
76 }
77
78 pub(crate) fn base_key(&self) -> &PublicKey {
79 &self.base_key
80 }
81
82 pub(crate) fn kyber_pre_key_id(&self) -> Option<KyberPreKeyId> {
83 self.kyber_pre_key_id
84 }
85
86 pub(crate) fn kyber_ciphertext(&self) -> Option<&'a [u8]> {
87 self.kyber_ciphertext
88 }
89
90 pub(crate) fn timestamp(&self) -> SystemTime {
91 self.timestamp
92 }
93}
94
95bitflags! {
96 #[repr(transparent)]
106 #[derive(Clone, Copy, PartialEq, Eq)]
107 pub struct SessionUsabilityRequirements : u32 {
108 const NotStale = 1 << 0;
115 const EstablishedWithPqxdh = 1 << 1;
120 const Spqr = 1 << 2;
128 }
129}
130
131#[derive(Clone, Debug)]
132pub(crate) struct SessionState {
133 session: SessionStructure,
134}
135
136impl SessionState {
137 pub(crate) fn from_session_structure(session: SessionStructure) -> Self {
138 Self { session }
139 }
140
141 pub(crate) fn new(
142 version: u8,
143 our_identity: &IdentityKey,
144 their_identity: &IdentityKey,
145 root_key: &RootKey,
146 alice_base_key: &PublicKey,
147 pq_ratchet_state: spqr::SerializedState,
148 ) -> Self {
149 Self {
150 session: SessionStructure {
151 session_version: version as u32,
152 local_identity_public: our_identity.public_key().serialize().into_vec(),
153 remote_identity_public: their_identity.serialize().into_vec(),
154 root_key: root_key.key().to_vec(),
155 previous_counter: 0,
156 sender_chain: None,
157 receiver_chains: vec![],
158 pending_pre_key: None,
159 pending_kyber_pre_key: None,
160 remote_registration_id: 0,
161 local_registration_id: 0,
162 alice_base_key: alice_base_key.serialize().into_vec(),
163 pq_ratchet_state,
164 },
165 }
166 }
167
168 pub(crate) fn alice_base_key(&self) -> &[u8] {
169 &self.session.alice_base_key
171 }
172
173 pub(crate) fn session_version(&self) -> Result<u32, InvalidSessionError> {
174 match self.session.session_version {
175 0 => Ok(2),
176 v => Ok(v),
177 }
178 }
179
180 pub(crate) fn remote_identity_key(&self) -> Result<Option<IdentityKey>, InvalidSessionError> {
181 match self.session.remote_identity_public.len() {
182 0 => Ok(None),
183 _ => Ok(Some(
184 IdentityKey::decode(&self.session.remote_identity_public)
185 .map_err(|_| InvalidSessionError("invalid remote identity key"))?,
186 )),
187 }
188 }
189
190 pub(crate) fn remote_identity_key_bytes(&self) -> Result<Option<Vec<u8>>, InvalidSessionError> {
191 Ok(self.remote_identity_key()?.map(|k| k.serialize().to_vec()))
192 }
193
194 pub(crate) fn local_identity_key(&self) -> Result<IdentityKey, InvalidSessionError> {
195 IdentityKey::decode(&self.session.local_identity_public)
196 .map_err(|_| InvalidSessionError("invalid local identity key"))
197 }
198
199 pub(crate) fn local_identity_key_bytes(&self) -> Result<Vec<u8>, InvalidSessionError> {
200 Ok(self.local_identity_key()?.serialize().to_vec())
201 }
202
203 pub(crate) fn session_with_self(&self) -> Result<bool, InvalidSessionError> {
204 if let Some(remote_id) = self.remote_identity_key_bytes()? {
205 let local_id = self.local_identity_key_bytes()?;
206 return Ok(remote_id == local_id);
207 }
208
209 Ok(false)
211 }
212
213 pub(crate) fn previous_counter(&self) -> u32 {
214 self.session.previous_counter
215 }
216
217 pub(crate) fn set_previous_counter(&mut self, ctr: u32) {
218 self.session.previous_counter = ctr;
219 }
220
221 pub(crate) fn root_key(&self) -> Result<RootKey, InvalidSessionError> {
222 let root_key_bytes = self.session.root_key[..]
223 .try_into()
224 .map_err(|_| InvalidSessionError("invalid root key"))?;
225 Ok(RootKey::new(root_key_bytes))
226 }
227
228 pub(crate) fn set_root_key(&mut self, root_key: &RootKey) {
229 self.session.root_key = root_key.key().to_vec();
230 }
231
232 pub(crate) fn sender_ratchet_key(&self) -> Result<PublicKey, InvalidSessionError> {
233 match self.session.sender_chain {
234 None => Err(InvalidSessionError("missing sender chain")),
235 Some(ref c) => PublicKey::deserialize(&c.sender_ratchet_key)
236 .map_err(|_| InvalidSessionError("invalid sender chain ratchet key")),
237 }
238 }
239
240 pub(crate) fn sender_ratchet_key_for_logging(&self) -> Result<String, InvalidSessionError> {
241 Ok(hex::encode(self.sender_ratchet_key()?.public_key_bytes()))
242 }
243
244 pub(crate) fn sender_ratchet_private_key(&self) -> Result<PrivateKey, InvalidSessionError> {
245 match self.session.sender_chain {
246 None => Err(InvalidSessionError("missing sender chain")),
247 Some(ref c) => PrivateKey::deserialize(&c.sender_ratchet_key_private)
248 .map_err(|_| InvalidSessionError("invalid sender chain private ratchet key")),
249 }
250 }
251
252 pub fn has_usable_sender_chain(
253 &self,
254 now: SystemTime,
255 requirements: SessionUsabilityRequirements,
256 ) -> Result<bool, InvalidSessionError> {
257 if self.session.sender_chain.is_none() {
258 return Ok(false);
259 }
260 if requirements.contains(SessionUsabilityRequirements::NotStale) {
261 if let Some(pending_pre_key) = &self.session.pending_pre_key {
262 let creation_timestamp =
263 SystemTime::UNIX_EPOCH + Duration::from_secs(pending_pre_key.timestamp);
264 if creation_timestamp + consts::MAX_UNACKNOWLEDGED_SESSION_AGE < now {
265 return Ok(false);
266 }
267 }
268 }
269 #[allow(clippy::collapsible_if)]
270 if requirements.contains(SessionUsabilityRequirements::EstablishedWithPqxdh) {
271 if self.session_version()? <= CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION.into() {
272 return Ok(false);
273 }
274 }
275 #[allow(clippy::collapsible_if)]
276 if requirements.contains(SessionUsabilityRequirements::Spqr) {
277 if self.pq_ratchet_state().is_empty() {
278 return Ok(false);
279 }
280 }
281 Ok(true)
282 }
283
284 pub(crate) fn all_receiver_chain_logging_info(&self) -> Vec<(Vec<u8>, Option<u32>)> {
285 let mut results = vec![];
286 for chain in self.session.receiver_chains.iter() {
287 let sender_ratchet_public = chain.sender_ratchet_key.clone();
288
289 let chain_key_idx = chain.chain_key.as_ref().map(|chain_key| chain_key.index);
290
291 results.push((sender_ratchet_public, chain_key_idx))
292 }
293 results
294 }
295
296 pub(crate) fn get_receiver_chain(
297 &self,
298 sender: &PublicKey,
299 ) -> Result<Option<(session_structure::Chain, usize)>, InvalidSessionError> {
300 for (idx, chain) in self.session.receiver_chains.iter().enumerate() {
301 let chain_ratchet_key = PublicKey::deserialize(&chain.sender_ratchet_key)
304 .map_err(|_| InvalidSessionError("invalid receiver chain ratchet key"))?;
305
306 if &chain_ratchet_key == sender {
307 return Ok(Some((chain.clone(), idx)));
308 }
309 }
310
311 Ok(None)
312 }
313
314 pub(crate) fn get_receiver_chain_key(
315 &self,
316 sender: &PublicKey,
317 ) -> Result<Option<ChainKey>, InvalidSessionError> {
318 match self.get_receiver_chain(sender)? {
319 None => Ok(None),
320 Some((chain, _)) => match chain.chain_key {
321 None => Err(InvalidSessionError("missing receiver chain key")),
322 Some(c) => {
323 let chain_key_bytes = c.key[..]
324 .try_into()
325 .map_err(|_| InvalidSessionError("invalid receiver chain key"))?;
326 Ok(Some(ChainKey::new(chain_key_bytes, c.index)))
327 }
328 },
329 }
330 }
331
332 pub(crate) fn add_receiver_chain(&mut self, sender: &PublicKey, chain_key: &ChainKey) {
333 let chain_key = session_structure::chain::ChainKey {
334 index: chain_key.index(),
335 key: chain_key.key().to_vec(),
336 };
337
338 let chain = session_structure::Chain {
339 sender_ratchet_key: sender.serialize().to_vec(),
340 sender_ratchet_key_private: vec![],
341 chain_key: Some(chain_key),
342 message_keys: vec![],
343 };
344
345 self.session.receiver_chains.push(chain);
346
347 if self.session.receiver_chains.len() > consts::MAX_RECEIVER_CHAINS {
348 log::info!(
349 "Trimming excessive receiver_chain for session with base key {}, chain count: {}",
350 self.sender_ratchet_key_for_logging()
351 .unwrap_or_else(|e| format!("<error: {}>", e.0)),
352 self.session.receiver_chains.len()
353 );
354 self.session.receiver_chains.remove(0);
355 }
356 }
357
358 pub(crate) fn with_receiver_chain(mut self, sender: &PublicKey, chain_key: &ChainKey) -> Self {
359 self.add_receiver_chain(sender, chain_key);
360 self
361 }
362
363 pub(crate) fn set_sender_chain(&mut self, sender: &KeyPair, next_chain_key: &ChainKey) {
364 let chain_key = session_structure::chain::ChainKey {
365 index: next_chain_key.index(),
366 key: next_chain_key.key().to_vec(),
367 };
368
369 let new_chain = session_structure::Chain {
370 sender_ratchet_key: sender.public_key.serialize().to_vec(),
371 sender_ratchet_key_private: sender.private_key.serialize().to_vec(),
372 chain_key: Some(chain_key),
373 message_keys: vec![],
374 };
375
376 self.session.sender_chain = Some(new_chain);
377 }
378
379 pub(crate) fn with_sender_chain(mut self, sender: &KeyPair, next_chain_key: &ChainKey) -> Self {
380 self.set_sender_chain(sender, next_chain_key);
381 self
382 }
383
384 pub(crate) fn get_sender_chain_key(&self) -> Result<ChainKey, InvalidSessionError> {
385 let sender_chain = self
386 .session
387 .sender_chain
388 .as_ref()
389 .ok_or(InvalidSessionError("missing sender chain"))?;
390
391 let chain_key = sender_chain
392 .chain_key
393 .as_ref()
394 .ok_or(InvalidSessionError("missing sender chain key"))?;
395
396 let chain_key_bytes = chain_key.key[..]
397 .try_into()
398 .map_err(|_| InvalidSessionError("invalid sender chain key"))?;
399
400 Ok(ChainKey::new(chain_key_bytes, chain_key.index))
401 }
402
403 pub(crate) fn get_sender_chain_key_bytes(&self) -> Result<Vec<u8>, InvalidSessionError> {
404 Ok(self.get_sender_chain_key()?.key().to_vec())
405 }
406
407 pub(crate) fn set_sender_chain_key(&mut self, next_chain_key: &ChainKey) {
408 let chain_key = session_structure::chain::ChainKey {
409 index: next_chain_key.index(),
410 key: next_chain_key.key().to_vec(),
411 };
412
413 let new_chain = match self.session.sender_chain.take() {
416 None => session_structure::Chain {
417 sender_ratchet_key: vec![],
418 sender_ratchet_key_private: vec![],
419 chain_key: Some(chain_key),
420 message_keys: vec![],
421 },
422 Some(mut c) => {
423 c.chain_key = Some(chain_key);
424 c
425 }
426 };
427
428 self.session.sender_chain = Some(new_chain);
429 }
430
431 pub(crate) fn get_message_keys(
432 &mut self,
433 sender: &PublicKey,
434 counter: u32,
435 ) -> Result<Option<MessageKeyGenerator>, InvalidSessionError> {
436 if let Some(mut chain_and_index) = self.get_receiver_chain(sender)? {
437 let message_key_idx = chain_and_index
438 .0
439 .message_keys
440 .iter()
441 .position(|m| m.index == counter);
442
443 if let Some(position) = message_key_idx {
444 let message_key = chain_and_index.0.message_keys.remove(position);
445 let keys =
446 MessageKeyGenerator::from_pb(message_key).map_err(InvalidSessionError)?;
447
448 self.session.receiver_chains[chain_and_index.1] = chain_and_index.0;
450 return Ok(Some(keys));
451 }
452 }
453
454 Ok(None)
455 }
456
457 pub(crate) fn set_message_keys(
458 &mut self,
459 sender: &PublicKey,
460 message_keys: MessageKeyGenerator,
461 ) -> Result<(), InvalidSessionError> {
462 let chain_and_index = self
463 .get_receiver_chain(sender)?
464 .expect("called set_message_keys for a non-existent chain");
465 let mut updated_chain = chain_and_index.0;
466 updated_chain.message_keys.insert(0, message_keys.into_pb());
467
468 if updated_chain.message_keys.len() > consts::MAX_MESSAGE_KEYS {
469 updated_chain.message_keys.pop();
470 }
471
472 self.session.receiver_chains[chain_and_index.1] = updated_chain;
473
474 Ok(())
475 }
476
477 pub(crate) fn set_receiver_chain_key(
478 &mut self,
479 sender: &PublicKey,
480 chain_key: &ChainKey,
481 ) -> Result<(), InvalidSessionError> {
482 let chain_and_index = self
483 .get_receiver_chain(sender)?
484 .expect("called set_receiver_chain_key for a non-existent chain");
485 let mut updated_chain = chain_and_index.0;
486 updated_chain.chain_key = Some(session_structure::chain::ChainKey {
487 index: chain_key.index(),
488 key: chain_key.key().to_vec(),
489 });
490
491 self.session.receiver_chains[chain_and_index.1] = updated_chain;
492
493 Ok(())
494 }
495
496 pub(crate) fn set_unacknowledged_pre_key_message(
497 &mut self,
498 pre_key_id: Option<PreKeyId>,
499 signed_ec_pre_key_id: SignedPreKeyId,
500 base_key: &PublicKey,
501 now: SystemTime,
502 ) {
503 let signed_ec_pre_key_id: u32 = signed_ec_pre_key_id.into();
504 let pending = session_structure::PendingPreKey {
505 pre_key_id: pre_key_id.map(PreKeyId::into),
506 signed_pre_key_id: signed_ec_pre_key_id as i32,
507 base_key: base_key.serialize().to_vec(),
508 timestamp: now
509 .duration_since(SystemTime::UNIX_EPOCH)
510 .unwrap_or_default()
511 .as_secs(),
512 };
513 self.session.pending_pre_key = Some(pending);
514 }
515
516 pub(crate) fn set_kyber_ciphertext(&mut self, ciphertext: kem::SerializedCiphertext) {
517 let pending = session_structure::PendingKyberPreKey {
518 pre_key_id: u32::MAX, ciphertext: ciphertext.into_vec(),
520 };
521 self.session.pending_kyber_pre_key = Some(pending);
522 }
523
524 pub(crate) fn set_unacknowledged_kyber_pre_key_id(
525 &mut self,
526 signed_kyber_pre_key_id: KyberPreKeyId,
527 ) {
528 let pending = self
529 .session
530 .pending_kyber_pre_key
531 .as_mut()
532 .expect("must have been set if kyber pre key is present");
533 pending.pre_key_id = signed_kyber_pre_key_id.into();
534 }
535
536 pub(crate) fn unacknowledged_pre_key_message_items(
537 &self,
538 ) -> Result<Option<UnacknowledgedPreKeyMessageItems<'_>>, InvalidSessionError> {
539 if let Some(ref pending_pre_key) = self.session.pending_pre_key {
540 Ok(Some(UnacknowledgedPreKeyMessageItems::new(
541 pending_pre_key.pre_key_id.map(Into::into),
542 (pending_pre_key.signed_pre_key_id as u32).into(),
543 PublicKey::deserialize(&pending_pre_key.base_key)
544 .map_err(|_| InvalidSessionError("invalid pending PreKey message base key"))?,
545 self.session.pending_kyber_pre_key.as_ref(),
546 SystemTime::UNIX_EPOCH + Duration::from_secs(pending_pre_key.timestamp),
547 )))
548 } else {
549 Ok(None)
550 }
551 }
552
553 pub(crate) fn clear_unacknowledged_pre_key_message(&mut self) {
554 let SessionStructure {
557 session_version: _session_version,
558 local_identity_public: _local_identity_public,
559 remote_identity_public: _remote_identity_public,
560 root_key: _root_key,
561 previous_counter: _previous_counter,
562 sender_chain: _sender_chain,
563 receiver_chains: _receiver_chains,
564 pending_pre_key: _pending_pre_key,
565 pending_kyber_pre_key: _pending_kyber_pre_key,
566 remote_registration_id: _remote_registration_id,
567 local_registration_id: _local_registration_id,
568 alice_base_key: _alice_base_key,
569 pq_ratchet_state: _pq_ratchet_state,
570 } = &self.session;
571 self.session.pending_pre_key = None;
575 self.session.pending_kyber_pre_key = None;
576 }
577
578 pub(crate) fn set_remote_registration_id(&mut self, registration_id: u32) {
579 self.session.remote_registration_id = registration_id;
580 }
581
582 pub(crate) fn remote_registration_id(&self) -> u32 {
583 self.session.remote_registration_id
584 }
585
586 pub(crate) fn set_local_registration_id(&mut self, registration_id: u32) {
587 self.session.local_registration_id = registration_id;
588 }
589
590 pub(crate) fn local_registration_id(&self) -> u32 {
591 self.session.local_registration_id
592 }
593
594 pub(crate) fn get_kyber_ciphertext(&self) -> Option<&Vec<u8>> {
595 self.session
596 .pending_kyber_pre_key
597 .as_ref()
598 .map(|pending| &pending.ciphertext)
599 }
600
601 pub(crate) fn pq_ratchet_recv(
602 &mut self,
603 msg: &spqr::SerializedMessage,
604 ) -> Result<spqr::MessageKey, spqr::Error> {
605 let spqr::Recv { state, key } = spqr::recv(&self.session.pq_ratchet_state, msg)?;
606 self.session.pq_ratchet_state = state;
607 Ok(key)
608 }
609
610 pub(crate) fn pq_ratchet_send<R: Rng + CryptoRng>(
611 &mut self,
612 csprng: &mut R,
613 ) -> Result<(spqr::SerializedMessage, spqr::MessageKey), spqr::Error> {
614 let spqr::Send { state, key, msg } = spqr::send(&self.session.pq_ratchet_state, csprng)?;
615 self.session.pq_ratchet_state = state;
616 Ok((msg, key))
617 }
618
619 pub(crate) fn pq_ratchet_state(&self) -> &spqr::SerializedState {
620 &self.session.pq_ratchet_state
621 }
622}
623
624impl From<SessionStructure> for SessionState {
625 fn from(value: SessionStructure) -> SessionState {
626 SessionState::from_session_structure(value)
627 }
628}
629
630impl From<SessionState> for SessionStructure {
631 fn from(value: SessionState) -> SessionStructure {
632 value.session
633 }
634}
635
636impl From<&SessionState> for SessionStructure {
637 fn from(value: &SessionState) -> SessionStructure {
638 value.session.clone()
639 }
640}
641
642#[derive(Clone)]
643pub struct SessionRecord {
644 current_session: Option<SessionState>,
645 previous_sessions: Vec<Vec<u8>>,
646}
647
648impl SessionRecord {
649 pub fn new_fresh() -> Self {
650 Self {
651 current_session: None,
652 previous_sessions: Vec::new(),
653 }
654 }
655
656 pub(crate) fn new(state: SessionState) -> Self {
657 Self {
658 current_session: Some(state),
659 previous_sessions: Vec::new(),
660 }
661 }
662
663 pub fn deserialize(bytes: &[u8]) -> Result<Self, SignalProtocolError> {
664 let record = RecordStructure::decode(bytes)
665 .map_err(|_| InvalidSessionError("failed to decode session record protobuf"))?;
666
667 Ok(Self {
668 current_session: record.current_session.map(|s| s.into()),
669 previous_sessions: record.previous_sessions,
670 })
671 }
672
673 pub(crate) fn promote_matching_session(
680 &mut self,
681 version: u32,
682 alice_base_key: &[u8],
683 ) -> Result<bool, InvalidSessionError> {
684 if let Some(current_session) = &self.current_session {
685 if current_session.session_version()? == version
686 && alice_base_key
687 .ct_eq(current_session.alice_base_key())
688 .into()
689 {
690 return Ok(true);
691 }
692 }
693
694 let mut session_to_promote = None;
695 for (i, previous) in self.previous_session_states().enumerate() {
696 let previous = previous?;
697 if previous.session_version()? == version
698 && alice_base_key.ct_eq(previous.alice_base_key()).into()
699 {
700 session_to_promote = Some((i, previous));
701 break;
702 }
703 }
704
705 if let Some((i, state)) = session_to_promote {
706 self.promote_old_session(i, state);
707 return Ok(true);
708 }
709
710 Ok(false)
711 }
712
713 pub(crate) fn session_state(&self) -> Option<&SessionState> {
714 self.current_session.as_ref()
715 }
716
717 pub(crate) fn session_state_mut(&mut self) -> Option<&mut SessionState> {
718 self.current_session.as_mut()
719 }
720
721 pub(crate) fn set_session_state(&mut self, session: SessionState) {
722 self.current_session = Some(session);
723 }
724
725 pub(crate) fn previous_session_states(
726 &self,
727 ) -> impl ExactSizeIterator<Item = Result<SessionState, InvalidSessionError>> + '_ {
728 self.previous_sessions.iter().map(|bytes| {
729 Ok(SessionStructure::decode(&bytes[..])
730 .map_err(|_| InvalidSessionError("failed to decode previous session protobuf"))?
731 .into())
732 })
733 }
734
735 pub(crate) fn promote_old_session(
736 &mut self,
737 old_session: usize,
738 updated_session: SessionState,
739 ) {
740 self.previous_sessions.remove(old_session);
741 self.promote_state(updated_session)
742 }
743
744 pub(crate) fn promote_state(&mut self, new_state: SessionState) {
745 self.archive_current_state_inner();
746 self.current_session = Some(new_state);
747 }
748
749 fn archive_current_state_inner(&mut self) -> bool {
753 if let Some(mut current_session) = self.current_session.take() {
754 if self.previous_sessions.len() >= consts::ARCHIVED_STATES_MAX_LENGTH {
755 self.previous_sessions.pop();
756 }
757 current_session.clear_unacknowledged_pre_key_message();
758 self.previous_sessions
759 .insert(0, current_session.session.encode_to_vec());
760 true
761 } else {
762 false
763 }
764 }
765
766 pub fn archive_current_state(&mut self) -> Result<(), SignalProtocolError> {
767 if !self.archive_current_state_inner() {
768 log::info!("Skipping archive, current session state is fresh");
769 }
770 Ok(())
771 }
772
773 pub fn serialize(&self) -> Result<Vec<u8>, SignalProtocolError> {
774 let record = RecordStructure {
775 current_session: self.current_session.as_ref().map(|s| s.into()),
776 previous_sessions: self.previous_sessions.clone(),
777 };
778 Ok(record.encode_to_vec())
779 }
780
781 pub fn current_pq_state(&self) -> Option<&spqr::SerializedState> {
782 self.current_session.as_ref().map(|s| s.pq_ratchet_state())
783 }
784
785 pub fn remote_registration_id(&self) -> Result<u32, SignalProtocolError> {
786 Ok(self
787 .session_state()
788 .ok_or_else(|| {
789 SignalProtocolError::InvalidState(
790 "remote_registration_id",
791 "No current session".into(),
792 )
793 })?
794 .remote_registration_id())
795 }
796
797 pub fn local_registration_id(&self) -> Result<u32, SignalProtocolError> {
798 Ok(self
799 .session_state()
800 .ok_or_else(|| {
801 SignalProtocolError::InvalidState(
802 "local_registration_id",
803 "No current session".into(),
804 )
805 })?
806 .local_registration_id())
807 }
808
809 pub fn session_version(&self) -> Result<u32, SignalProtocolError> {
810 Ok(self
811 .session_state()
812 .ok_or_else(|| {
813 SignalProtocolError::InvalidState("session_version", "No current session".into())
814 })?
815 .session_version()?)
816 }
817
818 pub fn local_identity_key_bytes(&self) -> Result<Vec<u8>, SignalProtocolError> {
819 Ok(self
820 .session_state()
821 .ok_or_else(|| {
822 SignalProtocolError::InvalidState(
823 "local_identity_key_bytes",
824 "No current session".into(),
825 )
826 })?
827 .local_identity_key_bytes()?)
828 }
829
830 pub fn remote_identity_key_bytes(&self) -> Result<Option<Vec<u8>>, SignalProtocolError> {
831 Ok(self
832 .session_state()
833 .ok_or_else(|| {
834 SignalProtocolError::InvalidState(
835 "remote_identity_key_bytes",
836 "No current session".into(),
837 )
838 })?
839 .remote_identity_key_bytes()?)
840 }
841
842 pub fn has_usable_sender_chain(
843 &self,
844 now: SystemTime,
845 requirements: SessionUsabilityRequirements,
846 ) -> Result<bool, SignalProtocolError> {
847 match &self.current_session {
848 Some(session) => Ok(session.has_usable_sender_chain(now, requirements)?),
849 None => Ok(false),
850 }
851 }
852
853 pub fn alice_base_key(&self) -> Result<&[u8], SignalProtocolError> {
854 Ok(self
855 .session_state()
856 .ok_or_else(|| {
857 SignalProtocolError::InvalidState("alice_base_key", "No current session".into())
858 })?
859 .alice_base_key())
860 }
861
862 pub fn get_receiver_chain_key_bytes(
863 &self,
864 sender: &PublicKey,
865 ) -> Result<Option<Box<[u8]>>, SignalProtocolError> {
866 Ok(self
867 .session_state()
868 .ok_or_else(|| {
869 SignalProtocolError::InvalidState(
870 "get_receiver_chain_key",
871 "No current session".into(),
872 )
873 })?
874 .get_receiver_chain_key(sender)?
875 .map(|chain| chain.key()[..].into()))
876 }
877
878 pub fn get_sender_chain_key_bytes(&self) -> Result<Vec<u8>, SignalProtocolError> {
879 Ok(self
880 .session_state()
881 .ok_or_else(|| {
882 SignalProtocolError::InvalidState(
883 "get_sender_chain_key_bytes",
884 "No current session".into(),
885 )
886 })?
887 .get_sender_chain_key_bytes()?)
888 }
889
890 pub fn current_ratchet_key_matches(
891 &self,
892 key: &PublicKey,
893 ) -> Result<bool, SignalProtocolError> {
894 match &self.current_session {
895 Some(session) => Ok(&session.sender_ratchet_key()? == key),
896 None => Ok(false),
897 }
898 }
899
900 pub fn get_kyber_ciphertext(&self) -> Result<Option<&Vec<u8>>, SignalProtocolError> {
901 Ok(self
902 .session_state()
903 .ok_or_else(|| {
904 SignalProtocolError::InvalidState(
905 "get_kyber_ciphertext",
906 "No current session".into(),
907 )
908 })?
909 .get_kyber_ciphertext())
910 }
911}