libsignal_service/
account_manager.rs

1use base64::prelude::*;
2use libsignal_core::DeviceId;
3use phonenumber::PhoneNumber;
4use rand::{CryptoRng, Rng};
5use reqwest::Method;
6use std::collections::HashMap;
7use std::convert::{TryFrom, TryInto};
8
9use aes::cipher::{KeyIvInit, StreamCipher as _};
10use hmac::digest::Output;
11use hmac::{Hmac, Mac};
12use libsignal_protocol::{
13    kem, Aci, GenericSignedPreKey, IdentityKey, IdentityKeyPair,
14    IdentityKeyStore, KeyPair, KyberPreKeyRecord, PrivateKey, ProtocolStore,
15    PublicKey, SenderKeyStore, ServiceIdKind, SignedPreKeyRecord, Timestamp,
16};
17use prost::Message;
18use serde::{Deserialize, Serialize};
19use sha2::{Digest, Sha256};
20use tracing_futures::Instrument;
21use zkgroup::profiles::ProfileKey;
22
23use crate::content::ContentBody;
24use crate::master_key::MasterKey;
25use crate::pre_keys::{
26    KyberPreKeyEntity, PreKeyEntity, PreKeysStore, SignedPreKeyEntity,
27    PRE_KEY_BATCH_SIZE, PRE_KEY_MINIMUM,
28};
29use crate::prelude::{MessageSender, MessageSenderError};
30use crate::proto::sync_message::PniChangeNumber;
31use crate::proto::{DeviceName, SyncMessage};
32use crate::provisioning::generate_registration_id;
33use crate::push_service::{
34    AvatarWrite, HttpAuthOverride, ReqwestExt, DEFAULT_DEVICE_ID,
35};
36use crate::sender::OutgoingPushMessage;
37use crate::service_address::ServiceIdExt;
38use crate::session_store::SessionStoreExt;
39use crate::timestamp::TimestampExt as _;
40use crate::utils::{random_length_padding, BASE64_RELAXED};
41use crate::websocket::account::DeviceInfo;
42use crate::websocket::keys::PreKeyStatus;
43use crate::websocket::registration::{
44    CaptchaAttributes, DeviceActivationRequest, RegistrationMethod,
45    VerifyAccountResponse,
46};
47use crate::websocket::{self, SignalWebSocket};
48use crate::{
49    configuration::{Endpoint, ServiceCredentials},
50    pre_keys::PreKeyState,
51    profile_cipher::{ProfileCipher, ProfileCipherError},
52    profile_name::ProfileName,
53    proto::{ProvisionEnvelope, ProvisionMessage, ProvisioningVersion},
54    provisioning::{ProvisioningCipher, ProvisioningError},
55    push_service::{PushService, ServiceError},
56    utils::serde_base64,
57    websocket::account::AccountAttributes,
58};
59
/// AES-256 in CTR mode with a 128-bit big-endian counter.
///
/// NOTE(review): not referenced in the visible part of this file; presumably used by the
/// device-name encryption/decryption helpers defined further down — confirm before removing.
type Aes256Ctr128BE = ctr::Ctr128BE<aes::Aes256>;
61
/// Account-level operations: pre-key maintenance, device linking, registration,
/// profile upload/retrieval and device naming.
pub struct AccountManager {
    /// HTTP client, used here for the REST endpoints (provisioning, device name, challenge).
    service: PushService,
    /// Identified websocket, used for pre-key, profile, registration and device-list requests.
    websocket: SignalWebSocket<websocket::Identified>,
    /// The local profile key. The profile upload/retrieve methods panic when this is
    /// `None`; `link_device` forwards it to the new device when it is set.
    profile_key: Option<ProfileKey>,
}
67
/// Errors returned by the profile methods of [`AccountManager`].
#[derive(thiserror::Error, Debug)]
pub enum ProfileManagerError {
    /// The request to the service failed.
    #[error(transparent)]
    ServiceError(#[from] ServiceError),
    /// Encrypting or decrypting profile fields failed.
    #[error(transparent)]
    ProfileCipherError(#[from] ProfileCipherError),
}
75
/// A decrypted profile, as returned by [`AccountManager::retrieve_profile`].
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct Profile {
    /// The profile name, if one is set.
    pub name: Option<ProfileName<String>>,
    /// The free-text "about" field, if set.
    pub about: Option<String>,
    /// The emoji accompanying the "about" field, if set.
    pub about_emoji: Option<String>,
    /// NOTE(review): presumably the avatar's path/URL on the CDN — confirm against ProfileCipher.
    pub avatar: Option<String>,
    /// NOTE(review): presumably whether anyone may send sealed-sender messages to this
    /// account without a valid unidentified-access key — confirm.
    pub unrestricted_unidentified_access: bool,
}
84
85impl AccountManager {
86    pub fn new(
87        service: PushService,
88        websocket: SignalWebSocket<websocket::Identified>,
89        profile_key: Option<ProfileKey>,
90    ) -> Self {
91        Self {
92            service,
93            websocket,
94            profile_key,
95        }
96    }
97
98    #[allow(clippy::too_many_arguments)]
99    #[tracing::instrument(skip(self, protocol_store))]
100    pub async fn check_pre_keys<P: PreKeysStore>(
101        &mut self,
102        protocol_store: &mut P,
103        service_id_kind: ServiceIdKind,
104    ) -> Result<bool, ServiceError> {
105        let Some(signed_prekey_id) = protocol_store.signed_prekey_id().await?
106        else {
107            tracing::warn!("No signed prekey found");
108            return Ok(false);
109        };
110        // XXX: should we instead use the `load_last_resort_kyber_pre_keys` method? Or refactor
111        //      those whole traits?
112        let Some(kyber_prekey_id) =
113            protocol_store.last_resort_kyber_prekey_id().await?
114        else {
115            tracing::warn!("No last resort kyber prekey found");
116            return Ok(false);
117        };
118
119        let signed_prekey =
120            protocol_store.get_signed_pre_key(signed_prekey_id).await?;
121        let kyber_prekey =
122            protocol_store.get_kyber_pre_key(kyber_prekey_id).await?;
123
124        // `SHA256(identityKeyBytes || signedEcPreKeyId || signedEcPreKeyIdBytes || lastResortKeyId || lastResortKeyBytes)`
125        let mut hash = Sha256::default();
126        hash.update(
127            protocol_store
128                .get_identity_key_pair()
129                .await?
130                .public_key()
131                .serialize(),
132        );
133        hash.update((u32::from(signed_prekey_id) as u64).to_be_bytes());
134        hash.update(signed_prekey.public_key()?.serialize());
135        hash.update((u32::from(kyber_prekey_id) as u64).to_be_bytes());
136        hash.update(kyber_prekey.public_key()?.serialize());
137
138        self.websocket
139            .check_pre_keys(service_id_kind, hash.finalize().as_ref())
140            .await
141    }
142
143    /// Checks the availability of pre-keys, and updates them as necessary.
144    ///
145    /// Parameters are the protocol's `StoreContext`, and the offsets for the next pre-key and
146    /// signed pre-keys.
147    ///
148    /// Equivalent to Java's RefreshPreKeysJob
149    #[allow(clippy::too_many_arguments)]
150    #[tracing::instrument(skip(self, protocol_store))]
151    pub async fn update_pre_key_bundle<P: PreKeysStore>(
152        &mut self,
153        protocol_store: &mut P,
154        service_id_kind: ServiceIdKind,
155        use_last_resort_key: bool,
156    ) -> Result<(), ServiceError> {
157        let prekey_status = match self
158            .websocket
159            .get_pre_key_status(service_id_kind)
160            .instrument(tracing::span!(
161                tracing::Level::DEBUG,
162                "Fetching pre key status"
163            ))
164            .await
165        {
166            Ok(status) => status,
167            Err(ServiceError::Unauthorized) => {
168                tracing::info!("Got Unauthorized when fetching pre-key status. Assuming first installment.");
169                // Additionally, the second PUT request will fail if this really comes down to an
170                // authorization failure.
171                PreKeyStatus {
172                    count: 0,
173                    pq_count: 0,
174                }
175            },
176            Err(e) => return Err(e),
177        };
178        tracing::trace!("Remaining pre-keys on server: {:?}", prekey_status);
179
180        let check_pre_keys = self
181            .check_pre_keys(protocol_store, service_id_kind)
182            .instrument(tracing::span!(
183                tracing::Level::DEBUG,
184                "Checking pre keys"
185            ))
186            .await?;
187        if !check_pre_keys {
188            tracing::info!(
189                "Last resort pre-keys are not up to date; refreshing."
190            );
191        } else {
192            tracing::debug!("Last resort pre-keys are up to date.");
193        }
194
195        // XXX We should honestly compare the pre-key count with the number of pre-keys we have
196        // locally. If we have more than the server, we should upload them.
197        // Currently the trait doesn't allow us to do that, so we just upload the batch size and
198        // pray.
199        if check_pre_keys
200            && (prekey_status.count >= PRE_KEY_MINIMUM
201                && prekey_status.pq_count >= PRE_KEY_MINIMUM)
202        {
203            if protocol_store.signed_pre_keys_count().await? > 0
204                && protocol_store.kyber_pre_keys_count(true).await? > 0
205                && protocol_store.signed_prekey_id().await?.is_some()
206                && protocol_store
207                    .last_resort_kyber_prekey_id()
208                    .await?
209                    .is_some()
210            {
211                tracing::debug!("Available keys sufficient");
212                return Ok(());
213            }
214            tracing::info!("Available keys sufficient; forcing refresh.");
215        }
216
217        let identity_key_pair = protocol_store
218            .get_identity_key_pair()
219            .instrument(tracing::trace_span!("get identity key pair"))
220            .await?;
221
222        let last_resort_keys = protocol_store
223            .load_last_resort_kyber_pre_keys()
224            .instrument(tracing::trace_span!("fetch last resort key"))
225            .await?;
226
227        // XXX: Maybe this check should be done in the generate_pre_keys function?
228        let has_last_resort_key = !last_resort_keys.is_empty();
229
230        let (pre_keys, signed_pre_key, pq_pre_keys, pq_last_resort_key) =
231            crate::pre_keys::replenish_pre_keys(
232                protocol_store,
233                &mut rand::rng(),
234                &identity_key_pair,
235                use_last_resort_key && !has_last_resort_key,
236                PRE_KEY_BATCH_SIZE,
237                PRE_KEY_BATCH_SIZE,
238            )
239            .await?;
240
241        let pq_last_resort_key = if has_last_resort_key {
242            if last_resort_keys.len() > 1 {
243                tracing::warn!(
244                    "More than one last resort key found; only uploading first"
245                );
246            }
247            Some(KyberPreKeyEntity::try_from(last_resort_keys[0].clone())?)
248        } else {
249            pq_last_resort_key
250                .map(KyberPreKeyEntity::try_from)
251                .transpose()?
252        };
253
254        let identity_key = *identity_key_pair.identity_key();
255
256        let pre_keys: Vec<_> = pre_keys
257            .into_iter()
258            .map(PreKeyEntity::try_from)
259            .collect::<Result<_, _>>()?;
260        let signed_pre_key = signed_pre_key.try_into()?;
261        let pq_pre_keys: Vec<_> = pq_pre_keys
262            .into_iter()
263            .map(KyberPreKeyEntity::try_from)
264            .collect::<Result<_, _>>()?;
265
266        tracing::info!(
267            "Uploading pre-keys: {} one-time, {} PQ, {} PQ last resort",
268            pre_keys.len(),
269            pq_pre_keys.len(),
270            if pq_last_resort_key.is_some() { 1 } else { 0 }
271        );
272
273        let pre_key_state = PreKeyState {
274            pre_keys,
275            signed_pre_key,
276            identity_key,
277            pq_pre_keys,
278            pq_last_resort_key,
279        };
280
281        self.websocket
282            .register_pre_keys(service_id_kind, pre_key_state)
283            .instrument(tracing::span!(
284                tracing::Level::DEBUG,
285                "Uploading pre keys"
286            ))
287            .await?;
288
289        Ok(())
290    }
291
292    async fn new_device_provisioning_code(
293        &mut self,
294    ) -> Result<String, ServiceError> {
295        #[derive(serde::Deserialize)]
296        #[serde(rename_all = "camelCase")]
297        struct DeviceCode {
298            verification_code: String,
299        }
300
301        let dc: DeviceCode = self
302            .service
303            .request(
304                Method::GET,
305                Endpoint::service("/v1/devices/provisioning/code"),
306                HttpAuthOverride::NoOverride,
307            )?
308            .send()
309            .await?
310            .service_error_for_status()
311            .await?
312            .json()
313            .await?;
314
315        Ok(dc.verification_code)
316    }
317
318    async fn send_provisioning_message(
319        &mut self,
320        destination: &str,
321        env: ProvisionEnvelope,
322    ) -> Result<(), ServiceError> {
323        #[derive(serde::Serialize)]
324        struct ProvisioningMessage {
325            body: String,
326        }
327
328        let body = env.encode_to_vec();
329
330        self.service
331            .request(
332                Method::PUT,
333                Endpoint::service(format!("/v1/provisioning/{destination}")),
334                HttpAuthOverride::NoOverride,
335            )?
336            .json(&ProvisioningMessage {
337                body: BASE64_RELAXED.encode(body),
338            })
339            .send()
340            .await?
341            .service_error_for_status()
342            .await?;
343
344        Ok(())
345    }
346
347    /// Link a new device, given a tsurl.
348    ///
349    /// Equivalent of Java's `AccountManager::addDevice()`
350    ///
351    /// When calling this, make sure that UnidentifiedDelivery is disabled, ie., that your
352    /// application does not send any unidentified messages before linking is complete.
353    /// Cfr.:
354    /// - `app/src/main/java/org/thoughtcrime/securesms/migrations/LegacyMigrationJob.java`:250 and;
355    /// - `app/src/main/java/org/thoughtcrime/securesms/DeviceActivity.java`:195
356    ///
357    /// ```java
358    /// TextSecurePreferences.setIsUnidentifiedDeliveryEnabled(context, false);
359    /// ```
360    pub async fn link_device<R: Rng + CryptoRng>(
361        &mut self,
362        csprng: &mut R,
363        url: url::Url,
364        aci_identity_store: &dyn IdentityKeyStore,
365        pni_identity_store: &dyn IdentityKeyStore,
366        credentials: ServiceCredentials,
367        master_key: Option<MasterKey>,
368    ) -> Result<(), ProvisioningError> {
369        let query: HashMap<_, _> = url.query_pairs().collect();
370        let ephemeral_id =
371            query.get("uuid").ok_or(ProvisioningError::MissingUuid)?;
372        let pub_key = query
373            .get("pub_key")
374            .ok_or(ProvisioningError::MissingPublicKey)?;
375
376        let pub_key = BASE64_RELAXED
377            .decode(&**pub_key)
378            .map_err(|e| ProvisioningError::InvalidPublicKey(e.into()))?;
379        let pub_key = PublicKey::deserialize(&pub_key)
380            .map_err(|e| ProvisioningError::InvalidPublicKey(e.into()))?;
381
382        let aci_identity_key_pair =
383            aci_identity_store.get_identity_key_pair().await?;
384        let pni_identity_key_pair =
385            pni_identity_store.get_identity_key_pair().await?;
386
387        if credentials.aci.is_none() {
388            tracing::warn!("No local ACI set");
389        }
390        if credentials.pni.is_none() {
391            tracing::warn!("No local PNI set");
392        }
393
394        let provisioning_code = self.new_device_provisioning_code().await?;
395
396        let msg = ProvisionMessage {
397            aci: credentials.aci.as_ref().map(|u| u.to_string()),
398            aci_binary: credentials.aci.map(|u| u.into_bytes().into()),
399            aci_identity_key_public: Some(
400                aci_identity_key_pair.public_key().serialize().into_vec(),
401            ),
402            aci_identity_key_private: Some(
403                aci_identity_key_pair.private_key().serialize(),
404            ),
405            number: Some(credentials.e164()),
406            pni_identity_key_public: Some(
407                pni_identity_key_pair.public_key().serialize().into_vec(),
408            ),
409            pni_identity_key_private: Some(
410                pni_identity_key_pair.private_key().serialize(),
411            ),
412            pni: credentials.pni.as_ref().map(uuid::Uuid::to_string),
413            pni_binary: credentials.pni.map(|u| u.into_bytes().into()),
414            profile_key: self.profile_key.as_ref().map(|x| x.bytes.to_vec()),
415            // CURRENT is not exposed by prost :(
416            provisioning_version: Some(i32::from(
417                ProvisioningVersion::TabletSupport,
418            ) as _),
419            provisioning_code: Some(provisioning_code),
420            read_receipts: None,
421            user_agent: None,
422            master_key: master_key.map(|x| x.into()),
423            ephemeral_backup_key: None,
424            account_entropy_pool: None,
425            media_root_backup_key: None,
426        };
427
428        let cipher = ProvisioningCipher::from_public(pub_key);
429
430        let encrypted = cipher.encrypt(csprng, msg)?;
431        self.send_provisioning_message(ephemeral_id, encrypted)
432            .await?;
433        Ok(())
434    }
435
436    pub async fn linked_devices(
437        &mut self,
438        aci_identity_store: &dyn IdentityKeyStore,
439    ) -> Result<Vec<DeviceInfo>, ServiceError> {
440        let device_infos = self.websocket.devices().await?;
441        let aci_identity_keypair =
442            aci_identity_store.get_identity_key_pair().await?;
443
444        device_infos
445            .into_iter()
446            .map(|i| {
447                Ok(DeviceInfo {
448                    id: i.id,
449                    name: i.name.and_then(|s| {
450                        match decrypt_device_name_from_device_info(
451                            &s,
452                            &aci_identity_keypair,
453                        ) {
454                            Ok(name) => Some(name),
455                            Err(e) => {
456                                tracing::error!("{e}");
457                                None
458                            },
459                        }
460                    }),
461                    registration_id: i.registration_id,
462                    last_seen: i.last_seen,
463                    created_at: decrypt_device_created_at_from_device_info(
464                        i.id,
465                        i.registration_id,
466                        &i.created_at_ciphertext,
467                        &aci_identity_keypair,
468                    )?,
469                })
470            })
471            .collect()
472    }
473
474    pub async fn register_account<
475        R: Rng + CryptoRng,
476        Aci: PreKeysStore + IdentityKeyStore,
477        Pni: PreKeysStore + IdentityKeyStore,
478    >(
479        &mut self,
480        csprng: &mut R,
481        registration_method: RegistrationMethod<'_>,
482        account_attributes: AccountAttributes,
483        aci_protocol_store: &mut Aci,
484        pni_protocol_store: &mut Pni,
485        skip_device_transfer: bool,
486    ) -> Result<VerifyAccountResponse, ProvisioningError> {
487        let aci_identity_key_pair = aci_protocol_store
488            .get_identity_key_pair()
489            .instrument(tracing::trace_span!("get ACI identity key pair"))
490            .await?;
491        let pni_identity_key_pair = pni_protocol_store
492            .get_identity_key_pair()
493            .instrument(tracing::trace_span!("get PNI identity key pair"))
494            .await?;
495
496        let (
497            _aci_pre_keys,
498            aci_signed_pre_key,
499            _aci_kyber_pre_keys,
500            aci_last_resort_kyber_prekey,
501        ) = crate::pre_keys::replenish_pre_keys(
502            aci_protocol_store,
503            csprng,
504            &aci_identity_key_pair,
505            true,
506            0,
507            0,
508        )
509        .await?;
510
511        let (
512            _pni_pre_keys,
513            pni_signed_pre_key,
514            _pni_kyber_pre_keys,
515            pni_last_resort_kyber_prekey,
516        ) = crate::pre_keys::replenish_pre_keys(
517            pni_protocol_store,
518            csprng,
519            &pni_identity_key_pair,
520            true,
521            0,
522            0,
523        )
524        .await?;
525
526        let aci_identity_key = aci_identity_key_pair.identity_key();
527        let pni_identity_key = pni_identity_key_pair.identity_key();
528
529        let dar = DeviceActivationRequest {
530            aci_signed_pre_key: aci_signed_pre_key.try_into()?,
531            pni_signed_pre_key: pni_signed_pre_key.try_into()?,
532            aci_pq_last_resort_pre_key: aci_last_resort_kyber_prekey
533                .expect("requested last resort prekey")
534                .try_into()?,
535            pni_pq_last_resort_pre_key: pni_last_resort_kyber_prekey
536                .expect("requested last resort prekey")
537                .try_into()?,
538        };
539
540        let result = self
541            .websocket
542            .submit_registration_request(
543                registration_method,
544                account_attributes,
545                skip_device_transfer,
546                aci_identity_key,
547                pni_identity_key,
548                dar,
549            )
550            .await?;
551
552        Ok(result)
553    }
554
555    /// Upload a profile
556    ///
557    /// Panics if no `profile_key` was set.
558    ///
559    /// Convenience method for
560    /// ```ignore
561    /// manager.upload_versioned_profile::<std::io::Cursor<Vec<u8>>, _>(uuid, name, about, about_emoji, _)
562    /// ```
563    /// in which the `retain_avatar` parameter sets whether to remove (`false`) or retain (`true`) the
564    /// currently set avatar.
565    pub async fn upload_versioned_profile_without_avatar<
566        R: Rng + CryptoRng,
567        S: AsRef<str>,
568    >(
569        &mut self,
570        aci: libsignal_protocol::Aci,
571        name: ProfileName<S>,
572        about: Option<String>,
573        about_emoji: Option<String>,
574        retain_avatar: bool,
575        csprng: &mut R,
576    ) -> Result<(), ProfileManagerError> {
577        self.upload_versioned_profile::<std::io::Cursor<Vec<u8>>, _, _>(
578            aci,
579            name,
580            about,
581            about_emoji,
582            if retain_avatar {
583                AvatarWrite::RetainAvatar
584            } else {
585                AvatarWrite::NoAvatar
586            },
587            csprng,
588        )
589        .await?;
590        Ok(())
591    }
592
593    pub async fn retrieve_profile(
594        &mut self,
595        address: Aci,
596    ) -> Result<Profile, ProfileManagerError> {
597        let profile_key =
598            self.profile_key.expect("set profile key in AccountManager");
599
600        let encrypted_profile = self
601            .websocket
602            .retrieve_profile_by_id(address, Some(profile_key))
603            .await?;
604
605        let profile_cipher = ProfileCipher::new(profile_key);
606        Ok(profile_cipher.decrypt(encrypted_profile)?)
607    }
608
609    /// Upload a profile
610    ///
611    /// Panics if no `profile_key` was set.
612    ///
613    /// Returns the avatar url path.
614    pub async fn upload_versioned_profile<
615        's,
616        C: std::io::Read + Send + 's,
617        R: Rng + CryptoRng,
618        S: AsRef<str>,
619    >(
620        &mut self,
621        aci: libsignal_protocol::Aci,
622        name: ProfileName<S>,
623        about: Option<String>,
624        about_emoji: Option<String>,
625        avatar: AvatarWrite<&'s mut C>,
626        csprng: &mut R,
627    ) -> Result<Option<String>, ProfileManagerError> {
628        let profile_key =
629            self.profile_key.expect("set profile key in AccountManager");
630        let profile_cipher = ProfileCipher::new(profile_key);
631
632        // Profile encryption
633        let name = profile_cipher.encrypt_name(name.as_ref(), csprng)?;
634        let about = about.unwrap_or_default();
635        let about = profile_cipher.encrypt_about(about, csprng)?;
636        let about_emoji = about_emoji.unwrap_or_default();
637        let about_emoji = profile_cipher.encrypt_emoji(about_emoji, csprng)?;
638
639        // If avatar -> upload
640        if matches!(avatar, AvatarWrite::NewAvatar(_)) {
641            // FIXME ProfileCipherOutputStream.java
642            // It's just AES GCM, but a bit of work to decently implement it with a stream.
643            unimplemented!("Setting avatar requires ProfileCipherStream")
644        }
645
646        let profile_key = profile_cipher.into_inner();
647        let commitment = profile_key.get_commitment(aci);
648        let profile_key_version = profile_key.get_profile_key_version(aci);
649
650        Ok(self
651            .websocket
652            .write_profile::<C, S>(
653                &profile_key_version,
654                &name,
655                &about,
656                &about_emoji,
657                &commitment,
658                avatar,
659            )
660            .await?)
661    }
662
663    /// Set profile attributes
664    ///
665    /// Signal Android does not allow unsetting voice/video.
666    pub async fn set_account_attributes(
667        &mut self,
668        attributes: AccountAttributes,
669    ) -> Result<(), ServiceError> {
670        self.websocket.set_account_attributes(attributes).await
671    }
672
673    /// Update (encrypted) device name
674    pub async fn update_device_name<R: Rng + CryptoRng>(
675        &mut self,
676        device_id: libsignal_core::DeviceId,
677        device_name: &str,
678        aci: Aci,
679        aci_identity_store: &dyn IdentityKeyStore,
680        csprng: &mut R,
681    ) -> Result<(), ServiceError> {
682        let addr = aci.to_protocol_address(device_id).unwrap();
683        let public_key = aci_identity_store.get_identity(&addr).await?;
684        let Some(public_key) = public_key else {
685            return Err(ServiceError::SendError {
686                reason: format!("public key for device {addr:?} not found"),
687            });
688        };
689        let encrypted_device_name =
690            encrypt_device_name(csprng, device_name, &public_key)?;
691
692        #[derive(Serialize)]
693        #[serde(rename_all = "camelCase")]
694        struct Data {
695            #[serde(with = "serde_base64")]
696            device_name: Vec<u8>,
697        }
698
699        self.service
700            .request(
701                Method::PUT,
702                Endpoint::service(format!(
703                    "/v1/accounts/name?deviceId={}",
704                    device_id
705                )),
706                HttpAuthOverride::NoOverride,
707            )?
708            .json(&Data {
709                device_name: encrypted_device_name.encode_to_vec(),
710            })
711            .send()
712            .await?
713            .service_error_for_status()
714            .await?;
715
716        Ok(())
717    }
718
719    /// Upload a proof-required reCaptcha token and response.
720    ///
721    /// Token gotten originally with HTTP status 428 response to sending a message.
722    /// Captcha gotten from user completing the challenge captcha.
723    ///
724    /// It's either a silent OK, or throws a ServiceError.
725    pub async fn submit_recaptcha_challenge(
726        &mut self,
727        token: &str,
728        captcha: &str,
729    ) -> Result<(), ServiceError> {
730        self.service
731            .request(
732                Method::PUT,
733                Endpoint::service("/v1/challenge"),
734                HttpAuthOverride::NoOverride,
735            )?
736            .json(&CaptchaAttributes {
737                challenge_type: "captcha",
738                token,
739                captcha,
740            })
741            .send()
742            .await?
743            .service_error_for_status()
744            .await?;
745
746        Ok(())
747    }
748
749    /// Initialize PNI on linked devices.
750    ///
751    /// Should be called as the primary device to migrate from pre-PNI to PNI.
752    ///
753    /// This is the equivalent of Android's PnpInitializeDevicesJob or iOS' PniHelloWorldManager.
754    #[tracing::instrument(skip(self, aci_protocol_store, pni_protocol_store, sender, local_aci, csprng), fields(local_aci = local_aci.service_id_string()))]
755    pub async fn pnp_initialize_devices<
756        R: Rng + CryptoRng,
757        AciStore: PreKeysStore + SessionStoreExt,
758        PniStore: PreKeysStore,
759        AciOrPni: ProtocolStore + SenderKeyStore + SessionStoreExt + Sync + Clone,
760    >(
761        &mut self,
762        aci_protocol_store: &mut AciStore,
763        pni_protocol_store: &mut PniStore,
764        mut sender: MessageSender<AciOrPni>,
765        local_aci: Aci,
766        e164: PhoneNumber,
767        csprng: &mut R,
768    ) -> Result<(), MessageSenderError> {
769        let pni_identity_key_pair =
770            pni_protocol_store.get_identity_key_pair().await?;
771
772        let pni_identity_key = pni_identity_key_pair.identity_key();
773
774        // For every linked device, we generate a new set of pre-keys, and send them to the device.
775        let local_device_ids = aci_protocol_store
776            .get_sub_device_sessions(&local_aci.into())
777            .await?;
778
779        let mut device_messages =
780            Vec::<OutgoingPushMessage>::with_capacity(local_device_ids.len());
781        let mut device_pni_signed_prekeys =
782            HashMap::<String, SignedPreKeyEntity>::with_capacity(
783                local_device_ids.len(),
784            );
785        let mut device_pni_last_resort_kyber_prekeys =
786            HashMap::<String, KyberPreKeyEntity>::with_capacity(
787                local_device_ids.len(),
788            );
789        let mut pni_registration_ids =
790            HashMap::<String, u32>::with_capacity(local_device_ids.len());
791
792        let signature_valid_on_each_signed_pre_key = true;
793        for local_device_id in
794            std::iter::once(*DEFAULT_DEVICE_ID).chain(local_device_ids)
795        {
796            let local_protocol_address =
797                local_aci.to_protocol_address(local_device_id)?;
798            let span = tracing::trace_span!(
799                "filtering devices",
800                address = %local_protocol_address
801            );
802            // Skip if we don't have a session with the device
803            if (local_device_id != *DEFAULT_DEVICE_ID)
804                && aci_protocol_store
805                    .load_session(&local_protocol_address)
806                    .instrument(span)
807                    .await?
808                    .is_none()
809            {
810                tracing::warn!(
811                    "No session with device {}, skipping PNI provisioning",
812                    local_device_id
813                );
814                continue;
815            }
816            let (
817                _pre_keys,
818                signed_pre_key,
819                _kyber_pre_keys,
820                last_resort_kyber_prekey,
821            ) = if local_device_id == *DEFAULT_DEVICE_ID {
822                crate::pre_keys::replenish_pre_keys(
823                    pni_protocol_store,
824                    csprng,
825                    &pni_identity_key_pair,
826                    true,
827                    0,
828                    0,
829                )
830                .await?
831            } else {
832                // Generate a signed prekey
833                let signed_pre_key_pair = KeyPair::generate(csprng);
834                let signed_pre_key_public = signed_pre_key_pair.public_key;
835                let signed_pre_key_signature = pni_identity_key_pair
836                    .private_key()
837                    .calculate_signature(
838                        &signed_pre_key_public.serialize(),
839                        csprng,
840                    )
841                    .map_err(MessageSenderError::InvalidPrivateKey)?;
842
843                let signed_prekey_record = SignedPreKeyRecord::new(
844                    csprng.random_range::<u32, _>(0..0xFFFFFF).into(),
845                    Timestamp::now(),
846                    &signed_pre_key_pair,
847                    &signed_pre_key_signature,
848                );
849
850                // Generate a last-resort Kyber prekey
851                let kyber_pre_key_record = KyberPreKeyRecord::generate(
852                    kem::KeyType::Kyber1024,
853                    csprng.random_range::<u32, _>(0..0xFFFFFF).into(),
854                    pni_identity_key_pair.private_key(),
855                )?;
856                (
857                    vec![],
858                    signed_prekey_record,
859                    vec![],
860                    Some(kyber_pre_key_record),
861                )
862            };
863
864            let registration_id = if local_device_id == *DEFAULT_DEVICE_ID {
865                pni_protocol_store.get_local_registration_id().await?
866            } else {
867                loop {
868                    let regid = generate_registration_id(csprng);
869                    if !pni_registration_ids.iter().any(|(_k, v)| *v == regid) {
870                        break regid;
871                    }
872                }
873            };
874
875            let local_device_id_s = local_device_id.to_string();
876            device_pni_signed_prekeys.insert(
877                local_device_id_s.clone(),
878                SignedPreKeyEntity::try_from(&signed_pre_key)?,
879            );
880            device_pni_last_resort_kyber_prekeys.insert(
881                local_device_id_s.clone(),
882                KyberPreKeyEntity::try_from(
883                    last_resort_kyber_prekey
884                        .as_ref()
885                        .expect("requested last resort key"),
886                )?,
887            );
888            pni_registration_ids
889                .insert(local_device_id_s.clone(), registration_id);
890
891            assert!(_pre_keys.is_empty());
892            assert!(_kyber_pre_keys.is_empty());
893
894            if local_device_id == *DEFAULT_DEVICE_ID {
895                // This is the primary device
896                // We don't need to send a message to the primary device
897                continue;
898            }
899            // cfr. SignalServiceMessageSender::getEncryptedSyncPniInitializeDeviceMessage
900            let msg = SyncMessage {
901                pni_change_number: Some(PniChangeNumber {
902                    identity_key_pair: Some(
903                        pni_identity_key_pair.serialize().to_vec(),
904                    ),
905                    signed_pre_key: Some(signed_pre_key.serialize()?),
906                    last_resort_kyber_pre_key: Some(
907                        last_resort_kyber_prekey
908                            .expect("requested last resort key")
909                            .serialize()?,
910                    ),
911                    registration_id: Some(registration_id),
912                    new_e164: Some(
913                        e164.format().mode(phonenumber::Mode::E164).to_string(),
914                    ),
915                }),
916                padding: Some(random_length_padding(csprng, 512)),
917                ..SyncMessage::default()
918            };
919            let content: ContentBody = msg.into();
920            let msg = sender
921                .create_encrypted_message(
922                    &local_aci.into(),
923                    None,
924                    local_device_id,
925                    &content.into_proto().encode_to_vec(),
926                )
927                .await?;
928            device_messages.push(msg);
929        }
930
931        self.websocket
932            .distribute_pni_keys(
933                pni_identity_key,
934                device_messages,
935                device_pni_signed_prekeys,
936                device_pni_last_resort_kyber_prekeys,
937                pni_registration_ids,
938                signature_valid_on_each_signed_pre_key,
939            )
940            .await?;
941
942        Ok(())
943    }
944}
945
946fn calculate_hmac256(
947    mac_key: &[u8],
948    ciphertext: &[u8],
949) -> Result<Output<Hmac<Sha256>>, ServiceError> {
950    let mut mac = Hmac::<Sha256>::new_from_slice(mac_key)
951        .map_err(|_| ServiceError::MacError)?;
952    mac.update(ciphertext);
953    Ok(mac.finalize().into_bytes())
954}
955
956pub fn encrypt_device_name<R: rand::Rng + rand::CryptoRng>(
957    csprng: &mut R,
958    device_name: &str,
959    identity_public: &IdentityKey,
960) -> Result<DeviceName, ServiceError> {
961    let plaintext = device_name.as_bytes().to_vec();
962    let ephemeral_key_pair = KeyPair::generate(csprng);
963
964    let master_secret = ephemeral_key_pair
965        .private_key
966        .calculate_agreement(identity_public.public_key())?;
967
968    let key1 = calculate_hmac256(&master_secret, b"auth")?;
969    let synthetic_iv = calculate_hmac256(&key1, &plaintext)?;
970    let synthetic_iv = &synthetic_iv[..16];
971
972    let key2 = calculate_hmac256(&master_secret, b"cipher")?;
973    let cipher_key = calculate_hmac256(&key2, synthetic_iv)?;
974
975    let mut ciphertext = plaintext;
976
977    const IV: [u8; 16] = [0; 16];
978    let mut cipher = Aes256Ctr128BE::new(&cipher_key, &IV.into());
979    cipher.apply_keystream(&mut ciphertext);
980
981    let device_name = DeviceName {
982        ephemeral_public: Some(
983            ephemeral_key_pair.public_key.serialize().to_vec(),
984        ),
985        synthetic_iv: Some(synthetic_iv.to_vec()),
986        ciphertext: Some(ciphertext),
987    };
988
989    Ok(device_name)
990}
991
992fn decrypt_device_name_from_device_info(
993    string: &str,
994    aci: &IdentityKeyPair,
995) -> Result<String, ServiceError> {
996    let data = BASE64_RELAXED.decode(string)?;
997    let name = DeviceName::decode(&*data)?;
998    crate::decrypt_device_name(aci.private_key(), &name)
999}
1000
1001// Analogous to https://github.com/signalapp/Signal-Android/blob/d88a862e0985cc2bbc463c5f504f5bb4e91ad4fc/app/src/main/java/org/thoughtcrime/securesms/linkdevice/LinkDeviceRepository.kt#L121.
1002fn decrypt_device_created_at_from_device_info(
1003    id: DeviceId,
1004    registration_id: i32,
1005    string: &str,
1006    aci: &IdentityKeyPair,
1007) -> Result<chrono::DateTime<chrono::Utc>, ServiceError> {
1008    use signal_crypto::SimpleHpkeReceiver;
1009
1010    let mut associated_data = [0; 5];
1011    associated_data[0] = id.into();
1012    associated_data[1..].copy_from_slice(&registration_id.to_be_bytes());
1013
1014    let data = BASE64_RELAXED.decode(string)?;
1015
1016    let result =
1017        aci.private_key()
1018            .open(b"deviceCreatedAt", &associated_data, &data)?;
1019
1020    let timestamp = i64::from_be_bytes(result.try_into().map_err(|_| {
1021        ServiceError::DecryptDeviceInfoFieldError("created-at")
1022    })?);
1023
1024    chrono::DateTime::<chrono::Utc>::from_timestamp_millis(timestamp)
1025        .ok_or(ServiceError::DecryptDeviceInfoFieldError("created-at"))
1026}
1027
1028pub fn decrypt_device_name(
1029    private_key: &PrivateKey,
1030    device_name: &DeviceName,
1031) -> Result<String, ServiceError> {
1032    let DeviceName {
1033        ephemeral_public: Some(ephemeral_public),
1034        synthetic_iv: Some(synthetic_iv),
1035        ciphertext: Some(ciphertext),
1036    } = device_name
1037    else {
1038        return Err(ServiceError::DecryptDeviceInfoFieldError("name"));
1039    };
1040
1041    let synthetic_iv: [u8; 16] = synthetic_iv[..synthetic_iv.len().min(16)]
1042        .try_into()
1043        .map_err(|_| ServiceError::MacError)?;
1044
1045    let ephemeral_public = PublicKey::deserialize(ephemeral_public)?;
1046
1047    let master_secret = private_key.calculate_agreement(&ephemeral_public)?;
1048    let key2 = calculate_hmac256(&master_secret, b"cipher")?;
1049    let cipher_key = calculate_hmac256(&key2, &synthetic_iv)?;
1050
1051    let mut plaintext = ciphertext.to_vec();
1052    const IV: [u8; 16] = [0; 16];
1053    let mut cipher =
1054        Aes256Ctr128BE::new(cipher_key.as_slice().into(), &IV.into());
1055    cipher.apply_keystream(&mut plaintext);
1056
1057    let key1 = calculate_hmac256(&master_secret, b"auth")?;
1058    let our_synthetic_iv = calculate_hmac256(&key1, &plaintext)?;
1059    let our_synthetic_iv = &our_synthetic_iv[..16];
1060
1061    if synthetic_iv != our_synthetic_iv {
1062        Err(ServiceError::MacError)
1063    } else {
1064        Ok(String::from_utf8_lossy(&plaintext).to_string())
1065    }
1066}
1067
#[cfg(test)]
mod tests {
    //! Round-trip, known-answer and tamper-rejection tests for the
    //! device-name encryption helpers.

    use crate::utils::BASE64_RELAXED;
    use base64::Engine;
    use libsignal_protocol::{IdentityKeyPair, PrivateKey, PublicKey};

    use super::DeviceName;

    #[test]
    fn encrypt_device_name() -> anyhow::Result<()> {
        let input_device_name = "Nokia 3310 Millenial Edition";
        let mut csprng = rand::rng();
        let identity = IdentityKeyPair::generate(&mut csprng);

        let device_name = super::encrypt_device_name(
            &mut csprng,
            input_device_name,
            identity.identity_key(),
        )?;

        let decrypted_device_name =
            super::decrypt_device_name(identity.private_key(), &device_name)?;

        assert_eq!(input_device_name, decrypted_device_name);

        Ok(())
    }

    #[test]
    fn decrypt_device_name() -> anyhow::Result<()> {
        let ephemeral_private_key = PrivateKey::deserialize(
            &BASE64_RELAXED
                .decode("0CgxHjwwblXjvX8sD5wZDWdYToMRf+CZSlgaUrxCGVo=")?,
        )?;
        let ephemeral_public_key = PublicKey::deserialize(
            &BASE64_RELAXED
                .decode("BcZS+Lt6yAKbEpXnRX+I5wHqesuvu93Q2V+fjidwW8R6")?,
        )?;

        let device_name = DeviceName {
            ephemeral_public: Some(ephemeral_public_key.serialize().to_vec()),
            synthetic_iv: Some(
                BASE64_RELAXED.decode("86gekHGmltnnZ9QARhiFcg==")?,
            ),
            ciphertext: Some(
                BASE64_RELAXED
                    .decode("MtJ9/9KBWLBVAxfZJD4pLKzP4q+iodRJeCc+/A==")?,
            ),
        };

        let decrypted_device_name =
            super::decrypt_device_name(&ephemeral_private_key, &device_name)?;

        assert_eq!(decrypted_device_name, "Nokia 3310 Millenial Edition");

        Ok(())
    }

    /// Decryption must fail (not return garbage) when the ciphertext or
    /// synthetic IV has been altered, or when a field is missing.
    #[test]
    fn decrypt_device_name_rejects_tampering() -> anyhow::Result<()> {
        let mut csprng = rand::rng();
        let identity = IdentityKeyPair::generate(&mut csprng);

        let device_name = super::encrypt_device_name(
            &mut csprng,
            "Nokia 3310 Millenial Edition",
            identity.identity_key(),
        )?;

        // A single flipped ciphertext bit changes the recomputed synthetic IV.
        let mut tampered = device_name.clone();
        tampered.ciphertext.as_mut().expect("ciphertext is set")[0] ^= 0x01;
        assert!(
            super::decrypt_device_name(identity.private_key(), &tampered)
                .is_err()
        );

        // A synthetic IV shorter than 16 bytes is rejected outright.
        let mut short_iv = device_name.clone();
        short_iv
            .synthetic_iv
            .as_mut()
            .expect("synthetic IV is set")
            .truncate(8);
        assert!(
            super::decrypt_device_name(identity.private_key(), &short_iv)
                .is_err()
        );

        // Missing fields are reported as errors.
        let missing = DeviceName {
            ciphertext: None,
            ..device_name
        };
        assert!(
            super::decrypt_device_name(identity.private_key(), &missing)
                .is_err()
        );

        Ok(())
    }
}