libsignal_service/
pre_keys.rs

1use std::convert::TryFrom;
2
3use crate::{
4    timestamp::TimestampExt as _,
5    utils::{serde_base64, serde_identity_key},
6};
7use async_trait::async_trait;
8use libsignal_protocol::{
9    error::SignalProtocolError, kem, GenericSignedPreKey, IdentityKey,
10    IdentityKeyPair, IdentityKeyStore, KeyPair, KyberPreKeyId,
11    KyberPreKeyRecord, KyberPreKeyStore, PreKeyRecord, PreKeyStore,
12    SignedPreKeyId, SignedPreKeyRecord, SignedPreKeyStore, Timestamp,
13};
14
15use rand::{CryptoRng, Rng};
16use serde::{Deserialize, Serialize};
17use tracing::Instrument;
18
19#[async_trait(?Send)]
20/// Additional methods for the Kyber pre key store
21///
22/// Analogue of Android's ServiceKyberPreKeyStore
23pub trait KyberPreKeyStoreExt: KyberPreKeyStore {
24    async fn store_last_resort_kyber_pre_key(
25        &mut self,
26        kyber_prekey_id: KyberPreKeyId,
27        record: &KyberPreKeyRecord,
28    ) -> Result<(), SignalProtocolError>;
29
30    async fn load_last_resort_kyber_pre_keys(
31        &self,
32    ) -> Result<Vec<KyberPreKeyRecord>, SignalProtocolError>;
33
34    async fn remove_kyber_pre_key(
35        &mut self,
36        kyber_prekey_id: KyberPreKeyId,
37    ) -> Result<(), SignalProtocolError>;
38
39    /// Analogous to markAllOneTimeKyberPreKeysStaleIfNecessary
40    async fn mark_all_one_time_kyber_pre_keys_stale_if_necessary(
41        &mut self,
42        stale_time: chrono::DateTime<chrono::Utc>,
43    ) -> Result<(), SignalProtocolError>;
44
45    /// Analogue of deleteAllStaleOneTimeKyberPreKeys
46    async fn delete_all_stale_one_time_kyber_pre_keys(
47        &mut self,
48        threshold: chrono::DateTime<chrono::Utc>,
49        min_count: usize,
50    ) -> Result<(), SignalProtocolError>;
51}
52
53/// Stores the ID of keys published ahead of time
54///
55/// <https://signal.org/docs/specifications/x3dh/>
56#[async_trait(?Send)]
57pub trait PreKeysStore:
58    PreKeyStore
59    + IdentityKeyStore
60    + SignedPreKeyStore
61    + KyberPreKeyStore
62    + KyberPreKeyStoreExt
63{
64    /// ID of the next pre key
65    async fn next_pre_key_id(&self) -> Result<u32, SignalProtocolError>;
66
67    /// ID of the next signed pre key
68    async fn next_signed_pre_key_id(&self) -> Result<u32, SignalProtocolError>;
69
70    /// ID of the next PQ pre key
71    async fn next_pq_pre_key_id(&self) -> Result<u32, SignalProtocolError>;
72
73    /// number of signed pre-keys we currently have in store
74    async fn signed_pre_keys_count(&self)
75        -> Result<usize, SignalProtocolError>;
76
77    /// number of kyber pre-keys we currently have in store
78    async fn kyber_pre_keys_count(
79        &self,
80        last_resort: bool,
81    ) -> Result<usize, SignalProtocolError>;
82
83    async fn signed_prekey_id(
84        &self,
85    ) -> Result<Option<SignedPreKeyId>, SignalProtocolError>;
86
87    async fn last_resort_kyber_prekey_id(
88        &self,
89    ) -> Result<Option<KyberPreKeyId>, SignalProtocolError>;
90}
91
92#[derive(Debug, Deserialize, Serialize)]
93#[serde(rename_all = "camelCase")]
94pub struct PreKeyEntity {
95    pub key_id: u32,
96    #[serde(with = "serde_base64")]
97    pub public_key: Vec<u8>,
98}
99
100impl TryFrom<PreKeyRecord> for PreKeyEntity {
101    type Error = SignalProtocolError;
102
103    fn try_from(key: PreKeyRecord) -> Result<Self, Self::Error> {
104        Ok(PreKeyEntity {
105            key_id: key.id()?.into(),
106            public_key: key.key_pair()?.public_key.serialize().to_vec(),
107        })
108    }
109}
110
111#[derive(Debug, Deserialize, Serialize)]
112#[serde(rename_all = "camelCase")]
113pub struct SignedPreKeyEntity {
114    pub key_id: u32,
115    #[serde(with = "serde_base64")]
116    pub public_key: Vec<u8>,
117    #[serde(with = "serde_base64")]
118    pub signature: Vec<u8>,
119}
120
121impl TryFrom<&'_ SignedPreKeyRecord> for SignedPreKeyEntity {
122    type Error = SignalProtocolError;
123
124    fn try_from(key: &'_ SignedPreKeyRecord) -> Result<Self, Self::Error> {
125        Ok(SignedPreKeyEntity {
126            key_id: key.id()?.into(),
127            public_key: key.key_pair()?.public_key.serialize().to_vec(),
128            signature: key.signature()?.to_vec(),
129        })
130    }
131}
132
133impl TryFrom<SignedPreKeyRecord> for SignedPreKeyEntity {
134    type Error = SignalProtocolError;
135
136    fn try_from(key: SignedPreKeyRecord) -> Result<Self, Self::Error> {
137        SignedPreKeyEntity::try_from(&key)
138    }
139}
140
141#[derive(Debug, Deserialize, Serialize)]
142#[serde(rename_all = "camelCase")]
143pub struct KyberPreKeyEntity {
144    pub key_id: u32,
145    #[serde(with = "serde_base64")]
146    pub public_key: Vec<u8>,
147    #[serde(with = "serde_base64")]
148    pub signature: Vec<u8>,
149}
150
151impl TryFrom<&'_ KyberPreKeyRecord> for KyberPreKeyEntity {
152    type Error = SignalProtocolError;
153
154    fn try_from(key: &'_ KyberPreKeyRecord) -> Result<Self, Self::Error> {
155        Ok(KyberPreKeyEntity {
156            key_id: key.id()?.into(),
157            public_key: key.key_pair()?.public_key.serialize().to_vec(),
158            signature: key.signature()?,
159        })
160    }
161}
162
163impl TryFrom<KyberPreKeyRecord> for KyberPreKeyEntity {
164    type Error = SignalProtocolError;
165
166    fn try_from(key: KyberPreKeyRecord) -> Result<Self, Self::Error> {
167        KyberPreKeyEntity::try_from(&key)
168    }
169}
170
171#[derive(Debug, Serialize)]
172#[serde(rename_all = "camelCase")]
173pub struct PreKeyState {
174    pub pre_keys: Vec<PreKeyEntity>,
175    pub signed_pre_key: SignedPreKeyEntity,
176    #[serde(with = "serde_identity_key")]
177    pub identity_key: IdentityKey,
178    #[serde(skip_serializing_if = "Option::is_none")]
179    pub pq_last_resort_key: Option<KyberPreKeyEntity>,
180    pub pq_pre_keys: Vec<KyberPreKeyEntity>,
181}
182
/// Pre key count threshold.
///
/// NOTE(review): presumably the count below which pre keys get replenished —
/// confirm at the use sites.
pub(crate) const PRE_KEY_MINIMUM: u32 = 10;

