libsignal_service/push_service/
mod.rs

1use std::{sync::LazyLock, time::Duration};
2
3use crate::{
4    configuration::{Endpoint, ServiceCredentials, SignalServers},
5    prelude::ServiceConfiguration,
6    utils::serde_device_id_vec,
7    websocket::{SignalWebSocket, WebSocketType},
8};
9
10use libsignal_core::DeviceId;
11use protobuf::ProtobufResponseExt;
12use reqwest::{Method, RequestBuilder};
13use reqwest_websocket::RequestBuilderExt;
14use serde::{Deserialize, Serialize};
15use tracing::{debug_span, Instrument};
16
/// Keepalive timeout, expressed as a [`Duration`] of 55 seconds.
///
/// NOTE(review): despite the `_SECONDS` suffix this is a `Duration`, not a raw
/// number of seconds. Its consumers are not visible in this file.
pub const KEEPALIVE_TIMEOUT_SECONDS: Duration = Duration::from_secs(55);
18pub static DEFAULT_DEVICE_ID: LazyLock<libsignal_core::DeviceId> =
19    LazyLock::new(|| libsignal_core::DeviceId::try_from(1).unwrap());
20
21mod account;
22mod cdn;
23mod error;
24pub mod linking;
25pub(crate) mod response;
26
27pub use account::*;
28pub use cdn::*;
29pub use error::*;
30pub(crate) use response::{ReqwestExt, SignalServiceResponse};
31
/// Serializable payload made of a `token` and a list of `options`.
///
/// NOTE(review): presumably the body the server sends when it demands a
/// proof/challenge before continuing — confirm against the error handling.
#[derive(Debug, Serialize, Deserialize)]
pub struct ProofRequired {
    /// Token string carried by the payload.
    pub token: String,
    /// Option strings carried by the payload.
    pub options: Vec<String>,
}
37
/// Username/password pair that [`PushService::request`] sends as HTTP basic
/// authentication.
#[derive(derive_more::Debug, Clone, Serialize, Deserialize)]
pub struct HttpAuth {
    /// Basic-auth user name.
    pub username: String,
    /// Basic-auth password; excluded from the `Debug` output.
    #[debug(ignore)]
    pub password: String,
}
44
/// Selects which credentials [`PushService::request`] attaches to a request.
#[derive(Debug, Clone)]
pub enum HttpAuthOverride {
    /// Use the credentials stored in the [`PushService`], if it has any.
    NoOverride,
    /// Send the request without any authentication.
    Unidentified,
    /// Authenticate with the supplied credentials instead of the stored ones.
    Identified(HttpAuth),
}
51
/// Three-way choice describing an avatar update, generic over the avatar
/// content type `C`.
///
/// NOTE(review): the variant meanings below are read from their names; the
/// code consuming this enum is not visible in this file — confirm there.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum AvatarWrite<C> {
    /// Carries new avatar content of type `C`.
    NewAvatar(C),
    /// Presumably: keep the avatar that is currently set.
    RetainAvatar,
    /// Presumably: there should be no avatar.
    NoAvatar,
}
58
/// Deserialized pair of device-id lists, read from the camelCase keys
/// `missingDevices` and `extraDevices`.
///
/// NOTE(review): presumably the server's report that the device list used
/// for a send does not match the recipient's devices — confirm with callers.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MismatchedDevices {
    /// Device ids listed under `missingDevices`, decoded through
    /// `serde_device_id_vec`.
    #[serde(with = "serde_device_id_vec")]
    pub missing_devices: Vec<DeviceId>,
    /// Device ids listed under `extraDevices`, decoded through
    /// `serde_device_id_vec`.
    #[serde(with = "serde_device_id_vec")]
    pub extra_devices: Vec<DeviceId>,
}
67
/// Deserialized list of device ids, read from the camelCase key
/// `staleDevices`.
///
/// NOTE(review): presumably devices whose sessions the server considers
/// outdated — confirm with callers.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StaleDevices {
    /// Device ids listed under `staleDevices`, decoded through
    /// `serde_device_id_vec`.
    #[serde(with = "serde_device_id_vec")]
    pub stale_devices: Vec<DeviceId>,
}
74
/// HTTP client for one server environment, optionally carrying default
/// basic-auth credentials.
#[derive(Clone)]
pub struct PushService {
    /// Server environment this service was created for.
    pub(crate) servers: SignalServers,
    /// Service configuration derived from `servers`; used to resolve
    /// endpoints into URLs.
    cfg: ServiceConfiguration,
    /// Credentials applied to requests made with
    /// `HttpAuthOverride::NoOverride`; `None` means such requests are sent
    /// without authentication.
    credentials: Option<HttpAuth>,
    /// Underlying `reqwest` client used for every request.
    client: reqwest::Client,
}
82
83impl PushService {
84    pub fn new(
85        env: SignalServers,
86        credentials: Option<ServiceCredentials>,
87        user_agent: impl AsRef<str>,
88    ) -> Self {
89        let cfg: ServiceConfiguration = env.into();
90        let client = reqwest::ClientBuilder::new()
91            .tls_built_in_root_certs(false)
92            .add_root_certificate(
93                reqwest::Certificate::from_pem(
94                    cfg.certificate_authority.as_bytes(),
95                )
96                .unwrap(),
97            )
98            .connect_timeout(Duration::from_secs(10))
99            .timeout(Duration::from_secs(65))
100            .user_agent(user_agent.as_ref())
101            .http1_only()
102            .build()
103            .unwrap();
104
105        Self {
106            servers: env,
107            cfg,
108            credentials: credentials.and_then(|c| c.authorization()),
109            client,
110        }
111    }
112
113    #[tracing::instrument(skip(self), fields(endpoint = %endpoint))]
114    pub fn request(
115        &self,
116        method: Method,
117        endpoint: Endpoint,
118        auth_override: HttpAuthOverride,
119    ) -> Result<RequestBuilder, ServiceError> {
120        let url = endpoint.into_url(&self.cfg)?;
121        let mut builder = self.client.request(method, url);
122
123        builder = match auth_override {
124            HttpAuthOverride::NoOverride => {
125                if let Some(HttpAuth { username, password }) =
126                    self.credentials.as_ref()
127                {
128                    builder.basic_auth(username, Some(password))
129                } else {
130                    builder
131                }
132            },
133            HttpAuthOverride::Identified(HttpAuth { username, password }) => {
134                builder.basic_auth(username, Some(password))
135            },
136            HttpAuthOverride::Unidentified => builder,
137        };
138
139        Ok(builder)
140    }
141
142    pub async fn ws<C: WebSocketType>(
143        &mut self,
144        path: &str,
145        keepalive_path: &str,
146        additional_headers: &[(&'static str, &str)],
147        credentials: Option<ServiceCredentials>,
148    ) -> Result<SignalWebSocket<C>, ServiceError> {
149        let span = debug_span!("websocket");
150
151        let mut url = Endpoint::service(path).into_url(&self.cfg)?;
152        url.set_scheme("wss").expect("valid https base url");
153
154        let mut builder = self.client.get(url);
155        for (key, value) in additional_headers {
156            builder = builder.header(*key, *value);
157        }
158
159        if let Some(credentials) = credentials {
160            builder =
161                builder.basic_auth(credentials.login(), credentials.password);
162        }
163
164        let ws = builder
165            .upgrade()
166            .send()
167            .await?
168            .into_websocket()
169            .instrument(span.clone())
170            .await?;
171
172        let unidentified_push_service = PushService {
173            servers: self.servers,
174            cfg: self.cfg.clone(),
175            credentials: None,
176            client: self.client.clone(),
177        };
178        let (ws, task) = SignalWebSocket::new(
179            ws,
180            keepalive_path.to_owned(),
181            unidentified_push_service,
182        );
183        let task = task.instrument(span);
184        tokio::task::spawn(task);
185        Ok(ws)
186    }
187
188    pub(crate) async fn get_group(
189        &mut self,
190        credentials: HttpAuth,
191    ) -> Result<crate::proto::Group, ServiceError> {
192        self.request(
193            Method::GET,
194            Endpoint::storage("/v1/groups/"),
195            HttpAuthOverride::Identified(credentials),
196        )?
197        .send()
198        .await?
199        .service_error_for_status()
200        .await?
201        .protobuf()
202        .await
203    }
204}
205
206pub(crate) mod protobuf {
207    use async_trait::async_trait;
208    use prost::{EncodeError, Message};
209    use reqwest::{header, RequestBuilder, Response};
210
211    use super::ServiceError;
212
213    pub(crate) trait ProtobufRequestBuilderExt
214    where
215        Self: Sized,
216    {
217        /// Set the request payload encoded as protobuf.
218        /// Sets the `Content-Type` header to `application/x-protobuf`
219        #[allow(dead_code)]
220        fn protobuf<T: Message + Default>(
221            self,
222            value: T,
223        ) -> Result<Self, EncodeError>;
224    }
225
226    #[async_trait::async_trait]
227    pub(crate) trait ProtobufResponseExt {
228        /// Get the response body decoded from Protobuf
229        async fn protobuf<T>(self) -> Result<T, ServiceError>
230        where
231            T: prost::Message + Default;
232    }
233
234    impl ProtobufRequestBuilderExt for RequestBuilder {
235        fn protobuf<T: Message + Default>(
236            self,
237            value: T,
238        ) -> Result<Self, EncodeError> {
239            let mut buf = Vec::new();
240            value.encode(&mut buf)?;
241            let this =
242                self.header(header::CONTENT_TYPE, "application/x-protobuf");
243            Ok(this.body(buf))
244        }
245    }
246
247    #[async_trait]
248    impl ProtobufResponseExt for Response {
249        async fn protobuf<T>(self) -> Result<T, ServiceError>
250        where
251            T: Message + Default,
252        {
253            let body = self.bytes().await?;
254            let decoded = T::decode(body)?;
255            Ok(decoded)
256        }
257    }
258}
259
#[cfg(test)]
mod tests {
    use crate::configuration::SignalServers;
    use bytes::{Buf, Bytes};

    /// Building a service must succeed for both server environments.
    #[test]
    fn create_clients() {
        for env in [SignalServers::Staging, SignalServers::Production] {
            let _service =
                super::PushService::new(env, None, "libsignal-service test");
        }
    }

    #[test]
    fn serde_json_from_empty_reader() {
        // Parsing an empty input is an error, so an empty response body has
        // to be handled separately wherever JSON responses are decoded.
        let empty = Bytes::from("");
        let parsed: Result<String, _> =
            serde_json::from_reader(empty.reader());
        assert!(parsed.is_err());
    }

    #[test]
    fn serde_json_form_empty_vec() {
        // If we're trying to send an empty payload, serde_json must be able
        // to make a Vec out of it.
        let encoded = serde_json::to_vec(b"");
        assert!(encoded.is_ok());
    }
}