libsignal_protocol/ratchet/keys.rs

//
// Copyright 2020 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//

6use std::fmt;
7
8use zerocopy::{FromBytes, IntoBytes, KnownLayout};
9
10use crate::proto::storage::session_structure;
11use crate::{PrivateKey, PublicKey, Result, crypto};
12
13pub(crate) enum MessageKeyGenerator {
14    Keys(MessageKeys),
15    Seed((Vec<u8>, u32)),
16}
17
18impl MessageKeyGenerator {
19    pub(crate) fn new_from_seed(seed: &[u8], counter: u32) -> Self {
20        Self::Seed((seed.to_vec(), counter))
21    }
22    pub(crate) fn generate_keys(self, pqr_key: spqr::MessageKey) -> MessageKeys {
23        match self {
24            Self::Seed((seed, counter)) => {
25                MessageKeys::derive_keys(&seed, pqr_key.as_deref(), counter)
26            }
27            Self::Keys(k) => {
28                // PQR keys should only be set for newer sessions, and in
29                // newer sessions there should be only seed-based generators.
30                assert!(pqr_key.is_none());
31                k
32            }
33        }
34    }
35    pub(crate) fn into_pb(self) -> session_structure::chain::MessageKey {
36        match self {
37            Self::Keys(k) => session_structure::chain::MessageKey {
38                cipher_key: k.cipher_key().to_vec(),
39                mac_key: k.mac_key().to_vec(),
40                iv: k.iv().to_vec(),
41                index: k.counter(),
42                seed: vec![],
43            },
44            Self::Seed((seed, counter)) => session_structure::chain::MessageKey {
45                cipher_key: vec![],
46                mac_key: vec![],
47                iv: vec![],
48                index: counter,
49                seed,
50            },
51        }
52    }
53    pub(crate) fn from_pb(
54        pb: session_structure::chain::MessageKey,
55    ) -> std::result::Result<Self, &'static str> {
56        Ok(if pb.seed.is_empty() {
57            Self::Keys(MessageKeys {
58                cipher_key: pb
59                    .cipher_key
60                    .as_slice()
61                    .try_into()
62                    .map_err(|_| "invalid message cipher key")?,
63                mac_key: pb
64                    .mac_key
65                    .as_slice()
66                    .try_into()
67                    .map_err(|_| "invalid message MAC key")?,
68                iv: pb
69                    .iv
70                    .as_slice()
71                    .try_into()
72                    .map_err(|_| "invalid message IV")?,
73                counter: pb.index,
74            })
75        } else {
76            Self::Seed((pb.seed, pb.index))
77        })
78    }
79}
80
81#[derive(Clone, Copy)]
82pub(crate) struct MessageKeys {
83    cipher_key: [u8; 32],
84    mac_key: [u8; 32],
85    iv: [u8; 16],
86    counter: u32,
87}
88
89impl MessageKeys {
90    pub(crate) fn derive_keys(
91        input_key_material: &[u8],
92        optional_salt: Option<&[u8]>,
93        counter: u32,
94    ) -> Self {
95        #[derive(Default, KnownLayout, IntoBytes, FromBytes)]
96        #[repr(C, packed)]
97        struct DerivedSecretBytes([u8; 32], [u8; 32], [u8; 16]);
98        let mut okm = DerivedSecretBytes::default();
99
100        hkdf::Hkdf::<sha2::Sha256>::new(optional_salt, input_key_material)
101            .expand(b"WhisperMessageKeys", okm.as_mut_bytes())
102            .expect("valid output length");
103
104        let DerivedSecretBytes(cipher_key, mac_key, iv) = okm;
105
106        MessageKeys {
107            cipher_key,
108            mac_key,
109            iv,
110            counter,
111        }
112    }
113
114    #[inline]
115    pub(crate) fn cipher_key(&self) -> &[u8; 32] {
116        &self.cipher_key
117    }
118
119    #[inline]
120    pub(crate) fn mac_key(&self) -> &[u8; 32] {
121        &self.mac_key
122    }
123
124    #[inline]
125    pub(crate) fn iv(&self) -> &[u8; 16] {
126        &self.iv
127    }
128
129    #[inline]
130    pub(crate) fn counter(&self) -> u32 {
131        self.counter
132    }
133}
134
135#[derive(Clone, Debug)]
136pub(crate) struct ChainKey {
137    key: [u8; 32],
138    index: u32,
139}
140
141impl ChainKey {
142    const MESSAGE_KEY_SEED: [u8; 1] = [0x01u8];
143    const CHAIN_KEY_SEED: [u8; 1] = [0x02u8];
144
145    pub(crate) fn new(key: [u8; 32], index: u32) -> Self {
146        Self { key, index }
147    }
148
149    #[inline]
150    pub(crate) fn key(&self) -> &[u8; 32] {
151        &self.key
152    }
153
154    #[inline]
155    pub(crate) fn index(&self) -> u32 {
156        self.index
157    }
158
159    pub(crate) fn next_chain_key(&self) -> Self {
160        Self {
161            key: self.calculate_base_material(Self::CHAIN_KEY_SEED),
162            index: self.index + 1,
163        }
164    }
165
166    pub(crate) fn message_keys(&self) -> MessageKeyGenerator {
167        MessageKeyGenerator::new_from_seed(
168            &self.calculate_base_material(Self::MESSAGE_KEY_SEED),
169            self.index,
170        )
171    }
172
173    fn calculate_base_material(&self, seed: [u8; 1]) -> [u8; 32] {
174        crypto::hmac_sha256(&self.key, &seed)
175    }
176}
177
178#[derive(Clone, Debug)]
179pub(crate) struct RootKey {
180    key: [u8; 32],
181}
182
183impl RootKey {
184    pub(crate) fn new(key: [u8; 32]) -> Self {
185        Self { key }
186    }
187
188    pub(crate) fn key(&self) -> &[u8; 32] {
189        &self.key
190    }
191
192    pub(crate) fn create_chain(
193        self,
194        their_ratchet_key: &PublicKey,
195        our_ratchet_key: &PrivateKey,
196    ) -> Result<(RootKey, ChainKey)> {
197        let shared_secret = our_ratchet_key.calculate_agreement(their_ratchet_key)?;
198        #[derive(Default, KnownLayout, IntoBytes, FromBytes)]
199        #[repr(C, packed)]
200        struct DerivedSecretBytes([u8; 32], [u8; 32]);
201        let mut derived_secret_bytes = DerivedSecretBytes::default();
202
203        hkdf::Hkdf::<sha2::Sha256>::new(Some(&self.key), &shared_secret)
204            .expand(b"WhisperRatchet", derived_secret_bytes.as_mut_bytes())
205            .expect("valid output length");
206
207        let DerivedSecretBytes(root_key, chain_key) = derived_secret_bytes;
208
209        Ok((
210            RootKey { key: root_key },
211            ChainKey {
212                key: chain_key,
213                index: 0,
214            },
215        ))
216    }
217}
218
219impl fmt::Display for RootKey {
220    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
221        write!(f, "{}", hex::encode(self.key))
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Known-answer test for chain-key and message-key derivation at index 0 and 1.
    #[test]
    fn test_chain_key_derivation() -> Result<()> {
        let seed: [u8; 32] = [
            0x8a, 0xb7, 0x2d, 0x6f, 0x4c, 0xc5, 0xac, 0x0d, 0x38, 0x7e, 0xaf, 0x46, 0x33, 0x78,
            0xdd, 0xb2, 0x8e, 0xdd, 0x07, 0x38, 0x5b, 0x1c, 0xb0, 0x12, 0x50, 0xc7, 0x15, 0x98,
            0x2e, 0x7a, 0xd4, 0x8f,
        ];
        let expected_cipher_key: [u8; 32] = [
            0xbf, 0x51, 0xe9, 0xd7, 0x5e, 0x0e, 0x31, 0x03, 0x10, 0x51, 0xf8, 0x2a, 0x24, 0x91,
            0xff, 0xc0, 0x84, 0xfa, 0x29, 0x8b, 0x77, 0x93, 0xbd, 0x9d, 0xb6, 0x20, 0x05, 0x6f,
            0xeb, 0xf4, 0x52, 0x17,
        ];
        let expected_mac_key: [u8; 32] = [
            0xc6, 0xc7, 0x7d, 0x6a, 0x73, 0xa3, 0x54, 0x33, 0x7a, 0x56, 0x43, 0x5e, 0x34, 0x60,
            0x7d, 0xfe, 0x48, 0xe3, 0xac, 0xe1, 0x4e, 0x77, 0x31, 0x4d, 0xc6, 0xab, 0xc1, 0x72,
            0xe7, 0xa7, 0x03, 0x0b,
        ];
        let expected_next_key: [u8; 32] = [
            0x28, 0xe8, 0xf8, 0xfe, 0xe5, 0x4b, 0x80, 0x1e, 0xef, 0x7c, 0x5c, 0xfb, 0x2f, 0x17,
            0xf3, 0x2c, 0x7b, 0x33, 0x44, 0x85, 0xbb, 0xb7, 0x0f, 0xac, 0x6e, 0xc1, 0x03, 0x42,
            0xa2, 0x46, 0xd1, 0x5d,
        ];

        // Index 0: the chain key itself and the message keys derived from it.
        let chain_key = ChainKey::new(seed, 0);
        assert_eq!(&seed, chain_key.key());
        assert_eq!(0, chain_key.index());

        let keys = chain_key.message_keys().generate_keys(None);
        assert_eq!(&expected_cipher_key, keys.cipher_key());
        assert_eq!(&expected_mac_key, keys.mac_key());
        assert_eq!(0, keys.counter());

        // Index 1: one ratchet step forward.
        let advanced = chain_key.next_chain_key();
        assert_eq!(&expected_next_key, advanced.key());
        assert_eq!(1, advanced.index());
        assert_eq!(1, advanced.message_keys().generate_keys(None).counter());

        Ok(())
    }
}