libsignal_service/push_service/
mod.rs1use std::{sync::LazyLock, time::Duration};
2
3use crate::{
4 configuration::{Endpoint, ServiceCredentials},
5 prelude::ServiceConfiguration,
6 utils::serde_device_id_vec,
7 websocket::{SignalWebSocket, WebSocketType},
8};
9
10use libsignal_core::DeviceId;
11use protobuf::ProtobufResponseExt;
12use reqwest::{Method, RequestBuilder};
13use reqwest_websocket::RequestBuilderExt;
14use serde::{Deserialize, Serialize};
15use tracing::{debug_span, Instrument};
16
17pub const KEEPALIVE_TIMEOUT_SECONDS: Duration = Duration::from_secs(55);
18pub static DEFAULT_DEVICE_ID: LazyLock<libsignal_core::DeviceId> =
19 LazyLock::new(|| libsignal_core::DeviceId::try_from(1).unwrap());
20
21mod account;
22mod cdn;
23mod error;
24pub mod linking;
25pub(crate) mod response;
26
27pub use account::*;
28pub use cdn::*;
29pub use error::*;
30pub(crate) use response::{ReqwestExt, SignalServiceResponse};
31
/// A token together with a list of option strings, (de)serialized with serde.
///
/// NOTE(review): presumably the payload of a "proof required" response from
/// the server — confirm the meaning of `token` and `options` against callers.
#[derive(Debug, Serialize, Deserialize)]
pub struct ProofRequired {
    pub token: String,
    pub options: Vec<String>,
}
37
/// HTTP basic-auth credentials (username and password).
///
/// `Debug` is derived via `derive_more` so that the password can be excluded
/// from debug output (e.g. logs and tracing fields).
#[derive(derive_more::Debug, Clone, Serialize, Deserialize)]
pub struct HttpAuth {
    pub username: String,
    // Never printed by `Debug`, to avoid leaking the secret.
    #[debug(ignore)]
    pub password: String,
}
44
/// Selects which authentication `PushService::request` attaches to a request.
#[derive(Debug, Clone)]
pub enum HttpAuthOverride {
    /// Use the service's own credentials, if it has any; otherwise send none.
    NoOverride,
    /// Send the request without any basic-auth header.
    Unidentified,
    /// Authenticate with the given credentials instead of the service's own.
    Identified(HttpAuth),
}
51
/// Three-way choice describing how an avatar should be written, generic over
/// the type `C` carrying new avatar content.
///
/// NOTE(review): variant meanings are inferred from their names — presumably
/// `NewAvatar` supplies new content, `RetainAvatar` keeps the existing avatar
/// and `NoAvatar` removes it; confirm against callers.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum AvatarWrite<C> {
    NewAvatar(C),
    RetainAvatar,
    NoAvatar,
}
58
/// Two device-id lists deserialized from a camelCase JSON object
/// (`missingDevices`, `extraDevices`).
///
/// NOTE(review): presumably the body of a server response reporting that the
/// set of recipient devices did not match — confirm against callers.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MismatchedDevices {
    // Device ids are decoded by the crate's `serde_device_id_vec` helper.
    #[serde(with = "serde_device_id_vec")]
    pub missing_devices: Vec<DeviceId>,
    #[serde(with = "serde_device_id_vec")]
    pub extra_devices: Vec<DeviceId>,
}
67
/// A device-id list deserialized from a camelCase JSON object
/// (`staleDevices`).
///
/// NOTE(review): presumably the body of a server response listing devices
/// whose state is stale — confirm against callers.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StaleDevices {
    // Device ids are decoded by the crate's `serde_device_id_vec` helper.
    #[serde(with = "serde_device_id_vec")]
    pub stale_devices: Vec<DeviceId>,
}
74
/// HTTP client for the service endpoints described by a
/// `ServiceConfiguration`.
#[derive(Clone)]
pub struct PushService {
    /// Configuration used to resolve `Endpoint`s into URLs and to supply the
    /// trusted certificate authority.
    cfg: ServiceConfiguration,
    /// Default basic-auth credentials, applied by `request` when called with
    /// `HttpAuthOverride::NoOverride`; `None` means no default authentication.
    credentials: Option<HttpAuth>,
    /// `reqwest` client configured in `PushService::new`.
    client: reqwest::Client,
}
81
82impl PushService {
83 pub fn new(
84 cfg: impl Into<ServiceConfiguration>,
85 credentials: Option<ServiceCredentials>,
86 user_agent: impl AsRef<str>,
87 ) -> Self {
88 let cfg = cfg.into();
89 let client = reqwest::ClientBuilder::new()
90 .tls_built_in_root_certs(false)
91 .add_root_certificate(
92 reqwest::Certificate::from_pem(
93 cfg.certificate_authority.as_bytes(),
94 )
95 .unwrap(),
96 )
97 .connect_timeout(Duration::from_secs(10))
98 .timeout(Duration::from_secs(65))
99 .user_agent(user_agent.as_ref())
100 .http1_only()
101 .build()
102 .unwrap();
103
104 Self {
105 cfg,
106 credentials: credentials.and_then(|c| c.authorization()),
107 client,
108 }
109 }
110
111 #[tracing::instrument(skip(self), fields(endpoint = %endpoint))]
112 pub fn request(
113 &self,
114 method: Method,
115 endpoint: Endpoint,
116 auth_override: HttpAuthOverride,
117 ) -> Result<RequestBuilder, ServiceError> {
118 let url = endpoint.into_url(&self.cfg)?;
119 let mut builder = self.client.request(method, url);
120
121 builder = match auth_override {
122 HttpAuthOverride::NoOverride => {
123 if let Some(HttpAuth { username, password }) =
124 self.credentials.as_ref()
125 {
126 builder.basic_auth(username, Some(password))
127 } else {
128 builder
129 }
130 },
131 HttpAuthOverride::Identified(HttpAuth { username, password }) => {
132 builder.basic_auth(username, Some(password))
133 },
134 HttpAuthOverride::Unidentified => builder,
135 };
136
137 Ok(builder)
138 }
139
140 pub async fn ws<C: WebSocketType>(
141 &mut self,
142 path: &str,
143 keepalive_path: &str,
144 additional_headers: &[(&'static str, &str)],
145 credentials: Option<ServiceCredentials>,
146 ) -> Result<SignalWebSocket<C>, ServiceError> {
147 let span = debug_span!("websocket");
148
149 let mut url = Endpoint::service(path).into_url(&self.cfg)?;
150 url.set_scheme("wss").expect("valid https base url");
151
152 let mut builder = self.client.get(url);
153 for (key, value) in additional_headers {
154 builder = builder.header(*key, *value);
155 }
156
157 if let Some(credentials) = credentials {
158 builder =
159 builder.basic_auth(credentials.login(), credentials.password);
160 }
161
162 let ws = builder
163 .upgrade()
164 .send()
165 .await?
166 .into_websocket()
167 .instrument(span.clone())
168 .await?;
169
170 let unidentified_push_service = PushService {
171 cfg: self.cfg.clone(),
172 credentials: None,
173 client: self.client.clone(),
174 };
175 let (ws, task) = SignalWebSocket::new(
176 ws,
177 keepalive_path.to_owned(),
178 unidentified_push_service,
179 );
180 let task = task.instrument(span);
181 tokio::task::spawn(task);
182 Ok(ws)
183 }
184
185 pub(crate) async fn get_group(
186 &mut self,
187 credentials: HttpAuth,
188 ) -> Result<crate::proto::Group, ServiceError> {
189 self.request(
190 Method::GET,
191 Endpoint::storage("/v1/groups/"),
192 HttpAuthOverride::Identified(credentials),
193 )?
194 .send()
195 .await?
196 .service_error_for_status()
197 .await?
198 .protobuf()
199 .await
200 }
201}
202
203pub(crate) mod protobuf {
204 use async_trait::async_trait;
205 use prost::{EncodeError, Message};
206 use reqwest::{header, RequestBuilder, Response};
207
208 use super::ServiceError;
209
210 pub(crate) trait ProtobufRequestBuilderExt
211 where
212 Self: Sized,
213 {
214 #[allow(dead_code)]
217 fn protobuf<T: Message + Default>(
218 self,
219 value: T,
220 ) -> Result<Self, EncodeError>;
221 }
222
223 #[async_trait::async_trait]
224 pub(crate) trait ProtobufResponseExt {
225 async fn protobuf<T>(self) -> Result<T, ServiceError>
227 where
228 T: prost::Message + Default;
229 }
230
231 impl ProtobufRequestBuilderExt for RequestBuilder {
232 fn protobuf<T: Message + Default>(
233 self,
234 value: T,
235 ) -> Result<Self, EncodeError> {
236 let mut buf = Vec::new();
237 value.encode(&mut buf)?;
238 let this =
239 self.header(header::CONTENT_TYPE, "application/protobuf");
240 Ok(this.body(buf))
241 }
242 }
243
244 #[async_trait]
245 impl ProtobufResponseExt for Response {
246 async fn protobuf<T>(self) -> Result<T, ServiceError>
247 where
248 T: Message + Default,
249 {
250 let body = self.bytes().await?;
251 let decoded = T::decode(body)?;
252 Ok(decoded)
253 }
254 }
255}
256
#[cfg(test)]
mod tests {
    use crate::configuration::SignalServers;
    use bytes::{Buf, Bytes};

    /// Building a `PushService` for each built-in server configuration must
    /// not panic (it would on an invalid bundled CA certificate).
    #[test]
    fn create_clients() {
        for cfg in &[SignalServers::Staging, SignalServers::Production] {
            let _ =
                super::PushService::new(cfg, None, "libsignal-service test");
        }
    }

    /// Deserializing JSON from an empty body is an error, not a default value.
    #[test]
    fn serde_json_from_empty_reader() {
        let bytes: Bytes = "".into();
        assert!(
            serde_json::from_reader::<bytes::buf::Reader<Bytes>, String>(
                bytes.reader()
            )
            .is_err()
        );
    }

    /// Serializing an empty byte array yields an empty JSON array.
    #[test]
    fn serde_json_to_vec_of_empty_bytes() {
        assert_eq!(serde_json::to_vec(b"").unwrap(), b"[]");
    }
}