libsignal_protocol/
kem.rs

1//
2// Copyright 2023 Signal Messenger, LLC.
3// SPDX-License-Identifier: AGPL-3.0-only
4//
5
6//! Keys and protocol functions for standard key encapsulation mechanisms (KEMs).
7//!
8//! A KEM allows the holder of a `PublicKey` to create a shared secret with the
9//! holder of the corresponding `SecretKey`. This is done by calling the function
10//! `encapsulate` on the `PublicKey` to produce a `SharedSecret` and `Ciphertext`.
11//! The `Ciphertext` is then sent to the recipient who can now call
//! `SecretKey::decapsulate(&ct)` to construct the same `SharedSecret`.
13//!
14//! # Supported KEMs
//! Kyber1024 is always supported. Kyber768 and ML-KEM-1024 are additionally available when
//! the `kyber768` and `mlkem1024` cargo features, respectively, are enabled.
16//!
17//! # Serialization
18//! `PublicKey`s and `SecretKey`s have serialization functions that encode the
19//! KEM protocol. Calls to `PublicKey::deserialize()` and `SecretKey::deserialize()`
20//! will use this to ensure the key is used for the correct KEM protocol.
21//!
22//! # Example
23//! Basic usage:
24//! ```
25//! # use libsignal_protocol::kem::*;
26//! // Generate a Kyber1024 key pair
27//! let kp = KeyPair::generate(KeyType::Kyber1024);
28//!
29//! // The sender computes the shared secret and the ciphertext to send
30//! let (ss_for_sender, ct) = kp.public_key.encapsulate();
31//!
32//! // Once the recipient receives the ciphertext, they use it with the
33//! // secret key to construct the (same) shared secret.
34//! let ss_for_recipient = kp.secret_key.decapsulate(&ct).expect("decapsulation succeeds");
35//! assert_eq!(ss_for_recipient, ss_for_sender);
36//! ```
37//!
38//! Serialization:
39//! ```
40//! # use libsignal_protocol::kem::*;
41//! // Generate a Kyber1024 key pair
42//! let kp = KeyPair::generate(KeyType::Kyber1024);
43//!
44//! let pk_for_wire = kp.public_key.serialize();
45//! // serialized form has an extra byte to encode the protocol
46//! assert_eq!(pk_for_wire.len(), 1568 + 1);
47//!
48//! let kp_reconstituted = PublicKey::deserialize(pk_for_wire.as_ref()).expect("deserialized correctly");
49//! assert_eq!(kp_reconstituted.key_type(), KeyType::Kyber1024);
50//!
51//! ```
52//!
53mod kyber1024;
54#[cfg(any(feature = "kyber768", test))]
55mod kyber768;
56#[cfg(feature = "mlkem1024")]
57mod mlkem1024;
58
59use std::marker::PhantomData;
60use std::ops::Deref;
61
62use derive_where::derive_where;
63use displaydoc::Display;
64use subtle::ConstantTimeEq;
65
66use crate::{Result, SignalProtocolError};
67
/// The shared secret produced by `PublicKey::encapsulate` and recovered by
/// `SecretKey::decapsulate`.
type SharedSecret = Box<[u8]>;

/// Ciphertext bytes exactly as produced by a KEM implementation, *without* the leading
/// [`KeyType`] identifier byte.
pub(crate) type RawCiphertext = Box<[u8]>;

