libsignal_service/push_service/mod.rs

1use std::{sync::LazyLock, time::Duration};
2
3use crate::{
4    configuration::{Endpoint, ServiceCredentials},
5    pre_keys::{KyberPreKeyEntity, PreKeyEntity, SignedPreKeyEntity},
6    prelude::ServiceConfiguration,
7    utils::serde_base64,
8    websocket::SignalWebSocket,
9};
10
11use derivative::Derivative;
12use libsignal_protocol::{
13    error::SignalProtocolError,
14    kem::{Key, Public},
15    IdentityKey, PreKeyBundle, PublicKey,
16};
17use protobuf::ProtobufResponseExt;
18use reqwest::{Method, RequestBuilder};
19use reqwest_websocket::RequestBuilderExt;
20use serde::{Deserialize, Serialize};
21use serde_with::serde_as;
22use tracing::{debug_span, Instrument};
23
/// Keep-alive timeout of 55 seconds.
///
/// Despite the `_SECONDS` suffix this is a [`Duration`], not a bare number.
/// NOTE(review): not referenced in this module; presumably consumed by the
/// websocket keep-alive logic — confirm.
pub const KEEPALIVE_TIMEOUT_SECONDS: Duration = Duration::new(55, 0);
25pub static DEFAULT_DEVICE_ID: LazyLock<libsignal_core::DeviceId> =
26    LazyLock::new(|| libsignal_core::DeviceId::try_from(1).unwrap());
27
28mod account;
29mod cdn;
30mod error;
31mod keys;
32mod linking;
33mod profile;
34mod registration;
35mod response;
36mod stickers;
37
38pub use account::*;
39pub use cdn::*;
40pub use error::*;
41pub use keys::*;
42pub use linking::*;
43pub use profile::*;
44pub use registration::*;
45pub(crate) use response::{ReqwestExt, SignalServiceResponse};
46
47#[derive(Debug, Serialize, Deserialize)]
48pub struct ProofRequired {
49    pub token: String,
50    pub options: Vec<String>,
51}
52
53#[derive(Derivative, Clone, Serialize, Deserialize)]
54#[derivative(Debug)]
55pub struct HttpAuth {
56    pub username: String,
57    #[derivative(Debug = "ignore")]
58    pub password: String,
59}
60
61#[derive(Debug, Clone)]
62pub enum HttpAuthOverride {
63    NoOverride,
64    Unidentified,
65    Identified(HttpAuth),
66}
67
/// What to do with an avatar when writing it, generic over the avatar
/// payload type `C`.
///
/// NOTE(review): presumably used when updating a profile — confirm.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AvatarWrite<C> {
    /// Replace the avatar with the given content.
    NewAvatar(C),
    /// Keep the existing avatar.
    RetainAvatar,
    /// Have no avatar at all.
    NoAvatar,
}
74
75#[derive(Debug, Deserialize)]
76#[serde(rename_all = "camelCase")]
77struct SenderCertificateJson {
78    #[serde(with = "serde_base64")]
79    certificate: Vec<u8>,
80}
81
82#[derive(Debug, Deserialize)]
83#[serde(rename_all = "camelCase")]
84pub struct PreKeyResponse {
85    #[serde(with = "serde_base64")]
86    pub identity_key: Vec<u8>,
87    pub devices: Vec<PreKeyResponseItem>,
88}
89
90#[derive(Debug, Deserialize)]
91#[serde(rename_all = "camelCase")]
92pub struct PreKeyResponseItem {
93    pub device_id: u32,
94    pub registration_id: u32,
95    pub signed_pre_key: SignedPreKeyEntity,
96    pub pre_key: Option<PreKeyEntity>,
97    pub pq_pre_key: KyberPreKeyEntity,
98}
99
100impl PreKeyResponseItem {
101    #[allow(clippy::result_large_err)]
102    pub(crate) fn into_bundle(
103        self,
104        identity: IdentityKey,
105    ) -> Result<PreKeyBundle, ServiceError> {
106        Ok(PreKeyBundle::new(
107            self.registration_id,
108            self.device_id.try_into()?,
109            self.pre_key
110                .map(|pk| -> Result<_, SignalProtocolError> {
111                    Ok((
112                        pk.key_id.into(),
113                        PublicKey::deserialize(&pk.public_key)?,
114                    ))
115                })
116                .transpose()?,
117            // pre_key: Option<(u32, PublicKey)>,
118            self.signed_pre_key.key_id.into(),
119            PublicKey::deserialize(&self.signed_pre_key.public_key)?,
120            self.signed_pre_key.signature,
121            self.pq_pre_key.key_id.into(),
122            Key::<Public>::deserialize(&self.pq_pre_key.public_key)?,
123            self.pq_pre_key.signature,
124            identity,
125        )?)
126    }
127}
128
129#[derive(Debug, Deserialize)]
130#[serde(rename_all = "camelCase")]
131pub struct MismatchedDevices {
132    pub missing_devices: Vec<u32>,
133    pub extra_devices: Vec<u32>,
134}
135
136#[derive(Debug, Deserialize)]
137#[serde_as]
138#[serde(rename_all = "camelCase")]
139pub struct StaleDevices {
140    pub stale_devices: Vec<u32>,
141}
142
143#[derive(Clone)]
144pub struct PushService {
145    cfg: ServiceConfiguration,
146    credentials: Option<HttpAuth>,
147    client: reqwest::Client,
148}
149
150impl PushService {
151    pub fn new(
152        cfg: impl Into<ServiceConfiguration>,
153        credentials: Option<ServiceCredentials>,
154        user_agent: impl AsRef<str>,
155    ) -> Self {
156        let cfg = cfg.into();
157        let client = reqwest::ClientBuilder::new()
158            .tls_built_in_root_certs(false)
159            .add_root_certificate(
160                reqwest::Certificate::from_pem(
161                    cfg.certificate_authority.as_bytes(),
162                )
163                .unwrap(),
164            )
165            .connect_timeout(Duration::from_secs(10))
166            .timeout(Duration::from_secs(65))
167            .user_agent(user_agent.as_ref())
168            .build()
169            .unwrap();
170
171        Self {
172            cfg,
173            credentials: credentials.and_then(|c| c.authorization()),
174            client,
175        }
176    }
177
178    #[expect(clippy::result_large_err)]
179    #[tracing::instrument(skip(self), fields(endpoint = %endpoint))]
180    pub fn request(
181        &self,
182        method: Method,
183        endpoint: Endpoint,
184        auth_override: HttpAuthOverride,
185    ) -> Result<RequestBuilder, ServiceError> {
186        let url = endpoint.into_url(&self.cfg)?;
187        let mut builder = self.client.request(method, url);
188
189        builder = match auth_override {
190            HttpAuthOverride::NoOverride => {
191                if let Some(HttpAuth { username, password }) =
192                    self.credentials.as_ref()
193                {
194                    builder.basic_auth(username, Some(password))
195                } else {
196                    builder
197                }
198            },
199            HttpAuthOverride::Identified(HttpAuth { username, password }) => {
200                builder.basic_auth(username, Some(password))
201            },
202            HttpAuthOverride::Unidentified => builder,
203        };
204
205        Ok(builder)
206    }
207
208    pub async fn ws(
209        &mut self,
210        path: &str,
211        keepalive_path: &str,
212        additional_headers: &[(&'static str, &str)],
213        credentials: Option<ServiceCredentials>,
214    ) -> Result<SignalWebSocket, ServiceError> {
215        let span = debug_span!("websocket");
216
217        let mut url = Endpoint::service(path).into_url(&self.cfg)?;
218        url.set_scheme("wss").expect("valid https base url");
219
220        let mut builder = self.client.get(url);
221        for (key, value) in additional_headers {
222            builder = builder.header(*key, *value);
223        }
224
225        if let Some(credentials) = credentials {
226            builder =
227                builder.basic_auth(credentials.login(), credentials.password);
228        }
229
230        let ws = builder
231            .upgrade()
232            .send()
233            .await?
234            .into_websocket()
235            .instrument(span.clone())
236            .await?;
237
238        let (ws, task) =
239            SignalWebSocket::from_socket(ws, keepalive_path.to_owned());
240        let task = task.instrument(span);
241        tokio::task::spawn(task);
242        Ok(ws)
243    }
244
245    pub(crate) async fn get_group(
246        &mut self,
247        credentials: HttpAuth,
248    ) -> Result<crate::proto::Group, ServiceError> {
249        self.request(
250            Method::GET,
251            Endpoint::storage("/v1/groups/"),
252            HttpAuthOverride::Identified(credentials),
253        )?
254        .send()
255        .await?
256        .service_error_for_status()
257        .await?
258        .protobuf()
259        .await
260    }
261}
262
263pub(crate) mod protobuf {
264    use async_trait::async_trait;
265    use prost::{EncodeError, Message};
266    use reqwest::{header, RequestBuilder, Response};
267
268    use super::ServiceError;
269
270    pub(crate) trait ProtobufRequestBuilderExt
271    where
272        Self: Sized,
273    {
274        /// Set the request payload encoded as protobuf.
275        /// Sets the `Content-Type` header to `application/protobuf`
276        #[allow(dead_code)]
277        fn protobuf<T: Message + Default>(
278            self,
279            value: T,
280        ) -> Result<Self, EncodeError>;
281    }
282
283    #[async_trait::async_trait]
284    pub(crate) trait ProtobufResponseExt {
285        /// Get the response body decoded from Protobuf
286        async fn protobuf<T>(self) -> Result<T, ServiceError>
287        where
288            T: prost::Message + Default;
289    }
290
291    impl ProtobufRequestBuilderExt for RequestBuilder {
292        fn protobuf<T: Message + Default>(
293            self,
294            value: T,
295        ) -> Result<Self, EncodeError> {
296            let mut buf = Vec::new();
297            value.encode(&mut buf)?;
298            let this =
299                self.header(header::CONTENT_TYPE, "application/protobuf");
300            Ok(this.body(buf))
301        }
302    }
303
304    #[async_trait]
305    impl ProtobufResponseExt for Response {
306        async fn protobuf<T>(self) -> Result<T, ServiceError>
307        where
308            T: Message + Default,
309        {
310            let body = self.bytes().await?;
311            let decoded = T::decode(body)?;
312            Ok(decoded)
313        }
314    }
315}
316
#[cfg(test)]
mod tests {
    use crate::configuration::SignalServers;
    use bytes::{Buf, Bytes};

    /// Building a `PushService` for each bundled server configuration must
    /// not panic (it parses the configured CA and builds the HTTP client).
    #[test]
    fn create_clients() {
        let configs = &[SignalServers::Staging, SignalServers::Production];

        for cfg in configs {
            let _ =
                super::PushService::new(cfg, None, "libsignal-service test");
        }
    }

    #[test]
    fn serde_json_from_empty_reader() {
        // Deserializing from an empty body fails, so an empty response body
        // has to be handled separately before it reaches serde_json.
        // NOTE(review): this used to point at `HyperPushService::json()`,
        // which is not in this module; update to the current JSON helper.
        let bytes: Bytes = "".into();
        assert!(
            serde_json::from_reader::<bytes::buf::Reader<Bytes>, String>(
                bytes.reader()
            )
            .is_err()
        );
    }

    #[test]
    fn serde_json_form_empty_vec() {
        // If we're trying to send an empty payload, serde_json must still be
        // able to serialize it into a Vec.
        assert!(serde_json::to_vec(b"").is_ok());
    }
}