1#![allow(non_snake_case)]
7
8use std::sync::LazyLock;
9
10use curve25519_dalek_signal::ristretto::RistrettoPoint;
11use partial_default::PartialDefault;
12use serde::{Deserialize, Serialize};
13use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
14use zkcredential::attributes::Attribute;
15
16use crate::common::errors::*;
17use crate::common::sho::*;
18use crate::common::simple_types::*;
19use crate::crypto::profile_key_struct;
20
21static SYSTEM_PARAMS: LazyLock<SystemParams> =
22 LazyLock::new(|| crate::deserialize::<SystemParams>(&SystemParams::SYSTEM_HARDCODED).unwrap());
23
24#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, PartialDefault)]
25pub struct SystemParams {
26 pub(crate) G_b1: RistrettoPoint,
27 pub(crate) G_b2: RistrettoPoint,
28}
29
30pub type KeyPair = zkcredential::attributes::KeyPair<ProfileKeyEncryptionDomain>;
31pub type PublicKey = zkcredential::attributes::PublicKey<ProfileKeyEncryptionDomain>;
32pub type Ciphertext = zkcredential::attributes::Ciphertext<ProfileKeyEncryptionDomain>;
33
34impl SystemParams {
35 pub fn generate() -> Self {
36 let mut sho = Sho::new(
37 b"Signal_ZKGroup_20200424_Constant_ProfileKeyEncryption_SystemParams_Generate",
38 b"",
39 );
40 let G_b1 = sho.get_point();
41 let G_b2 = sho.get_point();
42 SystemParams { G_b1, G_b2 }
43 }
44
45 pub fn get_hardcoded() -> SystemParams {
46 *SYSTEM_PARAMS
47 }
48
49 const SYSTEM_HARDCODED: [u8; 64] = [
50 0xf6, 0xba, 0xa3, 0x17, 0xce, 0x18, 0x39, 0xc9, 0x3d, 0x61, 0x7e, 0xc, 0xd8, 0x37, 0xd1,
51 0x9d, 0xa9, 0xc8, 0xa4, 0xc5, 0x20, 0xbf, 0x7c, 0x51, 0xb1, 0xe6, 0xc2, 0xcb, 0x2a, 0x4,
52 0x9c, 0x61, 0x2e, 0x1, 0x75, 0x89, 0x4c, 0x87, 0x30, 0xb2, 0x3, 0xab, 0x3b, 0xd9, 0x8e,
53 0xcb, 0x2d, 0x81, 0xab, 0xac, 0xb6, 0x5f, 0x8a, 0x61, 0x24, 0xf4, 0x97, 0x71, 0xd1, 0x4a,
54 0x98, 0x52, 0x12, 0xc,
55 ];
56}
57
58pub struct ProfileKeyEncryptionDomain;
59impl zkcredential::attributes::Domain for ProfileKeyEncryptionDomain {
60 type Attribute = profile_key_struct::ProfileKeyStruct;
61
62 const ID: &'static str = "Signal_ZKGroup_20231011_ProfileKeyEncryption";
63
64 fn G_a() -> [RistrettoPoint; 2] {
65 let system = SystemParams::get_hardcoded();
66 [system.G_b1, system.G_b2]
67 }
68}
69
impl ProfileKeyEncryptionDomain {
    /// Decrypts `ciphertext` into a [`profile_key_struct::ProfileKeyStruct`]. It uses
    /// `uid_bytes` to recompute and check each candidate's `M3` point.
    ///
    /// How it works:
    /// - The second plaintext point (`M4`) is recovered from the ciphertext.
    /// - `M4` is decoded into eight candidate byte strings. Bit `i` of `mask` marks whether
    ///   candidate `i` is a valid decoding.
    /// - Each candidate expands into eight variants by optionally setting three fixed bits:
    ///   `pk[0] |= 0x01`, `pk[31] |= 0x80` and `pk[31] |= 0x40`. These are presumably the bits
    ///   the 253-bit encoding cannot carry (TODO confirm).
    /// - A variant matches when its recomputed `M3` equals the first ciphertext point
    ///   multiplied by `a1^-1`.
    ///
    /// All 64 variants are always evaluated. The result is picked with `subtle`'s `ct_eq` /
    /// `conditional_assign` instead of a data-dependent early return.
    ///
    /// # Errors
    ///
    /// Returns [`ZkGroupVerificationFailure`] if the second point cannot be decrypted, or if
    /// the number of matching variants is not exactly one.
    pub(crate) fn decrypt(
        key_pair: &KeyPair,
        ciphertext: &Ciphertext,
        uid_bytes: UidBytes,
    ) -> Result<profile_key_struct::ProfileKeyStruct, ZkGroupVerificationFailure> {
        let M4 = key_pair
            .decrypt_to_second_point(ciphertext)
            .map_err(|_| ZkGroupVerificationFailure)?;
        let (mask, candidates) = M4.decode_253_bits();

        // For a well-formed ciphertext this should equal the M3 of the true plaintext.
        let target_M3 = key_pair.a1.invert() * ciphertext.as_points()[0];

        let mut retval: profile_key_struct::ProfileKeyStruct = PartialDefault::partial_default();
        // Count matches instead of stopping at the first, so the work done does not depend on
        // which variant matched. A unique match is required below.
        let mut n_found = 0;
        // `i` indexes both the `mask` bits and `candidates`, so an index loop is the clearer
        // form here.
        #[allow(clippy::needless_range_loop)]
        for i in 0..8 {
            let is_valid_fe = Choice::from((mask >> i) & 1);
            let profile_key_bytes: ProfileKeyBytes = candidates[i];
            // Each bit of `j` selects whether one of the three fixed bits is set.
            for j in 0..8 {
                let mut pk = profile_key_bytes;
                if ((j >> 2) & 1) == 1 {
                    pk[0] |= 0x01;
                }
                if ((j >> 1) & 1) == 1 {
                    pk[31] |= 0x80;
                }
                if (j & 1) == 1 {
                    pk[31] |= 0x40;
                }
                let M3 = profile_key_struct::ProfileKeyStruct::calc_M3(pk, uid_bytes);
                let candidate_retval = profile_key_struct::ProfileKeyStruct { bytes: pk, M3, M4 };
                // Only variants of a validly decoded candidate may be accepted.
                let found = M3.ct_eq(&target_M3) & is_valid_fe;
                retval.conditional_assign(&candidate_retval, found);
                n_found += found.unwrap_u8();
            }
        }
        if n_found == 1 {
            Ok(retval)
        } else {
            Err(ZkGroupVerificationFailure)
        }
    }
}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::constants::*;

    /// Encrypts `profile_key` under `key_pair` and asserts that decrypting with `uid_bytes`
    /// returns the same profile key.
    #[track_caller]
    fn assert_round_trip(
        key_pair: &KeyPair,
        profile_key: &profile_key_struct::ProfileKeyStruct,
        uid_bytes: UidBytes,
    ) {
        let ciphertext = key_pair.encrypt(profile_key);
        let decrypted = ProfileKeyEncryptionDomain::decrypt(key_pair, &ciphertext, uid_bytes)
            .expect("decrypting a freshly encrypted profile key should succeed");
        assert!(decrypted == *profile_key);
    }

    #[test]
    fn test_profile_key_encryption() {
        let master_key = TEST_ARRAY_32_1;
        let mut sho = Sho::new(b"Test_Profile_Key_Encryption", &master_key);

        // The hardcoded constant must match what `generate` derives.
        assert!(SystemParams::generate() == SystemParams::get_hardcoded());

        let key_pair = KeyPair::derive_from(sho.as_mut());

        // A truncated key pair must be rejected; the full encoding must round-trip.
        let key_pair_bytes = bincode::serialize(&key_pair).unwrap();
        let truncated = &key_pair_bytes[..key_pair_bytes.len() - 1];
        assert!(bincode::deserialize::<KeyPair>(truncated).is_err());
        let key_pair2: KeyPair = bincode::deserialize(&key_pair_bytes).unwrap();
        assert!(key_pair == key_pair2);

        // Known-answer test: fixed key pair and inputs must give fixed ciphertext bytes.
        let profile_key_bytes = TEST_ARRAY_32_1;
        let uid_bytes = TEST_ARRAY_16_1;
        let profile_key = profile_key_struct::ProfileKeyStruct::new(profile_key_bytes, uid_bytes);
        let ciphertext = key_pair.encrypt(&profile_key);

        let ciphertext_bytes = bincode::serialize(&ciphertext).unwrap();
        assert_eq!(ciphertext_bytes.len(), 64);
        let ciphertext2: Ciphertext = bincode::deserialize(&ciphertext_bytes).unwrap();
        assert!(ciphertext == ciphertext2);
        println!("ciphertext_bytes = {ciphertext_bytes:#x?}");
        assert_eq!(
            ciphertext_bytes,
            vec![
                0x56, 0x18, 0xcb, 0x4c, 0x7d, 0x72, 0x1e, 0x1, 0x2b, 0x22, 0xf0, 0x77, 0xef,
                0x12, 0x64, 0xf6, 0xb1, 0x43, 0xbb, 0x59, 0x7a, 0x1d, 0x66, 0x5a, 0x70, 0xaa,
                0x84, 0x24, 0x5f, 0x24, 0x6d, 0x20, 0xba, 0xdb, 0x97, 0x47, 0x4a, 0x56, 0xf4,
                0xb5, 0x36, 0x1a, 0xec, 0xa9, 0xd1, 0x18, 0xb7, 0x0, 0x4e, 0x14, 0x9, 0x71,
                0x99, 0xa, 0xab, 0x2a, 0xf2, 0x43, 0x2d, 0x3f, 0x8f, 0x7d, 0x21, 0x3a,
            ]
        );

        let plaintext =
            ProfileKeyEncryptionDomain::decrypt(&key_pair, &ciphertext2, uid_bytes).unwrap();
        assert!(plaintext == profile_key);

        // Round trips over pseudo-random UIDs and profile keys.
        let mut sho = Sho::new(b"Test_Repeated_ProfileKeyEnc/Dec", b"seed");
        for _ in 0..100 {
            let uid_bytes: UidBytes = sho.squeeze_as_array();
            let profile_key_bytes: ProfileKeyBytes = sho.squeeze_as_array();
            let profile_key =
                profile_key_struct::ProfileKeyStruct::new(profile_key_bytes, uid_bytes);
            assert_round_trip(&key_pair, &profile_key, uid_bytes);
        }

        // Round trips over the shared fixed test vectors.
        let vectors = [TEST_ARRAY_32, TEST_ARRAY_32_2, TEST_ARRAY_32_3, TEST_ARRAY_32_4];
        for profile_key_bytes in vectors {
            let profile_key =
                profile_key_struct::ProfileKeyStruct::new(profile_key_bytes, TEST_ARRAY_16);
            assert_round_trip(&key_pair, &profile_key, TEST_ARRAY_16);
        }
    }
}