libsignal_protocol/
protocol.rs

1//
2// Copyright 2020-2021 Signal Messenger, LLC.
3// SPDX-License-Identifier: AGPL-3.0-only
4//
5
6use hmac::{Hmac, Mac};
7use prost::Message;
8use rand::{CryptoRng, Rng};
9use sha2::Sha256;
10use subtle::ConstantTimeEq;
11use uuid::Uuid;
12
13use crate::state::{KyberPreKeyId, PreKeyId, SignedPreKeyId};
14use crate::{
15    kem, proto, IdentityKey, PrivateKey, PublicKey, Result, SignalProtocolError, Timestamp,
16};
17
/// Highest ciphertext message version produced and accepted for
/// `SignalMessage` / `PreKeySignalMessage`.
pub(crate) const CIPHERTEXT_MESSAGE_CURRENT_VERSION: u8 = 4;
/// Lowest ciphertext message version still accepted: the backward-compatible format that
/// lacks Kyber keys.
pub(crate) const CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION: u8 = 3;
/// The only version produced and accepted for sender-key and sender-key distribution messages.
pub(crate) const SENDERKEY_MESSAGE_CURRENT_VERSION: u8 = 3;
22
23#[derive(Debug)]
24pub enum CiphertextMessage {
25    SignalMessage(SignalMessage),
26    PreKeySignalMessage(PreKeySignalMessage),
27    SenderKeyMessage(SenderKeyMessage),
28    PlaintextContent(PlaintextContent),
29}
30
31#[derive(Copy, Clone, Eq, PartialEq, Debug, derive_more::TryFrom)]
32#[repr(u8)]
33#[try_from(repr)]
34pub enum CiphertextMessageType {
35    Whisper = 2,
36    PreKey = 3,
37    SenderKey = 7,
38    Plaintext = 8,
39}
40
41impl CiphertextMessage {
42    pub fn message_type(&self) -> CiphertextMessageType {
43        match self {
44            CiphertextMessage::SignalMessage(_) => CiphertextMessageType::Whisper,
45            CiphertextMessage::PreKeySignalMessage(_) => CiphertextMessageType::PreKey,
46            CiphertextMessage::SenderKeyMessage(_) => CiphertextMessageType::SenderKey,
47            CiphertextMessage::PlaintextContent(_) => CiphertextMessageType::Plaintext,
48        }
49    }
50
51    pub fn serialize(&self) -> &[u8] {
52        match self {
53            CiphertextMessage::SignalMessage(x) => x.serialized(),
54            CiphertextMessage::PreKeySignalMessage(x) => x.serialized(),
55            CiphertextMessage::SenderKeyMessage(x) => x.serialized(),
56            CiphertextMessage::PlaintextContent(x) => x.serialized(),
57        }
58    }
59}
60
61#[derive(Debug, Clone)]
62pub struct SignalMessage {
63    message_version: u8,
64    sender_ratchet_key: PublicKey,
65    counter: u32,
66    #[cfg_attr(not(test), expect(dead_code))]
67    previous_counter: u32,
68    ciphertext: Box<[u8]>,
69    pq_ratchet: spqr::SerializedState,
70    serialized: Box<[u8]>,
71}
72
73impl SignalMessage {
74    const MAC_LENGTH: usize = 8;
75
76    #[allow(clippy::too_many_arguments)]
77    pub fn new(
78        message_version: u8,
79        mac_key: &[u8],
80        sender_ratchet_key: PublicKey,
81        counter: u32,
82        previous_counter: u32,
83        ciphertext: &[u8],
84        sender_identity_key: &IdentityKey,
85        receiver_identity_key: &IdentityKey,
86        pq_ratchet: &[u8],
87    ) -> Result<Self> {
88        let message = proto::wire::SignalMessage {
89            ratchet_key: Some(sender_ratchet_key.serialize().into_vec()),
90            counter: Some(counter),
91            previous_counter: Some(previous_counter),
92            ciphertext: Some(Vec::<u8>::from(ciphertext)),
93            pq_ratchet: if pq_ratchet.is_empty() {
94                None
95            } else {
96                Some(pq_ratchet.to_vec())
97            },
98        };
99        let mut serialized = Vec::with_capacity(1 + message.encoded_len() + Self::MAC_LENGTH);
100        serialized.push(((message_version & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION);
101        message
102            .encode(&mut serialized)
103            .expect("can always append to a buffer");
104        let mac = Self::compute_mac(
105            sender_identity_key,
106            receiver_identity_key,
107            mac_key,
108            &serialized,
109        )?;
110        serialized.extend_from_slice(&mac);
111        let serialized = serialized.into_boxed_slice();
112        Ok(Self {
113            message_version,
114            sender_ratchet_key,
115            counter,
116            previous_counter,
117            ciphertext: ciphertext.into(),
118            pq_ratchet: pq_ratchet.to_vec(),
119            serialized,
120        })
121    }
122
123    #[inline]
124    pub fn message_version(&self) -> u8 {
125        self.message_version
126    }
127
128    #[inline]
129    pub fn sender_ratchet_key(&self) -> &PublicKey {
130        &self.sender_ratchet_key
131    }
132
133    #[inline]
134    pub fn counter(&self) -> u32 {
135        self.counter
136    }
137
138    #[inline]
139    pub fn pq_ratchet(&self) -> &spqr::SerializedMessage {
140        &self.pq_ratchet
141    }
142
143    #[inline]
144    pub fn serialized(&self) -> &[u8] {
145        &self.serialized
146    }
147
148    #[inline]
149    pub fn body(&self) -> &[u8] {
150        &self.ciphertext
151    }
152
153    pub fn verify_mac(
154        &self,
155        sender_identity_key: &IdentityKey,
156        receiver_identity_key: &IdentityKey,
157        mac_key: &[u8],
158    ) -> Result<bool> {
159        let our_mac = &Self::compute_mac(
160            sender_identity_key,
161            receiver_identity_key,
162            mac_key,
163            &self.serialized[..self.serialized.len() - Self::MAC_LENGTH],
164        )?;
165        let their_mac = &self.serialized[self.serialized.len() - Self::MAC_LENGTH..];
166        let result: bool = our_mac.ct_eq(their_mac).into();
167        if !result {
168            // A warning instead of an error because we try multiple sessions.
169            log::warn!(
170                "Bad Mac! Their Mac: {} Our Mac: {}",
171                hex::encode(their_mac),
172                hex::encode(our_mac)
173            );
174        }
175        Ok(result)
176    }
177
178    fn compute_mac(
179        sender_identity_key: &IdentityKey,
180        receiver_identity_key: &IdentityKey,
181        mac_key: &[u8],
182        message: &[u8],
183    ) -> Result<[u8; Self::MAC_LENGTH]> {
184        if mac_key.len() != 32 {
185            return Err(SignalProtocolError::InvalidMacKeyLength(mac_key.len()));
186        }
187        let mut mac = Hmac::<Sha256>::new_from_slice(mac_key)
188            .expect("HMAC-SHA256 should accept any size key");
189
190        mac.update(sender_identity_key.public_key().serialize().as_ref());
191        mac.update(receiver_identity_key.public_key().serialize().as_ref());
192        mac.update(message);
193        let mut result = [0u8; Self::MAC_LENGTH];
194        result.copy_from_slice(&mac.finalize().into_bytes()[..Self::MAC_LENGTH]);
195        Ok(result)
196    }
197}
198
199impl AsRef<[u8]> for SignalMessage {
200    fn as_ref(&self) -> &[u8] {
201        &self.serialized
202    }
203}
204
205impl TryFrom<&[u8]> for SignalMessage {
206    type Error = SignalProtocolError;
207
208    fn try_from(value: &[u8]) -> Result<Self> {
209        if value.len() < SignalMessage::MAC_LENGTH + 1 {
210            return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
211        }
212        let message_version = value[0] >> 4;
213        if message_version < CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION {
214            return Err(SignalProtocolError::LegacyCiphertextVersion(
215                message_version,
216            ));
217        }
218        if message_version > CIPHERTEXT_MESSAGE_CURRENT_VERSION {
219            return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
220                message_version,
221            ));
222        }
223
224        let proto_structure =
225            proto::wire::SignalMessage::decode(&value[1..value.len() - SignalMessage::MAC_LENGTH])
226                .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
227
228        let sender_ratchet_key = proto_structure
229            .ratchet_key
230            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
231        let sender_ratchet_key = PublicKey::deserialize(&sender_ratchet_key)?;
232        let counter = proto_structure
233            .counter
234            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
235        let previous_counter = proto_structure.previous_counter.unwrap_or(0);
236        let ciphertext = proto_structure
237            .ciphertext
238            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?
239            .into_boxed_slice();
240
241        Ok(SignalMessage {
242            message_version,
243            sender_ratchet_key,
244            counter,
245            previous_counter,
246            ciphertext,
247            pq_ratchet: proto_structure.pq_ratchet.unwrap_or(vec![]),
248            serialized: Box::from(value),
249        })
250    }
251}
252
253#[derive(Debug, Clone)]
254pub struct KyberPayload {
255    pre_key_id: KyberPreKeyId,
256    ciphertext: kem::SerializedCiphertext,
257}
258
259impl KyberPayload {
260    pub fn new(id: KyberPreKeyId, ciphertext: kem::SerializedCiphertext) -> Self {
261        Self {
262            pre_key_id: id,
263            ciphertext,
264        }
265    }
266}
267
268#[derive(Debug, Clone)]
269pub struct PreKeySignalMessage {
270    message_version: u8,
271    registration_id: u32,
272    pre_key_id: Option<PreKeyId>,
273    signed_pre_key_id: SignedPreKeyId,
274    // While we reject messages without Kyber payloads, we still for now allow constructing the
275    // struct without one so that we can provide a better error message when we try to process it.
276    kyber_payload: Option<KyberPayload>,
277    base_key: PublicKey,
278    identity_key: IdentityKey,
279    message: SignalMessage,
280    serialized: Box<[u8]>,
281}
282
283impl PreKeySignalMessage {
284    pub fn new(
285        message_version: u8,
286        registration_id: u32,
287        pre_key_id: Option<PreKeyId>,
288        signed_pre_key_id: SignedPreKeyId,
289        kyber_payload: Option<KyberPayload>,
290        base_key: PublicKey,
291        identity_key: IdentityKey,
292        message: SignalMessage,
293    ) -> Result<Self> {
294        let proto_message = proto::wire::PreKeySignalMessage {
295            registration_id: Some(registration_id),
296            pre_key_id: pre_key_id.map(|id| id.into()),
297            signed_pre_key_id: Some(signed_pre_key_id.into()),
298            kyber_pre_key_id: kyber_payload.as_ref().map(|kyber| kyber.pre_key_id.into()),
299            kyber_ciphertext: kyber_payload
300                .as_ref()
301                .map(|kyber| kyber.ciphertext.to_vec()),
302            base_key: Some(base_key.serialize().into_vec()),
303            identity_key: Some(identity_key.serialize().into_vec()),
304            message: Some(Vec::from(message.as_ref())),
305        };
306        let mut serialized = Vec::with_capacity(1 + proto_message.encoded_len());
307        serialized.push(((message_version & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION);
308        proto_message
309            .encode(&mut serialized)
310            .expect("can always append to a Vec");
311        Ok(Self {
312            message_version,
313            registration_id,
314            pre_key_id,
315            signed_pre_key_id,
316            kyber_payload,
317            base_key,
318            identity_key,
319            message,
320            serialized: serialized.into_boxed_slice(),
321        })
322    }
323
324    #[inline]
325    pub fn message_version(&self) -> u8 {
326        self.message_version
327    }
328
329    #[inline]
330    pub fn registration_id(&self) -> u32 {
331        self.registration_id
332    }
333
334    #[inline]
335    pub fn pre_key_id(&self) -> Option<PreKeyId> {
336        self.pre_key_id
337    }
338
339    #[inline]
340    pub fn signed_pre_key_id(&self) -> SignedPreKeyId {
341        self.signed_pre_key_id
342    }
343
344    #[inline]
345    pub fn kyber_pre_key_id(&self) -> Option<KyberPreKeyId> {
346        self.kyber_payload.as_ref().map(|kyber| kyber.pre_key_id)
347    }
348
349    #[inline]
350    pub fn kyber_ciphertext(&self) -> Option<&kem::SerializedCiphertext> {
351        self.kyber_payload.as_ref().map(|kyber| &kyber.ciphertext)
352    }
353
354    #[inline]
355    pub fn base_key(&self) -> &PublicKey {
356        &self.base_key
357    }
358
359    #[inline]
360    pub fn identity_key(&self) -> &IdentityKey {
361        &self.identity_key
362    }
363
364    #[inline]
365    pub fn message(&self) -> &SignalMessage {
366        &self.message
367    }
368
369    #[inline]
370    pub fn serialized(&self) -> &[u8] {
371        &self.serialized
372    }
373}
374
375impl AsRef<[u8]> for PreKeySignalMessage {
376    fn as_ref(&self) -> &[u8] {
377        &self.serialized
378    }
379}
380
381impl TryFrom<&[u8]> for PreKeySignalMessage {
382    type Error = SignalProtocolError;
383
384    fn try_from(value: &[u8]) -> Result<Self> {
385        if value.is_empty() {
386            return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
387        }
388
389        let message_version = value[0] >> 4;
390        if message_version < CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION {
391            return Err(SignalProtocolError::LegacyCiphertextVersion(
392                message_version,
393            ));
394        }
395        if message_version > CIPHERTEXT_MESSAGE_CURRENT_VERSION {
396            return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
397                message_version,
398            ));
399        }
400
401        let proto_structure = proto::wire::PreKeySignalMessage::decode(&value[1..])
402            .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
403
404        let base_key = proto_structure
405            .base_key
406            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
407        let identity_key = proto_structure
408            .identity_key
409            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
410        let message = proto_structure
411            .message
412            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
413        let signed_pre_key_id = proto_structure
414            .signed_pre_key_id
415            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
416
417        let base_key = PublicKey::deserialize(base_key.as_ref())?;
418
419        let kyber_payload = match (
420            proto_structure.kyber_pre_key_id,
421            proto_structure.kyber_ciphertext,
422        ) {
423            (Some(id), Some(ct)) => Some(KyberPayload::new(id.into(), ct.into_boxed_slice())),
424            (None, None) if message_version <= CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION => None,
425            (None, None) => {
426                return Err(SignalProtocolError::InvalidMessage(
427                    CiphertextMessageType::PreKey,
428                    "Kyber pre key must be present for this session version",
429                ));
430            }
431            _ => {
432                return Err(SignalProtocolError::InvalidMessage(
433                    CiphertextMessageType::PreKey,
434                    "Both or neither kyber pre_key_id and kyber_ciphertext can be present",
435                ));
436            }
437        };
438
439        Ok(PreKeySignalMessage {
440            message_version,
441            registration_id: proto_structure.registration_id.unwrap_or(0),
442            pre_key_id: proto_structure.pre_key_id.map(|id| id.into()),
443            signed_pre_key_id: signed_pre_key_id.into(),
444            kyber_payload,
445            base_key,
446            identity_key: IdentityKey::try_from(identity_key.as_ref())?,
447            message: SignalMessage::try_from(message.as_ref())?,
448            serialized: Box::from(value),
449        })
450    }
451}
452
453#[derive(Debug, Clone)]
454pub struct SenderKeyMessage {
455    message_version: u8,
456    distribution_id: Uuid,
457    chain_id: u32,
458    iteration: u32,
459    ciphertext: Box<[u8]>,
460    serialized: Box<[u8]>,
461}
462
463impl SenderKeyMessage {
464    const SIGNATURE_LEN: usize = 64;
465
466    pub fn new<R: CryptoRng + Rng>(
467        message_version: u8,
468        distribution_id: Uuid,
469        chain_id: u32,
470        iteration: u32,
471        ciphertext: Box<[u8]>,
472        csprng: &mut R,
473        signature_key: &PrivateKey,
474    ) -> Result<Self> {
475        let proto_message = proto::wire::SenderKeyMessage {
476            distribution_uuid: Some(distribution_id.as_bytes().to_vec()),
477            chain_id: Some(chain_id),
478            iteration: Some(iteration),
479            ciphertext: Some(ciphertext.to_vec()),
480        };
481        let proto_message_len = proto_message.encoded_len();
482        let mut serialized = Vec::with_capacity(1 + proto_message_len + Self::SIGNATURE_LEN);
483        serialized.push(((message_version & 0xF) << 4) | SENDERKEY_MESSAGE_CURRENT_VERSION);
484        proto_message
485            .encode(&mut serialized)
486            .expect("can always append to a buffer");
487        let signature = signature_key.calculate_signature(&serialized, csprng)?;
488        serialized.extend_from_slice(&signature[..]);
489        Ok(Self {
490            message_version: SENDERKEY_MESSAGE_CURRENT_VERSION,
491            distribution_id,
492            chain_id,
493            iteration,
494            ciphertext,
495            serialized: serialized.into_boxed_slice(),
496        })
497    }
498
499    pub fn verify_signature(&self, signature_key: &PublicKey) -> Result<bool> {
500        let valid = signature_key.verify_signature(
501            &self.serialized[..self.serialized.len() - Self::SIGNATURE_LEN],
502            &self.serialized[self.serialized.len() - Self::SIGNATURE_LEN..],
503        );
504
505        Ok(valid)
506    }
507
508    #[inline]
509    pub fn message_version(&self) -> u8 {
510        self.message_version
511    }
512
513    #[inline]
514    pub fn distribution_id(&self) -> Uuid {
515        self.distribution_id
516    }
517
518    #[inline]
519    pub fn chain_id(&self) -> u32 {
520        self.chain_id
521    }
522
523    #[inline]
524    pub fn iteration(&self) -> u32 {
525        self.iteration
526    }
527
528    #[inline]
529    pub fn ciphertext(&self) -> &[u8] {
530        &self.ciphertext
531    }
532
533    #[inline]
534    pub fn serialized(&self) -> &[u8] {
535        &self.serialized
536    }
537}
538
539impl AsRef<[u8]> for SenderKeyMessage {
540    fn as_ref(&self) -> &[u8] {
541        &self.serialized
542    }
543}
544
545impl TryFrom<&[u8]> for SenderKeyMessage {
546    type Error = SignalProtocolError;
547
548    fn try_from(value: &[u8]) -> Result<Self> {
549        if value.len() < 1 + Self::SIGNATURE_LEN {
550            return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
551        }
552        let message_version = value[0] >> 4;
553        if message_version < SENDERKEY_MESSAGE_CURRENT_VERSION {
554            return Err(SignalProtocolError::LegacyCiphertextVersion(
555                message_version,
556            ));
557        }
558        if message_version > SENDERKEY_MESSAGE_CURRENT_VERSION {
559            return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
560                message_version,
561            ));
562        }
563        let proto_structure =
564            proto::wire::SenderKeyMessage::decode(&value[1..value.len() - Self::SIGNATURE_LEN])
565                .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
566
567        let distribution_id = proto_structure
568            .distribution_uuid
569            .and_then(|bytes| Uuid::from_slice(bytes.as_slice()).ok())
570            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
571        let chain_id = proto_structure
572            .chain_id
573            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
574        let iteration = proto_structure
575            .iteration
576            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
577        let ciphertext = proto_structure
578            .ciphertext
579            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?
580            .into_boxed_slice();
581
582        Ok(SenderKeyMessage {
583            message_version,
584            distribution_id,
585            chain_id,
586            iteration,
587            ciphertext,
588            serialized: Box::from(value),
589        })
590    }
591}
592
593#[derive(Debug, Clone)]
594pub struct SenderKeyDistributionMessage {
595    message_version: u8,
596    distribution_id: Uuid,
597    chain_id: u32,
598    iteration: u32,
599    chain_key: Vec<u8>,
600    signing_key: PublicKey,
601    serialized: Box<[u8]>,
602}
603
604impl SenderKeyDistributionMessage {
605    pub fn new(
606        message_version: u8,
607        distribution_id: Uuid,
608        chain_id: u32,
609        iteration: u32,
610        chain_key: Vec<u8>,
611        signing_key: PublicKey,
612    ) -> Result<Self> {
613        let proto_message = proto::wire::SenderKeyDistributionMessage {
614            distribution_uuid: Some(distribution_id.as_bytes().to_vec()),
615            chain_id: Some(chain_id),
616            iteration: Some(iteration),
617            chain_key: Some(chain_key.clone()),
618            signing_key: Some(signing_key.serialize().to_vec()),
619        };
620        let mut serialized = Vec::with_capacity(1 + proto_message.encoded_len());
621        serialized.push(((message_version & 0xF) << 4) | SENDERKEY_MESSAGE_CURRENT_VERSION);
622        proto_message
623            .encode(&mut serialized)
624            .expect("can always append to a buffer");
625
626        Ok(Self {
627            message_version,
628            distribution_id,
629            chain_id,
630            iteration,
631            chain_key,
632            signing_key,
633            serialized: serialized.into_boxed_slice(),
634        })
635    }
636
637    #[inline]
638    pub fn message_version(&self) -> u8 {
639        self.message_version
640    }
641
642    #[inline]
643    pub fn distribution_id(&self) -> Result<Uuid> {
644        Ok(self.distribution_id)
645    }
646
647    #[inline]
648    pub fn chain_id(&self) -> Result<u32> {
649        Ok(self.chain_id)
650    }
651
652    #[inline]
653    pub fn iteration(&self) -> Result<u32> {
654        Ok(self.iteration)
655    }
656
657    #[inline]
658    pub fn chain_key(&self) -> Result<&[u8]> {
659        Ok(&self.chain_key)
660    }
661
662    #[inline]
663    pub fn signing_key(&self) -> Result<&PublicKey> {
664        Ok(&self.signing_key)
665    }
666
667    #[inline]
668    pub fn serialized(&self) -> &[u8] {
669        &self.serialized
670    }
671}
672
673impl AsRef<[u8]> for SenderKeyDistributionMessage {
674    fn as_ref(&self) -> &[u8] {
675        &self.serialized
676    }
677}
678
679impl TryFrom<&[u8]> for SenderKeyDistributionMessage {
680    type Error = SignalProtocolError;
681
682    fn try_from(value: &[u8]) -> Result<Self> {
683        // The message contains at least a X25519 key and a chain key
684        if value.len() < 1 + 32 + 32 {
685            return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
686        }
687
688        let message_version = value[0] >> 4;
689
690        if message_version < SENDERKEY_MESSAGE_CURRENT_VERSION {
691            return Err(SignalProtocolError::LegacyCiphertextVersion(
692                message_version,
693            ));
694        }
695        if message_version > SENDERKEY_MESSAGE_CURRENT_VERSION {
696            return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
697                message_version,
698            ));
699        }
700
701        let proto_structure = proto::wire::SenderKeyDistributionMessage::decode(&value[1..])
702            .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
703
704        let distribution_id = proto_structure
705            .distribution_uuid
706            .and_then(|bytes| Uuid::from_slice(bytes.as_slice()).ok())
707            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
708        let chain_id = proto_structure
709            .chain_id
710            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
711        let iteration = proto_structure
712            .iteration
713            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
714        let chain_key = proto_structure
715            .chain_key
716            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
717        let signing_key = proto_structure
718            .signing_key
719            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
720
721        if chain_key.len() != 32 || signing_key.len() != 33 {
722            return Err(SignalProtocolError::InvalidProtobufEncoding);
723        }
724
725        let signing_key = PublicKey::deserialize(&signing_key)?;
726
727        Ok(SenderKeyDistributionMessage {
728            message_version,
729            distribution_id,
730            chain_id,
731            iteration,
732            chain_key,
733            signing_key,
734            serialized: Box::from(value),
735        })
736    }
737}
738
/// Unencrypted content.
///
/// The serialization is a leading identifier byte followed by the body.
#[derive(Debug, Clone)]
pub struct PlaintextContent {
    serialized: Box<[u8]>,
}

