1use std::convert::TryInto;
2
3use aes_gcm::{aead::Aead, AeadCore, AeadInPlace, Aes256Gcm, KeyInit};
4use rand::{rand_core, CryptoRng, RngCore};
5use zkgroup::profiles::ProfileKey;
6
7use crate::{
8 profile_name::ProfileName, push_service::SignalServiceProfile, Profile,
9};
10
/// Encrypts and decrypts profile fields (name, about text, about emoji and
/// avatar) with AES-256-GCM, using the raw bytes of a [`ProfileKey`] as the
/// cipher key.
pub struct ProfileCipher {
    /// Key whose bytes (`get_bytes()`) form the AES-256-GCM key.
    profile_key: ProfileKey,
}
34
// Plaintext is zero-padded to one of a few fixed sizes before encryption, so
// the ciphertext length reveals only which bracket the field fell into.

/// Padded sizes, in bytes, for a serialized profile name.
const NAME_PADDED_LENGTH_1: usize = 53;
const NAME_PADDED_LENGTH_2: usize = 257;
/// Name brackets, smallest first; `pad_plaintext` picks the first entry
/// that is at least as long as the input.
const NAME_PADDING_BRACKETS: &[usize] =
    &[NAME_PADDED_LENGTH_1, NAME_PADDED_LENGTH_2];

/// Padded sizes, in bytes, for an "about" text.
const ABOUT_PADDED_LENGTH_1: usize = 128;
const ABOUT_PADDED_LENGTH_2: usize = 254;
const ABOUT_PADDED_LENGTH_3: usize = 512;
/// About brackets, smallest first.
const ABOUT_PADDING_BRACKETS: &[usize] =
    &[ABOUT_PADDED_LENGTH_1, ABOUT_PADDED_LENGTH_2, ABOUT_PADDED_LENGTH_3];

