libsignal_protocol/
protocol.rs

1//
2// Copyright 2020-2021 Signal Messenger, LLC.
3// SPDX-License-Identifier: AGPL-3.0-only
4//
5
6use hmac::{Hmac, Mac};
7use prost::Message;
8use rand::{CryptoRng, Rng};
9use sha2::Sha256;
10use subtle::ConstantTimeEq;
11use uuid::Uuid;
12
13use crate::state::{KyberPreKeyId, PreKeyId, SignedPreKeyId};
14use crate::{
15    IdentityKey, PrivateKey, PublicKey, Result, SignalProtocolError, Timestamp, kem, proto,
16};
17
/// Version written into the low nibble of the header byte of every `SignalMessage` and
/// `PreKeySignalMessage` built in this module. It is also the highest header version
/// (high nibble) their parsers accept.
pub(crate) const CIPHERTEXT_MESSAGE_CURRENT_VERSION: u8 = 4;
/// Backward-compatible version that predates Kyber keys: the oldest header version
/// (high nibble) the `SignalMessage` / `PreKeySignalMessage` parsers accept. A
/// `PreKeySignalMessage` at this version may omit its Kyber payload.
pub(crate) const CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION: u8 = 3;
/// Version used by `SenderKeyMessage` and `SenderKeyDistributionMessage`. Their
/// constructors write it into the low nibble of the header byte, and it is the only
/// header version (high nibble) their parsers accept.
pub(crate) const SENDERKEY_MESSAGE_CURRENT_VERSION: u8 = 3;
22
/// One of the message kinds this module can produce or parse.
///
/// `message_type` maps each variant to its `CiphertextMessageType` tag, and `serialize`
/// returns the wrapped message's serialized bytes.
#[derive(Debug)]
pub enum CiphertextMessage {
    /// Tagged as `CiphertextMessageType::Whisper`.
    SignalMessage(SignalMessage),
    /// Tagged as `CiphertextMessageType::PreKey`.
    PreKeySignalMessage(PreKeySignalMessage),
    /// Tagged as `CiphertextMessageType::SenderKey`.
    SenderKeyMessage(SenderKeyMessage),
    /// Tagged as `CiphertextMessageType::Plaintext`.
    PlaintextContent(PlaintextContent),
}
30
/// Numeric type tag for each `CiphertextMessage` variant.
///
/// Represented as a `u8` (`#[repr(u8)]`). The `derive_more::TryFrom` derive with
/// `#[try_from(repr)]` supplies a fallible conversion from that raw `u8` value.
#[derive(Copy, Clone, Eq, PartialEq, Debug, derive_more::TryFrom)]
#[repr(u8)]
#[try_from(repr)]
pub enum CiphertextMessageType {
    /// Tag for a `SignalMessage`.
    Whisper = 2,
    /// Tag for a `PreKeySignalMessage`.
    PreKey = 3,
    /// Tag for a `SenderKeyMessage`.
    SenderKey = 7,
    /// Tag for unencrypted `PlaintextContent`.
    Plaintext = 8,
}
40
41impl CiphertextMessage {
42    pub fn message_type(&self) -> CiphertextMessageType {
43        match self {
44            CiphertextMessage::SignalMessage(_) => CiphertextMessageType::Whisper,
45            CiphertextMessage::PreKeySignalMessage(_) => CiphertextMessageType::PreKey,
46            CiphertextMessage::SenderKeyMessage(_) => CiphertextMessageType::SenderKey,
47            CiphertextMessage::PlaintextContent(_) => CiphertextMessageType::Plaintext,
48        }
49    }
50
51    pub fn serialize(&self) -> &[u8] {
52        match self {
53            CiphertextMessage::SignalMessage(x) => x.serialized(),
54            CiphertextMessage::PreKeySignalMessage(x) => x.serialized(),
55            CiphertextMessage::SenderKeyMessage(x) => x.serialized(),
56            CiphertextMessage::PlaintextContent(x) => x.serialized(),
57        }
58    }
59}
60
/// A ratchet ("Whisper") message, either freshly built by `SignalMessage::new` or parsed
/// from bytes.
///
/// `serialized` holds the exact wire encoding (header byte, protobuf body, trailing MAC).
/// The other fields are the decoded values.
#[derive(Debug, Clone)]
pub struct SignalMessage {
    /// Message version: the header byte's high nibble when parsed, or the value passed to `new`.
    message_version: u8,
    /// Sender's ratchet public key carried by this message.
    sender_ratchet_key: PublicKey,
    /// Message counter.
    counter: u32,
    /// Previous counter. Not read outside tests, hence the `expect(dead_code)`.
    #[cfg_attr(not(test), expect(dead_code))]
    previous_counter: u32,
    /// Encrypted body.
    ciphertext: Box<[u8]>,
    /// Opaque post-quantum ratchet payload. Empty when absent from the wire.
    pq_ratchet: spqr::SerializedState,
    /// Full wire encoding, including the trailing MAC.
    serialized: Box<[u8]>,
}
72
73impl SignalMessage {
74    const MAC_LENGTH: usize = 8;
75
76    #[allow(clippy::too_many_arguments)]
77    pub fn new(
78        message_version: u8,
79        mac_key: &[u8],
80        sender_ratchet_key: PublicKey,
81        counter: u32,
82        previous_counter: u32,
83        ciphertext: &[u8],
84        sender_identity_key: &IdentityKey,
85        receiver_identity_key: &IdentityKey,
86        pq_ratchet: &[u8],
87    ) -> Result<Self> {
88        let message = proto::wire::SignalMessage {
89            ratchet_key: Some(sender_ratchet_key.serialize().into_vec()),
90            counter: Some(counter),
91            previous_counter: Some(previous_counter),
92            ciphertext: Some(Vec::<u8>::from(ciphertext)),
93            pq_ratchet: if pq_ratchet.is_empty() {
94                None
95            } else {
96                Some(pq_ratchet.to_vec())
97            },
98        };
99        let mut serialized = Vec::with_capacity(1 + message.encoded_len() + Self::MAC_LENGTH);
100        serialized.push(((message_version & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION);
101        message
102            .encode(&mut serialized)
103            .expect("can always append to a buffer");
104        let mac = Self::compute_mac(
105            sender_identity_key,
106            receiver_identity_key,
107            mac_key,
108            &serialized,
109        )?;
110        serialized.extend_from_slice(&mac);
111        let serialized = serialized.into_boxed_slice();
112        Ok(Self {
113            message_version,
114            sender_ratchet_key,
115            counter,
116            previous_counter,
117            ciphertext: ciphertext.into(),
118            pq_ratchet: pq_ratchet.to_vec(),
119            serialized,
120        })
121    }
122
123    #[inline]
124    pub fn message_version(&self) -> u8 {
125        self.message_version
126    }
127
128    #[inline]
129    pub fn sender_ratchet_key(&self) -> &PublicKey {
130        &self.sender_ratchet_key
131    }
132
133    #[inline]
134    pub fn counter(&self) -> u32 {
135        self.counter
136    }
137
138    #[inline]
139    pub fn pq_ratchet(&self) -> &spqr::SerializedMessage {
140        &self.pq_ratchet
141    }
142
143    #[inline]
144    pub fn serialized(&self) -> &[u8] {
145        &self.serialized
146    }
147
148    #[inline]
149    pub fn body(&self) -> &[u8] {
150        &self.ciphertext
151    }
152
153    pub fn verify_mac(
154        &self,
155        sender_identity_key: &IdentityKey,
156        receiver_identity_key: &IdentityKey,
157        mac_key: &[u8],
158    ) -> Result<bool> {
159        let (content, their_mac) = self
160            .serialized
161            .split_last_chunk::<{ Self::MAC_LENGTH }>()
162            .expect("length checked at construction");
163        let our_mac =
164            Self::compute_mac(sender_identity_key, receiver_identity_key, mac_key, content)?;
165        let result: bool = our_mac.ct_eq(their_mac).into();
166        if !result {
167            // A warning instead of an error because we try multiple sessions.
168            log::warn!(
169                "Bad Mac! Their Mac: {} Our Mac: {}",
170                hex::encode(their_mac),
171                hex::encode(our_mac)
172            );
173        }
174        Ok(result)
175    }
176
177    fn compute_mac(
178        sender_identity_key: &IdentityKey,
179        receiver_identity_key: &IdentityKey,
180        mac_key: &[u8],
181        message: &[u8],
182    ) -> Result<[u8; Self::MAC_LENGTH]> {
183        if mac_key.len() != 32 {
184            return Err(SignalProtocolError::InvalidMacKeyLength(mac_key.len()));
185        }
186        let mut mac = Hmac::<Sha256>::new_from_slice(mac_key)
187            .expect("HMAC-SHA256 should accept any size key");
188
189        mac.update(sender_identity_key.public_key().serialize().as_ref());
190        mac.update(receiver_identity_key.public_key().serialize().as_ref());
191        mac.update(message);
192        let result = *mac
193            .finalize()
194            .into_bytes()
195            .first_chunk()
196            .expect("enough bytes");
197        Ok(result)
198    }
199}
200
201impl AsRef<[u8]> for SignalMessage {
202    fn as_ref(&self) -> &[u8] {
203        &self.serialized
204    }
205}
206
207impl TryFrom<&[u8]> for SignalMessage {
208    type Error = SignalProtocolError;
209
210    fn try_from(value: &[u8]) -> Result<Self> {
211        if value.len() < SignalMessage::MAC_LENGTH + 1 {
212            return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
213        }
214        let message_version = value[0] >> 4;
215        if message_version < CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION {
216            return Err(SignalProtocolError::LegacyCiphertextVersion(
217                message_version,
218            ));
219        }
220        if message_version > CIPHERTEXT_MESSAGE_CURRENT_VERSION {
221            return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
222                message_version,
223            ));
224        }
225
226        let proto_structure =
227            proto::wire::SignalMessage::decode(&value[1..value.len() - SignalMessage::MAC_LENGTH])
228                .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
229
230        let sender_ratchet_key = proto_structure
231            .ratchet_key
232            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
233        let sender_ratchet_key = PublicKey::deserialize(&sender_ratchet_key)?;
234        let counter = proto_structure
235            .counter
236            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
237        let previous_counter = proto_structure.previous_counter.unwrap_or(0);
238        let ciphertext = proto_structure
239            .ciphertext
240            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?
241            .into_boxed_slice();
242
243        Ok(SignalMessage {
244            message_version,
245            sender_ratchet_key,
246            counter,
247            previous_counter,
248            ciphertext,
249            pq_ratchet: proto_structure.pq_ratchet.unwrap_or(vec![]),
250            serialized: Box::from(value),
251        })
252    }
253}
254
/// The Kyber portion of a `PreKeySignalMessage`: which Kyber pre-key was used, plus the
/// accompanying KEM ciphertext.
#[derive(Debug, Clone)]
pub struct KyberPayload {
    /// Id of the Kyber pre-key this payload refers to.
    pre_key_id: KyberPreKeyId,
    /// Serialized KEM ciphertext.
    ciphertext: kem::SerializedCiphertext,
}
260
261impl KyberPayload {
262    pub fn new(id: KyberPreKeyId, ciphertext: kem::SerializedCiphertext) -> Self {
263        Self {
264            pre_key_id: id,
265            ciphertext,
266        }
267    }
268}
269
/// A message that carries the pre-key information needed alongside an embedded
/// `SignalMessage`.
///
/// `serialized` holds the exact wire encoding (header byte followed by the protobuf body).
/// The other fields are the decoded values.
#[derive(Debug, Clone)]
pub struct PreKeySignalMessage {
    /// Message version: the header byte's high nibble when parsed, or the value passed to `new`.
    message_version: u8,
    /// Sender's registration id. A missing value on the wire is parsed as 0.
    registration_id: u32,
    /// Optional one-time pre-key id.
    pre_key_id: Option<PreKeyId>,
    /// Signed pre-key id (required on the wire).
    signed_pre_key_id: SignedPreKeyId,
    // While we reject messages without Kyber payloads, we still for now allow constructing the
    // struct without one so that we can provide a better error message when we try to process it.
    kyber_payload: Option<KyberPayload>,
    /// Sender's base public key.
    base_key: PublicKey,
    /// Sender's identity key.
    identity_key: IdentityKey,
    /// The embedded ratchet message (its serialized form includes its own MAC).
    message: SignalMessage,
    /// Full wire encoding of this pre-key message.
    serialized: Box<[u8]>,
}
284
285impl PreKeySignalMessage {
286    pub fn new(
287        message_version: u8,
288        registration_id: u32,
289        pre_key_id: Option<PreKeyId>,
290        signed_pre_key_id: SignedPreKeyId,
291        kyber_payload: Option<KyberPayload>,
292        base_key: PublicKey,
293        identity_key: IdentityKey,
294        message: SignalMessage,
295    ) -> Result<Self> {
296        let proto_message = proto::wire::PreKeySignalMessage {
297            registration_id: Some(registration_id),
298            pre_key_id: pre_key_id.map(|id| id.into()),
299            signed_pre_key_id: Some(signed_pre_key_id.into()),
300            kyber_pre_key_id: kyber_payload.as_ref().map(|kyber| kyber.pre_key_id.into()),
301            kyber_ciphertext: kyber_payload
302                .as_ref()
303                .map(|kyber| kyber.ciphertext.to_vec()),
304            base_key: Some(base_key.serialize().into_vec()),
305            identity_key: Some(identity_key.serialize().into_vec()),
306            message: Some(Vec::from(message.as_ref())),
307        };
308        let mut serialized = Vec::with_capacity(1 + proto_message.encoded_len());
309        serialized.push(((message_version & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION);
310        proto_message
311            .encode(&mut serialized)
312            .expect("can always append to a Vec");
313        Ok(Self {
314            message_version,
315            registration_id,
316            pre_key_id,
317            signed_pre_key_id,
318            kyber_payload,
319            base_key,
320            identity_key,
321            message,
322            serialized: serialized.into_boxed_slice(),
323        })
324    }
325
326    #[inline]
327    pub fn message_version(&self) -> u8 {
328        self.message_version
329    }
330
331    #[inline]
332    pub fn registration_id(&self) -> u32 {
333        self.registration_id
334    }
335
336    #[inline]
337    pub fn pre_key_id(&self) -> Option<PreKeyId> {
338        self.pre_key_id
339    }
340
341    #[inline]
342    pub fn signed_pre_key_id(&self) -> SignedPreKeyId {
343        self.signed_pre_key_id
344    }
345
346    #[inline]
347    pub fn kyber_pre_key_id(&self) -> Option<KyberPreKeyId> {
348        self.kyber_payload.as_ref().map(|kyber| kyber.pre_key_id)
349    }
350
351    #[inline]
352    pub fn kyber_ciphertext(&self) -> Option<&kem::SerializedCiphertext> {
353        self.kyber_payload.as_ref().map(|kyber| &kyber.ciphertext)
354    }
355
356    #[inline]
357    pub fn base_key(&self) -> &PublicKey {
358        &self.base_key
359    }
360
361    #[inline]
362    pub fn identity_key(&self) -> &IdentityKey {
363        &self.identity_key
364    }
365
366    #[inline]
367    pub fn message(&self) -> &SignalMessage {
368        &self.message
369    }
370
371    #[inline]
372    pub fn serialized(&self) -> &[u8] {
373        &self.serialized
374    }
375}
376
377impl AsRef<[u8]> for PreKeySignalMessage {
378    fn as_ref(&self) -> &[u8] {
379        &self.serialized
380    }
381}
382
383impl TryFrom<&[u8]> for PreKeySignalMessage {
384    type Error = SignalProtocolError;
385
386    fn try_from(value: &[u8]) -> Result<Self> {
387        if value.is_empty() {
388            return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
389        }
390
391        let message_version = value[0] >> 4;
392        if message_version < CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION {
393            return Err(SignalProtocolError::LegacyCiphertextVersion(
394                message_version,
395            ));
396        }
397        if message_version > CIPHERTEXT_MESSAGE_CURRENT_VERSION {
398            return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
399                message_version,
400            ));
401        }
402
403        let proto_structure = proto::wire::PreKeySignalMessage::decode(&value[1..])
404            .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
405
406        let base_key = proto_structure
407            .base_key
408            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
409        let identity_key = proto_structure
410            .identity_key
411            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
412        let message = proto_structure
413            .message
414            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
415        let signed_pre_key_id = proto_structure
416            .signed_pre_key_id
417            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
418
419        let base_key = PublicKey::deserialize(base_key.as_ref())?;
420
421        let kyber_payload = match (
422            proto_structure.kyber_pre_key_id,
423            proto_structure.kyber_ciphertext,
424        ) {
425            (Some(id), Some(ct)) => Some(KyberPayload::new(id.into(), ct.into_boxed_slice())),
426            (None, None) if message_version <= CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION => None,
427            (None, None) => {
428                return Err(SignalProtocolError::InvalidMessage(
429                    CiphertextMessageType::PreKey,
430                    "Kyber pre key must be present for this session version",
431                ));
432            }
433            _ => {
434                return Err(SignalProtocolError::InvalidMessage(
435                    CiphertextMessageType::PreKey,
436                    "Both or neither kyber pre_key_id and kyber_ciphertext can be present",
437                ));
438            }
439        };
440
441        Ok(PreKeySignalMessage {
442            message_version,
443            registration_id: proto_structure.registration_id.unwrap_or(0),
444            pre_key_id: proto_structure.pre_key_id.map(|id| id.into()),
445            signed_pre_key_id: signed_pre_key_id.into(),
446            kyber_payload,
447            base_key,
448            identity_key: IdentityKey::try_from(identity_key.as_ref())?,
449            message: SignalMessage::try_from(message.as_ref())?,
450            serialized: Box::from(value),
451        })
452    }
453}
454
/// A sender-key message.
///
/// `serialized` holds the exact wire encoding: a header byte, the protobuf body, and a
/// trailing signature. The other fields are the decoded values.
#[derive(Debug, Clone)]
pub struct SenderKeyMessage {
    /// Message version recorded for this message.
    message_version: u8,
    /// Distribution id, decoded from the 16-byte `distribution_uuid` field.
    distribution_id: Uuid,
    /// Chain id.
    chain_id: u32,
    /// Iteration.
    iteration: u32,
    /// Encrypted body.
    ciphertext: Box<[u8]>,
    /// Full wire encoding, including the trailing signature.
    serialized: Box<[u8]>,
}
464
465impl SenderKeyMessage {
466    const SIGNATURE_LEN: usize = 64;
467
468    pub fn new<R: CryptoRng + Rng>(
469        message_version: u8,
470        distribution_id: Uuid,
471        chain_id: u32,
472        iteration: u32,
473        ciphertext: Box<[u8]>,
474        csprng: &mut R,
475        signature_key: &PrivateKey,
476    ) -> Result<Self> {
477        let proto_message = proto::wire::SenderKeyMessage {
478            distribution_uuid: Some(distribution_id.as_bytes().to_vec()),
479            chain_id: Some(chain_id),
480            iteration: Some(iteration),
481            ciphertext: Some(ciphertext.to_vec()),
482        };
483        let proto_message_len = proto_message.encoded_len();
484        let mut serialized = Vec::with_capacity(1 + proto_message_len + Self::SIGNATURE_LEN);
485        serialized.push(((message_version & 0xF) << 4) | SENDERKEY_MESSAGE_CURRENT_VERSION);
486        proto_message
487            .encode(&mut serialized)
488            .expect("can always append to a buffer");
489        let signature = signature_key.calculate_signature(&serialized, csprng)?;
490        serialized.extend_from_slice(&signature[..]);
491        Ok(Self {
492            message_version: SENDERKEY_MESSAGE_CURRENT_VERSION,
493            distribution_id,
494            chain_id,
495            iteration,
496            ciphertext,
497            serialized: serialized.into_boxed_slice(),
498        })
499    }
500
501    pub fn verify_signature(&self, signature_key: &PublicKey) -> Result<bool> {
502        let (content, signature) = self
503            .serialized
504            .split_last_chunk::<{ Self::SIGNATURE_LEN }>()
505            .expect("length checked on initialization");
506        let valid = signature_key.verify_signature(content, signature);
507
508        Ok(valid)
509    }
510
511    #[inline]
512    pub fn message_version(&self) -> u8 {
513        self.message_version
514    }
515
516    #[inline]
517    pub fn distribution_id(&self) -> Uuid {
518        self.distribution_id
519    }
520
521    #[inline]
522    pub fn chain_id(&self) -> u32 {
523        self.chain_id
524    }
525
526    #[inline]
527    pub fn iteration(&self) -> u32 {
528        self.iteration
529    }
530
531    #[inline]
532    pub fn ciphertext(&self) -> &[u8] {
533        &self.ciphertext
534    }
535
536    #[inline]
537    pub fn serialized(&self) -> &[u8] {
538        &self.serialized
539    }
540}
541
542impl AsRef<[u8]> for SenderKeyMessage {
543    fn as_ref(&self) -> &[u8] {
544        &self.serialized
545    }
546}
547
548impl TryFrom<&[u8]> for SenderKeyMessage {
549    type Error = SignalProtocolError;
550
551    fn try_from(value: &[u8]) -> Result<Self> {
552        if value.len() < 1 + Self::SIGNATURE_LEN {
553            return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
554        }
555        let message_version = value[0] >> 4;
556        if message_version < SENDERKEY_MESSAGE_CURRENT_VERSION {
557            return Err(SignalProtocolError::LegacyCiphertextVersion(
558                message_version,
559            ));
560        }
561        if message_version > SENDERKEY_MESSAGE_CURRENT_VERSION {
562            return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
563                message_version,
564            ));
565        }
566        let proto_structure =
567            proto::wire::SenderKeyMessage::decode(&value[1..value.len() - Self::SIGNATURE_LEN])
568                .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
569
570        let distribution_id = proto_structure
571            .distribution_uuid
572            .and_then(|bytes| Uuid::from_slice(bytes.as_slice()).ok())
573            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
574        let chain_id = proto_structure
575            .chain_id
576            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
577        let iteration = proto_structure
578            .iteration
579            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
580        let ciphertext = proto_structure
581            .ciphertext
582            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?
583            .into_boxed_slice();
584
585        Ok(SenderKeyMessage {
586            message_version,
587            distribution_id,
588            chain_id,
589            iteration,
590            ciphertext,
591            serialized: Box::from(value),
592        })
593    }
594}
595
/// A sender-key distribution message.
///
/// `serialized` holds the exact wire encoding (header byte followed by the protobuf body).
/// The other fields are the decoded values.
#[derive(Debug, Clone)]
pub struct SenderKeyDistributionMessage {
    /// Message version: the header byte's high nibble when parsed, or the value passed to `new`.
    message_version: u8,
    /// Distribution id, decoded from the 16-byte `distribution_uuid` field.
    distribution_id: Uuid,
    /// Chain id.
    chain_id: u32,
    /// Iteration.
    iteration: u32,
    /// Chain key bytes. `try_from` requires exactly 32; `new` does not check the length.
    chain_key: Vec<u8>,
    /// Signing public key. `try_from` requires its serialized form to be 33 bytes.
    signing_key: PublicKey,
    /// Full wire encoding.
    serialized: Box<[u8]>,
}
606
607impl SenderKeyDistributionMessage {
608    pub fn new(
609        message_version: u8,
610        distribution_id: Uuid,
611        chain_id: u32,
612        iteration: u32,
613        chain_key: Vec<u8>,
614        signing_key: PublicKey,
615    ) -> Result<Self> {
616        let proto_message = proto::wire::SenderKeyDistributionMessage {
617            distribution_uuid: Some(distribution_id.as_bytes().to_vec()),
618            chain_id: Some(chain_id),
619            iteration: Some(iteration),
620            chain_key: Some(chain_key.clone()),
621            signing_key: Some(signing_key.serialize().to_vec()),
622        };
623        let mut serialized = Vec::with_capacity(1 + proto_message.encoded_len());
624        serialized.push(((message_version & 0xF) << 4) | SENDERKEY_MESSAGE_CURRENT_VERSION);
625        proto_message
626            .encode(&mut serialized)
627            .expect("can always append to a buffer");
628
629        Ok(Self {
630            message_version,
631            distribution_id,
632            chain_id,
633            iteration,
634            chain_key,
635            signing_key,
636            serialized: serialized.into_boxed_slice(),
637        })
638    }
639
640    #[inline]
641    pub fn message_version(&self) -> u8 {
642        self.message_version
643    }
644
645    #[inline]
646    pub fn distribution_id(&self) -> Result<Uuid> {
647        Ok(self.distribution_id)
648    }
649
650    #[inline]
651    pub fn chain_id(&self) -> Result<u32> {
652        Ok(self.chain_id)
653    }
654
655    #[inline]
656    pub fn iteration(&self) -> Result<u32> {
657        Ok(self.iteration)
658    }
659
660    #[inline]
661    pub fn chain_key(&self) -> Result<&[u8]> {
662        Ok(&self.chain_key)
663    }
664
665    #[inline]
666    pub fn signing_key(&self) -> Result<&PublicKey> {
667        Ok(&self.signing_key)
668    }
669
670    #[inline]
671    pub fn serialized(&self) -> &[u8] {
672        &self.serialized
673    }
674}
675
676impl AsRef<[u8]> for SenderKeyDistributionMessage {
677    fn as_ref(&self) -> &[u8] {
678        &self.serialized
679    }
680}
681
682impl TryFrom<&[u8]> for SenderKeyDistributionMessage {
683    type Error = SignalProtocolError;
684
685    fn try_from(value: &[u8]) -> Result<Self> {
686        // The message contains at least a X25519 key and a chain key
687        if value.len() < 1 + 32 + 32 {
688            return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
689        }
690
691        let message_version = value[0] >> 4;
692
693        if message_version < SENDERKEY_MESSAGE_CURRENT_VERSION {
694            return Err(SignalProtocolError::LegacyCiphertextVersion(
695                message_version,
696            ));
697        }
698        if message_version > SENDERKEY_MESSAGE_CURRENT_VERSION {
699            return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
700                message_version,
701            ));
702        }
703
704        let proto_structure = proto::wire::SenderKeyDistributionMessage::decode(&value[1..])
705            .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
706
707        let distribution_id = proto_structure
708            .distribution_uuid
709            .and_then(|bytes| Uuid::from_slice(bytes.as_slice()).ok())
710            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
711        let chain_id = proto_structure
712            .chain_id
713            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
714        let iteration = proto_structure
715            .iteration
716            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
717        let chain_key = proto_structure
718            .chain_key
719            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
720        let signing_key = proto_structure
721            .signing_key
722            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
723
724        if chain_key.len() != 32 || signing_key.len() != 33 {
725            return Err(SignalProtocolError::InvalidProtobufEncoding);
726        }
727
728        let signing_key = PublicKey::deserialize(&signing_key)?;
729
730        Ok(SenderKeyDistributionMessage {
731            message_version,
732            distribution_id,
733            chain_id,
734            iteration,
735            chain_key,
736            signing_key,
737            serialized: Box::from(value),
738        })
739    }
740}
741
/// Unencrypted content that may be sent in the clear.
#[derive(Debug, Clone)]
pub struct PlaintextContent {
    /// Wire bytes: `PLAINTEXT_CONTEXT_IDENTIFIER_BYTE` followed by the body.
    serialized: Box<[u8]>,
}
746
747impl PlaintextContent {
748    /// Identifies a serialized PlaintextContent.
749    ///
750    /// This ensures someone doesn't try to serialize an arbitrary Content message as
751    /// PlaintextContent; only messages that are okay to send as plaintext should be allowed.
752    const PLAINTEXT_CONTEXT_IDENTIFIER_BYTE: u8 = 0xC0;
753
754    /// Marks the end of a message and the start of any padding.
755    ///
756    /// Usually messages are padded to avoid exposing patterns,
757    /// but PlaintextContent messages are all fixed-length anyway, so there won't be any padding.
758    const PADDING_BOUNDARY_BYTE: u8 = 0x80;
759
760    #[inline]
761    pub fn body(&self) -> &[u8] {
762        &self.serialized[1..]
763    }
764
765    #[inline]
766    pub fn serialized(&self) -> &[u8] {
767        &self.serialized
768    }
769}
770
771impl From<DecryptionErrorMessage> for PlaintextContent {
772    fn from(message: DecryptionErrorMessage) -> Self {
773        let proto_structure = proto::service::Content {
774            decryption_error_message: Some(message.serialized().to_vec()),
775            ..Default::default()
776        };
777        let mut serialized = vec![Self::PLAINTEXT_CONTEXT_IDENTIFIER_BYTE];
778        proto_structure
779            .encode(&mut serialized)
780            .expect("can always encode to a Vec");
781        serialized.push(Self::PADDING_BOUNDARY_BYTE);
782        Self {
783            serialized: Box::from(serialized),
784        }
785    }
786}
787
788impl TryFrom<&[u8]> for PlaintextContent {
789    type Error = SignalProtocolError;
790
791    fn try_from(value: &[u8]) -> Result<Self> {
792        if value.is_empty() {
793            return Err(SignalProtocolError::CiphertextMessageTooShort(0));
794        }
795        if value[0] != Self::PLAINTEXT_CONTEXT_IDENTIFIER_BYTE {
796            return Err(SignalProtocolError::UnrecognizedMessageVersion(
797                value[0] as u32,
798            ));
799        }
800        Ok(Self {
801            serialized: Box::from(value),
802        })
803    }
804}
805
/// Report describing a message that could not be decrypted.
///
/// `serialized` is the bare protobuf encoding (no version byte).
#[derive(Debug, Clone)]
pub struct DecryptionErrorMessage {
    /// Ratchet key from the original message. `None` for sender-key originals or when absent.
    ratchet_key: Option<PublicKey>,
    /// Timestamp of the original message.
    timestamp: Timestamp,
    /// Device id of the original sender. A missing value on the wire is parsed as 0.
    device_id: u32,
    /// Protobuf encoding of this report.
    serialized: Box<[u8]>,
}
813
814impl DecryptionErrorMessage {
815    pub fn for_original(
816        original_bytes: &[u8],
817        original_type: CiphertextMessageType,
818        original_timestamp: Timestamp,
819        original_sender_device_id: u32,
820    ) -> Result<Self> {
821        let ratchet_key = match original_type {
822            CiphertextMessageType::Whisper => {
823                Some(*SignalMessage::try_from(original_bytes)?.sender_ratchet_key())
824            }
825            CiphertextMessageType::PreKey => Some(
826                *PreKeySignalMessage::try_from(original_bytes)?
827                    .message()
828                    .sender_ratchet_key(),
829            ),
830            CiphertextMessageType::SenderKey => None,
831            CiphertextMessageType::Plaintext => {
832                return Err(SignalProtocolError::InvalidArgument(
833                    "cannot create a DecryptionErrorMessage for plaintext content; it is not encrypted".to_string()
834                ));
835            }
836        };
837
838        let proto_message = proto::service::DecryptionErrorMessage {
839            timestamp: Some(original_timestamp.epoch_millis()),
840            ratchet_key: ratchet_key.map(|k| k.serialize().into()),
841            device_id: Some(original_sender_device_id),
842        };
843        let serialized = proto_message.encode_to_vec();
844
845        Ok(Self {
846            ratchet_key,
847            timestamp: original_timestamp,
848            device_id: original_sender_device_id,
849            serialized: serialized.into_boxed_slice(),
850        })
851    }
852
853    #[inline]
854    pub fn timestamp(&self) -> Timestamp {
855        self.timestamp
856    }
857
858    #[inline]
859    pub fn ratchet_key(&self) -> Option<&PublicKey> {
860        self.ratchet_key.as_ref()
861    }
862
863    #[inline]
864    pub fn device_id(&self) -> u32 {
865        self.device_id
866    }
867
868    #[inline]
869    pub fn serialized(&self) -> &[u8] {
870        &self.serialized
871    }
872}
873
874impl TryFrom<&[u8]> for DecryptionErrorMessage {
875    type Error = SignalProtocolError;
876
877    fn try_from(value: &[u8]) -> Result<Self> {
878        let proto_structure = proto::service::DecryptionErrorMessage::decode(value)
879            .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
880        let timestamp = proto_structure
881            .timestamp
882            .map(Timestamp::from_epoch_millis)
883            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
884        let ratchet_key = proto_structure
885            .ratchet_key
886            .map(|k| PublicKey::deserialize(&k))
887            .transpose()?;
888        let device_id = proto_structure.device_id.unwrap_or_default();
889        Ok(Self {
890            timestamp,
891            ratchet_key,
892            device_id,
893            serialized: Box::from(value),
894        })
895    }
896}
897
898/// For testing
899pub fn extract_decryption_error_message_from_serialized_content(
900    bytes: &[u8],
901) -> Result<DecryptionErrorMessage> {
902    if bytes.last() != Some(&PlaintextContent::PADDING_BOUNDARY_BYTE) {
903        return Err(SignalProtocolError::InvalidProtobufEncoding);
904    }
905    let content = proto::service::Content::decode(bytes.split_last().expect("checked above").1)
906        .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
907    content
908        .decryption_error_message
909        .as_deref()
910        .ok_or_else(|| {
911            SignalProtocolError::InvalidArgument(
912                "Content does not contain DecryptionErrorMessage".to_owned(),
913            )
914        })
915        .and_then(DecryptionErrorMessage::try_from)
916}
917
#[cfg(test)]
mod tests {
    use rand::rngs::OsRng;
    use rand::{CryptoRng, Rng, TryRngCore as _};

