libsignal_protocol/storage/
inmem.rs

1//
2// Copyright 2020-2022 Signal Messenger, LLC.
3// SPDX-License-Identifier: AGPL-3.0-only
4//
5
6//! Implementations for stores defined in [super::traits].
7//!
8//! These implementations are purely in-memory, and therefore most likely useful for testing.
9
10use std::borrow::Cow;
11use std::collections::HashMap;
12
13use async_trait::async_trait;
14use uuid::Uuid;
15
16use crate::storage::traits::{self, IdentityChange};
17use crate::{
18    IdentityKey, IdentityKeyPair, KyberPreKeyId, KyberPreKeyRecord, PreKeyId, PreKeyRecord,
19    ProtocolAddress, Result, SenderKeyRecord, SessionRecord, SignalProtocolError, SignedPreKeyId,
20    SignedPreKeyRecord,
21};
22
23/// Reference implementation of [traits::IdentityKeyStore].
24#[derive(Clone)]
25pub struct InMemIdentityKeyStore {
26    key_pair: IdentityKeyPair,
27    registration_id: u32,
28    known_keys: HashMap<ProtocolAddress, IdentityKey>,
29}
30
31impl InMemIdentityKeyStore {
32    /// Create a new instance.
33    ///
34    /// `key_pair` corresponds to [traits::IdentityKeyStore::get_identity_key_pair], and
35    /// `registration_id` corresponds to [traits::IdentityKeyStore::get_local_registration_id].
36    pub fn new(key_pair: IdentityKeyPair, registration_id: u32) -> Self {
37        Self {
38            key_pair,
39            registration_id,
40            known_keys: HashMap::new(),
41        }
42    }
43
44    /// Clear the mapping of known keys.
45    pub fn reset(&mut self) {
46        self.known_keys.clear();
47    }
48}
49
50#[async_trait(?Send)]
51impl traits::IdentityKeyStore for InMemIdentityKeyStore {
52    async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair> {
53        Ok(self.key_pair)
54    }
55
56    async fn get_local_registration_id(&self) -> Result<u32> {
57        Ok(self.registration_id)
58    }
59
60    async fn save_identity(
61        &mut self,
62        address: &ProtocolAddress,
63        identity: &IdentityKey,
64    ) -> Result<IdentityChange> {
65        match self.known_keys.get(address) {
66            None => {
67                self.known_keys.insert(address.clone(), *identity);
68                Ok(IdentityChange::NewOrUnchanged)
69            }
70            Some(k) if k == identity => Ok(IdentityChange::NewOrUnchanged),
71            Some(_k) => {
72                self.known_keys.insert(address.clone(), *identity);
73                Ok(IdentityChange::ReplacedExisting)
74            }
75        }
76    }
77
78    async fn is_trusted_identity(
79        &self,
80        address: &ProtocolAddress,
81        identity: &IdentityKey,
82        _direction: traits::Direction,
83    ) -> Result<bool> {
84        match self.known_keys.get(address) {
85            None => {
86                Ok(true) // first use
87            }
88            Some(k) => Ok(k == identity),
89        }
90    }
91
92    async fn get_identity(&self, address: &ProtocolAddress) -> Result<Option<IdentityKey>> {
93        match self.known_keys.get(address) {
94            None => Ok(None),
95            Some(k) => Ok(Some(k.to_owned())),
96        }
97    }
98}
99
100/// Reference implementation of [traits::PreKeyStore].
101#[derive(Clone)]
102pub struct InMemPreKeyStore {
103    pre_keys: HashMap<PreKeyId, PreKeyRecord>,
104}
105
106impl InMemPreKeyStore {
107    /// Create an empty pre-key store.
108    pub fn new() -> Self {
109        Self {
110            pre_keys: HashMap::new(),
111        }
112    }
113
114    /// Returns all registered pre-key ids
115    pub fn all_pre_key_ids(&self) -> impl Iterator<Item = &PreKeyId> {
116        self.pre_keys.keys()
117    }
118}
119
120impl Default for InMemPreKeyStore {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126#[async_trait(?Send)]
127impl traits::PreKeyStore for InMemPreKeyStore {
128    async fn get_pre_key(&self, id: PreKeyId) -> Result<PreKeyRecord> {
129        Ok(self
130            .pre_keys
131            .get(&id)
132            .ok_or(SignalProtocolError::InvalidPreKeyId)?
133            .clone())
134    }
135
136    async fn save_pre_key(&mut self, id: PreKeyId, record: &PreKeyRecord) -> Result<()> {
137        // This overwrites old values, which matches Java behavior, but is it correct?
138        self.pre_keys.insert(id, record.to_owned());
139        Ok(())
140    }
141
142    async fn remove_pre_key(&mut self, id: PreKeyId) -> Result<()> {
143        // If id does not exist this silently does nothing
144        self.pre_keys.remove(&id);
145        Ok(())
146    }
147}
148
149/// Reference implementation of [traits::SignedPreKeyStore].
150#[derive(Clone)]
151pub struct InMemSignedPreKeyStore {
152    signed_pre_keys: HashMap<SignedPreKeyId, SignedPreKeyRecord>,
153}
154
155impl InMemSignedPreKeyStore {
156    /// Create an empty signed pre-key store.
157    pub fn new() -> Self {
158        Self {
159            signed_pre_keys: HashMap::new(),
160        }
161    }
162
163    /// Returns all registered signed pre-key ids
164    pub fn all_signed_pre_key_ids(&self) -> impl Iterator<Item = &SignedPreKeyId> {
165        self.signed_pre_keys.keys()
166    }
167}
168
169impl Default for InMemSignedPreKeyStore {
170    fn default() -> Self {
171        Self::new()
172    }
173}
174
175#[async_trait(?Send)]
176impl traits::SignedPreKeyStore for InMemSignedPreKeyStore {
177    async fn get_signed_pre_key(&self, id: SignedPreKeyId) -> Result<SignedPreKeyRecord> {
178        Ok(self
179            .signed_pre_keys
180            .get(&id)
181            .ok_or(SignalProtocolError::InvalidSignedPreKeyId)?
182            .clone())
183    }
184
185    async fn save_signed_pre_key(
186        &mut self,
187        id: SignedPreKeyId,
188        record: &SignedPreKeyRecord,
189    ) -> Result<()> {
190        // This overwrites old values, which matches Java behavior, but is it correct?
191        self.signed_pre_keys.insert(id, record.to_owned());
192        Ok(())
193    }
194}
195
196/// Reference implementation of [traits::KyberPreKeyStore].
197#[derive(Clone)]
198pub struct InMemKyberPreKeyStore {
199    kyber_pre_keys: HashMap<KyberPreKeyId, KyberPreKeyRecord>,
200}
201
202impl InMemKyberPreKeyStore {
203    /// Create an empty kyber pre-key store.
204    pub fn new() -> Self {
205        Self {
206            kyber_pre_keys: HashMap::new(),
207        }
208    }
209
210    /// Returns all registered Kyber pre-key ids
211    pub fn all_kyber_pre_key_ids(&self) -> impl Iterator<Item = &KyberPreKeyId> {
212        self.kyber_pre_keys.keys()
213    }
214}
215
216impl Default for InMemKyberPreKeyStore {
217    fn default() -> Self {
218        Self::new()
219    }
220}
221
222#[async_trait(?Send)]
223impl traits::KyberPreKeyStore for InMemKyberPreKeyStore {
224    async fn get_kyber_pre_key(&self, kyber_prekey_id: KyberPreKeyId) -> Result<KyberPreKeyRecord> {
225        Ok(self
226            .kyber_pre_keys
227            .get(&kyber_prekey_id)
228            .ok_or(SignalProtocolError::InvalidKyberPreKeyId)?
229            .clone())
230    }
231
232    async fn save_kyber_pre_key(
233        &mut self,
234        kyber_prekey_id: KyberPreKeyId,
235        record: &KyberPreKeyRecord,
236    ) -> Result<()> {
237        self.kyber_pre_keys
238            .insert(kyber_prekey_id, record.to_owned());
239        Ok(())
240    }
241
242    async fn mark_kyber_pre_key_used(&mut self, _kyber_prekey_id: KyberPreKeyId) -> Result<()> {
243        Ok(())
244    }
245}
246
247/// Reference implementation of [traits::SessionStore].
248#[derive(Clone)]
249pub struct InMemSessionStore {
250    sessions: HashMap<ProtocolAddress, SessionRecord>,
251}
252
253impl InMemSessionStore {
254    /// Create an empty session store.
255    pub fn new() -> Self {
256        Self {
257            sessions: HashMap::new(),
258        }
259    }
260
261    /// Bulk version of [`SessionStore::load_session`].
262    ///
263    /// Useful for [crate::sealed_sender_multi_recipient_encrypt].
264    ///
265    /// [`SessionStore::load_session`]: crate::SessionStore::load_session
266    pub fn load_existing_sessions(
267        &self,
268        addresses: &[&ProtocolAddress],
269    ) -> Result<Vec<&SessionRecord>> {
270        addresses
271            .iter()
272            .map(|&address| {
273                self.sessions
274                    .get(address)
275                    .ok_or_else(|| SignalProtocolError::SessionNotFound(address.clone()))
276            })
277            .collect()
278    }
279}
280
281impl Default for InMemSessionStore {
282    fn default() -> Self {
283        Self::new()
284    }
285}
286
287#[async_trait(?Send)]
288impl traits::SessionStore for InMemSessionStore {
289    async fn load_session(&self, address: &ProtocolAddress) -> Result<Option<SessionRecord>> {
290        match self.sessions.get(address) {
291            None => Ok(None),
292            Some(s) => Ok(Some(s.clone())),
293        }
294    }
295
296    async fn store_session(
297        &mut self,
298        address: &ProtocolAddress,
299        record: &SessionRecord,
300    ) -> Result<()> {
301        self.sessions.insert(address.clone(), record.clone());
302        Ok(())
303    }
304}
305
306/// Reference implementation of [traits::SenderKeyStore].
307#[derive(Clone)]
308pub struct InMemSenderKeyStore {
309    // We use Cow keys in order to store owned values but compare to referenced ones.
310    // See https://users.rust-lang.org/t/hashmap-with-tuple-keys/12711/6.
311    keys: HashMap<(Cow<'static, ProtocolAddress>, Uuid), SenderKeyRecord>,
312}
313
314impl InMemSenderKeyStore {
315    /// Create an empty sender key store.
316    pub fn new() -> Self {
317        Self {
318            keys: HashMap::new(),
319        }
320    }
321}
322
323impl Default for InMemSenderKeyStore {
324    fn default() -> Self {
325        Self::new()
326    }
327}
328
329#[async_trait(?Send)]
330impl traits::SenderKeyStore for InMemSenderKeyStore {
331    async fn store_sender_key(
332        &mut self,
333        sender: &ProtocolAddress,
334        distribution_id: Uuid,
335        record: &SenderKeyRecord,
336    ) -> Result<()> {
337        self.keys.insert(
338            (Cow::Owned(sender.clone()), distribution_id),
339            record.clone(),
340        );
341        Ok(())
342    }
343
344    async fn load_sender_key(
345        &mut self,
346        sender: &ProtocolAddress,
347        distribution_id: Uuid,
348    ) -> Result<Option<SenderKeyRecord>> {
349        Ok(self
350            .keys
351            .get(&(Cow::Borrowed(sender), distribution_id))
352            .cloned())
353    }
354}
355
356/// Reference implementation of [traits::ProtocolStore].
357#[allow(missing_docs)]
358#[derive(Clone)]
359pub struct InMemSignalProtocolStore {
360    pub session_store: InMemSessionStore,
361    pub pre_key_store: InMemPreKeyStore,
362    pub signed_pre_key_store: InMemSignedPreKeyStore,
363    pub kyber_pre_key_store: InMemKyberPreKeyStore,
364    pub identity_store: InMemIdentityKeyStore,
365    pub sender_key_store: InMemSenderKeyStore,
366}
367
368impl InMemSignalProtocolStore {
369    /// Create an object with the minimal implementation of [traits::ProtocolStore], representing
370    /// the given identity `key_pair` along with the separate randomly chosen `registration_id`.
371    pub fn new(key_pair: IdentityKeyPair, registration_id: u32) -> Result<Self> {
372        Ok(Self {
373            session_store: InMemSessionStore::new(),
374            pre_key_store: InMemPreKeyStore::new(),
375            signed_pre_key_store: InMemSignedPreKeyStore::new(),
376            kyber_pre_key_store: InMemKyberPreKeyStore::new(),
377            identity_store: InMemIdentityKeyStore::new(key_pair, registration_id),
378            sender_key_store: InMemSenderKeyStore::new(),
379        })
380    }
381
382    /// Returns all registered pre-key ids
383    pub fn all_pre_key_ids(&self) -> impl Iterator<Item = &PreKeyId> {
384        self.pre_key_store.all_pre_key_ids()
385    }
386
387    /// Returns all registered signed pre-key ids
388    pub fn all_signed_pre_key_ids(&self) -> impl Iterator<Item = &SignedPreKeyId> {
389        self.signed_pre_key_store.all_signed_pre_key_ids()
390    }
391
392    /// Returns all registered Kyber pre-key ids
393    pub fn all_kyber_pre_key_ids(&self) -> impl Iterator<Item = &KyberPreKeyId> {
394        self.kyber_pre_key_store.all_kyber_pre_key_ids()
395    }
396}
397
398#[async_trait(?Send)]
399impl traits::IdentityKeyStore for InMemSignalProtocolStore {
400    async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair> {
401        self.identity_store.get_identity_key_pair().await
402    }
403
404    async fn get_local_registration_id(&self) -> Result<u32> {
405        self.identity_store.get_local_registration_id().await
406    }
407
408    async fn save_identity(
409        &mut self,
410        address: &ProtocolAddress,
411        identity: &IdentityKey,
412    ) -> Result<IdentityChange> {
413        self.identity_store.save_identity(address, identity).await
414    }
415
416    async fn is_trusted_identity(
417        &self,
418        address: &ProtocolAddress,
419        identity: &IdentityKey,
420        direction: traits::Direction,
421    ) -> Result<bool> {
422        self.identity_store
423            .is_trusted_identity(address, identity, direction)
424            .await
425    }
426
427    async fn get_identity(&self, address: &ProtocolAddress) -> Result<Option<IdentityKey>> {
428        self.identity_store.get_identity(address).await
429    }
430}
431
432#[async_trait(?Send)]
433impl traits::PreKeyStore for InMemSignalProtocolStore {
434    async fn get_pre_key(&self, id: PreKeyId) -> Result<PreKeyRecord> {
435        self.pre_key_store.get_pre_key(id).await
436    }
437
438    async fn save_pre_key(&mut self, id: PreKeyId, record: &PreKeyRecord) -> Result<()> {
439        self.pre_key_store.save_pre_key(id, record).await
440    }
441
442    async fn remove_pre_key(&mut self, id: PreKeyId) -> Result<()> {
443        self.pre_key_store.remove_pre_key(id).await
444    }
445}
446
447#[async_trait(?Send)]
448impl traits::SignedPreKeyStore for InMemSignalProtocolStore {
449    async fn get_signed_pre_key(&self, id: SignedPreKeyId) -> Result<SignedPreKeyRecord> {
450        self.signed_pre_key_store.get_signed_pre_key(id).await
451    }
452
453    async fn save_signed_pre_key(
454        &mut self,
455        id: SignedPreKeyId,
456        record: &SignedPreKeyRecord,
457    ) -> Result<()> {
458        self.signed_pre_key_store
459            .save_signed_pre_key(id, record)
460            .await
461    }
462}
463
464#[async_trait(?Send)]
465impl traits::KyberPreKeyStore for InMemSignalProtocolStore {
466    async fn get_kyber_pre_key(&self, kyber_prekey_id: KyberPreKeyId) -> Result<KyberPreKeyRecord> {
467        self.kyber_pre_key_store
468            .get_kyber_pre_key(kyber_prekey_id)
469            .await
470    }
471
472    async fn save_kyber_pre_key(
473        &mut self,
474        kyber_prekey_id: KyberPreKeyId,
475        record: &KyberPreKeyRecord,
476    ) -> Result<()> {
477        self.kyber_pre_key_store
478            .save_kyber_pre_key(kyber_prekey_id, record)
479            .await
480    }
481
482    async fn mark_kyber_pre_key_used(&mut self, kyber_prekey_id: KyberPreKeyId) -> Result<()> {
483        self.kyber_pre_key_store
484            .mark_kyber_pre_key_used(kyber_prekey_id)
485            .await
486    }
487}
488
489#[async_trait(?Send)]
490impl traits::SessionStore for InMemSignalProtocolStore {
491    async fn load_session(&self, address: &ProtocolAddress) -> Result<Option<SessionRecord>> {
492        self.session_store.load_session(address).await
493    }
494
495    async fn store_session(
496        &mut self,
497        address: &ProtocolAddress,
498        record: &SessionRecord,
499    ) -> Result<()> {
500        self.session_store.store_session(address, record).await
501    }
502}
503
504#[async_trait(?Send)]
505impl traits::SenderKeyStore for InMemSignalProtocolStore {
506    async fn store_sender_key(
507        &mut self,
508        sender: &ProtocolAddress,
509        distribution_id: Uuid,
510        record: &SenderKeyRecord,
511    ) -> Result<()> {
512        self.sender_key_store
513            .store_sender_key(sender, distribution_id, record)
514            .await
515    }
516
517    async fn load_sender_key(
518        &mut self,
519        sender: &ProtocolAddress,
520        distribution_id: Uuid,
521    ) -> Result<Option<SenderKeyRecord>> {
522        self.sender_key_store
523            .load_sender_key(sender, distribution_id)
524            .await
525    }
526}
527
528impl traits::ProtocolStore for InMemSignalProtocolStore {}