// libsignal_protocol/curve.rs

//
// Copyright 2020-2021 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//
5
6pub(crate) mod curve25519;
7
8use std::cmp::Ordering;
9use std::fmt;
10
11use arrayref::array_ref;
12use curve25519_dalek::scalar;
13use rand::{CryptoRng, Rng};
14use subtle::ConstantTimeEq;
15
16use crate::{Result, SignalProtocolError};
17
/// The algorithm family of a key, identified by a one-byte type tag when serialized.
///
/// Only Curve25519 ("Djb") keys are currently supported.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum KeyType {
    /// Curve25519 keys, tagged with the byte `0x05` when serialized.
    Djb,
}
22
23impl fmt::Display for KeyType {
24    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
25        fmt::Debug::fmt(self, f)
26    }
27}
28
29impl KeyType {
30    fn value(&self) -> u8 {
31        match &self {
32            KeyType::Djb => 0x05u8,
33        }
34    }
35}
36
37impl TryFrom<u8> for KeyType {
38    type Error = SignalProtocolError;
39
40    fn try_from(x: u8) -> Result<Self> {
41        match x {
42            0x05u8 => Ok(KeyType::Djb),
43            t => Err(SignalProtocolError::BadKeyType(t)),
44        }
45    }
46}
47
48#[derive(Debug, Clone, Copy, Eq, PartialEq)]
49enum PublicKeyData {
50    DjbPublicKey([u8; curve25519::PUBLIC_KEY_LENGTH]),
51}
52
53#[derive(Clone, Copy, Eq)]
54pub struct PublicKey {
55    key: PublicKeyData,
56}
57
58impl PublicKey {
59    fn new(key: PublicKeyData) -> Self {
60        Self { key }
61    }
62
63    pub fn deserialize(value: &[u8]) -> Result<Self> {
64        if value.is_empty() {
65            return Err(SignalProtocolError::NoKeyTypeIdentifier);
66        }
67        let key_type = KeyType::try_from(value[0])?;
68        match key_type {
69            KeyType::Djb => {
70                // We allow trailing data after the public key (why?)
71                if value.len() < curve25519::PUBLIC_KEY_LENGTH + 1 {
72                    return Err(SignalProtocolError::BadKeyLength(KeyType::Djb, value.len()));
73                }
74                let mut key = [0u8; curve25519::PUBLIC_KEY_LENGTH];
75                key.copy_from_slice(&value[1..][..curve25519::PUBLIC_KEY_LENGTH]);
76                Ok(PublicKey {
77                    key: PublicKeyData::DjbPublicKey(key),
78                })
79            }
80        }
81    }
82
83    pub fn public_key_bytes(&self) -> Result<&[u8]> {
84        match &self.key {
85            PublicKeyData::DjbPublicKey(v) => Ok(v),
86        }
87    }
88
89    pub fn from_djb_public_key_bytes(bytes: &[u8]) -> Result<Self> {
90        match <[u8; curve25519::PUBLIC_KEY_LENGTH]>::try_from(bytes) {
91            Err(_) => Err(SignalProtocolError::BadKeyLength(KeyType::Djb, bytes.len())),
92            Ok(key) => Ok(PublicKey {
93                key: PublicKeyData::DjbPublicKey(key),
94            }),
95        }
96    }
97
98    pub fn serialize(&self) -> Box<[u8]> {
99        let value_len = match &self.key {
100            PublicKeyData::DjbPublicKey(v) => v.len(),
101        };
102        let mut result = Vec::with_capacity(1 + value_len);
103        result.push(self.key_type().value());
104        match &self.key {
105            PublicKeyData::DjbPublicKey(v) => result.extend_from_slice(v),
106        }
107        result.into_boxed_slice()
108    }
109
110    pub fn verify_signature(&self, message: &[u8], signature: &[u8]) -> Result<bool> {
111        self.verify_signature_for_multipart_message(&[message], signature)
112    }
113
114    pub fn verify_signature_for_multipart_message(
115        &self,
116        message: &[&[u8]],
117        signature: &[u8],
118    ) -> Result<bool> {
119        match &self.key {
120            PublicKeyData::DjbPublicKey(pub_key) => {
121                if signature.len() != curve25519::SIGNATURE_LENGTH {
122                    return Ok(false);
123                }
124                Ok(curve25519::PrivateKey::verify_signature(
125                    pub_key,
126                    message,
127                    array_ref![signature, 0, curve25519::SIGNATURE_LENGTH],
128                ))
129            }
130        }
131    }
132
133    fn key_data(&self) -> &[u8] {
134        match &self.key {
135            PublicKeyData::DjbPublicKey(ref k) => k.as_ref(),
136        }
137    }
138
139    pub fn key_type(&self) -> KeyType {
140        match &self.key {
141            PublicKeyData::DjbPublicKey(_) => KeyType::Djb,
142        }
143    }
144}
145
146impl From<PublicKeyData> for PublicKey {
147    fn from(key: PublicKeyData) -> PublicKey {
148        Self { key }
149    }
150}
151
152impl TryFrom<&[u8]> for PublicKey {
153    type Error = SignalProtocolError;
154
155    fn try_from(value: &[u8]) -> Result<Self> {
156        Self::deserialize(value)
157    }
158}
159
160impl subtle::ConstantTimeEq for PublicKey {
161    /// A constant-time comparison as long as the two keys have a matching type.
162    ///
163    /// If the two keys have different types, the comparison short-circuits,
164    /// much like comparing two slices of different lengths.
165    fn ct_eq(&self, other: &PublicKey) -> subtle::Choice {
166        if self.key_type() != other.key_type() {
167            return 0.ct_eq(&1);
168        }
169        self.key_data().ct_eq(other.key_data())
170    }
171}
172
173impl PartialEq for PublicKey {
174    fn eq(&self, other: &PublicKey) -> bool {
175        bool::from(self.ct_eq(other))
176    }
177}
178
179impl Ord for PublicKey {
180    fn cmp(&self, other: &Self) -> Ordering {
181        if self.key_type() != other.key_type() {
182            return self.key_type().cmp(&other.key_type());
183        }
184
185        crate::utils::constant_time_cmp(self.key_data(), other.key_data())
186    }
187}
188
189impl PartialOrd for PublicKey {
190    fn partial_cmp(&self, other: &PublicKey) -> Option<Ordering> {
191        Some(self.cmp(other))
192    }
193}
194
195impl fmt::Debug for PublicKey {
196    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
197        write!(
198            f,
199            "PublicKey {{ key_type={}, serialize={:?} }}",
200            self.key_type(),
201            self.serialize()
202        )
203    }
204}
205
206#[derive(Debug, Clone, Copy, Eq, PartialEq)]
207enum PrivateKeyData {
208    DjbPrivateKey([u8; curve25519::PRIVATE_KEY_LENGTH]),
209}
210
211#[derive(Clone, Copy, Eq, PartialEq)]
212pub struct PrivateKey {
213    key: PrivateKeyData,
214}
215
216impl PrivateKey {
217    pub fn deserialize(value: &[u8]) -> Result<Self> {
218        if value.len() != curve25519::PRIVATE_KEY_LENGTH {
219            Err(SignalProtocolError::BadKeyLength(KeyType::Djb, value.len()))
220        } else {
221            let mut key = [0u8; curve25519::PRIVATE_KEY_LENGTH];
222            key.copy_from_slice(&value[..curve25519::PRIVATE_KEY_LENGTH]);
223            // Clamping is not necessary but is kept for backward compatibility
224            key = scalar::clamp_integer(key);
225            Ok(Self {
226                key: PrivateKeyData::DjbPrivateKey(key),
227            })
228        }
229    }
230
231    pub fn serialize(&self) -> Vec<u8> {
232        match &self.key {
233            PrivateKeyData::DjbPrivateKey(v) => v.to_vec(),
234        }
235    }
236
237    pub fn public_key(&self) -> Result<PublicKey> {
238        match &self.key {
239            PrivateKeyData::DjbPrivateKey(private_key) => {
240                let public_key =
241                    curve25519::PrivateKey::from(*private_key).derive_public_key_bytes();
242                Ok(PublicKey::new(PublicKeyData::DjbPublicKey(public_key)))
243            }
244        }
245    }
246
247    pub fn key_type(&self) -> KeyType {
248        match &self.key {
249            PrivateKeyData::DjbPrivateKey(_) => KeyType::Djb,
250        }
251    }
252
253    pub fn calculate_signature<R: CryptoRng + Rng>(
254        &self,
255        message: &[u8],
256        csprng: &mut R,
257    ) -> Result<Box<[u8]>> {
258        self.calculate_signature_for_multipart_message(&[message], csprng)
259    }
260
261    pub fn calculate_signature_for_multipart_message<R: CryptoRng + Rng>(
262        &self,
263        message: &[&[u8]],
264        csprng: &mut R,
265    ) -> Result<Box<[u8]>> {
266        match self.key {
267            PrivateKeyData::DjbPrivateKey(k) => {
268                let private_key = curve25519::PrivateKey::from(k);
269                Ok(Box::new(private_key.calculate_signature(csprng, message)))
270            }
271        }
272    }
273
274    pub fn calculate_agreement(&self, their_key: &PublicKey) -> Result<Box<[u8]>> {
275        match (self.key, their_key.key) {
276            (PrivateKeyData::DjbPrivateKey(priv_key), PublicKeyData::DjbPublicKey(pub_key)) => {
277                let private_key = curve25519::PrivateKey::from(priv_key);
278                Ok(Box::new(private_key.calculate_agreement(&pub_key)))
279            }
280        }
281    }
282}
283
284impl From<PrivateKeyData> for PrivateKey {
285    fn from(key: PrivateKeyData) -> PrivateKey {
286        Self { key }
287    }
288}
289
290impl TryFrom<&[u8]> for PrivateKey {
291    type Error = SignalProtocolError;
292
293    fn try_from(value: &[u8]) -> Result<Self> {
294        Self::deserialize(value)
295    }
296}
297
298#[derive(Copy, Clone)]
299pub struct KeyPair {
300    pub public_key: PublicKey,
301    pub private_key: PrivateKey,
302}
303
304impl KeyPair {
305    pub fn generate<R: Rng + CryptoRng>(csprng: &mut R) -> Self {
306        let private_key = curve25519::PrivateKey::new(csprng);
307
308        let public_key = PublicKey::from(PublicKeyData::DjbPublicKey(
309            private_key.derive_public_key_bytes(),
310        ));
311        let private_key = PrivateKey::from(PrivateKeyData::DjbPrivateKey(
312            private_key.private_key_bytes(),
313        ));
314
315        Self {
316            public_key,
317            private_key,
318        }
319    }
320
321    pub fn new(public_key: PublicKey, private_key: PrivateKey) -> Self {
322        Self {
323            public_key,
324            private_key,
325        }
326    }
327
328    pub fn from_public_and_private(public_key: &[u8], private_key: &[u8]) -> Result<Self> {
329        let public_key = PublicKey::try_from(public_key)?;
330        let private_key = PrivateKey::try_from(private_key)?;
331        Ok(Self {
332            public_key,
333            private_key,
334        })
335    }
336
337    pub fn calculate_signature<R: CryptoRng + Rng>(
338        &self,
339        message: &[u8],
340        csprng: &mut R,
341    ) -> Result<Box<[u8]>> {
342        self.private_key.calculate_signature(message, csprng)
343    }
344
345    pub fn calculate_agreement(&self, their_key: &PublicKey) -> Result<Box<[u8]>> {
346        self.private_key.calculate_agreement(their_key)
347    }
348}
349
350impl TryFrom<PrivateKey> for KeyPair {
351    type Error = SignalProtocolError;
352
353    fn try_from(value: PrivateKey) -> Result<Self> {
354        let public_key = value.public_key()?;
355        Ok(Self::new(public_key, value))
356    }
357}
358
#[cfg(test)]
mod tests {
    use rand::rngs::OsRng;