    use super::*;
    use crate::KeyPair;

    /// Builds a version-4 `SignalMessage` from freshly generated random material.
    ///
    /// Randomness is consumed in this order:
    /// 1. a 32-byte MAC key;
    /// 2. a 20-byte ciphertext;
    /// 3. the sender ratchet, sender identity and receiver identity key pairs.
    ///
    /// The counter is 42, the previous counter is 41, and the PQ ratchet payload is empty.
    fn create_signal_message<R: Rng + CryptoRng>(rng: &mut R) -> Result<SignalMessage> {
        let mut mac_key = [0u8; 32];
        rng.fill_bytes(&mut mac_key);
        let mut body = [0u8; 20];
        rng.fill_bytes(&mut body);

        let ratchet = KeyPair::generate(rng);
        let sender_identity = KeyPair::generate(rng);
        let receiver_identity = KeyPair::generate(rng);

        SignalMessage::new(
            4,
            &mac_key,
            ratchet.public_key,
            42,
            41,
            &body,
            &sender_identity.public_key.into(),
            &receiver_identity.public_key.into(),
            b"", // pq_ratchet
        )
    }

    /// Asserts that two `SignalMessage`s agree on every parsed field and on their wire bytes.
    fn assert_signal_message_equals(expected: &SignalMessage, actual: &SignalMessage) {
        assert_eq!(
            (
                &expected.message_version,
                &expected.sender_ratchet_key,
                &expected.counter,
                &expected.previous_counter,
                &expected.ciphertext,
                &expected.serialized,
            ),
            (
                &actual.message_version,
                &actual.sender_ratchet_key,
                &actual.counter,
                &actual.previous_counter,
                &actual.ciphertext,
                &actual.serialized,
            )
        );
    }

