1mod kyber1024;
54#[cfg(any(feature = "kyber768", test))]
55mod kyber768;
56#[cfg(feature = "mlkem1024")]
57mod mlkem1024;
58
59use std::marker::PhantomData;
60use std::ops::Deref;
61
62use derive_where::derive_where;
63use displaydoc::Display;
64use subtle::ConstantTimeEq;
65
66use crate::{Result, SignalProtocolError};
67
/// Shared secret bytes produced by encapsulation and recovered by decapsulation.
type SharedSecret = Box<[u8]>;

/// Ciphertext bytes exactly as produced by a KEM implementation, without the
/// leading key-type byte.
pub(crate) type RawCiphertext = Box<[u8]>;
/// Ciphertext in its serialized form: a one-byte [`KeyType`] identifier
/// followed by the raw ciphertext bytes.
pub type SerializedCiphertext = Box<[u8]>;
73
/// Static description of one concrete KEM algorithm.
///
/// Each algorithm module provides a type implementing this trait. The blanket
/// `impl<T: Parameters> DynParameters for T` in this module turns it into a
/// `DynParameters` trait object, which `KeyType::parameters` returns.
trait Parameters {
    /// Length in bytes of a raw public key (not counting the key-type prefix byte).
    const PUBLIC_KEY_LENGTH: usize;
    /// Length in bytes of a raw secret key (not counting the key-type prefix byte).
    const SECRET_KEY_LENGTH: usize;
    /// Length in bytes of a raw ciphertext (not counting the key-type prefix byte).
    const CIPHERTEXT_LENGTH: usize;
    /// Length in bytes of the shared secret.
    const SHARED_SECRET_LENGTH: usize;
    /// Generates a new key pair, returned as `(public, secret)`.
    fn generate() -> (KeyMaterial<Public>, KeyMaterial<Secret>);
    /// Encapsulates to `pub_key`, returning `(shared secret, raw ciphertext)`.
    fn encapsulate(pub_key: &KeyMaterial<Public>) -> (SharedSecret, RawCiphertext);
    /// Recovers the shared secret from a raw `ciphertext` using `secret_key`.
    ///
    /// NOTE(review): the exact failure conditions are defined by each
    /// implementation module; presumably malformed ciphertexts are rejected.
    fn decapsulate(secret_key: &KeyMaterial<Secret>, ciphertext: &[u8]) -> Result<SharedSecret>;
}
95
/// Object-safe mirror of `Parameters`.
///
/// It is implemented for every `Parameters` type by the blanket impl in this
/// module, so `KeyType::parameters` can hand out a `&'static dyn DynParameters`
/// chosen at runtime.
trait DynParameters {
    /// Runtime accessor for `Parameters::PUBLIC_KEY_LENGTH`.
    fn public_key_length(&self) -> usize;
    /// Runtime accessor for `Parameters::SECRET_KEY_LENGTH`.
    fn secret_key_length(&self) -> usize;
    /// Runtime accessor for `Parameters::CIPHERTEXT_LENGTH`.
    fn ciphertext_length(&self) -> usize;
    /// Runtime accessor for `Parameters::SHARED_SECRET_LENGTH`.
    // The lint is silenced because the only caller in this file is the
    // `test_dyn_parameters_consts` unit test; non-test builds would otherwise
    // warn that this method is never used.
    #[allow(dead_code)]
    fn shared_secret_length(&self) -> usize;
    /// Runtime dispatch to `Parameters::generate`; returns `(public, secret)`.
    fn generate(&self) -> (KeyMaterial<Public>, KeyMaterial<Secret>);
    /// Runtime dispatch to `Parameters::encapsulate`; returns `(shared secret, raw ciphertext)`.
    fn encapsulate(&self, pub_key: &KeyMaterial<Public>) -> (SharedSecret, RawCiphertext);
    /// Runtime dispatch to `Parameters::decapsulate`.
    fn decapsulate(
        &self,
        secret_key: &KeyMaterial<Secret>,
        ciphertext: &[u8],
    ) -> Result<SharedSecret>;
}
111
112impl<T: Parameters> DynParameters for T {
113 fn public_key_length(&self) -> usize {
114 Self::PUBLIC_KEY_LENGTH
115 }
116
117 fn secret_key_length(&self) -> usize {
118 Self::SECRET_KEY_LENGTH
119 }
120
121 fn ciphertext_length(&self) -> usize {
122 Self::CIPHERTEXT_LENGTH
123 }
124
125 fn shared_secret_length(&self) -> usize {
126 Self::SHARED_SECRET_LENGTH
127 }
128
129 fn generate(&self) -> (KeyMaterial<Public>, KeyMaterial<Secret>) {
130 Self::generate()
131 }
132
133 fn encapsulate(&self, pub_key: &KeyMaterial<Public>) -> (SharedSecret, RawCiphertext) {
134 Self::encapsulate(pub_key)
135 }
136
137 fn decapsulate(
138 &self,
139 secret_key: &KeyMaterial<Secret>,
140 ciphertext: &[u8],
141 ) -> Result<SharedSecret> {
142 Self::decapsulate(secret_key, ciphertext)
143 }
144}
145
// Identifies which KEM algorithm a key or ciphertext belongs to. Each variant
// maps to a one-byte identifier via `KeyType::value` (and back via
// `TryFrom<u8>`); that byte prefixes every serialized key and ciphertext.
//
// NOTE: plain `//` comments are used on purpose. `displaydoc::Display` derives
// the `Display` text from `///` doc comments, so adding or editing doc comments
// on this enum would change its runtime `Display` output.
#[derive(Display, Debug, Copy, Clone, PartialEq, Eq)]
pub enum KeyType {
    // Present only with the "kyber768" feature or in test builds.
    #[cfg(any(feature = "kyber768", test))]
    Kyber768,
    // Always available.
    Kyber1024,
    // Present only with the "mlkem1024" feature.
    #[cfg(feature = "mlkem1024")]
    MLKEM1024,
}
158
159impl KeyType {
160 fn value(&self) -> u8 {
161 match self {
162 #[cfg(any(feature = "kyber768", test))]
163 KeyType::Kyber768 => 0x07,
164 KeyType::Kyber1024 => 0x08,
165 #[cfg(feature = "mlkem1024")]
166 KeyType::MLKEM1024 => 0x0A,
167 }
168 }
169
170 const fn parameters(&self) -> &'static dyn DynParameters {
174 match self {
175 #[cfg(any(feature = "kyber768", test))]
176 KeyType::Kyber768 => &kyber768::Parameters,
177 KeyType::Kyber1024 => &kyber1024::Parameters,
178 #[cfg(feature = "mlkem1024")]
179 KeyType::MLKEM1024 => &mlkem1024::Parameters,
180 }
181 }
182}
183
184impl TryFrom<u8> for KeyType {
185 type Error = SignalProtocolError;
186
187 fn try_from(x: u8) -> Result<Self> {
188 match x {
189 #[cfg(any(feature = "kyber768", test))]
190 0x07 => Ok(KeyType::Kyber768),
191 0x08 => Ok(KeyType::Kyber1024),
192 #[cfg(feature = "mlkem1024")]
193 0x0A => Ok(KeyType::MLKEM1024),
194 t => Err(SignalProtocolError::BadKEMKeyType(t)),
195 }
196 }
197}
198
199pub trait KeyKind {
200 fn key_length(key_type: KeyType) -> usize;
201}
202
203pub enum Public {}
204
205impl KeyKind for Public {
206 fn key_length(key_type: KeyType) -> usize {
207 key_type.parameters().public_key_length()
208 }
209}
210
211pub enum Secret {}
212
213impl KeyKind for Secret {
214 fn key_length(key_type: KeyType) -> usize {
215 key_type.parameters().secret_key_length()
216 }
217}
218
219#[derive_where(Clone)]
220pub(crate) struct KeyMaterial<T: KeyKind> {
221 data: Box<[u8]>,
222 kind: PhantomData<T>,
223}
224
225impl<T: KeyKind> KeyMaterial<T> {
226 fn new(data: Box<[u8]>) -> Self {
227 KeyMaterial {
228 data,
229 kind: PhantomData,
230 }
231 }
232}
233
234impl<T: KeyKind> Deref for KeyMaterial<T> {
235 type Target = [u8];
236
237 fn deref(&self) -> &Self::Target {
238 self.data.deref()
239 }
240}
241
242#[derive_where(Clone)]
243pub struct Key<T: KeyKind> {
244 key_type: KeyType,
245 key_data: KeyMaterial<T>,
246}
247
248impl<T: KeyKind> Key<T> {
249 pub fn deserialize(value: &[u8]) -> Result<Self> {
252 if value.is_empty() {
253 return Err(SignalProtocolError::NoKeyTypeIdentifier);
254 }
255 let key_type = KeyType::try_from(value[0])?;
256 if value.len() != T::key_length(key_type) + 1 {
257 return Err(SignalProtocolError::BadKEMKeyLength(key_type, value.len()));
258 }
259 Ok(Key {
260 key_type,
261 key_data: KeyMaterial::new(value[1..].into()),
262 })
263 }
264 pub fn serialize(&self) -> Box<[u8]> {
266 let mut result = Vec::with_capacity(1 + self.key_data.len());
267 result.push(self.key_type.value());
268 result.extend_from_slice(&self.key_data);
269 result.into_boxed_slice()
270 }
271
272 pub fn key_type(&self) -> KeyType {
274 self.key_type
275 }
276}
277
278impl Key<Public> {
279 pub fn encapsulate(&self) -> (SharedSecret, SerializedCiphertext) {
283 let (ss, ct) = self.key_type.parameters().encapsulate(&self.key_data);
284 (
285 ss,
286 Ciphertext {
287 key_type: self.key_type,
288 data: &ct,
289 }
290 .serialize(),
291 )
292 }
293}
294
295impl Key<Secret> {
296 pub fn decapsulate(&self, ct_bytes: &SerializedCiphertext) -> Result<Box<[u8]>> {
299 let ct = Ciphertext::deserialize(ct_bytes)?;
301 if ct.key_type != self.key_type {
302 return Err(SignalProtocolError::WrongKEMKeyType(
303 ct.key_type.value(),
304 self.key_type.value(),
305 ));
306 }
307 self.key_type
308 .parameters()
309 .decapsulate(&self.key_data, ct.data)
310 }
311}
312
313impl TryFrom<&[u8]> for Key<Public> {
314 type Error = SignalProtocolError;
315
316 fn try_from(value: &[u8]) -> Result<Self> {
317 Self::deserialize(value)
318 }
319}
320
321impl TryFrom<&[u8]> for Key<Secret> {
322 type Error = SignalProtocolError;
323
324 fn try_from(value: &[u8]) -> Result<Self> {
325 Self::deserialize(value)
326 }
327}
328
329impl subtle::ConstantTimeEq for Key<Public> {
330 fn ct_eq(&self, other: &Self) -> subtle::Choice {
335 if self.key_type != other.key_type {
336 return 0.ct_eq(&1);
337 }
338 self.key_data.ct_eq(&other.key_data)
339 }
340}
341
342impl PartialEq for Key<Public> {
343 fn eq(&self, other: &Self) -> bool {
344 bool::from(self.ct_eq(other))
345 }
346}
347
348impl Eq for Key<Public> {}
349
/// A KEM public key.
pub type PublicKey = Key<Public>;

