libsignal_protocol/
protocol.rs

1//
2// Copyright 2020-2021 Signal Messenger, LLC.
3// SPDX-License-Identifier: AGPL-3.0-only
4//
5
6use hmac::{Hmac, Mac};
7use prost::Message;
8use rand::{CryptoRng, Rng};
9use sha2::Sha256;
10use subtle::ConstantTimeEq;
11use uuid::Uuid;
12
13use crate::state::{KyberPreKeyId, PreKeyId, SignedPreKeyId};
14use crate::{
15    kem, proto, IdentityKey, PrivateKey, PublicKey, Result, SignalProtocolError, Timestamp,
16};
17
/// Highest Signal/PreKey ciphertext version accepted when parsing.
///
/// Also written into the low nibble of the version byte of every newly serialized
/// `SignalMessage` and `PreKeySignalMessage`.
pub(crate) const CIPHERTEXT_MESSAGE_CURRENT_VERSION: u8 = 4;
/// Lowest Signal/PreKey ciphertext version still accepted: the backward-compatible format
/// that predates Kyber. PreKey messages at or below this version may omit the Kyber payload.
pub(crate) const CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION: u8 = 3;
/// The only sender-key message version accepted when parsing; also written into the low
/// nibble of the version byte of newly serialized sender-key messages.
pub(crate) const SENDERKEY_MESSAGE_CURRENT_VERSION: u8 = 3;
22
/// One of the message kinds this module can produce, wrapped for uniform handling.
///
/// Use [`CiphertextMessage::message_type`] to get the numeric type tag of the wrapped value
/// and [`CiphertextMessage::serialize`] to borrow its serialized bytes.
#[derive(Debug)]
pub enum CiphertextMessage {
    /// A [`SignalMessage`]; tagged [`CiphertextMessageType::Whisper`].
    SignalMessage(SignalMessage),
    /// A [`PreKeySignalMessage`]; tagged [`CiphertextMessageType::PreKey`].
    PreKeySignalMessage(PreKeySignalMessage),
    /// A [`SenderKeyMessage`]; tagged [`CiphertextMessageType::SenderKey`].
    SenderKeyMessage(SenderKeyMessage),
    /// A [`PlaintextContent`]; tagged [`CiphertextMessageType::Plaintext`].
    PlaintextContent(PlaintextContent),
}
30
/// Numeric type tag for each [`CiphertextMessage`] variant.
///
/// The discriminants are explicit and the enum is `#[repr(u8)]`. The derived
/// `num_enum::TryFromPrimitive` converts a raw `u8` back into a variant.
/// NOTE(review): these values are presumably shared with other clients. Do not renumber
/// them without confirming.
#[derive(Copy, Clone, Eq, PartialEq, Debug, num_enum::TryFromPrimitive)]
#[repr(u8)]
pub enum CiphertextMessageType {
    /// Tag for [`CiphertextMessage::SignalMessage`].
    Whisper = 2,
    /// Tag for [`CiphertextMessage::PreKeySignalMessage`].
    PreKey = 3,
    /// Tag for [`CiphertextMessage::SenderKeyMessage`].
    SenderKey = 7,
    /// Tag for [`CiphertextMessage::PlaintextContent`].
    Plaintext = 8,
}
39
40impl CiphertextMessage {
41    pub fn message_type(&self) -> CiphertextMessageType {
42        match self {
43            CiphertextMessage::SignalMessage(_) => CiphertextMessageType::Whisper,
44            CiphertextMessage::PreKeySignalMessage(_) => CiphertextMessageType::PreKey,
45            CiphertextMessage::SenderKeyMessage(_) => CiphertextMessageType::SenderKey,
46            CiphertextMessage::PlaintextContent(_) => CiphertextMessageType::Plaintext,
47        }
48    }
49
50    pub fn serialize(&self) -> &[u8] {
51        match self {
52            CiphertextMessage::SignalMessage(x) => x.serialized(),
53            CiphertextMessage::PreKeySignalMessage(x) => x.serialized(),
54            CiphertextMessage::SenderKeyMessage(x) => x.serialized(),
55            CiphertextMessage::PlaintextContent(x) => x.serialized(),
56        }
57    }
58}
59
/// A session message: ciphertext plus the sender's ratchet key and counters.
///
/// The parsed fields are kept together with the exact serialized bytes. Those bytes are the
/// version byte, the protobuf body, and a truncated MAC.
#[derive(Debug, Clone)]
pub struct SignalMessage {
    /// High nibble of the version byte (the value given to `new`, or parsed from the wire).
    message_version: u8,
    /// The sender's ratchet public key.
    sender_ratchet_key: PublicKey,
    /// The protobuf `counter` field.
    counter: u32,
    // No accessor reads this outside of tests, hence the dead_code allowance; it is kept so
    // that the parsed representation is complete (0 when absent on the wire).
    #[allow(dead_code)]
    previous_counter: u32,
    /// The encrypted payload (protobuf `ciphertext` field).
    ciphertext: Box<[u8]>,
    /// Full wire form: version byte, protobuf body, then `MAC_LENGTH` MAC bytes.
    serialized: Box<[u8]>,
}
70
71impl SignalMessage {
72    const MAC_LENGTH: usize = 8;
73
74    pub fn new(
75        message_version: u8,
76        mac_key: &[u8],
77        sender_ratchet_key: PublicKey,
78        counter: u32,
79        previous_counter: u32,
80        ciphertext: &[u8],
81        sender_identity_key: &IdentityKey,
82        receiver_identity_key: &IdentityKey,
83    ) -> Result<Self> {
84        let message = proto::wire::SignalMessage {
85            ratchet_key: Some(sender_ratchet_key.serialize().into_vec()),
86            counter: Some(counter),
87            previous_counter: Some(previous_counter),
88            ciphertext: Some(Vec::<u8>::from(ciphertext)),
89        };
90        let mut serialized = Vec::with_capacity(1 + message.encoded_len() + Self::MAC_LENGTH);
91        serialized.push(((message_version & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION);
92        message
93            .encode(&mut serialized)
94            .expect("can always append to a buffer");
95        let mac = Self::compute_mac(
96            sender_identity_key,
97            receiver_identity_key,
98            mac_key,
99            &serialized,
100        )?;
101        serialized.extend_from_slice(&mac);
102        let serialized = serialized.into_boxed_slice();
103        Ok(Self {
104            message_version,
105            sender_ratchet_key,
106            counter,
107            previous_counter,
108            ciphertext: ciphertext.into(),
109            serialized,
110        })
111    }
112
113    #[inline]
114    pub fn message_version(&self) -> u8 {
115        self.message_version
116    }
117
118    #[inline]
119    pub fn sender_ratchet_key(&self) -> &PublicKey {
120        &self.sender_ratchet_key
121    }
122
123    #[inline]
124    pub fn counter(&self) -> u32 {
125        self.counter
126    }
127
128    #[inline]
129    pub fn serialized(&self) -> &[u8] {
130        &self.serialized
131    }
132
133    #[inline]
134    pub fn body(&self) -> &[u8] {
135        &self.ciphertext
136    }
137
138    pub fn verify_mac(
139        &self,
140        sender_identity_key: &IdentityKey,
141        receiver_identity_key: &IdentityKey,
142        mac_key: &[u8],
143    ) -> Result<bool> {
144        let our_mac = &Self::compute_mac(
145            sender_identity_key,
146            receiver_identity_key,
147            mac_key,
148            &self.serialized[..self.serialized.len() - Self::MAC_LENGTH],
149        )?;
150        let their_mac = &self.serialized[self.serialized.len() - Self::MAC_LENGTH..];
151        let result: bool = our_mac.ct_eq(their_mac).into();
152        if !result {
153            // A warning instead of an error because we try multiple sessions.
154            log::warn!(
155                "Bad Mac! Their Mac: {} Our Mac: {}",
156                hex::encode(their_mac),
157                hex::encode(our_mac)
158            );
159        }
160        Ok(result)
161    }
162
163    fn compute_mac(
164        sender_identity_key: &IdentityKey,
165        receiver_identity_key: &IdentityKey,
166        mac_key: &[u8],
167        message: &[u8],
168    ) -> Result<[u8; Self::MAC_LENGTH]> {
169        if mac_key.len() != 32 {
170            return Err(SignalProtocolError::InvalidMacKeyLength(mac_key.len()));
171        }
172        let mut mac = Hmac::<Sha256>::new_from_slice(mac_key)
173            .expect("HMAC-SHA256 should accept any size key");
174
175        mac.update(sender_identity_key.public_key().serialize().as_ref());
176        mac.update(receiver_identity_key.public_key().serialize().as_ref());
177        mac.update(message);
178        let mut result = [0u8; Self::MAC_LENGTH];
179        result.copy_from_slice(&mac.finalize().into_bytes()[..Self::MAC_LENGTH]);
180        Ok(result)
181    }
182}
183
184impl AsRef<[u8]> for SignalMessage {
185    fn as_ref(&self) -> &[u8] {
186        &self.serialized
187    }
188}
189
190impl TryFrom<&[u8]> for SignalMessage {
191    type Error = SignalProtocolError;
192
193    fn try_from(value: &[u8]) -> Result<Self> {
194        if value.len() < SignalMessage::MAC_LENGTH + 1 {
195            return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
196        }
197        let message_version = value[0] >> 4;
198        if message_version < CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION {
199            return Err(SignalProtocolError::LegacyCiphertextVersion(
200                message_version,
201            ));
202        }
203        if message_version > CIPHERTEXT_MESSAGE_CURRENT_VERSION {
204            return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
205                message_version,
206            ));
207        }
208
209        let proto_structure =
210            proto::wire::SignalMessage::decode(&value[1..value.len() - SignalMessage::MAC_LENGTH])
211                .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
212
213        let sender_ratchet_key = proto_structure
214            .ratchet_key
215            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
216        let sender_ratchet_key = PublicKey::deserialize(&sender_ratchet_key)?;
217        let counter = proto_structure
218            .counter
219            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
220        let previous_counter = proto_structure.previous_counter.unwrap_or(0);
221        let ciphertext = proto_structure
222            .ciphertext
223            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?
224            .into_boxed_slice();
225
226        Ok(SignalMessage {
227            message_version,
228            sender_ratchet_key,
229            counter,
230            previous_counter,
231            ciphertext,
232            serialized: Box::from(value),
233        })
234    }
235}
236
/// The Kyber part of a [`PreKeySignalMessage`]: which Kyber pre-key was used, plus the
/// serialized KEM ciphertext.
#[derive(Debug, Clone)]
pub struct KyberPayload {
    /// Identifier of the Kyber pre-key.
    pre_key_id: KyberPreKeyId,
    /// Serialized KEM ciphertext.
    ciphertext: kem::SerializedCiphertext,
}
242
243impl KyberPayload {
244    pub fn new(id: KyberPreKeyId, ciphertext: kem::SerializedCiphertext) -> Self {
245        Self {
246            pre_key_id: id,
247            ciphertext,
248        }
249    }
250}
251
/// A [`SignalMessage`] wrapped with the pre-key information needed to set up a session.
///
/// It carries no MAC or signature of its own. The wire form is a version byte followed by a
/// protobuf body that embeds the inner message's serialized bytes.
#[derive(Debug, Clone)]
pub struct PreKeySignalMessage {
    /// High nibble of the version byte.
    message_version: u8,
    /// Registration id (0 when absent on the wire).
    registration_id: u32,
    /// One-time pre-key id, if one was used.
    pre_key_id: Option<PreKeyId>,
    /// Signed pre-key id (required on the wire).
    signed_pre_key_id: SignedPreKeyId,
    /// Kyber id + ciphertext. When parsing, only versions at or below
    /// `CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION` may omit it.
    kyber_payload: Option<KyberPayload>,
    /// Base public key.
    base_key: PublicKey,
    /// Sender identity key.
    identity_key: IdentityKey,
    /// The embedded inner message.
    message: SignalMessage,
    /// Full wire form: version byte followed by the protobuf body.
    serialized: Box<[u8]>,
}
264
265impl PreKeySignalMessage {
266    #[allow(clippy::too_many_arguments)]
267    pub fn new(
268        message_version: u8,
269        registration_id: u32,
270        pre_key_id: Option<PreKeyId>,
271        signed_pre_key_id: SignedPreKeyId,
272        kyber_payload: Option<KyberPayload>,
273        base_key: PublicKey,
274        identity_key: IdentityKey,
275        message: SignalMessage,
276    ) -> Result<Self> {
277        let proto_message = proto::wire::PreKeySignalMessage {
278            registration_id: Some(registration_id),
279            pre_key_id: pre_key_id.map(|id| id.into()),
280            signed_pre_key_id: Some(signed_pre_key_id.into()),
281            kyber_pre_key_id: kyber_payload.as_ref().map(|kyber| kyber.pre_key_id.into()),
282            kyber_ciphertext: kyber_payload
283                .as_ref()
284                .map(|kyber| kyber.ciphertext.to_vec()),
285            base_key: Some(base_key.serialize().into_vec()),
286            identity_key: Some(identity_key.serialize().into_vec()),
287            message: Some(Vec::from(message.as_ref())),
288        };
289        let mut serialized = Vec::with_capacity(1 + proto_message.encoded_len());
290        serialized.push(((message_version & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION);
291        proto_message
292            .encode(&mut serialized)
293            .expect("can always append to a Vec");
294        Ok(Self {
295            message_version,
296            registration_id,
297            pre_key_id,
298            signed_pre_key_id,
299            kyber_payload,
300            base_key,
301            identity_key,
302            message,
303            serialized: serialized.into_boxed_slice(),
304        })
305    }
306
307    #[inline]
308    pub fn message_version(&self) -> u8 {
309        self.message_version
310    }
311
312    #[inline]
313    pub fn registration_id(&self) -> u32 {
314        self.registration_id
315    }
316
317    #[inline]
318    pub fn pre_key_id(&self) -> Option<PreKeyId> {
319        self.pre_key_id
320    }
321
322    #[inline]
323    pub fn signed_pre_key_id(&self) -> SignedPreKeyId {
324        self.signed_pre_key_id
325    }
326
327    #[inline]
328    pub fn kyber_pre_key_id(&self) -> Option<KyberPreKeyId> {
329        self.kyber_payload.as_ref().map(|kyber| kyber.pre_key_id)
330    }
331
332    #[inline]
333    pub fn kyber_ciphertext(&self) -> Option<&kem::SerializedCiphertext> {
334        self.kyber_payload.as_ref().map(|kyber| &kyber.ciphertext)
335    }
336
337    #[inline]
338    pub fn base_key(&self) -> &PublicKey {
339        &self.base_key
340    }
341
342    #[inline]
343    pub fn identity_key(&self) -> &IdentityKey {
344        &self.identity_key
345    }
346
347    #[inline]
348    pub fn message(&self) -> &SignalMessage {
349        &self.message
350    }
351
352    #[inline]
353    pub fn serialized(&self) -> &[u8] {
354        &self.serialized
355    }
356}
357
358impl AsRef<[u8]> for PreKeySignalMessage {
359    fn as_ref(&self) -> &[u8] {
360        &self.serialized
361    }
362}
363
364impl TryFrom<&[u8]> for PreKeySignalMessage {
365    type Error = SignalProtocolError;
366
367    fn try_from(value: &[u8]) -> Result<Self> {
368        if value.is_empty() {
369            return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
370        }
371
372        let message_version = value[0] >> 4;
373        if message_version < CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION {
374            return Err(SignalProtocolError::LegacyCiphertextVersion(
375                message_version,
376            ));
377        }
378        if message_version > CIPHERTEXT_MESSAGE_CURRENT_VERSION {
379            return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
380                message_version,
381            ));
382        }
383
384        let proto_structure = proto::wire::PreKeySignalMessage::decode(&value[1..])
385            .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
386
387        let base_key = proto_structure
388            .base_key
389            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
390        let identity_key = proto_structure
391            .identity_key
392            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
393        let message = proto_structure
394            .message
395            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
396        let signed_pre_key_id = proto_structure
397            .signed_pre_key_id
398            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
399
400        let base_key = PublicKey::deserialize(base_key.as_ref())?;
401
402        let kyber_payload = match (
403            proto_structure.kyber_pre_key_id,
404            proto_structure.kyber_ciphertext,
405        ) {
406            (Some(id), Some(ct)) => Some(KyberPayload::new(id.into(), ct.into_boxed_slice())),
407            (None, None) if message_version <= CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION => None,
408            (None, None) => {
409                return Err(SignalProtocolError::InvalidMessage(
410                    CiphertextMessageType::PreKey,
411                    "Kyber pre key must be present for this session version",
412                ));
413            }
414            _ => {
415                return Err(SignalProtocolError::InvalidMessage(
416                    CiphertextMessageType::PreKey,
417                    "Both or neither kyber pre_key_id and kyber_ciphertext can be present",
418                ));
419            }
420        };
421
422        Ok(PreKeySignalMessage {
423            message_version,
424            registration_id: proto_structure.registration_id.unwrap_or(0),
425            pre_key_id: proto_structure.pre_key_id.map(|id| id.into()),
426            signed_pre_key_id: signed_pre_key_id.into(),
427            kyber_payload,
428            base_key,
429            identity_key: IdentityKey::try_from(identity_key.as_ref())?,
430            message: SignalMessage::try_from(message.as_ref())?,
431            serialized: Box::from(value),
432        })
433    }
434}
435
/// A signed group (sender-key) message.
///
/// The wire form is a version byte, the protobuf body, and a trailing `SIGNATURE_LEN`-byte
/// signature over the preceding bytes.
#[derive(Debug, Clone)]
pub struct SenderKeyMessage {
    /// Message version reported by `message_version()`.
    message_version: u8,
    /// Distribution identifier (a UUID on the wire).
    distribution_id: Uuid,
    /// Chain identifier.
    chain_id: u32,
    /// Iteration number.
    iteration: u32,
    /// Encrypted payload.
    ciphertext: Box<[u8]>,
    /// Full wire form, including the trailing signature.
    serialized: Box<[u8]>,
}
445
446impl SenderKeyMessage {
447    const SIGNATURE_LEN: usize = 64;
448
449    pub fn new<R: CryptoRng + Rng>(
450        message_version: u8,
451        distribution_id: Uuid,
452        chain_id: u32,
453        iteration: u32,
454        ciphertext: Box<[u8]>,
455        csprng: &mut R,
456        signature_key: &PrivateKey,
457    ) -> Result<Self> {
458        let proto_message = proto::wire::SenderKeyMessage {
459            distribution_uuid: Some(distribution_id.as_bytes().to_vec()),
460            chain_id: Some(chain_id),
461            iteration: Some(iteration),
462            ciphertext: Some(ciphertext.to_vec()),
463        };
464        let proto_message_len = proto_message.encoded_len();
465        let mut serialized = Vec::with_capacity(1 + proto_message_len + Self::SIGNATURE_LEN);
466        serialized.push(((message_version & 0xF) << 4) | SENDERKEY_MESSAGE_CURRENT_VERSION);
467        proto_message
468            .encode(&mut serialized)
469            .expect("can always append to a buffer");
470        let signature = signature_key.calculate_signature(&serialized, csprng)?;
471        serialized.extend_from_slice(&signature[..]);
472        Ok(Self {
473            message_version: SENDERKEY_MESSAGE_CURRENT_VERSION,
474            distribution_id,
475            chain_id,
476            iteration,
477            ciphertext,
478            serialized: serialized.into_boxed_slice(),
479        })
480    }
481
482    pub fn verify_signature(&self, signature_key: &PublicKey) -> Result<bool> {
483        let valid = signature_key.verify_signature(
484            &self.serialized[..self.serialized.len() - Self::SIGNATURE_LEN],
485            &self.serialized[self.serialized.len() - Self::SIGNATURE_LEN..],
486        )?;
487
488        Ok(valid)
489    }
490
491    #[inline]
492    pub fn message_version(&self) -> u8 {
493        self.message_version
494    }
495
496    #[inline]
497    pub fn distribution_id(&self) -> Uuid {
498        self.distribution_id
499    }
500
501    #[inline]
502    pub fn chain_id(&self) -> u32 {
503        self.chain_id
504    }
505
506    #[inline]
507    pub fn iteration(&self) -> u32 {
508        self.iteration
509    }
510
511    #[inline]
512    pub fn ciphertext(&self) -> &[u8] {
513        &self.ciphertext
514    }
515
516    #[inline]
517    pub fn serialized(&self) -> &[u8] {
518        &self.serialized
519    }
520}
521
522impl AsRef<[u8]> for SenderKeyMessage {
523    fn as_ref(&self) -> &[u8] {
524        &self.serialized
525    }
526}
527
528impl TryFrom<&[u8]> for SenderKeyMessage {
529    type Error = SignalProtocolError;
530
531    fn try_from(value: &[u8]) -> Result<Self> {
532        if value.len() < 1 + Self::SIGNATURE_LEN {
533            return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
534        }
535        let message_version = value[0] >> 4;
536        if message_version < SENDERKEY_MESSAGE_CURRENT_VERSION {
537            return Err(SignalProtocolError::LegacyCiphertextVersion(
538                message_version,
539            ));
540        }
541        if message_version > SENDERKEY_MESSAGE_CURRENT_VERSION {
542            return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
543                message_version,
544            ));
545        }
546        let proto_structure =
547            proto::wire::SenderKeyMessage::decode(&value[1..value.len() - Self::SIGNATURE_LEN])
548                .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
549
550        let distribution_id = proto_structure
551            .distribution_uuid
552            .and_then(|bytes| Uuid::from_slice(bytes.as_slice()).ok())
553            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
554        let chain_id = proto_structure
555            .chain_id
556            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
557        let iteration = proto_structure
558            .iteration
559            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
560        let ciphertext = proto_structure
561            .ciphertext
562            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?
563            .into_boxed_slice();
564
565        Ok(SenderKeyMessage {
566            message_version,
567            distribution_id,
568            chain_id,
569            iteration,
570            ciphertext,
571            serialized: Box::from(value),
572        })
573    }
574}
575
/// The message that hands out a sender-key chain.
///
/// The wire form is a version byte followed by the protobuf body; no signature is appended.
#[derive(Debug, Clone)]
pub struct SenderKeyDistributionMessage {
    /// High nibble of the version byte.
    message_version: u8,
    /// Distribution identifier (a UUID on the wire).
    distribution_id: Uuid,
    /// Chain identifier.
    chain_id: u32,
    /// Iteration number.
    iteration: u32,
    /// Chain key bytes. Parsing requires exactly 32 bytes; `new` does not check the length.
    chain_key: Vec<u8>,
    /// Signing public key. Parsing requires a 33-byte encoding.
    signing_key: PublicKey,
    /// Full wire form.
    serialized: Box<[u8]>,
}
586
587impl SenderKeyDistributionMessage {
588    pub fn new(
589        message_version: u8,
590        distribution_id: Uuid,
591        chain_id: u32,
592        iteration: u32,
593        chain_key: Vec<u8>,
594        signing_key: PublicKey,
595    ) -> Result<Self> {
596        let proto_message = proto::wire::SenderKeyDistributionMessage {
597            distribution_uuid: Some(distribution_id.as_bytes().to_vec()),
598            chain_id: Some(chain_id),
599            iteration: Some(iteration),
600            chain_key: Some(chain_key.clone()),
601            signing_key: Some(signing_key.serialize().to_vec()),
602        };
603        let mut serialized = Vec::with_capacity(1 + proto_message.encoded_len());
604        serialized.push(((message_version & 0xF) << 4) | SENDERKEY_MESSAGE_CURRENT_VERSION);
605        proto_message
606            .encode(&mut serialized)
607            .expect("can always append to a buffer");
608
609        Ok(Self {
610            message_version,
611            distribution_id,
612            chain_id,
613            iteration,
614            chain_key,
615            signing_key,
616            serialized: serialized.into_boxed_slice(),
617        })
618    }
619
620    #[inline]
621    pub fn message_version(&self) -> u8 {
622        self.message_version
623    }
624
625    #[inline]
626    pub fn distribution_id(&self) -> Result<Uuid> {
627        Ok(self.distribution_id)
628    }
629
630    #[inline]
631    pub fn chain_id(&self) -> Result<u32> {
632        Ok(self.chain_id)
633    }
634
635    #[inline]
636    pub fn iteration(&self) -> Result<u32> {
637        Ok(self.iteration)
638    }
639
640    #[inline]
641    pub fn chain_key(&self) -> Result<&[u8]> {
642        Ok(&self.chain_key)
643    }
644
645    #[inline]
646    pub fn signing_key(&self) -> Result<&PublicKey> {
647        Ok(&self.signing_key)
648    }
649
650    #[inline]
651    pub fn serialized(&self) -> &[u8] {
652        &self.serialized
653    }
654}
655
656impl AsRef<[u8]> for SenderKeyDistributionMessage {
657    fn as_ref(&self) -> &[u8] {
658        &self.serialized
659    }
660}
661
662impl TryFrom<&[u8]> for SenderKeyDistributionMessage {
663    type Error = SignalProtocolError;
664
665    fn try_from(value: &[u8]) -> Result<Self> {
666        // The message contains at least a X25519 key and a chain key
667        if value.len() < 1 + 32 + 32 {
668            return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
669        }
670
671        let message_version = value[0] >> 4;
672
673        if message_version < SENDERKEY_MESSAGE_CURRENT_VERSION {
674            return Err(SignalProtocolError::LegacyCiphertextVersion(
675                message_version,
676            ));
677        }
678        if message_version > SENDERKEY_MESSAGE_CURRENT_VERSION {
679            return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
680                message_version,
681            ));
682        }
683
684        let proto_structure = proto::wire::SenderKeyDistributionMessage::decode(&value[1..])
685            .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
686
687        let distribution_id = proto_structure
688            .distribution_uuid
689            .and_then(|bytes| Uuid::from_slice(bytes.as_slice()).ok())
690            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
691        let chain_id = proto_structure
692            .chain_id
693            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
694        let iteration = proto_structure
695            .iteration
696            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
697        let chain_key = proto_structure
698            .chain_key
699            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
700        let signing_key = proto_structure
701            .signing_key
702            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
703
704        if chain_key.len() != 32 || signing_key.len() != 33 {
705            return Err(SignalProtocolError::InvalidProtobufEncoding);
706        }
707
708        let signing_key = PublicKey::deserialize(&signing_key)?;
709
710        Ok(SenderKeyDistributionMessage {
711            message_version,
712            distribution_id,
713            chain_id,
714            iteration,
715            chain_key,
716            signing_key,
717            serialized: Box::from(value),
718        })
719    }
720}
721
/// Unencrypted content that is explicitly allowed to be sent as plaintext.
#[derive(Debug, Clone)]
pub struct PlaintextContent {
    /// Wire form: `PLAINTEXT_CONTEXT_IDENTIFIER_BYTE` followed by the body. Never empty.
    serialized: Box<[u8]>,
}
726
727impl PlaintextContent {
728    /// Identifies a serialized PlaintextContent.
729    ///
730    /// This ensures someone doesn't try to serialize an arbitrary Content message as
731    /// PlaintextContent; only messages that are okay to send as plaintext should be allowed.
732    const PLAINTEXT_CONTEXT_IDENTIFIER_BYTE: u8 = 0xC0;
733
734    /// Marks the end of a message and the start of any padding.
735    ///
736    /// Usually messages are padded to avoid exposing patterns,
737    /// but PlaintextContent messages are all fixed-length anyway, so there won't be any padding.
738    const PADDING_BOUNDARY_BYTE: u8 = 0x80;
739
740    #[inline]
741    pub fn body(&self) -> &[u8] {
742        &self.serialized[1..]
743    }
744
745    #[inline]
746    pub fn serialized(&self) -> &[u8] {
747        &self.serialized
748    }
749}
750
751impl From<DecryptionErrorMessage> for PlaintextContent {
752    fn from(message: DecryptionErrorMessage) -> Self {
753        let proto_structure = proto::service::Content {
754            decryption_error_message: Some(message.serialized().to_vec()),
755            ..Default::default()
756        };
757        let mut serialized = vec![Self::PLAINTEXT_CONTEXT_IDENTIFIER_BYTE];
758        proto_structure
759            .encode(&mut serialized)
760            .expect("can always encode to a Vec");
761        serialized.push(Self::PADDING_BOUNDARY_BYTE);
762        Self {
763            serialized: Box::from(serialized),
764        }
765    }
766}
767
768impl TryFrom<&[u8]> for PlaintextContent {
769    type Error = SignalProtocolError;
770
771    fn try_from(value: &[u8]) -> Result<Self> {
772        if value.is_empty() {
773            return Err(SignalProtocolError::CiphertextMessageTooShort(0));
774        }
775        if value[0] != Self::PLAINTEXT_CONTEXT_IDENTIFIER_BYTE {
776            return Err(SignalProtocolError::UnrecognizedMessageVersion(
777                value[0] as u32,
778            ));
779        }
780        Ok(Self {
781            serialized: Box::from(value),
782        })
783    }
784}
785
/// A report about an original message, identified by its timestamp, the original sender's
/// device id and (when available) its ratchet key.
#[derive(Debug, Clone)]
pub struct DecryptionErrorMessage {
    /// Ratchet key of the original message. It is `None` for sender-key originals (see
    /// `for_original`) or when absent on the wire.
    ratchet_key: Option<PublicKey>,
    /// Timestamp of the original message.
    timestamp: Timestamp,
    /// Device id of the original sender (0 when absent on the wire).
    device_id: u32,
    /// Protobuf-encoded form (no version byte).
    serialized: Box<[u8]>,
}
793
794impl DecryptionErrorMessage {
795    pub fn for_original(
796        original_bytes: &[u8],
797        original_type: CiphertextMessageType,
798        original_timestamp: Timestamp,
799        original_sender_device_id: u32,
800    ) -> Result<Self> {
801        let ratchet_key = match original_type {
802            CiphertextMessageType::Whisper => {
803                Some(*SignalMessage::try_from(original_bytes)?.sender_ratchet_key())
804            }
805            CiphertextMessageType::PreKey => Some(
806                *PreKeySignalMessage::try_from(original_bytes)?
807                    .message()
808                    .sender_ratchet_key(),
809            ),
810            CiphertextMessageType::SenderKey => None,
811            CiphertextMessageType::Plaintext => {
812                return Err(SignalProtocolError::InvalidArgument(
813                    "cannot create a DecryptionErrorMessage for plaintext content; it is not encrypted".to_string()
814                ));
815            }
816        };
817
818        let proto_message = proto::service::DecryptionErrorMessage {
819            timestamp: Some(original_timestamp.epoch_millis()),
820            ratchet_key: ratchet_key.map(|k| k.serialize().into()),
821            device_id: Some(original_sender_device_id),
822        };
823        let serialized = proto_message.encode_to_vec();
824
825        Ok(Self {
826            ratchet_key,
827            timestamp: original_timestamp,
828            device_id: original_sender_device_id,
829            serialized: serialized.into_boxed_slice(),
830        })
831    }
832
833    #[inline]
834    pub fn timestamp(&self) -> Timestamp {
835        self.timestamp
836    }
837
838    #[inline]
839    pub fn ratchet_key(&self) -> Option<&PublicKey> {
840        self.ratchet_key.as_ref()
841    }
842
843    #[inline]
844    pub fn device_id(&self) -> u32 {
845        self.device_id
846    }
847
848    #[inline]
849    pub fn serialized(&self) -> &[u8] {
850        &self.serialized
851    }
852}
853
854impl TryFrom<&[u8]> for DecryptionErrorMessage {
855    type Error = SignalProtocolError;
856
857    fn try_from(value: &[u8]) -> Result<Self> {
858        let proto_structure = proto::service::DecryptionErrorMessage::decode(value)
859            .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
860        let timestamp = proto_structure
861            .timestamp
862            .map(Timestamp::from_epoch_millis)
863            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
864        let ratchet_key = proto_structure
865            .ratchet_key
866            .map(|k| PublicKey::deserialize(&k))
867            .transpose()?;
868        let device_id = proto_structure.device_id.unwrap_or_default();
869        Ok(Self {
870            timestamp,
871            ratchet_key,
872            device_id,
873            serialized: Box::from(value),
874        })
875    }
876}
877
878/// For testing
879pub fn extract_decryption_error_message_from_serialized_content(
880    bytes: &[u8],
881) -> Result<DecryptionErrorMessage> {
882    if bytes.last() != Some(&PlaintextContent::PADDING_BOUNDARY_BYTE) {
883        return Err(SignalProtocolError::InvalidProtobufEncoding);
884    }
885    let content = proto::service::Content::decode(bytes.split_last().expect("checked above").1)
886        .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
887    content
888        .decryption_error_message
889        .as_deref()
890        .ok_or_else(|| {
891            SignalProtocolError::InvalidArgument(
892                "Content does not contain DecryptionErrorMessage".to_owned(),
893            )
894        })
895        .and_then(DecryptionErrorMessage::try_from)
896}
897
#[cfg(test)]
mod tests {
    use rand::rngs::OsRng;
    use rand::{CryptoRng, Rng};

