1mod kyber1024;
56#[cfg(any(feature = "kyber768", test))]
57mod kyber768;
58#[cfg(feature = "mlkem1024")]
59mod mlkem1024;
60
61use std::marker::PhantomData;
62
63use derive_where::derive_where;
64use displaydoc::Display;
65use rand::{CryptoRng, Rng};
66use subtle::ConstantTimeEq;
67
68use crate::{Result, SignalProtocolError};
69
/// Shared-secret bytes, as returned by both encapsulation and decapsulation.
type SharedSecret = Box<[u8]>;

/// Ciphertext bytes exactly as a parameter set produces/consumes them, i.e.
/// without the leading key-type byte.
pub(crate) type RawCiphertext = Box<[u8]>;
/// Ciphertext bytes in serialized form: one key-type byte followed by the raw
/// ciphertext.
pub type SerializedCiphertext = Box<[u8]>;
75
/// Static description of one KEM parameter set.
///
/// Sizes are associated constants and the operations are associated functions
/// (no `self`). Every implementor automatically gets the `&self`-based
/// [`DynParameters`] interface through a blanket impl.
trait Parameters {
    /// The [`KeyType`] tag for this parameter set; used when building errors.
    const KEY_TYPE: KeyType;
    /// Public key length in bytes, not counting the serialized key-type byte.
    const PUBLIC_KEY_LENGTH: usize;
    /// Secret key length in bytes, not counting the serialized key-type byte.
    const SECRET_KEY_LENGTH: usize;
    /// Raw ciphertext length in bytes, not counting the serialized key-type byte.
    const CIPHERTEXT_LENGTH: usize;
    /// Shared secret length in bytes. Only referenced from tests, hence the
    /// `dead_code` expectation in non-test builds.
    #[cfg_attr(not(test), expect(dead_code))]
    const SHARED_SECRET_LENGTH: usize;
    /// Generates a fresh key pair, drawing randomness from `csprng`.
    fn generate<R: CryptoRng + ?Sized>(
        csprng: &mut R,
    ) -> (KeyMaterial<Public>, KeyMaterial<Secret>);
    /// Encapsulates to `pub_key`, returning the shared secret and the raw
    /// ciphertext.
    ///
    /// Fails with [`BadKEMKeyLength`]; presumably when `pub_key` does not have
    /// `PUBLIC_KEY_LENGTH` bytes (verify against the implementors).
    fn encapsulate<R: CryptoRng + ?Sized>(
        pub_key: &KeyMaterial<Public>,
        csprng: &mut R,
    ) -> std::result::Result<(SharedSecret, RawCiphertext), BadKEMKeyLength>;
    /// Recovers the shared secret from a raw `ciphertext` using `secret_key`.
    ///
    /// Fails with a [`DecapsulateError`] identifying whether the key or the
    /// ciphertext was rejected.
    fn decapsulate(
        secret_key: &KeyMaterial<Secret>,
        ciphertext: &[u8],
    ) -> std::result::Result<SharedSecret, DecapsulateError>;
}
107
/// `&self`-based counterpart of [`Parameters`], usable as `&dyn DynParameters`.
///
/// This is what [`KeyType::parameters`] hands out, so callers can pick a
/// parameter set at runtime. Errors are reported as [`SignalProtocolError`].
trait DynParameters {
    /// Public key length in bytes (without the key-type byte).
    fn public_key_length(&self) -> usize;
    /// Secret key length in bytes (without the key-type byte).
    fn secret_key_length(&self) -> usize;
    /// Raw ciphertext length in bytes (without the key-type byte).
    fn ciphertext_length(&self) -> usize;
    /// Shared secret length in bytes. Only called from tests.
    #[cfg_attr(not(test), expect(dead_code))]
    fn shared_secret_length(&self) -> usize;
    /// Generates a fresh key pair, drawing randomness from `rng`.
    fn generate(&self, rng: &mut dyn CryptoRng) -> (KeyMaterial<Public>, KeyMaterial<Secret>);
    /// Encapsulates to `pub_key`, returning the shared secret and raw ciphertext.
    fn encapsulate(
        &self,
        pub_key: &KeyMaterial<Public>,
        csprng: &mut dyn CryptoRng,
    ) -> Result<(SharedSecret, RawCiphertext)>;
    /// Recovers the shared secret from a raw `ciphertext` using `secret_key`.
    fn decapsulate(
        &self,
        secret_key: &KeyMaterial<Secret>,
        ciphertext: &[u8],
    ) -> Result<SharedSecret>;
}
127
128impl<T: Parameters> DynParameters for T {
129 fn public_key_length(&self) -> usize {
130 Self::PUBLIC_KEY_LENGTH
131 }
132
133 fn secret_key_length(&self) -> usize {
134 Self::SECRET_KEY_LENGTH
135 }
136
137 fn ciphertext_length(&self) -> usize {
138 Self::CIPHERTEXT_LENGTH
139 }
140
141 fn shared_secret_length(&self) -> usize {
142 Self::SHARED_SECRET_LENGTH
143 }
144
145 fn generate(&self, csprng: &mut dyn CryptoRng) -> (KeyMaterial<Public>, KeyMaterial<Secret>) {
146 Self::generate(csprng)
147 }
148
149 fn encapsulate(
150 &self,
151 pub_key: &KeyMaterial<Public>,
152 csprng: &mut dyn CryptoRng,
153 ) -> Result<(Box<[u8]>, Box<[u8]>)> {
154 Self::encapsulate(pub_key, csprng).map_err(|BadKEMKeyLength| {
155 SignalProtocolError::BadKEMKeyLength(T::KEY_TYPE, pub_key.len())
156 })
157 }
158
159 fn decapsulate(
160 &self,
161 secret_key: &KeyMaterial<Secret>,
162 ciphertext: &[u8],
163 ) -> Result<SharedSecret> {
164 Self::decapsulate(secret_key, ciphertext).map_err(|e| match e {
165 DecapsulateError::BadKeyLength => {
166 SignalProtocolError::BadKEMKeyLength(T::KEY_TYPE, secret_key.len())
167 }
168 DecapsulateError::BadCiphertext => {
169 SignalProtocolError::BadKEMCiphertextLength(T::KEY_TYPE, ciphertext.len())
170 }
171 })
172 }
173}
174
/// Exposes a type's fixed byte length as an associated constant.
///
/// Not used in this part of the file; presumably the parameter-set submodules
/// use it to derive their `*_LENGTH` constants (confirm there).
trait ConstantLength {
    /// The length in bytes.
    const LENGTH: usize;
}

