1use std::convert::TryInto;
2
3use aes_gcm::{aead::Aead, AeadCore, AeadInPlace, Aes256Gcm, KeyInit};
4use rand::{CryptoRng, Rng};
5use zkgroup::profiles::ProfileKey;
6
7use crate::{
8 profile_name::ProfileName, push_service::SignalServiceProfile, Profile,
9};
10
11pub struct ProfileCipher {
32 profile_key: ProfileKey,
33}
34
// Padded plaintext sizes (in bytes) for a serialized profile name.
const NAME_PADDED_LENGTH_1: usize = 53;
const NAME_PADDED_LENGTH_2: usize = 257;
/// Allowed padded sizes for names, in ascending order.
const NAME_PADDING_BRACKETS: &[usize] =
    &[NAME_PADDED_LENGTH_1, NAME_PADDED_LENGTH_2];

// Padded plaintext sizes (in bytes) for the "about" text.
const ABOUT_PADDED_LENGTH_1: usize = 128;
const ABOUT_PADDED_LENGTH_2: usize = 254;
const ABOUT_PADDED_LENGTH_3: usize = 512;
/// Allowed padded sizes for the "about" text, in ascending order.
const ABOUT_PADDING_BRACKETS: &[usize] =
    &[ABOUT_PADDED_LENGTH_1, ABOUT_PADDED_LENGTH_2, ABOUT_PADDED_LENGTH_3];

