1use std::fmt;
7
8use libsignal_core::derive_arrays;
9
10use crate::proto::storage::session_structure;
11use crate::{PrivateKey, PublicKey, Result, crypto};
12
13pub(crate) enum MessageKeyGenerator {
14 Keys(MessageKeys),
15 Seed((Vec<u8>, u32)),
16}
17
18impl MessageKeyGenerator {
19 pub(crate) fn new_from_seed(seed: &[u8], counter: u32) -> Self {
20 Self::Seed((seed.to_vec(), counter))
21 }
22 pub(crate) fn generate_keys(self, pqr_key: spqr::MessageKey) -> MessageKeys {
23 match self {
24 Self::Seed((seed, counter)) => {
25 MessageKeys::derive_keys(&seed, pqr_key.as_deref(), counter)
26 }
27 Self::Keys(k) => {
28 assert!(pqr_key.is_none());
31 k
32 }
33 }
34 }
35 pub(crate) fn into_pb(self) -> session_structure::chain::MessageKey {
36 match self {
37 Self::Keys(k) => session_structure::chain::MessageKey {
38 cipher_key: k.cipher_key().to_vec(),
39 mac_key: k.mac_key().to_vec(),
40 iv: k.iv().to_vec(),
41 index: k.counter(),
42 seed: vec![],
43 },
44 Self::Seed((seed, counter)) => session_structure::chain::MessageKey {
45 cipher_key: vec![],
46 mac_key: vec![],
47 iv: vec![],
48 index: counter,
49 seed,
50 },
51 }
52 }
53 pub(crate) fn from_pb(
54 pb: session_structure::chain::MessageKey,
55 ) -> std::result::Result<Self, &'static str> {
56 Ok(if pb.seed.is_empty() {
57 Self::Keys(MessageKeys {
58 cipher_key: pb
59 .cipher_key
60 .as_slice()
61 .try_into()
62 .map_err(|_| "invalid message cipher key")?,
63 mac_key: pb
64 .mac_key
65 .as_slice()
66 .try_into()
67 .map_err(|_| "invalid message MAC key")?,
68 iv: pb
69 .iv
70 .as_slice()
71 .try_into()
72 .map_err(|_| "invalid message IV")?,
73 counter: pb.index,
74 })
75 } else {
76 Self::Seed((pb.seed, pb.index))
77 })
78 }
79}
80
81#[derive(Clone, Copy)]
82pub(crate) struct MessageKeys {
83 cipher_key: [u8; 32],
84 mac_key: [u8; 32],
85 iv: [u8; 16],
86 counter: u32,
87}
88
89impl MessageKeys {
90 pub(crate) fn derive_keys(
91 input_key_material: &[u8],
92 optional_salt: Option<&[u8]>,
93 counter: u32,
94 ) -> Self {
95 let _trace = libsignal_debug::trace_block!("MessageKeys::derive_keys");
96 let (cipher_key, mac_key, iv) = derive_arrays(|okm| {
97 hkdf::Hkdf::<sha2::Sha256>::new(optional_salt, input_key_material)
98 .expand(b"WhisperMessageKeys", okm)
99 .expect("valid output length")
100 });
101
102 MessageKeys {
103 cipher_key,
104 mac_key,
105 iv,
106 counter,
107 }
108 }
109
110 #[inline]
111 pub(crate) fn cipher_key(&self) -> &[u8; 32] {
112 &self.cipher_key
113 }
114
115 #[inline]
116 pub(crate) fn mac_key(&self) -> &[u8; 32] {
117 &self.mac_key
118 }
119
120 #[inline]
121 pub(crate) fn iv(&self) -> &[u8; 16] {
122 &self.iv
123 }
124
125 #[inline]
126 pub(crate) fn counter(&self) -> u32 {
127 self.counter
128 }
129}
130
131#[derive(Clone, Debug)]
132pub(crate) struct ChainKey {
133 key: [u8; 32],
134 index: u32,
135}
136
137impl ChainKey {
138 const MESSAGE_KEY_SEED: [u8; 1] = [0x01u8];
139 const CHAIN_KEY_SEED: [u8; 1] = [0x02u8];
140
141 pub(crate) fn new(key: [u8; 32], index: u32) -> Self {
142 Self { key, index }
143 }
144
145 #[inline]
146 pub(crate) fn key(&self) -> &[u8; 32] {
147 &self.key
148 }
149
150 #[inline]
151 pub(crate) fn index(&self) -> u32 {
152 self.index
153 }
154
155 pub(crate) fn next_chain_key(&self) -> Self {
156 Self {
157 key: self.calculate_base_material(Self::CHAIN_KEY_SEED),
158 index: self.index + 1,
159 }
160 }
161
162 pub(crate) fn message_keys(&self) -> MessageKeyGenerator {
163 MessageKeyGenerator::new_from_seed(
164 &self.calculate_base_material(Self::MESSAGE_KEY_SEED),
165 self.index,
166 )
167 }
168
169 fn calculate_base_material(&self, seed: [u8; 1]) -> [u8; 32] {
170 let _trace = libsignal_debug::trace_block!("keys::calculate_base_material");
171 crypto::hmac_sha256(&self.key, &seed)
172 }
173}
174
175#[derive(Clone, Debug)]
176pub(crate) struct RootKey {
177 key: [u8; 32],
178}
179
180impl RootKey {
181 pub(crate) fn new(key: [u8; 32]) -> Self {
182 Self { key }
183 }
184
185 pub(crate) fn key(&self) -> &[u8; 32] {
186 &self.key
187 }
188
189 pub(crate) fn create_chain(
190 self,
191 their_ratchet_key: &PublicKey,
192 our_ratchet_key: &PrivateKey,
193 ) -> Result<(RootKey, ChainKey)> {
194 let shared_secret = our_ratchet_key.calculate_agreement(their_ratchet_key)?;
195 let (root_key, chain_key, []) = derive_arrays(|bytes| {
196 hkdf::Hkdf::<sha2::Sha256>::new(Some(&self.key), &shared_secret)
197 .expand(b"WhisperRatchet", bytes)
198 .expect("valid output length")
199 });
200
201 Ok((
202 RootKey { key: root_key },
203 ChainKey {
204 key: chain_key,
205 index: 0,
206 },
207 ))
208 }
209}
210
211impl fmt::Display for RootKey {
212 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
213 write!(f, "{}", hex::encode(self.key))
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Known-answer test covering message-key and next-chain-key derivation from a fixed
    /// chain key at index 0, plus the index/counter bookkeeping at indices 0 and 1.
    #[test]
    fn test_chain_key_derivation() -> Result<()> {
        let seed = [
            0x8au8, 0xb7, 0x2d, 0x6f, 0x4c, 0xc5, 0xac, 0x0d, 0x38, 0x7e, 0xaf, 0x46, 0x33, 0x78,
            0xdd, 0xb2, 0x8e, 0xdd, 0x07, 0x38, 0x5b, 0x1c, 0xb0, 0x12, 0x50, 0xc7, 0x15, 0x98,
            0x2e, 0x7a, 0xd4, 0x8f,
        ];
        let expected_cipher_key = [
            0xbfu8, 0x51, 0xe9, 0xd7, 0x5e, 0x0e, 0x31, 0x03, 0x10, 0x51, 0xf8, 0x2a, 0x24, 0x91,
            0xff, 0xc0, 0x84, 0xfa, 0x29, 0x8b, 0x77, 0x93, 0xbd, 0x9d, 0xb6, 0x20, 0x05, 0x6f,
            0xeb, 0xf4, 0x52, 0x17,
        ];
        let expected_mac_key = [
            0xc6u8, 0xc7, 0x7d, 0x6a, 0x73, 0xa3, 0x54, 0x33, 0x7a, 0x56, 0x43, 0x5e, 0x34, 0x60,
            0x7d, 0xfe, 0x48, 0xe3, 0xac, 0xe1, 0x4e, 0x77, 0x31, 0x4d, 0xc6, 0xab, 0xc1, 0x72,
            0xe7, 0xa7, 0x03, 0x0b,
        ];
        let expected_next_chain_key = [
            0x28u8, 0xe8, 0xf8, 0xfe, 0xe5, 0x4b, 0x80, 0x1e, 0xef, 0x7c, 0x5c, 0xfb, 0x2f, 0x17,
            0xf3, 0x2c, 0x7b, 0x33, 0x44, 0x85, 0xbb, 0xb7, 0x0f, 0xac, 0x6e, 0xc1, 0x03, 0x42,
            0xa2, 0x46, 0xd1, 0x5d,
        ];

        // The starting chain key stores its inputs verbatim.
        let chain_key = ChainKey::new(seed, 0);
        assert_eq!(&seed, chain_key.key());
        assert_eq!(0, chain_key.index());

        // Message keys at index 0.
        let keys = chain_key.message_keys().generate_keys(None);
        assert_eq!(&expected_cipher_key, keys.cipher_key());
        assert_eq!(&expected_mac_key, keys.mac_key());
        assert_eq!(0, keys.counter());

        // One ratchet step forward.
        let advanced = chain_key.next_chain_key();
        assert_eq!(&expected_next_chain_key, advanced.key());
        assert_eq!(1, advanced.index());
        assert_eq!(1, advanced.message_keys().generate_keys(None).counter());

        Ok(())
    }
}
268}