1mod kyber1024;
56#[cfg(feature = "kyber768")]
57mod kyber768;
58#[cfg(feature = "mlkem1024")]
59mod mlkem1024;
60
61use std::fmt;
62use std::marker::PhantomData;
63
64use derive_where::derive_where;
65use displaydoc::Display;
66use rand::{CryptoRng, Rng};
67use subtle::ConstantTimeEq;
68
69use crate::{Result, SignalProtocolError};
70
/// Shared-secret bytes produced by encapsulation and recovered by decapsulation.
type SharedSecret = Box<[u8]>;

/// Ciphertext bytes exactly as produced by the underlying KEM, without a key-type prefix.
pub(crate) type RawCiphertext = Box<[u8]>;
/// A ciphertext prefixed with the one-byte identifier of its [`KeyType`].
pub type SerializedCiphertext = Box<[u8]>;
76
/// Static description of one KEM algorithm: its [`KeyType`], its fixed byte lengths, and its
/// key-generation / encapsulation / decapsulation operations.
///
/// Every implementor automatically implements the object-safe [`DynParameters`] through a
/// blanket impl, which is how the rest of this module dispatches on a runtime [`KeyType`].
trait Parameters {
    /// Key type attached to the errors reported for this algorithm.
    const KEY_TYPE: KeyType;
    /// Length in bytes of a raw public key (without the key-type prefix byte).
    const PUBLIC_KEY_LENGTH: usize;
    /// Length in bytes of a raw secret key (without the key-type prefix byte).
    const SECRET_KEY_LENGTH: usize;
    /// Length in bytes of a raw ciphertext (without the key-type prefix byte).
    const CIPHERTEXT_LENGTH: usize;
    /// Length in bytes of the shared secret produced by encapsulation.
    const SHARED_SECRET_LENGTH: usize;
    /// Generates a fresh key pair from `csprng`, returned as `(public, secret)`.
    fn generate<R: CryptoRng + ?Sized>(
        csprng: &mut R,
    ) -> (KeyMaterial<Public>, KeyMaterial<Secret>);
    /// Encapsulates a fresh shared secret to `pub_key`, returning
    /// `(shared_secret, raw_ciphertext)`.
    ///
    /// NOTE(review): the [`BadKEMKeyLength`] error presumably means `pub_key` does not have
    /// `PUBLIC_KEY_LENGTH` bytes; the implementations are not visible here — confirm.
    fn encapsulate<R: CryptoRng + ?Sized>(
        pub_key: &KeyMaterial<Public>,
        csprng: &mut R,
    ) -> std::result::Result<(SharedSecret, RawCiphertext), BadKEMKeyLength>;
    /// Recovers the shared secret from a raw (unprefixed) `ciphertext` using `secret_key`.
    ///
    /// [`DecapsulateError`] distinguishes a rejected secret key from a rejected ciphertext.
    fn decapsulate(
        secret_key: &KeyMaterial<Secret>,
        ciphertext: &[u8],
    ) -> std::result::Result<SharedSecret, DecapsulateError>;
}
107
/// Object-safe counterpart of [`Parameters`], so an algorithm can be chosen at runtime via
/// `KeyType::parameters`, which returns a `&'static dyn DynParameters`.
///
/// Errors are reported as [`SignalProtocolError`] values rather than algorithm-local markers.
trait DynParameters {
    /// Raw public-key length in bytes.
    fn public_key_length(&self) -> usize;
    /// Raw secret-key length in bytes.
    fn secret_key_length(&self) -> usize;
    /// Raw ciphertext length in bytes.
    fn ciphertext_length(&self) -> usize;
    /// Shared-secret length in bytes. Only read by tests, hence the lint expectation below.
    #[cfg_attr(not(test), expect(dead_code))]
    fn shared_secret_length(&self) -> usize;
    /// Generates a fresh `(public, secret)` key pair.
    fn generate(&self, rng: &mut dyn CryptoRng) -> (KeyMaterial<Public>, KeyMaterial<Secret>);
    /// Encapsulates to `pub_key`, returning `(shared_secret, raw_ciphertext)`.
    fn encapsulate(
        &self,
        pub_key: &KeyMaterial<Public>,
        csprng: &mut dyn CryptoRng,
    ) -> Result<(SharedSecret, RawCiphertext)>;
    /// Decapsulates a raw (unprefixed) `ciphertext` with `secret_key`.
    fn decapsulate(
        &self,
        secret_key: &KeyMaterial<Secret>,
        ciphertext: &[u8],
    ) -> Result<SharedSecret>;
}
127
128impl<T: Parameters> DynParameters for T {
129 fn public_key_length(&self) -> usize {
130 Self::PUBLIC_KEY_LENGTH
131 }
132
133 fn secret_key_length(&self) -> usize {
134 Self::SECRET_KEY_LENGTH
135 }
136
137 fn ciphertext_length(&self) -> usize {
138 Self::CIPHERTEXT_LENGTH
139 }
140
141 fn shared_secret_length(&self) -> usize {
142 Self::SHARED_SECRET_LENGTH
143 }
144
145 fn generate(&self, csprng: &mut dyn CryptoRng) -> (KeyMaterial<Public>, KeyMaterial<Secret>) {
146 Self::generate(csprng)
147 }
148
149 fn encapsulate(
150 &self,
151 pub_key: &KeyMaterial<Public>,
152 csprng: &mut dyn CryptoRng,
153 ) -> Result<(Box<[u8]>, Box<[u8]>)> {
154 Self::encapsulate(pub_key, csprng).map_err(|BadKEMKeyLength| {
155 SignalProtocolError::BadKEMKeyLength(T::KEY_TYPE, pub_key.len())
156 })
157 }
158
159 fn decapsulate(
160 &self,
161 secret_key: &KeyMaterial<Secret>,
162 ciphertext: &[u8],
163 ) -> Result<SharedSecret> {
164 Self::decapsulate(secret_key, ciphertext).map_err(|e| match e {
165 DecapsulateError::BadKeyLength => {
166 SignalProtocolError::BadKEMKeyLength(T::KEY_TYPE, secret_key.len())
167 }
168 DecapsulateError::BadCiphertext => {
169 SignalProtocolError::BadKEMCiphertextLength(T::KEY_TYPE, ciphertext.len())
170 }
171 })
172 }
173}
174
/// Exposes the const-generic size parameter `N` of a libcrux ML-KEM byte container as an
/// associated constant, so the size can be named from the type alone.
///
/// NOTE(review): presumably consumed by the per-algorithm modules (`kyber1024`, ...) to define
/// their `Parameters` length constants; those uses are not visible here — confirm.
trait ConstantLength {
    /// The container's size parameter `N`.
    const LENGTH: usize;
}

