libsignal_service/push_service/
response.rs

1use reqwest::StatusCode;
2
3use crate::proto::WebSocketResponseMessage;
4
5use super::ServiceError;
6
7async fn service_error_for_status<R>(response: R) -> Result<R, ServiceError>
8where
9    R: SignalServiceResponse,
10    ServiceError: From<<R as SignalServiceResponse>::Error>,
11{
12    match response.status_code() {
13        StatusCode::OK
14        | StatusCode::CREATED
15        | StatusCode::ACCEPTED
16        | StatusCode::NO_CONTENT => Ok(response),
17        StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
18            Err(ServiceError::Unauthorized)
19        },
20        StatusCode::NOT_FOUND => {
21            // This is 404 and means that e.g. recipient is not registered
22            Err(ServiceError::NotFoundError)
23        },
24        StatusCode::PAYLOAD_TOO_LARGE => {
25            // This is 413 and means rate limit exceeded for Signal.
26            Err(ServiceError::RateLimitExceeded)
27        },
28        StatusCode::CONFLICT => {
29            let mismatched_devices =
30                response.json().await.map_err(|error| {
31                    tracing::error!(
32                        %error,
33                        "failed to decode HTTP 409 status"
34                    );
35                    ServiceError::UnhandledResponseCode {
36                        http_code: StatusCode::CONFLICT.as_u16(),
37                    }
38                })?;
39            Err(ServiceError::MismatchedDevicesException(mismatched_devices))
40        },
41        StatusCode::GONE => {
42            let stale_devices = response.json().await.map_err(|error| {
43                tracing::error!(%error, "failed to decode HTTP 410 status");
44                ServiceError::UnhandledResponseCode {
45                    http_code: StatusCode::GONE.as_u16(),
46                }
47            })?;
48            Err(ServiceError::StaleDevices(stale_devices))
49        },
50        StatusCode::LOCKED => {
51            let locked = response.json().await.map_err(|error| {
52                tracing::error!(%error, "failed to decode HTTP 423 status");
53                ServiceError::UnhandledResponseCode {
54                    http_code: StatusCode::LOCKED.as_u16(),
55                }
56            })?;
57            Err(ServiceError::Locked(locked))
58        },
59        StatusCode::PRECONDITION_REQUIRED => {
60            let proof_required = response.json().await.map_err(|error| {
61                tracing::error!(
62                    %error,
63                    "failed to decode HTTP 428 status"
64                );
65                ServiceError::UnhandledResponseCode {
66                    http_code: StatusCode::PRECONDITION_REQUIRED.as_u16(),
67                }
68            })?;
69            Err(ServiceError::ProofRequiredError(proof_required))
70        },
71        // XXX: fill in rest from PushServiceSocket
72        code => {
73            let response_text = response.text().await?;
74            tracing::trace!(status_code =% code, body = response_text, "unhandled HTTP response");
75            Err(ServiceError::UnhandledResponseCode {
76                http_code: code.as_u16(),
77            })
78        },
79    }
80}
81
82#[async_trait::async_trait]
83pub(crate) trait SignalServiceResponse {
84    type Error: std::error::Error;
85
86    fn status_code(&self) -> StatusCode;
87
88    async fn json<U>(self) -> Result<U, Self::Error>
89    where
90        for<'de> U: serde::Deserialize<'de>;
91
92    async fn text(self) -> Result<String, Self::Error>;
93}
94
95#[async_trait::async_trait]
96impl SignalServiceResponse for reqwest::Response {
97    type Error = reqwest::Error;
98
99    fn status_code(&self) -> StatusCode {
100        self.status()
101    }
102
103    async fn json<U>(self) -> Result<U, Self::Error>
104    where
105        for<'de> U: serde::Deserialize<'de>,
106    {
107        reqwest::Response::json(self).await
108    }
109
110    async fn text(self) -> Result<String, Self::Error> {
111        reqwest::Response::text(self).await
112    }
113}
114
115#[async_trait::async_trait]
116impl SignalServiceResponse for WebSocketResponseMessage {
117    type Error = ServiceError;
118
119    fn status_code(&self) -> StatusCode {
120        StatusCode::from_u16(self.status() as u16).unwrap_or_default()
121    }
122
123    async fn json<U>(self) -> Result<U, Self::Error>
124    where
125        for<'de> U: serde::Deserialize<'de>,
126    {
127        serde_json::from_slice(self.body()).map_err(Into::into)
128    }
129
130    async fn text(self) -> Result<String, Self::Error> {
131        Ok(self
132            .body
133            .map(|body| String::from_utf8_lossy(&body).to_string())
134            .unwrap_or_default())
135    }
136}
137
138#[async_trait::async_trait]
139pub(crate) trait ReqwestExt
140where
141    Self: Sized,
142{
143    /// convenience error handler to be used in the builder-style API of `reqwest::Response`
144    async fn service_error_for_status(
145        self,
146    ) -> Result<reqwest::Response, ServiceError>;
147}
148
149#[async_trait::async_trait]
150impl ReqwestExt for reqwest::Response {
151    async fn service_error_for_status(
152        self,
153    ) -> Result<reqwest::Response, ServiceError> {
154        service_error_for_status(self).await
155    }
156}
157
158impl WebSocketResponseMessage {
159    pub async fn service_error_for_status(self) -> Result<Self, ServiceError> {
160        service_error_for_status(self).await
161    }
162}