/// A KEM secret key.
pub type SecretKey = Key<Secret>;
355
356#[derive(Clone)]
358pub struct KeyPair {
359 pub public_key: PublicKey,
360 pub secret_key: SecretKey,
361}
362
363impl KeyPair {
364 pub fn generate(key_type: KeyType) -> Self {
367 let (pk, sk) = key_type.parameters().generate();
368 Self {
369 secret_key: SecretKey {
370 key_type,
371 key_data: sk,
372 },
373 public_key: PublicKey {
374 key_type,
375 key_data: pk,
376 },
377 }
378 }
379
380 pub fn new(public_key: PublicKey, secret_key: SecretKey) -> Self {
381 assert_eq!(public_key.key_type, secret_key.key_type);
382 Self {
383 public_key,
384 secret_key,
385 }
386 }
387
388 pub fn from_public_and_private(public_key: &[u8], secret_key: &[u8]) -> Result<Self> {
391 let public_key = PublicKey::try_from(public_key)?;
392 let secret_key = SecretKey::try_from(secret_key)?;
393 if public_key.key_type != secret_key.key_type {
394 Err(SignalProtocolError::WrongKEMKeyType(
395 secret_key.key_type.value(),
396 public_key.key_type.value(),
397 ))
398 } else {
399 Ok(Self {
400 public_key,
401 secret_key,
402 })
403 }
404 }
405}
406
407struct Ciphertext<'a> {
409 key_type: KeyType,
410 data: &'a [u8],
411}
412
413impl<'a> Ciphertext<'a> {
414 pub fn deserialize(value: &'a [u8]) -> Result<Self> {
417 if value.is_empty() {
418 return Err(SignalProtocolError::NoKeyTypeIdentifier);
419 }
420 let key_type = KeyType::try_from(value[0])?;
421 if value.len() != key_type.parameters().ciphertext_length() + 1 {
422 return Err(SignalProtocolError::BadKEMCiphertextLength(
423 key_type,
424 value.len(),
425 ));
426 }
427 Ok(Ciphertext {
428 key_type,
429 data: &value[1..],
430 })
431 }
432
433 pub fn serialize(&self) -> SerializedCiphertext {
435 let mut result = Vec::with_capacity(1 + self.data.len());
436 result.push(self.key_type.value());
437 result.extend_from_slice(self.data);
438 result.into_boxed_slice()
439 }
440}
441
#[cfg(test)]
mod tests {
    use super::*;

