1use std::result::Result;
7use std::time::{Duration, SystemTime};
8
9use bitflags::bitflags;
10use prost::Message;
11#[cfg(test)]
12use rand::{CryptoRng, Rng};
13use subtle::ConstantTimeEq;
14
15use crate::proto::storage::{RecordStructure, SessionStructure, session_structure};
16use crate::protocol::CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION;
17#[cfg(test)]
18use crate::ratchet::MessageKeyGenerator;
19use crate::ratchet::{ChainKey, RootKey};
20use crate::state::{KyberPreKeyId, PreKeyId, SignedPreKeyId};
21use crate::{IdentityKey, KeyPair, PrivateKey, PublicKey, SignalProtocolError, consts, kem};
22
/// Describes what was wrong with a stored session structure.
///
/// The payload is a static, human-readable description; it is converted into
/// `SignalProtocolError::InvalidSessionStructure` when it crosses the crate boundary.
#[derive(Debug)]
pub(crate) struct InvalidSessionError(pub(crate) &'static str);

impl std::fmt::Display for InvalidSessionError {
    /// Formats the description exactly as `str`'s own `Display` would, so width and
    /// alignment flags are honoured.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(description) = self;
        std::fmt::Display::fmt(*description, f)
    }
}
32
33impl From<InvalidSessionError> for SignalProtocolError {
34 fn from(e: InvalidSessionError) -> Self {
35 Self::InvalidSessionStructure(e.0)
36 }
37}
38
39#[derive(Debug, Clone)]
40pub(crate) struct UnacknowledgedPreKeyMessageItems<'a> {
41 pre_key_id: Option<PreKeyId>,
42 signed_pre_key_id: SignedPreKeyId,
43 base_key: PublicKey,
44 kyber_pre_key_id: Option<KyberPreKeyId>,
48 kyber_ciphertext: Option<&'a [u8]>,
49 timestamp: SystemTime,
50}
51
52impl<'a> UnacknowledgedPreKeyMessageItems<'a> {
53 fn new(
54 pre_key_id: Option<PreKeyId>,
55 signed_pre_key_id: SignedPreKeyId,
56 base_key: PublicKey,
57 pending_kyber_pre_key: Option<&'a session_structure::PendingKyberPreKey>,
58 timestamp: SystemTime,
59 ) -> Self {
60 let (kyber_pre_key_id, kyber_ciphertext) = pending_kyber_pre_key
61 .map(|pending| (pending.pre_key_id.into(), pending.ciphertext.as_slice()))
62 .unzip();
63 Self {
64 pre_key_id,
65 signed_pre_key_id,
66 base_key,
67 kyber_pre_key_id,
68 kyber_ciphertext,
69 timestamp,
70 }
71 }
72
73 pub(crate) fn pre_key_id(&self) -> Option<PreKeyId> {
74 self.pre_key_id
75 }
76
77 pub(crate) fn signed_pre_key_id(&self) -> SignedPreKeyId {
78 self.signed_pre_key_id
79 }
80
81 pub(crate) fn base_key(&self) -> &PublicKey {
82 &self.base_key
83 }
84
85 pub(crate) fn kyber_pre_key_id(&self) -> Option<KyberPreKeyId> {
86 self.kyber_pre_key_id
87 }
88
89 pub(crate) fn kyber_ciphertext(&self) -> Option<&'a [u8]> {
90 self.kyber_ciphertext
91 }
92
93 pub(crate) fn timestamp(&self) -> SystemTime {
94 self.timestamp
95 }
96}
97
98bitflags! {
99 #[repr(transparent)]
109 #[derive(Clone, Copy, PartialEq, Eq)]
110 pub struct SessionUsabilityRequirements : u32 {
111 const NotStale = 1 << 0;
118 const EstablishedWithPqxdh = 1 << 1;
123 const Spqr = 1 << 2;
131 }
132}
133
134#[derive(Clone, Debug)]
135pub(crate) struct SessionState {
136 session: SessionStructure,
137}
138
139impl SessionState {
140 pub(crate) fn from_session_structure(session: SessionStructure) -> Self {
141 Self { session }
142 }
143
144 pub(crate) fn new(
145 version: u8,
146 our_identity: &IdentityKey,
147 their_identity: &IdentityKey,
148 root_key: &RootKey,
149 alice_base_key: &PublicKey,
150 pq_ratchet_state: spqr::SerializedState,
151 ) -> Self {
152 Self {
153 session: SessionStructure {
154 session_version: version as u32,
155 local_identity_public: our_identity.public_key().serialize().into_vec(),
156 remote_identity_public: their_identity.serialize().into_vec(),
157 root_key: root_key.key().to_vec(),
158 previous_counter: 0,
159 sender_chain: None,
160 receiver_chains: vec![],
161 pending_pre_key: None,
162 pending_kyber_pre_key: None,
163 remote_registration_id: 0,
164 local_registration_id: 0,
165 alice_base_key: alice_base_key.serialize().into_vec(),
166 pq_ratchet_state,
167 },
168 }
169 }
170
171 pub(crate) fn alice_base_key(&self) -> &[u8] {
172 &self.session.alice_base_key
174 }
175
176 pub(crate) fn session_version(&self) -> Result<u32, InvalidSessionError> {
177 match self.session.session_version {
178 0 => Ok(2),
179 v => Ok(v),
180 }
181 }
182
183 pub(crate) fn remote_identity_key(&self) -> Result<Option<IdentityKey>, InvalidSessionError> {
184 match self.session.remote_identity_public.len() {
185 0 => Ok(None),
186 _ => Ok(Some(
187 IdentityKey::decode(&self.session.remote_identity_public)
188 .map_err(|_| InvalidSessionError("invalid remote identity key"))?,
189 )),
190 }
191 }
192
193 pub(crate) fn remote_identity_key_bytes(&self) -> Result<Option<Vec<u8>>, InvalidSessionError> {
194 Ok(self.remote_identity_key()?.map(|k| k.serialize().to_vec()))
195 }
196
197 pub(crate) fn local_identity_key(&self) -> Result<IdentityKey, InvalidSessionError> {
198 IdentityKey::decode(&self.session.local_identity_public)
199 .map_err(|_| InvalidSessionError("invalid local identity key"))
200 }
201
202 pub(crate) fn local_identity_key_bytes(&self) -> Result<Vec<u8>, InvalidSessionError> {
203 Ok(self.local_identity_key()?.serialize().to_vec())
204 }
205
206 #[deprecated]
207 #[allow(unused)]
208 pub(crate) fn session_with_self(&self) -> Result<bool, InvalidSessionError> {
209 if let Some(remote_id) = self.remote_identity_key_bytes()? {
210 let local_id = self.local_identity_key_bytes()?;
211 return Ok(remote_id == local_id);
212 }
213
214 Ok(false)
216 }
217
218 pub(crate) fn previous_counter(&self) -> u32 {
219 self.session.previous_counter
220 }
221
222 #[cfg(test)]
223 pub(crate) fn set_previous_counter(&mut self, ctr: u32) {
224 self.session.previous_counter = ctr;
225 }
226
227 #[cfg(test)]
228 pub(crate) fn root_key(&self) -> Result<RootKey, InvalidSessionError> {
229 let root_key_bytes = self.session.root_key[..]
230 .try_into()
231 .map_err(|_| InvalidSessionError("invalid root key"))?;
232 Ok(RootKey::new(root_key_bytes))
233 }
234
235 #[cfg(test)]
236 pub(crate) fn set_root_key(&mut self, root_key: &RootKey) {
237 self.session.root_key = root_key.key().to_vec();
238 }
239
240 pub(crate) fn sender_ratchet_key(&self) -> Result<PublicKey, InvalidSessionError> {
241 match self.session.sender_chain {
242 None => Err(InvalidSessionError("missing sender chain")),
243 Some(ref c) => PublicKey::deserialize(&c.sender_ratchet_key)
244 .map_err(|_| InvalidSessionError("invalid sender chain ratchet key")),
245 }
246 }
247
248 pub(crate) fn sender_ratchet_key_for_logging(&self) -> Result<String, InvalidSessionError> {
249 Ok(hex::encode(self.sender_ratchet_key()?.public_key_bytes()))
250 }
251
252 pub(crate) fn sender_ratchet_private_key(&self) -> Result<PrivateKey, InvalidSessionError> {
253 match self.session.sender_chain {
254 None => Err(InvalidSessionError("missing sender chain")),
255 Some(ref c) => PrivateKey::deserialize(&c.sender_ratchet_key_private)
256 .map_err(|_| InvalidSessionError("invalid sender chain private ratchet key")),
257 }
258 }
259
260 pub fn has_usable_sender_chain(
261 &self,
262 now: SystemTime,
263 requirements: SessionUsabilityRequirements,
264 ) -> Result<bool, InvalidSessionError> {
265 if self.session.sender_chain.is_none() {
266 return Ok(false);
267 }
268 if requirements.contains(SessionUsabilityRequirements::NotStale) {
269 if let Some(pending_pre_key) = &self.session.pending_pre_key {
270 let creation_timestamp =
271 SystemTime::UNIX_EPOCH + Duration::from_secs(pending_pre_key.timestamp);
272 if creation_timestamp + consts::MAX_UNACKNOWLEDGED_SESSION_AGE < now {
273 return Ok(false);
274 }
275 }
276 }
277 #[allow(clippy::collapsible_if)]
278 if requirements.contains(SessionUsabilityRequirements::EstablishedWithPqxdh) {
279 if self.session_version()? <= CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION.into() {
280 return Ok(false);
281 }
282 }
283 #[allow(clippy::collapsible_if)]
284 if requirements.contains(SessionUsabilityRequirements::Spqr) {
285 if self.pq_ratchet_state().is_empty() {
286 return Ok(false);
287 }
288 }
289 Ok(true)
290 }
291
292 pub(crate) fn all_receiver_chain_logging_info(&self) -> Vec<(Vec<u8>, Option<u32>)> {
293 let mut results = vec![];
294 for chain in self.session.receiver_chains.iter() {
295 let sender_ratchet_public = chain.sender_ratchet_key.clone();
296
297 let chain_key_idx = chain.chain_key.as_ref().map(|chain_key| chain_key.index);
298
299 results.push((sender_ratchet_public, chain_key_idx))
300 }
301 results
302 }
303
304 pub(crate) fn get_receiver_chain(
305 &self,
306 sender: &PublicKey,
307 ) -> Result<Option<(session_structure::Chain, usize)>, InvalidSessionError> {
308 for (idx, chain) in self.session.receiver_chains.iter().enumerate() {
309 let chain_ratchet_key = PublicKey::deserialize(&chain.sender_ratchet_key)
312 .map_err(|_| InvalidSessionError("invalid receiver chain ratchet key"))?;
313
314 if &chain_ratchet_key == sender {
315 return Ok(Some((chain.clone(), idx)));
316 }
317 }
318
319 Ok(None)
320 }
321
322 pub(crate) fn get_receiver_chain_key(
323 &self,
324 sender: &PublicKey,
325 ) -> Result<Option<ChainKey>, InvalidSessionError> {
326 match self.get_receiver_chain(sender)? {
327 None => Ok(None),
328 Some((chain, _)) => match chain.chain_key {
329 None => Err(InvalidSessionError("missing receiver chain key")),
330 Some(c) => {
331 let chain_key_bytes = c.key[..]
332 .try_into()
333 .map_err(|_| InvalidSessionError("invalid receiver chain key"))?;
334 Ok(Some(ChainKey::new(chain_key_bytes, c.index)))
335 }
336 },
337 }
338 }
339
340 pub(crate) fn add_receiver_chain(&mut self, sender: &PublicKey, chain_key: &ChainKey) {
341 let chain_key = session_structure::chain::ChainKey {
342 index: chain_key.index(),
343 key: chain_key.key().to_vec(),
344 };
345
346 let chain = session_structure::Chain {
347 sender_ratchet_key: sender.serialize().to_vec(),
348 sender_ratchet_key_private: vec![],
349 chain_key: Some(chain_key),
350 message_keys: vec![],
351 };
352
353 self.session.receiver_chains.push(chain);
354
355 if self.session.receiver_chains.len() > consts::MAX_RECEIVER_CHAINS {
356 log::info!(
357 "Trimming excessive receiver_chain for session with base key {}, chain count: {}",
358 self.sender_ratchet_key_for_logging()
359 .unwrap_or_else(|e| format!("<error: {}>", e.0)),
360 self.session.receiver_chains.len()
361 );
362 self.session.receiver_chains.remove(0);
363 }
364 }
365
366 pub(crate) fn with_receiver_chain(mut self, sender: &PublicKey, chain_key: &ChainKey) -> Self {
367 self.add_receiver_chain(sender, chain_key);
368 self
369 }
370
371 pub(crate) fn set_sender_chain(&mut self, sender: &KeyPair, next_chain_key: &ChainKey) {
372 let chain_key = session_structure::chain::ChainKey {
373 index: next_chain_key.index(),
374 key: next_chain_key.key().to_vec(),
375 };
376
377 let new_chain = session_structure::Chain {
378 sender_ratchet_key: sender.public_key.serialize().to_vec(),
379 sender_ratchet_key_private: sender.private_key.serialize().to_vec(),
380 chain_key: Some(chain_key),
381 message_keys: vec![],
382 };
383
384 self.session.sender_chain = Some(new_chain);
385 }
386
387 pub(crate) fn with_sender_chain(mut self, sender: &KeyPair, next_chain_key: &ChainKey) -> Self {
388 self.set_sender_chain(sender, next_chain_key);
389 self
390 }
391
392 pub(crate) fn get_sender_chain_key(&self) -> Result<ChainKey, InvalidSessionError> {
393 let sender_chain = self
394 .session
395 .sender_chain
396 .as_ref()
397 .ok_or(InvalidSessionError("missing sender chain"))?;
398
399 let chain_key = sender_chain
400 .chain_key
401 .as_ref()
402 .ok_or(InvalidSessionError("missing sender chain key"))?;
403
404 let chain_key_bytes = chain_key.key[..]
405 .try_into()
406 .map_err(|_| InvalidSessionError("invalid sender chain key"))?;
407
408 Ok(ChainKey::new(chain_key_bytes, chain_key.index))
409 }
410
411 pub(crate) fn get_sender_chain_key_bytes(&self) -> Result<Vec<u8>, InvalidSessionError> {
412 Ok(self.get_sender_chain_key()?.key().to_vec())
413 }
414
415 pub(crate) fn set_sender_chain_key(&mut self, next_chain_key: &ChainKey) {
416 let chain_key = session_structure::chain::ChainKey {
417 index: next_chain_key.index(),
418 key: next_chain_key.key().to_vec(),
419 };
420
421 let new_chain = match self.session.sender_chain.take() {
424 None => session_structure::Chain {
425 sender_ratchet_key: vec![],
426 sender_ratchet_key_private: vec![],
427 chain_key: Some(chain_key),
428 message_keys: vec![],
429 },
430 Some(mut c) => {
431 c.chain_key = Some(chain_key);
432 c
433 }
434 };
435
436 self.session.sender_chain = Some(new_chain);
437 }
438
439 #[cfg(test)]
440 pub(crate) fn get_message_keys(
441 &mut self,
442 sender: &PublicKey,
443 counter: u32,
444 ) -> Result<Option<MessageKeyGenerator>, InvalidSessionError> {
445 if let Some(mut chain_and_index) = self.get_receiver_chain(sender)? {
446 let message_key_idx = chain_and_index
447 .0
448 .message_keys
449 .iter()
450 .position(|m| m.index == counter);
451
452 if let Some(position) = message_key_idx {
453 let message_key = chain_and_index.0.message_keys.remove(position);
454 let keys =
455 MessageKeyGenerator::from_pb(message_key).map_err(InvalidSessionError)?;
456
457 self.session.receiver_chains[chain_and_index.1] = chain_and_index.0;
459 return Ok(Some(keys));
460 }
461 }
462
463 Ok(None)
464 }
465
466 #[cfg(test)]
467 pub(crate) fn set_message_keys(
468 &mut self,
469 sender: &PublicKey,
470 message_keys: MessageKeyGenerator,
471 ) -> Result<(), InvalidSessionError> {
472 let chain_and_index = self
473 .get_receiver_chain(sender)?
474 .expect("called set_message_keys for a non-existent chain");
475 let mut updated_chain = chain_and_index.0;
476 updated_chain.message_keys.insert(0, message_keys.into_pb());
477
478 if updated_chain.message_keys.len() > consts::MAX_MESSAGE_KEYS {
479 updated_chain.message_keys.pop();
480 }
481
482 self.session.receiver_chains[chain_and_index.1] = updated_chain;
483
484 Ok(())
485 }
486
487 #[cfg(test)]
488 pub(crate) fn set_receiver_chain_key(
489 &mut self,
490 sender: &PublicKey,
491 chain_key: &ChainKey,
492 ) -> Result<(), InvalidSessionError> {
493 let chain_and_index = self
494 .get_receiver_chain(sender)?
495 .expect("called set_receiver_chain_key for a non-existent chain");
496 let mut updated_chain = chain_and_index.0;
497 updated_chain.chain_key = Some(session_structure::chain::ChainKey {
498 index: chain_key.index(),
499 key: chain_key.key().to_vec(),
500 });
501
502 self.session.receiver_chains[chain_and_index.1] = updated_chain;
503
504 Ok(())
505 }
506
507 pub(crate) fn set_unacknowledged_pre_key_message(
508 &mut self,
509 pre_key_id: Option<PreKeyId>,
510 signed_ec_pre_key_id: SignedPreKeyId,
511 base_key: &PublicKey,
512 now: SystemTime,
513 ) {
514 let signed_ec_pre_key_id: u32 = signed_ec_pre_key_id.into();
515 let pending = session_structure::PendingPreKey {
516 pre_key_id: pre_key_id.map(PreKeyId::into),
517 signed_pre_key_id: signed_ec_pre_key_id as i32,
518 base_key: base_key.serialize().to_vec(),
519 timestamp: now
520 .duration_since(SystemTime::UNIX_EPOCH)
521 .unwrap_or_default()
522 .as_secs(),
523 };
524 self.session.pending_pre_key = Some(pending);
525 }
526
527 pub(crate) fn set_kyber_ciphertext(&mut self, ciphertext: kem::SerializedCiphertext) {
528 let pending = session_structure::PendingKyberPreKey {
529 pre_key_id: u32::MAX, ciphertext: ciphertext.into_vec(),
531 };
532 self.session.pending_kyber_pre_key = Some(pending);
533 }
534
535 pub(crate) fn set_unacknowledged_kyber_pre_key_id(
536 &mut self,
537 signed_kyber_pre_key_id: KyberPreKeyId,
538 ) {
539 let pending = self
540 .session
541 .pending_kyber_pre_key
542 .as_mut()
543 .expect("must have been set if kyber pre key is present");
544 pending.pre_key_id = signed_kyber_pre_key_id.into();
545 }
546
547 pub(crate) fn unacknowledged_pre_key_message_items(
548 &self,
549 ) -> Result<Option<UnacknowledgedPreKeyMessageItems<'_>>, InvalidSessionError> {
550 if let Some(ref pending_pre_key) = self.session.pending_pre_key {
551 Ok(Some(UnacknowledgedPreKeyMessageItems::new(
552 pending_pre_key.pre_key_id.map(Into::into),
553 (pending_pre_key.signed_pre_key_id as u32).into(),
554 PublicKey::deserialize(&pending_pre_key.base_key)
555 .map_err(|_| InvalidSessionError("invalid pending PreKey message base key"))?,
556 self.session.pending_kyber_pre_key.as_ref(),
557 SystemTime::UNIX_EPOCH + Duration::from_secs(pending_pre_key.timestamp),
558 )))
559 } else {
560 Ok(None)
561 }
562 }
563
564 pub(crate) fn clear_unacknowledged_pre_key_message(&mut self) {
565 let SessionStructure {
568 session_version: _session_version,
569 local_identity_public: _local_identity_public,
570 remote_identity_public: _remote_identity_public,
571 root_key: _root_key,
572 previous_counter: _previous_counter,
573 sender_chain: _sender_chain,
574 receiver_chains: _receiver_chains,
575 pending_pre_key: _pending_pre_key,
576 pending_kyber_pre_key: _pending_kyber_pre_key,
577 remote_registration_id: _remote_registration_id,
578 local_registration_id: _local_registration_id,
579 alice_base_key: _alice_base_key,
580 pq_ratchet_state: _pq_ratchet_state,
581 } = &self.session;
582 self.session.pending_pre_key = None;
586 self.session.pending_kyber_pre_key = None;
587 }
588
589 pub(crate) fn set_remote_registration_id(&mut self, registration_id: u32) {
590 self.session.remote_registration_id = registration_id;
591 }
592
593 pub(crate) fn remote_registration_id(&self) -> u32 {
594 self.session.remote_registration_id
595 }
596
597 pub(crate) fn set_local_registration_id(&mut self, registration_id: u32) {
598 self.session.local_registration_id = registration_id;
599 }
600
601 pub(crate) fn local_registration_id(&self) -> u32 {
602 self.session.local_registration_id
603 }
604
605 pub(crate) fn get_kyber_ciphertext(&self) -> Option<&Vec<u8>> {
606 self.session
607 .pending_kyber_pre_key
608 .as_ref()
609 .map(|pending| &pending.ciphertext)
610 }
611
612 pub(crate) fn take_ratchet_state(
624 &mut self,
625 self_session: bool,
626 ) -> crate::error::Result<crate::double_ratchet::RatchetState> {
627 let receiver_chains = std::mem::take(&mut self.session.receiver_chains);
628 Ok(crate::double_ratchet::RatchetState::from_pb(
629 &self.session,
630 self_session,
631 receiver_chains,
632 )?)
633 }
634
635 pub(crate) fn apply_ratchet_state(&mut self, ratchet: crate::double_ratchet::RatchetState) {
641 ratchet.apply_to_pb(&mut self.session);
642 }
643
644 #[cfg(test)]
645 pub(crate) fn pq_ratchet_recv(
646 &mut self,
647 msg: &spqr::SerializedMessage,
648 ) -> Result<spqr::MessageKey, spqr::Error> {
649 let _trace = libsignal_debug::trace_block!("SessionState::pq_ratchet_recv");
650 let spqr::Recv { state, key } = spqr::recv(&self.session.pq_ratchet_state, msg)?;
651 self.session.pq_ratchet_state = state;
652 Ok(key)
653 }
654
655 #[cfg(test)]
656 pub(crate) fn pq_ratchet_send<R: Rng + CryptoRng>(
657 &mut self,
658 csprng: &mut R,
659 ) -> Result<(spqr::SerializedMessage, spqr::MessageKey), spqr::Error> {
660 let _trace = libsignal_debug::trace_block!("SessionState::pq_ratchet_send");
661 let spqr::Send { state, key, msg } = spqr::send(&self.session.pq_ratchet_state, csprng)?;
662 self.session.pq_ratchet_state = state;
663 Ok((msg, key))
664 }
665
666 pub(crate) fn pq_ratchet_state(&self) -> &spqr::SerializedState {
667 &self.session.pq_ratchet_state
668 }
669
670 pub(crate) fn take_pq_ratchet_state(&mut self) -> spqr::SerializedState {
676 std::mem::take(&mut self.session.pq_ratchet_state)
677 }
678
679 pub(crate) fn set_pq_ratchet_state(&mut self, state: spqr::SerializedState) {
680 self.session.pq_ratchet_state = state;
681 }
682}
683
684impl From<SessionStructure> for SessionState {
685 fn from(value: SessionStructure) -> SessionState {
686 SessionState::from_session_structure(value)
687 }
688}
689
690impl From<SessionState> for SessionStructure {
691 fn from(value: SessionState) -> SessionStructure {
692 value.session
693 }
694}
695
696impl From<&SessionState> for SessionStructure {
697 fn from(value: &SessionState) -> SessionStructure {
698 value.session.clone()
699 }
700}
701
702#[derive(Clone)]
703pub struct SessionRecord {
704 current_session: Option<SessionState>,
705 previous_sessions: Vec<Vec<u8>>,
706}
707
708impl SessionRecord {
709 pub fn new_fresh() -> Self {
710 Self {
711 current_session: None,
712 previous_sessions: Vec::new(),
713 }
714 }
715
716 pub(crate) fn new(state: SessionState) -> Self {
717 Self {
718 current_session: Some(state),
719 previous_sessions: Vec::new(),
720 }
721 }
722
723 pub fn deserialize(bytes: &[u8]) -> Result<Self, SignalProtocolError> {
724 let record = RecordStructure::decode(bytes)
725 .map_err(|_| InvalidSessionError("failed to decode session record protobuf"))?;
726
727 Ok(Self {
728 current_session: record.current_session.map(|s| s.into()),
729 previous_sessions: record.previous_sessions,
730 })
731 }
732
733 pub(crate) fn promote_matching_session(
740 &mut self,
741 version: u32,
742 alice_base_key: &[u8],
743 ) -> Result<bool, InvalidSessionError> {
744 if let Some(current_session) = &self.current_session {
745 if current_session.session_version()? == version
746 && alice_base_key
747 .ct_eq(current_session.alice_base_key())
748 .into()
749 {
750 return Ok(true);
751 }
752 }
753
754 let mut session_to_promote = None;
755 for (i, previous) in self.previous_session_states().enumerate() {
756 let previous = previous?;
757 if previous.session_version()? == version
758 && alice_base_key.ct_eq(previous.alice_base_key()).into()
759 {
760 session_to_promote = Some((i, previous));
761 break;
762 }
763 }
764
765 if let Some((i, state)) = session_to_promote {
766 self.promote_old_session(i, state);
767 return Ok(true);
768 }
769
770 Ok(false)
771 }
772
773 pub(crate) fn session_state(&self) -> Option<&SessionState> {
774 self.current_session.as_ref()
775 }
776
777 pub(crate) fn session_state_mut(&mut self) -> Option<&mut SessionState> {
778 self.current_session.as_mut()
779 }
780
781 pub(crate) fn set_session_state(&mut self, session: SessionState) {
782 self.current_session = Some(session);
783 }
784
785 pub(crate) fn previous_session_states(
786 &self,
787 ) -> impl ExactSizeIterator<Item = Result<SessionState, InvalidSessionError>> + '_ {
788 self.previous_sessions.iter().map(|bytes| {
789 Ok(SessionStructure::decode(&bytes[..])
790 .map_err(|_| InvalidSessionError("failed to decode previous session protobuf"))?
791 .into())
792 })
793 }
794
795 pub(crate) fn promote_old_session(
796 &mut self,
797 old_session: usize,
798 updated_session: SessionState,
799 ) {
800 self.previous_sessions.remove(old_session);
801 self.promote_state(updated_session)
802 }
803
804 pub(crate) fn promote_state(&mut self, new_state: SessionState) {
805 self.archive_current_state_inner();
806 self.current_session = Some(new_state);
807 }
808
809 fn archive_current_state_inner(&mut self) -> bool {
813 if let Some(mut current_session) = self.current_session.take() {
814 if self.previous_sessions.len() >= consts::ARCHIVED_STATES_MAX_LENGTH {
815 self.previous_sessions.pop();
816 }
817 current_session.clear_unacknowledged_pre_key_message();
818 self.previous_sessions
819 .insert(0, current_session.session.encode_to_vec());
820 true
821 } else {
822 false
823 }
824 }
825
826 pub fn archive_current_state(&mut self) -> Result<(), SignalProtocolError> {
827 if !self.archive_current_state_inner() {
828 log::info!("Skipping archive, current session state is fresh");
829 }
830 Ok(())
831 }
832
833 pub fn serialize(&self) -> Result<Vec<u8>, SignalProtocolError> {
834 let record = RecordStructure {
835 current_session: self.current_session.as_ref().map(|s| s.into()),
836 previous_sessions: self.previous_sessions.clone(),
837 };
838 Ok(record.encode_to_vec())
839 }
840
841 pub fn current_pq_state(&self) -> Option<&spqr::SerializedState> {
842 self.current_session.as_ref().map(|s| s.pq_ratchet_state())
843 }
844
845 pub fn remote_registration_id(&self) -> Result<u32, SignalProtocolError> {
846 Ok(self
847 .session_state()
848 .ok_or_else(|| {
849 SignalProtocolError::InvalidState(
850 "remote_registration_id",
851 "No current session".into(),
852 )
853 })?
854 .remote_registration_id())
855 }
856
857 pub fn local_registration_id(&self) -> Result<u32, SignalProtocolError> {
858 Ok(self
859 .session_state()
860 .ok_or_else(|| {
861 SignalProtocolError::InvalidState(
862 "local_registration_id",
863 "No current session".into(),
864 )
865 })?
866 .local_registration_id())
867 }
868
869 pub fn session_version(&self) -> Result<u32, SignalProtocolError> {
870 Ok(self
871 .session_state()
872 .ok_or_else(|| {
873 SignalProtocolError::InvalidState("session_version", "No current session".into())
874 })?
875 .session_version()?)
876 }
877
878 pub fn local_identity_key_bytes(&self) -> Result<Vec<u8>, SignalProtocolError> {
879 Ok(self
880 .session_state()
881 .ok_or_else(|| {
882 SignalProtocolError::InvalidState(
883 "local_identity_key_bytes",
884 "No current session".into(),
885 )
886 })?
887 .local_identity_key_bytes()?)
888 }
889
890 pub fn remote_identity_key_bytes(&self) -> Result<Option<Vec<u8>>, SignalProtocolError> {
891 Ok(self
892 .session_state()
893 .ok_or_else(|| {
894 SignalProtocolError::InvalidState(
895 "remote_identity_key_bytes",
896 "No current session".into(),
897 )
898 })?
899 .remote_identity_key_bytes()?)
900 }
901
902 pub fn has_usable_sender_chain(
903 &self,
904 now: SystemTime,
905 requirements: SessionUsabilityRequirements,
906 ) -> Result<bool, SignalProtocolError> {
907 match &self.current_session {
908 Some(session) => Ok(session.has_usable_sender_chain(now, requirements)?),
909 None => Ok(false),
910 }
911 }
912
913 pub fn alice_base_key(&self) -> Result<&[u8], SignalProtocolError> {
914 Ok(self
915 .session_state()
916 .ok_or_else(|| {
917 SignalProtocolError::InvalidState("alice_base_key", "No current session".into())
918 })?
919 .alice_base_key())
920 }
921
922 pub fn get_receiver_chain_key_bytes(
923 &self,
924 sender: &PublicKey,
925 ) -> Result<Option<Box<[u8]>>, SignalProtocolError> {
926 Ok(self
927 .session_state()
928 .ok_or_else(|| {
929 SignalProtocolError::InvalidState(
930 "get_receiver_chain_key",
931 "No current session".into(),
932 )
933 })?
934 .get_receiver_chain_key(sender)?
935 .map(|chain| chain.key()[..].into()))
936 }
937
938 pub fn get_sender_chain_key_bytes(&self) -> Result<Vec<u8>, SignalProtocolError> {
939 Ok(self
940 .session_state()
941 .ok_or_else(|| {
942 SignalProtocolError::InvalidState(
943 "get_sender_chain_key_bytes",
944 "No current session".into(),
945 )
946 })?
947 .get_sender_chain_key_bytes()?)
948 }
949
950 pub fn current_ratchet_key_matches(
951 &self,
952 key: &PublicKey,
953 ) -> Result<bool, SignalProtocolError> {
954 match &self.current_session {
955 Some(session) => Ok(&session.sender_ratchet_key()? == key),
956 None => Ok(false),
957 }
958 }
959
960 pub fn get_kyber_ciphertext(&self) -> Result<Option<&Vec<u8>>, SignalProtocolError> {
961 Ok(self
962 .session_state()
963 .ok_or_else(|| {
964 SignalProtocolError::InvalidState(
965 "get_kyber_ciphertext",
966 "No current session".into(),
967 )
968 })?
969 .get_kyber_ciphertext())
970 }
971}