Skip to main content

libsignal_service/
account_manager.rs

1use base64::prelude::*;
2use libsignal_core::{DeviceId, E164};
3use rand::{CryptoRng, Rng};
4use reqwest::Method;
5use std::collections::HashMap;
6use std::convert::{TryFrom, TryInto};
7
8use aes::cipher::{KeyIvInit, StreamCipher as _};
9use hmac::digest::Output;
10use hmac::{Hmac, Mac};
11use libsignal_protocol::{
12    kem, Aci, GenericSignedPreKey, IdentityKey, IdentityKeyPair,
13    IdentityKeyStore, KeyPair, KyberPreKeyRecord, PrivateKey, ProtocolStore,
14    PublicKey, SenderKeyStore, ServiceIdKind, SignedPreKeyRecord, Timestamp,
15};
16use prost::Message;
17use serde::{Deserialize, Serialize};
18use sha2::{Digest, Sha256};
19use tracing_futures::Instrument;
20use zkgroup::profiles::ProfileKey;
21
22use crate::content::ContentBody;
23use crate::pre_keys::{
24    KyberPreKeyEntity, PreKeyEntity, PreKeysStore, SignedPreKeyEntity,
25    PRE_KEY_BATCH_SIZE, PRE_KEY_MINIMUM,
26};
27use crate::prelude::{MessageSender, MessageSenderError};
28use crate::proto::sync_message::PniChangeNumber;
29use crate::proto::{DeviceName, SyncMessage};
30use crate::provisioning::{generate_registration_id, ProvisioningSecrets};
31use crate::push_service::{
32    AvatarWrite, HttpAuthOverride, ReqwestExt, DEFAULT_DEVICE_ID,
33};
34use crate::sender::OutgoingPushMessage;
35use crate::service_address::ServiceIdExt;
36use crate::session_store::SessionStoreExt;
37use crate::timestamp::TimestampExt as _;
38use crate::utils::{random_length_padding, BASE64_RELAXED};
39use crate::websocket::account::DeviceInfo;
40use crate::websocket::keys::PreKeyStatus;
41use crate::websocket::registration::{
42    CaptchaAttributes, DeviceActivationRequest, RegistrationMethod,
43    VerifyAccountResponse,
44};
45use crate::websocket::{self, SignalWebSocket};
46use crate::{
47    configuration::Endpoint,
48    pre_keys::PreKeyState,
49    profile_cipher::{ProfileCipher, ProfileCipherError},
50    profile_name::ProfileName,
51    proto::{ProvisionEnvelope, ProvisionMessage, ProvisioningVersion},
52    provisioning::{ProvisioningCipher, ProvisioningError},
53    push_service::{PushService, ServiceError},
54    utils::serde_base64,
55    websocket::account::AccountAttributes,
56};
57
/// AES-256 in CTR mode with a 128-bit big-endian counter.
// NOTE(review): not referenced in this chunk; presumably used by the
// device-name encryption/decryption helpers later in the file — confirm.
type Aes256Ctr128BE = ctr::Ctr128BE<aes::Aes256>;
59
/// Account-level operations: pre-key maintenance, device linking,
/// registration, profile and account-attribute updates.
pub struct AccountManager {
    /// HTTP push service, used for the REST endpoints called via `request(..)`.
    service: PushService,
    /// Identified websocket connection, used for the remaining service calls.
    websocket: SignalWebSocket<websocket::Identified>,
    /// Local profile key. Methods that need it (e.g. `retrieve_profile`,
    /// `upload_versioned_profile`) panic when it is `None`.
    profile_key: Option<ProfileKey>,
}
65
/// Errors returned by the profile methods of `AccountManager`.
///
/// Both variants are `#[error(transparent)]`, so their `Display` output is
/// that of the wrapped error; `#[from]` lets `?` convert either source error.
#[derive(thiserror::Error, Debug)]
pub enum ProfileManagerError {
    /// Error from the service layer (HTTP or websocket request).
    #[error(transparent)]
    ServiceError(#[from] ServiceError),
    /// Error while encrypting or decrypting profile fields.
    #[error(transparent)]
    ProfileCipherError(#[from] ProfileCipherError),
}
73
/// A decrypted profile, as produced by `ProfileCipher::decrypt` in
/// `AccountManager::retrieve_profile`.
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct Profile {
    /// Profile name, if any.
    pub name: Option<ProfileName<String>>,
    /// "About" text, if any.
    pub about: Option<String>,
    /// "About" emoji, if any.
    pub about_emoji: Option<String>,
    // NOTE(review): presumably the avatar's storage path/URL — confirm.
    pub avatar: Option<String>,
    // NOTE(review): presumably whether unidentified (sealed-sender) delivery is
    // allowed from anyone, not only from contacts — confirm against server API.
    pub unrestricted_unidentified_access: bool,
}
82
83impl AccountManager {
84    pub fn new(
85        service: PushService,
86        websocket: SignalWebSocket<websocket::Identified>,
87        profile_key: Option<ProfileKey>,
88    ) -> Self {
89        Self {
90            service,
91            websocket,
92            profile_key,
93        }
94    }
95
96    #[allow(clippy::too_many_arguments)]
97    #[tracing::instrument(skip(self, protocol_store))]
98    pub async fn check_pre_keys<P: PreKeysStore>(
99        &mut self,
100        protocol_store: &mut P,
101        service_id_kind: ServiceIdKind,
102    ) -> Result<bool, ServiceError> {
103        let Some(signed_prekey_id) = protocol_store.signed_prekey_id().await?
104        else {
105            tracing::warn!("No signed prekey found");
106            return Ok(false);
107        };
108        // XXX: should we instead use the `load_last_resort_kyber_pre_keys` method? Or refactor
109        //      those whole traits?
110        let Some(kyber_prekey_id) =
111            protocol_store.last_resort_kyber_prekey_id().await?
112        else {
113            tracing::warn!("No last resort kyber prekey found");
114            return Ok(false);
115        };
116
117        let signed_prekey =
118            protocol_store.get_signed_pre_key(signed_prekey_id).await?;
119        let kyber_prekey =
120            protocol_store.get_kyber_pre_key(kyber_prekey_id).await?;
121
122        // `SHA256(identityKeyBytes || signedEcPreKeyId || signedEcPreKeyIdBytes || lastResortKeyId || lastResortKeyBytes)`
123        let mut hash = Sha256::default();
124        hash.update(
125            protocol_store
126                .get_identity_key_pair()
127                .await?
128                .public_key()
129                .serialize(),
130        );
131        hash.update((u32::from(signed_prekey_id) as u64).to_be_bytes());
132        hash.update(signed_prekey.public_key()?.serialize());
133        hash.update((u32::from(kyber_prekey_id) as u64).to_be_bytes());
134        hash.update(kyber_prekey.public_key()?.serialize());
135
136        self.websocket
137            .check_pre_keys(service_id_kind, hash.finalize().as_ref())
138            .await
139    }
140
141    /// Checks the availability of pre-keys, and updates them as necessary.
142    ///
143    /// Parameters are the protocol's `StoreContext`, and the offsets for the next pre-key and
144    /// signed pre-keys.
145    ///
146    /// Equivalent to Java's RefreshPreKeysJob
147    #[allow(clippy::too_many_arguments)]
148    #[tracing::instrument(skip(self, protocol_store))]
149    pub async fn update_pre_key_bundle<P: PreKeysStore>(
150        &mut self,
151        protocol_store: &mut P,
152        service_id_kind: ServiceIdKind,
153        use_last_resort_key: bool,
154    ) -> Result<(), ServiceError> {
155        let prekey_status = match self
156            .websocket
157            .get_pre_key_status(service_id_kind)
158            .instrument(tracing::span!(
159                tracing::Level::DEBUG,
160                "Fetching pre key status"
161            ))
162            .await
163        {
164            Ok(status) => status,
165            Err(ServiceError::Unauthorized) => {
166                tracing::info!("Got Unauthorized when fetching pre-key status. Assuming first installment.");
167                // Additionally, the second PUT request will fail if this really comes down to an
168                // authorization failure.
169                PreKeyStatus {
170                    count: 0,
171                    pq_count: 0,
172                }
173            },
174            Err(e) => return Err(e),
175        };
176        tracing::trace!("Remaining pre-keys on server: {:?}", prekey_status);
177
178        let check_pre_keys = self
179            .check_pre_keys(protocol_store, service_id_kind)
180            .instrument(tracing::span!(
181                tracing::Level::DEBUG,
182                "Checking pre keys"
183            ))
184            .await?;
185        if !check_pre_keys {
186            tracing::info!(
187                "Last resort pre-keys are not up to date; refreshing."
188            );
189        } else {
190            tracing::debug!("Last resort pre-keys are up to date.");
191        }
192
193        // XXX We should honestly compare the pre-key count with the number of pre-keys we have
194        // locally. If we have more than the server, we should upload them.
195        // Currently the trait doesn't allow us to do that, so we just upload the batch size and
196        // pray.
197        if check_pre_keys
198            && (prekey_status.count >= PRE_KEY_MINIMUM
199                && prekey_status.pq_count >= PRE_KEY_MINIMUM)
200        {
201            if protocol_store.signed_pre_keys_count().await? > 0
202                && protocol_store.kyber_pre_keys_count(true).await? > 0
203                && protocol_store.signed_prekey_id().await?.is_some()
204                && protocol_store
205                    .last_resort_kyber_prekey_id()
206                    .await?
207                    .is_some()
208            {
209                tracing::debug!("Available keys sufficient");
210                return Ok(());
211            }
212            tracing::info!("Available keys sufficient; forcing refresh.");
213        }
214
215        let identity_key_pair = protocol_store
216            .get_identity_key_pair()
217            .instrument(tracing::trace_span!("get identity key pair"))
218            .await?;
219
220        let last_resort_keys = protocol_store
221            .load_last_resort_kyber_pre_keys()
222            .instrument(tracing::trace_span!("fetch last resort key"))
223            .await?;
224
225        // XXX: Maybe this check should be done in the generate_pre_keys function?
226        let has_last_resort_key = !last_resort_keys.is_empty();
227
228        let (pre_keys, signed_pre_key, pq_pre_keys, pq_last_resort_key) =
229            crate::pre_keys::replenish_pre_keys(
230                protocol_store,
231                &mut rand::rng(),
232                &identity_key_pair,
233                use_last_resort_key && !has_last_resort_key,
234                PRE_KEY_BATCH_SIZE,
235                PRE_KEY_BATCH_SIZE,
236            )
237            .await?;
238
239        let pq_last_resort_key = if has_last_resort_key {
240            if last_resort_keys.len() > 1 {
241                tracing::warn!(
242                    "More than one last resort key found; only uploading first"
243                );
244            }
245            Some(KyberPreKeyEntity::try_from(last_resort_keys[0].clone())?)
246        } else {
247            pq_last_resort_key
248                .map(KyberPreKeyEntity::try_from)
249                .transpose()?
250        };
251
252        let identity_key = *identity_key_pair.identity_key();
253
254        let pre_keys: Vec<_> = pre_keys
255            .into_iter()
256            .map(PreKeyEntity::try_from)
257            .collect::<Result<_, _>>()?;
258        let signed_pre_key = signed_pre_key.try_into()?;
259        let pq_pre_keys: Vec<_> = pq_pre_keys
260            .into_iter()
261            .map(KyberPreKeyEntity::try_from)
262            .collect::<Result<_, _>>()?;
263
264        tracing::info!(
265            "Uploading pre-keys: {} one-time, {} PQ, {} PQ last resort",
266            pre_keys.len(),
267            pq_pre_keys.len(),
268            if pq_last_resort_key.is_some() { 1 } else { 0 }
269        );
270
271        let pre_key_state = PreKeyState {
272            pre_keys,
273            signed_pre_key,
274            identity_key,
275            pq_pre_keys,
276            pq_last_resort_key,
277        };
278
279        self.websocket
280            .register_pre_keys(service_id_kind, pre_key_state)
281            .instrument(tracing::span!(
282                tracing::Level::DEBUG,
283                "Uploading pre keys"
284            ))
285            .await?;
286
287        Ok(())
288    }
289
290    async fn new_device_provisioning_code(
291        &mut self,
292    ) -> Result<String, ServiceError> {
293        #[derive(serde::Deserialize)]
294        #[serde(rename_all = "camelCase")]
295        struct DeviceCode {
296            verification_code: String,
297        }
298
299        let dc: DeviceCode = self
300            .service
301            .request(
302                Method::GET,
303                Endpoint::service("/v1/devices/provisioning/code"),
304                HttpAuthOverride::NoOverride,
305            )?
306            .send()
307            .await?
308            .service_error_for_status()
309            .await?
310            .json()
311            .await?;
312
313        Ok(dc.verification_code)
314    }
315
316    async fn send_provisioning_message(
317        &mut self,
318        destination: &str,
319        env: ProvisionEnvelope,
320    ) -> Result<(), ServiceError> {
321        #[derive(serde::Serialize)]
322        struct ProvisioningMessage {
323            body: String,
324        }
325
326        let body = env.encode_to_vec();
327
328        self.service
329            .request(
330                Method::PUT,
331                Endpoint::service(format!("/v1/provisioning/{destination}")),
332                HttpAuthOverride::NoOverride,
333            )?
334            .json(&ProvisioningMessage {
335                body: BASE64_RELAXED.encode(body),
336            })
337            .send()
338            .await?
339            .service_error_for_status()
340            .await?;
341
342        Ok(())
343    }
344
345    /// Link a new device, given a tsurl.
346    ///
347    /// Equivalent of Java's `AccountManager::addDevice()`
348    ///
349    /// When calling this, make sure that UnidentifiedDelivery is disabled, ie., that your
350    /// application does not send any unidentified messages before linking is complete.
351    /// Cfr.:
352    /// - `app/src/main/java/org/thoughtcrime/securesms/migrations/LegacyMigrationJob.java`:250 and;
353    /// - `app/src/main/java/org/thoughtcrime/securesms/DeviceActivity.java`:195
354    ///
355    /// ```java
356    /// TextSecurePreferences.setIsUnidentifiedDeliveryEnabled(context, false);
357    /// ```
358    pub async fn link_device<R: Rng + CryptoRng>(
359        &mut self,
360        csprng: &mut R,
361        url: url::Url,
362        aci_identity_store: &dyn IdentityKeyStore,
363        pni_identity_store: &dyn IdentityKeyStore,
364        secrets: ProvisioningSecrets,
365    ) -> Result<(), ProvisioningError> {
366        let ProvisioningSecrets {
367            credentials,
368            master_key,
369            ephemeral_backup_key,
370            account_entropy_pool,
371            media_root_backup_key,
372        } = secrets;
373
374        let query: HashMap<_, _> = url.query_pairs().collect();
375        let ephemeral_id =
376            query.get("uuid").ok_or(ProvisioningError::MissingUuid)?;
377        let pub_key = query
378            .get("pub_key")
379            .ok_or(ProvisioningError::MissingPublicKey)?;
380
381        let pub_key = BASE64_RELAXED
382            .decode(&**pub_key)
383            .map_err(|e| ProvisioningError::InvalidPublicKey(e.into()))?;
384        let pub_key = PublicKey::deserialize(&pub_key)
385            .map_err(|e| ProvisioningError::InvalidPublicKey(e.into()))?;
386
387        let aci_identity_key_pair =
388            aci_identity_store.get_identity_key_pair().await?;
389        let pni_identity_key_pair =
390            pni_identity_store.get_identity_key_pair().await?;
391
392        if credentials.aci.is_none() {
393            tracing::warn!("No local ACI set");
394        }
395        if credentials.pni.is_none() {
396            tracing::warn!("No local PNI set");
397        }
398
399        let provisioning_code = self.new_device_provisioning_code().await?;
400
401        // Same order as in Provisioning.proto for helping comparison.
402        let msg = ProvisionMessage {
403            aci_identity_key_public: Some(
404                aci_identity_key_pair.public_key().serialize().into_vec(),
405            ),
406            aci_identity_key_private: Some(
407                aci_identity_key_pair.private_key().serialize(),
408            ),
409            pni_identity_key_public: Some(
410                pni_identity_key_pair.public_key().serialize().into_vec(),
411            ),
412            pni_identity_key_private: Some(
413                pni_identity_key_pair.private_key().serialize(),
414            ),
415            aci: credentials.aci.as_ref().map(|u| u.to_string()),
416            pni: credentials.pni.as_ref().map(uuid::Uuid::to_string),
417            number: Some(credentials.e164()),
418            provisioning_code: Some(provisioning_code),
419            user_agent: None,
420            profile_key: self.profile_key.as_ref().map(|x| x.bytes.to_vec()),
421            // CURRENT is not exposed by prost :(
422            read_receipts: None,
423            provisioning_version: Some(i32::from(
424                ProvisioningVersion::TabletSupport,
425            ) as _),
426            master_key: master_key.map(|x| x.into()),
427            ephemeral_backup_key: ephemeral_backup_key.map(|k| k.to_vec()), // 32-bytes
428            account_entropy_pool: Some(account_entropy_pool.to_string()),
429            media_root_backup_key: media_root_backup_key.map(|k| k.to_vec()), // 32-bytes
430            aci_binary: credentials.aci.map(|u| u.into_bytes().into()),
431            pni_binary: credentials.pni.map(|u| u.into_bytes().into()),
432        };
433
434        let cipher = ProvisioningCipher::from_public(pub_key);
435
436        let encrypted = cipher.encrypt(csprng, msg)?;
437        self.send_provisioning_message(ephemeral_id, encrypted)
438            .await?;
439        Ok(())
440    }
441
442    pub async fn linked_devices(
443        &mut self,
444        aci_identity_store: &dyn IdentityKeyStore,
445    ) -> Result<Vec<DeviceInfo>, ServiceError> {
446        let device_infos = self.websocket.devices().await?;
447        let aci_identity_keypair =
448            aci_identity_store.get_identity_key_pair().await?;
449
450        device_infos
451            .into_iter()
452            .map(|i| {
453                Ok(DeviceInfo {
454                    id: i.id,
455                    name: i.name.and_then(|s| {
456                        match decrypt_device_name_from_device_info(
457                            &s,
458                            &aci_identity_keypair,
459                        ) {
460                            Ok(name) => Some(name),
461                            Err(e) => {
462                                tracing::error!("{e}");
463                                None
464                            },
465                        }
466                    }),
467                    registration_id: i.registration_id,
468                    last_seen: i.last_seen,
469                    created_at: decrypt_device_created_at_from_device_info(
470                        i.id,
471                        i.registration_id,
472                        &i.created_at_ciphertext,
473                        &aci_identity_keypair,
474                    )?,
475                })
476            })
477            .collect()
478    }
479
480    pub async fn register_account<
481        R: Rng + CryptoRng,
482        Aci: PreKeysStore + IdentityKeyStore,
483        Pni: PreKeysStore + IdentityKeyStore,
484    >(
485        &mut self,
486        csprng: &mut R,
487        registration_method: RegistrationMethod<'_>,
488        account_attributes: AccountAttributes,
489        aci_protocol_store: &mut Aci,
490        pni_protocol_store: &mut Pni,
491        skip_device_transfer: bool,
492    ) -> Result<VerifyAccountResponse, ProvisioningError> {
493        let aci_identity_key_pair = aci_protocol_store
494            .get_identity_key_pair()
495            .instrument(tracing::trace_span!("get ACI identity key pair"))
496            .await?;
497        let pni_identity_key_pair = pni_protocol_store
498            .get_identity_key_pair()
499            .instrument(tracing::trace_span!("get PNI identity key pair"))
500            .await?;
501
502        let (
503            _aci_pre_keys,
504            aci_signed_pre_key,
505            _aci_kyber_pre_keys,
506            aci_last_resort_kyber_prekey,
507        ) = crate::pre_keys::replenish_pre_keys(
508            aci_protocol_store,
509            csprng,
510            &aci_identity_key_pair,
511            true,
512            0,
513            0,
514        )
515        .await?;
516
517        let (
518            _pni_pre_keys,
519            pni_signed_pre_key,
520            _pni_kyber_pre_keys,
521            pni_last_resort_kyber_prekey,
522        ) = crate::pre_keys::replenish_pre_keys(
523            pni_protocol_store,
524            csprng,
525            &pni_identity_key_pair,
526            true,
527            0,
528            0,
529        )
530        .await?;
531
532        let aci_identity_key = aci_identity_key_pair.identity_key();
533        let pni_identity_key = pni_identity_key_pair.identity_key();
534
535        let dar = DeviceActivationRequest {
536            aci_signed_pre_key: aci_signed_pre_key.try_into()?,
537            pni_signed_pre_key: pni_signed_pre_key.try_into()?,
538            aci_pq_last_resort_pre_key: aci_last_resort_kyber_prekey
539                .expect("requested last resort prekey")
540                .try_into()?,
541            pni_pq_last_resort_pre_key: pni_last_resort_kyber_prekey
542                .expect("requested last resort prekey")
543                .try_into()?,
544        };
545
546        let result = self
547            .websocket
548            .submit_registration_request(
549                registration_method,
550                account_attributes,
551                skip_device_transfer,
552                aci_identity_key,
553                pni_identity_key,
554                dar,
555            )
556            .await?;
557
558        Ok(result)
559    }
560
561    /// Upload a profile
562    ///
563    /// Panics if no `profile_key` was set.
564    ///
565    /// Convenience method for
566    /// ```ignore
567    /// manager.upload_versioned_profile::<std::io::Cursor<Vec<u8>>, _>(uuid, name, about, about_emoji, _)
568    /// ```
569    /// in which the `retain_avatar` parameter sets whether to remove (`false`) or retain (`true`) the
570    /// currently set avatar.
571    pub async fn upload_versioned_profile_without_avatar<
572        R: Rng + CryptoRng,
573        S: AsRef<str>,
574    >(
575        &mut self,
576        aci: libsignal_protocol::Aci,
577        name: ProfileName<S>,
578        about: Option<String>,
579        about_emoji: Option<String>,
580        retain_avatar: bool,
581        csprng: &mut R,
582    ) -> Result<(), ProfileManagerError> {
583        self.upload_versioned_profile::<std::io::Cursor<Vec<u8>>, _, _>(
584            aci,
585            name,
586            about,
587            about_emoji,
588            if retain_avatar {
589                AvatarWrite::RetainAvatar
590            } else {
591                AvatarWrite::NoAvatar
592            },
593            csprng,
594        )
595        .await?;
596        Ok(())
597    }
598
599    pub async fn retrieve_profile(
600        &mut self,
601        address: Aci,
602    ) -> Result<Profile, ProfileManagerError> {
603        let profile_key =
604            self.profile_key.expect("set profile key in AccountManager");
605
606        let encrypted_profile = self
607            .websocket
608            .retrieve_profile_by_id(address, Some(profile_key))
609            .await?;
610
611        let profile_cipher = ProfileCipher::new(profile_key);
612        Ok(profile_cipher.decrypt(encrypted_profile)?)
613    }
614
615    /// Upload a profile
616    ///
617    /// Panics if no `profile_key` was set.
618    ///
619    /// Returns the avatar url path.
620    pub async fn upload_versioned_profile<
621        's,
622        C: std::io::Read + Send + 's,
623        R: Rng + CryptoRng,
624        S: AsRef<str>,
625    >(
626        &mut self,
627        aci: libsignal_protocol::Aci,
628        name: ProfileName<S>,
629        about: Option<String>,
630        about_emoji: Option<String>,
631        avatar: AvatarWrite<&'s mut C>,
632        csprng: &mut R,
633    ) -> Result<Option<String>, ProfileManagerError> {
634        let profile_key =
635            self.profile_key.expect("set profile key in AccountManager");
636        let profile_cipher = ProfileCipher::new(profile_key);
637
638        // Profile encryption
639        let name = profile_cipher.encrypt_name(name.as_ref(), csprng)?;
640        let about = about.unwrap_or_default();
641        let about = profile_cipher.encrypt_about(about, csprng)?;
642        let about_emoji = about_emoji.unwrap_or_default();
643        let about_emoji = profile_cipher.encrypt_emoji(about_emoji, csprng)?;
644
645        // If avatar -> upload
646        if matches!(avatar, AvatarWrite::NewAvatar(_)) {
647            // FIXME ProfileCipherOutputStream.java
648            // It's just AES GCM, but a bit of work to decently implement it with a stream.
649            unimplemented!("Setting avatar requires ProfileCipherStream")
650        }
651
652        let profile_key = profile_cipher.into_inner();
653        let commitment = profile_key.get_commitment(aci);
654        let profile_key_version = profile_key.get_profile_key_version(aci);
655
656        Ok(self
657            .websocket
658            .write_profile::<C, S>(
659                &profile_key_version,
660                &name,
661                &about,
662                &about_emoji,
663                &commitment,
664                avatar,
665            )
666            .await?)
667    }
668
669    /// Set profile attributes
670    ///
671    /// Signal Android does not allow unsetting voice/video.
672    pub async fn set_account_attributes(
673        &mut self,
674        attributes: AccountAttributes,
675    ) -> Result<(), ServiceError> {
676        self.websocket.set_account_attributes(attributes).await
677    }
678
679    /// Update (encrypted) device name
680    pub async fn update_device_name<R: Rng + CryptoRng>(
681        &mut self,
682        device_id: libsignal_core::DeviceId,
683        device_name: &str,
684        aci: Aci,
685        aci_identity_store: &dyn IdentityKeyStore,
686        csprng: &mut R,
687    ) -> Result<(), ServiceError> {
688        let addr = aci.to_protocol_address(device_id).unwrap();
689        let public_key = aci_identity_store.get_identity(&addr).await?;
690        let Some(public_key) = public_key else {
691            return Err(ServiceError::SendError {
692                reason: format!("public key for device {addr:?} not found"),
693            });
694        };
695        let encrypted_device_name =
696            encrypt_device_name(csprng, device_name, &public_key)?;
697
698        #[derive(Serialize)]
699        #[serde(rename_all = "camelCase")]
700        struct Data {
701            #[serde(with = "serde_base64")]
702            device_name: Vec<u8>,
703        }
704
705        self.service
706            .request(
707                Method::PUT,
708                Endpoint::service(format!(
709                    "/v1/accounts/name?deviceId={}",
710                    device_id
711                )),
712                HttpAuthOverride::NoOverride,
713            )?
714            .json(&Data {
715                device_name: encrypted_device_name.encode_to_vec(),
716            })
717            .send()
718            .await?
719            .service_error_for_status()
720            .await?;
721
722        Ok(())
723    }
724
725    /// Upload a proof-required reCaptcha token and response.
726    ///
727    /// Token gotten originally with HTTP status 428 response to sending a message.
728    /// Captcha gotten from user completing the challenge captcha.
729    ///
730    /// It's either a silent OK, or throws a ServiceError.
731    pub async fn submit_recaptcha_challenge(
732        &mut self,
733        token: &str,
734        captcha: &str,
735    ) -> Result<(), ServiceError> {
736        self.service
737            .request(
738                Method::PUT,
739                Endpoint::service("/v1/challenge"),
740                HttpAuthOverride::NoOverride,
741            )?
742            .json(&CaptchaAttributes {
743                challenge_type: "captcha",
744                token,
745                captcha,
746            })
747            .send()
748            .await?
749            .service_error_for_status()
750            .await?;
751
752        Ok(())
753    }
754
755    /// Initialize PNI on linked devices.
756    ///
757    /// Should be called as the primary device to migrate from pre-PNI to PNI.
758    ///
759    /// This is the equivalent of Android's PnpInitializeDevicesJob or iOS' PniHelloWorldManager.
760    #[tracing::instrument(skip(self, aci_protocol_store, pni_protocol_store, sender, local_aci, csprng), fields(local_aci = local_aci.service_id_string()))]
761    pub async fn pnp_initialize_devices<
762        R: Rng + CryptoRng,
763        AciStore: PreKeysStore + SessionStoreExt,
764        PniStore: PreKeysStore,
765        AciOrPni: ProtocolStore + SenderKeyStore + SessionStoreExt + Sync + Clone,
766    >(
767        &mut self,
768        aci_protocol_store: &mut AciStore,
769        pni_protocol_store: &mut PniStore,
770        mut sender: MessageSender<AciOrPni>,
771        local_aci: Aci,
772        e164: E164,
773        csprng: &mut R,
774    ) -> Result<(), MessageSenderError> {
775        let pni_identity_key_pair =
776            pni_protocol_store.get_identity_key_pair().await?;
777
778        let pni_identity_key = pni_identity_key_pair.identity_key();
779
780        // For every linked device, we generate a new set of pre-keys, and send them to the device.
781        let local_device_ids = aci_protocol_store
782            .get_sub_device_sessions(&local_aci.into())
783            .await?;
784
785        let mut device_messages =
786            Vec::<OutgoingPushMessage>::with_capacity(local_device_ids.len());
787        let mut device_pni_signed_prekeys =
788            HashMap::<String, SignedPreKeyEntity>::with_capacity(
789                local_device_ids.len(),
790            );
791        let mut device_pni_last_resort_kyber_prekeys =
792            HashMap::<String, KyberPreKeyEntity>::with_capacity(
793                local_device_ids.len(),
794            );
795        let mut pni_registration_ids =
796            HashMap::<String, u32>::with_capacity(local_device_ids.len());
797
798        let signature_valid_on_each_signed_pre_key = true;
799        for local_device_id in
800            std::iter::once(*DEFAULT_DEVICE_ID).chain(local_device_ids)
801        {
802            let local_protocol_address =
803                local_aci.to_protocol_address(local_device_id)?;
804            let span = tracing::trace_span!(
805                "filtering devices",
806                address = %local_protocol_address
807            );
808            // Skip if we don't have a session with the device
809            if (local_device_id != *DEFAULT_DEVICE_ID)
810                && aci_protocol_store
811                    .load_session(&local_protocol_address)
812                    .instrument(span)
813                    .await?
814                    .is_none()
815            {
816                tracing::warn!(
817                    "No session with device {}, skipping PNI provisioning",
818                    local_device_id
819                );
820                continue;
821            }
822            let (
823                _pre_keys,
824                signed_pre_key,
825                _kyber_pre_keys,
826                last_resort_kyber_prekey,
827            ) = if local_device_id == *DEFAULT_DEVICE_ID {
828                crate::pre_keys::replenish_pre_keys(
829                    pni_protocol_store,
830                    csprng,
831                    &pni_identity_key_pair,
832                    true,
833                    0,
834                    0,
835                )
836                .await?
837            } else {
838                // Generate a signed prekey
839                let signed_pre_key_pair = KeyPair::generate(csprng);
840                let signed_pre_key_public = signed_pre_key_pair.public_key;
841                let signed_pre_key_signature = pni_identity_key_pair
842                    .private_key()
843                    .calculate_signature(
844                        &signed_pre_key_public.serialize(),
845                        csprng,
846                    )
847                    .map_err(MessageSenderError::InvalidPrivateKey)?;
848
849                let signed_prekey_record = SignedPreKeyRecord::new(
850                    csprng.random_range::<u32, _>(0..0xFFFFFF).into(),
851                    Timestamp::now(),
852                    &signed_pre_key_pair,
853                    &signed_pre_key_signature,
854                );
855
856                // Generate a last-resort Kyber prekey
857                let kyber_pre_key_record = KyberPreKeyRecord::generate(
858                    kem::KeyType::Kyber1024,
859                    csprng.random_range::<u32, _>(0..0xFFFFFF).into(),
860                    pni_identity_key_pair.private_key(),
861                )?;
862                (
863                    vec![],
864                    signed_prekey_record,
865                    vec![],
866                    Some(kyber_pre_key_record),
867                )
868            };
869
870            let registration_id = if local_device_id == *DEFAULT_DEVICE_ID {
871                pni_protocol_store.get_local_registration_id().await?
872            } else {
873                loop {
874                    let regid = generate_registration_id(csprng);
875                    if !pni_registration_ids.iter().any(|(_k, v)| *v == regid) {
876                        break regid;
877                    }
878                }
879            };
880
881            let local_device_id_s = local_device_id.to_string();
882            device_pni_signed_prekeys.insert(
883                local_device_id_s.clone(),
884                SignedPreKeyEntity::try_from(&signed_pre_key)?,
885            );
886            device_pni_last_resort_kyber_prekeys.insert(
887                local_device_id_s.clone(),
888                KyberPreKeyEntity::try_from(
889                    last_resort_kyber_prekey
890                        .as_ref()
891                        .expect("requested last resort key"),
892                )?,
893            );
894            pni_registration_ids
895                .insert(local_device_id_s.clone(), registration_id);
896
897            assert!(_pre_keys.is_empty());
898            assert!(_kyber_pre_keys.is_empty());
899
900            if local_device_id == *DEFAULT_DEVICE_ID {
901                // This is the primary device
902                // We don't need to send a message to the primary device
903                continue;
904            }
905            // cfr. SignalServiceMessageSender::getEncryptedSyncPniInitializeDeviceMessage
906            let msg = SyncMessage {
907                pni_change_number: Some(PniChangeNumber {
908                    identity_key_pair: Some(
909                        pni_identity_key_pair.serialize().to_vec(),
910                    ),
911                    signed_pre_key: Some(signed_pre_key.serialize()?),
912                    last_resort_kyber_pre_key: Some(
913                        last_resort_kyber_prekey
914                            .expect("requested last resort key")
915                            .serialize()?,
916                    ),
917                    registration_id: Some(registration_id),
918                    new_e164: Some(e164.to_string()),
919                }),
920                padding: Some(random_length_padding(csprng, 512)),
921                ..SyncMessage::default()
922            };
923            let content: ContentBody = msg.into();
924            let msg = sender
925                .create_encrypted_message(
926                    &local_aci.into(),
927                    None,
928                    local_device_id,
929                    &content.into_proto().encode_to_vec(),
930                )
931                .await?;
932            device_messages.push(msg);
933        }
934
935        self.websocket
936            .distribute_pni_keys(
937                pni_identity_key,
938                device_messages,
939                device_pni_signed_prekeys,
940                device_pni_last_resort_kyber_prekeys,
941                pni_registration_ids,
942                signature_valid_on_each_signed_pre_key,
943            )
944            .await?;
945
946        Ok(())
947    }
948}
949
950fn calculate_hmac256(
951    mac_key: &[u8],
952    ciphertext: &[u8],
953) -> Result<Output<Hmac<Sha256>>, ServiceError> {
954    let mut mac = Hmac::<Sha256>::new_from_slice(mac_key)
955        .map_err(|_| ServiceError::MacError)?;
956    mac.update(ciphertext);
957    Ok(mac.finalize().into_bytes())
958}
959
960pub fn encrypt_device_name<R: rand::Rng + rand::CryptoRng>(
961    csprng: &mut R,
962    device_name: &str,
963    identity_public: &IdentityKey,
964) -> Result<DeviceName, ServiceError> {
965    let plaintext = device_name.as_bytes().to_vec();
966    let ephemeral_key_pair = KeyPair::generate(csprng);
967
968    let master_secret = ephemeral_key_pair
969        .private_key
970        .calculate_agreement(identity_public.public_key())?;
971
972    let key1 = calculate_hmac256(&master_secret, b"auth")?;
973    let synthetic_iv = calculate_hmac256(&key1, &plaintext)?;
974    let synthetic_iv = &synthetic_iv[..16];
975
976    let key2 = calculate_hmac256(&master_secret, b"cipher")?;
977    let cipher_key = calculate_hmac256(&key2, synthetic_iv)?;
978
979    let mut ciphertext = plaintext;
980
981    const IV: [u8; 16] = [0; 16];
982    let mut cipher = Aes256Ctr128BE::new(&cipher_key, &IV.into());
983    cipher.apply_keystream(&mut ciphertext);
984
985    let device_name = DeviceName {
986        ephemeral_public: Some(
987            ephemeral_key_pair.public_key.serialize().to_vec(),
988        ),
989        synthetic_iv: Some(synthetic_iv.to_vec()),
990        ciphertext: Some(ciphertext),
991    };
992
993    Ok(device_name)
994}
995
996fn decrypt_device_name_from_device_info(
997    string: &str,
998    aci: &IdentityKeyPair,
999) -> Result<String, ServiceError> {
1000    let Ok(data) = BASE64_RELAXED.decode(string) else {
1001        tracing::trace!(
1002            "Device name not base64 encoded, assuming plaintext: {}",
1003            string.to_string()
1004        );
1005        return Ok(string.to_string());
1006    };
1007    let name = match DeviceName::decode(data.as_slice()) {
1008        Ok(decoded) => decoded,
1009        Err(_) => {
1010            tracing::trace!(
1011                "Decoding device name failed, assuming plaintext: {}",
1012                string.to_string()
1013            );
1014            return Ok(string.to_string());
1015        },
1016    };
1017    crate::decrypt_device_name(aci.private_key(), &name)
1018}
1019
1020// Analogous to https://github.com/signalapp/Signal-Android/blob/d88a862e0985cc2bbc463c5f504f5bb4e91ad4fc/app/src/main/java/org/thoughtcrime/securesms/linkdevice/LinkDeviceRepository.kt#L121.
1021fn decrypt_device_created_at_from_device_info(
1022    id: DeviceId,
1023    registration_id: i32,
1024    string: &str,
1025    aci: &IdentityKeyPair,
1026) -> Result<chrono::DateTime<chrono::Utc>, ServiceError> {
1027    use signal_crypto::SimpleHpkeReceiver;
1028
1029    let mut associated_data = [0; 5];
1030    associated_data[0] = id.into();
1031    associated_data[1..].copy_from_slice(&registration_id.to_be_bytes());
1032
1033    let data = BASE64_RELAXED.decode(string)?;
1034
1035    let result =
1036        aci.private_key()
1037            .open(b"deviceCreatedAt", &associated_data, &data)?;
1038
1039    let timestamp = i64::from_be_bytes(result.try_into().map_err(|_| {
1040        ServiceError::DecryptDeviceInfoFieldError("created-at")
1041    })?);
1042
1043    chrono::DateTime::<chrono::Utc>::from_timestamp_millis(timestamp)
1044        .ok_or(ServiceError::DecryptDeviceInfoFieldError("created-at"))
1045}
1046
1047pub fn decrypt_device_name(
1048    private_key: &PrivateKey,
1049    device_name: &DeviceName,
1050) -> Result<String, ServiceError> {
1051    let DeviceName {
1052        ephemeral_public: Some(ephemeral_public),
1053        synthetic_iv: Some(synthetic_iv),
1054        ciphertext: Some(ciphertext),
1055    } = device_name
1056    else {
1057        return Err(ServiceError::DecryptDeviceInfoFieldError("name"));
1058    };
1059
1060    let synthetic_iv: [u8; 16] = synthetic_iv[..synthetic_iv.len().min(16)]
1061        .try_into()
1062        .map_err(|_| ServiceError::MacError)?;
1063
1064    let ephemeral_public = PublicKey::deserialize(ephemeral_public)?;
1065
1066    let master_secret = private_key.calculate_agreement(&ephemeral_public)?;
1067    let key2 = calculate_hmac256(&master_secret, b"cipher")?;
1068    let cipher_key = calculate_hmac256(&key2, &synthetic_iv)?;
1069
1070    let mut plaintext = ciphertext.to_vec();
1071    const IV: [u8; 16] = [0; 16];
1072    let mut cipher =
1073        Aes256Ctr128BE::new(cipher_key.as_slice().into(), &IV.into());
1074    cipher.apply_keystream(&mut plaintext);
1075
1076    let key1 = calculate_hmac256(&master_secret, b"auth")?;
1077    let our_synthetic_iv = calculate_hmac256(&key1, &plaintext)?;
1078    let our_synthetic_iv = &our_synthetic_iv[..16];
1079
1080    if synthetic_iv != our_synthetic_iv {
1081        Err(ServiceError::MacError)
1082    } else {
1083        Ok(String::from_utf8_lossy(&plaintext).to_string())
1084    }
1085}
1086
#[cfg(test)]
mod tests {
    use crate::utils::BASE64_RELAXED;
    use base64::Engine;
    use libsignal_protocol::{IdentityKeyPair, PrivateKey, PublicKey};

    use super::DeviceName;

    /// Plaintext used by both tests; the fixture in `decrypt_device_name`
    /// decrypts to exactly this string.
    const NOKIA: &str = "Nokia 3310 Millenial Edition";

    /// Round-trip: encrypting to an identity key and decrypting with its
    /// private key yields the original name.
    #[test]
    fn encrypt_device_name() -> anyhow::Result<()> {
        let mut rng = rand::rng();
        let key_pair = IdentityKeyPair::generate(&mut rng);

        let encrypted = super::encrypt_device_name(
            &mut rng,
            NOKIA,
            key_pair.identity_key(),
        )?;
        let round_tripped =
            super::decrypt_device_name(key_pair.private_key(), &encrypted)?;

        assert_eq!(NOKIA, round_tripped);
        Ok(())
    }

    /// Decrypts a fixed, known-good `DeviceName` fixture.
    #[test]
    fn decrypt_device_name() -> anyhow::Result<()> {
        let b64 = |s: &str| BASE64_RELAXED.decode(s);

        let decryption_key = PrivateKey::deserialize(&b64(
            "0CgxHjwwblXjvX8sD5wZDWdYToMRf+CZSlgaUrxCGVo=",
        )?)?;
        let ephemeral_public = PublicKey::deserialize(&b64(
            "BcZS+Lt6yAKbEpXnRX+I5wHqesuvu93Q2V+fjidwW8R6",
        )?)?;

        let fixture = DeviceName {
            ephemeral_public: Some(ephemeral_public.serialize().to_vec()),
            synthetic_iv: Some(b64("86gekHGmltnnZ9QARhiFcg==")?),
            ciphertext: Some(b64("MtJ9/9KBWLBVAxfZJD4pLKzP4q+iodRJeCc+/A==")?),
        };

        let plaintext =
            super::decrypt_device_name(&decryption_key, &fixture)?;
        assert_eq!(plaintext, NOKIA);
        Ok(())
    }
}