libsignal_service/
account_manager.rs

1use base64::prelude::*;
2use phonenumber::PhoneNumber;
3use rand::{CryptoRng, Rng};
4use reqwest::Method;
5use std::collections::HashMap;
6use std::convert::{TryFrom, TryInto};
7
8use aes::cipher::{KeyIvInit, StreamCipher as _};
9use hmac::digest::Output;
10use hmac::{Hmac, Mac};
11use libsignal_protocol::{
12    kem, Aci, GenericSignedPreKey, IdentityKey, IdentityKeyPair,
13    IdentityKeyStore, KeyPair, KyberPreKeyRecord, PrivateKey, ProtocolStore,
14    PublicKey, SenderKeyStore, ServiceIdKind, SignedPreKeyRecord, Timestamp,
15};
16use prost::Message;
17use serde::{Deserialize, Serialize};
18use sha2::{Digest, Sha256};
19use tracing_futures::Instrument;
20use zkgroup::profiles::ProfileKey;
21
22use crate::content::ContentBody;
23use crate::master_key::MasterKey;
24use crate::pre_keys::{
25    KyberPreKeyEntity, PreKeyEntity, PreKeysStore, SignedPreKeyEntity,
26    PRE_KEY_BATCH_SIZE, PRE_KEY_MINIMUM,
27};
28use crate::prelude::{MessageSender, MessageSenderError};
29use crate::proto::sync_message::PniChangeNumber;
30use crate::proto::{DeviceName, SyncMessage};
31use crate::provisioning::generate_registration_id;
32use crate::push_service::{
33    AvatarWrite, CaptchaAttributes, DeviceActivationRequest, DeviceInfo,
34    HttpAuthOverride, RegistrationMethod, ReqwestExt, VerifyAccountResponse,
35    DEFAULT_DEVICE_ID,
36};
37use crate::sender::OutgoingPushMessage;
38use crate::service_address::ServiceIdExt;
39use crate::session_store::SessionStoreExt;
40use crate::timestamp::TimestampExt as _;
41use crate::utils::{random_length_padding, BASE64_RELAXED};
42use crate::{
43    configuration::{Endpoint, ServiceCredentials},
44    pre_keys::PreKeyState,
45    profile_cipher::{ProfileCipher, ProfileCipherError},
46    profile_name::ProfileName,
47    proto::{ProvisionEnvelope, ProvisionMessage, ProvisioningVersion},
48    provisioning::{ProvisioningCipher, ProvisioningError},
49    push_service::{AccountAttributes, PushService, ServiceError},
50    utils::serde_base64,
51};
52
53type Aes256Ctr128BE = ctr::Ctr128BE<aes::Aes256>;
54
55pub struct AccountManager {
56    service: PushService,
57    profile_key: Option<ProfileKey>,
58}
59
60#[derive(thiserror::Error, Debug)]
61pub enum ProfileManagerError {
62    #[error(transparent)]
63    ServiceError(#[from] ServiceError),
64    #[error(transparent)]
65    ProfileCipherError(#[from] ProfileCipherError),
66}
67
68#[derive(Debug, Default, Serialize, Deserialize, Clone)]
69pub struct Profile {
70    pub name: Option<ProfileName<String>>,
71    pub about: Option<String>,
72    pub about_emoji: Option<String>,
73    pub avatar: Option<String>,
74    pub unrestricted_unidentified_access: bool,
75}
76
77impl AccountManager {
78    pub fn new(service: PushService, profile_key: Option<ProfileKey>) -> Self {
79        Self {
80            service,
81            profile_key,
82        }
83    }
84
85    #[allow(clippy::too_many_arguments)]
86    #[tracing::instrument(skip(self, protocol_store))]
87    pub async fn check_pre_keys<P: PreKeysStore>(
88        &mut self,
89        protocol_store: &mut P,
90        service_id_kind: ServiceIdKind,
91    ) -> Result<bool, ServiceError> {
92        let Some(signed_prekey_id) = protocol_store.signed_prekey_id().await?
93        else {
94            tracing::warn!("No signed prekey found");
95            return Ok(false);
96        };
97        // XXX: should we instead use the `load_last_resort_kyber_pre_keys` method? Or refactor
98        //      those whole traits?
99        let Some(kyber_prekey_id) =
100            protocol_store.last_resort_kyber_prekey_id().await?
101        else {
102            tracing::warn!("No last resort kyber prekey found");
103            return Ok(false);
104        };
105
106        let signed_prekey =
107            protocol_store.get_signed_pre_key(signed_prekey_id).await?;
108        let kyber_prekey =
109            protocol_store.get_kyber_pre_key(kyber_prekey_id).await?;
110
111        // `SHA256(identityKeyBytes || signedEcPreKeyId || signedEcPreKeyIdBytes || lastResortKeyId || lastResortKeyBytes)`
112        let mut hash = Sha256::default();
113        hash.update(
114            protocol_store
115                .get_identity_key_pair()
116                .await?
117                .public_key()
118                .serialize(),
119        );
120        hash.update((u32::from(signed_prekey_id) as u64).to_be_bytes());
121        hash.update(signed_prekey.public_key()?.serialize());
122        hash.update((u32::from(kyber_prekey_id) as u64).to_be_bytes());
123        hash.update(kyber_prekey.public_key()?.serialize());
124
125        self.service
126            .check_pre_keys(service_id_kind, hash.finalize().as_ref())
127            .await
128    }
129
130    /// Checks the availability of pre-keys, and updates them as necessary.
131    ///
132    /// Parameters are the protocol's `StoreContext`, and the offsets for the next pre-key and
133    /// signed pre-keys.
134    ///
135    /// Equivalent to Java's RefreshPreKeysJob
136    #[allow(clippy::too_many_arguments)]
137    #[tracing::instrument(skip(self, csprng, protocol_store))]
138    pub async fn update_pre_key_bundle<R: Rng + CryptoRng, P: PreKeysStore>(
139        &mut self,
140        protocol_store: &mut P,
141        service_id_kind: ServiceIdKind,
142        use_last_resort_key: bool,
143        csprng: &mut R,
144    ) -> Result<(), ServiceError> {
145        let prekey_status = match self
146            .service
147            .get_pre_key_status(service_id_kind)
148            .instrument(tracing::span!(
149                tracing::Level::DEBUG,
150                "Fetching pre key status"
151            ))
152            .await
153        {
154            Ok(status) => status,
155            Err(ServiceError::Unauthorized) => {
156                tracing::info!("Got Unauthorized when fetching pre-key status. Assuming first installment.");
157                // Additionally, the second PUT request will fail if this really comes down to an
158                // authorization failure.
159                crate::push_service::PreKeyStatus {
160                    count: 0,
161                    pq_count: 0,
162                }
163            },
164            Err(e) => return Err(e),
165        };
166        tracing::trace!("Remaining pre-keys on server: {:?}", prekey_status);
167
168        let check_pre_keys = self
169            .check_pre_keys(protocol_store, service_id_kind)
170            .instrument(tracing::span!(
171                tracing::Level::DEBUG,
172                "Checking pre keys"
173            ))
174            .await?;
175        if !check_pre_keys {
176            tracing::info!(
177                "Last resort pre-keys are not up to date; refreshing."
178            );
179        } else {
180            tracing::debug!("Last resort pre-keys are up to date.");
181        }
182
183        // XXX We should honestly compare the pre-key count with the number of pre-keys we have
184        // locally. If we have more than the server, we should upload them.
185        // Currently the trait doesn't allow us to do that, so we just upload the batch size and
186        // pray.
187        if check_pre_keys
188            && (prekey_status.count >= PRE_KEY_MINIMUM
189                && prekey_status.pq_count >= PRE_KEY_MINIMUM)
190        {
191            if protocol_store.signed_pre_keys_count().await? > 0
192                && protocol_store.kyber_pre_keys_count(true).await? > 0
193                && protocol_store.signed_prekey_id().await?.is_some()
194                && protocol_store
195                    .last_resort_kyber_prekey_id()
196                    .await?
197                    .is_some()
198            {
199                tracing::debug!("Available keys sufficient");
200                return Ok(());
201            }
202            tracing::info!("Available keys sufficient; forcing refresh.");
203        }
204
205        let identity_key_pair = protocol_store
206            .get_identity_key_pair()
207            .instrument(tracing::trace_span!("get identity key pair"))
208            .await?;
209
210        let last_resort_keys = protocol_store
211            .load_last_resort_kyber_pre_keys()
212            .instrument(tracing::trace_span!("fetch last resort key"))
213            .await?;
214
215        // XXX: Maybe this check should be done in the generate_pre_keys function?
216        let has_last_resort_key = !last_resort_keys.is_empty();
217
218        let (pre_keys, signed_pre_key, pq_pre_keys, pq_last_resort_key) =
219            crate::pre_keys::replenish_pre_keys(
220                protocol_store,
221                csprng,
222                &identity_key_pair,
223                use_last_resort_key && !has_last_resort_key,
224                PRE_KEY_BATCH_SIZE,
225                PRE_KEY_BATCH_SIZE,
226            )
227            .await?;
228
229        let pq_last_resort_key = if has_last_resort_key {
230            if last_resort_keys.len() > 1 {
231                tracing::warn!(
232                    "More than one last resort key found; only uploading first"
233                );
234            }
235            Some(KyberPreKeyEntity::try_from(last_resort_keys[0].clone())?)
236        } else {
237            pq_last_resort_key
238                .map(KyberPreKeyEntity::try_from)
239                .transpose()?
240        };
241
242        let identity_key = *identity_key_pair.identity_key();
243
244        let pre_keys: Vec<_> = pre_keys
245            .into_iter()
246            .map(PreKeyEntity::try_from)
247            .collect::<Result<_, _>>()?;
248        let signed_pre_key = signed_pre_key.try_into()?;
249        let pq_pre_keys: Vec<_> = pq_pre_keys
250            .into_iter()
251            .map(KyberPreKeyEntity::try_from)
252            .collect::<Result<_, _>>()?;
253
254        tracing::info!(
255            "Uploading pre-keys: {} one-time, {} PQ, {} PQ last resort",
256            pre_keys.len(),
257            pq_pre_keys.len(),
258            if pq_last_resort_key.is_some() { 1 } else { 0 }
259        );
260
261        let pre_key_state = PreKeyState {
262            pre_keys,
263            signed_pre_key,
264            identity_key,
265            pq_pre_keys,
266            pq_last_resort_key,
267        };
268
269        self.service
270            .register_pre_keys(service_id_kind, pre_key_state)
271            .instrument(tracing::span!(
272                tracing::Level::DEBUG,
273                "Uploading pre keys"
274            ))
275            .await?;
276
277        Ok(())
278    }
279
280    async fn new_device_provisioning_code(
281        &mut self,
282    ) -> Result<String, ServiceError> {
283        #[derive(serde::Deserialize)]
284        #[serde(rename_all = "camelCase")]
285        struct DeviceCode {
286            verification_code: String,
287        }
288
289        let dc: DeviceCode = self
290            .service
291            .request(
292                Method::GET,
293                Endpoint::service("/v1/devices/provisioning/code"),
294                HttpAuthOverride::NoOverride,
295            )?
296            .send()
297            .await?
298            .service_error_for_status()
299            .await?
300            .json()
301            .await?;
302
303        Ok(dc.verification_code)
304    }
305
306    async fn send_provisioning_message(
307        &mut self,
308        destination: &str,
309        env: ProvisionEnvelope,
310    ) -> Result<(), ServiceError> {
311        #[derive(serde::Serialize)]
312        struct ProvisioningMessage {
313            body: String,
314        }
315
316        let body = env.encode_to_vec();
317
318        self.service
319            .request(
320                Method::PUT,
321                Endpoint::service(format!("/v1/provisioning/{destination}")),
322                HttpAuthOverride::NoOverride,
323            )?
324            .json(&ProvisioningMessage {
325                body: BASE64_RELAXED.encode(body),
326            })
327            .send()
328            .await?
329            .service_error_for_status()
330            .await?;
331
332        Ok(())
333    }
334
335    /// Link a new device, given a tsurl.
336    ///
337    /// Equivalent of Java's `AccountManager::addDevice()`
338    ///
339    /// When calling this, make sure that UnidentifiedDelivery is disabled, ie., that your
340    /// application does not send any unidentified messages before linking is complete.
341    /// Cfr.:
342    /// - `app/src/main/java/org/thoughtcrime/securesms/migrations/LegacyMigrationJob.java`:250 and;
343    /// - `app/src/main/java/org/thoughtcrime/securesms/DeviceActivity.java`:195
344    ///
345    /// ```java
346    /// TextSecurePreferences.setIsUnidentifiedDeliveryEnabled(context, false);
347    /// ```
348    pub async fn link_device<R: Rng + CryptoRng>(
349        &mut self,
350        csprng: &mut R,
351        url: url::Url,
352        aci_identity_store: &dyn IdentityKeyStore,
353        pni_identity_store: &dyn IdentityKeyStore,
354        credentials: ServiceCredentials,
355        master_key: Option<MasterKey>,
356    ) -> Result<(), ProvisioningError> {
357        let query: HashMap<_, _> = url.query_pairs().collect();
358        let ephemeral_id =
359            query.get("uuid").ok_or(ProvisioningError::MissingUuid)?;
360        let pub_key = query
361            .get("pub_key")
362            .ok_or(ProvisioningError::MissingPublicKey)?;
363
364        let pub_key = BASE64_RELAXED
365            .decode(&**pub_key)
366            .map_err(|e| ProvisioningError::InvalidPublicKey(e.into()))?;
367        let pub_key = PublicKey::deserialize(&pub_key)
368            .map_err(|e| ProvisioningError::InvalidPublicKey(e.into()))?;
369
370        let aci_identity_key_pair =
371            aci_identity_store.get_identity_key_pair().await?;
372        let pni_identity_key_pair =
373            pni_identity_store.get_identity_key_pair().await?;
374
375        if credentials.aci.is_none() {
376            tracing::warn!("No local ACI set");
377        }
378        if credentials.pni.is_none() {
379            tracing::warn!("No local PNI set");
380        }
381
382        let provisioning_code = self.new_device_provisioning_code().await?;
383
384        let msg = ProvisionMessage {
385            aci: credentials.aci.as_ref().map(|u| u.to_string()),
386            aci_identity_key_public: Some(
387                aci_identity_key_pair.public_key().serialize().into_vec(),
388            ),
389            aci_identity_key_private: Some(
390                aci_identity_key_pair.private_key().serialize(),
391            ),
392            number: Some(credentials.e164()),
393            pni_identity_key_public: Some(
394                pni_identity_key_pair.public_key().serialize().into_vec(),
395            ),
396            pni_identity_key_private: Some(
397                pni_identity_key_pair.private_key().serialize(),
398            ),
399            pni: credentials.pni.as_ref().map(uuid::Uuid::to_string),
400            profile_key: self.profile_key.as_ref().map(|x| x.bytes.to_vec()),
401            // CURRENT is not exposed by prost :(
402            provisioning_version: Some(i32::from(
403                ProvisioningVersion::TabletSupport,
404            ) as _),
405            provisioning_code: Some(provisioning_code),
406            read_receipts: None,
407            user_agent: None,
408            master_key: master_key.map(|x| x.into()),
409            ephemeral_backup_key: None,
410            account_entropy_pool: None,
411            media_root_backup_key: None,
412        };
413
414        let cipher = ProvisioningCipher::from_public(pub_key);
415
416        let encrypted = cipher.encrypt(csprng, msg)?;
417        self.send_provisioning_message(ephemeral_id, encrypted)
418            .await?;
419        Ok(())
420    }
421
422    pub async fn linked_devices(
423        &mut self,
424        aci_identity_store: &dyn IdentityKeyStore,
425    ) -> Result<Vec<DeviceInfo>, ServiceError> {
426        let device_infos = self.service.devices().await?;
427        let aci_identity_keypair =
428            aci_identity_store.get_identity_key_pair().await?;
429
430        device_infos
431            .into_iter()
432            .map(|i| {
433                Ok(DeviceInfo {
434                    id: i.id,
435                    name: i
436                        .name
437                        .map(|s| {
438                            decrypt_device_name_from_device_info(
439                                &s,
440                                &aci_identity_keypair,
441                            )
442                        })
443                        .transpose()?,
444                    created: i.created,
445                    last_seen: i.last_seen,
446                })
447            })
448            .collect()
449    }
450
451    pub async fn register_account<
452        R: Rng + CryptoRng,
453        Aci: PreKeysStore + IdentityKeyStore,
454        Pni: PreKeysStore + IdentityKeyStore,
455    >(
456        &mut self,
457        csprng: &mut R,
458        registration_method: RegistrationMethod<'_>,
459        account_attributes: AccountAttributes,
460        aci_protocol_store: &mut Aci,
461        pni_protocol_store: &mut Pni,
462        skip_device_transfer: bool,
463    ) -> Result<VerifyAccountResponse, ProvisioningError> {
464        let aci_identity_key_pair = aci_protocol_store
465            .get_identity_key_pair()
466            .instrument(tracing::trace_span!("get ACI identity key pair"))
467            .await?;
468        let pni_identity_key_pair = pni_protocol_store
469            .get_identity_key_pair()
470            .instrument(tracing::trace_span!("get PNI identity key pair"))
471            .await?;
472
473        let (
474            _aci_pre_keys,
475            aci_signed_pre_key,
476            _aci_kyber_pre_keys,
477            aci_last_resort_kyber_prekey,
478        ) = crate::pre_keys::replenish_pre_keys(
479            aci_protocol_store,
480            csprng,
481            &aci_identity_key_pair,
482            true,
483            0,
484            0,
485        )
486        .await?;
487
488        let (
489            _pni_pre_keys,
490            pni_signed_pre_key,
491            _pni_kyber_pre_keys,
492            pni_last_resort_kyber_prekey,
493        ) = crate::pre_keys::replenish_pre_keys(
494            pni_protocol_store,
495            csprng,
496            &pni_identity_key_pair,
497            true,
498            0,
499            0,
500        )
501        .await?;
502
503        let aci_identity_key = aci_identity_key_pair.identity_key();
504        let pni_identity_key = pni_identity_key_pair.identity_key();
505
506        let dar = DeviceActivationRequest {
507            aci_signed_pre_key: aci_signed_pre_key.try_into()?,
508            pni_signed_pre_key: pni_signed_pre_key.try_into()?,
509            aci_pq_last_resort_pre_key: aci_last_resort_kyber_prekey
510                .expect("requested last resort prekey")
511                .try_into()?,
512            pni_pq_last_resort_pre_key: pni_last_resort_kyber_prekey
513                .expect("requested last resort prekey")
514                .try_into()?,
515        };
516
517        let result = self
518            .service
519            .submit_registration_request(
520                registration_method,
521                account_attributes,
522                skip_device_transfer,
523                aci_identity_key,
524                pni_identity_key,
525                dar,
526            )
527            .await?;
528
529        Ok(result)
530    }
531
532    /// Upload a profile
533    ///
534    /// Panics if no `profile_key` was set.
535    ///
536    /// Convenience method for
537    /// ```ignore
538    /// manager.upload_versioned_profile::<std::io::Cursor<Vec<u8>>, _>(uuid, name, about, about_emoji, _)
539    /// ```
540    /// in which the `retain_avatar` parameter sets whether to remove (`false`) or retain (`true`) the
541    /// currently set avatar.
542    pub async fn upload_versioned_profile_without_avatar<
543        R: Rng + CryptoRng,
544        S: AsRef<str>,
545    >(
546        &mut self,
547        aci: libsignal_protocol::Aci,
548        name: ProfileName<S>,
549        about: Option<String>,
550        about_emoji: Option<String>,
551        retain_avatar: bool,
552        csprng: &mut R,
553    ) -> Result<(), ProfileManagerError> {
554        self.upload_versioned_profile::<std::io::Cursor<Vec<u8>>, _, _>(
555            aci,
556            name,
557            about,
558            about_emoji,
559            if retain_avatar {
560                AvatarWrite::RetainAvatar
561            } else {
562                AvatarWrite::NoAvatar
563            },
564            csprng,
565        )
566        .await?;
567        Ok(())
568    }
569
570    pub async fn retrieve_profile(
571        &mut self,
572        address: Aci,
573    ) -> Result<Profile, ProfileManagerError> {
574        let profile_key =
575            self.profile_key.expect("set profile key in AccountManager");
576
577        let encrypted_profile = self
578            .service
579            .retrieve_profile_by_id(address, Some(profile_key))
580            .await?;
581
582        let profile_cipher = ProfileCipher::new(profile_key);
583        Ok(profile_cipher.decrypt(encrypted_profile)?)
584    }
585
586    /// Upload a profile
587    ///
588    /// Panics if no `profile_key` was set.
589    ///
590    /// Returns the avatar url path.
591    pub async fn upload_versioned_profile<
592        's,
593        C: std::io::Read + Send + 's,
594        R: Rng + CryptoRng,
595        S: AsRef<str>,
596    >(
597        &mut self,
598        aci: libsignal_protocol::Aci,
599        name: ProfileName<S>,
600        about: Option<String>,
601        about_emoji: Option<String>,
602        avatar: AvatarWrite<&'s mut C>,
603        csprng: &mut R,
604    ) -> Result<Option<String>, ProfileManagerError> {
605        let profile_key =
606            self.profile_key.expect("set profile key in AccountManager");
607        let profile_cipher = ProfileCipher::new(profile_key);
608
609        // Profile encryption
610        let name = profile_cipher.encrypt_name(name.as_ref(), csprng)?;
611        let about = about.unwrap_or_default();
612        let about = profile_cipher.encrypt_about(about, csprng)?;
613        let about_emoji = about_emoji.unwrap_or_default();
614        let about_emoji = profile_cipher.encrypt_emoji(about_emoji, csprng)?;
615
616        // If avatar -> upload
617        if matches!(avatar, AvatarWrite::NewAvatar(_)) {
618            // FIXME ProfileCipherOutputStream.java
619            // It's just AES GCM, but a bit of work to decently implement it with a stream.
620            unimplemented!("Setting avatar requires ProfileCipherStream")
621        }
622
623        let profile_key = profile_cipher.into_inner();
624        let commitment = profile_key.get_commitment(aci);
625        let profile_key_version = profile_key.get_profile_key_version(aci);
626
627        Ok(self
628            .service
629            .write_profile::<C, S>(
630                &profile_key_version,
631                &name,
632                &about,
633                &about_emoji,
634                &commitment,
635                avatar,
636            )
637            .await?)
638    }
639
640    /// Set profile attributes
641    ///
642    /// Signal Android does not allow unsetting voice/video.
643    pub async fn set_account_attributes(
644        &mut self,
645        attributes: AccountAttributes,
646    ) -> Result<(), ServiceError> {
647        self.service.set_account_attributes(attributes).await
648    }
649
650    /// Update (encrypted) device name
651    pub async fn update_device_name<R: Rng + CryptoRng>(
652        &mut self,
653        device_name: &str,
654        public_key: &IdentityKey,
655        csprng: &mut R,
656    ) -> Result<(), ServiceError> {
657        let encrypted_device_name =
658            encrypt_device_name(csprng, device_name, public_key)?;
659
660        #[derive(Serialize)]
661        #[serde(rename_all = "camelCase")]
662        struct Data {
663            #[serde(with = "serde_base64")]
664            device_name: Vec<u8>,
665        }
666
667        self.service
668            .request(
669                Method::PUT,
670                Endpoint::service("/v1/accounts/name"),
671                HttpAuthOverride::NoOverride,
672            )?
673            .json(&Data {
674                device_name: encrypted_device_name.encode_to_vec(),
675            })
676            .send()
677            .await?
678            .service_error_for_status()
679            .await?;
680
681        Ok(())
682    }
683
684    /// Upload a proof-required reCaptcha token and response.
685    ///
686    /// Token gotten originally with HTTP status 428 response to sending a message.
687    /// Captcha gotten from user completing the challenge captcha.
688    ///
689    /// It's either a silent OK, or throws a ServiceError.
690    pub async fn submit_recaptcha_challenge(
691        &mut self,
692        token: &str,
693        captcha: &str,
694    ) -> Result<(), ServiceError> {
695        self.service
696            .request(
697                Method::PUT,
698                Endpoint::service("/v1/challenge"),
699                HttpAuthOverride::NoOverride,
700            )?
701            .json(&CaptchaAttributes {
702                challenge_type: "captcha",
703                token,
704                captcha,
705            })
706            .send()
707            .await?
708            .service_error_for_status()
709            .await?;
710
711        Ok(())
712    }
713
714    /// Initialize PNI on linked devices.
715    ///
716    /// Should be called as the primary device to migrate from pre-PNI to PNI.
717    ///
718    /// This is the equivalent of Android's PnpInitializeDevicesJob or iOS' PniHelloWorldManager.
719    #[tracing::instrument(skip(self, aci_protocol_store, pni_protocol_store, sender, local_aci, csprng), fields(local_aci = local_aci.service_id_string()))]
720    pub async fn pnp_initialize_devices<
721        R: Rng + CryptoRng,
722        AciStore: PreKeysStore + SessionStoreExt,
723        PniStore: PreKeysStore,
724        AciOrPni: ProtocolStore + SenderKeyStore + SessionStoreExt + Sync + Clone,
725    >(
726        &mut self,
727        aci_protocol_store: &mut AciStore,
728        pni_protocol_store: &mut PniStore,
729        mut sender: MessageSender<AciOrPni>,
730        local_aci: Aci,
731        e164: PhoneNumber,
732        csprng: &mut R,
733    ) -> Result<(), MessageSenderError> {
734        let pni_identity_key_pair =
735            pni_protocol_store.get_identity_key_pair().await?;
736
737        let pni_identity_key = pni_identity_key_pair.identity_key();
738
739        // For every linked device, we generate a new set of pre-keys, and send them to the device.
740        let local_device_ids = aci_protocol_store
741            .get_sub_device_sessions(&local_aci.into())
742            .await?;
743
744        let mut device_messages =
745            Vec::<OutgoingPushMessage>::with_capacity(local_device_ids.len());
746        let mut device_pni_signed_prekeys =
747            HashMap::<String, SignedPreKeyEntity>::with_capacity(
748                local_device_ids.len(),
749            );
750        let mut device_pni_last_resort_kyber_prekeys =
751            HashMap::<String, KyberPreKeyEntity>::with_capacity(
752                local_device_ids.len(),
753            );
754        let mut pni_registration_ids =
755            HashMap::<String, u32>::with_capacity(local_device_ids.len());
756
757        let signature_valid_on_each_signed_pre_key = true;
758        for local_device_id in
759            std::iter::once(*DEFAULT_DEVICE_ID).chain(local_device_ids)
760        {
761            let local_protocol_address =
762                local_aci.to_protocol_address(local_device_id)?;
763            let span = tracing::trace_span!(
764                "filtering devices",
765                address = %local_protocol_address
766            );
767            // Skip if we don't have a session with the device
768            if (local_device_id != *DEFAULT_DEVICE_ID)
769                && aci_protocol_store
770                    .load_session(&local_protocol_address)
771                    .instrument(span)
772                    .await?
773                    .is_none()
774            {
775                tracing::warn!(
776                    "No session with device {}, skipping PNI provisioning",
777                    local_device_id
778                );
779                continue;
780            }
781            let (
782                _pre_keys,
783                signed_pre_key,
784                _kyber_pre_keys,
785                last_resort_kyber_prekey,
786            ) = if local_device_id == *DEFAULT_DEVICE_ID {
787                crate::pre_keys::replenish_pre_keys(
788                    pni_protocol_store,
789                    csprng,
790                    &pni_identity_key_pair,
791                    true,
792                    0,
793                    0,
794                )
795                .await?
796            } else {
797                // Generate a signed prekey
798                let signed_pre_key_pair = KeyPair::generate(csprng);
799                let signed_pre_key_public = signed_pre_key_pair.public_key;
800                let signed_pre_key_signature = pni_identity_key_pair
801                    .private_key()
802                    .calculate_signature(
803                        &signed_pre_key_public.serialize(),
804                        csprng,
805                    )
806                    .map_err(MessageSenderError::InvalidPrivateKey)?;
807
808                let signed_prekey_record = SignedPreKeyRecord::new(
809                    csprng.random_range::<u32, _>(0..0xFFFFFF).into(),
810                    Timestamp::now(),
811                    &signed_pre_key_pair,
812                    &signed_pre_key_signature,
813                );
814
815                // Generate a last-resort Kyber prekey
816                let kyber_pre_key_record = KyberPreKeyRecord::generate(
817                    kem::KeyType::Kyber1024,
818                    csprng.random_range::<u32, _>(0..0xFFFFFF).into(),
819                    pni_identity_key_pair.private_key(),
820                )?;
821                (
822                    vec![],
823                    signed_prekey_record,
824                    vec![],
825                    Some(kyber_pre_key_record),
826                )
827            };
828
829            let registration_id = if local_device_id == *DEFAULT_DEVICE_ID {
830                pni_protocol_store.get_local_registration_id().await?
831            } else {
832                loop {
833                    let regid = generate_registration_id(csprng);
834                    if !pni_registration_ids.iter().any(|(_k, v)| *v == regid) {
835                        break regid;
836                    }
837                }
838            };
839
840            let local_device_id_s = local_device_id.to_string();
841            device_pni_signed_prekeys.insert(
842                local_device_id_s.clone(),
843                SignedPreKeyEntity::try_from(&signed_pre_key)?,
844            );
845            device_pni_last_resort_kyber_prekeys.insert(
846                local_device_id_s.clone(),
847                KyberPreKeyEntity::try_from(
848                    last_resort_kyber_prekey
849                        .as_ref()
850                        .expect("requested last resort key"),
851                )?,
852            );
853            pni_registration_ids
854                .insert(local_device_id_s.clone(), registration_id);
855
856            assert!(_pre_keys.is_empty());
857            assert!(_kyber_pre_keys.is_empty());
858
859            if local_device_id == *DEFAULT_DEVICE_ID {
860                // This is the primary device
861                // We don't need to send a message to the primary device
862                continue;
863            }
864            // cfr. SignalServiceMessageSender::getEncryptedSyncPniInitializeDeviceMessage
865            let msg = SyncMessage {
866                pni_change_number: Some(PniChangeNumber {
867                    identity_key_pair: Some(
868                        pni_identity_key_pair.serialize().to_vec(),
869                    ),
870                    signed_pre_key: Some(signed_pre_key.serialize()?),
871                    last_resort_kyber_pre_key: Some(
872                        last_resort_kyber_prekey
873                            .expect("requested last resort key")
874                            .serialize()?,
875                    ),
876                    registration_id: Some(registration_id),
877                    new_e164: Some(
878                        e164.format().mode(phonenumber::Mode::E164).to_string(),
879                    ),
880                }),
881                padding: Some(random_length_padding(csprng, 512)),
882                ..SyncMessage::default()
883            };
884            let content: ContentBody = msg.into();
885            let msg = sender
886                .create_encrypted_message(
887                    &local_aci.into(),
888                    None,
889                    local_device_id,
890                    &content.into_proto().encode_to_vec(),
891                )
892                .await?;
893            device_messages.push(msg);
894        }
895
896        self.service
897            .distribute_pni_keys(
898                pni_identity_key,
899                device_messages,
900                device_pni_signed_prekeys,
901                device_pni_last_resort_kyber_prekeys,
902                pni_registration_ids,
903                signature_valid_on_each_signed_pre_key,
904            )
905            .await?;
906
907        Ok(())
908    }
909}
910
911#[expect(clippy::result_large_err)]
912fn calculate_hmac256(
913    mac_key: &[u8],
914    ciphertext: &[u8],
915) -> Result<Output<Hmac<Sha256>>, ServiceError> {
916    let mut mac = Hmac::<Sha256>::new_from_slice(mac_key)
917        .map_err(|_| ServiceError::MacError)?;
918    mac.update(ciphertext);
919    Ok(mac.finalize().into_bytes())
920}
921
922#[expect(clippy::result_large_err)]
923pub fn encrypt_device_name<R: rand::Rng + rand::CryptoRng>(
924    csprng: &mut R,
925    device_name: &str,
926    identity_public: &IdentityKey,
927) -> Result<DeviceName, ServiceError> {
928    let plaintext = device_name.as_bytes().to_vec();
929    let ephemeral_key_pair = KeyPair::generate(csprng);
930
931    let master_secret = ephemeral_key_pair
932        .private_key
933        .calculate_agreement(identity_public.public_key())?;
934
935    let key1 = calculate_hmac256(&master_secret, b"auth")?;
936    let synthetic_iv = calculate_hmac256(&key1, &plaintext)?;
937    let synthetic_iv = &synthetic_iv[..16];
938
939    let key2 = calculate_hmac256(&master_secret, b"cipher")?;
940    let cipher_key = calculate_hmac256(&key2, synthetic_iv)?;
941
942    let mut ciphertext = plaintext;
943
944    const IV: [u8; 16] = [0; 16];
945    let mut cipher = Aes256Ctr128BE::new(&cipher_key, &IV.into());
946    cipher.apply_keystream(&mut ciphertext);
947
948    let device_name = DeviceName {
949        ephemeral_public: Some(
950            ephemeral_key_pair.public_key.serialize().to_vec(),
951        ),
952        synthetic_iv: Some(synthetic_iv.to_vec()),
953        ciphertext: Some(ciphertext),
954    };
955
956    Ok(device_name)
957}
958
959#[expect(clippy::result_large_err)]
960fn decrypt_device_name_from_device_info(
961    string: &str,
962    aci: &IdentityKeyPair,
963) -> Result<String, ServiceError> {
964    let data = BASE64_RELAXED.decode(string)?;
965    let name = DeviceName::decode(&*data)?;
966    crate::decrypt_device_name(aci.private_key(), &name)
967}
968
969#[expect(clippy::result_large_err)]
970pub fn decrypt_device_name(
971    private_key: &PrivateKey,
972    device_name: &DeviceName,
973) -> Result<String, ServiceError> {
974    let DeviceName {
975        ephemeral_public: Some(ephemeral_public),
976        synthetic_iv: Some(synthetic_iv),
977        ciphertext: Some(ciphertext),
978    } = device_name
979    else {
980        return Err(ServiceError::InvalidDeviceName);
981    };
982
983    let synthetic_iv: [u8; 16] = synthetic_iv[..synthetic_iv.len().min(16)]
984        .try_into()
985        .map_err(|_| ServiceError::MacError)?;
986
987    let ephemeral_public = PublicKey::deserialize(ephemeral_public)?;
988
989    let master_secret = private_key.calculate_agreement(&ephemeral_public)?;
990    let key2 = calculate_hmac256(&master_secret, b"cipher")?;
991    let cipher_key = calculate_hmac256(&key2, &synthetic_iv)?;
992
993    let mut plaintext = ciphertext.to_vec();
994    const IV: [u8; 16] = [0; 16];
995    let mut cipher =
996        Aes256Ctr128BE::new(cipher_key.as_slice().into(), &IV.into());
997    cipher.apply_keystream(&mut plaintext);
998
999    let key1 = calculate_hmac256(&master_secret, b"auth")?;
1000    let our_synthetic_iv = calculate_hmac256(&key1, &plaintext)?;
1001    let our_synthetic_iv = &our_synthetic_iv[..16];
1002
1003    if synthetic_iv != our_synthetic_iv {
1004        Err(ServiceError::MacError)
1005    } else {
1006        Ok(String::from_utf8_lossy(&plaintext).to_string())
1007    }
1008}
1009
#[cfg(test)]
mod tests {
    use base64::Engine;
    use libsignal_protocol::{IdentityKeyPair, PrivateKey, PublicKey};

    use super::DeviceName;
    use crate::utils::BASE64_RELAXED;

    /// Plaintext used by both tests. The fixture in `decrypt_device_name`
    /// decrypts to exactly this string.
    const NOKIA: &str = "Nokia 3310 Millenial Edition";

    /// Encrypting for a fresh identity and decrypting with its private key
    /// must round-trip the original name.
    #[test]
    fn encrypt_device_name() -> anyhow::Result<()> {
        let mut rng = rand::rng();
        let identity = IdentityKeyPair::generate(&mut rng);

        let encrypted = super::encrypt_device_name(
            &mut rng,
            NOKIA,
            identity.identity_key(),
        )?;
        let roundtripped =
            super::decrypt_device_name(identity.private_key(), &encrypted)?;

        assert_eq!(NOKIA, roundtripped);
        Ok(())
    }

    /// Decrypting a fixed, known-good fixture must yield the expected name.
    #[test]
    fn decrypt_device_name() -> anyhow::Result<()> {
        let b64 = |s: &str| BASE64_RELAXED.decode(s);

        let private_key = PrivateKey::deserialize(&b64(
            "0CgxHjwwblXjvX8sD5wZDWdYToMRf+CZSlgaUrxCGVo=",
        )?)?;
        let ephemeral_public = PublicKey::deserialize(&b64(
            "BcZS+Lt6yAKbEpXnRX+I5wHqesuvu93Q2V+fjidwW8R6",
        )?)?;

        let device_name = DeviceName {
            ephemeral_public: Some(ephemeral_public.serialize().to_vec()),
            synthetic_iv: Some(b64("86gekHGmltnnZ9QARhiFcg==")?),
            ciphertext: Some(b64(
                "MtJ9/9KBWLBVAxfZJD4pLKzP4q+iodRJeCc+/A==",
            )?),
        };

        let decrypted =
            super::decrypt_device_name(&private_key, &device_name)?;
        assert_eq!(decrypted, NOKIA);
        Ok(())
    }
}