1mod kyber1024;
56#[cfg(any(feature = "kyber768", test))]
57mod kyber768;
58#[cfg(feature = "mlkem1024")]
59mod mlkem1024;
60
61use std::marker::PhantomData;
62
63use derive_where::derive_where;
64use displaydoc::Display;
65use rand::{CryptoRng, Rng};
66use subtle::ConstantTimeEq;
67
68use crate::{Result, SignalProtocolError};
69
/// Bytes of the secret agreed on by both sides of an encapsulate/decapsulate exchange.
type SharedSecret = Box<[u8]>;

/// Ciphertext bytes exactly as a parameter set produced them, without a key-type byte.
pub(crate) type RawCiphertext = Box<[u8]>;
/// Ciphertext bytes preceded by the one-byte key-type tag.
pub type SerializedCiphertext = Box<[u8]>;
75
76trait Parameters {
89 const KEY_TYPE: KeyType;
90 const PUBLIC_KEY_LENGTH: usize;
91 const SECRET_KEY_LENGTH: usize;
92 const CIPHERTEXT_LENGTH: usize;
93 const SHARED_SECRET_LENGTH: usize;
94 fn generate<R: CryptoRng + ?Sized>(
95 csprng: &mut R,
96 ) -> (KeyMaterial<Public>, KeyMaterial<Secret>);
97 fn encapsulate<R: CryptoRng + ?Sized>(
98 pub_key: &KeyMaterial<Public>,
99 csprng: &mut R,
100 ) -> std::result::Result<(SharedSecret, RawCiphertext), BadKEMKeyLength>;
101 fn decapsulate(
102 secret_key: &KeyMaterial<Secret>,
103 ciphertext: &[u8],
104 ) -> std::result::Result<SharedSecret, DecapsulateError>;
105}
106
107trait DynParameters {
109 fn public_key_length(&self) -> usize;
110 fn secret_key_length(&self) -> usize;
111 fn ciphertext_length(&self) -> usize;
112 #[cfg_attr(not(test), expect(dead_code))]
113 fn shared_secret_length(&self) -> usize;
114 fn generate(&self, rng: &mut dyn CryptoRng) -> (KeyMaterial<Public>, KeyMaterial<Secret>);
115 fn encapsulate(
116 &self,
117 pub_key: &KeyMaterial<Public>,
118 csprng: &mut dyn CryptoRng,
119 ) -> Result<(SharedSecret, RawCiphertext)>;
120 fn decapsulate(
121 &self,
122 secret_key: &KeyMaterial<Secret>,
123 ciphertext: &[u8],
124 ) -> Result<SharedSecret>;
125}
126
127impl<T: Parameters> DynParameters for T {
128 fn public_key_length(&self) -> usize {
129 Self::PUBLIC_KEY_LENGTH
130 }
131
132 fn secret_key_length(&self) -> usize {
133 Self::SECRET_KEY_LENGTH
134 }
135
136 fn ciphertext_length(&self) -> usize {
137 Self::CIPHERTEXT_LENGTH
138 }
139
140 fn shared_secret_length(&self) -> usize {
141 Self::SHARED_SECRET_LENGTH
142 }
143
144 fn generate(&self, csprng: &mut dyn CryptoRng) -> (KeyMaterial<Public>, KeyMaterial<Secret>) {
145 Self::generate(csprng)
146 }
147
148 fn encapsulate(
149 &self,
150 pub_key: &KeyMaterial<Public>,
151 csprng: &mut dyn CryptoRng,
152 ) -> Result<(Box<[u8]>, Box<[u8]>)> {
153 Self::encapsulate(pub_key, csprng).map_err(|BadKEMKeyLength| {
154 SignalProtocolError::BadKEMKeyLength(T::KEY_TYPE, pub_key.len())
155 })
156 }
157
158 fn decapsulate(
159 &self,
160 secret_key: &KeyMaterial<Secret>,
161 ciphertext: &[u8],
162 ) -> Result<SharedSecret> {
163 Self::decapsulate(secret_key, ciphertext).map_err(|e| match e {
164 DecapsulateError::BadKeyLength => {
165 SignalProtocolError::BadKEMKeyLength(T::KEY_TYPE, secret_key.len())
166 }
167 DecapsulateError::BadCiphertext => {
168 SignalProtocolError::BadKEMCiphertextLength(T::KEY_TYPE, ciphertext.len())
169 }
170 })
171 }
172}
173
174trait ConstantLength {
176 const LENGTH: usize;
177}
178
179impl<const N: usize> ConstantLength for libcrux_ml_kem::MlKemPrivateKey<N> {
180 const LENGTH: usize = N;
181}
182impl<const N: usize> ConstantLength for libcrux_ml_kem::MlKemPublicKey<N> {
183 const LENGTH: usize = N;
184}
185impl<const N: usize> ConstantLength for libcrux_ml_kem::MlKemCiphertext<N> {
186 const LENGTH: usize = N;
187}
188
/// Marker error returned by [`Parameters::encapsulate`] when the public key is rejected.
struct BadKEMKeyLength;
191
/// Failure modes of [`Parameters::decapsulate`], naming which input was rejected.
enum DecapsulateError {
    /// The secret key was rejected; surfaced to callers as a bad key length.
    BadKeyLength,
    /// The ciphertext was rejected; surfaced to callers as a bad ciphertext length.
    BadCiphertext,
}
197
198#[derive(Display, Debug, Copy, Clone, PartialEq, Eq)]
200pub enum KeyType {
201 #[cfg(any(feature = "kyber768", test))]
203 Kyber768,
204 Kyber1024,
206 #[cfg(feature = "mlkem1024")]
208 MLKEM1024,
209}
210
211impl KeyType {
212 fn value(&self) -> u8 {
213 match self {
214 #[cfg(any(feature = "kyber768", test))]
215 KeyType::Kyber768 => 0x07,
216 KeyType::Kyber1024 => 0x08,
217 #[cfg(feature = "mlkem1024")]
218 KeyType::MLKEM1024 => 0x0A,
219 }
220 }
221
222 const fn parameters(&self) -> &'static dyn DynParameters {
226 match self {
227 #[cfg(any(feature = "kyber768", test))]
228 KeyType::Kyber768 => &kyber768::Parameters,
229 KeyType::Kyber1024 => &kyber1024::Parameters,
230 #[cfg(feature = "mlkem1024")]
231 KeyType::MLKEM1024 => &mlkem1024::Parameters,
232 }
233 }
234}
235
236impl TryFrom<u8> for KeyType {
237 type Error = SignalProtocolError;
238
239 fn try_from(x: u8) -> Result<Self> {
240 match x {
241 #[cfg(any(feature = "kyber768", test))]
242 0x07 => Ok(KeyType::Kyber768),
243 0x08 => Ok(KeyType::Kyber1024),
244 #[cfg(feature = "mlkem1024")]
245 0x0A => Ok(KeyType::MLKEM1024),
246 t => Err(SignalProtocolError::BadKEMKeyType(t)),
247 }
248 }
249}
250
251pub trait KeyKind {
252 fn key_length(key_type: KeyType) -> usize;
253}
254
255pub enum Public {}
256
257impl KeyKind for Public {
258 fn key_length(key_type: KeyType) -> usize {
259 key_type.parameters().public_key_length()
260 }
261}
262
263pub enum Secret {}
264
265impl KeyKind for Secret {
266 fn key_length(key_type: KeyType) -> usize {
267 key_type.parameters().secret_key_length()
268 }
269}
270
271#[derive(derive_more::Deref)]
272#[derive_where(Clone)]
273pub(crate) struct KeyMaterial<T: KeyKind> {
274 #[deref(forward)]
275 data: Box<[u8]>,
276 kind: PhantomData<T>,
277}
278
279impl<T: KeyKind> KeyMaterial<T> {
280 fn new(data: Box<[u8]>) -> Self {
281 KeyMaterial {
282 data,
283 kind: PhantomData,
284 }
285 }
286}
287
288impl<const SIZE: usize> From<libcrux_ml_kem::MlKemPublicKey<SIZE>> for KeyMaterial<Public> {
289 fn from(value: libcrux_ml_kem::MlKemPublicKey<SIZE>) -> Self {
290 KeyMaterial::new(value.as_ref().into())
291 }
292}
293
294impl<const SIZE: usize> From<libcrux_ml_kem::MlKemPrivateKey<SIZE>> for KeyMaterial<Secret> {
295 fn from(value: libcrux_ml_kem::MlKemPrivateKey<SIZE>) -> Self {
296 KeyMaterial::new(value.as_ref().into())
297 }
298}
299
300#[derive_where(Clone)]
301pub struct Key<T: KeyKind> {
302 key_type: KeyType,
303 key_data: KeyMaterial<T>,
304}
305
306impl<T: KeyKind> Key<T> {
307 pub fn deserialize(value: &[u8]) -> Result<Self> {
310 if value.is_empty() {
311 return Err(SignalProtocolError::NoKeyTypeIdentifier);
312 }
313 let key_type = KeyType::try_from(value[0])?;
314 if value.len() != T::key_length(key_type) + 1 {
315 return Err(SignalProtocolError::BadKEMKeyLength(key_type, value.len()));
316 }
317 Ok(Key {
318 key_type,
319 key_data: KeyMaterial::new(value[1..].into()),
320 })
321 }
322 pub fn serialize(&self) -> Box<[u8]> {
324 let mut result = Vec::with_capacity(1 + self.key_data.len());
325 result.push(self.key_type.value());
326 result.extend_from_slice(&self.key_data);
327 result.into_boxed_slice()
328 }
329
330 pub fn key_type(&self) -> KeyType {
332 self.key_type
333 }
334}
335
336impl Key<Public> {
337 pub fn encapsulate<R: CryptoRng>(
341 &self,
342 csprng: &mut R,
343 ) -> Result<(SharedSecret, SerializedCiphertext)> {
344 let (ss, ct) = self
345 .key_type
346 .parameters()
347 .encapsulate(&self.key_data, csprng)?;
348 Ok((
349 ss,
350 Ciphertext {
351 key_type: self.key_type,
352 data: &ct,
353 }
354 .serialize(),
355 ))
356 }
357}
358
359impl Key<Secret> {
360 pub fn decapsulate(&self, ct_bytes: &SerializedCiphertext) -> Result<Box<[u8]>> {
363 let ct = Ciphertext::deserialize(ct_bytes)?;
365 if ct.key_type != self.key_type {
366 return Err(SignalProtocolError::WrongKEMKeyType(
367 ct.key_type.value(),
368 self.key_type.value(),
369 ));
370 }
371 self.key_type
372 .parameters()
373 .decapsulate(&self.key_data, ct.data)
374 }
375}
376
377impl TryFrom<&[u8]> for Key<Public> {
378 type Error = SignalProtocolError;
379
380 fn try_from(value: &[u8]) -> Result<Self> {
381 Self::deserialize(value)
382 }
383}
384
385impl TryFrom<&[u8]> for Key<Secret> {
386 type Error = SignalProtocolError;
387
388 fn try_from(value: &[u8]) -> Result<Self> {
389 Self::deserialize(value)
390 }
391}
392
393impl subtle::ConstantTimeEq for Key<Public> {
394 fn ct_eq(&self, other: &Self) -> subtle::Choice {
399 if self.key_type != other.key_type {
400 return 0.ct_eq(&1);
401 }
402 self.key_data.ct_eq(&other.key_data)
403 }
404}
405
406impl PartialEq for Key<Public> {
407 fn eq(&self, other: &Self) -> bool {
408 bool::from(self.ct_eq(other))
409 }
410}
411
412impl Eq for Key<Public> {}
413
414pub type PublicKey = Key<Public>;
416
417pub type SecretKey = Key<Secret>;
419
420#[derive(Clone)]
422pub struct KeyPair {
423 pub public_key: PublicKey,
424 pub secret_key: SecretKey,
425}
426
427impl KeyPair {
428 pub fn generate<R: Rng + CryptoRng>(key_type: KeyType, csprng: &mut R) -> Self {
430 let (pk, sk) = key_type.parameters().generate(csprng);
431 Self {
432 secret_key: SecretKey {
433 key_type,
434 key_data: sk,
435 },
436 public_key: PublicKey {
437 key_type,
438 key_data: pk,
439 },
440 }
441 }
442
443 pub fn new(public_key: PublicKey, secret_key: SecretKey) -> Self {
444 assert_eq!(public_key.key_type, secret_key.key_type);
445 Self {
446 public_key,
447 secret_key,
448 }
449 }
450
451 pub fn from_public_and_private(public_key: &[u8], secret_key: &[u8]) -> Result<Self> {
454 let public_key = PublicKey::try_from(public_key)?;
455 let secret_key = SecretKey::try_from(secret_key)?;
456 if public_key.key_type != secret_key.key_type {
457 Err(SignalProtocolError::WrongKEMKeyType(
458 secret_key.key_type.value(),
459 public_key.key_type.value(),
460 ))
461 } else {
462 Ok(Self {
463 public_key,
464 secret_key,
465 })
466 }
467 }
468}
469
470struct Ciphertext<'a> {
472 key_type: KeyType,
473 data: &'a [u8],
474}
475
476impl<'a> Ciphertext<'a> {
477 pub fn deserialize(value: &'a [u8]) -> Result<Self> {
480 if value.is_empty() {
481 return Err(SignalProtocolError::NoKeyTypeIdentifier);
482 }
483 let key_type = KeyType::try_from(value[0])?;
484 if value.len() != key_type.parameters().ciphertext_length() + 1 {
485 return Err(SignalProtocolError::BadKEMCiphertextLength(
486 key_type,
487 value.len(),
488 ));
489 }
490 Ok(Ciphertext {
491 key_type,
492 data: &value[1..],
493 })
494 }
495
496 pub fn serialize(&self) -> SerializedCiphertext {
498 let mut result = Vec::with_capacity(1 + self.data.len());
499 result.push(self.key_type.value());
500 result.extend_from_slice(self.data);
501 result.into_boxed_slice()
502 }
503}
504
#[cfg(test)]
mod tests {
    use rand::{Rng as _, TryRngCore as _};

