1use std::convert::TryInto;
2
3use aes_gcm::{aead::Aead, AeadCore, AeadInPlace, Aes256Gcm, KeyInit};
4use rand::{rand_core, CryptoRng, RngCore};
5use zkgroup::profiles::ProfileKey;
6
7use crate::{
8 profile_name::ProfileName, websocket::profile::SignalServiceProfile,
9 Profile,
10};
11
/// Encrypts and decrypts profile fields (name, about text, about emoji) and
/// decrypts avatars, using AES-256-GCM keyed by a [`ProfileKey`].
pub struct ProfileCipher {
    /// The key whose raw bytes (`get_bytes()`) are used as the AES-256-GCM key.
    profile_key: ProfileKey,
}
35
// Plaintext is zero-padded up to one of these sizes (in bytes) before
// encryption, so the ciphertext length only reveals which bracket was used.

/// Smaller padded size for an encrypted profile name.
const NAME_PADDED_LENGTH_1: usize = 53;
/// Larger padded size for an encrypted profile name.
const NAME_PADDED_LENGTH_2: usize = 257;
/// All padded sizes a profile name may use, smallest first.
const NAME_PADDING_BRACKETS: &[usize] = &[NAME_PADDED_LENGTH_1, NAME_PADDED_LENGTH_2];

/// Smallest padded size for an encrypted "about" text.
const ABOUT_PADDED_LENGTH_1: usize = 128;
/// Middle padded size for an encrypted "about" text.
const ABOUT_PADDED_LENGTH_2: usize = 254;
/// Largest padded size for an encrypted "about" text.
const ABOUT_PADDED_LENGTH_3: usize = 512;
/// All padded sizes an "about" text may use, smallest first.
const ABOUT_PADDING_BRACKETS: &[usize] =
    &[ABOUT_PADDED_LENGTH_1, ABOUT_PADDED_LENGTH_2, ABOUT_PADDED_LENGTH_3];