    /// Prefixes a raw Kyber1024 key with the Kyber1024 key-type byte.
    /// `key_length` is the raw key length, used only to size the buffer.
    fn with_kyber1024_prefix(raw: &[u8], key_length: usize) -> Vec<u8> {
        let mut serialized = Vec::with_capacity(1 + key_length);
        serialized.push(KeyType::Kyber1024.value());
        serialized.extend_from_slice(raw);
        serialized
    }

    /// Returns the Kyber1024 test-vector `(public, secret)` keys in serialized
    /// form. The raw files on disk lack the key-type prefix, so it is added here.
    fn serialized_kyber1024_test_keys() -> (Vec<u8>, Vec<u8>) {
        let pk_bytes = include_bytes!("kem/test-data/pk.dat");
        let sk_bytes = include_bytes!("kem/test-data/sk.dat");
        (
            with_kyber1024_prefix(pk_bytes, kyber1024::Parameters::PUBLIC_KEY_LENGTH),
            with_kyber1024_prefix(sk_bytes, kyber1024::Parameters::SECRET_KEY_LENGTH),
        )
    }

    /// Encapsulates to `pubkey`, decapsulates with `secretkey`, and checks that
    /// both sides obtain the same shared secret.
    fn assert_kem_roundtrip(pubkey: &PublicKey, secretkey: &SecretKey) {
        let (ss_for_sender, ct) = pubkey.encapsulate();
        let ss_for_recipient = secretkey.decapsulate(&ct).expect("decapsulation works");
        assert_eq!(ss_for_sender, ss_for_recipient);
    }