    use super::*;
    use crate::KeyPair;

    /// Builds a [`SignalMessage`] at `CIPHERTEXT_MESSAGE_CURRENT_VERSION` with a random 32-byte
    /// MAC key, 20 random ciphertext bytes, counter 42 / previous counter 41, and freshly
    /// generated ratchet and identity key pairs.
    fn create_signal_message<T>(csprng: &mut T) -> Result<SignalMessage>
    where
        T: Rng + CryptoRng,
    {
        let mut mac_key = [0u8; 32];
        csprng.fill_bytes(&mut mac_key);

        let mut ciphertext = [0u8; 20];
        csprng.fill_bytes(&mut ciphertext);

        let sender_ratchet_key_pair = KeyPair::generate(csprng);
        let sender_identity_key_pair = KeyPair::generate(csprng);
        let receiver_identity_key_pair = KeyPair::generate(csprng);

        SignalMessage::new(
            CIPHERTEXT_MESSAGE_CURRENT_VERSION,
            &mac_key,
            sender_ratchet_key_pair.public_key,
            42,
            41,
            &ciphertext,
            &sender_identity_key_pair.public_key.into(),
            &receiver_identity_key_pair.public_key.into(),
        )
    }

    /// Wraps `message` in a [`PreKeySignalMessage`] at `CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION`,
    /// using registration ID 365, signed pre-key ID 97, no one-time pre-key and no Kyber payload.
    fn create_pre_key_signal_message(
        base_key: PublicKey,
        identity_key: IdentityKey,
        message: SignalMessage,
    ) -> Result<PreKeySignalMessage> {
        PreKeySignalMessage::new(
            CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION,
            365,
            None,
            97.into(),
            None, // TODO: add kyber prekeys
            base_key,
            identity_key,
            message,
        )
    }

