Skip to main content

libsignal_service/
sender.rs

1use std::{collections::HashSet, time::SystemTime};
2
3use chrono::prelude::*;
4use libsignal_core::{curve::CurveError, InvalidDeviceId};
5use libsignal_protocol::{
6    process_prekey_bundle, Aci, DeviceId, IdentityKey, IdentityKeyPair, Pni,
7    ProtocolStore, SenderCertificate, SenderKeyStore, ServiceId,
8    SignalProtocolError,
9};
10use rand::{rng, CryptoRng, Rng};
11use tracing::{debug, error, info, trace, warn};
12use tracing_futures::Instrument;
13use uuid::Uuid;
14use zkgroup::GROUP_IDENTIFIER_LEN;
15
16use crate::{
17    cipher::{get_preferred_protocol_address, ServiceCipher},
18    content::ContentBody,
19    proto::{
20        attachment_pointer::{
21            AttachmentIdentifier, Flags as AttachmentPointerFlags,
22        },
23        sync_message::{
24            self, message_request_response, MessageRequestResponse,
25        },
26        AttachmentPointer, SyncMessage,
27    },
28    push_service::*,
29    service_address::ServiceIdExt,
30    session_store::SessionStoreExt,
31    unidentified_access::UnidentifiedAccess,
32    utils::{serde_device_id, serde_service_id},
33    websocket::{self, SignalWebSocket},
34};
35
36pub use crate::proto::ContactDetails;
37
38#[derive(serde::Serialize, Debug)]
39#[serde(rename_all = "camelCase")]
40pub struct OutgoingPushMessage {
41    pub r#type: u32,
42    #[serde(with = "serde_device_id")]
43    pub destination_device_id: DeviceId,
44    pub destination_registration_id: u32,
45    pub content: String,
46}
47
48#[derive(serde::Serialize, Debug)]
49pub struct OutgoingPushMessages {
50    #[serde(with = "serde_service_id")]
51    pub destination: ServiceId,
52    pub timestamp: u64,
53    pub messages: Vec<OutgoingPushMessage>,
54    pub online: bool,
55}
56
57#[derive(serde::Deserialize, Debug)]
58#[serde(rename_all = "camelCase")]
59pub struct SendMessageResponse {
60    pub needs_sync: bool,
61}
62
63pub type SendMessageResult = Result<SentMessage, MessageSenderError>;
64
65#[derive(Debug, Clone)]
66pub struct SentMessage {
67    pub recipient: ServiceId,
68    pub used_identity_key: IdentityKey,
69    pub unidentified: bool,
70    pub needs_sync: bool,
71}
72
/// Attachment specification to be used for uploading.
///
/// Loose equivalent of Java's `SignalServiceAttachmentStream`. The derived
/// `Default` gives an empty content type, zero length and every optional field
/// unset.
#[derive(Default, Debug)]
pub struct AttachmentSpec {
    /// Content (MIME) type, copied into the resulting pointer.
    pub content_type: String,
    /// Plaintext length. NOTE(review): `MessageSender::upload_attachment`
    /// measures the contents itself and does not read this field.
    pub length: usize,
    /// Optional file name, copied into the resulting pointer.
    pub file_name: Option<String>,
    /// Preview bytes. NOTE(review): not read by `MessageSender::upload_attachment`.
    pub preview: Option<Vec<u8>>,
    /// `Some(true)` sets the voice-message flag on the pointer.
    pub voice_note: Option<bool>,
    /// `Some(true)` sets the borderless flag on the pointer.
    pub borderless: Option<bool>,
    /// Optional width, copied into the pointer.
    pub width: Option<u32>,
    /// Optional height, copied into the pointer.
    pub height: Option<u32>,
    /// Optional caption, copied into the pointer.
    pub caption: Option<String>,
    /// Optional blur hash, copied into the pointer.
    pub blur_hash: Option<String>,
}
89
90#[derive(Clone)]
91pub struct MessageSender<S> {
92    identified_ws: SignalWebSocket<websocket::Identified>,
93    unidentified_ws: SignalWebSocket<websocket::Unidentified>,
94    service: PushService,
95    cipher: ServiceCipher<S>,
96    protocol_store: S,
97    local_aci: Aci,
98    local_pni: Pni,
99    aci_identity: IdentityKeyPair,
100    pni_identity: Option<IdentityKeyPair>,
101    device_id: DeviceId,
102}
103
104#[derive(thiserror::Error, Debug)]
105pub enum AttachmentUploadError {
106    #[error("{0}")]
107    ServiceError(#[from] ServiceError),
108
109    #[error("Could not read attachment contents")]
110    IoError(#[from] std::io::Error),
111}
112
113#[derive(thiserror::Error, Debug)]
114pub enum MessageSenderError {
115    #[error("service error: {0}")]
116    ServiceError(#[from] ServiceError),
117
118    #[error("protocol error: {0}")]
119    ProtocolError(#[from] SignalProtocolError),
120
121    #[error("invalid private key: {0}")]
122    InvalidPrivateKey(#[from] CurveError),
123
124    #[error("invalid device ID: {0}")]
125    InvalidDeviceId(#[from] InvalidDeviceId),
126
127    #[error("Failed to upload attachment {0}")]
128    AttachmentUploadError(#[from] AttachmentUploadError),
129
130    #[error("primary device can't send sync message {0:?}")]
131    SendSyncMessageError(sync_message::request::Type),
132
133    #[error("Untrusted identity key with {address:?}")]
134    UntrustedIdentity { address: ServiceId },
135
136    #[error("Exceeded maximum number of retries")]
137    MaximumRetriesLimitExceeded,
138
139    #[error("Proof of type {options:?} required using token {token}")]
140    ProofRequired { token: String, options: Vec<String> },
141
142    #[error("Recipient not found: {service_id:?}")]
143    NotFound { service_id: ServiceId },
144
145    #[error("no messages were encrypted: this should not really happen and most likely implies a logic error")]
146    NoMessagesToSend,
147}
148
149pub type GroupV2Id = [u8; GROUP_IDENTIFIER_LEN];
150
151#[derive(Debug)]
152pub enum ThreadIdentifier {
153    Aci(Uuid),
154    Group(GroupV2Id),
155}
156
157#[derive(Debug)]
158pub struct EncryptedMessages {
159    messages: Vec<OutgoingPushMessage>,
160    used_identity_key: IdentityKey,
161}
162
163impl<S> MessageSender<S>
164where
165    S: ProtocolStore + SenderKeyStore + SessionStoreExt + Sync + Clone,
166{
167    #[allow(clippy::too_many_arguments)]
168    pub fn new(
169        identified_ws: SignalWebSocket<websocket::Identified>,
170        unidentified_ws: SignalWebSocket<websocket::Unidentified>,
171        service: PushService,
172        cipher: ServiceCipher<S>,
173        protocol_store: S,
174        local_aci: impl Into<Aci>,
175        local_pni: impl Into<Pni>,
176        aci_identity: IdentityKeyPair,
177        pni_identity: Option<IdentityKeyPair>,
178        device_id: DeviceId,
179    ) -> Self {
180        MessageSender {
181            service,
182            identified_ws,
183            unidentified_ws,
184            cipher,
185            protocol_store,
186            local_aci: local_aci.into(),
187            local_pni: local_pni.into(),
188            aci_identity,
189            pni_identity,
190            device_id,
191        }
192    }
193
194    /// Encrypts and uploads an attachment
195    ///
196    /// Contents are accepted as an owned, plain text Vec, because encryption happens in-place.
197    #[tracing::instrument(skip(self, contents, csprng), fields(size = contents.len()))]
198    pub async fn upload_attachment<R: Rng + CryptoRng>(
199        &mut self,
200        spec: AttachmentSpec,
201        mut contents: Vec<u8>,
202        csprng: &mut R,
203    ) -> Result<AttachmentPointer, AttachmentUploadError> {
204        let len = contents.len();
205        // Encrypt
206        let (key, iv) = {
207            let mut key = [0u8; 64];
208            let mut iv = [0u8; 16];
209            csprng.fill_bytes(&mut key);
210            csprng.fill_bytes(&mut iv);
211            (key, iv)
212        };
213
214        // Padded length uses an exponential bracketting thingy.
215        // If you want to see how it looks:
216        // https://www.wolframalpha.com/input/?i=plot+floor%281.05%5Eceil%28log_1.05%28x%29%29%29+for+x+from+0+to+5000000
217        let padded_len: usize = {
218            // Java:
219            // return (int) Math.max(541, Math.floor(Math.pow(1.05, Math.ceil(Math.log(size) / Math.log(1.05)))))
220            std::cmp::max(
221                541,
222                1.05f64.powf((len as f64).log(1.05).ceil()).floor() as usize,
223            )
224        };
225        if padded_len < len {
226            error!(
227                "Padded len {} < len {}. Continuing with a privacy risk.",
228                padded_len, len
229            );
230        } else {
231            contents.resize(padded_len, 0);
232        }
233
234        tracing::trace_span!("encrypting attachment").in_scope(|| {
235            crate::attachment_cipher::encrypt_in_place(iv, key, &mut contents)
236        });
237
238        // Request upload attributes
239        // TODO: we can actually store the upload spec to be able to resume the upload later
240        // if it fails or stalls (= we should at least split the API calls so clients can decide what to do)
241        let attachment_upload_form = self
242            .service
243            .get_attachment_v4_upload_attributes()
244            .instrument(tracing::trace_span!("requesting upload attributes"))
245            .await?;
246
247        let resumable_upload_url = self
248            .service
249            .get_attachment_resumable_upload_url(&attachment_upload_form)
250            .await?;
251
252        let attachment_digest = self
253            .service
254            .upload_attachment_v4(
255                attachment_upload_form.cdn,
256                &resumable_upload_url,
257                contents.len() as u64,
258                attachment_upload_form.headers,
259                &mut std::io::Cursor::new(&contents),
260            )
261            .await?;
262
263        Ok(AttachmentPointer {
264            content_type: Some(spec.content_type),
265            key: Some(key.to_vec()),
266            size: Some(len as u32),
267            // thumbnail: Option<Vec<u8>>,
268            digest: Some(attachment_digest.digest),
269            file_name: spec.file_name,
270            flags: Some(
271                if spec.voice_note == Some(true) {
272                    AttachmentPointerFlags::VoiceMessage as u32
273                } else {
274                    0
275                } | if spec.borderless == Some(true) {
276                    AttachmentPointerFlags::Borderless as u32
277                } else {
278                    0
279                },
280            ),
281            width: spec.width,
282            height: spec.height,
283            caption: spec.caption,
284            blur_hash: spec.blur_hash,
285            upload_timestamp: Some(
286                SystemTime::now()
287                    .duration_since(SystemTime::UNIX_EPOCH)
288                    .expect("unix epoch in the past")
289                    .as_millis() as u64,
290            ),
291            cdn_number: Some(attachment_upload_form.cdn),
292            attachment_identifier: Some(AttachmentIdentifier::CdnKey(
293                attachment_upload_form.key,
294            )),
295            ..Default::default()
296        })
297    }
298
299    /// Upload contact details to the CDN
300    ///
301    /// Returns attachment ID and the attachment digest
302    #[tracing::instrument(skip(self, contacts))]
303    async fn upload_contact_details<Contacts>(
304        &mut self,
305        contacts: Contacts,
306    ) -> Result<AttachmentPointer, AttachmentUploadError>
307    where
308        Contacts: IntoIterator<Item = ContactDetails>,
309    {
310        use prost::Message;
311        let mut out = Vec::new();
312        for contact in contacts {
313            contact
314                .encode_length_delimited(&mut out)
315                .expect("infallible encoding");
316            // XXX add avatar here
317        }
318
319        let spec = AttachmentSpec {
320            content_type: "application/octet-stream".into(),
321            length: out.len(),
322            file_name: None,
323            preview: None,
324            voice_note: None,
325            borderless: None,
326            width: None,
327            height: None,
328            caption: None,
329            blur_hash: None,
330        };
331        self.upload_attachment(spec, out, &mut rng()).await
332    }
333
334    /// Return whether we have to prepare sync messages for other devices
335    ///
336    /// - If we are the main registered device, and there are established sub-device sessions (linked clients), return true
337    /// - If we are a secondary linked device, return true
338    async fn is_multi_device(&self) -> bool {
339        if self.device_id == *DEFAULT_DEVICE_ID {
340            self.protocol_store
341                .get_sub_device_sessions(&self.local_aci.into())
342                .await
343                .is_ok_and(|s| !s.is_empty())
344        } else {
345            true
346        }
347    }
348
349    /// Send a message `content` to a single `recipient`.
350    #[tracing::instrument(
351        skip(self, unidentified_access, message),
352        fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()),
353    )]
354    pub async fn send_message(
355        &mut self,
356        recipient: &ServiceId,
357        mut unidentified_access: Option<UnidentifiedAccess>,
358        message: impl Into<ContentBody>,
359        timestamp: u64,
360        include_pni_signature: bool,
361        online: bool,
362    ) -> SendMessageResult {
363        let content_body = message.into();
364        let message_to_self = recipient == &self.local_aci;
365        let sync_message =
366            matches!(content_body, ContentBody::SynchronizeMessage(..));
367        let is_multi_device = self.is_multi_device().await;
368
369        use crate::proto::data_message::Flags;
370
371        let end_session = match &content_body {
372            ContentBody::DataMessage(message) => {
373                message.flags == Some(Flags::EndSession as u32)
374            },
375            _ => false,
376        };
377
378        // only send a sync message when sending to self and skip the rest of the process
379        if message_to_self && is_multi_device && !sync_message {
380            debug!("sending note to self");
381            if let Some(sync_message) = self
382                .create_multi_device_sent_transcript_content(
383                    Some(recipient),
384                    content_body,
385                    timestamp,
386                    None,
387                )
388            {
389                return self
390                    .try_send_message(
391                        *recipient,
392                        None,
393                        &sync_message,
394                        timestamp,
395                        include_pni_signature,
396                        online,
397                    )
398                    .await;
399            } else {
400                error!("could not create sync message from message to self");
401                return SendMessageResult::Err(
402                    MessageSenderError::NoMessagesToSend,
403                );
404            }
405        }
406
407        // don't send session enders as sealed sender
408        // sync messages are never sent as unidentified (reasons unclear), see: https://github.com/signalapp/Signal-Android/blob/main/libsignal-service/src/main/java/org/whispersystems/signalservice/api/SignalServiceMessageSender.java#L779
409        if end_session || sync_message {
410            unidentified_access.take();
411        }
412
413        // try to send the original message to all the recipient's devices
414        let result = self
415            .try_send_message(
416                *recipient,
417                unidentified_access.as_ref(),
418                &content_body,
419                timestamp,
420                include_pni_signature,
421                online,
422            )
423            .await;
424
425        let needs_sync = match &result {
426            Ok(SentMessage { needs_sync, .. }) => *needs_sync,
427            _ => false,
428        };
429
430        if needs_sync || is_multi_device {
431            let sync_body = if sync_message {
432                Some(content_body)
433            } else {
434                // Only some ContentBody types are syncable to self,
435                // not getting a content body to sync is not an error.
436                self.create_multi_device_sent_transcript_content(
437                    Some(recipient),
438                    content_body,
439                    timestamp,
440                    Some(&result),
441                )
442            };
443            if let Some(body) = sync_body {
444                debug!("sending multi-device sync message");
445                self.try_send_message(
446                    self.local_aci.into(),
447                    None,
448                    &body,
449                    timestamp,
450                    false,
451                    false,
452                )
453                .await?;
454            }
455        }
456
457        if end_session {
458            let n = self.protocol_store.delete_all_sessions(recipient).await?;
459            tracing::debug!(
460                "ended {} sessions with {}",
461                n,
462                recipient.raw_uuid()
463            );
464        }
465
466        result
467    }
468
469    /// Send a message to the recipients in a group.
470    ///
471    /// Recipients are a list of tuples, each containing:
472    /// - The recipient's address
473    /// - The recipient's unidentified access
474    /// - Whether the recipient requires a PNI signature
475    #[tracing::instrument(
476        skip(self, recipients, message),
477        fields(recipients = recipients.as_ref().len()),
478    )]
479    pub async fn send_message_to_group(
480        &mut self,
481        recipients: impl AsRef<[(ServiceId, Option<UnidentifiedAccess>, bool)]>,
482        message: impl Into<ContentBody>,
483        timestamp: u64,
484        online: bool,
485    ) -> Vec<SendMessageResult> {
486        let content_body: ContentBody = message.into();
487        let mut results = vec![];
488
489        let mut needs_sync_in_results = false;
490
491        for (recipient, unidentified_access, include_pni_signature) in
492            recipients.as_ref()
493        {
494            let result = self
495                .try_send_message(
496                    *recipient,
497                    unidentified_access.as_ref(),
498                    &content_body,
499                    timestamp,
500                    *include_pni_signature,
501                    online,
502                )
503                .await;
504
505            match result {
506                Ok(SentMessage { needs_sync, .. }) if needs_sync => {
507                    needs_sync_in_results = true;
508                },
509                _ => (),
510            };
511
512            results.push(result);
513        }
514
515        // we only need to send a synchronization message once
516        if needs_sync_in_results || self.is_multi_device().await {
517            if let Some(sync_message) = self
518                .create_multi_device_sent_transcript_content(
519                    None,
520                    content_body.clone(),
521                    timestamp,
522                    &results,
523                )
524            {
525                // Note: the result of sending a sync message is not included in results
526                // See Signal Android `SignalServiceMessageSender.java:2817`
527                if let Err(error) = self
528                    .try_send_message(
529                        self.local_aci.into(),
530                        None,
531                        &sync_message,
532                        timestamp,
533                        false, // XXX: maybe the sync device does want a PNI signature?
534                        false,
535                    )
536                    .await
537                {
538                    error!(%error, "failed to send a synchronization message");
539                }
540            } else {
541                error!("could not create sync message from a group message")
542            }
543        }
544
545        results
546    }
547
548    /// Send a message (`content`) to an address (`recipient`).
549    #[tracing::instrument(
550        level = "trace",
551        skip(self, unidentified_access, content_body, recipient),
552        fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()),
553    )]
554    async fn try_send_message(
555        &mut self,
556        recipient: ServiceId,
557        mut unidentified_access: Option<&UnidentifiedAccess>,
558        content_body: &ContentBody,
559        timestamp: u64,
560        include_pni_signature: bool,
561        online: bool,
562    ) -> SendMessageResult {
563        trace!("trying to send a message");
564
565        use prost::Message;
566
567        let mut content = content_body.clone().into_proto();
568        if include_pni_signature {
569            content.pni_signature_message = Some(self.create_pni_signature()?);
570        }
571
572        let content_bytes = content.encode_to_vec();
573
574        let mut rng = rng();
575
576        for _ in 0..4u8 {
577            let Some(EncryptedMessages {
578                messages,
579                used_identity_key,
580            }) = self
581                .create_encrypted_messages(
582                    &recipient,
583                    unidentified_access.map(|x| &x.certificate),
584                    &content_bytes,
585                )
586                .await?
587            else {
588                // this can happen for example when a device is primary, without any secondaries
589                // and we send a message to ourselves (which is only a SyncMessage { sent: ... })
590                // addressed to self
591                return Err(MessageSenderError::NoMessagesToSend);
592            };
593
594            let messages = OutgoingPushMessages {
595                destination: recipient,
596                timestamp,
597                messages,
598                online,
599            };
600
601            let send = if let Some(unidentified) = &unidentified_access {
602                tracing::debug!("sending via unidentified");
603                self.unidentified_ws
604                    .send_messages_unidentified(messages, unidentified)
605                    .await
606            } else {
607                tracing::debug!("sending identified");
608                self.identified_ws.send_messages(messages).await
609            };
610
611            match send {
612                Ok(SendMessageResponse { needs_sync }) => {
613                    tracing::debug!("message sent!");
614                    return Ok(SentMessage {
615                        recipient,
616                        used_identity_key,
617                        unidentified: unidentified_access.is_some(),
618                        needs_sync,
619                    });
620                },
621                Err(ServiceError::Unauthorized)
622                    if unidentified_access.is_some() =>
623                {
624                    tracing::trace!("unauthorized error using unidentified; retry over identified");
625                    unidentified_access = None;
626                },
627                Err(ServiceError::MismatchedDevicesException(ref m)) => {
628                    tracing::debug!("{:?}", m);
629                    for extra_device_id in &m.extra_devices {
630                        tracing::debug!(
631                            "dropping session with device {}",
632                            extra_device_id
633                        );
634                        self.protocol_store
635                            .delete_service_addr_device_session(
636                                &recipient
637                                    .to_protocol_address(*extra_device_id)?,
638                            )
639                            .await?;
640                    }
641
642                    for missing_device_id in &m.missing_devices {
643                        tracing::debug!(
644                            "creating session with missing device {}",
645                            missing_device_id
646                        );
647                        let remote_address = recipient
648                            .to_protocol_address(*missing_device_id)?;
649                        let pre_key = self
650                            .identified_ws
651                            .get_pre_key(&recipient, *missing_device_id)
652                            .await?;
653
654                        process_prekey_bundle(
655                            &remote_address,
656                            &self
657                                .local_aci
658                                .to_protocol_address(self.device_id)
659                                .expect("valid device id"),
660                            &mut self.protocol_store.clone(),
661                            &mut self.protocol_store,
662                            &pre_key,
663                            SystemTime::now(),
664                            &mut rng,
665                        )
666                        .await
667                        .map_err(|e| {
668                            error!("failed to create session: {}", e);
669                            MessageSenderError::UntrustedIdentity {
670                                address: recipient,
671                            }
672                        })?;
673                    }
674                },
675                Err(ServiceError::StaleDevices(ref m)) => {
676                    tracing::debug!("{:?}", m);
677                    for extra_device_id in &m.stale_devices {
678                        tracing::debug!(
679                            "dropping session with device {}",
680                            extra_device_id
681                        );
682                        self.protocol_store
683                            .delete_service_addr_device_session(
684                                &recipient
685                                    .to_protocol_address(*extra_device_id)?,
686                            )
687                            .await?;
688                    }
689                },
690                Err(ServiceError::ProofRequiredError(ref p)) => {
691                    tracing::debug!("{:?}", p);
692                    return Err(MessageSenderError::ProofRequired {
693                        token: p.token.clone(),
694                        options: p.options.clone(),
695                    });
696                },
697                Err(ServiceError::NotFoundError) => {
698                    tracing::debug!("Not found when sending a message");
699                    return Err(MessageSenderError::NotFound {
700                        service_id: recipient,
701                    });
702                },
703                Err(e) => {
704                    tracing::debug!(
705                        "Default error handler for ws.send_messages: {}",
706                        e
707                    );
708                    return Err(MessageSenderError::ServiceError(e));
709                },
710            }
711        }
712
713        Err(MessageSenderError::MaximumRetriesLimitExceeded)
714    }
715
716    /// Upload contact details to the CDN and send a sync message
717    #[tracing::instrument(
718        skip(self, unidentified_access, contacts, recipient),
719        fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()),
720    )]
721    pub async fn send_contact_details<Contacts>(
722        &mut self,
723        recipient: &ServiceId,
724        unidentified_access: Option<UnidentifiedAccess>,
725        // XXX It may be interesting to use an intermediary type,
726        //     instead of ContactDetails directly,
727        //     because it allows us to add the avatar content.
728        contacts: Contacts,
729        online: bool,
730        complete: bool,
731    ) -> Result<(), MessageSenderError>
732    where
733        Contacts: IntoIterator<Item = ContactDetails>,
734    {
735        let ptr = self.upload_contact_details(contacts).await?;
736
737        let msg = SyncMessage {
738            contacts: Some(sync_message::Contacts {
739                blob: Some(ptr),
740                complete: Some(complete),
741            }),
742            ..SyncMessage::with_padding(&mut rng())
743        };
744
745        self.send_sync_message(msg).await?;
746
747        Ok(())
748    }
749
750    /// Send `MessageRequestResponse` synchronization message with either a recipient ACI or a GroupV2 ID
751    #[tracing::instrument(skip(self), fields(recipient = recipient.service_id_string()))]
752    pub async fn send_message_request_response(
753        &mut self,
754        recipient: &ServiceId,
755        thread: &ThreadIdentifier,
756        action: message_request_response::Type,
757    ) -> Result<(), MessageSenderError> {
758        let message_request_response = Some(match thread {
759            ThreadIdentifier::Aci(aci) => {
760                tracing::debug!(
761                    "sending message request response {:?} for recipient {:?}",
762                    action,
763                    aci
764                );
765                MessageRequestResponse {
766                    thread_aci: Some(aci.to_string()),
767                    thread_aci_binary: Some(aci.into_bytes().to_vec()),
768                    group_id: None,
769                    r#type: Some(action.into()),
770                }
771            },
772            ThreadIdentifier::Group(id) => {
773                tracing::debug!(
774                    "sending message request response {:?} for group {:?}",
775                    action,
776                    id
777                );
778                MessageRequestResponse {
779                    thread_aci: None,
780                    thread_aci_binary: None,
781                    group_id: Some(id.to_vec()),
782                    r#type: Some(action.into()),
783                }
784            },
785        });
786
787        let msg = SyncMessage {
788            message_request_response,
789            ..SyncMessage::with_padding(&mut rng())
790        };
791
792        let ts = Utc::now().timestamp_millis() as u64;
793        self.send_message(recipient, None, msg, ts, false, false)
794            .await?;
795
796        Ok(())
797    }
798
799    /// Send a `SyncMessage` to own devices, if any.
800    pub async fn send_sync_message(
801        &mut self,
802        sync: SyncMessage,
803    ) -> Result<(), MessageSenderError> {
804        if self.is_multi_device().await {
805            let content = sync.into();
806            let timestamp = Utc::now().timestamp_millis() as u64;
807            debug!(
808                "sending multi-device sync message with content {content:?}"
809            );
810            self.try_send_message(
811                self.local_aci.into(),
812                None,
813                &content,
814                timestamp,
815                false,
816                false,
817            )
818            .await?;
819        }
820        Ok(())
821    }
822
823    /// Send a `SyncMessage` request message
824    #[tracing::instrument(skip(self))]
825    pub async fn send_sync_message_request(
826        &mut self,
827        recipient: &ServiceId,
828        request_type: sync_message::request::Type,
829    ) -> Result<(), MessageSenderError> {
830        if self.device_id == *DEFAULT_DEVICE_ID {
831            return Err(MessageSenderError::SendSyncMessageError(request_type));
832        }
833
834        let msg = SyncMessage {
835            request: Some(sync_message::Request {
836                r#type: Some(request_type.into()),
837            }),
838            ..SyncMessage::with_padding(&mut rng())
839        };
840        self.send_sync_message(msg).await?;
841
842        Ok(())
843    }
844
845    #[tracing::instrument(level = "trace", skip(self))]
846    fn create_pni_signature(
847        &mut self,
848    ) -> Result<crate::proto::PniSignatureMessage, MessageSenderError> {
849        let mut rng = rng();
850        let signature = self
851            .pni_identity
852            .expect("PNI key set when PNI signature requested")
853            .sign_alternate_identity(
854                self.aci_identity.identity_key(),
855                &mut rng,
856            )?;
857        Ok(crate::proto::PniSignatureMessage {
858            pni: Some(self.local_pni.service_id_binary()),
859            signature: Some(signature.into()),
860        })
861    }
862
863    // Equivalent with `getEncryptedMessages`
864    #[tracing::instrument(
865        level = "trace",
866        skip(self, unidentified_access, content),
867        fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()),
868    )]
869    async fn create_encrypted_messages(
870        &mut self,
871        recipient: &ServiceId,
872        unidentified_access: Option<&SenderCertificate>,
873        content: &[u8],
874    ) -> Result<Option<EncryptedMessages>, MessageSenderError> {
875        let mut messages = vec![];
876
877        let mut devices: HashSet<DeviceId> = self
878            .protocol_store
879            .get_sub_device_sessions(recipient)
880            .await?
881            .into_iter()
882            .collect();
883
884        // always send to the primary device no matter what
885        devices.insert(*DEFAULT_DEVICE_ID);
886
887        // never try to send messages to the sender device
888        match recipient {
889            ServiceId::Aci(aci) => {
890                if *aci == self.local_aci {
891                    devices.remove(&self.device_id);
892                }
893            },
894            ServiceId::Pni(pni) => {
895                if *pni == self.local_pni {
896                    devices.remove(&self.device_id);
897                }
898            },
899        };
900
901        for device_id in devices {
902            trace!("sending message to device {}", device_id);
903            // `create_encrypted_message` may fail with `SessionNotFound` if the session is corrupted;
904            // see https://github.com/whisperfish/libsignal-client/commit/601454d20.
905            // If this happens, delete the session and retry.
906            for _attempt in 0..2 {
907                match self
908                    .create_encrypted_message(
909                        recipient,
910                        unidentified_access,
911                        device_id,
912                        content,
913                    )
914                    .await
915                {
916                    Ok(message) => {
917                        messages.push(message);
918                        break;
919                    },
920                    Err(MessageSenderError::ServiceError(
921                        ServiceError::SignalProtocolError(
922                            SignalProtocolError::SessionNotFound(addr),
923                        ),
924                    )) => {
925                        // SessionNotFound is returned on certain session corruption.
926                        // Since delete_session *creates* a session if it doesn't exist,
927                        // the NotFound error is an indicator of session corruption.
928                        // Try to delete this session, if it gets succesfully deleted, retry.  Otherwise, fail.
929                        tracing::warn!("Potential session corruption for {}, deleting session", addr);
930                        match self.protocol_store.delete_session(&addr).await {
931                            Ok(()) => continue,
932                            Err(error) => {
933                                tracing::warn!(%error, %addr, "failed to delete session");
934                                return Err(
935                                    SignalProtocolError::SessionNotFound(addr)
936                                        .into(),
937                                );
938                            },
939                        }
940                    },
941                    Err(e) => return Err(e),
942                }
943            }
944        }
945
946        if messages.is_empty() {
947            Ok(None)
948        } else {
949            Ok(Some(EncryptedMessages {
950                messages,
951                used_identity_key: self
952                    .protocol_store
953                    .get_identity(
954                        &recipient.to_protocol_address(*DEFAULT_DEVICE_ID),
955                    )
956                    .await?
957                    .ok_or(MessageSenderError::UntrustedIdentity {
958                        address: *recipient,
959                    })?,
960            }))
961        }
962    }
963
964    /// Equivalent to `getEncryptedMessage`
965    ///
966    /// When no session with the recipient exists, we need to create one.
967    #[tracing::instrument(
968        level = "trace",
969        skip(self, unidentified_access, content),
970        fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()),
971    )]
972    pub(crate) async fn create_encrypted_message(
973        &mut self,
974        recipient: &ServiceId,
975        unidentified_access: Option<&SenderCertificate>,
976        device_id: DeviceId,
977        content: &[u8],
978    ) -> Result<OutgoingPushMessage, MessageSenderError> {
979        let recipient_protocol_address =
980            recipient.to_protocol_address(device_id);
981
982        tracing::trace!(
983            "encrypting message for {}",
984            recipient_protocol_address
985        );
986
987        // establish a session with the recipient/device if necessary
988        // no need to establish a session with ourselves (and our own current device)
989        if self
990            .protocol_store
991            .load_session(&recipient_protocol_address)
992            .await?
993            .is_none()
994        {
995            info!(
996                "establishing new session with {}",
997                recipient_protocol_address
998            );
999            let pre_keys = match self
1000                .identified_ws
1001                .get_pre_keys(recipient, device_id)
1002                .await
1003            {
1004                Ok(ok) => {
1005                    tracing::trace!("Get prekeys OK");
1006                    ok
1007                },
1008                Err(ServiceError::NotFoundError) => {
1009                    return Err(MessageSenderError::NotFound {
1010                        service_id: *recipient,
1011                    });
1012                },
1013                Err(e) => Err(e)?,
1014            };
1015
1016            let mut rng = rng();
1017
1018            for pre_key_bundle in pre_keys {
1019                if recipient == &self.local_aci
1020                    && self.device_id == pre_key_bundle.device_id()?
1021                {
1022                    trace!("not establishing a session with myself!");
1023                    continue;
1024                }
1025
1026                let pre_key_address = get_preferred_protocol_address(
1027                    &self.protocol_store,
1028                    recipient,
1029                    pre_key_bundle.device_id()?,
1030                )
1031                .await?;
1032
1033                process_prekey_bundle(
1034                    &pre_key_address,
1035                    &self
1036                        .local_aci
1037                        .to_protocol_address(self.device_id)
1038                        .expect("valid device id"),
1039                    &mut self.protocol_store.clone(),
1040                    &mut self.protocol_store,
1041                    &pre_key_bundle,
1042                    SystemTime::now(),
1043                    &mut rng,
1044                )
1045                .await?;
1046            }
1047        }
1048
1049        let message = self
1050            .cipher
1051            .encrypt(
1052                &recipient_protocol_address,
1053                unidentified_access,
1054                content,
1055                &mut rng(),
1056            )
1057            .instrument(tracing::trace_span!("encrypting message"))
1058            .await?;
1059
1060        Ok(message)
1061    }
1062
1063    fn create_multi_device_sent_transcript_content<'a>(
1064        &mut self,
1065        recipient: Option<&ServiceId>,
1066        content_body: ContentBody,
1067        timestamp: u64,
1068        send_message_results: impl IntoIterator<Item = &'a SendMessageResult>,
1069    ) -> Option<ContentBody> {
1070        use sync_message::sent::UnidentifiedDeliveryStatus;
1071        let (message, edit_message) = match content_body {
1072            ContentBody::DataMessage(m) => (Some(m), None),
1073            ContentBody::EditMessage(m) => (None, Some(m)),
1074            content_body => {
1075                tracing::trace!(?content_body, "not syncing to self");
1076                return None;
1077            },
1078        };
1079        let unidentified_status: Vec<UnidentifiedDeliveryStatus> =
1080            send_message_results
1081                .into_iter()
1082                .filter_map(|result| result.as_ref().ok())
1083                .map(|sent| {
1084                    let SentMessage {
1085                        recipient,
1086                        unidentified,
1087                        used_identity_key,
1088                        ..
1089                    } = sent;
1090                    UnidentifiedDeliveryStatus {
1091                        destination_service_id: Some(
1092                            recipient.service_id_string(),
1093                        ),
1094                        destination_service_id_binary: Some(
1095                            recipient.service_id_binary(),
1096                        ),
1097                        unidentified: Some(*unidentified),
1098                        destination_pni_identity_key: Some(
1099                            used_identity_key.serialize().into(),
1100                        ),
1101                    }
1102                })
1103                .collect();
1104        Some(ContentBody::SynchronizeMessage(SyncMessage {
1105            sent: Some(sync_message::Sent {
1106                destination_service_id: recipient
1107                    .map(ServiceId::service_id_string),
1108                destination_service_id_binary: recipient
1109                    .map(ServiceId::service_id_binary),
1110                destination_e164: None,
1111                expiration_start_timestamp: message
1112                    .as_ref()
1113                    .and_then(|m| m.expire_timer)
1114                    .map(|_| timestamp),
1115                message,
1116                edit_message,
1117                timestamp: Some(timestamp),
1118                unidentified_status,
1119                ..Default::default()
1120            }),
1121            ..SyncMessage::with_padding(&mut rng())
1122        }))
1123    }
1124}