/// Ciphertext bytes in wire format: one [`KeyType`] identifier byte followed by the raw
/// ciphertext.
pub type SerializedCiphertext = Box<[u8]>;
73
74/// Each KEM supported by libsignal-protocol implements this trait.
75///
76/// Similar to the traits in RustCrypto's [kem](https://docs.rs/kem/) crate.
77///
78/// # Example
79/// ```ignore
80/// struct MyNiftyKEM;
81/// # #[cfg(ignore_even_when_running_all_tests)]
82/// impl Parameters for MyNiftyKEM {
83///     // ...
84/// }
85/// ```
86trait Parameters {
87    const PUBLIC_KEY_LENGTH: usize;
88    const SECRET_KEY_LENGTH: usize;
89    const CIPHERTEXT_LENGTH: usize;
90    const SHARED_SECRET_LENGTH: usize;
91    fn generate() -> (KeyMaterial<Public>, KeyMaterial<Secret>);
92    fn encapsulate(pub_key: &KeyMaterial<Public>) -> (SharedSecret, RawCiphertext);
93    fn decapsulate(secret_key: &KeyMaterial<Secret>, ciphertext: &[u8]) -> Result<SharedSecret>;
94}
95
96/// Acts as a bridge between the static [Parameters] trait and the dynamic [KeyType] enum.
97trait DynParameters {
98    fn public_key_length(&self) -> usize;
99    fn secret_key_length(&self) -> usize;
100    fn ciphertext_length(&self) -> usize;
101    #[allow(dead_code)]
102    fn shared_secret_length(&self) -> usize;
103    fn generate(&self) -> (KeyMaterial<Public>, KeyMaterial<Secret>);
104    fn encapsulate(&self, pub_key: &KeyMaterial<Public>) -> (SharedSecret, RawCiphertext);
105    fn decapsulate(
106        &self,
107        secret_key: &KeyMaterial<Secret>,
108        ciphertext: &[u8],
109    ) -> Result<SharedSecret>;
110}
111
112impl<T: Parameters> DynParameters for T {
113    fn public_key_length(&self) -> usize {
114        Self::PUBLIC_KEY_LENGTH
115    }
116
117    fn secret_key_length(&self) -> usize {
118        Self::SECRET_KEY_LENGTH
119    }
120
121    fn ciphertext_length(&self) -> usize {
122        Self::CIPHERTEXT_LENGTH
123    }
124
125    fn shared_secret_length(&self) -> usize {
126        Self::SHARED_SECRET_LENGTH
127    }
128
129    fn generate(&self) -> (KeyMaterial<Public>, KeyMaterial<Secret>) {
130        Self::generate()
131    }
132
133    fn encapsulate(&self, pub_key: &KeyMaterial<Public>) -> (SharedSecret, RawCiphertext) {
134        Self::encapsulate(pub_key)
135    }
136
137    fn decapsulate(
138        &self,
139        secret_key: &KeyMaterial<Secret>,
140        ciphertext: &[u8],
141    ) -> Result<SharedSecret> {
142        Self::decapsulate(secret_key, ciphertext)
143    }
144}
145
146/// Designates a supported KEM protocol
147#[derive(Display, Debug, Copy, Clone, PartialEq, Eq)]
148pub enum KeyType {
149    /// Kyber768 key
150    #[cfg(any(feature = "kyber768", test))]
151    Kyber768,
152    /// Kyber1024 key
153    Kyber1024,
154    /// ML-KEM 1024 key
155    #[cfg(feature = "mlkem1024")]
156    MLKEM1024,
157}
158
159impl KeyType {
160    fn value(&self) -> u8 {
161        match self {
162            #[cfg(any(feature = "kyber768", test))]
163            KeyType::Kyber768 => 0x07,
164            KeyType::Kyber1024 => 0x08,
165            #[cfg(feature = "mlkem1024")]
166            KeyType::MLKEM1024 => 0x0A,
167        }
168    }
169
170    /// Allows KeyType to act like `&dyn Parameters` while still being represented by a single byte.
171    ///
172    /// Declared `const` to encourage inlining.
173    const fn parameters(&self) -> &'static dyn DynParameters {
174        match self {
175            #[cfg(any(feature = "kyber768", test))]
176            KeyType::Kyber768 => &kyber768::Parameters,
177            KeyType::Kyber1024 => &kyber1024::Parameters,
178            #[cfg(feature = "mlkem1024")]
179            KeyType::MLKEM1024 => &mlkem1024::Parameters,
180        }
181    }
182}
183
184impl TryFrom<u8> for KeyType {
185    type Error = SignalProtocolError;
186
187    fn try_from(x: u8) -> Result<Self> {
188        match x {
189            #[cfg(any(feature = "kyber768", test))]
190            0x07 => Ok(KeyType::Kyber768),
191            0x08 => Ok(KeyType::Kyber1024),
192            #[cfg(feature = "mlkem1024")]
193            0x0A => Ok(KeyType::MLKEM1024),
194            t => Err(SignalProtocolError::BadKEMKeyType(t)),
195        }
196    }
197}
198
199pub trait KeyKind {
200    fn key_length(key_type: KeyType) -> usize;
201}
202
203pub enum Public {}
204
205impl KeyKind for Public {
206    fn key_length(key_type: KeyType) -> usize {
207        key_type.parameters().public_key_length()
208    }
209}
210
211pub enum Secret {}
212
213impl KeyKind for Secret {
214    fn key_length(key_type: KeyType) -> usize {
215        key_type.parameters().secret_key_length()
216    }
217}
218
219#[derive_where(Clone)]
220pub(crate) struct KeyMaterial<T: KeyKind> {
221    data: Box<[u8]>,
222    kind: PhantomData<T>,
223}
224
225impl<T: KeyKind> KeyMaterial<T> {
226    fn new(data: Box<[u8]>) -> Self {
227        KeyMaterial {
228            data,
229            kind: PhantomData,
230        }
231    }
232}
233
234impl<T: KeyKind> Deref for KeyMaterial<T> {
235    type Target = [u8];
236
237    fn deref(&self) -> &Self::Target {
238        self.data.deref()
239    }
240}
241
242#[derive_where(Clone)]
243pub struct Key<T: KeyKind> {
244    key_type: KeyType,
245    key_data: KeyMaterial<T>,
246}
247
248impl<T: KeyKind> Key<T> {
249    /// Create a `Key<Kind>` instance from a byte string created with the
250    /// function `Key<Kind>::serialize(&self)`.
251    pub fn deserialize(value: &[u8]) -> Result<Self> {
252        if value.is_empty() {
253            return Err(SignalProtocolError::NoKeyTypeIdentifier);
254        }
255        let key_type = KeyType::try_from(value[0])?;
256        if value.len() != T::key_length(key_type) + 1 {
257            return Err(SignalProtocolError::BadKEMKeyLength(key_type, value.len()));
258        }
259        Ok(Key {
260            key_type,
261            key_data: KeyMaterial::new(value[1..].into()),
262        })
263    }
264    /// Create a binary representation of the key that includes a protocol identifier.
265    pub fn serialize(&self) -> Box<[u8]> {
266        let mut result = Vec::with_capacity(1 + self.key_data.len());
267        result.push(self.key_type.value());
268        result.extend_from_slice(&self.key_data);
269        result.into_boxed_slice()
270    }
271
272    /// Return the `KeyType` that identifies the KEM protocol for this key.
273    pub fn key_type(&self) -> KeyType {
274        self.key_type
275    }
276}
277
278impl Key<Public> {
279    /// Create a `SharedSecret` and a `Ciphertext`. The `Ciphertext` can be safely sent to the
280    /// holder of the corresponding `SecretKey` who can then use it to `decapsulate` the same
281    /// `SharedSecret`.
282    pub fn encapsulate(&self) -> (SharedSecret, SerializedCiphertext) {
283        let (ss, ct) = self.key_type.parameters().encapsulate(&self.key_data);
284        (
285            ss,
286            Ciphertext {
287                key_type: self.key_type,
288                data: &ct,
289            }
290            .serialize(),
291        )
292    }
293}
294
295impl Key<Secret> {
296    /// Decapsulates a `SharedSecret` that was encapsulated into a `Ciphertext` by a holder of
297    /// the corresponding `PublicKey`.
298    pub fn decapsulate(&self, ct_bytes: &SerializedCiphertext) -> Result<Box<[u8]>> {
299        // deserialization checks that the length is correct for the KeyType
300        let ct = Ciphertext::deserialize(ct_bytes)?;
301        if ct.key_type != self.key_type {
302            return Err(SignalProtocolError::WrongKEMKeyType(
303                ct.key_type.value(),
304                self.key_type.value(),
305            ));
306        }
307        self.key_type
308            .parameters()
309            .decapsulate(&self.key_data, ct.data)
310    }
311}
312
313impl TryFrom<&[u8]> for Key<Public> {
314    type Error = SignalProtocolError;
315
316    fn try_from(value: &[u8]) -> Result<Self> {
317        Self::deserialize(value)
318    }
319}
320
321impl TryFrom<&[u8]> for Key<Secret> {
322    type Error = SignalProtocolError;
323
324    fn try_from(value: &[u8]) -> Result<Self> {
325        Self::deserialize(value)
326    }
327}
328
329impl subtle::ConstantTimeEq for Key<Public> {
330    /// A constant-time comparison as long as the two keys have a matching type.
331    ///
332    /// If the two keys have different types, the comparison short-circuits,
333    /// much like comparing two slices of different lengths.
334    fn ct_eq(&self, other: &Self) -> subtle::Choice {
335        if self.key_type != other.key_type {
336            return 0.ct_eq(&1);
337        }
338        self.key_data.ct_eq(&other.key_data)
339    }
340}
341
342impl PartialEq for Key<Public> {
343    fn eq(&self, other: &Self) -> bool {
344        bool::from(self.ct_eq(other))
345    }
346}
347
348impl Eq for Key<Public> {}
349
350/// A KEM public key with the ability to encapsulate a shared secret.
351pub type PublicKey = Key<Public>;
352
353/// A KEM secret key with the ability to decapsulate a shared secret.
354pub type SecretKey = Key<Secret>;
355
356/// A public/secret key pair for a KEM protocol.
357#[derive(Clone)]
358pub struct KeyPair {
359    pub public_key: PublicKey,
360    pub secret_key: SecretKey,
361}
362
363impl KeyPair {
364    /// Creates a public-secret key pair for a specified KEM protocol. Uses system randomness
365    /// [implemented by PQClean](https://github.com/PQClean/PQClean/blob/c1b19a865de329e87e9b3e9152362fcb709da8ab/common/randombytes.c#L335).
366    pub fn generate(key_type: KeyType) -> Self {
367        let (pk, sk) = key_type.parameters().generate();
368        Self {
369            secret_key: SecretKey {
370                key_type,
371                key_data: sk,
372            },
373            public_key: PublicKey {
374                key_type,
375                key_data: pk,
376            },
377        }
378    }
379
380    pub fn new(public_key: PublicKey, secret_key: SecretKey) -> Self {
381        assert_eq!(public_key.key_type, secret_key.key_type);
382        Self {
383            public_key,
384            secret_key,
385        }
386    }
387
388    /// Deserialize public and secret keys that were serialized by `PublicKey::serialize()`
389    /// and `SecretKey::serialize()` respectively.
390    pub fn from_public_and_private(public_key: &[u8], secret_key: &[u8]) -> Result<Self> {
391        let public_key = PublicKey::try_from(public_key)?;
392        let secret_key = SecretKey::try_from(secret_key)?;
393        if public_key.key_type != secret_key.key_type {
394            Err(SignalProtocolError::WrongKEMKeyType(
395                secret_key.key_type.value(),
396                public_key.key_type.value(),
397            ))
398        } else {
399            Ok(Self {
400                public_key,
401                secret_key,
402            })
403        }
404    }
405}
406
407/// Utility type to handle serialization and deserialization of ciphertext data
408struct Ciphertext<'a> {
409    key_type: KeyType,
410    data: &'a [u8],
411}
412
413impl<'a> Ciphertext<'a> {
414    /// Create a `Ciphertext` instance from a byte string created with the
415    /// function `Ciphertext::serialize(&self)`.
416    pub fn deserialize(value: &'a [u8]) -> Result<Self> {
417        if value.is_empty() {
418            return Err(SignalProtocolError::NoKeyTypeIdentifier);
419        }
420        let key_type = KeyType::try_from(value[0])?;
421        if value.len() != key_type.parameters().ciphertext_length() + 1 {
422            return Err(SignalProtocolError::BadKEMCiphertextLength(
423                key_type,
424                value.len(),
425            ));
426        }
427        Ok(Ciphertext {
428            key_type,
429            data: &value[1..],
430        })
431    }
432
433    /// Create a binary representation of the key that includes a protocol identifier.
434    pub fn serialize(&self) -> SerializedCiphertext {
435        let mut result = Vec::with_capacity(1 + self.data.len());
436        result.push(self.key_type.value());
437        result.extend_from_slice(self.data);
438        result.into_boxed_slice()
439    }
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// Prepends the wire identifier byte for `key_type` to `raw`, producing the serialized
    /// form accepted by `Key::deserialize`.
    fn with_type_prefix(key_type: KeyType, raw: &[u8]) -> Vec<u8> {
        let mut serialized = Vec::with_capacity(1 + raw.len());
        serialized.push(key_type.value());
        serialized.extend_from_slice(raw);
        serialized
    }

    /// The checked-in raw Kyber1024 test vectors, serialized as `(public, secret)` keys.
    fn serialized_kyber1024_test_keys() -> (Vec<u8>, Vec<u8>) {
        let pk_bytes = include_bytes!("kem/test-data/pk.dat");
        let sk_bytes = include_bytes!("kem/test-data/sk.dat");
        (
            with_type_prefix(KeyType::Kyber1024, pk_bytes),
            with_type_prefix(KeyType::Kyber1024, sk_bytes),
        )
    }

    /// Generates a key pair for `key_type`, checks every serialized length against the static
    /// parameters `P`, and checks that encapsulation and decapsulation agree.
    fn check_generated_keypair<P: Parameters>(key_type: KeyType) {
        let kp = KeyPair::generate(key_type);
        assert_eq!(P::SECRET_KEY_LENGTH + 1, kp.secret_key.serialize().len());
        assert_eq!(P::PUBLIC_KEY_LENGTH + 1, kp.public_key.serialize().len());

        let (ss_for_sender, ct) = kp.public_key.encapsulate();
        assert_eq!(P::CIPHERTEXT_LENGTH + 1, ct.len());
        assert_eq!(P::SHARED_SECRET_LENGTH, ss_for_sender.len());

        let ss_for_recipient = kp.secret_key.decapsulate(&ct).expect("decapsulation works");
        assert_eq!(ss_for_recipient, ss_for_sender);
    }

    #[test]
    fn test_serialize() {
        let (serialized_pk, serialized_sk) = serialized_kyber1024_test_keys();

        let pk = PublicKey::deserialize(&serialized_pk).expect("deserialize pk");
        let sk = SecretKey::deserialize(&serialized_sk).expect("deserialize sk");

        assert_eq!(serialized_pk, pk.serialize().into_vec());
        assert_eq!(serialized_sk, sk.serialize().into_vec());
    }

    #[test]
    fn test_raw_kem() {
        use pqcrypto_kyber::kyber1024::{decapsulate, encapsulate, keypair};
        let (pk, sk) = keypair();
        let (ss1, ct) = encapsulate(&pk);
        let ss2 = decapsulate(&ct, &sk);
        assert!(ss1 == ss2);
    }

    #[test]
    fn test_kyber1024_kem() {
        let (serialized_pk, serialized_sk) = serialized_kyber1024_test_keys();

        let pubkey = PublicKey::deserialize(&serialized_pk).expect("deserialize pubkey");
        let secretkey = SecretKey::deserialize(&serialized_sk).expect("deserialize secretkey");

        assert_eq!(pubkey.key_type, KeyType::Kyber1024);
        let (ss_for_sender, ct) = pubkey.encapsulate();
        let ss_for_recipient = secretkey.decapsulate(&ct).expect("decapsulation works");

        assert_eq!(ss_for_sender, ss_for_recipient);
    }

    #[cfg(feature = "mlkem1024")]
    #[test]
    fn test_mlkem1024_kem() {
        // ML-KEM-1024 test data; these files already include the KeyType prefix byte.
        let pk_bytes = include_bytes!("kem/test-data/mlkem-pk.dat");
        let sk_bytes = include_bytes!("kem/test-data/mlkem-sk.dat");

        let pubkey = PublicKey::deserialize(pk_bytes).expect("deserialize pubkey");
        let secretkey = SecretKey::deserialize(sk_bytes).expect("deserialize secretkey");

        assert_eq!(pubkey.key_type, KeyType::MLKEM1024);
        let (ss_for_sender, ct) = pubkey.encapsulate();
        let ss_for_recipient = secretkey.decapsulate(&ct).expect("decapsulation works");

        assert_eq!(ss_for_sender, ss_for_recipient);
    }

    #[test]
    fn test_kyber1024_keypair() {
        check_generated_keypair::<kyber1024::Parameters>(KeyType::Kyber1024);
    }

    #[test]
    fn test_kyber768_keypair() {
        check_generated_keypair::<kyber768::Parameters>(KeyType::Kyber768);
    }

    #[cfg(feature = "mlkem1024")]
    #[test]
    fn test_mlkem1024_keypair() {
        check_generated_keypair::<mlkem1024::Parameters>(KeyType::MLKEM1024);
    }

    #[test]
    fn test_deserialize_rejects_malformed_keys() {
        assert!(matches!(
            PublicKey::deserialize(&[]),
            Err(SignalProtocolError::NoKeyTypeIdentifier)
        ));
        assert!(matches!(
            SecretKey::deserialize(&[0xFF, 0x00]),
            Err(SignalProtocolError::BadKEMKeyType(0xFF))
        ));
        assert!(matches!(
            PublicKey::deserialize(&[KeyType::Kyber1024.value(), 0x00]),
            Err(SignalProtocolError::BadKEMKeyLength(KeyType::Kyber1024, 2))
        ));
    }

    #[test]
    fn test_decapsulate_rejects_bad_ciphertexts() {
        let kp_1024 = KeyPair::generate(KeyType::Kyber1024);
        let kp_768 = KeyPair::generate(KeyType::Kyber768);

        // A well-formed ciphertext for a different KEM.
        let (_, ct_768) = kp_768.public_key.encapsulate();
        assert!(matches!(
            kp_1024.secret_key.decapsulate(&ct_768),
            Err(SignalProtocolError::WrongKEMKeyType(..))
        ));

        // A ciphertext of the right type that is one byte short.
        let (_, ct_1024) = kp_1024.public_key.encapsulate();
        let truncated: SerializedCiphertext = Box::from(&ct_1024[..ct_1024.len() - 1]);
        assert!(matches!(
            kp_1024.secret_key.decapsulate(&truncated),
            Err(SignalProtocolError::BadKEMCiphertextLength(KeyType::Kyber1024, _))
        ));
    }

    #[test]
    fn test_from_public_and_private_rejects_mismatched_types() {
        let kp_1024 = KeyPair::generate(KeyType::Kyber1024);
        let kp_768 = KeyPair::generate(KeyType::Kyber768);
        assert!(matches!(
            KeyPair::from_public_and_private(
                &kp_1024.public_key.serialize(),
                &kp_768.secret_key.serialize(),
            ),
            Err(SignalProtocolError::WrongKEMKeyType(..))
        ));
    }

    #[test]
    fn test_dyn_parameters_consts() {
        assert_eq!(
            kyber1024::Parameters::SECRET_KEY_LENGTH,
            kyber1024::Parameters.secret_key_length()
        );
        assert_eq!(
            kyber1024::Parameters::PUBLIC_KEY_LENGTH,
            kyber1024::Parameters.public_key_length()
        );
        assert_eq!(
            kyber1024::Parameters::CIPHERTEXT_LENGTH,
            kyber1024::Parameters.ciphertext_length()
        );
        assert_eq!(
            kyber1024::Parameters::SHARED_SECRET_LENGTH,
            kyber1024::Parameters.shared_secret_length()
        );
    }
}