1use std::collections::{HashMap, HashSet};
2use std::marker::PhantomData;
3use std::sync::{Arc, Mutex, MutexGuard};
4
5use std::future::Future;
6
7use bytes::Bytes;
8use futures::channel::oneshot::Canceled;
9use futures::channel::{mpsc, oneshot};
10use futures::future::BoxFuture;
11use futures::prelude::*;
12use futures::stream::FuturesUnordered;
13use reqwest::Method;
14use reqwest_websocket::WebSocket;
15use tokio::time::Instant;
16use tracing::debug;
17
18use crate::configuration::SignalServers;
19use crate::prelude::PushService;
20use crate::proto::{
21 web_socket_message, WebSocketMessage, WebSocketRequestMessage,
22 WebSocketResponseMessage,
23};
24use crate::push_service::{self, ServiceError, SignalServiceResponse};
25
26pub mod account;
27#[cfg(feature = "cdsi")]
28pub mod directory;
29pub mod keys;
30pub mod linking;
31pub mod profile;
32pub mod registration;
33mod request;
34mod sender;
35pub mod stickers;
36mod usernames;
37
38pub use request::WebSocketRequestMessageBuilder;
39
40type RequestStreamItem = (
41 WebSocketRequestMessage,
42 oneshot::Sender<WebSocketResponseMessage>,
43);
44
45pub struct SignalRequestStream {
46 inner: mpsc::UnboundedReceiver<RequestStreamItem>,
47}
48
49impl Stream for SignalRequestStream {
50 type Item = RequestStreamItem;
51
52 fn poll_next(
53 mut self: std::pin::Pin<&mut Self>,
54 cx: &mut std::task::Context<'_>,
55 ) -> std::task::Poll<Option<Self::Item>> {
56 let inner = &mut self.inner;
57 futures::pin_mut!(inner);
58 Stream::poll_next(inner, cx)
59 }
60}
61
/// Marker trait for the flavours a [`SignalWebSocket`] can be parameterised
/// with. It has no methods; it only restricts which types may be used as `C`.
pub trait WebSocketType: 'static {}

/// Marker selecting the identified flavour of [`SignalWebSocket`].
#[derive(Debug, Clone)]
pub struct Identified;

impl WebSocketType for Identified {}

/// Marker selecting the unidentified flavour of [`SignalWebSocket`].
#[derive(Debug, Clone)]
pub struct Unidentified;

