mod kyber1024;
#[cfg(any(feature = "kyber768", test))]
mod kyber768;
#[cfg(feature = "mlkem1024")]
mod mlkem1024;
use std::marker::PhantomData;
use std::ops::Deref;
use derive_where::derive_where;
use displaydoc::Display;
use subtle::ConstantTimeEq;
use crate::{Result, SignalProtocolError};
/// Shared secret bytes produced by encapsulation and recovered by decapsulation.
type SharedSecret = Box<[u8]>;
/// Ciphertext bytes as produced by an algorithm implementation, without the key-type prefix byte.
pub(crate) type RawCiphertext = Box<[u8]>;
/// Ciphertext as returned by `Key<Public>::encapsulate`: one key-type byte followed by the
/// raw ciphertext.
pub type SerializedCiphertext = Box<[u8]>;
/// Static description of one KEM algorithm: its KEM operations and its fixed byte lengths.
///
/// Implementors (e.g. `kyber1024::Parameters`) are reached at runtime through the
/// object-safe `DynParameters` adapter.
trait Parameters {
    /// Generates a new key pair, returned as `(public, secret)`.
    fn generate() -> (KeyMaterial<Public>, KeyMaterial<Secret>);

    /// Derives a shared secret for `pub_key`. Returns it together with the raw ciphertext
    /// (no key-type prefix) from which the secret-key holder can recover it.
    fn encapsulate(pub_key: &KeyMaterial<Public>) -> (SharedSecret, RawCiphertext);

    /// Recovers the shared secret from a raw (unprefixed) `ciphertext`.
    fn decapsulate(secret_key: &KeyMaterial<Secret>, ciphertext: &[u8]) -> Result<SharedSecret>;

    /// Length in bytes of a raw public key (without the key-type byte).
    const PUBLIC_KEY_LENGTH: usize;
    /// Length in bytes of a raw secret key (without the key-type byte).
    const SECRET_KEY_LENGTH: usize;
    /// Length in bytes of a raw ciphertext (without the key-type byte).
    const CIPHERTEXT_LENGTH: usize;
    /// Length in bytes of the shared secret.
    const SHARED_SECRET_LENGTH: usize;
}
/// Object-safe counterpart of `Parameters`. `KeyType::parameters` can therefore pick an
/// algorithm at runtime and hand it out as a `&'static dyn DynParameters`.
trait DynParameters {
    /// Generates a new key pair, returned as `(public, secret)`.
    fn generate(&self) -> (KeyMaterial<Public>, KeyMaterial<Secret>);

    /// Returns the shared secret and the raw (unprefixed) ciphertext for `pub_key`.
    fn encapsulate(&self, pub_key: &KeyMaterial<Public>) -> (SharedSecret, RawCiphertext);

    /// Recovers the shared secret from a raw (unprefixed) `ciphertext`.
    fn decapsulate(
        &self,
        secret_key: &KeyMaterial<Secret>,
        ciphertext: &[u8],
    ) -> Result<SharedSecret>;

    /// Raw public key length in bytes.
    fn public_key_length(&self) -> usize;
    /// Raw secret key length in bytes.
    fn secret_key_length(&self) -> usize;
    /// Raw ciphertext length in bytes.
    fn ciphertext_length(&self) -> usize;
    /// Shared secret length in bytes.
    // In this file it is only called from tests, so non-test builds would otherwise warn
    // that the method is never used.
    #[allow(dead_code)]
    fn shared_secret_length(&self) -> usize;
}
/// Forwards every `DynParameters` method to the matching associated item of `Parameters`.
/// `self` is unused because all the information lives in the type `T`.
impl<T: Parameters> DynParameters for T {
    fn generate(&self) -> (KeyMaterial<Public>, KeyMaterial<Secret>) {
        <T as Parameters>::generate()
    }

    fn encapsulate(&self, pub_key: &KeyMaterial<Public>) -> (SharedSecret, RawCiphertext) {
        <T as Parameters>::encapsulate(pub_key)
    }

    fn decapsulate(
        &self,
        secret_key: &KeyMaterial<Secret>,
        ciphertext: &[u8],
    ) -> Result<SharedSecret> {
        <T as Parameters>::decapsulate(secret_key, ciphertext)
    }

    fn public_key_length(&self) -> usize {
        <T as Parameters>::PUBLIC_KEY_LENGTH
    }

    fn secret_key_length(&self) -> usize {
        <T as Parameters>::SECRET_KEY_LENGTH
    }

