1use std::fmt;
7
8use libsignal_core::derive_arrays;
9
10use crate::proto::storage::session_structure;
11use crate::{PrivateKey, PublicKey, Result, crypto};
12
13pub(crate) enum MessageKeyGenerator {
14 Keys(MessageKeys),
15 Seed((Vec<u8>, u32)),
16}
17
18impl MessageKeyGenerator {
19 pub(crate) fn new_from_seed(seed: &[u8], counter: u32) -> Self {
20 Self::Seed((seed.to_vec(), counter))
21 }
22 pub(crate) fn generate_keys(self, pqr_key: spqr::MessageKey) -> MessageKeys {
23 match self {
24 Self::Seed((seed, counter)) => {
25 MessageKeys::derive_keys(&seed, pqr_key.as_deref(), counter)
26 }
27 Self::Keys(k) => {
28 assert!(pqr_key.is_none());
31 k
32 }
33 }
34 }
35 pub(crate) fn into_pb(self) -> session_structure::chain::MessageKey {
36 match self {
37 Self::Keys(k) => session_structure::chain::MessageKey {
38 cipher_key: k.cipher_key().to_vec(),
39 mac_key: k.mac_key().to_vec(),
40 iv: k.iv().to_vec(),
41 index: k.counter(),
42 seed: vec![],
43 },
44 Self::Seed((seed, counter)) => session_structure::chain::MessageKey {
45 cipher_key: vec![],
46 mac_key: vec![],
47 iv: vec![],
48 index: counter,
49 seed,
50 },
51 }
52 }
53 pub(crate) fn from_pb(
54 pb: session_structure::chain::MessageKey,
55 ) -> std::result::Result<Self, &'static str> {
56 Ok(if pb.seed.is_empty() {
57 Self::Keys(MessageKeys {
58 cipher_key: pb
59 .cipher_key
60 .as_slice()
61 .try_into()
62 .map_err(|_| "invalid message cipher key")?,
63 mac_key: pb
64 .mac_key
65 .as_slice()
66 .try_into()
67 .map_err(|_| "invalid message MAC key")?,
68 iv: pb
69 .iv
70 .as_slice()
71 .try_into()
72 .map_err(|_| "invalid message IV")?,
73 counter: pb.index,
74 })
75 } else {
76 Self::Seed((pb.seed, pb.index))
77 })
78 }
79}
80
81#[derive(Clone, Copy)]
82pub(crate) struct MessageKeys {
83 cipher_key: [u8; 32],
84 mac_key: [u8; 32],
85 iv: [u8; 16],
86 counter: u32,
87}
88
89impl MessageKeys {
90 pub(crate) fn derive_keys(
91 input_key_material: &[u8],
92 optional_salt: Option<&[u8]>,
93 counter: u32,
94 ) -> Self {
95 let (cipher_key, mac_key, iv) = derive_arrays(|okm| {
96 hkdf::Hkdf::<sha2::Sha256>::new(optional_salt, input_key_material)
97 .expand(b"WhisperMessageKeys", okm)
98 .expect("valid output length")
99 });
100
101 MessageKeys {
102 cipher_key,
103 mac_key,
104 iv,
105 counter,
106 }
107 }
108
109 #[inline]
110 pub(crate) fn cipher_key(&self) -> &[u8; 32] {
111 &self.cipher_key
112 }
113
114 #[inline]
115 pub(crate) fn mac_key(&self) -> &[u8; 32] {
116 &self.mac_key
117 }
118
119 #[inline]
120 pub(crate) fn iv(&self) -> &[u8; 16] {
121 &self.iv
122 }
123
124 #[inline]
125 pub(crate) fn counter(&self) -> u32 {
126 self.counter
127 }
128}
129
130#[derive(Clone, Debug)]
131pub(crate) struct ChainKey {
132 key: [u8; 32],
133 index: u32,
134}
135
136impl ChainKey {
137 const MESSAGE_KEY_SEED: [u8; 1] = [0x01u8];
138 const CHAIN_KEY_SEED: [u8; 1] = [0x02u8];
139
140 pub(crate) fn new(key: [u8; 32], index: u32) -> Self {
141 Self { key, index }
142 }
143
144 #[inline]
145 pub(crate) fn key(&self) -> &[u8; 32] {
146 &self.key
147 }
148
149 #[inline]
150 pub(crate) fn index(&self) -> u32 {
151 self.index
152 }
153
154 pub(crate) fn next_chain_key(&self) -> Self {
155 Self {
156 key: self.calculate_base_material(Self::CHAIN_KEY_SEED),
157 index: self.index + 1,
158 }
159 }
160
161 pub(crate) fn message_keys(&self) -> MessageKeyGenerator {
162 MessageKeyGenerator::new_from_seed(
163 &self.calculate_base_material(Self::MESSAGE_KEY_SEED),
164 self.index,
165 )
166 }
167
168 fn calculate_base_material(&self, seed: [u8; 1]) -> [u8; 32] {
169 crypto::hmac_sha256(&self.key, &seed)
170 }
171}
172
173#[derive(Clone, Debug)]
174pub(crate) struct RootKey {
175 key: [u8; 32],
176}
177
178impl RootKey {
179 pub(crate) fn new(key: [u8; 32]) -> Self {
180 Self { key }
181 }
182
183 pub(crate) fn key(&self) -> &[u8; 32] {
184 &self.key
185 }
186
187 pub(crate) fn create_chain(
188 self,
189 their_ratchet_key: &PublicKey,
190 our_ratchet_key: &PrivateKey,
191 ) -> Result<(RootKey, ChainKey)> {
192 let shared_secret = our_ratchet_key.calculate_agreement(their_ratchet_key)?;
193 let (root_key, chain_key, []) = derive_arrays(|bytes| {
194 hkdf::Hkdf::<sha2::Sha256>::new(Some(&self.key), &shared_secret)
195 .expand(b"WhisperRatchet", bytes)
196 .expect("valid output length")
197 });
198
199 Ok((
200 RootKey { key: root_key },
201 ChainKey {
202 key: chain_key,
203 index: 0,
204 },
205 ))
206 }
207}
208
209impl fmt::Display for RootKey {
210 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
211 write!(f, "{}", hex::encode(self.key))
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks message-key, MAC-key and next-chain-key derivation, plus index
    /// and counter bookkeeping, against fixed expected values.
    #[test]
    fn test_chain_key_derivation() -> Result<()> {
        let seed = [
            0x8au8, 0xb7, 0x2d, 0x6f, 0x4c, 0xc5, 0xac, 0x0d, 0x38, 0x7e, 0xaf, 0x46, 0x33, 0x78,
            0xdd, 0xb2, 0x8e, 0xdd, 0x07, 0x38, 0x5b, 0x1c, 0xb0, 0x12, 0x50, 0xc7, 0x15, 0x98,
            0x2e, 0x7a, 0xd4, 0x8f,
        ];
        let message_key = [
            0xbfu8, 0x51, 0xe9, 0xd7, 0x5e, 0x0e, 0x31, 0x03, 0x10, 0x51, 0xf8, 0x2a, 0x24, 0x91,
            0xff, 0xc0, 0x84, 0xfa, 0x29, 0x8b, 0x77, 0x93, 0xbd, 0x9d, 0xb6, 0x20, 0x05, 0x6f,
            0xeb, 0xf4, 0x52, 0x17,
        ];
        let mac_key = [
            0xc6u8, 0xc7, 0x7d, 0x6a, 0x73, 0xa3, 0x54, 0x33, 0x7a, 0x56, 0x43, 0x5e, 0x34, 0x60,
            0x7d, 0xfe, 0x48, 0xe3, 0xac, 0xe1, 0x4e, 0x77, 0x31, 0x4d, 0xc6, 0xab, 0xc1, 0x72,
            0xe7, 0xa7, 0x03, 0x0b,
        ];
        let next_chain_key = [
            0x28u8, 0xe8, 0xf8, 0xfe, 0xe5, 0x4b, 0x80, 0x1e, 0xef, 0x7c, 0x5c, 0xfb, 0x2f, 0x17,
            0xf3, 0x2c, 0x7b, 0x33, 0x44, 0x85, 0xbb, 0xb7, 0x0f, 0xac, 0x6e, 0xc1, 0x03, 0x42,
            0xa2, 0x46, 0xd1, 0x5d,
        ];

        let chain_key = ChainKey::new(seed, 0);
        assert_eq!(&seed, chain_key.key());

        // Keys derived at the starting chain position.
        let first_keys = chain_key.message_keys().generate_keys(None);
        assert_eq!(&message_key, first_keys.cipher_key());
        assert_eq!(&mac_key, first_keys.mac_key());

        // One step further along the chain.
        let advanced = chain_key.next_chain_key();
        assert_eq!(&next_chain_key, advanced.key());

        assert_eq!(0, chain_key.index());
        assert_eq!(0, first_keys.counter());
        assert_eq!(1, advanced.index());

        let second_keys = advanced.message_keys().generate_keys(None);
        assert_eq!(1, second_keys.counter());

        Ok(())
    }
}