Skip to main content

libsignal_protocol/
protocol.rs

1//
2// Copyright 2020-2021 Signal Messenger, LLC.
3// SPDX-License-Identifier: AGPL-3.0-only
4//
5
6use hmac::{Hmac, Mac};
7use prost::Message;
8use rand::{CryptoRng, Rng};
9use sha2::Sha256;
10use subtle::ConstantTimeEq;
11use uuid::Uuid;
12
13use crate::state::{KyberPreKeyId, PreKeyId, SignedPreKeyId};
14use crate::{
15    IdentityKey, PrivateKey, ProtocolAddress, PublicKey, Result, ServiceId, SignalProtocolError,
16    Timestamp, kem, proto,
17};
18
19pub(crate) const CIPHERTEXT_MESSAGE_CURRENT_VERSION: u8 = 4;
20// Backward compatible, lacking Kyber keys, version
21pub(crate) const CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION: u8 = 3;
22pub(crate) const SENDERKEY_MESSAGE_CURRENT_VERSION: u8 = 3;
23
24#[derive(Debug, Clone)]
25pub enum CiphertextMessage {
26    SignalMessage(SignalMessage),
27    PreKeySignalMessage(PreKeySignalMessage),
28    SenderKeyMessage(SenderKeyMessage),
29    PlaintextContent(PlaintextContent),
30}
31
32#[derive(Copy, Clone, Eq, PartialEq, Debug, derive_more::TryFrom)]
33#[repr(u8)]
34#[try_from(repr)]
35pub enum CiphertextMessageType {
36    Whisper = 2,
37    PreKey = 3,
38    SenderKey = 7,
39    Plaintext = 8,
40}
41
42impl CiphertextMessage {
43    pub fn message_type(&self) -> CiphertextMessageType {
44        match self {
45            CiphertextMessage::SignalMessage(_) => CiphertextMessageType::Whisper,
46            CiphertextMessage::PreKeySignalMessage(_) => CiphertextMessageType::PreKey,
47            CiphertextMessage::SenderKeyMessage(_) => CiphertextMessageType::SenderKey,
48            CiphertextMessage::PlaintextContent(_) => CiphertextMessageType::Plaintext,
49        }
50    }
51
52    pub fn serialize(&self) -> &[u8] {
53        match self {
54            CiphertextMessage::SignalMessage(x) => x.serialized(),
55            CiphertextMessage::PreKeySignalMessage(x) => x.serialized(),
56            CiphertextMessage::SenderKeyMessage(x) => x.serialized(),
57            CiphertextMessage::PlaintextContent(x) => x.serialized(),
58        }
59    }
60}
61
62#[derive(Debug, Clone)]
63pub struct SignalMessage {
64    message_version: u8,
65    sender_ratchet_key: PublicKey,
66    counter: u32,
67    #[cfg_attr(not(test), expect(dead_code))]
68    previous_counter: u32,
69    ciphertext: Box<[u8]>,
70    pq_ratchet: spqr::SerializedState,
71    addresses: Option<Box<[u8]>>,
72    serialized: Box<[u8]>,
73}
74
75impl SignalMessage {
76    const MAC_LENGTH: usize = 8;
77
78    #[allow(clippy::too_many_arguments)]
79    pub fn new(
80        message_version: u8,
81        mac_key: &[u8],
82        addresses: Option<(&ProtocolAddress, &ProtocolAddress)>,
83        sender_ratchet_key: PublicKey,
84        counter: u32,
85        previous_counter: u32,
86        ciphertext: &[u8],
87        sender_identity_key: &IdentityKey,
88        receiver_identity_key: &IdentityKey,
89        pq_ratchet: &[u8],
90    ) -> Result<Self> {
91        let addresses =
92            addresses.and_then(|(sender, recipient)| Self::serialize_addresses(sender, recipient));
93        let message = proto::wire::SignalMessage {
94            ratchet_key: Some(sender_ratchet_key.serialize().into_vec()),
95            counter: Some(counter),
96            previous_counter: Some(previous_counter),
97            ciphertext: Some(Vec::<u8>::from(ciphertext)),
98            pq_ratchet: if pq_ratchet.is_empty() {
99                None
100            } else {
101                Some(pq_ratchet.to_vec())
102            },
103            addresses,
104        };
105        let mut serialized = Vec::with_capacity(1 + message.encoded_len() + Self::MAC_LENGTH);
106        serialized.push(((message_version & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION);
107        message
108            .encode(&mut serialized)
109            .expect("can always append to a buffer");
110        let mac = Self::compute_mac(
111            sender_identity_key,
112            receiver_identity_key,
113            mac_key,
114            &serialized,
115        )?;
116        serialized.extend_from_slice(&mac);
117        let serialized = serialized.into_boxed_slice();
118        Ok(Self {
119            message_version,
120            sender_ratchet_key,
121            counter,
122            previous_counter,
123            ciphertext: ciphertext.into(),
124            pq_ratchet: pq_ratchet.to_vec(),
125            addresses: message.addresses.map(Into::into),
126            serialized,
127        })
128    }
129
130    #[inline]
131    pub fn message_version(&self) -> u8 {
132        self.message_version
133    }
134
135    #[inline]
136    pub fn sender_ratchet_key(&self) -> &PublicKey {
137        &self.sender_ratchet_key
138    }
139
140    #[inline]
141    pub fn counter(&self) -> u32 {
142        self.counter
143    }
144
145    #[inline]
146    pub fn pq_ratchet(&self) -> &spqr::SerializedMessage {
147        &self.pq_ratchet
148    }
149
150    #[inline]
151    pub fn serialized(&self) -> &[u8] {
152        &self.serialized
153    }
154
155    #[inline]
156    pub fn body(&self) -> &[u8] {
157        &self.ciphertext
158    }
159
160    pub(crate) fn verify_mac(
161        &self,
162        sender_identity_key: &IdentityKey,
163        receiver_identity_key: &IdentityKey,
164        mac_key: &[u8],
165    ) -> Result<bool> {
166        let (content, their_mac) = self
167            .serialized
168            .split_last_chunk::<{ Self::MAC_LENGTH }>()
169            .expect("length checked at construction");
170        let our_mac =
171            Self::compute_mac(sender_identity_key, receiver_identity_key, mac_key, content)?;
172        let result: bool = our_mac.ct_eq(their_mac).into();
173        if !result {
174            // A warning instead of an error because we try multiple sessions.
175            log::warn!(
176                "Bad Mac! Their Mac: {} Our Mac: {}",
177                hex::encode(their_mac),
178                hex::encode(our_mac)
179            );
180            return Ok(false);
181        }
182
183        Ok(true)
184    }
185
186    pub fn verify_mac_with_addresses(
187        &self,
188        sender_address: &ProtocolAddress,
189        recipient_address: &ProtocolAddress,
190        sender_identity_key: &IdentityKey,
191        receiver_identity_key: &IdentityKey,
192        mac_key: &[u8],
193    ) -> Result<bool> {
194        if !self.verify_mac(sender_identity_key, receiver_identity_key, mac_key)? {
195            return Ok(false);
196        }
197
198        // If the sender didn't include addresses, accept the message for
199        // backward compatibility with older clients.
200        let Some(encoded_addresses) = &self.addresses else {
201            return Ok(true);
202        };
203
204        let Some(expected) = Self::serialize_addresses(sender_address, recipient_address) else {
205            log::warn!(
206                "Locally supplied addresses not valid Service IDs: sender={}, recipient={}",
207                sender_address,
208                recipient_address,
209            );
210            return Ok(false);
211        };
212
213        if bool::from(expected.ct_eq(encoded_addresses.as_ref())) {
214            Ok(true)
215        } else {
216            log::warn!(
217                "Address mismatch: sender={}, recipient={}",
218                sender_address,
219                recipient_address,
220            );
221            Ok(false)
222        }
223    }
224
225    fn compute_mac(
226        sender_identity_key: &IdentityKey,
227        receiver_identity_key: &IdentityKey,
228        mac_key: &[u8],
229        message: &[u8],
230    ) -> Result<[u8; Self::MAC_LENGTH]> {
231        if mac_key.len() != 32 {
232            return Err(SignalProtocolError::InvalidMacKeyLength(mac_key.len()));
233        }
234        let mut mac = Hmac::<Sha256>::new_from_slice(mac_key)
235            .expect("HMAC-SHA256 should accept any size key");
236
237        mac.update(sender_identity_key.public_key().serialize().as_ref());
238        mac.update(receiver_identity_key.public_key().serialize().as_ref());
239        mac.update(message);
240        let result = *mac
241            .finalize()
242            .into_bytes()
243            .first_chunk()
244            .expect("enough bytes");
245        Ok(result)
246    }
247
248    /// Serializes sender and recipient addresses into a single byte vector.
249    /// Returns `None` if either address name is not a valid ServiceId.
250    fn serialize_addresses(
251        sender: &ProtocolAddress,
252        recipient: &ProtocolAddress,
253    ) -> Option<Vec<u8>> {
254        let sender_service_id = ServiceId::parse_from_service_id_string(sender.name())?;
255        let recipient_service_id = ServiceId::parse_from_service_id_string(recipient.name())?;
256
257        let mut bytes = Vec::with_capacity(36);
258        bytes.extend_from_slice(&sender_service_id.service_id_fixed_width_binary());
259        bytes.push(sender.device_id().into());
260        bytes.extend_from_slice(&recipient_service_id.service_id_fixed_width_binary());
261        bytes.push(recipient.device_id().into());
262        Some(bytes)
263    }
264}
265
266impl AsRef<[u8]> for SignalMessage {
267    fn as_ref(&self) -> &[u8] {
268        &self.serialized
269    }
270}
271
272impl TryFrom<&[u8]> for SignalMessage {
273    type Error = SignalProtocolError;
274
275    fn try_from(value: &[u8]) -> Result<Self> {
276        if value.len() < SignalMessage::MAC_LENGTH + 1 {
277            return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
278        }
279        let message_version = value[0] >> 4;
280        if message_version < CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION {
281            return Err(SignalProtocolError::LegacyCiphertextVersion(
282                message_version,
283            ));
284        }
285        if message_version > CIPHERTEXT_MESSAGE_CURRENT_VERSION {
286            return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
287                message_version,
288            ));
289        }
290
291        let proto_structure =
292            proto::wire::SignalMessage::decode(&value[1..value.len() - SignalMessage::MAC_LENGTH])
293                .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
294
295        let sender_ratchet_key = proto_structure
296            .ratchet_key
297            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
298        let sender_ratchet_key = PublicKey::deserialize(&sender_ratchet_key)?;
299        let counter = proto_structure
300            .counter
301            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
302        let previous_counter = proto_structure.previous_counter.unwrap_or(0);
303        let ciphertext = proto_structure
304            .ciphertext
305            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?
306            .into_boxed_slice();
307
308        Ok(SignalMessage {
309            message_version,
310            sender_ratchet_key,
311            counter,
312            previous_counter,
313            ciphertext,
314            pq_ratchet: proto_structure.pq_ratchet.unwrap_or(vec![]),
315            addresses: proto_structure.addresses.map(Into::into),
316            serialized: Box::from(value),
317        })
318    }
319}
320
321#[derive(Debug, Clone)]
322pub struct KyberPayload {
323    pre_key_id: KyberPreKeyId,
324    ciphertext: kem::SerializedCiphertext,
325}
326
327impl KyberPayload {
328    pub fn new(id: KyberPreKeyId, ciphertext: kem::SerializedCiphertext) -> Self {
329        Self {
330            pre_key_id: id,
331            ciphertext,
332        }
333    }
334}
335
336#[derive(Debug, Clone)]
337pub struct PreKeySignalMessage {
338    message_version: u8,
339    registration_id: u32,
340    pre_key_id: Option<PreKeyId>,
341    signed_pre_key_id: SignedPreKeyId,
342    // While we reject messages without Kyber payloads, we still for now allow constructing the
343    // struct without one so that we can provide a better error message when we try to process it.
344    kyber_payload: Option<KyberPayload>,
345    base_key: PublicKey,
346    identity_key: IdentityKey,
347    message: SignalMessage,
348    serialized: Box<[u8]>,
349}
350
351impl PreKeySignalMessage {
352    pub fn new(
353        message_version: u8,
354        registration_id: u32,
355        pre_key_id: Option<PreKeyId>,
356        signed_pre_key_id: SignedPreKeyId,
357        kyber_payload: Option<KyberPayload>,
358        base_key: PublicKey,
359        identity_key: IdentityKey,
360        message: SignalMessage,
361    ) -> Result<Self> {
362        let proto_message = proto::wire::PreKeySignalMessage {
363            registration_id: Some(registration_id),
364            pre_key_id: pre_key_id.map(|id| id.into()),
365            signed_pre_key_id: Some(signed_pre_key_id.into()),
366            kyber_pre_key_id: kyber_payload.as_ref().map(|kyber| kyber.pre_key_id.into()),
367            kyber_ciphertext: kyber_payload
368                .as_ref()
369                .map(|kyber| kyber.ciphertext.to_vec()),
370            base_key: Some(base_key.serialize().into_vec()),
371            identity_key: Some(identity_key.serialize().into_vec()),
372            message: Some(Vec::from(message.as_ref())),
373        };
374        let mut serialized = Vec::with_capacity(1 + proto_message.encoded_len());
375        serialized.push(((message_version & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION);
376        proto_message
377            .encode(&mut serialized)
378            .expect("can always append to a Vec");
379        Ok(Self {
380            message_version,
381            registration_id,
382            pre_key_id,
383            signed_pre_key_id,
384            kyber_payload,
385            base_key,
386            identity_key,
387            message,
388            serialized: serialized.into_boxed_slice(),
389        })
390    }
391
392    #[inline]
393    pub fn message_version(&self) -> u8 {
394        self.message_version
395    }
396
397    #[inline]
398    pub fn registration_id(&self) -> u32 {
399        self.registration_id
400    }
401
402    #[inline]
403    pub fn pre_key_id(&self) -> Option<PreKeyId> {
404        self.pre_key_id
405    }
406
407    #[inline]
408    pub fn signed_pre_key_id(&self) -> SignedPreKeyId {
409        self.signed_pre_key_id
410    }
411
412    #[inline]
413    pub fn kyber_pre_key_id(&self) -> Option<KyberPreKeyId> {
414        self.kyber_payload.as_ref().map(|kyber| kyber.pre_key_id)
415    }
416
417    #[inline]
418    pub fn kyber_ciphertext(&self) -> Option<&kem::SerializedCiphertext> {
419        self.kyber_payload.as_ref().map(|kyber| &kyber.ciphertext)
420    }
421
422    #[inline]
423    pub fn base_key(&self) -> &PublicKey {
424        &self.base_key
425    }
426
427    #[inline]
428    pub fn identity_key(&self) -> &IdentityKey {
429        &self.identity_key
430    }
431
432    #[inline]
433    pub fn message(&self) -> &SignalMessage {
434        &self.message
435    }
436
437    #[inline]
438    pub fn serialized(&self) -> &[u8] {
439        &self.serialized
440    }
441}
442
443impl AsRef<[u8]> for PreKeySignalMessage {
444    fn as_ref(&self) -> &[u8] {
445        &self.serialized
446    }
447}
448
449impl TryFrom<&[u8]> for PreKeySignalMessage {
450    type Error = SignalProtocolError;
451
452    fn try_from(value: &[u8]) -> Result<Self> {
453        if value.is_empty() {
454            return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
455        }
456
457        let message_version = value[0] >> 4;
458        if message_version < CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION {
459            return Err(SignalProtocolError::LegacyCiphertextVersion(
460                message_version,
461            ));
462        }
463        if message_version > CIPHERTEXT_MESSAGE_CURRENT_VERSION {
464            return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
465                message_version,
466            ));
467        }
468
469        let proto_structure = proto::wire::PreKeySignalMessage::decode(&value[1..])
470            .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
471
472        let base_key = proto_structure
473            .base_key
474            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
475        let identity_key = proto_structure
476            .identity_key
477            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
478        let message = proto_structure
479            .message
480            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
481        let signed_pre_key_id = proto_structure
482            .signed_pre_key_id
483            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
484
485        let base_key = PublicKey::deserialize(base_key.as_ref())?;
486
487        let kyber_payload = match (
488            proto_structure.kyber_pre_key_id,
489            proto_structure.kyber_ciphertext,
490        ) {
491            (Some(id), Some(ct)) => Some(KyberPayload::new(id.into(), ct.into_boxed_slice())),
492            (None, None) if message_version <= CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION => None,
493            (None, None) => {
494                return Err(SignalProtocolError::InvalidMessage(
495                    CiphertextMessageType::PreKey,
496                    "Kyber pre key must be present for this session version",
497                ));
498            }
499            _ => {
500                return Err(SignalProtocolError::InvalidMessage(
501                    CiphertextMessageType::PreKey,
502                    "Both or neither kyber pre_key_id and kyber_ciphertext can be present",
503                ));
504            }
505        };
506
507        Ok(PreKeySignalMessage {
508            message_version,
509            registration_id: proto_structure.registration_id.unwrap_or(0),
510            pre_key_id: proto_structure.pre_key_id.map(|id| id.into()),
511            signed_pre_key_id: signed_pre_key_id.into(),
512            kyber_payload,
513            base_key,
514            identity_key: IdentityKey::try_from(identity_key.as_ref())?,
515            message: SignalMessage::try_from(message.as_ref())?,
516            serialized: Box::from(value),
517        })
518    }
519}
520
521#[derive(Debug, Clone)]
522pub struct SenderKeyMessage {
523    message_version: u8,
524    distribution_id: Uuid,
525    chain_id: u32,
526    iteration: u32,
527    ciphertext: Box<[u8]>,
528    serialized: Box<[u8]>,
529}
530
531impl SenderKeyMessage {
532    const SIGNATURE_LEN: usize = 64;
533
534    pub fn new<R: CryptoRng + Rng>(
535        message_version: u8,
536        distribution_id: Uuid,
537        chain_id: u32,
538        iteration: u32,
539        ciphertext: Box<[u8]>,
540        csprng: &mut R,
541        signature_key: &PrivateKey,
542    ) -> Result<Self> {
543        let proto_message = proto::wire::SenderKeyMessage {
544            distribution_uuid: Some(distribution_id.as_bytes().to_vec()),
545            chain_id: Some(chain_id),
546            iteration: Some(iteration),
547            ciphertext: Some(ciphertext.to_vec()),
548        };
549        let proto_message_len = proto_message.encoded_len();
550        let mut serialized = Vec::with_capacity(1 + proto_message_len + Self::SIGNATURE_LEN);
551        serialized.push(((message_version & 0xF) << 4) | SENDERKEY_MESSAGE_CURRENT_VERSION);
552        proto_message
553            .encode(&mut serialized)
554            .expect("can always append to a buffer");
555        let signature = signature_key.calculate_signature(&serialized, csprng)?;
556        serialized.extend_from_slice(&signature[..]);
557        Ok(Self {
558            message_version: SENDERKEY_MESSAGE_CURRENT_VERSION,
559            distribution_id,
560            chain_id,
561            iteration,
562            ciphertext,
563            serialized: serialized.into_boxed_slice(),
564        })
565    }
566
567    pub fn verify_signature(&self, signature_key: &PublicKey) -> Result<bool> {
568        let (content, signature) = self
569            .serialized
570            .split_last_chunk::<{ Self::SIGNATURE_LEN }>()
571            .expect("length checked on initialization");
572        let valid = signature_key.verify_signature(content, signature);
573
574        Ok(valid)
575    }
576
577    #[inline]
578    pub fn message_version(&self) -> u8 {
579        self.message_version
580    }
581
582    #[inline]
583    pub fn distribution_id(&self) -> Uuid {
584        self.distribution_id
585    }
586
587    #[inline]
588    pub fn chain_id(&self) -> u32 {
589        self.chain_id
590    }
591
592    #[inline]
593    pub fn iteration(&self) -> u32 {
594        self.iteration
595    }
596
597    #[inline]
598    pub fn ciphertext(&self) -> &[u8] {
599        &self.ciphertext
600    }
601
602    #[inline]
603    pub fn serialized(&self) -> &[u8] {
604        &self.serialized
605    }
606}
607
608impl AsRef<[u8]> for SenderKeyMessage {
609    fn as_ref(&self) -> &[u8] {
610        &self.serialized
611    }
612}
613
614impl TryFrom<&[u8]> for SenderKeyMessage {
615    type Error = SignalProtocolError;
616
617    fn try_from(value: &[u8]) -> Result<Self> {
618        if value.len() < 1 + Self::SIGNATURE_LEN {
619            return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
620        }
621        let message_version = value[0] >> 4;
622        if message_version < SENDERKEY_MESSAGE_CURRENT_VERSION {
623            return Err(SignalProtocolError::LegacyCiphertextVersion(
624                message_version,
625            ));
626        }
627        if message_version > SENDERKEY_MESSAGE_CURRENT_VERSION {
628            return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
629                message_version,
630            ));
631        }
632        let proto_structure =
633            proto::wire::SenderKeyMessage::decode(&value[1..value.len() - Self::SIGNATURE_LEN])
634                .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
635
636        let distribution_id = proto_structure
637            .distribution_uuid
638            .and_then(|bytes| Uuid::from_slice(bytes.as_slice()).ok())
639            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
640        let chain_id = proto_structure
641            .chain_id
642            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
643        let iteration = proto_structure
644            .iteration
645            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
646        let ciphertext = proto_structure
647            .ciphertext
648            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?
649            .into_boxed_slice();
650
651        Ok(SenderKeyMessage {
652            message_version,
653            distribution_id,
654            chain_id,
655            iteration,
656            ciphertext,
657            serialized: Box::from(value),
658        })
659    }
660}
661
662#[derive(Debug, Clone)]
663pub struct SenderKeyDistributionMessage {
664    message_version: u8,
665    distribution_id: Uuid,
666    chain_id: u32,
667    iteration: u32,
668    chain_key: Vec<u8>,
669    signing_key: PublicKey,
670    serialized: Box<[u8]>,
671}
672
673impl SenderKeyDistributionMessage {
674    pub fn new(
675        message_version: u8,
676        distribution_id: Uuid,
677        chain_id: u32,
678        iteration: u32,
679        chain_key: Vec<u8>,
680        signing_key: PublicKey,
681    ) -> Result<Self> {
682        let proto_message = proto::wire::SenderKeyDistributionMessage {
683            distribution_uuid: Some(distribution_id.as_bytes().to_vec()),
684            chain_id: Some(chain_id),
685            iteration: Some(iteration),
686            chain_key: Some(chain_key.clone()),
687            signing_key: Some(signing_key.serialize().to_vec()),
688        };
689        let mut serialized = Vec::with_capacity(1 + proto_message.encoded_len());
690        serialized.push(((message_version & 0xF) << 4) | SENDERKEY_MESSAGE_CURRENT_VERSION);
691        proto_message
692            .encode(&mut serialized)
693            .expect("can always append to a buffer");
694
695        Ok(Self {
696            message_version,
697            distribution_id,
698            chain_id,
699            iteration,
700            chain_key,
701            signing_key,
702            serialized: serialized.into_boxed_slice(),
703        })
704    }
705
706    #[inline]
707    pub fn message_version(&self) -> u8 {
708        self.message_version
709    }
710
711    #[inline]
712    pub fn distribution_id(&self) -> Result<Uuid> {
713        Ok(self.distribution_id)
714    }
715
716    #[inline]
717    pub fn chain_id(&self) -> Result<u32> {
718        Ok(self.chain_id)
719    }
720
721    #[inline]
722    pub fn iteration(&self) -> Result<u32> {
723        Ok(self.iteration)
724    }
725
726    #[inline]
727    pub fn chain_key(&self) -> Result<&[u8]> {
728        Ok(&self.chain_key)
729    }
730
731    #[inline]
732    pub fn signing_key(&self) -> Result<&PublicKey> {
733        Ok(&self.signing_key)
734    }
735
736    #[inline]
737    pub fn serialized(&self) -> &[u8] {
738        &self.serialized
739    }
740}
741
742impl AsRef<[u8]> for SenderKeyDistributionMessage {
743    fn as_ref(&self) -> &[u8] {
744        &self.serialized
745    }
746}
747
748impl TryFrom<&[u8]> for SenderKeyDistributionMessage {
749    type Error = SignalProtocolError;
750
751    fn try_from(value: &[u8]) -> Result<Self> {
752        // The message contains at least a X25519 key and a chain key
753        if value.len() < 1 + 32 + 32 {
754            return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
755        }
756
757        let message_version = value[0] >> 4;
758
759        if message_version < SENDERKEY_MESSAGE_CURRENT_VERSION {
760            return Err(SignalProtocolError::LegacyCiphertextVersion(
761                message_version,
762            ));
763        }
764        if message_version > SENDERKEY_MESSAGE_CURRENT_VERSION {
765            return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
766                message_version,
767            ));
768        }
769
770        let proto_structure = proto::wire::SenderKeyDistributionMessage::decode(&value[1..])
771            .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
772
773        let distribution_id = proto_structure
774            .distribution_uuid
775            .and_then(|bytes| Uuid::from_slice(bytes.as_slice()).ok())
776            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
777        let chain_id = proto_structure
778            .chain_id
779            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
780        let iteration = proto_structure
781            .iteration
782            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
783        let chain_key = proto_structure
784            .chain_key
785            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
786        let signing_key = proto_structure
787            .signing_key
788            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
789
790        if chain_key.len() != 32 || signing_key.len() != 33 {
791            return Err(SignalProtocolError::InvalidProtobufEncoding);
792        }
793
794        let signing_key = PublicKey::deserialize(&signing_key)?;
795
796        Ok(SenderKeyDistributionMessage {
797            message_version,
798            distribution_id,
799            chain_id,
800            iteration,
801            chain_key,
802            signing_key,
803            serialized: Box::from(value),
804        })
805    }
806}
807
808#[derive(Debug, Clone)]
809pub struct PlaintextContent {
810    serialized: Box<[u8]>,
811}
812
813impl PlaintextContent {
814    /// Identifies a serialized PlaintextContent.
815    ///
816    /// This ensures someone doesn't try to serialize an arbitrary Content message as
817    /// PlaintextContent; only messages that are okay to send as plaintext should be allowed.
818    const PLAINTEXT_CONTEXT_IDENTIFIER_BYTE: u8 = 0xC0;
819
820    /// Marks the end of a message and the start of any padding.
821    ///
822    /// Usually messages are padded to avoid exposing patterns,
823    /// but PlaintextContent messages are all fixed-length anyway, so there won't be any padding.
824    const PADDING_BOUNDARY_BYTE: u8 = 0x80;
825
826    #[inline]
827    pub fn body(&self) -> &[u8] {
828        &self.serialized[1..]
829    }
830
831    #[inline]
832    pub fn serialized(&self) -> &[u8] {
833        &self.serialized
834    }
835}
836
837impl From<DecryptionErrorMessage> for PlaintextContent {
838    fn from(message: DecryptionErrorMessage) -> Self {
839        let proto_structure = proto::service::Content {
840            decryption_error_message: Some(message.serialized().to_vec()),
841            ..Default::default()
842        };
843        let mut serialized = vec![Self::PLAINTEXT_CONTEXT_IDENTIFIER_BYTE];
844        proto_structure
845            .encode(&mut serialized)
846            .expect("can always encode to a Vec");
847        serialized.push(Self::PADDING_BOUNDARY_BYTE);
848        Self {
849            serialized: Box::from(serialized),
850        }
851    }
852}
853
854impl TryFrom<&[u8]> for PlaintextContent {
855    type Error = SignalProtocolError;
856
857    fn try_from(value: &[u8]) -> Result<Self> {
858        if value.is_empty() {
859            return Err(SignalProtocolError::CiphertextMessageTooShort(0));
860        }
861        if value[0] != Self::PLAINTEXT_CONTEXT_IDENTIFIER_BYTE {
862            return Err(SignalProtocolError::UnrecognizedMessageVersion(
863                value[0] as u32,
864            ));
865        }
866        Ok(Self {
867            serialized: Box::from(value),
868        })
869    }
870}
871
872#[derive(Debug, Clone)]
873pub struct DecryptionErrorMessage {
874    ratchet_key: Option<PublicKey>,
875    timestamp: Timestamp,
876    device_id: u32,
877    serialized: Box<[u8]>,
878}
879
880impl DecryptionErrorMessage {
881    pub fn for_original(
882        original_bytes: &[u8],
883        original_type: CiphertextMessageType,
884        original_timestamp: Timestamp,
885        original_sender_device_id: u32,
886    ) -> Result<Self> {
887        let ratchet_key = match original_type {
888            CiphertextMessageType::Whisper => {
889                Some(*SignalMessage::try_from(original_bytes)?.sender_ratchet_key())
890            }
891            CiphertextMessageType::PreKey => Some(
892                *PreKeySignalMessage::try_from(original_bytes)?
893                    .message()
894                    .sender_ratchet_key(),
895            ),
896            CiphertextMessageType::SenderKey => None,
897            CiphertextMessageType::Plaintext => {
898                return Err(SignalProtocolError::InvalidArgument(
899                    "cannot create a DecryptionErrorMessage for plaintext content; it is not encrypted".to_string()
900                ));
901            }
902        };
903
904        let proto_message = proto::service::DecryptionErrorMessage {
905            timestamp: Some(original_timestamp.epoch_millis()),
906            ratchet_key: ratchet_key.map(|k| k.serialize().into()),
907            device_id: Some(original_sender_device_id),
908        };
909        let serialized = proto_message.encode_to_vec();
910
911        Ok(Self {
912            ratchet_key,
913            timestamp: original_timestamp,
914            device_id: original_sender_device_id,
915            serialized: serialized.into_boxed_slice(),
916        })
917    }
918
919    #[inline]
920    pub fn timestamp(&self) -> Timestamp {
921        self.timestamp
922    }
923
924    #[inline]
925    pub fn ratchet_key(&self) -> Option<&PublicKey> {
926        self.ratchet_key.as_ref()
927    }
928
929    #[inline]
930    pub fn device_id(&self) -> u32 {
931        self.device_id
932    }
933
934    #[inline]
935    pub fn serialized(&self) -> &[u8] {
936        &self.serialized
937    }
938}
939
940impl TryFrom<&[u8]> for DecryptionErrorMessage {
941    type Error = SignalProtocolError;
942
943    fn try_from(value: &[u8]) -> Result<Self> {
944        let proto_structure = proto::service::DecryptionErrorMessage::decode(value)
945            .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
946        let timestamp = proto_structure
947            .timestamp
948            .map(Timestamp::from_epoch_millis)
949            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
950        let ratchet_key = proto_structure
951            .ratchet_key
952            .map(|k| PublicKey::deserialize(&k))
953            .transpose()?;
954        let device_id = proto_structure.device_id.unwrap_or_default();
955        Ok(Self {
956            timestamp,
957            ratchet_key,
958            device_id,
959            serialized: Box::from(value),
960        })
961    }
962}
963
964/// For testing
965pub fn extract_decryption_error_message_from_serialized_content(
966    bytes: &[u8],
967) -> Result<DecryptionErrorMessage> {
968    if bytes.last() != Some(&PlaintextContent::PADDING_BOUNDARY_BYTE) {
969        return Err(SignalProtocolError::InvalidProtobufEncoding);
970    }
971    let content = proto::service::Content::decode(bytes.split_last().expect("checked above").1)
972        .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
973    content
974        .decryption_error_message
975        .as_deref()
976        .ok_or_else(|| {
977            SignalProtocolError::InvalidArgument(
978                "Content does not contain DecryptionErrorMessage".to_owned(),
979            )
980        })
981        .and_then(DecryptionErrorMessage::try_from)
982}
983
/// Deterministically decides whether a session that is not post-quantum (PQ)
/// should still be used (`true`) or should be archived (`false`).
///
/// `require_pq_ratio` is the fraction of non-PQ sessions that should be
/// archived:
/// - A ratio at or above 1.0 archives every session.
/// - A ratio at or below 0.0 keeps every session.
/// - A NaN ratio behaves like 0.0.
///
/// For ratios in between, the first 4 bytes of `session_key` are read as a
/// big-endian `u32`. That value is expected to be roughly uniform. The session
/// is kept only if the value is at least `require_pq_ratio` scaled onto the
/// `u32` range.
///
/// The decision depends only on the ratio and the key bytes. Both sides of a
/// session therefore reach the same answer when given the same key. A
/// session's `alice_base_key()` works well for this.
///
/// # Panics
///
/// Panics if `session_key` is shorter than 4 bytes.
pub fn should_use_nonpq_session(require_pq_ratio: f64, session_key: &[u8]) -> bool {
    assert!(
        session_key.len() >= 4,
        "session key must be at least 4 bytes"
    );
    if require_pq_ratio >= 1.0 {
        return false;
    } else if require_pq_ratio <= 0.0 {
        return true;
    }
    // The session has a chain, but it is not a PQ chain. Decide
    // deterministically, from the ratio and the session key alone, so that
    // Alice and Bob make the same choice. The session key is an x25519 key, so
    // its leading 4 bytes are expected to be relatively uniform.
    let key_prefix: [u8; 4] = session_key[..4]
        .try_into()
        .expect("a 4-byte slice always converts to [u8; 4]");
    let sess_u32 = u32::from_be_bytes(key_prefix);
    // Scale the ratio onto the u32 range: ratios near 0 map near 0x00000000
    // and ratios near 1 map near 0xFFFFFFFF.
    // The cast cannot overflow: the ratio is strictly between 0 and 1 here, so
    // the product stays within [0, u32::MAX]. A NaN ratio casts to 0.
    #[allow(clippy::cast_possible_truncation)]
    let ratio_u32 = ((u32::MAX as f64) * require_pq_ratio) as u32;
    // Keep the session unless its key falls below the ratio threshold.
    ratio_u32 <= sess_u32
}
1016
#[cfg(test)]
mod tests {
    // Unit tests for message serialization round-trips, MAC verification with
    // and without embedded addresses, DecryptionErrorMessage construction and
    // parsing, and the should_use_nonpq_session ratio logic.
    use rand::rngs::OsRng;
    use rand::{CryptoRng, Rng, RngCore, TryRngCore as _};

