// libsignal_service/push_service/mod.rs

1use std::{sync::LazyLock, time::Duration};
2
3use crate::{
4    configuration::{Endpoint, ServiceCredentials},
5    prelude::ServiceConfiguration,
6    websocket::{SignalWebSocket, WebSocketType},
7};
8
9use protobuf::ProtobufResponseExt;
10use reqwest::{Method, RequestBuilder};
11use reqwest_websocket::RequestBuilderExt;
12use serde::{Deserialize, Serialize};
13use tracing::{debug_span, Instrument};
14
/// Keep-alive timeout for service connections: 55 seconds.
///
/// Despite the `_SECONDS` suffix this is a [`Duration`], not a raw count.
pub const KEEPALIVE_TIMEOUT_SECONDS: Duration = Duration::from_millis(55_000);
16pub static DEFAULT_DEVICE_ID: LazyLock<libsignal_core::DeviceId> =
17    LazyLock::new(|| libsignal_core::DeviceId::try_from(1).unwrap());
18
19mod account;
20mod cdn;
21mod error;
22pub mod linking;
23pub(crate) mod response;
24
25pub use account::*;
26pub use cdn::*;
27pub use error::*;
28pub(crate) use response::{ReqwestExt, SignalServiceResponse};
29
30#[derive(Debug, Serialize, Deserialize)]
31pub struct ProofRequired {
32    pub token: String,
33    pub options: Vec<String>,
34}
35
36#[derive(derive_more::Debug, Clone, Serialize, Deserialize)]
37pub struct HttpAuth {
38    pub username: String,
39    #[debug(ignore)]
40    pub password: String,
41}
42
43#[derive(Debug, Clone)]
44pub enum HttpAuthOverride {
45    NoOverride,
46    Unidentified,
47    Identified(HttpAuth),
48}
49
/// What to do with an avatar when writing data that may carry one.
///
/// NOTE(review): variant meanings are inferred from their names — confirm
/// against the call sites.
#[derive(PartialEq, Eq, Clone, Debug)]
pub enum AvatarWrite<C> {
    /// Upload new avatar content of type `C`.
    NewAvatar(C),
    /// Presumably keeps whatever avatar is already stored.
    RetainAvatar,
    /// Presumably clears / omits the avatar.
    NoAvatar,
}
56
57#[derive(Debug, Deserialize)]
58#[serde(rename_all = "camelCase")]
59pub struct MismatchedDevices {
60    pub missing_devices: Vec<u32>,
61    pub extra_devices: Vec<u32>,
62}
63
64#[derive(Debug, Deserialize)]
65#[serde(rename_all = "camelCase")]
66pub struct StaleDevices {
67    pub stale_devices: Vec<u32>,
68}
69
70#[derive(Clone)]
71pub struct PushService {
72    cfg: ServiceConfiguration,
73    credentials: Option<HttpAuth>,
74    client: reqwest::Client,
75}
76
77impl PushService {
78    pub fn new(
79        cfg: impl Into<ServiceConfiguration>,
80        credentials: Option<ServiceCredentials>,
81        user_agent: impl AsRef<str>,
82    ) -> Self {
83        let cfg = cfg.into();
84        let client = reqwest::ClientBuilder::new()
85            .tls_built_in_root_certs(false)
86            .add_root_certificate(
87                reqwest::Certificate::from_pem(
88                    cfg.certificate_authority.as_bytes(),
89                )
90                .unwrap(),
91            )
92            .connect_timeout(Duration::from_secs(10))
93            .timeout(Duration::from_secs(65))
94            .user_agent(user_agent.as_ref())
95            .build()
96            .unwrap();
97
98        Self {
99            cfg,
100            credentials: credentials.and_then(|c| c.authorization()),
101            client,
102        }
103    }
104
105    #[tracing::instrument(skip(self), fields(endpoint = %endpoint))]
106    pub fn request(
107        &self,
108        method: Method,
109        endpoint: Endpoint,
110        auth_override: HttpAuthOverride,
111    ) -> Result<RequestBuilder, ServiceError> {
112        let url = endpoint.into_url(&self.cfg)?;
113        let mut builder = self.client.request(method, url);
114
115        builder = match auth_override {
116            HttpAuthOverride::NoOverride => {
117                if let Some(HttpAuth { username, password }) =
118                    self.credentials.as_ref()
119                {
120                    builder.basic_auth(username, Some(password))
121                } else {
122                    builder
123                }
124            },
125            HttpAuthOverride::Identified(HttpAuth { username, password }) => {
126                builder.basic_auth(username, Some(password))
127            },
128            HttpAuthOverride::Unidentified => builder,
129        };
130
131        Ok(builder)
132    }
133
134    pub async fn ws<C: WebSocketType>(
135        &mut self,
136        path: &str,
137        keepalive_path: &str,
138        additional_headers: &[(&'static str, &str)],
139        credentials: Option<ServiceCredentials>,
140    ) -> Result<SignalWebSocket<C>, ServiceError> {
141        let span = debug_span!("websocket");
142
143        let mut url = Endpoint::service(path).into_url(&self.cfg)?;
144        url.set_scheme("wss").expect("valid https base url");
145
146        let mut builder = self.client.get(url);
147        for (key, value) in additional_headers {
148            builder = builder.header(*key, *value);
149        }
150
151        if let Some(credentials) = credentials {
152            builder =
153                builder.basic_auth(credentials.login(), credentials.password);
154        }
155
156        let ws = builder
157            .upgrade()
158            .send()
159            .await?
160            .into_websocket()
161            .instrument(span.clone())
162            .await?;
163
164        let unidentified_push_service = PushService {
165            cfg: self.cfg.clone(),
166            credentials: None,
167            client: self.client.clone(),
168        };
169        let (ws, task) = SignalWebSocket::new(
170            ws,
171            keepalive_path.to_owned(),
172            unidentified_push_service,
173        );
174        let task = task.instrument(span);
175        tokio::task::spawn(task);
176        Ok(ws)
177    }
178
179    pub(crate) async fn get_group(
180        &mut self,
181        credentials: HttpAuth,
182    ) -> Result<crate::proto::Group, ServiceError> {
183        self.request(
184            Method::GET,
185            Endpoint::storage("/v1/groups/"),
186            HttpAuthOverride::Identified(credentials),
187        )?
188        .send()
189        .await?
190        .service_error_for_status()
191        .await?
192        .protobuf()
193        .await
194    }
195}
196
197pub(crate) mod protobuf {
198    use async_trait::async_trait;
199    use prost::{EncodeError, Message};
200    use reqwest::{header, RequestBuilder, Response};
201
202    use super::ServiceError;
203
204    pub(crate) trait ProtobufRequestBuilderExt
205    where
206        Self: Sized,
207    {
208        /// Set the request payload encoded as protobuf.
209        /// Sets the `Content-Type` header to `application/protobuf`
210        #[allow(dead_code)]
211        fn protobuf<T: Message + Default>(
212            self,
213            value: T,
214        ) -> Result<Self, EncodeError>;
215    }
216
217    #[async_trait::async_trait]
218    pub(crate) trait ProtobufResponseExt {
219        /// Get the response body decoded from Protobuf
220        async fn protobuf<T>(self) -> Result<T, ServiceError>
221        where
222            T: prost::Message + Default;
223    }
224
225    impl ProtobufRequestBuilderExt for RequestBuilder {
226        fn protobuf<T: Message + Default>(
227            self,
228            value: T,
229        ) -> Result<Self, EncodeError> {
230            let mut buf = Vec::new();
231            value.encode(&mut buf)?;
232            let this =
233                self.header(header::CONTENT_TYPE, "application/protobuf");
234            Ok(this.body(buf))
235        }
236    }
237
238    #[async_trait]
239    impl ProtobufResponseExt for Response {
240        async fn protobuf<T>(self) -> Result<T, ServiceError>
241        where
242            T: Message + Default,
243        {
244            let body = self.bytes().await?;
245            let decoded = T::decode(body)?;
246            Ok(decoded)
247        }
248    }
249}
250
#[cfg(test)]
mod tests {
    use crate::configuration::SignalServers;
    use bytes::{Buf, Bytes};

    /// `PushService::new` must not panic for any known server configuration.
    /// This also checks that each configured certificate authority parses.
    #[test]
    fn create_clients() {
        let configs = &[SignalServers::Staging, SignalServers::Production];

        for cfg in configs {
            let _ =
                super::PushService::new(cfg, None, "libsignal-service test");
        }
    }

    #[test]
    fn serde_json_from_empty_reader() {
        // Deserializing an empty body fails, so empty response bodies have to
        // be handled separately before JSON decoding.
        let bytes: Bytes = "".into();
        assert!(
            serde_json::from_reader::<bytes::buf::Reader<Bytes>, String>(
                bytes.reader()
            )
            .is_err()
        );
    }

    #[test]
    fn serde_json_to_vec_of_empty_payload() {
        // When sending an empty payload, serde_json must still be able to
        // serialize it into a `Vec`.
        assert!(serde_json::to_vec(b"").is_ok());
    }
}