1use std::time::SystemTime;
7
8use rand::{CryptoRng, Rng};
9
10use crate::protocol::CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION;
11use crate::ratchet::{AliceSignalProtocolParameters, BobSignalProtocolParameters};
12use crate::state::GenericSignedPreKey;
13use crate::{
14 CiphertextMessageType, Direction, IdentityKey, IdentityKeyStore, KeyPair, KyberPreKeyId,
15 KyberPreKeyStore, PreKeyBundle, PreKeyId, PreKeySignalMessage, PreKeyStore, ProtocolAddress,
16 Result, SessionRecord, SessionStore, SignalProtocolError, SignedPreKeyId, SignedPreKeyStore,
17 ratchet,
18};
19
20pub struct PreKeysUsed {
21 pub one_time_ec_pre_key_id: Option<PreKeyId>,
22 pub signed_ec_pre_key_id: SignedPreKeyId,
23 pub kyber_pre_key_id: Option<KyberPreKeyId>,
24}
25
26#[must_use]
32pub struct IdentityToSave<'a> {
33 pub remote_address: &'a ProtocolAddress,
34 pub their_identity_key: &'a IdentityKey,
35}
36
37pub async fn process_prekey<'a>(
47 message: &'a PreKeySignalMessage,
48 remote_address: &'a ProtocolAddress,
49 session_record: &mut SessionRecord,
50 identity_store: &dyn IdentityKeyStore,
51 pre_key_store: &dyn PreKeyStore,
52 signed_prekey_store: &dyn SignedPreKeyStore,
53 kyber_prekey_store: &dyn KyberPreKeyStore,
54) -> Result<(Option<PreKeysUsed>, IdentityToSave<'a>)> {
55 let their_identity_key = message.identity_key();
56
57 if !identity_store
58 .is_trusted_identity(remote_address, their_identity_key, Direction::Receiving)
59 .await?
60 {
61 return Err(SignalProtocolError::UntrustedIdentity(
62 remote_address.clone(),
63 ));
64 }
65
66 let pre_keys_used = process_prekey_impl(
67 message,
68 remote_address,
69 session_record,
70 signed_prekey_store,
71 kyber_prekey_store,
72 pre_key_store,
73 identity_store,
74 )
75 .await?;
76
77 let identity_to_save = IdentityToSave {
78 remote_address,
79 their_identity_key,
80 };
81
82 Ok((pre_keys_used, identity_to_save))
83}
84
85async fn process_prekey_impl(
86 message: &PreKeySignalMessage,
87 remote_address: &ProtocolAddress,
88 session_record: &mut SessionRecord,
89 signed_prekey_store: &dyn SignedPreKeyStore,
90 kyber_prekey_store: &dyn KyberPreKeyStore,
91 pre_key_store: &dyn PreKeyStore,
92 identity_store: &dyn IdentityKeyStore,
93) -> Result<Option<PreKeysUsed>> {
94 if session_record.promote_matching_session(
95 message.message_version() as u32,
96 &message.base_key().serialize(),
97 )? {
98 return Ok(None);
100 }
101
102 if message.message_version() == CIPHERTEXT_MESSAGE_PRE_KYBER_VERSION {
105 return Err(SignalProtocolError::InvalidMessage(
109 CiphertextMessageType::PreKey,
110 "X3DH no longer supported",
111 ));
112 }
113
114 let our_signed_pre_key_pair = signed_prekey_store
115 .get_signed_pre_key(message.signed_pre_key_id())
116 .await?
117 .key_pair()?;
118
119 let our_kyber_pre_key_pair = if let Some(kyber_pre_key_id) = message.kyber_pre_key_id() {
120 kyber_prekey_store
121 .get_kyber_pre_key(kyber_pre_key_id)
122 .await?
123 .key_pair()?
124 } else {
125 return Err(SignalProtocolError::InvalidMessage(
126 CiphertextMessageType::PreKey,
127 "missing pq pre-key ID",
128 ));
129 };
130 let kyber_ciphertext =
131 message
132 .kyber_ciphertext()
133 .ok_or(SignalProtocolError::InvalidMessage(
134 CiphertextMessageType::PreKey,
135 "missing pq ciphertext",
136 ))?;
137
138 let our_one_time_pre_key_pair = if let Some(pre_key_id) = message.pre_key_id() {
139 log::info!("processing PreKey message from {remote_address}");
140 Some(pre_key_store.get_pre_key(pre_key_id).await?.key_pair()?)
141 } else {
142 log::warn!("processing PreKey message from {remote_address} which had no one-time prekey");
143 None
144 };
145
146 let parameters = BobSignalProtocolParameters::new(
147 identity_store.get_identity_key_pair().await?,
148 our_signed_pre_key_pair, our_one_time_pre_key_pair,
150 our_signed_pre_key_pair, our_kyber_pre_key_pair,
152 *message.identity_key(),
153 *message.base_key(),
154 kyber_ciphertext,
155 );
156
157 let mut new_session = ratchet::initialize_bob_session(¶meters)?;
158
159 new_session.set_local_registration_id(identity_store.get_local_registration_id().await?);
160 new_session.set_remote_registration_id(message.registration_id());
161
162 session_record.promote_state(new_session);
163
164 let pre_keys_used = PreKeysUsed {
165 one_time_ec_pre_key_id: message.pre_key_id(),
166 signed_ec_pre_key_id: message.signed_pre_key_id(),
167 kyber_pre_key_id: message.kyber_pre_key_id(),
168 };
169 Ok(Some(pre_keys_used))
170}
171
172pub async fn process_prekey_bundle<R: Rng + CryptoRng>(
173 remote_address: &ProtocolAddress,
174 session_store: &mut dyn SessionStore,
175 identity_store: &mut dyn IdentityKeyStore,
176 bundle: &PreKeyBundle,
177 now: SystemTime,
178 mut csprng: &mut R,
179) -> Result<()> {
180 let their_identity_key = bundle.identity_key()?;
181
182 if !identity_store
183 .is_trusted_identity(remote_address, their_identity_key, Direction::Sending)
184 .await?
185 {
186 return Err(SignalProtocolError::UntrustedIdentity(
187 remote_address.clone(),
188 ));
189 }
190
191 if !their_identity_key.public_key().verify_signature(
192 &bundle.signed_pre_key_public()?.serialize(),
193 bundle.signed_pre_key_signature()?,
194 ) {
195 return Err(SignalProtocolError::SignatureValidationFailed);
196 }
197
198 if !their_identity_key.public_key().verify_signature(
199 &bundle.kyber_pre_key_public()?.serialize(),
200 bundle.kyber_pre_key_signature()?,
201 ) {
202 return Err(SignalProtocolError::SignatureValidationFailed);
203 }
204
205 let mut session_record = session_store
206 .load_session(remote_address)
207 .await?
208 .unwrap_or_else(SessionRecord::new_fresh);
209
210 let our_base_key_pair = KeyPair::generate(&mut csprng);
211 let their_signed_prekey = bundle.signed_pre_key_public()?;
212 let their_kyber_prekey = bundle.kyber_pre_key_public()?;
213
214 let their_one_time_prekey_id = bundle.pre_key_id()?;
215
216 let our_identity_key_pair = identity_store.get_identity_key_pair().await?;
217
218 let mut parameters = AliceSignalProtocolParameters::new(
219 our_identity_key_pair,
220 our_base_key_pair,
221 *their_identity_key,
222 their_signed_prekey,
223 their_signed_prekey,
224 their_kyber_prekey.clone(),
225 );
226 if let Some(key) = bundle.pre_key_public()? {
227 parameters.set_their_one_time_pre_key(key);
228 }
229
230 let mut session = ratchet::initialize_alice_session(¶meters, csprng)?;
231
232 log::info!(
233 "set_unacknowledged_pre_key_message for: {} with preKeyId: {}",
234 remote_address,
235 their_one_time_prekey_id.map_or_else(|| "<none>".to_string(), |id| id.to_string())
236 );
237
238 session.set_unacknowledged_pre_key_message(
239 their_one_time_prekey_id,
240 bundle.signed_pre_key_id()?,
241 &our_base_key_pair.public_key,
242 now,
243 );
244 session.set_unacknowledged_kyber_pre_key_id(bundle.kyber_pre_key_id()?);
245
246 session.set_local_registration_id(identity_store.get_local_registration_id().await?);
247 session.set_remote_registration_id(bundle.registration_id()?);
248
249 identity_store
250 .save_identity(remote_address, their_identity_key)
251 .await?;
252
253 session_record.promote_state(session);
254
255 session_store
256 .store_session(remote_address, &session_record)
257 .await?;
258
259 Ok(())
260}