impl WebSocketType for Unidentified {}
73
74#[derive(Clone)]
79pub struct SignalWebSocket<C: WebSocketType> {
80 _type: PhantomData<C>,
81 pub(crate) unidentified_push_service: PushService,
83 inner: Arc<Mutex<SignalWebSocketInner>>,
84 request_sink: mpsc::Sender<(
85 WebSocketRequestMessage,
86 oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
87 )>,
88}
89
90struct SignalWebSocketInner {
91 stream: Option<SignalRequestStream>,
92}
93
94struct SignalWebSocketProcess {
95 keep_alive_path: String,
97
98 requests: mpsc::Receiver<(
100 WebSocketRequestMessage,
101 oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
102 )>,
103 request_sink: mpsc::UnboundedSender<RequestStreamItem>,
105
106 outgoing_requests: HashMap<
107 u64,
108 oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
109 >,
110
111 outgoing_keep_alive_set: HashSet<u64>,
112
113 outgoing_responses: FuturesUnordered<
114 BoxFuture<'static, Result<WebSocketResponseMessage, Canceled>>,
115 >,
116
117 ws: WebSocket,
118}
119
120impl SignalWebSocketProcess {
121 async fn process_frame(
122 &mut self,
123 frame: Vec<u8>,
124 ) -> Result<(), ServiceError> {
125 use prost::Message;
126 let msg = WebSocketMessage::decode(Bytes::from(frame))?;
127 if let Some(request) = &msg.request {
128 tracing::trace!(
129 msg_type =? msg.r#type(),
130 request.id,
131 request.verb,
132 request.path,
133 request_body_size_bytes = request.body.as_ref().map(|x| x.len()).unwrap_or(0),
134 ?request.headers,
135 "decoded WebSocketMessage request"
136 );
137 } else if let Some(response) = &msg.response {
138 tracing::trace!(
139 msg_type =? msg.r#type(),
140 response.status,
141 response.message,
142 response_body_size_bytes = response.body.as_ref().map(|x| x.len()).unwrap_or(0),
143 ?response.headers,
144 response.id,
145 "decoded WebSocketMessage response"
146 );
147 } else {
148 tracing::debug!("decoded {msg:?}");
149 }
150
151 use web_socket_message::Type;
152 match (msg.r#type(), msg.request, msg.response) {
153 (Type::Unknown, _, _) => Err(ServiceError::InvalidFrame {
154 reason: "unknown frame type",
155 }),
156 (Type::Request, Some(request), _) => {
157 let (sink, recv) = oneshot::channel();
158 tracing::trace!("sending request with body");
159 self.request_sink.send((request, sink)).await.map_err(
160 |_| ServiceError::WsClosing {
161 reason: "request handler failed",
162 },
163 )?;
164 self.outgoing_responses.push(Box::pin(recv));
165
166 Ok(())
167 },
168 (Type::Request, None, _) => Err(ServiceError::InvalidFrame {
169 reason: "type was request, but does not contain request",
170 }),
171 (Type::Response, _, Some(response)) => {
172 if let Some(id) = response.id {
173 if let Some(responder) = self.outgoing_requests.remove(&id)
174 {
175 if let Err(e) = responder.send(Ok(response)) {
176 tracing::warn!(
177 "Could not deliver response for id {}: {:?}",
178 id,
179 e
180 );
181 }
182 } else if let Some(_x) =
183 self.outgoing_keep_alive_set.take(&id)
184 {
185 let status = reqwest::StatusCode::from_u16(
186 response.status() as _,
187 )
188 .map_err(|e| {
189 ServiceError::IO(std::io::Error::other(format!(
190 "invalid http status code {} - {e}",
191 response.status()
192 )))
193 })?;
194 if !status.is_success() {
195 tracing::warn!(
196 %status,
197 "response code for keep-alive not successful"
198 );
199 return Err(ServiceError::UnhandledResponseCode {
200 status,
201 body: String::from_utf8_lossy(response.body())
202 .into_owned(),
203 });
204 }
205 } else {
206 tracing::warn!(
207 ?response,
208 "response for non existing request"
209 );
210 }
211 }
212
213 Ok(())
214 },
215 (Type::Response, _, None) => Err(ServiceError::InvalidFrame {
216 reason: "type was response, but does not contain response",
217 }),
218 }
219 }
220
221 fn next_request_id(&self) -> u64 {
222 use rand::Rng;
223 let mut rng = rand::rng();
224 loop {
225 let id = rng.random();
226 if !self.outgoing_requests.contains_key(&id) {
227 return id;
228 }
229 }
230 }
231
232 async fn run(mut self) -> Result<(), ServiceError> {
233 let mut ka_interval = tokio::time::interval_at(
234 Instant::now(),
235 push_service::KEEPALIVE_TIMEOUT_SECONDS,
236 );
237
238 loop {
239 futures::select! {
240 _ = ka_interval.tick().fuse() => {
241 use prost::Message;
242 if !self.outgoing_keep_alive_set.is_empty() {
243 tracing::warn!("Websocket will be closed due to failed keepalives.");
244 if let Err(e) = self.ws.close(reqwest_websocket::CloseCode::Away, None).await {
245 tracing::debug!("Could not close WebSocket: {:?}", e);
246 }
247 self.outgoing_keep_alive_set.clear();
248 break;
249 }
250 tracing::debug!("sending keep-alive");
251 let request = WebSocketRequestMessage::new(Method::GET)
252 .id(self.next_request_id())
253 .path(&self.keep_alive_path)
254 .build();
255 self.outgoing_keep_alive_set.insert(request.id.unwrap());
256 let msg = WebSocketMessage {
257 r#type: Some(web_socket_message::Type::Request.into()),
258 request: Some(request),
259 ..Default::default()
260 };
261 let buffer = msg.encode_to_vec();
262 if let Err(e) = self.ws.send(reqwest_websocket::Message::Binary(buffer)).await {
263 tracing::info!("Websocket sink has closed: {:?}.", e);
264 break;
265 };
266 },
267 x = self.requests.next() => {
269 match x {
270 Some((mut request, responder)) => {
271 use prost::Message;
272
273 request.id = Some(
275 request
276 .id
277 .filter(|x| !self.outgoing_requests.contains_key(x))
278 .unwrap_or_else(|| self.next_request_id()),
279 );
280 tracing::trace!(
281 request.id,
282 request.verb,
283 request.path,
284 request_body_size_bytes = request.body.as_ref().map(|x| x.len()),
285 ?request.headers,
286 "sending WebSocketRequestMessage",
287 );
288
289 self.outgoing_requests.insert(request.id.unwrap(), responder);
290 let msg = WebSocketMessage {
291 r#type: Some(web_socket_message::Type::Request.into()),
292 request: Some(request),
293 ..Default::default()
294 };
295 let buffer = msg.encode_to_vec();
296 self.ws.send(reqwest_websocket::Message::Binary(buffer)).await?
297 }
298 None => {
299 debug!("end of application request stream; websocket closing");
300 return Ok(());
301 }
302 }
303 }
304 web_socket_item = self.ws.next().fuse() => {
306 use reqwest_websocket::Message;
307 match web_socket_item {
308 Some(Ok(Message::Close { code, reason })) => {
309 tracing::warn!(%code, reason, "websocket closed");
310 break;
311 },
312 Some(Ok(Message::Binary(frame))) => {
313 self.process_frame(frame).await?;
314 }
315 Some(Ok(Message::Ping(_))) => {
316 tracing::trace!("received ping");
317 }
318 Some(Ok(Message::Pong(_))) => {
319 tracing::trace!("received pong");
320 }
321 Some(Ok(Message::Text(_))) => {
322 tracing::trace!("received text (unsupported, skipping)");
323 }
324 Some(Err(e)) => return Err(e.into()),
325 None => {
326 return Err(ServiceError::WsClosing {
327 reason: "end of web request stream; socket closing"
328 });
329 }
330 }
331 }
332 response = self.outgoing_responses.next() => {
333 use prost::Message;
334 match response {
335 Some(Ok(response)) => {
336 tracing::trace!("sending response {:?}", response);
337
338 let msg = WebSocketMessage {
339 r#type: Some(web_socket_message::Type::Response.into()),
340 response: Some(response),
341 ..Default::default()
342 };
343 let buffer = msg.encode_to_vec();
344 self.ws.send(buffer.into()).await?;
345 }
346 Some(Err(error)) => {
347 tracing::error!(%error, "could not generate response to a Signal request; responder was canceled. continuing.");
348 }
349 None => {
350 unreachable!("outgoing responses should never fuse")
351 }
352 }
353 }
354 }
355 }
356 Ok(())
357 }
358}
359
360impl<C: WebSocketType> SignalWebSocket<C> {
361 fn inner_locked(&self) -> MutexGuard<'_, SignalWebSocketInner> {
362 self.inner.lock().unwrap()
363 }
364
365 pub fn new(
366 ws: WebSocket,
367 keep_alive_path: String,
368 unidentified_push_service: PushService,
369 ) -> (Self, impl Future<Output = ()>) {
370 let (incoming_request_sink, incoming_request_stream) =
372 mpsc::unbounded();
373 let (outgoing_request_sink, outgoing_requests) = mpsc::channel(1);
374
375 let process = SignalWebSocketProcess {
376 keep_alive_path,
377 requests: outgoing_requests,
378 request_sink: incoming_request_sink,
379 outgoing_requests: HashMap::default(),
380 outgoing_keep_alive_set: HashSet::new(),
381 outgoing_responses: vec![
384 Box::pin(futures::future::pending()) as BoxFuture<_>
385 ]
386 .into_iter()
387 .collect(),
388 ws,
389 };
390 let process = process.run().map(|x| match x {
391 Ok(()) => (),
392 Err(e) => {
393 tracing::error!("SignalWebSocket: {}", e);
394 },
395 });
396
397 (
398 Self {
399 _type: PhantomData,
400 request_sink: outgoing_request_sink,
401 unidentified_push_service,
402 inner: Arc::new(Mutex::new(SignalWebSocketInner {
403 stream: Some(SignalRequestStream {
404 inner: incoming_request_stream,
405 }),
406 })),
407 },
408 process,
409 )
410 }
411
412 pub fn servers(&self) -> SignalServers {
413 self.unidentified_push_service.servers
414 }
415
416 pub fn is_closed(&self) -> bool {
417 self.request_sink.is_closed()
418 }
419
420 pub fn is_used(&self) -> bool {
421 self.inner_locked().stream.is_none()
422 }
423
424 pub(crate) fn take_request_stream(
425 &mut self,
426 ) -> Option<SignalRequestStream> {
427 self.inner_locked().stream.take()
428 }
429
430 pub(crate) fn return_request_stream(&mut self, r: SignalRequestStream) {
431 self.inner_locked().stream.replace(r);
432 }
433
434 pub async fn with_request_stream<
437 R: 'static,
438 F: FnOnce(&mut SignalRequestStream) -> R,
439 >(
440 &mut self,
441 f: F,
442 ) -> R {
443 let mut s = self
444 .inner_locked()
445 .stream
446 .take()
447 .expect("request stream invariant");
448 let r = f(&mut s);
449 self.inner_locked().stream.replace(s);
450 r
451 }
452
453 pub fn request(
454 &mut self,
455 r: WebSocketRequestMessage,
456 ) -> impl Future<Output = Result<WebSocketResponseMessage, ServiceError>>
457 {
458 let (sink, recv): (
459 oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
460 _,
461 ) = oneshot::channel();
462
463 let mut request_sink = self.request_sink.clone();
464 async move {
465 if let Err(_e) = request_sink.send((r, sink)).await {
466 return Err(ServiceError::WsClosing {
467 reason: "WebSocket closing while sending request",
468 });
469 }
470 match recv.await {
472 Ok(x) => x,
473 Err(_) => Err(ServiceError::WsClosing {
474 reason: "WebSocket closing while waiting for a response",
475 }),
476 }
477 }
478 }
479
480 pub(crate) async fn request_json<T>(
481 &mut self,
482 r: WebSocketRequestMessage,
483 ) -> Result<T, ServiceError>
484 where
485 for<'de> T: serde::Deserialize<'de>,
486 {
487 self.request(r)
488 .await?
489 .service_error_for_status()
490 .await?
491 .json()
492 .await
493 }
494}
495
496impl WebSocketResponseMessage {
497 pub async fn service_error_for_status(self) -> Result<Self, ServiceError> {
498 super::push_service::response::service_error_for_status(self).await
499 }
500
501 pub async fn json<T: for<'a> serde::Deserialize<'a>>(
502 &self,
503 ) -> Result<T, ServiceError> {
504 self.body
505 .as_ref()
506 .ok_or(ServiceError::UnsupportedContent)
507 .and_then(|b| serde_json::from_slice(b).map_err(Into::into))
508 }
509}