libsignal_protocol/ratchet/
keys.rs

1//
2// Copyright 2020 Signal Messenger, LLC.
3// SPDX-License-Identifier: AGPL-3.0-only
4//
5
6use std::fmt;
7
8use libsignal_core::derive_arrays;
9
10use crate::proto::storage::session_structure;
11use crate::{PrivateKey, PublicKey, Result, crypto};
12
13pub(crate) enum MessageKeyGenerator {
14    Keys(MessageKeys),
15    Seed((Vec<u8>, u32)),
16}
17
18impl MessageKeyGenerator {
19    pub(crate) fn new_from_seed(seed: &[u8], counter: u32) -> Self {
20        Self::Seed((seed.to_vec(), counter))
21    }
22    pub(crate) fn generate_keys(self, pqr_key: spqr::MessageKey) -> MessageKeys {
23        match self {
24            Self::Seed((seed, counter)) => {
25                MessageKeys::derive_keys(&seed, pqr_key.as_deref(), counter)
26            }
27            Self::Keys(k) => {
28                // PQR keys should only be set for newer sessions, and in
29                // newer sessions there should be only seed-based generators.
30                assert!(pqr_key.is_none());
31                k
32            }
33        }
34    }
35    pub(crate) fn into_pb(self) -> session_structure::chain::MessageKey {
36        match self {
37            Self::Keys(k) => session_structure::chain::MessageKey {
38                cipher_key: k.cipher_key().to_vec(),
39                mac_key: k.mac_key().to_vec(),
40                iv: k.iv().to_vec(),
41                index: k.counter(),
42                seed: vec![],
43            },
44            Self::Seed((seed, counter)) => session_structure::chain::MessageKey {
45                cipher_key: vec![],
46                mac_key: vec![],
47                iv: vec![],
48                index: counter,
49                seed,
50            },
51        }
52    }
53    pub(crate) fn from_pb(
54        pb: session_structure::chain::MessageKey,
55    ) -> std::result::Result<Self, &'static str> {
56        Ok(if pb.seed.is_empty() {
57            Self::Keys(MessageKeys {
58                cipher_key: pb
59                    .cipher_key
60                    .as_slice()
61                    .try_into()
62                    .map_err(|_| "invalid message cipher key")?,
63                mac_key: pb
64                    .mac_key
65                    .as_slice()
66                    .try_into()
67                    .map_err(|_| "invalid message MAC key")?,
68                iv: pb
69                    .iv
70                    .as_slice()
71                    .try_into()
72                    .map_err(|_| "invalid message IV")?,
73                counter: pb.index,
74            })
75        } else {
76            Self::Seed((pb.seed, pb.index))
77        })
78    }
79}
80
/// Concrete key material for one message: cipher key, MAC key, IV, and the chain
/// counter of the message the keys belong to.
#[derive(Copy, Clone)]
pub(crate) struct MessageKeys {
    cipher_key: [u8; 32],
    mac_key: [u8; 32],
    iv: [u8; 16],
    counter: u32,
}
88
89impl MessageKeys {
90    pub(crate) fn derive_keys(
91        input_key_material: &[u8],
92        optional_salt: Option<&[u8]>,
93        counter: u32,
94    ) -> Self {
95        let _trace = libsignal_debug::trace_block!("MessageKeys::derive_keys");
96        let (cipher_key, mac_key, iv) = derive_arrays(|okm| {
97            hkdf::Hkdf::<sha2::Sha256>::new(optional_salt, input_key_material)
98                .expand(b"WhisperMessageKeys", okm)
99                .expect("valid output length")
100        });
101
102        MessageKeys {
103            cipher_key,
104            mac_key,
105            iv,
106            counter,
107        }
108    }
109
110    #[inline]
111    pub(crate) fn cipher_key(&self) -> &[u8; 32] {
112        &self.cipher_key
113    }
114
115    #[inline]
116    pub(crate) fn mac_key(&self) -> &[u8; 32] {
117        &self.mac_key
118    }
119
120    #[inline]
121    pub(crate) fn iv(&self) -> &[u8; 16] {
122        &self.iv
123    }
124
125    #[inline]
126    pub(crate) fn counter(&self) -> u32 {
127        self.counter
128    }
129}
130
/// A symmetric-ratchet chain key together with its position (`index`) in the chain.
// NOTE(review): the derived `Debug` prints the raw key bytes; confirm it never reaches logs.
#[derive(Clone, Debug)]
pub(crate) struct ChainKey {
    key: [u8; 32],
    index: u32,
}
136
137impl ChainKey {
138    const MESSAGE_KEY_SEED: [u8; 1] = [0x01u8];
139    const CHAIN_KEY_SEED: [u8; 1] = [0x02u8];
140
141    pub(crate) fn new(key: [u8; 32], index: u32) -> Self {
142        Self { key, index }
143    }
144
145    #[inline]
146    pub(crate) fn key(&self) -> &[u8; 32] {
147        &self.key
148    }
149
150    #[inline]
151    pub(crate) fn index(&self) -> u32 {
152        self.index
153    }
154
155    pub(crate) fn next_chain_key(&self) -> Self {
156        Self {
157            key: self.calculate_base_material(Self::CHAIN_KEY_SEED),
158            index: self.index + 1,
159        }
160    }
161
162    pub(crate) fn message_keys(&self) -> MessageKeyGenerator {
163        MessageKeyGenerator::new_from_seed(
164            &self.calculate_base_material(Self::MESSAGE_KEY_SEED),
165            self.index,
166        )
167    }
168
169    fn calculate_base_material(&self, seed: [u8; 1]) -> [u8; 32] {
170        let _trace = libsignal_debug::trace_block!("keys::calculate_base_material");
171        crypto::hmac_sha256(&self.key, &seed)
172    }
173}
174
/// The ratchet root key from which `create_chain` derives new root/chain key pairs.
// NOTE(review): the derived `Debug` prints the raw key bytes; confirm it never reaches logs.
#[derive(Clone, Debug)]
pub(crate) struct RootKey {
    key: [u8; 32],
}
179
180impl RootKey {
181    pub(crate) fn new(key: [u8; 32]) -> Self {
182        Self { key }
183    }
184
185    pub(crate) fn key(&self) -> &[u8; 32] {
186        &self.key
187    }
188
189    pub(crate) fn create_chain(
190        self,
191        their_ratchet_key: &PublicKey,
192        our_ratchet_key: &PrivateKey,
193    ) -> Result<(RootKey, ChainKey)> {
194        let shared_secret = our_ratchet_key.calculate_agreement(their_ratchet_key)?;
195        let (root_key, chain_key, []) = derive_arrays(|bytes| {
196            hkdf::Hkdf::<sha2::Sha256>::new(Some(&self.key), &shared_secret)
197                .expand(b"WhisperRatchet", bytes)
198                .expect("valid output length")
199        });
200
201        Ok((
202            RootKey { key: root_key },
203            ChainKey {
204                key: chain_key,
205                index: 0,
206            },
207        ))
208    }
209}
210
211impl fmt::Display for RootKey {
212    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
213        write!(f, "{}", hex::encode(self.key))
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a key-based (seed-less) protobuf record whose key fields have the given lengths.
    fn key_record(
        cipher_key_len: usize,
        mac_key_len: usize,
        iv_len: usize,
    ) -> session_structure::chain::MessageKey {
        session_structure::chain::MessageKey {
            cipher_key: vec![0xaa; cipher_key_len],
            mac_key: vec![0xbb; mac_key_len],
            iv: vec![0xcc; iv_len],
            index: 42,
            seed: vec![],
        }
    }

    #[test]
    fn test_chain_key_derivation() -> Result<()> {
        let seed = [
            0x8au8, 0xb7, 0x2d, 0x6f, 0x4c, 0xc5, 0xac, 0x0d, 0x38, 0x7e, 0xaf, 0x46, 0x33, 0x78,
            0xdd, 0xb2, 0x8e, 0xdd, 0x07, 0x38, 0x5b, 0x1c, 0xb0, 0x12, 0x50, 0xc7, 0x15, 0x98,
            0x2e, 0x7a, 0xd4, 0x8f,
        ];
        let message_key = [
            0xbfu8, 0x51, 0xe9, 0xd7, 0x5e, 0x0e, 0x31, 0x03, 0x10, 0x51, 0xf8, 0x2a, 0x24, 0x91,
            0xff, 0xc0, 0x84, 0xfa, 0x29, 0x8b, 0x77, 0x93, 0xbd, 0x9d, 0xb6, 0x20, 0x05, 0x6f,
            0xeb, 0xf4, 0x52, 0x17,
        ];
        let mac_key = [
            0xc6u8, 0xc7, 0x7d, 0x6a, 0x73, 0xa3, 0x54, 0x33, 0x7a, 0x56, 0x43, 0x5e, 0x34, 0x60,
            0x7d, 0xfe, 0x48, 0xe3, 0xac, 0xe1, 0x4e, 0x77, 0x31, 0x4d, 0xc6, 0xab, 0xc1, 0x72,
            0xe7, 0xa7, 0x03, 0x0b,
        ];
        let next_chain_key = [
            0x28u8, 0xe8, 0xf8, 0xfe, 0xe5, 0x4b, 0x80, 0x1e, 0xef, 0x7c, 0x5c, 0xfb, 0x2f, 0x17,
            0xf3, 0x2c, 0x7b, 0x33, 0x44, 0x85, 0xbb, 0xb7, 0x0f, 0xac, 0x6e, 0xc1, 0x03, 0x42,
            0xa2, 0x46, 0xd1, 0x5d,
        ];

        let chain_key = ChainKey::new(seed, 0);
        assert_eq!(&seed, chain_key.key());
        assert_eq!(
            &message_key,
            chain_key.message_keys().generate_keys(None).cipher_key()
        );
        assert_eq!(
            &mac_key,
            chain_key.message_keys().generate_keys(None).mac_key()
        );
        assert_eq!(&next_chain_key, chain_key.next_chain_key().key());
        assert_eq!(0, chain_key.index());
        assert_eq!(0, chain_key.message_keys().generate_keys(None).counter());
        assert_eq!(1, chain_key.next_chain_key().index());
        assert_eq!(
            1,
            chain_key
                .next_chain_key()
                .message_keys()
                .generate_keys(None)
                .counter()
        );
        Ok(())
    }

    /// A seed-based generator serializes with empty key fields and deserializes back to the
    /// same seed and counter.
    #[test]
    fn test_seed_generator_round_trips_through_pb() {
        let seed = vec![0x11u8; 32];
        let pb = MessageKeyGenerator::new_from_seed(&seed, 7).into_pb();
        assert!(pb.cipher_key.is_empty());
        assert!(pb.mac_key.is_empty());
        assert!(pb.iv.is_empty());
        assert_eq!(7, pb.index);
        assert_eq!(seed, pb.seed);

        match MessageKeyGenerator::from_pb(pb).expect("seed record is valid") {
            MessageKeyGenerator::Seed((restored_seed, counter)) => {
                assert_eq!(seed, restored_seed);
                assert_eq!(7, counter);
            }
            MessageKeyGenerator::Keys(_) => panic!("expected a seed-based generator"),
        }
    }

    /// A key-based generator serializes with an empty seed and deserializes back to
    /// identical keys (returned unchanged by `generate_keys(None)`).
    #[test]
    fn test_key_generator_round_trips_through_pb() {
        let keys = MessageKeys {
            cipher_key: [0x01; 32],
            mac_key: [0x02; 32],
            iv: [0x03; 16],
            counter: 9,
        };
        let pb = MessageKeyGenerator::Keys(keys).into_pb();
        assert!(pb.seed.is_empty());
        assert_eq!(9, pb.index);

        let restored = MessageKeyGenerator::from_pb(pb)
            .expect("key record is valid")
            .generate_keys(None);
        assert_eq!(keys.cipher_key(), restored.cipher_key());
        assert_eq!(keys.mac_key(), restored.mac_key());
        assert_eq!(keys.iv(), restored.iv());
        assert_eq!(keys.counter(), restored.counter());
    }

    /// Seed-less records with a wrongly sized key field are rejected with a field-specific
    /// error.
    #[test]
    fn test_from_pb_rejects_malformed_key_lengths() {
        assert!(MessageKeyGenerator::from_pb(key_record(32, 32, 16)).is_ok());
        assert_eq!(
            Some("invalid message cipher key"),
            MessageKeyGenerator::from_pb(key_record(31, 32, 16)).err()
        );
        assert_eq!(
            Some("invalid message MAC key"),
            MessageKeyGenerator::from_pb(key_record(32, 33, 16)).err()
        );
        assert_eq!(
            Some("invalid message IV"),
            MessageKeyGenerator::from_pb(key_record(32, 32, 0)).err()
        );
    }
}
268}