// For the libcrux key and ciphertext wrappers the length is simply the
// const-generic size parameter `N`.
impl<const N: usize> ConstantLength for libcrux_ml_kem::MlKemPrivateKey<N> {
    const LENGTH: usize = N;
}
impl<const N: usize> ConstantLength for libcrux_ml_kem::MlKemPublicKey<N> {
    const LENGTH: usize = N;
}
impl<const N: usize> ConstantLength for libcrux_ml_kem::MlKemCiphertext<N> {
    const LENGTH: usize = N;
}
189
/// Error returned by [`Parameters::encapsulate`]; it carries no data because
/// the caller already knows the key type and length needed to build the
/// public-facing error.
struct BadKEMKeyLength;

/// Error returned by [`Parameters::decapsulate`], identifying which input was
/// rejected.
enum DecapsulateError {
    /// The secret key was rejected (surfaced as `BadKEMKeyLength`).
    BadKeyLength,
    /// The ciphertext was rejected (surfaced as `BadKEMCiphertextLength`).
    BadCiphertext,
}
198
199#[derive(Display, Debug, Copy, Clone, PartialEq, Eq)]
201pub enum KeyType {
202 #[cfg(any(feature = "kyber768", test))]
204 Kyber768,
205 Kyber1024,
207 #[cfg(feature = "mlkem1024")]
209 MLKEM1024,
210}
211
212impl KeyType {
213 fn value(&self) -> u8 {
214 match self {
215 #[cfg(any(feature = "kyber768", test))]
216 KeyType::Kyber768 => 0x07,
217 KeyType::Kyber1024 => 0x08,
218 #[cfg(feature = "mlkem1024")]
219 KeyType::MLKEM1024 => 0x0A,
220 }
221 }
222
223 const fn parameters(&self) -> &'static dyn DynParameters {
227 match self {
228 #[cfg(any(feature = "kyber768", test))]
229 KeyType::Kyber768 => &kyber768::Parameters,
230 KeyType::Kyber1024 => &kyber1024::Parameters,
231 #[cfg(feature = "mlkem1024")]
232 KeyType::MLKEM1024 => &mlkem1024::Parameters,
233 }
234 }
235}
236
237impl TryFrom<u8> for KeyType {
238 type Error = SignalProtocolError;
239
240 fn try_from(x: u8) -> Result<Self> {
241 match x {
242 #[cfg(any(feature = "kyber768", test))]
243 0x07 => Ok(KeyType::Kyber768),
244 0x08 => Ok(KeyType::Kyber1024),
245 #[cfg(feature = "mlkem1024")]
246 0x0A => Ok(KeyType::MLKEM1024),
247 t => Err(SignalProtocolError::BadKEMKeyType(t)),
248 }
249 }
250}
251
/// Type-level selector for which half of a key pair a [`Key`] holds.
pub trait KeyKind {
    /// Returns the length in bytes of this kind of key for `key_type`, not
    /// counting the serialized key-type byte.
    fn key_length(key_type: KeyType) -> usize;
}
255
256pub enum Public {}
257
258impl KeyKind for Public {
259 fn key_length(key_type: KeyType) -> usize {
260 key_type.parameters().public_key_length()
261 }
262}
263
264pub enum Secret {}
265
266impl KeyKind for Secret {
267 fn key_length(key_type: KeyType) -> usize {
268 key_type.parameters().secret_key_length()
269 }
270}
271
/// Raw key bytes tagged at the type level as [`Public`] or [`Secret`].
///
/// Dereferences to the underlying byte slice, so slice methods such as `len()`
/// can be called on it directly.
// `Clone` comes from `derive_where` rather than `#[derive(Clone)]`; presumably
// so that cloning does not require `T: Clone` (the marker types are not
// `Clone`). Confirm against the derive_where documentation.
#[derive(derive_more::Deref)]
#[derive_where(Clone)]
pub(crate) struct KeyMaterial<T: KeyKind> {
    /// The key bytes, without any key-type prefix.
    #[deref(forward)]
    data: Box<[u8]>,
    /// Zero-sized marker recording the key kind.
    kind: PhantomData<T>,
}
279
280impl<T: KeyKind> KeyMaterial<T> {
281 fn new(data: Box<[u8]>) -> Self {
282 KeyMaterial {
283 data,
284 kind: PhantomData,
285 }
286 }
287}
288
289impl<const SIZE: usize> From<libcrux_ml_kem::MlKemPublicKey<SIZE>> for KeyMaterial<Public> {
290 fn from(value: libcrux_ml_kem::MlKemPublicKey<SIZE>) -> Self {
291 KeyMaterial::new(value.as_ref().into())
292 }
293}
294
295impl<const SIZE: usize> From<libcrux_ml_kem::MlKemPrivateKey<SIZE>> for KeyMaterial<Secret> {
296 fn from(value: libcrux_ml_kem::MlKemPrivateKey<SIZE>) -> Self {
297 KeyMaterial::new(value.as_ref().into())
298 }
299}
300
/// A KEM key of kind `T` ([`Public`] or [`Secret`]) together with the
/// [`KeyType`] it belongs to.
#[derive_where(Clone)]
pub struct Key<T: KeyKind> {
    /// Which KEM parameter set this key is for.
    key_type: KeyType,
    /// The key bytes, without the key-type prefix.
    key_data: KeyMaterial<T>,
}
306
307impl<T: KeyKind> Key<T> {
308 pub fn deserialize(value: &[u8]) -> Result<Self> {
311 if value.is_empty() {
312 return Err(SignalProtocolError::NoKeyTypeIdentifier);
313 }
314 let key_type = KeyType::try_from(value[0])?;
315 if value.len() != T::key_length(key_type) + 1 {
316 return Err(SignalProtocolError::BadKEMKeyLength(key_type, value.len()));
317 }
318 Ok(Key {
319 key_type,
320 key_data: KeyMaterial::new(value[1..].into()),
321 })
322 }
323 pub fn serialize(&self) -> Box<[u8]> {
325 let mut result = Vec::with_capacity(1 + self.key_data.len());
326 result.push(self.key_type.value());
327 result.extend_from_slice(&self.key_data);
328 result.into_boxed_slice()
329 }
330
331 pub fn key_type(&self) -> KeyType {
333 self.key_type
334 }
335}
336
337impl Key<Public> {
338 pub fn encapsulate<R: CryptoRng>(
342 &self,
343 csprng: &mut R,
344 ) -> Result<(SharedSecret, SerializedCiphertext)> {
345 let (ss, ct) = self
346 .key_type
347 .parameters()
348 .encapsulate(&self.key_data, csprng)?;
349 Ok((
350 ss,
351 Ciphertext {
352 key_type: self.key_type,
353 data: &ct,
354 }
355 .serialize(),
356 ))
357 }
358}
359
360impl Key<Secret> {
361 pub fn decapsulate(&self, ct_bytes: &SerializedCiphertext) -> Result<Box<[u8]>> {
364 let ct = Ciphertext::deserialize(ct_bytes)?;
366 if ct.key_type != self.key_type {
367 return Err(SignalProtocolError::WrongKEMKeyType(
368 ct.key_type.value(),
369 self.key_type.value(),
370 ));
371 }
372 self.key_type
373 .parameters()
374 .decapsulate(&self.key_data, ct.data)
375 }
376}
377
378impl TryFrom<&[u8]> for Key<Public> {
379 type Error = SignalProtocolError;
380
381 fn try_from(value: &[u8]) -> Result<Self> {
382 Self::deserialize(value)
383 }
384}
385
386impl TryFrom<&[u8]> for Key<Secret> {
387 type Error = SignalProtocolError;
388
389 fn try_from(value: &[u8]) -> Result<Self> {
390 Self::deserialize(value)
391 }
392}
393
394impl subtle::ConstantTimeEq for Key<Public> {
395 fn ct_eq(&self, other: &Self) -> subtle::Choice {
400 if self.key_type != other.key_type {
401 return 0.ct_eq(&1);
402 }
403 self.key_data.ct_eq(&other.key_data)
404 }
405}
406
407impl PartialEq for Key<Public> {
408 fn eq(&self, other: &Self) -> bool {
409 bool::from(self.ct_eq(other))
410 }
411}
412
413impl Eq for Key<Public> {}
414
/// A KEM public key with its [`KeyType`].
pub type PublicKey = Key<Public>;

