1use base64::prelude::*;
2use libsignal_core::{DeviceId, E164};
3use rand::{CryptoRng, Rng};
4use reqwest::Method;
5use std::collections::HashMap;
6use std::convert::{TryFrom, TryInto};
7
8use aes::cipher::{KeyIvInit, StreamCipher as _};
9use hmac::digest::Output;
10use hmac::{Hmac, Mac};
11use libsignal_protocol::{
12 kem, Aci, GenericSignedPreKey, IdentityKey, IdentityKeyPair,
13 IdentityKeyStore, KeyPair, KyberPreKeyRecord, PrivateKey, ProtocolStore,
14 PublicKey, SenderKeyStore, ServiceIdKind, SignedPreKeyRecord, Timestamp,
15};
16use prost::Message;
17use serde::{Deserialize, Serialize};
18use sha2::{Digest, Sha256};
19use tracing_futures::Instrument;
20use zkgroup::profiles::ProfileKey;
21
22use crate::content::ContentBody;
23use crate::master_key::MasterKey;
24use crate::pre_keys::{
25 KyberPreKeyEntity, PreKeyEntity, PreKeysStore, SignedPreKeyEntity,
26 PRE_KEY_BATCH_SIZE, PRE_KEY_MINIMUM,
27};
28use crate::prelude::{MessageSender, MessageSenderError};
29use crate::proto::sync_message::PniChangeNumber;
30use crate::proto::{DeviceName, SyncMessage};
31use crate::provisioning::generate_registration_id;
32use crate::push_service::{
33 AvatarWrite, HttpAuthOverride, ReqwestExt, DEFAULT_DEVICE_ID,
34};
35use crate::sender::OutgoingPushMessage;
36use crate::service_address::ServiceIdExt;
37use crate::session_store::SessionStoreExt;
38use crate::timestamp::TimestampExt as _;
39use crate::utils::{random_length_padding, BASE64_RELAXED};
40use crate::websocket::account::DeviceInfo;
41use crate::websocket::keys::PreKeyStatus;
42use crate::websocket::registration::{
43 CaptchaAttributes, DeviceActivationRequest, RegistrationMethod,
44 VerifyAccountResponse,
45};
46use crate::websocket::{self, SignalWebSocket};
47use crate::{
48 configuration::{Endpoint, ServiceCredentials},
49 pre_keys::PreKeyState,
50 profile_cipher::{ProfileCipher, ProfileCipherError},
51 profile_name::ProfileName,
52 proto::{ProvisionEnvelope, ProvisionMessage, ProvisioningVersion},
53 provisioning::{ProvisioningCipher, ProvisioningError},
54 push_service::{PushService, ServiceError},
55 utils::serde_base64,
56 websocket::account::AccountAttributes,
57};
58
/// AES-256 in CTR mode with a 128-bit big-endian counter.
///
/// Used by `encrypt_device_name` / `decrypt_device_name` (always with an
/// all-zero IV; the per-message uniqueness comes from the derived key).
type Aes256Ctr128BE = ctr::Ctr128BE<aes::Aes256>;
60
/// Account-level operations against the Signal service: pre-key upkeep,
/// device linking, registration, profiles and device names.
pub struct AccountManager {
    /// HTTP client used for the endpoints this type calls directly
    /// (provisioning code/message, device name, challenge).
    service: PushService,
    /// `Identified` websocket used for the remaining API calls.
    websocket: SignalWebSocket<websocket::Identified>,
    /// The account's own profile key. `retrieve_profile` and
    /// `upload_versioned_profile` panic when this is `None`.
    profile_key: Option<ProfileKey>,
}
66
/// Errors returned by the profile methods of `AccountManager`.
#[derive(thiserror::Error, Debug)]
pub enum ProfileManagerError {
    /// Error reported by the service/websocket layer.
    #[error(transparent)]
    ServiceError(#[from] ServiceError),
    /// Error while encrypting or decrypting profile fields.
    #[error(transparent)]
    ProfileCipherError(#[from] ProfileCipherError),
}
74
/// A decrypted Signal profile (as produced by `ProfileCipher::decrypt`).
///
/// Field order is significant for the derived `Serialize` output.
#[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct Profile {
    /// Profile name, if one is set.
    pub name: Option<ProfileName<String>>,
    /// Free-form "about" text.
    pub about: Option<String>,
    /// Emoji accompanying the "about" text.
    pub about_emoji: Option<String>,
    /// Avatar reference as returned by the server — presumably a CDN
    /// path; TODO confirm.
    pub avatar: Option<String>,
    /// NOTE(review): presumably the server flag allowing unidentified
    /// (sealed-sender) delivery without an access key — confirm.
    pub unrestricted_unidentified_access: bool,
}
83
84impl AccountManager {
85 pub fn new(
86 service: PushService,
87 websocket: SignalWebSocket<websocket::Identified>,
88 profile_key: Option<ProfileKey>,
89 ) -> Self {
90 Self {
91 service,
92 websocket,
93 profile_key,
94 }
95 }
96
97 #[allow(clippy::too_many_arguments)]
98 #[tracing::instrument(skip(self, protocol_store))]
99 pub async fn check_pre_keys<P: PreKeysStore>(
100 &mut self,
101 protocol_store: &mut P,
102 service_id_kind: ServiceIdKind,
103 ) -> Result<bool, ServiceError> {
104 let Some(signed_prekey_id) = protocol_store.signed_prekey_id().await?
105 else {
106 tracing::warn!("No signed prekey found");
107 return Ok(false);
108 };
109 let Some(kyber_prekey_id) =
112 protocol_store.last_resort_kyber_prekey_id().await?
113 else {
114 tracing::warn!("No last resort kyber prekey found");
115 return Ok(false);
116 };
117
118 let signed_prekey =
119 protocol_store.get_signed_pre_key(signed_prekey_id).await?;
120 let kyber_prekey =
121 protocol_store.get_kyber_pre_key(kyber_prekey_id).await?;
122
123 let mut hash = Sha256::default();
125 hash.update(
126 protocol_store
127 .get_identity_key_pair()
128 .await?
129 .public_key()
130 .serialize(),
131 );
132 hash.update((u32::from(signed_prekey_id) as u64).to_be_bytes());
133 hash.update(signed_prekey.public_key()?.serialize());
134 hash.update((u32::from(kyber_prekey_id) as u64).to_be_bytes());
135 hash.update(kyber_prekey.public_key()?.serialize());
136
137 self.websocket
138 .check_pre_keys(service_id_kind, hash.finalize().as_ref())
139 .await
140 }
141
142 #[allow(clippy::too_many_arguments)]
149 #[tracing::instrument(skip(self, protocol_store))]
150 pub async fn update_pre_key_bundle<P: PreKeysStore>(
151 &mut self,
152 protocol_store: &mut P,
153 service_id_kind: ServiceIdKind,
154 use_last_resort_key: bool,
155 ) -> Result<(), ServiceError> {
156 let prekey_status = match self
157 .websocket
158 .get_pre_key_status(service_id_kind)
159 .instrument(tracing::span!(
160 tracing::Level::DEBUG,
161 "Fetching pre key status"
162 ))
163 .await
164 {
165 Ok(status) => status,
166 Err(ServiceError::Unauthorized) => {
167 tracing::info!("Got Unauthorized when fetching pre-key status. Assuming first installment.");
168 PreKeyStatus {
171 count: 0,
172 pq_count: 0,
173 }
174 },
175 Err(e) => return Err(e),
176 };
177 tracing::trace!("Remaining pre-keys on server: {:?}", prekey_status);
178
179 let check_pre_keys = self
180 .check_pre_keys(protocol_store, service_id_kind)
181 .instrument(tracing::span!(
182 tracing::Level::DEBUG,
183 "Checking pre keys"
184 ))
185 .await?;
186 if !check_pre_keys {
187 tracing::info!(
188 "Last resort pre-keys are not up to date; refreshing."
189 );
190 } else {
191 tracing::debug!("Last resort pre-keys are up to date.");
192 }
193
194 if check_pre_keys
199 && (prekey_status.count >= PRE_KEY_MINIMUM
200 && prekey_status.pq_count >= PRE_KEY_MINIMUM)
201 {
202 if protocol_store.signed_pre_keys_count().await? > 0
203 && protocol_store.kyber_pre_keys_count(true).await? > 0
204 && protocol_store.signed_prekey_id().await?.is_some()
205 && protocol_store
206 .last_resort_kyber_prekey_id()
207 .await?
208 .is_some()
209 {
210 tracing::debug!("Available keys sufficient");
211 return Ok(());
212 }
213 tracing::info!("Available keys sufficient; forcing refresh.");
214 }
215
216 let identity_key_pair = protocol_store
217 .get_identity_key_pair()
218 .instrument(tracing::trace_span!("get identity key pair"))
219 .await?;
220
221 let last_resort_keys = protocol_store
222 .load_last_resort_kyber_pre_keys()
223 .instrument(tracing::trace_span!("fetch last resort key"))
224 .await?;
225
226 let has_last_resort_key = !last_resort_keys.is_empty();
228
229 let (pre_keys, signed_pre_key, pq_pre_keys, pq_last_resort_key) =
230 crate::pre_keys::replenish_pre_keys(
231 protocol_store,
232 &mut rand::rng(),
233 &identity_key_pair,
234 use_last_resort_key && !has_last_resort_key,
235 PRE_KEY_BATCH_SIZE,
236 PRE_KEY_BATCH_SIZE,
237 )
238 .await?;
239
240 let pq_last_resort_key = if has_last_resort_key {
241 if last_resort_keys.len() > 1 {
242 tracing::warn!(
243 "More than one last resort key found; only uploading first"
244 );
245 }
246 Some(KyberPreKeyEntity::try_from(last_resort_keys[0].clone())?)
247 } else {
248 pq_last_resort_key
249 .map(KyberPreKeyEntity::try_from)
250 .transpose()?
251 };
252
253 let identity_key = *identity_key_pair.identity_key();
254
255 let pre_keys: Vec<_> = pre_keys
256 .into_iter()
257 .map(PreKeyEntity::try_from)
258 .collect::<Result<_, _>>()?;
259 let signed_pre_key = signed_pre_key.try_into()?;
260 let pq_pre_keys: Vec<_> = pq_pre_keys
261 .into_iter()
262 .map(KyberPreKeyEntity::try_from)
263 .collect::<Result<_, _>>()?;
264
265 tracing::info!(
266 "Uploading pre-keys: {} one-time, {} PQ, {} PQ last resort",
267 pre_keys.len(),
268 pq_pre_keys.len(),
269 if pq_last_resort_key.is_some() { 1 } else { 0 }
270 );
271
272 let pre_key_state = PreKeyState {
273 pre_keys,
274 signed_pre_key,
275 identity_key,
276 pq_pre_keys,
277 pq_last_resort_key,
278 };
279
280 self.websocket
281 .register_pre_keys(service_id_kind, pre_key_state)
282 .instrument(tracing::span!(
283 tracing::Level::DEBUG,
284 "Uploading pre keys"
285 ))
286 .await?;
287
288 Ok(())
289 }
290
291 async fn new_device_provisioning_code(
292 &mut self,
293 ) -> Result<String, ServiceError> {
294 #[derive(serde::Deserialize)]
295 #[serde(rename_all = "camelCase")]
296 struct DeviceCode {
297 verification_code: String,
298 }
299
300 let dc: DeviceCode = self
301 .service
302 .request(
303 Method::GET,
304 Endpoint::service("/v1/devices/provisioning/code"),
305 HttpAuthOverride::NoOverride,
306 )?
307 .send()
308 .await?
309 .service_error_for_status()
310 .await?
311 .json()
312 .await?;
313
314 Ok(dc.verification_code)
315 }
316
317 async fn send_provisioning_message(
318 &mut self,
319 destination: &str,
320 env: ProvisionEnvelope,
321 ) -> Result<(), ServiceError> {
322 #[derive(serde::Serialize)]
323 struct ProvisioningMessage {
324 body: String,
325 }
326
327 let body = env.encode_to_vec();
328
329 self.service
330 .request(
331 Method::PUT,
332 Endpoint::service(format!("/v1/provisioning/{destination}")),
333 HttpAuthOverride::NoOverride,
334 )?
335 .json(&ProvisioningMessage {
336 body: BASE64_RELAXED.encode(body),
337 })
338 .send()
339 .await?
340 .service_error_for_status()
341 .await?;
342
343 Ok(())
344 }
345
346 pub async fn link_device<R: Rng + CryptoRng>(
360 &mut self,
361 csprng: &mut R,
362 url: url::Url,
363 aci_identity_store: &dyn IdentityKeyStore,
364 pni_identity_store: &dyn IdentityKeyStore,
365 credentials: ServiceCredentials,
366 master_key: Option<MasterKey>,
367 ) -> Result<(), ProvisioningError> {
368 let query: HashMap<_, _> = url.query_pairs().collect();
369 let ephemeral_id =
370 query.get("uuid").ok_or(ProvisioningError::MissingUuid)?;
371 let pub_key = query
372 .get("pub_key")
373 .ok_or(ProvisioningError::MissingPublicKey)?;
374
375 let pub_key = BASE64_RELAXED
376 .decode(&**pub_key)
377 .map_err(|e| ProvisioningError::InvalidPublicKey(e.into()))?;
378 let pub_key = PublicKey::deserialize(&pub_key)
379 .map_err(|e| ProvisioningError::InvalidPublicKey(e.into()))?;
380
381 let aci_identity_key_pair =
382 aci_identity_store.get_identity_key_pair().await?;
383 let pni_identity_key_pair =
384 pni_identity_store.get_identity_key_pair().await?;
385
386 if credentials.aci.is_none() {
387 tracing::warn!("No local ACI set");
388 }
389 if credentials.pni.is_none() {
390 tracing::warn!("No local PNI set");
391 }
392
393 let provisioning_code = self.new_device_provisioning_code().await?;
394
395 let msg = ProvisionMessage {
396 aci: credentials.aci.as_ref().map(|u| u.to_string()),
397 aci_binary: credentials.aci.map(|u| u.into_bytes().into()),
398 aci_identity_key_public: Some(
399 aci_identity_key_pair.public_key().serialize().into_vec(),
400 ),
401 aci_identity_key_private: Some(
402 aci_identity_key_pair.private_key().serialize(),
403 ),
404 number: Some(credentials.e164()),
405 pni_identity_key_public: Some(
406 pni_identity_key_pair.public_key().serialize().into_vec(),
407 ),
408 pni_identity_key_private: Some(
409 pni_identity_key_pair.private_key().serialize(),
410 ),
411 pni: credentials.pni.as_ref().map(uuid::Uuid::to_string),
412 pni_binary: credentials.pni.map(|u| u.into_bytes().into()),
413 profile_key: self.profile_key.as_ref().map(|x| x.bytes.to_vec()),
414 provisioning_version: Some(i32::from(
416 ProvisioningVersion::TabletSupport,
417 ) as _),
418 provisioning_code: Some(provisioning_code),
419 read_receipts: None,
420 user_agent: None,
421 master_key: master_key.map(|x| x.into()),
422 ephemeral_backup_key: None,
423 account_entropy_pool: None,
424 media_root_backup_key: None,
425 };
426
427 let cipher = ProvisioningCipher::from_public(pub_key);
428
429 let encrypted = cipher.encrypt(csprng, msg)?;
430 self.send_provisioning_message(ephemeral_id, encrypted)
431 .await?;
432 Ok(())
433 }
434
435 pub async fn linked_devices(
436 &mut self,
437 aci_identity_store: &dyn IdentityKeyStore,
438 ) -> Result<Vec<DeviceInfo>, ServiceError> {
439 let device_infos = self.websocket.devices().await?;
440 let aci_identity_keypair =
441 aci_identity_store.get_identity_key_pair().await?;
442
443 device_infos
444 .into_iter()
445 .map(|i| {
446 Ok(DeviceInfo {
447 id: i.id,
448 name: i.name.and_then(|s| {
449 match decrypt_device_name_from_device_info(
450 &s,
451 &aci_identity_keypair,
452 ) {
453 Ok(name) => Some(name),
454 Err(e) => {
455 tracing::error!("{e}");
456 None
457 },
458 }
459 }),
460 registration_id: i.registration_id,
461 last_seen: i.last_seen,
462 created_at: decrypt_device_created_at_from_device_info(
463 i.id,
464 i.registration_id,
465 &i.created_at_ciphertext,
466 &aci_identity_keypair,
467 )?,
468 })
469 })
470 .collect()
471 }
472
473 pub async fn register_account<
474 R: Rng + CryptoRng,
475 Aci: PreKeysStore + IdentityKeyStore,
476 Pni: PreKeysStore + IdentityKeyStore,
477 >(
478 &mut self,
479 csprng: &mut R,
480 registration_method: RegistrationMethod<'_>,
481 account_attributes: AccountAttributes,
482 aci_protocol_store: &mut Aci,
483 pni_protocol_store: &mut Pni,
484 skip_device_transfer: bool,
485 ) -> Result<VerifyAccountResponse, ProvisioningError> {
486 let aci_identity_key_pair = aci_protocol_store
487 .get_identity_key_pair()
488 .instrument(tracing::trace_span!("get ACI identity key pair"))
489 .await?;
490 let pni_identity_key_pair = pni_protocol_store
491 .get_identity_key_pair()
492 .instrument(tracing::trace_span!("get PNI identity key pair"))
493 .await?;
494
495 let (
496 _aci_pre_keys,
497 aci_signed_pre_key,
498 _aci_kyber_pre_keys,
499 aci_last_resort_kyber_prekey,
500 ) = crate::pre_keys::replenish_pre_keys(
501 aci_protocol_store,
502 csprng,
503 &aci_identity_key_pair,
504 true,
505 0,
506 0,
507 )
508 .await?;
509
510 let (
511 _pni_pre_keys,
512 pni_signed_pre_key,
513 _pni_kyber_pre_keys,
514 pni_last_resort_kyber_prekey,
515 ) = crate::pre_keys::replenish_pre_keys(
516 pni_protocol_store,
517 csprng,
518 &pni_identity_key_pair,
519 true,
520 0,
521 0,
522 )
523 .await?;
524
525 let aci_identity_key = aci_identity_key_pair.identity_key();
526 let pni_identity_key = pni_identity_key_pair.identity_key();
527
528 let dar = DeviceActivationRequest {
529 aci_signed_pre_key: aci_signed_pre_key.try_into()?,
530 pni_signed_pre_key: pni_signed_pre_key.try_into()?,
531 aci_pq_last_resort_pre_key: aci_last_resort_kyber_prekey
532 .expect("requested last resort prekey")
533 .try_into()?,
534 pni_pq_last_resort_pre_key: pni_last_resort_kyber_prekey
535 .expect("requested last resort prekey")
536 .try_into()?,
537 };
538
539 let result = self
540 .websocket
541 .submit_registration_request(
542 registration_method,
543 account_attributes,
544 skip_device_transfer,
545 aci_identity_key,
546 pni_identity_key,
547 dar,
548 )
549 .await?;
550
551 Ok(result)
552 }
553
554 pub async fn upload_versioned_profile_without_avatar<
565 R: Rng + CryptoRng,
566 S: AsRef<str>,
567 >(
568 &mut self,
569 aci: libsignal_protocol::Aci,
570 name: ProfileName<S>,
571 about: Option<String>,
572 about_emoji: Option<String>,
573 retain_avatar: bool,
574 csprng: &mut R,
575 ) -> Result<(), ProfileManagerError> {
576 self.upload_versioned_profile::<std::io::Cursor<Vec<u8>>, _, _>(
577 aci,
578 name,
579 about,
580 about_emoji,
581 if retain_avatar {
582 AvatarWrite::RetainAvatar
583 } else {
584 AvatarWrite::NoAvatar
585 },
586 csprng,
587 )
588 .await?;
589 Ok(())
590 }
591
592 pub async fn retrieve_profile(
593 &mut self,
594 address: Aci,
595 ) -> Result<Profile, ProfileManagerError> {
596 let profile_key =
597 self.profile_key.expect("set profile key in AccountManager");
598
599 let encrypted_profile = self
600 .websocket
601 .retrieve_profile_by_id(address, Some(profile_key))
602 .await?;
603
604 let profile_cipher = ProfileCipher::new(profile_key);
605 Ok(profile_cipher.decrypt(encrypted_profile)?)
606 }
607
608 pub async fn upload_versioned_profile<
614 's,
615 C: std::io::Read + Send + 's,
616 R: Rng + CryptoRng,
617 S: AsRef<str>,
618 >(
619 &mut self,
620 aci: libsignal_protocol::Aci,
621 name: ProfileName<S>,
622 about: Option<String>,
623 about_emoji: Option<String>,
624 avatar: AvatarWrite<&'s mut C>,
625 csprng: &mut R,
626 ) -> Result<Option<String>, ProfileManagerError> {
627 let profile_key =
628 self.profile_key.expect("set profile key in AccountManager");
629 let profile_cipher = ProfileCipher::new(profile_key);
630
631 let name = profile_cipher.encrypt_name(name.as_ref(), csprng)?;
633 let about = about.unwrap_or_default();
634 let about = profile_cipher.encrypt_about(about, csprng)?;
635 let about_emoji = about_emoji.unwrap_or_default();
636 let about_emoji = profile_cipher.encrypt_emoji(about_emoji, csprng)?;
637
638 if matches!(avatar, AvatarWrite::NewAvatar(_)) {
640 unimplemented!("Setting avatar requires ProfileCipherStream")
643 }
644
645 let profile_key = profile_cipher.into_inner();
646 let commitment = profile_key.get_commitment(aci);
647 let profile_key_version = profile_key.get_profile_key_version(aci);
648
649 Ok(self
650 .websocket
651 .write_profile::<C, S>(
652 &profile_key_version,
653 &name,
654 &about,
655 &about_emoji,
656 &commitment,
657 avatar,
658 )
659 .await?)
660 }
661
662 pub async fn set_account_attributes(
666 &mut self,
667 attributes: AccountAttributes,
668 ) -> Result<(), ServiceError> {
669 self.websocket.set_account_attributes(attributes).await
670 }
671
672 pub async fn update_device_name<R: Rng + CryptoRng>(
674 &mut self,
675 device_id: libsignal_core::DeviceId,
676 device_name: &str,
677 aci: Aci,
678 aci_identity_store: &dyn IdentityKeyStore,
679 csprng: &mut R,
680 ) -> Result<(), ServiceError> {
681 let addr = aci.to_protocol_address(device_id).unwrap();
682 let public_key = aci_identity_store.get_identity(&addr).await?;
683 let Some(public_key) = public_key else {
684 return Err(ServiceError::SendError {
685 reason: format!("public key for device {addr:?} not found"),
686 });
687 };
688 let encrypted_device_name =
689 encrypt_device_name(csprng, device_name, &public_key)?;
690
691 #[derive(Serialize)]
692 #[serde(rename_all = "camelCase")]
693 struct Data {
694 #[serde(with = "serde_base64")]
695 device_name: Vec<u8>,
696 }
697
698 self.service
699 .request(
700 Method::PUT,
701 Endpoint::service(format!(
702 "/v1/accounts/name?deviceId={}",
703 device_id
704 )),
705 HttpAuthOverride::NoOverride,
706 )?
707 .json(&Data {
708 device_name: encrypted_device_name.encode_to_vec(),
709 })
710 .send()
711 .await?
712 .service_error_for_status()
713 .await?;
714
715 Ok(())
716 }
717
718 pub async fn submit_recaptcha_challenge(
725 &mut self,
726 token: &str,
727 captcha: &str,
728 ) -> Result<(), ServiceError> {
729 self.service
730 .request(
731 Method::PUT,
732 Endpoint::service("/v1/challenge"),
733 HttpAuthOverride::NoOverride,
734 )?
735 .json(&CaptchaAttributes {
736 challenge_type: "captcha",
737 token,
738 captcha,
739 })
740 .send()
741 .await?
742 .service_error_for_status()
743 .await?;
744
745 Ok(())
746 }
747
748 #[tracing::instrument(skip(self, aci_protocol_store, pni_protocol_store, sender, local_aci, csprng), fields(local_aci = local_aci.service_id_string()))]
754 pub async fn pnp_initialize_devices<
755 R: Rng + CryptoRng,
756 AciStore: PreKeysStore + SessionStoreExt,
757 PniStore: PreKeysStore,
758 AciOrPni: ProtocolStore + SenderKeyStore + SessionStoreExt + Sync + Clone,
759 >(
760 &mut self,
761 aci_protocol_store: &mut AciStore,
762 pni_protocol_store: &mut PniStore,
763 mut sender: MessageSender<AciOrPni>,
764 local_aci: Aci,
765 e164: E164,
766 csprng: &mut R,
767 ) -> Result<(), MessageSenderError> {
768 let pni_identity_key_pair =
769 pni_protocol_store.get_identity_key_pair().await?;
770
771 let pni_identity_key = pni_identity_key_pair.identity_key();
772
773 let local_device_ids = aci_protocol_store
775 .get_sub_device_sessions(&local_aci.into())
776 .await?;
777
778 let mut device_messages =
779 Vec::<OutgoingPushMessage>::with_capacity(local_device_ids.len());
780 let mut device_pni_signed_prekeys =
781 HashMap::<String, SignedPreKeyEntity>::with_capacity(
782 local_device_ids.len(),
783 );
784 let mut device_pni_last_resort_kyber_prekeys =
785 HashMap::<String, KyberPreKeyEntity>::with_capacity(
786 local_device_ids.len(),
787 );
788 let mut pni_registration_ids =
789 HashMap::<String, u32>::with_capacity(local_device_ids.len());
790
791 let signature_valid_on_each_signed_pre_key = true;
792 for local_device_id in
793 std::iter::once(*DEFAULT_DEVICE_ID).chain(local_device_ids)
794 {
795 let local_protocol_address =
796 local_aci.to_protocol_address(local_device_id)?;
797 let span = tracing::trace_span!(
798 "filtering devices",
799 address = %local_protocol_address
800 );
801 if (local_device_id != *DEFAULT_DEVICE_ID)
803 && aci_protocol_store
804 .load_session(&local_protocol_address)
805 .instrument(span)
806 .await?
807 .is_none()
808 {
809 tracing::warn!(
810 "No session with device {}, skipping PNI provisioning",
811 local_device_id
812 );
813 continue;
814 }
815 let (
816 _pre_keys,
817 signed_pre_key,
818 _kyber_pre_keys,
819 last_resort_kyber_prekey,
820 ) = if local_device_id == *DEFAULT_DEVICE_ID {
821 crate::pre_keys::replenish_pre_keys(
822 pni_protocol_store,
823 csprng,
824 &pni_identity_key_pair,
825 true,
826 0,
827 0,
828 )
829 .await?
830 } else {
831 let signed_pre_key_pair = KeyPair::generate(csprng);
833 let signed_pre_key_public = signed_pre_key_pair.public_key;
834 let signed_pre_key_signature = pni_identity_key_pair
835 .private_key()
836 .calculate_signature(
837 &signed_pre_key_public.serialize(),
838 csprng,
839 )
840 .map_err(MessageSenderError::InvalidPrivateKey)?;
841
842 let signed_prekey_record = SignedPreKeyRecord::new(
843 csprng.random_range::<u32, _>(0..0xFFFFFF).into(),
844 Timestamp::now(),
845 &signed_pre_key_pair,
846 &signed_pre_key_signature,
847 );
848
849 let kyber_pre_key_record = KyberPreKeyRecord::generate(
851 kem::KeyType::Kyber1024,
852 csprng.random_range::<u32, _>(0..0xFFFFFF).into(),
853 pni_identity_key_pair.private_key(),
854 )?;
855 (
856 vec![],
857 signed_prekey_record,
858 vec![],
859 Some(kyber_pre_key_record),
860 )
861 };
862
863 let registration_id = if local_device_id == *DEFAULT_DEVICE_ID {
864 pni_protocol_store.get_local_registration_id().await?
865 } else {
866 loop {
867 let regid = generate_registration_id(csprng);
868 if !pni_registration_ids.iter().any(|(_k, v)| *v == regid) {
869 break regid;
870 }
871 }
872 };
873
874 let local_device_id_s = local_device_id.to_string();
875 device_pni_signed_prekeys.insert(
876 local_device_id_s.clone(),
877 SignedPreKeyEntity::try_from(&signed_pre_key)?,
878 );
879 device_pni_last_resort_kyber_prekeys.insert(
880 local_device_id_s.clone(),
881 KyberPreKeyEntity::try_from(
882 last_resort_kyber_prekey
883 .as_ref()
884 .expect("requested last resort key"),
885 )?,
886 );
887 pni_registration_ids
888 .insert(local_device_id_s.clone(), registration_id);
889
890 assert!(_pre_keys.is_empty());
891 assert!(_kyber_pre_keys.is_empty());
892
893 if local_device_id == *DEFAULT_DEVICE_ID {
894 continue;
897 }
898 let msg = SyncMessage {
900 pni_change_number: Some(PniChangeNumber {
901 identity_key_pair: Some(
902 pni_identity_key_pair.serialize().to_vec(),
903 ),
904 signed_pre_key: Some(signed_pre_key.serialize()?),
905 last_resort_kyber_pre_key: Some(
906 last_resort_kyber_prekey
907 .expect("requested last resort key")
908 .serialize()?,
909 ),
910 registration_id: Some(registration_id),
911 new_e164: Some(e164.to_string()),
912 }),
913 padding: Some(random_length_padding(csprng, 512)),
914 ..SyncMessage::default()
915 };
916 let content: ContentBody = msg.into();
917 let msg = sender
918 .create_encrypted_message(
919 &local_aci.into(),
920 None,
921 local_device_id,
922 &content.into_proto().encode_to_vec(),
923 )
924 .await?;
925 device_messages.push(msg);
926 }
927
928 self.websocket
929 .distribute_pni_keys(
930 pni_identity_key,
931 device_messages,
932 device_pni_signed_prekeys,
933 device_pni_last_resort_kyber_prekeys,
934 pni_registration_ids,
935 signature_valid_on_each_signed_pre_key,
936 )
937 .await?;
938
939 Ok(())
940 }
941}
942
943fn calculate_hmac256(
944 mac_key: &[u8],
945 ciphertext: &[u8],
946) -> Result<Output<Hmac<Sha256>>, ServiceError> {
947 let mut mac = Hmac::<Sha256>::new_from_slice(mac_key)
948 .map_err(|_| ServiceError::MacError)?;
949 mac.update(ciphertext);
950 Ok(mac.finalize().into_bytes())
951}
952
953pub fn encrypt_device_name<R: rand::Rng + rand::CryptoRng>(
954 csprng: &mut R,
955 device_name: &str,
956 identity_public: &IdentityKey,
957) -> Result<DeviceName, ServiceError> {
958 let plaintext = device_name.as_bytes().to_vec();
959 let ephemeral_key_pair = KeyPair::generate(csprng);
960
961 let master_secret = ephemeral_key_pair
962 .private_key
963 .calculate_agreement(identity_public.public_key())?;
964
965 let key1 = calculate_hmac256(&master_secret, b"auth")?;
966 let synthetic_iv = calculate_hmac256(&key1, &plaintext)?;
967 let synthetic_iv = &synthetic_iv[..16];
968
969 let key2 = calculate_hmac256(&master_secret, b"cipher")?;
970 let cipher_key = calculate_hmac256(&key2, synthetic_iv)?;
971
972 let mut ciphertext = plaintext;
973
974 const IV: [u8; 16] = [0; 16];
975 let mut cipher = Aes256Ctr128BE::new(&cipher_key, &IV.into());
976 cipher.apply_keystream(&mut ciphertext);
977
978 let device_name = DeviceName {
979 ephemeral_public: Some(
980 ephemeral_key_pair.public_key.serialize().to_vec(),
981 ),
982 synthetic_iv: Some(synthetic_iv.to_vec()),
983 ciphertext: Some(ciphertext),
984 };
985
986 Ok(device_name)
987}
988
989fn decrypt_device_name_from_device_info(
990 string: &str,
991 aci: &IdentityKeyPair,
992) -> Result<String, ServiceError> {
993 let data = BASE64_RELAXED.decode(string)?;
994 let name = DeviceName::decode(&*data)?;
995 crate::decrypt_device_name(aci.private_key(), &name)
996}
997
998fn decrypt_device_created_at_from_device_info(
1000 id: DeviceId,
1001 registration_id: i32,
1002 string: &str,
1003 aci: &IdentityKeyPair,
1004) -> Result<chrono::DateTime<chrono::Utc>, ServiceError> {
1005 use signal_crypto::SimpleHpkeReceiver;
1006
1007 let mut associated_data = [0; 5];
1008 associated_data[0] = id.into();
1009 associated_data[1..].copy_from_slice(®istration_id.to_be_bytes());
1010
1011 let data = BASE64_RELAXED.decode(string)?;
1012
1013 let result =
1014 aci.private_key()
1015 .open(b"deviceCreatedAt", &associated_data, &data)?;
1016
1017 let timestamp = i64::from_be_bytes(result.try_into().map_err(|_| {
1018 ServiceError::DecryptDeviceInfoFieldError("created-at")
1019 })?);
1020
1021 chrono::DateTime::<chrono::Utc>::from_timestamp_millis(timestamp)
1022 .ok_or(ServiceError::DecryptDeviceInfoFieldError("created-at"))
1023}
1024
1025pub fn decrypt_device_name(
1026 private_key: &PrivateKey,
1027 device_name: &DeviceName,
1028) -> Result<String, ServiceError> {
1029 let DeviceName {
1030 ephemeral_public: Some(ephemeral_public),
1031 synthetic_iv: Some(synthetic_iv),
1032 ciphertext: Some(ciphertext),
1033 } = device_name
1034 else {
1035 return Err(ServiceError::DecryptDeviceInfoFieldError("name"));
1036 };
1037
1038 let synthetic_iv: [u8; 16] = synthetic_iv[..synthetic_iv.len().min(16)]
1039 .try_into()
1040 .map_err(|_| ServiceError::MacError)?;
1041
1042 let ephemeral_public = PublicKey::deserialize(ephemeral_public)?;
1043
1044 let master_secret = private_key.calculate_agreement(&ephemeral_public)?;
1045 let key2 = calculate_hmac256(&master_secret, b"cipher")?;
1046 let cipher_key = calculate_hmac256(&key2, &synthetic_iv)?;
1047
1048 let mut plaintext = ciphertext.to_vec();
1049 const IV: [u8; 16] = [0; 16];
1050 let mut cipher =
1051 Aes256Ctr128BE::new(cipher_key.as_slice().into(), &IV.into());
1052 cipher.apply_keystream(&mut plaintext);
1053
1054 let key1 = calculate_hmac256(&master_secret, b"auth")?;
1055 let our_synthetic_iv = calculate_hmac256(&key1, &plaintext)?;
1056 let our_synthetic_iv = &our_synthetic_iv[..16];
1057
1058 if synthetic_iv != our_synthetic_iv {
1059 Err(ServiceError::MacError)
1060 } else {
1061 Ok(String::from_utf8_lossy(&plaintext).to_string())
1062 }
1063}
1064
#[cfg(test)]
mod tests {
    use base64::Engine;
    use libsignal_protocol::{IdentityKeyPair, PrivateKey, PublicKey};