/// The only padded size for an encrypted "about" emoji.
const EMOJI_PADDED_LENGTH: usize = 32;
51
52#[derive(thiserror::Error, Debug)]
53pub enum ProfileCipherError {
54 #[error("Encryption error")]
55 EncryptionError,
56 #[error("UTF-8 decode error {0}")]
57 Utf8Error(#[from] std::str::Utf8Error),
58 #[error("Input name too long")]
59 InputTooLong,
60}
61
62fn pad_plaintext(
63 bytes: &mut Vec<u8>,
64 brackets: &[usize],
65) -> Result<usize, ProfileCipherError> {
66 let len = brackets
67 .iter()
68 .find(|x| **x >= bytes.len())
69 .ok_or(ProfileCipherError::InputTooLong)?;
70 let len: usize = *len;
71
72 bytes.resize(len, 0);
73 assert!(brackets.contains(&bytes.len()));
74
75 Ok(len)
76}
77
78impl ProfileCipher {
79 pub fn new(profile_key: ProfileKey) -> Self {
80 Self { profile_key }
81 }
82
83 pub fn into_inner(self) -> ProfileKey {
84 self.profile_key
85 }
86
87 fn pad_and_encrypt<R: RngCore + CryptoRng>(
88 &self,
89 mut bytes: Vec<u8>,
90 padding_brackets: &[usize],
91 csprng: &mut R,
92 ) -> Result<Vec<u8>, ProfileCipherError> {
93 let _len = pad_plaintext(&mut bytes, padding_brackets)?;
94
95 let csprng = Rng06Shiv(csprng);
96
97 let cipher = Aes256Gcm::new(&self.profile_key.get_bytes().into());
98 let nonce = Aes256Gcm::generate_nonce(csprng);
99
100 cipher
101 .encrypt_in_place(&nonce, b"", &mut bytes)
102 .map_err(|_| ProfileCipherError::EncryptionError)?;
103
104 let mut concat = Vec::with_capacity(nonce.len() + bytes.len());
105 concat.extend_from_slice(&nonce);
106 concat.extend_from_slice(&bytes);
107 Ok(concat)
108 }
109
110 fn decrypt_and_unpad(
111 &self,
112 bytes: impl AsRef<[u8]>,
113 ) -> Result<Vec<u8>, ProfileCipherError> {
114 let bytes = bytes.as_ref();
115 let nonce: [u8; 12] = bytes[0..12]
116 .try_into()
117 .expect("fixed length nonce material");
118 let cipher = Aes256Gcm::new(&self.profile_key.get_bytes().into());
119
120 let mut plaintext = cipher
121 .decrypt(&nonce.into(), &bytes[12..])
122 .map_err(|_| ProfileCipherError::EncryptionError)?;
123
124 let len = plaintext
126 .iter()
127 .rposition(|x| *x != 0)
129 .map(|x| x + 1)
131 .unwrap_or(0);
133 plaintext.truncate(len);
134 Ok(plaintext)
135 }
136
137 pub fn decrypt(
138 &self,
139 encrypted_profile: SignalServiceProfile,
140 ) -> Result<Profile, ProfileCipherError> {
141 let name = encrypted_profile
142 .name
143 .as_ref()
144 .map(|data| self.decrypt_name(data))
145 .transpose()?
146 .flatten();
147 let about = encrypted_profile
148 .about
149 .as_ref()
150 .map(|data| self.decrypt_about(data))
151 .transpose()?;
152 let about_emoji = encrypted_profile
153 .about_emoji
154 .as_ref()
155 .map(|data| self.decrypt_emoji(data))
156 .transpose()?;
157
158 Ok(Profile {
159 name,
160 about,
161 about_emoji,
162 avatar: encrypted_profile.avatar,
163 unrestricted_unidentified_access: encrypted_profile
164 .unrestricted_unidentified_access,
165 })
166 }
167
168 pub fn decrypt_avatar(
169 &self,
170 bytes: &[u8],
171 ) -> Result<Vec<u8>, ProfileCipherError> {
172 self.decrypt_and_unpad(bytes)
173 }
174
175 pub fn encrypt_name<'inp, R: RngCore + CryptoRng>(
176 &self,
177 name: impl std::borrow::Borrow<ProfileName<&'inp str>>,
178 csprng: &mut R,
179 ) -> Result<Vec<u8>, ProfileCipherError> {
180 let name = name.borrow();
181 let bytes = name.serialize();
182 self.pad_and_encrypt(bytes, NAME_PADDING_BRACKETS, csprng)
183 }
184
185 pub fn decrypt_name(
186 &self,
187 bytes: impl AsRef<[u8]>,
188 ) -> Result<Option<ProfileName<String>>, ProfileCipherError> {
189 let bytes = bytes.as_ref();
190
191 let plaintext = self.decrypt_and_unpad(bytes)?;
192
193 Ok(ProfileName::<String>::deserialize(&plaintext)?)
194 }
195
196 pub fn encrypt_about<R: RngCore + CryptoRng>(
197 &self,
198 about: String,
199 csprng: &mut R,
200 ) -> Result<Vec<u8>, ProfileCipherError> {
201 let bytes = about.into_bytes();
202 self.pad_and_encrypt(bytes, ABOUT_PADDING_BRACKETS, csprng)
203 }
204
205 pub fn decrypt_about(
206 &self,
207 bytes: impl AsRef<[u8]>,
208 ) -> Result<String, ProfileCipherError> {
209 let bytes = bytes.as_ref();
210
211 let plaintext = self.decrypt_and_unpad(bytes)?;
212
213 Ok(std::str::from_utf8(&plaintext)?.into())
215 }
216
217 pub fn encrypt_emoji<R: RngCore + CryptoRng>(
218 &self,
219 emoji: String,
220 csprng: &mut R,
221 ) -> Result<Vec<u8>, ProfileCipherError> {
222 let bytes = emoji.into_bytes();
223 self.pad_and_encrypt(bytes, &[EMOJI_PADDED_LENGTH], csprng)
224 }
225
226 pub fn decrypt_emoji(
227 &self,
228 bytes: impl AsRef<[u8]>,
229 ) -> Result<String, ProfileCipherError> {
230 let bytes = bytes.as_ref();
231
232 let plaintext = self.decrypt_and_unpad(bytes)?;
233
234 Ok(std::str::from_utf8(&plaintext)?.into())
236 }
237}
238
239struct Rng06Shiv<'a, T>(&'a mut T);
240
241impl<T: rand_core::RngCore> rand_core_06::RngCore for Rng06Shiv<'_, T> {
242 fn next_u32(&mut self) -> u32 {
243 self.0.next_u32()
244 }
245
246 fn next_u64(&mut self) -> u64 {
247 self.0.next_u64()
248 }
249
250 fn fill_bytes(&mut self, dest: &mut [u8]) {
251 self.0.fill_bytes(dest)
252 }
253
254 fn try_fill_bytes(
255 &mut self,
256 dest: &mut [u8],
257 ) -> Result<(), rand_core_06::Error> {
258 self.0.fill_bytes(dest);
259 Ok(())
260 }
261}
262
263impl<T: rand_core::CryptoRng> rand_core_06::CryptoRng for Rng06Shiv<'_, T> {}
264
#[cfg(test)]
mod tests {
    use super::*;
    use crate::profile_name::ProfileName;
    use rand::Rng;
    use zkgroup::profiles::ProfileKey;