impl PlaintextContent {
    /// Identifies a serialized PlaintextContent.
    ///
    /// This guards against serializing an arbitrary Content message as PlaintextContent.
    /// Only messages that are okay to send as plaintext should be allowed.
    const PLAINTEXT_CONTEXT_IDENTIFIER_BYTE: u8 = 0xC0;

    /// Marks the end of a message and the start of any padding.
    ///
    /// Messages are usually padded to avoid exposing patterns. PlaintextContent messages are
    /// all fixed-length anyway, so they carry no padding.
    const PADDING_BOUNDARY_BYTE: u8 = 0x80;

    /// Everything after the leading identifier byte.
    #[inline]
    pub fn body(&self) -> &[u8] {
        let (_identifier, rest) = self.serialized.split_at(1);
        rest
    }

    /// The complete serialized content, including the identifier byte.
    #[inline]
    pub fn serialized(&self) -> &[u8] {
        &self.serialized
    }
}
767
768impl From<DecryptionErrorMessage> for PlaintextContent {
769    fn from(message: DecryptionErrorMessage) -> Self {
770        let proto_structure = proto::service::Content {
771            decryption_error_message: Some(message.serialized().to_vec()),
772            ..Default::default()
773        };
774        let mut serialized = vec![Self::PLAINTEXT_CONTEXT_IDENTIFIER_BYTE];
775        proto_structure
776            .encode(&mut serialized)
777            .expect("can always encode to a Vec");
778        serialized.push(Self::PADDING_BOUNDARY_BYTE);
779        Self {
780            serialized: Box::from(serialized),
781        }
782    }
783}
784
785impl TryFrom<&[u8]> for PlaintextContent {
786    type Error = SignalProtocolError;
787
788    fn try_from(value: &[u8]) -> Result<Self> {
789        if value.is_empty() {
790            return Err(SignalProtocolError::CiphertextMessageTooShort(0));
791        }
792        if value[0] != Self::PLAINTEXT_CONTEXT_IDENTIFIER_BYTE {
793            return Err(SignalProtocolError::UnrecognizedMessageVersion(
794                value[0] as u32,
795            ));
796        }
797        Ok(Self {
798            serialized: Box::from(value),
799        })
800    }
801}
802
803#[derive(Debug, Clone)]
804pub struct DecryptionErrorMessage {
805    ratchet_key: Option<PublicKey>,
806    timestamp: Timestamp,
807    device_id: u32,
808    serialized: Box<[u8]>,
809}
810
811impl DecryptionErrorMessage {
812    pub fn for_original(
813        original_bytes: &[u8],
814        original_type: CiphertextMessageType,
815        original_timestamp: Timestamp,
816        original_sender_device_id: u32,
817    ) -> Result<Self> {
818        let ratchet_key = match original_type {
819            CiphertextMessageType::Whisper => {
820                Some(*SignalMessage::try_from(original_bytes)?.sender_ratchet_key())
821            }
822            CiphertextMessageType::PreKey => Some(
823                *PreKeySignalMessage::try_from(original_bytes)?
824                    .message()
825                    .sender_ratchet_key(),
826            ),
827            CiphertextMessageType::SenderKey => None,
828            CiphertextMessageType::Plaintext => {
829                return Err(SignalProtocolError::InvalidArgument(
830                    "cannot create a DecryptionErrorMessage for plaintext content; it is not encrypted".to_string()
831                ));
832            }
833        };
834
835        let proto_message = proto::service::DecryptionErrorMessage {
836            timestamp: Some(original_timestamp.epoch_millis()),
837            ratchet_key: ratchet_key.map(|k| k.serialize().into()),
838            device_id: Some(original_sender_device_id),
839        };
840        let serialized = proto_message.encode_to_vec();
841
842        Ok(Self {
843            ratchet_key,
844            timestamp: original_timestamp,
845            device_id: original_sender_device_id,
846            serialized: serialized.into_boxed_slice(),
847        })
848    }
849
850    #[inline]
851    pub fn timestamp(&self) -> Timestamp {
852        self.timestamp
853    }
854
855    #[inline]
856    pub fn ratchet_key(&self) -> Option<&PublicKey> {
857        self.ratchet_key.as_ref()
858    }
859
860    #[inline]
861    pub fn device_id(&self) -> u32 {
862        self.device_id
863    }
864
865    #[inline]
866    pub fn serialized(&self) -> &[u8] {
867        &self.serialized
868    }
869}
870
871impl TryFrom<&[u8]> for DecryptionErrorMessage {
872    type Error = SignalProtocolError;
873
874    fn try_from(value: &[u8]) -> Result<Self> {
875        let proto_structure = proto::service::DecryptionErrorMessage::decode(value)
876            .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
877        let timestamp = proto_structure
878            .timestamp
879            .map(Timestamp::from_epoch_millis)
880            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
881        let ratchet_key = proto_structure
882            .ratchet_key
883            .map(|k| PublicKey::deserialize(&k))
884            .transpose()?;
885        let device_id = proto_structure.device_id.unwrap_or_default();
886        Ok(Self {
887            timestamp,
888            ratchet_key,
889            device_id,
890            serialized: Box::from(value),
891        })
892    }
893}
894
895/// For testing
896pub fn extract_decryption_error_message_from_serialized_content(
897    bytes: &[u8],
898) -> Result<DecryptionErrorMessage> {
899    if bytes.last() != Some(&PlaintextContent::PADDING_BOUNDARY_BYTE) {
900        return Err(SignalProtocolError::InvalidProtobufEncoding);
901    }
902    let content = proto::service::Content::decode(bytes.split_last().expect("checked above").1)
903        .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
904    content
905        .decryption_error_message
906        .as_deref()
907        .ok_or_else(|| {
908            SignalProtocolError::InvalidArgument(
909                "Content does not contain DecryptionErrorMessage".to_owned(),
910            )
911        })
912        .and_then(DecryptionErrorMessage::try_from)
913}
914
#[cfg(test)]
mod tests {
    use rand::rngs::OsRng;
    use rand::{CryptoRng, Rng, TryRngCore as _};

