libsignal_protocol/storage/
inmem.rs1use std::borrow::Cow;
11use std::collections::HashMap;
12
13use async_trait::async_trait;
14use uuid::Uuid;
15
16use crate::storage::traits::{self, IdentityChange};
17use crate::{
18 IdentityKey, IdentityKeyPair, KyberPreKeyId, KyberPreKeyRecord, PreKeyId, PreKeyRecord,
19 ProtocolAddress, Result, SenderKeyRecord, SessionRecord, SignalProtocolError, SignedPreKeyId,
20 SignedPreKeyRecord,
21};
22
23#[derive(Clone)]
25pub struct InMemIdentityKeyStore {
26 key_pair: IdentityKeyPair,
27 registration_id: u32,
28 known_keys: HashMap<ProtocolAddress, IdentityKey>,
29}
30
31impl InMemIdentityKeyStore {
32 pub fn new(key_pair: IdentityKeyPair, registration_id: u32) -> Self {
37 Self {
38 key_pair,
39 registration_id,
40 known_keys: HashMap::new(),
41 }
42 }
43
44 pub fn reset(&mut self) {
46 self.known_keys.clear();
47 }
48}
49
50#[async_trait(?Send)]
51impl traits::IdentityKeyStore for InMemIdentityKeyStore {
52 async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair> {
53 Ok(self.key_pair)
54 }
55
56 async fn get_local_registration_id(&self) -> Result<u32> {
57 Ok(self.registration_id)
58 }
59
60 async fn save_identity(
61 &mut self,
62 address: &ProtocolAddress,
63 identity: &IdentityKey,
64 ) -> Result<IdentityChange> {
65 match self.known_keys.get(address) {
66 None => {
67 self.known_keys.insert(address.clone(), *identity);
68 Ok(IdentityChange::NewOrUnchanged)
69 }
70 Some(k) if k == identity => Ok(IdentityChange::NewOrUnchanged),
71 Some(_k) => {
72 self.known_keys.insert(address.clone(), *identity);
73 Ok(IdentityChange::ReplacedExisting)
74 }
75 }
76 }
77
78 async fn is_trusted_identity(
79 &self,
80 address: &ProtocolAddress,
81 identity: &IdentityKey,
82 _direction: traits::Direction,
83 ) -> Result<bool> {
84 match self.known_keys.get(address) {
85 None => {
86 Ok(true) }
88 Some(k) => Ok(k == identity),
89 }
90 }
91
92 async fn get_identity(&self, address: &ProtocolAddress) -> Result<Option<IdentityKey>> {
93 match self.known_keys.get(address) {
94 None => Ok(None),
95 Some(k) => Ok(Some(k.to_owned())),
96 }
97 }
98}
99
100#[derive(Clone)]
102pub struct InMemPreKeyStore {
103 pre_keys: HashMap<PreKeyId, PreKeyRecord>,
104}
105
106impl InMemPreKeyStore {
107 pub fn new() -> Self {
109 Self {
110 pre_keys: HashMap::new(),
111 }
112 }
113
114 pub fn all_pre_key_ids(&self) -> impl Iterator<Item = &PreKeyId> {
116 self.pre_keys.keys()
117 }
118}
119
120impl Default for InMemPreKeyStore {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126#[async_trait(?Send)]
127impl traits::PreKeyStore for InMemPreKeyStore {
128 async fn get_pre_key(&self, id: PreKeyId) -> Result<PreKeyRecord> {
129 Ok(self
130 .pre_keys
131 .get(&id)
132 .ok_or(SignalProtocolError::InvalidPreKeyId)?
133 .clone())
134 }
135
136 async fn save_pre_key(&mut self, id: PreKeyId, record: &PreKeyRecord) -> Result<()> {
137 self.pre_keys.insert(id, record.to_owned());
139 Ok(())
140 }
141
142 async fn remove_pre_key(&mut self, id: PreKeyId) -> Result<()> {
143 self.pre_keys.remove(&id);
145 Ok(())
146 }
147}
148
149#[derive(Clone)]
151pub struct InMemSignedPreKeyStore {
152 signed_pre_keys: HashMap<SignedPreKeyId, SignedPreKeyRecord>,
153}
154
155impl InMemSignedPreKeyStore {
156 pub fn new() -> Self {
158 Self {
159 signed_pre_keys: HashMap::new(),
160 }
161 }
162
163 pub fn all_signed_pre_key_ids(&self) -> impl Iterator<Item = &SignedPreKeyId> {
165 self.signed_pre_keys.keys()
166 }
167}
168
169impl Default for InMemSignedPreKeyStore {
170 fn default() -> Self {
171 Self::new()
172 }
173}
174
175#[async_trait(?Send)]
176impl traits::SignedPreKeyStore for InMemSignedPreKeyStore {
177 async fn get_signed_pre_key(&self, id: SignedPreKeyId) -> Result<SignedPreKeyRecord> {
178 Ok(self
179 .signed_pre_keys
180 .get(&id)
181 .ok_or(SignalProtocolError::InvalidSignedPreKeyId)?
182 .clone())
183 }
184
185 async fn save_signed_pre_key(
186 &mut self,
187 id: SignedPreKeyId,
188 record: &SignedPreKeyRecord,
189 ) -> Result<()> {
190 self.signed_pre_keys.insert(id, record.to_owned());
192 Ok(())
193 }
194}
195
196#[derive(Clone)]
198pub struct InMemKyberPreKeyStore {
199 kyber_pre_keys: HashMap<KyberPreKeyId, KyberPreKeyRecord>,
200}
201
202impl InMemKyberPreKeyStore {
203 pub fn new() -> Self {
205 Self {
206 kyber_pre_keys: HashMap::new(),
207 }
208 }
209
210 pub fn all_kyber_pre_key_ids(&self) -> impl Iterator<Item = &KyberPreKeyId> {
212 self.kyber_pre_keys.keys()
213 }
214}
215
216impl Default for InMemKyberPreKeyStore {
217 fn default() -> Self {
218 Self::new()
219 }
220}
221
222#[async_trait(?Send)]
223impl traits::KyberPreKeyStore for InMemKyberPreKeyStore {
224 async fn get_kyber_pre_key(&self, kyber_prekey_id: KyberPreKeyId) -> Result<KyberPreKeyRecord> {
225 Ok(self
226 .kyber_pre_keys
227 .get(&kyber_prekey_id)
228 .ok_or(SignalProtocolError::InvalidKyberPreKeyId)?
229 .clone())
230 }
231
232 async fn save_kyber_pre_key(
233 &mut self,
234 kyber_prekey_id: KyberPreKeyId,
235 record: &KyberPreKeyRecord,
236 ) -> Result<()> {
237 self.kyber_pre_keys
238 .insert(kyber_prekey_id, record.to_owned());
239 Ok(())
240 }
241
242 async fn mark_kyber_pre_key_used(&mut self, _kyber_prekey_id: KyberPreKeyId) -> Result<()> {
243 Ok(())
244 }
245}
246
247#[derive(Clone)]
249pub struct InMemSessionStore {
250 sessions: HashMap<ProtocolAddress, SessionRecord>,
251}
252
253impl InMemSessionStore {
254 pub fn new() -> Self {
256 Self {
257 sessions: HashMap::new(),
258 }
259 }
260
261 pub fn load_existing_sessions(
267 &self,
268 addresses: &[&ProtocolAddress],
269 ) -> Result<Vec<&SessionRecord>> {
270 addresses
271 .iter()
272 .map(|&address| {
273 self.sessions
274 .get(address)
275 .ok_or_else(|| SignalProtocolError::SessionNotFound(address.clone()))
276 })
277 .collect()
278 }
279}
280
281impl Default for InMemSessionStore {
282 fn default() -> Self {
283 Self::new()
284 }
285}
286
287#[async_trait(?Send)]
288impl traits::SessionStore for InMemSessionStore {
289 async fn load_session(&self, address: &ProtocolAddress) -> Result<Option<SessionRecord>> {
290 match self.sessions.get(address) {
291 None => Ok(None),
292 Some(s) => Ok(Some(s.clone())),
293 }
294 }
295
296 async fn store_session(
297 &mut self,
298 address: &ProtocolAddress,
299 record: &SessionRecord,
300 ) -> Result<()> {
301 self.sessions.insert(address.clone(), record.clone());
302 Ok(())
303 }
304}
305
306#[derive(Clone)]
308pub struct InMemSenderKeyStore {
309 keys: HashMap<(Cow<'static, ProtocolAddress>, Uuid), SenderKeyRecord>,
312}
313
314impl InMemSenderKeyStore {
315 pub fn new() -> Self {
317 Self {
318 keys: HashMap::new(),
319 }
320 }
321}
322
323impl Default for InMemSenderKeyStore {
324 fn default() -> Self {
325 Self::new()
326 }
327}
328
329#[async_trait(?Send)]
330impl traits::SenderKeyStore for InMemSenderKeyStore {
331 async fn store_sender_key(
332 &mut self,
333 sender: &ProtocolAddress,
334 distribution_id: Uuid,
335 record: &SenderKeyRecord,
336 ) -> Result<()> {
337 self.keys.insert(
338 (Cow::Owned(sender.clone()), distribution_id),
339 record.clone(),
340 );
341 Ok(())
342 }
343
344 async fn load_sender_key(
345 &mut self,
346 sender: &ProtocolAddress,
347 distribution_id: Uuid,
348 ) -> Result<Option<SenderKeyRecord>> {
349 Ok(self
350 .keys
351 .get(&(Cow::Borrowed(sender), distribution_id))
352 .cloned())
353 }
354}
355
356#[allow(missing_docs)]
358#[derive(Clone)]
359pub struct InMemSignalProtocolStore {
360 pub session_store: InMemSessionStore,
361 pub pre_key_store: InMemPreKeyStore,
362 pub signed_pre_key_store: InMemSignedPreKeyStore,
363 pub kyber_pre_key_store: InMemKyberPreKeyStore,
364 pub identity_store: InMemIdentityKeyStore,
365 pub sender_key_store: InMemSenderKeyStore,
366}
367
368impl InMemSignalProtocolStore {
369 pub fn new(key_pair: IdentityKeyPair, registration_id: u32) -> Result<Self> {
372 Ok(Self {
373 session_store: InMemSessionStore::new(),
374 pre_key_store: InMemPreKeyStore::new(),
375 signed_pre_key_store: InMemSignedPreKeyStore::new(),
376 kyber_pre_key_store: InMemKyberPreKeyStore::new(),
377 identity_store: InMemIdentityKeyStore::new(key_pair, registration_id),
378 sender_key_store: InMemSenderKeyStore::new(),
379 })
380 }
381
382 pub fn all_pre_key_ids(&self) -> impl Iterator<Item = &PreKeyId> {
384 self.pre_key_store.all_pre_key_ids()
385 }
386
387 pub fn all_signed_pre_key_ids(&self) -> impl Iterator<Item = &SignedPreKeyId> {
389 self.signed_pre_key_store.all_signed_pre_key_ids()
390 }
391
392 pub fn all_kyber_pre_key_ids(&self) -> impl Iterator<Item = &KyberPreKeyId> {
394 self.kyber_pre_key_store.all_kyber_pre_key_ids()
395 }
396}
397
398#[async_trait(?Send)]
399impl traits::IdentityKeyStore for InMemSignalProtocolStore {
400 async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair> {
401 self.identity_store.get_identity_key_pair().await
402 }
403
404 async fn get_local_registration_id(&self) -> Result<u32> {
405 self.identity_store.get_local_registration_id().await
406 }
407
408 async fn save_identity(
409 &mut self,
410 address: &ProtocolAddress,
411 identity: &IdentityKey,
412 ) -> Result<IdentityChange> {
413 self.identity_store.save_identity(address, identity).await
414 }
415
416 async fn is_trusted_identity(
417 &self,
418 address: &ProtocolAddress,
419 identity: &IdentityKey,
420 direction: traits::Direction,
421 ) -> Result<bool> {
422 self.identity_store
423 .is_trusted_identity(address, identity, direction)
424 .await
425 }
426
427 async fn get_identity(&self, address: &ProtocolAddress) -> Result<Option<IdentityKey>> {
428 self.identity_store.get_identity(address).await
429 }
430}
431
432#[async_trait(?Send)]
433impl traits::PreKeyStore for InMemSignalProtocolStore {
434 async fn get_pre_key(&self, id: PreKeyId) -> Result<PreKeyRecord> {
435 self.pre_key_store.get_pre_key(id).await
436 }
437
438 async fn save_pre_key(&mut self, id: PreKeyId, record: &PreKeyRecord) -> Result<()> {
439 self.pre_key_store.save_pre_key(id, record).await
440 }
441
442 async fn remove_pre_key(&mut self, id: PreKeyId) -> Result<()> {
443 self.pre_key_store.remove_pre_key(id).await
444 }
445}
446
447#[async_trait(?Send)]
448impl traits::SignedPreKeyStore for InMemSignalProtocolStore {
449 async fn get_signed_pre_key(&self, id: SignedPreKeyId) -> Result<SignedPreKeyRecord> {
450 self.signed_pre_key_store.get_signed_pre_key(id).await
451 }
452
453 async fn save_signed_pre_key(
454 &mut self,
455 id: SignedPreKeyId,
456 record: &SignedPreKeyRecord,
457 ) -> Result<()> {
458 self.signed_pre_key_store
459 .save_signed_pre_key(id, record)
460 .await
461 }
462}
463
464#[async_trait(?Send)]
465impl traits::KyberPreKeyStore for InMemSignalProtocolStore {
466 async fn get_kyber_pre_key(&self, kyber_prekey_id: KyberPreKeyId) -> Result<KyberPreKeyRecord> {
467 self.kyber_pre_key_store
468 .get_kyber_pre_key(kyber_prekey_id)
469 .await
470 }
471
472 async fn save_kyber_pre_key(
473 &mut self,
474 kyber_prekey_id: KyberPreKeyId,
475 record: &KyberPreKeyRecord,
476 ) -> Result<()> {
477 self.kyber_pre_key_store
478 .save_kyber_pre_key(kyber_prekey_id, record)
479 .await
480 }
481
482 async fn mark_kyber_pre_key_used(&mut self, kyber_prekey_id: KyberPreKeyId) -> Result<()> {
483 self.kyber_pre_key_store
484 .mark_kyber_pre_key_used(kyber_prekey_id)
485 .await
486 }
487}
488
489#[async_trait(?Send)]
490impl traits::SessionStore for InMemSignalProtocolStore {
491 async fn load_session(&self, address: &ProtocolAddress) -> Result<Option<SessionRecord>> {
492 self.session_store.load_session(address).await
493 }
494
495 async fn store_session(
496 &mut self,
497 address: &ProtocolAddress,
498 record: &SessionRecord,
499 ) -> Result<()> {
500 self.session_store.store_session(address, record).await
501 }
502}
503
504#[async_trait(?Send)]
505impl traits::SenderKeyStore for InMemSignalProtocolStore {
506 async fn store_sender_key(
507 &mut self,
508 sender: &ProtocolAddress,
509 distribution_id: Uuid,
510 record: &SenderKeyRecord,
511 ) -> Result<()> {
512 self.sender_key_store
513 .store_sender_key(sender, distribution_id, record)
514 .await
515 }
516
517 async fn load_sender_key(
518 &mut self,
519 sender: &ProtocolAddress,
520 distribution_id: Uuid,
521 ) -> Result<Option<SenderKeyRecord>> {
522 self.sender_key_store
523 .load_sender_key(sender, distribution_id)
524 .await
525 }
526}
527
528impl traits::ProtocolStore for InMemSignalProtocolStore {}