/// Number of pre keys generated per replenishment batch.
///
/// NOTE(review): inferred from the name; confirm at the use sites.
pub(crate) const PRE_KEY_BATCH_SIZE: u32 = 100;

/// Largest unsigned 24-bit value (`0xFF_FFFF`).
///
/// `replenish_pre_keys` wraps generated IDs into `1..=PRE_KEY_MEDIUM_MAX_VALUE - 1`.
/// The name presumably mirrors Android's `Medium.MAX_VALUE`.
pub(crate) const PRE_KEY_MEDIUM_MAX_VALUE: u32 = 0x00FF_FFFF;
186
187pub(crate) async fn replenish_pre_keys<R: Rng + CryptoRng, P: PreKeysStore>(
188    protocol_store: &mut P,
189    csprng: &mut R,
190    identity_key_pair: &IdentityKeyPair,
191    use_last_resort_key: bool,
192    pre_key_count: u32,
193    kyber_pre_key_count: u32,
194) -> Result<
195    (
196        Vec<PreKeyRecord>,
197        SignedPreKeyRecord,
198        Vec<KyberPreKeyRecord>,
199        Option<KyberPreKeyRecord>,
200    ),
201    SignalProtocolError,
202> {
203    let pre_keys_offset_id = protocol_store.next_pre_key_id().await?;
204    let next_signed_pre_key_id =
205        protocol_store.next_signed_pre_key_id().await?;
206    let pq_pre_keys_offset_id = protocol_store.next_pq_pre_key_id().await?;
207
208    let span = tracing::span!(tracing::Level::DEBUG, "Generating pre keys");
209
210    let mut pre_keys = vec![];
211    let mut pq_pre_keys = vec![];
212
213    // EC keys
214    for i in 0..pre_key_count {
215        let key_pair = KeyPair::generate(csprng);
216        let pre_key_id =
217            (((pre_keys_offset_id + i) % (PRE_KEY_MEDIUM_MAX_VALUE - 1)) + 1)
218                .into();
219        let pre_key_record = PreKeyRecord::new(pre_key_id, &key_pair);
220        protocol_store
221                    .save_pre_key(pre_key_id, &pre_key_record)
222                    .instrument(tracing::trace_span!(parent: &span, "save pre key", ?pre_key_id)).await?;
223        // TODO: Shouldn't this also remove the previous pre-keys from storage?
224        //       I think we might want to update the storage, and then sync the storage to the
225        //       server.
226
227        pre_keys.push(pre_key_record);
228    }
229
230    // Kyber keys
231    for i in 0..kyber_pre_key_count {
232        let pre_key_id = (((pq_pre_keys_offset_id + i)
233            % (PRE_KEY_MEDIUM_MAX_VALUE - 1))
234            + 1)
235        .into();
236        let pre_key_record = KyberPreKeyRecord::generate(
237            kem::KeyType::Kyber1024,
238            pre_key_id,
239            identity_key_pair.private_key(),
240        )?;
241        protocol_store
242                    .save_kyber_pre_key(pre_key_id, &pre_key_record)
243                    .instrument(tracing::trace_span!(parent: &span, "save kyber pre key", ?pre_key_id)).await?;
244        // TODO: Shouldn't this also remove the previous pre-keys from storage?
245        //       I think we might want to update the storage, and then sync the storage to the
246        //       server.
247
248        pq_pre_keys.push(pre_key_record);
249    }
250
251    // Generate and store the next signed prekey
252    let signed_pre_key_pair = KeyPair::generate(csprng);
253    let signed_pre_key_public = signed_pre_key_pair.public_key;
254    let signed_pre_key_signature = identity_key_pair
255        .private_key()
256        .calculate_signature(&signed_pre_key_public.serialize(), csprng)?;
257
258    let signed_prekey_record = SignedPreKeyRecord::new(
259        next_signed_pre_key_id.into(),
260        Timestamp::now(),
261        &signed_pre_key_pair,
262        &signed_pre_key_signature,
263    );
264
265    protocol_store
266                .save_signed_pre_key(
267                    next_signed_pre_key_id.into(),
268                    &signed_prekey_record,
269                )
270                    .instrument(tracing::trace_span!(parent: &span, "save signed pre key", signed_pre_key_id = ?next_signed_pre_key_id)).await?;
271
272    let pq_last_resort_key = if use_last_resort_key {
273        let pre_key_id = (((pq_pre_keys_offset_id + kyber_pre_key_count)
274            % (PRE_KEY_MEDIUM_MAX_VALUE - 1))
275            + 1)
276        .into();
277
278        if !pq_pre_keys.is_empty() {
279            assert_eq!(
280                u32::from(pq_pre_keys.last().unwrap().id()?) + 1,
281                u32::from(pre_key_id)
282            );
283        }
284
285        let pre_key_record = KyberPreKeyRecord::generate(
286            kem::KeyType::Kyber1024,
287            pre_key_id,
288            identity_key_pair.private_key(),
289        )?;
290        protocol_store
291                    .store_last_resort_kyber_pre_key(pre_key_id, &pre_key_record)
292                    .instrument(tracing::trace_span!(parent: &span, "save last resort kyber pre key", ?pre_key_id)).await?;
293        // TODO: Shouldn't this also remove the previous pre-keys from storage?
294        //       I think we might want to update the storage, and then sync the storage to the
295        //       server.
296
297        Some(pre_key_record)
298    } else {
299        None
300    };
301
302    Ok((
303        pre_keys,
304        signed_prekey_record,
305        pq_pre_keys,
306        pq_last_resort_key,
307    ))
308}