1#![allow(non_snake_case)]
7
8use std::sync::LazyLock;
9
10use curve25519_dalek_signal::ristretto::RistrettoPoint;
11use partial_default::PartialDefault;
12use serde::{Deserialize, Serialize};
13use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
14use zkcredential::attributes::Attribute;
15
16use crate::common::errors::*;
17use crate::common::sho::*;
18use crate::common::simple_types::*;
19use crate::crypto::profile_key_struct;
20
21static SYSTEM_PARAMS: LazyLock<SystemParams> = LazyLock::new(|| {
22 crate::deserialize(&SystemParams::SYSTEM_HARDCODED).expect("valid hardcoded params")
23});
24
/// Fixed generator points for the profile-key encryption domain.
///
/// `SystemParams::generate` derives them; `SystemParams::SYSTEM_HARDCODED` holds their
/// serialized form; `ProfileKeyEncryptionDomain::G_a` hands them to `zkcredential` as
/// `[G_b1, G_b2]`.
///
/// Serde serializes fields in declaration order. Reordering the fields would therefore change
/// the encoded layout and break decoding of the hard-coded bytes.
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, PartialDefault)]
pub struct SystemParams {
    pub(crate) G_b1: RistrettoPoint,
    pub(crate) G_b2: RistrettoPoint,
}
30
/// `zkcredential` attribute-encryption key pair, specialized to `ProfileKeyEncryptionDomain`.
pub type KeyPair = zkcredential::attributes::KeyPair<ProfileKeyEncryptionDomain>;
/// `zkcredential` attribute-encryption public key, specialized to `ProfileKeyEncryptionDomain`.
pub type PublicKey = zkcredential::attributes::PublicKey<ProfileKeyEncryptionDomain>;
/// An encrypted `ProfileKeyStruct` in the profile-key encryption domain.
pub type Ciphertext = zkcredential::attributes::Ciphertext<ProfileKeyEncryptionDomain>;
34
35impl SystemParams {
36 pub fn generate() -> Self {
37 let mut sho = Sho::new(
38 b"Signal_ZKGroup_20200424_Constant_ProfileKeyEncryption_SystemParams_Generate",
39 b"",
40 );
41 let G_b1 = sho.get_point();
42 let G_b2 = sho.get_point();
43 SystemParams { G_b1, G_b2 }
44 }
45
46 pub fn get_hardcoded() -> SystemParams {
47 *SYSTEM_PARAMS
48 }
49
50 const SYSTEM_HARDCODED: [u8; 64] = [
51 0xf6, 0xba, 0xa3, 0x17, 0xce, 0x18, 0x39, 0xc9, 0x3d, 0x61, 0x7e, 0xc, 0xd8, 0x37, 0xd1,
52 0x9d, 0xa9, 0xc8, 0xa4, 0xc5, 0x20, 0xbf, 0x7c, 0x51, 0xb1, 0xe6, 0xc2, 0xcb, 0x2a, 0x4,
53 0x9c, 0x61, 0x2e, 0x1, 0x75, 0x89, 0x4c, 0x87, 0x30, 0xb2, 0x3, 0xab, 0x3b, 0xd9, 0x8e,
54 0xcb, 0x2d, 0x81, 0xab, 0xac, 0xb6, 0x5f, 0x8a, 0x61, 0x24, 0xf4, 0x97, 0x71, 0xd1, 0x4a,
55 0x98, 0x52, 0x12, 0xc,
56 ];
57}
58
59pub struct ProfileKeyEncryptionDomain;
60impl zkcredential::attributes::Domain for ProfileKeyEncryptionDomain {
61 type Attribute = profile_key_struct::ProfileKeyStruct;
62
63 const ID: &'static str = "Signal_ZKGroup_20231011_ProfileKeyEncryption";
64
65 fn G_a() -> [RistrettoPoint; 2] {
66 let system = SystemParams::get_hardcoded();
67 [system.G_b1, system.G_b2]
68 }
69}
70
71impl ProfileKeyEncryptionDomain {
72 pub(crate) fn decrypt(
73 key_pair: &KeyPair,
74 ciphertext: &Ciphertext,
75 uid_bytes: UidBytes,
76 ) -> Result<profile_key_struct::ProfileKeyStruct, ZkGroupVerificationFailure> {
77 let M4 = key_pair
78 .decrypt_to_second_point(ciphertext)
79 .map_err(|_| ZkGroupVerificationFailure)?;
80 let (mask, candidates) = M4.decode_253_bits();
81
82 let target_M3 = key_pair.a1.invert() * ciphertext.as_points()[0];
83 let seed_sho = profile_key_struct::ProfileKeyStruct::seed_M3();
84
85 let mut retval: profile_key_struct::ProfileKeyStruct = PartialDefault::partial_default();
86 let mut n_found = 0;
87 #[allow(clippy::needless_range_loop)]
88 for i in 0..8 {
89 let is_valid_fe = Choice::from((mask >> i) & 1);
90 let profile_key_bytes: ProfileKeyBytes = candidates[i];
91 for j in 0..8 {
92 let mut pk = profile_key_bytes;
93 if ((j >> 2) & 1) == 1 {
94 pk[0] |= 0x01;
95 }
96 if ((j >> 1) & 1) == 1 {
97 pk[31] |= 0x80;
98 }
99 if (j & 1) == 1 {
100 pk[31] |= 0x40;
101 }
102 let M3 =
103 profile_key_struct::ProfileKeyStruct::calc_M3(seed_sho.clone(), pk, uid_bytes);
104 let candidate_retval = profile_key_struct::ProfileKeyStruct { bytes: pk, M3, M4 };
105 let found = M3.ct_eq(&target_M3) & is_valid_fe;
106 retval.conditional_assign(&candidate_retval, found);
107 n_found += found.unwrap_u8();
108 }
109 }
110 if n_found == 1 {
111 Ok(retval)
112 } else {
113 Err(ZkGroupVerificationFailure)
114 }
115 }
116}
117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::constants::*;

    /// Encrypts `profile_key_bytes` for `uid_bytes` under `key_pair` and asserts that
    /// decryption returns the original `ProfileKeyStruct`.
    fn assert_round_trip(
        key_pair: &KeyPair,
        profile_key_bytes: ProfileKeyBytes,
        uid_bytes: UidBytes,
    ) {
        let profile_key = profile_key_struct::ProfileKeyStruct::new(profile_key_bytes, uid_bytes);
        let ciphertext = key_pair.encrypt(&profile_key);
        let decrypted = ProfileKeyEncryptionDomain::decrypt(key_pair, &ciphertext, uid_bytes)
            .expect("a freshly encrypted profile key must decrypt");
        assert!(decrypted == profile_key);
    }

    #[test]
    fn test_profile_key_encryption() {
        let master_key = TEST_ARRAY_32_1;
        let mut sho = Sho::new(b"Test_Profile_Key_Encryption", &master_key);

        // The embedded constant must stay in sync with the derivation.
        assert!(SystemParams::generate() == SystemParams::get_hardcoded());

        let key_pair = KeyPair::derive_from(sho.as_mut());

        // KeyPair serialization: a truncated encoding is rejected, the full one round-trips.
        let key_pair_bytes = bincode::serialize(&key_pair).unwrap();
        assert!(
            bincode::deserialize::<KeyPair>(&key_pair_bytes[..key_pair_bytes.len() - 1]).is_err(),
            "truncated KeyPair encoding unexpectedly deserialized"
        );
        let key_pair2: KeyPair = bincode::deserialize(&key_pair_bytes).unwrap();
        assert!(key_pair == key_pair2);

        // Known-answer test: a fixed key pair and plaintext must give a fixed ciphertext.
        let profile_key_bytes = TEST_ARRAY_32_1;
        let uid_bytes = TEST_ARRAY_16_1;
        let profile_key = profile_key_struct::ProfileKeyStruct::new(profile_key_bytes, uid_bytes);
        let ciphertext = key_pair.encrypt(&profile_key);

        let ciphertext_bytes = bincode::serialize(&ciphertext).unwrap();
        assert_eq!(ciphertext_bytes.len(), 64);
        let ciphertext2: Ciphertext = bincode::deserialize(&ciphertext_bytes).unwrap();
        assert!(ciphertext == ciphertext2);
        println!("ciphertext_bytes = {ciphertext_bytes:#x?}");
        assert_eq!(
            ciphertext_bytes,
            vec![
                0x56, 0x18, 0xcb, 0x4c, 0x7d, 0x72, 0x1e, 0x1, 0x2b, 0x22, 0xf0, 0x77, 0xef,
                0x12, 0x64, 0xf6, 0xb1, 0x43, 0xbb, 0x59, 0x7a, 0x1d, 0x66, 0x5a, 0x70, 0xaa,
                0x84, 0x24, 0x5f, 0x24, 0x6d, 0x20, 0xba, 0xdb, 0x97, 0x47, 0x4a, 0x56, 0xf4,
                0xb5, 0x36, 0x1a, 0xec, 0xa9, 0xd1, 0x18, 0xb7, 0x0, 0x4e, 0x14, 0x9, 0x71,
                0x99, 0xa, 0xab, 0x2a, 0xf2, 0x43, 0x2d, 0x3f, 0x8f, 0x7d, 0x21, 0x3a,
            ]
        );

        let plaintext =
            ProfileKeyEncryptionDomain::decrypt(&key_pair, &ciphertext2, uid_bytes).unwrap();
        assert!(plaintext == profile_key);

        // Round-trip deterministic pseudo-random UIDs and profile keys.
        let mut sho = Sho::new(b"Test_Repeated_ProfileKeyEnc/Dec", b"seed");
        for _ in 0..100 {
            let uid_bytes: UidBytes = sho.squeeze_as_array();
            let profile_key_bytes: ProfileKeyBytes = sho.squeeze_as_array();
            assert_round_trip(&key_pair, profile_key_bytes, uid_bytes);
        }

        // Round-trip a few fixed profile keys for the same UID.
        for profile_key_bytes in [TEST_ARRAY_32, TEST_ARRAY_32_2, TEST_ARRAY_32_3, TEST_ARRAY_32_4]
        {
            assert_round_trip(&key_pair, profile_key_bytes, TEST_ARRAY_16);
        }
    }
}