    use super::*;
    use crate::{DeviceId, KeyPair};

    // Builds a version-4 SignalMessage for tests. It uses a random 32-byte MAC
    // key, 20 bytes of random ciphertext, fresh ratchet and identity key pairs,
    // fixed sender/recipient addresses (device 1), the fixed numeric arguments
    // 42 and 41, and an empty pq_ratchet.
    fn create_signal_message<T>(csprng: &mut T) -> Result<SignalMessage>
    where
        T: Rng + CryptoRng,
    {
        let mut mac_key = [0u8; 32];
        csprng.fill_bytes(&mut mac_key);
        // Rebind immutably once filled.
        let mac_key = mac_key;

        let mut ciphertext = [0u8; 20];
        csprng.fill_bytes(&mut ciphertext);
        let ciphertext = ciphertext;

        let sender_ratchet_key_pair = KeyPair::generate(csprng);
        let sender_identity_key_pair = KeyPair::generate(csprng);
        let receiver_identity_key_pair = KeyPair::generate(csprng);
        let sender_address = ProtocolAddress::new(
            "31415926-5358-9793-2384-626433827950".to_owned(),
            DeviceId::new(1).unwrap(),
        );
        let recipient_address = ProtocolAddress::new(
            "27182818-2845-9045-2353-602874713526".to_owned(),
            DeviceId::new(1).unwrap(),
        );

        SignalMessage::new(
            4,
            &mac_key,
            Some((&sender_address, &recipient_address)),
            sender_ratchet_key_pair.public_key,
            42,
            41,
            &ciphertext,
            &sender_identity_key_pair.public_key.into(),
            &receiver_identity_key_pair.public_key.into(),
            b"", // pq_ratchet
        )
    }

