1use std::convert::TryFrom;
2
3use crate::{
4 timestamp::TimestampExt as _,
5 utils::{serde_base64, serde_identity_key},
6};
7use async_trait::async_trait;
8use libsignal_protocol::{
9 error::SignalProtocolError, kem, GenericSignedPreKey, IdentityKey,
10 IdentityKeyPair, IdentityKeyStore, KeyPair, KyberPreKeyId,
11 KyberPreKeyRecord, KyberPreKeyStore, PreKeyRecord, PreKeyStore,
12 SignedPreKeyId, SignedPreKeyRecord, SignedPreKeyStore, Timestamp,
13};
14
15use rand::{CryptoRng, Rng};
16use serde::{Deserialize, Serialize};
17use tracing::Instrument;
18
/// Additional Kyber pre-key storage operations layered on top of
/// [`KyberPreKeyStore`]: last-resort key persistence and maintenance of
/// one-time Kyber pre keys.
#[async_trait(?Send)]
pub trait KyberPreKeyStoreExt: KyberPreKeyStore {
    /// Persists `record` under `kyber_prekey_id` as a last-resort Kyber pre key.
    ///
    /// `replenish_pre_keys` calls this for the last-resort key it generates
    /// (instead of `save_kyber_pre_key`, which it uses for one-time keys).
    async fn store_last_resort_kyber_pre_key(
        &mut self,
        kyber_prekey_id: KyberPreKeyId,
        record: &KyberPreKeyRecord,
    ) -> Result<(), SignalProtocolError>;

    /// Loads the stored last-resort Kyber pre key records.
    async fn load_last_resort_kyber_pre_keys(
        &self,
    ) -> Result<Vec<KyberPreKeyRecord>, SignalProtocolError>;

    /// Removes the Kyber pre key stored under `kyber_prekey_id`.
    ///
    /// NOTE(review): whether this applies to last-resort keys as well as
    /// one-time keys is left to the implementor — confirm.
    async fn remove_kyber_pre_key(
        &mut self,
        kyber_prekey_id: KyberPreKeyId,
    ) -> Result<(), SignalProtocolError>;

    /// Marks one-time Kyber pre keys as stale, recording `stale_time`.
    ///
    /// NOTE(review): the "if necessary" condition (e.g. skipping keys that are
    /// already marked) is implementor-defined — confirm against implementations.
    async fn mark_all_one_time_kyber_pre_keys_stale_if_necessary(
        &mut self,
        stale_time: chrono::DateTime<chrono::Utc>,
    ) -> Result<(), SignalProtocolError>;

