1#![allow(non_snake_case)]
7
8use curve25519_dalek_signal::ristretto::RistrettoPoint;
9use lazy_static::lazy_static;
10use partial_default::PartialDefault;
11use serde::{Deserialize, Serialize};
12use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
13use zkcredential::attributes::Attribute;
14
15use crate::common::errors::*;
16use crate::common::sho::*;
17use crate::common::simple_types::*;
18use crate::crypto::profile_key_struct;
19
20lazy_static! {
21 static ref SYSTEM_PARAMS: SystemParams =
22 crate::deserialize::<SystemParams>(&SystemParams::SYSTEM_HARDCODED).unwrap();
23}
24
25#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, PartialDefault)]
26pub struct SystemParams {
27 pub(crate) G_b1: RistrettoPoint,
28 pub(crate) G_b2: RistrettoPoint,
29}
30
31pub type KeyPair = zkcredential::attributes::KeyPair<ProfileKeyEncryptionDomain>;
32pub type PublicKey = zkcredential::attributes::PublicKey<ProfileKeyEncryptionDomain>;
33pub type Ciphertext = zkcredential::attributes::Ciphertext<ProfileKeyEncryptionDomain>;
34
35impl SystemParams {
36 pub fn generate() -> Self {
37 let mut sho = Sho::new(
38 b"Signal_ZKGroup_20200424_Constant_ProfileKeyEncryption_SystemParams_Generate",
39 b"",
40 );
41 let G_b1 = sho.get_point();
42 let G_b2 = sho.get_point();
43 SystemParams { G_b1, G_b2 }
44 }
45
46 pub fn get_hardcoded() -> SystemParams {
47 *SYSTEM_PARAMS
48 }
49
50 const SYSTEM_HARDCODED: [u8; 64] = [
51 0xf6, 0xba, 0xa3, 0x17, 0xce, 0x18, 0x39, 0xc9, 0x3d, 0x61, 0x7e, 0xc, 0xd8, 0x37, 0xd1,
52 0x9d, 0xa9, 0xc8, 0xa4, 0xc5, 0x20, 0xbf, 0x7c, 0x51, 0xb1, 0xe6, 0xc2, 0xcb, 0x2a, 0x4,
53 0x9c, 0x61, 0x2e, 0x1, 0x75, 0x89, 0x4c, 0x87, 0x30, 0xb2, 0x3, 0xab, 0x3b, 0xd9, 0x8e,
54 0xcb, 0x2d, 0x81, 0xab, 0xac, 0xb6, 0x5f, 0x8a, 0x61, 0x24, 0xf4, 0x97, 0x71, 0xd1, 0x4a,
55 0x98, 0x52, 0x12, 0xc,
56 ];
57}
58
59pub struct ProfileKeyEncryptionDomain;
60impl zkcredential::attributes::Domain for ProfileKeyEncryptionDomain {
61 type Attribute = profile_key_struct::ProfileKeyStruct;
62
63 const ID: &'static str = "Signal_ZKGroup_20231011_ProfileKeyEncryption";
64
65 fn G_a() -> [RistrettoPoint; 2] {
66 let system = SystemParams::get_hardcoded();
67 [system.G_b1, system.G_b2]
68 }
69}
70
71impl ProfileKeyEncryptionDomain {
72 pub(crate) fn decrypt(
73 key_pair: &KeyPair,
74 ciphertext: &Ciphertext,
75 uid_bytes: UidBytes,
76 ) -> Result<profile_key_struct::ProfileKeyStruct, ZkGroupVerificationFailure> {
77 let M4 = key_pair
78 .decrypt_to_second_point(ciphertext)
79 .map_err(|_| ZkGroupVerificationFailure)?;
80 let (mask, candidates) = M4.decode_253_bits();
81
82 let target_M3 = key_pair.a1.invert() * ciphertext.as_points()[0];
83
84 let mut retval: profile_key_struct::ProfileKeyStruct = PartialDefault::partial_default();
85 let mut n_found = 0;
86 #[allow(clippy::needless_range_loop)]
87 for i in 0..8 {
88 let is_valid_fe = Choice::from((mask >> i) & 1);
89 let profile_key_bytes: ProfileKeyBytes = candidates[i];
90 for j in 0..8 {
91 let mut pk = profile_key_bytes;
92 if ((j >> 2) & 1) == 1 {
93 pk[0] |= 0x01;
94 }
95 if ((j >> 1) & 1) == 1 {
96 pk[31] |= 0x80;
97 }
98 if (j & 1) == 1 {
99 pk[31] |= 0x40;
100 }
101 let M3 = profile_key_struct::ProfileKeyStruct::calc_M3(pk, uid_bytes);
102 let candidate_retval = profile_key_struct::ProfileKeyStruct { bytes: pk, M3, M4 };
103 let found = M3.ct_eq(&target_M3) & is_valid_fe;
104 retval.conditional_assign(&candidate_retval, found);
105 n_found += found.unwrap_u8();
106 }
107 }
108 if n_found == 1 {
109 Ok(retval)
110 } else {
111 Err(ZkGroupVerificationFailure)
112 }
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::constants::*;

    #[test]
    fn test_profile_key_encryption() {
        let master_key = TEST_ARRAY_32_1;
        let mut sho = Sho::new(b"Test_Profile_Key_Encryption", &master_key);

        // The baked-in parameters must be exactly what `generate` derives.
        assert!(SystemParams::generate() == SystemParams::get_hardcoded());

        let key_pair = KeyPair::derive_from(sho.as_mut());

        // Key pair serialization round-trips, and a truncated encoding is rejected.
        let key_pair_bytes = bincode::serialize(&key_pair).unwrap();
        let truncated = &key_pair_bytes[..key_pair_bytes.len() - 1];
        assert!(bincode::deserialize::<KeyPair>(truncated).is_err());
        let key_pair2: KeyPair = bincode::deserialize(&key_pair_bytes).unwrap();
        assert!(key_pair == key_pair2);

        // Known-answer test: fixed key pair, profile key and UID give a fixed ciphertext.
        let profile_key_bytes = TEST_ARRAY_32_1;
        let uid_bytes = TEST_ARRAY_16_1;
        let profile_key = profile_key_struct::ProfileKeyStruct::new(profile_key_bytes, uid_bytes);
        let ciphertext = key_pair.encrypt(&profile_key);

        let ciphertext_bytes = bincode::serialize(&ciphertext).unwrap();
        assert_eq!(ciphertext_bytes.len(), 64);
        let ciphertext2: Ciphertext = bincode::deserialize(&ciphertext_bytes).unwrap();
        assert!(ciphertext == ciphertext2);
        // `assert_eq!` prints both byte vectors on mismatch, which helps when regenerating.
        assert_eq!(
            ciphertext_bytes,
            vec![
                0x56, 0x18, 0xcb, 0x4c, 0x7d, 0x72, 0x1e, 0x1, 0x2b, 0x22, 0xf0, 0x77, 0xef,
                0x12, 0x64, 0xf6, 0xb1, 0x43, 0xbb, 0x59, 0x7a, 0x1d, 0x66, 0x5a, 0x70, 0xaa,
                0x84, 0x24, 0x5f, 0x24, 0x6d, 0x20, 0xba, 0xdb, 0x97, 0x47, 0x4a, 0x56, 0xf4,
                0xb5, 0x36, 0x1a, 0xec, 0xa9, 0xd1, 0x18, 0xb7, 0x0, 0x4e, 0x14, 0x9, 0x71,
                0x99, 0xa, 0xab, 0x2a, 0xf2, 0x43, 0x2d, 0x3f, 0x8f, 0x7d, 0x21, 0x3a,
            ]
        );

        let plaintext =
            ProfileKeyEncryptionDomain::decrypt(&key_pair, &ciphertext2, uid_bytes).unwrap();
        assert!(plaintext == profile_key);

        // Round trips over pseudo-random profile keys and UIDs.
        let mut sho = Sho::new(b"Test_Repeated_ProfileKeyEnc/Dec", b"seed");
        for _ in 0..100 {
            let mut uid_bytes: UidBytes = Default::default();
            let mut profile_key_bytes: ProfileKeyBytes = Default::default();

            uid_bytes.copy_from_slice(&sho.squeeze(UUID_LEN)[..]);
            profile_key_bytes.copy_from_slice(&sho.squeeze(PROFILE_KEY_LEN)[..]);

            assert_round_trip(&key_pair, profile_key_bytes, uid_bytes);
        }

        // Round trips over the fixed test profile keys, all bound to the same UID.
        let fixed_keys = [TEST_ARRAY_32, TEST_ARRAY_32_2, TEST_ARRAY_32_3, TEST_ARRAY_32_4];
        for profile_key_bytes in fixed_keys {
            assert_round_trip(&key_pair, profile_key_bytes, TEST_ARRAY_16);
        }
    }

    /// Encrypts `profile_key_bytes` bound to `uid_bytes`, then asserts that decryption
    /// recovers the same profile key.
    fn assert_round_trip(
        key_pair: &KeyPair,
        profile_key_bytes: ProfileKeyBytes,
        uid_bytes: UidBytes,
    ) {
        let profile_key = profile_key_struct::ProfileKeyStruct::new(profile_key_bytes, uid_bytes);
        let ciphertext = key_pair.encrypt(&profile_key);
        let decrypted =
            ProfileKeyEncryptionDomain::decrypt(key_pair, &ciphertext, uid_bytes).unwrap();
        assert!(decrypted == profile_key);
    }
}