libsignal_service/websocket/mod.rs

1use std::collections::{HashMap, HashSet};
2use std::marker::PhantomData;
3use std::sync::{Arc, Mutex, MutexGuard};
4
5use std::future::Future;
6
7use bytes::Bytes;
8use futures::channel::oneshot::Canceled;
9use futures::channel::{mpsc, oneshot};
10use futures::future::BoxFuture;
11use futures::prelude::*;
12use futures::stream::FuturesUnordered;
13use reqwest::Method;
14use reqwest_websocket::WebSocket;
15use tokio::time::Instant;
16use tracing::debug;
17
18use crate::configuration::SignalServers;
19use crate::prelude::PushService;
20use crate::proto::{
21    web_socket_message, WebSocketMessage, WebSocketRequestMessage,
22    WebSocketResponseMessage,
23};
24use crate::push_service::{self, ServiceError, SignalServiceResponse};
25
26pub mod account;
27#[cfg(feature = "cdsi")]
28pub mod directory;
29pub mod keys;
30pub mod linking;
31pub mod profile;
32pub mod registration;
33mod request;
34mod sender;
35pub mod stickers;
36mod usernames;
37
38pub use request::WebSocketRequestMessageBuilder;
39
40type RequestStreamItem = (
41    WebSocketRequestMessage,
42    oneshot::Sender<WebSocketResponseMessage>,
43);
44
45pub struct SignalRequestStream {
46    inner: mpsc::UnboundedReceiver<RequestStreamItem>,
47}
48
49impl Stream for SignalRequestStream {
50    type Item = RequestStreamItem;
51
52    fn poll_next(
53        mut self: std::pin::Pin<&mut Self>,
54        cx: &mut std::task::Context<'_>,
55    ) -> std::task::Poll<Option<Self::Item>> {
56        let inner = &mut self.inner;
57        futures::pin_mut!(inner);
58        Stream::poll_next(inner, cx)
59    }
60}
61
/// Type-level tag selecting which flavour of [`SignalWebSocket`] a handle is.
///
/// This module implements it for the two zero-sized markers [`Identified`] and
/// [`Unidentified`]. Implementors must be `'static`.
pub trait WebSocketType: 'static {}

/// Marker for the identified flavour of [`SignalWebSocket`].
#[derive(Debug, Clone)]
pub struct Identified;

impl WebSocketType for Identified {}

/// Marker for the unidentified flavour of [`SignalWebSocket`].
#[derive(Debug, Clone)]
pub struct Unidentified;

impl WebSocketType for Unidentified {}
73
74/// A dispatching web socket client for the Signal web socket API.
75///
76/// This structure can be freely cloned, since this acts as a *facade* for multiple entry and exit
77/// points.
78#[derive(Clone)]
79pub struct SignalWebSocket<C: WebSocketType> {
80    _type: PhantomData<C>,
81    // XXX: at the end of the migration, this should be CDN operations only
82    pub(crate) unidentified_push_service: PushService,
83    inner: Arc<Mutex<SignalWebSocketInner>>,
84    request_sink: mpsc::Sender<(
85        WebSocketRequestMessage,
86        oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
87    )>,
88}
89
90struct SignalWebSocketInner {
91    stream: Option<SignalRequestStream>,
92}
93
/// State of the background task that drives one websocket connection.
///
/// Built by `SignalWebSocket::new` and consumed by `run`, which multiplexes application
/// requests, Signal-initiated requests and their responses, and keep-alives over `ws`.
struct SignalWebSocketProcess {
    /// Path of the keep-alive request sent on every tick of the keep-alive interval.
    keep_alive_path: String,

    /// Receives requests from the application, which we forward to Signal.
    requests: mpsc::Receiver<(
        WebSocketRequestMessage,
        oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
    )>,
    /// Signal's requests should go in here, to be delivered to the application.
    request_sink: mpsc::UnboundedSender<RequestStreamItem>,

    /// Responders of application requests already sent to Signal, keyed by request id; an
    /// entry is removed when the response carrying that id arrives.
    outgoing_requests: HashMap<
        u64,
        oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
    >,

    /// Ids of keep-alive requests still waiting for a response. If this is non-empty when
    /// the next keep-alive is due, the socket is closed.
    outgoing_keep_alive_set: HashSet<u64>,

