libsignal_protocol/storage/
inmem.rs

//
// Copyright 2020-2022 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//

//! Implementations for stores defined in [super::traits].
//!
//! These implementations are purely in-memory, and therefore most likely useful for testing.

10use std::borrow::Cow;
11use std::collections::HashMap;
12
13use async_trait::async_trait;
14use uuid::Uuid;
15
16use crate::storage::traits;
17use crate::{
18    IdentityKey, IdentityKeyPair, KyberPreKeyId, KyberPreKeyRecord, PreKeyId, PreKeyRecord,
19    ProtocolAddress, Result, SenderKeyRecord, SessionRecord, SignalProtocolError, SignedPreKeyId,
20    SignedPreKeyRecord,
21};
22
23/// Reference implementation of [traits::IdentityKeyStore].
24#[derive(Clone)]
25pub struct InMemIdentityKeyStore {
26    key_pair: IdentityKeyPair,
27    registration_id: u32,
28    known_keys: HashMap<ProtocolAddress, IdentityKey>,
29}
30
31impl InMemIdentityKeyStore {
32    /// Create a new instance.
33    ///
34    /// `key_pair` corresponds to [traits::IdentityKeyStore::get_identity_key_pair], and
35    /// `registration_id` corresponds to [traits::IdentityKeyStore::get_local_registration_id].
36    pub fn new(key_pair: IdentityKeyPair, registration_id: u32) -> Self {
37        Self {
38            key_pair,
39            registration_id,
40            known_keys: HashMap::new(),
41        }
42    }
43
44    /// Clear the mapping of known keys.
45    pub fn reset(&mut self) {
46        self.known_keys.clear();
47    }
48}
49
50#[async_trait(?Send)]
51impl traits::IdentityKeyStore for InMemIdentityKeyStore {
52    async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair> {
53        Ok(self.key_pair)
54    }
55
56    async fn get_local_registration_id(&self) -> Result<u32> {
57        Ok(self.registration_id)
58    }
59
60    async fn save_identity(
61        &mut self,
62        address: &ProtocolAddress,
63        identity: &IdentityKey,
64    ) -> Result<bool> {
65        match self.known_keys.get(address) {
66            None => {
67                self.known_keys.insert(address.clone(), *identity);
68                Ok(false) // new key
69            }
70            Some(k) if k == identity => {
71                Ok(false) // same key
72            }
73            Some(_k) => {
74                self.known_keys.insert(address.clone(), *identity);
75                Ok(true) // overwrite
76            }
77        }
78    }
79
80    async fn is_trusted_identity(
81        &self,
82        address: &ProtocolAddress,
83        identity: &IdentityKey,
84        _direction: traits::Direction,
85    ) -> Result<bool> {
86        match self.known_keys.get(address) {
87            None => {
88                Ok(true) // first use
89            }
90            Some(k) => Ok(k == identity),
91        }
92    }
93
94    async fn get_identity(&self, address: &ProtocolAddress) -> Result<Option<IdentityKey>> {
95        match self.known_keys.get(address) {
96            None => Ok(None),
97            Some(k) => Ok(Some(k.to_owned())),
98        }
99    }
100}
101
102/// Reference implementation of [traits::PreKeyStore].
103#[derive(Clone)]
104pub struct InMemPreKeyStore {
105    pre_keys: HashMap<PreKeyId, PreKeyRecord>,
106}
107
108impl InMemPreKeyStore {
109    /// Create an empty pre-key store.
110    pub fn new() -> Self {
111        Self {
112            pre_keys: HashMap::new(),
113        }
114    }
115
116    /// Returns all registered pre-key ids
117    pub fn all_pre_key_ids(&self) -> impl Iterator<Item = &PreKeyId> {
118        self.pre_keys.keys()
119    }
120}
121
122impl Default for InMemPreKeyStore {
123    fn default() -> Self {
124        Self::new()
125    }
126}
127
128#[async_trait(?Send)]
129impl traits::PreKeyStore for InMemPreKeyStore {
130    async fn get_pre_key(&self, id: PreKeyId) -> Result<PreKeyRecord> {
131        Ok(self
132            .pre_keys
133            .get(&id)
134            .ok_or(SignalProtocolError::InvalidPreKeyId)?
135            .clone())
136    }
137
138    async fn save_pre_key(&mut self, id: PreKeyId, record: &PreKeyRecord) -> Result<()> {
139        // This overwrites old values, which matches Java behavior, but is it correct?
140        self.pre_keys.insert(id, record.to_owned());
141        Ok(())
142    }
143
144    async fn remove_pre_key(&mut self, id: PreKeyId) -> Result<()> {
145        // If id does not exist this silently does nothing
146        self.pre_keys.remove(&id);
147        Ok(())
148    }
149}
150
151/// Reference implementation of [traits::SignedPreKeyStore].
152#[derive(Clone)]
153pub struct InMemSignedPreKeyStore {
154    signed_pre_keys: HashMap<SignedPreKeyId, SignedPreKeyRecord>,
155}
156
157impl InMemSignedPreKeyStore {
158    /// Create an empty signed pre-key store.
159    pub fn new() -> Self {
160        Self {
161            signed_pre_keys: HashMap::new(),
162        }
163    }
164
165    /// Returns all registered signed pre-key ids
166    pub fn all_signed_pre_key_ids(&self) -> impl Iterator<Item = &SignedPreKeyId> {
167        self.signed_pre_keys.keys()
168    }
169}
170
171impl Default for InMemSignedPreKeyStore {
172    fn default() -> Self {
173        Self::new()
174    }
175}
176
177#[async_trait(?Send)]
178impl traits::SignedPreKeyStore for InMemSignedPreKeyStore {
179    async fn get_signed_pre_key(&self, id: SignedPreKeyId) -> Result<SignedPreKeyRecord> {
180        Ok(self
181            .signed_pre_keys
182            .get(&id)
183            .ok_or(SignalProtocolError::InvalidSignedPreKeyId)?
184            .clone())
185    }
186
187    async fn save_signed_pre_key(
188        &mut self,
189        id: SignedPreKeyId,
190        record: &SignedPreKeyRecord,
191    ) -> Result<()> {
192        // This overwrites old values, which matches Java behavior, but is it correct?
193        self.signed_pre_keys.insert(id, record.to_owned());
194        Ok(())
195    }
196}
197
198/// Reference implementation of [traits::KyberPreKeyStore].
199#[derive(Clone)]
200pub struct InMemKyberPreKeyStore {
201    kyber_pre_keys: HashMap<KyberPreKeyId, KyberPreKeyRecord>,
202}
203
204impl InMemKyberPreKeyStore {
205    /// Create an empty kyber pre-key store.
206    pub fn new() -> Self {
207        Self {
208            kyber_pre_keys: HashMap::new(),
209        }
210    }
211
212    /// Returns all registered Kyber pre-key ids
213    pub fn all_kyber_pre_key_ids(&self) -> impl Iterator<Item = &KyberPreKeyId> {
214        self.kyber_pre_keys.keys()
215    }
216}
217
218impl Default for InMemKyberPreKeyStore {
219    fn default() -> Self {
220        Self::new()
221    }
222}
223
224#[async_trait(?Send)]
225impl traits::KyberPreKeyStore for InMemKyberPreKeyStore {
226    async fn get_kyber_pre_key(&self, kyber_prekey_id: KyberPreKeyId) -> Result<KyberPreKeyRecord> {
227        Ok(self
228            .kyber_pre_keys
229            .get(&kyber_prekey_id)
230            .ok_or(SignalProtocolError::InvalidKyberPreKeyId)?
231            .clone())
232    }
233
234    async fn save_kyber_pre_key(
235        &mut self,
236        kyber_prekey_id: KyberPreKeyId,
237        record: &KyberPreKeyRecord,
238    ) -> Result<()> {
239        self.kyber_pre_keys
240            .insert(kyber_prekey_id, record.to_owned());
241        Ok(())
242    }
243
244    async fn mark_kyber_pre_key_used(&mut self, _kyber_prekey_id: KyberPreKeyId) -> Result<()> {
245        Ok(())
246    }
247}
248
249/// Reference implementation of [traits::SessionStore].
250#[derive(Clone)]
251pub struct InMemSessionStore {
252    sessions: HashMap<ProtocolAddress, SessionRecord>,
253}
254
255impl InMemSessionStore {
256    /// Create an empty session store.
257    pub fn new() -> Self {
258        Self {
259            sessions: HashMap::new(),
260        }
261    }
262
263    /// Bulk version of [`SessionStore::load_session`].
264    ///
265    /// Useful for [crate::sealed_sender_multi_recipient_encrypt].
266    ///
267    /// [`SessionStore::load_session`]: crate::SessionStore::load_session
268    pub fn load_existing_sessions(
269        &self,
270        addresses: &[&ProtocolAddress],
271    ) -> Result<Vec<&SessionRecord>> {
272        addresses
273            .iter()
274            .map(|&address| {
275                self.sessions
276                    .get(address)
277                    .ok_or_else(|| SignalProtocolError::SessionNotFound(address.clone()))
278            })
279            .collect()
280    }
281}
282
283impl Default for InMemSessionStore {
284    fn default() -> Self {
285        Self::new()
286    }
287}
288
289#[async_trait(?Send)]
290impl traits::SessionStore for InMemSessionStore {
291    async fn load_session(&self, address: &ProtocolAddress) -> Result<Option<SessionRecord>> {
292        match self.sessions.get(address) {
293            None => Ok(None),
294            Some(s) => Ok(Some(s.clone())),
295        }
296    }
297
298    async fn store_session(
299        &mut self,
300        address: &ProtocolAddress,
301        record: &SessionRecord,
302    ) -> Result<()> {
303        self.sessions.insert(address.clone(), record.clone());
304        Ok(())
305    }
306}
307
308/// Reference implementation of [traits::SenderKeyStore].
309#[derive(Clone)]
310pub struct InMemSenderKeyStore {
311    // We use Cow keys in order to store owned values but compare to referenced ones.
312    // See https://users.rust-lang.org/t/hashmap-with-tuple-keys/12711/6.
313    keys: HashMap<(Cow<'static, ProtocolAddress>, Uuid), SenderKeyRecord>,
314}
315
316impl InMemSenderKeyStore {
317    /// Create an empty sender key store.
318    pub fn new() -> Self {
319        Self {
320            keys: HashMap::new(),
321        }
322    }
323}
324
325impl Default for InMemSenderKeyStore {
326    fn default() -> Self {
327        Self::new()
328    }
329}
330
331#[async_trait(?Send)]
332impl traits::SenderKeyStore for InMemSenderKeyStore {
333    async fn store_sender_key(
334        &mut self,
335        sender: &ProtocolAddress,
336        distribution_id: Uuid,
337        record: &SenderKeyRecord,
338    ) -> Result<()> {
339        self.keys.insert(
340            (Cow::Owned(sender.clone()), distribution_id),
341            record.clone(),
342        );
343        Ok(())
344    }
345
346    async fn load_sender_key(
347        &mut self,
348        sender: &ProtocolAddress,
349        distribution_id: Uuid,
350    ) -> Result<Option<SenderKeyRecord>> {
351        Ok(self
352            .keys
353            .get(&(Cow::Borrowed(sender), distribution_id))
354            .cloned())
355    }
356}
357
358/// Reference implementation of [traits::ProtocolStore].
359#[allow(missing_docs)]
360#[derive(Clone)]
361pub struct InMemSignalProtocolStore {
362    pub session_store: InMemSessionStore,
363    pub pre_key_store: InMemPreKeyStore,
364    pub signed_pre_key_store: InMemSignedPreKeyStore,
365    pub kyber_pre_key_store: InMemKyberPreKeyStore,
366    pub identity_store: InMemIdentityKeyStore,
367    pub sender_key_store: InMemSenderKeyStore,
368}
369
370impl InMemSignalProtocolStore {
371    /// Create an object with the minimal implementation of [traits::ProtocolStore], representing
372    /// the given identity `key_pair` along with the separate randomly chosen `registration_id`.
373    pub fn new(key_pair: IdentityKeyPair, registration_id: u32) -> Result<Self> {
374        Ok(Self {
375            session_store: InMemSessionStore::new(),
376            pre_key_store: InMemPreKeyStore::new(),
377            signed_pre_key_store: InMemSignedPreKeyStore::new(),
378            kyber_pre_key_store: InMemKyberPreKeyStore::new(),
379            identity_store: InMemIdentityKeyStore::new(key_pair, registration_id),
380            sender_key_store: InMemSenderKeyStore::new(),
381        })
382    }
383
384    /// Returns all registered pre-key ids
385    pub fn all_pre_key_ids(&self) -> impl Iterator<Item = &PreKeyId> {
386        self.pre_key_store.all_pre_key_ids()
387    }
388
389    /// Returns all registered signed pre-key ids
390    pub fn all_signed_pre_key_ids(&self) -> impl Iterator<Item = &SignedPreKeyId> {
391        self.signed_pre_key_store.all_signed_pre_key_ids()
392    }
393
394    /// Returns all registered Kyber pre-key ids
395    pub fn all_kyber_pre_key_ids(&self) -> impl Iterator<Item = &KyberPreKeyId> {
396        self.kyber_pre_key_store.all_kyber_pre_key_ids()
397    }
398}
399
400#[async_trait(?Send)]
401impl traits::IdentityKeyStore for InMemSignalProtocolStore {
402    async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair> {
403        self.identity_store.get_identity_key_pair().await
404    }
405
406    async fn get_local_registration_id(&self) -> Result<u32> {
407        self.identity_store.get_local_registration_id().await
408    }
409
410    async fn save_identity(
411        &mut self,
412        address: &ProtocolAddress,
413        identity: &IdentityKey,
414    ) -> Result<bool> {
415        self.identity_store.save_identity(address, identity).await
416    }
417
418    async fn is_trusted_identity(
419        &self,
420        address: &ProtocolAddress,
421        identity: &IdentityKey,
422        direction: traits::Direction,
423    ) -> Result<bool> {
424        self.identity_store
425            .is_trusted_identity(address, identity, direction)
426            .await
427    }
428
429    async fn get_identity(&self, address: &ProtocolAddress) -> Result<Option<IdentityKey>> {
430        self.identity_store.get_identity(address).await
431    }
432}
433
434#[async_trait(?Send)]
435impl traits::PreKeyStore for InMemSignalProtocolStore {
436    async fn get_pre_key(&self, id: PreKeyId) -> Result<PreKeyRecord> {
437        self.pre_key_store.get_pre_key(id).await
438    }
439
440    async fn save_pre_key(&mut self, id: PreKeyId, record: &PreKeyRecord) -> Result<()> {
441        self.pre_key_store.save_pre_key(id, record).await
442    }
443
444    async fn remove_pre_key(&mut self, id: PreKeyId) -> Result<()> {
445        self.pre_key_store.remove_pre_key(id).await
446    }
447}
448
449#[async_trait(?Send)]
450impl traits::SignedPreKeyStore for InMemSignalProtocolStore {
451    async fn get_signed_pre_key(&self, id: SignedPreKeyId) -> Result<SignedPreKeyRecord> {
452        self.signed_pre_key_store.get_signed_pre_key(id).await
453    }
454
455    async fn save_signed_pre_key(
456        &mut self,
457        id: SignedPreKeyId,
458        record: &SignedPreKeyRecord,
459    ) -> Result<()> {
460        self.signed_pre_key_store
461            .save_signed_pre_key(id, record)
462            .await
463    }
464}
465
466#[async_trait(?Send)]
467impl traits::KyberPreKeyStore for InMemSignalProtocolStore {
468    async fn get_kyber_pre_key(&self, kyber_prekey_id: KyberPreKeyId) -> Result<KyberPreKeyRecord> {
469        self.kyber_pre_key_store
470            .get_kyber_pre_key(kyber_prekey_id)
471            .await
472    }
473
474    async fn save_kyber_pre_key(
475        &mut self,
476        kyber_prekey_id: KyberPreKeyId,
477        record: &KyberPreKeyRecord,
478    ) -> Result<()> {
479        self.kyber_pre_key_store
480            .save_kyber_pre_key(kyber_prekey_id, record)
481            .await
482    }
483
484    async fn mark_kyber_pre_key_used(&mut self, kyber_prekey_id: KyberPreKeyId) -> Result<()> {
485        self.kyber_pre_key_store
486            .mark_kyber_pre_key_used(kyber_prekey_id)
487            .await
488    }
489}
490
491#[async_trait(?Send)]
492impl traits::SessionStore for InMemSignalProtocolStore {
493    async fn load_session(&self, address: &ProtocolAddress) -> Result<Option<SessionRecord>> {
494        self.session_store.load_session(address).await
495    }
496
497    async fn store_session(
498        &mut self,
499        address: &ProtocolAddress,
500        record: &SessionRecord,
501    ) -> Result<()> {
502        self.session_store.store_session(address, record).await
503    }
504}
505
506#[async_trait(?Send)]
507impl traits::SenderKeyStore for InMemSignalProtocolStore {
508    async fn store_sender_key(
509        &mut self,
510        sender: &ProtocolAddress,
511        distribution_id: Uuid,
512        record: &SenderKeyRecord,
513    ) -> Result<()> {
514        self.sender_key_store
515            .store_sender_key(sender, distribution_id, record)
516            .await
517    }
518
519    async fn load_sender_key(
520        &mut self,
521        sender: &ProtocolAddress,
522        distribution_id: Uuid,
523    ) -> Result<Option<SenderKeyRecord>> {
524        self.sender_key_store
525            .load_sender_key(sender, distribution_id)
526            .await
527    }
528}
529
530impl traits::ProtocolStore for InMemSignalProtocolStore {}