1pub(crate) mod curve25519;
7
8use std::cmp::Ordering;
9use std::fmt;
10
11use arrayref::array_ref;
12use curve25519_dalek::scalar;
13use rand::{CryptoRng, Rng};
14use subtle::ConstantTimeEq;
15
16use crate::{Result, SignalProtocolError};
17
18#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
19pub enum KeyType {
20 Djb,
21}
22
23impl fmt::Display for KeyType {
24 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
25 fmt::Debug::fmt(self, f)
26 }
27}
28
29impl KeyType {
30 fn value(&self) -> u8 {
31 match &self {
32 KeyType::Djb => 0x05u8,
33 }
34 }
35}
36
37impl TryFrom<u8> for KeyType {
38 type Error = SignalProtocolError;
39
40 fn try_from(x: u8) -> Result<Self> {
41 match x {
42 0x05u8 => Ok(KeyType::Djb),
43 t => Err(SignalProtocolError::BadKeyType(t)),
44 }
45 }
46}
47
48#[derive(Debug, Clone, Copy, Eq, PartialEq)]
49enum PublicKeyData {
50 DjbPublicKey([u8; curve25519::PUBLIC_KEY_LENGTH]),
51}
52
53#[derive(Clone, Copy, Eq)]
54pub struct PublicKey {
55 key: PublicKeyData,
56}
57
58impl PublicKey {
59 fn new(key: PublicKeyData) -> Self {
60 Self { key }
61 }
62
63 pub fn deserialize(value: &[u8]) -> Result<Self> {
64 if value.is_empty() {
65 return Err(SignalProtocolError::NoKeyTypeIdentifier);
66 }
67 let key_type = KeyType::try_from(value[0])?;
68 match key_type {
69 KeyType::Djb => {
70 if value.len() < curve25519::PUBLIC_KEY_LENGTH + 1 {
72 return Err(SignalProtocolError::BadKeyLength(KeyType::Djb, value.len()));
73 }
74 let mut key = [0u8; curve25519::PUBLIC_KEY_LENGTH];
75 key.copy_from_slice(&value[1..][..curve25519::PUBLIC_KEY_LENGTH]);
76 Ok(PublicKey {
77 key: PublicKeyData::DjbPublicKey(key),
78 })
79 }
80 }
81 }
82
83 pub fn public_key_bytes(&self) -> Result<&[u8]> {
84 match &self.key {
85 PublicKeyData::DjbPublicKey(v) => Ok(v),
86 }
87 }
88
89 pub fn from_djb_public_key_bytes(bytes: &[u8]) -> Result<Self> {
90 match <[u8; curve25519::PUBLIC_KEY_LENGTH]>::try_from(bytes) {
91 Err(_) => Err(SignalProtocolError::BadKeyLength(KeyType::Djb, bytes.len())),
92 Ok(key) => Ok(PublicKey {
93 key: PublicKeyData::DjbPublicKey(key),
94 }),
95 }
96 }
97
98 pub fn serialize(&self) -> Box<[u8]> {
99 let value_len = match &self.key {
100 PublicKeyData::DjbPublicKey(v) => v.len(),
101 };
102 let mut result = Vec::with_capacity(1 + value_len);
103 result.push(self.key_type().value());
104 match &self.key {
105 PublicKeyData::DjbPublicKey(v) => result.extend_from_slice(v),
106 }
107 result.into_boxed_slice()
108 }
109
110 pub fn verify_signature(&self, message: &[u8], signature: &[u8]) -> Result<bool> {
111 self.verify_signature_for_multipart_message(&[message], signature)
112 }
113
114 pub fn verify_signature_for_multipart_message(
115 &self,
116 message: &[&[u8]],
117 signature: &[u8],
118 ) -> Result<bool> {
119 match &self.key {
120 PublicKeyData::DjbPublicKey(pub_key) => {
121 if signature.len() != curve25519::SIGNATURE_LENGTH {
122 return Ok(false);
123 }
124 Ok(curve25519::PrivateKey::verify_signature(
125 pub_key,
126 message,
127 array_ref![signature, 0, curve25519::SIGNATURE_LENGTH],
128 ))
129 }
130 }
131 }
132
133 fn key_data(&self) -> &[u8] {
134 match &self.key {
135 PublicKeyData::DjbPublicKey(ref k) => k.as_ref(),
136 }
137 }
138
139 pub fn key_type(&self) -> KeyType {
140 match &self.key {
141 PublicKeyData::DjbPublicKey(_) => KeyType::Djb,
142 }
143 }
144}
145
146impl From<PublicKeyData> for PublicKey {
147 fn from(key: PublicKeyData) -> PublicKey {
148 Self { key }
149 }
150}
151
152impl TryFrom<&[u8]> for PublicKey {
153 type Error = SignalProtocolError;
154
155 fn try_from(value: &[u8]) -> Result<Self> {
156 Self::deserialize(value)
157 }
158}
159
160impl subtle::ConstantTimeEq for PublicKey {
161 fn ct_eq(&self, other: &PublicKey) -> subtle::Choice {
166 if self.key_type() != other.key_type() {
167 return 0.ct_eq(&1);
168 }
169 self.key_data().ct_eq(other.key_data())
170 }
171}
172
173impl PartialEq for PublicKey {
174 fn eq(&self, other: &PublicKey) -> bool {
175 bool::from(self.ct_eq(other))
176 }
177}
178
179impl Ord for PublicKey {
180 fn cmp(&self, other: &Self) -> Ordering {
181 if self.key_type() != other.key_type() {
182 return self.key_type().cmp(&other.key_type());
183 }
184
185 crate::utils::constant_time_cmp(self.key_data(), other.key_data())
186 }
187}
188
189impl PartialOrd for PublicKey {
190 fn partial_cmp(&self, other: &PublicKey) -> Option<Ordering> {
191 Some(self.cmp(other))
192 }
193}
194
195impl fmt::Debug for PublicKey {
196 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
197 write!(
198 f,
199 "PublicKey {{ key_type={}, serialize={:?} }}",
200 self.key_type(),
201 self.serialize()
202 )
203 }
204}
205
206#[derive(Debug, Clone, Copy, Eq, PartialEq)]
207enum PrivateKeyData {
208 DjbPrivateKey([u8; curve25519::PRIVATE_KEY_LENGTH]),
209}
210
211#[derive(Clone, Copy, Eq, PartialEq)]
212pub struct PrivateKey {
213 key: PrivateKeyData,
214}
215
216impl PrivateKey {
217 pub fn deserialize(value: &[u8]) -> Result<Self> {
218 if value.len() != curve25519::PRIVATE_KEY_LENGTH {
219 Err(SignalProtocolError::BadKeyLength(KeyType::Djb, value.len()))
220 } else {
221 let mut key = [0u8; curve25519::PRIVATE_KEY_LENGTH];
222 key.copy_from_slice(&value[..curve25519::PRIVATE_KEY_LENGTH]);
223 key = scalar::clamp_integer(key);
225 Ok(Self {
226 key: PrivateKeyData::DjbPrivateKey(key),
227 })
228 }
229 }
230
231 pub fn serialize(&self) -> Vec<u8> {
232 match &self.key {
233 PrivateKeyData::DjbPrivateKey(v) => v.to_vec(),
234 }
235 }
236
237 pub fn public_key(&self) -> Result<PublicKey> {
238 match &self.key {
239 PrivateKeyData::DjbPrivateKey(private_key) => {
240 let public_key =
241 curve25519::PrivateKey::from(*private_key).derive_public_key_bytes();
242 Ok(PublicKey::new(PublicKeyData::DjbPublicKey(public_key)))
243 }
244 }
245 }
246
247 pub fn key_type(&self) -> KeyType {
248 match &self.key {
249 PrivateKeyData::DjbPrivateKey(_) => KeyType::Djb,
250 }
251 }
252
253 pub fn calculate_signature<R: CryptoRng + Rng>(
254 &self,
255 message: &[u8],
256 csprng: &mut R,
257 ) -> Result<Box<[u8]>> {
258 self.calculate_signature_for_multipart_message(&[message], csprng)
259 }
260
261 pub fn calculate_signature_for_multipart_message<R: CryptoRng + Rng>(
262 &self,
263 message: &[&[u8]],
264 csprng: &mut R,
265 ) -> Result<Box<[u8]>> {
266 match self.key {
267 PrivateKeyData::DjbPrivateKey(k) => {
268 let private_key = curve25519::PrivateKey::from(k);
269 Ok(Box::new(private_key.calculate_signature(csprng, message)))
270 }
271 }
272 }
273
274 pub fn calculate_agreement(&self, their_key: &PublicKey) -> Result<Box<[u8]>> {
275 match (self.key, their_key.key) {
276 (PrivateKeyData::DjbPrivateKey(priv_key), PublicKeyData::DjbPublicKey(pub_key)) => {
277 let private_key = curve25519::PrivateKey::from(priv_key);
278 Ok(Box::new(private_key.calculate_agreement(&pub_key)))
279 }
280 }
281 }
282}
283
284impl From<PrivateKeyData> for PrivateKey {
285 fn from(key: PrivateKeyData) -> PrivateKey {
286 Self { key }
287 }
288}
289
290impl TryFrom<&[u8]> for PrivateKey {
291 type Error = SignalProtocolError;
292
293 fn try_from(value: &[u8]) -> Result<Self> {
294 Self::deserialize(value)
295 }
296}
297
298#[derive(Copy, Clone)]
299pub struct KeyPair {
300 pub public_key: PublicKey,
301 pub private_key: PrivateKey,
302}
303
304impl KeyPair {
305 pub fn generate<R: Rng + CryptoRng>(csprng: &mut R) -> Self {
306 let private_key = curve25519::PrivateKey::new(csprng);
307
308 let public_key = PublicKey::from(PublicKeyData::DjbPublicKey(
309 private_key.derive_public_key_bytes(),
310 ));
311 let private_key = PrivateKey::from(PrivateKeyData::DjbPrivateKey(
312 private_key.private_key_bytes(),
313 ));
314
315 Self {
316 public_key,
317 private_key,
318 }
319 }
320
321 pub fn new(public_key: PublicKey, private_key: PrivateKey) -> Self {
322 Self {
323 public_key,
324 private_key,
325 }
326 }
327
328 pub fn from_public_and_private(public_key: &[u8], private_key: &[u8]) -> Result<Self> {
329 let public_key = PublicKey::try_from(public_key)?;
330 let private_key = PrivateKey::try_from(private_key)?;
331 Ok(Self {
332 public_key,
333 private_key,
334 })
335 }
336
337 pub fn calculate_signature<R: CryptoRng + Rng>(
338 &self,
339 message: &[u8],
340 csprng: &mut R,
341 ) -> Result<Box<[u8]>> {
342 self.private_key.calculate_signature(message, csprng)
343 }
344
345 pub fn calculate_agreement(&self, their_key: &PublicKey) -> Result<Box<[u8]>> {
346 self.private_key.calculate_agreement(their_key)
347 }
348}
349
350impl TryFrom<PrivateKey> for KeyPair {
351 type Error = SignalProtocolError;
352
353 fn try_from(value: PrivateKey) -> Result<Self> {
354 let public_key = value.public_key()?;
355 Ok(Self::new(public_key, value))
356 }
357}
358
#[cfg(test)]
mod tests {
    use rand::rngs::OsRng;

