libsignal_service/
profile_cipher.rs

1use std::convert::TryInto;
2
3use aes_gcm::{aead::Aead, AeadCore, AeadInPlace, Aes256Gcm, KeyInit};
4use rand::{CryptoRng, Rng};
5use zkgroup::profiles::ProfileKey;
6
7use crate::{
8    profile_name::ProfileName, push_service::SignalServiceProfile, Profile,
9};
10
11/// Encrypt and decrypt a [`ProfileName`] and other profile information.
12///
13/// # Example
14///
15/// ```rust
16/// # use libsignal_service::{profile_name::ProfileName, profile_cipher::ProfileCipher};
17/// # use zkgroup::profiles::ProfileKey;
18/// # use rand::Rng;
19/// # let mut rng = rand::thread_rng();
20/// # let some_randomness = rng.gen();
21/// let profile_key = ProfileKey::generate(some_randomness);
22/// let name = ProfileName::<&str> {
23///     given_name: "Bill",
24///     family_name: None,
25/// };
26/// let cipher = ProfileCipher::new(profile_key);
27/// let encrypted = cipher.encrypt_name(&name, &mut rng).unwrap();
28/// let decrypted = cipher.decrypt_name(&encrypted).unwrap().unwrap();
29/// assert_eq!(decrypted.as_ref(), name);
30/// ```
31pub struct ProfileCipher {
32    profile_key: ProfileKey,
33}
34
// Padded plaintext lengths. A plaintext is zero-padded up to the first
// bracket that can hold it, so every bracket list is in ascending order.

/// Smaller padded length for a serialized profile name.
const NAME_PADDED_LENGTH_1: usize = 53;
/// Largest padded length (and hence the size limit) for a serialized profile name.
const NAME_PADDED_LENGTH_2: usize = 257;
/// Allowed padded lengths for profile names, smallest first.
const NAME_PADDING_BRACKETS: &[usize] = &[NAME_PADDED_LENGTH_1, NAME_PADDED_LENGTH_2];

/// Smallest padded length for the "about" text.
const ABOUT_PADDED_LENGTH_1: usize = 128;
/// Middle padded length for the "about" text.
const ABOUT_PADDED_LENGTH_2: usize = 254;
/// Largest padded length (and hence the size limit) for the "about" text.
const ABOUT_PADDED_LENGTH_3: usize = 512;
/// Allowed padded lengths for the "about" text, smallest first.
const ABOUT_PADDING_BRACKETS: &[usize] =
    &[ABOUT_PADDED_LENGTH_1, ABOUT_PADDED_LENGTH_2, ABOUT_PADDED_LENGTH_3];