/// The only padded size, in bytes, for an "about" emoji.
const EMOJI_PADDED_LENGTH: usize = 32;
50
51#[derive(thiserror::Error, Debug)]
52pub enum ProfileCipherError {
53 #[error("Encryption error")]
54 EncryptionError,
55 #[error("UTF-8 decode error {0}")]
56 Utf8Error(#[from] std::str::Utf8Error),
57 #[error("Input name too long")]
58 InputTooLong,
59}
60
61fn pad_plaintext(
62 bytes: &mut Vec<u8>,
63 brackets: &[usize],
64) -> Result<usize, ProfileCipherError> {
65 let len = brackets
66 .iter()
67 .find(|x| **x >= bytes.len())
68 .ok_or(ProfileCipherError::InputTooLong)?;
69 let len: usize = *len;
70
71 bytes.resize(len, 0);
72 assert!(brackets.contains(&bytes.len()));
73
74 Ok(len)
75}
76
77impl ProfileCipher {
78 pub fn new(profile_key: ProfileKey) -> Self {
79 Self { profile_key }
80 }
81
82 pub fn into_inner(self) -> ProfileKey {
83 self.profile_key
84 }
85
86 fn pad_and_encrypt<R: RngCore + CryptoRng>(
87 &self,
88 mut bytes: Vec<u8>,
89 padding_brackets: &[usize],
90 csprng: &mut R,
91 ) -> Result<Vec<u8>, ProfileCipherError> {
92 let _len = pad_plaintext(&mut bytes, padding_brackets)?;
93
94 let csprng = Rng06Shiv(csprng);
95
96 let cipher = Aes256Gcm::new(&self.profile_key.get_bytes().into());
97 let nonce = Aes256Gcm::generate_nonce(csprng);
98
99 cipher
100 .encrypt_in_place(&nonce, b"", &mut bytes)
101 .map_err(|_| ProfileCipherError::EncryptionError)?;
102
103 let mut concat = Vec::with_capacity(nonce.len() + bytes.len());
104 concat.extend_from_slice(&nonce);
105 concat.extend_from_slice(&bytes);
106 Ok(concat)
107 }
108
109 fn decrypt_and_unpad(
110 &self,
111 bytes: impl AsRef<[u8]>,
112 ) -> Result<Vec<u8>, ProfileCipherError> {
113 let bytes = bytes.as_ref();
114 let nonce: [u8; 12] = bytes[0..12]
115 .try_into()
116 .expect("fixed length nonce material");
117 let cipher = Aes256Gcm::new(&self.profile_key.get_bytes().into());
118
119 let mut plaintext = cipher
120 .decrypt(&nonce.into(), &bytes[12..])
121 .map_err(|_| ProfileCipherError::EncryptionError)?;
122
123 let len = plaintext
125 .iter()
126 .rposition(|x| *x != 0)
128 .map(|x| x + 1)
130 .unwrap_or(0);
132 plaintext.truncate(len);
133 Ok(plaintext)
134 }
135
136 pub fn decrypt(
137 &self,
138 encrypted_profile: SignalServiceProfile,
139 ) -> Result<Profile, ProfileCipherError> {
140 let name = encrypted_profile
141 .name
142 .as_ref()
143 .map(|data| self.decrypt_name(data))
144 .transpose()?
145 .flatten();
146 let about = encrypted_profile
147 .about
148 .as_ref()
149 .map(|data| self.decrypt_about(data))
150 .transpose()?;
151 let about_emoji = encrypted_profile
152 .about_emoji
153 .as_ref()
154 .map(|data| self.decrypt_emoji(data))
155 .transpose()?;
156
157 Ok(Profile {
158 name,
159 about,
160 about_emoji,
161 avatar: encrypted_profile.avatar,
162 unrestricted_unidentified_access: encrypted_profile
163 .unrestricted_unidentified_access,
164 })
165 }
166
167 pub fn decrypt_avatar(
168 &self,
169 bytes: &[u8],
170 ) -> Result<Vec<u8>, ProfileCipherError> {
171 self.decrypt_and_unpad(bytes)
172 }
173
174 pub fn encrypt_name<'inp, R: RngCore + CryptoRng>(
175 &self,
176 name: impl std::borrow::Borrow<ProfileName<&'inp str>>,
177 csprng: &mut R,
178 ) -> Result<Vec<u8>, ProfileCipherError> {
179 let name = name.borrow();
180 let bytes = name.serialize();
181 self.pad_and_encrypt(bytes, NAME_PADDING_BRACKETS, csprng)
182 }
183
184 pub fn decrypt_name(
185 &self,
186 bytes: impl AsRef<[u8]>,
187 ) -> Result<Option<ProfileName<String>>, ProfileCipherError> {
188 let bytes = bytes.as_ref();
189
190 let plaintext = self.decrypt_and_unpad(bytes)?;
191
192 Ok(ProfileName::<String>::deserialize(&plaintext)?)
193 }
194
195 pub fn encrypt_about<R: RngCore + CryptoRng>(
196 &self,
197 about: String,
198 csprng: &mut R,
199 ) -> Result<Vec<u8>, ProfileCipherError> {
200 let bytes = about.into_bytes();
201 self.pad_and_encrypt(bytes, ABOUT_PADDING_BRACKETS, csprng)
202 }
203
204 pub fn decrypt_about(
205 &self,
206 bytes: impl AsRef<[u8]>,
207 ) -> Result<String, ProfileCipherError> {
208 let bytes = bytes.as_ref();
209
210 let plaintext = self.decrypt_and_unpad(bytes)?;
211
212 Ok(std::str::from_utf8(&plaintext)?.into())
214 }
215
216 pub fn encrypt_emoji<R: RngCore + CryptoRng>(
217 &self,
218 emoji: String,
219 csprng: &mut R,
220 ) -> Result<Vec<u8>, ProfileCipherError> {
221 let bytes = emoji.into_bytes();
222 self.pad_and_encrypt(bytes, &[EMOJI_PADDED_LENGTH], csprng)
223 }
224
225 pub fn decrypt_emoji(
226 &self,
227 bytes: impl AsRef<[u8]>,
228 ) -> Result<String, ProfileCipherError> {
229 let bytes = bytes.as_ref();
230
231 let plaintext = self.decrypt_and_unpad(bytes)?;
232
233 Ok(std::str::from_utf8(&plaintext)?.into())
235 }
236}
237
238struct Rng06Shiv<'a, T>(&'a mut T);
239
240impl<T: rand_core::RngCore> rand_core_06::RngCore for Rng06Shiv<'_, T> {
241 fn next_u32(&mut self) -> u32 {
242 self.0.next_u32()
243 }
244
245 fn next_u64(&mut self) -> u64 {
246 self.0.next_u64()
247 }
248
249 fn fill_bytes(&mut self, dest: &mut [u8]) {
250 self.0.fill_bytes(dest)
251 }
252
253 fn try_fill_bytes(
254 &mut self,
255 dest: &mut [u8],
256 ) -> Result<(), rand_core_06::Error> {
257 self.0.fill_bytes(dest);
258 Ok(())
259 }
260}
261
262impl<T: rand_core::CryptoRng> rand_core_06::CryptoRng for Rng06Shiv<'_, T> {}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::profile_name::ProfileName;
    use rand::Rng;
    use zkgroup::profiles::ProfileKey;

    #[test]
    fn roundtrip_name() {
        let names = [
            "Me and my guitar",
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz",
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzx",
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzxf",
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzxfoobar",
        ];

        // Names just below, exactly at, and just above the first bracket.
        assert_eq!(names[1].len(), NAME_PADDED_LENGTH_1 - 1);
        assert_eq!(names[2].len(), NAME_PADDED_LENGTH_1);
        assert_eq!(names[3].len(), NAME_PADDED_LENGTH_1 + 1);

        let mut rng = rand::rng();
        let some_randomness = rng.random();
        let profile_key = ProfileKey::generate(some_randomness);
        let cipher = ProfileCipher::new(profile_key);
        for name in &names {
            let profile_name = ProfileName::<&str> {
                given_name: name,
                family_name: None,
            };
            assert_eq!(profile_name.serialize().len(), name.len());
            let encrypted =
                cipher.encrypt_name(&profile_name, &mut rng).unwrap();
            let decrypted = cipher.decrypt_name(encrypted).unwrap().unwrap();

            assert_eq!(decrypted.as_ref(), profile_name);
            assert_eq!(decrypted.serialize(), profile_name.serialize());
            assert_eq!(&decrypted.given_name, name);
        }
    }

    #[test]
    fn roundtrip_about() {
        // The empty string (all padding), a short text, inputs on either
        // side of the first bracket, and one that exactly fills the largest
        // bracket.
        let mut abouts: Vec<String> =
            vec![String::new(), "Me and my guitar".to_string()];
        abouts.extend(
            [
                ABOUT_PADDED_LENGTH_1 - 1,
                ABOUT_PADDED_LENGTH_1,
                ABOUT_PADDED_LENGTH_1 + 1,
                ABOUT_PADDED_LENGTH_3,
            ]
            .iter()
            .map(|&len| "a".repeat(len)),
        );

        let mut rng = rand::rng();
        let some_randomness = rng.random();
        let profile_key = ProfileKey::generate(some_randomness);
        let cipher = ProfileCipher::new(profile_key);

        for about in &abouts {
            let encrypted =
                cipher.encrypt_about(about.clone(), &mut rng).unwrap();
            let decrypted = cipher.decrypt_about(encrypted).unwrap();

            assert_eq!(&decrypted, about);
        }
    }

    #[test]
    fn roundtrip_emoji() {
        let emojii = ["❤️", "💩", "🤣", "😲", "🐠"];

        let mut rng = rand::rng();
        let some_randomness = rng.random();
        let profile_key = ProfileKey::generate(some_randomness);
        let cipher = ProfileCipher::new(profile_key);

        for &emoji in &emojii {
            let encrypted =
                cipher.encrypt_emoji(emoji.into(), &mut rng).unwrap();
            let decrypted = cipher.decrypt_emoji(encrypted).unwrap();

            assert_eq!(decrypted, emoji);
        }
    }

    #[test]
    fn rejects_input_longer_than_largest_bracket() {
        let mut rng = rand::rng();
        let some_randomness = rng.random();
        let profile_key = ProfileKey::generate(some_randomness);
        let cipher = ProfileCipher::new(profile_key);

        let long_name = "a".repeat(NAME_PADDED_LENGTH_2 + 1);
        let profile_name = ProfileName::<&str> {
            given_name: long_name.as_str(),
            family_name: None,
        };
        assert!(matches!(
            cipher.encrypt_name(&profile_name, &mut rng),
            Err(ProfileCipherError::InputTooLong)
        ));

        let long_about = "a".repeat(ABOUT_PADDED_LENGTH_3 + 1);
        assert!(matches!(
            cipher.encrypt_about(long_about, &mut rng),
            Err(ProfileCipherError::InputTooLong)
        ));

        let long_emoji = "a".repeat(EMOJI_PADDED_LENGTH + 1);
        assert!(matches!(
            cipher.encrypt_emoji(long_emoji, &mut rng),
            Err(ProfileCipherError::InputTooLong)
        ));
    }
}