1use rand::{CryptoRng, Rng};
7use uuid::Uuid;
8
9use crate::protocol::SENDERKEY_MESSAGE_CURRENT_VERSION;
10use crate::sender_keys::{SenderKeyState, SenderMessageKey};
11use crate::{
12 consts, CiphertextMessageType, KeyPair, ProtocolAddress, Result, SenderKeyDistributionMessage,
13 SenderKeyMessage, SenderKeyRecord, SenderKeyStore, SignalProtocolError,
14};
15
16pub async fn group_encrypt<R: Rng + CryptoRng>(
17 sender_key_store: &mut dyn SenderKeyStore,
18 sender: &ProtocolAddress,
19 distribution_id: Uuid,
20 plaintext: &[u8],
21 csprng: &mut R,
22) -> Result<SenderKeyMessage> {
23 let mut record = sender_key_store
24 .load_sender_key(sender, distribution_id)
25 .await?
26 .ok_or(SignalProtocolError::NoSenderKeyState { distribution_id })?;
27
28 let sender_key_state = record
29 .sender_key_state_mut()
30 .map_err(|_| SignalProtocolError::InvalidSenderKeySession { distribution_id })?;
31
32 let sender_chain_key = sender_key_state
33 .sender_chain_key()
34 .ok_or(SignalProtocolError::InvalidSenderKeySession { distribution_id })?;
35
36 let message_keys = sender_chain_key.sender_message_key();
37
38 let ciphertext =
39 signal_crypto::aes_256_cbc_encrypt(plaintext, message_keys.cipher_key(), message_keys.iv())
40 .map_err(|_| {
41 log::error!(
42 "outgoing sender key state corrupt for distribution ID {}",
43 distribution_id,
44 );
45 SignalProtocolError::InvalidSenderKeySession { distribution_id }
46 })?;
47
48 let signing_key = sender_key_state
49 .signing_key_private()
50 .map_err(|_| SignalProtocolError::InvalidSenderKeySession { distribution_id })?;
51
52 let skm = SenderKeyMessage::new(
53 sender_key_state.message_version() as u8,
54 distribution_id,
55 sender_key_state.chain_id(),
56 message_keys.iteration(),
57 ciphertext.into_boxed_slice(),
58 csprng,
59 &signing_key,
60 )?;
61
62 sender_key_state.set_sender_chain_key(sender_chain_key.next());
63
64 sender_key_store
65 .store_sender_key(sender, distribution_id, &record)
66 .await?;
67
68 Ok(skm)
69}
70
71fn get_sender_key(
72 state: &mut SenderKeyState,
73 iteration: u32,
74 distribution_id: Uuid,
75) -> Result<SenderMessageKey> {
76 let sender_chain_key = state
77 .sender_chain_key()
78 .ok_or(SignalProtocolError::InvalidSenderKeySession { distribution_id })?;
79 let current_iteration = sender_chain_key.iteration();
80
81 if current_iteration > iteration {
82 if let Some(smk) = state.remove_sender_message_key(iteration) {
83 return Ok(smk);
84 } else {
85 log::info!(
86 "SenderKey distribution {} Duplicate message for iteration: {}",
87 distribution_id,
88 iteration
89 );
90 return Err(SignalProtocolError::DuplicatedMessage(
91 current_iteration,
92 iteration,
93 ));
94 }
95 }
96
97 let jump = (iteration - current_iteration) as usize;
98 if jump > consts::MAX_FORWARD_JUMPS {
99 log::error!(
100 "SenderKey distribution {} Exceeded future message limit: {}, current iteration: {})",
101 distribution_id,
102 consts::MAX_FORWARD_JUMPS,
103 current_iteration
104 );
105 return Err(SignalProtocolError::InvalidMessage(
106 CiphertextMessageType::SenderKey,
107 "message from too far into the future",
108 ));
109 }
110
111 let mut sender_chain_key = sender_chain_key;
112
113 while sender_chain_key.iteration() < iteration {
114 state.add_sender_message_key(&sender_chain_key.sender_message_key());
115 sender_chain_key = sender_chain_key.next();
116 }
117
118 state.set_sender_chain_key(sender_chain_key.next());
119 Ok(sender_chain_key.sender_message_key())
120}
121
122pub async fn group_decrypt(
123 skm_bytes: &[u8],
124 sender_key_store: &mut dyn SenderKeyStore,
125 sender: &ProtocolAddress,
126) -> Result<Vec<u8>> {
127 let skm = SenderKeyMessage::try_from(skm_bytes)?;
128
129 let distribution_id = skm.distribution_id();
130 let chain_id = skm.chain_id();
131
132 let mut record = sender_key_store
133 .load_sender_key(sender, skm.distribution_id())
134 .await?
135 .ok_or(SignalProtocolError::NoSenderKeyState { distribution_id })?;
136
137 let sender_key_state = match record.sender_key_state_for_chain_id(chain_id) {
138 Some(state) => state,
139 None => {
140 log::error!(
141 "SenderKey distribution {} could not find chain ID {} (known chain IDs: {:?})",
142 distribution_id,
143 chain_id,
144 record.chain_ids_for_logging().collect::<Vec<_>>(),
145 );
146 return Err(SignalProtocolError::NoSenderKeyState { distribution_id });
147 }
148 };
149
150 let message_version = skm.message_version() as u32;
151 if message_version != sender_key_state.message_version() {
152 return Err(SignalProtocolError::UnrecognizedMessageVersion(
153 message_version,
154 ));
155 }
156
157 let signing_key = sender_key_state
158 .signing_key_public()
159 .map_err(|_| SignalProtocolError::InvalidSenderKeySession { distribution_id })?;
160 if !skm.verify_signature(&signing_key)? {
161 return Err(SignalProtocolError::SignatureValidationFailed);
162 }
163
164 let sender_key = get_sender_key(sender_key_state, skm.iteration(), distribution_id)?;
165
166 let plaintext = match signal_crypto::aes_256_cbc_decrypt(
167 skm.ciphertext(),
168 sender_key.cipher_key(),
169 sender_key.iv(),
170 ) {
171 Ok(plaintext) => plaintext,
172 Err(signal_crypto::DecryptionError::BadKeyOrIv) => {
173 log::error!(
174 "incoming sender key state corrupt for {}, distribution ID {}, chain ID {}",
175 sender,
176 distribution_id,
177 chain_id,
178 );
179 return Err(SignalProtocolError::InvalidSenderKeySession { distribution_id });
180 }
181 Err(signal_crypto::DecryptionError::BadCiphertext(msg)) => {
182 log::error!("sender key decryption failed: {}", msg);
183 return Err(SignalProtocolError::InvalidMessage(
184 CiphertextMessageType::SenderKey,
185 "decryption failed",
186 ));
187 }
188 };
189
190 sender_key_store
191 .store_sender_key(sender, distribution_id, &record)
192 .await?;
193
194 Ok(plaintext)
195}
196
197pub async fn process_sender_key_distribution_message(
198 sender: &ProtocolAddress,
199 skdm: &SenderKeyDistributionMessage,
200 sender_key_store: &mut dyn SenderKeyStore,
201) -> Result<()> {
202 let distribution_id = skdm.distribution_id()?;
203 log::info!(
204 "{} Processing SenderKey distribution {} with chain ID {}",
205 sender,
206 distribution_id,
207 skdm.chain_id()?
208 );
209
210 let mut sender_key_record = sender_key_store
211 .load_sender_key(sender, distribution_id)
212 .await?
213 .unwrap_or_else(SenderKeyRecord::new_empty);
214
215 sender_key_record.add_sender_key_state(
216 skdm.message_version(),
217 skdm.chain_id()?,
218 skdm.iteration()?,
219 skdm.chain_key()?,
220 *skdm.signing_key()?,
221 None,
222 );
223 sender_key_store
224 .store_sender_key(sender, distribution_id, &sender_key_record)
225 .await?;
226 Ok(())
227}
228
229pub async fn create_sender_key_distribution_message<R: Rng + CryptoRng>(
230 sender: &ProtocolAddress,
231 distribution_id: Uuid,
232 sender_key_store: &mut dyn SenderKeyStore,
233 csprng: &mut R,
234) -> Result<SenderKeyDistributionMessage> {
235 let sender_key_record = sender_key_store
236 .load_sender_key(sender, distribution_id)
237 .await?;
238
239 let sender_key_record = match sender_key_record {
240 Some(record) => record,
241 None => {
242 let chain_id = (csprng.gen::<u32>()) >> 1;
244 log::info!(
245 "Creating SenderKey for distribution {} with chain ID {}",
246 distribution_id,
247 chain_id
248 );
249
250 let iteration = 0;
251 let sender_key: [u8; 32] = csprng.gen();
252 let signing_key = KeyPair::generate(csprng);
253 let mut record = SenderKeyRecord::new_empty();
254 record.add_sender_key_state(
255 SENDERKEY_MESSAGE_CURRENT_VERSION,
256 chain_id,
257 iteration,
258 &sender_key,
259 signing_key.public_key,
260 Some(signing_key.private_key),
261 );
262 sender_key_store
263 .store_sender_key(sender, distribution_id, &record)
264 .await?;
265 record
266 }
267 };
268
269 let state = sender_key_record
270 .sender_key_state()
271 .map_err(|_| SignalProtocolError::InvalidSenderKeySession { distribution_id })?;
272 let sender_chain_key = state
273 .sender_chain_key()
274 .ok_or(SignalProtocolError::InvalidSenderKeySession { distribution_id })?;
275
276 SenderKeyDistributionMessage::new(
277 state.message_version() as u8,
278 distribution_id,
279 state.chain_id(),
280 sender_chain_key.iteration(),
281 sender_chain_key.seed().to_vec(),
282 state
283 .signing_key_public()
284 .map_err(|_| SignalProtocolError::InvalidSenderKeySession { distribution_id })?,
285 )
286}