    /// Generates a key pair of `key_type` and checks the following against
    /// `P`'s constants:
    /// - the serialized key lengths (each is raw length + 1 type byte);
    /// - the ciphertext length;
    /// - the shared-secret length.
    /// It then checks that the shared secret round-trips.
    fn check_generated_keypair<P: Parameters>(key_type: KeyType) {
        let kp = KeyPair::generate(key_type);
        assert_eq!(P::SECRET_KEY_LENGTH + 1, kp.secret_key.serialize().len());
        assert_eq!(P::PUBLIC_KEY_LENGTH + 1, kp.public_key.serialize().len());
        let (ss_for_sender, ct) = kp.public_key.encapsulate();
        assert_eq!(P::CIPHERTEXT_LENGTH + 1, ct.len());
        assert_eq!(P::SHARED_SECRET_LENGTH, ss_for_sender.len());
        let ss_for_recipient = kp.secret_key.decapsulate(&ct).expect("decapsulation works");
        assert_eq!(ss_for_recipient, ss_for_sender);
    }

    #[test]
    fn test_serialize() {
        let (serialized_pk, serialized_sk) = serialized_kyber1024_test_keys();

        let pk = PublicKey::deserialize(&serialized_pk).expect("deserialize pk");
        let sk = SecretKey::deserialize(&serialized_sk).expect("deserialize sk");

        let reserialized_pk = pk.serialize();
        let reserialized_sk = sk.serialize();

        assert_eq!(serialized_pk, reserialized_pk.into_vec());
        assert_eq!(serialized_sk, reserialized_sk.into_vec());
    }

