1use hmac::{Hmac, Mac};
7use prost::Message;
8use rand::{CryptoRng, Rng};
9use sha2::Sha256;
10use subtle::ConstantTimeEq;
11use uuid::Uuid;
12
13use crate::state::{KyberPreKeyId, PreKeyId, SignedPreKeyId};
14use crate::{
15 IdentityKey, PrivateKey, ProtocolAddress, PublicKey, Result, ServiceId, SignalProtocolError,
16 Timestamp, kem, proto,
17};
18
/// Version written into the low nibble of every serialized `SignalMessage` /
/// `PreKeySignalMessage` header byte, and the highest message version (high
/// nibble) accepted when parsing them.
pub(crate) const CIPHERTEXT_MESSAGE_CURRENT_VERSION: u8 = 4;
/// Lowest message version accepted when parsing `SignalMessage` /
/// `PreKeySignalMessage`. `PreKeySignalMessage`s at or below this version may
/// omit the Kyber pre-key payload.
pub(crate) const CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION: u8 = 3;
/// Version written into the low nibble of sender-key (distribution) message
/// headers; it is also the only high-nibble version accepted when parsing them.
pub(crate) const SENDERKEY_MESSAGE_CURRENT_VERSION: u8 = 3;
23
24#[derive(Debug, Clone)]
25pub enum CiphertextMessage {
26 SignalMessage(SignalMessage),
27 PreKeySignalMessage(PreKeySignalMessage),
28 SenderKeyMessage(SenderKeyMessage),
29 PlaintextContent(PlaintextContent),
30}
31
32#[derive(Copy, Clone, Eq, PartialEq, Debug, derive_more::TryFrom)]
33#[repr(u8)]
34#[try_from(repr)]
35pub enum CiphertextMessageType {
36 Whisper = 2,
37 PreKey = 3,
38 SenderKey = 7,
39 Plaintext = 8,
40}
41
42impl CiphertextMessage {
43 pub fn message_type(&self) -> CiphertextMessageType {
44 match self {
45 CiphertextMessage::SignalMessage(_) => CiphertextMessageType::Whisper,
46 CiphertextMessage::PreKeySignalMessage(_) => CiphertextMessageType::PreKey,
47 CiphertextMessage::SenderKeyMessage(_) => CiphertextMessageType::SenderKey,
48 CiphertextMessage::PlaintextContent(_) => CiphertextMessageType::Plaintext,
49 }
50 }
51
52 pub fn serialize(&self) -> &[u8] {
53 match self {
54 CiphertextMessage::SignalMessage(x) => x.serialized(),
55 CiphertextMessage::PreKeySignalMessage(x) => x.serialized(),
56 CiphertextMessage::SenderKeyMessage(x) => x.serialized(),
57 CiphertextMessage::PlaintextContent(x) => x.serialized(),
58 }
59 }
60}
61
/// An encrypted session message (wire type `Whisper`), kept together with the
/// exact bytes it was built from or parsed from.
#[derive(Debug, Clone)]
pub struct SignalMessage {
    /// The version passed to `new`, or the header byte's high nibble when parsed.
    message_version: u8,
    /// The sender's ratchet public key carried in the message.
    sender_ratchet_key: PublicKey,
    counter: u32,
    // Only read by tests, hence the `dead_code` expectation in non-test builds.
    #[cfg_attr(not(test), expect(dead_code))]
    previous_counter: u32,
    /// Encrypted payload.
    ciphertext: Box<[u8]>,
    /// Serialized post-quantum ratchet data; empty when none was present.
    pq_ratchet: spqr::SerializedState,
    /// Encoded (sender, recipient) addresses, if the message carries them.
    addresses: Option<Box<[u8]>>,
    /// Full wire encoding: header byte, protobuf body, then the truncated MAC.
    serialized: Box<[u8]>,
}
74
impl SignalMessage {
    /// Number of leading HMAC-SHA256 output bytes appended to (and checked
    /// against) a serialized message.
    const MAC_LENGTH: usize = 8;