    /// The application's pending answers to Signal-initiated requests. Seeded with a
    /// never-completing future so that it never runs empty.
    outgoing_responses: FuturesUnordered<
        BoxFuture<'static, Result<WebSocketResponseMessage, Canceled>>,
    >,

    /// The underlying websocket connection.
    ws: WebSocket,
}
119
120impl SignalWebSocketProcess {
121    async fn process_frame(
122        &mut self,
123        frame: Vec<u8>,
124    ) -> Result<(), ServiceError> {
125        use prost::Message;
126        let msg = WebSocketMessage::decode(Bytes::from(frame))?;
127        if let Some(request) = &msg.request {
128            tracing::trace!(
129                msg_type =? msg.r#type(),
130                request.id,
131                request.verb,
132                request.path,
133                request_body_size_bytes = request.body.as_ref().map(|x| x.len()).unwrap_or(0),
134                ?request.headers,
135                "decoded WebSocketMessage request"
136            );
137        } else if let Some(response) = &msg.response {
138            tracing::trace!(
139                msg_type =? msg.r#type(),
140                response.status,
141                response.message,
142                response_body_size_bytes = response.body.as_ref().map(|x| x.len()).unwrap_or(0),
143                ?response.headers,
144                response.id,
145                "decoded WebSocketMessage response"
146            );
147        } else {
148            tracing::debug!("decoded {msg:?}");
149        }
150
151        use web_socket_message::Type;
152        match (msg.r#type(), msg.request, msg.response) {
153            (Type::Unknown, _, _) => Err(ServiceError::InvalidFrame {
154                reason: "unknown frame type",
155            }),
156            (Type::Request, Some(request), _) => {
157                let (sink, recv) = oneshot::channel();
158                tracing::trace!("sending request with body");
159                self.request_sink.send((request, sink)).await.map_err(
160                    |_| ServiceError::WsClosing {
161                        reason: "request handler failed",
162                    },
163                )?;
164                self.outgoing_responses.push(Box::pin(recv));
165
166                Ok(())
167            },
168            (Type::Request, None, _) => Err(ServiceError::InvalidFrame {
169                reason: "type was request, but does not contain request",
170            }),
171            (Type::Response, _, Some(response)) => {
172                if let Some(id) = response.id {
173                    if let Some(responder) = self.outgoing_requests.remove(&id)
174                    {
175                        if let Err(e) = responder.send(Ok(response)) {
176                            tracing::warn!(
177                                "Could not deliver response for id {}: {:?}",
178                                id,
179                                e
180                            );
181                        }
182                    } else if let Some(_x) =
183                        self.outgoing_keep_alive_set.take(&id)
184                    {
185                        let status_code = response.status();
186                        if status_code != 200 {
187                            tracing::warn!(
188                                status_code,
189                                "response code for keep-alive != 200"
190                            );
191                            return Err(ServiceError::UnhandledResponseCode {
192                                http_code: response.status() as u16,
193                            });
194                        }
195                    } else {
196                        tracing::warn!(
197                            ?response,
198                            "response for non existing request"
199                        );
200                    }
201                }
202
203                Ok(())
204            },
205            (Type::Response, _, None) => Err(ServiceError::InvalidFrame {
206                reason: "type was response, but does not contain response",
207            }),
208        }
209    }
210
211    fn next_request_id(&self) -> u64 {
212        use rand::Rng;
213        let mut rng = rand::rng();
214        loop {
215            let id = rng.random();
216            if !self.outgoing_requests.contains_key(&id) {
217                return id;
218            }
219        }
220    }
221
222    async fn run(mut self) -> Result<(), ServiceError> {
223        let mut ka_interval = tokio::time::interval_at(
224            Instant::now(),
225            push_service::KEEPALIVE_TIMEOUT_SECONDS,
226        );
227
228        loop {
229            futures::select! {
230                _ = ka_interval.tick().fuse() => {
231                    use prost::Message;
232                    if !self.outgoing_keep_alive_set.is_empty() {
233                        tracing::warn!("Websocket will be closed due to failed keepalives.");
234                        if let Err(e) = self.ws.close(reqwest_websocket::CloseCode::Away, None).await {
235                            tracing::debug!("Could not close WebSocket: {:?}", e);
236                        }
237                        self.outgoing_keep_alive_set.clear();
238                        break;
239                    }
240                    tracing::debug!("sending keep-alive");
241                    let request = WebSocketRequestMessage::new(Method::GET)
242                        .id(self.next_request_id())
243                        .path(&self.keep_alive_path)
244                        .build();
245                    self.outgoing_keep_alive_set.insert(request.id.unwrap());
246                    let msg = WebSocketMessage {
247                        r#type: Some(web_socket_message::Type::Request.into()),
248                        request: Some(request),
249                        ..Default::default()
250                    };
251                    let buffer = msg.encode_to_vec();
252                    if let Err(e) = self.ws.send(reqwest_websocket::Message::Binary(buffer)).await {
253                        tracing::info!("Websocket sink has closed: {:?}.", e);
254                        break;
255                    };
256                },
257                // Process requests from the application, forward them to Signal
258                x = self.requests.next() => {
259                    match x {
260                        Some((mut request, responder)) => {
261                            use prost::Message;
262
263                            // Regenerate ID if already in the table
264                            request.id = Some(
265                                request
266                                    .id
267                                    .filter(|x| !self.outgoing_requests.contains_key(x))
268                                    .unwrap_or_else(|| self.next_request_id()),
269                            );
270                            tracing::trace!(
271                                request.id,
272                                request.verb,
273                                request.path,
274                                request_body_size_bytes = request.body.as_ref().map(|x| x.len()),
275                                ?request.headers,
276                                "sending WebSocketRequestMessage",
277                            );
278
279                            self.outgoing_requests.insert(request.id.unwrap(), responder);
280                            let msg = WebSocketMessage {
281                                r#type: Some(web_socket_message::Type::Request.into()),
282                                request: Some(request),
283                                ..Default::default()
284                            };
285                            let buffer = msg.encode_to_vec();
286                            self.ws.send(reqwest_websocket::Message::Binary(buffer)).await?
287                        }
288                        None => {
289                            debug!("end of application request stream; websocket closing");
290                            return Ok(());
291                        }
292                    }
293                }
294                // Incoming websocket message
295                web_socket_item = self.ws.next().fuse() => {
296                    use reqwest_websocket::Message;
297                    match web_socket_item {
298                        Some(Ok(Message::Close { code, reason })) => {
299                            tracing::warn!(%code, reason, "websocket closed");
300                            break;
301                        },
302                        Some(Ok(Message::Binary(frame))) => {
303                            self.process_frame(frame).await?;
304                        }
305                        Some(Ok(Message::Ping(_))) => {
306                            tracing::trace!("received ping");
307                        }
308                        Some(Ok(Message::Pong(_))) => {
309                            tracing::trace!("received pong");
310                        }
311                        Some(Ok(Message::Text(_))) => {
312                            tracing::trace!("received text (unsupported, skipping)");
313                        }
314                        Some(Err(e)) => return Err(e.into()),
315                        None => {
316                            return Err(ServiceError::WsClosing {
317                                reason: "end of web request stream; socket closing"
318                            });
319                        }
320                    }
321                }
322                response = self.outgoing_responses.next() => {
323                    use prost::Message;
324                    match response {
325                        Some(Ok(response)) => {
326                            tracing::trace!("sending response {:?}", response);
327
328                            let msg = WebSocketMessage {
329                                r#type: Some(web_socket_message::Type::Response.into()),
330                                response: Some(response),
331                                ..Default::default()
332                            };
333                            let buffer = msg.encode_to_vec();
334                            self.ws.send(buffer.into()).await?;
335                        }
336                        Some(Err(error)) => {
337                            tracing::error!(%error, "could not generate response to a Signal request; responder was canceled. continuing.");
338                        }
339                        None => {
340                            unreachable!("outgoing responses should never fuse")
341                        }
342                    }
343                }
344            }
345        }
346        Ok(())
347    }
348}
349
350impl<C: WebSocketType> SignalWebSocket<C> {
351    fn inner_locked(&self) -> MutexGuard<'_, SignalWebSocketInner> {
352        self.inner.lock().unwrap()
353    }
354
355    pub fn new(
356        ws: WebSocket,
357        keep_alive_path: String,
358        unidentified_push_service: PushService,
359    ) -> (Self, impl Future<Output = ()>) {
360        // Create process
361        let (incoming_request_sink, incoming_request_stream) =
362            mpsc::unbounded();
363        let (outgoing_request_sink, outgoing_requests) = mpsc::channel(1);
364
365        let process = SignalWebSocketProcess {
366            keep_alive_path,
367            requests: outgoing_requests,
368            request_sink: incoming_request_sink,
369            outgoing_requests: HashMap::default(),
370            outgoing_keep_alive_set: HashSet::new(),
371            // Initializing the FuturesUnordered with a `pending` future means it will never fuse
372            // itself, so an "empty" FuturesUnordered will still allow new futures to be added.
373            outgoing_responses: vec![
374                Box::pin(futures::future::pending()) as BoxFuture<_>
375            ]
376            .into_iter()
377            .collect(),
378            ws,
379        };
380        let process = process.run().map(|x| match x {
381            Ok(()) => (),
382            Err(e) => {
383                tracing::error!("SignalWebSocket: {}", e);
384            },
385        });
386
387        (
388            Self {
389                _type: PhantomData,
390                request_sink: outgoing_request_sink,
391                unidentified_push_service,
392                inner: Arc::new(Mutex::new(SignalWebSocketInner {
393                    stream: Some(SignalRequestStream {
394                        inner: incoming_request_stream,
395                    }),
396                })),
397            },
398            process,
399        )
400    }
401
402    pub fn servers(&self) -> SignalServers {
403        self.unidentified_push_service.servers
404    }
405
406    pub fn is_closed(&self) -> bool {
407        self.request_sink.is_closed()
408    }
409
410    pub fn is_used(&self) -> bool {
411        self.inner_locked().stream.is_none()
412    }
413
414    pub(crate) fn take_request_stream(
415        &mut self,
416    ) -> Option<SignalRequestStream> {
417        self.inner_locked().stream.take()
418    }
419
420    pub(crate) fn return_request_stream(&mut self, r: SignalRequestStream) {
421        self.inner_locked().stream.replace(r);
422    }
423
424    // XXX Ideally, this should take an *async* closure, then we could get rid of the
425    // `take_request_stream` and `return_request_stream`.
426    pub async fn with_request_stream<
427        R: 'static,
428        F: FnOnce(&mut SignalRequestStream) -> R,
429    >(
430        &mut self,
431        f: F,
432    ) -> R {
433        let mut s = self
434            .inner_locked()
435            .stream
436            .take()
437            .expect("request stream invariant");
438        let r = f(&mut s);
439        self.inner_locked().stream.replace(s);
440        r
441    }
442
443    pub fn request(
444        &mut self,
445        r: WebSocketRequestMessage,
446    ) -> impl Future<Output = Result<WebSocketResponseMessage, ServiceError>>
447    {
448        let (sink, recv): (
449            oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
450            _,
451        ) = oneshot::channel();
452
453        let mut request_sink = self.request_sink.clone();
454        async move {
455            if let Err(_e) = request_sink.send((r, sink)).await {
456                return Err(ServiceError::WsClosing {
457                    reason: "WebSocket closing while sending request",
458                });
459            }
460            // Handle the oneshot sender error for dropped senders.
461            match recv.await {
462                Ok(x) => x,
463                Err(_) => Err(ServiceError::WsClosing {
464                    reason: "WebSocket closing while waiting for a response",
465                }),
466            }
467        }
468    }
469
470    pub(crate) async fn request_json<T>(
471        &mut self,
472        r: WebSocketRequestMessage,
473    ) -> Result<T, ServiceError>
474    where
475        for<'de> T: serde::Deserialize<'de>,
476    {
477        self.request(r)
478            .await?
479            .service_error_for_status()
480            .await?
481            .json()
482            .await
483    }
484}
485
486impl WebSocketResponseMessage {
487    pub async fn service_error_for_status(self) -> Result<Self, ServiceError> {
488        super::push_service::response::service_error_for_status(self).await
489    }
490
491    pub async fn json<T: for<'a> serde::Deserialize<'a>>(
492        &self,
493    ) -> Result<T, ServiceError> {
494        self.body
495            .as_ref()
496            .ok_or(ServiceError::UnsupportedContent)
497            .and_then(|b| serde_json::from_slice(b).map_err(Into::into))
498    }
499}