1#![allow(non_snake_case)]
7
8use std::sync::LazyLock;
9
10use curve25519_dalek_signal::ristretto::RistrettoPoint;
11use partial_default::PartialDefault;
12use serde::{Deserialize, Serialize};
13use subtle::{ConditionallySelectable, ConstantTimeEq};
14use zkcredential::attributes::Attribute;
15
16use crate::common::errors::*;
17use crate::common::sho::*;
18use crate::crypto::uid_struct;
19
20static SYSTEM_PARAMS: LazyLock<SystemParams> =
21 LazyLock::new(|| crate::deserialize::<SystemParams>(&SystemParams::SYSTEM_HARDCODED).unwrap());
22
23#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, PartialDefault)]
24pub struct SystemParams {
25 pub(crate) G_a1: RistrettoPoint,
26 pub(crate) G_a2: RistrettoPoint,
27}
28
29pub type KeyPair = zkcredential::attributes::KeyPair<UidEncryptionDomain>;
30pub type PublicKey = zkcredential::attributes::PublicKey<UidEncryptionDomain>;
31pub type Ciphertext = zkcredential::attributes::Ciphertext<UidEncryptionDomain>;
32
33impl SystemParams {
34 pub fn generate() -> Self {
35 let mut sho = Sho::new(
36 b"Signal_ZKGroup_20200424_Constant_UidEncryption_SystemParams_Generate",
37 b"",
38 );
39 let G_a1 = sho.get_point();
40 let G_a2 = sho.get_point();
41 SystemParams { G_a1, G_a2 }
42 }
43
44 pub fn get_hardcoded() -> SystemParams {
45 *SYSTEM_PARAMS
46 }
47
48 const SYSTEM_HARDCODED: [u8; 64] = [
49 0xa6, 0x32, 0x4c, 0x36, 0x8d, 0xf7, 0x34, 0x69, 0x11, 0x47, 0x98, 0x13, 0x48, 0xb6, 0xe7,
50 0xeb, 0x42, 0xc3, 0x30, 0x7e, 0x71, 0x1b, 0x6c, 0x7e, 0xcc, 0xd3, 0x3, 0x2d, 0x45, 0x69,
51 0x3f, 0x5a, 0x4, 0x80, 0x13, 0x52, 0x5b, 0x76, 0x12, 0x4b, 0xf2, 0x64, 0xc, 0x5e, 0x93,
52 0x69, 0xc7, 0x6e, 0xfb, 0xe8, 0xa, 0xba, 0x2a, 0x24, 0xaa, 0x5d, 0x8e, 0x18, 0xa9, 0x8e,
53 0xba, 0x14, 0xf8, 0x37,
54 ];
55}
56
57pub struct UidEncryptionDomain;
58impl zkcredential::attributes::Domain for UidEncryptionDomain {
59 type Attribute = uid_struct::UidStruct;
60
61 const ID: &'static str = "Signal_ZKGroup_20230419_UidEncryption";
62
63 fn G_a() -> [RistrettoPoint; 2] {
64 let system = SystemParams::get_hardcoded();
65 [system.G_a1, system.G_a2]
66 }
67}
68
69impl UidEncryptionDomain {
70 pub(crate) fn decrypt(
71 key_pair: &KeyPair,
72 ciphertext: &Ciphertext,
73 ) -> Result<libsignal_core::ServiceId, ZkGroupVerificationFailure> {
74 let M2 = key_pair
75 .decrypt_to_second_point(ciphertext)
76 .map_err(|_| ZkGroupVerificationFailure)?;
77 match M2.lizard_decode::<sha2::Sha256>() {
78 None => Err(ZkGroupVerificationFailure),
79 Some(bytes) => {
80 let decoded_uuid = uuid::Uuid::from_bytes(bytes);
87 let decoded_service_ids = [
88 libsignal_core::Aci::from(decoded_uuid).into(),
89 libsignal_core::Pni::from(decoded_uuid).into(),
90 ];
91 let decoded_aci = &decoded_service_ids[0];
92 let decoded_pni = &decoded_service_ids[1];
93 let aci_M1 = uid_struct::UidStruct::calc_M1(*decoded_aci);
94 let pni_M1 = uid_struct::UidStruct::calc_M1(*decoded_pni);
95 debug_assert!(aci_M1 != pni_M1);
96 let decrypted_M1 = key_pair.a1.invert() * ciphertext.as_points()[0];
97 let mut index = u8::MAX;
98 index.conditional_assign(&0, decrypted_M1.ct_eq(&aci_M1));
99 index.conditional_assign(&1, decrypted_M1.ct_eq(&pni_M1));
100 decoded_service_ids
101 .get(index as usize)
102 .copied()
103 .ok_or(ZkGroupVerificationFailure)
104 }
105 }
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::constants::*;

    /// End-to-end check for ACI encryption: parameter consistency, key-pair and ciphertext
    /// serialization, a pinned ciphertext encoding, and decryption back to an ACI.
    #[test]
    fn test_uid_encryption() {
        let master_key = TEST_ARRAY_32;
        let mut sho = Sho::new(b"Test_Uid_Encryption", &master_key);

        // The hardcoded parameters must be exactly what `generate` derives.
        assert!(SystemParams::generate() == SystemParams::get_hardcoded());

        let key_pair = KeyPair::derive_from(sho.as_mut());

        // A truncated key-pair encoding must be rejected; the full one must round-trip.
        let key_pair_bytes = bincode::serialize(&key_pair).unwrap();
        assert!(
            bincode::deserialize::<KeyPair>(&key_pair_bytes[0..key_pair_bytes.len() - 1]).is_err()
        );
        let key_pair2: KeyPair = bincode::deserialize(&key_pair_bytes).unwrap();
        assert!(key_pair == key_pair2);

        let aci = libsignal_core::Aci::from_uuid_bytes(TEST_ARRAY_16);
        let uid = uid_struct::UidStruct::from_service_id(aci.into());
        let ciphertext = key_pair.encrypt(&uid);

        let ciphertext_bytes = bincode::serialize(&ciphertext).unwrap();
        assert_eq!(ciphertext_bytes.len(), 64);
        let ciphertext2: Ciphertext = bincode::deserialize(&ciphertext_bytes).unwrap();
        assert!(ciphertext == ciphertext2);
        // Pin the exact encoding so any change to the scheme or its parameters is caught.
        assert_eq!(
            ciphertext_bytes,
            vec![
                0xf8, 0x9e, 0xe7, 0x70, 0x5a, 0x66, 0x3, 0x6b, 0x90, 0x8d, 0xb8, 0x84, 0x21,
                0x1b, 0x77, 0x3a, 0xc5, 0x43, 0xee, 0x35, 0xc4, 0xa3, 0x8, 0x62, 0x20, 0xfc,
                0x3e, 0x1e, 0x35, 0xb4, 0x23, 0x4c, 0xfa, 0x1d, 0x2e, 0xea, 0x2c, 0xc2, 0xf4,
                0xb4, 0xc4, 0x2c, 0xff, 0x39, 0xa9, 0xdc, 0xeb, 0x57, 0x29, 0x3b, 0x5f, 0x87,
                0x70, 0xca, 0x60, 0xf9, 0xe9, 0xb7, 0x44, 0x47, 0xbf, 0xd3, 0xbd, 0x3d,
            ]
        );

        let plaintext = UidEncryptionDomain::decrypt(&key_pair, &ciphertext2).unwrap();
        assert!(matches!(plaintext, libsignal_core::ServiceId::Aci(_)));
        assert!(uid_struct::UidStruct::from_service_id(plaintext) == uid);
    }

    /// Checks that a PNI survives serialization and decrypts back to a PNI, not an ACI.
    #[test]
    fn test_pni_encryption() {
        let mut sho = Sho::new(b"Test_Pni_Encryption", &[]);
        let key_pair = KeyPair::derive_from(sho.as_mut());

        let pni = libsignal_core::Pni::from_uuid_bytes(TEST_ARRAY_16);
        let uid = uid_struct::UidStruct::from_service_id(pni.into());
        let ciphertext = key_pair.encrypt(&uid);

        let ciphertext_bytes = bincode::serialize(&ciphertext).unwrap();
        assert_eq!(ciphertext_bytes.len(), 64);
        let ciphertext2: Ciphertext = bincode::deserialize(&ciphertext_bytes).unwrap();
        assert!(ciphertext == ciphertext2);

        let plaintext = UidEncryptionDomain::decrypt(&key_pair, &ciphertext2).unwrap();
        assert!(matches!(plaintext, libsignal_core::ServiceId::Pni(_)));
        assert!(uid_struct::UidStruct::from_service_id(plaintext) == uid);
    }
}
179}