    fn ciphertext_length(&self) -> usize {
        <T as Parameters>::CIPHERTEXT_LENGTH
    }

    fn shared_secret_length(&self) -> usize {
        <T as Parameters>::SHARED_SECRET_LENGTH
    }
}
// Identifies which KEM algorithm a key or ciphertext belongs to. Each variant's one-byte
// serialization identifier is assigned in `KeyType::value`.
//
// NOTE(review): `Display` is derived with `displaydoc`, which builds its output from `///`
// doc comments. These notes are deliberately plain `//` comments so they do not change
// that output. The variants currently carry no doc comments, so confirm the intended
// `Display` text.
#[derive(Display, Debug, Copy, Clone, PartialEq, Eq)]
pub enum KeyType {
    // Only compiled with the `kyber768` feature or in test builds.
    #[cfg(any(feature = "kyber768", test))]
    Kyber768,
    Kyber1024,
    // Only compiled with the `mlkem1024` feature.
    #[cfg(feature = "mlkem1024")]
    MLKEM1024,
}
impl KeyType {
    /// Returns the algorithm implementation (byte lengths and KEM operations) for this key type.
    const fn parameters(&self) -> &'static dyn DynParameters {
        match self {
            #[cfg(any(feature = "kyber768", test))]
            Self::Kyber768 => &kyber768::Parameters,
            Self::Kyber1024 => &kyber1024::Parameters,
            #[cfg(feature = "mlkem1024")]
            Self::MLKEM1024 => &mlkem1024::Parameters,
        }
    }

    /// Returns the one-byte identifier written in front of serialized keys and ciphertexts.
    fn value(&self) -> u8 {
        match self {
            #[cfg(any(feature = "kyber768", test))]
            Self::Kyber768 => 0x07,
            Self::Kyber1024 => 0x08,
            #[cfg(feature = "mlkem1024")]
            Self::MLKEM1024 => 0x0A,
        }
    }
}
/// Parses the one-byte key-type identifier (the same bytes `KeyType::value` emits).
///
/// Identifiers whose variant is compiled out, or that are unknown, yield
/// `SignalProtocolError::BadKEMKeyType` carrying the offending byte.
impl TryFrom<u8> for KeyType {
    type Error = SignalProtocolError;

    fn try_from(x: u8) -> Result<Self> {
        let key_type = match x {
            #[cfg(any(feature = "kyber768", test))]
            0x07 => Self::Kyber768,
            0x08 => Self::Kyber1024,
            #[cfg(feature = "mlkem1024")]
            0x0A => Self::MLKEM1024,
            unknown => return Err(SignalProtocolError::BadKEMKeyType(unknown)),
        };
        Ok(key_type)
    }
}
/// Type-level tag for the kind of KEM key material ([`Public`] or [`Secret`]).
pub trait KeyKind {
    /// Expected length in bytes of this kind of key for `key_type`, excluding the
    /// key-type prefix byte.
    fn key_length(key_type: KeyType) -> usize;
}

/// Tag for public keys. This enum is uninhabited; it is only used as a type parameter.
pub enum Public {}

impl KeyKind for Public {
    fn key_length(key_type: KeyType) -> usize {
        let params = key_type.parameters();
        params.public_key_length()
    }
}

/// Tag for secret keys. This enum is uninhabited; it is only used as a type parameter.
pub enum Secret {}

impl KeyKind for Secret {
    fn key_length(key_type: KeyType) -> usize {
        let params = key_type.parameters();
        params.secret_key_length()
    }
}
/// Raw key bytes (no key-type prefix), tagged at the type level with their [`KeyKind`].
// `derive_where` keeps `Clone` from requiring `T: Clone`; the tags `Public`/`Secret`
// do not implement it.
#[derive_where(Clone)]
pub(crate) struct KeyMaterial<T: KeyKind> {
    data: Box<[u8]>,
    // Zero-sized marker; no value of type `T` is stored.
    kind: PhantomData<T>,
}

impl<T: KeyKind> KeyMaterial<T> {
    /// Wraps `data` as-is; no length check is performed here.
    fn new(data: Box<[u8]>) -> Self {
        Self {
            data,
            kind: PhantomData,
        }
    }
}

/// Exposes the key material directly as a byte slice.
impl<T: KeyKind> Deref for KeyMaterial<T> {
    type Target = [u8];