    use super::*;

    #[test]
    fn test_large_signatures() -> Result<()> {
        let mut csprng = OsRng;
        let key_pair = KeyPair::generate(&mut csprng);
        // 1 MiB message. It lives on the heap rather than as a stack array so the
        // test does not depend on the test thread's stack size.
        let mut message = vec![0u8; 1024 * 1024];
        let signature = key_pair
            .private_key
            .calculate_signature(&message, &mut csprng)?;

        assert!(key_pair.public_key.verify_signature(&message, &signature)?);
        message[0] ^= 0x01u8;
        assert!(!key_pair.public_key.verify_signature(&message, &signature)?);
        message[0] ^= 0x01u8;
        let public_key = key_pair.private_key.public_key()?;
        assert!(public_key.verify_signature(&message, &signature)?);

        assert!(public_key
            .verify_signature_for_multipart_message(&[&message[..7], &message[7..]], &signature)?);

        let signature = key_pair
            .private_key
            .calculate_signature_for_multipart_message(
                &[&message[..20], &message[20..]],
                &mut csprng,
            )?;
        assert!(public_key.verify_signature(&message, &signature)?);

        Ok(())
    }

    #[test]
    fn test_decode_size() -> Result<()> {
        let mut csprng = OsRng;
        let key_pair = KeyPair::generate(&mut csprng);
        let serialized_public = key_pair.public_key.serialize();

        assert_eq!(
            serialized_public,
            key_pair.private_key.public_key()?.serialize()
        );
        let empty: [u8; 0] = [];

        let just_right = PublicKey::try_from(&serialized_public[..]);

        assert!(just_right.is_ok());
        assert!(PublicKey::try_from(&serialized_public[1..]).is_err());
        assert!(PublicKey::try_from(&empty[..]).is_err());

        let mut bad_key_type = [0u8; 33];
        bad_key_type[..].copy_from_slice(&serialized_public[..]);
        bad_key_type[0] = 0x01u8;
        assert!(PublicKey::try_from(&bad_key_type[..]).is_err());

        let mut extra_space = [0u8; 34];
        extra_space[..33].copy_from_slice(&serialized_public[..]);
        let extra_space_decode = PublicKey::try_from(&extra_space[..]);
        assert!(extra_space_decode.is_ok());

        assert_eq!(&serialized_public[..], &just_right?.serialize()[..]);
        assert_eq!(&serialized_public[..], &extra_space_decode?.serialize()[..]);
        Ok(())
    }

    #[test]
    fn test_private_key_equality() {
        let mut csprng = OsRng;
        let first = KeyPair::generate(&mut csprng).private_key;
        let second = KeyPair::generate(&mut csprng).private_key;
        let first_copy = first;

        assert!(first == first_copy);
        // Two independently generated keys are equal only with negligible probability.
        assert!(first != second);
    }
}