    use super::*;

    /// Prepends the one-byte tag of `key_type` to `raw`.
    fn tagged(key_type: KeyType, raw: &[u8]) -> Vec<u8> {
        let mut serialized = Vec::with_capacity(1 + raw.len());
        serialized.push(key_type.value());
        serialized.extend_from_slice(raw);
        serialized
    }

    /// Serialized Kyber1024 `(public, secret)` keys built from the checked-in test data.
    fn kyber1024_fixture() -> (Vec<u8>, Vec<u8>) {
        let pk_bytes = include_bytes!("kem/test-data/pk.dat");
        let sk_bytes = include_bytes!("kem/test-data/sk.dat");
        (
            tagged(KeyType::Kyber1024, pk_bytes),
            tagged(KeyType::Kyber1024, sk_bytes),
        )
    }

    /// Generates a key pair of `key_type`, checks every serialized length against
    /// the constants of `P`, and verifies an encapsulate/decapsulate round trip.
    fn check_generated_keypair<P: Parameters>(key_type: KeyType) {
        let mut rng = rand::rngs::OsRng.unwrap_err();
        let kp = KeyPair::generate(key_type, &mut rng);

        assert_eq!(P::SECRET_KEY_LENGTH + 1, kp.secret_key.serialize().len());
        assert_eq!(P::PUBLIC_KEY_LENGTH + 1, kp.public_key.serialize().len());

        let (ss_for_sender, ct) = kp
            .public_key
            .encapsulate(&mut rng)
            .expect("encapsulation works");
        assert_eq!(P::CIPHERTEXT_LENGTH + 1, ct.len());
        assert_eq!(P::SHARED_SECRET_LENGTH, ss_for_sender.len());

        let ss_for_recipient = kp.secret_key.decapsulate(&ct).expect("decapsulation works");
        assert_eq!(ss_for_recipient, ss_for_sender);
    }

