1use std::collections::{HashMap, HashSet};
2use std::marker::PhantomData;
3use std::sync::{Arc, Mutex, MutexGuard};
4
5use std::future::Future;
6
7use bytes::Bytes;
8use futures::channel::oneshot::Canceled;
9use futures::channel::{mpsc, oneshot};
10use futures::future::BoxFuture;
11use futures::prelude::*;
12use futures::stream::FuturesUnordered;
13use reqwest::Method;
14use reqwest_websocket::WebSocket;
15use tokio::time::Instant;
16use tracing::debug;
17
18use crate::configuration::SignalServers;
19use crate::prelude::PushService;
20use crate::proto::{
21 web_socket_message, WebSocketMessage, WebSocketRequestMessage,
22 WebSocketResponseMessage,
23};
24use crate::push_service::{self, ServiceError, SignalServiceResponse};
25
26pub mod account;
27#[cfg(feature = "cdsi")]
28pub mod directory;
29pub mod keys;
30pub mod linking;
31pub mod profile;
32pub mod registration;
33mod request;
34mod sender;
35pub mod stickers;
36mod usernames;
37
38pub use request::WebSocketRequestMessageBuilder;
39
40type RequestStreamItem = (
41 WebSocketRequestMessage,
42 oneshot::Sender<WebSocketResponseMessage>,
43);
44
45pub struct SignalRequestStream {
46 inner: mpsc::UnboundedReceiver<RequestStreamItem>,
47}
48
49impl Stream for SignalRequestStream {
50 type Item = RequestStreamItem;
51
52 fn poll_next(
53 mut self: std::pin::Pin<&mut Self>,
54 cx: &mut std::task::Context<'_>,
55 ) -> std::task::Poll<Option<Self::Item>> {
56 let inner = &mut self.inner;
57 futures::pin_mut!(inner);
58 Stream::poll_next(inner, cx)
59 }
60}
61
/// Bound shared by the marker types used as the type parameter of
/// `SignalWebSocket`; it has no methods and only restricts which markers
/// are accepted.
pub trait WebSocketType: 'static {}

/// Marker for the identified flavour of websocket (presumably the
/// authenticated connection — confirm at the construction sites).
#[derive(Debug, Clone)]
pub struct Identified;

impl WebSocketType for Identified {}

/// Marker for the unidentified flavour of websocket (presumably the
/// unauthenticated connection — confirm at the construction sites).
#[derive(Debug, Clone)]
pub struct Unidentified;

