libsignal_service/
sender.rs

1use std::{collections::HashSet, time::SystemTime};
2
3use chrono::prelude::*;
4use libsignal_core::{curve::CurveError, InvalidDeviceId};
5use libsignal_protocol::{
6    process_prekey_bundle, Aci, DeviceId, IdentityKey, IdentityKeyPair, Pni,
7    ProtocolStore, SenderCertificate, SenderKeyStore, ServiceId,
8    SignalProtocolError, UsePQRatchet,
9};
10use rand::{rng, CryptoRng, Rng};
11use tracing::{debug, error, info, trace, warn};
12use tracing_futures::Instrument;
13use uuid::Uuid;
14use zkgroup::GROUP_IDENTIFIER_LEN;
15
16use crate::{
17    cipher::{get_preferred_protocol_address, ServiceCipher},
18    content::ContentBody,
19    proto::{
20        attachment_pointer::{
21            AttachmentIdentifier, Flags as AttachmentPointerFlags,
22        },
23        sync_message::{
24            self, message_request_response, MessageRequestResponse,
25        },
26        AttachmentPointer, SyncMessage,
27    },
28    push_service::*,
29    service_address::ServiceIdExt,
30    session_store::SessionStoreExt,
31    unidentified_access::UnidentifiedAccess,
32    utils::serde_service_id,
33    websocket::SignalWebSocket,
34};
35
36pub use crate::proto::ContactDetails;
37
38#[derive(serde::Serialize, Debug)]
39#[serde(rename_all = "camelCase")]
40pub struct OutgoingPushMessage {
41    pub r#type: u32,
42    pub destination_device_id: u32,
43    pub destination_registration_id: u32,
44    pub content: String,
45}
46
47#[derive(serde::Serialize, Debug)]
48pub struct OutgoingPushMessages {
49    #[serde(with = "serde_service_id")]
50    pub destination: ServiceId,
51    pub timestamp: u64,
52    pub messages: Vec<OutgoingPushMessage>,
53    pub online: bool,
54}
55
56#[derive(serde::Deserialize, Debug)]
57#[serde(rename_all = "camelCase")]
58pub struct SendMessageResponse {
59    pub needs_sync: bool,
60}
61
62pub type SendMessageResult = Result<SentMessage, MessageSenderError>;
63
64#[derive(Debug, Clone)]
65pub struct SentMessage {
66    pub recipient: ServiceId,
67    pub used_identity_key: IdentityKey,
68    pub unidentified: bool,
69    pub needs_sync: bool,
70}
71
/// Metadata describing an attachment that is about to be uploaded.
///
/// Roughly the counterpart of Java's `SignalServiceAttachmentStream`.
#[derive(Default, Debug)]
pub struct AttachmentSpec {
    /// MIME type copied into the resulting attachment pointer.
    pub content_type: String,
    /// Plaintext length in bytes. Informational only: `upload_attachment`
    /// takes the size from the contents it is handed.
    pub length: usize,
    /// Optional original file name.
    pub file_name: Option<String>,
    /// Optional preview bytes. NOTE(review): not read by `upload_attachment`.
    pub preview: Option<Vec<u8>>,
    /// `Some(true)` sets the voice-message flag on the pointer.
    pub voice_note: Option<bool>,
    /// `Some(true)` sets the borderless flag on the pointer.
    pub borderless: Option<bool>,
    /// Optional width copied into the pointer.
    pub width: Option<u32>,
    /// Optional height copied into the pointer.
    pub height: Option<u32>,
    /// Optional caption copied into the pointer.
    pub caption: Option<String>,
    /// Optional blur hash copied into the pointer.
    pub blur_hash: Option<String>,
}
88
89#[derive(Clone)]
90pub struct MessageSender<S> {
91    identified_ws: SignalWebSocket,
92    unidentified_ws: SignalWebSocket,
93    service: PushService,
94    cipher: ServiceCipher<S>,
95    protocol_store: S,
96    local_aci: Aci,
97    local_pni: Pni,
98    aci_identity: IdentityKeyPair,
99    pni_identity: Option<IdentityKeyPair>,
100    device_id: DeviceId,
101}
102
103#[derive(thiserror::Error, Debug)]
104pub enum AttachmentUploadError {
105    #[error("{0}")]
106    ServiceError(#[from] ServiceError),
107
108    #[error("Could not read attachment contents")]
109    IoError(#[from] std::io::Error),
110}
111
112#[derive(thiserror::Error, Debug)]
113pub enum MessageSenderError {
114    #[error("service error: {0}")]
115    ServiceError(#[from] ServiceError),
116
117    #[error("protocol error: {0}")]
118    ProtocolError(#[from] SignalProtocolError),
119
120    #[error("invalid private key: {0}")]
121    InvalidPrivateKey(#[from] CurveError),
122
123    #[error("invalid device ID: {0}")]
124    InvalidDeviceId(#[from] InvalidDeviceId),
125
126    #[error("Failed to upload attachment {0}")]
127    AttachmentUploadError(#[from] AttachmentUploadError),
128
129    #[error("primary device can't send sync message {0:?}")]
130    SendSyncMessageError(sync_message::request::Type),
131
132    #[error("Untrusted identity key with {address:?}")]
133    UntrustedIdentity { address: ServiceId },
134
135    #[error("Exceeded maximum number of retries")]
136    MaximumRetriesLimitExceeded,
137
138    #[error("Proof of type {options:?} required using token {token}")]
139    ProofRequired { token: String, options: Vec<String> },
140
141    #[error("Recipient not found: {service_id:?}")]
142    NotFound { service_id: ServiceId },
143
144    #[error("no messages were encrypted: this should not really happen and most likely implies a logic error")]
145    NoMessagesToSend,
146}
147
148pub type GroupV2Id = [u8; GROUP_IDENTIFIER_LEN];
149
150#[derive(Debug)]
151pub enum ThreadIdentifier {
152    Aci(Uuid),
153    Group(GroupV2Id),
154}
155
156#[derive(Debug)]
157pub struct EncryptedMessages {
158    messages: Vec<OutgoingPushMessage>,
159    used_identity_key: IdentityKey,
160}
161
162impl<S> MessageSender<S>
163where
164    S: ProtocolStore + SenderKeyStore + SessionStoreExt + Sync + Clone,
165{
166    #[allow(clippy::too_many_arguments)]
167    pub fn new(
168        identified_ws: SignalWebSocket,
169        unidentified_ws: SignalWebSocket,
170        service: PushService,
171        cipher: ServiceCipher<S>,
172        protocol_store: S,
173        local_aci: impl Into<Aci>,
174        local_pni: impl Into<Pni>,
175        aci_identity: IdentityKeyPair,
176        pni_identity: Option<IdentityKeyPair>,
177        device_id: DeviceId,
178    ) -> Self {
179        MessageSender {
180            service,
181            identified_ws,
182            unidentified_ws,
183            cipher,
184            protocol_store,
185            local_aci: local_aci.into(),
186            local_pni: local_pni.into(),
187            aci_identity,
188            pni_identity,
189            device_id,
190        }
191    }
192
193    /// Encrypts and uploads an attachment
194    ///
195    /// Contents are accepted as an owned, plain text Vec, because encryption happens in-place.
196    #[tracing::instrument(skip(self, contents, csprng), fields(size = contents.len()))]
197    pub async fn upload_attachment<R: Rng + CryptoRng>(
198        &mut self,
199        spec: AttachmentSpec,
200        mut contents: Vec<u8>,
201        csprng: &mut R,
202    ) -> Result<AttachmentPointer, AttachmentUploadError> {
203        let len = contents.len();
204        // Encrypt
205        let (key, iv) = {
206            let mut key = [0u8; 64];
207            let mut iv = [0u8; 16];
208            csprng.fill_bytes(&mut key);
209            csprng.fill_bytes(&mut iv);
210            (key, iv)
211        };
212
213        // Padded length uses an exponential bracketting thingy.
214        // If you want to see how it looks:
215        // https://www.wolframalpha.com/input/?i=plot+floor%281.05%5Eceil%28log_1.05%28x%29%29%29+for+x+from+0+to+5000000
216        let padded_len: usize = {
217            // Java:
218            // return (int) Math.max(541, Math.floor(Math.pow(1.05, Math.ceil(Math.log(size) / Math.log(1.05)))))
219            std::cmp::max(
220                541,
221                1.05f64.powf((len as f64).log(1.05).ceil()).floor() as usize,
222            )
223        };
224        if padded_len < len {
225            error!(
226                "Padded len {} < len {}. Continuing with a privacy risk.",
227                padded_len, len
228            );
229        } else {
230            contents.resize(padded_len, 0);
231        }
232
233        tracing::trace_span!("encrypting attachment").in_scope(|| {
234            crate::attachment_cipher::encrypt_in_place(iv, key, &mut contents)
235        });
236
237        // Request upload attributes
238        // TODO: we can actually store the upload spec to be able to resume the upload later
239        // if it fails or stalls (= we should at least split the API calls so clients can decide what to do)
240        let attachment_upload_form = self
241            .service
242            .get_attachment_v4_upload_attributes()
243            .instrument(tracing::trace_span!("requesting upload attributes"))
244            .await?;
245
246        let resumable_upload_url = self
247            .service
248            .get_attachment_resumable_upload_url(&attachment_upload_form)
249            .await?;
250
251        let attachment_digest = self
252            .service
253            .upload_attachment_v4(
254                attachment_upload_form.cdn,
255                &resumable_upload_url,
256                contents.len() as u64,
257                attachment_upload_form.headers,
258                &mut std::io::Cursor::new(&contents),
259            )
260            .await?;
261
262        Ok(AttachmentPointer {
263            content_type: Some(spec.content_type),
264            key: Some(key.to_vec()),
265            size: Some(len as u32),
266            // thumbnail: Option<Vec<u8>>,
267            digest: Some(attachment_digest.digest),
268            file_name: spec.file_name,
269            flags: Some(
270                if spec.voice_note == Some(true) {
271                    AttachmentPointerFlags::VoiceMessage as u32
272                } else {
273                    0
274                } | if spec.borderless == Some(true) {
275                    AttachmentPointerFlags::Borderless as u32
276                } else {
277                    0
278                },
279            ),
280            width: spec.width,
281            height: spec.height,
282            caption: spec.caption,
283            blur_hash: spec.blur_hash,
284            upload_timestamp: Some(
285                SystemTime::now()
286                    .duration_since(SystemTime::UNIX_EPOCH)
287                    .expect("unix epoch in the past")
288                    .as_millis() as u64,
289            ),
290            cdn_number: Some(attachment_upload_form.cdn),
291            attachment_identifier: Some(AttachmentIdentifier::CdnKey(
292                attachment_upload_form.key,
293            )),
294            ..Default::default()
295        })
296    }
297
298    /// Upload contact details to the CDN
299    ///
300    /// Returns attachment ID and the attachment digest
301    #[tracing::instrument(skip(self, contacts))]
302    async fn upload_contact_details<Contacts>(
303        &mut self,
304        contacts: Contacts,
305    ) -> Result<AttachmentPointer, AttachmentUploadError>
306    where
307        Contacts: IntoIterator<Item = ContactDetails>,
308    {
309        use prost::Message;
310        let mut out = Vec::new();
311        for contact in contacts {
312            contact
313                .encode_length_delimited(&mut out)
314                .expect("infallible encoding");
315            // XXX add avatar here
316        }
317
318        let spec = AttachmentSpec {
319            content_type: "application/octet-stream".into(),
320            length: out.len(),
321            file_name: None,
322            preview: None,
323            voice_note: None,
324            borderless: None,
325            width: None,
326            height: None,
327            caption: None,
328            blur_hash: None,
329        };
330        self.upload_attachment(spec, out, &mut rng()).await
331    }
332
333    /// Return whether we have to prepare sync messages for other devices
334    ///
335    /// - If we are the main registered device, and there are established sub-device sessions (linked clients), return true
336    /// - If we are a secondary linked device, return true
337    async fn is_multi_device(&self) -> bool {
338        if self.device_id == *DEFAULT_DEVICE_ID {
339            self.protocol_store
340                .get_sub_device_sessions(&self.local_aci.into())
341                .await
342                .is_ok_and(|s| !s.is_empty())
343        } else {
344            true
345        }
346    }
347
348    /// Send a message `content` to a single `recipient`.
349    #[tracing::instrument(
350        skip(self, unidentified_access, message),
351        fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()),
352    )]
353    pub async fn send_message(
354        &mut self,
355        recipient: &ServiceId,
356        mut unidentified_access: Option<UnidentifiedAccess>,
357        message: impl Into<ContentBody>,
358        timestamp: u64,
359        include_pni_signature: bool,
360        online: bool,
361    ) -> SendMessageResult {
362        let content_body = message.into();
363        let message_to_self = recipient == &self.local_aci;
364        let sync_message =
365            matches!(content_body, ContentBody::SynchronizeMessage(..));
366        let is_multi_device = self.is_multi_device().await;
367
368        use crate::proto::data_message::Flags;
369
370        let end_session = match &content_body {
371            ContentBody::DataMessage(message) => {
372                message.flags == Some(Flags::EndSession as u32)
373            },
374            _ => false,
375        };
376
377        // only send a sync message when sending to self and skip the rest of the process
378        if message_to_self && is_multi_device && !sync_message {
379            debug!("sending note to self");
380            let sync_message = self
381                .create_multi_device_sent_transcript_content(
382                    Some(recipient),
383                    content_body,
384                    timestamp,
385                    None,
386                );
387            return self
388                .try_send_message(
389                    *recipient,
390                    None,
391                    &sync_message,
392                    timestamp,
393                    include_pni_signature,
394                    online,
395                )
396                .await;
397        }
398
399        // don't send session enders as sealed sender
400        // sync messages are never sent as unidentified (reasons unclear), see: https://github.com/signalapp/Signal-Android/blob/main/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java#L779
401        if end_session || sync_message {
402            unidentified_access.take();
403        }
404
405        // try to send the original message to all the recipient's devices
406        let result = self
407            .try_send_message(
408                *recipient,
409                unidentified_access.as_ref(),
410                &content_body,
411                timestamp,
412                include_pni_signature,
413                online,
414            )
415            .await;
416
417        let needs_sync = match &result {
418            Ok(SentMessage { needs_sync, .. }) => *needs_sync,
419            _ => false,
420        };
421
422        if needs_sync || is_multi_device {
423            debug!("sending multi-device sync message");
424            let sync_message = self
425                .create_multi_device_sent_transcript_content(
426                    Some(recipient),
427                    content_body,
428                    timestamp,
429                    Some(&result),
430                );
431            self.try_send_message(
432                self.local_aci.into(),
433                None,
434                &sync_message,
435                timestamp,
436                false,
437                false,
438            )
439            .await?;
440        }
441
442        if end_session {
443            let n = self.protocol_store.delete_all_sessions(recipient).await?;
444            tracing::debug!(
445                "ended {} sessions with {}",
446                n,
447                recipient.raw_uuid()
448            );
449        }
450
451        result
452    }
453
454    /// Send a message to the recipients in a group.
455    ///
456    /// Recipients are a list of tuples, each containing:
457    /// - The recipient's address
458    /// - The recipient's unidentified access
459    /// - Whether the recipient requires a PNI signature
460    #[tracing::instrument(
461        skip(self, recipients, message),
462        fields(recipients = recipients.as_ref().len()),
463    )]
464    pub async fn send_message_to_group(
465        &mut self,
466        recipients: impl AsRef<[(ServiceId, Option<UnidentifiedAccess>, bool)]>,
467        message: impl Into<ContentBody>,
468        timestamp: u64,
469        online: bool,
470    ) -> Vec<SendMessageResult> {
471        let content_body: ContentBody = message.into();
472        let mut results = vec![];
473
474        let mut needs_sync_in_results = false;
475
476        for (recipient, unidentified_access, include_pni_signature) in
477            recipients.as_ref()
478        {
479            let result = self
480                .try_send_message(
481                    *recipient,
482                    unidentified_access.as_ref(),
483                    &content_body,
484                    timestamp,
485                    *include_pni_signature,
486                    online,
487                )
488                .await;
489
490            match result {
491                Ok(SentMessage { needs_sync, .. }) if needs_sync => {
492                    needs_sync_in_results = true;
493                },
494                _ => (),
495            };
496
497            results.push(result);
498        }
499
500        // we only need to send a synchronization message once
501        if needs_sync_in_results || self.is_multi_device().await {
502            let sync_message = self
503                .create_multi_device_sent_transcript_content(
504                    None,
505                    content_body,
506                    timestamp,
507                    &results,
508                );
509            // Note: the result of sending a sync message is not included in results
510            // See Signal Android `SignalServiceMessageSender.java:2817`
511            if let Err(error) = self
512                .try_send_message(
513                    self.local_aci.into(),
514                    None,
515                    &sync_message,
516                    timestamp,
517                    false, // XXX: maybe the sync device does want a PNI signature?
518                    false,
519                )
520                .await
521            {
522                error!(%error, "failed to send a synchronization message");
523            }
524        }
525
526        results
527    }
528
529    /// Send a message (`content`) to an address (`recipient`).
530    #[tracing::instrument(
531        level = "trace",
532        skip(self, unidentified_access, content_body, recipient),
533        fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()),
534    )]
535    async fn try_send_message(
536        &mut self,
537        recipient: ServiceId,
538        mut unidentified_access: Option<&UnidentifiedAccess>,
539        content_body: &ContentBody,
540        timestamp: u64,
541        include_pni_signature: bool,
542        online: bool,
543    ) -> SendMessageResult {
544        trace!("trying to send a message");
545
546        use prost::Message;
547
548        let mut content = content_body.clone().into_proto();
549        if include_pni_signature {
550            content.pni_signature_message = Some(self.create_pni_signature()?);
551        }
552
553        let content_bytes = content.encode_to_vec();
554
555        let mut rng = rng();
556
557        for _ in 0..4u8 {
558            let Some(EncryptedMessages {
559                messages,
560                used_identity_key,
561            }) = self
562                .create_encrypted_messages(
563                    &recipient,
564                    unidentified_access.map(|x| &x.certificate),
565                    &content_bytes,
566                )
567                .await?
568            else {
569                // this can happen for example when a device is primary, without any secondaries
570                // and we send a message to ourselves (which is only a SyncMessage { sent: ... })
571                // addressed to self
572                return Err(MessageSenderError::NoMessagesToSend);
573            };
574
575            let messages = OutgoingPushMessages {
576                destination: recipient,
577                timestamp,
578                messages,
579                online,
580            };
581
582            let send = if let Some(unidentified) = &unidentified_access {
583                tracing::debug!("sending via unidentified");
584                self.unidentified_ws
585                    .send_messages_unidentified(messages, unidentified)
586                    .await
587            } else {
588                tracing::debug!("sending identified");
589                self.identified_ws.send_messages(messages).await
590            };
591
592            match send {
593                Ok(SendMessageResponse { needs_sync }) => {
594                    tracing::debug!("message sent!");
595                    return Ok(SentMessage {
596                        recipient,
597                        used_identity_key,
598                        unidentified: unidentified_access.is_some(),
599                        needs_sync,
600                    });
601                },
602                Err(ServiceError::Unauthorized)
603                    if unidentified_access.is_some() =>
604                {
605                    tracing::trace!("unauthorized error using unidentified; retry over identified");
606                    unidentified_access = None;
607                },
608                Err(ServiceError::MismatchedDevicesException(ref m)) => {
609                    tracing::debug!("{:?}", m);
610                    for extra_device_id in &m.extra_devices {
611                        tracing::debug!(
612                            "dropping session with device {}",
613                            extra_device_id
614                        );
615                        self.protocol_store
616                            .delete_service_addr_device_session(
617                                &recipient.to_protocol_address(
618                                    (*extra_device_id).try_into()?,
619                                )?,
620                            )
621                            .await?;
622                    }
623
624                    for missing_device_id in &m.missing_devices {
625                        tracing::debug!(
626                            "creating session with missing device {}",
627                            missing_device_id
628                        );
629                        let remote_address = recipient.to_protocol_address(
630                            (*missing_device_id).try_into()?,
631                        )?;
632                        let pre_key = self
633                            .service
634                            .get_pre_key(&recipient, *missing_device_id)
635                            .await?;
636
637                        process_prekey_bundle(
638                            &remote_address,
639                            &mut self.protocol_store.clone(),
640                            &mut self.protocol_store,
641                            &pre_key,
642                            SystemTime::now(),
643                            &mut rng,
644                            UsePQRatchet::No,
645                        )
646                        .await
647                        .map_err(|e| {
648                            error!("failed to create session: {}", e);
649                            MessageSenderError::UntrustedIdentity {
650                                address: recipient,
651                            }
652                        })?;
653                    }
654                },
655                Err(ServiceError::StaleDevices(ref m)) => {
656                    tracing::debug!("{:?}", m);
657                    for extra_device_id in &m.stale_devices {
658                        tracing::debug!(
659                            "dropping session with device {}",
660                            extra_device_id
661                        );
662                        self.protocol_store
663                            .delete_service_addr_device_session(
664                                &recipient.to_protocol_address(
665                                    (*extra_device_id).try_into()?,
666                                )?,
667                            )
668                            .await?;
669                    }
670                },
671                Err(ServiceError::ProofRequiredError(ref p)) => {
672                    tracing::debug!("{:?}", p);
673                    return Err(MessageSenderError::ProofRequired {
674                        token: p.token.clone(),
675                        options: p.options.clone(),
676                    });
677                },
678                Err(ServiceError::NotFoundError) => {
679                    tracing::debug!("Not found when sending a message");
680                    return Err(MessageSenderError::NotFound {
681                        service_id: recipient,
682                    });
683                },
684                Err(e) => {
685                    tracing::debug!(
686                        "Default error handler for ws.send_messages: {}",
687                        e
688                    );
689                    return Err(MessageSenderError::ServiceError(e));
690                },
691            }
692        }
693
694        Err(MessageSenderError::MaximumRetriesLimitExceeded)
695    }
696
697    /// Upload contact details to the CDN and send a sync message
698    #[tracing::instrument(
699        skip(self, unidentified_access, contacts, recipient),
700        fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()),
701    )]
702    pub async fn send_contact_details<Contacts>(
703        &mut self,
704        recipient: &ServiceId,
705        unidentified_access: Option<UnidentifiedAccess>,
706        // XXX It may be interesting to use an intermediary type,
707        //     instead of ContactDetails directly,
708        //     because it allows us to add the avatar content.
709        contacts: Contacts,
710        online: bool,
711        complete: bool,
712    ) -> Result<(), MessageSenderError>
713    where
714        Contacts: IntoIterator<Item = ContactDetails>,
715    {
716        let ptr = self.upload_contact_details(contacts).await?;
717
718        let msg = SyncMessage {
719            contacts: Some(sync_message::Contacts {
720                blob: Some(ptr),
721                complete: Some(complete),
722            }),
723            ..SyncMessage::with_padding(&mut rng())
724        };
725
726        self.send_message(
727            recipient,
728            unidentified_access,
729            msg,
730            Utc::now().timestamp_millis() as u64,
731            false,
732            online,
733        )
734        .await?;
735
736        Ok(())
737    }
738
739    /// Send `Configuration` synchronization message
740    #[tracing::instrument(skip(self), fields(recipient = recipient.service_id_string()))]
741    pub async fn send_configuration(
742        &mut self,
743        recipient: &ServiceId,
744        configuration: sync_message::Configuration,
745    ) -> Result<(), MessageSenderError> {
746        let msg = SyncMessage {
747            configuration: Some(configuration),
748            ..SyncMessage::with_padding(&mut rng())
749        };
750
751        let ts = Utc::now().timestamp_millis() as u64;
752        self.send_message(recipient, None, msg, ts, false, false)
753            .await?;
754
755        Ok(())
756    }
757
758    /// Send `MessageRequestResponse` synchronization message with either a recipient ACI or a GroupV2 ID
759    #[tracing::instrument(skip(self), fields(recipient = recipient.service_id_string()))]
760    pub async fn send_message_request_response(
761        &mut self,
762        recipient: &ServiceId,
763        thread: &ThreadIdentifier,
764        action: message_request_response::Type,
765    ) -> Result<(), MessageSenderError> {
766        let message_request_response = Some(match thread {
767            ThreadIdentifier::Aci(aci) => {
768                tracing::debug!(
769                    "sending message request response {:?} for recipient {:?}",
770                    action,
771                    aci
772                );
773                MessageRequestResponse {
774                    thread_aci: Some(aci.to_string()),
775                    group_id: None,
776                    r#type: Some(action.into()),
777                }
778            },
779            ThreadIdentifier::Group(id) => {
780                tracing::debug!(
781                    "sending message request response {:?} for group {:?}",
782                    action,
783                    id
784                );
785                MessageRequestResponse {
786                    thread_aci: None,
787                    group_id: Some(id.to_vec()),
788                    r#type: Some(action.into()),
789                }
790            },
791        });
792
793        let msg = SyncMessage {
794            message_request_response,
795            ..SyncMessage::with_padding(&mut rng())
796        };
797
798        let ts = Utc::now().timestamp_millis() as u64;
799        self.send_message(recipient, None, msg, ts, false, false)
800            .await?;
801
802        Ok(())
803    }
804
805    /// Send `Keys` synchronization message
806    #[tracing::instrument(skip(self), fields(recipient = recipient.service_id_string()))]
807    pub async fn send_keys(
808        &mut self,
809        recipient: &ServiceId,
810        keys: sync_message::Keys,
811    ) -> Result<(), MessageSenderError> {
812        let msg = SyncMessage {
813            keys: Some(keys),
814            ..SyncMessage::with_padding(&mut rng())
815        };
816
817        let ts = Utc::now().timestamp_millis() as u64;
818        self.send_message(recipient, None, msg, ts, false, false)
819            .await?;
820
821        Ok(())
822    }
823
824    /// Send `Blocked` synchronization message
825    #[tracing::instrument(skip(self), fields(recipient = recipient.service_id_string()))]
826    pub async fn send_blocked(
827        &mut self,
828        recipient: &ServiceId,
829        blocked: sync_message::Blocked,
830    ) -> Result<(), MessageSenderError> {
831        let msg = SyncMessage {
832            blocked: Some(blocked),
833            ..SyncMessage::with_padding(&mut rng())
834        };
835
836        let ts = Utc::now().timestamp_millis() as u64;
837        self.send_message(recipient, None, msg, ts, false, false)
838            .await?;
839
840        Ok(())
841    }
842
843    /// Send a `Keys` request message
844    #[tracing::instrument(skip(self))]
845    pub async fn send_sync_message_request(
846        &mut self,
847        recipient: &ServiceId,
848        request_type: sync_message::request::Type,
849    ) -> Result<(), MessageSenderError> {
850        if self.device_id == *DEFAULT_DEVICE_ID {
851            return Err(MessageSenderError::SendSyncMessageError(request_type));
852        }
853
854        let msg = SyncMessage {
855            request: Some(sync_message::Request {
856                r#type: Some(request_type.into()),
857            }),
858            ..SyncMessage::with_padding(&mut rng())
859        };
860
861        let ts = Utc::now().timestamp_millis() as u64;
862        self.send_message(recipient, None, msg, ts, false, false)
863            .await?;
864
865        Ok(())
866    }
867
868    #[expect(clippy::result_large_err)]
869    #[tracing::instrument(level = "trace", skip(self))]
870    fn create_pni_signature(
871        &mut self,
872    ) -> Result<crate::proto::PniSignatureMessage, MessageSenderError> {
873        let mut rng = rng();
874        let signature = self
875            .pni_identity
876            .expect("PNI key set when PNI signature requested")
877            .sign_alternate_identity(
878                self.aci_identity.identity_key(),
879                &mut rng,
880            )?;
881        Ok(crate::proto::PniSignatureMessage {
882            pni: Some(self.local_pni.service_id_binary()),
883            signature: Some(signature.into()),
884        })
885    }
886
887    // Equivalent with `getEncryptedMessages`
888    #[tracing::instrument(
889        level = "trace",
890        skip(self, unidentified_access, content),
891        fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()),
892    )]
893    async fn create_encrypted_messages(
894        &mut self,
895        recipient: &ServiceId,
896        unidentified_access: Option<&SenderCertificate>,
897        content: &[u8],
898    ) -> Result<Option<EncryptedMessages>, MessageSenderError> {
899        let mut messages = vec![];
900
901        let mut devices: HashSet<DeviceId> = self
902            .protocol_store
903            .get_sub_device_sessions(recipient)
904            .await?
905            .into_iter()
906            .collect();
907
908        // always send to the primary device no matter what
909        devices.insert(*DEFAULT_DEVICE_ID);
910
911        // never try to send messages to the sender device
912        match recipient {
913            ServiceId::Aci(aci) => {
914                if *aci == self.local_aci {
915                    devices.remove(&self.device_id);
916                }
917            },
918            ServiceId::Pni(pni) => {
919                if *pni == self.local_pni {
920                    devices.remove(&self.device_id);
921                }
922            },
923        };
924
925        for device_id in devices {
926            trace!("sending message to device {}", device_id);
927            // `create_encrypted_message` may fail with `SessionNotFound` if the session is corrupted;
928            // see https://github.com/whisperfish/libsignal-client/commit/601454d20.
929            // If this happens, delete the session and retry.
930            for _attempt in 0..2 {
931                match self
932                    .create_encrypted_message(
933                        recipient,
934                        unidentified_access,
935                        device_id,
936                        content,
937                    )
938                    .await
939                {
940                    Ok(message) => {
941                        messages.push(message);
942                        break;
943                    },
944                    Err(MessageSenderError::ServiceError(
945                        ServiceError::SignalProtocolError(
946                            SignalProtocolError::SessionNotFound(addr),
947                        ),
948                    )) => {
949                        // SessionNotFound is returned on certain session corruption.
950                        // Since delete_session *creates* a session if it doesn't exist,
951                        // the NotFound error is an indicator of session corruption.
952                        // Try to delete this session, if it gets succesfully deleted, retry.  Otherwise, fail.
953                        tracing::warn!("Potential session corruption for {}, deleting session", addr);
954                        match self.protocol_store.delete_session(&addr).await {
955                            Ok(()) => continue,
956                            Err(error) => {
957                                tracing::warn!(%error, %addr, "failed to delete session");
958                                return Err(
959                                    SignalProtocolError::SessionNotFound(addr)
960                                        .into(),
961                                );
962                            },
963                        }
964                    },
965                    Err(e) => return Err(e),
966                }
967            }
968        }
969
970        if messages.is_empty() {
971            Ok(None)
972        } else {
973            Ok(Some(EncryptedMessages {
974                messages,
975                used_identity_key: self
976                    .protocol_store
977                    .get_identity(
978                        &recipient.to_protocol_address(*DEFAULT_DEVICE_ID),
979                    )
980                    .await?
981                    .ok_or(MessageSenderError::UntrustedIdentity {
982                        address: *recipient,
983                    })?,
984            }))
985        }
986    }
987
988    /// Equivalent to `getEncryptedMessage`
989    ///
990    /// When no session with the recipient exists, we need to create one.
991    #[tracing::instrument(
992        level = "trace",
993        skip(self, unidentified_access, content),
994        fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()),
995    )]
996    pub(crate) async fn create_encrypted_message(
997        &mut self,
998        recipient: &ServiceId,
999        unidentified_access: Option<&SenderCertificate>,
1000        device_id: DeviceId,
1001        content: &[u8],
1002    ) -> Result<OutgoingPushMessage, MessageSenderError> {
1003        let recipient_protocol_address =
1004            recipient.to_protocol_address(device_id);
1005
1006        tracing::trace!(
1007            "encrypting message for {}",
1008            recipient_protocol_address
1009        );
1010
1011        // establish a session with the recipient/device if necessary
1012        // no need to establish a session with ourselves (and our own current device)
1013        if self
1014            .protocol_store
1015            .load_session(&recipient_protocol_address)
1016            .await?
1017            .is_none()
1018        {
1019            info!(
1020                "establishing new session with {}",
1021                recipient_protocol_address
1022            );
1023            let pre_keys = match self
1024                .service
1025                .get_pre_keys(recipient, device_id.into())
1026                .await
1027            {
1028                Ok(ok) => {
1029                    tracing::trace!("Get prekeys OK");
1030                    ok
1031                },
1032                Err(ServiceError::NotFoundError) => {
1033                    return Err(MessageSenderError::NotFound {
1034                        service_id: *recipient,
1035                    });
1036                },
1037                Err(e) => Err(e)?,
1038            };
1039
1040            let mut rng = rng();
1041
1042            for pre_key_bundle in pre_keys {
1043                if recipient == &self.local_aci
1044                    && self.device_id == pre_key_bundle.device_id()?
1045                {
1046                    trace!("not establishing a session with myself!");
1047                    continue;
1048                }
1049
1050                let pre_key_address = get_preferred_protocol_address(
1051                    &self.protocol_store,
1052                    recipient,
1053                    pre_key_bundle.device_id()?,
1054                )
1055                .await?;
1056
1057                process_prekey_bundle(
1058                    &pre_key_address,
1059                    &mut self.protocol_store.clone(),
1060                    &mut self.protocol_store,
1061                    &pre_key_bundle,
1062                    SystemTime::now(),
1063                    &mut rng,
1064                    UsePQRatchet::No,
1065                )
1066                .await?;
1067            }
1068        }
1069
1070        let message = self
1071            .cipher
1072            .encrypt(
1073                &recipient_protocol_address,
1074                unidentified_access,
1075                content,
1076                &mut rng(),
1077            )
1078            .instrument(tracing::trace_span!("encrypting message"))
1079            .await?;
1080
1081        Ok(message)
1082    }
1083
1084    fn create_multi_device_sent_transcript_content<'a>(
1085        &mut self,
1086        recipient: Option<&ServiceId>,
1087        content_body: ContentBody,
1088        timestamp: u64,
1089        send_message_results: impl IntoIterator<Item = &'a SendMessageResult>,
1090    ) -> ContentBody {
1091        use sync_message::sent::UnidentifiedDeliveryStatus;
1092        let (data_message, edit_message) = match content_body {
1093            ContentBody::DataMessage(m) => (Some(m), None),
1094            ContentBody::EditMessage(m) => (None, Some(m)),
1095            _ => (None, None),
1096        };
1097        let unidentified_status: Vec<UnidentifiedDeliveryStatus> =
1098            send_message_results
1099                .into_iter()
1100                .filter_map(|result| result.as_ref().ok())
1101                .map(|sent| {
1102                    let SentMessage {
1103                        recipient,
1104                        unidentified,
1105                        used_identity_key,
1106                        ..
1107                    } = sent;
1108                    UnidentifiedDeliveryStatus {
1109                        destination_service_id: Some(
1110                            recipient.service_id_string(),
1111                        ),
1112                        unidentified: Some(*unidentified),
1113                        destination_pni_identity_key: Some(
1114                            used_identity_key.serialize().into(),
1115                        ),
1116                    }
1117                })
1118                .collect();
1119        ContentBody::SynchronizeMessage(SyncMessage {
1120            sent: Some(sync_message::Sent {
1121                destination_service_id: recipient
1122                    .map(ServiceId::service_id_string),
1123                destination_e164: None,
1124                expiration_start_timestamp: data_message
1125                    .as_ref()
1126                    .and_then(|m| m.expire_timer)
1127                    .map(|_| timestamp),
1128                message: data_message,
1129                edit_message,
1130                timestamp: Some(timestamp),
1131                unidentified_status,
1132                ..Default::default()
1133            }),
1134            ..SyncMessage::with_padding(&mut rng())
1135        })
1136    }
1137}