libsignal_service/push_service/
mod.rs1use std::{sync::LazyLock, time::Duration};
2
3use crate::{
4 configuration::{Endpoint, ServiceCredentials},
5 pre_keys::{KyberPreKeyEntity, PreKeyEntity, SignedPreKeyEntity},
6 prelude::ServiceConfiguration,
7 utils::serde_base64,
8 websocket::SignalWebSocket,
9};
10
11use derivative::Derivative;
12use libsignal_protocol::{
13 error::SignalProtocolError,
14 kem::{Key, Public},
15 IdentityKey, PreKeyBundle, PublicKey,
16};
17use protobuf::ProtobufResponseExt;
18use reqwest::{Method, RequestBuilder};
19use reqwest_websocket::RequestBuilderExt;
20use serde::{Deserialize, Serialize};
21use tracing::{debug_span, Instrument};
22
23pub const KEEPALIVE_TIMEOUT_SECONDS: Duration = Duration::from_secs(55);
24pub static DEFAULT_DEVICE_ID: LazyLock<libsignal_core::DeviceId> =
25 LazyLock::new(|| libsignal_core::DeviceId::try_from(1).unwrap());
26
27mod account;
28mod cdn;
29mod error;
30mod keys;
31mod linking;
32mod profile;
33mod registration;
34mod response;
35mod stickers;
36
37pub use account::*;
38pub use cdn::*;
39pub use error::*;
40pub use keys::*;
41pub use linking::*;
42pub use profile::*;
43pub use registration::*;
44pub(crate) use response::{ReqwestExt, SignalServiceResponse};
45
46#[derive(Debug, Serialize, Deserialize)]
47pub struct ProofRequired {
48 pub token: String,
49 pub options: Vec<String>,
50}
51
52#[derive(Derivative, Clone, Serialize, Deserialize)]
53#[derivative(Debug)]
54pub struct HttpAuth {
55 pub username: String,
56 #[derivative(Debug = "ignore")]
57 pub password: String,
58}
59
60#[derive(Debug, Clone)]
61pub enum HttpAuthOverride {
62 NoOverride,
63 Unidentified,
64 Identified(HttpAuth),
65}
66
/// What to do with an avatar when writing it.
///
/// NOTE(review): semantics are inferred from the variant names — `NewAvatar`
/// carries new avatar content of type `C`, `RetainAvatar` keeps the current
/// one, and `NoAvatar` clears it; confirm against the callers.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AvatarWrite<C> {
    /// Replace the avatar with the given content.
    NewAvatar(C),
    /// Keep the existing avatar.
    RetainAvatar,
    /// Have no avatar.
    NoAvatar,
}
73
74#[derive(Debug, Deserialize)]
75#[serde(rename_all = "camelCase")]
76struct SenderCertificateJson {
77 #[serde(with = "serde_base64")]
78 certificate: Vec<u8>,
79}
80
81#[derive(Debug, Deserialize)]
82#[serde(rename_all = "camelCase")]
83pub struct PreKeyResponse {
84 #[serde(with = "serde_base64")]
85 pub identity_key: Vec<u8>,
86 pub devices: Vec<PreKeyResponseItem>,
87}
88
89#[derive(Debug, Deserialize)]
90#[serde(rename_all = "camelCase")]
91pub struct PreKeyResponseItem {
92 pub device_id: u32,
93 pub registration_id: u32,
94 pub signed_pre_key: SignedPreKeyEntity,
95 pub pre_key: Option<PreKeyEntity>,
96 pub pq_pre_key: KyberPreKeyEntity,
97}
98
99impl PreKeyResponseItem {
100 #[allow(clippy::result_large_err)]
101 pub(crate) fn into_bundle(
102 self,
103 identity: IdentityKey,
104 ) -> Result<PreKeyBundle, ServiceError> {
105 Ok(PreKeyBundle::new(
106 self.registration_id,
107 self.device_id.try_into()?,
108 self.pre_key
109 .map(|pk| -> Result<_, SignalProtocolError> {
110 Ok((
111 pk.key_id.into(),
112 PublicKey::deserialize(&pk.public_key)?,
113 ))
114 })
115 .transpose()?,
116 self.signed_pre_key.key_id.into(),
118 PublicKey::deserialize(&self.signed_pre_key.public_key)?,
119 self.signed_pre_key.signature,
120 self.pq_pre_key.key_id.into(),
121 Key::<Public>::deserialize(&self.pq_pre_key.public_key)?,
122 self.pq_pre_key.signature,
123 identity,
124 )?)
125 }
126}
127
128#[derive(Debug, Deserialize)]
129#[serde(rename_all = "camelCase")]
130pub struct MismatchedDevices {
131 pub missing_devices: Vec<u32>,
132 pub extra_devices: Vec<u32>,
133}
134
135#[derive(Debug, Deserialize)]
136#[serde(rename_all = "camelCase")]
137pub struct StaleDevices {
138 pub stale_devices: Vec<u32>,
139}
140
141#[derive(Clone)]
142pub struct PushService {
143 cfg: ServiceConfiguration,
144 credentials: Option<HttpAuth>,
145 client: reqwest::Client,
146}
147
148impl PushService {
149 pub fn new(
150 cfg: impl Into<ServiceConfiguration>,
151 credentials: Option<ServiceCredentials>,
152 user_agent: impl AsRef<str>,
153 ) -> Self {
154 let cfg = cfg.into();
155 let client = reqwest::ClientBuilder::new()
156 .tls_built_in_root_certs(false)
157 .add_root_certificate(
158 reqwest::Certificate::from_pem(
159 cfg.certificate_authority.as_bytes(),
160 )
161 .unwrap(),
162 )
163 .connect_timeout(Duration::from_secs(10))
164 .timeout(Duration::from_secs(65))
165 .user_agent(user_agent.as_ref())
166 .build()
167 .unwrap();
168
169 Self {
170 cfg,
171 credentials: credentials.and_then(|c| c.authorization()),
172 client,
173 }
174 }
175
176 #[expect(clippy::result_large_err)]
177 #[tracing::instrument(skip(self), fields(endpoint = %endpoint))]
178 pub fn request(
179 &self,
180 method: Method,
181 endpoint: Endpoint,
182 auth_override: HttpAuthOverride,
183 ) -> Result<RequestBuilder, ServiceError> {
184 let url = endpoint.into_url(&self.cfg)?;
185 let mut builder = self.client.request(method, url);
186
187 builder = match auth_override {
188 HttpAuthOverride::NoOverride => {
189 if let Some(HttpAuth { username, password }) =
190 self.credentials.as_ref()
191 {
192 builder.basic_auth(username, Some(password))
193 } else {
194 builder
195 }
196 },
197 HttpAuthOverride::Identified(HttpAuth { username, password }) => {
198 builder.basic_auth(username, Some(password))
199 },
200 HttpAuthOverride::Unidentified => builder,
201 };
202
203 Ok(builder)
204 }
205
206 pub async fn ws(
207 &mut self,
208 path: &str,
209 keepalive_path: &str,
210 additional_headers: &[(&'static str, &str)],
211 credentials: Option<ServiceCredentials>,
212 ) -> Result<SignalWebSocket, ServiceError> {
213 let span = debug_span!("websocket");
214
215 let mut url = Endpoint::service(path).into_url(&self.cfg)?;
216 url.set_scheme("wss").expect("valid https base url");
217
218 let mut builder = self.client.get(url);
219 for (key, value) in additional_headers {
220 builder = builder.header(*key, *value);
221 }
222
223 if let Some(credentials) = credentials {
224 builder =
225 builder.basic_auth(credentials.login(), credentials.password);
226 }
227
228 let ws = builder
229 .upgrade()
230 .send()
231 .await?
232 .into_websocket()
233 .instrument(span.clone())
234 .await?;
235
236 let (ws, task) =
237 SignalWebSocket::from_socket(ws, keepalive_path.to_owned());
238 let task = task.instrument(span);
239 tokio::task::spawn(task);
240 Ok(ws)
241 }
242
243 pub(crate) async fn get_group(
244 &mut self,
245 credentials: HttpAuth,
246 ) -> Result<crate::proto::Group, ServiceError> {
247 self.request(
248 Method::GET,
249 Endpoint::storage("/v1/groups/"),
250 HttpAuthOverride::Identified(credentials),
251 )?
252 .send()
253 .await?
254 .service_error_for_status()
255 .await?
256 .protobuf()
257 .await
258 }
259}
260
261pub(crate) mod protobuf {
262 use async_trait::async_trait;
263 use prost::{EncodeError, Message};
264 use reqwest::{header, RequestBuilder, Response};
265
266 use super::ServiceError;
267
268 pub(crate) trait ProtobufRequestBuilderExt
269 where
270 Self: Sized,
271 {
272 #[allow(dead_code)]
275 fn protobuf<T: Message + Default>(
276 self,
277 value: T,
278 ) -> Result<Self, EncodeError>;
279 }
280
281 #[async_trait::async_trait]
282 pub(crate) trait ProtobufResponseExt {
283 async fn protobuf<T>(self) -> Result<T, ServiceError>
285 where
286 T: prost::Message + Default;
287 }
288
289 impl ProtobufRequestBuilderExt for RequestBuilder {
290 fn protobuf<T: Message + Default>(
291 self,
292 value: T,
293 ) -> Result<Self, EncodeError> {
294 let mut buf = Vec::new();
295 value.encode(&mut buf)?;
296 let this =
297 self.header(header::CONTENT_TYPE, "application/protobuf");
298 Ok(this.body(buf))
299 }
300 }
301
302 #[async_trait]
303 impl ProtobufResponseExt for Response {
304 async fn protobuf<T>(self) -> Result<T, ServiceError>
305 where
306 T: Message + Default,
307 {
308 let body = self.bytes().await?;
309 let decoded = T::decode(body)?;
310 Ok(decoded)
311 }
312 }
313}
314
#[cfg(test)]
mod tests {
    use crate::configuration::SignalServers;
    use bytes::{Buf, Bytes};

    /// Building a client for every known server configuration must not
    /// panic (it parses the configured CA certificate).
    #[test]
    fn create_clients() {
        let configs = &[SignalServers::Staging, SignalServers::Production];

        for cfg in configs {
            let _ =
                super::PushService::new(cfg, None, "libsignal-service test");
        }
    }

    /// An empty body is not valid JSON, so deserializing from it must fail.
    #[test]
    fn serde_json_from_empty_reader() {
        let bytes: Bytes = "".into();
        assert!(
            serde_json::from_reader::<bytes::buf::Reader<Bytes>, String>(
                bytes.reader()
            )
            .is_err()
        );
    }

    /// Serializing an empty byte array succeeds and yields an empty JSON
    /// array.
    #[test]
    fn serde_json_to_vec_empty_slice() {
        let encoded = serde_json::to_vec(b"").expect("serializable");
        assert_eq!(encoded, b"[]");
    }
}