// libsignal_service/websocket/mod.rs

use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};

use bytes::Bytes;
use futures::channel::oneshot::Canceled;
use futures::channel::{mpsc, oneshot};
use futures::future::BoxFuture;
use futures::prelude::*;
use futures::stream::FuturesUnordered;
use reqwest::Method;
use reqwest_websocket::WebSocket;
use tokio::time::Instant;

use crate::proto::{
    web_socket_message, WebSocketMessage, WebSocketRequestMessage,
    WebSocketResponseMessage,
};
use crate::push_service::{self, ServiceError, SignalServiceResponse};

mod request;
mod sender;

pub use request::WebSocketRequestMessageBuilder;
26
27type RequestStreamItem = (
28    WebSocketRequestMessage,
29    oneshot::Sender<WebSocketResponseMessage>,
30);
31
/// Stream of requests initiated by the Signal server.
///
/// Each item pairs a decoded request with a oneshot sender; whatever response
/// is sent through it is forwarded back over the websocket by the future
/// returned from `SignalWebSocket::from_socket`.
pub struct SignalRequestStream {
    /// Receiving end of the channel fed by the websocket processing future.
    inner: mpsc::Receiver<RequestStreamItem>,
}
35
36impl Stream for SignalRequestStream {
37    type Item = RequestStreamItem;
38
39    fn poll_next(
40        mut self: std::pin::Pin<&mut Self>,
41        cx: &mut std::task::Context<'_>,
42    ) -> std::task::Poll<Option<Self::Item>> {
43        let inner = &mut self.inner;
44        futures::pin_mut!(inner);
45        Stream::poll_next(inner, cx)
46    }
47}
48
/// A dispatching web socket client for the Signal web socket API.
///
/// Cloning is cheap and every clone talks to the same connection: the request
/// sink is a cloned channel sender, and the stream of server-initiated requests
/// lives behind a shared `Arc<Mutex<_>>`. All socket I/O is performed by the
/// future returned alongside the handle from `from_socket`.
#[derive(Clone)]
pub struct SignalWebSocket {
    /// Shared slot holding the server-request stream (`None` while taken).
    inner: Arc<Mutex<SignalWebSocketInner>>,
    /// Hands application requests, plus a oneshot for the eventual response,
    /// to the processing future.
    request_sink: mpsc::Sender<(
        WebSocketRequestMessage,
        oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
    )>,
}
61
/// State shared between all clones of a [`SignalWebSocket`].
struct SignalWebSocketInner {
    /// Stream of requests initiated by Signal; `None` while a caller has taken
    /// it (via `take_request_stream` or during `with_request_stream`).
    stream: Option<SignalRequestStream>,
}
65
/// State owned by the future that drives a [`SignalWebSocket`] connection.
///
/// Built by `SignalWebSocket::from_socket` and consumed by its `run` method.
struct SignalWebSocketProcess {
    /// Path requested by the periodic keep-alive `GET`.
    ///
    /// Keep-alives are always sent; this is not an on/off switch.
    keep_alive_path: String,

    /// Receives requests from the application, which we forward to Signal.
    requests: mpsc::Receiver<(
        WebSocketRequestMessage,
        oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
    )>,
    /// Signal's requests should go in here, to be delivered to the application.
    request_sink: mpsc::Sender<RequestStreamItem>,

    /// Application requests already sent to Signal, keyed by request id, with
    /// the channel on which the matching response is delivered.
    outgoing_requests: HashMap<
        u64,
        oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
    >,

    /// Ids of keep-alive requests that have not been answered yet. If this is
    /// still non-empty at the next keep-alive tick, the socket is closed.
    outgoing_keep_alive_set: HashSet<u64>,

    /// Pending application responses to requests made by Signal. Each future
    /// resolves (or is canceled) when the application answers; the answer is
    /// then written back to the socket.
    outgoing_responses: FuturesUnordered<
        BoxFuture<'static, Result<WebSocketResponseMessage, Canceled>>,
    >,

