// libsignal_service/websocket/mod.rs
mod.rs1use std::collections::{HashMap, HashSet};
2use std::sync::{Arc, Mutex, MutexGuard};
3
4use std::future::Future;
5
6use bytes::Bytes;
7use futures::channel::oneshot::Canceled;
8use futures::channel::{mpsc, oneshot};
9use futures::future::BoxFuture;
10use futures::prelude::*;
11use futures::stream::FuturesUnordered;
12use reqwest::Method;
13use reqwest_websocket::WebSocket;
14use tokio::time::Instant;
15
16use crate::proto::{
17 web_socket_message, WebSocketMessage, WebSocketRequestMessage,
18 WebSocketResponseMessage,
19};
20use crate::push_service::{self, ServiceError, SignalServiceResponse};
21
22mod request;
23mod sender;
24
25pub use request::WebSocketRequestMessageBuilder;
26
27type RequestStreamItem = (
28 WebSocketRequestMessage,
29 oneshot::Sender<WebSocketResponseMessage>,
30);
31
32pub struct SignalRequestStream {
33 inner: mpsc::Receiver<RequestStreamItem>,
34}
35
36impl Stream for SignalRequestStream {
37 type Item = RequestStreamItem;
38
39 fn poll_next(
40 mut self: std::pin::Pin<&mut Self>,
41 cx: &mut std::task::Context<'_>,
42 ) -> std::task::Poll<Option<Self::Item>> {
43 let inner = &mut self.inner;
44 futures::pin_mut!(inner);
45 Stream::poll_next(inner, cx)
46 }
47}
48
49#[derive(Clone)]
54pub struct SignalWebSocket {
55 inner: Arc<Mutex<SignalWebSocketInner>>,
56 request_sink: mpsc::Sender<(
57 WebSocketRequestMessage,
58 oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
59 )>,
60}
61
62struct SignalWebSocketInner {
63 stream: Option<SignalRequestStream>,
64}
65
66struct SignalWebSocketProcess {
67 keep_alive_path: String,
69
70 requests: mpsc::Receiver<(
72 WebSocketRequestMessage,
73 oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
74 )>,
75 request_sink: mpsc::Sender<RequestStreamItem>,
77
78 outgoing_requests: HashMap<
79 u64,
80 oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
81 >,
82
83 outgoing_keep_alive_set: HashSet<u64>,
84
85 outgoing_responses: FuturesUnordered<
86 BoxFuture<'static, Result<WebSocketResponseMessage, Canceled>>,
87 >,
88
89 ws: WebSocket,
90}
91
92impl SignalWebSocketProcess {
93 async fn process_frame(
94 &mut self,
95 frame: Vec<u8>,
96 ) -> Result<(), ServiceError> {
97 use prost::Message;
98 let msg = WebSocketMessage::decode(Bytes::from(frame))?;
99 if let Some(request) = &msg.request {
100 tracing::trace!(
101 msg_type =? msg.r#type(),
102 request.id,
103 request.verb,
104 request.path,
105 request_body_size_bytes = request.body.as_ref().map(|x| x.len()).unwrap_or(0),
106 ?request.headers,
107 "decoded WebSocketMessage request"
108 );
109 } else if let Some(response) = &msg.response {
110 tracing::trace!(
111 msg_type =? msg.r#type(),
112 response.status,
113 response.message,
114 response_body_size_bytes = response.body.as_ref().map(|x| x.len()).unwrap_or(0),
115 ?response.headers,
116 response.id,
117 "decoded WebSocketMessage response"
118 );
119 } else {
120 tracing::debug!("decoded {msg:?}");
121 }
122
123 use web_socket_message::Type;
124 match (msg.r#type(), msg.request, msg.response) {
125 (Type::Unknown, _, _) => Err(ServiceError::InvalidFrame {
126 reason: "unknown frame type",
127 }),
128 (Type::Request, Some(request), _) => {
129 let (sink, recv) = oneshot::channel();
130 tracing::trace!("sending request with body");
131 self.request_sink.send((request, sink)).await.map_err(
132 |_| ServiceError::WsClosing {
133 reason: "request handler failed",
134 },
135 )?;
136 self.outgoing_responses.push(Box::pin(recv));
137
138 Ok(())
139 },
140 (Type::Request, None, _) => Err(ServiceError::InvalidFrame {
141 reason: "type was request, but does not contain request",
142 }),
143 (Type::Response, _, Some(response)) => {
144 if let Some(id) = response.id {
145 if let Some(responder) = self.outgoing_requests.remove(&id)
146 {
147 if let Err(e) = responder.send(Ok(response)) {
148 tracing::warn!(
149 "Could not deliver response for id {}: {:?}",
150 id,
151 e
152 );
153 }
154 } else if let Some(_x) =
155 self.outgoing_keep_alive_set.take(&id)
156 {
157 let status_code = response.status();
158 if status_code != 200 {
159 tracing::warn!(
160 status_code,
161 "response code for keep-alive != 200"
162 );
163 return Err(ServiceError::UnhandledResponseCode {
164 http_code: response.status() as u16,
165 });
166 }
167 } else {
168 tracing::warn!(
169 ?response,
170 "response for non existing request"
171 );
172 }
173 }
174
175 Ok(())
176 },
177 (Type::Response, _, None) => Err(ServiceError::InvalidFrame {
178 reason: "type was response, but does not contain response",
179 }),
180 }
181 }
182
183 fn next_request_id(&self) -> u64 {
184 use rand::Rng;
185 let mut rng = rand::thread_rng();
186 loop {
187 let id = rng.gen();
188 if !self.outgoing_requests.contains_key(&id) {
189 return id;
190 }
191 }
192 }
193
194 async fn run(mut self) -> Result<(), ServiceError> {
195 let mut ka_interval = tokio::time::interval_at(
196 Instant::now(),
197 push_service::KEEPALIVE_TIMEOUT_SECONDS,
198 );
199
200 loop {
201 futures::select! {
202 _ = ka_interval.tick().fuse() => {
203 use prost::Message;
204 if !self.outgoing_keep_alive_set.is_empty() {
205 tracing::warn!("Websocket will be closed due to failed keepalives.");
206 if let Err(e) = self.ws.close(reqwest_websocket::CloseCode::Away, None).await {
207 tracing::debug!("Could not close WebSocket: {:?}", e);
208 }
209 self.outgoing_keep_alive_set.clear();
210 break;
211 }
212 tracing::debug!("sending keep-alive");
213 let request = WebSocketRequestMessage::new(Method::GET)
214 .id(self.next_request_id())
215 .path(&self.keep_alive_path)
216 .build();
217 self.outgoing_keep_alive_set.insert(request.id.unwrap());
218 let msg = WebSocketMessage {
219 r#type: Some(web_socket_message::Type::Request.into()),
220 request: Some(request),
221 ..Default::default()
222 };
223 let buffer = msg.encode_to_vec();
224 if let Err(e) = self.ws.send(reqwest_websocket::Message::Binary(buffer)).await {
225 tracing::info!("Websocket sink has closed: {:?}.", e);
226 break;
227 };
228 },
229 x = self.requests.next() => {
231 match x {
232 Some((mut request, responder)) => {
233 use prost::Message;
234
235 request.id = Some(
237 request
238 .id
239 .filter(|x| !self.outgoing_requests.contains_key(x))
240 .unwrap_or_else(|| self.next_request_id()),
241 );
242 tracing::trace!(
243 request.id,
244 request.verb,
245 request.path,
246 request_body_size_bytes = request.body.as_ref().map(|x| x.len()),
247 ?request.headers,
248 "sending WebSocketRequestMessage",
249 );
250
251 self.outgoing_requests.insert(request.id.unwrap(), responder);
252 let msg = WebSocketMessage {
253 r#type: Some(web_socket_message::Type::Request.into()),
254 request: Some(request),
255 ..Default::default()
256 };
257 let buffer = msg.encode_to_vec();
258 self.ws.send(reqwest_websocket::Message::Binary(buffer)).await?
259 }
260 None => {
261 return Err(ServiceError::WsClosing {
262 reason: "end of application request stream; socket closing"
263 });
264 }
265 }
266 }
267 web_socket_item = self.ws.next().fuse() => {
269 use reqwest_websocket::Message;
270 match web_socket_item {
271 Some(Ok(Message::Close { code, reason })) => {
272 tracing::warn!(%code, reason, "websocket closed");
273 break;
274 },
275 Some(Ok(Message::Binary(frame))) => {
276 self.process_frame(frame).await?;
277 }
278 Some(Ok(Message::Ping(_))) => {
279 tracing::trace!("received ping");
280 }
281 Some(Ok(Message::Pong(_))) => {
282 tracing::trace!("received pong");
283 }
284 Some(Ok(Message::Text(_))) => {
285 tracing::trace!("received text (unsupported, skipping)");
286 }
287 Some(Err(e)) => return Err(ServiceError::WsError(e)),
288 None => {
289 return Err(ServiceError::WsClosing {
290 reason: "end of web request stream; socket closing"
291 });
292 }
293 }
294 }
295 response = self.outgoing_responses.next() => {
296 use prost::Message;
297 match response {
298 Some(Ok(response)) => {
299 tracing::trace!("sending response {:?}", response);
300
301 let msg = WebSocketMessage {
302 r#type: Some(web_socket_message::Type::Response.into()),
303 response: Some(response),
304 ..Default::default()
305 };
306 let buffer = msg.encode_to_vec();
307 self.ws.send(buffer.into()).await?;
308 }
309 Some(Err(error)) => {
310 tracing::error!(%error, "could not generate response to a Signal request; responder was canceled. continuing.");
311 }
312 None => {
313 unreachable!("outgoing responses should never fuse")
314 }
315 }
316 }
317 }
318 }
319 Ok(())
320 }
321}
322
323impl SignalWebSocket {
324 fn inner_locked(&self) -> MutexGuard<'_, SignalWebSocketInner> {
325 self.inner.lock().unwrap()
326 }
327
328 pub fn from_socket(
329 ws: WebSocket,
330 keep_alive_path: String,
331 ) -> (Self, impl Future<Output = ()>) {
332 let (incoming_request_sink, incoming_request_stream) = mpsc::channel(1);
334 let (outgoing_request_sink, outgoing_requests) = mpsc::channel(1);
335
336 let process = SignalWebSocketProcess {
337 keep_alive_path,
338 requests: outgoing_requests,
339 request_sink: incoming_request_sink,
340 outgoing_requests: HashMap::default(),
341 outgoing_keep_alive_set: HashSet::new(),
342 outgoing_responses: vec![
345 Box::pin(futures::future::pending()) as BoxFuture<_>
346 ]
347 .into_iter()
348 .collect(),
349 ws,
350 };
351 let process = process.run().map(|x| match x {
352 Ok(()) => (),
353 Err(e) => {
354 tracing::error!("SignalWebSocket: {}", e);
355 },
356 });
357
358 (
359 Self {
360 request_sink: outgoing_request_sink,
361 inner: Arc::new(Mutex::new(SignalWebSocketInner {
362 stream: Some(SignalRequestStream {
363 inner: incoming_request_stream,
364 }),
365 })),
366 },
367 process,
368 )
369 }
370
371 pub fn is_closed(&self) -> bool {
372 self.request_sink.is_closed()
373 }
374
375 pub fn is_used(&self) -> bool {
376 self.inner_locked().stream.is_none()
377 }
378
379 pub(crate) fn take_request_stream(
380 &mut self,
381 ) -> Option<SignalRequestStream> {
382 self.inner_locked().stream.take()
383 }
384
385 pub(crate) fn return_request_stream(&mut self, r: SignalRequestStream) {
386 self.inner_locked().stream.replace(r);
387 }
388
389 pub async fn with_request_stream<
392 R: 'static,
393 F: FnOnce(&mut SignalRequestStream) -> R,
394 >(
395 &mut self,
396 f: F,
397 ) -> R {
398 let mut s = self
399 .inner_locked()
400 .stream
401 .take()
402 .expect("request stream invariant");
403 let r = f(&mut s);
404 self.inner_locked().stream.replace(s);
405 r
406 }
407
408 pub fn request(
409 &mut self,
410 r: WebSocketRequestMessage,
411 ) -> impl Future<Output = Result<WebSocketResponseMessage, ServiceError>>
412 {
413 let (sink, recv): (
414 oneshot::Sender<Result<WebSocketResponseMessage, ServiceError>>,
415 _,
416 ) = oneshot::channel();
417
418 let mut request_sink = self.request_sink.clone();
419 async move {
420 if let Err(_e) = request_sink.send((r, sink)).await {
421 return Err(ServiceError::WsClosing {
422 reason: "WebSocket closing while sending request",
423 });
424 }
425 match recv.await {
427 Ok(x) => x,
428 Err(_) => Err(ServiceError::WsClosing {
429 reason: "WebSocket closing while waiting for a response",
430 }),
431 }
432 }
433 }
434
435 pub(crate) async fn request_json<T>(
436 &mut self,
437 r: WebSocketRequestMessage,
438 ) -> Result<T, ServiceError>
439 where
440 for<'de> T: serde::Deserialize<'de>,
441 {
442 self.request(r)
443 .await?
444 .service_error_for_status()
445 .await?
446 .json()
447 .await
448 }
449}