libsignal_protocol/storage/
inmem.rs

1//
2// Copyright 2020-2022 Signal Messenger, LLC.
3// SPDX-License-Identifier: AGPL-3.0-only
4//
5
6//! Implementations for stores defined in [super::traits].
7//!
8//! These implementations are purely in-memory, and therefore most likely useful for testing.
9
10use std::borrow::Cow;
11use std::collections::HashMap;
12
13use async_trait::async_trait;
14use uuid::Uuid;
15
16use crate::storage::traits::{self, IdentityChange};
17use crate::{
18    CiphertextMessageType, IdentityKey, IdentityKeyPair, KyberPreKeyId, KyberPreKeyRecord,
19    PreKeyId, PreKeyRecord, ProtocolAddress, PublicKey, Result, SenderKeyRecord, SessionRecord,
20    SignalProtocolError, SignedPreKeyId, SignedPreKeyRecord,
21};
22
23/// Reference implementation of [traits::IdentityKeyStore].
24#[derive(Clone)]
25pub struct InMemIdentityKeyStore {
26    key_pair: IdentityKeyPair,
27    registration_id: u32,
28    known_keys: HashMap<ProtocolAddress, IdentityKey>,
29}
30
31impl InMemIdentityKeyStore {
32    /// Create a new instance.
33    ///
34    /// `key_pair` corresponds to [traits::IdentityKeyStore::get_identity_key_pair], and
35    /// `registration_id` corresponds to [traits::IdentityKeyStore::get_local_registration_id].
36    pub fn new(key_pair: IdentityKeyPair, registration_id: u32) -> Self {
37        Self {
38            key_pair,
39            registration_id,
40            known_keys: HashMap::new(),
41        }
42    }
43
44    /// Clear the mapping of known keys.
45    pub fn reset(&mut self) {
46        self.known_keys.clear();
47    }
48}
49
50#[async_trait(?Send)]
51impl traits::IdentityKeyStore for InMemIdentityKeyStore {
52    async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair> {
53        Ok(self.key_pair)
54    }
55
56    async fn get_local_registration_id(&self) -> Result<u32> {
57        Ok(self.registration_id)
58    }
59
60    async fn save_identity(
61        &mut self,
62        address: &ProtocolAddress,
63        identity: &IdentityKey,
64    ) -> Result<IdentityChange> {
65        match self.known_keys.get(address) {
66            None => {
67                self.known_keys.insert(address.clone(), *identity);
68                Ok(IdentityChange::NewOrUnchanged)
69            }
70            Some(k) if k == identity => Ok(IdentityChange::NewOrUnchanged),
71            Some(_k) => {
72                self.known_keys.insert(address.clone(), *identity);
73                Ok(IdentityChange::ReplacedExisting)
74            }
75        }
76    }
77
78    async fn is_trusted_identity(
79        &self,
80        address: &ProtocolAddress,
81        identity: &IdentityKey,
82        _direction: traits::Direction,
83    ) -> Result<bool> {
84        match self.known_keys.get(address) {
85            None => {
86                Ok(true) // first use
87            }
88            Some(k) => Ok(k == identity),
89        }
90    }
91
92    async fn get_identity(&self, address: &ProtocolAddress) -> Result<Option<IdentityKey>> {
93        match self.known_keys.get(address) {
94            None => Ok(None),
95            Some(k) => Ok(Some(k.to_owned())),
96        }
97    }
98}
99
100/// Reference implementation of [traits::PreKeyStore].
101#[derive(Clone)]
102pub struct InMemPreKeyStore {
103    pre_keys: HashMap<PreKeyId, PreKeyRecord>,
104}
105
106impl InMemPreKeyStore {
107    /// Create an empty pre-key store.
108    pub fn new() -> Self {
109        Self {
110            pre_keys: HashMap::new(),
111        }
112    }
113
114    /// Returns all registered pre-key ids
115    pub fn all_pre_key_ids(&self) -> impl Iterator<Item = &PreKeyId> {
116        self.pre_keys.keys()
117    }
118}
119
120impl Default for InMemPreKeyStore {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126#[async_trait(?Send)]
127impl traits::PreKeyStore for InMemPreKeyStore {
128    async fn get_pre_key(&self, id: PreKeyId) -> Result<PreKeyRecord> {
129        Ok(self
130            .pre_keys
131            .get(&id)
132            .ok_or(SignalProtocolError::InvalidPreKeyId)?
133            .clone())
134    }
135
136    async fn save_pre_key(&mut self, id: PreKeyId, record: &PreKeyRecord) -> Result<()> {
137        // This overwrites old values, which matches Java behavior, but is it correct?
138        self.pre_keys.insert(id, record.to_owned());
139        Ok(())
140    }
141
142    async fn remove_pre_key(&mut self, id: PreKeyId) -> Result<()> {
143        // If id does not exist this silently does nothing
144        self.pre_keys.remove(&id);
145        Ok(())
146    }
147}
148
149/// Reference implementation of [traits::SignedPreKeyStore].
150#[derive(Clone)]
151pub struct InMemSignedPreKeyStore {
152    signed_pre_keys: HashMap<SignedPreKeyId, SignedPreKeyRecord>,
153}
154
155impl InMemSignedPreKeyStore {
156    /// Create an empty signed pre-key store.
157    pub fn new() -> Self {
158        Self {
159            signed_pre_keys: HashMap::new(),
160        }
161    }
162
163    /// Returns all registered signed pre-key ids
164    pub fn all_signed_pre_key_ids(&self) -> impl Iterator<Item = &SignedPreKeyId> {
165        self.signed_pre_keys.keys()
166    }
167}
168
169impl Default for InMemSignedPreKeyStore {
170    fn default() -> Self {
171        Self::new()
172    }
173}
174
175#[async_trait(?Send)]
176impl traits::SignedPreKeyStore for InMemSignedPreKeyStore {
177    async fn get_signed_pre_key(&self, id: SignedPreKeyId) -> Result<SignedPreKeyRecord> {
178        Ok(self
179            .signed_pre_keys
180            .get(&id)
181            .ok_or(SignalProtocolError::InvalidSignedPreKeyId)?
182            .clone())
183    }
184
185    async fn save_signed_pre_key(
186        &mut self,
187        id: SignedPreKeyId,
188        record: &SignedPreKeyRecord,
189    ) -> Result<()> {
190        // This overwrites old values, which matches Java behavior, but is it correct?
191        self.signed_pre_keys.insert(id, record.to_owned());
192        Ok(())
193    }
194}
195
196/// Basic implementation of [traits::KyberPreKeyStore].
197///
198/// Note that this implementation does not clear any keys upon use! This is correct for last-resort
199/// keys, but a real client would normally have a set of one-time keys to use first.
200#[derive(Clone)]
201pub struct InMemKyberPreKeyStore {
202    kyber_pre_keys: HashMap<KyberPreKeyId, KyberPreKeyRecord>,
203    base_keys_seen: HashMap<(KyberPreKeyId, SignedPreKeyId), Vec<PublicKey>>,
204}
205
206impl InMemKyberPreKeyStore {
207    /// Create an empty kyber pre-key store.
208    pub fn new() -> Self {
209        Self {
210            kyber_pre_keys: HashMap::new(),
211            base_keys_seen: HashMap::new(),
212        }
213    }
214
215    /// Returns all registered Kyber pre-key ids
216    pub fn all_kyber_pre_key_ids(&self) -> impl Iterator<Item = &KyberPreKeyId> {
217        self.kyber_pre_keys.keys()
218    }
219}
220
221impl Default for InMemKyberPreKeyStore {
222    fn default() -> Self {
223        Self::new()
224    }
225}
226
227#[async_trait(?Send)]
228impl traits::KyberPreKeyStore for InMemKyberPreKeyStore {
229    async fn get_kyber_pre_key(&self, kyber_prekey_id: KyberPreKeyId) -> Result<KyberPreKeyRecord> {
230        Ok(self
231            .kyber_pre_keys
232            .get(&kyber_prekey_id)
233            .ok_or(SignalProtocolError::InvalidKyberPreKeyId)?
234            .clone())
235    }
236
237    async fn save_kyber_pre_key(
238        &mut self,
239        kyber_prekey_id: KyberPreKeyId,
240        record: &KyberPreKeyRecord,
241    ) -> Result<()> {
242        self.kyber_pre_keys
243            .insert(kyber_prekey_id, record.to_owned());
244        Ok(())
245    }
246
247    async fn mark_kyber_pre_key_used(
248        &mut self,
249        kyber_prekey_id: KyberPreKeyId,
250        ec_prekey_id: SignedPreKeyId,
251        base_key: &PublicKey,
252    ) -> Result<()> {
253        let base_keys_seen = self
254            .base_keys_seen
255            .entry((kyber_prekey_id, ec_prekey_id))
256            .or_default();
257        if base_keys_seen.contains(base_key) {
258            return Err(SignalProtocolError::InvalidMessage(
259                CiphertextMessageType::PreKey,
260                "reused base key",
261            ));
262        }
263        base_keys_seen.push(*base_key);
264        Ok(())
265    }
266}
267
268/// Reference implementation of [traits::SessionStore].
269#[derive(Clone)]
270pub struct InMemSessionStore {
271    sessions: HashMap<ProtocolAddress, SessionRecord>,
272}
273
274impl InMemSessionStore {
275    /// Create an empty session store.
276    pub fn new() -> Self {
277        Self {
278            sessions: HashMap::new(),
279        }
280    }
281
282    /// Bulk version of [`SessionStore::load_session`].
283    ///
284    /// Useful for [crate::sealed_sender_multi_recipient_encrypt].
285    ///
286    /// [`SessionStore::load_session`]: crate::SessionStore::load_session
287    pub fn load_existing_sessions(
288        &self,
289        addresses: &[&ProtocolAddress],
290    ) -> Result<Vec<&SessionRecord>> {
291        addresses
292            .iter()
293            .map(|&address| {
294                self.sessions
295                    .get(address)
296                    .ok_or_else(|| SignalProtocolError::SessionNotFound(address.clone()))
297            })
298            .collect()
299    }
300}
301
302impl Default for InMemSessionStore {
303    fn default() -> Self {
304        Self::new()
305    }
306}
307
308#[async_trait(?Send)]
309impl traits::SessionStore for InMemSessionStore {
310    async fn load_session(&self, address: &ProtocolAddress) -> Result<Option<SessionRecord>> {
311        match self.sessions.get(address) {
312            None => Ok(None),
313            Some(s) => Ok(Some(s.clone())),
314        }
315    }
316
317    async fn store_session(
318        &mut self,
319        address: &ProtocolAddress,
320        record: &SessionRecord,
321    ) -> Result<()> {
322        self.sessions.insert(address.clone(), record.clone());
323        Ok(())
324    }
325}
326
327/// Reference implementation of [traits::SenderKeyStore].
328#[derive(Clone)]
329pub struct InMemSenderKeyStore {
330    // We use Cow keys in order to store owned values but compare to referenced ones.
331    // See https://users.rust-lang.org/t/hashmap-with-tuple-keys/12711/6.
332    keys: HashMap<(Cow<'static, ProtocolAddress>, Uuid), SenderKeyRecord>,
333}
334
335impl InMemSenderKeyStore {
336    /// Create an empty sender key store.
337    pub fn new() -> Self {
338        Self {
339            keys: HashMap::new(),
340        }
341    }
342}
343
344impl Default for InMemSenderKeyStore {
345    fn default() -> Self {
346        Self::new()
347    }
348}
349
350#[async_trait(?Send)]
351impl traits::SenderKeyStore for InMemSenderKeyStore {
352    async fn store_sender_key(
353        &mut self,
354        sender: &ProtocolAddress,
355        distribution_id: Uuid,
356        record: &SenderKeyRecord,
357    ) -> Result<()> {
358        self.keys.insert(
359            (Cow::Owned(sender.clone()), distribution_id),
360            record.clone(),
361        );
362        Ok(())
363    }
364
365    async fn load_sender_key(
366        &mut self,
367        sender: &ProtocolAddress,
368        distribution_id: Uuid,
369    ) -> Result<Option<SenderKeyRecord>> {
370        Ok(self
371            .keys
372            .get(&(Cow::Borrowed(sender), distribution_id))
373            .cloned())
374    }
375}
376
377/// Reference implementation of [traits::ProtocolStore].
378#[allow(missing_docs)]
379#[derive(Clone)]
380pub struct InMemSignalProtocolStore {
381    pub session_store: InMemSessionStore,
382    pub pre_key_store: InMemPreKeyStore,
383    pub signed_pre_key_store: InMemSignedPreKeyStore,
384    pub kyber_pre_key_store: InMemKyberPreKeyStore,
385    pub identity_store: InMemIdentityKeyStore,
386    pub sender_key_store: InMemSenderKeyStore,
387}
388
389impl InMemSignalProtocolStore {
390    /// Create an object with the minimal implementation of [traits::ProtocolStore], representing
391    /// the given identity `key_pair` along with the separate randomly chosen `registration_id`.
392    pub fn new(key_pair: IdentityKeyPair, registration_id: u32) -> Result<Self> {
393        Ok(Self {
394            session_store: InMemSessionStore::new(),
395            pre_key_store: InMemPreKeyStore::new(),
396            signed_pre_key_store: InMemSignedPreKeyStore::new(),
397            kyber_pre_key_store: InMemKyberPreKeyStore::new(),
398            identity_store: InMemIdentityKeyStore::new(key_pair, registration_id),
399            sender_key_store: InMemSenderKeyStore::new(),
400        })
401    }
402
403    /// Returns all registered pre-key ids
404    pub fn all_pre_key_ids(&self) -> impl Iterator<Item = &PreKeyId> {
405        self.pre_key_store.all_pre_key_ids()
406    }
407
408    /// Returns all registered signed pre-key ids
409    pub fn all_signed_pre_key_ids(&self) -> impl Iterator<Item = &SignedPreKeyId> {
410        self.signed_pre_key_store.all_signed_pre_key_ids()
411    }
412
413    /// Returns all registered Kyber pre-key ids
414    pub fn all_kyber_pre_key_ids(&self) -> impl Iterator<Item = &KyberPreKeyId> {
415        self.kyber_pre_key_store.all_kyber_pre_key_ids()
416    }
417}
418
419#[async_trait(?Send)]
420impl traits::IdentityKeyStore for InMemSignalProtocolStore {
421    async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair> {
422        self.identity_store.get_identity_key_pair().await
423    }
424
425    async fn get_local_registration_id(&self) -> Result<u32> {
426        self.identity_store.get_local_registration_id().await
427    }
428
429    async fn save_identity(
430        &mut self,
431        address: &ProtocolAddress,
432        identity: &IdentityKey,
433    ) -> Result<IdentityChange> {
434        self.identity_store.save_identity(address, identity).await
435    }
436
437    async fn is_trusted_identity(
438        &self,
439        address: &ProtocolAddress,
440        identity: &IdentityKey,
441        direction: traits::Direction,
442    ) -> Result<bool> {
443        self.identity_store
444            .is_trusted_identity(address, identity, direction)
445            .await
446    }
447
448    async fn get_identity(&self, address: &ProtocolAddress) -> Result<Option<IdentityKey>> {
449        self.identity_store.get_identity(address).await
450    }
451}
452
453#[async_trait(?Send)]
454impl traits::PreKeyStore for InMemSignalProtocolStore {
455    async fn get_pre_key(&self, id: PreKeyId) -> Result<PreKeyRecord> {
456        self.pre_key_store.get_pre_key(id).await
457    }
458
459    async fn save_pre_key(&mut self, id: PreKeyId, record: &PreKeyRecord) -> Result<()> {
460        self.pre_key_store.save_pre_key(id, record).await
461    }
462
463    async fn remove_pre_key(&mut self, id: PreKeyId) -> Result<()> {
464        self.pre_key_store.remove_pre_key(id).await
465    }
466}
467
468#[async_trait(?Send)]
469impl traits::SignedPreKeyStore for InMemSignalProtocolStore {
470    async fn get_signed_pre_key(&self, id: SignedPreKeyId) -> Result<SignedPreKeyRecord> {
471        self.signed_pre_key_store.get_signed_pre_key(id).await
472    }
473
474    async fn save_signed_pre_key(
475        &mut self,
476        id: SignedPreKeyId,
477        record: &SignedPreKeyRecord,
478    ) -> Result<()> {
479        self.signed_pre_key_store
480            .save_signed_pre_key(id, record)
481            .await
482    }
483}
484
485#[async_trait(?Send)]
486impl traits::KyberPreKeyStore for InMemSignalProtocolStore {
487    async fn get_kyber_pre_key(&self, kyber_prekey_id: KyberPreKeyId) -> Result<KyberPreKeyRecord> {
488        self.kyber_pre_key_store
489            .get_kyber_pre_key(kyber_prekey_id)
490            .await
491    }
492
493    async fn save_kyber_pre_key(
494        &mut self,
495        kyber_prekey_id: KyberPreKeyId,
496        record: &KyberPreKeyRecord,
497    ) -> Result<()> {
498        self.kyber_pre_key_store
499            .save_kyber_pre_key(kyber_prekey_id, record)
500            .await
501    }
502
503    async fn mark_kyber_pre_key_used(
504        &mut self,
505        kyber_prekey_id: KyberPreKeyId,
506        ec_prekey_id: SignedPreKeyId,
507        base_key: &PublicKey,
508    ) -> Result<()> {
509        self.kyber_pre_key_store
510            .mark_kyber_pre_key_used(kyber_prekey_id, ec_prekey_id, base_key)
511            .await
512    }
513}
514
515#[async_trait(?Send)]
516impl traits::SessionStore for InMemSignalProtocolStore {
517    async fn load_session(&self, address: &ProtocolAddress) -> Result<Option<SessionRecord>> {
518        self.session_store.load_session(address).await
519    }
520
521    async fn store_session(
522        &mut self,
523        address: &ProtocolAddress,
524        record: &SessionRecord,
525    ) -> Result<()> {
526        self.session_store.store_session(address, record).await
527    }
528}
529
530#[async_trait(?Send)]
531impl traits::SenderKeyStore for InMemSignalProtocolStore {
532    async fn store_sender_key(
533        &mut self,
534        sender: &ProtocolAddress,
535        distribution_id: Uuid,
536        record: &SenderKeyRecord,
537    ) -> Result<()> {
538        self.sender_key_store
539            .store_sender_key(sender, distribution_id, record)
540            .await
541    }
542
543    async fn load_sender_key(
544        &mut self,
545        sender: &ProtocolAddress,
546        distribution_id: Uuid,
547    ) -> Result<Option<SenderKeyRecord>> {
548        self.sender_key_store
549            .load_sender_key(sender, distribution_id)
550            .await
551    }
552}
553
554impl traits::ProtocolStore for InMemSignalProtocolStore {}