libsignal_service/push_service/
mod.rs1use std::{sync::LazyLock, time::Duration};
2
3use crate::{
4 configuration::{Endpoint, ServiceCredentials},
5 pre_keys::{KyberPreKeyEntity, PreKeyEntity, SignedPreKeyEntity},
6 prelude::ServiceConfiguration,
7 utils::serde_base64,
8 websocket::SignalWebSocket,
9};
10
11use derivative::Derivative;
12use libsignal_protocol::{
13 error::SignalProtocolError,
14 kem::{Key, Public},
15 IdentityKey, PreKeyBundle, PublicKey,
16};
17use protobuf::ProtobufResponseExt;
18use reqwest::{Method, RequestBuilder};
19use reqwest_websocket::RequestBuilderExt;
20use serde::{Deserialize, Serialize};
21use serde_with::serde_as;
22use tracing::{debug_span, Instrument};
23
/// Keep-alive timeout of 55 seconds.
///
/// Despite the `_SECONDS` suffix this is a [`Duration`], not a bare count.
/// NOTE(review): presumably consumed by the WebSocket keep-alive logic —
/// confirm against callers.
pub const KEEPALIVE_TIMEOUT_SECONDS: Duration = Duration::from_secs(55);
25pub static DEFAULT_DEVICE_ID: LazyLock<libsignal_core::DeviceId> =
26 LazyLock::new(|| libsignal_core::DeviceId::try_from(1).unwrap());
27
28mod account;
29mod cdn;
30mod error;
31mod keys;
32mod linking;
33mod profile;
34mod registration;
35mod response;
36mod stickers;
37
38pub use account::*;
39pub use cdn::*;
40pub use error::*;
41pub use keys::*;
42pub use linking::*;
43pub use profile::*;
44pub use registration::*;
45pub(crate) use response::{ReqwestExt, SignalServiceResponse};
46
47#[derive(Debug, Serialize, Deserialize)]
48pub struct ProofRequired {
49 pub token: String,
50 pub options: Vec<String>,
51}
52
53#[derive(Derivative, Clone, Serialize, Deserialize)]
54#[derivative(Debug)]
55pub struct HttpAuth {
56 pub username: String,
57 #[derivative(Debug = "ignore")]
58 pub password: String,
59}
60
61#[derive(Debug, Clone)]
62pub enum HttpAuthOverride {
63 NoOverride,
64 Unidentified,
65 Identified(HttpAuth),
66}
67
/// What to do with an avatar when writing a profile.
///
/// NOTE(review): the variant semantics below are inferred from their names —
/// confirm against the profile upload code.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AvatarWrite<C> {
    /// Upload the given new avatar content.
    NewAvatar(C),
    /// Keep the currently stored avatar.
    RetainAvatar,
    /// Have no avatar.
    NoAvatar,
}
74
75#[derive(Debug, Deserialize)]
76#[serde(rename_all = "camelCase")]
77struct SenderCertificateJson {
78 #[serde(with = "serde_base64")]
79 certificate: Vec<u8>,
80}
81
82#[derive(Debug, Deserialize)]
83#[serde(rename_all = "camelCase")]
84pub struct PreKeyResponse {
85 #[serde(with = "serde_base64")]
86 pub identity_key: Vec<u8>,
87 pub devices: Vec<PreKeyResponseItem>,
88}
89
90#[derive(Debug, Deserialize)]
91#[serde(rename_all = "camelCase")]
92pub struct PreKeyResponseItem {
93 pub device_id: u32,
94 pub registration_id: u32,
95 pub signed_pre_key: SignedPreKeyEntity,
96 pub pre_key: Option<PreKeyEntity>,
97 pub pq_pre_key: KyberPreKeyEntity,
98}
99
100impl PreKeyResponseItem {
101 #[allow(clippy::result_large_err)]
102 pub(crate) fn into_bundle(
103 self,
104 identity: IdentityKey,
105 ) -> Result<PreKeyBundle, ServiceError> {
106 Ok(PreKeyBundle::new(
107 self.registration_id,
108 self.device_id.try_into()?,
109 self.pre_key
110 .map(|pk| -> Result<_, SignalProtocolError> {
111 Ok((
112 pk.key_id.into(),
113 PublicKey::deserialize(&pk.public_key)?,
114 ))
115 })
116 .transpose()?,
117 self.signed_pre_key.key_id.into(),
119 PublicKey::deserialize(&self.signed_pre_key.public_key)?,
120 self.signed_pre_key.signature,
121 self.pq_pre_key.key_id.into(),
122 Key::<Public>::deserialize(&self.pq_pre_key.public_key)?,
123 self.pq_pre_key.signature,
124 identity,
125 )?)
126 }
127}
128
129#[derive(Debug, Deserialize)]
130#[serde(rename_all = "camelCase")]
131pub struct MismatchedDevices {
132 pub missing_devices: Vec<u32>,
133 pub extra_devices: Vec<u32>,
134}
135
136#[derive(Debug, Deserialize)]
137#[serde_as]
138#[serde(rename_all = "camelCase")]
139pub struct StaleDevices {
140 pub stale_devices: Vec<u32>,
141}
142
143#[derive(Clone)]
144pub struct PushService {
145 cfg: ServiceConfiguration,
146 credentials: Option<HttpAuth>,
147 client: reqwest::Client,
148}
149
150impl PushService {
151 pub fn new(
152 cfg: impl Into<ServiceConfiguration>,
153 credentials: Option<ServiceCredentials>,
154 user_agent: impl AsRef<str>,
155 ) -> Self {
156 let cfg = cfg.into();
157 let client = reqwest::ClientBuilder::new()
158 .tls_built_in_root_certs(false)
159 .add_root_certificate(
160 reqwest::Certificate::from_pem(
161 cfg.certificate_authority.as_bytes(),
162 )
163 .unwrap(),
164 )
165 .connect_timeout(Duration::from_secs(10))
166 .timeout(Duration::from_secs(65))
167 .user_agent(user_agent.as_ref())
168 .build()
169 .unwrap();
170
171 Self {
172 cfg,
173 credentials: credentials.and_then(|c| c.authorization()),
174 client,
175 }
176 }
177
178 #[expect(clippy::result_large_err)]
179 #[tracing::instrument(skip(self), fields(endpoint = %endpoint))]
180 pub fn request(
181 &self,
182 method: Method,
183 endpoint: Endpoint,
184 auth_override: HttpAuthOverride,
185 ) -> Result<RequestBuilder, ServiceError> {
186 let url = endpoint.into_url(&self.cfg)?;
187 let mut builder = self.client.request(method, url);
188
189 builder = match auth_override {
190 HttpAuthOverride::NoOverride => {
191 if let Some(HttpAuth { username, password }) =
192 self.credentials.as_ref()
193 {
194 builder.basic_auth(username, Some(password))
195 } else {
196 builder
197 }
198 },
199 HttpAuthOverride::Identified(HttpAuth { username, password }) => {
200 builder.basic_auth(username, Some(password))
201 },
202 HttpAuthOverride::Unidentified => builder,
203 };
204
205 Ok(builder)
206 }
207
208 pub async fn ws(
209 &mut self,
210 path: &str,
211 keepalive_path: &str,
212 additional_headers: &[(&'static str, &str)],
213 credentials: Option<ServiceCredentials>,
214 ) -> Result<SignalWebSocket, ServiceError> {
215 let span = debug_span!("websocket");
216
217 let mut url = Endpoint::service(path).into_url(&self.cfg)?;
218 url.set_scheme("wss").expect("valid https base url");
219
220 let mut builder = self.client.get(url);
221 for (key, value) in additional_headers {
222 builder = builder.header(*key, *value);
223 }
224
225 if let Some(credentials) = credentials {
226 builder =
227 builder.basic_auth(credentials.login(), credentials.password);
228 }
229
230 let ws = builder
231 .upgrade()
232 .send()
233 .await?
234 .into_websocket()
235 .instrument(span.clone())
236 .await?;
237
238 let (ws, task) =
239 SignalWebSocket::from_socket(ws, keepalive_path.to_owned());
240 let task = task.instrument(span);
241 tokio::task::spawn(task);
242 Ok(ws)
243 }
244
245 pub(crate) async fn get_group(
246 &mut self,
247 credentials: HttpAuth,
248 ) -> Result<crate::proto::Group, ServiceError> {
249 self.request(
250 Method::GET,
251 Endpoint::storage("/v1/groups/"),
252 HttpAuthOverride::Identified(credentials),
253 )?
254 .send()
255 .await?
256 .service_error_for_status()
257 .await?
258 .protobuf()
259 .await
260 }
261}
262
263pub(crate) mod protobuf {
264 use async_trait::async_trait;
265 use prost::{EncodeError, Message};
266 use reqwest::{header, RequestBuilder, Response};
267
268 use super::ServiceError;
269
270 pub(crate) trait ProtobufRequestBuilderExt
271 where
272 Self: Sized,
273 {
274 #[allow(dead_code)]
277 fn protobuf<T: Message + Default>(
278 self,
279 value: T,
280 ) -> Result<Self, EncodeError>;
281 }
282
283 #[async_trait::async_trait]
284 pub(crate) trait ProtobufResponseExt {
285 async fn protobuf<T>(self) -> Result<T, ServiceError>
287 where
288 T: prost::Message + Default;
289 }
290
291 impl ProtobufRequestBuilderExt for RequestBuilder {
292 fn protobuf<T: Message + Default>(
293 self,
294 value: T,
295 ) -> Result<Self, EncodeError> {
296 let mut buf = Vec::new();
297 value.encode(&mut buf)?;
298 let this =
299 self.header(header::CONTENT_TYPE, "application/protobuf");
300 Ok(this.body(buf))
301 }
302 }
303
304 #[async_trait]
305 impl ProtobufResponseExt for Response {
306 async fn protobuf<T>(self) -> Result<T, ServiceError>
307 where
308 T: Message + Default,
309 {
310 let body = self.bytes().await?;
311 let decoded = T::decode(body)?;
312 Ok(decoded)
313 }
314 }
315}
316
#[cfg(test)]
mod tests {
    use crate::configuration::SignalServers;
    use bytes::{Buf, Bytes};

    /// Constructing a `PushService` for each bundled server configuration
    /// must not panic (e.g. while parsing the configured certificate
    /// authority).
    #[test]
    fn create_clients() {
        let configs = &[SignalServers::Staging, SignalServers::Production];

        for cfg in configs {
            let _ =
                super::PushService::new(cfg, None, "libsignal-service test");
        }
    }

    /// An empty body is not valid JSON, so deserializing a `String` from an
    /// empty reader must fail.
    #[test]
    fn serde_json_from_empty_reader() {
        let bytes: Bytes = "".into();
        let parsed =
            serde_json::from_reader::<bytes::buf::Reader<Bytes>, String>(
                bytes.reader(),
            );
        assert!(parsed.is_err());
    }

    /// Serializing an empty byte array succeeds and yields an empty JSON
    /// array.
    #[test]
    fn serde_json_to_vec_of_empty_bytes() {
        assert_eq!(serde_json::to_vec(b"").unwrap(), b"[]");
    }
}