    // Field-by-field equality check for two SignalMessages, including the
    // cached serialized bytes.
    fn assert_signal_message_equals(m1: &SignalMessage, m2: &SignalMessage) {
        assert_eq!(m1.message_version, m2.message_version);
        assert_eq!(m1.sender_ratchet_key, m2.sender_ratchet_key);
        assert_eq!(m1.counter, m2.counter);
        assert_eq!(m1.previous_counter, m2.previous_counter);
        assert_eq!(m1.ciphertext, m2.ciphertext);
        assert_eq!(m1.addresses, m2.addresses);
        assert_eq!(m1.serialized, m2.serialized);
    }

    // A SignalMessage survives serialize -> deserialize unchanged.
    #[test]
    fn test_signal_message_serialize_deserialize() -> Result<()> {
        let mut csprng = OsRng.unwrap_err();
        let message = create_signal_message(&mut csprng)?;
        let deser_message =
            SignalMessage::try_from(message.as_ref()).expect("should deserialize without error");
        assert_signal_message_equals(&message, &deser_message);
        Ok(())
    }

    // A PreKeySignalMessage (no Kyber prekey) survives serialize ->
    // deserialize, including its wrapped SignalMessage.
    #[test]
    fn test_pre_key_signal_message_serialize_deserialize() -> Result<()> {
        let mut csprng = OsRng.unwrap_err();
        let identity_key_pair = KeyPair::generate(&mut csprng);
        let base_key_pair = KeyPair::generate(&mut csprng);
        let message = create_signal_message(&mut csprng)?;
        let pre_key_signal_message = PreKeySignalMessage::new(
            3,
            365,
            None,
            97.into(),
            None, // TODO: add kyber prekeys
            base_key_pair.public_key,
            identity_key_pair.public_key.into(),
            message,
        )?;
        let deser_pre_key_signal_message =
            PreKeySignalMessage::try_from(pre_key_signal_message.as_ref())
                .expect("should deserialize without error");
        assert_eq!(
            pre_key_signal_message.message_version,
            deser_pre_key_signal_message.message_version
        );
        assert_eq!(
            pre_key_signal_message.registration_id,
            deser_pre_key_signal_message.registration_id
        );
        assert_eq!(
            pre_key_signal_message.pre_key_id,
            deser_pre_key_signal_message.pre_key_id
        );
        assert_eq!(
            pre_key_signal_message.signed_pre_key_id,
            deser_pre_key_signal_message.signed_pre_key_id
        );
        assert_eq!(
            pre_key_signal_message.base_key,
            deser_pre_key_signal_message.base_key
        );
        assert_eq!(
            pre_key_signal_message.identity_key.public_key(),
            deser_pre_key_signal_message.identity_key.public_key()
        );
        assert_signal_message_equals(
            &pre_key_signal_message.message,
            &deser_pre_key_signal_message.message,
        );
        assert_eq!(
            pre_key_signal_message.serialized,
            deser_pre_key_signal_message.serialized
        );
        Ok(())
    }

