libsignal_service/push_service/mod.rs

1use std::{sync::LazyLock, time::Duration};
2
3use crate::{
4    configuration::{Endpoint, ServiceCredentials},
5    prelude::ServiceConfiguration,
6    utils::serde_device_id_vec,
7    websocket::{SignalWebSocket, WebSocketType},
8};
9
10use libsignal_core::DeviceId;
11use protobuf::ProtobufResponseExt;
12use reqwest::{Method, RequestBuilder};
13use reqwest_websocket::RequestBuilderExt;
14use serde::{Deserialize, Serialize};
15use tracing::{debug_span, Instrument};
16
17pub const KEEPALIVE_TIMEOUT_SECONDS: Duration = Duration::from_secs(55);
18pub static DEFAULT_DEVICE_ID: LazyLock<libsignal_core::DeviceId> =
19    LazyLock::new(|| libsignal_core::DeviceId::try_from(1).unwrap());
20
21mod account;
22mod cdn;
23mod error;
24pub mod linking;
25pub(crate) mod response;
26
27pub use account::*;
28pub use cdn::*;
29pub use error::*;
30pub(crate) use response::{ReqwestExt, SignalServiceResponse};
31
32#[derive(Debug, Serialize, Deserialize)]
33pub struct ProofRequired {
34    pub token: String,
35    pub options: Vec<String>,
36}
37
38#[derive(derive_more::Debug, Clone, Serialize, Deserialize)]
39pub struct HttpAuth {
40    pub username: String,
41    #[debug(ignore)]
42    pub password: String,
43}
44
45#[derive(Debug, Clone)]
46pub enum HttpAuthOverride {
47    NoOverride,
48    Unidentified,
49    Identified(HttpAuth),
50}
51
/// How an avatar should be written, generic over the avatar content `C`.
///
/// NOTE(review): variant meanings below are inferred from their names —
/// confirm against the code that consumes this type.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AvatarWrite<C> {
    /// Presumably: upload `C` as the new avatar.
    NewAvatar(C),
    /// Presumably: keep whatever avatar is currently set.
    RetainAvatar,
    /// Presumably: have no avatar.
    NoAvatar,
}
58
59#[derive(Debug, Deserialize)]
60#[serde(rename_all = "camelCase")]
61pub struct MismatchedDevices {
62    #[serde(with = "serde_device_id_vec")]
63    pub missing_devices: Vec<DeviceId>,
64    #[serde(with = "serde_device_id_vec")]
65    pub extra_devices: Vec<DeviceId>,
66}
67
68#[derive(Debug, Deserialize)]
69#[serde(rename_all = "camelCase")]
70pub struct StaleDevices {
71    #[serde(with = "serde_device_id_vec")]
72    pub stale_devices: Vec<DeviceId>,
73}
74
75#[derive(Clone)]
76pub struct PushService {
77    cfg: ServiceConfiguration,
78    credentials: Option<HttpAuth>,
79    client: reqwest::Client,
80}
81
82impl PushService {
83    pub fn new(
84        cfg: impl Into<ServiceConfiguration>,
85        credentials: Option<ServiceCredentials>,
86        user_agent: impl AsRef<str>,
87    ) -> Self {
88        let cfg = cfg.into();
89        let client = reqwest::ClientBuilder::new()
90            .tls_built_in_root_certs(false)
91            .add_root_certificate(
92                reqwest::Certificate::from_pem(
93                    cfg.certificate_authority.as_bytes(),
94                )
95                .unwrap(),
96            )
97            .connect_timeout(Duration::from_secs(10))
98            .timeout(Duration::from_secs(65))
99            .user_agent(user_agent.as_ref())
100            .http1_only()
101            .build()
102            .unwrap();
103
104        Self {
105            cfg,
106            credentials: credentials.and_then(|c| c.authorization()),
107            client,
108        }
109    }
110
111    #[tracing::instrument(skip(self), fields(endpoint = %endpoint))]
112    pub fn request(
113        &self,
114        method: Method,
115        endpoint: Endpoint,
116        auth_override: HttpAuthOverride,
117    ) -> Result<RequestBuilder, ServiceError> {
118        let url = endpoint.into_url(&self.cfg)?;
119        let mut builder = self.client.request(method, url);
120
121        builder = match auth_override {
122            HttpAuthOverride::NoOverride => {
123                if let Some(HttpAuth { username, password }) =
124                    self.credentials.as_ref()
125                {
126                    builder.basic_auth(username, Some(password))
127                } else {
128                    builder
129                }
130            },
131            HttpAuthOverride::Identified(HttpAuth { username, password }) => {
132                builder.basic_auth(username, Some(password))
133            },
134            HttpAuthOverride::Unidentified => builder,
135        };
136
137        Ok(builder)
138    }
139
140    pub async fn ws<C: WebSocketType>(
141        &mut self,
142        path: &str,
143        keepalive_path: &str,
144        additional_headers: &[(&'static str, &str)],
145        credentials: Option<ServiceCredentials>,
146    ) -> Result<SignalWebSocket<C>, ServiceError> {
147        let span = debug_span!("websocket");
148
149        let mut url = Endpoint::service(path).into_url(&self.cfg)?;
150        url.set_scheme("wss").expect("valid https base url");
151
152        let mut builder = self.client.get(url);
153        for (key, value) in additional_headers {
154            builder = builder.header(*key, *value);
155        }
156
157        if let Some(credentials) = credentials {
158            builder =
159                builder.basic_auth(credentials.login(), credentials.password);
160        }
161
162        let ws = builder
163            .upgrade()
164            .send()
165            .await?
166            .into_websocket()
167            .instrument(span.clone())
168            .await?;
169
170        let unidentified_push_service = PushService {
171            cfg: self.cfg.clone(),
172            credentials: None,
173            client: self.client.clone(),
174        };
175        let (ws, task) = SignalWebSocket::new(
176            ws,
177            keepalive_path.to_owned(),
178            unidentified_push_service,
179        );
180        let task = task.instrument(span);
181        tokio::task::spawn(task);
182        Ok(ws)
183    }
184
185    pub(crate) async fn get_group(
186        &mut self,
187        credentials: HttpAuth,
188    ) -> Result<crate::proto::Group, ServiceError> {
189        self.request(
190            Method::GET,
191            Endpoint::storage("/v1/groups/"),
192            HttpAuthOverride::Identified(credentials),
193        )?
194        .send()
195        .await?
196        .service_error_for_status()
197        .await?
198        .protobuf()
199        .await
200    }
201}
202
203pub(crate) mod protobuf {
204    use async_trait::async_trait;
205    use prost::{EncodeError, Message};
206    use reqwest::{header, RequestBuilder, Response};
207
208    use super::ServiceError;
209
210    pub(crate) trait ProtobufRequestBuilderExt
211    where
212        Self: Sized,
213    {
214        /// Set the request payload encoded as protobuf.
215        /// Sets the `Content-Type` header to `application/protobuf`
216        #[allow(dead_code)]
217        fn protobuf<T: Message + Default>(
218            self,
219            value: T,
220        ) -> Result<Self, EncodeError>;
221    }
222
223    #[async_trait::async_trait]
224    pub(crate) trait ProtobufResponseExt {
225        /// Get the response body decoded from Protobuf
226        async fn protobuf<T>(self) -> Result<T, ServiceError>
227        where
228            T: prost::Message + Default;
229    }
230
231    impl ProtobufRequestBuilderExt for RequestBuilder {
232        fn protobuf<T: Message + Default>(
233            self,
234            value: T,
235        ) -> Result<Self, EncodeError> {
236            let mut buf = Vec::new();
237            value.encode(&mut buf)?;
238            let this =
239                self.header(header::CONTENT_TYPE, "application/protobuf");
240            Ok(this.body(buf))
241        }
242    }
243
244    #[async_trait]
245    impl ProtobufResponseExt for Response {
246        async fn protobuf<T>(self) -> Result<T, ServiceError>
247        where
248            T: Message + Default,
249        {
250            let body = self.bytes().await?;
251            let decoded = T::decode(body)?;
252            Ok(decoded)
253        }
254    }
255}
256
#[cfg(test)]
mod tests {
    use crate::configuration::SignalServers;
    use bytes::{Buf, Bytes};

    /// Building a `PushService` must not panic for any built-in server
    /// configuration (certificate parsing and client construction).
    #[test]
    fn create_clients() {
        for cfg in &[SignalServers::Staging, SignalServers::Production] {
            let _ =
                super::PushService::new(cfg, None, "libsignal-service test");
        }
    }

    /// Deserializing JSON from an empty body fails, so empty response bodies
    /// have to be handled separately before JSON decoding (historically in
    /// `HyperPushService::json()`; NOTE(review): that type is not in this
    /// file — confirm where this handling now lives).
    #[test]
    fn serde_json_from_empty_reader() {
        let bytes: Bytes = "".into();
        assert!(serde_json::from_reader::<_, String>(bytes.reader()).is_err());
    }

    /// When sending an empty payload, serde_json must still be able to
    /// serialize it into a `Vec<u8>`.
    #[test]
    fn serde_json_form_empty_vec() {
        assert!(serde_json::to_vec(b"").is_ok());
    }
}