    /// Asserts that two [`SignalMessage`]s agree on every compared field, including the raw
    /// serialized bytes.
    fn assert_signal_message_equals(m1: &SignalMessage, m2: &SignalMessage) {
        assert_eq!(m1.message_version, m2.message_version);
        assert_eq!(m1.sender_ratchet_key, m2.sender_ratchet_key);
        assert_eq!(m1.counter, m2.counter);
        assert_eq!(m1.previous_counter, m2.previous_counter);
        assert_eq!(m1.ciphertext, m2.ciphertext);
        assert_eq!(m1.serialized, m2.serialized);
    }

    #[test]
    fn test_signal_message_serialize_deserialize() -> Result<()> {
        let mut csprng = OsRng;
        let message = create_signal_message(&mut csprng)?;
        let deser_message =
            SignalMessage::try_from(message.as_ref()).expect("should deserialize without error");
        assert_signal_message_equals(&message, &deser_message);
        Ok(())
    }

    #[test]
    fn test_pre_key_signal_message_serialize_deserialize() -> Result<()> {
        let mut csprng = OsRng;
        let identity_key_pair = KeyPair::generate(&mut csprng);
        let base_key_pair = KeyPair::generate(&mut csprng);
        let message = create_signal_message(&mut csprng)?;
        let pre_key_signal_message = create_pre_key_signal_message(
            base_key_pair.public_key,
            identity_key_pair.public_key.into(),
            message,
        )?;
        let deser_pre_key_signal_message =
            PreKeySignalMessage::try_from(pre_key_signal_message.as_ref())
                .expect("should deserialize without error");
        assert_eq!(
            pre_key_signal_message.message_version,
            deser_pre_key_signal_message.message_version
        );
        assert_eq!(
            pre_key_signal_message.registration_id,
            deser_pre_key_signal_message.registration_id
        );
        assert_eq!(
            pre_key_signal_message.pre_key_id,
            deser_pre_key_signal_message.pre_key_id
        );
        assert_eq!(
            pre_key_signal_message.signed_pre_key_id,
            deser_pre_key_signal_message.signed_pre_key_id
        );
        assert_eq!(
            pre_key_signal_message.base_key,
            deser_pre_key_signal_message.base_key
        );
        assert_eq!(
            pre_key_signal_message.identity_key.public_key(),
            deser_pre_key_signal_message.identity_key.public_key()
        );
        assert_signal_message_equals(
            &pre_key_signal_message.message,
            &deser_pre_key_signal_message.message,
        );
        assert_eq!(
            pre_key_signal_message.serialized,
            deser_pre_key_signal_message.serialized
        );
        Ok(())
    }

