// libsignal_service/profile_cipher.rs

1use std::convert::TryInto;
2
3use aes_gcm::{aead::Aead, AeadCore, AeadInPlace, Aes256Gcm, KeyInit};
4use rand::{rand_core, CryptoRng, RngCore};
5use zkgroup::profiles::ProfileKey;
6
7use crate::{
8    profile_name::ProfileName, push_service::SignalServiceProfile, Profile,
9};
10
/// Encrypt and decrypt a [`ProfileName`] and other profile information.
///
/// Every field is encrypted with AES-256-GCM, keyed directly by the bytes of
/// the wrapped [`ProfileKey`]. Before encryption, the plaintext is zero-padded
/// up to one of a small, per-field set of fixed sizes. Ciphertexts of inputs
/// that fall into the same size bracket therefore have the same length. Each
/// ciphertext is laid out as `nonce (12 bytes) || AES-GCM output`.
///
/// # Example
///
/// ```rust
/// # use libsignal_service::{profile_name::ProfileName, profile_cipher::ProfileCipher};
/// # use zkgroup::profiles::ProfileKey;
/// # use rand::Rng;
/// # let mut rng = rand::rng();
/// # let some_randomness = rng.random();
/// let profile_key = ProfileKey::generate(some_randomness);
/// let name = ProfileName::<&str> {
///     given_name: "Bill",
///     family_name: None,
/// };
/// let cipher = ProfileCipher::new(profile_key);
/// let encrypted = cipher.encrypt_name(&name, &mut rng).unwrap();
/// let decrypted = cipher.decrypt_name(&encrypted).unwrap().unwrap();
/// assert_eq!(decrypted.as_ref(), name);
/// ```
pub struct ProfileCipher {
    /// Profile key whose raw bytes are used as the AES-256-GCM key.
    profile_key: ProfileKey,
}
34
// Padded plaintext sizes, in bytes. A plaintext is zero-padded up to the
// first bracket that can hold it; anything longer than the last bracket of
// its list is rejected.

/// Smallest padded size of a serialized profile name.
const NAME_PADDED_LENGTH_1: usize = 53;
/// Largest padded size of a serialized profile name.
const NAME_PADDED_LENGTH_2: usize = 257;
/// Padding brackets for profile names, in ascending order.
const NAME_PADDING_BRACKETS: &[usize] =
    &[NAME_PADDED_LENGTH_1, NAME_PADDED_LENGTH_2];

/// Smallest padded size of the "about" text.
const ABOUT_PADDED_LENGTH_1: usize = 128;
/// Intermediate padded size of the "about" text.
const ABOUT_PADDED_LENGTH_2: usize = 254;
/// Largest padded size of the "about" text.
const ABOUT_PADDED_LENGTH_3: usize = 512;
/// Padding brackets for the "about" text, in ascending order.
const ABOUT_PADDING_BRACKETS: &[usize] = &[
    ABOUT_PADDED_LENGTH_1,
    ABOUT_PADDED_LENGTH_2,
    ABOUT_PADDED_LENGTH_3,
];

