1use std::time::SystemTime;
7
8use rand::{CryptoRng, Rng};
9
10use crate::protocol::CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION;
11use crate::ratchet::{AliceSignalProtocolParameters, BobSignalProtocolParameters};
12use crate::state::GenericSignedPreKey;
13use crate::{
14 ratchet, CiphertextMessageType, Direction, IdentityKey, IdentityKeyStore, KeyPair,
15 KyberPreKeyId, KyberPreKeyStore, PreKeyBundle, PreKeyId, PreKeySignalMessage, PreKeyStore,
16 ProtocolAddress, Result, SessionRecord, SessionStore, SignalProtocolError, SignedPreKeyStore,
17};
18
19#[derive(Default)]
20pub struct PreKeysUsed {
21 pub pre_key_id: Option<PreKeyId>,
22 pub kyber_pre_key_id: Option<KyberPreKeyId>,
23}
24
25#[must_use]
31pub struct IdentityToSave<'a> {
32 pub remote_address: &'a ProtocolAddress,
33 pub their_identity_key: &'a IdentityKey,
34}
35
36pub async fn process_prekey<'a>(
46 message: &'a PreKeySignalMessage,
47 remote_address: &'a ProtocolAddress,
48 session_record: &mut SessionRecord,
49 identity_store: &dyn IdentityKeyStore,
50 pre_key_store: &dyn PreKeyStore,
51 signed_prekey_store: &dyn SignedPreKeyStore,
52 kyber_prekey_store: &dyn KyberPreKeyStore,
53 use_pq_ratchet: ratchet::UsePQRatchet,
54) -> Result<(PreKeysUsed, IdentityToSave<'a>)> {
55 let their_identity_key = message.identity_key();
56
57 if !identity_store
58 .is_trusted_identity(remote_address, their_identity_key, Direction::Receiving)
59 .await?
60 {
61 return Err(SignalProtocolError::UntrustedIdentity(
62 remote_address.clone(),
63 ));
64 }
65
66 let pre_keys_used = process_prekey_impl(
67 message,
68 remote_address,
69 session_record,
70 signed_prekey_store,
71 kyber_prekey_store,
72 pre_key_store,
73 identity_store,
74 use_pq_ratchet,
75 )
76 .await?;
77
78 let identity_to_save = IdentityToSave {
79 remote_address,
80 their_identity_key,
81 };
82
83 Ok((pre_keys_used, identity_to_save))
84}
85
86async fn process_prekey_impl(
87 message: &PreKeySignalMessage,
88 remote_address: &ProtocolAddress,
89 session_record: &mut SessionRecord,
90 signed_prekey_store: &dyn SignedPreKeyStore,
91 kyber_prekey_store: &dyn KyberPreKeyStore,
92 pre_key_store: &dyn PreKeyStore,
93 identity_store: &dyn IdentityKeyStore,
94 use_pq_ratchet: ratchet::UsePQRatchet,
95) -> Result<PreKeysUsed> {
96 if session_record.promote_matching_session(
97 message.message_version() as u32,
98 &message.base_key().serialize(),
99 )? {
100 return Ok(Default::default());
102 }
103
104 if message.message_version() == CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION {
107 return Err(SignalProtocolError::InvalidMessage(
111 CiphertextMessageType::PreKey,
112 "X3DH no longer supported",
113 ));
114 }
115
116 let our_signed_pre_key_pair = signed_prekey_store
117 .get_signed_pre_key(message.signed_pre_key_id())
118 .await?
119 .key_pair()?;
120
121 let our_kyber_pre_key_pair = if let Some(kyber_pre_key_id) = message.kyber_pre_key_id() {
122 kyber_prekey_store
123 .get_kyber_pre_key(kyber_pre_key_id)
124 .await?
125 .key_pair()?
126 } else {
127 return Err(SignalProtocolError::InvalidMessage(
128 CiphertextMessageType::PreKey,
129 "missing pq pre-key ID",
130 ));
131 };
132 let kyber_ciphertext =
133 message
134 .kyber_ciphertext()
135 .ok_or(SignalProtocolError::InvalidMessage(
136 CiphertextMessageType::PreKey,
137 "missing pq ciphertext",
138 ))?;
139
140 let our_one_time_pre_key_pair = if let Some(pre_key_id) = message.pre_key_id() {
141 log::info!("processing PreKey message from {remote_address}");
142 Some(pre_key_store.get_pre_key(pre_key_id).await?.key_pair()?)
143 } else {
144 log::warn!("processing PreKey message from {remote_address} which had no one-time prekey");
145 None
146 };
147
148 let parameters = BobSignalProtocolParameters::new(
149 identity_store.get_identity_key_pair().await?,
150 our_signed_pre_key_pair, our_one_time_pre_key_pair,
152 our_signed_pre_key_pair, our_kyber_pre_key_pair,
154 *message.identity_key(),
155 *message.base_key(),
156 kyber_ciphertext,
157 use_pq_ratchet,
158 );
159
160 let mut new_session = ratchet::initialize_bob_session(¶meters)?;
161
162 new_session.set_local_registration_id(identity_store.get_local_registration_id().await?);
163 new_session.set_remote_registration_id(message.registration_id());
164
165 session_record.promote_state(new_session);
166
167 let pre_keys_used = PreKeysUsed {
168 pre_key_id: message.pre_key_id(),
169 kyber_pre_key_id: message.kyber_pre_key_id(),
170 };
171 Ok(pre_keys_used)
172}
173
174pub async fn process_prekey_bundle<R: Rng + CryptoRng>(
175 remote_address: &ProtocolAddress,
176 session_store: &mut dyn SessionStore,
177 identity_store: &mut dyn IdentityKeyStore,
178 bundle: &PreKeyBundle,
179 now: SystemTime,
180 mut csprng: &mut R,
181 use_pq_ratchet: ratchet::UsePQRatchet,
182) -> Result<()> {
183 let their_identity_key = bundle.identity_key()?;
184
185 if !identity_store
186 .is_trusted_identity(remote_address, their_identity_key, Direction::Sending)
187 .await?
188 {
189 return Err(SignalProtocolError::UntrustedIdentity(
190 remote_address.clone(),
191 ));
192 }
193
194 if !their_identity_key.public_key().verify_signature(
195 &bundle.signed_pre_key_public()?.serialize(),
196 bundle.signed_pre_key_signature()?,
197 ) {
198 return Err(SignalProtocolError::SignatureValidationFailed);
199 }
200
201 if !their_identity_key.public_key().verify_signature(
202 &bundle.kyber_pre_key_public()?.serialize(),
203 bundle.kyber_pre_key_signature()?,
204 ) {
205 return Err(SignalProtocolError::SignatureValidationFailed);
206 }
207
208 let mut session_record = session_store
209 .load_session(remote_address)
210 .await?
211 .unwrap_or_else(SessionRecord::new_fresh);
212
213 let our_base_key_pair = KeyPair::generate(&mut csprng);
214 let their_signed_prekey = bundle.signed_pre_key_public()?;
215 let their_kyber_prekey = bundle.kyber_pre_key_public()?;
216
217 let their_one_time_prekey_id = bundle.pre_key_id()?;
218
219 let our_identity_key_pair = identity_store.get_identity_key_pair().await?;
220
221 let mut parameters = AliceSignalProtocolParameters::new(
222 our_identity_key_pair,
223 our_base_key_pair,
224 *their_identity_key,
225 their_signed_prekey,
226 their_signed_prekey,
227 their_kyber_prekey.clone(),
228 use_pq_ratchet,
229 );
230 if let Some(key) = bundle.pre_key_public()? {
231 parameters.set_their_one_time_pre_key(key);
232 }
233
234 let mut session = ratchet::initialize_alice_session(¶meters, csprng)?;
235
236 log::info!(
237 "set_unacknowledged_pre_key_message for: {} with preKeyId: {}",
238 remote_address,
239 their_one_time_prekey_id.map_or_else(|| "<none>".to_string(), |id| id.to_string())
240 );
241
242 session.set_unacknowledged_pre_key_message(
243 their_one_time_prekey_id,
244 bundle.signed_pre_key_id()?,
245 &our_base_key_pair.public_key,
246 now,
247 );
248 session.set_unacknowledged_kyber_pre_key_id(bundle.kyber_pre_key_id()?);
249
250 session.set_local_registration_id(identity_store.get_local_registration_id().await?);
251 session.set_remote_registration_id(bundle.registration_id()?);
252
253 identity_store
254 .save_identity(remote_address, their_identity_key)
255 .await?;
256
257 session_record.promote_state(session);
258
259 session_store
260 .store_session(remote_address, &session_record)
261 .await?;
262
263 Ok(())
264}