1use std::fmt;
7
8use zerocopy::{FromBytes, IntoBytes, KnownLayout};
9
10use crate::proto::storage::session_structure;
11use crate::{PrivateKey, PublicKey, Result, crypto};
12
13pub(crate) enum MessageKeyGenerator {
14 Keys(MessageKeys),
15 Seed((Vec<u8>, u32)),
16}
17
18impl MessageKeyGenerator {
19 pub(crate) fn new_from_seed(seed: &[u8], counter: u32) -> Self {
20 Self::Seed((seed.to_vec(), counter))
21 }
22 pub(crate) fn generate_keys(self, pqr_key: spqr::MessageKey) -> MessageKeys {
23 match self {
24 Self::Seed((seed, counter)) => {
25 MessageKeys::derive_keys(&seed, pqr_key.as_deref(), counter)
26 }
27 Self::Keys(k) => {
28 assert!(pqr_key.is_none());
31 k
32 }
33 }
34 }
35 pub(crate) fn into_pb(self) -> session_structure::chain::MessageKey {
36 match self {
37 Self::Keys(k) => session_structure::chain::MessageKey {
38 cipher_key: k.cipher_key().to_vec(),
39 mac_key: k.mac_key().to_vec(),
40 iv: k.iv().to_vec(),
41 index: k.counter(),
42 seed: vec![],
43 },
44 Self::Seed((seed, counter)) => session_structure::chain::MessageKey {
45 cipher_key: vec![],
46 mac_key: vec![],
47 iv: vec![],
48 index: counter,
49 seed,
50 },
51 }
52 }
53 pub(crate) fn from_pb(
54 pb: session_structure::chain::MessageKey,
55 ) -> std::result::Result<Self, &'static str> {
56 Ok(if pb.seed.is_empty() {
57 Self::Keys(MessageKeys {
58 cipher_key: pb
59 .cipher_key
60 .as_slice()
61 .try_into()
62 .map_err(|_| "invalid message cipher key")?,
63 mac_key: pb
64 .mac_key
65 .as_slice()
66 .try_into()
67 .map_err(|_| "invalid message MAC key")?,
68 iv: pb
69 .iv
70 .as_slice()
71 .try_into()
72 .map_err(|_| "invalid message IV")?,
73 counter: pb.index,
74 })
75 } else {
76 Self::Seed((pb.seed, pb.index))
77 })
78 }
79}
80
81#[derive(Clone, Copy)]
82pub(crate) struct MessageKeys {
83 cipher_key: [u8; 32],
84 mac_key: [u8; 32],
85 iv: [u8; 16],
86 counter: u32,
87}
88
89impl MessageKeys {
90 pub(crate) fn derive_keys(
91 input_key_material: &[u8],
92 optional_salt: Option<&[u8]>,
93 counter: u32,
94 ) -> Self {
95 #[derive(Default, KnownLayout, IntoBytes, FromBytes)]
96 #[repr(C, packed)]
97 struct DerivedSecretBytes([u8; 32], [u8; 32], [u8; 16]);
98 let mut okm = DerivedSecretBytes::default();
99
100 hkdf::Hkdf::<sha2::Sha256>::new(optional_salt, input_key_material)
101 .expand(b"WhisperMessageKeys", okm.as_mut_bytes())
102 .expect("valid output length");
103
104 let DerivedSecretBytes(cipher_key, mac_key, iv) = okm;
105
106 MessageKeys {
107 cipher_key,
108 mac_key,
109 iv,
110 counter,
111 }
112 }
113
114 #[inline]
115 pub(crate) fn cipher_key(&self) -> &[u8; 32] {
116 &self.cipher_key
117 }
118
119 #[inline]
120 pub(crate) fn mac_key(&self) -> &[u8; 32] {
121 &self.mac_key
122 }
123
124 #[inline]
125 pub(crate) fn iv(&self) -> &[u8; 16] {
126 &self.iv
127 }
128
129 #[inline]
130 pub(crate) fn counter(&self) -> u32 {
131 self.counter
132 }
133}
134
135#[derive(Clone, Debug)]
136pub(crate) struct ChainKey {
137 key: [u8; 32],
138 index: u32,
139}
140
141impl ChainKey {
142 const MESSAGE_KEY_SEED: [u8; 1] = [0x01u8];
143 const CHAIN_KEY_SEED: [u8; 1] = [0x02u8];
144
145 pub(crate) fn new(key: [u8; 32], index: u32) -> Self {
146 Self { key, index }
147 }
148
149 #[inline]
150 pub(crate) fn key(&self) -> &[u8; 32] {
151 &self.key
152 }
153
154 #[inline]
155 pub(crate) fn index(&self) -> u32 {
156 self.index
157 }
158
159 pub(crate) fn next_chain_key(&self) -> Self {
160 Self {
161 key: self.calculate_base_material(Self::CHAIN_KEY_SEED),
162 index: self.index + 1,
163 }
164 }
165
166 pub(crate) fn message_keys(&self) -> MessageKeyGenerator {
167 MessageKeyGenerator::new_from_seed(
168 &self.calculate_base_material(Self::MESSAGE_KEY_SEED),
169 self.index,
170 )
171 }
172
173 fn calculate_base_material(&self, seed: [u8; 1]) -> [u8; 32] {
174 crypto::hmac_sha256(&self.key, &seed)
175 }
176}
177
178#[derive(Clone, Debug)]
179pub(crate) struct RootKey {
180 key: [u8; 32],
181}
182
183impl RootKey {
184 pub(crate) fn new(key: [u8; 32]) -> Self {
185 Self { key }
186 }
187
188 pub(crate) fn key(&self) -> &[u8; 32] {
189 &self.key
190 }
191
192 pub(crate) fn create_chain(
193 self,
194 their_ratchet_key: &PublicKey,
195 our_ratchet_key: &PrivateKey,
196 ) -> Result<(RootKey, ChainKey)> {
197 let shared_secret = our_ratchet_key.calculate_agreement(their_ratchet_key)?;
198 #[derive(Default, KnownLayout, IntoBytes, FromBytes)]
199 #[repr(C, packed)]
200 struct DerivedSecretBytes([u8; 32], [u8; 32]);
201 let mut derived_secret_bytes = DerivedSecretBytes::default();
202
203 hkdf::Hkdf::<sha2::Sha256>::new(Some(&self.key), &shared_secret)
204 .expand(b"WhisperRatchet", derived_secret_bytes.as_mut_bytes())
205 .expect("valid output length");
206
207 let DerivedSecretBytes(root_key, chain_key) = derived_secret_bytes;
208
209 Ok((
210 RootKey { key: root_key },
211 ChainKey {
212 key: chain_key,
213 index: 0,
214 },
215 ))
216 }
217}
218
219impl fmt::Display for RootKey {
220 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
221 write!(f, "{}", hex::encode(self.key))
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks one chain step against fixed vectors: the message cipher and
    /// MAC keys, the next chain key, and the indices/counters of both the
    /// starting key and its successor.
    #[test]
    fn test_chain_key_derivation() -> Result<()> {
        let seed: [u8; 32] = [
            0x8a, 0xb7, 0x2d, 0x6f, 0x4c, 0xc5, 0xac, 0x0d, 0x38, 0x7e, 0xaf, 0x46, 0x33, 0x78,
            0xdd, 0xb2, 0x8e, 0xdd, 0x07, 0x38, 0x5b, 0x1c, 0xb0, 0x12, 0x50, 0xc7, 0x15, 0x98,
            0x2e, 0x7a, 0xd4, 0x8f,
        ];
        let expected_cipher_key: [u8; 32] = [
            0xbf, 0x51, 0xe9, 0xd7, 0x5e, 0x0e, 0x31, 0x03, 0x10, 0x51, 0xf8, 0x2a, 0x24, 0x91,
            0xff, 0xc0, 0x84, 0xfa, 0x29, 0x8b, 0x77, 0x93, 0xbd, 0x9d, 0xb6, 0x20, 0x05, 0x6f,
            0xeb, 0xf4, 0x52, 0x17,
        ];
        let expected_mac_key: [u8; 32] = [
            0xc6, 0xc7, 0x7d, 0x6a, 0x73, 0xa3, 0x54, 0x33, 0x7a, 0x56, 0x43, 0x5e, 0x34, 0x60,
            0x7d, 0xfe, 0x48, 0xe3, 0xac, 0xe1, 0x4e, 0x77, 0x31, 0x4d, 0xc6, 0xab, 0xc1, 0x72,
            0xe7, 0xa7, 0x03, 0x0b,
        ];
        let expected_next_chain_key: [u8; 32] = [
            0x28, 0xe8, 0xf8, 0xfe, 0xe5, 0x4b, 0x80, 0x1e, 0xef, 0x7c, 0x5c, 0xfb, 0x2f, 0x17,
            0xf3, 0x2c, 0x7b, 0x33, 0x44, 0x85, 0xbb, 0xb7, 0x0f, 0xac, 0x6e, 0xc1, 0x03, 0x42,
            0xa2, 0x46, 0xd1, 0x5d,
        ];

        let chain_key = ChainKey::new(seed, 0);
        assert_eq!(&seed, chain_key.key());
        assert_eq!(0, chain_key.index());

        // Keys derived at the starting position (no extra salt supplied).
        let first_keys = chain_key.message_keys().generate_keys(None);
        assert_eq!(&expected_cipher_key, first_keys.cipher_key());
        assert_eq!(&expected_mac_key, first_keys.mac_key());
        assert_eq!(0, first_keys.counter());

        // One step along the chain.
        let successor = chain_key.next_chain_key();
        assert_eq!(&expected_next_chain_key, successor.key());
        assert_eq!(1, successor.index());

        let second_keys = successor.message_keys().generate_keys(None);
        assert_eq!(1, second_keys.counter());

        Ok(())
    }
}