/// The only padded size used for the "about" emoji.
const EMOJI_PADDED_LENGTH: usize = 32;
50
51#[derive(thiserror::Error, Debug)]
52pub enum ProfileCipherError {
53    #[error("Encryption error")]
54    EncryptionError,
55    #[error("UTF-8 decode error {0}")]
56    Utf8Error(#[from] std::str::Utf8Error),
57    #[error("Input name too long")]
58    InputTooLong,
59}
60
61fn pad_plaintext(
62    bytes: &mut Vec<u8>,
63    brackets: &[usize],
64) -> Result<usize, ProfileCipherError> {
65    let len = brackets
66        .iter()
67        .find(|x| **x >= bytes.len())
68        .ok_or(ProfileCipherError::InputTooLong)?;
69    let len: usize = *len;
70
71    bytes.resize(len, 0);
72    assert!(brackets.contains(&bytes.len()));
73
74    Ok(len)
75}
76
77impl ProfileCipher {
78    pub fn new(profile_key: ProfileKey) -> Self {
79        Self { profile_key }
80    }
81
82    pub fn into_inner(self) -> ProfileKey {
83        self.profile_key
84    }
85
86    fn pad_and_encrypt<R: RngCore + CryptoRng>(
87        &self,
88        mut bytes: Vec<u8>,
89        padding_brackets: &[usize],
90        csprng: &mut R,
91    ) -> Result<Vec<u8>, ProfileCipherError> {
92        let _len = pad_plaintext(&mut bytes, padding_brackets)?;
93
94        let csprng = Rng06Shiv(csprng);
95
96        let cipher = Aes256Gcm::new(&self.profile_key.get_bytes().into());
97        let nonce = Aes256Gcm::generate_nonce(csprng);
98
99        cipher
100            .encrypt_in_place(&nonce, b"", &mut bytes)
101            .map_err(|_| ProfileCipherError::EncryptionError)?;
102
103        let mut concat = Vec::with_capacity(nonce.len() + bytes.len());
104        concat.extend_from_slice(&nonce);
105        concat.extend_from_slice(&bytes);
106        Ok(concat)
107    }
108
109    fn decrypt_and_unpad(
110        &self,
111        bytes: impl AsRef<[u8]>,
112    ) -> Result<Vec<u8>, ProfileCipherError> {
113        let bytes = bytes.as_ref();
114        let nonce: [u8; 12] = bytes[0..12]
115            .try_into()
116            .expect("fixed length nonce material");
117        let cipher = Aes256Gcm::new(&self.profile_key.get_bytes().into());
118
119        let mut plaintext = cipher
120            .decrypt(&nonce.into(), &bytes[12..])
121            .map_err(|_| ProfileCipherError::EncryptionError)?;
122
123        // Unpad
124        let len = plaintext
125            .iter()
126            // Search the first non-0 char...
127            .rposition(|x| *x != 0)
128            // ...and strip until right after.
129            .map(|x| x + 1)
130            // If it's all zeroes, the string is 0-length.
131            .unwrap_or(0);
132        plaintext.truncate(len);
133        Ok(plaintext)
134    }
135
136    pub fn decrypt(
137        &self,
138        encrypted_profile: SignalServiceProfile,
139    ) -> Result<Profile, ProfileCipherError> {
140        let name = encrypted_profile
141            .name
142            .as_ref()
143            .map(|data| self.decrypt_name(data))
144            .transpose()?
145            .flatten();
146        let about = encrypted_profile
147            .about
148            .as_ref()
149            .map(|data| self.decrypt_about(data))
150            .transpose()?;
151        let about_emoji = encrypted_profile
152            .about_emoji
153            .as_ref()
154            .map(|data| self.decrypt_emoji(data))
155            .transpose()?;
156
157        Ok(Profile {
158            name,
159            about,
160            about_emoji,
161            avatar: encrypted_profile.avatar,
162            unrestricted_unidentified_access: encrypted_profile
163                .unrestricted_unidentified_access,
164        })
165    }
166
167    pub fn decrypt_avatar(
168        &self,
169        bytes: &[u8],
170    ) -> Result<Vec<u8>, ProfileCipherError> {
171        self.decrypt_and_unpad(bytes)
172    }
173
174    pub fn encrypt_name<'inp, R: RngCore + CryptoRng>(
175        &self,
176        name: impl std::borrow::Borrow<ProfileName<&'inp str>>,
177        csprng: &mut R,
178    ) -> Result<Vec<u8>, ProfileCipherError> {
179        let name = name.borrow();
180        let bytes = name.serialize();
181        self.pad_and_encrypt(bytes, NAME_PADDING_BRACKETS, csprng)
182    }
183
184    pub fn decrypt_name(
185        &self,
186        bytes: impl AsRef<[u8]>,
187    ) -> Result<Option<ProfileName<String>>, ProfileCipherError> {
188        let bytes = bytes.as_ref();
189
190        let plaintext = self.decrypt_and_unpad(bytes)?;
191
192        Ok(ProfileName::<String>::deserialize(&plaintext)?)
193    }
194
195    pub fn encrypt_about<R: RngCore + CryptoRng>(
196        &self,
197        about: String,
198        csprng: &mut R,
199    ) -> Result<Vec<u8>, ProfileCipherError> {
200        let bytes = about.into_bytes();
201        self.pad_and_encrypt(bytes, ABOUT_PADDING_BRACKETS, csprng)
202    }
203
204    pub fn decrypt_about(
205        &self,
206        bytes: impl AsRef<[u8]>,
207    ) -> Result<String, ProfileCipherError> {
208        let bytes = bytes.as_ref();
209
210        let plaintext = self.decrypt_and_unpad(bytes)?;
211
212        // XXX This re-allocates.
213        Ok(std::str::from_utf8(&plaintext)?.into())
214    }
215
216    pub fn encrypt_emoji<R: RngCore + CryptoRng>(
217        &self,
218        emoji: String,
219        csprng: &mut R,
220    ) -> Result<Vec<u8>, ProfileCipherError> {
221        let bytes = emoji.into_bytes();
222        self.pad_and_encrypt(bytes, &[EMOJI_PADDED_LENGTH], csprng)
223    }
224
225    pub fn decrypt_emoji(
226        &self,
227        bytes: impl AsRef<[u8]>,
228    ) -> Result<String, ProfileCipherError> {
229        let bytes = bytes.as_ref();
230
231        let plaintext = self.decrypt_and_unpad(bytes)?;
232
233        // XXX This re-allocates.
234        Ok(std::str::from_utf8(&plaintext)?.into())
235    }
236}
237
238struct Rng06Shiv<'a, T>(&'a mut T);
239
240impl<T: rand_core::RngCore> rand_core_06::RngCore for Rng06Shiv<'_, T> {
241    fn next_u32(&mut self) -> u32 {
242        self.0.next_u32()
243    }
244
245    fn next_u64(&mut self) -> u64 {
246        self.0.next_u64()
247    }
248
249    fn fill_bytes(&mut self, dest: &mut [u8]) {
250        self.0.fill_bytes(dest)
251    }
252
253    fn try_fill_bytes(
254        &mut self,
255        dest: &mut [u8],
256    ) -> Result<(), rand_core_06::Error> {
257        self.0.fill_bytes(dest);
258        Ok(())
259    }
260}
261
262impl<T: rand_core::CryptoRng> rand_core_06::CryptoRng for Rng06Shiv<'_, T> {}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::profile_name::ProfileName;
    use rand::Rng;
    use zkgroup::profiles::ProfileKey;

    /// Length of the AES-GCM nonce prepended to every ciphertext.
    const NONCE_LEN: usize = 12;
    /// Length of the AES-GCM authentication tag appended by the cipher.
    const TAG_LEN: usize = 16;

    /// Builds a cipher around a freshly generated random profile key.
    fn random_cipher<R: Rng>(rng: &mut R) -> ProfileCipher {
        let some_randomness = rng.random();
        ProfileCipher::new(ProfileKey::generate(some_randomness))
    }

    #[test]
    fn roundtrip_name() {
        let names = [
            "Me and my guitar", // shorter than 53
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz", // one shorter than 53
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzx", // exactly 53
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzxf", // one more than 53
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzxfoobar", // a bit more than 53
        ];

        // Sanity-check the boundary test cases themselves.
        assert_eq!(names[1].len(), NAME_PADDED_LENGTH_1 - 1);
        assert_eq!(names[2].len(), NAME_PADDED_LENGTH_1);
        assert_eq!(names[3].len(), NAME_PADDED_LENGTH_1 + 1);

        let mut rng = rand::rng();
        let cipher = random_cipher(&mut rng);
        for name in &names {
            let profile_name = ProfileName::<&str> {
                given_name: name,
                family_name: None,
            };
            assert_eq!(profile_name.serialize().len(), name.len());
            let encrypted =
                cipher.encrypt_name(&profile_name, &mut rng).unwrap();

            // The ciphertext length reveals only which bracket was used.
            let padded = if name.len() <= NAME_PADDED_LENGTH_1 {
                NAME_PADDED_LENGTH_1
            } else {
                NAME_PADDED_LENGTH_2
            };
            assert_eq!(encrypted.len(), NONCE_LEN + padded + TAG_LEN);

            let decrypted = cipher.decrypt_name(encrypted).unwrap().unwrap();

            assert_eq!(decrypted.as_ref(), profile_name);
            assert_eq!(decrypted.serialize(), profile_name.serialize());
            assert_eq!(&decrypted.given_name, name);
        }
    }

    #[test]
    fn name_longer_than_largest_bracket_is_rejected() {
        let mut rng = rand::rng();
        let cipher = random_cipher(&mut rng);
        let given_name = "x".repeat(NAME_PADDED_LENGTH_2 + 1);
        let profile_name = ProfileName::<&str> {
            given_name: &given_name,
            family_name: None,
        };

        assert!(matches!(
            cipher.encrypt_name(&profile_name, &mut rng),
            Err(ProfileCipherError::InputTooLong)
        ));
    }

    #[test]
    fn roundtrip_about() {
        let abouts = [
            String::new(),                         // empty: pure padding
            "Me and my guitar".to_string(),        // shorter than 128
            "a".repeat(ABOUT_PADDED_LENGTH_1),     // exactly the first bracket
            "a".repeat(ABOUT_PADDED_LENGTH_1 + 1), // spills into the second
            "a".repeat(ABOUT_PADDED_LENGTH_3),     // exactly the last bracket
        ];

        let mut rng = rand::rng();
        let cipher = random_cipher(&mut rng);

        for about in &abouts {
            let encrypted =
                cipher.encrypt_about(about.clone(), &mut rng).unwrap();
            let decrypted = cipher.decrypt_about(encrypted).unwrap();

            assert_eq!(&decrypted, about);
        }
    }

    #[test]
    fn about_longer_than_largest_bracket_is_rejected() {
        let mut rng = rand::rng();
        let cipher = random_cipher(&mut rng);
        let about = "a".repeat(ABOUT_PADDED_LENGTH_3 + 1);

        assert!(matches!(
            cipher.encrypt_about(about, &mut rng),
            Err(ProfileCipherError::InputTooLong)
        ));
    }

    #[test]
    fn roundtrip_emoji() {
        let emojii = ["❤️", "💩", "🤣", "😲", "🐠"];

        let mut rng = rand::rng();
        let cipher = random_cipher(&mut rng);

        for &emoji in &emojii {
            let encrypted =
                cipher.encrypt_emoji(emoji.into(), &mut rng).unwrap();
            let decrypted = cipher.decrypt_emoji(encrypted).unwrap();

            assert_eq!(decrypted, emoji);
        }
    }

    #[test]
    fn emoji_longer_than_bracket_is_rejected() {
        let mut rng = rand::rng();
        let cipher = random_cipher(&mut rng);
        let emoji = "a".repeat(EMOJI_PADDED_LENGTH + 1);

        assert!(matches!(
            cipher.encrypt_emoji(emoji, &mut rng),
            Err(ProfileCipherError::InputTooLong)
        ));
    }

    #[test]
    fn decrypt_with_wrong_key_fails() {
        let mut rng = rand::rng();
        let cipher = random_cipher(&mut rng);
        let other_cipher = random_cipher(&mut rng);
        let encrypted =
            cipher.encrypt_about("secret".into(), &mut rng).unwrap();

        assert!(matches!(
            other_cipher.decrypt_about(encrypted),
            Err(ProfileCipherError::EncryptionError)
        ));
    }

    #[test]
    fn decrypt_about_rejects_invalid_utf8() {
        let mut rng = rand::rng();
        let cipher = random_cipher(&mut rng);
        let encrypted = cipher
            .pad_and_encrypt(vec![0xff, 0xfe], ABOUT_PADDING_BRACKETS, &mut rng)
            .unwrap();

        assert!(matches!(
            cipher.decrypt_about(encrypted),
            Err(ProfileCipherError::Utf8Error(_))
        ));
    }
}