impl WebSocketType for Unidentified {}
73
74#[derive(Clone)]
79pub struct SignalWebSocket<C: WebSocketType> {
80 _type: PhantomData<C>,
81 pub(crate) unidentified_push_service: PushService,
83 inner: Arc<Mutex<SignalWebSocketInner>>,
84 request_sink: mpsc::Sender<(
85 WebSocketRequestMessage,
86 oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
87 )>,
88}
89
90struct SignalWebSocketInner {
91 stream: Option<SignalRequestStream>,
92}
93
94struct SignalWebSocketProcess {
95 keep_alive_path: String,
97
98 requests: mpsc::Receiver<(
100 WebSocketRequestMessage,
101 oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
102 )>,
103 request_sink: mpsc::UnboundedSender<RequestStreamItem>,
105
106 outgoing_requests: HashMap<
107 u64,
108 oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
109 >,
110
111 outgoing_keep_alive_set: HashSet<u64>,
112
113 outgoing_responses: FuturesUnordered<
114 BoxFuture<'static, Result<WebSocketResponseMessage, Canceled>>,
115 >,
116
117 ws: WebSocket,
118}
119
120impl SignalWebSocketProcess {
121 async fn process_frame(
122 &mut self,
123 frame: Vec<u8>,
124 ) -> Result<(), ServiceError> {
125 use prost::Message;
126 let msg = WebSocketMessage::decode(Bytes::from(frame))?;
127 if let Some(request) = &msg.request {
128 tracing::trace!(
129 msg_type =? msg.r#type(),
130 request.id,
131 request.verb,
132 request.path,
133 request_body_size_bytes = request.body.as_ref().map(|x| x.len()).unwrap_or(0),
134 ?request.headers,
135 "decoded WebSocketMessage request"
136 );
137 } else if let Some(response) = &msg.response {
138 tracing::trace!(
139 msg_type =? msg.r#type(),
140 response.status,
141 response.message,
142 response_body_size_bytes = response.body.as_ref().map(|x| x.len()).unwrap_or(0),
143 ?response.headers,
144 response.id,
145 "decoded WebSocketMessage response"
146 );
147 } else {
148 tracing::debug!("decoded {msg:?}");
149 }
150
151 use web_socket_message::Type;
152 match (msg.r#type(), msg.request, msg.response) {
153 (Type::Unknown, _, _) => Err(ServiceError::InvalidFrame {
154 reason: "unknown frame type",
155 }),
156 (Type::Request, Some(request), _) => {
157 let (sink, recv) = oneshot::channel();
158 tracing::trace!("sending request with body");
159 self.request_sink.send((request, sink)).await.map_err(
160 |_| ServiceError::WsClosing {
161 reason: "request handler failed",
162 },
163 )?;
164 self.outgoing_responses.push(Box::pin(recv));
165
166 Ok(())
167 },
168 (Type::Request, None, _) => Err(ServiceError::InvalidFrame {
169 reason: "type was request, but does not contain request",
170 }),
171 (Type::Response, _, Some(response)) => {
172 if let Some(id) = response.id {
173 if let Some(responder) = self.outgoing_requests.remove(&id)
174 {
175 if let Err(e) = responder.send(Ok(response)) {
176 tracing::warn!(
177 "Could not deliver response for id {}: {:?}",
178 id,
179 e
180 );
181 }
182 } else if let Some(_x) =
183 self.outgoing_keep_alive_set.take(&id)
184 {
185 let status_code = response.status();
186 if status_code != 200 {
187 tracing::warn!(
188 status_code,
189 "response code for keep-alive != 200"
190 );
191 return Err(ServiceError::UnhandledResponseCode {
192 http_code: response.status() as u16,
193 });
194 }
195 } else {
196 tracing::warn!(
197 ?response,
198 "response for non existing request"
199 );
200 }
201 }
202
203 Ok(())
204 },
205 (Type::Response, _, None) => Err(ServiceError::InvalidFrame {
206 reason: "type was response, but does not contain response",
207 }),
208 }
209 }
210
211 fn next_request_id(&self) -> u64 {
212 use rand::Rng;
213 let mut rng = rand::rng();
214 loop {
215 let id = rng.random();
216 if !self.outgoing_requests.contains_key(&id) {
217 return id;
218 }
219 }
220 }
221
222 async fn run(mut self) -> Result<(), ServiceError> {
223 let mut ka_interval = tokio::time::interval_at(
224 Instant::now(),
225 push_service::KEEPALIVE_TIMEOUT_SECONDS,
226 );
227
228 loop {
229 futures::select! {
230 _ = ka_interval.tick().fuse() => {
231 use prost::Message;
232 if !self.outgoing_keep_alive_set.is_empty() {
233 tracing::warn!("Websocket will be closed due to failed keepalives.");
234 if let Err(e) = self.ws.close(reqwest_websocket::CloseCode::Away, None).await {
235 tracing::debug!("Could not close WebSocket: {:?}", e);
236 }
237 self.outgoing_keep_alive_set.clear();
238 break;
239 }
240 tracing::debug!("sending keep-alive");
241 let request = WebSocketRequestMessage::new(Method::GET)
242 .id(self.next_request_id())
243 .path(&self.keep_alive_path)
244 .build();
245 self.outgoing_keep_alive_set.insert(request.id.unwrap());
246 let msg = WebSocketMessage {
247 r#type: Some(web_socket_message::Type::Request.into()),
248 request: Some(request),
249 ..Default::default()
250 };
251 let buffer = msg.encode_to_vec();
252 if let Err(e) = self.ws.send(reqwest_websocket::Message::Binary(buffer)).await {
253 tracing::info!("Websocket sink has closed: {:?}.", e);
254 break;
255 };
256 },
257 x = self.requests.next() => {
259 match x {
260 Some((mut request, responder)) => {
261 use prost::Message;
262
263 request.id = Some(
265 request
266 .id
267 .filter(|x| !self.outgoing_requests.contains_key(x))
268 .unwrap_or_else(|| self.next_request_id()),
269 );
270 tracing::trace!(
271 request.id,
272 request.verb,
273 request.path,
274 request_body_size_bytes = request.body.as_ref().map(|x| x.len()),
275 ?request.headers,
276 "sending WebSocketRequestMessage",
277 );
278
279 self.outgoing_requests.insert(request.id.unwrap(), responder);
280 let msg = WebSocketMessage {
281 r#type: Some(web_socket_message::Type::Request.into()),
282 request: Some(request),
283 ..Default::default()
284 };
285 let buffer = msg.encode_to_vec();
286 self.ws.send(reqwest_websocket::Message::Binary(buffer)).await?
287 }
288 None => {
289 debug!("end of application request stream; websocket closing");
290 return Ok(());
291 }
292 }
293 }
294 web_socket_item = self.ws.next().fuse() => {
296 use reqwest_websocket::Message;
297 match web_socket_item {
298 Some(Ok(Message::Close { code, reason })) => {
299 tracing::warn!(%code, reason, "websocket closed");
300 break;
301 },
302 Some(Ok(Message::Binary(frame))) => {
303 self.process_frame(frame).await?;
304 }
305 Some(Ok(Message::Ping(_))) => {
306 tracing::trace!("received ping");
307 }
308 Some(Ok(Message::Pong(_))) => {
309 tracing::trace!("received pong");
310 }
311 Some(Ok(Message::Text(_))) => {
312 tracing::trace!("received text (unsupported, skipping)");
313 }
314 Some(Err(e)) => return Err(e.into()),
315 None => {
316 return Err(ServiceError::WsClosing {
317 reason: "end of web request stream; socket closing"
318 });
319 }
320 }
321 }
322 response = self.outgoing_responses.next() => {
323 use prost::Message;
324 match response {
325 Some(Ok(response)) => {
326 tracing::trace!("sending response {:?}", response);
327
328 let msg = WebSocketMessage {
329 r#type: Some(web_socket_message::Type::Response.into()),
330 response: Some(response),
331 ..Default::default()
332 };
333 let buffer = msg.encode_to_vec();
334 self.ws.send(buffer.into()).await?;
335 }
336 Some(Err(error)) => {
337 tracing::error!(%error, "could not generate response to a Signal request; responder was canceled. continuing.");
338 }
339 None => {
340 unreachable!("outgoing responses should never fuse")
341 }
342 }
343 }
344 }
345 }
346 Ok(())
347 }
348}
349
350impl<C: WebSocketType> SignalWebSocket<C> {
351 fn inner_locked(&self) -> MutexGuard<'_, SignalWebSocketInner> {
352 self.inner.lock().unwrap()
353 }
354
355 pub fn new(
356 ws: WebSocket,
357 keep_alive_path: String,
358 unidentified_push_service: PushService,
359 ) -> (Self, impl Future<Output = ()>) {
360 let (incoming_request_sink, incoming_request_stream) =
362 mpsc::unbounded();
363 let (outgoing_request_sink, outgoing_requests) = mpsc::channel(1);
364
365 let process = SignalWebSocketProcess {
366 keep_alive_path,
367 requests: outgoing_requests,
368 request_sink: incoming_request_sink,
369 outgoing_requests: HashMap::default(),
370 outgoing_keep_alive_set: HashSet::new(),
371 outgoing_responses: vec![
374 Box::pin(futures::future::pending()) as BoxFuture<_>
375 ]
376 .into_iter()
377 .collect(),
378 ws,
379 };
380 let process = process.run().map(|x| match x {
381 Ok(()) => (),
382 Err(e) => {
383 tracing::error!("SignalWebSocket: {}", e);
384 },
385 });
386
387 (
388 Self {
389 _type: PhantomData,
390 request_sink: outgoing_request_sink,
391 unidentified_push_service,
392 inner: Arc::new(Mutex::new(SignalWebSocketInner {
393 stream: Some(SignalRequestStream {
394 inner: incoming_request_stream,
395 }),
396 })),
397 },
398 process,
399 )
400 }
401
402 pub fn servers(&self) -> SignalServers {
403 self.unidentified_push_service.servers
404 }
405
406 pub fn is_closed(&self) -> bool {
407 self.request_sink.is_closed()
408 }
409
410 pub fn is_used(&self) -> bool {
411 self.inner_locked().stream.is_none()
412 }
413
414 pub(crate) fn take_request_stream(
415 &mut self,
416 ) -> Option<SignalRequestStream> {
417 self.inner_locked().stream.take()
418 }
419
420 pub(crate) fn return_request_stream(&mut self, r: SignalRequestStream) {
421 self.inner_locked().stream.replace(r);
422 }
423
424 pub async fn with_request_stream<
427 R: 'static,
428 F: FnOnce(&mut SignalRequestStream) -> R,
429 >(
430 &mut self,
431 f: F,
432 ) -> R {
433 let mut s = self
434 .inner_locked()
435 .stream
436 .take()
437 .expect("request stream invariant");
438 let r = f(&mut s);
439 self.inner_locked().stream.replace(s);
440 r
441 }
442
443 pub fn request(
444 &mut self,
445 r: WebSocketRequestMessage,
446 ) -> impl Future<Output = Result<WebSocketResponseMessage, ServiceError>>
447 {
448 let (sink, recv): (
449 oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
450 _,
451 ) = oneshot::channel();
452
453 let mut request_sink = self.request_sink.clone();
454 async move {
455 if let Err(_e) = request_sink.send((r, sink)).await {
456 return Err(ServiceError::WsClosing {
457 reason: "WebSocket closing while sending request",
458 });
459 }
460 match recv.await {
462 Ok(x) => x,
463 Err(_) => Err(ServiceError::WsClosing {
464 reason: "WebSocket closing while waiting for a response",
465 }),
466 }
467 }
468 }
469
470 pub(crate) async fn request_json<T>(
471 &mut self,
472 r: WebSocketRequestMessage,
473 ) -> Result<T, ServiceError>
474 where
475 for<'de> T: serde::Deserialize<'de>,
476 {
477 self.request(r)
478 .await?
479 .service_error_for_status()
480 .await?
481 .json()
482 .await
483 }
484}
485
486impl WebSocketResponseMessage {
487 pub async fn service_error_for_status(self) -> Result<Self, ServiceError> {
488 super::push_service::response::service_error_for_status(self).await
489 }
490
491 pub async fn json<T: for<'a> serde::Deserialize<'a>>(
492 &self,
493 ) -> Result<T, ServiceError> {
494 self.body
495 .as_ref()
496 .ok_or(ServiceError::UnsupportedContent)
497 .and_then(|b| serde_json::from_slice(b).map_err(Into::into))
498 }
499}