    /// Builds a message together with its serialized form, including the MAC.
    ///
    /// Layout of the serialized bytes:
    /// 1. one header byte: `message_version` (masked to 4 bits) in the high
    ///    nibble, `CIPHERTEXT_MESSAGE_CURRENT_VERSION` in the low nibble;
    /// 2. the protobuf-encoded `proto::wire::SignalMessage`;
    /// 3. `MAC_LENGTH` MAC bytes computed by `compute_mac` over both identity
    ///    keys and the bytes of (1) and (2).
    ///
    /// `addresses` is embedded only if both address names parse as Service IDs;
    /// otherwise the field is silently omitted. An empty `pq_ratchet` is also
    /// omitted from the protobuf.
    ///
    /// # Errors
    /// Returns [`SignalProtocolError::InvalidMacKeyLength`] if `mac_key` is not
    /// 32 bytes long.
    // Every argument feeds either the wire format or the MAC.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        message_version: u8,
        mac_key: &[u8],
        addresses: Option<(&ProtocolAddress, &ProtocolAddress)>,
        sender_ratchet_key: PublicKey,
        counter: u32,
        previous_counter: u32,
        ciphertext: &[u8],
        sender_identity_key: &IdentityKey,
        receiver_identity_key: &IdentityKey,
        pq_ratchet: &[u8],
    ) -> Result<Self> {
        let addresses =
            addresses.and_then(|(sender, recipient)| Self::serialize_addresses(sender, recipient));
        let message = proto::wire::SignalMessage {
            ratchet_key: Some(sender_ratchet_key.serialize().into_vec()),
            counter: Some(counter),
            previous_counter: Some(previous_counter),
            ciphertext: Some(Vec::<u8>::from(ciphertext)),
            pq_ratchet: if pq_ratchet.is_empty() {
                None
            } else {
                Some(pq_ratchet.to_vec())
            },
            addresses,
        };
        // Room for the header byte, the protobuf body and the MAC.
        let mut serialized = Vec::with_capacity(1 + message.encoded_len() + Self::MAC_LENGTH);
        serialized.push(((message_version & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION);
        message
            .encode(&mut serialized)
            .expect("can always append to a buffer");
        // The MAC covers the header byte and the protobuf body written so far.
        let mac = Self::compute_mac(
            sender_identity_key,
            receiver_identity_key,
            mac_key,
            &serialized,
        )?;
        serialized.extend_from_slice(&mac);
        let serialized = serialized.into_boxed_slice();
        Ok(Self {
            message_version,
            sender_ratchet_key,
            counter,
            previous_counter,
            ciphertext: ciphertext.into(),
            pq_ratchet: pq_ratchet.to_vec(),
            addresses: message.addresses.map(Into::into),
            serialized,
        })
    }

    /// The message version (see the `message_version` field).
    #[inline]
    pub fn message_version(&self) -> u8 {
        self.message_version
    }

    /// The sender's ratchet public key.
    #[inline]
    pub fn sender_ratchet_key(&self) -> &PublicKey {
        &self.sender_ratchet_key
    }

    /// The message counter.
    #[inline]
    pub fn counter(&self) -> u32 {
        self.counter
    }

    /// Serialized PQ ratchet data; empty if the message carried none.
    #[inline]
    pub fn pq_ratchet(&self) -> &spqr::SerializedMessage {
        &self.pq_ratchet
    }

    /// The complete wire encoding, MAC included.
    #[inline]
    pub fn serialized(&self) -> &[u8] {
        &self.serialized
    }

    /// The encrypted payload.
    #[inline]
    pub fn body(&self) -> &[u8] {
        &self.ciphertext
    }

    /// Recomputes the MAC over everything except the trailing `MAC_LENGTH`
    /// bytes and compares it, in constant time, with those trailing bytes.
    ///
    /// Returns `Ok(false)` on mismatch, after logging both MAC values.
    ///
    /// # Errors
    /// Propagates `compute_mac`'s error for a `mac_key` that is not 32 bytes.
    pub(crate) fn verify_mac(
        &self,
        sender_identity_key: &IdentityKey,
        receiver_identity_key: &IdentityKey,
        mac_key: &[u8],
    ) -> Result<bool> {
        // Both construction paths guarantee at least MAC_LENGTH bytes: `new`
        // appends the MAC and `try_from` rejects shorter input.
        let (content, their_mac) = self
            .serialized
            .split_last_chunk::<{ Self::MAC_LENGTH }>()
            .expect("length checked at construction");
        let our_mac =
            Self::compute_mac(sender_identity_key, receiver_identity_key, mac_key, content)?;
        let result: bool = our_mac.ct_eq(their_mac).into();
        if !result {
            log::warn!(
                "Bad Mac! Their Mac: {} Our Mac: {}",
                hex::encode(their_mac),
                hex::encode(our_mac)
            );
            return Ok(false);
        }

        Ok(true)
    }

    /// Verifies the MAC and, if the message carries encoded addresses, also
    /// checks them against `sender_address` / `recipient_address`.
    ///
    /// Returns `Ok(false)` if:
    /// - the MAC is bad;
    /// - either local address is not a valid Service ID;
    /// - the encoded addresses differ from the local ones (compared in
    ///   constant time).
    ///
    /// Messages without an `addresses` field are accepted after MAC
    /// verification alone. The test
    /// `test_signal_message_verify_mac_accepts_legacy_message_without_addresses`
    /// pins this behaviour.
    ///
    /// # Errors
    /// Propagates the `mac_key` length error from `verify_mac`.
    pub fn verify_mac_with_addresses(
        &self,
        sender_address: &ProtocolAddress,
        recipient_address: &ProtocolAddress,
        sender_identity_key: &IdentityKey,
        receiver_identity_key: &IdentityKey,
        mac_key: &[u8],
    ) -> Result<bool> {
        if !self.verify_mac(sender_identity_key, receiver_identity_key, mac_key)? {
            return Ok(false);
        }

        // No addresses on the wire: accept based on the MAC alone.
        let Some(encoded_addresses) = &self.addresses else {
            return Ok(true);
        };

        let Some(expected) = Self::serialize_addresses(sender_address, recipient_address) else {
            log::warn!(
                "Locally supplied addresses not valid Service IDs: sender={}, recipient={}",
                sender_address,
                recipient_address,
            );
            return Ok(false);
        };

        if bool::from(expected.ct_eq(encoded_addresses.as_ref())) {
            Ok(true)
        } else {
            log::warn!(
                "Address mismatch: sender={}, recipient={}",
                sender_address,
                recipient_address,
            );
            Ok(false)
        }
    }

    /// Computes the truncated MAC for `message`.
    ///
    /// This is HMAC-SHA256 keyed with `mac_key` over:
    /// 1. the sender identity public key's serialization;
    /// 2. the receiver identity public key's serialization;
    /// 3. `message`.
    ///
    /// The output is truncated to its first `MAC_LENGTH` bytes.
    ///
    /// # Errors
    /// Returns [`SignalProtocolError::InvalidMacKeyLength`] unless `mac_key` is
    /// exactly 32 bytes.
    fn compute_mac(
        sender_identity_key: &IdentityKey,
        receiver_identity_key: &IdentityKey,
        mac_key: &[u8],
        message: &[u8],
    ) -> Result<[u8; Self::MAC_LENGTH]> {
        if mac_key.len() != 32 {
            return Err(SignalProtocolError::InvalidMacKeyLength(mac_key.len()));
        }
        let mut mac = Hmac::<Sha256>::new_from_slice(mac_key)
            .expect("HMAC-SHA256 should accept any size key");

        mac.update(sender_identity_key.public_key().serialize().as_ref());
        mac.update(receiver_identity_key.public_key().serialize().as_ref());
        mac.update(message);
        // Keep only the first MAC_LENGTH bytes of the 32-byte HMAC output.
        let result = *mac
            .finalize()
            .into_bytes()
            .first_chunk()
            .expect("enough bytes");
        Ok(result)
    }

    /// Encodes the address pair as
    /// `sender_service_id || sender_device || recipient_service_id || recipient_device`.
    ///
    /// Each service id uses its fixed-width binary form and each device id is a
    /// single byte. Returns `None` if either address name does not parse as a
    /// Service ID string.
    fn serialize_addresses(
        sender: &ProtocolAddress,
        recipient: &ProtocolAddress,
    ) -> Option<Vec<u8>> {
        let sender_service_id = ServiceId::parse_from_service_id_string(sender.name())?;
        let recipient_service_id = ServiceId::parse_from_service_id_string(recipient.name())?;

        // NOTE(review): 36 presumably = 2 * (17-byte fixed-width service id +
        // 1-byte device id); it is only a capacity hint — confirm.
        let mut bytes = Vec::with_capacity(36);
        bytes.extend_from_slice(&sender_service_id.service_id_fixed_width_binary());
        bytes.push(sender.device_id().into());
        bytes.extend_from_slice(&recipient_service_id.service_id_fixed_width_binary());
        bytes.push(recipient.device_id().into());
        Some(bytes)
    }
}
265
266impl AsRef<[u8]> for SignalMessage {
267 fn as_ref(&self) -> &[u8] {
268 &self.serialized
269 }
270}
271
272impl TryFrom<&[u8]> for SignalMessage {
273 type Error = SignalProtocolError;
274
275 fn try_from(value: &[u8]) -> Result<Self> {
276 if value.len() < SignalMessage::MAC_LENGTH + 1 {
277 return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
278 }
279 let message_version = value[0] >> 4;
280 if message_version < CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION {
281 return Err(SignalProtocolError::LegacyCiphertextVersion(
282 message_version,
283 ));
284 }
285 if message_version > CIPHERTEXT_MESSAGE_CURRENT_VERSION {
286 return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
287 message_version,
288 ));
289 }
290
291 let proto_structure =
292 proto::wire::SignalMessage::decode(&value[1..value.len() - SignalMessage::MAC_LENGTH])
293 .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
294
295 let sender_ratchet_key = proto_structure
296 .ratchet_key
297 .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
298 let sender_ratchet_key = PublicKey::deserialize(&sender_ratchet_key)?;
299 let counter = proto_structure
300 .counter
301 .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
302 let previous_counter = proto_structure.previous_counter.unwrap_or(0);
303 let ciphertext = proto_structure
304 .ciphertext
305 .ok_or(SignalProtocolError::InvalidProtobufEncoding)?
306 .into_boxed_slice();
307
308 Ok(SignalMessage {
309 message_version,
310 sender_ratchet_key,
311 counter,
312 previous_counter,
313 ciphertext,
314 pq_ratchet: proto_structure.pq_ratchet.unwrap_or(vec![]),
315 addresses: proto_structure.addresses.map(Into::into),
316 serialized: Box::from(value),
317 })
318 }
319}
320
321#[derive(Debug, Clone)]
322pub struct KyberPayload {
323 pre_key_id: KyberPreKeyId,
324 ciphertext: kem::SerializedCiphertext,
325}
326
327impl KyberPayload {
328 pub fn new(id: KyberPreKeyId, ciphertext: kem::SerializedCiphertext) -> Self {
329 Self {
330 pre_key_id: id,
331 ciphertext,
332 }
333 }
334}
335
/// A session-setup message (wire type `PreKey`).
///
/// It carries pre-key identifiers, the sender's base and identity keys, and an
/// embedded [`SignalMessage`].
#[derive(Debug, Clone)]
pub struct PreKeySignalMessage {
    /// The version passed to `new`, or the header byte's high nibble when parsed.
    message_version: u8,
    /// Parsed as 0 when absent on the wire.
    registration_id: u32,
    /// Optional pre-key id.
    pre_key_id: Option<PreKeyId>,
    signed_pre_key_id: SignedPreKeyId,
    /// Kyber pre-key id and ciphertext.
    ///
    /// When parsing, this may be absent only for versions at or below
    /// `CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION`.
    kyber_payload: Option<KyberPayload>,
    base_key: PublicKey,
    identity_key: IdentityKey,
    /// The embedded message.
    message: SignalMessage,
    /// Full wire encoding: header byte followed by the protobuf body.
    serialized: Box<[u8]>,
}
350
351impl PreKeySignalMessage {
352 pub fn new(
353 message_version: u8,
354 registration_id: u32,
355 pre_key_id: Option<PreKeyId>,
356 signed_pre_key_id: SignedPreKeyId,
357 kyber_payload: Option<KyberPayload>,
358 base_key: PublicKey,
359 identity_key: IdentityKey,
360 message: SignalMessage,
361 ) -> Result<Self> {
362 let proto_message = proto::wire::PreKeySignalMessage {
363 registration_id: Some(registration_id),
364 pre_key_id: pre_key_id.map(|id| id.into()),
365 signed_pre_key_id: Some(signed_pre_key_id.into()),
366 kyber_pre_key_id: kyber_payload.as_ref().map(|kyber| kyber.pre_key_id.into()),
367 kyber_ciphertext: kyber_payload
368 .as_ref()
369 .map(|kyber| kyber.ciphertext.to_vec()),
370 base_key: Some(base_key.serialize().into_vec()),
371 identity_key: Some(identity_key.serialize().into_vec()),
372 message: Some(Vec::from(message.as_ref())),
373 };
374 let mut serialized = Vec::with_capacity(1 + proto_message.encoded_len());
375 serialized.push(((message_version & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION);
376 proto_message
377 .encode(&mut serialized)
378 .expect("can always append to a Vec");
379 Ok(Self {
380 message_version,
381 registration_id,
382 pre_key_id,
383 signed_pre_key_id,
384 kyber_payload,
385 base_key,
386 identity_key,
387 message,
388 serialized: serialized.into_boxed_slice(),
389 })
390 }
391
392 #[inline]
393 pub fn message_version(&self) -> u8 {
394 self.message_version
395 }
396
397 #[inline]
398 pub fn registration_id(&self) -> u32 {
399 self.registration_id
400 }
401
402 #[inline]
403 pub fn pre_key_id(&self) -> Option<PreKeyId> {
404 self.pre_key_id
405 }
406
407 #[inline]
408 pub fn signed_pre_key_id(&self) -> SignedPreKeyId {
409 self.signed_pre_key_id
410 }
411
412 #[inline]
413 pub fn kyber_pre_key_id(&self) -> Option<KyberPreKeyId> {
414 self.kyber_payload.as_ref().map(|kyber| kyber.pre_key_id)
415 }
416
417 #[inline]
418 pub fn kyber_ciphertext(&self) -> Option<&kem::SerializedCiphertext> {
419 self.kyber_payload.as_ref().map(|kyber| &kyber.ciphertext)
420 }
421
422 #[inline]
423 pub fn base_key(&self) -> &PublicKey {
424 &self.base_key
425 }
426
427 #[inline]
428 pub fn identity_key(&self) -> &IdentityKey {
429 &self.identity_key
430 }
431
432 #[inline]
433 pub fn message(&self) -> &SignalMessage {
434 &self.message
435 }
436
437 #[inline]
438 pub fn serialized(&self) -> &[u8] {
439 &self.serialized
440 }
441}
442
443impl AsRef<[u8]> for PreKeySignalMessage {
444 fn as_ref(&self) -> &[u8] {
445 &self.serialized
446 }
447}
448
impl TryFrom<&[u8]> for PreKeySignalMessage {
    type Error = SignalProtocolError;

    /// Parses `[header byte] || protobuf` into a [`PreKeySignalMessage`].
    ///
    /// The version is the header's high nibble. It must lie between
    /// `CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION` and
    /// `CIPHERTEXT_MESSAGE_CURRENT_VERSION`, inclusive.
    ///
    /// The two Kyber fields must be present or absent together. Both may be
    /// absent only at or below the pre-Kyber version. The embedded
    /// [`SignalMessage`] is parsed as well, but no MAC is verified here.
    ///
    /// # Errors
    /// - `CiphertextMessageTooShort` for empty input.
    /// - `LegacyCiphertextVersion` / `UnrecognizedCiphertextVersion` for an
    ///   unsupported version.
    /// - `InvalidProtobufEncoding` for an undecodable body or a missing
    ///   required field.
    /// - `InvalidMessage` for an inconsistent Kyber payload.
    /// - Errors from parsing the keys or the embedded message.
    fn try_from(value: &[u8]) -> Result<Self> {
        if value.is_empty() {
            return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
        }

        // The high nibble of the header byte is the message version.
        let message_version = value[0] >> 4;
        if message_version < CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION {
            return Err(SignalProtocolError::LegacyCiphertextVersion(
                message_version,
            ));
        }
        if message_version > CIPHERTEXT_MESSAGE_CURRENT_VERSION {
            return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
                message_version,
            ));
        }

        let proto_structure = proto::wire::PreKeySignalMessage::decode(&value[1..])
            .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;

        // Required fields; any of them missing means a malformed message.
        let base_key = proto_structure
            .base_key
            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
        let identity_key = proto_structure
            .identity_key
            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
        let message = proto_structure
            .message
            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
        let signed_pre_key_id = proto_structure
            .signed_pre_key_id
            .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;

        let base_key = PublicKey::deserialize(base_key.as_ref())?;

        let kyber_payload = match (
            proto_structure.kyber_pre_key_id,
            proto_structure.kyber_ciphertext,
        ) {
            (Some(id), Some(ct)) => Some(KyberPayload::new(id.into(), ct.into_boxed_slice())),
            // Only versions at or below the pre-Kyber version may omit Kyber entirely.
            (None, None) if message_version <= CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION => None,
            (None, None) => {
                return Err(SignalProtocolError::InvalidMessage(
                    CiphertextMessageType::PreKey,
                    "Kyber pre key must be present for this session version",
                ));
            }
            // Exactly one of the two Kyber fields is present.
            _ => {
                return Err(SignalProtocolError::InvalidMessage(
                    CiphertextMessageType::PreKey,
                    "Both or neither kyber pre_key_id and kyber_ciphertext can be present",
                ));
            }
        };

        Ok(PreKeySignalMessage {
            message_version,
            // A missing registration id is treated as 0.
            registration_id: proto_structure.registration_id.unwrap_or(0),
            pre_key_id: proto_structure.pre_key_id.map(|id| id.into()),
            signed_pre_key_id: signed_pre_key_id.into(),
            kyber_payload,
            base_key,
            identity_key: IdentityKey::try_from(identity_key.as_ref())?,
            message: SignalMessage::try_from(message.as_ref())?,
            serialized: Box::from(value),
        })
    }
}
520
/// A message encrypted under a sender key (wire type `SenderKey`).
///
/// Its serialized form ends with a signature that `verify_signature` checks.
#[derive(Debug, Clone)]
pub struct SenderKeyMessage {
    /// High nibble of the header byte when parsed.
    message_version: u8,
    distribution_id: Uuid,
    chain_id: u32,
    iteration: u32,
    /// Encrypted payload.
    ciphertext: Box<[u8]>,
    /// Full wire encoding: header byte, protobuf body, then the signature.
    serialized: Box<[u8]>,
}
530
531impl SenderKeyMessage {
532 const SIGNATURE_LEN: usize = 64;
533
534 pub fn new<R: CryptoRng + Rng>(
535 message_version: u8,
536 distribution_id: Uuid,
537 chain_id: u32,
538 iteration: u32,
539 ciphertext: Box<[u8]>,
540 csprng: &mut R,
541 signature_key: &PrivateKey,
542 ) -> Result<Self> {
543 let proto_message = proto::wire::SenderKeyMessage {
544 distribution_uuid: Some(distribution_id.as_bytes().to_vec()),
545 chain_id: Some(chain_id),
546 iteration: Some(iteration),
547 ciphertext: Some(ciphertext.to_vec()),
548 };
549 let proto_message_len = proto_message.encoded_len();
550 let mut serialized = Vec::with_capacity(1 + proto_message_len + Self::SIGNATURE_LEN);
551 serialized.push(((message_version & 0xF) << 4) | SENDERKEY_MESSAGE_CURRENT_VERSION);
552 proto_message
553 .encode(&mut serialized)
554 .expect("can always append to a buffer");
555 let signature = signature_key.calculate_signature(&serialized, csprng)?;
556 serialized.extend_from_slice(&signature[..]);
557 Ok(Self {
558 message_version: SENDERKEY_MESSAGE_CURRENT_VERSION,
559 distribution_id,
560 chain_id,
561 iteration,
562 ciphertext,
563 serialized: serialized.into_boxed_slice(),
564 })
565 }
566
567 pub fn verify_signature(&self, signature_key: &PublicKey) -> Result<bool> {
568 let (content, signature) = self
569 .serialized
570 .split_last_chunk::<{ Self::SIGNATURE_LEN }>()
571 .expect("length checked on initialization");
572 let valid = signature_key.verify_signature(content, signature);
573
574 Ok(valid)
575 }
576
577 #[inline]
578 pub fn message_version(&self) -> u8 {
579 self.message_version
580 }
581
582 #[inline]
583 pub fn distribution_id(&self) -> Uuid {
584 self.distribution_id
585 }
586
587 #[inline]
588 pub fn chain_id(&self) -> u32 {
589 self.chain_id
590 }
591
592 #[inline]
593 pub fn iteration(&self) -> u32 {
594 self.iteration
595 }
596
597 #[inline]
598 pub fn ciphertext(&self) -> &[u8] {
599 &self.ciphertext
600 }
601
602 #[inline]
603 pub fn serialized(&self) -> &[u8] {
604 &self.serialized
605 }
606}
607
608impl AsRef<[u8]> for SenderKeyMessage {
609 fn as_ref(&self) -> &[u8] {
610 &self.serialized
611 }
612}
613
614impl TryFrom<&[u8]> for SenderKeyMessage {
615 type Error = SignalProtocolError;
616
617 fn try_from(value: &[u8]) -> Result<Self> {
618 if value.len() < 1 + Self::SIGNATURE_LEN {
619 return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
620 }
621 let message_version = value[0] >> 4;
622 if message_version < SENDERKEY_MESSAGE_CURRENT_VERSION {
623 return Err(SignalProtocolError::LegacyCiphertextVersion(
624 message_version,
625 ));
626 }
627 if message_version > SENDERKEY_MESSAGE_CURRENT_VERSION {
628 return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
629 message_version,
630 ));
631 }
632 let proto_structure =
633 proto::wire::SenderKeyMessage::decode(&value[1..value.len() - Self::SIGNATURE_LEN])
634 .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
635
636 let distribution_id = proto_structure
637 .distribution_uuid
638 .and_then(|bytes| Uuid::from_slice(bytes.as_slice()).ok())
639 .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
640 let chain_id = proto_structure
641 .chain_id
642 .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
643 let iteration = proto_structure
644 .iteration
645 .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
646 let ciphertext = proto_structure
647 .ciphertext
648 .ok_or(SignalProtocolError::InvalidProtobufEncoding)?
649 .into_boxed_slice();
650
651 Ok(SenderKeyMessage {
652 message_version,
653 distribution_id,
654 chain_id,
655 iteration,
656 ciphertext,
657 serialized: Box::from(value),
658 })
659 }
660}
661
/// Carries sender-key material for a distribution id.
///
/// The material is a chain id, iteration, chain key and public signing key.
/// NOTE(review): presumably this is what a recipient needs to process later
/// [`SenderKeyMessage`]s with the same distribution id — confirm.
#[derive(Debug, Clone)]
pub struct SenderKeyDistributionMessage {
    /// The version passed to `new`, or the header byte's high nibble when parsed.
    message_version: u8,
    distribution_id: Uuid,
    chain_id: u32,
    iteration: u32,
    /// Chain key bytes; exactly 32 bytes when parsed.
    chain_key: Vec<u8>,
    signing_key: PublicKey,
    /// Full wire encoding: header byte followed by the protobuf body.
    serialized: Box<[u8]>,
}
672
673impl SenderKeyDistributionMessage {
674 pub fn new(
675 message_version: u8,
676 distribution_id: Uuid,
677 chain_id: u32,
678 iteration: u32,
679 chain_key: Vec<u8>,
680 signing_key: PublicKey,
681 ) -> Result<Self> {
682 let proto_message = proto::wire::SenderKeyDistributionMessage {
683 distribution_uuid: Some(distribution_id.as_bytes().to_vec()),
684 chain_id: Some(chain_id),
685 iteration: Some(iteration),
686 chain_key: Some(chain_key.clone()),
687 signing_key: Some(signing_key.serialize().to_vec()),
688 };
689 let mut serialized = Vec::with_capacity(1 + proto_message.encoded_len());
690 serialized.push(((message_version & 0xF) << 4) | SENDERKEY_MESSAGE_CURRENT_VERSION);
691 proto_message
692 .encode(&mut serialized)
693 .expect("can always append to a buffer");
694
695 Ok(Self {
696 message_version,
697 distribution_id,
698 chain_id,
699 iteration,
700 chain_key,
701 signing_key,
702 serialized: serialized.into_boxed_slice(),
703 })
704 }
705
706 #[inline]
707 pub fn message_version(&self) -> u8 {
708 self.message_version
709 }
710
711 #[inline]
712 pub fn distribution_id(&self) -> Result<Uuid> {
713 Ok(self.distribution_id)
714 }
715
716 #[inline]
717 pub fn chain_id(&self) -> Result<u32> {
718 Ok(self.chain_id)
719 }
720
721 #[inline]
722 pub fn iteration(&self) -> Result<u32> {
723 Ok(self.iteration)
724 }
725
726 #[inline]
727 pub fn chain_key(&self) -> Result<&[u8]> {
728 Ok(&self.chain_key)
729 }
730
731 #[inline]
732 pub fn signing_key(&self) -> Result<&PublicKey> {
733 Ok(&self.signing_key)
734 }
735
736 #[inline]
737 pub fn serialized(&self) -> &[u8] {
738 &self.serialized
739 }
740}
741
742impl AsRef<[u8]> for SenderKeyDistributionMessage {
743 fn as_ref(&self) -> &[u8] {
744 &self.serialized
745 }
746}
747
748impl TryFrom<&[u8]> for SenderKeyDistributionMessage {
749 type Error = SignalProtocolError;
750
751 fn try_from(value: &[u8]) -> Result<Self> {
752 if value.len() < 1 + 32 + 32 {
754 return Err(SignalProtocolError::CiphertextMessageTooShort(value.len()));
755 }
756
757 let message_version = value[0] >> 4;
758
759 if message_version < SENDERKEY_MESSAGE_CURRENT_VERSION {
760 return Err(SignalProtocolError::LegacyCiphertextVersion(
761 message_version,
762 ));
763 }
764 if message_version > SENDERKEY_MESSAGE_CURRENT_VERSION {
765 return Err(SignalProtocolError::UnrecognizedCiphertextVersion(
766 message_version,
767 ));
768 }
769
770 let proto_structure = proto::wire::SenderKeyDistributionMessage::decode(&value[1..])
771 .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
772
773 let distribution_id = proto_structure
774 .distribution_uuid
775 .and_then(|bytes| Uuid::from_slice(bytes.as_slice()).ok())
776 .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
777 let chain_id = proto_structure
778 .chain_id
779 .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
780 let iteration = proto_structure
781 .iteration
782 .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
783 let chain_key = proto_structure
784 .chain_key
785 .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
786 let signing_key = proto_structure
787 .signing_key
788 .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
789
790 if chain_key.len() != 32 || signing_key.len() != 33 {
791 return Err(SignalProtocolError::InvalidProtobufEncoding);
792 }
793
794 let signing_key = PublicKey::deserialize(&signing_key)?;
795
796 Ok(SenderKeyDistributionMessage {
797 message_version,
798 distribution_id,
799 chain_id,
800 iteration,
801 chain_key,
802 signing_key,
803 serialized: Box::from(value),
804 })
805 }
806}
807
808#[derive(Debug, Clone)]
809pub struct PlaintextContent {
810 serialized: Box<[u8]>,
811}
812
813impl PlaintextContent {
814 const PLAINTEXT_CONTEXT_IDENTIFIER_BYTE: u8 = 0xC0;
819
820 const PADDING_BOUNDARY_BYTE: u8 = 0x80;
825
826 #[inline]
827 pub fn body(&self) -> &[u8] {
828 &self.serialized[1..]
829 }
830
831 #[inline]
832 pub fn serialized(&self) -> &[u8] {
833 &self.serialized
834 }
835}
836
837impl From<DecryptionErrorMessage> for PlaintextContent {
838 fn from(message: DecryptionErrorMessage) -> Self {
839 let proto_structure = proto::service::Content {
840 decryption_error_message: Some(message.serialized().to_vec()),
841 ..Default::default()
842 };
843 let mut serialized = vec![Self::PLAINTEXT_CONTEXT_IDENTIFIER_BYTE];
844 proto_structure
845 .encode(&mut serialized)
846 .expect("can always encode to a Vec");
847 serialized.push(Self::PADDING_BOUNDARY_BYTE);
848 Self {
849 serialized: Box::from(serialized),
850 }
851 }
852}
853
854impl TryFrom<&[u8]> for PlaintextContent {
855 type Error = SignalProtocolError;
856
857 fn try_from(value: &[u8]) -> Result<Self> {
858 if value.is_empty() {
859 return Err(SignalProtocolError::CiphertextMessageTooShort(0));
860 }
861 if value[0] != Self::PLAINTEXT_CONTEXT_IDENTIFIER_BYTE {
862 return Err(SignalProtocolError::UnrecognizedMessageVersion(
863 value[0] as u32,
864 ));
865 }
866 Ok(Self {
867 serialized: Box::from(value),
868 })
869 }
870}
871
872#[derive(Debug, Clone)]
873pub struct DecryptionErrorMessage {
874 ratchet_key: Option<PublicKey>,
875 timestamp: Timestamp,
876 device_id: u32,
877 serialized: Box<[u8]>,
878}
879
880impl DecryptionErrorMessage {
881 pub fn for_original(
882 original_bytes: &[u8],
883 original_type: CiphertextMessageType,
884 original_timestamp: Timestamp,
885 original_sender_device_id: u32,
886 ) -> Result<Self> {
887 let ratchet_key = match original_type {
888 CiphertextMessageType::Whisper => {
889 Some(*SignalMessage::try_from(original_bytes)?.sender_ratchet_key())
890 }
891 CiphertextMessageType::PreKey => Some(
892 *PreKeySignalMessage::try_from(original_bytes)?
893 .message()
894 .sender_ratchet_key(),
895 ),
896 CiphertextMessageType::SenderKey => None,
897 CiphertextMessageType::Plaintext => {
898 return Err(SignalProtocolError::InvalidArgument(
899 "cannot create a DecryptionErrorMessage for plaintext content; it is not encrypted".to_string()
900 ));
901 }
902 };
903
904 let proto_message = proto::service::DecryptionErrorMessage {
905 timestamp: Some(original_timestamp.epoch_millis()),
906 ratchet_key: ratchet_key.map(|k| k.serialize().into()),
907 device_id: Some(original_sender_device_id),
908 };
909 let serialized = proto_message.encode_to_vec();
910
911 Ok(Self {
912 ratchet_key,
913 timestamp: original_timestamp,
914 device_id: original_sender_device_id,
915 serialized: serialized.into_boxed_slice(),
916 })
917 }
918
919 #[inline]
920 pub fn timestamp(&self) -> Timestamp {
921 self.timestamp
922 }
923
924 #[inline]
925 pub fn ratchet_key(&self) -> Option<&PublicKey> {
926 self.ratchet_key.as_ref()
927 }
928
929 #[inline]
930 pub fn device_id(&self) -> u32 {
931 self.device_id
932 }
933
934 #[inline]
935 pub fn serialized(&self) -> &[u8] {
936 &self.serialized
937 }
938}
939
940impl TryFrom<&[u8]> for DecryptionErrorMessage {
941 type Error = SignalProtocolError;
942
943 fn try_from(value: &[u8]) -> Result<Self> {
944 let proto_structure = proto::service::DecryptionErrorMessage::decode(value)
945 .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
946 let timestamp = proto_structure
947 .timestamp
948 .map(Timestamp::from_epoch_millis)
949 .ok_or(SignalProtocolError::InvalidProtobufEncoding)?;
950 let ratchet_key = proto_structure
951 .ratchet_key
952 .map(|k| PublicKey::deserialize(&k))
953 .transpose()?;
954 let device_id = proto_structure.device_id.unwrap_or_default();
955 Ok(Self {
956 timestamp,
957 ratchet_key,
958 device_id,
959 serialized: Box::from(value),
960 })
961 }
962}
963
964pub fn extract_decryption_error_message_from_serialized_content(
966 bytes: &[u8],
967) -> Result<DecryptionErrorMessage> {
968 if bytes.last() != Some(&PlaintextContent::PADDING_BOUNDARY_BYTE) {
969 return Err(SignalProtocolError::InvalidProtobufEncoding);
970 }
971 let content = proto::service::Content::decode(bytes.split_last().expect("checked above").1)
972 .map_err(|_| SignalProtocolError::InvalidProtobufEncoding)?;
973 content
974 .decryption_error_message
975 .as_deref()
976 .ok_or_else(|| {
977 SignalProtocolError::InvalidArgument(
978 "Content does not contain DecryptionErrorMessage".to_owned(),
979 )
980 })
981 .and_then(DecryptionErrorMessage::try_from)
982}
983
/// Deterministically decides from `session_key` whether a session may skip
/// the post-quantum requirement.
///
/// The first four bytes of `session_key` are read as a big-endian `u32`. The
/// session may go non-PQ when that value is at least
/// `require_pq_ratio * u32::MAX`. So, assuming those bytes are uniformly
/// distributed, roughly a `require_pq_ratio` fraction of sessions must use PQ.
///
/// - `require_pq_ratio >= 1.0` always returns `false` (PQ always required).
/// - `require_pq_ratio <= 0.0` always returns `true`.
/// - A NaN ratio takes neither early return and scales to 0 (Rust's `as`
///   maps NaN to 0), so it also returns `true`.
///
/// # Panics
/// Panics if `session_key` is shorter than 4 bytes, regardless of the ratio.
pub fn should_use_nonpq_session(require_pq_ratio: f64, session_key: &[u8]) -> bool {
    assert!(
        session_key.len() >= 4,
        "session_key must contain at least 4 bytes"
    );
    if require_pq_ratio >= 1.0 {
        return false;
    }
    if require_pq_ratio <= 0.0 {
        return true;
    }
    let prefix: [u8; 4] = *session_key
        .first_chunk::<4>()
        .expect("length asserted above");
    let sess_u32 = u32::from_be_bytes(prefix);
    // Float-to-int `as` saturates. For a ratio in (0, 1) the product lies within
    // u32's range, so the cast only drops the fractional part.
    #[allow(clippy::cast_possible_truncation)]
    let ratio_u32 = (f64::from(u32::MAX) * require_pq_ratio) as u32;
    ratio_u32 <= sess_u32
}
1016
1017#[cfg(test)]
1018mod tests {
1019 use rand::rngs::OsRng;
1020 use rand::{CryptoRng, Rng, RngCore, TryRngCore as _};
1021
1022 use super::*;
1023 use crate::{DeviceId, KeyPair};
1024
1025 fn create_signal_message<T>(csprng: &mut T) -> Result<SignalMessage>
1026 where
1027 T: Rng + CryptoRng,
1028 {
1029 let mut mac_key = [0u8; 32];
1030 csprng.fill_bytes(&mut mac_key);
1031 let mac_key = mac_key;
1032
1033 let mut ciphertext = [0u8; 20];
1034 csprng.fill_bytes(&mut ciphertext);
1035 let ciphertext = ciphertext;
1036
1037 let sender_ratchet_key_pair = KeyPair::generate(csprng);
1038 let sender_identity_key_pair = KeyPair::generate(csprng);
1039 let receiver_identity_key_pair = KeyPair::generate(csprng);
1040 let sender_address = ProtocolAddress::new(
1041 "31415926-5358-9793-2384-626433827950".to_owned(),
1042 DeviceId::new(1).unwrap(),
1043 );
1044 let recipient_address = ProtocolAddress::new(
1045 "27182818-2845-9045-2353-602874713526".to_owned(),
1046 DeviceId::new(1).unwrap(),
1047 );
1048
1049 SignalMessage::new(
1050 4,
1051 &mac_key,
1052 Some((&sender_address, &recipient_address)),
1053 sender_ratchet_key_pair.public_key,
1054 42,
1055 41,
1056 &ciphertext,
1057 &sender_identity_key_pair.public_key.into(),
1058 &receiver_identity_key_pair.public_key.into(),
1059 b"", )
1061 }
1062
1063 fn assert_signal_message_equals(m1: &SignalMessage, m2: &SignalMessage) {
1064 assert_eq!(m1.message_version, m2.message_version);
1065 assert_eq!(m1.sender_ratchet_key, m2.sender_ratchet_key);
1066 assert_eq!(m1.counter, m2.counter);
1067 assert_eq!(m1.previous_counter, m2.previous_counter);
1068 assert_eq!(m1.ciphertext, m2.ciphertext);
1069 assert_eq!(m1.addresses, m2.addresses);
1070 assert_eq!(m1.serialized, m2.serialized);
1071 }
1072
1073 #[test]
1074 fn test_signal_message_serialize_deserialize() -> Result<()> {
1075 let mut csprng = OsRng.unwrap_err();
1076 let message = create_signal_message(&mut csprng)?;
1077 let deser_message =
1078 SignalMessage::try_from(message.as_ref()).expect("should deserialize without error");
1079 assert_signal_message_equals(&message, &deser_message);
1080 Ok(())
1081 }
1082
1083 #[test]
1084 fn test_pre_key_signal_message_serialize_deserialize() -> Result<()> {
1085 let mut csprng = OsRng.unwrap_err();
1086 let identity_key_pair = KeyPair::generate(&mut csprng);
1087 let base_key_pair = KeyPair::generate(&mut csprng);
1088 let message = create_signal_message(&mut csprng)?;
1089 let pre_key_signal_message = PreKeySignalMessage::new(
1090 3,
1091 365,
1092 None,
1093 97.into(),
1094 None, base_key_pair.public_key,
1096 identity_key_pair.public_key.into(),
1097 message,
1098 )?;
1099 let deser_pre_key_signal_message =
1100 PreKeySignalMessage::try_from(pre_key_signal_message.as_ref())
1101 .expect("should deserialize without error");
1102 assert_eq!(
1103 pre_key_signal_message.message_version,
1104 deser_pre_key_signal_message.message_version
1105 );
1106 assert_eq!(
1107 pre_key_signal_message.registration_id,
1108 deser_pre_key_signal_message.registration_id
1109 );
1110 assert_eq!(
1111 pre_key_signal_message.pre_key_id,
1112 deser_pre_key_signal_message.pre_key_id
1113 );
1114 assert_eq!(
1115 pre_key_signal_message.signed_pre_key_id,
1116 deser_pre_key_signal_message.signed_pre_key_id
1117 );
1118 assert_eq!(
1119 pre_key_signal_message.base_key,
1120 deser_pre_key_signal_message.base_key
1121 );
1122 assert_eq!(
1123 pre_key_signal_message.identity_key.public_key(),
1124 deser_pre_key_signal_message.identity_key.public_key()
1125 );
1126 assert_signal_message_equals(
1127 &pre_key_signal_message.message,
1128 &deser_pre_key_signal_message.message,
1129 );
1130 assert_eq!(
1131 pre_key_signal_message.serialized,
1132 deser_pre_key_signal_message.serialized
1133 );
1134 Ok(())
1135 }
1136
    /// A message whose protobuf lacks the `addresses` field must still pass
    /// `verify_mac_with_addresses`.
    ///
    /// The test re-encodes a normal message without addresses and re-MACs it.
    #[test]
    fn test_signal_message_verify_mac_accepts_legacy_message_without_addresses() -> Result<()> {
        let mut csprng = OsRng.unwrap_err();
        let mut mac_key = [0u8; 32];
        csprng.fill_bytes(&mut mac_key);

        let mut ciphertext = [0u8; 20];
        csprng.fill_bytes(&mut ciphertext);

        let sender_ratchet_key_pair = KeyPair::generate(&mut csprng);
        let sender_identity_key_pair = KeyPair::generate(&mut csprng);
        let receiver_identity_key_pair = KeyPair::generate(&mut csprng);
        let sender_address = ProtocolAddress::new(
            "16180339-8874-9894-8482-045868343656".to_owned(),
            DeviceId::new(1).unwrap(),
        );
        let recipient_address = ProtocolAddress::new(
            "14142135-6237-3095-0488-016887242096".to_owned(),
            DeviceId::new(1).unwrap(),
        );

        // Start from a normal message that does carry addresses.
        let message = SignalMessage::new(
            4,
            &mac_key,
            Some((&sender_address, &recipient_address)),
            sender_ratchet_key_pair.public_key,
            42,
            41,
            &ciphertext,
            &sender_identity_key_pair.public_key.into(),
            &receiver_identity_key_pair.public_key.into(),
            b"",
        )?;

        // Decode its protobuf body (between the header byte and the MAC) and
        // drop the addresses.
        let mut proto_structure = proto::wire::SignalMessage::decode(
            &message.serialized()[1..message.serialized().len() - SignalMessage::MAC_LENGTH],
        )
        .expect("valid protobuf");
        proto_structure.addresses = None;

        // Rebuild the wire form: same header byte, new body, freshly computed MAC.
        let mut serialized =
            vec![((message.message_version() & 0xF) << 4) | CIPHERTEXT_MESSAGE_CURRENT_VERSION];
        proto_structure.encode(&mut serialized).expect("encodes");
        let mac = SignalMessage::compute_mac(
            &sender_identity_key_pair.public_key.into(),
            &receiver_identity_key_pair.public_key.into(),
            &mac_key,
            &serialized,
        )?;
        serialized.extend_from_slice(&mac);

        // With no addresses on the wire, verification relies on the MAC alone.
        let legacy_message = SignalMessage::try_from(serialized.as_slice())?;
        assert!(legacy_message.verify_mac_with_addresses(
            &sender_address,
            &recipient_address,
            &sender_identity_key_pair.public_key.into(),
            &receiver_identity_key_pair.public_key.into(),
            &mac_key,
        )?);

        Ok(())
    }
1199
1200 #[test]
1201 fn test_signal_message_verify_mac_rejects_wrong_address() -> Result<()> {
1202 let mut csprng = OsRng.unwrap_err();
1203 let mut mac_key = [0u8; 32];
1204 csprng.fill_bytes(&mut mac_key);
1205
1206 let mut ciphertext = [0u8; 20];
1207 csprng.fill_bytes(&mut ciphertext);
1208
1209 let sender_ratchet_key_pair = KeyPair::generate(&mut csprng);
1210 let sender_identity_key_pair = KeyPair::generate(&mut csprng);
1211 let receiver_identity_key_pair = KeyPair::generate(&mut csprng);
1212 let sender_address = ProtocolAddress::new(
1213 "deadbeef-cafe-babe-feed-faceb00c0ffe".to_owned(),
1214 DeviceId::new(1).unwrap(),
1215 );
1216 let recipient_address = ProtocolAddress::new(
1217 "01120358-1321-3455-0891-44233377610a".to_owned(),
1218 DeviceId::new(1).unwrap(),
1219 );
1220 let wrong_address = ProtocolAddress::new(
1221 "02030507-1113-1719-2329-313741434753".to_owned(),
1222 DeviceId::new(1).unwrap(),
1223 );
1224
1225 let message = SignalMessage::new(
1226 4,
1227 &mac_key,
1228 Some((&sender_address, &recipient_address)),
1229 sender_ratchet_key_pair.public_key,
1230 42,
1231 41,
1232 &ciphertext,
1233 &sender_identity_key_pair.public_key.into(),
1234 &receiver_identity_key_pair.public_key.into(),
1235 b"",
1236 )?;
1237
1238 assert!(!message.verify_mac_with_addresses(
1240 &wrong_address,
1241 &recipient_address,
1242 &sender_identity_key_pair.public_key.into(),
1243 &receiver_identity_key_pair.public_key.into(),
1244 &mac_key,
1245 )?);
1246
1247 assert!(!message.verify_mac_with_addresses(
1249 &sender_address,
1250 &wrong_address,
1251 &sender_identity_key_pair.public_key.into(),
1252 &receiver_identity_key_pair.public_key.into(),
1253 &mac_key,
1254 )?);
1255
1256 assert!(message.verify_mac_with_addresses(
1258 &sender_address,
1259 &recipient_address,
1260 &sender_identity_key_pair.public_key.into(),
1261 &receiver_identity_key_pair.public_key.into(),
1262 &mac_key,
1263 )?);
1264
1265 Ok(())
1266 }
1267
1268 #[test]
1269 fn test_sender_key_message_serialize_deserialize() -> Result<()> {
1270 let mut csprng = OsRng.unwrap_err();
1271 let signature_key_pair = KeyPair::generate(&mut csprng);
1272 let sender_key_message = SenderKeyMessage::new(
1273 SENDERKEY_MESSAGE_CURRENT_VERSION,
1274 Uuid::from_u128(0xd1d1d1d1_7000_11eb_b32a_33b8a8a487a6),
1275 42,
1276 7,
1277 [1u8, 2, 3].into(),
1278 &mut csprng,
1279 &signature_key_pair.private_key,
1280 )?;
1281 let deser_sender_key_message = SenderKeyMessage::try_from(sender_key_message.as_ref())
1282 .expect("should deserialize without error");
1283 assert_eq!(
1284 sender_key_message.message_version,
1285 deser_sender_key_message.message_version
1286 );
1287 assert_eq!(
1288 sender_key_message.chain_id,
1289 deser_sender_key_message.chain_id
1290 );
1291 assert_eq!(
1292 sender_key_message.iteration,
1293 deser_sender_key_message.iteration
1294 );
1295 assert_eq!(
1296 sender_key_message.ciphertext,
1297 deser_sender_key_message.ciphertext
1298 );
1299 assert_eq!(
1300 sender_key_message.serialized,
1301 deser_sender_key_message.serialized
1302 );
1303 Ok(())
1304 }
1305
1306 #[test]
1307 fn test_decryption_error_message() -> Result<()> {
1308 let mut csprng = OsRng.unwrap_err();
1309 let identity_key_pair = KeyPair::generate(&mut csprng);
1310 let base_key_pair = KeyPair::generate(&mut csprng);
1311 let message = create_signal_message(&mut csprng)?;
1312 let timestamp: Timestamp = Timestamp::from_epoch_millis(0x2_0000_0001);
1313 let device_id = 0x8086_2021;
1314
1315 {
1316 let error_message = DecryptionErrorMessage::for_original(
1317 message.serialized(),
1318 CiphertextMessageType::Whisper,
1319 timestamp,
1320 device_id,
1321 )?;
1322 let error_message = DecryptionErrorMessage::try_from(error_message.serialized())?;
1323 assert_eq!(
1324 error_message.ratchet_key(),
1325 Some(message.sender_ratchet_key())
1326 );
1327 assert_eq!(error_message.timestamp(), timestamp);
1328 assert_eq!(error_message.device_id(), device_id);
1329 }
1330
1331 let pre_key_signal_message = PreKeySignalMessage::new(
1332 3,
1333 365,
1334 None,
1335 97.into(),
1336 None, base_key_pair.public_key,
1338 identity_key_pair.public_key.into(),
1339 message,
1340 )?;
1341
1342 {
1343 let error_message = DecryptionErrorMessage::for_original(
1344 pre_key_signal_message.serialized(),
1345 CiphertextMessageType::PreKey,
1346 timestamp,
1347 device_id,
1348 )?;
1349 let error_message = DecryptionErrorMessage::try_from(error_message.serialized())?;
1350 assert_eq!(
1351 error_message.ratchet_key(),
1352 Some(pre_key_signal_message.message().sender_ratchet_key())
1353 );
1354 assert_eq!(error_message.timestamp(), timestamp);
1355 assert_eq!(error_message.device_id(), device_id);
1356 }
1357
1358 let sender_key_message = SenderKeyMessage::new(
1359 3,
1360 Uuid::nil(),
1361 1,
1362 2,
1363 Box::from(b"test".to_owned()),
1364 &mut csprng,
1365 &base_key_pair.private_key,
1366 )?;
1367
1368 {
1369 let error_message = DecryptionErrorMessage::for_original(
1370 sender_key_message.serialized(),
1371 CiphertextMessageType::SenderKey,
1372 timestamp,
1373 device_id,
1374 )?;
1375 let error_message = DecryptionErrorMessage::try_from(error_message.serialized())?;
1376 assert_eq!(error_message.ratchet_key(), None);
1377 assert_eq!(error_message.timestamp(), timestamp);
1378 assert_eq!(error_message.device_id(), device_id);
1379 }
1380
1381 Ok(())
1382 }
1383
1384 #[test]
1385 fn test_decryption_error_message_for_plaintext() {
1386 assert!(matches!(
1387 DecryptionErrorMessage::for_original(
1388 &[],
1389 CiphertextMessageType::Plaintext,
1390 Timestamp::from_epoch_millis(5),
1391 7
1392 ),
1393 Err(SignalProtocolError::InvalidArgument(_))
1394 ));
1395 }
1396
1397 #[test]
1398 fn test_should_use_nonpq_session() {
1399 let max = b"\xff\xff\xff\xff";
1400 let min = b"\x00\x00\x00\x00";
1401 let mid = b"\x7f\xff\xff\xff";
1402 assert!(!should_use_nonpq_session(1.0, max));
1403 assert!(!should_use_nonpq_session(1.0, min));
1404 assert!(should_use_nonpq_session(0.0, max));
1405 assert!(should_use_nonpq_session(0.0, min));
1406
1407 assert!(!should_use_nonpq_session(0.75, min));
1408 assert!(!should_use_nonpq_session(0.75, mid));
1409 assert!(should_use_nonpq_session(0.75, max));
1410
1411 assert!(!should_use_nonpq_session(0.25, min));
1412 assert!(should_use_nonpq_session(0.25, mid));
1413 assert!(should_use_nonpq_session(0.25, max));
1414 }
1415}