    fn deref(&self) -> &[u8] {
        &self.data
    }
}
/// A KEM key of kind `T` ([`Public`] or [`Secret`]) together with the algorithm it belongs to.
#[derive_where(Clone)]
pub struct Key<T: KeyKind> {
    key_type: KeyType,
    key_data: KeyMaterial<T>,
}

impl<T: KeyKind> Key<T> {
    /// Parses a key serialized as one key-type byte followed by the raw key bytes.
    ///
    /// # Errors
    /// - `NoKeyTypeIdentifier` if `value` is empty.
    /// - The error from `KeyType::try_from` if the first byte is not a known key type.
    /// - `BadKEMKeyLength(key_type, value.len())` if the bytes after the prefix are not
    ///   exactly the expected key length. The reported length is that of the whole input.
    pub fn deserialize(value: &[u8]) -> Result<Self> {
        let (&type_byte, key_bytes) = value
            .split_first()
            .ok_or(SignalProtocolError::NoKeyTypeIdentifier)?;
        let key_type = KeyType::try_from(type_byte)?;
        if key_bytes.len() != T::key_length(key_type) {
            return Err(SignalProtocolError::BadKEMKeyLength(key_type, value.len()));
        }
        Ok(Self {
            key_type,
            key_data: KeyMaterial::new(key_bytes.into()),
        })
    }

    /// Serializes as the key-type byte followed by the raw key bytes; this is the inverse
    /// of [`Key::deserialize`].
    pub fn serialize(&self) -> Box<[u8]> {
        std::iter::once(self.key_type.value())
            .chain(self.key_data.iter().copied())
            .collect()
    }

    /// Returns the algorithm this key belongs to.
    pub fn key_type(&self) -> KeyType {
        self.key_type
    }
}
impl Key<Public> {
    /// Encapsulates to this public key.
    ///
    /// Returns the shared secret and the serialized ciphertext: this key's type byte
    /// followed by the raw ciphertext. The secret-key holder uses the ciphertext to
    /// recover the same shared secret.
    pub fn encapsulate(&self) -> (SharedSecret, SerializedCiphertext) {
        let params = self.key_type.parameters();
        let (shared_secret, raw_ciphertext) = params.encapsulate(&self.key_data);
        let serialized = Ciphertext {
            key_type: self.key_type,
            data: &raw_ciphertext,
        }
        .serialize();
        (shared_secret, serialized)
    }
}
impl Key<Secret> {
    /// Recovers the shared secret from a serialized ciphertext, as produced by
    /// `Key<Public>::encapsulate`.
    ///
    /// `ct_bytes` is one key-type byte followed by the raw ciphertext. It is taken as a plain
    /// byte slice, so callers holding a `SerializedCiphertext` (a `Box<[u8]>`) can still pass
    /// `&ct` via deref coercion.
    ///
    /// # Errors
    /// - `NoKeyTypeIdentifier`, an unknown-key-type error, or `BadKEMCiphertextLength` if
    ///   `ct_bytes` is not a well-formed ciphertext.
    /// - `WrongKEMKeyType(ciphertext type, key type)` if the ciphertext's key type differs
    ///   from this key's.
    /// - Any error reported by the algorithm's decapsulation.
    pub fn decapsulate(&self, ct_bytes: &[u8]) -> Result<Box<[u8]>> {
        let ct = Ciphertext::deserialize(ct_bytes)?;
        // Reject mismatched algorithms before handing bytes to the KEM implementation.
        if ct.key_type != self.key_type {
            return Err(SignalProtocolError::WrongKEMKeyType(
                ct.key_type.value(),
                self.key_type.value(),
            ));
        }
        self.key_type
            .parameters()
            .decapsulate(&self.key_data, ct.data)
    }
}
impl TryFrom<&[u8]> for Key<Public> {
type Error = SignalProtocolError;
fn try_from(value: &[u8]) -> Result<Self> {
Self::deserialize(value)
}
}
impl TryFrom<&[u8]> for Key<Secret> {
type Error = SignalProtocolError;
fn try_from(value: &[u8]) -> Result<Self> {
Self::deserialize(value)
}
}
/// Compares public keys, using `subtle`'s constant-time comparison for the key bytes.
/// The key-type check is an ordinary branch: keys of different types are unequal.
impl subtle::ConstantTimeEq for Key<Public> {
    fn ct_eq(&self, other: &Self) -> subtle::Choice {
        if self.key_type == other.key_type {
            self.key_data.ct_eq(&other.key_data)
        } else {
            subtle::Choice::from(0u8)
        }
    }
}