    use super::*;
    use crate::KeyPair;

    /// Builds a current-version `SignalMessage` from a random 32-byte MAC key, a random 20-byte
    /// ciphertext and freshly generated ratchet/identity key pairs (empty `pq_ratchet`).
    fn create_signal_message<T>(csprng: &mut T) -> Result<SignalMessage>
    where
        T: Rng + CryptoRng,
    {
        let mut mac_key = [0u8; 32];
        csprng.fill_bytes(&mut mac_key);
        let mac_key = mac_key;

        let mut ciphertext = [0u8; 20];
        csprng.fill_bytes(&mut ciphertext);
        let ciphertext = ciphertext;

        let sender_ratchet_key_pair = KeyPair::generate(csprng);
        let sender_identity_key_pair = KeyPair::generate(csprng);
        let receiver_identity_key_pair = KeyPair::generate(csprng);

        SignalMessage::new(
            CIPHERTEXT_MESSAGE_CURRENT_VERSION,
            &mac_key,
            sender_ratchet_key_pair.public_key,
            42,
            41,
            &ciphertext,
            &sender_identity_key_pair.public_key.into(),
            &receiver_identity_key_pair.public_key.into(),
            b"", // pq_ratchet
        )
    }

    /// Asserts that two `SignalMessage`s agree on version, ratchet key, counters, ciphertext and
    /// serialized bytes.
    fn assert_signal_message_equals(m1: &SignalMessage, m2: &SignalMessage) {
        assert_eq!(m1.message_version, m2.message_version);
        assert_eq!(m1.sender_ratchet_key, m2.sender_ratchet_key);
        assert_eq!(m1.counter, m2.counter);
        assert_eq!(m1.previous_counter, m2.previous_counter);
        assert_eq!(m1.ciphertext, m2.ciphertext);
        assert_eq!(m1.serialized, m2.serialized);
    }

