libsignal_service/push_service/
mod.rs1use std::time::Duration;
2
3use crate::{
4 configuration::{Endpoint, ServiceCredentials},
5 pre_keys::{KyberPreKeyEntity, PreKeyEntity, SignedPreKeyEntity},
6 prelude::ServiceConfiguration,
7 utils::serde_base64,
8 websocket::SignalWebSocket,
9};
10
11use derivative::Derivative;
12use libsignal_protocol::{
13 error::SignalProtocolError,
14 kem::{Key, Public},
15 IdentityKey, PreKeyBundle, PublicKey,
16};
17use protobuf::ProtobufResponseExt;
18use reqwest::{Method, RequestBuilder};
19use reqwest_websocket::RequestBuilderExt;
20use serde::{Deserialize, Serialize};
21use tracing::{debug_span, Instrument};
22
/// Keep-alive timing for service connections.
///
/// Note: despite the `_SECONDS` suffix this is a [`Duration`] (55 seconds),
/// not a raw number of seconds. Presumably used for websocket keep-alives —
/// confirm at the use sites.
pub const KEEPALIVE_TIMEOUT_SECONDS: Duration = Duration::from_secs(55);
/// Device id `1`; presumably the id of an account's primary device.
pub const DEFAULT_DEVICE_ID: u32 = 1;
25
26mod account;
27mod cdn;
28mod error;
29mod keys;
30mod linking;
31mod profile;
32mod registration;
33mod response;
34mod stickers;
35
36pub use account::*;
37pub use cdn::*;
38pub use error::*;
39pub use keys::*;
40pub use linking::*;
41pub use profile::*;
42pub use registration::*;
43pub(crate) use response::{ReqwestExt, SignalServiceResponse};
44
45#[derive(Debug, Serialize, Deserialize)]
46pub struct ProofRequired {
47 pub token: String,
48 pub options: Vec<String>,
49}
50
51#[derive(Derivative, Clone, Serialize, Deserialize)]
52#[derivative(Debug)]
53pub struct HttpAuth {
54 pub username: String,
55 #[derivative(Debug = "ignore")]
56 pub password: String,
57}
58
59#[derive(Debug, Clone)]
60pub enum HttpAuthOverride {
61 NoOverride,
62 Unidentified,
63 Identified(HttpAuth),
64}
65
/// What to do with an avatar on a write; presumably used when uploading a
/// profile (confirm at call sites).
///
/// `C` is the type carrying the new avatar's content.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AvatarWrite<C> {
    /// Replace the avatar with the given content.
    NewAvatar(C),
    /// Keep whatever avatar is currently set.
    RetainAvatar,
    /// Have no avatar.
    NoAvatar,
}
72
73#[derive(Debug, Deserialize)]
74#[serde(rename_all = "camelCase")]
75struct SenderCertificateJson {
76 #[serde(with = "serde_base64")]
77 certificate: Vec<u8>,
78}
79
80#[derive(Debug, Deserialize)]
81#[serde(rename_all = "camelCase")]
82pub struct PreKeyResponse {
83 #[serde(with = "serde_base64")]
84 pub identity_key: Vec<u8>,
85 pub devices: Vec<PreKeyResponseItem>,
86}
87
88#[derive(Debug, Deserialize)]
89#[serde(rename_all = "camelCase")]
90pub struct PreKeyResponseItem {
91 pub device_id: u32,
92 pub registration_id: u32,
93 pub signed_pre_key: SignedPreKeyEntity,
94 pub pre_key: Option<PreKeyEntity>,
95 pub pq_pre_key: Option<KyberPreKeyEntity>,
96}
97
98impl PreKeyResponseItem {
99 pub(crate) fn into_bundle(
100 self,
101 identity: IdentityKey,
102 ) -> Result<PreKeyBundle, SignalProtocolError> {
103 let b = PreKeyBundle::new(
104 self.registration_id,
105 self.device_id.into(),
106 self.pre_key
107 .map(|pk| -> Result<_, SignalProtocolError> {
108 Ok((
109 pk.key_id.into(),
110 PublicKey::deserialize(&pk.public_key)?,
111 ))
112 })
113 .transpose()?,
114 self.signed_pre_key.key_id.into(),
116 PublicKey::deserialize(&self.signed_pre_key.public_key)?,
117 self.signed_pre_key.signature,
118 identity,
119 )?;
120
121 if let Some(pq_pk) = self.pq_pre_key {
122 Ok(b.with_kyber_pre_key(
123 pq_pk.key_id.into(),
124 Key::<Public>::deserialize(&pq_pk.public_key)?,
125 pq_pk.signature,
126 ))
127 } else {
128 Ok(b)
129 }
130 }
131}
132
133#[derive(Debug, Deserialize)]
134#[serde(rename_all = "camelCase")]
135pub struct MismatchedDevices {
136 pub missing_devices: Vec<u32>,
137 pub extra_devices: Vec<u32>,
138}
139
140#[derive(Debug, Deserialize)]
141#[serde(rename_all = "camelCase")]
142pub struct StaleDevices {
143 pub stale_devices: Vec<u32>,
144}
145
146#[derive(Clone)]
147pub struct PushService {
148 cfg: ServiceConfiguration,
149 credentials: Option<HttpAuth>,
150 client: reqwest::Client,
151}
152
153impl PushService {
154 pub fn new(
155 cfg: impl Into<ServiceConfiguration>,
156 credentials: Option<ServiceCredentials>,
157 user_agent: impl AsRef<str>,
158 ) -> Self {
159 let cfg = cfg.into();
160 let client = reqwest::ClientBuilder::new()
161 .tls_built_in_root_certs(false)
162 .add_root_certificate(
163 reqwest::Certificate::from_pem(
164 cfg.certificate_authority.as_bytes(),
165 )
166 .unwrap(),
167 )
168 .connect_timeout(Duration::from_secs(10))
169 .timeout(Duration::from_secs(65))
170 .user_agent(user_agent.as_ref())
171 .build()
172 .unwrap();
173
174 Self {
175 cfg,
176 credentials: credentials.and_then(|c| c.authorization()),
177 client,
178 }
179 }
180
181 #[tracing::instrument(skip(self), fields(endpoint = %endpoint))]
182 pub fn request(
183 &self,
184 method: Method,
185 endpoint: Endpoint,
186 auth_override: HttpAuthOverride,
187 ) -> Result<RequestBuilder, ServiceError> {
188 let url = endpoint.into_url(&self.cfg)?;
189 let mut builder = self.client.request(method, url);
190
191 builder = match auth_override {
192 HttpAuthOverride::NoOverride => {
193 if let Some(HttpAuth { username, password }) =
194 self.credentials.as_ref()
195 {
196 builder.basic_auth(username, Some(password))
197 } else {
198 builder
199 }
200 },
201 HttpAuthOverride::Identified(HttpAuth { username, password }) => {
202 builder.basic_auth(username, Some(password))
203 },
204 HttpAuthOverride::Unidentified => builder,
205 };
206
207 Ok(builder)
208 }
209
210 pub async fn ws(
211 &mut self,
212 path: &str,
213 keepalive_path: &str,
214 additional_headers: &[(&'static str, &str)],
215 credentials: Option<ServiceCredentials>,
216 ) -> Result<SignalWebSocket, ServiceError> {
217 let span = debug_span!("websocket");
218
219 let mut url = Endpoint::service(path).into_url(&self.cfg)?;
220 url.set_scheme("wss").expect("valid https base url");
221
222 let mut builder = self.client.get(url);
223 for (key, value) in additional_headers {
224 builder = builder.header(*key, *value);
225 }
226
227 if let Some(credentials) = credentials {
228 builder =
229 builder.basic_auth(credentials.login(), credentials.password);
230 }
231
232 let ws = builder
233 .upgrade()
234 .send()
235 .await?
236 .into_websocket()
237 .instrument(span.clone())
238 .await?;
239
240 let (ws, task) =
241 SignalWebSocket::from_socket(ws, keepalive_path.to_owned());
242 let task = task.instrument(span);
243 tokio::task::spawn(task);
244 Ok(ws)
245 }
246
247 pub(crate) async fn get_group(
248 &mut self,
249 credentials: HttpAuth,
250 ) -> Result<crate::proto::Group, ServiceError> {
251 self.request(
252 Method::GET,
253 Endpoint::storage("/v1/groups/"),
254 HttpAuthOverride::Identified(credentials),
255 )?
256 .send()
257 .await?
258 .service_error_for_status()
259 .await?
260 .protobuf()
261 .await
262 }
263}
264
265pub(crate) mod protobuf {
266 use async_trait::async_trait;
267 use prost::{EncodeError, Message};
268 use reqwest::{header, RequestBuilder, Response};
269
270 use super::ServiceError;
271
272 pub(crate) trait ProtobufRequestBuilderExt
273 where
274 Self: Sized,
275 {
276 #[allow(dead_code)]
279 fn protobuf<T: Message + Default>(
280 self,
281 value: T,
282 ) -> Result<Self, EncodeError>;
283 }
284
285 #[async_trait::async_trait]
286 pub(crate) trait ProtobufResponseExt {
287 async fn protobuf<T>(self) -> Result<T, ServiceError>
289 where
290 T: prost::Message + Default;
291 }
292
293 impl ProtobufRequestBuilderExt for RequestBuilder {
294 fn protobuf<T: Message + Default>(
295 self,
296 value: T,
297 ) -> Result<Self, EncodeError> {
298 let mut buf = Vec::new();
299 value.encode(&mut buf)?;
300 let this =
301 self.header(header::CONTENT_TYPE, "application/protobuf");
302 Ok(this.body(buf))
303 }
304 }
305
306 #[async_trait]
307 impl ProtobufResponseExt for Response {
308 async fn protobuf<T>(self) -> Result<T, ServiceError>
309 where
310 T: Message + Default,
311 {
312 let body = self.bytes().await?;
313 let decoded = T::decode(body)?;
314 Ok(decoded)
315 }
316 }
317}
318
319#[cfg(test)]
320mod tests {
321 use crate::configuration::SignalServers;
322 use bytes::{Buf, Bytes};
323
324 #[test]
325 fn create_clients() {
326 let configs = &[SignalServers::Staging, SignalServers::Production];
327
328 for cfg in configs {
329 let _ =
330 super::PushService::new(cfg, None, "libsignal-service test");
331 }
332 }
333
334 #[test]
335 fn serde_json_from_empty_reader() {
336 let bytes: Bytes = "".into();
338 assert!(
339 serde_json::from_reader::<bytes::buf::Reader<Bytes>, String>(
340 bytes.reader()
341 )
342 .is_err()
343 );
344 }
345
346 #[test]
347 fn serde_json_form_empty_vec() {
348 assert!(serde_json::to_vec(b"").is_ok());
350 }
351}