libsignal_service/
profile_cipher.rs

1use std::convert::TryInto;
2
3use aes_gcm::{aead::Aead, AeadCore, AeadInPlace, Aes256Gcm, KeyInit};
4use rand::{rand_core, CryptoRng, RngCore};
5use zkgroup::profiles::ProfileKey;
6
7use crate::{
8    profile_name::ProfileName, websocket::profile::SignalServiceProfile,
9    Profile,
10};
11
12/// Encrypt and decrypt a [`ProfileName`] and other profile information.
13///
14/// # Example
15///
16/// ```rust
17/// # use libsignal_service::{profile_name::ProfileName, profile_cipher::ProfileCipher};
18/// # use zkgroup::profiles::ProfileKey;
19/// # use rand::Rng;
20/// # let mut rng = rand::rng();
21/// # let some_randomness = rng.random();
22/// let profile_key = ProfileKey::generate(some_randomness);
23/// let name = ProfileName::<&str> {
24///     given_name: "Bill",
25///     family_name: None,
26/// };
27/// let cipher = ProfileCipher::new(profile_key);
28/// let encrypted = cipher.encrypt_name(&name, &mut rng).unwrap();
29/// let decrypted = cipher.decrypt_name(&encrypted).unwrap().unwrap();
30/// assert_eq!(decrypted.as_ref(), name);
31/// ```
32pub struct ProfileCipher {
33    profile_key: ProfileKey,
34}
35
/// Smaller of the two lengths a serialized profile name is zero-padded to.
const NAME_PADDED_LENGTH_1: usize = 53;
/// Larger of the two lengths a serialized profile name is zero-padded to;
/// names longer than this are rejected.
const NAME_PADDED_LENGTH_2: usize = 257;
/// Padding targets for profile names, in ascending order. The first entry
/// that is large enough for the plaintext is the one that gets used.
const NAME_PADDING_BRACKETS: &[usize] =
    &[NAME_PADDED_LENGTH_1, NAME_PADDED_LENGTH_2];

/// Smallest length an "about" text is zero-padded to.
const ABOUT_PADDED_LENGTH_1: usize = 128;
/// Middle length an "about" text is zero-padded to.
const ABOUT_PADDED_LENGTH_2: usize = 254;
/// Largest length an "about" text is zero-padded to; longer texts are
/// rejected.
const ABOUT_PADDED_LENGTH_3: usize = 512;
/// Padding targets for "about" texts, in ascending order. The first entry
/// that is large enough for the plaintext is the one that gets used.
const ABOUT_PADDING_BRACKETS: &[usize] = &[
    ABOUT_PADDED_LENGTH_1,
    ABOUT_PADDED_LENGTH_2,
    ABOUT_PADDED_LENGTH_3,
];

