libsignal_service/
sender.rs

1use std::{collections::HashSet, time::SystemTime};
2
3use chrono::prelude::*;
4use libsignal_protocol::{
5    process_prekey_bundle, Aci, DeviceId, IdentityKey, IdentityKeyPair, Pni,
6    ProtocolStore, SenderCertificate, SenderKeyStore, ServiceId,
7    SignalProtocolError,
8};
9use rand::{thread_rng, CryptoRng, Rng};
10use tracing::{debug, error, info, trace, warn};
11use tracing_futures::Instrument;
12use uuid::Uuid;
13use zkgroup::GROUP_IDENTIFIER_LEN;
14
15use crate::{
16    cipher::{get_preferred_protocol_address, ServiceCipher},
17    content::ContentBody,
18    proto::{
19        attachment_pointer::{
20            AttachmentIdentifier, Flags as AttachmentPointerFlags,
21        },
22        sync_message::{
23            self, message_request_response, MessageRequestResponse,
24        },
25        AttachmentPointer, SyncMessage,
26    },
27    push_service::*,
28    service_address::ServiceIdExt,
29    session_store::SessionStoreExt,
30    unidentified_access::UnidentifiedAccess,
31    utils::serde_service_id,
32    websocket::SignalWebSocket,
33};
34
35pub use crate::proto::{ContactDetails, GroupDetails};
36
37#[derive(serde::Serialize, Debug)]
38#[serde(rename_all = "camelCase")]
39pub struct OutgoingPushMessage {
40    pub r#type: u32,
41    pub destination_device_id: u32,
42    pub destination_registration_id: u32,
43    pub content: String,
44}
45
46#[derive(serde::Serialize, Debug)]
47pub struct OutgoingPushMessages {
48    #[serde(with = "serde_service_id")]
49    pub destination: ServiceId,
50    pub timestamp: u64,
51    pub messages: Vec<OutgoingPushMessage>,
52    pub online: bool,
53}
54
55#[derive(serde::Deserialize, Debug)]
56#[serde(rename_all = "camelCase")]
57pub struct SendMessageResponse {
58    pub needs_sync: bool,
59}
60
61pub type SendMessageResult = Result<SentMessage, MessageSenderError>;
62
63#[derive(Debug, Clone)]
64pub struct SentMessage {
65    pub recipient: ServiceId,
66    pub used_identity_key: IdentityKey,
67    pub unidentified: bool,
68    pub needs_sync: bool,
69}
70
/// Metadata describing an attachment that is about to be uploaded.
///
/// Roughly corresponds to Java's `SignalServiceAttachmentStream`. All optional
/// properties default to `None`, the content type to an empty string and the
/// length to zero.
#[derive(Default, Debug)]
pub struct AttachmentSpec {
    /// MIME type of the attachment.
    pub content_type: String,
    /// Length of the plaintext contents in bytes.
    pub length: usize,
    /// Original file name, if any.
    pub file_name: Option<String>,
    /// Optional preview bytes (e.g. a thumbnail).
    pub preview: Option<Vec<u8>>,
    /// `Some(true)` marks the attachment as a voice note.
    pub voice_note: Option<bool>,
    /// `Some(true)` marks the attachment as borderless.
    pub borderless: Option<bool>,
    /// Width, copied into the attachment pointer.
    pub width: Option<u32>,
    /// Height, copied into the attachment pointer.
    pub height: Option<u32>,
    /// Caption text shown with the attachment.
    pub caption: Option<String>,
    /// Blur hash placeholder string.
    pub blur_hash: Option<String>,
}
87
88#[derive(Clone)]
89pub struct MessageSender<S> {
90    identified_ws: SignalWebSocket,
91    unidentified_ws: SignalWebSocket,
92    service: PushService,
93    cipher: ServiceCipher<S>,
94    protocol_store: S,
95    local_aci: Aci,
96    local_pni: Pni,
97    aci_identity: IdentityKeyPair,
98    pni_identity: Option<IdentityKeyPair>,
99    device_id: DeviceId,
100}
101
102#[derive(thiserror::Error, Debug)]
103pub enum AttachmentUploadError {
104    #[error("{0}")]
105    ServiceError(#[from] ServiceError),
106
107    #[error("Could not read attachment contents")]
108    IoError(#[from] std::io::Error),
109}
110
111#[derive(thiserror::Error, Debug)]
112pub enum MessageSenderError {
113    #[error("service error: {0}")]
114    ServiceError(#[from] ServiceError),
115
116    #[error("protocol error: {0}")]
117    ProtocolError(#[from] SignalProtocolError),
118
119    #[error("Failed to upload attachment {0}")]
120    AttachmentUploadError(#[from] AttachmentUploadError),
121
122    #[error("primary device can't send sync message {0:?}")]
123    SendSyncMessageError(sync_message::request::Type),
124
125    #[error("Untrusted identity key with {address:?}")]
126    UntrustedIdentity { address: ServiceId },
127
128    #[error("Exceeded maximum number of retries")]
129    MaximumRetriesLimitExceeded,
130
131    #[error("Proof of type {options:?} required using token {token}")]
132    ProofRequired { token: String, options: Vec<String> },
133
134    #[error("Recipient not found: {service_id:?}")]
135    NotFound { service_id: ServiceId },
136
137    #[error("no messages were encrypted: this should not really happen and most likely implies a logic error")]
138    NoMessagesToSend,
139}
140
141pub type GroupV2Id = [u8; GROUP_IDENTIFIER_LEN];
142
143#[derive(Debug)]
144pub enum ThreadIdentifier {
145    Aci(Uuid),
146    Group(GroupV2Id),
147}
148
149#[derive(Debug)]
150pub struct EncryptedMessages {
151    messages: Vec<OutgoingPushMessage>,
152    used_identity_key: IdentityKey,
153}
154
155impl<S> MessageSender<S>
156where
157    S: ProtocolStore + SenderKeyStore + SessionStoreExt + Sync + Clone,
158{
159    #[allow(clippy::too_many_arguments)]
160    pub fn new(
161        identified_ws: SignalWebSocket,
162        unidentified_ws: SignalWebSocket,
163        service: PushService,
164        cipher: ServiceCipher<S>,
165        protocol_store: S,
166        local_aci: impl Into<Aci>,
167        local_pni: impl Into<Pni>,
168        aci_identity: IdentityKeyPair,
169        pni_identity: Option<IdentityKeyPair>,
170        device_id: DeviceId,
171    ) -> Self {
172        MessageSender {
173            service,
174            identified_ws,
175            unidentified_ws,
176            cipher,
177            protocol_store,
178            local_aci: local_aci.into(),
179            local_pni: local_pni.into(),
180            aci_identity,
181            pni_identity,
182            device_id,
183        }
184    }
185
186    /// Encrypts and uploads an attachment
187    ///
188    /// Contents are accepted as an owned, plain text Vec, because encryption happens in-place.
189    #[tracing::instrument(skip(self, contents, csprng), fields(size = contents.len()))]
190    pub async fn upload_attachment<R: Rng + CryptoRng>(
191        &mut self,
192        spec: AttachmentSpec,
193        mut contents: Vec<u8>,
194        csprng: &mut R,
195    ) -> Result<AttachmentPointer, AttachmentUploadError> {
196        let len = contents.len();
197        // Encrypt
198        let (key, iv) = {
199            let mut key = [0u8; 64];
200            let mut iv = [0u8; 16];
201            csprng.fill_bytes(&mut key);
202            csprng.fill_bytes(&mut iv);
203            (key, iv)
204        };
205
206        // Padded length uses an exponential bracketting thingy.
207        // If you want to see how it looks:
208        // https://www.wolframalpha.com/input/?i=plot+floor%281.05%5Eceil%28log_1.05%28x%29%29%29+for+x+from+0+to+5000000
209        let padded_len: usize = {
210            // Java:
211            // return (int) Math.max(541, Math.floor(Math.pow(1.05, Math.ceil(Math.log(size) / Math.log(1.05)))))
212            std::cmp::max(
213                541,
214                1.05f64.powf((len as f64).log(1.05).ceil()).floor() as usize,
215            )
216        };
217        if padded_len < len {
218            error!(
219                "Padded len {} < len {}. Continuing with a privacy risk.",
220                padded_len, len
221            );
222        } else {
223            contents.resize(padded_len, 0);
224        }
225
226        tracing::trace_span!("encrypting attachment").in_scope(|| {
227            crate::attachment_cipher::encrypt_in_place(iv, key, &mut contents)
228        });
229
230        // Request upload attributes
231        // TODO: we can actually store the upload spec to be able to resume the upload later
232        // if it fails or stalls (= we should at least split the API calls so clients can decide what to do)
233        let attachment_upload_form = self
234            .service
235            .get_attachment_v4_upload_attributes()
236            .instrument(tracing::trace_span!("requesting upload attributes"))
237            .await?;
238
239        let resumable_upload_url = self
240            .service
241            .get_attachment_resumable_upload_url(&attachment_upload_form)
242            .await?;
243
244        let attachment_digest = self
245            .service
246            .upload_attachment_v4(
247                attachment_upload_form.cdn,
248                &resumable_upload_url,
249                contents.len() as u64,
250                attachment_upload_form.headers,
251                &mut std::io::Cursor::new(&contents),
252            )
253            .await?;
254
255        Ok(AttachmentPointer {
256            content_type: Some(spec.content_type),
257            key: Some(key.to_vec()),
258            size: Some(len as u32),
259            // thumbnail: Option<Vec<u8>>,
260            digest: Some(attachment_digest.digest),
261            file_name: spec.file_name,
262            flags: Some(
263                if spec.voice_note == Some(true) {
264                    AttachmentPointerFlags::VoiceMessage as u32
265                } else {
266                    0
267                } | if spec.borderless == Some(true) {
268                    AttachmentPointerFlags::Borderless as u32
269                } else {
270                    0
271                },
272            ),
273            width: spec.width,
274            height: spec.height,
275            caption: spec.caption,
276            blur_hash: spec.blur_hash,
277            upload_timestamp: Some(
278                SystemTime::now()
279                    .duration_since(SystemTime::UNIX_EPOCH)
280                    .expect("unix epoch in the past")
281                    .as_millis() as u64,
282            ),
283            cdn_number: Some(attachment_upload_form.cdn),
284            attachment_identifier: Some(AttachmentIdentifier::CdnKey(
285                attachment_upload_form.key,
286            )),
287            ..Default::default()
288        })
289    }
290
291    /// Upload contact details to the CDN
292    ///
293    /// Returns attachment ID and the attachment digest
294    #[tracing::instrument(skip(self, contacts))]
295    async fn upload_contact_details<Contacts>(
296        &mut self,
297        contacts: Contacts,
298    ) -> Result<AttachmentPointer, AttachmentUploadError>
299    where
300        Contacts: IntoIterator<Item = ContactDetails>,
301    {
302        use prost::Message;
303        let mut out = Vec::new();
304        for contact in contacts {
305            contact
306                .encode_length_delimited(&mut out)
307                .expect("infallible encoding");
308            // XXX add avatar here
309        }
310
311        let spec = AttachmentSpec {
312            content_type: "application/octet-stream".into(),
313            length: out.len(),
314            file_name: None,
315            preview: None,
316            voice_note: None,
317            borderless: None,
318            width: None,
319            height: None,
320            caption: None,
321            blur_hash: None,
322        };
323        self.upload_attachment(spec, out, &mut thread_rng()).await
324    }
325
326    /// Return whether we have to prepare sync messages for other devices
327    ///
328    /// - If we are the main registered device, and there are established sub-device sessions (linked clients), return true
329    /// - If we are a secondary linked device, return true
330    async fn is_multi_device(&self) -> bool {
331        if self.device_id == DEFAULT_DEVICE_ID.into() {
332            self.protocol_store
333                .get_sub_device_sessions(&self.local_aci.into())
334                .await
335                .is_ok_and(|s| !s.is_empty())
336        } else {
337            true
338        }
339    }
340
341    /// Send a message `content` to a single `recipient`.
342    #[tracing::instrument(
343        skip(self, unidentified_access, message),
344        fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()),
345    )]
346    pub async fn send_message(
347        &mut self,
348        recipient: &ServiceId,
349        mut unidentified_access: Option<UnidentifiedAccess>,
350        message: impl Into<ContentBody>,
351        timestamp: u64,
352        include_pni_signature: bool,
353        online: bool,
354    ) -> SendMessageResult {
355        let content_body = message.into();
356        let message_to_self = recipient == &self.local_aci;
357        let sync_message =
358            matches!(content_body, ContentBody::SynchronizeMessage(..));
359        let is_multi_device = self.is_multi_device().await;
360
361        use crate::proto::data_message::Flags;
362
363        let end_session = match &content_body {
364            ContentBody::DataMessage(message) => {
365                message.flags == Some(Flags::EndSession as u32)
366            },
367            _ => false,
368        };
369
370        // only send a sync message when sending to self and skip the rest of the process
371        if message_to_self && is_multi_device && !sync_message {
372            debug!("sending note to self");
373            let sync_message = self
374                .create_multi_device_sent_transcript_content(
375                    Some(recipient),
376                    content_body,
377                    timestamp,
378                    None,
379                );
380            return self
381                .try_send_message(
382                    *recipient,
383                    None,
384                    &sync_message,
385                    timestamp,
386                    include_pni_signature,
387                    online,
388                )
389                .await;
390        }
391
392        // don't send session enders as sealed sender
393        // sync messages are never sent as unidentified (reasons unclear), see: https://github.com/signalapp/Signal-Android/blob/main/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java#L779
394        if end_session || sync_message {
395            unidentified_access.take();
396        }
397
398        // try to send the original message to all the recipient's devices
399        let result = self
400            .try_send_message(
401                *recipient,
402                unidentified_access.as_ref(),
403                &content_body,
404                timestamp,
405                include_pni_signature,
406                online,
407            )
408            .await;
409
410        let needs_sync = match &result {
411            Ok(SentMessage { needs_sync, .. }) => *needs_sync,
412            _ => false,
413        };
414
415        if needs_sync || is_multi_device {
416            debug!("sending multi-device sync message");
417            let sync_message = self
418                .create_multi_device_sent_transcript_content(
419                    Some(recipient),
420                    content_body,
421                    timestamp,
422                    Some(&result),
423                );
424            self.try_send_message(
425                self.local_aci.into(),
426                None,
427                &sync_message,
428                timestamp,
429                false,
430                false,
431            )
432            .await?;
433        }
434
435        if end_session {
436            let n = self.protocol_store.delete_all_sessions(recipient).await?;
437            tracing::debug!(
438                "ended {} sessions with {}",
439                n,
440                recipient.raw_uuid()
441            );
442        }
443
444        result
445    }
446
447    /// Send a message to the recipients in a group.
448    ///
449    /// Recipients are a list of tuples, each containing:
450    /// - The recipient's address
451    /// - The recipient's unidentified access
452    /// - Whether the recipient requires a PNI signature
453    #[tracing::instrument(
454        skip(self, recipients, message),
455        fields(recipients = recipients.as_ref().len()),
456    )]
457    pub async fn send_message_to_group(
458        &mut self,
459        recipients: impl AsRef<[(ServiceId, Option<UnidentifiedAccess>, bool)]>,
460        message: impl Into<ContentBody>,
461        timestamp: u64,
462        online: bool,
463    ) -> Vec<SendMessageResult> {
464        let content_body: ContentBody = message.into();
465        let mut results = vec![];
466
467        let mut needs_sync_in_results = false;
468
469        for (recipient, unidentified_access, include_pni_signature) in
470            recipients.as_ref()
471        {
472            let result = self
473                .try_send_message(
474                    *recipient,
475                    unidentified_access.as_ref(),
476                    &content_body,
477                    timestamp,
478                    *include_pni_signature,
479                    online,
480                )
481                .await;
482
483            match result {
484                Ok(SentMessage { needs_sync, .. }) if needs_sync => {
485                    needs_sync_in_results = true;
486                },
487                _ => (),
488            };
489
490            results.push(result);
491        }
492
493        // we only need to send a synchronization message once
494        if needs_sync_in_results || self.is_multi_device().await {
495            let sync_message = self
496                .create_multi_device_sent_transcript_content(
497                    None,
498                    content_body,
499                    timestamp,
500                    &results,
501                );
502            // Note: the result of sending a sync message is not included in results
503            // See Signal Android `SignalServiceMessageSender.java:2817`
504            if let Err(error) = self
505                .try_send_message(
506                    self.local_aci.into(),
507                    None,
508                    &sync_message,
509                    timestamp,
510                    false, // XXX: maybe the sync device does want a PNI signature?
511                    false,
512                )
513                .await
514            {
515                error!(%error, "failed to send a synchronization message");
516            }
517        }
518
519        results
520    }
521
522    /// Send a message (`content`) to an address (`recipient`).
523    #[tracing::instrument(
524        level = "trace",
525        skip(self, unidentified_access, content_body, recipient),
526        fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()),
527    )]
528    async fn try_send_message(
529        &mut self,
530        recipient: ServiceId,
531        mut unidentified_access: Option<&UnidentifiedAccess>,
532        content_body: &ContentBody,
533        timestamp: u64,
534        include_pni_signature: bool,
535        online: bool,
536    ) -> SendMessageResult {
537        trace!("trying to send a message");
538
539        use prost::Message;
540
541        let mut content = content_body.clone().into_proto();
542        if include_pni_signature {
543            content.pni_signature_message = Some(self.create_pni_signature()?);
544        }
545
546        let content_bytes = content.encode_to_vec();
547
548        let mut rng = thread_rng();
549
550        for _ in 0..4u8 {
551            let Some(EncryptedMessages {
552                messages,
553                used_identity_key,
554            }) = self
555                .create_encrypted_messages(
556                    &recipient,
557                    unidentified_access.map(|x| &x.certificate),
558                    &content_bytes,
559                )
560                .await?
561            else {
562                // this can happen for example when a device is primary, without any secondaries
563                // and we send a message to ourselves (which is only a SyncMessage { sent: ... })
564                // addressed to self
565                return Err(MessageSenderError::NoMessagesToSend);
566            };
567
568            let messages = OutgoingPushMessages {
569                destination: recipient,
570                timestamp,
571                messages,
572                online,
573            };
574
575            let send = if let Some(unidentified) = &unidentified_access {
576                tracing::debug!("sending via unidentified");
577                self.unidentified_ws
578                    .send_messages_unidentified(messages, unidentified)
579                    .await
580            } else {
581                tracing::debug!("sending identified");
582                self.identified_ws.send_messages(messages).await
583            };
584
585            match send {
586                Ok(SendMessageResponse { needs_sync }) => {
587                    tracing::debug!("message sent!");
588                    return Ok(SentMessage {
589                        recipient,
590                        used_identity_key,
591                        unidentified: unidentified_access.is_some(),
592                        needs_sync,
593                    });
594                },
595                Err(ServiceError::Unauthorized)
596                    if unidentified_access.is_some() =>
597                {
598                    tracing::trace!("unauthorized error using unidentified; retry over identified");
599                    unidentified_access = None;
600                },
601                Err(ServiceError::MismatchedDevicesException(ref m)) => {
602                    tracing::debug!("{:?}", m);
603                    for extra_device_id in &m.extra_devices {
604                        tracing::debug!(
605                            "dropping session with device {}",
606                            extra_device_id
607                        );
608                        self.protocol_store
609                            .delete_service_addr_device_session(
610                                &recipient
611                                    .to_protocol_address(*extra_device_id),
612                            )
613                            .await?;
614                    }
615
616                    for missing_device_id in &m.missing_devices {
617                        tracing::debug!(
618                            "creating session with missing device {}",
619                            missing_device_id
620                        );
621                        let remote_address =
622                            recipient.to_protocol_address(*missing_device_id);
623                        let pre_key = self
624                            .service
625                            .get_pre_key(&recipient, *missing_device_id)
626                            .await?;
627
628                        process_prekey_bundle(
629                            &remote_address,
630                            &mut self.protocol_store.clone(),
631                            &mut self.protocol_store,
632                            &pre_key,
633                            SystemTime::now(),
634                            &mut rng,
635                        )
636                        .await
637                        .map_err(|e| {
638                            error!("failed to create session: {}", e);
639                            MessageSenderError::UntrustedIdentity {
640                                address: recipient,
641                            }
642                        })?;
643                    }
644                },
645                Err(ServiceError::StaleDevices(ref m)) => {
646                    tracing::debug!("{:?}", m);
647                    for extra_device_id in &m.stale_devices {
648                        tracing::debug!(
649                            "dropping session with device {}",
650                            extra_device_id
651                        );
652                        self.protocol_store
653                            .delete_service_addr_device_session(
654                                &recipient
655                                    .to_protocol_address(*extra_device_id),
656                            )
657                            .await?;
658                    }
659                },
660                Err(ServiceError::ProofRequiredError(ref p)) => {
661                    tracing::debug!("{:?}", p);
662                    return Err(MessageSenderError::ProofRequired {
663                        token: p.token.clone(),
664                        options: p.options.clone(),
665                    });
666                },
667                Err(ServiceError::NotFoundError) => {
668                    tracing::debug!("Not found when sending a message");
669                    return Err(MessageSenderError::NotFound {
670                        service_id: recipient,
671                    });
672                },
673                Err(e) => {
674                    tracing::debug!(
675                        "Default error handler for ws.send_messages: {}",
676                        e
677                    );
678                    return Err(MessageSenderError::ServiceError(e));
679                },
680            }
681        }
682
683        Err(MessageSenderError::MaximumRetriesLimitExceeded)
684    }
685
686    /// Upload contact details to the CDN and send a sync message
687    #[tracing::instrument(
688        skip(self, unidentified_access, contacts, recipient),
689        fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()),
690    )]
691    pub async fn send_contact_details<Contacts>(
692        &mut self,
693        recipient: &ServiceId,
694        unidentified_access: Option<UnidentifiedAccess>,
695        // XXX It may be interesting to use an intermediary type,
696        //     instead of ContactDetails directly,
697        //     because it allows us to add the avatar content.
698        contacts: Contacts,
699        online: bool,
700        complete: bool,
701    ) -> Result<(), MessageSenderError>
702    where
703        Contacts: IntoIterator<Item = ContactDetails>,
704    {
705        let ptr = self.upload_contact_details(contacts).await?;
706
707        let msg = SyncMessage {
708            contacts: Some(sync_message::Contacts {
709                blob: Some(ptr),
710                complete: Some(complete),
711            }),
712            ..SyncMessage::with_padding(&mut thread_rng())
713        };
714
715        self.send_message(
716            recipient,
717            unidentified_access,
718            msg,
719            Utc::now().timestamp_millis() as u64,
720            false,
721            online,
722        )
723        .await?;
724
725        Ok(())
726    }
727
728    /// Send `Configuration` synchronization message
729    #[tracing::instrument(skip(self), fields(recipient = recipient.service_id_string()))]
730    pub async fn send_configuration(
731        &mut self,
732        recipient: &ServiceId,
733        configuration: sync_message::Configuration,
734    ) -> Result<(), MessageSenderError> {
735        let msg = SyncMessage {
736            configuration: Some(configuration),
737            ..SyncMessage::with_padding(&mut thread_rng())
738        };
739
740        let ts = Utc::now().timestamp_millis() as u64;
741        self.send_message(recipient, None, msg, ts, false, false)
742            .await?;
743
744        Ok(())
745    }
746
747    /// Send `MessageRequestResponse` synchronization message with either a recipient ACI or a GroupV2 ID
748    #[tracing::instrument(skip(self), fields(recipient = recipient.service_id_string()))]
749    pub async fn send_message_request_response(
750        &mut self,
751        recipient: &ServiceId,
752        thread: &ThreadIdentifier,
753        action: message_request_response::Type,
754    ) -> Result<(), MessageSenderError> {
755        let message_request_response = Some(match thread {
756            ThreadIdentifier::Aci(aci) => {
757                tracing::debug!(
758                    "sending message request response {:?} for recipient {:?}",
759                    action,
760                    aci
761                );
762                MessageRequestResponse {
763                    thread_aci: Some(aci.to_string()),
764                    group_id: None,
765                    r#type: Some(action.into()),
766                }
767            },
768            ThreadIdentifier::Group(id) => {
769                tracing::debug!(
770                    "sending message request response {:?} for group {:?}",
771                    action,
772                    id
773                );
774                MessageRequestResponse {
775                    thread_aci: None,
776                    group_id: Some(id.to_vec()),
777                    r#type: Some(action.into()),
778                }
779            },
780        });
781
782        let msg = SyncMessage {
783            message_request_response,
784            ..SyncMessage::with_padding(&mut thread_rng())
785        };
786
787        let ts = Utc::now().timestamp_millis() as u64;
788        self.send_message(recipient, None, msg, ts, false, false)
789            .await?;
790
791        Ok(())
792    }
793
794    /// Send `Keys` synchronization message
795    #[tracing::instrument(skip(self), fields(recipient = recipient.service_id_string()))]
796    pub async fn send_keys(
797        &mut self,
798        recipient: &ServiceId,
799        keys: sync_message::Keys,
800    ) -> Result<(), MessageSenderError> {
801        let msg = SyncMessage {
802            keys: Some(keys),
803            ..SyncMessage::with_padding(&mut thread_rng())
804        };
805
806        let ts = Utc::now().timestamp_millis() as u64;
807        self.send_message(recipient, None, msg, ts, false, false)
808            .await?;
809
810        Ok(())
811    }
812
813    /// Send a `Keys` request message
814    #[tracing::instrument(skip(self))]
815    pub async fn send_sync_message_request(
816        &mut self,
817        recipient: &ServiceId,
818        request_type: sync_message::request::Type,
819    ) -> Result<(), MessageSenderError> {
820        if self.device_id == DEFAULT_DEVICE_ID.into() {
821            return Err(MessageSenderError::SendSyncMessageError(request_type));
822        }
823
824        let msg = SyncMessage {
825            request: Some(sync_message::Request {
826                r#type: Some(request_type.into()),
827            }),
828            ..SyncMessage::with_padding(&mut thread_rng())
829        };
830
831        let ts = Utc::now().timestamp_millis() as u64;
832        self.send_message(recipient, None, msg, ts, false, false)
833            .await?;
834
835        Ok(())
836    }
837
838    #[tracing::instrument(level = "trace", skip(self))]
839    fn create_pni_signature(
840        &mut self,
841    ) -> Result<crate::proto::PniSignatureMessage, MessageSenderError> {
842        let mut rng = thread_rng();
843        let signature = self
844            .pni_identity
845            .expect("PNI key set when PNI signature requested")
846            .sign_alternate_identity(
847                self.aci_identity.identity_key(),
848                &mut rng,
849            )?;
850        Ok(crate::proto::PniSignatureMessage {
851            pni: Some(self.local_pni.service_id_binary()),
852            signature: Some(signature.into()),
853        })
854    }
855
856    // Equivalent with `getEncryptedMessages`
857    #[tracing::instrument(
858        level = "trace",
859        skip(self, unidentified_access, content),
860        fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()),
861    )]
862    async fn create_encrypted_messages(
863        &mut self,
864        recipient: &ServiceId,
865        unidentified_access: Option<&SenderCertificate>,
866        content: &[u8],
867    ) -> Result<Option<EncryptedMessages>, MessageSenderError> {
868        let mut messages = vec![];
869
870        let mut devices: HashSet<DeviceId> = self
871            .protocol_store
872            .get_sub_device_sessions(recipient)
873            .await?
874            .into_iter()
875            .map(DeviceId::from)
876            .collect();
877
878        // always send to the primary device no matter what
879        devices.insert(DEFAULT_DEVICE_ID.into());
880
881        // never try to send messages to the sender device
882        match recipient {
883            ServiceId::Aci(aci) => {
884                if *aci == self.local_aci {
885                    devices.remove(&self.device_id);
886                }
887            },
888            ServiceId::Pni(pni) => {
889                if *pni == self.local_pni {
890                    devices.remove(&self.device_id);
891                }
892            },
893        };
894
895        for device_id in devices {
896            trace!("sending message to device {}", device_id);
897            // `create_encrypted_message` may fail with `SessionNotFound` if the session is corrupted;
898            // see https://github.com/whisperfish/libsignal-client/commit/601454d20.
899            // If this happens, delete the session and retry.
900            for _attempt in 0..2 {
901                match self
902                    .create_encrypted_message(
903                        recipient,
904                        unidentified_access,
905                        device_id,
906                        content,
907                    )
908                    .await
909                {
910                    Ok(message) => {
911                        messages.push(message);
912                        break;
913                    },
914                    Err(MessageSenderError::ServiceError(
915                        ServiceError::SignalProtocolError(
916                            SignalProtocolError::SessionNotFound(addr),
917                        ),
918                    )) => {
919                        // SessionNotFound is returned on certain session corruption.
920                        // Since delete_session *creates* a session if it doesn't exist,
921                        // the NotFound error is an indicator of session corruption.
922                        // Try to delete this session, if it gets succesfully deleted, retry.  Otherwise, fail.
923                        tracing::warn!("Potential session corruption for {}, deleting session", addr);
924                        match self.protocol_store.delete_session(&addr).await {
925                            Ok(()) => continue,
926                            Err(error) => {
927                                tracing::warn!(%error, %addr, "failed to delete session");
928                                return Err(
929                                    SignalProtocolError::SessionNotFound(addr)
930                                        .into(),
931                                );
932                            },
933                        }
934                    },
935                    Err(e) => return Err(e),
936                }
937            }
938        }
939
940        if messages.is_empty() {
941            Ok(None)
942        } else {
943            Ok(Some(EncryptedMessages {
944                messages,
945                used_identity_key: self
946                    .protocol_store
947                    .get_identity(
948                        &recipient.to_protocol_address(DEFAULT_DEVICE_ID),
949                    )
950                    .await?
951                    .ok_or(MessageSenderError::UntrustedIdentity {
952                        address: *recipient,
953                    })?,
954            }))
955        }
956    }
957
958    /// Equivalent to `getEncryptedMessage`
959    ///
960    /// When no session with the recipient exists, we need to create one.
961    #[tracing::instrument(
962        level = "trace",
963        skip(self, unidentified_access, content),
964        fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()),
965    )]
966    pub(crate) async fn create_encrypted_message(
967        &mut self,
968        recipient: &ServiceId,
969        unidentified_access: Option<&SenderCertificate>,
970        device_id: DeviceId,
971        content: &[u8],
972    ) -> Result<OutgoingPushMessage, MessageSenderError> {
973        let recipient_protocol_address =
974            recipient.to_protocol_address(device_id);
975
976        tracing::trace!(
977            "encrypting message for {}",
978            recipient_protocol_address
979        );
980
981        // establish a session with the recipient/device if necessary
982        // no need to establish a session with ourselves (and our own current device)
983        if self
984            .protocol_store
985            .load_session(&recipient_protocol_address)
986            .await?
987            .is_none()
988        {
989            info!(
990                "establishing new session with {}",
991                recipient_protocol_address
992            );
993            let pre_keys = match self
994                .service
995                .get_pre_keys(recipient, device_id.into())
996                .await
997            {
998                Ok(ok) => {
999                    tracing::trace!("Get prekeys OK");
1000                    ok
1001                },
1002                Err(ServiceError::NotFoundError) => {
1003                    return Err(MessageSenderError::NotFound {
1004                        service_id: *recipient,
1005                    });
1006                },
1007                Err(e) => Err(e)?,
1008            };
1009
1010            let mut rng = thread_rng();
1011
1012            for pre_key_bundle in pre_keys {
1013                if recipient == &self.local_aci
1014                    && self.device_id == pre_key_bundle.device_id()?
1015                {
1016                    trace!("not establishing a session with myself!");
1017                    continue;
1018                }
1019
1020                let pre_key_address = get_preferred_protocol_address(
1021                    &self.protocol_store,
1022                    recipient,
1023                    pre_key_bundle.device_id()?,
1024                )
1025                .await?;
1026
1027                process_prekey_bundle(
1028                    &pre_key_address,
1029                    &mut self.protocol_store.clone(),
1030                    &mut self.protocol_store,
1031                    &pre_key_bundle,
1032                    SystemTime::now(),
1033                    &mut rng,
1034                )
1035                .await?;
1036            }
1037        }
1038
1039        let message = self
1040            .cipher
1041            .encrypt(
1042                &recipient_protocol_address,
1043                unidentified_access,
1044                content,
1045                &mut thread_rng(),
1046            )
1047            .instrument(tracing::trace_span!("encrypting message"))
1048            .await?;
1049
1050        Ok(message)
1051    }
1052
1053    fn create_multi_device_sent_transcript_content<'a>(
1054        &mut self,
1055        recipient: Option<&ServiceId>,
1056        content_body: ContentBody,
1057        timestamp: u64,
1058        send_message_results: impl IntoIterator<Item = &'a SendMessageResult>,
1059    ) -> ContentBody {
1060        use sync_message::sent::UnidentifiedDeliveryStatus;
1061        let (data_message, edit_message) = match content_body {
1062            ContentBody::DataMessage(m) => (Some(m), None),
1063            ContentBody::EditMessage(m) => (None, Some(m)),
1064            _ => (None, None),
1065        };
1066        let unidentified_status: Vec<UnidentifiedDeliveryStatus> =
1067            send_message_results
1068                .into_iter()
1069                .filter_map(|result| result.as_ref().ok())
1070                .map(|sent| {
1071                    let SentMessage {
1072                        recipient,
1073                        unidentified,
1074                        used_identity_key,
1075                        ..
1076                    } = sent;
1077                    UnidentifiedDeliveryStatus {
1078                        destination_service_id: Some(
1079                            recipient.service_id_string(),
1080                        ),
1081                        unidentified: Some(*unidentified),
1082                        destination_identity_key: Some(
1083                            used_identity_key.serialize().into(),
1084                        ),
1085                    }
1086                })
1087                .collect();
1088        ContentBody::SynchronizeMessage(SyncMessage {
1089            sent: Some(sync_message::Sent {
1090                destination_service_id: recipient
1091                    .map(ServiceId::service_id_string),
1092                destination_e164: None,
1093                expiration_start_timestamp: data_message
1094                    .as_ref()
1095                    .and_then(|m| m.expire_timer)
1096                    .map(|_| timestamp),
1097                message: data_message,
1098                edit_message,
1099                timestamp: Some(timestamp),
1100                unidentified_status,
1101                ..Default::default()
1102            }),
1103            ..SyncMessage::with_padding(&mut thread_rng())
1104        })
1105    }
1106}