// Private (secret) key container: LENGTH is its size parameter.
impl<const N: usize> ConstantLength for libcrux_ml_kem::MlKemPrivateKey<N> {
    const LENGTH: usize = N;
}
// Public key container: LENGTH is its size parameter.
impl<const N: usize> ConstantLength for libcrux_ml_kem::MlKemPublicKey<N> {
    const LENGTH: usize = N;
}
// Ciphertext container: LENGTH is its size parameter.
impl<const N: usize> ConstantLength for libcrux_ml_kem::MlKemCiphertext<N> {
    const LENGTH: usize = N;
}
189
/// Error marker returned by [`Parameters::encapsulate`]. The blanket [`DynParameters`] impl
/// converts it into `SignalProtocolError::BadKEMKeyLength`, carrying the public key's length.
struct BadKEMKeyLength;
192
/// Failure reasons reported by [`Parameters::decapsulate`].
enum DecapsulateError {
    /// The secret key was rejected; surfaced as `SignalProtocolError::BadKEMKeyLength`.
    BadKeyLength,
    /// The ciphertext was rejected; surfaced as `SignalProtocolError::BadKEMCiphertextLength`.
    BadCiphertext,
}
198
// Identifies which KEM algorithm a key or ciphertext belongs to. Kyber768 and MLKEM1024 are
// only compiled in when their respective cargo features are enabled.
//
// NOTE(review): `displaydoc::Display` builds the `Display` text from `///` doc comments on the
// variants, so those comments are runtime-visible output. These notes are deliberately plain
// `//` comments so they do not affect that derive.
#[derive(Display, Debug, Copy, Clone, PartialEq, Eq)]
pub enum KeyType {
    // Wire identifier 0x07 (see `KeyType::value`).
    #[cfg(feature = "kyber768")]
    Kyber768,
    // Wire identifier 0x08 (see `KeyType::value`).
    Kyber1024,
    // Wire identifier 0x0A (see `KeyType::value`).
    #[cfg(feature = "mlkem1024")]
    MLKEM1024,
}
211
212impl KeyType {
213 fn value(&self) -> u8 {
214 match self {
215 #[cfg(feature = "kyber768")]
216 KeyType::Kyber768 => 0x07,
217 KeyType::Kyber1024 => 0x08,
218 #[cfg(feature = "mlkem1024")]
219 KeyType::MLKEM1024 => 0x0A,
220 }
221 }
222
223 const fn parameters(&self) -> &'static dyn DynParameters {
227 match self {
228 #[cfg(feature = "kyber768")]
229 KeyType::Kyber768 => &kyber768::Parameters,
230 KeyType::Kyber1024 => &kyber1024::Parameters,
231 #[cfg(feature = "mlkem1024")]
232 KeyType::MLKEM1024 => &mlkem1024::Parameters,
233 }
234 }
235}
236
237impl TryFrom<u8> for KeyType {
238 type Error = SignalProtocolError;
239
240 fn try_from(x: u8) -> Result<Self> {
241 match x {
242 #[cfg(feature = "kyber768")]
243 0x07 => Ok(KeyType::Kyber768),
244 0x08 => Ok(KeyType::Kyber1024),
245 #[cfg(feature = "mlkem1024")]
246 0x0A => Ok(KeyType::MLKEM1024),
247 t => Err(SignalProtocolError::BadKEMKeyType(t)),
248 }
249 }
250}
251
252pub trait KeyKind {
253 fn key_length(key_type: KeyType) -> usize;
254}
255
256pub enum Public {}
257
258impl KeyKind for Public {
259 fn key_length(key_type: KeyType) -> usize {
260 key_type.parameters().public_key_length()
261 }
262}
263
264pub enum Secret {}
265
266impl KeyKind for Secret {
267 fn key_length(key_type: KeyType) -> usize {
268 key_type.parameters().secret_key_length()
269 }
270}
271
272#[derive(derive_more::Deref)]
273#[derive_where(Clone)]
274pub(crate) struct KeyMaterial<T: KeyKind> {
275 #[deref(forward)]
276 data: Box<[u8]>,
277 kind: PhantomData<T>,
278}
279
280impl<T: KeyKind> KeyMaterial<T> {
281 fn new(data: Box<[u8]>) -> Self {
282 KeyMaterial {
283 data,
284 kind: PhantomData,
285 }
286 }
287}
288
289impl<const SIZE: usize> From<libcrux_ml_kem::MlKemPublicKey<SIZE>> for KeyMaterial<Public> {
290 fn from(value: libcrux_ml_kem::MlKemPublicKey<SIZE>) -> Self {
291 KeyMaterial::new(value.as_ref().into())
292 }
293}
294
295impl<const SIZE: usize> From<libcrux_ml_kem::MlKemPrivateKey<SIZE>> for KeyMaterial<Secret> {
296 fn from(value: libcrux_ml_kem::MlKemPrivateKey<SIZE>) -> Self {
297 KeyMaterial::new(value.as_ref().into())
298 }
299}
300
301#[derive_where(Clone)]
302pub struct Key<T: KeyKind> {
303 key_type: KeyType,
304 key_data: KeyMaterial<T>,
305}
306
307impl<T: KeyKind> fmt::Debug for Key<T> {
308 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
309 f.debug_struct("Key")
310 .field("key_type", &self.key_type)
311 .field("bytes_len", &self.key_data.len())
312 .finish()
313 }
314}
315
316impl<T: KeyKind> Key<T> {
317 pub fn deserialize(value: &[u8]) -> Result<Self> {
320 if value.is_empty() {
321 return Err(SignalProtocolError::NoKeyTypeIdentifier);
322 }
323 let key_type = KeyType::try_from(value[0])?;
324 if value.len() != T::key_length(key_type) + 1 {
325 return Err(SignalProtocolError::BadKEMKeyLength(key_type, value.len()));
326 }
327 Ok(Key {
328 key_type,
329 key_data: KeyMaterial::new(value[1..].into()),
330 })
331 }
332 pub fn serialize(&self) -> Box<[u8]> {
334 let mut result = Vec::with_capacity(1 + self.key_data.len());
335 result.push(self.key_type.value());
336 result.extend_from_slice(&self.key_data);
337 result.into_boxed_slice()
338 }
339
340 pub fn key_type(&self) -> KeyType {
342 self.key_type
343 }
344}
345
346impl Key<Public> {
347 pub fn encapsulate<R: CryptoRng>(
351 &self,
352 csprng: &mut R,
353 ) -> Result<(SharedSecret, SerializedCiphertext)> {
354 let (ss, ct) = self
355 .key_type
356 .parameters()
357 .encapsulate(&self.key_data, csprng)?;
358 Ok((
359 ss,
360 Ciphertext {
361 key_type: self.key_type,
362 data: &ct,
363 }
364 .serialize(),
365 ))
366 }
367}
368
369impl Key<Secret> {
370 pub fn decapsulate(&self, ct_bytes: &SerializedCiphertext) -> Result<Box<[u8]>> {
373 let ct = Ciphertext::deserialize(ct_bytes)?;
375 if ct.key_type != self.key_type {
376 return Err(SignalProtocolError::WrongKEMKeyType(
377 ct.key_type.value(),
378 self.key_type.value(),
379 ));
380 }
381 self.key_type
382 .parameters()
383 .decapsulate(&self.key_data, ct.data)
384 }
385}
386
387impl TryFrom<&[u8]> for Key<Public> {
388 type Error = SignalProtocolError;
389
390 fn try_from(value: &[u8]) -> Result<Self> {
391 Self::deserialize(value)
392 }
393}
394
395impl TryFrom<&[u8]> for Key<Secret> {
396 type Error = SignalProtocolError;
397
398 fn try_from(value: &[u8]) -> Result<Self> {
399 Self::deserialize(value)
400 }
401}
402
403impl subtle::ConstantTimeEq for Key<Public> {
404 fn ct_eq(&self, other: &Self) -> subtle::Choice {
409 if self.key_type != other.key_type {
410 return 0.ct_eq(&1);
411 }
412 self.key_data.ct_eq(&other.key_data)
413 }
414}
415
416impl PartialEq for Key<Public> {
417 fn eq(&self, other: &Self) -> bool {
418 bool::from(self.ct_eq(other))
419 }
420}
421
422impl Eq for Key<Public> {}
423
/// A KEM public key: used to encapsulate shared secrets.
pub type PublicKey = Key<Public>;

