libsignal_service/websocket/
keys.rs

1use std::collections::HashMap;
2
3use libsignal_core::DeviceId;
4use libsignal_protocol::{
5    kem::{Key, Public},
6    IdentityKey, PreKeyBundle, PublicKey, SenderCertificate, ServiceId,
7    ServiceIdKind, SignalProtocolError,
8};
9use reqwest::Method;
10use serde::Deserialize;
11
12use crate::{
13    pre_keys::{
14        KyberPreKeyEntity, PreKeyEntity, PreKeyState, SignedPreKeyEntity,
15    },
16    push_service::DEFAULT_DEVICE_ID,
17    sender::OutgoingPushMessage,
18    utils::{serde_base64, serde_device_id},
19    websocket::{self, registration::VerifyAccountResponse, SignalWebSocket},
20};
21
22use super::ServiceError;
23
24#[derive(Debug, Deserialize, Default)]
25#[serde(rename_all = "camelCase")]
26pub struct PreKeyStatus {
27    pub count: u32,
28    pub pq_count: u32,
29}
30
31#[derive(Debug, Deserialize)]
32#[serde(rename_all = "camelCase")]
33pub struct PreKeyResponse {
34    #[serde(with = "serde_base64")]
35    pub identity_key: Vec<u8>,
36    pub devices: Vec<PreKeyResponseItem>,
37}
38
39#[derive(Debug, Deserialize)]
40#[serde(rename_all = "camelCase")]
41pub struct PreKeyResponseItem {
42    #[serde(with = "serde_device_id")]
43    pub device_id: DeviceId,
44    pub registration_id: u32,
45    pub signed_pre_key: SignedPreKeyEntity,
46    pub pre_key: Option<PreKeyEntity>,
47    pub pq_pre_key: KyberPreKeyEntity,
48}
49
50impl PreKeyResponseItem {
51    pub(crate) fn into_bundle(
52        self,
53        identity: IdentityKey,
54    ) -> Result<PreKeyBundle, ServiceError> {
55        let pre_key_bundle = PreKeyBundle::new(
56            self.registration_id,
57            self.device_id,
58            self.pre_key
59                .map(|pk| -> Result<_, SignalProtocolError> {
60                    Ok((
61                        pk.key_id.into(),
62                        PublicKey::deserialize(&pk.public_key)?,
63                    ))
64                })
65                .transpose()?,
66            // pre_key: Option<(u32, PublicKey)>,
67            self.signed_pre_key.key_id.into(),
68            PublicKey::deserialize(&self.signed_pre_key.public_key)?,
69            self.signed_pre_key.signature,
70            self.pq_pre_key.key_id.into(),
71            Key::<Public>::deserialize(&self.pq_pre_key.public_key)?,
72            self.pq_pre_key.signature,
73            identity,
74        )?;
75
76        Ok(pre_key_bundle)
77    }
78}
79
80#[derive(Debug, Deserialize)]
81#[serde(rename_all = "camelCase")]
82struct SenderCertificateJson {
83    #[serde(with = "serde_base64")]
84    certificate: Vec<u8>,
85}
86
87impl SignalWebSocket<websocket::Identified> {
88    pub async fn get_pre_key_status(
89        &mut self,
90        service_id_kind: ServiceIdKind,
91    ) -> Result<PreKeyStatus, ServiceError> {
92        self.http_request(
93            Method::GET,
94            format!("/v2/keys?identity={}", service_id_kind),
95        )?
96        .send()
97        .await?
98        .service_error_for_status()
99        .await?
100        .json()
101        .await
102    }
103
104    /// Checks for consistency of the repeated-use keys
105    ///
106    /// Supply the digest as follows:
107    /// `SHA256(identityKeyBytes || signedEcPreKeyId || signedEcPreKeyIdBytes || lastResortKeyId ||
108    /// lastResortKeyBytes)`
109    ///
110    /// The IDs are represented as 8-byte big endian ints.
111    ///
112    /// Retuns `Ok(true)` if the view is consistent, `Ok(false)` if the view is inconsistent.
113    pub async fn check_pre_keys(
114        &mut self,
115        service_id_kind: ServiceIdKind,
116        digest: &[u8; 32],
117    ) -> Result<bool, ServiceError> {
118        #[derive(serde::Serialize)]
119        #[serde(rename_all = "camelCase")]
120        struct CheckPreKeysRequest<'a> {
121            identity_type: String,
122            #[serde(with = "serde_base64")]
123            digest: &'a [u8; 32],
124        }
125
126        let req = CheckPreKeysRequest {
127            identity_type: service_id_kind.to_string(),
128            digest,
129        };
130
131        let res = self
132            .http_request(Method::POST, "/v2/keys/check")?
133            .send_json(&req)
134            .await?;
135
136        if res.status_code() == Some(reqwest::StatusCode::CONFLICT) {
137            return Ok(false);
138        }
139
140        res.service_error_for_status().await?;
141
142        Ok(true)
143    }
144
145    pub async fn register_pre_keys(
146        &mut self,
147        service_id_kind: ServiceIdKind,
148        pre_key_state: PreKeyState,
149    ) -> Result<(), ServiceError> {
150        self.http_request(
151            Method::PUT,
152            format!("/v2/keys?identity={}", service_id_kind),
153        )?
154        .send_json(&pre_key_state)
155        .await?
156        .service_error_for_status()
157        .await?;
158
159        Ok(())
160    }
161
162    pub async fn get_pre_key(
163        &mut self,
164        destination: &ServiceId,
165        device_id: DeviceId,
166    ) -> Result<PreKeyBundle, ServiceError> {
167        let path = format!(
168            "/v2/keys/{}/{}",
169            destination.service_id_string(),
170            device_id
171        );
172
173        let mut pre_key_response: PreKeyResponse = self
174            .http_request(Method::GET, path)?
175            .send()
176            .await?
177            .service_error_for_status()
178            .await?
179            .json()
180            .await?;
181
182        assert!(!pre_key_response.devices.is_empty());
183
184        let identity = IdentityKey::decode(&pre_key_response.identity_key)?;
185        let device = pre_key_response.devices.remove(0);
186        device.into_bundle(identity)
187    }
188
189    pub(crate) async fn get_pre_keys(
190        &mut self,
191        destination: &ServiceId,
192        device_id: DeviceId,
193    ) -> Result<Vec<PreKeyBundle>, ServiceError> {
194        let path = if device_id == *DEFAULT_DEVICE_ID {
195            format!("/v2/keys/{}/*", destination.service_id_string())
196        } else {
197            format!(
198                "/v2/keys/{}/{}",
199                destination.service_id_string(),
200                device_id
201            )
202        };
203        let pre_key_response: PreKeyResponse = self
204            .http_request(Method::GET, path)?
205            .send()
206            .await?
207            .service_error_for_status()
208            .await?
209            .json()
210            .await?;
211        let mut pre_keys = vec![];
212        let identity = IdentityKey::decode(&pre_key_response.identity_key)?;
213        for device in pre_key_response.devices {
214            pre_keys.push(device.into_bundle(identity)?);
215        }
216        Ok(pre_keys)
217    }
218
219    pub async fn get_sender_certificate(
220        &mut self,
221    ) -> Result<SenderCertificate, ServiceError> {
222        let cert: SenderCertificateJson = self
223            .http_request(Method::GET, "/v1/certificate/delivery")?
224            .send()
225            .await?
226            .service_error_for_status()
227            .await?
228            .json()
229            .await?;
230        Ok(SenderCertificate::deserialize(&cert.certificate)?)
231    }
232
233    pub async fn get_uuid_only_sender_certificate(
234        &mut self,
235    ) -> Result<SenderCertificate, ServiceError> {
236        let cert: SenderCertificateJson = self
237            .http_request(
238                Method::GET,
239                "/v1/certificate/delivery?includeE164=false",
240            )?
241            .send()
242            .await?
243            .service_error_for_status()
244            .await?
245            .json()
246            .await?;
247        Ok(SenderCertificate::deserialize(&cert.certificate)?)
248    }
249
250    pub async fn distribute_pni_keys(
251        &mut self,
252        pni_identity_key: &IdentityKey,
253        device_messages: Vec<OutgoingPushMessage>,
254        device_pni_signed_prekeys: HashMap<String, SignedPreKeyEntity>,
255        device_pni_last_resort_kyber_prekeys: HashMap<
256            String,
257            KyberPreKeyEntity,
258        >,
259        pni_registration_ids: HashMap<String, u32>,
260        signature_valid_on_each_signed_pre_key: bool,
261    ) -> Result<VerifyAccountResponse, ServiceError> {
262        #[derive(serde::Serialize, Debug)]
263        #[serde(rename_all = "camelCase")]
264        struct PniKeyDistributionRequest {
265            #[serde(with = "serde_base64")]
266            pni_identity_key: Vec<u8>,
267            device_messages: Vec<OutgoingPushMessage>,
268            device_pni_signed_prekeys: HashMap<String, SignedPreKeyEntity>,
269            #[serde(rename = "devicePniPqLastResortPrekeys")]
270            device_pni_last_resort_kyber_prekeys:
271                HashMap<String, KyberPreKeyEntity>,
272            pni_registration_ids: HashMap<String, u32>,
273            signature_valid_on_each_signed_pre_key: bool,
274        }
275        self.http_request(
276            Method::PUT,
277            "/v2/accounts/phone_number_identity_key_distribution",
278        )?
279        .send_json(&PniKeyDistributionRequest {
280            pni_identity_key: pni_identity_key.serialize().into(),
281            device_messages,
282            device_pni_signed_prekeys,
283            device_pni_last_resort_kyber_prekeys,
284            pni_registration_ids,
285            signature_valid_on_each_signed_pre_key,
286        })
287        .await?
288        .service_error_for_status()
289        .await?
290        .json()
291        .await
292    }
293}