1use std::collections::{HashMap, HashSet};
2use std::marker::PhantomData;
3use std::sync::{Arc, Mutex, MutexGuard};
4
5use std::future::Future;
6
7use bytes::Bytes;
8use futures::channel::oneshot::Canceled;
9use futures::channel::{mpsc, oneshot};
10use futures::future::BoxFuture;
11use futures::prelude::*;
12use futures::stream::FuturesUnordered;
13use reqwest::Method;
14use reqwest_websocket::WebSocket;
15use tokio::time::Instant;
16use tracing::debug;
17
18use crate::prelude::PushService;
19use crate::proto::{
20 web_socket_message, WebSocketMessage, WebSocketRequestMessage,
21 WebSocketResponseMessage,
22};
23use crate::push_service::{self, ServiceError, SignalServiceResponse};
24
25pub mod account;
26pub mod keys;
27pub mod linking;
28pub mod profile;
29pub mod registration;
30mod request;
31mod sender;
32pub mod stickers;
33
34pub use request::WebSocketRequestMessageBuilder;
35
/// A request initiated by the server, paired with the oneshot sender through
/// which the application supplies its response. Whatever is sent on that
/// sender is written back to the server as a response frame by the
/// connection driver.
type RequestStreamItem = (
    WebSocketRequestMessage,
    oneshot::Sender<WebSocketResponseMessage>,
);
40
/// Stream of requests initiated by the server over the websocket.
///
/// Each item carries a `oneshot::Sender` used to answer the request. If the
/// sender is dropped without sending, the connection driver logs an error and
/// continues without replying to that request.
pub struct SignalRequestStream {
    // Receiving half of the channel fed by the connection driver.
    inner: mpsc::Receiver<RequestStreamItem>,
}
44
45impl Stream for SignalRequestStream {
46 type Item = RequestStreamItem;
47
48 fn poll_next(
49 mut self: std::pin::Pin<&mut Self>,
50 cx: &mut std::task::Context<'_>,
51 ) -> std::task::Poll<Option<Self::Item>> {
52 let inner = &mut self.inner;
53 futures::pin_mut!(inner);
54 Stream::poll_next(inner, cx)
55 }
56}
57
/// Type-level marker selecting the "identified" kind of [`SignalWebSocket`].
///
/// NOTE(review): presumably an authenticated connection; the distinction is
/// not visible in this module, so confirm against the code creating sockets.
#[derive(Debug, Clone)]
pub struct Identified;

/// Type-level marker selecting the "unidentified" kind of [`SignalWebSocket`].
///
/// NOTE(review): presumably an unauthenticated connection; confirm.
#[derive(Debug, Clone)]
pub struct Unidentified;

/// Marker trait bounding the `C` parameter of [`SignalWebSocket`].
///
/// It has no methods. In this module only [`Identified`] and
/// [`Unidentified`] implement it, so it only tags the connection kind at the
/// type level.
pub trait WebSocketType: 'static {}

impl WebSocketType for Identified {}

impl WebSocketType for Unidentified {}
69
/// Handle to a Signal websocket connection.
///
/// Created by [`SignalWebSocket::new`], which also returns the future that
/// drives the connection. Clones share the same outgoing request channel and
/// the same slot holding the incoming [`SignalRequestStream`].
///
/// `C` is a [`WebSocketType`] marker ([`Identified`] or [`Unidentified`]).
#[derive(Clone)]
pub struct SignalWebSocket<C: WebSocketType> {
    // Zero-sized tag recording the connection kind `C`.
    _type: PhantomData<C>,
    // NOTE(review): only stored in this module, never read here; presumably
    // used by sibling modules for unauthenticated HTTP calls. Confirm.
    pub(crate) unidentified_push_service: PushService,
    // Shared (across clones) slot holding the server-initiated request stream
    // until a consumer takes it out.
    inner: Arc<Mutex<SignalWebSocketInner>>,
    // Outgoing requests, each paired with the oneshot on which the driver
    // delivers the matching response.
    request_sink: mpsc::Sender<(
        WebSocketRequestMessage,
        oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
    )>,
}
85
/// State shared, behind a mutex, by all clones of a [`SignalWebSocket`].
struct SignalWebSocketInner {
    // `None` while a consumer has taken the stream out (see `is_used`).
    stream: Option<SignalRequestStream>,
}
89
/// Connection driver: owns the websocket and moves traffic between it and the
/// [`SignalWebSocket`] handles.
///
/// Consumed by `run`, whose future is returned from [`SignalWebSocket::new`].
struct SignalWebSocketProcess {
    // Path requested with `GET` for every keep-alive.
    keep_alive_path: String,