/// Equality is defined by the constant-time comparison above.
impl PartialEq for Key<Public> {
    fn eq(&self, other: &Self) -> bool {
        self.ct_eq(other).into()
    }
}

impl Eq for Key<Public> {}
/// A KEM public key.
pub type PublicKey = Key<Public>;
/// A KEM secret key.
pub type SecretKey = Key<Secret>;
/// A public key together with its secret key.
///
/// The constructors in `impl KeyPair` keep both halves at the same `KeyType`. The fields are
/// public, so direct struct construction is not checked.
#[derive(Clone)]
pub struct KeyPair {
    pub public_key: PublicKey,
    pub secret_key: SecretKey,
}
impl KeyPair {
    /// Generates a new key pair for `key_type`.
    pub fn generate(key_type: KeyType) -> Self {
        let (public_material, secret_material) = key_type.parameters().generate();
        let public_key = PublicKey {
            key_type,
            key_data: public_material,
        };
        let secret_key = SecretKey {
            key_type,
            key_data: secret_material,
        };
        Self {
            public_key,
            secret_key,
        }
    }

    /// Pairs two already-parsed keys.
    ///
    /// # Panics
    /// Panics if `public_key` and `secret_key` have different key types. Use
    /// [`KeyPair::from_public_and_private`] to get an error instead.
    pub fn new(public_key: PublicKey, secret_key: SecretKey) -> Self {
        assert_eq!(public_key.key_type, secret_key.key_type);
        Self {
            public_key,
            secret_key,
        }
    }

    /// Parses both halves of a key pair from their serialized forms.
    ///
    /// # Errors
    /// Returns any deserialization error for either key. Returns
    /// `WrongKEMKeyType(secret type, public type)` if the two keys have different key types.
    pub fn from_public_and_private(public_key: &[u8], secret_key: &[u8]) -> Result<Self> {
        let public = PublicKey::try_from(public_key)?;
        let secret = SecretKey::try_from(secret_key)?;
        if secret.key_type != public.key_type {
            return Err(SignalProtocolError::WrongKEMKeyType(
                secret.key_type.value(),
                public.key_type.value(),
            ));
        }
        Ok(Self {
            public_key: public,
            secret_key: secret,
        })
    }
}
/// Borrowed view of a ciphertext: its key type plus the raw ciphertext bytes (no prefix).
struct Ciphertext<'a> {
    key_type: KeyType,
    data: &'a [u8],
}

impl<'a> Ciphertext<'a> {
    /// Parses one key-type byte followed by the raw ciphertext, borrowing from `value`.
    ///
    /// # Errors
    /// - `NoKeyTypeIdentifier` if `value` is empty.
    /// - The error from `KeyType::try_from` for an unknown key-type byte.
    /// - `BadKEMCiphertextLength(key_type, value.len())` if the bytes after the prefix are
    ///   not exactly that algorithm's ciphertext length.
    pub fn deserialize(value: &'a [u8]) -> Result<Self> {
        let (&type_byte, body) = value
            .split_first()
            .ok_or(SignalProtocolError::NoKeyTypeIdentifier)?;
        let key_type = KeyType::try_from(type_byte)?;
        if body.len() != key_type.parameters().ciphertext_length() {
            return Err(SignalProtocolError::BadKEMCiphertextLength(
                key_type,
                value.len(),
            ));
        }
        Ok(Self {
            key_type,
            data: body,
        })
    }

