libsignal_protocol/ratchet/keys.rs

//
// Copyright 2020 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//
5
6use std::fmt;
7
8use libsignal_core::derive_arrays;
9
10use crate::proto::storage::session_structure;
11use crate::{PrivateKey, PublicKey, Result, crypto};
12
13pub(crate) enum MessageKeyGenerator {
14    Keys(MessageKeys),
15    Seed((Vec<u8>, u32)),
16}
17
18impl MessageKeyGenerator {
19    pub(crate) fn new_from_seed(seed: &[u8], counter: u32) -> Self {
20        Self::Seed((seed.to_vec(), counter))
21    }
22    pub(crate) fn generate_keys(self, pqr_key: spqr::MessageKey) -> MessageKeys {
23        match self {
24            Self::Seed((seed, counter)) => {
25                MessageKeys::derive_keys(&seed, pqr_key.as_deref(), counter)
26            }
27            Self::Keys(k) => {
28                // PQR keys should only be set for newer sessions, and in
29                // newer sessions there should be only seed-based generators.
30                assert!(pqr_key.is_none());
31                k
32            }
33        }
34    }
35    pub(crate) fn into_pb(self) -> session_structure::chain::MessageKey {
36        match self {
37            Self::Keys(k) => session_structure::chain::MessageKey {
38                cipher_key: k.cipher_key().to_vec(),
39                mac_key: k.mac_key().to_vec(),
40                iv: k.iv().to_vec(),
41                index: k.counter(),
42                seed: vec![],
43            },
44            Self::Seed((seed, counter)) => session_structure::chain::MessageKey {
45                cipher_key: vec![],
46                mac_key: vec![],
47                iv: vec![],
48                index: counter,
49                seed,
50            },
51        }
52    }
53    pub(crate) fn from_pb(
54        pb: session_structure::chain::MessageKey,
55    ) -> std::result::Result<Self, &'static str> {
56        Ok(if pb.seed.is_empty() {
57            Self::Keys(MessageKeys {
58                cipher_key: pb
59                    .cipher_key
60                    .as_slice()
61                    .try_into()
62                    .map_err(|_| "invalid message cipher key")?,
63                mac_key: pb
64                    .mac_key
65                    .as_slice()
66                    .try_into()
67                    .map_err(|_| "invalid message MAC key")?,
68                iv: pb
69                    .iv
70                    .as_slice()
71                    .try_into()
72                    .map_err(|_| "invalid message IV")?,
73                counter: pb.index,
74            })
75        } else {
76            Self::Seed((pb.seed, pb.index))
77        })
78    }
79}
80
/// Per-message key material, plus the chain counter it was derived for.
#[derive(Copy, Clone)]
pub(crate) struct MessageKeys {
    /// 32-byte cipher key.
    cipher_key: [u8; 32],
    /// 32-byte MAC key.
    mac_key: [u8; 32],
    /// 16-byte initialization vector.
    iv: [u8; 16],
    /// Chain index these keys belong to.
    counter: u32,
}
88
89impl MessageKeys {
90    pub(crate) fn derive_keys(
91        input_key_material: &[u8],
92        optional_salt: Option<&[u8]>,
93        counter: u32,
94    ) -> Self {
95        let (cipher_key, mac_key, iv) = derive_arrays(|okm| {
96            hkdf::Hkdf::<sha2::Sha256>::new(optional_salt, input_key_material)
97                .expand(b"WhisperMessageKeys", okm)
98                .expect("valid output length")
99        });
100
101        MessageKeys {
102            cipher_key,
103            mac_key,
104            iv,
105            counter,
106        }
107    }
108
109    #[inline]
110    pub(crate) fn cipher_key(&self) -> &[u8; 32] {
111        &self.cipher_key
112    }
113
114    #[inline]
115    pub(crate) fn mac_key(&self) -> &[u8; 32] {
116        &self.mac_key
117    }
118
119    #[inline]
120    pub(crate) fn iv(&self) -> &[u8; 16] {
121        &self.iv
122    }
123
124    #[inline]
125    pub(crate) fn counter(&self) -> u32 {
126        self.counter
127    }
128}
129
/// A chain key together with its position in the chain.
///
/// NOTE(review): `Debug` is derived, so `{:?}` prints the raw key bytes;
/// make sure values of this type never reach logs.
#[derive(Debug, Clone)]
pub(crate) struct ChainKey {
    /// Current 32-byte chain key.
    key: [u8; 32],
    /// Index of this key within its chain (`create_chain` starts at 0).
    index: u32,
}
135
136impl ChainKey {
137    const MESSAGE_KEY_SEED: [u8; 1] = [0x01u8];
138    const CHAIN_KEY_SEED: [u8; 1] = [0x02u8];
139
140    pub(crate) fn new(key: [u8; 32], index: u32) -> Self {
141        Self { key, index }
142    }
143
144    #[inline]
145    pub(crate) fn key(&self) -> &[u8; 32] {
146        &self.key
147    }
148
149    #[inline]
150    pub(crate) fn index(&self) -> u32 {
151        self.index
152    }
153
154    pub(crate) fn next_chain_key(&self) -> Self {
155        Self {
156            key: self.calculate_base_material(Self::CHAIN_KEY_SEED),
157            index: self.index + 1,
158        }
159    }
160
161    pub(crate) fn message_keys(&self) -> MessageKeyGenerator {
162        MessageKeyGenerator::new_from_seed(
163            &self.calculate_base_material(Self::MESSAGE_KEY_SEED),
164            self.index,
165        )
166    }
167
168    fn calculate_base_material(&self, seed: [u8; 1]) -> [u8; 32] {
169        crypto::hmac_sha256(&self.key, &seed)
170    }
171}
172
/// Root key from which new root and chain keys are derived by `create_chain`.
///
/// NOTE(review): `Debug` is derived, so `{:?}` prints the raw key bytes.
#[derive(Debug, Clone)]
pub(crate) struct RootKey {
    /// Raw 32-byte root key.
    key: [u8; 32],
}
177
178impl RootKey {
179    pub(crate) fn new(key: [u8; 32]) -> Self {
180        Self { key }
181    }
182
183    pub(crate) fn key(&self) -> &[u8; 32] {
184        &self.key
185    }
186
187    pub(crate) fn create_chain(
188        self,
189        their_ratchet_key: &PublicKey,
190        our_ratchet_key: &PrivateKey,
191    ) -> Result<(RootKey, ChainKey)> {
192        let shared_secret = our_ratchet_key.calculate_agreement(their_ratchet_key)?;
193        let (root_key, chain_key, []) = derive_arrays(|bytes| {
194            hkdf::Hkdf::<sha2::Sha256>::new(Some(&self.key), &shared_secret)
195                .expand(b"WhisperRatchet", bytes)
196                .expect("valid output length")
197        });
198
199        Ok((
200            RootKey { key: root_key },
201            ChainKey {
202                key: chain_key,
203                index: 0,
204            },
205        ))
206    }
207}
208
209impl fmt::Display for RootKey {
210    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
211        write!(f, "{}", hex::encode(self.key))
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// Known-answer test for chain key stepping and message key derivation.
    #[test]
    fn test_chain_key_derivation() -> Result<()> {
        let seed = [
            0x8au8, 0xb7, 0x2d, 0x6f, 0x4c, 0xc5, 0xac, 0x0d, 0x38, 0x7e, 0xaf, 0x46, 0x33, 0x78,
            0xdd, 0xb2, 0x8e, 0xdd, 0x07, 0x38, 0x5b, 0x1c, 0xb0, 0x12, 0x50, 0xc7, 0x15, 0x98,
            0x2e, 0x7a, 0xd4, 0x8f,
        ];
        let message_key = [
            0xbfu8, 0x51, 0xe9, 0xd7, 0x5e, 0x0e, 0x31, 0x03, 0x10, 0x51, 0xf8, 0x2a, 0x24, 0x91,
            0xff, 0xc0, 0x84, 0xfa, 0x29, 0x8b, 0x77, 0x93, 0xbd, 0x9d, 0xb6, 0x20, 0x05, 0x6f,
            0xeb, 0xf4, 0x52, 0x17,
        ];
        let mac_key = [
            0xc6u8, 0xc7, 0x7d, 0x6a, 0x73, 0xa3, 0x54, 0x33, 0x7a, 0x56, 0x43, 0x5e, 0x34, 0x60,
            0x7d, 0xfe, 0x48, 0xe3, 0xac, 0xe1, 0x4e, 0x77, 0x31, 0x4d, 0xc6, 0xab, 0xc1, 0x72,
            0xe7, 0xa7, 0x03, 0x0b,
        ];
        let next_chain_key = [
            0x28u8, 0xe8, 0xf8, 0xfe, 0xe5, 0x4b, 0x80, 0x1e, 0xef, 0x7c, 0x5c, 0xfb, 0x2f, 0x17,
            0xf3, 0x2c, 0x7b, 0x33, 0x44, 0x85, 0xbb, 0xb7, 0x0f, 0xac, 0x6e, 0xc1, 0x03, 0x42,
            0xa2, 0x46, 0xd1, 0x5d,
        ];

        let chain_key = ChainKey::new(seed, 0);
        assert_eq!(&seed, chain_key.key());
        assert_eq!(
            &message_key,
            chain_key.message_keys().generate_keys(None).cipher_key()
        );
        assert_eq!(
            &mac_key,
            chain_key.message_keys().generate_keys(None).mac_key()
        );
        assert_eq!(&next_chain_key, chain_key.next_chain_key().key());
        assert_eq!(0, chain_key.index());
        assert_eq!(0, chain_key.message_keys().generate_keys(None).counter());
        assert_eq!(1, chain_key.next_chain_key().index());
        assert_eq!(
            1,
            chain_key
                .next_chain_key()
                .message_keys()
                .generate_keys(None)
                .counter()
        );
        Ok(())
    }

    /// Both generator variants must survive a protobuf round trip and yield
    /// the same keys and counter as before serialization.
    #[test]
    fn test_message_key_generator_pb_round_trip() {
        let chain_key = ChainKey::new([0x42u8; 32], 7);
        let expected = chain_key.message_keys().generate_keys(None);

        let from_seed = MessageKeyGenerator::from_pb(chain_key.message_keys().into_pb())
            .expect("seed-based record should parse")
            .generate_keys(None);
        let from_keys =
            MessageKeyGenerator::from_pb(MessageKeyGenerator::Keys(expected).into_pb())
                .expect("key-based record should parse")
                .generate_keys(None);

        for restored in [from_seed, from_keys] {
            assert_eq!(expected.cipher_key(), restored.cipher_key());
            assert_eq!(expected.mac_key(), restored.mac_key());
            assert_eq!(expected.iv(), restored.iv());
            assert_eq!(7, restored.counter());
        }
    }

    /// A seedless record with a wrong-length cipher key must be rejected
    /// with the matching error message.
    #[test]
    fn test_message_key_generator_rejects_malformed_keys() {
        let pb = session_structure::chain::MessageKey {
            cipher_key: vec![0u8; 31],
            mac_key: vec![0u8; 32],
            iv: vec![0u8; 16],
            index: 0,
            seed: vec![],
        };
        assert_eq!(
            Some("invalid message cipher key"),
            MessageKeyGenerator::from_pb(pb).err()
        );
    }
}