/// A KEM secret key: used to decapsulate ciphertexts produced with the matching public key.
pub type SecretKey = Key<Secret>;
429
430#[derive(Clone)]
432pub struct KeyPair {
433 pub public_key: PublicKey,
434 pub secret_key: SecretKey,
435}
436
437impl KeyPair {
438 pub fn generate<R: Rng + CryptoRng>(key_type: KeyType, csprng: &mut R) -> Self {
440 let (pk, sk) = key_type.parameters().generate(csprng);
441 Self {
442 secret_key: SecretKey {
443 key_type,
444 key_data: sk,
445 },
446 public_key: PublicKey {
447 key_type,
448 key_data: pk,
449 },
450 }
451 }
452
453 pub fn new(public_key: PublicKey, secret_key: SecretKey) -> Self {
454 assert_eq!(public_key.key_type, secret_key.key_type);
455 Self {
456 public_key,
457 secret_key,
458 }
459 }
460
461 pub fn from_public_and_private(public_key: &[u8], secret_key: &[u8]) -> Result<Self> {
464 let public_key = PublicKey::try_from(public_key)?;
465 let secret_key = SecretKey::try_from(secret_key)?;
466 if public_key.key_type != secret_key.key_type {
467 Err(SignalProtocolError::WrongKEMKeyType(
468 secret_key.key_type.value(),
469 public_key.key_type.value(),
470 ))
471 } else {
472 Ok(Self {
473 public_key,
474 secret_key,
475 })
476 }
477 }
478}
479
480struct Ciphertext<'a> {
482 key_type: KeyType,
483 data: &'a [u8],
484}
485
486impl<'a> Ciphertext<'a> {
487 pub fn deserialize(value: &'a [u8]) -> Result<Self> {
490 if value.is_empty() {
491 return Err(SignalProtocolError::NoKeyTypeIdentifier);
492 }
493 let key_type = KeyType::try_from(value[0])?;
494 if value.len() != key_type.parameters().ciphertext_length() + 1 {
495 return Err(SignalProtocolError::BadKEMCiphertextLength(
496 key_type,
497 value.len(),
498 ));
499 }
500 Ok(Ciphertext {
501 key_type,
502 data: &value[1..],
503 })
504 }
505
506 pub fn serialize(&self) -> SerializedCiphertext {
508 let mut result = Vec::with_capacity(1 + self.data.len());
509 result.push(self.key_type.value());
510 result.extend_from_slice(self.data);
511 result.into_boxed_slice()
512 }
513}
514
#[cfg(test)]
mod tests {
    use rand::{Rng as _, TryRngCore as _};