    /// The underlying websocket connection.
    ws: WebSocket,
}
91
92impl SignalWebSocketProcess {
93    async fn process_frame(
94        &mut self,
95        frame: Vec<u8>,
96    ) -> Result<(), ServiceError> {
97        use prost::Message;
98        let msg = WebSocketMessage::decode(Bytes::from(frame))?;
99        if let Some(request) = &msg.request {
100            tracing::trace!(
101                msg_type =? msg.r#type(),
102                request.id,
103                request.verb,
104                request.path,
105                request_body_size_bytes = request.body.as_ref().map(|x| x.len()).unwrap_or(0),
106                ?request.headers,
107                "decoded WebSocketMessage request"
108            );
109        } else if let Some(response) = &msg.response {
110            tracing::trace!(
111                msg_type =? msg.r#type(),
112                response.status,
113                response.message,
114                response_body_size_bytes = response.body.as_ref().map(|x| x.len()).unwrap_or(0),
115                ?response.headers,
116                response.id,
117                "decoded WebSocketMessage response"
118            );
119        } else {
120            tracing::debug!("decoded {msg:?}");
121        }
122
123        use web_socket_message::Type;
124        match (msg.r#type(), msg.request, msg.response) {
125            (Type::Unknown, _, _) => Err(ServiceError::InvalidFrame {
126                reason: "unknown frame type",
127            }),
128            (Type::Request, Some(request), _) => {
129                let (sink, recv) = oneshot::channel();
130                tracing::trace!("sending request with body");
131                self.request_sink.send((request, sink)).await.map_err(
132                    |_| ServiceError::WsClosing {
133                        reason: "request handler failed",
134                    },
135                )?;
136                self.outgoing_responses.push(Box::pin(recv));
137
138                Ok(())
139            },
140            (Type::Request, None, _) => Err(ServiceError::InvalidFrame {
141                reason: "type was request, but does not contain request",
142            }),
143            (Type::Response, _, Some(response)) => {
144                if let Some(id) = response.id {
145                    if let Some(responder) = self.outgoing_requests.remove(&id)
146                    {
147                        if let Err(e) = responder.send(Ok(response)) {
148                            tracing::warn!(
149                                "Could not deliver response for id {}: {:?}",
150                                id,
151                                e
152                            );
153                        }
154                    } else if let Some(_x) =
155                        self.outgoing_keep_alive_set.take(&id)
156                    {
157                        let status_code = response.status();
158                        if status_code != 200 {
159                            tracing::warn!(
160                                status_code,
161                                "response code for keep-alive != 200"
162                            );
163                            return Err(ServiceError::UnhandledResponseCode {
164                                http_code: response.status() as u16,
165                            });
166                        }
167                    } else {
168                        tracing::warn!(
169                            ?response,
170                            "response for non existing request"
171                        );
172                    }
173                }
174
175                Ok(())
176            },
177            (Type::Response, _, None) => Err(ServiceError::InvalidFrame {
178                reason: "type was response, but does not contain response",
179            }),
180        }
181    }
182
183    fn next_request_id(&self) -> u64 {
184        use rand::Rng;
185        let mut rng = rand::thread_rng();
186        loop {
187            let id = rng.gen();
188            if !self.outgoing_requests.contains_key(&id) {
189                return id;
190            }
191        }
192    }
193
194    async fn run(mut self) -> Result<(), ServiceError> {
195        let mut ka_interval = tokio::time::interval_at(
196            Instant::now(),
197            push_service::KEEPALIVE_TIMEOUT_SECONDS,
198        );
199
200        loop {
201            futures::select! {
202                _ = ka_interval.tick().fuse() => {
203                    use prost::Message;
204                    if !self.outgoing_keep_alive_set.is_empty() {
205                        tracing::warn!("Websocket will be closed due to failed keepalives.");
206                        if let Err(e) = self.ws.close(reqwest_websocket::CloseCode::Away, None).await {
207                            tracing::debug!("Could not close WebSocket: {:?}", e);
208                        }
209                        self.outgoing_keep_alive_set.clear();
210                        break;
211                    }
212                    tracing::debug!("sending keep-alive");
213                    let request = WebSocketRequestMessage::new(Method::GET)
214                        .id(self.next_request_id())
215                        .path(&self.keep_alive_path)
216                        .build();
217                    self.outgoing_keep_alive_set.insert(request.id.unwrap());
218                    let msg = WebSocketMessage {
219                        r#type: Some(web_socket_message::Type::Request.into()),
220                        request: Some(request),
221                        ..Default::default()
222                    };
223                    let buffer = msg.encode_to_vec();
224                    if let Err(e) = self.ws.send(reqwest_websocket::Message::Binary(buffer)).await {
225                        tracing::info!("Websocket sink has closed: {:?}.", e);
226                        break;
227                    };
228                },
229                // Process requests from the application, forward them to Signal
230                x = self.requests.next() => {
231                    match x {
232                        Some((mut request, responder)) => {
233                            use prost::Message;
234
235                            // Regenerate ID if already in the table
236                            request.id = Some(
237                                request
238                                    .id
239                                    .filter(|x| !self.outgoing_requests.contains_key(x))
240                                    .unwrap_or_else(|| self.next_request_id()),
241                            );
242                            tracing::trace!(
243                                request.id,
244                                request.verb,
245                                request.path,
246                                request_body_size_bytes = request.body.as_ref().map(|x| x.len()),
247                                ?request.headers,
248                                "sending WebSocketRequestMessage",
249                            );
250
251                            self.outgoing_requests.insert(request.id.unwrap(), responder);
252                            let msg = WebSocketMessage {
253                                r#type: Some(web_socket_message::Type::Request.into()),
254                                request: Some(request),
255                                ..Default::default()
256                            };
257                            let buffer = msg.encode_to_vec();
258                            self.ws.send(reqwest_websocket::Message::Binary(buffer)).await?
259                        }
260                        None => {
261                            return Err(ServiceError::WsClosing {
262                                reason: "end of application request stream; socket closing"
263                            });
264                        }
265                    }
266                }
267                // Incoming websocket message
268                web_socket_item = self.ws.next().fuse() => {
269                    use reqwest_websocket::Message;
270                    match web_socket_item {
271                        Some(Ok(Message::Close { code, reason })) => {
272                            tracing::warn!(%code, reason, "websocket closed");
273                            break;
274                        },
275                        Some(Ok(Message::Binary(frame))) => {
276                            self.process_frame(frame).await?;
277                        }
278                        Some(Ok(Message::Ping(_))) => {
279                            tracing::trace!("received ping");
280                        }
281                        Some(Ok(Message::Pong(_))) => {
282                            tracing::trace!("received pong");
283                        }
284                        Some(Ok(Message::Text(_))) => {
285                            tracing::trace!("received text (unsupported, skipping)");
286                        }
287                        Some(Err(e)) => return Err(ServiceError::WsError(e)),
288                        None => {
289                            return Err(ServiceError::WsClosing {
290                                reason: "end of web request stream; socket closing"
291                            });
292                        }
293                    }
294                }
295                response = self.outgoing_responses.next() => {
296                    use prost::Message;
297                    match response {
298                        Some(Ok(response)) => {
299                            tracing::trace!("sending response {:?}", response);
300
301                            let msg = WebSocketMessage {
302                                r#type: Some(web_socket_message::Type::Response.into()),
303                                response: Some(response),
304                                ..Default::default()
305                            };
306                            let buffer = msg.encode_to_vec();
307                            self.ws.send(buffer.into()).await?;
308                        }
309                        Some(Err(error)) => {
310                            tracing::error!(%error, "could not generate response to a Signal request; responder was canceled. continuing.");
311                        }
312                        None => {
313                            unreachable!("outgoing responses should never fuse")
314                        }
315                    }
316                }
317            }
318        }
319        Ok(())
320    }
321}
322
323impl SignalWebSocket {
324    fn inner_locked(&self) -> MutexGuard<'_, SignalWebSocketInner> {
325        self.inner.lock().unwrap()
326    }
327
328    pub fn from_socket(
329        ws: WebSocket,
330        keep_alive_path: String,
331    ) -> (Self, impl Future<Output = ()>) {
332        // Create process
333        let (incoming_request_sink, incoming_request_stream) = mpsc::channel(1);
334        let (outgoing_request_sink, outgoing_requests) = mpsc::channel(1);
335
336        let process = SignalWebSocketProcess {
337            keep_alive_path,
338            requests: outgoing_requests,
339            request_sink: incoming_request_sink,
340            outgoing_requests: HashMap::default(),
341            outgoing_keep_alive_set: HashSet::new(),
342            // Initializing the FuturesUnordered with a `pending` future means it will never fuse
343            // itself, so an "empty" FuturesUnordered will still allow new futures to be added.
344            outgoing_responses: vec![
345                Box::pin(futures::future::pending()) as BoxFuture<_>
346            ]
347            .into_iter()
348            .collect(),
349            ws,
350        };
351        let process = process.run().map(|x| match x {
352            Ok(()) => (),
353            Err(e) => {
354                tracing::error!("SignalWebSocket: {}", e);
355            },
356        });
357
358        (
359            Self {
360                request_sink: outgoing_request_sink,
361                inner: Arc::new(Mutex::new(SignalWebSocketInner {
362                    stream: Some(SignalRequestStream {
363                        inner: incoming_request_stream,
364                    }),
365                })),
366            },
367            process,
368        )
369    }
370
371    pub fn is_closed(&self) -> bool {
372        self.request_sink.is_closed()
373    }
374
375    pub fn is_used(&self) -> bool {
376        self.inner_locked().stream.is_none()
377    }
378
379    pub(crate) fn take_request_stream(
380        &mut self,
381    ) -> Option<SignalRequestStream> {
382        self.inner_locked().stream.take()
383    }
384
385    pub(crate) fn return_request_stream(&mut self, r: SignalRequestStream) {
386        self.inner_locked().stream.replace(r);
387    }
388
389    // XXX Ideally, this should take an *async* closure, then we could get rid of the
390    // `take_request_stream` and `return_request_stream`.
391    pub async fn with_request_stream<
392        R: 'static,
393        F: FnOnce(&mut SignalRequestStream) -> R,
394    >(
395        &mut self,
396        f: F,
397    ) -> R {
398        let mut s = self
399            .inner_locked()
400            .stream
401            .take()
402            .expect("request stream invariant");
403        let r = f(&mut s);
404        self.inner_locked().stream.replace(s);
405        r
406    }
407
408    pub fn request(
409        &mut self,
410        r: WebSocketRequestMessage,
411    ) -> impl Future<Output = Result<WebSocketResponseMessage, ServiceError>>
412    {
413        let (sink, recv): (
414            oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
415            _,
416        ) = oneshot::channel();
417
418        let mut request_sink = self.request_sink.clone();
419        async move {
420            if let Err(_e) = request_sink.send((r, sink)).await {
421                return Err(ServiceError::WsClosing {
422                    reason: "WebSocket closing while sending request",
423                });
424            }
425            // Handle the oneshot sender error for dropped senders.
426            match recv.await {
427                Ok(x) => x,
428                Err(_) => Err(ServiceError::WsClosing {
429                    reason: "WebSocket closing while waiting for a response",
430                }),
431            }
432        }
433    }
434
435    pub(crate) async fn request_json<T>(
436        &mut self,
437        r: WebSocketRequestMessage,
438    ) -> Result<T, ServiceError>
439    where
440        for<'de> T: serde::Deserialize<'de>,
441    {
442        self.request(r)
443            .await?
444            .service_error_for_status()
445            .await?
446            .json()
447            .await
448    }
449}