// libsignal_service/push_service/mod.rs
use std::{sync::LazyLock, time::Duration};

use crate::{
    configuration::{Endpoint, ServiceCredentials, SignalServers},
    prelude::ServiceConfiguration,
    utils::serde_device_id_vec,
    websocket::{SignalWebSocket, WebSocketType},
};

use libsignal_core::DeviceId;
use protobuf::ProtobufResponseExt;
use reqwest::{Method, RequestBuilder};
use reqwest_websocket::RequestBuilderExt;
use serde::{Deserialize, Serialize};
use tracing::{debug_span, Instrument};
16
17pub const KEEPALIVE_TIMEOUT_SECONDS: Duration = Duration::from_secs(55);
18pub static DEFAULT_DEVICE_ID: LazyLock<libsignal_core::DeviceId> =
19 LazyLock::new(|| libsignal_core::DeviceId::try_from(1).unwrap());
20
21mod account;
22mod cdn;
23mod error;
24pub mod linking;
25pub(crate) mod response;
26
27pub use account::*;
28pub use cdn::*;
29pub use error::*;
30pub(crate) use response::{ReqwestExt, SignalServiceResponse};
31
32#[derive(Debug, Serialize, Deserialize)]
33pub struct ProofRequired {
34 pub token: String,
35 pub options: Vec<String>,
36}
37
38#[derive(derive_more::Debug, Clone, Serialize, Deserialize)]
39pub struct HttpAuth {
40 pub username: String,
41 #[debug(ignore)]
42 pub password: String,
43}
44
45#[derive(Debug, Clone)]
46pub enum HttpAuthOverride {
47 NoOverride,
48 Unidentified,
49 Identified(HttpAuth),
50}
51
/// Avatar instruction for a write operation, generic over the avatar content
/// type `Content`.
///
/// NOTE(review): variant meanings are inferred from their names (upload new
/// content / keep the current avatar / have no avatar) — confirm with callers.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum AvatarWrite<Content> {
    NewAvatar(Content),
    RetainAvatar,
    NoAvatar,
}
58
59#[derive(Debug, Deserialize)]
60#[serde(rename_all = "camelCase")]
61pub struct MismatchedDevices {
62 #[serde(with = "serde_device_id_vec")]
63 pub missing_devices: Vec<DeviceId>,
64 #[serde(with = "serde_device_id_vec")]
65 pub extra_devices: Vec<DeviceId>,
66}
67
68#[derive(Debug, Deserialize)]
69#[serde(rename_all = "camelCase")]
70pub struct StaleDevices {
71 #[serde(with = "serde_device_id_vec")]
72 pub stale_devices: Vec<DeviceId>,
73}
74
75#[derive(Clone)]
76pub struct PushService {
77 pub(crate) servers: SignalServers,
78 cfg: ServiceConfiguration,
79 credentials: Option<HttpAuth>,
80 client: reqwest::Client,
81}
82
83impl PushService {
84 pub fn new(
85 env: SignalServers,
86 credentials: Option<ServiceCredentials>,
87 user_agent: impl AsRef<str>,
88 ) -> Self {
89 let cfg: ServiceConfiguration = env.into();
90 let client = reqwest::ClientBuilder::new()
91 .tls_built_in_root_certs(false)
92 .add_root_certificate(
93 reqwest::Certificate::from_pem(
94 cfg.certificate_authority.as_bytes(),
95 )
96 .unwrap(),
97 )
98 .connect_timeout(Duration::from_secs(10))
99 .timeout(Duration::from_secs(65))
100 .user_agent(user_agent.as_ref())
101 .http1_only()
102 .build()
103 .unwrap();
104
105 Self {
106 servers: env,
107 cfg,
108 credentials: credentials.and_then(|c| c.authorization()),
109 client,
110 }
111 }
112
113 #[tracing::instrument(skip(self), fields(endpoint = %endpoint))]
114 pub fn request(
115 &self,
116 method: Method,
117 endpoint: Endpoint,
118 auth_override: HttpAuthOverride,
119 ) -> Result<RequestBuilder, ServiceError> {
120 let url = endpoint.into_url(&self.cfg)?;
121 let mut builder = self.client.request(method, url);
122
123 builder = match auth_override {
124 HttpAuthOverride::NoOverride => {
125 if let Some(HttpAuth { username, password }) =
126 self.credentials.as_ref()
127 {
128 builder.basic_auth(username, Some(password))
129 } else {
130 builder
131 }
132 },
133 HttpAuthOverride::Identified(HttpAuth { username, password }) => {
134 builder.basic_auth(username, Some(password))
135 },
136 HttpAuthOverride::Unidentified => builder,
137 };
138
139 Ok(builder)
140 }
141
142 pub async fn ws<C: WebSocketType>(
143 &mut self,
144 path: &str,
145 keepalive_path: &str,
146 additional_headers: &[(&'static str, &str)],
147 credentials: Option<ServiceCredentials>,
148 ) -> Result<SignalWebSocket<C>, ServiceError> {
149 let span = debug_span!("websocket");
150
151 let mut url = Endpoint::service(path).into_url(&self.cfg)?;
152 url.set_scheme("wss").expect("valid https base url");
153
154 let mut builder = self.client.get(url);
155 for (key, value) in additional_headers {
156 builder = builder.header(*key, *value);
157 }
158
159 if let Some(credentials) = credentials {
160 builder =
161 builder.basic_auth(credentials.login(), credentials.password);
162 }
163
164 let ws = builder
165 .upgrade()
166 .send()
167 .await?
168 .into_websocket()
169 .instrument(span.clone())
170 .await?;
171
172 let unidentified_push_service = PushService {
173 servers: self.servers,
174 cfg: self.cfg.clone(),
175 credentials: None,
176 client: self.client.clone(),
177 };
178 let (ws, task) = SignalWebSocket::new(
179 ws,
180 keepalive_path.to_owned(),
181 unidentified_push_service,
182 );
183 let task = task.instrument(span);
184 tokio::task::spawn(task);
185 Ok(ws)
186 }
187
188 pub(crate) async fn get_group(
189 &mut self,
190 credentials: HttpAuth,
191 ) -> Result<crate::proto::Group, ServiceError> {
192 self.request(
193 Method::GET,
194 Endpoint::storage("/v1/groups/"),
195 HttpAuthOverride::Identified(credentials),
196 )?
197 .send()
198 .await?
199 .service_error_for_status()
200 .await?
201 .protobuf()
202 .await
203 }
204}
205
206pub(crate) mod protobuf {
207 use async_trait::async_trait;
208 use prost::{EncodeError, Message};
209 use reqwest::{header, RequestBuilder, Response};
210
211 use super::ServiceError;
212
213 pub(crate) trait ProtobufRequestBuilderExt
214 where
215 Self: Sized,
216 {
217 #[allow(dead_code)]
220 fn protobuf<T: Message + Default>(
221 self,
222 value: T,
223 ) -> Result<Self, EncodeError>;
224 }
225
226 #[async_trait::async_trait]
227 pub(crate) trait ProtobufResponseExt {
228 async fn protobuf<T>(self) -> Result<T, ServiceError>
230 where
231 T: prost::Message + Default;
232 }
233
234 impl ProtobufRequestBuilderExt for RequestBuilder {
235 fn protobuf<T: Message + Default>(
236 self,
237 value: T,
238 ) -> Result<Self, EncodeError> {
239 let mut buf = Vec::new();
240 value.encode(&mut buf)?;
241 let this =
242 self.header(header::CONTENT_TYPE, "application/x-protobuf");
243 Ok(this.body(buf))
244 }
245 }
246
247 #[async_trait]
248 impl ProtobufResponseExt for Response {
249 async fn protobuf<T>(self) -> Result<T, ServiceError>
250 where
251 T: Message + Default,
252 {
253 let body = self.bytes().await?;
254 let decoded = T::decode(body)?;
255 Ok(decoded)
256 }
257 }
258}
259
#[cfg(test)]
mod tests {
    use crate::configuration::SignalServers;
    use bytes::{Buf, Bytes};

    /// Building a client must not panic for any known environment.
    #[test]
    fn create_clients() {
        for env in [SignalServers::Staging, SignalServers::Production] {
            let _ =
                super::PushService::new(env, None, "libsignal-service test");
        }
    }

    /// An empty body is not valid JSON, even when a `String` is expected.
    #[test]
    fn serde_json_from_empty_reader() {
        let bytes: Bytes = "".into();
        assert!(
            serde_json::from_reader::<bytes::buf::Reader<Bytes>, String>(
                bytes.reader()
            )
            .is_err()
        );
    }

    /// Serializing an empty byte array yields an empty JSON array.
    #[test]
    fn serde_json_to_vec_of_empty_bytes() {
        assert_eq!(serde_json::to_vec(b"").unwrap(), b"[]");
    }
}