    use super::*;

    /// Builds a serialized key by prefixing `raw` with `key_type`'s identifier byte. This is
    /// the same layout that `Key::serialize` produces.
    fn with_type_prefix(key_type: KeyType, raw: &[u8]) -> Vec<u8> {
        let mut serialized = Vec::with_capacity(1 + raw.len());
        serialized.push(key_type.value());
        serialized.extend_from_slice(raw);
        serialized
    }

    /// Generates a fresh key pair of `key_type` and checks it end to end:
    /// - the serialized key, ciphertext, and shared-secret lengths match the given raw lengths;
    /// - both sides of an encapsulate/decapsulate round trip agree on the shared secret.
    fn check_generated_keypair(
        key_type: KeyType,
        secret_key_length: usize,
        public_key_length: usize,
        ciphertext_length: usize,
        shared_secret_length: usize,
    ) {
        let mut rng = rand::rngs::OsRng.unwrap_err();
        let kp = KeyPair::generate(key_type, &mut rng);
        // Serialized forms carry one extra byte for the key-type prefix.
        assert_eq!(secret_key_length + 1, kp.secret_key.serialize().len());
        assert_eq!(public_key_length + 1, kp.public_key.serialize().len());
        let (ss_for_sender, ct) = kp
            .public_key
            .encapsulate(&mut rng)
            .expect("encapsulation works");
        assert_eq!(ciphertext_length + 1, ct.len());
        assert_eq!(shared_secret_length, ss_for_sender.len());
        let ss_for_recipient = kp.secret_key.decapsulate(&ct).expect("decapsulation works");
        assert_eq!(ss_for_recipient, ss_for_sender);
    }