    /// Serializes as the key-type byte followed by the raw ciphertext.
    pub fn serialize(&self) -> SerializedCiphertext {
        std::iter::once(self.key_type.value())
            .chain(self.data.iter().copied())
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Prefixes raw Kyber1024 key bytes with the Kyber1024 key-type byte.
    fn prefix_kyber1024(raw: &[u8], key_length: usize) -> Vec<u8> {
        let mut serialized = Vec::with_capacity(1 + key_length);
        serialized.push(KeyType::Kyber1024.value());
        serialized.extend_from_slice(raw);
        serialized
    }

    /// Returns the serialized `(public, secret)` Kyber1024 test keys built from the raw key
    /// files in `kem/test-data`.
    fn serialized_kyber1024_test_keys() -> (Vec<u8>, Vec<u8>) {
        let serialized_pk = prefix_kyber1024(
            include_bytes!("kem/test-data/pk.dat"),
            kyber1024::Parameters::PUBLIC_KEY_LENGTH,
        );
        let serialized_sk = prefix_kyber1024(
            include_bytes!("kem/test-data/sk.dat"),
            kyber1024::Parameters::SECRET_KEY_LENGTH,
        );
        (serialized_pk, serialized_sk)
    }

    /// Generates a key pair of `key_type` and checks every serialized size against `P`'s
    /// constants. Also checks that encapsulation and decapsulation yield the same secret.
    fn check_generated_keypair<P: Parameters>(key_type: KeyType) {
        let kp = KeyPair::generate(key_type);
        assert_eq!(P::SECRET_KEY_LENGTH + 1, kp.secret_key.serialize().len());
        assert_eq!(P::PUBLIC_KEY_LENGTH + 1, kp.public_key.serialize().len());
        let (ss_for_sender, ct) = kp.public_key.encapsulate();
        assert_eq!(P::CIPHERTEXT_LENGTH + 1, ct.len());
        assert_eq!(P::SHARED_SECRET_LENGTH, ss_for_sender.len());
        let ss_for_recipient = kp.secret_key.decapsulate(&ct).expect("decapsulation works");
        assert_eq!(ss_for_recipient, ss_for_sender);
    }

    #[test]
    fn test_serialize() {
        let (serialized_pk, serialized_sk) = serialized_kyber1024_test_keys();
        let pk = PublicKey::deserialize(&serialized_pk).expect("deserialize pk");
        let sk = SecretKey::deserialize(&serialized_sk).expect("deserialize sk");
        assert_eq!(serialized_pk, pk.serialize().into_vec());
        assert_eq!(serialized_sk, sk.serialize().into_vec());
    }

    #[test]
    fn test_raw_kem() {
        use pqcrypto_kyber::kyber1024::{decapsulate, encapsulate, keypair};
        let (pk, sk) = keypair();
        let (ss1, ct) = encapsulate(&pk);
        let ss2 = decapsulate(&ct, &sk);
        assert!(ss1 == ss2);
    }

    #[test]
    fn test_kyber1024_kem() {
        let (serialized_pk, serialized_sk) = serialized_kyber1024_test_keys();
        let pubkey = PublicKey::deserialize(&serialized_pk).expect("deserialize pubkey");
        let secretkey = SecretKey::deserialize(&serialized_sk).expect("deserialize secretkey");
        assert_eq!(pubkey.key_type, KeyType::Kyber1024);
        let (ss_for_sender, ct) = pubkey.encapsulate();
        let ss_for_recipient = secretkey.decapsulate(&ct).expect("decapsulation works");
        assert_eq!(ss_for_sender, ss_for_recipient);
    }

    #[cfg(feature = "mlkem1024")]
    #[test]
    fn test_mlkem1024_kem() {
        // These test files already include the key-type prefix byte.
        let pk_bytes = include_bytes!("kem/test-data/mlkem-pk.dat");
        let sk_bytes = include_bytes!("kem/test-data/mlkem-sk.dat");
        let pubkey = PublicKey::deserialize(pk_bytes).expect("deserialize pubkey");
        let secretkey = SecretKey::deserialize(sk_bytes).expect("deserialize secretkey");
        assert_eq!(pubkey.key_type, KeyType::MLKEM1024);
        let (ss_for_sender, ct) = pubkey.encapsulate();
        let ss_for_recipient = secretkey.decapsulate(&ct).expect("decapsulation works");
        assert_eq!(ss_for_sender, ss_for_recipient);
    }

    #[test]
    fn test_kyber1024_keypair() {
        check_generated_keypair::<kyber1024::Parameters>(KeyType::Kyber1024);
    }

    #[test]
    fn test_kyber768_keypair() {
        check_generated_keypair::<kyber768::Parameters>(KeyType::Kyber768);
    }

    #[cfg(feature = "mlkem1024")]
    #[test]
    fn test_mlkem1024_keypair() {
        check_generated_keypair::<mlkem1024::Parameters>(KeyType::MLKEM1024);
    }

    #[test]
    fn test_dyn_parameters_consts() {
        assert_eq!(
            kyber1024::Parameters::SECRET_KEY_LENGTH,
            kyber1024::Parameters.secret_key_length()
        );
        assert_eq!(
            kyber1024::Parameters::PUBLIC_KEY_LENGTH,
            kyber1024::Parameters.public_key_length()
        );
        assert_eq!(
            kyber1024::Parameters::CIPHERTEXT_LENGTH,
            kyber1024::Parameters.ciphertext_length()
        );
        assert_eq!(
            kyber1024::Parameters::SHARED_SECRET_LENGTH,
            kyber1024::Parameters.shared_secret_length()
        );
    }
}