    // A message whose protobuf has no `addresses` field (a "legacy" message,
    // per the test name) still passes verify_mac_with_addresses when its MAC
    // was computed over the address-less encoding.
    #[test]
    fn test_signal_message_verify_mac_accepts_legacy_message_without_addresses() -> Result<()> {
        let mut csprng = OsRng.unwrap_err();
        let mut mac_key = [0u8; 32];
        csprng.fill_bytes(&mut mac_key);

        let mut ciphertext = [0u8; 20];
        csprng.fill_bytes(&mut ciphertext);

        let sender_ratchet_key_pair = KeyPair::generate(&mut csprng);
        let sender_identity_key_pair = KeyPair::generate(&mut csprng);
        let receiver_identity_key_pair = KeyPair::generate(&mut csprng);
        let sender_address = ProtocolAddress::new(
            "16180339-8874-9894-8482-045868343656".to_owned(),
            DeviceId::new(1).unwrap(),
        );
        let recipient_address = ProtocolAddress::new(
            "14142135-6237-3095-0488-016887242096".to_owned(),
            DeviceId::new(1).unwrap(),
        );

        let message = SignalMessage::new(
            4,
            &mac_key,
            Some((&sender_address, &recipient_address)),
            sender_ratchet_key_pair.public_key,
            42,
            41,
            &ciphertext,
            &sender_identity_key_pair.public_key.into(),
            &receiver_identity_key_pair.public_key.into(),
            b"",
        )?;

        // Decode the protobuf body between the leading version byte and the
        // trailing MAC, then drop the addresses field.
        let mut proto_structure = proto::wire::SignalMessage::decode(
            &message.serialized()[1..message.serialized().len() - SignalMessage::MAC_LENGTH],
        )
        .expect("valid protobuf");
        proto_structure.addresses = None;

        // Rebuild the wire form: version byte, re-encoded body, fresh MAC.
        let mut serialized =
            vec![((message.message_version() & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION];
        proto_structure.encode(&mut serialized).expect("encodes");
        let mac = SignalMessage::compute_mac(
            &sender_identity_key_pair.public_key.into(),
            &receiver_identity_key_pair.public_key.into(),
            &mac_key,
            &serialized,
        )?;
        serialized.extend_from_slice(&mac);

        let legacy_message = SignalMessage::try_from(serialized.as_slice())?;
        assert!(legacy_message.verify_mac_with_addresses(
            &sender_address,
            &recipient_address,
            &sender_identity_key_pair.public_key.into(),
            &receiver_identity_key_pair.public_key.into(),
            &mac_key,
        )?);

        Ok(())
    }

    // For a message built with addresses, verify_mac_with_addresses rejects a
    // wrong sender or a wrong recipient, and accepts the correct pair.
    #[test]
    fn test_signal_message_verify_mac_rejects_wrong_address() -> Result<()> {
        let mut csprng = OsRng.unwrap_err();
        let mut mac_key = [0u8; 32];
        csprng.fill_bytes(&mut mac_key);

        let mut ciphertext = [0u8; 20];
        csprng.fill_bytes(&mut ciphertext);

        let sender_ratchet_key_pair = KeyPair::generate(&mut csprng);
        let sender_identity_key_pair = KeyPair::generate(&mut csprng);
        let receiver_identity_key_pair = KeyPair::generate(&mut csprng);
        let sender_address = ProtocolAddress::new(
            "deadbeef-cafe-babe-feed-faceb00c0ffe".to_owned(),
            DeviceId::new(1).unwrap(),
        );
        let recipient_address = ProtocolAddress::new(
            "01120358-1321-3455-0891-44233377610a".to_owned(),
            DeviceId::new(1).unwrap(),
        );
        let wrong_address = ProtocolAddress::new(
            "02030507-1113-1719-2329-313741434753".to_owned(),
            DeviceId::new(1).unwrap(),
        );

        let message = SignalMessage::new(
            4,
            &mac_key,
            Some((&sender_address, &recipient_address)),
            sender_ratchet_key_pair.public_key,
            42,
            41,
            &ciphertext,
            &sender_identity_key_pair.public_key.into(),
            &receiver_identity_key_pair.public_key.into(),
            b"",
        )?;

        // Wrong sender address should be rejected.
        assert!(!message.verify_mac_with_addresses(
            &wrong_address,
            &recipient_address,
            &sender_identity_key_pair.public_key.into(),
            &receiver_identity_key_pair.public_key.into(),
            &mac_key,
        )?);

        // Wrong recipient address should be rejected.
        assert!(!message.verify_mac_with_addresses(
            &sender_address,
            &wrong_address,
            &sender_identity_key_pair.public_key.into(),
            &receiver_identity_key_pair.public_key.into(),
            &mac_key,
        )?);

        // Correct addresses should be accepted.
        assert!(message.verify_mac_with_addresses(
            &sender_address,
            &recipient_address,
            &sender_identity_key_pair.public_key.into(),
            &receiver_identity_key_pair.public_key.into(),
            &mac_key,
        )?);

        Ok(())
    }

    // A SenderKeyMessage survives serialize -> deserialize unchanged.
    #[test]
    fn test_sender_key_message_serialize_deserialize() -> Result<()> {
        let mut csprng = OsRng.unwrap_err();
        let signature_key_pair = KeyPair::generate(&mut csprng);
        let sender_key_message = SenderKeyMessage::new(
            SENDERKEY_MESSAGE_CURRENT_VERSION,
            Uuid::from_u128(0xd1d1d1d1_7000_11eb_b32a_33b8a8a487a6),
            42,
            7,
            [1u8, 2, 3].into(),
            &mut csprng,
            &signature_key_pair.private_key,
        )?;
        let deser_sender_key_message = SenderKeyMessage::try_from(sender_key_message.as_ref())
            .expect("should deserialize without error");
        assert_eq!(
            sender_key_message.message_version,
            deser_sender_key_message.message_version
        );
        assert_eq!(
            sender_key_message.chain_id,
            deser_sender_key_message.chain_id
        );
        assert_eq!(
            sender_key_message.iteration,
            deser_sender_key_message.iteration
        );
        assert_eq!(
            sender_key_message.ciphertext,
            deser_sender_key_message.ciphertext
        );
        assert_eq!(
            sender_key_message.serialized,
            deser_sender_key_message.serialized
        );
        Ok(())
    }

    // DecryptionErrorMessage::for_original followed by try_from round-trips
    // the timestamp and device id for Whisper, PreKey, and SenderKey
    // originals. It carries the ratchet key for the first two and none for
    // SenderKey.
    #[test]
    fn test_decryption_error_message() -> Result<()> {
        let mut csprng = OsRng.unwrap_err();
        let identity_key_pair = KeyPair::generate(&mut csprng);
        let base_key_pair = KeyPair::generate(&mut csprng);
        let message = create_signal_message(&mut csprng)?;
        // Both values are deliberately wider than 32 bits / high-bit-set to
        // exercise full-width encoding.
        let timestamp: Timestamp = Timestamp::from_epoch_millis(0x2_0000_0001);
        let device_id = 0x8086_2021;

        {
            let error_message = DecryptionErrorMessage::for_original(
                message.serialized(),
                CiphertextMessageType::Whisper,
                timestamp,
                device_id,
            )?;
            let error_message = DecryptionErrorMessage::try_from(error_message.serialized())?;
            assert_eq!(
                error_message.ratchet_key(),
                Some(message.sender_ratchet_key())
            );
            assert_eq!(error_message.timestamp(), timestamp);
            assert_eq!(error_message.device_id(), device_id);
        }

        let pre_key_signal_message = PreKeySignalMessage::new(
            3,
            365,
            None,
            97.into(),
            None, // TODO: add kyber prekeys
            base_key_pair.public_key,
            identity_key_pair.public_key.into(),
            message,
        )?;

        {
            let error_message = DecryptionErrorMessage::for_original(
                pre_key_signal_message.serialized(),
                CiphertextMessageType::PreKey,
                timestamp,
                device_id,
            )?;
            let error_message = DecryptionErrorMessage::try_from(error_message.serialized())?;
            assert_eq!(
                error_message.ratchet_key(),
                Some(pre_key_signal_message.message().sender_ratchet_key())
            );
            assert_eq!(error_message.timestamp(), timestamp);
            assert_eq!(error_message.device_id(), device_id);
        }

        let sender_key_message = SenderKeyMessage::new(
            3,
            Uuid::nil(),
            1,
            2,
            Box::from(b"test".to_owned()),
            &mut csprng,
            &base_key_pair.private_key,
        )?;

        {
            let error_message = DecryptionErrorMessage::for_original(
                sender_key_message.serialized(),
                CiphertextMessageType::SenderKey,
                timestamp,
                device_id,
            )?;
            let error_message = DecryptionErrorMessage::try_from(error_message.serialized())?;
            assert_eq!(error_message.ratchet_key(), None);
            assert_eq!(error_message.timestamp(), timestamp);
            assert_eq!(error_message.device_id(), device_id);
        }

        Ok(())
    }

    // Plaintext originals are rejected with InvalidArgument before any parsing.
    #[test]
    fn test_decryption_error_message_for_plaintext() {
        assert!(matches!(
            DecryptionErrorMessage::for_original(
                &[],
                CiphertextMessageType::Plaintext,
                Timestamp::from_epoch_millis(5),
                7
            ),
            Err(SignalProtocolError::InvalidArgument(_))
        ));
    }

    // Boundary ratios (0.0, 1.0) and intermediate ratios against keys whose
    // leading u32 is minimal, roughly half of u32::MAX, and maximal.
    #[test]
    fn test_should_use_nonpq_session() {
        let max = b"\xff\xff\xff\xff";
        let min = b"\x00\x00\x00\x00";
        let mid = b"\x7f\xff\xff\xff";
        assert!(!should_use_nonpq_session(1.0, max));
        assert!(!should_use_nonpq_session(1.0, min));
        assert!(should_use_nonpq_session(0.0, max));
        assert!(should_use_nonpq_session(0.0, min));

        assert!(!should_use_nonpq_session(0.75, min));
        assert!(!should_use_nonpq_session(0.75, mid));
        assert!(should_use_nonpq_session(0.75, max));

        assert!(!should_use_nonpq_session(0.25, min));
        assert!(should_use_nonpq_session(0.25, mid));
        assert!(should_use_nonpq_session(0.25, max));
    }
}