    #[test]
    fn test_serialize() {
        let pk_bytes = include_bytes!("kem/test-data/pk.dat");
        let sk_bytes = include_bytes!("kem/test-data/sk.dat");

        let serialized_pk = with_type_prefix(KeyType::Kyber1024, pk_bytes);
        let serialized_sk = with_type_prefix(KeyType::Kyber1024, sk_bytes);

        let pk = PublicKey::deserialize(serialized_pk.as_slice()).expect("deserialize pk");
        let sk = SecretKey::deserialize(serialized_sk.as_slice()).expect("deserialize sk");

        let reserialized_pk = pk.serialize();
        let reserialized_sk = sk.serialize();

        assert_eq!(serialized_pk, reserialized_pk.into_vec());
        assert_eq!(serialized_sk, reserialized_sk.into_vec());
    }

    #[test]
    fn test_raw_kem() {
        use libcrux_ml_kem::kyber1024::{decapsulate, encapsulate, generate_key_pair};
        let mut rng = rand::rngs::OsRng.unwrap_err();
        let (sk, pk) = generate_key_pair(rng.random()).into_parts();
        let (ct, ss1) = encapsulate(&pk, rng.random());
        let ss2 = decapsulate(&sk, &ct);
        assert_eq!(ss1, ss2);
    }

