libsignal_service/push_service/mod.rs

1use std::{sync::LazyLock, time::Duration};
2
3use crate::{
4    configuration::{Endpoint, ServiceCredentials},
5    pre_keys::{KyberPreKeyEntity, PreKeyEntity, SignedPreKeyEntity},
6    prelude::ServiceConfiguration,
7    utils::serde_base64,
8    websocket::SignalWebSocket,
9};
10
11use derivative::Derivative;
12use libsignal_protocol::{
13    error::SignalProtocolError,
14    kem::{Key, Public},
15    IdentityKey, PreKeyBundle, PublicKey,
16};
17use protobuf::ProtobufResponseExt;
18use reqwest::{Method, RequestBuilder};
19use reqwest_websocket::RequestBuilderExt;
20use serde::{Deserialize, Serialize};
21use tracing::{debug_span, Instrument};
22
/// Keep-alive timeout of 55 seconds.
///
/// Despite the `_SECONDS` suffix this is a [`Duration`], not a raw count of
/// seconds.
pub const KEEPALIVE_TIMEOUT_SECONDS: Duration = Duration::new(55, 0);
24pub static DEFAULT_DEVICE_ID: LazyLock<libsignal_core::DeviceId> =
25    LazyLock::new(|| libsignal_core::DeviceId::try_from(1).unwrap());
26
27mod account;
28mod cdn;
29mod error;
30mod keys;
31mod linking;
32mod profile;
33mod registration;
34mod response;
35mod stickers;
36
37pub use account::*;
38pub use cdn::*;
39pub use error::*;
40pub use keys::*;
41pub use linking::*;
42pub use profile::*;
43pub use registration::*;
44pub(crate) use response::{ReqwestExt, SignalServiceResponse};
45
46#[derive(Debug, Serialize, Deserialize)]
47pub struct ProofRequired {
48    pub token: String,
49    pub options: Vec<String>,
50}
51
52#[derive(Derivative, Clone, Serialize, Deserialize)]
53#[derivative(Debug)]
54pub struct HttpAuth {
55    pub username: String,
56    #[derivative(Debug = "ignore")]
57    pub password: String,
58}
59
60#[derive(Debug, Clone)]
61pub enum HttpAuthOverride {
62    NoOverride,
63    Unidentified,
64    Identified(HttpAuth),
65}
66
/// What to do with the avatar when writing: upload new content `C`, keep
/// the current avatar, or have none.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AvatarWrite<C> {
    /// Replace the avatar with the given content.
    NewAvatar(C),
    /// Keep whatever avatar is currently set.
    RetainAvatar,
    /// Write without an avatar.
    NoAvatar,
}
73
74#[derive(Debug, Deserialize)]
75#[serde(rename_all = "camelCase")]
76struct SenderCertificateJson {
77    #[serde(with = "serde_base64")]
78    certificate: Vec<u8>,
79}
80
81#[derive(Debug, Deserialize)]
82#[serde(rename_all = "camelCase")]
83pub struct PreKeyResponse {
84    #[serde(with = "serde_base64")]
85    pub identity_key: Vec<u8>,
86    pub devices: Vec<PreKeyResponseItem>,
87}
88
89#[derive(Debug, Deserialize)]
90#[serde(rename_all = "camelCase")]
91pub struct PreKeyResponseItem {
92    pub device_id: u32,
93    pub registration_id: u32,
94    pub signed_pre_key: SignedPreKeyEntity,
95    pub pre_key: Option<PreKeyEntity>,
96    pub pq_pre_key: KyberPreKeyEntity,
97}
98
99impl PreKeyResponseItem {
100    #[allow(clippy::result_large_err)]
101    pub(crate) fn into_bundle(
102        self,
103        identity: IdentityKey,
104    ) -> Result<PreKeyBundle, ServiceError> {
105        Ok(PreKeyBundle::new(
106            self.registration_id,
107            self.device_id.try_into()?,
108            self.pre_key
109                .map(|pk| -> Result<_, SignalProtocolError> {
110                    Ok((
111                        pk.key_id.into(),
112                        PublicKey::deserialize(&pk.public_key)?,
113                    ))
114                })
115                .transpose()?,
116            // pre_key: Option<(u32, PublicKey)>,
117            self.signed_pre_key.key_id.into(),
118            PublicKey::deserialize(&self.signed_pre_key.public_key)?,
119            self.signed_pre_key.signature,
120            self.pq_pre_key.key_id.into(),
121            Key::<Public>::deserialize(&self.pq_pre_key.public_key)?,
122            self.pq_pre_key.signature,
123            identity,
124        )?)
125    }
126}
127
128#[derive(Debug, Deserialize)]
129#[serde(rename_all = "camelCase")]
130pub struct MismatchedDevices {
131    pub missing_devices: Vec<u32>,
132    pub extra_devices: Vec<u32>,
133}
134
135#[derive(Debug, Deserialize)]
136#[serde(rename_all = "camelCase")]
137pub struct StaleDevices {
138    pub stale_devices: Vec<u32>,
139}
140
141#[derive(Clone)]
142pub struct PushService {
143    cfg: ServiceConfiguration,
144    credentials: Option<HttpAuth>,
145    client: reqwest::Client,
146}
147
148impl PushService {
149    pub fn new(
150        cfg: impl Into<ServiceConfiguration>,
151        credentials: Option<ServiceCredentials>,
152        user_agent: impl AsRef<str>,
153    ) -> Self {
154        let cfg = cfg.into();
155        let client = reqwest::ClientBuilder::new()
156            .tls_built_in_root_certs(false)
157            .add_root_certificate(
158                reqwest::Certificate::from_pem(
159                    cfg.certificate_authority.as_bytes(),
160                )
161                .unwrap(),
162            )
163            .connect_timeout(Duration::from_secs(10))
164            .timeout(Duration::from_secs(65))
165            .user_agent(user_agent.as_ref())
166            .build()
167            .unwrap();
168
169        Self {
170            cfg,
171            credentials: credentials.and_then(|c| c.authorization()),
172            client,
173        }
174    }
175
176    #[expect(clippy::result_large_err)]
177    #[tracing::instrument(skip(self), fields(endpoint = %endpoint))]
178    pub fn request(
179        &self,
180        method: Method,
181        endpoint: Endpoint,
182        auth_override: HttpAuthOverride,
183    ) -> Result<RequestBuilder, ServiceError> {
184        let url = endpoint.into_url(&self.cfg)?;
185        let mut builder = self.client.request(method, url);
186
187        builder = match auth_override {
188            HttpAuthOverride::NoOverride => {
189                if let Some(HttpAuth { username, password }) =
190                    self.credentials.as_ref()
191                {
192                    builder.basic_auth(username, Some(password))
193                } else {
194                    builder
195                }
196            },
197            HttpAuthOverride::Identified(HttpAuth { username, password }) => {
198                builder.basic_auth(username, Some(password))
199            },
200            HttpAuthOverride::Unidentified => builder,
201        };
202
203        Ok(builder)
204    }
205
206    pub async fn ws(
207        &mut self,
208        path: &str,
209        keepalive_path: &str,
210        additional_headers: &[(&'static str, &str)],
211        credentials: Option<ServiceCredentials>,
212    ) -> Result<SignalWebSocket, ServiceError> {
213        let span = debug_span!("websocket");
214
215        let mut url = Endpoint::service(path).into_url(&self.cfg)?;
216        url.set_scheme("wss").expect("valid https base url");
217
218        let mut builder = self.client.get(url);
219        for (key, value) in additional_headers {
220            builder = builder.header(*key, *value);
221        }
222
223        if let Some(credentials) = credentials {
224            builder =
225                builder.basic_auth(credentials.login(), credentials.password);
226        }
227
228        let ws = builder
229            .upgrade()
230            .send()
231            .await?
232            .into_websocket()
233            .instrument(span.clone())
234            .await?;
235
236        let (ws, task) =
237            SignalWebSocket::from_socket(ws, keepalive_path.to_owned());
238        let task = task.instrument(span);
239        tokio::task::spawn(task);
240        Ok(ws)
241    }
242
243    pub(crate) async fn get_group(
244        &mut self,
245        credentials: HttpAuth,
246    ) -> Result<crate::proto::Group, ServiceError> {
247        self.request(
248            Method::GET,
249            Endpoint::storage("/v1/groups/"),
250            HttpAuthOverride::Identified(credentials),
251        )?
252        .send()
253        .await?
254        .service_error_for_status()
255        .await?
256        .protobuf()
257        .await
258    }
259}
260
261pub(crate) mod protobuf {
262    use async_trait::async_trait;
263    use prost::{EncodeError, Message};
264    use reqwest::{header, RequestBuilder, Response};
265
266    use super::ServiceError;
267
268    pub(crate) trait ProtobufRequestBuilderExt
269    where
270        Self: Sized,
271    {
272        /// Set the request payload encoded as protobuf.
273        /// Sets the `Content-Type` header to `application/protobuf`
274        #[allow(dead_code)]
275        fn protobuf<T: Message + Default>(
276            self,
277            value: T,
278        ) -> Result<Self, EncodeError>;
279    }
280
281    #[async_trait::async_trait]
282    pub(crate) trait ProtobufResponseExt {
283        /// Get the response body decoded from Protobuf
284        async fn protobuf<T>(self) -> Result<T, ServiceError>
285        where
286            T: prost::Message + Default;
287    }
288
289    impl ProtobufRequestBuilderExt for RequestBuilder {
290        fn protobuf<T: Message + Default>(
291            self,
292            value: T,
293        ) -> Result<Self, EncodeError> {
294            let mut buf = Vec::new();
295            value.encode(&mut buf)?;
296            let this =
297                self.header(header::CONTENT_TYPE, "application/protobuf");
298            Ok(this.body(buf))
299        }
300    }
301
302    #[async_trait]
303    impl ProtobufResponseExt for Response {
304        async fn protobuf<T>(self) -> Result<T, ServiceError>
305        where
306            T: Message + Default,
307        {
308            let body = self.bytes().await?;
309            let decoded = T::decode(body)?;
310            Ok(decoded)
311        }
312    }
313}
314
#[cfg(test)]
mod tests {
    use crate::configuration::SignalServers;
    use bytes::{Buf, Bytes};

    /// Constructing a `PushService` must not panic for any bundled server
    /// configuration (for example while parsing its certificate authority).
    #[test]
    fn create_clients() {
        let configs = &[SignalServers::Staging, SignalServers::Production];

        for cfg in configs {
            let _ =
                super::PushService::new(cfg, None, "libsignal-service test");
        }
    }

    /// `serde_json` refuses to deserialize from an empty body.
    #[test]
    fn serde_json_from_empty_reader() {
        // This fails, so an empty response body has to be handled separately
        // before it reaches serde_json (originally in `HyperPushService::json()`).
        let bytes: Bytes = "".into();
        assert!(
            serde_json::from_reader::<bytes::buf::Reader<Bytes>, String>(
                bytes.reader()
            )
            .is_err()
        );
    }

    /// Serializing an empty byte payload must succeed.
    #[test]
    fn serde_json_form_empty_vec() {
        // When sending an empty payload, serde_json must still produce a body.
        // An empty byte array serializes as the JSON array `[]`.
        let encoded = serde_json::to_vec(b"")
            .expect("an empty payload should serialize");
        assert_eq!(encoded, b"[]");
    }
}