libsignal_service/push_service/mod.rs

1use std::time::Duration;
2
3use crate::{
4    configuration::{Endpoint, ServiceCredentials},
5    pre_keys::{KyberPreKeyEntity, PreKeyEntity, SignedPreKeyEntity},
6    prelude::ServiceConfiguration,
7    utils::serde_base64,
8    websocket::SignalWebSocket,
9};
10
11use derivative::Derivative;
12use libsignal_protocol::{
13    error::SignalProtocolError,
14    kem::{Key, Public},
15    IdentityKey, PreKeyBundle, PublicKey,
16};
17use protobuf::ProtobufResponseExt;
18use reqwest::{Method, RequestBuilder};
19use reqwest_websocket::RequestBuilderExt;
20use serde::{Deserialize, Serialize};
21use tracing::{debug_span, Instrument};
22
/// Keep-alive timeout: 55 seconds.
///
/// Note: despite the `_SECONDS` suffix this constant is a [`Duration`], not a
/// raw number of seconds.
pub const KEEPALIVE_TIMEOUT_SECONDS: Duration = Duration::from_secs(55);

/// Default device id (`1`); presumably the primary device of an account —
/// confirm against callers.
pub const DEFAULT_DEVICE_ID: u32 = 1;
25
26mod account;
27mod cdn;
28mod error;
29mod keys;
30mod linking;
31mod profile;
32mod registration;
33mod response;
34mod stickers;
35
36pub use account::*;
37pub use cdn::*;
38pub use error::*;
39pub use keys::*;
40pub use linking::*;
41pub use profile::*;
42pub use registration::*;
43pub(crate) use response::{ReqwestExt, SignalServiceResponse};
44
45#[derive(Debug, Serialize, Deserialize)]
46pub struct ProofRequired {
47    pub token: String,
48    pub options: Vec<String>,
49}
50
51#[derive(Derivative, Clone, Serialize, Deserialize)]
52#[derivative(Debug)]
53pub struct HttpAuth {
54    pub username: String,
55    #[derivative(Debug = "ignore")]
56    pub password: String,
57}
58
59#[derive(Debug, Clone)]
60pub enum HttpAuthOverride {
61    NoOverride,
62    Unidentified,
63    Identified(HttpAuth),
64}
65
/// What to do with an avatar when writing it (presumably as part of a profile
/// update — confirm against the `profile` module).
///
/// `C` is the type carrying the new avatar's contents.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AvatarWrite<C> {
    /// Replace the avatar with the given contents.
    NewAvatar(C),
    /// Keep the currently stored avatar.
    RetainAvatar,
    /// Have no avatar.
    NoAvatar,
}
72
73#[derive(Debug, Deserialize)]
74#[serde(rename_all = "camelCase")]
75struct SenderCertificateJson {
76    #[serde(with = "serde_base64")]
77    certificate: Vec<u8>,
78}
79
80#[derive(Debug, Deserialize)]
81#[serde(rename_all = "camelCase")]
82pub struct PreKeyResponse {
83    #[serde(with = "serde_base64")]
84    pub identity_key: Vec<u8>,
85    pub devices: Vec<PreKeyResponseItem>,
86}
87
88#[derive(Debug, Deserialize)]
89#[serde(rename_all = "camelCase")]
90pub struct PreKeyResponseItem {
91    pub device_id: u32,
92    pub registration_id: u32,
93    pub signed_pre_key: SignedPreKeyEntity,
94    pub pre_key: Option<PreKeyEntity>,
95    pub pq_pre_key: Option<KyberPreKeyEntity>,
96}
97
98impl PreKeyResponseItem {
99    pub(crate) fn into_bundle(
100        self,
101        identity: IdentityKey,
102    ) -> Result<PreKeyBundle, SignalProtocolError> {
103        let b = PreKeyBundle::new(
104            self.registration_id,
105            self.device_id.into(),
106            self.pre_key
107                .map(|pk| -> Result<_, SignalProtocolError> {
108                    Ok((
109                        pk.key_id.into(),
110                        PublicKey::deserialize(&pk.public_key)?,
111                    ))
112                })
113                .transpose()?,
114            // pre_key: Option<(u32, PublicKey)>,
115            self.signed_pre_key.key_id.into(),
116            PublicKey::deserialize(&self.signed_pre_key.public_key)?,
117            self.signed_pre_key.signature,
118            identity,
119        )?;
120
121        if let Some(pq_pk) = self.pq_pre_key {
122            Ok(b.with_kyber_pre_key(
123                pq_pk.key_id.into(),
124                Key::<Public>::deserialize(&pq_pk.public_key)?,
125                pq_pk.signature,
126            ))
127        } else {
128            Ok(b)
129        }
130    }
131}
132
133#[derive(Debug, Deserialize)]
134#[serde(rename_all = "camelCase")]
135pub struct MismatchedDevices {
136    pub missing_devices: Vec<u32>,
137    pub extra_devices: Vec<u32>,
138}
139
140#[derive(Debug, Deserialize)]
141#[serde(rename_all = "camelCase")]
142pub struct StaleDevices {
143    pub stale_devices: Vec<u32>,
144}
145
146#[derive(Clone)]
147pub struct PushService {
148    cfg: ServiceConfiguration,
149    credentials: Option<HttpAuth>,
150    client: reqwest::Client,
151}
152
153impl PushService {
154    pub fn new(
155        cfg: impl Into<ServiceConfiguration>,
156        credentials: Option<ServiceCredentials>,
157        user_agent: impl AsRef<str>,
158    ) -> Self {
159        let cfg = cfg.into();
160        let client = reqwest::ClientBuilder::new()
161            .tls_built_in_root_certs(false)
162            .add_root_certificate(
163                reqwest::Certificate::from_pem(
164                    cfg.certificate_authority.as_bytes(),
165                )
166                .unwrap(),
167            )
168            .connect_timeout(Duration::from_secs(10))
169            .timeout(Duration::from_secs(65))
170            .user_agent(user_agent.as_ref())
171            .build()
172            .unwrap();
173
174        Self {
175            cfg,
176            credentials: credentials.and_then(|c| c.authorization()),
177            client,
178        }
179    }
180
181    #[tracing::instrument(skip(self), fields(endpoint = %endpoint))]
182    pub fn request(
183        &self,
184        method: Method,
185        endpoint: Endpoint,
186        auth_override: HttpAuthOverride,
187    ) -> Result<RequestBuilder, ServiceError> {
188        let url = endpoint.into_url(&self.cfg)?;
189        let mut builder = self.client.request(method, url);
190
191        builder = match auth_override {
192            HttpAuthOverride::NoOverride => {
193                if let Some(HttpAuth { username, password }) =
194                    self.credentials.as_ref()
195                {
196                    builder.basic_auth(username, Some(password))
197                } else {
198                    builder
199                }
200            },
201            HttpAuthOverride::Identified(HttpAuth { username, password }) => {
202                builder.basic_auth(username, Some(password))
203            },
204            HttpAuthOverride::Unidentified => builder,
205        };
206
207        Ok(builder)
208    }
209
210    pub async fn ws(
211        &mut self,
212        path: &str,
213        keepalive_path: &str,
214        additional_headers: &[(&'static str, &str)],
215        credentials: Option<ServiceCredentials>,
216    ) -> Result<SignalWebSocket, ServiceError> {
217        let span = debug_span!("websocket");
218
219        let mut url = Endpoint::service(path).into_url(&self.cfg)?;
220        url.set_scheme("wss").expect("valid https base url");
221
222        let mut builder = self.client.get(url);
223        for (key, value) in additional_headers {
224            builder = builder.header(*key, *value);
225        }
226
227        if let Some(credentials) = credentials {
228            builder =
229                builder.basic_auth(credentials.login(), credentials.password);
230        }
231
232        let ws = builder
233            .upgrade()
234            .send()
235            .await?
236            .into_websocket()
237            .instrument(span.clone())
238            .await?;
239
240        let (ws, task) =
241            SignalWebSocket::from_socket(ws, keepalive_path.to_owned());
242        let task = task.instrument(span);
243        tokio::task::spawn(task);
244        Ok(ws)
245    }
246
247    pub(crate) async fn get_group(
248        &mut self,
249        credentials: HttpAuth,
250    ) -> Result<crate::proto::Group, ServiceError> {
251        self.request(
252            Method::GET,
253            Endpoint::storage("/v1/groups/"),
254            HttpAuthOverride::Identified(credentials),
255        )?
256        .send()
257        .await?
258        .service_error_for_status()
259        .await?
260        .protobuf()
261        .await
262    }
263}
264
265pub(crate) mod protobuf {
266    use async_trait::async_trait;
267    use prost::{EncodeError, Message};
268    use reqwest::{header, RequestBuilder, Response};
269
270    use super::ServiceError;
271
272    pub(crate) trait ProtobufRequestBuilderExt
273    where
274        Self: Sized,
275    {
276        /// Set the request payload encoded as protobuf.
277        /// Sets the `Content-Type` header to `application/protobuf`
278        #[allow(dead_code)]
279        fn protobuf<T: Message + Default>(
280            self,
281            value: T,
282        ) -> Result<Self, EncodeError>;
283    }
284
285    #[async_trait::async_trait]
286    pub(crate) trait ProtobufResponseExt {
287        /// Get the response body decoded from Protobuf
288        async fn protobuf<T>(self) -> Result<T, ServiceError>
289        where
290            T: prost::Message + Default;
291    }
292
293    impl ProtobufRequestBuilderExt for RequestBuilder {
294        fn protobuf<T: Message + Default>(
295            self,
296            value: T,
297        ) -> Result<Self, EncodeError> {
298            let mut buf = Vec::new();
299            value.encode(&mut buf)?;
300            let this =
301                self.header(header::CONTENT_TYPE, "application/protobuf");
302            Ok(this.body(buf))
303        }
304    }
305
306    #[async_trait]
307    impl ProtobufResponseExt for Response {
308        async fn protobuf<T>(self) -> Result<T, ServiceError>
309        where
310            T: Message + Default,
311        {
312            let body = self.bytes().await?;
313            let decoded = T::decode(body)?;
314            Ok(decoded)
315        }
316    }
317}
318
#[cfg(test)]
mod tests {
    use crate::configuration::SignalServers;
    use bytes::{Buf, Bytes};

    /// Building a `PushService` must not panic for any bundled server
    /// configuration (e.g. because of an invalid certificate authority).
    #[test]
    fn create_clients() {
        let configs = &[SignalServers::Staging, SignalServers::Production];

        for cfg in configs {
            let _ =
                super::PushService::new(cfg, None, "libsignal-service test");
        }
    }

    #[test]
    fn serde_json_from_empty_reader() {
        // Deserializing JSON from an empty body fails, so empty response
        // bodies have to be handled separately before handing them to
        // serde_json.
        let bytes: Bytes = "".into();
        assert!(
            serde_json::from_reader::<bytes::buf::Reader<Bytes>, String>(
                bytes.reader()
            )
            .is_err()
        );
    }

    #[test]
    fn serde_json_form_empty_vec() {
        // When sending an empty payload, serde_json must still be able to
        // serialize it into a `Vec<u8>`.
        assert!(serde_json::to_vec(b"").is_ok());
    }
}
349        assert!(serde_json::to_vec(b"").is_ok());
350    }
351}