Skip to main content

libsignal_service/websocket/
mod.rs

1use std::collections::{HashMap, HashSet};
2use std::marker::PhantomData;
3use std::sync::{Arc, Mutex, MutexGuard};
4
5use std::future::Future;
6
7use bytes::Bytes;
8use futures::channel::oneshot::Canceled;
9use futures::channel::{mpsc, oneshot};
10use futures::future::BoxFuture;
11use futures::prelude::*;
12use futures::stream::FuturesUnordered;
13use reqwest::Method;
14use reqwest_websocket::WebSocket;
15use tokio::time::Instant;
16use tracing::debug;
17
18use crate::configuration::SignalServers;
19use crate::prelude::PushService;
20use crate::proto::{
21    web_socket_message, WebSocketMessage, WebSocketRequestMessage,
22    WebSocketResponseMessage,
23};
24use crate::push_service::{self, ServiceError, SignalServiceResponse};
25
26pub mod account;
27#[cfg(feature = "cdsi")]
28pub mod directory;
29pub mod keys;
30pub mod linking;
31pub mod profile;
32pub mod registration;
33mod request;
34mod sender;
35pub mod stickers;
36mod usernames;
37
38pub use request::WebSocketRequestMessageBuilder;
39
40type RequestStreamItem = (
41    WebSocketRequestMessage,
42    oneshot::Sender<WebSocketResponseMessage>,
43);
44
45pub struct SignalRequestStream {
46    inner: mpsc::UnboundedReceiver<RequestStreamItem>,
47}
48
49impl Stream for SignalRequestStream {
50    type Item = RequestStreamItem;
51
52    fn poll_next(
53        mut self: std::pin::Pin<&mut Self>,
54        cx: &mut std::task::Context<'_>,
55    ) -> std::task::Poll<Option<Self::Item>> {
56        let inner = &mut self.inner;
57        futures::pin_mut!(inner);
58        Stream::poll_next(inner, cx)
59    }
60}
61
/// Type-level tag used as the `C` parameter of `SignalWebSocket`, selecting
/// which flavour of web socket a handle represents.
pub trait WebSocketType: 'static {}

/// Tag for an identified web socket.
#[derive(Debug, Clone)]
pub struct Identified;
impl WebSocketType for Identified {}