    use super::DeviceName;
    use crate::utils::BASE64_RELAXED;

    /// Encrypting to a fresh identity and decrypting with its private key
    /// must give back the original name.
    #[test]
    fn encrypt_device_name() -> anyhow::Result<()> {
        const NAME: &str = "Nokia 3310 Millenial Edition";

        let mut rng = rand::rng();
        let identity = IdentityKeyPair::generate(&mut rng);

        let encrypted = super::encrypt_device_name(
            &mut rng,
            NAME,
            identity.identity_key(),
        )?;
        let round_tripped =
            super::decrypt_device_name(identity.private_key(), &encrypted)?;

        assert_eq!(NAME, round_tripped);
        Ok(())
    }

    /// Known-answer test: a fixed ciphertext must decrypt to a fixed name.
    #[test]
    fn decrypt_device_name() -> anyhow::Result<()> {
        let private_key_bytes = BASE64_RELAXED
            .decode("0CgxHjwwblXjvX8sD5wZDWdYToMRf+CZSlgaUrxCGVo=")?;
        let public_key_bytes = BASE64_RELAXED
            .decode("BcZS+Lt6yAKbEpXnRX+I5wHqesuvu93Q2V+fjidwW8R6")?;
        let ephemeral_private_key =
            PrivateKey::deserialize(&private_key_bytes)?;
        let ephemeral_public_key = PublicKey::deserialize(&public_key_bytes)?;

        let synthetic_iv = BASE64_RELAXED.decode("86gekHGmltnnZ9QARhiFcg==")?;
        let ciphertext = BASE64_RELAXED
            .decode("MtJ9/9KBWLBVAxfZJD4pLKzP4q+iodRJeCc+/A==")?;

        let device_name = DeviceName {
            ephemeral_public: Some(ephemeral_public_key.serialize().to_vec()),
            synthetic_iv: Some(synthetic_iv),
            ciphertext: Some(ciphertext),
        };

        let decrypted =
            super::decrypt_device_name(&ephemeral_private_key, &device_name)?;
        assert_eq!(decrypted, "Nokia 3310 Millenial Edition");
        Ok(())
    }
}