    #[test]
    fn test_sender_key_message_serialize_deserialize() -> Result<()> {
        let mut csprng = OsRng;
        let signature_key_pair = KeyPair::generate(&mut csprng);
        let sender_key_message = SenderKeyMessage::new(
            SENDERKEY_MESSAGE_CURRENT_VERSION,
            Uuid::from_u128(0xd1d1d1d1_7000_11eb_b32a_33b8a8a487a6),
            42,
            7,
            [1u8, 2, 3].into(),
            &mut csprng,
            &signature_key_pair.private_key,
        )?;
        let deser_sender_key_message = SenderKeyMessage::try_from(sender_key_message.as_ref())
            .expect("should deserialize without error");
        assert_eq!(
            sender_key_message.message_version,
            deser_sender_key_message.message_version
        );
        // The distribution ID is serialized too; make sure it survives the round trip.
        assert_eq!(
            sender_key_message.distribution_id,
            deser_sender_key_message.distribution_id
        );
        assert_eq!(
            sender_key_message.chain_id,
            deser_sender_key_message.chain_id
        );
        assert_eq!(
            sender_key_message.iteration,
            deser_sender_key_message.iteration
        );
        assert_eq!(
            sender_key_message.ciphertext,
            deser_sender_key_message.ciphertext
        );
        assert_eq!(
            sender_key_message.serialized,
            deser_sender_key_message.serialized
        );
        // The deserialized message must still carry a signature valid under the signing key.
        let signature_valid =
            deser_sender_key_message.verify_signature(&signature_key_pair.public_key)?;
        assert!(signature_valid, "signature should verify after a round trip");
        Ok(())
    }

