zkgroup/crypto/
uid_encryption.rs

1//
2// Copyright 2020 Signal Messenger, LLC.
3// SPDX-License-Identifier: AGPL-3.0-only
4//
5
6#![allow(non_snake_case)]
7
8use std::sync::LazyLock;
9
10use curve25519_dalek_signal::ristretto::RistrettoPoint;
11use partial_default::PartialDefault;
12use serde::{Deserialize, Serialize};
13use subtle::{ConditionallySelectable, ConstantTimeEq};
14use zkcredential::attributes::Attribute;
15
16use crate::common::errors::*;
17use crate::common::sho::*;
18use crate::crypto::uid_struct;
19
20static SYSTEM_PARAMS: LazyLock<SystemParams> = LazyLock::new(|| {
21    crate::deserialize(&SystemParams::SYSTEM_HARDCODED).expect("valid hardcoded params")
22});
23
24#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, PartialDefault)]
25pub struct SystemParams {
26    pub(crate) G_a1: RistrettoPoint,
27    pub(crate) G_a2: RistrettoPoint,
28}
29
30pub type KeyPair = zkcredential::attributes::KeyPair<UidEncryptionDomain>;
31pub type PublicKey = zkcredential::attributes::PublicKey<UidEncryptionDomain>;
32pub type Ciphertext = zkcredential::attributes::Ciphertext<UidEncryptionDomain>;
33
34impl SystemParams {
35    pub fn generate() -> Self {
36        let mut sho = Sho::new(
37            b"Signal_ZKGroup_20200424_Constant_UidEncryption_SystemParams_Generate",
38            b"",
39        );
40        let G_a1 = sho.get_point();
41        let G_a2 = sho.get_point();
42        SystemParams { G_a1, G_a2 }
43    }
44
45    pub fn get_hardcoded() -> SystemParams {
46        *SYSTEM_PARAMS
47    }
48
49    const SYSTEM_HARDCODED: [u8; 64] = [
50        0xa6, 0x32, 0x4c, 0x36, 0x8d, 0xf7, 0x34, 0x69, 0x11, 0x47, 0x98, 0x13, 0x48, 0xb6, 0xe7,
51        0xeb, 0x42, 0xc3, 0x30, 0x7e, 0x71, 0x1b, 0x6c, 0x7e, 0xcc, 0xd3, 0x3, 0x2d, 0x45, 0x69,
52        0x3f, 0x5a, 0x4, 0x80, 0x13, 0x52, 0x5b, 0x76, 0x12, 0x4b, 0xf2, 0x64, 0xc, 0x5e, 0x93,
53        0x69, 0xc7, 0x6e, 0xfb, 0xe8, 0xa, 0xba, 0x2a, 0x24, 0xaa, 0x5d, 0x8e, 0x18, 0xa9, 0x8e,
54        0xba, 0x14, 0xf8, 0x37,
55    ];
56}
57
58pub struct UidEncryptionDomain;
59impl zkcredential::attributes::Domain for UidEncryptionDomain {
60    type Attribute = uid_struct::UidStruct;
61
62    const ID: &'static str = "Signal_ZKGroup_20230419_UidEncryption";
63
64    fn G_a() -> [RistrettoPoint; 2] {
65        let system = SystemParams::get_hardcoded();
66        [system.G_a1, system.G_a2]
67    }
68}
69
70impl UidEncryptionDomain {
71    pub(crate) fn decrypt(
72        key_pair: &KeyPair,
73        ciphertext: &Ciphertext,
74    ) -> Result<libsignal_core::ServiceId, ZkGroupVerificationFailure> {
75        let M2 = key_pair
76            .decrypt_to_second_point(ciphertext)
77            .map_err(|_| ZkGroupVerificationFailure)?;
78        match M2.lizard_decode::<sha2::Sha256>() {
79            None => Err(ZkGroupVerificationFailure),
80            Some(bytes) => {
81                // We want to do a constant-time choice between the ACI and the PNI possibilities.
82                // Only at the end do we do a normal branch to see if decryption succeeded,
83                // and even then we don't want to expose whether we picked the ACI or the PNI.
84                // So we store them both in an array, and index into it at the very end.
85                // This isn't fully "data-oblivious"; only one service ID gets loaded from memory at
86                // the end, and which one is data-dependent. But it is constant-time.
87                let decoded_uuid = uuid::Uuid::from_bytes(bytes);
88                let decoded_service_ids = [
89                    libsignal_core::Aci::from(decoded_uuid).into(),
90                    libsignal_core::Pni::from(decoded_uuid).into(),
91                ];
92                let decoded_aci = &decoded_service_ids[0];
93                let decoded_pni = &decoded_service_ids[1];
94                let sho_seed = uid_struct::UidStruct::seed_M1();
95                let aci_M1 = uid_struct::UidStruct::calc_M1(sho_seed.clone(), *decoded_aci);
96                let pni_M1 = uid_struct::UidStruct::calc_M1(sho_seed, *decoded_pni);
97                debug_assert!(aci_M1 != pni_M1);
98                let decrypted_M1 = key_pair.a1.invert() * ciphertext.as_points()[0];
99                let mut index = u8::MAX;
100                index.conditional_assign(&0, decrypted_M1.ct_eq(&aci_M1));
101                index.conditional_assign(&1, decrypted_M1.ct_eq(&pni_M1));
102                decoded_service_ids
103                    .get(index as usize)
104                    .copied()
105                    .ok_or(ZkGroupVerificationFailure)
106            }
107        }
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::constants::*;

    #[test]
    fn test_uid_encryption() {
        let master_key = TEST_ARRAY_32;
        let mut sho = Sho::new(b"Test_Uid_Encryption", &master_key);

        // The embedded constant must equal a fresh derivation. To regenerate it, print
        // `bincode::serialize(&SystemParams::generate())` with `{:#x?}` and paste the bytes into
        // `SystemParams::SYSTEM_HARDCODED`.
        assert!(SystemParams::generate() == SystemParams::get_hardcoded());

        let key_pair = KeyPair::derive_from(sho.as_mut());

        // Test serialize of key_pair: a truncated encoding must be rejected, and the full
        // encoding must round-trip.
        let key_pair_bytes = bincode::serialize(&key_pair).unwrap();
        let truncated = &key_pair_bytes[..key_pair_bytes.len() - 1];
        assert!(bincode::deserialize::<KeyPair>(truncated).is_err());
        let key_pair2: KeyPair = bincode::deserialize(&key_pair_bytes).unwrap();
        assert!(key_pair == key_pair2);

        let aci = libsignal_core::Aci::from_uuid_bytes(TEST_ARRAY_16);
        let uid = uid_struct::UidStruct::from_service_id(aci.into());
        let ciphertext = key_pair.encrypt(&uid);

        // Test serialize / deserialize of Ciphertext
        let ciphertext_bytes = bincode::serialize(&ciphertext).unwrap();
        assert_eq!(ciphertext_bytes.len(), 64);
        let ciphertext2: Ciphertext = bincode::deserialize(&ciphertext_bytes).unwrap();
        assert!(ciphertext == ciphertext2);

        // Known-answer vector. If the encryption intentionally changes, regenerate it by
        // printing `ciphertext_bytes` with `{:#x?}`.
        assert_eq!(
            ciphertext_bytes,
            vec![
                0xf8, 0x9e, 0xe7, 0x70, 0x5a, 0x66, 0x3, 0x6b, 0x90, 0x8d, 0xb8, 0x84, 0x21,
                0x1b, 0x77, 0x3a, 0xc5, 0x43, 0xee, 0x35, 0xc4, 0xa3, 0x8, 0x62, 0x20, 0xfc,
                0x3e, 0x1e, 0x35, 0xb4, 0x23, 0x4c, 0xfa, 0x1d, 0x2e, 0xea, 0x2c, 0xc2, 0xf4,
                0xb4, 0xc4, 0x2c, 0xff, 0x39, 0xa9, 0xdc, 0xeb, 0x57, 0x29, 0x3b, 0x5f, 0x87,
                0x70, 0xca, 0x60, 0xf9, 0xe9, 0xb7, 0x44, 0x47, 0xbf, 0xd3, 0xbd, 0x3d,
            ]
        );

        let plaintext = UidEncryptionDomain::decrypt(&key_pair, &ciphertext2).unwrap();
        assert!(matches!(plaintext, libsignal_core::ServiceId::Aci(_)));
        assert!(uid_struct::UidStruct::from_service_id(plaintext) == uid);
    }

    #[test]
    fn test_pni_encryption() {
        let mut sho = Sho::new(b"Test_Pni_Encryption", &[]);
        let key_pair = KeyPair::derive_from(sho.as_mut());

        let pni = libsignal_core::Pni::from_uuid_bytes(TEST_ARRAY_16);
        let uid = uid_struct::UidStruct::from_service_id(pni.into());
        let ciphertext = key_pair.encrypt(&uid);

        // Test serialize / deserialize of Ciphertext
        let ciphertext_bytes = bincode::serialize(&ciphertext).unwrap();
        assert_eq!(ciphertext_bytes.len(), 64);
        let ciphertext2: Ciphertext = bincode::deserialize(&ciphertext_bytes).unwrap();
        assert!(ciphertext == ciphertext2);

        let plaintext = UidEncryptionDomain::decrypt(&key_pair, &ciphertext2).unwrap();
        assert!(matches!(plaintext, libsignal_core::ServiceId::Pni(_)));
        assert!(uid_struct::UidStruct::from_service_id(plaintext) == uid);
    }

    /// A ciphertext must not decrypt under an unrelated key pair.
    #[test]
    fn test_decrypt_with_wrong_key_fails() {
        let mut sho = Sho::new(b"Test_Uid_Encryption_Right_Key", &[]);
        let key_pair = KeyPair::derive_from(sho.as_mut());
        let mut other_sho = Sho::new(b"Test_Uid_Encryption_Wrong_Key", &[]);
        let wrong_key_pair = KeyPair::derive_from(other_sho.as_mut());

        let aci = libsignal_core::Aci::from_uuid_bytes(TEST_ARRAY_16);
        let uid = uid_struct::UidStruct::from_service_id(aci.into());
        let ciphertext = key_pair.encrypt(&uid);

        assert!(UidEncryptionDomain::decrypt(&wrong_key_pair, &ciphertext).is_err());
    }
}