libsignal_service/
account_manager.rs

1use base64::prelude::*;
2use phonenumber::PhoneNumber;
3use rand::{CryptoRng, Rng};
4use reqwest::Method;
5use std::collections::HashMap;
6use std::convert::{TryFrom, TryInto};
7
8use aes::cipher::{KeyIvInit, StreamCipher as _};
9use hmac::digest::Output;
10use hmac::{Hmac, Mac};
11use libsignal_protocol::{
12    kem, Aci, GenericSignedPreKey, IdentityKey, IdentityKeyPair,
13    IdentityKeyStore, KeyPair, KyberPreKeyRecord, PrivateKey, ProtocolStore,
14    PublicKey, SenderKeyStore, ServiceIdKind, SignedPreKeyRecord, Timestamp,
15};
16use prost::Message;
17use serde::{Deserialize, Serialize};
18use sha2::{Digest, Sha256};
19use tracing_futures::Instrument;
20use zkgroup::profiles::ProfileKey;
21
22use crate::content::ContentBody;
23use crate::master_key::MasterKey;
24use crate::pre_keys::{
25    KyberPreKeyEntity, PreKeyEntity, PreKeysStore, SignedPreKeyEntity,
26    PRE_KEY_BATCH_SIZE, PRE_KEY_MINIMUM,
27};
28use crate::prelude::{MessageSender, MessageSenderError};
29use crate::proto::sync_message::PniChangeNumber;
30use crate::proto::{DeviceName, SyncMessage};
31use crate::provisioning::generate_registration_id;
32use crate::push_service::{
33    AvatarWrite, CaptchaAttributes, DeviceActivationRequest, DeviceInfo,
34    HttpAuthOverride, RegistrationMethod, ReqwestExt, VerifyAccountResponse,
35    DEFAULT_DEVICE_ID,
36};
37use crate::sender::OutgoingPushMessage;
38use crate::service_address::ServiceIdExt;
39use crate::session_store::SessionStoreExt;
40use crate::timestamp::TimestampExt as _;
41use crate::utils::{random_length_padding, BASE64_RELAXED};
42use crate::{
43    configuration::{Endpoint, ServiceCredentials},
44    pre_keys::PreKeyState,
45    profile_cipher::{ProfileCipher, ProfileCipherError},
46    profile_name::ProfileName,
47    proto::{ProvisionEnvelope, ProvisionMessage, ProvisioningVersion},
48    provisioning::{ProvisioningCipher, ProvisioningError},
49    push_service::{AccountAttributes, PushService, ServiceError},
50    utils::serde_base64,
51};
52
/// AES-256 in CTR mode with a 128-bit big-endian counter.
type Aes256Ctr128BE = ctr::Ctr128BE<aes::Aes256>;
54
/// Handle for account-level operations: pre-key upkeep, device linking,
/// registration and profile management.
///
/// Every request goes through the wrapped [`PushService`].
pub struct AccountManager {
    /// Service client used for every request this manager issues.
    service: PushService,
    /// Profile key used by the profile methods, which panic when this is
    /// `None`; it is also included in the provisioning message sent by
    /// `link_device` when present.
    profile_key: Option<ProfileKey>,
}
59
/// Errors returned by the profile upload and retrieval methods of
/// [`AccountManager`].
#[derive(thiserror::Error, Debug)]
pub enum ProfileManagerError {
    /// The request to the service failed.
    #[error(transparent)]
    ServiceError(#[from] ServiceError),
    /// The profile cipher reported an error while encrypting or decrypting
    /// profile data.
    #[error(transparent)]
    ProfileCipherError(#[from] ProfileCipherError),
}
67
/// Decrypted profile, as returned by [`AccountManager::retrieve_profile`].
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct Profile {
    /// Profile name, if one is set.
    pub name: Option<ProfileName<String>>,
    /// "About" text, if one is set.
    pub about: Option<String>,
    /// Emoji belonging to the "about" text, if one is set.
    pub about_emoji: Option<String>,
    /// Avatar reference, if one is set.
    // NOTE(review): presumably the avatar's path on the CDN — confirm.
    pub avatar: Option<String>,
    // NOTE(review): presumably whether the account accepts unidentified
    // (sealed sender) messages from anyone — confirm.
    pub unrestricted_unidentified_access: bool,
}
76
77impl AccountManager {
    /// Creates an account manager that issues its requests through `service`.
    ///
    /// `profile_key` may be `None`, but [`Self::retrieve_profile`] and
    /// [`Self::upload_versioned_profile`] panic without one.
    pub fn new(service: PushService, profile_key: Option<ProfileKey>) -> Self {
        Self {
            service,
            profile_key,
        }
    }
84
85    #[allow(clippy::too_many_arguments)]
86    #[tracing::instrument(skip(self, protocol_store))]
87    pub async fn check_pre_keys<P: PreKeysStore>(
88        &mut self,
89        protocol_store: &mut P,
90        service_id_kind: ServiceIdKind,
91    ) -> Result<bool, ServiceError> {
92        let Some(signed_prekey_id) = protocol_store.signed_prekey_id().await?
93        else {
94            tracing::warn!("No signed prekey found");
95            return Ok(false);
96        };
97        // XXX: should we instead use the `load_last_resort_kyber_pre_keys` method? Or refactor
98        //      those whole traits?
99        let Some(kyber_prekey_id) =
100            protocol_store.last_resort_kyber_prekey_id().await?
101        else {
102            tracing::warn!("No last resort kyber prekey found");
103            return Ok(false);
104        };
105
106        let signed_prekey =
107            protocol_store.get_signed_pre_key(signed_prekey_id).await?;
108        let kyber_prekey =
109            protocol_store.get_kyber_pre_key(kyber_prekey_id).await?;
110
111        // `SHA256(identityKeyBytes || signedEcPreKeyId || signedEcPreKeyIdBytes || lastResortKeyId || lastResortKeyBytes)`
112        let mut hash = Sha256::default();
113        hash.update(
114            protocol_store
115                .get_identity_key_pair()
116                .await?
117                .public_key()
118                .serialize(),
119        );
120        hash.update((u32::from(signed_prekey_id) as u64).to_be_bytes());
121        hash.update(signed_prekey.public_key()?.serialize());
122        hash.update((u32::from(kyber_prekey_id) as u64).to_be_bytes());
123        hash.update(kyber_prekey.public_key()?.serialize());
124
125        self.service
126            .check_pre_keys(service_id_kind, hash.finalize().as_ref())
127            .await
128    }
129
130    /// Checks the availability of pre-keys, and updates them as necessary.
131    ///
132    /// Parameters are the protocol's `StoreContext`, and the offsets for the next pre-key and
133    /// signed pre-keys.
134    ///
135    /// Equivalent to Java's RefreshPreKeysJob
136    #[allow(clippy::too_many_arguments)]
137    #[tracing::instrument(skip(self, csprng, protocol_store))]
138    pub async fn update_pre_key_bundle<R: Rng + CryptoRng, P: PreKeysStore>(
139        &mut self,
140        protocol_store: &mut P,
141        service_id_kind: ServiceIdKind,
142        use_last_resort_key: bool,
143        csprng: &mut R,
144    ) -> Result<(), ServiceError> {
145        let prekey_status = match self
146            .service
147            .get_pre_key_status(service_id_kind)
148            .instrument(tracing::span!(
149                tracing::Level::DEBUG,
150                "Fetching pre key status"
151            ))
152            .await
153        {
154            Ok(status) => status,
155            Err(ServiceError::Unauthorized) => {
156                tracing::info!("Got Unauthorized when fetching pre-key status. Assuming first installment.");
157                // Additionally, the second PUT request will fail if this really comes down to an
158                // authorization failure.
159                crate::push_service::PreKeyStatus {
160                    count: 0,
161                    pq_count: 0,
162                }
163            },
164            Err(e) => return Err(e),
165        };
166        tracing::trace!("Remaining pre-keys on server: {:?}", prekey_status);
167
168        let check_pre_keys = self
169            .check_pre_keys(protocol_store, service_id_kind)
170            .instrument(tracing::span!(
171                tracing::Level::DEBUG,
172                "Checking pre keys"
173            ))
174            .await?;
175        if !check_pre_keys {
176            tracing::info!(
177                "Last resort pre-keys are not up to date; refreshing."
178            );
179        } else {
180            tracing::debug!("Last resort pre-keys are up to date.");
181        }
182
183        // XXX We should honestly compare the pre-key count with the number of pre-keys we have
184        // locally. If we have more than the server, we should upload them.
185        // Currently the trait doesn't allow us to do that, so we just upload the batch size and
186        // pray.
187        if check_pre_keys
188            && (prekey_status.count >= PRE_KEY_MINIMUM
189                && prekey_status.pq_count >= PRE_KEY_MINIMUM)
190        {
191            if protocol_store.signed_pre_keys_count().await? > 0
192                && protocol_store.kyber_pre_keys_count(true).await? > 0
193                && protocol_store.signed_prekey_id().await?.is_some()
194                && protocol_store
195                    .last_resort_kyber_prekey_id()
196                    .await?
197                    .is_some()
198            {
199                tracing::debug!("Available keys sufficient");
200                return Ok(());
201            }
202            tracing::info!("Available keys sufficient; forcing refresh.");
203        }
204
205        let identity_key_pair = protocol_store
206            .get_identity_key_pair()
207            .instrument(tracing::trace_span!("get identity key pair"))
208            .await?;
209
210        let last_resort_keys = protocol_store
211            .load_last_resort_kyber_pre_keys()
212            .instrument(tracing::trace_span!("fetch last resort key"))
213            .await?;
214
215        // XXX: Maybe this check should be done in the generate_pre_keys function?
216        let has_last_resort_key = !last_resort_keys.is_empty();
217
218        let (pre_keys, signed_pre_key, pq_pre_keys, pq_last_resort_key) =
219            crate::pre_keys::replenish_pre_keys(
220                protocol_store,
221                csprng,
222                &identity_key_pair,
223                use_last_resort_key && !has_last_resort_key,
224                PRE_KEY_BATCH_SIZE,
225                PRE_KEY_BATCH_SIZE,
226            )
227            .await?;
228
229        let pq_last_resort_key = if has_last_resort_key {
230            if last_resort_keys.len() > 1 {
231                tracing::warn!(
232                    "More than one last resort key found; only uploading first"
233                );
234            }
235            Some(KyberPreKeyEntity::try_from(last_resort_keys[0].clone())?)
236        } else {
237            pq_last_resort_key
238                .map(KyberPreKeyEntity::try_from)
239                .transpose()?
240        };
241
242        let identity_key = *identity_key_pair.identity_key();
243
244        let pre_keys: Vec<_> = pre_keys
245            .into_iter()
246            .map(PreKeyEntity::try_from)
247            .collect::<Result<_, _>>()?;
248        let signed_pre_key = signed_pre_key.try_into()?;
249        let pq_pre_keys: Vec<_> = pq_pre_keys
250            .into_iter()
251            .map(KyberPreKeyEntity::try_from)
252            .collect::<Result<_, _>>()?;
253
254        tracing::info!(
255            "Uploading pre-keys: {} one-time, {} PQ, {} PQ last resort",
256            pre_keys.len(),
257            pq_pre_keys.len(),
258            if pq_last_resort_key.is_some() { 1 } else { 0 }
259        );
260
261        let pre_key_state = PreKeyState {
262            pre_keys,
263            signed_pre_key,
264            identity_key,
265            pq_pre_keys,
266            pq_last_resort_key,
267        };
268
269        self.service
270            .register_pre_keys(service_id_kind, pre_key_state)
271            .instrument(tracing::span!(
272                tracing::Level::DEBUG,
273                "Uploading pre keys"
274            ))
275            .await?;
276
277        Ok(())
278    }
279
280    async fn new_device_provisioning_code(
281        &mut self,
282    ) -> Result<String, ServiceError> {
283        #[derive(serde::Deserialize)]
284        #[serde(rename_all = "camelCase")]
285        struct DeviceCode {
286            verification_code: String,
287        }
288
289        let dc: DeviceCode = self
290            .service
291            .request(
292                Method::GET,
293                Endpoint::service("/v1/devices/provisioning/code"),
294                HttpAuthOverride::NoOverride,
295            )?
296            .send()
297            .await?
298            .service_error_for_status()
299            .await?
300            .json()
301            .await?;
302
303        Ok(dc.verification_code)
304    }
305
306    async fn send_provisioning_message(
307        &mut self,
308        destination: &str,
309        env: ProvisionEnvelope,
310    ) -> Result<(), ServiceError> {
311        #[derive(serde::Serialize)]
312        struct ProvisioningMessage {
313            body: String,
314        }
315
316        let body = env.encode_to_vec();
317
318        self.service
319            .request(
320                Method::PUT,
321                Endpoint::service(format!("/v1/provisioning/{destination}")),
322                HttpAuthOverride::NoOverride,
323            )?
324            .json(&ProvisioningMessage {
325                body: BASE64_RELAXED.encode(body),
326            })
327            .send()
328            .await?
329            .service_error_for_status()
330            .await?;
331
332        Ok(())
333    }
334
335    /// Link a new device, given a tsurl.
336    ///
337    /// Equivalent of Java's `AccountManager::addDevice()`
338    ///
339    /// When calling this, make sure that UnidentifiedDelivery is disabled, ie., that your
340    /// application does not send any unidentified messages before linking is complete.
341    /// Cfr.:
342    /// - `app/src/main/java/org/thoughtcrime/securesms/migrations/LegacyMigrationJob.java`:250 and;
343    /// - `app/src/main/java/org/thoughtcrime/securesms/DeviceActivity.java`:195
344    ///
345    /// ```java
346    /// TextSecurePreferences.setIsUnidentifiedDeliveryEnabled(context, false);
347    /// ```
348    pub async fn link_device<R: Rng + CryptoRng>(
349        &mut self,
350        csprng: &mut R,
351        url: url::Url,
352        aci_identity_store: &dyn IdentityKeyStore,
353        pni_identity_store: &dyn IdentityKeyStore,
354        credentials: ServiceCredentials,
355        master_key: Option<MasterKey>,
356    ) -> Result<(), ProvisioningError> {
357        let query: HashMap<_, _> = url.query_pairs().collect();
358        let ephemeral_id =
359            query.get("uuid").ok_or(ProvisioningError::MissingUuid)?;
360        let pub_key = query
361            .get("pub_key")
362            .ok_or(ProvisioningError::MissingPublicKey)?;
363        let pub_key = BASE64_RELAXED
364            .decode(&**pub_key)
365            .map_err(|e| ProvisioningError::InvalidPublicKey(e.into()))?;
366        let pub_key = PublicKey::deserialize(&pub_key)
367            .map_err(|e| ProvisioningError::InvalidPublicKey(e.into()))?;
368
369        let aci_identity_key_pair =
370            aci_identity_store.get_identity_key_pair().await?;
371        let pni_identity_key_pair =
372            pni_identity_store.get_identity_key_pair().await?;
373
374        if credentials.aci.is_none() {
375            tracing::warn!("No local ACI set");
376        }
377        if credentials.pni.is_none() {
378            tracing::warn!("No local PNI set");
379        }
380
381        let provisioning_code = self.new_device_provisioning_code().await?;
382
383        let msg = ProvisionMessage {
384            aci: credentials.aci.as_ref().map(|u| u.to_string()),
385            aci_identity_key_public: Some(
386                aci_identity_key_pair.public_key().serialize().into_vec(),
387            ),
388            aci_identity_key_private: Some(
389                aci_identity_key_pair.private_key().serialize(),
390            ),
391            number: Some(credentials.e164()),
392            pni_identity_key_public: Some(
393                pni_identity_key_pair.public_key().serialize().into_vec(),
394            ),
395            pni_identity_key_private: Some(
396                pni_identity_key_pair.private_key().serialize(),
397            ),
398            pni: credentials.pni.as_ref().map(uuid::Uuid::to_string),
399            profile_key: self.profile_key.as_ref().map(|x| x.bytes.to_vec()),
400            // CURRENT is not exposed by prost :(
401            provisioning_version: Some(i32::from(
402                ProvisioningVersion::TabletSupport,
403            ) as _),
404            provisioning_code: Some(provisioning_code),
405            read_receipts: None,
406            user_agent: None,
407            master_key: master_key.map(|x| x.into()),
408        };
409
410        let cipher = ProvisioningCipher::from_public(pub_key);
411
412        let encrypted = cipher.encrypt(csprng, msg)?;
413        self.send_provisioning_message(ephemeral_id, encrypted)
414            .await?;
415        Ok(())
416    }
417
418    pub async fn linked_devices(
419        &mut self,
420        aci_identity_store: &dyn IdentityKeyStore,
421    ) -> Result<Vec<DeviceInfo>, ServiceError> {
422        let device_infos = self.service.devices().await?;
423        let aci_identity_keypair =
424            aci_identity_store.get_identity_key_pair().await?;
425
426        device_infos
427            .into_iter()
428            .map(|i| {
429                Ok(DeviceInfo {
430                    id: i.id,
431                    name: i
432                        .name
433                        .map(|s| {
434                            decrypt_device_name_from_device_info(
435                                &s,
436                                &aci_identity_keypair,
437                            )
438                        })
439                        .transpose()?,
440                    created: i.created,
441                    last_seen: i.last_seen,
442                })
443            })
444            .collect()
445    }
446
447    pub async fn register_account<
448        R: Rng + CryptoRng,
449        Aci: PreKeysStore + IdentityKeyStore,
450        Pni: PreKeysStore + IdentityKeyStore,
451    >(
452        &mut self,
453        csprng: &mut R,
454        registration_method: RegistrationMethod<'_>,
455        account_attributes: AccountAttributes,
456        aci_protocol_store: &mut Aci,
457        pni_protocol_store: &mut Pni,
458        skip_device_transfer: bool,
459    ) -> Result<VerifyAccountResponse, ProvisioningError> {
460        let aci_identity_key_pair = aci_protocol_store
461            .get_identity_key_pair()
462            .instrument(tracing::trace_span!("get ACI identity key pair"))
463            .await?;
464        let pni_identity_key_pair = pni_protocol_store
465            .get_identity_key_pair()
466            .instrument(tracing::trace_span!("get PNI identity key pair"))
467            .await?;
468
469        let (
470            _aci_pre_keys,
471            aci_signed_pre_key,
472            _aci_kyber_pre_keys,
473            aci_last_resort_kyber_prekey,
474        ) = crate::pre_keys::replenish_pre_keys(
475            aci_protocol_store,
476            csprng,
477            &aci_identity_key_pair,
478            true,
479            0,
480            0,
481        )
482        .await?;
483
484        let (
485            _pni_pre_keys,
486            pni_signed_pre_key,
487            _pni_kyber_pre_keys,
488            pni_last_resort_kyber_prekey,
489        ) = crate::pre_keys::replenish_pre_keys(
490            pni_protocol_store,
491            csprng,
492            &pni_identity_key_pair,
493            true,
494            0,
495            0,
496        )
497        .await?;
498
499        let aci_identity_key = aci_identity_key_pair.identity_key();
500        let pni_identity_key = pni_identity_key_pair.identity_key();
501
502        let dar = DeviceActivationRequest {
503            aci_signed_pre_key: aci_signed_pre_key.try_into()?,
504            pni_signed_pre_key: pni_signed_pre_key.try_into()?,
505            aci_pq_last_resort_pre_key: aci_last_resort_kyber_prekey
506                .expect("requested last resort prekey")
507                .try_into()?,
508            pni_pq_last_resort_pre_key: pni_last_resort_kyber_prekey
509                .expect("requested last resort prekey")
510                .try_into()?,
511        };
512
513        let result = self
514            .service
515            .submit_registration_request(
516                registration_method,
517                account_attributes,
518                skip_device_transfer,
519                aci_identity_key,
520                pni_identity_key,
521                dar,
522            )
523            .await?;
524
525        Ok(result)
526    }
527
528    /// Upload a profile
529    ///
530    /// Panics if no `profile_key` was set.
531    ///
532    /// Convenience method for
533    /// ```ignore
534    /// manager.upload_versioned_profile::<std::io::Cursor<Vec<u8>>, _>(uuid, name, about, about_emoji, _)
535    /// ```
536    /// in which the `retain_avatar` parameter sets whether to remove (`false`) or retain (`true`) the
537    /// currently set avatar.
538    pub async fn upload_versioned_profile_without_avatar<
539        R: Rng + CryptoRng,
540        S: AsRef<str>,
541    >(
542        &mut self,
543        aci: libsignal_protocol::Aci,
544        name: ProfileName<S>,
545        about: Option<String>,
546        about_emoji: Option<String>,
547        retain_avatar: bool,
548        csprng: &mut R,
549    ) -> Result<(), ProfileManagerError> {
550        self.upload_versioned_profile::<std::io::Cursor<Vec<u8>>, _, _>(
551            aci,
552            name,
553            about,
554            about_emoji,
555            if retain_avatar {
556                AvatarWrite::RetainAvatar
557            } else {
558                AvatarWrite::NoAvatar
559            },
560            csprng,
561        )
562        .await?;
563        Ok(())
564    }
565
566    pub async fn retrieve_profile(
567        &mut self,
568        address: Aci,
569    ) -> Result<Profile, ProfileManagerError> {
570        let profile_key =
571            self.profile_key.expect("set profile key in AccountManager");
572
573        let encrypted_profile = self
574            .service
575            .retrieve_profile_by_id(address, Some(profile_key))
576            .await?;
577
578        let profile_cipher = ProfileCipher::new(profile_key);
579        Ok(profile_cipher.decrypt(encrypted_profile)?)
580    }
581
582    /// Upload a profile
583    ///
584    /// Panics if no `profile_key` was set.
585    ///
586    /// Returns the avatar url path.
587    pub async fn upload_versioned_profile<
588        's,
589        C: std::io::Read + Send + 's,
590        R: Rng + CryptoRng,
591        S: AsRef<str>,
592    >(
593        &mut self,
594        aci: libsignal_protocol::Aci,
595        name: ProfileName<S>,
596        about: Option<String>,
597        about_emoji: Option<String>,
598        avatar: AvatarWrite<&'s mut C>,
599        csprng: &mut R,
600    ) -> Result<Option<String>, ProfileManagerError> {
601        let profile_key =
602            self.profile_key.expect("set profile key in AccountManager");
603        let profile_cipher = ProfileCipher::new(profile_key);
604
605        // Profile encryption
606        let name = profile_cipher.encrypt_name(name.as_ref(), csprng)?;
607        let about = about.unwrap_or_default();
608        let about = profile_cipher.encrypt_about(about, csprng)?;
609        let about_emoji = about_emoji.unwrap_or_default();
610        let about_emoji = profile_cipher.encrypt_emoji(about_emoji, csprng)?;
611
612        // If avatar -> upload
613        if matches!(avatar, AvatarWrite::NewAvatar(_)) {
614            // FIXME ProfileCipherOutputStream.java
615            // It's just AES GCM, but a bit of work to decently implement it with a stream.
616            unimplemented!("Setting avatar requires ProfileCipherStream")
617        }
618
619        let profile_key = profile_cipher.into_inner();
620        let commitment = profile_key.get_commitment(aci);
621        let profile_key_version = profile_key.get_profile_key_version(aci);
622
623        Ok(self
624            .service
625            .write_profile::<C, S>(
626                &profile_key_version,
627                &name,
628                &about,
629                &about_emoji,
630                &commitment,
631                avatar,
632            )
633            .await?)
634    }
635
    /// Set account attributes
    ///
    /// Forwards `attributes` unchanged to the service and returns its result.
    ///
    /// Signal Android does not allow unsetting voice/video.
    pub async fn set_account_attributes(
        &mut self,
        attributes: AccountAttributes,
    ) -> Result<(), ServiceError> {
        self.service.set_account_attributes(attributes).await
    }
645
646    /// Update (encrypted) device name
647    pub async fn update_device_name<R: Rng + CryptoRng>(
648        &mut self,
649        device_name: &str,
650        public_key: &IdentityKey,
651        csprng: &mut R,
652    ) -> Result<(), ServiceError> {
653        let encrypted_device_name =
654            encrypt_device_name(csprng, device_name, public_key)?;
655
656        #[derive(Serialize)]
657        #[serde(rename_all = "camelCase")]
658        struct Data {
659            #[serde(with = "serde_base64")]
660            device_name: Vec<u8>,
661        }
662
663        self.service
664            .request(
665                Method::PUT,
666                Endpoint::service("/v1/accounts/name"),
667                HttpAuthOverride::NoOverride,
668            )?
669            .json(&Data {
670                device_name: encrypted_device_name.encode_to_vec(),
671            })
672            .send()
673            .await?
674            .service_error_for_status()
675            .await?;
676
677        Ok(())
678    }
679
680    /// Upload a proof-required reCaptcha token and response.
681    ///
682    /// Token gotten originally with HTTP status 428 response to sending a message.
683    /// Captcha gotten from user completing the challenge captcha.
684    ///
685    /// It's either a silent OK, or throws a ServiceError.
686    pub async fn submit_recaptcha_challenge(
687        &mut self,
688        token: &str,
689        captcha: &str,
690    ) -> Result<(), ServiceError> {
691        self.service
692            .request(
693                Method::PUT,
694                Endpoint::service("/v1/challenge"),
695                HttpAuthOverride::NoOverride,
696            )?
697            .json(&CaptchaAttributes {
698                challenge_type: "captcha",
699                token,
700                captcha,
701            })
702            .send()
703            .await?
704            .service_error_for_status()
705            .await?;
706
707        Ok(())
708    }
709
710    /// Initialize PNI on linked devices.
711    ///
712    /// Should be called as the primary device to migrate from pre-PNI to PNI.
713    ///
714    /// This is the equivalent of Android's PnpInitializeDevicesJob or iOS' PniHelloWorldManager.
715    #[tracing::instrument(skip(self, aci_protocol_store, pni_protocol_store, sender, local_aci, csprng), fields(local_aci = local_aci.service_id_string()))]
716    pub async fn pnp_initialize_devices<
717        R: Rng + CryptoRng,
718        AciStore: PreKeysStore + SessionStoreExt,
719        PniStore: PreKeysStore,
720        AciOrPni: ProtocolStore + SenderKeyStore + SessionStoreExt + Sync + Clone,
721    >(
722        &mut self,
723        aci_protocol_store: &mut AciStore,
724        pni_protocol_store: &mut PniStore,
725        mut sender: MessageSender<AciOrPni>,
726        local_aci: Aci,
727        e164: PhoneNumber,
728        csprng: &mut R,
729    ) -> Result<(), MessageSenderError> {
730        let pni_identity_key_pair =
731            pni_protocol_store.get_identity_key_pair().await?;
732
733        let pni_identity_key = pni_identity_key_pair.identity_key();
734
735        // For every linked device, we generate a new set of pre-keys, and send them to the device.
736        let local_device_ids = aci_protocol_store
737            .get_sub_device_sessions(&local_aci.into())
738            .await?;
739
740        let mut device_messages =
741            Vec::<OutgoingPushMessage>::with_capacity(local_device_ids.len());
742        let mut device_pni_signed_prekeys =
743            HashMap::<String, SignedPreKeyEntity>::with_capacity(
744                local_device_ids.len(),
745            );
746        let mut device_pni_last_resort_kyber_prekeys =
747            HashMap::<String, KyberPreKeyEntity>::with_capacity(
748                local_device_ids.len(),
749            );
750        let mut pni_registration_ids =
751            HashMap::<String, u32>::with_capacity(local_device_ids.len());
752
753        let signature_valid_on_each_signed_pre_key = true;
754        for local_device_id in
755            std::iter::once(DEFAULT_DEVICE_ID).chain(local_device_ids)
756        {
757            let local_protocol_address =
758                local_aci.to_protocol_address(local_device_id);
759            let span = tracing::trace_span!(
760                "filtering devices",
761                address = %local_protocol_address
762            );
763            // Skip if we don't have a session with the device
764            if (local_device_id != DEFAULT_DEVICE_ID)
765                && aci_protocol_store
766                    .load_session(&local_protocol_address)
767                    .instrument(span)
768                    .await?
769                    .is_none()
770            {
771                tracing::warn!(
772                    "No session with device {}, skipping PNI provisioning",
773                    local_device_id
774                );
775                continue;
776            }
777            let (
778                _pre_keys,
779                signed_pre_key,
780                _kyber_pre_keys,
781                last_resort_kyber_prekey,
782            ) = if local_device_id == DEFAULT_DEVICE_ID {
783                crate::pre_keys::replenish_pre_keys(
784                    pni_protocol_store,
785                    csprng,
786                    &pni_identity_key_pair,
787                    true,
788                    0,
789                    0,
790                )
791                .await?
792            } else {
793                // Generate a signed prekey
794                let signed_pre_key_pair = KeyPair::generate(csprng);
795                let signed_pre_key_public = signed_pre_key_pair.public_key;
796                let signed_pre_key_signature =
797                    pni_identity_key_pair.private_key().calculate_signature(
798                        &signed_pre_key_public.serialize(),
799                        csprng,
800                    )?;
801
802                let signed_prekey_record = SignedPreKeyRecord::new(
803                    csprng.gen_range::<u32, _>(0..0xFFFFFF).into(),
804                    Timestamp::now(),
805                    &signed_pre_key_pair,
806                    &signed_pre_key_signature,
807                );
808
809                // Generate a last-resort Kyber prekey
810                let kyber_pre_key_record = KyberPreKeyRecord::generate(
811                    kem::KeyType::Kyber1024,
812                    csprng.gen_range::<u32, _>(0..0xFFFFFF).into(),
813                    pni_identity_key_pair.private_key(),
814                )?;
815                (
816                    vec![],
817                    signed_prekey_record,
818                    vec![],
819                    Some(kyber_pre_key_record),
820                )
821            };
822
823            let registration_id = if local_device_id == DEFAULT_DEVICE_ID {
824                pni_protocol_store.get_local_registration_id().await?
825            } else {
826                loop {
827                    let regid = generate_registration_id(csprng);
828                    if !pni_registration_ids.iter().any(|(_k, v)| *v == regid) {
829                        break regid;
830                    }
831                }
832            };
833
834            let local_device_id_s = local_device_id.to_string();
835            device_pni_signed_prekeys.insert(
836                local_device_id_s.clone(),
837                SignedPreKeyEntity::try_from(&signed_pre_key)?,
838            );
839            device_pni_last_resort_kyber_prekeys.insert(
840                local_device_id_s.clone(),
841                KyberPreKeyEntity::try_from(
842                    last_resort_kyber_prekey
843                        .as_ref()
844                        .expect("requested last resort key"),
845                )?,
846            );
847            pni_registration_ids
848                .insert(local_device_id_s.clone(), registration_id);
849
850            assert!(_pre_keys.is_empty());
851            assert!(_kyber_pre_keys.is_empty());
852
853            if local_device_id == DEFAULT_DEVICE_ID {
854                // This is the primary device
855                // We don't need to send a message to the primary device
856                continue;
857            }
858            // cfr. SignalServiceMessageSender::getEncryptedSyncPniInitializeDeviceMessage
859            let msg = SyncMessage {
860                pni_change_number: Some(PniChangeNumber {
861                    identity_key_pair: Some(
862                        pni_identity_key_pair.serialize().to_vec(),
863                    ),
864                    signed_pre_key: Some(signed_pre_key.serialize()?),
865                    last_resort_kyber_pre_key: Some(
866                        last_resort_kyber_prekey
867                            .expect("requested last resort key")
868                            .serialize()?,
869                    ),
870                    registration_id: Some(registration_id),
871                    new_e164: Some(
872                        e164.format().mode(phonenumber::Mode::E164).to_string(),
873                    ),
874                }),
875                padding: Some(random_length_padding(csprng, 512)),
876                ..SyncMessage::default()
877            };
878            let content: ContentBody = msg.into();
879            let msg = sender
880                .create_encrypted_message(
881                    &local_aci.into(),
882                    None,
883                    local_device_id.into(),
884                    &content.into_proto().encode_to_vec(),
885                )
886                .await?;
887            device_messages.push(msg);
888        }
889
890        self.service
891            .distribute_pni_keys(
892                pni_identity_key,
893                device_messages,
894                device_pni_signed_prekeys,
895                device_pni_last_resort_kyber_prekeys,
896                pni_registration_ids,
897                signature_valid_on_each_signed_pre_key,
898            )
899            .await?;
900
901        Ok(())
902    }
903}
904
905#[expect(clippy::result_large_err)]
906fn calculate_hmac256(
907    mac_key: &[u8],
908    ciphertext: &[u8],
909) -> Result<Output<Hmac<Sha256>>, ServiceError> {
910    let mut mac = Hmac::<Sha256>::new_from_slice(mac_key)
911        .map_err(|_| ServiceError::MacError)?;
912    mac.update(ciphertext);
913    Ok(mac.finalize().into_bytes())
914}
915
916#[expect(clippy::result_large_err)]
917pub fn encrypt_device_name<R: rand::Rng + rand::CryptoRng>(
918    csprng: &mut R,
919    device_name: &str,
920    identity_public: &IdentityKey,
921) -> Result<DeviceName, ServiceError> {
922    let plaintext = device_name.as_bytes().to_vec();
923    let ephemeral_key_pair = KeyPair::generate(csprng);
924
925    let master_secret = ephemeral_key_pair
926        .private_key
927        .calculate_agreement(identity_public.public_key())?;
928
929    let key1 = calculate_hmac256(&master_secret, b"auth")?;
930    let synthetic_iv = calculate_hmac256(&key1, &plaintext)?;
931    let synthetic_iv = &synthetic_iv[..16];
932
933    let key2 = calculate_hmac256(&master_secret, b"cipher")?;
934    let cipher_key = calculate_hmac256(&key2, synthetic_iv)?;
935
936    let mut ciphertext = plaintext;
937
938    const IV: [u8; 16] = [0; 16];
939    let mut cipher = Aes256Ctr128BE::new(&cipher_key, &IV.into());
940    cipher.apply_keystream(&mut ciphertext);
941
942    let device_name = DeviceName {
943        ephemeral_public: Some(
944            ephemeral_key_pair.public_key.serialize().to_vec(),
945        ),
946        synthetic_iv: Some(synthetic_iv.to_vec()),
947        ciphertext: Some(ciphertext),
948    };
949
950    Ok(device_name)
951}
952
953#[expect(clippy::result_large_err)]
954fn decrypt_device_name_from_device_info(
955    string: &str,
956    aci: &IdentityKeyPair,
957) -> Result<String, ServiceError> {
958    let data = BASE64_RELAXED.decode(string)?;
959    let name = DeviceName::decode(&*data)?;
960    crate::decrypt_device_name(aci.private_key(), &name)
961}
962
963#[expect(clippy::result_large_err)]
964pub fn decrypt_device_name(
965    private_key: &PrivateKey,
966    device_name: &DeviceName,
967) -> Result<String, ServiceError> {
968    let DeviceName {
969        ephemeral_public: Some(ephemeral_public),
970        synthetic_iv: Some(synthetic_iv),
971        ciphertext: Some(ciphertext),
972    } = device_name
973    else {
974        return Err(ServiceError::InvalidDeviceName);
975    };
976
977    let synthetic_iv: [u8; 16] = synthetic_iv[..synthetic_iv.len().min(16)]
978        .try_into()
979        .map_err(|_| ServiceError::MacError)?;
980
981    let ephemeral_public = PublicKey::deserialize(ephemeral_public)?;
982
983    let master_secret = private_key.calculate_agreement(&ephemeral_public)?;
984    let key2 = calculate_hmac256(&master_secret, b"cipher")?;
985    let cipher_key = calculate_hmac256(&key2, &synthetic_iv)?;
986
987    let mut plaintext = ciphertext.to_vec();
988    const IV: [u8; 16] = [0; 16];
989    let mut cipher =
990        Aes256Ctr128BE::new(cipher_key.as_slice().into(), &IV.into());
991    cipher.apply_keystream(&mut plaintext);
992
993    let key1 = calculate_hmac256(&master_secret, b"auth")?;
994    let our_synthetic_iv = calculate_hmac256(&key1, &plaintext)?;
995    let our_synthetic_iv = &our_synthetic_iv[..16];
996
997    if synthetic_iv != our_synthetic_iv {
998        Err(ServiceError::MacError)
999    } else {
1000        Ok(String::from_utf8_lossy(&plaintext).to_string())
1001    }
1002}
1003
#[cfg(test)]
mod tests {
    use crate::utils::BASE64_RELAXED;
    use base64::Engine;
    use libsignal_protocol::{IdentityKeyPair, PrivateKey, PublicKey};

    use super::DeviceName;

    /// Device name used by both tests; the fixed vector in
    /// `decrypt_device_name` is expected to decrypt to exactly this text.
    const EXPECTED_NAME: &str = "Nokia 3310 Millenial Edition";

    /// Round trip: a name encrypted for a freshly generated identity must
    /// decrypt back to the same text with that identity's private key.
    #[test]
    fn encrypt_device_name() -> anyhow::Result<()> {
        let mut rng = rand::thread_rng();
        let key_pair = IdentityKeyPair::generate(&mut rng);

        let encrypted = super::encrypt_device_name(
            &mut rng,
            EXPECTED_NAME,
            key_pair.identity_key(),
        )?;
        let round_tripped =
            super::decrypt_device_name(key_pair.private_key(), &encrypted)?;

        assert_eq!(EXPECTED_NAME, round_tripped);

        Ok(())
    }

    /// Fixed vector: a hard-coded `DeviceName` must decrypt to the expected
    /// text with the hard-coded private key.
    #[test]
    fn decrypt_device_name() -> anyhow::Result<()> {
        let decode = |encoded: &str| BASE64_RELAXED.decode(encoded);

        let private_key = PrivateKey::deserialize(&decode(
            "0CgxHjwwblXjvX8sD5wZDWdYToMRf+CZSlgaUrxCGVo=",
        )?)?;
        // Deserialize first so that the public key bytes are validated
        // before being placed into the message.
        let public_key = PublicKey::deserialize(&decode(
            "BcZS+Lt6yAKbEpXnRX+I5wHqesuvu93Q2V+fjidwW8R6",
        )?)?;

        let encrypted = DeviceName {
            ephemeral_public: Some(public_key.serialize().to_vec()),
            synthetic_iv: Some(decode("86gekHGmltnnZ9QARhiFcg==")?),
            ciphertext: Some(decode(
                "MtJ9/9KBWLBVAxfZJD4pLKzP4q+iodRJeCc+/A==",
            )?),
        };

        let decrypted = super::decrypt_device_name(&private_key, &encrypted)?;

        assert_eq!(decrypted, EXPECTED_NAME);

        Ok(())
    }
}