libsignal_service/websocket/
mod.rs

1use std::collections::{HashMap, HashSet};
2use std::marker::PhantomData;
3use std::sync::{Arc, Mutex, MutexGuard};
4
5use std::future::Future;
6
7use bytes::Bytes;
8use futures::channel::oneshot::Canceled;
9use futures::channel::{mpsc, oneshot};
10use futures::future::BoxFuture;
11use futures::prelude::*;
12use futures::stream::FuturesUnordered;
13use reqwest::Method;
14use reqwest_websocket::WebSocket;
15use tokio::time::Instant;
16use tracing::debug;
17
18use crate::prelude::PushService;
19use crate::proto::{
20    web_socket_message, WebSocketMessage, WebSocketRequestMessage,
21    WebSocketResponseMessage,
22};
23use crate::push_service::{self, ServiceError, SignalServiceResponse};
24
25pub mod account;
26pub mod keys;
27pub mod linking;
28pub mod profile;
29pub mod registration;
30mod request;
31mod sender;
32pub mod stickers;
33
34pub use request::WebSocketRequestMessageBuilder;
35
36type RequestStreamItem = (
37    WebSocketRequestMessage,
38    oneshot::Sender<WebSocketResponseMessage>,
39);
40
41pub struct SignalRequestStream {
42    inner: mpsc::Receiver<RequestStreamItem>,
43}
44
45impl Stream for SignalRequestStream {
46    type Item = RequestStreamItem;
47
48    fn poll_next(
49        mut self: std::pin::Pin<&mut Self>,
50        cx: &mut std::task::Context<'_>,
51    ) -> std::task::Poll<Option<Self::Item>> {
52        let inner = &mut self.inner;
53        futures::pin_mut!(inner);
54        Stream::poll_next(inner, cx)
55    }
56}
57
/// Type-level tag telling which flavor of connection a [`SignalWebSocket`] is.
///
/// In this module it is implemented by the zero-sized markers [`Identified`] and
/// [`Unidentified`].
pub trait WebSocketType: 'static {}

/// Marker for the identified websocket flavor.
#[derive(Clone, Debug)]
pub struct Identified;

impl WebSocketType for Identified {}

/// Marker for the unidentified websocket flavor.
#[derive(Clone, Debug)]
pub struct Unidentified;

