libsignal_protocol/
kem.rs

1//
2// Copyright 2023 Signal Messenger, LLC.
3// SPDX-License-Identifier: AGPL-3.0-only
4//
5
6//! Keys and protocol functions for standard key encapsulation mechanisms (KEMs).
7//!
8//! A KEM allows the holder of a `PublicKey` to create a shared secret with the
9//! holder of the corresponding `SecretKey`. This is done by calling the function
10//! `encapsulate` on the `PublicKey` to produce a `SharedSecret` and `Ciphertext`.
11//! The `Ciphertext` is then sent to the recipient who can now call
12//! `SecretKey::decapsulate(ct: Ciphertext)` to construct the same `SharedSecret`.
13//!
14//! # Supported KEMs
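//! Kyber1024 is always supported. Kyber768 and ML-KEM-1024 (the variant standardized by NIST
//! as FIPS 203) are supported when the `kyber768` and `mlkem1024` cargo features, respectively,
//! are enabled.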
16//!
17//! # Serialization
//! `PublicKey`s and `SecretKey`s have serialization functions that prefix the raw key with a
//! one-byte identifier of the KEM protocol. `PublicKey::deserialize()` and
//! `SecretKey::deserialize()` read this identifier to ensure the key is used with the correct
//! KEM protocol.
21//!
22//! # Example
23//! Basic usage:
24//! ```
25//! # use libsignal_protocol::kem::*;
26//! let mut rng = rand::rng();
27//! // Generate a Kyber1024 key pair
28//! let kp = KeyPair::generate(KeyType::Kyber1024, &mut rng);
29//!
30//! // The sender computes the shared secret and the ciphertext to send
31//! let (ss_for_sender, ct) = kp.public_key.encapsulate(&mut rng).expect("encapsulation succeeds");
32//!
33//! // Once the recipient receives the ciphertext, they use it with the
34//! // secret key to construct the (same) shared secret.
35//! let ss_for_recipient = kp.secret_key.decapsulate(&ct).expect("decapsulation succeeds");
36//! assert_eq!(ss_for_recipient, ss_for_sender);
37//! ```
38//!
39//! Serialization:
40//! ```
41//! # use libsignal_protocol::kem::*;
42//! let mut rng = rand::rng();
43//! // Generate a Kyber1024 key pair
44//! let kp = KeyPair::generate(KeyType::Kyber1024, &mut rng);
45//!
46//! let pk_for_wire = kp.public_key.serialize();
47//! // serialized form has an extra byte to encode the protocol
48//! assert_eq!(pk_for_wire.len(), 1568 + 1);
49//!
50//! let kp_reconstituted = PublicKey::deserialize(pk_for_wire.as_ref()).expect("deserialized correctly");
51//! assert_eq!(kp_reconstituted.key_type(), KeyType::Kyber1024);
52//!
53//! ```
54//!
55mod kyber1024;
56#[cfg(feature = "kyber768")]
57mod kyber768;
58#[cfg(feature = "mlkem1024")]
59mod mlkem1024;
60
61use std::fmt;
62use std::marker::PhantomData;
63
64use derive_where::derive_where;
65use displaydoc::Display;
66use rand::{CryptoRng, Rng};
67use subtle::ConstantTimeEq;
68
69use crate::{Result, SignalProtocolError};
70
/// The shared secret produced by encapsulation and recovered by decapsulation.
type SharedSecret = Box<[u8]>;

/// A ciphertext exactly as produced by the underlying KEM, *without* the leading
/// `KeyType` identifier byte.
pub(crate) type RawCiphertext = Box<[u8]>;
/// A ciphertext in its wire format: one `KeyType` identifier byte followed by the raw
/// ciphertext.
pub type SerializedCiphertext = Box<[u8]>;
76
77/// Each KEM supported by libsignal-protocol implements this trait.
78///
79/// Similar to the traits in RustCrypto's [kem](https://docs.rs/kem/) crate.
80///
81/// # Example
82/// ```ignore
83/// struct MyNiftyKEM;
84/// # #[cfg(ignore_even_when_running_all_tests)]
85/// impl Parameters for MyNiftyKEM {
86///     // ...
87/// }
88/// ```
89trait Parameters {
90    const KEY_TYPE: KeyType;
91    const PUBLIC_KEY_LENGTH: usize;
92    const SECRET_KEY_LENGTH: usize;
93    const CIPHERTEXT_LENGTH: usize;
94    const SHARED_SECRET_LENGTH: usize;
95    fn generate<R: CryptoRng + ?Sized>(
96        csprng: &mut R,
97    ) -> (KeyMaterial<Public>, KeyMaterial<Secret>);
98    fn encapsulate<R: CryptoRng + ?Sized>(
99        pub_key: &KeyMaterial<Public>,
100        csprng: &mut R,
101    ) -> std::result::Result<(SharedSecret, RawCiphertext), BadKEMKeyLength>;
102    fn decapsulate(
103        secret_key: &KeyMaterial<Secret>,
104        ciphertext: &[u8],
105    ) -> std::result::Result<SharedSecret, DecapsulateError>;
106}
107
108/// Acts as a bridge between the static [Parameters] trait and the dynamic [KeyType] enum.
109trait DynParameters {
110    fn public_key_length(&self) -> usize;
111    fn secret_key_length(&self) -> usize;
112    fn ciphertext_length(&self) -> usize;
113    #[cfg_attr(not(test), expect(dead_code))]
114    fn shared_secret_length(&self) -> usize;
115    fn generate(&self, rng: &mut dyn CryptoRng) -> (KeyMaterial<Public>, KeyMaterial<Secret>);
116    fn encapsulate(
117        &self,
118        pub_key: &KeyMaterial<Public>,
119        csprng: &mut dyn CryptoRng,
120    ) -> Result<(SharedSecret, RawCiphertext)>;
121    fn decapsulate(
122        &self,
123        secret_key: &KeyMaterial<Secret>,
124        ciphertext: &[u8],
125    ) -> Result<SharedSecret>;
126}
127
128impl<T: Parameters> DynParameters for T {
129    fn public_key_length(&self) -> usize {
130        Self::PUBLIC_KEY_LENGTH
131    }
132
133    fn secret_key_length(&self) -> usize {
134        Self::SECRET_KEY_LENGTH
135    }
136
137    fn ciphertext_length(&self) -> usize {
138        Self::CIPHERTEXT_LENGTH
139    }
140
141    fn shared_secret_length(&self) -> usize {
142        Self::SHARED_SECRET_LENGTH
143    }
144
145    fn generate(&self, csprng: &mut dyn CryptoRng) -> (KeyMaterial<Public>, KeyMaterial<Secret>) {
146        Self::generate(csprng)
147    }
148
149    fn encapsulate(
150        &self,
151        pub_key: &KeyMaterial<Public>,
152        csprng: &mut dyn CryptoRng,
153    ) -> Result<(Box<[u8]>, Box<[u8]>)> {
154        Self::encapsulate(pub_key, csprng).map_err(|BadKEMKeyLength| {
155            SignalProtocolError::BadKEMKeyLength(T::KEY_TYPE, pub_key.len())
156        })
157    }
158
159    fn decapsulate(
160        &self,
161        secret_key: &KeyMaterial<Secret>,
162        ciphertext: &[u8],
163    ) -> Result<SharedSecret> {
164        Self::decapsulate(secret_key, ciphertext).map_err(|e| match e {
165            DecapsulateError::BadKeyLength => {
166                SignalProtocolError::BadKEMKeyLength(T::KEY_TYPE, secret_key.len())
167            }
168            DecapsulateError::BadCiphertext => {
169                SignalProtocolError::BadKEMCiphertextLength(T::KEY_TYPE, ciphertext.len())
170            }
171        })
172    }
173}
174
175/// Helper trait for extracting the size of [`libcrux_ml_kem`]'s generic types.
176trait ConstantLength {
177    const LENGTH: usize;
178}
179
180impl<const N: usize> ConstantLength for libcrux_ml_kem::MlKemPrivateKey<N> {
181    const LENGTH: usize = N;
182}
183impl<const N: usize> ConstantLength for libcrux_ml_kem::MlKemPublicKey<N> {
184    const LENGTH: usize = N;
185}
186impl<const N: usize> ConstantLength for libcrux_ml_kem::MlKemCiphertext<N> {
187    const LENGTH: usize = N;
188}
189
/// Error returned from `Parameters::encapsulate` when the public key is rejected.
struct BadKEMKeyLength;

/// Error returned from `Parameters::decapsulate`.
enum DecapsulateError {
    /// The secret key was rejected; surfaced to callers as a key-length error.
    BadKeyLength,
    /// The ciphertext was rejected; surfaced to callers as a ciphertext-length error.
    BadCiphertext,
}
198
/// Designates a supported KEM protocol
// NOTE: `displaydoc::Display` derives each variant's `Display` output from that
// variant's `///` doc comment (e.g. `Kyber1024` displays as "Kyber1024 key"), so
// editing those comments changes runtime behavior.
// The one-byte wire identifier of each variant is defined in `KeyType::value`, and
// parsed back by `TryFrom<u8> for KeyType`; keep the two in sync when adding variants.
#[derive(Display, Debug, Copy, Clone, PartialEq, Eq)]
pub enum KeyType {
    /// Kyber768 key
    #[cfg(feature = "kyber768")]
    Kyber768,
    /// Kyber1024 key
    Kyber1024,
    /// ML-KEM 1024 key
    #[cfg(feature = "mlkem1024")]
    MLKEM1024,
}
211
212impl KeyType {
213    fn value(&self) -> u8 {
214        match self {
215            #[cfg(feature = "kyber768")]
216            KeyType::Kyber768 => 0x07,
217            KeyType::Kyber1024 => 0x08,
218            #[cfg(feature = "mlkem1024")]
219            KeyType::MLKEM1024 => 0x0A,
220        }
221    }
222
223    /// Allows KeyType to act like `&dyn Parameters` while still being represented by a single byte.
224    ///
225    /// Declared `const` to encourage inlining.
226    const fn parameters(&self) -> &'static dyn DynParameters {
227        match self {
228            #[cfg(feature = "kyber768")]
229            KeyType::Kyber768 => &kyber768::Parameters,
230            KeyType::Kyber1024 => &kyber1024::Parameters,
231            #[cfg(feature = "mlkem1024")]
232            KeyType::MLKEM1024 => &mlkem1024::Parameters,
233        }
234    }
235}
236
237impl TryFrom<u8> for KeyType {
238    type Error = SignalProtocolError;
239
240    fn try_from(x: u8) -> Result<Self> {
241        match x {
242            #[cfg(feature = "kyber768")]
243            0x07 => Ok(KeyType::Kyber768),
244            0x08 => Ok(KeyType::Kyber1024),
245            #[cfg(feature = "mlkem1024")]
246            0x0A => Ok(KeyType::MLKEM1024),
247            t => Err(SignalProtocolError::BadKEMKeyType(t)),
248        }
249    }
250}
251
252pub trait KeyKind {
253    fn key_length(key_type: KeyType) -> usize;
254}
255
256pub enum Public {}
257
258impl KeyKind for Public {
259    fn key_length(key_type: KeyType) -> usize {
260        key_type.parameters().public_key_length()
261    }
262}
263
264pub enum Secret {}
265
266impl KeyKind for Secret {
267    fn key_length(key_type: KeyType) -> usize {
268        key_type.parameters().secret_key_length()
269    }
270}
271
272#[derive(derive_more::Deref)]
273#[derive_where(Clone)]
274pub(crate) struct KeyMaterial<T: KeyKind> {
275    #[deref(forward)]
276    data: Box<[u8]>,
277    kind: PhantomData<T>,
278}
279
280impl<T: KeyKind> KeyMaterial<T> {
281    fn new(data: Box<[u8]>) -> Self {
282        KeyMaterial {
283            data,
284            kind: PhantomData,
285        }
286    }
287}
288
289impl<const SIZE: usize> From<libcrux_ml_kem::MlKemPublicKey<SIZE>> for KeyMaterial<Public> {
290    fn from(value: libcrux_ml_kem::MlKemPublicKey<SIZE>) -> Self {
291        KeyMaterial::new(value.as_ref().into())
292    }
293}
294
295impl<const SIZE: usize> From<libcrux_ml_kem::MlKemPrivateKey<SIZE>> for KeyMaterial<Secret> {
296    fn from(value: libcrux_ml_kem::MlKemPrivateKey<SIZE>) -> Self {
297        KeyMaterial::new(value.as_ref().into())
298    }
299}
300
301#[derive_where(Clone)]
302pub struct Key<T: KeyKind> {
303    key_type: KeyType,
304    key_data: KeyMaterial<T>,
305}
306
307impl<T: KeyKind> fmt::Debug for Key<T> {
308    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
309        f.debug_struct("Key")
310            .field("key_type", &self.key_type)
311            .field("bytes_len", &self.key_data.len())
312            .finish()
313    }
314}
315
316impl<T: KeyKind> Key<T> {
317    /// Create a `Key<Kind>` instance from a byte string created with the
318    /// function `Key<Kind>::serialize(&self)`.
319    pub fn deserialize(value: &[u8]) -> Result<Self> {
320        if value.is_empty() {
321            return Err(SignalProtocolError::NoKeyTypeIdentifier);
322        }
323        let key_type = KeyType::try_from(value[0])?;
324        if value.len() != T::key_length(key_type) + 1 {
325            return Err(SignalProtocolError::BadKEMKeyLength(key_type, value.len()));
326        }
327        Ok(Key {
328            key_type,
329            key_data: KeyMaterial::new(value[1..].into()),
330        })
331    }
332    /// Create a binary representation of the key that includes a protocol identifier.
333    pub fn serialize(&self) -> Box<[u8]> {
334        let mut result = Vec::with_capacity(1 + self.key_data.len());
335        result.push(self.key_type.value());
336        result.extend_from_slice(&self.key_data);
337        result.into_boxed_slice()
338    }
339
340    /// Return the `KeyType` that identifies the KEM protocol for this key.
341    pub fn key_type(&self) -> KeyType {
342        self.key_type
343    }
344}
345
346impl Key<Public> {
347    /// Create a `SharedSecret` and a `Ciphertext`. The `Ciphertext` can be safely sent to the
348    /// holder of the corresponding `SecretKey` who can then use it to `decapsulate` the same
349    /// `SharedSecret`.
350    pub fn encapsulate<R: CryptoRng>(
351        &self,
352        csprng: &mut R,
353    ) -> Result<(SharedSecret, SerializedCiphertext)> {
354        let (ss, ct) = self
355            .key_type
356            .parameters()
357            .encapsulate(&self.key_data, csprng)?;
358        Ok((
359            ss,
360            Ciphertext {
361                key_type: self.key_type,
362                data: &ct,
363            }
364            .serialize(),
365        ))
366    }
367}
368
369impl Key<Secret> {
370    /// Decapsulates a `SharedSecret` that was encapsulated into a `Ciphertext` by a holder of
371    /// the corresponding `PublicKey`.
372    pub fn decapsulate(&self, ct_bytes: &SerializedCiphertext) -> Result<Box<[u8]>> {
373        // deserialization checks that the length is correct for the KeyType
374        let ct = Ciphertext::deserialize(ct_bytes)?;
375        if ct.key_type != self.key_type {
376            return Err(SignalProtocolError::WrongKEMKeyType(
377                ct.key_type.value(),
378                self.key_type.value(),
379            ));
380        }
381        self.key_type
382            .parameters()
383            .decapsulate(&self.key_data, ct.data)
384    }
385}
386
387impl TryFrom<&[u8]> for Key<Public> {
388    type Error = SignalProtocolError;
389
390    fn try_from(value: &[u8]) -> Result<Self> {
391        Self::deserialize(value)
392    }
393}
394
395impl TryFrom<&[u8]> for Key<Secret> {
396    type Error = SignalProtocolError;
397
398    fn try_from(value: &[u8]) -> Result<Self> {
399        Self::deserialize(value)
400    }
401}
402
403impl subtle::ConstantTimeEq for Key<Public> {
404    /// A constant-time comparison as long as the two keys have a matching type.
405    ///
406    /// If the two keys have different types, the comparison short-circuits,
407    /// much like comparing two slices of different lengths.
408    fn ct_eq(&self, other: &Self) -> subtle::Choice {
409        if self.key_type != other.key_type {
410            return 0.ct_eq(&1);
411        }
412        self.key_data.ct_eq(&other.key_data)
413    }
414}
415
416impl PartialEq for Key<Public> {
417    fn eq(&self, other: &Self) -> bool {
418        bool::from(self.ct_eq(other))
419    }
420}
421
422impl Eq for Key<Public> {}
423
424/// A KEM public key with the ability to encapsulate a shared secret.
425pub type PublicKey = Key<Public>;
426
427/// A KEM secret key with the ability to decapsulate a shared secret.
428pub type SecretKey = Key<Secret>;
429
430/// A public/secret key pair for a KEM protocol.
431#[derive(Clone)]
432pub struct KeyPair {
433    pub public_key: PublicKey,
434    pub secret_key: SecretKey,
435}
436
437impl KeyPair {
438    /// Creates a public-secret key pair for a specified KEM protocol.
439    pub fn generate<R: Rng + CryptoRng>(key_type: KeyType, csprng: &mut R) -> Self {
440        let (pk, sk) = key_type.parameters().generate(csprng);
441        Self {
442            secret_key: SecretKey {
443                key_type,
444                key_data: sk,
445            },
446            public_key: PublicKey {
447                key_type,
448                key_data: pk,
449            },
450        }
451    }
452
453    pub fn new(public_key: PublicKey, secret_key: SecretKey) -> Self {
454        assert_eq!(public_key.key_type, secret_key.key_type);
455        Self {
456            public_key,
457            secret_key,
458        }
459    }
460
461    /// Deserialize public and secret keys that were serialized by `PublicKey::serialize()`
462    /// and `SecretKey::serialize()` respectively.
463    pub fn from_public_and_private(public_key: &[u8], secret_key: &[u8]) -> Result<Self> {
464        let public_key = PublicKey::try_from(public_key)?;
465        let secret_key = SecretKey::try_from(secret_key)?;
466        if public_key.key_type != secret_key.key_type {
467            Err(SignalProtocolError::WrongKEMKeyType(
468                secret_key.key_type.value(),
469                public_key.key_type.value(),
470            ))
471        } else {
472            Ok(Self {
473                public_key,
474                secret_key,
475            })
476        }
477    }
478}
479
480/// Utility type to handle serialization and deserialization of ciphertext data
481struct Ciphertext<'a> {
482    key_type: KeyType,
483    data: &'a [u8],
484}
485
486impl<'a> Ciphertext<'a> {
487    /// Create a `Ciphertext` instance from a byte string created with the
488    /// function `Ciphertext::serialize(&self)`.
489    pub fn deserialize(value: &'a [u8]) -> Result<Self> {
490        if value.is_empty() {
491            return Err(SignalProtocolError::NoKeyTypeIdentifier);
492        }
493        let key_type = KeyType::try_from(value[0])?;
494        if value.len() != key_type.parameters().ciphertext_length() + 1 {
495            return Err(SignalProtocolError::BadKEMCiphertextLength(
496                key_type,
497                value.len(),
498            ));
499        }
500        Ok(Ciphertext {
501            key_type,
502            data: &value[1..],
503        })
504    }
505
506    /// Create a binary representation of the key that includes a protocol identifier.
507    pub fn serialize(&self) -> SerializedCiphertext {
508        let mut result = Vec::with_capacity(1 + self.data.len());
509        result.push(self.key_type.value());
510        result.extend_from_slice(self.data);
511        result.into_boxed_slice()
512    }
513}
514
#[cfg(test)]
mod tests {
    use rand::{Rng as _, TryRngCore as _};

    use super::*;

    /// Prepends the identifier byte of `key_type` to `raw`, producing the same layout
    /// as `Key::serialize`.
    fn with_type_prefix(key_type: KeyType, raw: &[u8]) -> Vec<u8> {
        let mut serialized = Vec::with_capacity(1 + raw.len());
        serialized.push(key_type.value());
        serialized.extend_from_slice(raw);
        serialized
    }

    #[test]
    fn test_serialize() {
        // Raw Kyber1024 keys, without the KeyType prefix byte.
        let pk_bytes = include_bytes!("kem/test-data/pk.dat");
        let sk_bytes = include_bytes!("kem/test-data/sk.dat");

        let serialized_pk = with_type_prefix(KeyType::Kyber1024, pk_bytes);
        let serialized_sk = with_type_prefix(KeyType::Kyber1024, sk_bytes);

        let pk = PublicKey::deserialize(serialized_pk.as_slice()).expect("deserialize pk");
        let sk = SecretKey::deserialize(serialized_sk.as_slice()).expect("deserialize sk");

        let reserialized_pk = pk.serialize();
        let reserialized_sk = sk.serialize();

        assert_eq!(serialized_pk, reserialized_pk.into_vec());
        assert_eq!(serialized_sk, reserialized_sk.into_vec());
    }

    #[test]
    fn test_raw_kem() {
        use libcrux_ml_kem::kyber1024::{decapsulate, encapsulate, generate_key_pair};
        let mut rng = rand::rngs::OsRng.unwrap_err();
        let (sk, pk) = generate_key_pair(rng.random()).into_parts();
        let (ct, ss1) = encapsulate(&pk, rng.random());
        let ss2 = decapsulate(&sk, &ct);
        assert!(ss1 == ss2);
    }

    #[test]
    fn test_kyber1024_kem() {
        // Raw Kyber1024 test keys, without the KeyType prefix byte.
        let pk_bytes = include_bytes!("kem/test-data/pk.dat");
        let sk_bytes = include_bytes!("kem/test-data/sk.dat");
        let mut rng = rand::rngs::OsRng.unwrap_err();

        let serialized_pk = with_type_prefix(KeyType::Kyber1024, pk_bytes);
        let serialized_sk = with_type_prefix(KeyType::Kyber1024, sk_bytes);

        let pubkey = PublicKey::deserialize(serialized_pk.as_slice()).expect("deserialize pubkey");
        let secretkey =
            SecretKey::deserialize(serialized_sk.as_slice()).expect("deserialize secretkey");

        assert_eq!(pubkey.key_type, KeyType::Kyber1024);
        let (ss_for_sender, ct) = pubkey.encapsulate(&mut rng).expect("encapsulation works");
        let ss_for_recipient = secretkey.decapsulate(&ct).expect("decapsulation works");

        assert_eq!(ss_for_sender, ss_for_recipient);
    }

    #[cfg(feature = "mlkem1024")]
    #[test]
    fn test_mlkem1024_kem() {
        // ML-KEM 1024 test keys; these files already include the KeyType prefix byte.
        let pk_bytes = include_bytes!("kem/test-data/mlkem-pk.dat");
        let sk_bytes = include_bytes!("kem/test-data/mlkem-sk.dat");
        let mut rng = rand::rngs::OsRng.unwrap_err();

        let pubkey = PublicKey::deserialize(pk_bytes).expect("deserialize pubkey");
        let secretkey = SecretKey::deserialize(sk_bytes).expect("deserialize secretkey");

        assert_eq!(pubkey.key_type, KeyType::MLKEM1024);
        let (ss_for_sender, ct) = pubkey.encapsulate(&mut rng).expect("encapsulation works");
        let ss_for_recipient = secretkey.decapsulate(&ct).expect("decapsulation works");

        assert_eq!(ss_for_sender, ss_for_recipient);
    }

    #[test]
    fn test_kyber1024_keypair() {
        let mut rng = rand::rngs::OsRng.unwrap_err();
        let kp = KeyPair::generate(KeyType::Kyber1024, &mut rng);
        assert_eq!(
            kyber1024::Parameters::SECRET_KEY_LENGTH + 1,
            kp.secret_key.serialize().len()
        );
        assert_eq!(
            kyber1024::Parameters::PUBLIC_KEY_LENGTH + 1,
            kp.public_key.serialize().len()
        );
        let (ss_for_sender, ct) = kp
            .public_key
            .encapsulate(&mut rng)
            .expect("encapsulation works");
        assert_eq!(kyber1024::Parameters::CIPHERTEXT_LENGTH + 1, ct.len());
        assert_eq!(
            kyber1024::Parameters::SHARED_SECRET_LENGTH,
            ss_for_sender.len()
        );
        let ss_for_recipient = kp.secret_key.decapsulate(&ct).expect("decapsulation works");
        assert_eq!(ss_for_recipient, ss_for_sender);
    }

    #[cfg(feature = "kyber768")]
    #[test]
    fn test_kyber768_keypair() {
        let mut rng = rand::rngs::OsRng.unwrap_err();
        let kp = KeyPair::generate(KeyType::Kyber768, &mut rng);
        assert_eq!(
            kyber768::Parameters::SECRET_KEY_LENGTH + 1,
            kp.secret_key.serialize().len()
        );
        assert_eq!(
            kyber768::Parameters::PUBLIC_KEY_LENGTH + 1,
            kp.public_key.serialize().len()
        );
        let (ss_for_sender, ct) = kp
            .public_key
            .encapsulate(&mut rng)
            .expect("encapsulation works");
        assert_eq!(kyber768::Parameters::CIPHERTEXT_LENGTH + 1, ct.len());
        assert_eq!(
            kyber768::Parameters::SHARED_SECRET_LENGTH,
            ss_for_sender.len()
        );
        let ss_for_recipient = kp.secret_key.decapsulate(&ct).expect("decapsulation works");
        assert_eq!(ss_for_recipient, ss_for_sender);
    }

    #[cfg(feature = "mlkem1024")]
    #[test]
    fn test_mlkem1024_keypair() {
        let mut rng = rand::rngs::OsRng.unwrap_err();
        let kp = KeyPair::generate(KeyType::MLKEM1024, &mut rng);
        assert_eq!(
            mlkem1024::Parameters::SECRET_KEY_LENGTH + 1,
            kp.secret_key.serialize().len()
        );
        assert_eq!(
            mlkem1024::Parameters::PUBLIC_KEY_LENGTH + 1,
            kp.public_key.serialize().len()
        );
        let (ss_for_sender, ct) = kp
            .public_key
            .encapsulate(&mut rng)
            .expect("encapsulation works");
        assert_eq!(mlkem1024::Parameters::CIPHERTEXT_LENGTH + 1, ct.len());
        assert_eq!(
            mlkem1024::Parameters::SHARED_SECRET_LENGTH,
            ss_for_sender.len()
        );
        let ss_for_recipient = kp.secret_key.decapsulate(&ct).expect("decapsulation works");
        assert_eq!(ss_for_recipient, ss_for_sender);
    }

    #[test]
    fn test_dyn_parameters_consts() {
        assert_eq!(
            kyber1024::Parameters::SECRET_KEY_LENGTH,
            kyber1024::Parameters.secret_key_length()
        );
        assert_eq!(
            kyber1024::Parameters::PUBLIC_KEY_LENGTH,
            kyber1024::Parameters.public_key_length()
        );
        assert_eq!(
            kyber1024::Parameters::CIPHERTEXT_LENGTH,
            kyber1024::Parameters.ciphertext_length()
        );
        assert_eq!(
            kyber1024::Parameters::SHARED_SECRET_LENGTH,
            kyber1024::Parameters.shared_secret_length()
        );
    }
}