1use std::time::SystemTime;
7
8use rand::{CryptoRng, Rng};
9
10use crate::protocol::CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION;
11use crate::ratchet::{AliceSignalProtocolParameters, BobSignalProtocolParameters};
12use crate::state::GenericSignedPreKey;
13use crate::{
14 CiphertextMessageType, Direction, IdentityKey, IdentityKeyStore, KeyPair, KyberPreKeyId,
15 KyberPreKeyStore, PreKeyBundle, PreKeyId, PreKeySignalMessage, PreKeyStore, ProtocolAddress,
16 Result, SessionRecord, SessionStore, SignalProtocolError, SignedPreKeyId, SignedPreKeyStore,
17 ratchet,
18};
19
20pub struct PreKeysUsed {
21 pub one_time_ec_pre_key_id: Option<PreKeyId>,
22 pub signed_ec_pre_key_id: SignedPreKeyId,
23 pub kyber_pre_key_id: Option<KyberPreKeyId>,
24}
25
26#[must_use]
32pub struct IdentityToSave<'a> {
33 pub remote_address: &'a ProtocolAddress,
34 pub their_identity_key: &'a IdentityKey,
35}
36
37pub async fn process_prekey<'a>(
47 message: &'a PreKeySignalMessage,
48 remote_address: &'a ProtocolAddress,
49 local_address: &ProtocolAddress,
50 session_record: &mut SessionRecord,
51 identity_store: &dyn IdentityKeyStore,
52 pre_key_store: &dyn PreKeyStore,
53 signed_prekey_store: &dyn SignedPreKeyStore,
54 kyber_prekey_store: &dyn KyberPreKeyStore,
55) -> Result<(Option<PreKeysUsed>, IdentityToSave<'a>)> {
56 let their_identity_key = message.identity_key();
57
58 if !identity_store
59 .is_trusted_identity(remote_address, their_identity_key, Direction::Receiving)
60 .await?
61 {
62 return Err(SignalProtocolError::UntrustedIdentity(
63 remote_address.clone(),
64 ));
65 }
66
67 let pre_keys_used = process_prekey_impl(
68 message,
69 remote_address,
70 local_address,
71 session_record,
72 signed_prekey_store,
73 kyber_prekey_store,
74 pre_key_store,
75 identity_store,
76 )
77 .await?;
78
79 let identity_to_save = IdentityToSave {
80 remote_address,
81 their_identity_key,
82 };
83
84 Ok((pre_keys_used, identity_to_save))
85}
86
87async fn process_prekey_impl(
88 message: &PreKeySignalMessage,
89 remote_address: &ProtocolAddress,
90 local_address: &ProtocolAddress,
91 session_record: &mut SessionRecord,
92 signed_prekey_store: &dyn SignedPreKeyStore,
93 kyber_prekey_store: &dyn KyberPreKeyStore,
94 pre_key_store: &dyn PreKeyStore,
95 identity_store: &dyn IdentityKeyStore,
96) -> Result<Option<PreKeysUsed>> {
97 if session_record.promote_matching_session(
98 message.message_version() as u32,
99 &message.base_key().serialize(),
100 )? {
101 return Ok(None);
103 }
104
105 if message.message_version() == CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION {
108 return Err(SignalProtocolError::InvalidMessage(
112 CiphertextMessageType::PreKey,
113 "X3DH no longer supported",
114 ));
115 }
116
117 let our_signed_pre_key_pair = signed_prekey_store
118 .get_signed_pre_key(message.signed_pre_key_id())
119 .await?
120 .key_pair()?;
121
122 let our_kyber_pre_key_pair = if let Some(kyber_pre_key_id) = message.kyber_pre_key_id() {
123 kyber_prekey_store
124 .get_kyber_pre_key(kyber_pre_key_id)
125 .await?
126 .key_pair()?
127 } else {
128 return Err(SignalProtocolError::InvalidMessage(
129 CiphertextMessageType::PreKey,
130 "missing pq pre-key ID",
131 ));
132 };
133 let kyber_ciphertext =
134 message
135 .kyber_ciphertext()
136 .ok_or(SignalProtocolError::InvalidMessage(
137 CiphertextMessageType::PreKey,
138 "missing pq ciphertext",
139 ))?;
140
141 let our_one_time_pre_key_pair = if let Some(pre_key_id) = message.pre_key_id() {
142 log::info!("processing PreKey message from {remote_address}");
143 Some(pre_key_store.get_pre_key(pre_key_id).await?.key_pair()?)
144 } else {
145 log::warn!("processing PreKey message from {remote_address} which had no one-time prekey");
146 None
147 };
148
149 let our_identity_key_pair = identity_store.get_identity_key_pair().await?;
150 let parameters = BobSignalProtocolParameters::new(
151 our_identity_key_pair,
152 our_signed_pre_key_pair,
153 our_one_time_pre_key_pair,
154 our_kyber_pre_key_pair,
155 *message.identity_key(),
156 *message.base_key(),
157 kyber_ciphertext,
158 our_identity_key_pair.identity_key().is_same_account(
159 local_address,
160 message.identity_key(),
161 remote_address,
162 ),
163 );
164
165 let mut new_session = ratchet::initialize_bob_session(¶meters, &our_signed_pre_key_pair)?;
167
168 new_session.set_local_registration_id(identity_store.get_local_registration_id().await?);
169 new_session.set_remote_registration_id(message.registration_id());
170
171 session_record.promote_state(new_session);
172
173 let pre_keys_used = PreKeysUsed {
174 one_time_ec_pre_key_id: message.pre_key_id(),
175 signed_ec_pre_key_id: message.signed_pre_key_id(),
176 kyber_pre_key_id: message.kyber_pre_key_id(),
177 };
178 Ok(Some(pre_keys_used))
179}
180
181pub async fn process_prekey_bundle<R: Rng + CryptoRng>(
182 remote_address: &ProtocolAddress,
183 local_address: &ProtocolAddress,
184 session_store: &mut dyn SessionStore,
185 identity_store: &mut dyn IdentityKeyStore,
186 bundle: &PreKeyBundle,
187 now: SystemTime,
188 mut csprng: &mut R,
189) -> Result<()> {
190 let their_identity_key = bundle.identity_key()?;
191
192 if !identity_store
193 .is_trusted_identity(remote_address, their_identity_key, Direction::Sending)
194 .await?
195 {
196 return Err(SignalProtocolError::UntrustedIdentity(
197 remote_address.clone(),
198 ));
199 }
200
201 if !their_identity_key.public_key().verify_signature(
202 &bundle.signed_pre_key_public()?.serialize(),
203 bundle.signed_pre_key_signature()?,
204 ) {
205 return Err(SignalProtocolError::SignatureValidationFailed);
206 }
207
208 if !their_identity_key.public_key().verify_signature(
209 &bundle.kyber_pre_key_public()?.serialize(),
210 bundle.kyber_pre_key_signature()?,
211 ) {
212 return Err(SignalProtocolError::SignatureValidationFailed);
213 }
214
215 let mut session_record = session_store
216 .load_session(remote_address)
217 .await?
218 .unwrap_or_else(SessionRecord::new_fresh);
219
220 let our_base_key_pair = KeyPair::generate(&mut csprng);
221 let their_signed_prekey = bundle.signed_pre_key_public()?;
222 let their_kyber_prekey = bundle.kyber_pre_key_public()?;
223
224 let their_one_time_prekey_id = bundle.pre_key_id()?;
225
226 let our_identity_key_pair = identity_store.get_identity_key_pair().await?;
227
228 let mut parameters = AliceSignalProtocolParameters::new(
229 our_identity_key_pair,
230 our_base_key_pair,
231 *their_identity_key,
232 their_signed_prekey,
233 their_signed_prekey,
234 their_kyber_prekey.clone(),
235 our_identity_key_pair.identity_key().is_same_account(
236 local_address,
237 their_identity_key,
238 remote_address,
239 ),
240 );
241 if let Some(key) = bundle.pre_key_public()? {
242 parameters.set_their_one_time_pre_key(key);
243 }
244
245 let mut session = ratchet::initialize_alice_session(¶meters, csprng)?;
246
247 log::info!(
248 "set_unacknowledged_pre_key_message for: {} with preKeyId: {}",
249 remote_address,
250 their_one_time_prekey_id.map_or_else(|| "<none>".to_string(), |id| id.to_string())
251 );
252
253 session.set_unacknowledged_pre_key_message(
254 their_one_time_prekey_id,
255 bundle.signed_pre_key_id()?,
256 &our_base_key_pair.public_key,
257 now,
258 );
259 session.set_unacknowledged_kyber_pre_key_id(bundle.kyber_pre_key_id()?);
260
261 session.set_local_registration_id(identity_store.get_local_registration_id().await?);
262 session.set_remote_registration_id(bundle.registration_id()?);
263
264 identity_store
265 .save_identity(remote_address, their_identity_key)
266 .await?;
267
268 session_record.promote_state(session);
269
270 session_store
271 .store_session(remote_address, &session_record)
272 .await?;
273
274 Ok(())
275}