libsignal_service/
utils.rs

1mod phonenumber;
2use libsignal_core::{Aci, Pni, ServiceId};
3pub use phonenumber::*;
4use uuid::Uuid;
5
6// Signal sometimes adds padding, sometimes it does not.
7// This requires a custom decoding engine.
8// This engine is as general as possible.
9pub const BASE64_RELAXED: base64::engine::GeneralPurpose =
10    base64::engine::GeneralPurpose::new(
11        &base64::alphabet::STANDARD,
12        base64::engine::GeneralPurposeConfig::new()
13            .with_encode_padding(true)
14            .with_decode_padding_mode(
15                base64::engine::DecodePaddingMode::Indifferent,
16            ),
17    );
18
19pub fn parse_aci_with_fallback(
20    bytes: Option<&[u8]>,
21    utf8: Option<&str>,
22) -> Option<Aci> {
23    let binary = bytes.and_then(|bytes| {
24        let bytes = bytes
25            .try_into()
26            .inspect_err(|_e| tracing::warn!("binary ACI not 16 bytes"))
27            .ok()?;
28        Some(Aci::from_uuid_bytes(bytes))
29    });
30
31    binary.or_else(|| {
32        let utf8 = utf8?;
33        match Aci::parse_from_service_id_string(utf8) {
34            Some(sid) => Some(sid),
35            None => {
36                tracing::warn!("unparseable utf8 ACI");
37                None
38            },
39        }
40    })
41}
42
43pub fn parse_pni_with_fallback(
44    bytes: Option<&[u8]>,
45    utf8: Option<&str>,
46    pni_is_uuid: bool,
47) -> Option<Pni> {
48    let binary = bytes.and_then(|bytes| {
49        let bytes = bytes
50            .try_into()
51            .inspect_err(|_e| tracing::warn!("binary PNI not 16 bytes"))
52            .ok()?;
53        Some(Pni::from_uuid_bytes(bytes))
54    });
55
56    binary.or_else(|| {
57        let utf8 = utf8?;
58        if pni_is_uuid {
59            let uuid: uuid::Uuid = utf8
60                .parse()
61                .inspect_err(|e| {
62                    tracing::warn!(error = %e, "unparseable UUID");
63                })
64                .ok()?;
65            Some(Pni::from_uuid_bytes(*uuid.as_bytes()))
66        } else {
67            match Pni::parse_from_service_id_string(utf8) {
68                Some(sid) => Some(sid),
69                None => {
70                    tracing::warn!("unparseable utf8 PNI");
71                    None
72                },
73            }
74        }
75    })
76}
77
78pub fn parse_service_id_with_fallback(
79    bytes: Option<&[u8]>,
80    utf8: Option<&str>,
81) -> Option<ServiceId> {
82    let binary = bytes.and_then(|bytes| {
83        match ServiceId::parse_from_service_id_binary(bytes) {
84            Some(sid) => Some(sid),
85            None => {
86                tracing::warn!("unparseable binary ServiceId");
87                None
88            },
89        }
90    });
91
92    binary.or_else(|| {
93        let utf8 = utf8?;
94        match ServiceId::parse_from_service_id_string(utf8) {
95            Some(sid) => Some(sid),
96            None => {
97                tracing::warn!("unparseable utf8 ServiceId");
98                None
99            },
100        }
101    })
102}
103
104/// Parse protobuf UUIDs specified in both binary and utf8 formats
105///
106/// Prefers the binary format
107pub fn parse_uuid_with_fallback(
108    binary: Option<&[u8]>,
109    utf8: Option<&str>,
110) -> Option<Uuid> {
111    let binary = binary
112        .map(<[u8; 16]>::try_from)
113        .transpose()
114        .inspect_err(|_e| tracing::warn!("invalid binary UUID length"))
115        .ok()
116        .flatten()
117        .map(Uuid::from_bytes);
118
119    binary.or_else(|| {
120        let utf8 = utf8?;
121        utf8.parse()
122            .inspect_err(|e| tracing::warn!(error=%e, "unparseable UUID"))
123            .ok()
124    })
125}
126
127pub fn random_length_padding<R: rand::Rng + rand::CryptoRng>(
128    csprng: &mut R,
129    max_len: usize,
130) -> Vec<u8> {
131    let length = csprng.random_range(0..max_len);
132    let mut padding = vec![0u8; length];
133    csprng.fill_bytes(&mut padding);
134    padding
135}
136
137pub mod serde_base64 {
138    use super::BASE64_RELAXED;
139    use base64::prelude::*;
140    use serde::{Deserialize, Deserializer, Serializer};
141
142    pub fn serialize<T, S>(bytes: &T, serializer: S) -> Result<S::Ok, S::Error>
143    where
144        T: AsRef<[u8]>,
145        S: Serializer,
146    {
147        serializer.serialize_str(&BASE64_RELAXED.encode(bytes.as_ref()))
148    }
149
150    pub fn deserialize<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
151    where
152        D: Deserializer<'de>,
153    {
154        use serde::de::Error;
155        <&str>::deserialize(deserializer).and_then(|string| {
156            BASE64_RELAXED
157                .decode(string)
158                .map_err(|err| Error::custom(err.to_string()))
159        })
160    }
161}
162
163pub mod serde_optional_base64 {
164    use super::BASE64_RELAXED;
165    use base64::prelude::*;
166    use serde::{Deserialize, Deserializer, Serializer};
167
168    use super::serde_base64;
169
170    pub fn serialize<T, S>(
171        bytes: &Option<T>,
172        serializer: S,
173    ) -> Result<S::Ok, S::Error>
174    where
175        T: AsRef<[u8]>,
176        S: Serializer,
177    {
178        match bytes {
179            Some(bytes) => serde_base64::serialize(bytes, serializer),
180            None => serializer.serialize_none(),
181        }
182    }
183
184    pub fn deserialize<'de, D>(
185        deserializer: D,
186    ) -> Result<Option<Vec<u8>>, D::Error>
187    where
188        D: Deserializer<'de>,
189    {
190        use serde::de::Error;
191        match Option::<String>::deserialize(deserializer)? {
192            Some(s) => BASE64_RELAXED
193                .decode(s)
194                .map_err(|err| Error::custom(err.to_string()))
195                .map(Some),
196            None => Ok(None),
197        }
198    }
199}
200
201pub mod serde_identity_key {
202    use super::BASE64_RELAXED;
203    use base64::prelude::*;
204    use libsignal_protocol::IdentityKey;
205    use serde::{Deserialize, Deserializer, Serializer};
206
207    pub fn serialize<S>(
208        public_key: &IdentityKey,
209        serializer: S,
210    ) -> Result<S::Ok, S::Error>
211    where
212        S: Serializer,
213    {
214        let public_key = public_key.serialize();
215        serializer.serialize_str(&BASE64_RELAXED.encode(&public_key))
216    }
217
218    pub fn deserialize<'de, D>(deserializer: D) -> Result<IdentityKey, D::Error>
219    where
220        D: Deserializer<'de>,
221    {
222        IdentityKey::decode(
223            &BASE64_RELAXED
224                .decode(<&str>::deserialize(deserializer)?)
225                .map_err(serde::de::Error::custom)?,
226        )
227        .map_err(serde::de::Error::custom)
228    }
229}
230
231pub mod serde_optional_identity_key {
232    use super::BASE64_RELAXED;
233    use base64::prelude::*;
234    use libsignal_protocol::IdentityKey;
235    use serde::{Deserialize, Deserializer, Serializer};
236
237    use super::serde_identity_key;
238
239    pub fn serialize<S>(
240        public_key: &Option<IdentityKey>,
241        serializer: S,
242    ) -> Result<S::Ok, S::Error>
243    where
244        S: Serializer,
245    {
246        match public_key {
247            Some(public_key) => {
248                serde_identity_key::serialize(public_key, serializer)
249            },
250            None => serializer.serialize_none(),
251        }
252    }
253
254    pub fn deserialize<'de, D>(
255        deserializer: D,
256    ) -> Result<Option<IdentityKey>, D::Error>
257    where
258        D: Deserializer<'de>,
259    {
260        match Option::<String>::deserialize(deserializer)? {
261            Some(public_key) => Ok(Some(
262                IdentityKey::decode(
263                    &BASE64_RELAXED
264                        .decode(public_key)
265                        .map_err(serde::de::Error::custom)?,
266                )
267                .map_err(serde::de::Error::custom)?,
268            )),
269            None => Ok(None),
270        }
271    }
272}
273
274pub mod serde_private_key {
275    use super::BASE64_RELAXED;
276    use base64::prelude::*;
277    use libsignal_protocol::PrivateKey;
278    use serde::{Deserialize, Deserializer, Serializer};
279
280    pub fn serialize<S>(
281        public_key: &PrivateKey,
282        serializer: S,
283    ) -> Result<S::Ok, S::Error>
284    where
285        S: Serializer,
286    {
287        let public_key = public_key.serialize();
288        serializer.serialize_str(&BASE64_RELAXED.encode(public_key))
289    }
290
291    pub fn deserialize<'de, D>(deserializer: D) -> Result<PrivateKey, D::Error>
292    where
293        D: Deserializer<'de>,
294    {
295        PrivateKey::deserialize(
296            &BASE64_RELAXED
297                .decode(<&str>::deserialize(deserializer)?)
298                .map_err(serde::de::Error::custom)?,
299        )
300        .map_err(serde::de::Error::custom)
301    }
302}
303
304pub mod serde_optional_private_key {
305    use super::BASE64_RELAXED;
306    use base64::prelude::*;
307    use libsignal_protocol::PrivateKey;
308    use serde::{Deserialize, Deserializer, Serializer};
309
310    use super::serde_private_key;
311
312    pub fn serialize<S>(
313        private_key: &Option<PrivateKey>,
314        serializer: S,
315    ) -> Result<S::Ok, S::Error>
316    where
317        S: Serializer,
318    {
319        match private_key {
320            Some(private_key) => {
321                serde_private_key::serialize(private_key, serializer)
322            },
323            None => serializer.serialize_none(),
324        }
325    }
326
327    pub fn deserialize<'de, D>(
328        deserializer: D,
329    ) -> Result<Option<PrivateKey>, D::Error>
330    where
331        D: Deserializer<'de>,
332    {
333        match Option::<String>::deserialize(deserializer)? {
334            Some(private_key) => Ok(Some(
335                PrivateKey::deserialize(
336                    &BASE64_RELAXED
337                        .decode(private_key)
338                        .map_err(serde::de::Error::custom)?,
339                )
340                .map_err(serde::de::Error::custom)?,
341            )),
342            None => Ok(None),
343        }
344    }
345}
346
347pub mod serde_optional_e164 {
348    use libsignal_core::E164;
349    use serde::{Deserialize, Deserializer, Serializer};
350
351    pub fn serialize<S>(
352        phone_number: &Option<E164>,
353        serializer: S,
354    ) -> Result<S::Ok, S::Error>
355    where
356        S: Serializer,
357    {
358        match phone_number {
359            Some(p) => serializer.serialize_str(&p.to_string()),
360            None => serializer.serialize_none(),
361        }
362    }
363
364    pub fn deserialize<'de, D>(
365        deserializer: D,
366    ) -> Result<Option<E164>, D::Error>
367    where
368        D: Deserializer<'de>,
369    {
370        match Option::<String>::deserialize(deserializer)? {
371            Some(s) => s.parse().map_err(serde::de::Error::custom).map(Some),
372            None => Ok(None),
373        }
374    }
375}
376
377pub mod serde_e164 {
378    use libsignal_core::E164;
379    use serde::{Deserialize, Deserializer, Serializer};
380
381    pub fn serialize<S>(
382        phone_number: &E164,
383        serializer: S,
384    ) -> Result<S::Ok, S::Error>
385    where
386        S: Serializer,
387    {
388        serializer.serialize_str(&phone_number.to_string())
389    }
390
391    pub fn deserialize<'de, D>(deserializer: D) -> Result<E164, D::Error>
392    where
393        D: Deserializer<'de>,
394    {
395        <&str>::deserialize(deserializer)?
396            .parse()
397            .map_err(serde::de::Error::custom)
398    }
399}
400
/// Serde helpers that store a `PhoneNumber` as a string.
#[cfg(feature = "phonenumber")]
pub mod serde_phone_number {
    use phonenumber::PhoneNumber;
    use serde::{Deserialize, Deserializer, Serializer};