/// Tag for an unidentified web socket.
#[derive(Debug, Clone)]
pub struct Unidentified;
impl WebSocketType for Unidentified {}
73
74/// A dispatching web socket client for the Signal web socket API.
75///
76/// This structure can be freely cloned, since this acts as a *facade* for multiple entry and exit
77/// points.
78#[derive(Clone)]
79pub struct SignalWebSocket<C: WebSocketType> {
80    _type: PhantomData<C>,
81    // XXX: at the end of the migration, this should be CDN operations only
82    pub(crate) unidentified_push_service: PushService,
83    inner: Arc<Mutex<SignalWebSocketInner>>,
84    request_sink: mpsc::Sender<(
85        WebSocketRequestMessage,
86        oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
87    )>,
88}
89
90struct SignalWebSocketInner {
91    stream: Option<SignalRequestStream>,
92}
93
94struct SignalWebSocketProcess {
95    /// Whether to enable keep-alive or not (and send a request to this path)
96    keep_alive_path: String,
97
98    /// Receives requests from the application, which we forward to Signal.
99    requests: mpsc::Receiver<(
100        WebSocketRequestMessage,
101        oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
102    )>,
103    /// Signal's requests should go in here, to be delivered to the application.
104    request_sink: mpsc::UnboundedSender<RequestStreamItem>,
105
106    outgoing_requests: HashMap<
107        u64,
108        oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
109    >,
110
111    outgoing_keep_alive_set: HashSet<u64>,
112
113    outgoing_responses: FuturesUnordered<
114        BoxFuture<'static, Result<WebSocketResponseMessage, Canceled>>,
115    >,
116
117    ws: WebSocket,
118}
119
120impl SignalWebSocketProcess {
121    async fn process_frame(
122        &mut self,
123        frame: Vec<u8>,
124    ) -> Result<(), ServiceError> {
125        use prost::Message;
126        let msg = WebSocketMessage::decode(Bytes::from(frame))?;
127        if let Some(request) = &msg.request {
128            tracing::trace!(
129                msg_type =? msg.r#type(),
130                request.id,
131                request.verb,
132                request.path,
133                request_body_size_bytes = request.body.as_ref().map(|x| x.len()).unwrap_or(0),
134                ?request.headers,
135                "decoded WebSocketMessage request"
136            );
137        } else if let Some(response) = &msg.response {
138            tracing::trace!(
139                msg_type =? msg.r#type(),
140                response.status,
141                response.message,
142                response_body_size_bytes = response.body.as_ref().map(|x| x.len()).unwrap_or(0),
143                ?response.headers,
144                response.id,
145                "decoded WebSocketMessage response"
146            );
147        } else {
148            tracing::debug!("decoded {msg:?}");
149        }
150
151        use web_socket_message::Type;
152        match (msg.r#type(), msg.request, msg.response) {
153            (Type::Unknown, _, _) => Err(ServiceError::InvalidFrame {
154                reason: "unknown frame type",
155            }),
156            (Type::Request, Some(request), _) => {
157                let (sink, recv) = oneshot::channel();
158                tracing::trace!("sending request with body");
159                self.request_sink.send((request, sink)).await.map_err(
160                    |_| ServiceError::WsClosing {
161                        reason: "request handler failed",
162                    },
163                )?;
164                self.outgoing_responses.push(Box::pin(recv));
165
166                Ok(())
167            },
168            (Type::Request, None, _) => Err(ServiceError::InvalidFrame {
169                reason: "type was request, but does not contain request",
170            }),
171            (Type::Response, _, Some(response)) => {
172                if let Some(id) = response.id {
173                    if let Some(responder) = self.outgoing_requests.remove(&id)
174                    {
175                        if let Err(e) = responder.send(Ok(response)) {
176                            tracing::warn!(
177                                "Could not deliver response for id {}: {:?}",
178                                id,
179                                e
180                            );
181                        }
182                    } else if let Some(_x) =
183                        self.outgoing_keep_alive_set.take(&id)
184                    {
185                        let status = reqwest::StatusCode::from_u16(
186                            response.status() as _,
187                        )
188                        .map_err(|e| {
189                            ServiceError::IO(std::io::Error::other(format!(
190                                "invalid http status code {} - {e}",
191                                response.status()
192                            )))
193                        })?;
194                        if !status.is_success() {
195                            tracing::warn!(
196                                %status,
197                                "response code for keep-alive not successful"
198                            );
199                            return Err(ServiceError::UnhandledResponseCode {
200                                status,
201                                body: String::from_utf8_lossy(response.body())
202                                    .into_owned(),
203                            });
204                        }
205                    } else {
206                        tracing::warn!(
207                            ?response,
208                            "response for non existing request"
209                        );
210                    }
211                }
212
213                Ok(())
214            },
215            (Type::Response, _, None) => Err(ServiceError::InvalidFrame {
216                reason: "type was response, but does not contain response",
217            }),
218        }
219    }
220
221    fn next_request_id(&self) -> u64 {
222        use rand::Rng;
223        let mut rng = rand::rng();
224        loop {
225            let id = rng.random();
226            if !self.outgoing_requests.contains_key(&id) {
227                return id;
228            }
229        }
230    }
231
232    async fn run(mut self) -> Result<(), ServiceError> {
233        let mut ka_interval = tokio::time::interval_at(
234            Instant::now(),
235            push_service::KEEPALIVE_TIMEOUT_SECONDS,
236        );
237
238        loop {
239            futures::select! {
240                _ = ka_interval.tick().fuse() => {
241                    use prost::Message;
242                    if !self.outgoing_keep_alive_set.is_empty() {
243                        tracing::warn!("Websocket will be closed due to failed keepalives.");
244                        if let Err(e) = self.ws.close(reqwest_websocket::CloseCode::Away, None).await {
245                            tracing::debug!("Could not close WebSocket: {:?}", e);
246                        }
247                        self.outgoing_keep_alive_set.clear();
248                        break;
249                    }
250                    tracing::debug!("sending keep-alive");
251                    let request = WebSocketRequestMessage::new(Method::GET)
252                        .id(self.next_request_id())
253                        .path(&self.keep_alive_path)
254                        .build();
255                    self.outgoing_keep_alive_set.insert(request.id.unwrap());
256                    let msg = WebSocketMessage {
257                        r#type: Some(web_socket_message::Type::Request.into()),
258                        request: Some(request),
259                        ..Default::default()
260                    };
261                    let buffer = msg.encode_to_vec();
262                    if let Err(e) = self.ws.send(reqwest_websocket::Message::Binary(buffer)).await {
263                        tracing::info!("Websocket sink has closed: {:?}.", e);
264                        break;
265                    };
266                },
267                // Process requests from the application, forward them to Signal
268                x = self.requests.next() => {
269                    match x {
270                        Some((mut request, responder)) => {
271                            use prost::Message;
272
273                            // Regenerate ID if already in the table
274                            request.id = Some(
275                                request
276                                    .id
277                                    .filter(|x| !self.outgoing_requests.contains_key(x))
278                                    .unwrap_or_else(|| self.next_request_id()),
279                            );
280                            tracing::trace!(
281                                request.id,
282                                request.verb,
283                                request.path,
284                                request_body_size_bytes = request.body.as_ref().map(|x| x.len()),
285                                ?request.headers,
286                                "sending WebSocketRequestMessage",
287                            );
288
289                            self.outgoing_requests.insert(request.id.unwrap(), responder);
290                            let msg = WebSocketMessage {
291                                r#type: Some(web_socket_message::Type::Request.into()),
292                                request: Some(request),
293                                ..Default::default()
294                            };
295                            let buffer = msg.encode_to_vec();
296                            self.ws.send(reqwest_websocket::Message::Binary(buffer)).await?
297                        }
298                        None => {
299                            debug!("end of application request stream; websocket closing");
300                            return Ok(());
301                        }
302                    }
303                }
304                // Incoming websocket message
305                web_socket_item = self.ws.next().fuse() => {
306                    use reqwest_websocket::Message;
307                    match web_socket_item {
308                        Some(Ok(Message::Close { code, reason })) => {
309                            tracing::warn!(%code, reason, "websocket closed");
310                            break;
311                        },
312                        Some(Ok(Message::Binary(frame))) => {
313                            self.process_frame(frame).await?;
314                        }
315                        Some(Ok(Message::Ping(_))) => {
316                            tracing::trace!("received ping");
317                        }
318                        Some(Ok(Message::Pong(_))) => {
319                            tracing::trace!("received pong");
320                        }
321                        Some(Ok(Message::Text(_))) => {
322                            tracing::trace!("received text (unsupported, skipping)");
323                        }
324                        Some(Err(e)) => return Err(e.into()),
325                        None => {
326                            return Err(ServiceError::WsClosing {
327                                reason: "end of web request stream; socket closing"
328                            });
329                        }
330                    }
331                }
332                response = self.outgoing_responses.next() => {
333                    use prost::Message;
334                    match response {
335                        Some(Ok(response)) => {
336                            tracing::trace!("sending response {:?}", response);
337
338                            let msg = WebSocketMessage {
339                                r#type: Some(web_socket_message::Type::Response.into()),
340                                response: Some(response),
341                                ..Default::default()
342                            };
343                            let buffer = msg.encode_to_vec();
344                            self.ws.send(buffer.into()).await?;
345                        }
346                        Some(Err(error)) => {
347                            tracing::error!(%error, "could not generate response to a Signal request; responder was canceled. continuing.");
348                        }
349                        None => {
350                            unreachable!("outgoing responses should never fuse")
351                        }
352                    }
353                }
354            }
355        }
356        Ok(())
357    }
358}
359
360impl<C: WebSocketType> SignalWebSocket<C> {
361    fn inner_locked(&self) -> MutexGuard<'_, SignalWebSocketInner> {
362        self.inner.lock().unwrap()
363    }
364
365    pub fn new(
366        ws: WebSocket,
367        keep_alive_path: String,
368        unidentified_push_service: PushService,
369    ) -> (Self, impl Future<Output = ()>) {
370        // Create process
371        let (incoming_request_sink, incoming_request_stream) =
372            mpsc::unbounded();
373        let (outgoing_request_sink, outgoing_requests) = mpsc::channel(1);
374
375        let process = SignalWebSocketProcess {
376            keep_alive_path,
377            requests: outgoing_requests,
378            request_sink: incoming_request_sink,
379            outgoing_requests: HashMap::default(),
380            outgoing_keep_alive_set: HashSet::new(),
381            // Initializing the FuturesUnordered with a `pending` future means it will never fuse
382            // itself, so an "empty" FuturesUnordered will still allow new futures to be added.
383            outgoing_responses: vec![
384                Box::pin(futures::future::pending()) as BoxFuture<_>
385            ]
386            .into_iter()
387            .collect(),
388            ws,
389        };
390        let process = process.run().map(|x| match x {
391            Ok(()) => (),
392            Err(e) => {
393                tracing::error!("SignalWebSocket: {}", e);
394            },
395        });
396
397        (
398            Self {
399                _type: PhantomData,
400                request_sink: outgoing_request_sink,
401                unidentified_push_service,
402                inner: Arc::new(Mutex::new(SignalWebSocketInner {
403                    stream: Some(SignalRequestStream {
404                        inner: incoming_request_stream,
405                    }),
406                })),
407            },
408            process,
409        )
410    }
411
412    pub fn servers(&self) -> SignalServers {
413        self.unidentified_push_service.servers
414    }
415
416    pub fn is_closed(&self) -> bool {
417        self.request_sink.is_closed()
418    }
419
420    pub fn is_used(&self) -> bool {
421        self.inner_locked().stream.is_none()
422    }
423
424    pub(crate) fn take_request_stream(
425        &mut self,
426    ) -> Option<SignalRequestStream> {
427        self.inner_locked().stream.take()
428    }
429
430    pub(crate) fn return_request_stream(&mut self, r: SignalRequestStream) {
431        self.inner_locked().stream.replace(r);
432    }
433
434    // XXX Ideally, this should take an *async* closure, then we could get rid of the
435    // `take_request_stream` and `return_request_stream`.
436    pub async fn with_request_stream<
437        R: 'static,
438        F: FnOnce(&mut SignalRequestStream) -> R,
439    >(
440        &mut self,
441        f: F,
442    ) -> R {
443        let mut s = self
444            .inner_locked()
445            .stream
446            .take()
447            .expect("request stream invariant");
448        let r = f(&mut s);
449        self.inner_locked().stream.replace(s);
450        r
451    }
452
453    pub fn request(
454        &mut self,
455        r: WebSocketRequestMessage,
456    ) -> impl Future<Output = Result<WebSocketResponseMessage, ServiceError>>
457    {
458        let (sink, recv): (
459            oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
460            _,
461        ) = oneshot::channel();
462
463        let mut request_sink = self.request_sink.clone();
464        async move {
465            if let Err(_e) = request_sink.send((r, sink)).await {
466                return Err(ServiceError::WsClosing {
467                    reason: "WebSocket closing while sending request",
468                });
469            }
470            // Handle the oneshot sender error for dropped senders.
471            match recv.await {
472                Ok(x) => x,
473                Err(_) => Err(ServiceError::WsClosing {
474                    reason: "WebSocket closing while waiting for a response",
475                }),
476            }
477        }
478    }
479
480    pub(crate) async fn request_json<T>(
481        &mut self,
482        r: WebSocketRequestMessage,
483    ) -> Result<T, ServiceError>
484    where
485        for<'de> T: serde::Deserialize<'de>,
486    {
487        self.request(r)
488            .await?
489            .service_error_for_status()
490            .await?
491            .json()
492            .await
493    }
494}
495
496impl WebSocketResponseMessage {
497    pub async fn service_error_for_status(self) -> Result<Self, ServiceError> {
498        super::push_service::response::service_error_for_status(self).await
499    }
500
501    pub async fn json<T: for<'a> serde::Deserialize<'a>>(
502        &self,
503    ) -> Result<T, ServiceError> {
504        self.body
505            .as_ref()
506            .ok_or(ServiceError::UnsupportedContent)
507            .and_then(|b| serde_json::from_slice(b).map_err(Into::into))
508    }
509}