    /// Deletes stale one-time Kyber pre keys.
    ///
    /// NOTE(review): presumably only keys marked stale before `threshold` are
    /// deleted, and at least `min_count` keys are kept — confirm.
    async fn delete_all_stale_one_time_kyber_pre_keys(
        &mut self,
        threshold: chrono::DateTime<chrono::Utc>,
        min_count: usize,
    ) -> Result<(), SignalProtocolError>;
}
52
/// Combined store bound used for pre-key generation: one-time EC pre keys,
/// the identity key store, signed pre keys and (extended) Kyber pre keys.
///
/// `replenish_pre_keys` requires its store to implement this trait.
#[async_trait(?Send)]
pub trait PreKeysStore:
    PreKeyStore
    + IdentityKeyStore
    + SignedPreKeyStore
    + KyberPreKeyStore
    + KyberPreKeyStoreExt
{
    /// Offset from which the next batch of one-time EC pre keys is numbered.
    ///
    /// `replenish_pre_keys` adds the index within the batch to this value and
    /// wraps the result into `1..PRE_KEY_MEDIUM_MAX_VALUE`.
    async fn next_pre_key_id(&self) -> Result<u32, SignalProtocolError>;

    /// Id that `replenish_pre_keys` assigns, unmodified, to the next signed
    /// pre key it generates.
    async fn next_signed_pre_key_id(&self) -> Result<u32, SignalProtocolError>;

    /// Offset from which the next batch of one-time Kyber pre keys (and the
    /// last-resort Kyber key that follows them) is numbered, with the same
    /// wrapping as [`PreKeysStore::next_pre_key_id`].
    async fn next_pq_pre_key_id(&self) -> Result<u32, SignalProtocolError>;

    /// Number of signed pre keys held by the store.
    async fn signed_pre_keys_count(&self)
        -> Result<usize, SignalProtocolError>;

    /// Number of Kyber pre keys held by the store.
    ///
    /// NOTE(review): presumably `last_resort` selects between counting
    /// last-resort keys (`true`) and one-time keys (`false`) — confirm.
    async fn kyber_pre_keys_count(
        &self,
        last_resort: bool,
    ) -> Result<usize, SignalProtocolError>;

    /// Id of a stored signed pre key, or `None` when there is none.
    ///
    /// NOTE(review): which key is returned when several exist (e.g. the most
    /// recent) is implementor-defined — confirm.
    async fn signed_prekey_id(
        &self,
    ) -> Result<Option<SignedPreKeyId>, SignalProtocolError>;

    /// Id of a stored last-resort Kyber pre key, or `None` when there is none.
    ///
    /// NOTE(review): selection among several last-resort keys is
    /// implementor-defined — confirm.
    async fn last_resort_kyber_prekey_id(
        &self,
    ) -> Result<Option<KyberPreKeyId>, SignalProtocolError>;
}
91
/// Serde-(de)serializable form of a one-time pre key: its id and serialized
/// public key. Field names use camelCase (`keyId`, `publicKey`).
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PreKeyEntity {
    /// Numeric pre key id.
    pub key_id: u32,
    /// Serialized public key bytes, encoded through the `serde_base64` helper
    /// (presumably as a base64 string — confirm in `utils::serde_base64`).
    #[serde(with = "serde_base64")]
    pub public_key: Vec<u8>,
}
99
100impl TryFrom<PreKeyRecord> for PreKeyEntity {
101 type Error = SignalProtocolError;
102
103 fn try_from(key: PreKeyRecord) -> Result<Self, Self::Error> {
104 Ok(PreKeyEntity {
105 key_id: key.id()?.into(),
106 public_key: key.key_pair()?.public_key.serialize().to_vec(),
107 })
108 }
109}
110
/// Serde-(de)serializable form of a signed pre key: id, serialized public key
/// and the signature over it. Field names use camelCase
/// (`keyId`, `publicKey`, `signature`).
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct SignedPreKeyEntity {
    /// Numeric signed pre key id.
    pub key_id: u32,
    /// Serialized public key bytes, encoded through the `serde_base64` helper.
    #[serde(with = "serde_base64")]
    pub public_key: Vec<u8>,
    /// Signature bytes, encoded through the `serde_base64` helper.
    #[serde(with = "serde_base64")]
    pub signature: Vec<u8>,
}
120
121impl TryFrom<&'_ SignedPreKeyRecord> for SignedPreKeyEntity {
122 type Error = SignalProtocolError;
123
124 fn try_from(key: &'_ SignedPreKeyRecord) -> Result<Self, Self::Error> {
125 Ok(SignedPreKeyEntity {
126 key_id: key.id()?.into(),
127 public_key: key.key_pair()?.public_key.serialize().to_vec(),
128 signature: key.signature()?.to_vec(),
129 })
130 }
131}
132
133impl TryFrom<SignedPreKeyRecord> for SignedPreKeyEntity {
134 type Error = SignalProtocolError;
135
136 fn try_from(key: SignedPreKeyRecord) -> Result<Self, Self::Error> {
137 SignedPreKeyEntity::try_from(&key)
138 }
139}
140
/// Serde-(de)serializable form of a Kyber pre key: id, serialized KEM public
/// key and the signature over it. Field names use camelCase
/// (`keyId`, `publicKey`, `signature`).
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct KyberPreKeyEntity {
    /// Numeric Kyber pre key id.
    pub key_id: u32,
    /// Serialized KEM public key bytes, encoded through the `serde_base64` helper.
    #[serde(with = "serde_base64")]
    pub public_key: Vec<u8>,
    /// Signature bytes, encoded through the `serde_base64` helper.
    #[serde(with = "serde_base64")]
    pub signature: Vec<u8>,
}
150
151impl TryFrom<&'_ KyberPreKeyRecord> for KyberPreKeyEntity {
152 type Error = SignalProtocolError;
153
154 fn try_from(key: &'_ KyberPreKeyRecord) -> Result<Self, Self::Error> {
155 Ok(KyberPreKeyEntity {
156 key_id: key.id()?.into(),
157 public_key: key.key_pair()?.public_key.serialize().to_vec(),
158 signature: key.signature()?,
159 })
160 }
161}
162
163impl TryFrom<KyberPreKeyRecord> for KyberPreKeyEntity {
164 type Error = SignalProtocolError;
165
166 fn try_from(key: KyberPreKeyRecord) -> Result<Self, Self::Error> {
167 KyberPreKeyEntity::try_from(&key)
168 }
169}
170
/// Serialize-only bundle of one-time pre keys, the signed pre key, the
/// identity key and Kyber pre keys. Field names use camelCase.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PreKeyState {
    /// One-time pre keys.
    pub pre_keys: Vec<PreKeyEntity>,
    /// The signed pre key.
    pub signed_pre_key: SignedPreKeyEntity,
    /// Identity key, serialized through the `serde_identity_key` helper.
    #[serde(with = "serde_identity_key")]
    pub identity_key: IdentityKey,
    /// Optional last-resort Kyber pre key; the field is omitted from the
    /// serialized output entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pq_last_resort_key: Option<KyberPreKeyEntity>,
    /// One-time Kyber pre keys.
    pub pq_pre_keys: Vec<KyberPreKeyEntity>,
}
182
/// Pre-key count threshold.
///
/// NOTE(review): not referenced in this section; presumably the minimum number
/// of one-time pre keys to keep available before replenishing — confirm.
pub(crate) const PRE_KEY_MINIMUM: u32 = 10;