    #[test]
    fn test_serialize() {
        let (serialized_pk, serialized_sk) = kyber1024_fixture();

        let pk = PublicKey::deserialize(&serialized_pk).expect("desrialize pk");
        let sk = SecretKey::deserialize(&serialized_sk).expect("desrialize sk");

        // Deserializing and serializing again must reproduce the input bytes.
        assert_eq!(serialized_pk, pk.serialize().into_vec());
        assert_eq!(serialized_sk, sk.serialize().into_vec());
    }

    #[test]
    fn test_raw_kem() {
        use libcrux_ml_kem::kyber1024::{decapsulate, encapsulate, generate_key_pair};

        let mut rng = rand::rngs::OsRng.unwrap_err();
        let (sk, pk) = generate_key_pair(rng.random()).into_parts();
        let (ct, ss1) = encapsulate(&pk, rng.random());
        let ss2 = decapsulate(&sk, &ct);
        assert!(ss1 == ss2);
    }

    #[test]
    fn test_kyber1024_kem() {
        let (serialized_pk, serialized_sk) = kyber1024_fixture();
        let mut rng = rand::rngs::OsRng.unwrap_err();

        let pubkey = PublicKey::deserialize(&serialized_pk).expect("deserialize pubkey");
        let secretkey = SecretKey::deserialize(&serialized_sk).expect("deserialize secretkey");

        assert_eq!(pubkey.key_type, KeyType::Kyber1024);
        let (ss_for_sender, ct) = pubkey.encapsulate(&mut rng).expect("encapsulation works");
        let ss_for_recipient = secretkey.decapsulate(&ct).expect("decapsulation works");

        assert_eq!(ss_for_sender, ss_for_recipient);
    }