/// The single length an "about" emoji is zero-padded to.
const EMOJI_PADDED_LENGTH: usize = 32;
51
52#[derive(thiserror::Error, Debug)]
53pub enum ProfileCipherError {
54    #[error("Encryption error")]
55    EncryptionError,
56    #[error("UTF-8 decode error {0}")]
57    Utf8Error(#[from] std::str::Utf8Error),
58    #[error("Input name too long")]
59    InputTooLong,
60}
61
62fn pad_plaintext(
63    bytes: &mut Vec<u8>,
64    brackets: &[usize],
65) -> Result<usize, ProfileCipherError> {
66    let len = brackets
67        .iter()
68        .find(|x| **x >= bytes.len())
69        .ok_or(ProfileCipherError::InputTooLong)?;
70    let len: usize = *len;
71
72    bytes.resize(len, 0);
73    assert!(brackets.contains(&bytes.len()));
74
75    Ok(len)
76}
77
78impl ProfileCipher {
79    pub fn new(profile_key: ProfileKey) -> Self {
80        Self { profile_key }
81    }
82
83    pub fn into_inner(self) -> ProfileKey {
84        self.profile_key
85    }
86
87    fn pad_and_encrypt<R: RngCore + CryptoRng>(
88        &self,
89        mut bytes: Vec<u8>,
90        padding_brackets: &[usize],
91        csprng: &mut R,
92    ) -> Result<Vec<u8>, ProfileCipherError> {
93        let _len = pad_plaintext(&mut bytes, padding_brackets)?;
94
95        let csprng = Rng06Shiv(csprng);
96
97        let cipher = Aes256Gcm::new(&self.profile_key.get_bytes().into());
98        let nonce = Aes256Gcm::generate_nonce(csprng);
99
100        cipher
101            .encrypt_in_place(&nonce, b"", &mut bytes)
102            .map_err(|_| ProfileCipherError::EncryptionError)?;
103
104        let mut concat = Vec::with_capacity(nonce.len() + bytes.len());
105        concat.extend_from_slice(&nonce);
106        concat.extend_from_slice(&bytes);
107        Ok(concat)
108    }
109
110    fn decrypt_and_unpad(
111        &self,
112        bytes: impl AsRef<[u8]>,
113    ) -> Result<Vec<u8>, ProfileCipherError> {
114        let bytes = bytes.as_ref();
115        let nonce: [u8; 12] = bytes[0..12]
116            .try_into()
117            .expect("fixed length nonce material");
118        let cipher = Aes256Gcm::new(&self.profile_key.get_bytes().into());
119
120        let mut plaintext = cipher
121            .decrypt(&nonce.into(), &bytes[12..])
122            .map_err(|_| ProfileCipherError::EncryptionError)?;
123
124        // Unpad
125        let len = plaintext
126            .iter()
127            // Search the first non-0 char...
128            .rposition(|x| *x != 0)
129            // ...and strip until right after.
130            .map(|x| x + 1)
131            // If it's all zeroes, the string is 0-length.
132            .unwrap_or(0);
133        plaintext.truncate(len);
134        Ok(plaintext)
135    }
136
137    pub fn decrypt(
138        &self,
139        encrypted_profile: SignalServiceProfile,
140    ) -> Result<Profile, ProfileCipherError> {
141        let name = encrypted_profile
142            .name
143            .as_ref()
144            .map(|data| self.decrypt_name(data))
145            .transpose()?
146            .flatten();
147        let about = encrypted_profile
148            .about
149            .as_ref()
150            .map(|data| self.decrypt_about(data))
151            .transpose()?;
152        let about_emoji = encrypted_profile
153            .about_emoji
154            .as_ref()
155            .map(|data| self.decrypt_emoji(data))
156            .transpose()?;
157
158        Ok(Profile {
159            name,
160            about,
161            about_emoji,
162            avatar: encrypted_profile.avatar,
163            unrestricted_unidentified_access: encrypted_profile
164                .unrestricted_unidentified_access,
165        })
166    }
167
168    pub fn decrypt_avatar(
169        &self,
170        bytes: &[u8],
171    ) -> Result<Vec<u8>, ProfileCipherError> {
172        self.decrypt_and_unpad(bytes)
173    }
174
175    pub fn encrypt_name<'inp, R: RngCore + CryptoRng>(
176        &self,
177        name: impl std::borrow::Borrow<ProfileName<&'inp str>>,
178        csprng: &mut R,
179    ) -> Result<Vec<u8>, ProfileCipherError> {
180        let name = name.borrow();
181        let bytes = name.serialize();
182        self.pad_and_encrypt(bytes, NAME_PADDING_BRACKETS, csprng)
183    }
184
185    pub fn decrypt_name(
186        &self,
187        bytes: impl AsRef<[u8]>,
188    ) -> Result<Option<ProfileName<String>>, ProfileCipherError> {
189        let bytes = bytes.as_ref();
190
191        let plaintext = self.decrypt_and_unpad(bytes)?;
192
193        Ok(ProfileName::<String>::deserialize(&plaintext)?)
194    }
195
196    pub fn encrypt_about<R: RngCore + CryptoRng>(
197        &self,
198        about: String,
199        csprng: &mut R,
200    ) -> Result<Vec<u8>, ProfileCipherError> {
201        let bytes = about.into_bytes();
202        self.pad_and_encrypt(bytes, ABOUT_PADDING_BRACKETS, csprng)
203    }
204
205    pub fn decrypt_about(
206        &self,
207        bytes: impl AsRef<[u8]>,
208    ) -> Result<String, ProfileCipherError> {
209        let bytes = bytes.as_ref();
210
211        let plaintext = self.decrypt_and_unpad(bytes)?;
212
213        // XXX This re-allocates.
214        Ok(std::str::from_utf8(&plaintext)?.into())
215    }
216
217    pub fn encrypt_emoji<R: RngCore + CryptoRng>(
218        &self,
219        emoji: String,
220        csprng: &mut R,
221    ) -> Result<Vec<u8>, ProfileCipherError> {
222        let bytes = emoji.into_bytes();
223        self.pad_and_encrypt(bytes, &[EMOJI_PADDED_LENGTH], csprng)
224    }
225
226    pub fn decrypt_emoji(
227        &self,
228        bytes: impl AsRef<[u8]>,
229    ) -> Result<String, ProfileCipherError> {
230        let bytes = bytes.as_ref();
231
232        let plaintext = self.decrypt_and_unpad(bytes)?;
233
234        // XXX This re-allocates.
235        Ok(std::str::from_utf8(&plaintext)?.into())
236    }
237}
238
239struct Rng06Shiv<'a, T>(&'a mut T);
240
241impl<T: rand_core::RngCore> rand_core_06::RngCore for Rng06Shiv<'_, T> {
242    fn next_u32(&mut self) -> u32 {
243        self.0.next_u32()
244    }
245
246    fn next_u64(&mut self) -> u64 {
247        self.0.next_u64()
248    }
249
250    fn fill_bytes(&mut self, dest: &mut [u8]) {
251        self.0.fill_bytes(dest)
252    }
253
254    fn try_fill_bytes(
255        &mut self,
256        dest: &mut [u8],
257    ) -> Result<(), rand_core_06::Error> {
258        self.0.fill_bytes(dest);
259        Ok(())
260    }
261}
262
263impl<T: rand_core::CryptoRng> rand_core_06::CryptoRng for Rng06Shiv<'_, T> {}
264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::profile_name::ProfileName;
    use rand::Rng;
    use zkgroup::profiles::ProfileKey;

    /// Builds a cipher around a freshly generated random profile key.
    fn random_cipher(rng: &mut impl Rng) -> ProfileCipher {
        let some_randomness = rng.random();
        ProfileCipher::new(ProfileKey::generate(some_randomness))
    }

    #[test]
    fn roundtrip_name() {
        let names = [
            "Me and my guitar", // shorter than 53
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz", // one shorter than 53
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzx", // exactly 53
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzxf", // one more than 53
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzxfoobar", // a bit more than 53
        ];

        // Test the test cases
        assert_eq!(names[1].len(), NAME_PADDED_LENGTH_1 - 1);
        assert_eq!(names[2].len(), NAME_PADDED_LENGTH_1);
        assert_eq!(names[3].len(), NAME_PADDED_LENGTH_1 + 1);

        let mut rng = rand::rng();
        let cipher = random_cipher(&mut rng);
        for name in &names {
            let profile_name = ProfileName::<&str> {
                given_name: name,
                family_name: None,
            };
            assert_eq!(profile_name.serialize().len(), name.len());
            let encrypted =
                cipher.encrypt_name(&profile_name, &mut rng).unwrap();
            let decrypted = cipher.decrypt_name(encrypted).unwrap().unwrap();

            assert_eq!(decrypted.as_ref(), profile_name);
            assert_eq!(decrypted.serialize(), profile_name.serialize());
            assert_eq!(&decrypted.given_name, name);
        }
    }

    #[test]
    fn name_longer_than_largest_bracket_is_rejected() {
        let mut rng = rand::rng();
        let cipher = random_cipher(&mut rng);

        let long_name = "x".repeat(NAME_PADDED_LENGTH_2 + 1);
        let profile_name = ProfileName::<&str> {
            given_name: long_name.as_str(),
            family_name: None,
        };
        // Guard the premise: with no family name the serialized form is as
        // long as the given name, i.e. one byte past the largest bracket.
        assert_eq!(profile_name.serialize().len(), NAME_PADDED_LENGTH_2 + 1);

        assert!(matches!(
            cipher.encrypt_name(&profile_name, &mut rng),
            Err(ProfileCipherError::InputTooLong)
        ));
    }

    #[test]
    fn roundtrip_about() {
        // Lengths around every "about" bracket boundary, plus the empty
        // string.
        let lengths = [
            0,
            ABOUT_PADDED_LENGTH_1 - 1,
            ABOUT_PADDED_LENGTH_1,
            ABOUT_PADDED_LENGTH_1 + 1,
            ABOUT_PADDED_LENGTH_2 - 1,
            ABOUT_PADDED_LENGTH_2,
            ABOUT_PADDED_LENGTH_2 + 1,
            ABOUT_PADDED_LENGTH_3 - 1,
            ABOUT_PADDED_LENGTH_3,
        ];
        let mut abouts = vec![String::from("Me and my guitar")];
        abouts.extend(lengths.iter().map(|&len| "a".repeat(len)));

        let mut rng = rand::rng();
        let cipher = random_cipher(&mut rng);

        for about in &abouts {
            let encrypted =
                cipher.encrypt_about(about.clone(), &mut rng).unwrap();
            let decrypted = cipher.decrypt_about(encrypted).unwrap();

            assert_eq!(&decrypted, about);
        }
    }

    #[test]
    fn about_longer_than_largest_bracket_is_rejected() {
        let mut rng = rand::rng();
        let cipher = random_cipher(&mut rng);

        let too_long = "a".repeat(ABOUT_PADDED_LENGTH_3 + 1);
        assert!(matches!(
            cipher.encrypt_about(too_long, &mut rng),
            Err(ProfileCipherError::InputTooLong)
        ));
    }

    #[test]
    fn roundtrip_emoji() {
        let emojii = ["❤️", "💩", "🤣", "😲", "🐠"];

        let mut rng = rand::rng();
        let cipher = random_cipher(&mut rng);

        for &emoji in &emojii {
            let encrypted =
                cipher.encrypt_emoji(emoji.into(), &mut rng).unwrap();
            let decrypted = cipher.decrypt_emoji(encrypted).unwrap();

            assert_eq!(decrypted, emoji);
        }
    }

    #[test]
    fn emoji_longer_than_padded_length_is_rejected() {
        let mut rng = rand::rng();
        let cipher = random_cipher(&mut rng);

        let too_long = "a".repeat(EMOJI_PADDED_LENGTH + 1);
        assert!(matches!(
            cipher.encrypt_emoji(too_long, &mut rng),
            Err(ProfileCipherError::InputTooLong)
        ));
    }

    #[test]
    fn tampered_ciphertext_is_rejected() {
        let mut rng = rand::rng();
        let cipher = random_cipher(&mut rng);

        let mut encrypted = cipher
            .encrypt_about("Me and my guitar".into(), &mut rng)
            .unwrap();
        // Flip one bit in the last byte of the encrypted output.
        let last = encrypted.len() - 1;
        encrypted[last] ^= 0x01;

        assert!(matches!(
            cipher.decrypt_about(encrypted),
            Err(ProfileCipherError::EncryptionError)
        ));
    }
}