    /// A `SignalMessage` survives a serialize/deserialize round trip field-for-field.
    #[test]
    fn test_signal_message_serialize_deserialize() -> Result<()> {
        let mut csprng = OsRng.unwrap_err();
        let message = create_signal_message(&mut csprng)?;
        let deser_message =
            SignalMessage::try_from(message.as_ref()).expect("should deserialize without error");
        assert_signal_message_equals(&message, &deser_message);
        Ok(())
    }

    /// A pre-Kyber `PreKeySignalMessage` survives a serialize/deserialize round trip,
    /// including its embedded `SignalMessage`.
    #[test]
    fn test_pre_key_signal_message_serialize_deserialize() -> Result<()> {
        let mut csprng = OsRng.unwrap_err();
        let identity_key_pair = KeyPair::generate(&mut csprng);
        let base_key_pair = KeyPair::generate(&mut csprng);
        let message = create_signal_message(&mut csprng)?;
        let pre_key_signal_message = PreKeySignalMessage::new(
            CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION,
            365,
            None,
            97.into(),
            None, // TODO: add kyber prekeys
            base_key_pair.public_key,
            identity_key_pair.public_key.into(),
            message,
        )?;
        let deser_pre_key_signal_message =
            PreKeySignalMessage::try_from(pre_key_signal_message.as_ref())
                .expect("should deserialize without error");
        assert_eq!(
            pre_key_signal_message.message_version,
            deser_pre_key_signal_message.message_version
        );
        assert_eq!(
            pre_key_signal_message.registration_id,
            deser_pre_key_signal_message.registration_id
        );
        assert_eq!(
            pre_key_signal_message.pre_key_id,
            deser_pre_key_signal_message.pre_key_id
        );
        assert_eq!(
            pre_key_signal_message.signed_pre_key_id,
            deser_pre_key_signal_message.signed_pre_key_id
        );
        assert_eq!(
            pre_key_signal_message.base_key,
            deser_pre_key_signal_message.base_key
        );
        assert_eq!(
            pre_key_signal_message.identity_key.public_key(),
            deser_pre_key_signal_message.identity_key.public_key()
        );
        assert_signal_message_equals(
            &pre_key_signal_message.message,
            &deser_pre_key_signal_message.message,
        );
        assert_eq!(
            pre_key_signal_message.serialized,
            deser_pre_key_signal_message.serialized
        );
        Ok(())
    }