    /// A `SignalMessage` survives a serialize → deserialize round trip unchanged.
    #[test]
    fn test_signal_message_serialize_deserialize() -> Result<()> {
        let mut rng = OsRng.unwrap_err();
        let original = create_signal_message(&mut rng)?;
        let parsed = SignalMessage::try_from(original.as_ref())
            .expect("should deserialize without error");
        assert_signal_message_equals(&original, &parsed);
        Ok(())
    }

    /// A `PreKeySignalMessage` (without Kyber pre-keys) survives a round trip unchanged,
    /// including its embedded `SignalMessage`.
    #[test]
    fn test_pre_key_signal_message_serialize_deserialize() -> Result<()> {
        let mut rng = OsRng.unwrap_err();
        let identity = KeyPair::generate(&mut rng);
        let base = KeyPair::generate(&mut rng);
        let inner = create_signal_message(&mut rng)?;

        let original = PreKeySignalMessage::new(
            3,
            365,
            None,
            97.into(),
            None, // TODO: add kyber prekeys
            base.public_key,
            identity.public_key.into(),
            inner,
        )?;
        let parsed = PreKeySignalMessage::try_from(original.as_ref())
            .expect("should deserialize without error");

        assert_eq!(
            (
                &original.message_version,
                &original.registration_id,
                &original.pre_key_id,
                &original.signed_pre_key_id,
                &original.base_key,
                original.identity_key.public_key(),
                &original.serialized,
            ),
            (
                &parsed.message_version,
                &parsed.registration_id,
                &parsed.pre_key_id,
                &parsed.signed_pre_key_id,
                &parsed.base_key,
                parsed.identity_key.public_key(),
                &parsed.serialized,
            )
        );
        assert_signal_message_equals(&original.message, &parsed.message);
        Ok(())
    }