    #[test]
    fn test_kyber1024_kem() {
        let pk_bytes = include_bytes!("kem/test-data/pk.dat");
        let sk_bytes = include_bytes!("kem/test-data/sk.dat");
        let mut rng = rand::rngs::OsRng.unwrap_err();

        let serialized_pk = with_type_prefix(KeyType::Kyber1024, pk_bytes);
        let serialized_sk = with_type_prefix(KeyType::Kyber1024, sk_bytes);

        let pubkey = PublicKey::deserialize(serialized_pk.as_slice()).expect("deserialize pubkey");
        let secretkey =
            SecretKey::deserialize(serialized_sk.as_slice()).expect("deserialize secretkey");

        assert_eq!(pubkey.key_type, KeyType::Kyber1024);
        let (ss_for_sender, ct) = pubkey.encapsulate(&mut rng).expect("encapsulation works");
        let ss_for_recipient = secretkey.decapsulate(&ct).expect("decapsulation works");

        assert_eq!(ss_for_sender, ss_for_recipient);
    }

    #[cfg(feature = "mlkem1024")]
    #[test]
    fn test_mlkem1024_kem() {
        // These fixtures already include the key-type prefix byte.
        let pk_bytes = include_bytes!("kem/test-data/mlkem-pk.dat");
        let sk_bytes = include_bytes!("kem/test-data/mlkem-sk.dat");
        let mut rng = rand::rngs::OsRng.unwrap_err();

        let pubkey = PublicKey::deserialize(pk_bytes).expect("deserialize pubkey");
        let secretkey = SecretKey::deserialize(sk_bytes).expect("deserialize secretkey");

        assert_eq!(pubkey.key_type, KeyType::MLKEM1024);
        let (ss_for_sender, ct) = pubkey.encapsulate(&mut rng).expect("encapsulation works");
        let ss_for_recipient = secretkey.decapsulate(&ct).expect("decapsulation works");

        assert_eq!(ss_for_sender, ss_for_recipient);
    }

    #[test]
    fn test_kyber1024_keypair() {
        check_generated_keypair(
            KeyType::Kyber1024,
            kyber1024::Parameters::SECRET_KEY_LENGTH,
            kyber1024::Parameters::PUBLIC_KEY_LENGTH,
            kyber1024::Parameters::CIPHERTEXT_LENGTH,
            kyber1024::Parameters::SHARED_SECRET_LENGTH,
        );
    }

    #[cfg(feature = "kyber768")]
    #[test]
    fn test_kyber768_keypair() {
        check_generated_keypair(
            KeyType::Kyber768,
            kyber768::Parameters::SECRET_KEY_LENGTH,
            kyber768::Parameters::PUBLIC_KEY_LENGTH,
            kyber768::Parameters::CIPHERTEXT_LENGTH,
            kyber768::Parameters::SHARED_SECRET_LENGTH,
        );
    }

    #[cfg(feature = "mlkem1024")]
    #[test]
    fn test_mlkem1024_keypair() {
        check_generated_keypair(
            KeyType::MLKEM1024,
            mlkem1024::Parameters::SECRET_KEY_LENGTH,
            mlkem1024::Parameters::PUBLIC_KEY_LENGTH,
            mlkem1024::Parameters::CIPHERTEXT_LENGTH,
            mlkem1024::Parameters::SHARED_SECRET_LENGTH,
        );
    }

    #[test]
    fn test_dyn_parameters_consts() {
        assert_eq!(
            kyber1024::Parameters::SECRET_KEY_LENGTH,
            kyber1024::Parameters.secret_key_length()
        );
        assert_eq!(
            kyber1024::Parameters::PUBLIC_KEY_LENGTH,
            kyber1024::Parameters.public_key_length()
        );
        assert_eq!(
            kyber1024::Parameters::CIPHERTEXT_LENGTH,
            kyber1024::Parameters.ciphertext_length()
        );
        assert_eq!(
            kyber1024::Parameters::SHARED_SECRET_LENGTH,
            kyber1024::Parameters.shared_secret_length()
        );
    }
}