    /// A `SenderKeyMessage` survives a serialize/deserialize round trip, including the
    /// distribution id it was constructed with.
    #[test]
    fn test_sender_key_message_serialize_deserialize() -> Result<()> {
        let mut csprng = OsRng.unwrap_err();
        let signature_key_pair = KeyPair::generate(&mut csprng);
        let sender_key_message = SenderKeyMessage::new(
            SENDERKEY_MESSAGE_CURRENT_VERSION,
            Uuid::from_u128(0xd1d1d1d1_7000_11eb_b32a_33b8a8a487a6),
            42,
            7,
            [1u8, 2, 3].into(),
            &mut csprng,
            &signature_key_pair.private_key,
        )?;
        let deser_sender_key_message = SenderKeyMessage::try_from(sender_key_message.as_ref())
            .expect("should deserialize without error");
        assert_eq!(
            sender_key_message.message_version,
            deser_sender_key_message.message_version
        );
        // The distribution id is a constructor input; without this check a deserializer that
        // lost or garbled it would go unnoticed.
        assert_eq!(
            sender_key_message.distribution_id,
            deser_sender_key_message.distribution_id
        );
        assert_eq!(
            sender_key_message.chain_id,
            deser_sender_key_message.chain_id
        );
        assert_eq!(
            sender_key_message.iteration,
            deser_sender_key_message.iteration
        );
        assert_eq!(
            sender_key_message.ciphertext,
            deser_sender_key_message.ciphertext
        );
        assert_eq!(
            sender_key_message.serialized,
            deser_sender_key_message.serialized
        );
        Ok(())
    }