    use super::*;

    /// Signs and verifies a 1 MiB message, including tampering and multipart variants.
    #[test]
    fn test_large_signatures() -> Result<()> {
        let mut csprng = OsRng;
        let key_pair = KeyPair::generate(&mut csprng);
        // Keep the 1 MiB message on the heap. As a stack array it would take half of
        // a spawned test thread's default stack (2 MiB on Tier-1 platforms).
        let mut message = vec![0u8; 1024 * 1024];
        let signature = key_pair
            .private_key
            .calculate_signature(&message, &mut csprng)?;

        assert!(key_pair.public_key.verify_signature(&message, &signature)?);
        // Flipping one bit must invalidate the signature...
        message[0] ^= 0x01u8;
        assert!(!key_pair.public_key.verify_signature(&message, &signature)?);
        // ...and flipping it back must make it valid again.
        message[0] ^= 0x01u8;
        let public_key = key_pair.private_key.public_key()?;
        assert!(public_key.verify_signature(&message, &signature)?);

        // The same bytes split into parts also verify.
        assert!(public_key
            .verify_signature_for_multipart_message(&[&message[..7], &message[7..]], &signature)?);

        // A signature over a multipart message verifies against the contiguous bytes.
        let signature = key_pair
            .private_key
            .calculate_signature_for_multipart_message(
                &[&message[..20], &message[20..]],
                &mut csprng,
            )?;
        assert!(public_key.verify_signature(&message, &signature)?);

        Ok(())
    }

