1#![allow(non_snake_case)]
7
8use std::sync::LazyLock;
9
10use curve25519_dalek_signal::ristretto::RistrettoPoint;
11use partial_default::PartialDefault;
12use serde::{Deserialize, Serialize};
13use subtle::{ConditionallySelectable, ConstantTimeEq};
14use zkcredential::attributes::Attribute;
15
16use crate::common::errors::*;
17use crate::common::sho::*;
18use crate::crypto::uid_struct;
19
20static SYSTEM_PARAMS: LazyLock<SystemParams> = LazyLock::new(|| {
21 crate::deserialize(&SystemParams::SYSTEM_HARDCODED).expect("valid hardcoded params")
22});
23
/// Public generator points used by the UID encryption domain.
///
/// The field order is significant. The derived `Serialize`/`Deserialize` impls encode fields
/// in declaration order, and `SystemParams::SYSTEM_HARDCODED` must decode to the same values
/// that `SystemParams::generate` produces (G_a1 first, then G_a2).
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, PartialDefault)]
pub struct SystemParams {
    /// First generator; returned as element 0 of `UidEncryptionDomain::G_a`.
    pub(crate) G_a1: RistrettoPoint,
    /// Second generator; returned as element 1 of `UidEncryptionDomain::G_a`.
    pub(crate) G_a2: RistrettoPoint,
}
29
30pub type KeyPair = zkcredential::attributes::KeyPair<UidEncryptionDomain>;
31pub type PublicKey = zkcredential::attributes::PublicKey<UidEncryptionDomain>;
32pub type Ciphertext = zkcredential::attributes::Ciphertext<UidEncryptionDomain>;
33
34impl SystemParams {
35 pub fn generate() -> Self {
36 let mut sho = Sho::new(
37 b"Signal_ZKGroup_20200424_Constant_UidEncryption_SystemParams_Generate",
38 b"",
39 );
40 let G_a1 = sho.get_point();
41 let G_a2 = sho.get_point();
42 SystemParams { G_a1, G_a2 }
43 }
44
45 pub fn get_hardcoded() -> SystemParams {
46 *SYSTEM_PARAMS
47 }
48
49 const SYSTEM_HARDCODED: [u8; 64] = [
50 0xa6, 0x32, 0x4c, 0x36, 0x8d, 0xf7, 0x34, 0x69, 0x11, 0x47, 0x98, 0x13, 0x48, 0xb6, 0xe7,
51 0xeb, 0x42, 0xc3, 0x30, 0x7e, 0x71, 0x1b, 0x6c, 0x7e, 0xcc, 0xd3, 0x3, 0x2d, 0x45, 0x69,
52 0x3f, 0x5a, 0x4, 0x80, 0x13, 0x52, 0x5b, 0x76, 0x12, 0x4b, 0xf2, 0x64, 0xc, 0x5e, 0x93,
53 0x69, 0xc7, 0x6e, 0xfb, 0xe8, 0xa, 0xba, 0x2a, 0x24, 0xaa, 0x5d, 0x8e, 0x18, 0xa9, 0x8e,
54 0xba, 0x14, 0xf8, 0x37,
55 ];
56}
57
58pub struct UidEncryptionDomain;
59impl zkcredential::attributes::Domain for UidEncryptionDomain {
60 type Attribute = uid_struct::UidStruct;
61
62 const ID: &'static str = "Signal_ZKGroup_20230419_UidEncryption";
63
64 fn G_a() -> [RistrettoPoint; 2] {
65 let system = SystemParams::get_hardcoded();
66 [system.G_a1, system.G_a2]
67 }
68}
69
impl UidEncryptionDomain {
    /// Decrypts `ciphertext` with `key_pair` and recovers the encrypted service ID.
    ///
    /// The second ciphertext point decrypts to a Lizard-encoded point that carries the
    /// 16 UUID bytes. The same UUID could be either an ACI or a PNI, so both candidates are
    /// rebuilt. The decrypted first point is then compared against each candidate's `M1` to
    /// decide which one was encrypted.
    ///
    /// # Errors
    ///
    /// Returns [`ZkGroupVerificationFailure`] if:
    /// - the second point cannot be decrypted,
    /// - it does not Lizard-decode to UUID bytes, or
    /// - the decrypted first point matches neither the ACI nor the PNI candidate.
    pub(crate) fn decrypt(
        key_pair: &KeyPair,
        ciphertext: &Ciphertext,
    ) -> Result<libsignal_core::ServiceId, ZkGroupVerificationFailure> {
        let M2 = key_pair
            .decrypt_to_second_point(ciphertext)
            .map_err(|_| ZkGroupVerificationFailure)?;
        // `None` here means M2 does not decode back to 16 bytes, so it cannot be a valid UID.
        match M2.lizard_decode::<sha2::Sha256>() {
            None => Err(ZkGroupVerificationFailure),
            Some(bytes) => {
                let decoded_uuid = uuid::Uuid::from_bytes(bytes);
                // The UUID alone does not say whether an ACI or a PNI was encrypted.
                // Choose between the two without branching on secret data:
                // - build both candidates as `ServiceId`s (one array, one element type),
                // - compare their M1 points against the decrypted M1 in constant time,
                // - only index into the array at the very end.
                // The final lookup distinguishes success from failure; it is intended not
                // to reveal which of the two candidates matched.
                let decoded_service_ids = [
                    libsignal_core::Aci::from(decoded_uuid).into(),
                    libsignal_core::Pni::from(decoded_uuid).into(),
                ];
                let decoded_aci = &decoded_service_ids[0];
                let decoded_pni = &decoded_service_ids[1];
                let sho_seed = uid_struct::UidStruct::seed_M1();
                let aci_M1 = uid_struct::UidStruct::calc_M1(sho_seed.clone(), *decoded_aci);
                let pni_M1 = uid_struct::UidStruct::calc_M1(sho_seed, *decoded_pni);
                // If the two candidates had the same M1, the selection below could not tell
                // them apart.
                debug_assert!(aci_M1 != pni_M1);
                // Undo the key's `a1` factor on the first ciphertext point (presumably
                // E_A1 = a1 * M1 in the zkcredential scheme — confirm there).
                let decrypted_M1 = key_pair.a1.invert() * ciphertext.as_points()[0];
                // Sentinel: stays out of bounds (255) if neither candidate matches, so the
                // `get` below yields `None` and the call fails.
                let mut index = u8::MAX;
                index.conditional_assign(&0, decrypted_M1.ct_eq(&aci_M1));
                index.conditional_assign(&1, decrypted_M1.ct_eq(&pni_M1));
                decoded_service_ids
                    .get(index as usize)
                    .copied()
                    .ok_or(ZkGroupVerificationFailure)
            }
        }
    }
}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::constants::*;

    /// Expected bincode encoding of the ACI ciphertext produced in `test_uid_encryption`.
    /// Pinning the bytes catches accidental changes to the encryption scheme.
    const EXPECTED_ACI_CIPHERTEXT: [u8; 64] = [
        0xf8, 0x9e, 0xe7, 0x70, 0x5a, 0x66, 0x3, 0x6b, 0x90, 0x8d, 0xb8, 0x84, 0x21,
        0x1b, 0x77, 0x3a, 0xc5, 0x43, 0xee, 0x35, 0xc4, 0xa3, 0x8, 0x62, 0x20, 0xfc,
        0x3e, 0x1e, 0x35, 0xb4, 0x23, 0x4c, 0xfa, 0x1d, 0x2e, 0xea, 0x2c, 0xc2, 0xf4,
        0xb4, 0xc4, 0x2c, 0xff, 0x39, 0xa9, 0xdc, 0xeb, 0x57, 0x29, 0x3b, 0x5f, 0x87,
        0x70, 0xca, 0x60, 0xf9, 0xe9, 0xb7, 0x44, 0x47, 0xbf, 0xd3, 0xbd, 0x3d,
    ];

    /// Serializes `ciphertext` and checks that the encoding is 64 bytes long.
    /// Deserializes it again and checks that the result equals the input.
    /// Returns the encoded bytes and the decoded ciphertext.
    fn round_trip_ciphertext(ciphertext: &Ciphertext) -> (Vec<u8>, Ciphertext) {
        let encoded = bincode::serialize(ciphertext).unwrap();
        assert_eq!(encoded.len(), 64);
        let decoded: Ciphertext = bincode::deserialize(&encoded).unwrap();
        assert!(*ciphertext == decoded);
        (encoded, decoded)
    }

    #[test]
    fn test_uid_encryption() {
        let mut sho = Sho::new(b"Test_Uid_Encryption", &TEST_ARRAY_32);

        // The hardcoded parameters must be exactly what `generate` derives.
        assert!(SystemParams::generate() == SystemParams::get_hardcoded());

        let key_pair = KeyPair::derive_from(sho.as_mut());

        // A truncated key-pair encoding is rejected; the full one round-trips.
        let key_pair_bytes = bincode::serialize(&key_pair).unwrap();
        let truncated = &key_pair_bytes[..key_pair_bytes.len() - 1];
        assert!(bincode::deserialize::<KeyPair>(truncated).is_err());
        let restored_key_pair: KeyPair = bincode::deserialize(&key_pair_bytes).unwrap();
        assert!(key_pair == restored_key_pair);

        let aci = libsignal_core::Aci::from_uuid_bytes(TEST_ARRAY_16);
        let uid = uid_struct::UidStruct::from_service_id(aci.into());
        let ciphertext = key_pair.encrypt(&uid);

        let (ciphertext_bytes, decoded_ciphertext) = round_trip_ciphertext(&ciphertext);
        assert_eq!(ciphertext_bytes, EXPECTED_ACI_CIPHERTEXT);

        let plaintext = UidEncryptionDomain::decrypt(&key_pair, &decoded_ciphertext).unwrap();
        assert!(matches!(plaintext, libsignal_core::ServiceId::Aci(_)));
        assert!(uid_struct::UidStruct::from_service_id(plaintext) == uid);
    }

    #[test]
    fn test_pni_encryption() {
        let mut sho = Sho::new(b"Test_Pni_Encryption", &[]);
        let key_pair = KeyPair::derive_from(sho.as_mut());

        let pni = libsignal_core::Pni::from_uuid_bytes(TEST_ARRAY_16);
        let uid = uid_struct::UidStruct::from_service_id(pni.into());
        let ciphertext = key_pair.encrypt(&uid);

        let (_, decoded_ciphertext) = round_trip_ciphertext(&ciphertext);

        // The same UUID bytes as the ACI test, but decryption must report a PNI.
        let plaintext = UidEncryptionDomain::decrypt(&key_pair, &decoded_ciphertext).unwrap();
        assert!(matches!(plaintext, libsignal_core::ServiceId::Pni(_)));
        assert!(uid_struct::UidStruct::from_service_id(plaintext) == uid);
    }
}