    /// `DecryptionErrorMessage::for_original` records the ratchet key (when the original message
    /// type has one), timestamp and device id, and all three survive a serialize/deserialize
    /// round trip for Whisper, PreKey and SenderKey originals.
    #[test]
    fn test_decryption_error_message() -> Result<()> {
        let mut csprng = OsRng.unwrap_err();
        let identity_key_pair = KeyPair::generate(&mut csprng);
        let base_key_pair = KeyPair::generate(&mut csprng);
        let message = create_signal_message(&mut csprng)?;
        // The timestamp needs more than 32 bits and the device id has its high bit set, so a
        // truncating or sign-mangling round trip would fail the equality checks below.
        let timestamp: Timestamp = Timestamp::from_epoch_millis(0x2_0000_0001);
        let device_id = 0x8086_2021;

        // Builds the error message for `original` and immediately round-trips it through its
        // serialized form, so every assertion below inspects the deserialized copy.
        let round_trip = |original: &[u8],
                          original_type: CiphertextMessageType|
         -> Result<DecryptionErrorMessage> {
            let error_message = DecryptionErrorMessage::for_original(
                original,
                original_type,
                timestamp,
                device_id,
            )?;
            DecryptionErrorMessage::try_from(error_message.serialized())
        };

        {
            let error_message = round_trip(message.serialized(), CiphertextMessageType::Whisper)?;
            assert_eq!(
                error_message.ratchet_key(),
                Some(message.sender_ratchet_key())
            );
            assert_eq!(error_message.timestamp(), timestamp);
            assert_eq!(error_message.device_id(), device_id);
        }

        let pre_key_signal_message = PreKeySignalMessage::new(
            CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION,
            365,
            None,
            97.into(),
            None, // TODO: add kyber prekeys
            base_key_pair.public_key,
            identity_key_pair.public_key.into(),
            message,
        )?;

        {
            let error_message = round_trip(
                pre_key_signal_message.serialized(),
                CiphertextMessageType::PreKey,
            )?;
            // The ratchet key comes from the SignalMessage embedded in the pre-key message.
            assert_eq!(
                error_message.ratchet_key(),
                Some(pre_key_signal_message.message().sender_ratchet_key())
            );
            assert_eq!(error_message.timestamp(), timestamp);
            assert_eq!(error_message.device_id(), device_id);
        }

        let sender_key_message = SenderKeyMessage::new(
            SENDERKEY_MESSAGE_CURRENT_VERSION,
            Uuid::nil(),
            1,
            2,
            Box::from(b"test".to_owned()),
            &mut csprng,
            &base_key_pair.private_key,
        )?;

        {
            let error_message = round_trip(
                sender_key_message.serialized(),
                CiphertextMessageType::SenderKey,
            )?;
            // A SenderKey original is expected to yield no ratchet key.
            assert_eq!(error_message.ratchet_key(), None);
            assert_eq!(error_message.timestamp(), timestamp);
            assert_eq!(error_message.device_id(), device_id);
        }

        Ok(())
    }

    /// Requesting a decryption error for a Plaintext original is rejected as an invalid argument.
    #[test]
    fn test_decryption_error_message_for_plaintext() {
        assert!(matches!(
            DecryptionErrorMessage::for_original(
                &[],
                CiphertextMessageType::Plaintext,
                Timestamp::from_epoch_millis(5),
                7
            ),
            Err(SignalProtocolError::InvalidArgument(_))
        ));
    }
}