    /// Serializes `phone_number` as its `to_string()` form.
    pub fn serialize<S>(
        phone_number: &PhoneNumber,
        serializer: S,
    ) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        serializer.serialize_str(&phone_number.to_string())
    }

    /// Deserializes a string and parses it with `phonenumber::parse`, passing
    /// `None` as the country argument.
    ///
    /// A string that fails to parse is reported as a custom serde error.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<PhoneNumber, D::Error>
    where
        D: Deserializer<'de>,
    {
        // An owned `String` is used instead of `&str` so that deserializers
        // which cannot lend out their input (escaped JSON strings, readers,
        // in-memory values) are accepted as well.
        let number = String::deserialize(deserializer)?;
        phonenumber::parse(None, number.as_str())
            .map_err(serde::de::Error::custom)
    }
}
424
425pub mod serde_service_id {
426    use libsignal_protocol::ServiceId;
427    use serde::{Deserialize, Deserializer, Serializer};
428
429    pub fn serialize<S>(
430        service_id: &ServiceId,
431        serializer: S,
432    ) -> Result<S::Ok, S::Error>
433    where
434        S: Serializer,
435    {
436        serializer.serialize_str(&service_id.service_id_string())
437    }
438
439    pub fn deserialize<'de, D>(deserializer: D) -> Result<ServiceId, D::Error>
440    where
441        D: Deserializer<'de>,
442    {
443        ServiceId::parse_from_service_id_string(<&str>::deserialize(
444            deserializer,
445        )?)
446        .ok_or_else(|| serde::de::Error::custom("invalid service ID string"))
447    }
448}
449
450pub mod serde_aci {
451    use libsignal_core::Aci;
452    use serde::{Deserialize, Deserializer, Serializer};
453
454    pub fn serialize<S>(aci: &Aci, serializer: S) -> Result<S::Ok, S::Error>
455    where
456        S: Serializer,
457    {
458        serializer.serialize_str(&aci.service_id_string())
459    }
460
461    pub fn deserialize<'de, D>(deserializer: D) -> Result<Aci, D::Error>
462    where
463        D: Deserializer<'de>,
464    {
465        Aci::parse_from_service_id_string(<&str>::deserialize(deserializer)?)
466            .ok_or_else(|| serde::de::Error::custom("invalid ACI string"))
467    }
468}
469
470pub mod serde_device_id {
471    use libsignal_core::DeviceId;
472    use serde::{Deserialize, Deserializer, Serializer};
473
474    pub fn serialize<S>(id: &DeviceId, serializer: S) -> Result<S::Ok, S::Error>
475    where
476        S: Serializer,
477    {
478        serializer.serialize_u8(u8::from(*id))
479    }
480
481    pub fn deserialize<'de, D>(deserializer: D) -> Result<DeviceId, D::Error>
482    where
483        D: Deserializer<'de>,
484    {
485        DeviceId::try_from(u8::deserialize(deserializer)?)
486            .map_err(|_| serde::de::Error::custom("invalid device id"))
487    }
488}
489
490pub mod serde_device_id_vec {
491    use libsignal_core::DeviceId;
492    use serde::{ser::SerializeSeq, Deserialize, Deserializer, Serializer};
493
494    pub fn serialize<S>(
495        ids: &Vec<DeviceId>,
496        serializer: S,
497    ) -> Result<S::Ok, S::Error>
498    where
499        S: Serializer,
500    {
501        let mut seq = serializer.serialize_seq(Some(ids.len()))?;
502        for id in ids {
503            seq.serialize_element(&u8::from(*id))?;
504        }
505        seq.end()
506    }
507
508    pub fn deserialize<'de, D>(
509        deserializer: D,
510    ) -> Result<Vec<DeviceId>, D::Error>
511    where
512        D: Deserializer<'de>,
513    {
514        Vec::<u8>::deserialize(deserializer)?
515            .into_iter()
516            .map(DeviceId::try_from)
517            .collect::<Result<Vec<_>, _>>()
518            .map_err(|_| serde::de::Error::custom("invalid device id"))
519    }
520}
521
522pub mod serde_prost_base64 {
523    use super::BASE64_RELAXED;
524    use base64::Engine;
525    use prost::Message;
526    use serde::{Deserialize, Deserializer, Serializer};
527
528    // Serializes a Prost message into a Base64 string
529    pub fn serialize<T, S>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
530    where
531        T: Message,
532        S: Serializer,
533    {
534        let b64 = BASE64_RELAXED.encode(value.encode_to_vec());
535        serializer.serialize_str(&b64)
536    }
537
538    // Deserializes a Base64 string back into a Prost message
539    pub fn deserialize<'de, T, D>(deserializer: D) -> Result<T, D::Error>
540    where
541        T: Message + Default,
542        D: Deserializer<'de>,
543    {
544        let bytes = BASE64_RELAXED
545            .decode(<&str>::deserialize(deserializer)?)
546            .map_err(serde::de::Error::custom)?;
547
548        T::decode(bytes.as_slice()).map_err(serde::de::Error::custom)
549    }
550}
551
552pub mod serde_optional_prost_base64 {
553    use base64::Engine;
554    use prost::Message;
555    use serde::{Deserialize, Deserializer, Serializer};
556
557    use super::{serde_prost_base64, BASE64_RELAXED};
558
559    pub fn serialize<T, S>(
560        value: &Option<T>,
561        serializer: S,
562    ) -> Result<S::Ok, S::Error>
563    where
564        T: Message,
565        S: Serializer,
566    {
567        match value {
568            Some(msg) => serde_prost_base64::serialize(msg, serializer),
569            None => serializer.serialize_none(),
570        }
571    }
572
573    pub fn deserialize<'de, T, D>(
574        deserializer: D,
575    ) -> Result<Option<T>, D::Error>
576    where
577        T: Message + Default,
578        D: Deserializer<'de>,
579    {
580        match Option::<String>::deserialize(deserializer)? {
581            Some(s) => {
582                let bytes = BASE64_RELAXED
583                    .decode(s)
584                    .map_err(serde::de::Error::custom)?;
585                let msg = T::decode(bytes.as_slice())
586                    .map_err(serde::de::Error::custom)?;
587                Ok(Some(msg))
588            },
589            None => Ok(None),
590        }
591    }
592}