zkgroup/crypto/uid_encryption.rs

1//
2// Copyright 2020 Signal Messenger, LLC.
3// SPDX-License-Identifier: AGPL-3.0-only
4//
5
6#![allow(non_snake_case)]
7
8use std::sync::LazyLock;
9
10use curve25519_dalek_signal::ristretto::RistrettoPoint;
11use partial_default::PartialDefault;
12use serde::{Deserialize, Serialize};
13use subtle::{ConditionallySelectable, ConstantTimeEq};
14use zkcredential::attributes::Attribute;
15
16use crate::common::errors::*;
17use crate::common::sho::*;
18use crate::crypto::uid_struct;
19
20static SYSTEM_PARAMS: LazyLock<SystemParams> =
21    LazyLock::new(|| crate::deserialize::<SystemParams>(&SystemParams::SYSTEM_HARDCODED).unwrap());
22
23#[derive(Copy, Clone, PartialEq, Eq, Serialize, Deserialize, PartialDefault)]
24pub struct SystemParams {
25    pub(crate) G_a1: RistrettoPoint,
26    pub(crate) G_a2: RistrettoPoint,
27}
28
/// Attribute-encryption key pair specialized to the UID encryption domain.
pub type KeyPair = zkcredential::attributes::KeyPair<UidEncryptionDomain>;
/// Attribute-encryption public key specialized to the UID encryption domain.
pub type PublicKey = zkcredential::attributes::PublicKey<UidEncryptionDomain>;
/// Encrypted UID attribute; see `UidEncryptionDomain::decrypt` for recovering the service ID.
pub type Ciphertext = zkcredential::attributes::Ciphertext<UidEncryptionDomain>;
32
33impl SystemParams {
34    pub fn generate() -> Self {
35        let mut sho = Sho::new(
36            b"Signal_ZKGroup_20200424_Constant_UidEncryption_SystemParams_Generate",
37            b"",
38        );
39        let G_a1 = sho.get_point();
40        let G_a2 = sho.get_point();
41        SystemParams { G_a1, G_a2 }
42    }
43
44    pub fn get_hardcoded() -> SystemParams {
45        *SYSTEM_PARAMS
46    }
47
48    const SYSTEM_HARDCODED: [u8; 64] = [
49        0xa6, 0x32, 0x4c, 0x36, 0x8d, 0xf7, 0x34, 0x69, 0x11, 0x47, 0x98, 0x13, 0x48, 0xb6, 0xe7,
50        0xeb, 0x42, 0xc3, 0x30, 0x7e, 0x71, 0x1b, 0x6c, 0x7e, 0xcc, 0xd3, 0x3, 0x2d, 0x45, 0x69,
51        0x3f, 0x5a, 0x4, 0x80, 0x13, 0x52, 0x5b, 0x76, 0x12, 0x4b, 0xf2, 0x64, 0xc, 0x5e, 0x93,
52        0x69, 0xc7, 0x6e, 0xfb, 0xe8, 0xa, 0xba, 0x2a, 0x24, 0xaa, 0x5d, 0x8e, 0x18, 0xa9, 0x8e,
53        0xba, 0x14, 0xf8, 0x37,
54    ];
55}
56
57pub struct UidEncryptionDomain;
58impl zkcredential::attributes::Domain for UidEncryptionDomain {
59    type Attribute = uid_struct::UidStruct;
60
61    const ID: &'static str = "Signal_ZKGroup_20230419_UidEncryption";
62
63    fn G_a() -> [RistrettoPoint; 2] {
64        let system = SystemParams::get_hardcoded();
65        [system.G_a1, system.G_a2]
66    }
67}
68
impl UidEncryptionDomain {
    /// Decrypts a UID ciphertext back into the ACI or PNI service ID it encrypts.
    ///
    /// The 16 UUID bytes are recovered by Lizard-decoding (with SHA-256) the second plaintext
    /// point. The same UUID can stand for either an ACI or a PNI, so the expected first point
    /// (`UidStruct::calc_M1`) is computed for both candidates. Each is compared, in constant
    /// time, against the first point recovered from the ciphertext.
    ///
    /// # Errors
    ///
    /// Returns [`ZkGroupVerificationFailure`] if any of the following holds:
    /// - `decrypt_to_second_point` fails;
    /// - the second point does not Lizard-decode;
    /// - the recovered first point matches neither the ACI nor the PNI candidate.
    pub(crate) fn decrypt(
        key_pair: &KeyPair,
        ciphertext: &Ciphertext,
    ) -> Result<libsignal_core::ServiceId, ZkGroupVerificationFailure> {
        // Any failure to recover the second point is collapsed into a verification failure.
        let M2 = key_pair
            .decrypt_to_second_point(ciphertext)
            .map_err(|_| ZkGroupVerificationFailure)?;
        // M2 is expected to Lizard-encode the UUID bytes; a point that doesn't decode cannot
        // be a valid UID ciphertext under this key.
        match M2.lizard_decode::<sha2::Sha256>() {
            None => Err(ZkGroupVerificationFailure),
            Some(bytes) => {
                // We want to do a constant-time choice between the ACI and the PNI possibilities.
                // Only at the end do we do a normal branch to see if decryption succeeded,
                // and even then we don't want to expose whether we picked the ACI or the PNI.
                // So we store them both in an array, and index into it at the very end.
                // This isn't fully "data-oblivious"; only one service ID gets loaded from memory at
                // the end, and which one is data-dependent. But it is constant-time.
                let decoded_uuid = uuid::Uuid::from_bytes(bytes);
                // Index 0 is the ACI interpretation, index 1 the PNI interpretation.
                let decoded_service_ids = [
                    libsignal_core::Aci::from(decoded_uuid).into(),
                    libsignal_core::Pni::from(decoded_uuid).into(),
                ];
                let decoded_aci = &decoded_service_ids[0];
                let decoded_pni = &decoded_service_ids[1];
                let aci_M1 = uid_struct::UidStruct::calc_M1(*decoded_aci);
                let pni_M1 = uid_struct::UidStruct::calc_M1(*decoded_pni);
                // The two interpretations must yield distinct M1 points, otherwise the
                // selection below could not tell them apart.
                debug_assert!(aci_M1 != pni_M1);
                // Recover M1 by undoing the `a1` scalar multiplication on the first
                // ciphertext point.
                let decrypted_M1 = key_pair.a1.invert() * ciphertext.as_points()[0];
                // u8::MAX is an out-of-range sentinel: if neither candidate matches, the
                // `get` below returns None and we report a verification failure.
                let mut index = u8::MAX;
                index.conditional_assign(&0, decrypted_M1.ct_eq(&aci_M1));
                index.conditional_assign(&1, decrypted_M1.ct_eq(&pni_M1));
                decoded_service_ids
                    .get(index as usize)
                    .copied()
                    .ok_or(ZkGroupVerificationFailure)
            }
        }
    }
}
108
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::constants::*;

    /// Round-trips an ACI through encryption, serialization and decryption, and pins the
    /// ciphertext bytes produced under a fixed key so accidental format changes are caught.
    #[test]
    fn test_uid_encryption() {
        let master_key = TEST_ARRAY_32;
        let mut sho = Sho::new(b"Test_Uid_Encryption", &master_key);

        // The hardcoded parameters must be exactly what `generate` derives. To regenerate the
        // constant, print `bincode::serialize(&SystemParams::generate())` with `{:#x?}`.
        assert!(SystemParams::generate() == SystemParams::get_hardcoded());

        let key_pair = KeyPair::derive_from(sho.as_mut());

        // A truncated key pair must be rejected; the full encoding must round-trip.
        let key_pair_bytes = bincode::serialize(&key_pair).unwrap();
        assert!(
            bincode::deserialize::<KeyPair>(&key_pair_bytes[..key_pair_bytes.len() - 1]).is_err(),
            "truncated KeyPair encoding unexpectedly deserialized"
        );
        let key_pair2: KeyPair = bincode::deserialize(&key_pair_bytes).unwrap();
        assert!(key_pair == key_pair2);

        let aci = libsignal_core::Aci::from_uuid_bytes(TEST_ARRAY_16);
        let uid = uid_struct::UidStruct::from_service_id(aci.into());
        let ciphertext = key_pair.encrypt(&uid);

        // Serialize / deserialize the ciphertext and pin its exact bytes.
        let ciphertext_bytes = bincode::serialize(&ciphertext).unwrap();
        assert_eq!(ciphertext_bytes.len(), 64);
        let ciphertext2: Ciphertext = bincode::deserialize(&ciphertext_bytes).unwrap();
        assert!(ciphertext == ciphertext2);
        assert_eq!(
            ciphertext_bytes,
            vec![
                0xf8, 0x9e, 0xe7, 0x70, 0x5a, 0x66, 0x3, 0x6b, 0x90, 0x8d, 0xb8, 0x84, 0x21, 0x1b,
                0x77, 0x3a, 0xc5, 0x43, 0xee, 0x35, 0xc4, 0xa3, 0x8, 0x62, 0x20, 0xfc, 0x3e, 0x1e,
                0x35, 0xb4, 0x23, 0x4c, 0xfa, 0x1d, 0x2e, 0xea, 0x2c, 0xc2, 0xf4, 0xb4, 0xc4, 0x2c,
                0xff, 0x39, 0xa9, 0xdc, 0xeb, 0x57, 0x29, 0x3b, 0x5f, 0x87, 0x70, 0xca, 0x60, 0xf9,
                0xe9, 0xb7, 0x44, 0x47, 0xbf, 0xd3, 0xbd, 0x3d,
            ]
        );

        let plaintext = UidEncryptionDomain::decrypt(&key_pair, &ciphertext2).unwrap();
        assert!(matches!(plaintext, libsignal_core::ServiceId::Aci(_)));
        assert!(uid_struct::UidStruct::from_service_id(plaintext) == uid);
    }

    /// Round-trips a PNI and checks that decryption picks the PNI (not the ACI) interpretation.
    #[test]
    fn test_pni_encryption() {
        let mut sho = Sho::new(b"Test_Pni_Encryption", &[]);
        let key_pair = KeyPair::derive_from(sho.as_mut());

        let pni = libsignal_core::Pni::from_uuid_bytes(TEST_ARRAY_16);
        let uid = uid_struct::UidStruct::from_service_id(pni.into());
        let ciphertext = key_pair.encrypt(&uid);

        // Serialize / deserialize the ciphertext.
        let ciphertext_bytes = bincode::serialize(&ciphertext).unwrap();
        assert_eq!(ciphertext_bytes.len(), 64);
        let ciphertext2: Ciphertext = bincode::deserialize(&ciphertext_bytes).unwrap();
        assert!(ciphertext == ciphertext2);

        let plaintext = UidEncryptionDomain::decrypt(&key_pair, &ciphertext2).unwrap();
        assert!(matches!(plaintext, libsignal_core::ServiceId::Pni(_)));
        assert!(uid_struct::UidStruct::from_service_id(plaintext) == uid);
    }

    /// Decrypting with a key pair other than the one used to encrypt must be rejected rather
    /// than yielding some unrelated service ID.
    #[test]
    fn test_decrypt_with_wrong_key_fails() {
        let mut sho = Sho::new(b"Test_Uid_Encryption_Wrong_Key", &[]);
        let key_pair = KeyPair::derive_from(sho.as_mut());
        let other_key_pair = KeyPair::derive_from(sho.as_mut());

        let aci = libsignal_core::Aci::from_uuid_bytes(TEST_ARRAY_16);
        let uid = uid_struct::UidStruct::from_service_id(aci.into());
        let ciphertext = key_pair.encrypt(&uid);

        assert!(
            UidEncryptionDomain::decrypt(&other_key_pair, &ciphertext).is_err(),
            "ciphertext decrypted under the wrong key pair"
        );
    }
}