    #[cfg(feature = "mlkem1024")]
    #[test]
    fn test_mlkem1024_kem() {
        // These fixtures are passed to `deserialize` as-is, with no tag prepended.
        let pk_bytes = include_bytes!("kem/test-data/mlkem-pk.dat");
        let sk_bytes = include_bytes!("kem/test-data/mlkem-sk.dat");
        let mut rng = rand::rngs::OsRng.unwrap_err();

        let pubkey = PublicKey::deserialize(pk_bytes).expect("deserialize pubkey");
        let secretkey = SecretKey::deserialize(sk_bytes).expect("deserialize secretkey");

        assert_eq!(pubkey.key_type, KeyType::MLKEM1024);
        let (ss_for_sender, ct) = pubkey.encapsulate(&mut rng).expect("encapsulation works");
        let ss_for_recipient = secretkey.decapsulate(&ct).expect("decapsulation works");

        assert_eq!(ss_for_sender, ss_for_recipient);
    }

    #[test]
    fn test_kyber1024_keypair() {
        check_generated_keypair::<kyber1024::Parameters>(KeyType::Kyber1024);
    }

    #[test]
    fn test_kyber768_keypair() {
        check_generated_keypair::<kyber768::Parameters>(KeyType::Kyber768);
    }

    #[cfg(feature = "mlkem1024")]
    #[test]
    fn test_mlkem1024_keypair() {
        check_generated_keypair::<mlkem1024::Parameters>(KeyType::MLKEM1024);
    }

    #[test]
    fn test_dyn_parameters_consts() {
        // Each `DynParameters` accessor must report the matching `Parameters` constant.
        let params = kyber1024::Parameters;
        let expectations = [
            (
                kyber1024::Parameters::SECRET_KEY_LENGTH,
                params.secret_key_length(),
            ),
            (
                kyber1024::Parameters::PUBLIC_KEY_LENGTH,
                params.public_key_length(),
            ),
            (
                kyber1024::Parameters::CIPHERTEXT_LENGTH,
                params.ciphertext_length(),
            ),
            (
                kyber1024::Parameters::SHARED_SECRET_LENGTH,
                params.shared_secret_length(),
            ),
        ];
        for (constant, dynamic) in expectations {
            assert_eq!(constant, dynamic);
        }
    }
}