1#![allow(non_snake_case)]
7
8use curve25519_dalek_signal::ristretto::RistrettoPoint;
9use lazy_static::lazy_static;
10use partial_default::PartialDefault;
11use serde::{Deserialize, Serialize};
12use subtle::{ConditionallySelectable, ConstantTimeEq};
13use zkcredential::attributes::Attribute;
14
15use crate::common::errors::*;
16use crate::common::sho::*;
17use crate::crypto::uid_struct;
18
19lazy_static! {
20 static ref SYSTEM_PARAMS: SystemParams =
21 crate::deserialize::<SystemParams>(&SystemParams::SYSTEM_HARDCODED).unwrap();
22}
23
/// Generator points used by the UID encryption domain.
///
/// The canonical values are derived by `SystemParams::generate`; a serialized copy is embedded
/// in `SystemParams::SYSTEM_HARDCODED` and returned by `SystemParams::get_hardcoded`.
// NOTE(review): field order presumably determines the serialized layout of
// `SYSTEM_HARDCODED` (64 bytes, i.e. two 32-byte points) -- do not reorder the fields
// without regenerating that constant.
#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, PartialDefault)]
pub struct SystemParams {
    /// First generator; drawn first in `generate` and returned as `G_a()[0]`.
    pub(crate) G_a1: RistrettoPoint,
    /// Second generator; drawn second in `generate` and returned as `G_a()[1]`.
    pub(crate) G_a2: RistrettoPoint,
}
29
/// zkcredential attribute-encryption key pair instantiated for [`UidEncryptionDomain`].
pub type KeyPair = zkcredential::attributes::KeyPair<UidEncryptionDomain>;
/// zkcredential attribute-encryption public key instantiated for [`UidEncryptionDomain`].
pub type PublicKey = zkcredential::attributes::PublicKey<UidEncryptionDomain>;
/// Encrypted UID attribute (a `UidStruct`) in the [`UidEncryptionDomain`].
pub type Ciphertext = zkcredential::attributes::Ciphertext<UidEncryptionDomain>;
33
34impl SystemParams {
35 pub fn generate() -> Self {
36 let mut sho = Sho::new(
37 b"Signal_ZKGroup_20200424_Constant_UidEncryption_SystemParams_Generate",
38 b"",
39 );
40 let G_a1 = sho.get_point();
41 let G_a2 = sho.get_point();
42 SystemParams { G_a1, G_a2 }
43 }
44
45 pub fn get_hardcoded() -> SystemParams {
46 *SYSTEM_PARAMS
47 }
48
49 const SYSTEM_HARDCODED: [u8; 64] = [
50 0xa6, 0x32, 0x4c, 0x36, 0x8d, 0xf7, 0x34, 0x69, 0x11, 0x47, 0x98, 0x13, 0x48, 0xb6, 0xe7,
51 0xeb, 0x42, 0xc3, 0x30, 0x7e, 0x71, 0x1b, 0x6c, 0x7e, 0xcc, 0xd3, 0x3, 0x2d, 0x45, 0x69,
52 0x3f, 0x5a, 0x4, 0x80, 0x13, 0x52, 0x5b, 0x76, 0x12, 0x4b, 0xf2, 0x64, 0xc, 0x5e, 0x93,
53 0x69, 0xc7, 0x6e, 0xfb, 0xe8, 0xa, 0xba, 0x2a, 0x24, 0xaa, 0x5d, 0x8e, 0x18, 0xa9, 0x8e,
54 0xba, 0x14, 0xf8, 0x37,
55 ];
56}
57
58pub struct UidEncryptionDomain;
59impl zkcredential::attributes::Domain for UidEncryptionDomain {
60 type Attribute = uid_struct::UidStruct;
61
62 const ID: &'static str = "Signal_ZKGroup_20230419_UidEncryption";
63
64 fn G_a() -> [RistrettoPoint; 2] {
65 let system = SystemParams::get_hardcoded();
66 [system.G_a1, system.G_a2]
67 }
68}
69
impl UidEncryptionDomain {
    /// Decrypts a UID ciphertext back to the [`libsignal_core::ServiceId`] it encrypts.
    ///
    /// The second point is decrypted and lizard-decoded to recover 16 UUID bytes. The same
    /// bytes could belong to either an ACI or a PNI, so both candidates are rebuilt. Their `M1`
    /// points are compared against the decrypted first point to determine which kind was
    /// encrypted.
    ///
    /// # Errors
    ///
    /// Returns [`ZkGroupVerificationFailure`] in three cases: `decrypt_to_second_point` fails,
    /// the second point does not lizard-decode, or the decrypted first point matches neither
    /// candidate.
    pub(crate) fn decrypt(
        key_pair: &KeyPair,
        ciphertext: &Ciphertext,
    ) -> Result<libsignal_core::ServiceId, ZkGroupVerificationFailure> {
        let M2 = key_pair
            .decrypt_to_second_point(ciphertext)
            .map_err(|_| ZkGroupVerificationFailure)?;
        match M2.lizard_decode::<sha2::Sha256>() {
            None => Err(ZkGroupVerificationFailure),
            Some(bytes) => {
                let decoded_uuid = uuid::Uuid::from_bytes(bytes);
                // Both possible interpretations of the decoded UUID:
                // index 0 is the ACI, index 1 is the PNI.
                let decoded_service_ids = [
                    libsignal_core::Aci::from(decoded_uuid).into(),
                    libsignal_core::Pni::from(decoded_uuid).into(),
                ];
                let decoded_aci = &decoded_service_ids[0];
                let decoded_pni = &decoded_service_ids[1];
                // Both candidate M1 points are always computed, regardless of which kind
                // was actually encrypted (presumably so timing does not reveal it).
                let aci_M1 = uid_struct::UidStruct::calc_M1(*decoded_aci);
                let pni_M1 = uid_struct::UidStruct::calc_M1(*decoded_pni);
                // The two candidates must differ, or the kind could not be told apart.
                debug_assert!(aci_M1 != pni_M1);
                // Recover M1 from the first ciphertext point by multiplying with a1^-1.
                let decrypted_M1 = key_pair.a1.invert() * ciphertext.as_points()[0];
                // Pick the matching index with `subtle`'s constant-time primitives rather than
                // branching. `u8::MAX` is an out-of-range sentinel: if neither candidate matches,
                // `get` below returns `None`, which becomes the error.
                let mut index = u8::MAX;
                index.conditional_assign(&0, decrypted_M1.ct_eq(&aci_M1));
                index.conditional_assign(&1, decrypted_M1.ct_eq(&pni_M1));
                decoded_service_ids
                    .get(index as usize)
                    .copied()
                    .ok_or(ZkGroupVerificationFailure)
            }
        }
    }
}
109
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::constants::*;

    /// Round-trips an ACI through key derivation, serialization, encryption and decryption,
    /// and pins the exact ciphertext bytes.
    #[test]
    fn test_uid_encryption() {
        let master_key = TEST_ARRAY_32;
        let mut sho = Sho::new(b"Test_Uid_Encryption", &master_key);

        // The hardcoded parameters must be exactly what `generate` derives.
        assert!(SystemParams::generate() == SystemParams::get_hardcoded());

        let key_pair = KeyPair::derive_from(sho.as_mut());

        // A truncated encoding must be rejected; the full encoding must round-trip.
        let key_pair_bytes = bincode::serialize(&key_pair).unwrap();
        assert!(
            bincode::deserialize::<KeyPair>(&key_pair_bytes[0..key_pair_bytes.len() - 1])
                .is_err(),
            "truncated KeyPair encoding must fail to deserialize"
        );
        let key_pair2: KeyPair = bincode::deserialize(&key_pair_bytes).unwrap();
        assert!(key_pair == key_pair2);

        let aci = libsignal_core::Aci::from_uuid_bytes(TEST_ARRAY_16);
        let uid = uid_struct::UidStruct::from_service_id(aci.into());
        let ciphertext = key_pair.encrypt(&uid);

        // A ciphertext serializes to two 32-byte points.
        let ciphertext_bytes = bincode::serialize(&ciphertext).unwrap();
        assert_eq!(ciphertext_bytes.len(), 64);
        let ciphertext2: Ciphertext = bincode::deserialize(&ciphertext_bytes).unwrap();
        assert!(ciphertext == ciphertext2);
        // Known-answer check: the encryption is deterministic for this key pair and UID.
        assert_eq!(
            ciphertext_bytes,
            vec![
                0xf8, 0x9e, 0xe7, 0x70, 0x5a, 0x66, 0x3, 0x6b, 0x90, 0x8d, 0xb8, 0x84, 0x21,
                0x1b, 0x77, 0x3a, 0xc5, 0x43, 0xee, 0x35, 0xc4, 0xa3, 0x8, 0x62, 0x20, 0xfc,
                0x3e, 0x1e, 0x35, 0xb4, 0x23, 0x4c, 0xfa, 0x1d, 0x2e, 0xea, 0x2c, 0xc2, 0xf4,
                0xb4, 0xc4, 0x2c, 0xff, 0x39, 0xa9, 0xdc, 0xeb, 0x57, 0x29, 0x3b, 0x5f, 0x87,
                0x70, 0xca, 0x60, 0xf9, 0xe9, 0xb7, 0x44, 0x47, 0xbf, 0xd3, 0xbd, 0x3d,
            ]
        );

        let plaintext = UidEncryptionDomain::decrypt(&key_pair, &ciphertext2).unwrap();
        assert!(matches!(plaintext, libsignal_core::ServiceId::Aci(_)));
        assert!(uid_struct::UidStruct::from_service_id(plaintext) == uid);
    }

    /// Round-trips a PNI and checks that decryption reports the PNI kind, not ACI.
    #[test]
    fn test_pni_encryption() {
        let mut sho = Sho::new(b"Test_Pni_Encryption", &[]);
        let key_pair = KeyPair::derive_from(sho.as_mut());

        let pni = libsignal_core::Pni::from_uuid_bytes(TEST_ARRAY_16);
        let uid = uid_struct::UidStruct::from_service_id(pni.into());
        let ciphertext = key_pair.encrypt(&uid);

        // A ciphertext serializes to two 32-byte points.
        let ciphertext_bytes = bincode::serialize(&ciphertext).unwrap();
        assert_eq!(ciphertext_bytes.len(), 64);
        let ciphertext2: Ciphertext = bincode::deserialize(&ciphertext_bytes).unwrap();
        assert!(ciphertext == ciphertext2);

        let plaintext = UidEncryptionDomain::decrypt(&key_pair, &ciphertext2).unwrap();
        assert!(matches!(plaintext, libsignal_core::ServiceId::Pni(_)));
        assert!(uid_struct::UidStruct::from_service_id(plaintext) == uid);
    }
}