libsignal_service/push_service/
mod.rs1use std::{sync::LazyLock, time::Duration};
2
3use crate::{
4 configuration::{Endpoint, ServiceCredentials},
5 prelude::ServiceConfiguration,
6 websocket::{SignalWebSocket, WebSocketType},
7};
8
9use protobuf::ProtobufResponseExt;
10use reqwest::{Method, RequestBuilder};
11use reqwest_websocket::RequestBuilderExt;
12use serde::{Deserialize, Serialize};
13use tracing::{debug_span, Instrument};
14
/// Keep-alive timeout of 55 seconds.
///
/// Despite the `_SECONDS` suffix, this is a [`Duration`], not a raw count of
/// seconds.
pub const KEEPALIVE_TIMEOUT_SECONDS: Duration = Duration::new(55, 0);
16pub static DEFAULT_DEVICE_ID: LazyLock<libsignal_core::DeviceId> =
17 LazyLock::new(|| libsignal_core::DeviceId::try_from(1).unwrap());
18
19mod account;
20mod cdn;
21mod error;
22pub mod linking;
23pub(crate) mod response;
24
25pub use account::*;
26pub use cdn::*;
27pub use error::*;
28pub(crate) use response::{ReqwestExt, SignalServiceResponse};
29
30#[derive(Debug, Serialize, Deserialize)]
31pub struct ProofRequired {
32 pub token: String,
33 pub options: Vec<String>,
34}
35
36#[derive(derive_more::Debug, Clone, Serialize, Deserialize)]
37pub struct HttpAuth {
38 pub username: String,
39 #[debug(ignore)]
40 pub password: String,
41}
42
43#[derive(Debug, Clone)]
44pub enum HttpAuthOverride {
45 NoOverride,
46 Unidentified,
47 Identified(HttpAuth),
48}
49
/// How an avatar should be written.
///
/// `C` is the content of a new avatar; its concrete type is chosen by the
/// caller.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum AvatarWrite<C> {
    /// Replace the avatar with this content.
    NewAvatar(C),
    /// Keep whatever avatar is currently set.
    RetainAvatar,
    /// Have no avatar.
    NoAvatar,
}
56
57#[derive(Debug, Deserialize)]
58#[serde(rename_all = "camelCase")]
59pub struct MismatchedDevices {
60 pub missing_devices: Vec<u32>,
61 pub extra_devices: Vec<u32>,
62}
63
64#[derive(Debug, Deserialize)]
65#[serde(rename_all = "camelCase")]
66pub struct StaleDevices {
67 pub stale_devices: Vec<u32>,
68}
69
70#[derive(Clone)]
71pub struct PushService {
72 cfg: ServiceConfiguration,
73 credentials: Option<HttpAuth>,
74 client: reqwest::Client,
75}
76
77impl PushService {
78 pub fn new(
79 cfg: impl Into<ServiceConfiguration>,
80 credentials: Option<ServiceCredentials>,
81 user_agent: impl AsRef<str>,
82 ) -> Self {
83 let cfg = cfg.into();
84 let client = reqwest::ClientBuilder::new()
85 .tls_built_in_root_certs(false)
86 .add_root_certificate(
87 reqwest::Certificate::from_pem(
88 cfg.certificate_authority.as_bytes(),
89 )
90 .unwrap(),
91 )
92 .connect_timeout(Duration::from_secs(10))
93 .timeout(Duration::from_secs(65))
94 .user_agent(user_agent.as_ref())
95 .build()
96 .unwrap();
97
98 Self {
99 cfg,
100 credentials: credentials.and_then(|c| c.authorization()),
101 client,
102 }
103 }
104
105 #[tracing::instrument(skip(self), fields(endpoint = %endpoint))]
106 pub fn request(
107 &self,
108 method: Method,
109 endpoint: Endpoint,
110 auth_override: HttpAuthOverride,
111 ) -> Result<RequestBuilder, ServiceError> {
112 let url = endpoint.into_url(&self.cfg)?;
113 let mut builder = self.client.request(method, url);
114
115 builder = match auth_override {
116 HttpAuthOverride::NoOverride => {
117 if let Some(HttpAuth { username, password }) =
118 self.credentials.as_ref()
119 {
120 builder.basic_auth(username, Some(password))
121 } else {
122 builder
123 }
124 },
125 HttpAuthOverride::Identified(HttpAuth { username, password }) => {
126 builder.basic_auth(username, Some(password))
127 },
128 HttpAuthOverride::Unidentified => builder,
129 };
130
131 Ok(builder)
132 }
133
134 pub async fn ws<C: WebSocketType>(
135 &mut self,
136 path: &str,
137 keepalive_path: &str,
138 additional_headers: &[(&'static str, &str)],
139 credentials: Option<ServiceCredentials>,
140 ) -> Result<SignalWebSocket<C>, ServiceError> {
141 let span = debug_span!("websocket");
142
143 let mut url = Endpoint::service(path).into_url(&self.cfg)?;
144 url.set_scheme("wss").expect("valid https base url");
145
146 let mut builder = self.client.get(url);
147 for (key, value) in additional_headers {
148 builder = builder.header(*key, *value);
149 }
150
151 if let Some(credentials) = credentials {
152 builder =
153 builder.basic_auth(credentials.login(), credentials.password);
154 }
155
156 let ws = builder
157 .upgrade()
158 .send()
159 .await?
160 .into_websocket()
161 .instrument(span.clone())
162 .await?;
163
164 let unidentified_push_service = PushService {
165 cfg: self.cfg.clone(),
166 credentials: None,
167 client: self.client.clone(),
168 };
169 let (ws, task) = SignalWebSocket::new(
170 ws,
171 keepalive_path.to_owned(),
172 unidentified_push_service,
173 );
174 let task = task.instrument(span);
175 tokio::task::spawn(task);
176 Ok(ws)
177 }
178
179 pub(crate) async fn get_group(
180 &mut self,
181 credentials: HttpAuth,
182 ) -> Result<crate::proto::Group, ServiceError> {
183 self.request(
184 Method::GET,
185 Endpoint::storage("/v1/groups/"),
186 HttpAuthOverride::Identified(credentials),
187 )?
188 .send()
189 .await?
190 .service_error_for_status()
191 .await?
192 .protobuf()
193 .await
194 }
195}
196
197pub(crate) mod protobuf {
198 use async_trait::async_trait;
199 use prost::{EncodeError, Message};
200 use reqwest::{header, RequestBuilder, Response};
201
202 use super::ServiceError;
203
204 pub(crate) trait ProtobufRequestBuilderExt
205 where
206 Self: Sized,
207 {
208 #[allow(dead_code)]
211 fn protobuf<T: Message + Default>(
212 self,
213 value: T,
214 ) -> Result<Self, EncodeError>;
215 }
216
217 #[async_trait::async_trait]
218 pub(crate) trait ProtobufResponseExt {
219 async fn protobuf<T>(self) -> Result<T, ServiceError>
221 where
222 T: prost::Message + Default;
223 }
224
225 impl ProtobufRequestBuilderExt for RequestBuilder {
226 fn protobuf<T: Message + Default>(
227 self,
228 value: T,
229 ) -> Result<Self, EncodeError> {
230 let mut buf = Vec::new();
231 value.encode(&mut buf)?;
232 let this =
233 self.header(header::CONTENT_TYPE, "application/protobuf");
234 Ok(this.body(buf))
235 }
236 }
237
238 #[async_trait]
239 impl ProtobufResponseExt for Response {
240 async fn protobuf<T>(self) -> Result<T, ServiceError>
241 where
242 T: Message + Default,
243 {
244 let body = self.bytes().await?;
245 let decoded = T::decode(body)?;
246 Ok(decoded)
247 }
248 }
249}
250
#[cfg(test)]
mod tests {
    use crate::configuration::SignalServers;
    use bytes::{Buf, Bytes};

    /// Building a client for each bundled server set must not panic, which
    /// means each configured certificate authority parses as PEM.
    #[test]
    fn create_clients() {
        let configs = &[SignalServers::Staging, SignalServers::Production];

        for cfg in configs {
            let _ =
                super::PushService::new(cfg, None, "libsignal-service test");
        }
    }

    /// Deserializing from an empty reader is an error in `serde_json`, not a
    /// default value.
    #[test]
    fn serde_json_from_empty_reader() {
        let bytes: Bytes = "".into();
        assert!(
            serde_json::from_reader::<bytes::buf::Reader<Bytes>, String>(
                bytes.reader()
            )
            .is_err()
        );
    }

    /// Serializing an empty byte string to JSON succeeds.
    #[test]
    fn serde_json_to_vec_empty() {
        assert!(serde_json::to_vec(b"").is_ok());
    }
}