impl WebSocketType for Unidentified {}
69
70/// A dispatching web socket client for the Signal web socket API.
71///
72/// This structure can be freely cloned, since this acts as a *facade* for multiple entry and exit
73/// points.
74#[derive(Clone)]
75pub struct SignalWebSocket<C: WebSocketType> {
76    _type: PhantomData<C>,
77    // XXX: at the end of the migration, this should be CDN operations only
78    pub(crate) unidentified_push_service: PushService,
79    inner: Arc<Mutex<SignalWebSocketInner>>,
80    request_sink: mpsc::Sender<(
81        WebSocketRequestMessage,
82        oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
83    )>,
84}
85
86struct SignalWebSocketInner {
87    stream: Option<SignalRequestStream>,
88}
89
90struct SignalWebSocketProcess {
91    /// Whether to enable keep-alive or not (and send a request to this path)
92    keep_alive_path: String,
93
94    /// Receives requests from the application, which we forward to Signal.
95    requests: mpsc::Receiver<(
96        WebSocketRequestMessage,
97        oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
98    )>,
99    /// Signal's requests should go in here, to be delivered to the application.
100    request_sink: mpsc::Sender<RequestStreamItem>,
101
102    outgoing_requests: HashMap<
103        u64,
104        oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
105    >,
106
107    outgoing_keep_alive_set: HashSet<u64>,
108
109    outgoing_responses: FuturesUnordered<
110        BoxFuture<'static, Result<WebSocketResponseMessage, Canceled>>,
111    >,
112
113    ws: WebSocket,
114}
115
116impl SignalWebSocketProcess {
117    async fn process_frame(
118        &mut self,
119        frame: Vec<u8>,
120    ) -> Result<(), ServiceError> {
121        use prost::Message;
122        let msg = WebSocketMessage::decode(Bytes::from(frame))?;
123        if let Some(request) = &msg.request {
124            tracing::trace!(
125                msg_type =? msg.r#type(),
126                request.id,
127                request.verb,
128                request.path,
129                request_body_size_bytes = request.body.as_ref().map(|x| x.len()).unwrap_or(0),
130                ?request.headers,
131                "decoded WebSocketMessage request"
132            );
133        } else if let Some(response) = &msg.response {
134            tracing::trace!(
135                msg_type =? msg.r#type(),
136                response.status,
137                response.message,
138                response_body_size_bytes = response.body.as_ref().map(|x| x.len()).unwrap_or(0),
139                ?response.headers,
140                response.id,
141                "decoded WebSocketMessage response"
142            );
143        } else {
144            tracing::debug!("decoded {msg:?}");
145        }
146
147        use web_socket_message::Type;
148        match (msg.r#type(), msg.request, msg.response) {
149            (Type::Unknown, _, _) => Err(ServiceError::InvalidFrame {
150                reason: "unknown frame type",
151            }),
152            (Type::Request, Some(request), _) => {
153                let (sink, recv) = oneshot::channel();
154                tracing::trace!("sending request with body");
155                self.request_sink.send((request, sink)).await.map_err(
156                    |_| ServiceError::WsClosing {
157                        reason: "request handler failed",
158                    },
159                )?;
160                self.outgoing_responses.push(Box::pin(recv));
161
162                Ok(())
163            },
164            (Type::Request, None, _) => Err(ServiceError::InvalidFrame {
165                reason: "type was request, but does not contain request",
166            }),
167            (Type::Response, _, Some(response)) => {
168                if let Some(id) = response.id {
169                    if let Some(responder) = self.outgoing_requests.remove(&id)
170                    {
171                        if let Err(e) = responder.send(Ok(response)) {
172                            tracing::warn!(
173                                "Could not deliver response for id {}: {:?}",
174                                id,
175                                e
176                            );
177                        }
178                    } else if let Some(_x) =
179                        self.outgoing_keep_alive_set.take(&id)
180                    {
181                        let status_code = response.status();
182                        if status_code != 200 {
183                            tracing::warn!(
184                                status_code,
185                                "response code for keep-alive != 200"
186                            );
187                            return Err(ServiceError::UnhandledResponseCode {
188                                http_code: response.status() as u16,
189                            });
190                        }
191                    } else {
192                        tracing::warn!(
193                            ?response,
194                            "response for non existing request"
195                        );
196                    }
197                }
198
199                Ok(())
200            },
201            (Type::Response, _, None) => Err(ServiceError::InvalidFrame {
202                reason: "type was response, but does not contain response",
203            }),
204        }
205    }
206
207    fn next_request_id(&self) -> u64 {
208        use rand::Rng;
209        let mut rng = rand::rng();
210        loop {
211            let id = rng.random();
212            if !self.outgoing_requests.contains_key(&id) {
213                return id;
214            }
215        }
216    }
217
218    async fn run(mut self) -> Result<(), ServiceError> {
219        let mut ka_interval = tokio::time::interval_at(
220            Instant::now(),
221            push_service::KEEPALIVE_TIMEOUT_SECONDS,
222        );
223
224        loop {
225            futures::select! {
226                _ = ka_interval.tick().fuse() => {
227                    use prost::Message;
228                    if !self.outgoing_keep_alive_set.is_empty() {
229                        tracing::warn!("Websocket will be closed due to failed keepalives.");
230                        if let Err(e) = self.ws.close(reqwest_websocket::CloseCode::Away, None).await {
231                            tracing::debug!("Could not close WebSocket: {:?}", e);
232                        }
233                        self.outgoing_keep_alive_set.clear();
234                        break;
235                    }
236                    tracing::debug!("sending keep-alive");
237                    let request = WebSocketRequestMessage::new(Method::GET)
238                        .id(self.next_request_id())
239                        .path(&self.keep_alive_path)
240                        .build();
241                    self.outgoing_keep_alive_set.insert(request.id.unwrap());
242                    let msg = WebSocketMessage {
243                        r#type: Some(web_socket_message::Type::Request.into()),
244                        request: Some(request),
245                        ..Default::default()
246                    };
247                    let buffer = msg.encode_to_vec();
248                    if let Err(e) = self.ws.send(reqwest_websocket::Message::Binary(buffer)).await {
249                        tracing::info!("Websocket sink has closed: {:?}.", e);
250                        break;
251                    };
252                },
253                // Process requests from the application, forward them to Signal
254                x = self.requests.next() => {
255                    match x {
256                        Some((mut request, responder)) => {
257                            use prost::Message;
258
259                            // Regenerate ID if already in the table
260                            request.id = Some(
261                                request
262                                    .id
263                                    .filter(|x| !self.outgoing_requests.contains_key(x))
264                                    .unwrap_or_else(|| self.next_request_id()),
265                            );
266                            tracing::trace!(
267                                request.id,
268                                request.verb,
269                                request.path,
270                                request_body_size_bytes = request.body.as_ref().map(|x| x.len()),
271                                ?request.headers,
272                                "sending WebSocketRequestMessage",
273                            );
274
275                            self.outgoing_requests.insert(request.id.unwrap(), responder);
276                            let msg = WebSocketMessage {
277                                r#type: Some(web_socket_message::Type::Request.into()),
278                                request: Some(request),
279                                ..Default::default()
280                            };
281                            let buffer = msg.encode_to_vec();
282                            self.ws.send(reqwest_websocket::Message::Binary(buffer)).await?
283                        }
284                        None => {
285                            debug!("end of application request stream; websocket closing");
286                            return Ok(());
287                        }
288                    }
289                }
290                // Incoming websocket message
291                web_socket_item = self.ws.next().fuse() => {
292                    use reqwest_websocket::Message;
293                    match web_socket_item {
294                        Some(Ok(Message::Close { code, reason })) => {
295                            tracing::warn!(%code, reason, "websocket closed");
296                            break;
297                        },
298                        Some(Ok(Message::Binary(frame))) => {
299                            self.process_frame(frame).await?;
300                        }
301                        Some(Ok(Message::Ping(_))) => {
302                            tracing::trace!("received ping");
303                        }
304                        Some(Ok(Message::Pong(_))) => {
305                            tracing::trace!("received pong");
306                        }
307                        Some(Ok(Message::Text(_))) => {
308                            tracing::trace!("received text (unsupported, skipping)");
309                        }
310                        Some(Err(e)) => return Err(e.into()),
311                        None => {
312                            return Err(ServiceError::WsClosing {
313                                reason: "end of web request stream; socket closing"
314                            });
315                        }
316                    }
317                }
318                response = self.outgoing_responses.next() => {
319                    use prost::Message;
320                    match response {
321                        Some(Ok(response)) => {
322                            tracing::trace!("sending response {:?}", response);
323
324                            let msg = WebSocketMessage {
325                                r#type: Some(web_socket_message::Type::Response.into()),
326                                response: Some(response),
327                                ..Default::default()
328                            };
329                            let buffer = msg.encode_to_vec();
330                            self.ws.send(buffer.into()).await?;
331                        }
332                        Some(Err(error)) => {
333                            tracing::error!(%error, "could not generate response to a Signal request; responder was canceled. continuing.");
334                        }
335                        None => {
336                            unreachable!("outgoing responses should never fuse")
337                        }
338                    }
339                }
340            }
341        }
342        Ok(())
343    }
344}
345
346impl<C: WebSocketType> SignalWebSocket<C> {
347    fn inner_locked(&self) -> MutexGuard<'_, SignalWebSocketInner> {
348        self.inner.lock().unwrap()
349    }
350
351    pub fn new(
352        ws: WebSocket,
353        keep_alive_path: String,
354        unidentified_push_service: PushService,
355    ) -> (Self, impl Future<Output = ()>) {
356        // Create process
357        let (incoming_request_sink, incoming_request_stream) = mpsc::channel(1);
358        let (outgoing_request_sink, outgoing_requests) = mpsc::channel(1);
359
360        let process = SignalWebSocketProcess {
361            keep_alive_path,
362            requests: outgoing_requests,
363            request_sink: incoming_request_sink,
364            outgoing_requests: HashMap::default(),
365            outgoing_keep_alive_set: HashSet::new(),
366            // Initializing the FuturesUnordered with a `pending` future means it will never fuse
367            // itself, so an "empty" FuturesUnordered will still allow new futures to be added.
368            outgoing_responses: vec![
369                Box::pin(futures::future::pending()) as BoxFuture<_>
370            ]
371            .into_iter()
372            .collect(),
373            ws,
374        };
375        let process = process.run().map(|x| match x {
376            Ok(()) => (),
377            Err(e) => {
378                tracing::error!("SignalWebSocket: {}", e);
379            },
380        });
381
382        (
383            Self {
384                _type: PhantomData,
385                request_sink: outgoing_request_sink,
386                unidentified_push_service,
387                inner: Arc::new(Mutex::new(SignalWebSocketInner {
388                    stream: Some(SignalRequestStream {
389                        inner: incoming_request_stream,
390                    }),
391                })),
392            },
393            process,
394        )
395    }
396
397    pub fn is_closed(&self) -> bool {
398        self.request_sink.is_closed()
399    }
400
401    pub fn is_used(&self) -> bool {
402        self.inner_locked().stream.is_none()
403    }
404
405    pub(crate) fn take_request_stream(
406        &mut self,
407    ) -> Option<SignalRequestStream> {
408        self.inner_locked().stream.take()
409    }
410
411    pub(crate) fn return_request_stream(&mut self, r: SignalRequestStream) {
412        self.inner_locked().stream.replace(r);
413    }
414
415    // XXX Ideally, this should take an *async* closure, then we could get rid of the
416    // `take_request_stream` and `return_request_stream`.
417    pub async fn with_request_stream<
418        R: 'static,
419        F: FnOnce(&mut SignalRequestStream) -> R,
420    >(
421        &mut self,
422        f: F,
423    ) -> R {
424        let mut s = self
425            .inner_locked()
426            .stream
427            .take()
428            .expect("request stream invariant");
429        let r = f(&mut s);
430        self.inner_locked().stream.replace(s);
431        r
432    }
433
434    pub fn request(
435        &mut self,
436        r: WebSocketRequestMessage,
437    ) -> impl Future<Output = Result<WebSocketResponseMessage, ServiceError>>
438    {
439        let (sink, recv): (
440            oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
441            _,
442        ) = oneshot::channel();
443
444        let mut request_sink = self.request_sink.clone();
445        async move {
446            if let Err(_e) = request_sink.send((r, sink)).await {
447                return Err(ServiceError::WsClosing {
448                    reason: "WebSocket closing while sending request",
449                });
450            }
451            // Handle the oneshot sender error for dropped senders.
452            match recv.await {
453                Ok(x) => x,
454                Err(_) => Err(ServiceError::WsClosing {
455                    reason: "WebSocket closing while waiting for a response",
456                }),
457            }
458        }
459    }
460
461    pub(crate) async fn request_json<T>(
462        &mut self,
463        r: WebSocketRequestMessage,
464    ) -> Result<T, ServiceError>
465    where
466        for<'de> T: serde::Deserialize<'de>,
467    {
468        self.request(r)
469            .await?
470            .service_error_for_status()
471            .await?
472            .json()
473            .await
474    }
475}
476
477impl WebSocketResponseMessage {
478    pub async fn service_error_for_status(self) -> Result<Self, ServiceError> {
479        super::push_service::response::service_error_for_status(self).await
480    }
481
482    pub async fn json<T: for<'a> serde::Deserialize<'a>>(
483        &self,
484    ) -> Result<T, ServiceError> {
485        self.body
486            .as_ref()
487            .ok_or(ServiceError::UnsupportedContent)
488            .and_then(|b| serde_json::from_slice(b).map_err(Into::into))
489    }
490}