    // Requests issued by the application through `SignalWebSocket::request`,
    // each paired with the oneshot that receives the server's response.
    requests: mpsc::Receiver<(
        WebSocketRequestMessage,
        oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
    )>,
    // Forwards server-initiated requests to the `SignalRequestStream` consumer.
    request_sink: mpsc::Sender<RequestStreamItem>,

    // Responders of in-flight application requests, keyed by request id.
    outgoing_requests: HashMap<
        u64,
        oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
    >,

    // Ids of keep-alive requests that have not been answered yet.
    outgoing_keep_alive_set: HashSet<u64>,

    // Application answers to server-initiated requests. Each resolves when the
    // consumer replies, or with `Canceled` if the responder is dropped. `new`
    // seeds it with a never-resolving future so the set never runs empty.
    outgoing_responses: FuturesUnordered<
        BoxFuture<'static, Result<WebSocketResponseMessage, Canceled>>,
    >,

    // The underlying websocket connection.
    ws: WebSocket,
}
115
116impl SignalWebSocketProcess {
117 async fn process_frame(
118 &mut self,
119 frame: Vec<u8>,
120 ) -> Result<(), ServiceError> {
121 use prost::Message;
122 let msg = WebSocketMessage::decode(Bytes::from(frame))?;
123 if let Some(request) = &msg.request {
124 tracing::trace!(
125 msg_type =? msg.r#type(),
126 request.id,
127 request.verb,
128 request.path,
129 request_body_size_bytes = request.body.as_ref().map(|x| x.len()).unwrap_or(0),
130 ?request.headers,
131 "decoded WebSocketMessage request"
132 );
133 } else if let Some(response) = &msg.response {
134 tracing::trace!(
135 msg_type =? msg.r#type(),
136 response.status,
137 response.message,
138 response_body_size_bytes = response.body.as_ref().map(|x| x.len()).unwrap_or(0),
139 ?response.headers,
140 response.id,
141 "decoded WebSocketMessage response"
142 );
143 } else {
144 tracing::debug!("decoded {msg:?}");
145 }
146
147 use web_socket_message::Type;
148 match (msg.r#type(), msg.request, msg.response) {
149 (Type::Unknown, _, _) => Err(ServiceError::InvalidFrame {
150 reason: "unknown frame type",
151 }),
152 (Type::Request, Some(request), _) => {
153 let (sink, recv) = oneshot::channel();
154 tracing::trace!("sending request with body");
155 self.request_sink.send((request, sink)).await.map_err(
156 |_| ServiceError::WsClosing {
157 reason: "request handler failed",
158 },
159 )?;
160 self.outgoing_responses.push(Box::pin(recv));
161
162 Ok(())
163 },
164 (Type::Request, None, _) => Err(ServiceError::InvalidFrame {
165 reason: "type was request, but does not contain request",
166 }),
167 (Type::Response, _, Some(response)) => {
168 if let Some(id) = response.id {
169 if let Some(responder) = self.outgoing_requests.remove(&id)
170 {
171 if let Err(e) = responder.send(Ok(response)) {
172 tracing::warn!(
173 "Could not deliver response for id {}: {:?}",
174 id,
175 e
176 );
177 }
178 } else if let Some(_x) =
179 self.outgoing_keep_alive_set.take(&id)
180 {
181 let status_code = response.status();
182 if status_code != 200 {
183 tracing::warn!(
184 status_code,
185 "response code for keep-alive != 200"
186 );
187 return Err(ServiceError::UnhandledResponseCode {
188 http_code: response.status() as u16,
189 });
190 }
191 } else {
192 tracing::warn!(
193 ?response,
194 "response for non existing request"
195 );
196 }
197 }
198
199 Ok(())
200 },
201 (Type::Response, _, None) => Err(ServiceError::InvalidFrame {
202 reason: "type was response, but does not contain response",
203 }),
204 }
205 }
206
207 fn next_request_id(&self) -> u64 {
208 use rand::Rng;
209 let mut rng = rand::rng();
210 loop {
211 let id = rng.random();
212 if !self.outgoing_requests.contains_key(&id) {
213 return id;
214 }
215 }
216 }
217
218 async fn run(mut self) -> Result<(), ServiceError> {
219 let mut ka_interval = tokio::time::interval_at(
220 Instant::now(),
221 push_service::KEEPALIVE_TIMEOUT_SECONDS,
222 );
223
224 loop {
225 futures::select! {
226 _ = ka_interval.tick().fuse() => {
227 use prost::Message;
228 if !self.outgoing_keep_alive_set.is_empty() {
229 tracing::warn!("Websocket will be closed due to failed keepalives.");
230 if let Err(e) = self.ws.close(reqwest_websocket::CloseCode::Away, None).await {
231 tracing::debug!("Could not close WebSocket: {:?}", e);
232 }
233 self.outgoing_keep_alive_set.clear();
234 break;
235 }
236 tracing::debug!("sending keep-alive");
237 let request = WebSocketRequestMessage::new(Method::GET)
238 .id(self.next_request_id())
239 .path(&self.keep_alive_path)
240 .build();
241 self.outgoing_keep_alive_set.insert(request.id.unwrap());
242 let msg = WebSocketMessage {
243 r#type: Some(web_socket_message::Type::Request.into()),
244 request: Some(request),
245 ..Default::default()
246 };
247 let buffer = msg.encode_to_vec();
248 if let Err(e) = self.ws.send(reqwest_websocket::Message::Binary(buffer)).await {
249 tracing::info!("Websocket sink has closed: {:?}.", e);
250 break;
251 };
252 },
253 x = self.requests.next() => {
255 match x {
256 Some((mut request, responder)) => {
257 use prost::Message;
258
259 request.id = Some(
261 request
262 .id
263 .filter(|x| !self.outgoing_requests.contains_key(x))
264 .unwrap_or_else(|| self.next_request_id()),
265 );
266 tracing::trace!(
267 request.id,
268 request.verb,
269 request.path,
270 request_body_size_bytes = request.body.as_ref().map(|x| x.len()),
271 ?request.headers,
272 "sending WebSocketRequestMessage",
273 );
274
275 self.outgoing_requests.insert(request.id.unwrap(), responder);
276 let msg = WebSocketMessage {
277 r#type: Some(web_socket_message::Type::Request.into()),
278 request: Some(request),
279 ..Default::default()
280 };
281 let buffer = msg.encode_to_vec();
282 self.ws.send(reqwest_websocket::Message::Binary(buffer)).await?
283 }
284 None => {
285 debug!("end of application request stream; websocket closing");
286 return Ok(());
287 }
288 }
289 }
290 web_socket_item = self.ws.next().fuse() => {
292 use reqwest_websocket::Message;
293 match web_socket_item {
294 Some(Ok(Message::Close { code, reason })) => {
295 tracing::warn!(%code, reason, "websocket closed");
296 break;
297 },
298 Some(Ok(Message::Binary(frame))) => {
299 self.process_frame(frame).await?;
300 }
301 Some(Ok(Message::Ping(_))) => {
302 tracing::trace!("received ping");
303 }
304 Some(Ok(Message::Pong(_))) => {
305 tracing::trace!("received pong");
306 }
307 Some(Ok(Message::Text(_))) => {
308 tracing::trace!("received text (unsupported, skipping)");
309 }
310 Some(Err(e)) => return Err(e.into()),
311 None => {
312 return Err(ServiceError::WsClosing {
313 reason: "end of web request stream; socket closing"
314 });
315 }
316 }
317 }
318 response = self.outgoing_responses.next() => {
319 use prost::Message;
320 match response {
321 Some(Ok(response)) => {
322 tracing::trace!("sending response {:?}", response);
323
324 let msg = WebSocketMessage {
325 r#type: Some(web_socket_message::Type::Response.into()),
326 response: Some(response),
327 ..Default::default()
328 };
329 let buffer = msg.encode_to_vec();
330 self.ws.send(buffer.into()).await?;
331 }
332 Some(Err(error)) => {
333 tracing::error!(%error, "could not generate response to a Signal request; responder was canceled. continuing.");
334 }
335 None => {
336 unreachable!("outgoing responses should never fuse")
337 }
338 }
339 }
340 }
341 }
342 Ok(())
343 }
344}
345
346impl<C: WebSocketType> SignalWebSocket<C> {
347 fn inner_locked(&self) -> MutexGuard<'_, SignalWebSocketInner> {
348 self.inner.lock().unwrap()
349 }
350
351 pub fn new(
352 ws: WebSocket,
353 keep_alive_path: String,
354 unidentified_push_service: PushService,
355 ) -> (Self, impl Future<Output = ()>) {
356 let (incoming_request_sink, incoming_request_stream) = mpsc::channel(1);
358 let (outgoing_request_sink, outgoing_requests) = mpsc::channel(1);
359
360 let process = SignalWebSocketProcess {
361 keep_alive_path,
362 requests: outgoing_requests,
363 request_sink: incoming_request_sink,
364 outgoing_requests: HashMap::default(),
365 outgoing_keep_alive_set: HashSet::new(),
366 outgoing_responses: vec![
369 Box::pin(futures::future::pending()) as BoxFuture<_>
370 ]
371 .into_iter()
372 .collect(),
373 ws,
374 };
375 let process = process.run().map(|x| match x {
376 Ok(()) => (),
377 Err(e) => {
378 tracing::error!("SignalWebSocket: {}", e);
379 },
380 });
381
382 (
383 Self {
384 _type: PhantomData,
385 request_sink: outgoing_request_sink,
386 unidentified_push_service,
387 inner: Arc::new(Mutex::new(SignalWebSocketInner {
388 stream: Some(SignalRequestStream {
389 inner: incoming_request_stream,
390 }),
391 })),
392 },
393 process,
394 )
395 }
396
397 pub fn is_closed(&self) -> bool {
398 self.request_sink.is_closed()
399 }
400
401 pub fn is_used(&self) -> bool {
402 self.inner_locked().stream.is_none()
403 }
404
405 pub(crate) fn take_request_stream(
406 &mut self,
407 ) -> Option<SignalRequestStream> {
408 self.inner_locked().stream.take()
409 }
410
411 pub(crate) fn return_request_stream(&mut self, r: SignalRequestStream) {
412 self.inner_locked().stream.replace(r);
413 }
414
415 pub async fn with_request_stream<
418 R: 'static,
419 F: FnOnce(&mut SignalRequestStream) -> R,
420 >(
421 &mut self,
422 f: F,
423 ) -> R {
424 let mut s = self
425 .inner_locked()
426 .stream
427 .take()
428 .expect("request stream invariant");
429 let r = f(&mut s);
430 self.inner_locked().stream.replace(s);
431 r
432 }
433
434 pub fn request(
435 &mut self,
436 r: WebSocketRequestMessage,
437 ) -> impl Future<Output = Result<WebSocketResponseMessage, ServiceError>>
438 {
439 let (sink, recv): (
440 oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
441 _,
442 ) = oneshot::channel();
443
444 let mut request_sink = self.request_sink.clone();
445 async move {
446 if let Err(_e) = request_sink.send((r, sink)).await {
447 return Err(ServiceError::WsClosing {
448 reason: "WebSocket closing while sending request",
449 });
450 }
451 match recv.await {
453 Ok(x) => x,
454 Err(_) => Err(ServiceError::WsClosing {
455 reason: "WebSocket closing while waiting for a response",
456 }),
457 }
458 }
459 }
460
461 pub(crate) async fn request_json<T>(
462 &mut self,
463 r: WebSocketRequestMessage,
464 ) -> Result<T, ServiceError>
465 where
466 for<'de> T: serde::Deserialize<'de>,
467 {
468 self.request(r)
469 .await?
470 .service_error_for_status()
471 .await?
472 .json()
473 .await
474 }
475}
476
477impl WebSocketResponseMessage {
478 pub async fn service_error_for_status(self) -> Result<Self, ServiceError> {
479 super::push_service::response::service_error_for_status(self).await
480 }
481
482 pub async fn json<T: for<'a> serde::Deserialize<'a>>(
483 &self,
484 ) -> Result<T, ServiceError> {
485 self.body
486 .as_ref()
487 .ok_or(ServiceError::UnsupportedContent)
488 .and_then(|b| serde_json::from_slice(b).map_err(Into::into))
489 }
490}