    /// Deserializing an empty buffer must fail rather than produce a message.
    #[test]
    fn test_deserialize_rejects_empty_input() {
        let empty: &[u8] = &[];
        assert!(SignalMessage::try_from(empty).is_err());
        assert!(PreKeySignalMessage::try_from(empty).is_err());
        assert!(SenderKeyMessage::try_from(empty).is_err());
    }

    #[test]
    fn test_decryption_error_message() -> Result<()> {
        let mut csprng = OsRng;
        let identity_key_pair = KeyPair::generate(&mut csprng);
        let base_key_pair = KeyPair::generate(&mut csprng);
        let message = create_signal_message(&mut csprng)?;
        // Chosen to exceed 32 bits so truncation of either field would be caught.
        let timestamp: Timestamp = Timestamp::from_epoch_millis(0x2_0000_0001);
        let device_id = 0x8086_2021;

        // Whisper (SignalMessage): the ratchet key is carried over.
        {
            let error_message = DecryptionErrorMessage::for_original(
                message.serialized(),
                CiphertextMessageType::Whisper,
                timestamp,
                device_id,
            )?;
            let error_message = DecryptionErrorMessage::try_from(error_message.serialized())?;
            assert_eq!(
                error_message.ratchet_key(),
                Some(message.sender_ratchet_key())
            );
            assert_eq!(error_message.timestamp(), timestamp);
            assert_eq!(error_message.device_id(), device_id);
        }

        let pre_key_signal_message = create_pre_key_signal_message(
            base_key_pair.public_key,
            identity_key_pair.public_key.into(),
            message,
        )?;

        // PreKey: the ratchet key comes from the embedded SignalMessage.
        {
            let error_message = DecryptionErrorMessage::for_original(
                pre_key_signal_message.serialized(),
                CiphertextMessageType::PreKey,
                timestamp,
                device_id,
            )?;
            let error_message = DecryptionErrorMessage::try_from(error_message.serialized())?;
            assert_eq!(
                error_message.ratchet_key(),
                Some(pre_key_signal_message.message().sender_ratchet_key())
            );
            assert_eq!(error_message.timestamp(), timestamp);
            assert_eq!(error_message.device_id(), device_id);
        }

        let sender_key_message = SenderKeyMessage::new(
            SENDERKEY_MESSAGE_CURRENT_VERSION,
            Uuid::nil(),
            1,
            2,
            Box::from(b"test".to_owned()),
            &mut csprng,
            &base_key_pair.private_key,
        )?;

        // SenderKey: no ratchet key is recorded.
        {
            let error_message = DecryptionErrorMessage::for_original(
                sender_key_message.serialized(),
                CiphertextMessageType::SenderKey,
                timestamp,
                device_id,
            )?;
            let error_message = DecryptionErrorMessage::try_from(error_message.serialized())?;
            assert_eq!(error_message.ratchet_key(), None);
            assert_eq!(error_message.timestamp(), timestamp);
            assert_eq!(error_message.device_id(), device_id);
        }

        Ok(())
    }

    /// A decryption error cannot be built for plaintext content.
    #[test]
    fn test_decryption_error_message_for_plaintext() {
        assert!(matches!(
            DecryptionErrorMessage::for_original(
                &[],
                CiphertextMessageType::Plaintext,
                Timestamp::from_epoch_millis(5),
                7
            ),
            Err(SignalProtocolError::InvalidArgument(_))
        ));
    }
}