    /// Bytes added around the padded plaintext: 12-byte nonce + 16-byte GCM tag.
    const OVERHEAD: usize = 12 + 16;

    /// Builds a cipher around a freshly generated random profile key.
    fn random_cipher(rng: &mut impl Rng) -> ProfileCipher {
        let some_randomness = rng.random();
        ProfileCipher::new(ProfileKey::generate(some_randomness))
    }

    #[test]
    fn roundtrip_name() {
        let names = [
            "Me and my guitar",
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz",
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzx",
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzxf",
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzxfoobar",
        ];

        assert_eq!(names[1].len(), NAME_PADDED_LENGTH_1 - 1);
        assert_eq!(names[2].len(), NAME_PADDED_LENGTH_1);
        assert_eq!(names[3].len(), NAME_PADDED_LENGTH_1 + 1);

        let mut rng = rand::rng();
        let cipher = random_cipher(&mut rng);
        for name in &names {
            let profile_name = ProfileName::<&str> {
                given_name: name,
                family_name: None,
            };
            assert_eq!(profile_name.serialize().len(), name.len());
            let encrypted =
                cipher.encrypt_name(&profile_name, &mut rng).unwrap();

            // The ciphertext size reveals only the padding bracket.
            let bracket = if name.len() <= NAME_PADDED_LENGTH_1 {
                NAME_PADDED_LENGTH_1
            } else {
                NAME_PADDED_LENGTH_2
            };
            assert_eq!(encrypted.len(), OVERHEAD + bracket);

            let decrypted = cipher.decrypt_name(encrypted).unwrap().unwrap();

            assert_eq!(decrypted.as_ref(), profile_name);
            assert_eq!(decrypted.serialize(), profile_name.serialize());
            assert_eq!(&decrypted.given_name, name);
        }
    }

    #[test]
    fn roundtrip_about() {
        let at_largest_bracket = "a".repeat(ABOUT_PADDED_LENGTH_3);
        let abouts = [
            "Me and my guitar",
            "",
            "Ik ben een 🐠",
            at_largest_bracket.as_str(),
        ];

        let mut rng = rand::rng();
        let cipher = random_cipher(&mut rng);

        for &about in &abouts {
            let encrypted =
                cipher.encrypt_about(about.into(), &mut rng).unwrap();
            let decrypted = cipher.decrypt_about(encrypted).unwrap();

            assert_eq!(decrypted, about);
        }
    }

    #[test]
    fn roundtrip_emoji() {
        let emojii = ["❤️", "💩", "🤣", "😲", "🐠"];

        let mut rng = rand::rng();
        let cipher = random_cipher(&mut rng);

        for &emoji in &emojii {
            let encrypted =
                cipher.encrypt_emoji(emoji.into(), &mut rng).unwrap();
            assert_eq!(encrypted.len(), OVERHEAD + EMOJI_PADDED_LENGTH);
            let decrypted = cipher.decrypt_emoji(encrypted).unwrap();

            assert_eq!(decrypted, emoji);
        }
    }

    #[test]
    fn encrypt_rejects_oversized_input() {
        let mut rng = rand::rng();
        let cipher = random_cipher(&mut rng);

        let long_name = "a".repeat(NAME_PADDED_LENGTH_2 + 1);
        let profile_name = ProfileName::<&str> {
            given_name: &long_name,
            family_name: None,
        };
        assert!(matches!(
            cipher.encrypt_name(&profile_name, &mut rng),
            Err(ProfileCipherError::InputTooLong)
        ));

        assert!(matches!(
            cipher.encrypt_about("a".repeat(ABOUT_PADDED_LENGTH_3 + 1), &mut rng),
            Err(ProfileCipherError::InputTooLong)
        ));

        assert!(matches!(
            cipher.encrypt_emoji("a".repeat(EMOJI_PADDED_LENGTH + 1), &mut rng),
            Err(ProfileCipherError::InputTooLong)
        ));
    }

    #[test]
    fn decrypt_with_wrong_key_fails() {
        let mut rng = rand::rng();
        let owner = random_cipher(&mut rng);
        let stranger = random_cipher(&mut rng);

        let encrypted = owner
            .encrypt_about("Me and my guitar".into(), &mut rng)
            .unwrap();

        assert!(matches!(
            stranger.decrypt_about(&encrypted),
            Err(ProfileCipherError::EncryptionError)
        ));
        // The right key still decrypts the same ciphertext.
        assert_eq!(owner.decrypt_about(&encrypted).unwrap(), "Me and my guitar");
    }
}