/// The single padded length (and size limit) for the about emoji.
const EMOJI_PADDED_LENGTH: usize = 32;
50
51#[derive(thiserror::Error, Debug)]
52pub enum ProfileCipherError {
53    #[error("Encryption error")]
54    EncryptionError,
55    #[error("UTF-8 decode error {0}")]
56    Utf8Error(#[from] std::str::Utf8Error),
57    #[error("Input name too long")]
58    InputTooLong,
59}
60
61fn pad_plaintext(
62    bytes: &mut Vec<u8>,
63    brackets: &[usize],
64) -> Result<usize, ProfileCipherError> {
65    let len = brackets
66        .iter()
67        .find(|x| **x >= bytes.len())
68        .ok_or(ProfileCipherError::InputTooLong)?;
69    let len: usize = *len;
70
71    bytes.resize(len, 0);
72    assert!(brackets.contains(&bytes.len()));
73
74    Ok(len)
75}
76
77impl ProfileCipher {
78    pub fn new(profile_key: ProfileKey) -> Self {
79        Self { profile_key }
80    }
81
82    pub fn into_inner(self) -> ProfileKey {
83        self.profile_key
84    }
85
86    fn pad_and_encrypt<R: Rng + CryptoRng>(
87        &self,
88        mut bytes: Vec<u8>,
89        padding_brackets: &[usize],
90        csprng: &mut R,
91    ) -> Result<Vec<u8>, ProfileCipherError> {
92        let _len = pad_plaintext(&mut bytes, padding_brackets)?;
93
94        let cipher = Aes256Gcm::new(&self.profile_key.get_bytes().into());
95        let nonce = Aes256Gcm::generate_nonce(csprng);
96
97        cipher
98            .encrypt_in_place(&nonce, b"", &mut bytes)
99            .map_err(|_| ProfileCipherError::EncryptionError)?;
100
101        let mut concat = Vec::with_capacity(nonce.len() + bytes.len());
102        concat.extend_from_slice(&nonce);
103        concat.extend_from_slice(&bytes);
104        Ok(concat)
105    }
106
107    fn decrypt_and_unpad(
108        &self,
109        bytes: impl AsRef<[u8]>,
110    ) -> Result<Vec<u8>, ProfileCipherError> {
111        let bytes = bytes.as_ref();
112        let nonce: [u8; 12] = bytes[0..12]
113            .try_into()
114            .expect("fixed length nonce material");
115        let cipher = Aes256Gcm::new(&self.profile_key.get_bytes().into());
116
117        let mut plaintext = cipher
118            .decrypt(&nonce.into(), &bytes[12..])
119            .map_err(|_| ProfileCipherError::EncryptionError)?;
120
121        // Unpad
122        let len = plaintext
123            .iter()
124            // Search the first non-0 char...
125            .rposition(|x| *x != 0)
126            // ...and strip until right after.
127            .map(|x| x + 1)
128            // If it's all zeroes, the string is 0-length.
129            .unwrap_or(0);
130        plaintext.truncate(len);
131        Ok(plaintext)
132    }
133
134    pub fn decrypt(
135        &self,
136        encrypted_profile: SignalServiceProfile,
137    ) -> Result<Profile, ProfileCipherError> {
138        let name = encrypted_profile
139            .name
140            .as_ref()
141            .map(|data| self.decrypt_name(data))
142            .transpose()?
143            .flatten();
144        let about = encrypted_profile
145            .about
146            .as_ref()
147            .map(|data| self.decrypt_about(data))
148            .transpose()?;
149        let about_emoji = encrypted_profile
150            .about_emoji
151            .as_ref()
152            .map(|data| self.decrypt_emoji(data))
153            .transpose()?;
154
155        Ok(Profile {
156            name,
157            about,
158            about_emoji,
159            avatar: encrypted_profile.avatar,
160            unrestricted_unidentified_access: encrypted_profile
161                .unrestricted_unidentified_access,
162        })
163    }
164
165    pub fn decrypt_avatar(
166        &self,
167        bytes: &[u8],
168    ) -> Result<Vec<u8>, ProfileCipherError> {
169        self.decrypt_and_unpad(bytes)
170    }
171
172    pub fn encrypt_name<'inp, R: Rng + CryptoRng>(
173        &self,
174        name: impl std::borrow::Borrow<ProfileName<&'inp str>>,
175        csprng: &mut R,
176    ) -> Result<Vec<u8>, ProfileCipherError> {
177        let name = name.borrow();
178        let bytes = name.serialize();
179        self.pad_and_encrypt(bytes, NAME_PADDING_BRACKETS, csprng)
180    }
181
182    pub fn decrypt_name(
183        &self,
184        bytes: impl AsRef<[u8]>,
185    ) -> Result<Option<ProfileName<String>>, ProfileCipherError> {
186        let bytes = bytes.as_ref();
187
188        let plaintext = self.decrypt_and_unpad(bytes)?;
189
190        Ok(ProfileName::<String>::deserialize(&plaintext)?)
191    }
192
193    pub fn encrypt_about<R: Rng + CryptoRng>(
194        &self,
195        about: String,
196        csprng: &mut R,
197    ) -> Result<Vec<u8>, ProfileCipherError> {
198        let bytes = about.into_bytes();
199        self.pad_and_encrypt(bytes, ABOUT_PADDING_BRACKETS, csprng)
200    }
201
202    pub fn decrypt_about(
203        &self,
204        bytes: impl AsRef<[u8]>,
205    ) -> Result<String, ProfileCipherError> {
206        let bytes = bytes.as_ref();
207
208        let plaintext = self.decrypt_and_unpad(bytes)?;
209
210        // XXX This re-allocates.
211        Ok(std::str::from_utf8(&plaintext)?.into())
212    }
213
214    pub fn encrypt_emoji<R: Rng + CryptoRng>(
215        &self,
216        emoji: String,
217        csprng: &mut R,
218    ) -> Result<Vec<u8>, ProfileCipherError> {
219        let bytes = emoji.into_bytes();
220        self.pad_and_encrypt(bytes, &[EMOJI_PADDED_LENGTH], csprng)
221    }
222
223    pub fn decrypt_emoji(
224        &self,
225        bytes: impl AsRef<[u8]>,
226    ) -> Result<String, ProfileCipherError> {
227        let bytes = bytes.as_ref();
228
229        let plaintext = self.decrypt_and_unpad(bytes)?;
230
231        // XXX This re-allocates.
232        Ok(std::str::from_utf8(&plaintext)?.into())
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::profile_name::ProfileName;
    use rand::Rng;
    use zkgroup::profiles::ProfileKey;

    /// Builds a [`ProfileCipher`] around a freshly generated random profile key.
    fn random_cipher<R: Rng>(rng: &mut R) -> ProfileCipher {
        let some_randomness = rng.gen();
        ProfileCipher::new(ProfileKey::generate(some_randomness))
    }

    #[test]
    fn roundtrip_name() {
        let names = [
            "Me and my guitar", // shorter than 53
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz", // one shorter than 53
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzx", // exactly 53
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzxf", // one more than 53
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzxfoobar", // a bit more than 53
        ];

        // Test the test cases
        assert_eq!(names[1].len(), NAME_PADDED_LENGTH_1 - 1);
        assert_eq!(names[2].len(), NAME_PADDED_LENGTH_1);
        assert_eq!(names[3].len(), NAME_PADDED_LENGTH_1 + 1);

        let mut rng = rand::thread_rng();
        let cipher = random_cipher(&mut rng);
        for name in &names {
            let profile_name = ProfileName::<&str> {
                given_name: name,
                family_name: None,
            };
            assert_eq!(profile_name.serialize().len(), name.len());
            let encrypted =
                cipher.encrypt_name(&profile_name, &mut rng).unwrap();
            let decrypted = cipher.decrypt_name(encrypted).unwrap().unwrap();

            assert_eq!(decrypted.as_ref(), profile_name);
            assert_eq!(decrypted.serialize(), profile_name.serialize());
            assert_eq!(&decrypted.given_name, name);
        }
    }

    #[test]
    fn name_length_limit() {
        let mut rng = rand::thread_rng();
        let cipher = random_cipher(&mut rng);

        // A name filling the largest bracket exactly still fits.
        let longest = "a".repeat(NAME_PADDED_LENGTH_2);
        let profile_name = ProfileName::<&str> {
            given_name: longest.as_str(),
            family_name: None,
        };
        let encrypted = cipher.encrypt_name(&profile_name, &mut rng).unwrap();
        let decrypted = cipher.decrypt_name(encrypted).unwrap().unwrap();
        assert_eq!(decrypted.given_name, longest);

        // One byte more leaves no bracket to pad into.
        let too_long = "a".repeat(NAME_PADDED_LENGTH_2 + 1);
        let profile_name = ProfileName::<&str> {
            given_name: too_long.as_str(),
            family_name: None,
        };
        assert!(matches!(
            cipher.encrypt_name(&profile_name, &mut rng),
            Err(ProfileCipherError::InputTooLong)
        ));
    }

    #[test]
    fn roundtrip_about() {
        let mut rng = rand::thread_rng();
        let cipher = random_cipher(&mut rng);

        // Well below the first bracket (128 bytes).
        let encrypted = cipher
            .encrypt_about("Me and my guitar".into(), &mut rng)
            .unwrap();
        assert_eq!(cipher.decrypt_about(encrypted).unwrap(), "Me and my guitar");

        // Empty text and the edges of every about bracket.
        let lengths = [
            0,
            ABOUT_PADDED_LENGTH_1,
            ABOUT_PADDED_LENGTH_1 + 1,
            ABOUT_PADDED_LENGTH_2,
            ABOUT_PADDED_LENGTH_2 + 1,
            ABOUT_PADDED_LENGTH_3,
        ];
        for &len in &lengths {
            let about = "a".repeat(len);
            let encrypted =
                cipher.encrypt_about(about.clone(), &mut rng).unwrap();
            let decrypted = cipher.decrypt_about(encrypted).unwrap();

            assert_eq!(decrypted, about);
        }
    }

    #[test]
    fn about_too_long() {
        let mut rng = rand::thread_rng();
        let cipher = random_cipher(&mut rng);

        let about = "a".repeat(ABOUT_PADDED_LENGTH_3 + 1);
        assert!(matches!(
            cipher.encrypt_about(about, &mut rng),
            Err(ProfileCipherError::InputTooLong)
        ));
    }

    #[test]
    fn roundtrip_emoji() {
        let emojii = ["❤️", "💩", "🤣", "😲", "🐠"];

        let mut rng = rand::thread_rng();
        let cipher = random_cipher(&mut rng);

        for &emoji in &emojii {
            let encrypted =
                cipher.encrypt_emoji(emoji.into(), &mut rng).unwrap();
            let decrypted = cipher.decrypt_emoji(encrypted).unwrap();

            assert_eq!(decrypted, emoji);
        }
    }

    #[test]
    fn emoji_too_long() {
        let mut rng = rand::thread_rng();
        let cipher = random_cipher(&mut rng);

        let emoji = "a".repeat(EMOJI_PADDED_LENGTH + 1);
        assert!(matches!(
            cipher.encrypt_emoji(emoji, &mut rng),
            Err(ProfileCipherError::InputTooLong)
        ));
    }

    #[test]
    fn wrong_key_or_tampering_is_rejected() {
        let mut rng = rand::thread_rng();
        let owner = random_cipher(&mut rng);
        let stranger = random_cipher(&mut rng);

        let encrypted = owner.encrypt_about("hello".into(), &mut rng).unwrap();
        assert!(matches!(
            stranger.decrypt_about(&encrypted),
            Err(ProfileCipherError::EncryptionError)
        ));

        let mut tampered = encrypted.clone();
        let last = tampered.len() - 1;
        tampered[last] ^= 1;
        assert!(matches!(
            owner.decrypt_about(&tampered),
            Err(ProfileCipherError::EncryptionError)
        ));

        // The untouched ciphertext still decrypts with the right key.
        assert_eq!(owner.decrypt_about(&encrypted).unwrap(), "hello");
    }
}