    /// A signed `SenderKeyMessage` survives a round trip with its header fields,
    /// ciphertext and wire bytes intact.
    #[test]
    fn test_sender_key_message_serialize_deserialize() -> Result<()> {
        let mut rng = OsRng.unwrap_err();
        let signing = KeyPair::generate(&mut rng);

        let original = SenderKeyMessage::new(
            SENDERKEY_MESSAGE_CURRENT_VERSION,
            Uuid::from_u128(0xd1d1d1d1_7000_11eb_b32a_33b8a8a487a6),
            42,
            7,
            [1u8, 2, 3].into(),
            &mut rng,
            &signing.private_key,
        )?;
        let parsed = SenderKeyMessage::try_from(original.as_ref())
            .expect("should deserialize without error");

        assert_eq!(
            (
                &original.message_version,
                &original.chain_id,
                &original.iteration,
                &original.ciphertext,
                &original.serialized,
            ),
            (
                &parsed.message_version,
                &parsed.chain_id,
                &parsed.iteration,
                &parsed.ciphertext,
                &parsed.serialized,
            )
        );
        Ok(())
    }

    /// A `DecryptionErrorMessage` built for each supported original message type survives a
    /// round trip.
    ///
    /// Timestamp and device id always come back unchanged. The reported ratchet key depends
    /// on the original type:
    /// - Whisper: the sender ratchet key;
    /// - PreKey: the inner message's sender ratchet key;
    /// - SenderKey: no key.
    #[test]
    fn test_decryption_error_message() -> Result<()> {
        let mut rng = OsRng.unwrap_err();
        let identity = KeyPair::generate(&mut rng);
        let base = KeyPair::generate(&mut rng);
        let signal_message = create_signal_message(&mut rng)?;
        let timestamp: Timestamp = Timestamp::from_epoch_millis(0x2_0000_0001);
        let device_id = 0x8086_2021;

        // Whisper: the original's sender ratchet key is echoed back.
        let round_tripped = DecryptionErrorMessage::try_from(
            DecryptionErrorMessage::for_original(
                signal_message.serialized(),
                CiphertextMessageType::Whisper,
                timestamp,
                device_id,
            )?
            .serialized(),
        )?;
        assert_eq!(
            (
                round_tripped.ratchet_key(),
                round_tripped.timestamp(),
                round_tripped.device_id(),
            ),
            (Some(signal_message.sender_ratchet_key()), timestamp, device_id)
        );

        let pre_key_message = PreKeySignalMessage::new(
            3,
            365,
            None,
            97.into(),
            None, // TODO: add kyber prekeys
            base.public_key,
            identity.public_key.into(),
            signal_message,
        )?;

        // PreKey: the ratchet key comes from the wrapped SignalMessage.
        let round_tripped = DecryptionErrorMessage::try_from(
            DecryptionErrorMessage::for_original(
                pre_key_message.serialized(),
                CiphertextMessageType::PreKey,
                timestamp,
                device_id,
            )?
            .serialized(),
        )?;
        assert_eq!(
            (
                round_tripped.ratchet_key(),
                round_tripped.timestamp(),
                round_tripped.device_id(),
            ),
            (
                Some(pre_key_message.message().sender_ratchet_key()),
                timestamp,
                device_id,
            )
        );

        let sender_key_message = SenderKeyMessage::new(
            3,
            Uuid::nil(),
            1,
            2,
            Box::from(b"test".to_owned()),
            &mut rng,
            &base.private_key,
        )?;

        // SenderKey: there is no ratchet key to report.
        let round_tripped = DecryptionErrorMessage::try_from(
            DecryptionErrorMessage::for_original(
                sender_key_message.serialized(),
                CiphertextMessageType::SenderKey,
                timestamp,
                device_id,
            )?
            .serialized(),
        )?;
        assert_eq!(
            (
                round_tripped.ratchet_key(),
                round_tripped.timestamp(),
                round_tripped.device_id(),
            ),
            (None, timestamp, device_id)
        );

        Ok(())
    }

    /// Requesting a decryption-error report for a plaintext message is rejected as an
    /// invalid argument.
    #[test]
    fn test_decryption_error_message_for_plaintext() {
        let outcome = DecryptionErrorMessage::for_original(
            &[],
            CiphertextMessageType::Plaintext,
            Timestamp::from_epoch_millis(5),
            7,
        );
        assert!(matches!(
            outcome,
            Err(SignalProtocolError::InvalidArgument(_))
        ));
    }
}