// Source: libsignal_protocol/ratchet/keys.rs

1//
2// Copyright 2020 Signal Messenger, LLC.
3// SPDX-License-Identifier: AGPL-3.0-only
4//
5
6use std::fmt;
7
8use libsignal_core::derive_arrays;
9
10use crate::proto::storage::session_structure;
11use crate::{PrivateKey, PublicKey, Result, crypto};
12
13#[derive(Clone)]
14pub(crate) enum MessageKeyGenerator {
15    Keys(MessageKeys),
16    Seed((Vec<u8>, u32)),
17}
18
19impl fmt::Debug for MessageKeyGenerator {
20    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
21        match self {
22            Self::Keys(k) => f.debug_tuple("Keys").field(&k.counter()).finish(),
23            Self::Seed((_, idx)) => f.debug_tuple("Seed").field(idx).finish(),
24        }
25    }
26}
27
28impl MessageKeyGenerator {
29    pub(crate) fn new_from_seed(seed: &[u8], counter: u32) -> Self {
30        Self::Seed((seed.to_vec(), counter))
31    }
32    pub(crate) fn generate_keys(self, pqr_key: spqr::MessageKey) -> MessageKeys {
33        match self {
34            Self::Seed((seed, counter)) => {
35                MessageKeys::derive_keys(&seed, pqr_key.as_deref(), counter)
36            }
37            Self::Keys(k) => {
38                // PQR keys should only be set for newer sessions, and in
39                // newer sessions there should be only seed-based generators.
40                assert!(pqr_key.is_none());
41                k
42            }
43        }
44    }
45    pub(crate) fn into_pb(self) -> session_structure::chain::MessageKey {
46        match self {
47            Self::Keys(k) => session_structure::chain::MessageKey {
48                cipher_key: k.cipher_key().to_vec(),
49                mac_key: k.mac_key().to_vec(),
50                iv: k.iv().to_vec(),
51                index: k.counter(),
52                seed: vec![],
53            },
54            Self::Seed((seed, counter)) => session_structure::chain::MessageKey {
55                cipher_key: vec![],
56                mac_key: vec![],
57                iv: vec![],
58                index: counter,
59                seed,
60            },
61        }
62    }
63    pub(crate) fn from_pb(
64        pb: session_structure::chain::MessageKey,
65    ) -> std::result::Result<Self, &'static str> {
66        Ok(if pb.seed.is_empty() {
67            Self::Keys(MessageKeys {
68                cipher_key: pb
69                    .cipher_key
70                    .as_slice()
71                    .try_into()
72                    .map_err(|_| "invalid message cipher key")?,
73                mac_key: pb
74                    .mac_key
75                    .as_slice()
76                    .try_into()
77                    .map_err(|_| "invalid message MAC key")?,
78                iv: pb
79                    .iv
80                    .as_slice()
81                    .try_into()
82                    .map_err(|_| "invalid message IV")?,
83                counter: pb.index,
84            })
85        } else {
86            Self::Seed((pb.seed, pb.index))
87        })
88    }
89}
90
/// The per-message key set: cipher key, MAC key and IV, tagged with the
/// counter of the message they belong to.
#[derive(Copy, Clone)]
pub(crate) struct MessageKeys {
    /// 32-byte cipher key.
    cipher_key: [u8; 32],
    /// 32-byte MAC key.
    mac_key: [u8; 32],
    /// 16-byte initialization vector.
    iv: [u8; 16],
    /// Message counter these keys were created for.
    counter: u32,
}
98
99impl MessageKeys {
100    pub(crate) fn derive_keys(
101        input_key_material: &[u8],
102        optional_salt: Option<&[u8]>,
103        counter: u32,
104    ) -> Self {
105        let _trace = libsignal_debug::trace_block!("MessageKeys::derive_keys");
106        let (cipher_key, mac_key, iv) = derive_arrays(|okm| {
107            hkdf::Hkdf::<sha2::Sha256>::new(optional_salt, input_key_material)
108                .expand(b"WhisperMessageKeys", okm)
109                .expect("valid output length")
110        });
111
112        MessageKeys {
113            cipher_key,
114            mac_key,
115            iv,
116            counter,
117        }
118    }
119
120    #[inline]
121    pub(crate) fn cipher_key(&self) -> &[u8; 32] {
122        &self.cipher_key
123    }
124
125    #[inline]
126    pub(crate) fn mac_key(&self) -> &[u8; 32] {
127        &self.mac_key
128    }
129
130    #[inline]
131    pub(crate) fn iv(&self) -> &[u8; 16] {
132        &self.iv
133    }
134
135    #[inline]
136    pub(crate) fn counter(&self) -> u32 {
137        self.counter
138    }
139}
140
/// One link of a ratchet chain: a 32-byte key and its position in the chain.
///
/// NOTE(review): the derived `Debug` output includes the raw `key` bytes,
/// unlike the hand-written `Debug` for `MessageKeyGenerator`; confirm that
/// this is intended before logging values of this type.
#[derive(Debug, Clone)]
pub(crate) struct ChainKey {
    /// 32-byte chain key material.
    key: [u8; 32],
    /// Position of this key in its chain.
    index: u32,
}
146
147impl ChainKey {
148    const MESSAGE_KEY_SEED: [u8; 1] = [0x01u8];
149    const CHAIN_KEY_SEED: [u8; 1] = [0x02u8];
150
151    pub(crate) fn new(key: [u8; 32], index: u32) -> Self {
152        Self { key, index }
153    }
154
155    #[inline]
156    pub(crate) fn key(&self) -> &[u8; 32] {
157        &self.key
158    }
159
160    #[inline]
161    pub(crate) fn index(&self) -> u32 {
162        self.index
163    }
164
165    pub(crate) fn next_chain_key(&self) -> Self {
166        Self {
167            key: self.calculate_base_material(Self::CHAIN_KEY_SEED),
168            index: self.index + 1,
169        }
170    }
171
172    pub(crate) fn message_keys(&self) -> MessageKeyGenerator {
173        MessageKeyGenerator::new_from_seed(
174            &self.calculate_base_material(Self::MESSAGE_KEY_SEED),
175            self.index,
176        )
177    }
178
179    fn calculate_base_material(&self, seed: [u8; 1]) -> [u8; 32] {
180        let _trace = libsignal_debug::trace_block!("keys::calculate_base_material");
181        crypto::hmac_sha256(&self.key, &seed)
182    }
183}
184
/// The 32-byte root key from which new ratchet chains are created.
///
/// NOTE(review): the derived `Debug` output includes the raw `key` bytes;
/// confirm that this is intended before logging values of this type.
#[derive(Debug, Clone)]
pub(crate) struct RootKey {
    /// 32-byte root key material.
    key: [u8; 32],
}
189
190impl RootKey {
191    pub(crate) fn new(key: [u8; 32]) -> Self {
192        Self { key }
193    }
194
195    pub(crate) fn key(&self) -> &[u8; 32] {
196        &self.key
197    }
198
199    pub(crate) fn create_chain(
200        self,
201        their_ratchet_key: &PublicKey,
202        our_ratchet_key: &PrivateKey,
203    ) -> Result<(RootKey, ChainKey)> {
204        let shared_secret = our_ratchet_key.calculate_agreement(their_ratchet_key)?;
205        let (root_key, chain_key, []) = derive_arrays(|bytes| {
206            hkdf::Hkdf::<sha2::Sha256>::new(Some(&self.key), &shared_secret)
207                .expand(b"WhisperRatchet", bytes)
208                .expect("valid output length")
209        });
210
211        Ok((
212            RootKey { key: root_key },
213            ChainKey {
214                key: chain_key,
215                index: 0,
216            },
217        ))
218    }
219}
220
221impl fmt::Display for RootKey {
222    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
223        write!(f, "{}", hex::encode(self.key))
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Known-answer test: checks the message keys and the next chain key
    /// derived from a fixed chain key against fixed expected bytes, and checks
    /// that indices and counters advance from 0 to 1.
    #[test]
    fn test_chain_key_derivation() -> Result<()> {
        let seed = [
            0x8au8, 0xb7, 0x2d, 0x6f, 0x4c, 0xc5, 0xac, 0x0d, 0x38, 0x7e, 0xaf, 0x46, 0x33, 0x78,
            0xdd, 0xb2, 0x8e, 0xdd, 0x07, 0x38, 0x5b, 0x1c, 0xb0, 0x12, 0x50, 0xc7, 0x15, 0x98,
            0x2e, 0x7a, 0xd4, 0x8f,
        ];
        let message_key = [
            0xbfu8, 0x51, 0xe9, 0xd7, 0x5e, 0x0e, 0x31, 0x03, 0x10, 0x51, 0xf8, 0x2a, 0x24, 0x91,
            0xff, 0xc0, 0x84, 0xfa, 0x29, 0x8b, 0x77, 0x93, 0xbd, 0x9d, 0xb6, 0x20, 0x05, 0x6f,
            0xeb, 0xf4, 0x52, 0x17,
        ];
        let mac_key = [
            0xc6u8, 0xc7, 0x7d, 0x6a, 0x73, 0xa3, 0x54, 0x33, 0x7a, 0x56, 0x43, 0x5e, 0x34, 0x60,
            0x7d, 0xfe, 0x48, 0xe3, 0xac, 0xe1, 0x4e, 0x77, 0x31, 0x4d, 0xc6, 0xab, 0xc1, 0x72,
            0xe7, 0xa7, 0x03, 0x0b,
        ];
        let next_chain_key = [
            0x28u8, 0xe8, 0xf8, 0xfe, 0xe5, 0x4b, 0x80, 0x1e, 0xef, 0x7c, 0x5c, 0xfb, 0x2f, 0x17,
            0xf3, 0x2c, 0x7b, 0x33, 0x44, 0x85, 0xbb, 0xb7, 0x0f, 0xac, 0x6e, 0xc1, 0x03, 0x42,
            0xa2, 0x46, 0xd1, 0x5d,
        ];

        let chain_key = ChainKey::new(seed, 0);
        // Derive once and reuse: the derivations depend only on the chain key.
        let derived = chain_key.message_keys().generate_keys(None);
        let advanced = chain_key.next_chain_key();

        assert_eq!(&seed, chain_key.key());
        assert_eq!(&message_key, derived.cipher_key());
        assert_eq!(&mac_key, derived.mac_key());
        assert_eq!(&next_chain_key, advanced.key());
        assert_eq!(0, chain_key.index());
        assert_eq!(0, derived.counter());
        assert_eq!(1, advanced.index());
        assert_eq!(1, advanced.message_keys().generate_keys(None).counter());
        Ok(())
    }
}