1use std::time::SystemTime;
7
8use rand::{CryptoRng, Rng};
9
10use crate::ratchet::{AliceSignalProtocolParameters, BobSignalProtocolParameters};
11use crate::state::GenericSignedPreKey;
12use crate::{
13 kem, ratchet, Direction, IdentityKeyStore, KeyPair, KyberPreKeyId, KyberPreKeyStore,
14 PreKeyBundle, PreKeyId, PreKeySignalMessage, PreKeyStore, ProtocolAddress, Result,
15 SessionRecord, SessionStore, SignalProtocolError, SignedPreKeyStore,
16};
17
18#[derive(Default)]
19pub struct PreKeysUsed {
20 pub pre_key_id: Option<PreKeyId>,
21 pub kyber_pre_key_id: Option<KyberPreKeyId>,
22}
23
24pub async fn process_prekey(
34 message: &PreKeySignalMessage,
35 remote_address: &ProtocolAddress,
36 session_record: &mut SessionRecord,
37 identity_store: &mut dyn IdentityKeyStore,
38 pre_key_store: &dyn PreKeyStore,
39 signed_prekey_store: &dyn SignedPreKeyStore,
40 kyber_prekey_store: &dyn KyberPreKeyStore,
41) -> Result<PreKeysUsed> {
42 let their_identity_key = message.identity_key();
43
44 if !identity_store
45 .is_trusted_identity(remote_address, their_identity_key, Direction::Receiving)
46 .await?
47 {
48 return Err(SignalProtocolError::UntrustedIdentity(
49 remote_address.clone(),
50 ));
51 }
52
53 let pre_keys_used = process_prekey_impl(
54 message,
55 remote_address,
56 session_record,
57 signed_prekey_store,
58 kyber_prekey_store,
59 pre_key_store,
60 identity_store,
61 )
62 .await?;
63
64 identity_store
65 .save_identity(remote_address, their_identity_key)
66 .await?;
67
68 Ok(pre_keys_used)
69}
70
71async fn process_prekey_impl(
72 message: &PreKeySignalMessage,
73 remote_address: &ProtocolAddress,
74 session_record: &mut SessionRecord,
75 signed_prekey_store: &dyn SignedPreKeyStore,
76 kyber_prekey_store: &dyn KyberPreKeyStore,
77 pre_key_store: &dyn PreKeyStore,
78 identity_store: &dyn IdentityKeyStore,
79) -> Result<PreKeysUsed> {
80 if session_record.has_session_state(
81 message.message_version() as u32,
82 &message.base_key().serialize(),
83 )? {
84 return Ok(Default::default());
86 }
87
88 let our_signed_pre_key_pair = signed_prekey_store
89 .get_signed_pre_key(message.signed_pre_key_id())
90 .await?
91 .key_pair()?;
92
93 let our_kyber_pre_key_pair: Option<kem::KeyPair>;
95 if let Some(kyber_pre_key_id) = message.kyber_pre_key_id() {
96 our_kyber_pre_key_pair = Some(
97 kyber_prekey_store
98 .get_kyber_pre_key(kyber_pre_key_id)
99 .await?
100 .key_pair()?,
101 );
102 } else {
103 our_kyber_pre_key_pair = None;
104 }
105
106 let our_one_time_pre_key_pair = if let Some(pre_key_id) = message.pre_key_id() {
107 log::info!("processing PreKey message from {}", remote_address);
108 Some(pre_key_store.get_pre_key(pre_key_id).await?.key_pair()?)
109 } else {
110 log::warn!(
111 "processing PreKey message from {} which had no one-time prekey",
112 remote_address
113 );
114 None
115 };
116
117 let parameters = BobSignalProtocolParameters::new(
118 identity_store.get_identity_key_pair().await?,
119 our_signed_pre_key_pair, our_one_time_pre_key_pair,
121 our_signed_pre_key_pair, our_kyber_pre_key_pair,
123 *message.identity_key(),
124 *message.base_key(),
125 message.kyber_ciphertext(),
126 );
127
128 let mut new_session = ratchet::initialize_bob_session(¶meters)?;
129
130 new_session.set_local_registration_id(identity_store.get_local_registration_id().await?);
131 new_session.set_remote_registration_id(message.registration_id());
132
133 session_record.promote_state(new_session);
134
135 let pre_keys_used = PreKeysUsed {
136 pre_key_id: message.pre_key_id(),
137 kyber_pre_key_id: message.kyber_pre_key_id(),
138 };
139 Ok(pre_keys_used)
140}
141
142pub async fn process_prekey_bundle<R: Rng + CryptoRng>(
143 remote_address: &ProtocolAddress,
144 session_store: &mut dyn SessionStore,
145 identity_store: &mut dyn IdentityKeyStore,
146 bundle: &PreKeyBundle,
147 now: SystemTime,
148 mut csprng: &mut R,
149) -> Result<()> {
150 let their_identity_key = bundle.identity_key()?;
151
152 if !identity_store
153 .is_trusted_identity(remote_address, their_identity_key, Direction::Sending)
154 .await?
155 {
156 return Err(SignalProtocolError::UntrustedIdentity(
157 remote_address.clone(),
158 ));
159 }
160
161 if !their_identity_key.public_key().verify_signature(
162 &bundle.signed_pre_key_public()?.serialize(),
163 bundle.signed_pre_key_signature()?,
164 )? {
165 return Err(SignalProtocolError::SignatureValidationFailed);
166 }
167
168 if let Some(kyber_public) = bundle.kyber_pre_key_public()? {
169 if !their_identity_key.public_key().verify_signature(
170 kyber_public.serialize().as_ref(),
171 bundle
172 .kyber_pre_key_signature()?
173 .expect("signature must be present"),
174 )? {
175 return Err(SignalProtocolError::SignatureValidationFailed);
176 }
177 }
178
179 let mut session_record = session_store
180 .load_session(remote_address)
181 .await?
182 .unwrap_or_else(SessionRecord::new_fresh);
183
184 let our_base_key_pair = KeyPair::generate(&mut csprng);
185 let their_signed_prekey = bundle.signed_pre_key_public()?;
186
187 let their_one_time_prekey_id = bundle.pre_key_id()?;
188
189 let our_identity_key_pair = identity_store.get_identity_key_pair().await?;
190
191 let mut parameters = AliceSignalProtocolParameters::new(
192 our_identity_key_pair,
193 our_base_key_pair,
194 *their_identity_key,
195 their_signed_prekey,
196 their_signed_prekey,
197 );
198 if let Some(key) = bundle.pre_key_public()? {
199 parameters.set_their_one_time_pre_key(key);
200 }
201
202 if let Some(key) = bundle.kyber_pre_key_public()? {
203 parameters.set_their_kyber_pre_key(key);
204 }
205
206 let mut session = ratchet::initialize_alice_session(¶meters, csprng)?;
207
208 log::info!(
209 "set_unacknowledged_pre_key_message for: {} with preKeyId: {}",
210 remote_address,
211 their_one_time_prekey_id.map_or_else(|| "<none>".to_string(), |id| id.to_string())
212 );
213
214 session.set_unacknowledged_pre_key_message(
215 their_one_time_prekey_id,
216 bundle.signed_pre_key_id()?,
217 &our_base_key_pair.public_key,
218 now,
219 );
220
221 if let Some(kyber_pre_key_id) = bundle.kyber_pre_key_id()? {
222 session.set_unacknowledged_kyber_pre_key_id(kyber_pre_key_id);
223 }
224
225 session.set_local_registration_id(identity_store.get_local_registration_id().await?);
226 session.set_remote_registration_id(bundle.registration_id()?);
227
228 identity_store
229 .save_identity(remote_address, their_identity_key)
230 .await?;
231
232 session_record.promote_state(session);
233
234 session_store
235 .store_session(remote_address, &session_record)
236 .await?;
237
238 Ok(())
239}