    /// Checks which input lengths and type bytes `PublicKey` deserialization accepts.
    #[test]
    fn test_decode_size() -> Result<()> {
        let mut csprng = OsRng;
        let key_pair = KeyPair::generate(&mut csprng);
        let serialized_public = key_pair.public_key.serialize();

        assert_eq!(
            serialized_public,
            key_pair.private_key.public_key()?.serialize()
        );
        let empty: [u8; 0] = [];

        let just_right = PublicKey::try_from(&serialized_public[..]);

        assert!(just_right.is_ok());
        assert!(PublicKey::try_from(&serialized_public[1..]).is_err());
        assert!(PublicKey::try_from(&empty[..]).is_err());

        // Derive malformed inputs from the real encoding instead of hard-coding its length.
        let mut bad_key_type = serialized_public.to_vec();
        bad_key_type[0] = 0x01u8;
        assert!(PublicKey::try_from(&bad_key_type[..]).is_err());

        // Trailing bytes after the key are accepted.
        let mut extra_space = serialized_public.to_vec();
        extra_space.push(0u8);
        let extra_space_decode = PublicKey::try_from(&extra_space[..]);
        assert!(extra_space_decode.is_ok());

        assert_eq!(&serialized_public[..], &just_right?.serialize()[..]);
        assert_eq!(&serialized_public[..], &extra_space_decode?.serialize()[..]);
        Ok(())
    }
}