zkgroup/crypto/uid_encryption.rs

1//
2// Copyright 2020 Signal Messenger, LLC.
3// SPDX-License-Identifier: AGPL-3.0-only
4//
5
6#![allow(non_snake_case)]
7
8use curve25519_dalek_signal::ristretto::RistrettoPoint;
9use lazy_static::lazy_static;
10use partial_default::PartialDefault;
11use serde::{Deserialize, Serialize};
12use subtle::{ConditionallySelectable, ConstantTimeEq};
13use zkcredential::attributes::Attribute;
14
15use crate::common::errors::*;
16use crate::common::sho::*;
17use crate::crypto::uid_struct;
18
19lazy_static! {
20    static ref SYSTEM_PARAMS: SystemParams =
21        crate::deserialize::<SystemParams>(&SystemParams::SYSTEM_HARDCODED).unwrap();
22}
23
24#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, PartialDefault)]
25pub struct SystemParams {
26    pub(crate) G_a1: RistrettoPoint,
27    pub(crate) G_a2: RistrettoPoint,
28}
29
30pub type KeyPair = zkcredential::attributes::KeyPair<UidEncryptionDomain>;
31pub type PublicKey = zkcredential::attributes::PublicKey<UidEncryptionDomain>;
32pub type Ciphertext = zkcredential::attributes::Ciphertext<UidEncryptionDomain>;
33
34impl SystemParams {
35    pub fn generate() -> Self {
36        let mut sho = Sho::new(
37            b"Signal_ZKGroup_20200424_Constant_UidEncryption_SystemParams_Generate",
38            b"",
39        );
40        let G_a1 = sho.get_point();
41        let G_a2 = sho.get_point();
42        SystemParams { G_a1, G_a2 }
43    }
44
45    pub fn get_hardcoded() -> SystemParams {
46        *SYSTEM_PARAMS
47    }
48
49    const SYSTEM_HARDCODED: [u8; 64] = [
50        0xa6, 0x32, 0x4c, 0x36, 0x8d, 0xf7, 0x34, 0x69, 0x11, 0x47, 0x98, 0x13, 0x48, 0xb6, 0xe7,
51        0xeb, 0x42, 0xc3, 0x30, 0x7e, 0x71, 0x1b, 0x6c, 0x7e, 0xcc, 0xd3, 0x3, 0x2d, 0x45, 0x69,
52        0x3f, 0x5a, 0x4, 0x80, 0x13, 0x52, 0x5b, 0x76, 0x12, 0x4b, 0xf2, 0x64, 0xc, 0x5e, 0x93,
53        0x69, 0xc7, 0x6e, 0xfb, 0xe8, 0xa, 0xba, 0x2a, 0x24, 0xaa, 0x5d, 0x8e, 0x18, 0xa9, 0x8e,
54        0xba, 0x14, 0xf8, 0x37,
55    ];
56}
57
58pub struct UidEncryptionDomain;
59impl zkcredential::attributes::Domain for UidEncryptionDomain {
60    type Attribute = uid_struct::UidStruct;
61
62    const ID: &'static str = "Signal_ZKGroup_20230419_UidEncryption";
63
64    fn G_a() -> [RistrettoPoint; 2] {
65        let system = SystemParams::get_hardcoded();
66        [system.G_a1, system.G_a2]
67    }
68}
69
70impl UidEncryptionDomain {
71    pub(crate) fn decrypt(
72        key_pair: &KeyPair,
73        ciphertext: &Ciphertext,
74    ) -> Result<libsignal_core::ServiceId, ZkGroupVerificationFailure> {
75        let M2 = key_pair
76            .decrypt_to_second_point(ciphertext)
77            .map_err(|_| ZkGroupVerificationFailure)?;
78        match M2.lizard_decode::<sha2::Sha256>() {
79            None => Err(ZkGroupVerificationFailure),
80            Some(bytes) => {
81                // We want to do a constant-time choice between the ACI and the PNI possibilities.
82                // Only at the end do we do a normal branch to see if decryption succeeded,
83                // and even then we don't want to expose whether we picked the ACI or the PNI.
84                // So we store them both in an array, and index into it at the very end.
85                // This isn't fully "data-oblivious"; only one service ID gets loaded from memory at
86                // the end, and which one is data-dependent. But it is constant-time.
87                let decoded_uuid = uuid::Uuid::from_bytes(bytes);
88                let decoded_service_ids = [
89                    libsignal_core::Aci::from(decoded_uuid).into(),
90                    libsignal_core::Pni::from(decoded_uuid).into(),
91                ];
92                let decoded_aci = &decoded_service_ids[0];
93                let decoded_pni = &decoded_service_ids[1];
94                let aci_M1 = uid_struct::UidStruct::calc_M1(*decoded_aci);
95                let pni_M1 = uid_struct::UidStruct::calc_M1(*decoded_pni);
96                debug_assert!(aci_M1 != pni_M1);
97                let decrypted_M1 = key_pair.a1.invert() * ciphertext.as_points()[0];
98                let mut index = u8::MAX;
99                index.conditional_assign(&0, decrypted_M1.ct_eq(&aci_M1));
100                index.conditional_assign(&1, decrypted_M1.ct_eq(&pni_M1));
101                decoded_service_ids
102                    .get(index as usize)
103                    .copied()
104                    .ok_or(ZkGroupVerificationFailure)
105            }
106        }
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::constants::*;

    /// Round-trips an ACI through encryption, serialization and decryption. Also pins the
    /// hardcoded system parameters and the exact ciphertext bytes.
    #[test]
    fn test_uid_encryption() {
        let master_key = TEST_ARRAY_32;
        let mut sho = Sho::new(b"Test_Uid_Encryption", &master_key);

        // SYSTEM_HARDCODED must match freshly generated parameters. If `generate` is changed
        // intentionally, regenerate the constant by printing
        // `bincode::serialize(&SystemParams::generate())` with `{:#x?}`.
        assert!(SystemParams::generate() == SystemParams::get_hardcoded());

        let key_pair = KeyPair::derive_from(sho.as_mut());

        // A truncated key pair encoding must be rejected; the full encoding must round-trip.
        let key_pair_bytes = bincode::serialize(&key_pair).unwrap();
        assert!(
            bincode::deserialize::<KeyPair>(&key_pair_bytes[..key_pair_bytes.len() - 1]).is_err()
        );
        let key_pair2: KeyPair = bincode::deserialize(&key_pair_bytes).unwrap();
        assert!(key_pair == key_pair2);

        let aci = libsignal_core::Aci::from_uuid_bytes(TEST_ARRAY_16);
        let uid = uid_struct::UidStruct::from_service_id(aci.into());
        let ciphertext = key_pair.encrypt(&uid);

        // Serialized ciphertexts are 64 bytes and must round-trip.
        let ciphertext_bytes = bincode::serialize(&ciphertext).unwrap();
        assert_eq!(ciphertext_bytes.len(), 64);
        let ciphertext2: Ciphertext = bincode::deserialize(&ciphertext_bytes).unwrap();
        assert!(ciphertext == ciphertext2);

        // Known-answer check for this key pair and UID. If the encoding changes intentionally,
        // regenerate by printing `ciphertext_bytes` with `{:#x?}`.
        assert_eq!(
            ciphertext_bytes,
            [
                0xf8, 0x9e, 0xe7, 0x70, 0x5a, 0x66, 0x3, 0x6b, 0x90, 0x8d, 0xb8, 0x84, 0x21, 0x1b,
                0x77, 0x3a, 0xc5, 0x43, 0xee, 0x35, 0xc4, 0xa3, 0x8, 0x62, 0x20, 0xfc, 0x3e, 0x1e,
                0x35, 0xb4, 0x23, 0x4c, 0xfa, 0x1d, 0x2e, 0xea, 0x2c, 0xc2, 0xf4, 0xb4, 0xc4, 0x2c,
                0xff, 0x39, 0xa9, 0xdc, 0xeb, 0x57, 0x29, 0x3b, 0x5f, 0x87, 0x70, 0xca, 0x60, 0xf9,
                0xe9, 0xb7, 0x44, 0x47, 0xbf, 0xd3, 0xbd, 0x3d,
            ]
        );

        let plaintext = UidEncryptionDomain::decrypt(&key_pair, &ciphertext2).unwrap();
        assert!(matches!(plaintext, libsignal_core::ServiceId::Aci(_)));
        assert!(uid_struct::UidStruct::from_service_id(plaintext) == uid);
    }

    /// Round-trips a PNI through encryption, serialization and decryption. Checks that the
    /// decrypted service ID comes back as a PNI rather than an ACI.
    #[test]
    fn test_pni_encryption() {
        let mut sho = Sho::new(b"Test_Pni_Encryption", &[]);
        let key_pair = KeyPair::derive_from(sho.as_mut());

        let pni = libsignal_core::Pni::from_uuid_bytes(TEST_ARRAY_16);
        let uid = uid_struct::UidStruct::from_service_id(pni.into());
        let ciphertext = key_pair.encrypt(&uid);

        // Serialized ciphertexts are 64 bytes and must round-trip.
        let ciphertext_bytes = bincode::serialize(&ciphertext).unwrap();
        assert_eq!(ciphertext_bytes.len(), 64);
        let ciphertext2: Ciphertext = bincode::deserialize(&ciphertext_bytes).unwrap();
        assert!(ciphertext == ciphertext2);

        let plaintext = UidEncryptionDomain::decrypt(&key_pair, &ciphertext2).unwrap();
        assert!(matches!(plaintext, libsignal_core::ServiceId::Pni(_)));
        assert!(uid_struct::UidStruct::from_service_id(plaintext) == uid);
    }
}