/// A KEM secret key with its [`KeyType`].
pub type SecretKey = Key<Secret>;
420
/// A public/secret KEM key pair.
///
/// The constructors in this file ensure both halves have the same
/// [`KeyType`], but the fields are public, so that is not enforced afterwards.
#[derive(Clone)]
pub struct KeyPair {
    /// The public half.
    pub public_key: PublicKey,
    /// The secret half.
    pub secret_key: SecretKey,
}
427
428impl KeyPair {
429 pub fn generate<R: Rng + CryptoRng>(key_type: KeyType, csprng: &mut R) -> Self {
431 let (pk, sk) = key_type.parameters().generate(csprng);
432 Self {
433 secret_key: SecretKey {
434 key_type,
435 key_data: sk,
436 },
437 public_key: PublicKey {
438 key_type,
439 key_data: pk,
440 },
441 }
442 }
443
444 pub fn new(public_key: PublicKey, secret_key: SecretKey) -> Self {
445 assert_eq!(public_key.key_type, secret_key.key_type);
446 Self {
447 public_key,
448 secret_key,
449 }
450 }
451
452 pub fn from_public_and_private(public_key: &[u8], secret_key: &[u8]) -> Result<Self> {
455 let public_key = PublicKey::try_from(public_key)?;
456 let secret_key = SecretKey::try_from(secret_key)?;
457 if public_key.key_type != secret_key.key_type {
458 Err(SignalProtocolError::WrongKEMKeyType(
459 secret_key.key_type.value(),
460 public_key.key_type.value(),
461 ))
462 } else {
463 Ok(Self {
464 public_key,
465 secret_key,
466 })
467 }
468 }
469}
470
/// Borrowed view of a ciphertext together with the key type it was made for.
struct Ciphertext<'a> {
    /// The KEM protocol that produced this ciphertext.
    key_type: KeyType,
    /// The raw ciphertext bytes, without the key-type prefix.
    data: &'a [u8],
}
476
477impl<'a> Ciphertext<'a> {
478 pub fn deserialize(value: &'a [u8]) -> Result<Self> {
481 if value.is_empty() {
482 return Err(SignalProtocolError::NoKeyTypeIdentifier);
483 }
484 let key_type = KeyType::try_from(value[0])?;
485 if value.len() != key_type.parameters().ciphertext_length() + 1 {
486 return Err(SignalProtocolError::BadKEMCiphertextLength(
487 key_type,
488 value.len(),
489 ));
490 }
491 Ok(Ciphertext {
492 key_type,
493 data: &value[1..],
494 })
495 }
496
497 pub fn serialize(&self) -> SerializedCiphertext {
499 let mut result = Vec::with_capacity(1 + self.data.len());
500 result.push(self.key_type.value());
501 result.extend_from_slice(self.data);
502 result.into_boxed_slice()
503 }
504}
505
506#[cfg(test)]
507mod tests {
508 use rand::{Rng as _, TryRngCore as _};
509
510 use super::*;
511
512 #[test]
513 fn test_serialize() {
514 let pk_bytes = include_bytes!("kem/test-data/pk.dat");
515 let sk_bytes = include_bytes!("kem/test-data/sk.dat");
516
517 let mut serialized_pk = Vec::with_capacity(1 + kyber1024::Parameters::PUBLIC_KEY_LENGTH);
518 serialized_pk.push(KeyType::Kyber1024.value());
519 serialized_pk.extend_from_slice(pk_bytes);
520
521 let mut serialized_sk = Vec::with_capacity(1 + kyber1024::Parameters::SECRET_KEY_LENGTH);
522 serialized_sk.push(KeyType::Kyber1024.value());
523 serialized_sk.extend_from_slice(sk_bytes);
524
525 let pk = PublicKey::deserialize(serialized_pk.as_slice()).expect("desrialize pk");
526 let sk = SecretKey::deserialize(serialized_sk.as_slice()).expect("desrialize sk");
527
528 let reserialized_pk = pk.serialize();
529 let reserialized_sk = sk.serialize();
530
531 assert_eq!(serialized_pk, reserialized_pk.into_vec());
532 assert_eq!(serialized_sk, reserialized_sk.into_vec());
533 }
534
535 #[test]
536 fn test_raw_kem() {
537 use libcrux_ml_kem::kyber1024::{decapsulate, encapsulate, generate_key_pair};
538 let mut rng = rand::rngs::OsRng.unwrap_err();
539 let (sk, pk) = generate_key_pair(rng.random()).into_parts();
540 let (ct, ss1) = encapsulate(&pk, rng.random());
541 let ss2 = decapsulate(&sk, &ct);
542 assert!(ss1 == ss2);
543 }
544
545 #[test]
546 fn test_kyber1024_kem() {
547 let pk_bytes = include_bytes!("kem/test-data/pk.dat");
549 let sk_bytes = include_bytes!("kem/test-data/sk.dat");
550 let mut rng = rand::rngs::OsRng.unwrap_err();
551
552 let mut serialized_pk = Vec::with_capacity(1 + kyber1024::Parameters::PUBLIC_KEY_LENGTH);
553 serialized_pk.push(KeyType::Kyber1024.value());
554 serialized_pk.extend_from_slice(pk_bytes);
555
556 let mut serialized_sk = Vec::with_capacity(1 + kyber1024::Parameters::SECRET_KEY_LENGTH);
557 serialized_sk.push(KeyType::Kyber1024.value());
558 serialized_sk.extend_from_slice(sk_bytes);
559
560 let pubkey = PublicKey::deserialize(serialized_pk.as_slice()).expect("deserialize pubkey");
561 let secretkey =
562 SecretKey::deserialize(serialized_sk.as_slice()).expect("deserialize secretkey");
563
564 assert_eq!(pubkey.key_type, KeyType::Kyber1024);
565 let (ss_for_sender, ct) = pubkey.encapsulate(&mut rng).expect("encapsulation works");
566 let ss_for_recipient = secretkey.decapsulate(&ct).expect("decapsulation works");
567
568 assert_eq!(ss_for_sender, ss_for_recipient);
569 }
570
571 #[cfg(feature = "mlkem1024")]
572 #[test]
573 fn test_mlkem1024_kem() {
574 let pk_bytes = include_bytes!("kem/test-data/mlkem-pk.dat");
576 let sk_bytes = include_bytes!("kem/test-data/mlkem-sk.dat");
577 let mut rng = rand::rngs::OsRng.unwrap_err();
578
579 let pubkey = PublicKey::deserialize(pk_bytes).expect("deserialize pubkey");
580 let secretkey = SecretKey::deserialize(sk_bytes).expect("deserialize secretkey");
581
582 assert_eq!(pubkey.key_type, KeyType::MLKEM1024);
583 let (ss_for_sender, ct) = pubkey.encapsulate(&mut rng).expect("encapsulation works");
584 let ss_for_recipient = secretkey.decapsulate(&ct).expect("decapsulation works");
585
586 assert_eq!(ss_for_sender, ss_for_recipient);
587 }
588
589 #[test]
590 fn test_kyber1024_keypair() {
591 let mut rng = rand::rngs::OsRng.unwrap_err();
592 let kp = KeyPair::generate(KeyType::Kyber1024, &mut rng);
593 assert_eq!(
594 kyber1024::Parameters::SECRET_KEY_LENGTH + 1,
595 kp.secret_key.serialize().len()
596 );
597 assert_eq!(
598 kyber1024::Parameters::PUBLIC_KEY_LENGTH + 1,
599 kp.public_key.serialize().len()
600 );
601 let (ss_for_sender, ct) = kp
602 .public_key
603 .encapsulate(&mut rng)
604 .expect("encapsulation works");
605 assert_eq!(kyber1024::Parameters::CIPHERTEXT_LENGTH + 1, ct.len());
606 assert_eq!(
607 kyber1024::Parameters::SHARED_SECRET_LENGTH,
608 ss_for_sender.len()
609 );
610 let ss_for_recipient = kp.secret_key.decapsulate(&ct).expect("decapsulation works");
611 assert_eq!(ss_for_recipient, ss_for_sender);
612 }
613
614 #[test]
615 fn test_kyber768_keypair() {
616 let mut rng = rand::rngs::OsRng.unwrap_err();
617 let kp = KeyPair::generate(KeyType::Kyber768, &mut rng);
618 assert_eq!(
619 kyber768::Parameters::SECRET_KEY_LENGTH + 1,
620 kp.secret_key.serialize().len()
621 );
622 assert_eq!(
623 kyber768::Parameters::PUBLIC_KEY_LENGTH + 1,
624 kp.public_key.serialize().len()
625 );
626 let (ss_for_sender, ct) = kp
627 .public_key
628 .encapsulate(&mut rng)
629 .expect("encapsulation works");
630 assert_eq!(kyber768::Parameters::CIPHERTEXT_LENGTH + 1, ct.len());
631 assert_eq!(
632 kyber768::Parameters::SHARED_SECRET_LENGTH,
633 ss_for_sender.len()
634 );
635 let ss_for_recipient = kp.secret_key.decapsulate(&ct).expect("decapsulation works");
636 assert_eq!(ss_for_recipient, ss_for_sender);
637 }
638
639 #[cfg(feature = "mlkem1024")]
640 #[test]
641 fn test_mlkem1024_keypair() {
642 let mut rng = rand::rngs::OsRng.unwrap_err();
643 let kp = KeyPair::generate(KeyType::MLKEM1024, &mut rng);
644 assert_eq!(
645 mlkem1024::Parameters::SECRET_KEY_LENGTH + 1,
646 kp.secret_key.serialize().len()
647 );
648 assert_eq!(
649 mlkem1024::Parameters::PUBLIC_KEY_LENGTH + 1,
650 kp.public_key.serialize().len()
651 );
652 let (ss_for_sender, ct) = kp
653 .public_key
654 .encapsulate(&mut rng)
655 .expect("encapsulation works");
656 assert_eq!(mlkem1024::Parameters::CIPHERTEXT_LENGTH + 1, ct.len());
657 assert_eq!(
658 mlkem1024::Parameters::SHARED_SECRET_LENGTH,
659 ss_for_sender.len()
660 );
661 let ss_for_recipient = kp.secret_key.decapsulate(&ct).expect("decapsulation works");
662 assert_eq!(ss_for_recipient, ss_for_sender);
663 }
664
665 #[test]
666 fn test_dyn_parameters_consts() {
667 assert_eq!(
668 kyber1024::Parameters::SECRET_KEY_LENGTH,
669 kyber1024::Parameters.secret_key_length()
670 );
671 assert_eq!(
672 kyber1024::Parameters::PUBLIC_KEY_LENGTH,
673 kyber1024::Parameters.public_key_length()
674 );
675 assert_eq!(
676 kyber1024::Parameters::CIPHERTEXT_LENGTH,
677 kyber1024::Parameters.ciphertext_length()
678 );
679 assert_eq!(
680 kyber1024::Parameters::SHARED_SECRET_LENGTH,
681 kyber1024::Parameters.shared_secret_length()
682 );
683 }
684}