libsignal_protocol/storage/
inmem.rs1use std::borrow::Cow;
11use std::collections::HashMap;
12
13use async_trait::async_trait;
14use uuid::Uuid;
15
16use crate::storage::traits::{self, IdentityChange};
17use crate::{
18 CiphertextMessageType, IdentityKey, IdentityKeyPair, KyberPreKeyId, KyberPreKeyRecord,
19 PreKeyId, PreKeyRecord, ProtocolAddress, PublicKey, Result, SenderKeyRecord, SessionRecord,
20 SignalProtocolError, SignedPreKeyId, SignedPreKeyRecord,
21};
22
23#[derive(Clone)]
25pub struct InMemIdentityKeyStore {
26 key_pair: IdentityKeyPair,
27 registration_id: u32,
28 known_keys: HashMap<ProtocolAddress, IdentityKey>,
29}
30
31impl InMemIdentityKeyStore {
32 pub fn new(key_pair: IdentityKeyPair, registration_id: u32) -> Self {
37 Self {
38 key_pair,
39 registration_id,
40 known_keys: HashMap::new(),
41 }
42 }
43
44 pub fn reset(&mut self) {
46 self.known_keys.clear();
47 }
48}
49
50#[async_trait(?Send)]
51impl traits::IdentityKeyStore for InMemIdentityKeyStore {
52 async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair> {
53 Ok(self.key_pair)
54 }
55
56 async fn get_local_registration_id(&self) -> Result<u32> {
57 Ok(self.registration_id)
58 }
59
60 async fn save_identity(
61 &mut self,
62 address: &ProtocolAddress,
63 identity: &IdentityKey,
64 ) -> Result<IdentityChange> {
65 match self.known_keys.get(address) {
66 None => {
67 self.known_keys.insert(address.clone(), *identity);
68 Ok(IdentityChange::NewOrUnchanged)
69 }
70 Some(k) if k == identity => Ok(IdentityChange::NewOrUnchanged),
71 Some(_k) => {
72 self.known_keys.insert(address.clone(), *identity);
73 Ok(IdentityChange::ReplacedExisting)
74 }
75 }
76 }
77
78 async fn is_trusted_identity(
79 &self,
80 address: &ProtocolAddress,
81 identity: &IdentityKey,
82 _direction: traits::Direction,
83 ) -> Result<bool> {
84 match self.known_keys.get(address) {
85 None => {
86 Ok(true) }
88 Some(k) => Ok(k == identity),
89 }
90 }
91
92 async fn get_identity(&self, address: &ProtocolAddress) -> Result<Option<IdentityKey>> {
93 match self.known_keys.get(address) {
94 None => Ok(None),
95 Some(k) => Ok(Some(k.to_owned())),
96 }
97 }
98}
99
100#[derive(Clone)]
102pub struct InMemPreKeyStore {
103 pre_keys: HashMap<PreKeyId, PreKeyRecord>,
104}
105
106impl InMemPreKeyStore {
107 pub fn new() -> Self {
109 Self {
110 pre_keys: HashMap::new(),
111 }
112 }
113
114 pub fn all_pre_key_ids(&self) -> impl Iterator<Item = &PreKeyId> {
116 self.pre_keys.keys()
117 }
118}
119
120impl Default for InMemPreKeyStore {
121 fn default() -> Self {
122 Self::new()
123 }
124}
125
126#[async_trait(?Send)]
127impl traits::PreKeyStore for InMemPreKeyStore {
128 async fn get_pre_key(&self, id: PreKeyId) -> Result<PreKeyRecord> {
129 Ok(self
130 .pre_keys
131 .get(&id)
132 .ok_or(SignalProtocolError::InvalidPreKeyId)?
133 .clone())
134 }
135
136 async fn save_pre_key(&mut self, id: PreKeyId, record: &PreKeyRecord) -> Result<()> {
137 self.pre_keys.insert(id, record.to_owned());
139 Ok(())
140 }
141
142 async fn remove_pre_key(&mut self, id: PreKeyId) -> Result<()> {
143 self.pre_keys.remove(&id);
145 Ok(())
146 }
147}
148
149#[derive(Clone)]
151pub struct InMemSignedPreKeyStore {
152 signed_pre_keys: HashMap<SignedPreKeyId, SignedPreKeyRecord>,
153}
154
155impl InMemSignedPreKeyStore {
156 pub fn new() -> Self {
158 Self {
159 signed_pre_keys: HashMap::new(),
160 }
161 }
162
163 pub fn all_signed_pre_key_ids(&self) -> impl Iterator<Item = &SignedPreKeyId> {
165 self.signed_pre_keys.keys()
166 }
167}
168
169impl Default for InMemSignedPreKeyStore {
170 fn default() -> Self {
171 Self::new()
172 }
173}
174
175#[async_trait(?Send)]
176impl traits::SignedPreKeyStore for InMemSignedPreKeyStore {
177 async fn get_signed_pre_key(&self, id: SignedPreKeyId) -> Result<SignedPreKeyRecord> {
178 Ok(self
179 .signed_pre_keys
180 .get(&id)
181 .ok_or(SignalProtocolError::InvalidSignedPreKeyId)?
182 .clone())
183 }
184
185 async fn save_signed_pre_key(
186 &mut self,
187 id: SignedPreKeyId,
188 record: &SignedPreKeyRecord,
189 ) -> Result<()> {
190 self.signed_pre_keys.insert(id, record.to_owned());
192 Ok(())
193 }
194}
195
196#[derive(Clone)]
201pub struct InMemKyberPreKeyStore {
202 kyber_pre_keys: HashMap<KyberPreKeyId, KyberPreKeyRecord>,
203 base_keys_seen: HashMap<(KyberPreKeyId, SignedPreKeyId), Vec<PublicKey>>,
204}
205
206impl InMemKyberPreKeyStore {
207 pub fn new() -> Self {
209 Self {
210 kyber_pre_keys: HashMap::new(),
211 base_keys_seen: HashMap::new(),
212 }
213 }
214
215 pub fn all_kyber_pre_key_ids(&self) -> impl Iterator<Item = &KyberPreKeyId> {
217 self.kyber_pre_keys.keys()
218 }
219}
220
221impl Default for InMemKyberPreKeyStore {
222 fn default() -> Self {
223 Self::new()
224 }
225}
226
227#[async_trait(?Send)]
228impl traits::KyberPreKeyStore for InMemKyberPreKeyStore {
229 async fn get_kyber_pre_key(&self, kyber_prekey_id: KyberPreKeyId) -> Result<KyberPreKeyRecord> {
230 Ok(self
231 .kyber_pre_keys
232 .get(&kyber_prekey_id)
233 .ok_or(SignalProtocolError::InvalidKyberPreKeyId)?
234 .clone())
235 }
236
237 async fn save_kyber_pre_key(
238 &mut self,
239 kyber_prekey_id: KyberPreKeyId,
240 record: &KyberPreKeyRecord,
241 ) -> Result<()> {
242 self.kyber_pre_keys
243 .insert(kyber_prekey_id, record.to_owned());
244 Ok(())
245 }
246
247 async fn mark_kyber_pre_key_used(
248 &mut self,
249 kyber_prekey_id: KyberPreKeyId,
250 ec_prekey_id: SignedPreKeyId,
251 base_key: &PublicKey,
252 ) -> Result<()> {
253 let base_keys_seen = self
254 .base_keys_seen
255 .entry((kyber_prekey_id, ec_prekey_id))
256 .or_default();
257 if base_keys_seen.contains(base_key) {
258 return Err(SignalProtocolError::InvalidMessage(
259 CiphertextMessageType::PreKey,
260 "reused base key",
261 ));
262 }
263 base_keys_seen.push(*base_key);
264 Ok(())
265 }
266}
267
268#[derive(Clone)]
270pub struct InMemSessionStore {
271 sessions: HashMap<ProtocolAddress, SessionRecord>,
272}
273
274impl InMemSessionStore {
275 pub fn new() -> Self {
277 Self {
278 sessions: HashMap::new(),
279 }
280 }
281
282 pub fn load_existing_sessions(
288 &self,
289 addresses: &[&ProtocolAddress],
290 ) -> Result<Vec<&SessionRecord>> {
291 addresses
292 .iter()
293 .map(|&address| {
294 self.sessions
295 .get(address)
296 .ok_or_else(|| SignalProtocolError::SessionNotFound(address.clone()))
297 })
298 .collect()
299 }
300}
301
302impl Default for InMemSessionStore {
303 fn default() -> Self {
304 Self::new()
305 }
306}
307
308#[async_trait(?Send)]
309impl traits::SessionStore for InMemSessionStore {
310 async fn load_session(&self, address: &ProtocolAddress) -> Result<Option<SessionRecord>> {
311 match self.sessions.get(address) {
312 None => Ok(None),
313 Some(s) => Ok(Some(s.clone())),
314 }
315 }
316
317 async fn store_session(
318 &mut self,
319 address: &ProtocolAddress,
320 record: &SessionRecord,
321 ) -> Result<()> {
322 self.sessions.insert(address.clone(), record.clone());
323 Ok(())
324 }
325}
326
327#[derive(Clone)]
329pub struct InMemSenderKeyStore {
330 keys: HashMap<(Cow<'static, ProtocolAddress>, Uuid), SenderKeyRecord>,
333}
334
335impl InMemSenderKeyStore {
336 pub fn new() -> Self {
338 Self {
339 keys: HashMap::new(),
340 }
341 }
342}
343
344impl Default for InMemSenderKeyStore {
345 fn default() -> Self {
346 Self::new()
347 }
348}
349
350#[async_trait(?Send)]
351impl traits::SenderKeyStore for InMemSenderKeyStore {
352 async fn store_sender_key(
353 &mut self,
354 sender: &ProtocolAddress,
355 distribution_id: Uuid,
356 record: &SenderKeyRecord,
357 ) -> Result<()> {
358 self.keys.insert(
359 (Cow::Owned(sender.clone()), distribution_id),
360 record.clone(),
361 );
362 Ok(())
363 }
364
365 async fn load_sender_key(
366 &mut self,
367 sender: &ProtocolAddress,
368 distribution_id: Uuid,
369 ) -> Result<Option<SenderKeyRecord>> {
370 Ok(self
371 .keys
372 .get(&(Cow::Borrowed(sender), distribution_id))
373 .cloned())
374 }
375}
376
377#[allow(missing_docs)]
379#[derive(Clone)]
380pub struct InMemSignalProtocolStore {
381 pub session_store: InMemSessionStore,
382 pub pre_key_store: InMemPreKeyStore,
383 pub signed_pre_key_store: InMemSignedPreKeyStore,
384 pub kyber_pre_key_store: InMemKyberPreKeyStore,
385 pub identity_store: InMemIdentityKeyStore,
386 pub sender_key_store: InMemSenderKeyStore,
387}
388
389impl InMemSignalProtocolStore {
390 pub fn new(key_pair: IdentityKeyPair, registration_id: u32) -> Result<Self> {
393 Ok(Self {
394 session_store: InMemSessionStore::new(),
395 pre_key_store: InMemPreKeyStore::new(),
396 signed_pre_key_store: InMemSignedPreKeyStore::new(),
397 kyber_pre_key_store: InMemKyberPreKeyStore::new(),
398 identity_store: InMemIdentityKeyStore::new(key_pair, registration_id),
399 sender_key_store: InMemSenderKeyStore::new(),
400 })
401 }
402
403 pub fn all_pre_key_ids(&self) -> impl Iterator<Item = &PreKeyId> {
405 self.pre_key_store.all_pre_key_ids()
406 }
407
408 pub fn all_signed_pre_key_ids(&self) -> impl Iterator<Item = &SignedPreKeyId> {
410 self.signed_pre_key_store.all_signed_pre_key_ids()
411 }
412
413 pub fn all_kyber_pre_key_ids(&self) -> impl Iterator<Item = &KyberPreKeyId> {
415 self.kyber_pre_key_store.all_kyber_pre_key_ids()
416 }
417}
418
419#[async_trait(?Send)]
420impl traits::IdentityKeyStore for InMemSignalProtocolStore {
421 async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair> {
422 self.identity_store.get_identity_key_pair().await
423 }
424
425 async fn get_local_registration_id(&self) -> Result<u32> {
426 self.identity_store.get_local_registration_id().await
427 }
428
429 async fn save_identity(
430 &mut self,
431 address: &ProtocolAddress,
432 identity: &IdentityKey,
433 ) -> Result<IdentityChange> {
434 self.identity_store.save_identity(address, identity).await
435 }
436
437 async fn is_trusted_identity(
438 &self,
439 address: &ProtocolAddress,
440 identity: &IdentityKey,
441 direction: traits::Direction,
442 ) -> Result<bool> {
443 self.identity_store
444 .is_trusted_identity(address, identity, direction)
445 .await
446 }
447
448 async fn get_identity(&self, address: &ProtocolAddress) -> Result<Option<IdentityKey>> {
449 self.identity_store.get_identity(address).await
450 }
451}
452
453#[async_trait(?Send)]
454impl traits::PreKeyStore for InMemSignalProtocolStore {
455 async fn get_pre_key(&self, id: PreKeyId) -> Result<PreKeyRecord> {
456 self.pre_key_store.get_pre_key(id).await
457 }
458
459 async fn save_pre_key(&mut self, id: PreKeyId, record: &PreKeyRecord) -> Result<()> {
460 self.pre_key_store.save_pre_key(id, record).await
461 }
462
463 async fn remove_pre_key(&mut self, id: PreKeyId) -> Result<()> {
464 self.pre_key_store.remove_pre_key(id).await
465 }
466}
467
468#[async_trait(?Send)]
469impl traits::SignedPreKeyStore for InMemSignalProtocolStore {
470 async fn get_signed_pre_key(&self, id: SignedPreKeyId) -> Result<SignedPreKeyRecord> {
471 self.signed_pre_key_store.get_signed_pre_key(id).await
472 }
473
474 async fn save_signed_pre_key(
475 &mut self,
476 id: SignedPreKeyId,
477 record: &SignedPreKeyRecord,
478 ) -> Result<()> {
479 self.signed_pre_key_store
480 .save_signed_pre_key(id, record)
481 .await
482 }
483}
484
485#[async_trait(?Send)]
486impl traits::KyberPreKeyStore for InMemSignalProtocolStore {
487 async fn get_kyber_pre_key(&self, kyber_prekey_id: KyberPreKeyId) -> Result<KyberPreKeyRecord> {
488 self.kyber_pre_key_store
489 .get_kyber_pre_key(kyber_prekey_id)
490 .await
491 }
492
493 async fn save_kyber_pre_key(
494 &mut self,
495 kyber_prekey_id: KyberPreKeyId,
496 record: &KyberPreKeyRecord,
497 ) -> Result<()> {
498 self.kyber_pre_key_store
499 .save_kyber_pre_key(kyber_prekey_id, record)
500 .await
501 }
502
503 async fn mark_kyber_pre_key_used(
504 &mut self,
505 kyber_prekey_id: KyberPreKeyId,
506 ec_prekey_id: SignedPreKeyId,
507 base_key: &PublicKey,
508 ) -> Result<()> {
509 self.kyber_pre_key_store
510 .mark_kyber_pre_key_used(kyber_prekey_id, ec_prekey_id, base_key)
511 .await
512 }
513}
514
515#[async_trait(?Send)]
516impl traits::SessionStore for InMemSignalProtocolStore {
517 async fn load_session(&self, address: &ProtocolAddress) -> Result<Option<SessionRecord>> {
518 self.session_store.load_session(address).await
519 }
520
521 async fn store_session(
522 &mut self,
523 address: &ProtocolAddress,
524 record: &SessionRecord,
525 ) -> Result<()> {
526 self.session_store.store_session(address, record).await
527 }
528}
529
530#[async_trait(?Send)]
531impl traits::SenderKeyStore for InMemSignalProtocolStore {
532 async fn store_sender_key(
533 &mut self,
534 sender: &ProtocolAddress,
535 distribution_id: Uuid,
536 record: &SenderKeyRecord,
537 ) -> Result<()> {
538 self.sender_key_store
539 .store_sender_key(sender, distribution_id, record)
540 .await
541 }
542
543 async fn load_sender_key(
544 &mut self,
545 sender: &ProtocolAddress,
546 distribution_id: Uuid,
547 ) -> Result<Option<SenderKeyRecord>> {
548 self.sender_key_store
549 .load_sender_key(sender, distribution_id)
550 .await
551 }
552}
553
554impl traits::ProtocolStore for InMemSignalProtocolStore {}