/// The single padded size (in bytes) used for the "about" emoji.
const EMOJI_PADDED_LENGTH: usize = 32;
50
51#[derive(thiserror::Error, Debug)]
52pub enum ProfileCipherError {
53 #[error("Encryption error")]
54 EncryptionError,
55 #[error("UTF-8 decode error {0}")]
56 Utf8Error(#[from] std::str::Utf8Error),
57 #[error("Input name too long")]
58 InputTooLong,
59}
60
61fn pad_plaintext(
62 bytes: &mut Vec<u8>,
63 brackets: &[usize],
64) -> Result<usize, ProfileCipherError> {
65 let len = brackets
66 .iter()
67 .find(|x| **x >= bytes.len())
68 .ok_or(ProfileCipherError::InputTooLong)?;
69 let len: usize = *len;
70
71 bytes.resize(len, 0);
72 assert!(brackets.contains(&bytes.len()));
73
74 Ok(len)
75}
76
77impl ProfileCipher {
78 pub fn new(profile_key: ProfileKey) -> Self {
79 Self { profile_key }
80 }
81
82 pub fn into_inner(self) -> ProfileKey {
83 self.profile_key
84 }
85
86 fn pad_and_encrypt<R: Rng + CryptoRng>(
87 &self,
88 mut bytes: Vec<u8>,
89 padding_brackets: &[usize],
90 csprng: &mut R,
91 ) -> Result<Vec<u8>, ProfileCipherError> {
92 let _len = pad_plaintext(&mut bytes, padding_brackets)?;
93
94 let cipher = Aes256Gcm::new(&self.profile_key.get_bytes().into());
95 let nonce = Aes256Gcm::generate_nonce(csprng);
96
97 cipher
98 .encrypt_in_place(&nonce, b"", &mut bytes)
99 .map_err(|_| ProfileCipherError::EncryptionError)?;
100
101 let mut concat = Vec::with_capacity(nonce.len() + bytes.len());
102 concat.extend_from_slice(&nonce);
103 concat.extend_from_slice(&bytes);
104 Ok(concat)
105 }
106
107 fn decrypt_and_unpad(
108 &self,
109 bytes: impl AsRef<[u8]>,
110 ) -> Result<Vec<u8>, ProfileCipherError> {
111 let bytes = bytes.as_ref();
112 let nonce: [u8; 12] = bytes[0..12]
113 .try_into()
114 .expect("fixed length nonce material");
115 let cipher = Aes256Gcm::new(&self.profile_key.get_bytes().into());
116
117 let mut plaintext = cipher
118 .decrypt(&nonce.into(), &bytes[12..])
119 .map_err(|_| ProfileCipherError::EncryptionError)?;
120
121 let len = plaintext
123 .iter()
124 .rposition(|x| *x != 0)
126 .map(|x| x + 1)
128 .unwrap_or(0);
130 plaintext.truncate(len);
131 Ok(plaintext)
132 }
133
134 pub fn decrypt(
135 &self,
136 encrypted_profile: SignalServiceProfile,
137 ) -> Result<Profile, ProfileCipherError> {
138 let name = encrypted_profile
139 .name
140 .as_ref()
141 .map(|data| self.decrypt_name(data))
142 .transpose()?
143 .flatten();
144 let about = encrypted_profile
145 .about
146 .as_ref()
147 .map(|data| self.decrypt_about(data))
148 .transpose()?;
149 let about_emoji = encrypted_profile
150 .about_emoji
151 .as_ref()
152 .map(|data| self.decrypt_emoji(data))
153 .transpose()?;
154
155 Ok(Profile {
156 name,
157 about,
158 about_emoji,
159 avatar: encrypted_profile.avatar,
160 unrestricted_unidentified_access: encrypted_profile
161 .unrestricted_unidentified_access,
162 })
163 }
164
165 pub fn decrypt_avatar(
166 &self,
167 bytes: &[u8],
168 ) -> Result<Vec<u8>, ProfileCipherError> {
169 self.decrypt_and_unpad(bytes)
170 }
171
172 pub fn encrypt_name<'inp, R: Rng + CryptoRng>(
173 &self,
174 name: impl std::borrow::Borrow<ProfileName<&'inp str>>,
175 csprng: &mut R,
176 ) -> Result<Vec<u8>, ProfileCipherError> {
177 let name = name.borrow();
178 let bytes = name.serialize();
179 self.pad_and_encrypt(bytes, NAME_PADDING_BRACKETS, csprng)
180 }
181
182 pub fn decrypt_name(
183 &self,
184 bytes: impl AsRef<[u8]>,
185 ) -> Result<Option<ProfileName<String>>, ProfileCipherError> {
186 let bytes = bytes.as_ref();
187
188 let plaintext = self.decrypt_and_unpad(bytes)?;
189
190 Ok(ProfileName::<String>::deserialize(&plaintext)?)
191 }
192
193 pub fn encrypt_about<R: Rng + CryptoRng>(
194 &self,
195 about: String,
196 csprng: &mut R,
197 ) -> Result<Vec<u8>, ProfileCipherError> {
198 let bytes = about.into_bytes();
199 self.pad_and_encrypt(bytes, ABOUT_PADDING_BRACKETS, csprng)
200 }
201
202 pub fn decrypt_about(
203 &self,
204 bytes: impl AsRef<[u8]>,
205 ) -> Result<String, ProfileCipherError> {
206 let bytes = bytes.as_ref();
207
208 let plaintext = self.decrypt_and_unpad(bytes)?;
209
210 Ok(std::str::from_utf8(&plaintext)?.into())
212 }
213
214 pub fn encrypt_emoji<R: Rng + CryptoRng>(
215 &self,
216 emoji: String,
217 csprng: &mut R,
218 ) -> Result<Vec<u8>, ProfileCipherError> {
219 let bytes = emoji.into_bytes();
220 self.pad_and_encrypt(bytes, &[EMOJI_PADDED_LENGTH], csprng)
221 }
222
223 pub fn decrypt_emoji(
224 &self,
225 bytes: impl AsRef<[u8]>,
226 ) -> Result<String, ProfileCipherError> {
227 let bytes = bytes.as_ref();
228
229 let plaintext = self.decrypt_and_unpad(bytes)?;
230
231 Ok(std::str::from_utf8(&plaintext)?.into())
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::profile_name::ProfileName;
    use rand::Rng;
    use zkgroup::profiles::ProfileKey;

    /// Builds a cipher around a profile key generated from fresh randomness.
    fn random_cipher<R: Rng>(rng: &mut R) -> ProfileCipher {
        let some_randomness = rng.gen();
        ProfileCipher::new(ProfileKey::generate(some_randomness))
    }

    /// Names around the first padding bracket survive an encrypt/decrypt
    /// round trip.
    #[test]
    fn roundtrip_name() {
        let names = [
            "Me and my guitar",
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz",
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzx",
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzxf",
            "abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzxfoobar",
        ];

        // Sanity-check that the fixtures straddle the first bracket.
        assert_eq!(names[1].len(), NAME_PADDED_LENGTH_1 - 1);
        assert_eq!(names[2].len(), NAME_PADDED_LENGTH_1);
        assert_eq!(names[3].len(), NAME_PADDED_LENGTH_1 + 1);

        let mut rng = rand::thread_rng();
        let cipher = random_cipher(&mut rng);

        for &given_name in &names {
            let profile_name = ProfileName::<&str> {
                given_name,
                family_name: None,
            };
            assert_eq!(profile_name.serialize().len(), given_name.len());

            let ciphertext =
                cipher.encrypt_name(&profile_name, &mut rng).unwrap();
            let roundtripped =
                cipher.decrypt_name(ciphertext).unwrap().unwrap();

            assert_eq!(roundtripped.as_ref(), profile_name);
            assert_eq!(roundtripped.serialize(), profile_name.serialize());
            assert_eq!(roundtripped.given_name, given_name);
        }
    }

    /// An "about" text survives an encrypt/decrypt round trip.
    #[test]
    fn roundtrip_about() {
        let abouts = ["Me and my guitar"];

        let mut rng = rand::thread_rng();
        let cipher = random_cipher(&mut rng);

        for about in abouts.iter().copied() {
            let ciphertext =
                cipher.encrypt_about(about.into(), &mut rng).unwrap();
            let roundtripped = cipher.decrypt_about(ciphertext).unwrap();

            assert_eq!(roundtripped, about);
        }
    }

    /// Emoji survive an encrypt/decrypt round trip.
    #[test]
    fn roundtrip_emoji() {
        let emojii = ["❤️", "💩", "🤣", "😲", "🐠"];

        let mut rng = rand::thread_rng();
        let cipher = random_cipher(&mut rng);

        for emoji in emojii.iter().copied() {
            let ciphertext =
                cipher.encrypt_emoji(emoji.into(), &mut rng).unwrap();
            let roundtripped = cipher.decrypt_emoji(ciphertext).unwrap();

            assert_eq!(roundtripped, emoji);
        }
    }
}