/// Pre-key batch size.
///
/// NOTE(review): not referenced in this section; presumably the number of pre
/// keys generated per replenishment — confirm.
pub(crate) const PRE_KEY_BATCH_SIZE: u32 = 100;

/// `2^24 - 1`. `replenish_pre_keys` reduces generated ids modulo
/// `PRE_KEY_MEDIUM_MAX_VALUE - 1` and adds one, so those ids lie in
/// `1..PRE_KEY_MEDIUM_MAX_VALUE` and never equal zero.
pub(crate) const PRE_KEY_MEDIUM_MAX_VALUE: u32 = (1 << 24) - 1;
186
187pub(crate) async fn replenish_pre_keys<R: Rng + CryptoRng, P: PreKeysStore>(
188 protocol_store: &mut P,
189 csprng: &mut R,
190 identity_key_pair: &IdentityKeyPair,
191 use_last_resort_key: bool,
192 pre_key_count: u32,
193 kyber_pre_key_count: u32,
194) -> Result<
195 (
196 Vec<PreKeyRecord>,
197 SignedPreKeyRecord,
198 Vec<KyberPreKeyRecord>,
199 Option<KyberPreKeyRecord>,
200 ),
201 SignalProtocolError,
202> {
203 let pre_keys_offset_id = protocol_store.next_pre_key_id().await?;
204 let next_signed_pre_key_id =
205 protocol_store.next_signed_pre_key_id().await?;
206 let pq_pre_keys_offset_id = protocol_store.next_pq_pre_key_id().await?;
207
208 let span = tracing::span!(tracing::Level::DEBUG, "Generating pre keys");
209
210 let mut pre_keys = vec![];
211 let mut pq_pre_keys = vec![];
212
213 for i in 0..pre_key_count {
215 let key_pair = KeyPair::generate(csprng);
216 let pre_key_id =
217 (((pre_keys_offset_id + i) % (PRE_KEY_MEDIUM_MAX_VALUE - 1)) + 1)
218 .into();
219 let pre_key_record = PreKeyRecord::new(pre_key_id, &key_pair);
220 protocol_store
221 .save_pre_key(pre_key_id, &pre_key_record)
222 .instrument(tracing::trace_span!(parent: &span, "save pre key", ?pre_key_id)).await?;
223 pre_keys.push(pre_key_record);
228 }
229
230 for i in 0..kyber_pre_key_count {
232 let pre_key_id = (((pq_pre_keys_offset_id + i)
233 % (PRE_KEY_MEDIUM_MAX_VALUE - 1))
234 + 1)
235 .into();
236 let pre_key_record = KyberPreKeyRecord::generate(
237 kem::KeyType::Kyber1024,
238 pre_key_id,
239 identity_key_pair.private_key(),
240 )?;
241 protocol_store
242 .save_kyber_pre_key(pre_key_id, &pre_key_record)
243 .instrument(tracing::trace_span!(parent: &span, "save kyber pre key", ?pre_key_id)).await?;
244 pq_pre_keys.push(pre_key_record);
249 }
250
251 let signed_pre_key_pair = KeyPair::generate(csprng);
253 let signed_pre_key_public = signed_pre_key_pair.public_key;
254 let signed_pre_key_signature = identity_key_pair
255 .private_key()
256 .calculate_signature(&signed_pre_key_public.serialize(), csprng)?;
257
258 let signed_prekey_record = SignedPreKeyRecord::new(
259 next_signed_pre_key_id.into(),
260 Timestamp::now(),
261 &signed_pre_key_pair,
262 &signed_pre_key_signature,
263 );
264
265 protocol_store
266 .save_signed_pre_key(
267 next_signed_pre_key_id.into(),
268 &signed_prekey_record,
269 )
270 .instrument(tracing::trace_span!(parent: &span, "save signed pre key", signed_pre_key_id = ?next_signed_pre_key_id)).await?;
271
272 let pq_last_resort_key = if use_last_resort_key {
273 let pre_key_id = (((pq_pre_keys_offset_id + kyber_pre_key_count)
274 % (PRE_KEY_MEDIUM_MAX_VALUE - 1))
275 + 1)
276 .into();
277
278 if !pq_pre_keys.is_empty() {
279 assert_eq!(
280 u32::from(pq_pre_keys.last().unwrap().id()?) + 1,
281 u32::from(pre_key_id)
282 );
283 }
284
285 let pre_key_record = KyberPreKeyRecord::generate(
286 kem::KeyType::Kyber1024,
287 pre_key_id,
288 identity_key_pair.private_key(),
289 )?;
290 protocol_store
291 .store_last_resort_kyber_pre_key(pre_key_id, &pre_key_record)
292 .instrument(tracing::trace_span!(parent: &span, "save last resort kyber pre key", ?pre_key_id)).await?;
293 Some(pre_key_record)
298 } else {
299 None
300 };
301
302 Ok((
303 pre_keys,
304 signed_prekey_record,
305 pq_pre_keys,
306 pq_last_resort_key,
307 ))
308}