    #[test]
    fn test_raw_kem() {
        use pqcrypto_kyber::kyber1024::{decapsulate, encapsulate, keypair};
        let (pk, sk) = keypair();
        let (ss1, ct) = encapsulate(&pk);
        let ss2 = decapsulate(&ct, &sk);
        assert!(ss1 == ss2);
    }

    #[test]
    fn test_kyber1024_kem() {
        let (serialized_pk, serialized_sk) = serialized_kyber1024_test_keys();

        let pubkey = PublicKey::deserialize(&serialized_pk).expect("deserialize pubkey");
        let secretkey = SecretKey::deserialize(&serialized_sk).expect("deserialize secretkey");

        assert_eq!(pubkey.key_type, KeyType::Kyber1024);
        assert_kem_roundtrip(&pubkey, &secretkey);
    }

    #[cfg(feature = "mlkem1024")]
    #[test]
    fn test_mlkem1024_kem() {
        // These test vectors already include the key-type prefix byte, so they
        // are deserialized directly.
        let pk_bytes = include_bytes!("kem/test-data/mlkem-pk.dat");
        let sk_bytes = include_bytes!("kem/test-data/mlkem-sk.dat");

        let pubkey = PublicKey::deserialize(pk_bytes).expect("deserialize pubkey");
        let secretkey = SecretKey::deserialize(sk_bytes).expect("deserialize secretkey");

        assert_eq!(pubkey.key_type, KeyType::MLKEM1024);
        assert_kem_roundtrip(&pubkey, &secretkey);
    }

    #[test]
    fn test_kyber1024_keypair() {
        check_generated_keypair::<kyber1024::Parameters>(KeyType::Kyber1024);
    }

    #[test]
    fn test_kyber768_keypair() {
        check_generated_keypair::<kyber768::Parameters>(KeyType::Kyber768);
    }

    #[cfg(feature = "mlkem1024")]
    #[test]
    fn test_mlkem1024_keypair() {
        check_generated_keypair::<mlkem1024::Parameters>(KeyType::MLKEM1024);
    }

    #[test]
    fn test_dyn_parameters_consts() {
        assert_eq!(
            kyber1024::Parameters::SECRET_KEY_LENGTH,
            kyber1024::Parameters.secret_key_length()
        );
        assert_eq!(
            kyber1024::Parameters::PUBLIC_KEY_LENGTH,
            kyber1024::Parameters.public_key_length()
        );
        assert_eq!(
            kyber1024::Parameters::CIPHERTEXT_LENGTH,
            kyber1024::Parameters.ciphertext_length()
        );
        assert_eq!(
            kyber1024::Parameters::SHARED_SECRET_LENGTH,
            kyber1024::Parameters.shared_secret_length()
        );
    }
}