libsignal_service/push_service/
response.rs1use reqwest::StatusCode;
2
3use crate::proto::WebSocketResponseMessage;
4
5use super::ServiceError;
6
7async fn service_error_for_status<R>(response: R) -> Result<R, ServiceError>
8where
9 R: SignalServiceResponse,
10 ServiceError: From<<R as SignalServiceResponse>::Error>,
11{
12 match response.status_code() {
13 StatusCode::OK
14 | StatusCode::CREATED
15 | StatusCode::ACCEPTED
16 | StatusCode::NO_CONTENT => Ok(response),
17 StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
18 Err(ServiceError::Unauthorized)
19 },
20 StatusCode::NOT_FOUND => {
21 Err(ServiceError::NotFoundError)
23 },
24 StatusCode::PAYLOAD_TOO_LARGE => {
25 Err(ServiceError::RateLimitExceeded)
27 },
28 StatusCode::CONFLICT => {
29 let mismatched_devices =
30 response.json().await.map_err(|error| {
31 tracing::error!(
32 %error,
33 "failed to decode HTTP 409 status"
34 );
35 ServiceError::UnhandledResponseCode {
36 http_code: StatusCode::CONFLICT.as_u16(),
37 }
38 })?;
39 Err(ServiceError::MismatchedDevicesException(mismatched_devices))
40 },
41 StatusCode::GONE => {
42 let stale_devices = response.json().await.map_err(|error| {
43 tracing::error!(%error, "failed to decode HTTP 410 status");
44 ServiceError::UnhandledResponseCode {
45 http_code: StatusCode::GONE.as_u16(),
46 }
47 })?;
48 Err(ServiceError::StaleDevices(stale_devices))
49 },
50 StatusCode::LOCKED => {
51 let locked = response.json().await.map_err(|error| {
52 tracing::error!(%error, "failed to decode HTTP 423 status");
53 ServiceError::UnhandledResponseCode {
54 http_code: StatusCode::LOCKED.as_u16(),
55 }
56 })?;
57 Err(ServiceError::Locked(locked))
58 },
59 StatusCode::PRECONDITION_REQUIRED => {
60 let proof_required = response.json().await.map_err(|error| {
61 tracing::error!(
62 %error,
63 "failed to decode HTTP 428 status"
64 );
65 ServiceError::UnhandledResponseCode {
66 http_code: StatusCode::PRECONDITION_REQUIRED.as_u16(),
67 }
68 })?;
69 Err(ServiceError::ProofRequiredError(proof_required))
70 },
71 code => {
73 let response_text = response.text().await?;
74 tracing::trace!(status_code =% code, body = response_text, "unhandled HTTP response");
75 Err(ServiceError::UnhandledResponseCode {
76 http_code: code.as_u16(),
77 })
78 },
79 }
80}
81
82#[async_trait::async_trait]
83pub(crate) trait SignalServiceResponse {
84 type Error: std::error::Error;
85
86 fn status_code(&self) -> StatusCode;
87
88 async fn json<U>(self) -> Result<U, Self::Error>
89 where
90 for<'de> U: serde::Deserialize<'de>;
91
92 async fn text(self) -> Result<String, Self::Error>;
93}
94
95#[async_trait::async_trait]
96impl SignalServiceResponse for reqwest::Response {
97 type Error = reqwest::Error;
98
99 fn status_code(&self) -> StatusCode {
100 self.status()
101 }
102
103 async fn json<U>(self) -> Result<U, Self::Error>
104 where
105 for<'de> U: serde::Deserialize<'de>,
106 {
107 reqwest::Response::json(self).await
108 }
109
110 async fn text(self) -> Result<String, Self::Error> {
111 reqwest::Response::text(self).await
112 }
113}
114
115#[async_trait::async_trait]
116impl SignalServiceResponse for WebSocketResponseMessage {
117 type Error = ServiceError;
118
119 fn status_code(&self) -> StatusCode {
120 StatusCode::from_u16(self.status() as u16).unwrap_or_default()
121 }
122
123 async fn json<U>(self) -> Result<U, Self::Error>
124 where
125 for<'de> U: serde::Deserialize<'de>,
126 {
127 serde_json::from_slice(self.body()).map_err(Into::into)
128 }
129
130 async fn text(self) -> Result<String, Self::Error> {
131 Ok(self
132 .body
133 .map(|body| String::from_utf8_lossy(&body).to_string())
134 .unwrap_or_default())
135 }
136}
137
138#[async_trait::async_trait]
139pub(crate) trait ReqwestExt
140where
141 Self: Sized,
142{
143 async fn service_error_for_status(
145 self,
146 ) -> Result<reqwest::Response, ServiceError>;
147}
148
149#[async_trait::async_trait]
150impl ReqwestExt for reqwest::Response {
151 async fn service_error_for_status(
152 self,
153 ) -> Result<reqwest::Response, ServiceError> {
154 service_error_for_status(self).await
155 }
156}
157
158impl WebSocketResponseMessage {
159 pub async fn service_error_for_status(self) -> Result<Self, ServiceError> {
160 service_error_for_status(self).await
161 }
162}