libsignal_protocol/storage/
inmem.rs1use std::borrow::Cow;
11use std::collections::HashMap;
12
13use async_trait::async_trait;
14use uuid::Uuid;
15
16use crate::storage::traits;
17use crate::{
18 IdentityKey, IdentityKeyPair, KyberPreKeyId, KyberPreKeyRecord, PreKeyId, PreKeyRecord,
19 ProtocolAddress, Result, SenderKeyRecord, SessionRecord, SignalProtocolError, SignedPreKeyId,
20 SignedPreKeyRecord,
21};
22
23#[derive(Clone)]
25pub struct InMemIdentityKeyStore {
26 key_pair: IdentityKeyPair,
27 registration_id: u32,
28 known_keys: HashMap<ProtocolAddress, IdentityKey>,
29}
30
31impl InMemIdentityKeyStore {
32 pub fn new(key_pair: IdentityKeyPair, registration_id: u32) -> Self {
37 Self {
38 key_pair,
39 registration_id,
40 known_keys: HashMap::new(),
41 }
42 }
43
44 pub fn reset(&mut self) {
46 self.known_keys.clear();
47 }
48}
49
50#[async_trait(?Send)]
51impl traits::IdentityKeyStore for InMemIdentityKeyStore {
52 async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair> {
53 Ok(self.key_pair)
54 }
55
56 async fn get_local_registration_id(&self) -> Result<u32> {
57 Ok(self.registration_id)
58 }
59
60 async fn save_identity(
61 &mut self,
62 address: &ProtocolAddress,
63 identity: &IdentityKey,
64 ) -> Result<bool> {
65 match self.known_keys.get(address) {
66 None => {
67 self.known_keys.insert(address.clone(), *identity);
68 Ok(false) }
70 Some(k) if k == identity => {
71 Ok(false) }
73 Some(_k) => {
74 self.known_keys.insert(address.clone(), *identity);
75 Ok(true) }
77 }
78 }
79
80 async fn is_trusted_identity(
81 &self,
82 address: &ProtocolAddress,
83 identity: &IdentityKey,
84 _direction: traits::Direction,
85 ) -> Result<bool> {
86 match self.known_keys.get(address) {
87 None => {
88 Ok(true) }
90 Some(k) => Ok(k == identity),
91 }
92 }
93
94 async fn get_identity(&self, address: &ProtocolAddress) -> Result<Option<IdentityKey>> {
95 match self.known_keys.get(address) {
96 None => Ok(None),
97 Some(k) => Ok(Some(k.to_owned())),
98 }
99 }
100}
101
102#[derive(Clone)]
104pub struct InMemPreKeyStore {
105 pre_keys: HashMap<PreKeyId, PreKeyRecord>,
106}
107
108impl InMemPreKeyStore {
109 pub fn new() -> Self {
111 Self {
112 pre_keys: HashMap::new(),
113 }
114 }
115
116 pub fn all_pre_key_ids(&self) -> impl Iterator<Item = &PreKeyId> {
118 self.pre_keys.keys()
119 }
120}
121
122impl Default for InMemPreKeyStore {
123 fn default() -> Self {
124 Self::new()
125 }
126}
127
128#[async_trait(?Send)]
129impl traits::PreKeyStore for InMemPreKeyStore {
130 async fn get_pre_key(&self, id: PreKeyId) -> Result<PreKeyRecord> {
131 Ok(self
132 .pre_keys
133 .get(&id)
134 .ok_or(SignalProtocolError::InvalidPreKeyId)?
135 .clone())
136 }
137
138 async fn save_pre_key(&mut self, id: PreKeyId, record: &PreKeyRecord) -> Result<()> {
139 self.pre_keys.insert(id, record.to_owned());
141 Ok(())
142 }
143
144 async fn remove_pre_key(&mut self, id: PreKeyId) -> Result<()> {
145 self.pre_keys.remove(&id);
147 Ok(())
148 }
149}
150
151#[derive(Clone)]
153pub struct InMemSignedPreKeyStore {
154 signed_pre_keys: HashMap<SignedPreKeyId, SignedPreKeyRecord>,
155}
156
157impl InMemSignedPreKeyStore {
158 pub fn new() -> Self {
160 Self {
161 signed_pre_keys: HashMap::new(),
162 }
163 }
164
165 pub fn all_signed_pre_key_ids(&self) -> impl Iterator<Item = &SignedPreKeyId> {
167 self.signed_pre_keys.keys()
168 }
169}
170
171impl Default for InMemSignedPreKeyStore {
172 fn default() -> Self {
173 Self::new()
174 }
175}
176
177#[async_trait(?Send)]
178impl traits::SignedPreKeyStore for InMemSignedPreKeyStore {
179 async fn get_signed_pre_key(&self, id: SignedPreKeyId) -> Result<SignedPreKeyRecord> {
180 Ok(self
181 .signed_pre_keys
182 .get(&id)
183 .ok_or(SignalProtocolError::InvalidSignedPreKeyId)?
184 .clone())
185 }
186
187 async fn save_signed_pre_key(
188 &mut self,
189 id: SignedPreKeyId,
190 record: &SignedPreKeyRecord,
191 ) -> Result<()> {
192 self.signed_pre_keys.insert(id, record.to_owned());
194 Ok(())
195 }
196}
197
198#[derive(Clone)]
200pub struct InMemKyberPreKeyStore {
201 kyber_pre_keys: HashMap<KyberPreKeyId, KyberPreKeyRecord>,
202}
203
204impl InMemKyberPreKeyStore {
205 pub fn new() -> Self {
207 Self {
208 kyber_pre_keys: HashMap::new(),
209 }
210 }
211
212 pub fn all_kyber_pre_key_ids(&self) -> impl Iterator<Item = &KyberPreKeyId> {
214 self.kyber_pre_keys.keys()
215 }
216}
217
218impl Default for InMemKyberPreKeyStore {
219 fn default() -> Self {
220 Self::new()
221 }
222}
223
224#[async_trait(?Send)]
225impl traits::KyberPreKeyStore for InMemKyberPreKeyStore {
226 async fn get_kyber_pre_key(&self, kyber_prekey_id: KyberPreKeyId) -> Result<KyberPreKeyRecord> {
227 Ok(self
228 .kyber_pre_keys
229 .get(&kyber_prekey_id)
230 .ok_or(SignalProtocolError::InvalidKyberPreKeyId)?
231 .clone())
232 }
233
234 async fn save_kyber_pre_key(
235 &mut self,
236 kyber_prekey_id: KyberPreKeyId,
237 record: &KyberPreKeyRecord,
238 ) -> Result<()> {
239 self.kyber_pre_keys
240 .insert(kyber_prekey_id, record.to_owned());
241 Ok(())
242 }
243
244 async fn mark_kyber_pre_key_used(&mut self, _kyber_prekey_id: KyberPreKeyId) -> Result<()> {
245 Ok(())
246 }
247}
248
249#[derive(Clone)]
251pub struct InMemSessionStore {
252 sessions: HashMap<ProtocolAddress, SessionRecord>,
253}
254
255impl InMemSessionStore {
256 pub fn new() -> Self {
258 Self {
259 sessions: HashMap::new(),
260 }
261 }
262
263 pub fn load_existing_sessions(
269 &self,
270 addresses: &[&ProtocolAddress],
271 ) -> Result<Vec<&SessionRecord>> {
272 addresses
273 .iter()
274 .map(|&address| {
275 self.sessions
276 .get(address)
277 .ok_or_else(|| SignalProtocolError::SessionNotFound(address.clone()))
278 })
279 .collect()
280 }
281}
282
283impl Default for InMemSessionStore {
284 fn default() -> Self {
285 Self::new()
286 }
287}
288
289#[async_trait(?Send)]
290impl traits::SessionStore for InMemSessionStore {
291 async fn load_session(&self, address: &ProtocolAddress) -> Result<Option<SessionRecord>> {
292 match self.sessions.get(address) {
293 None => Ok(None),
294 Some(s) => Ok(Some(s.clone())),
295 }
296 }
297
298 async fn store_session(
299 &mut self,
300 address: &ProtocolAddress,
301 record: &SessionRecord,
302 ) -> Result<()> {
303 self.sessions.insert(address.clone(), record.clone());
304 Ok(())
305 }
306}
307
308#[derive(Clone)]
310pub struct InMemSenderKeyStore {
311 keys: HashMap<(Cow<'static, ProtocolAddress>, Uuid), SenderKeyRecord>,
314}
315
316impl InMemSenderKeyStore {
317 pub fn new() -> Self {
319 Self {
320 keys: HashMap::new(),
321 }
322 }
323}
324
325impl Default for InMemSenderKeyStore {
326 fn default() -> Self {
327 Self::new()
328 }
329}
330
331#[async_trait(?Send)]
332impl traits::SenderKeyStore for InMemSenderKeyStore {
333 async fn store_sender_key(
334 &mut self,
335 sender: &ProtocolAddress,
336 distribution_id: Uuid,
337 record: &SenderKeyRecord,
338 ) -> Result<()> {
339 self.keys.insert(
340 (Cow::Owned(sender.clone()), distribution_id),
341 record.clone(),
342 );
343 Ok(())
344 }
345
346 async fn load_sender_key(
347 &mut self,
348 sender: &ProtocolAddress,
349 distribution_id: Uuid,
350 ) -> Result<Option<SenderKeyRecord>> {
351 Ok(self
352 .keys
353 .get(&(Cow::Borrowed(sender), distribution_id))
354 .cloned())
355 }
356}
357
358#[allow(missing_docs)]
360#[derive(Clone)]
361pub struct InMemSignalProtocolStore {
362 pub session_store: InMemSessionStore,
363 pub pre_key_store: InMemPreKeyStore,
364 pub signed_pre_key_store: InMemSignedPreKeyStore,
365 pub kyber_pre_key_store: InMemKyberPreKeyStore,
366 pub identity_store: InMemIdentityKeyStore,
367 pub sender_key_store: InMemSenderKeyStore,
368}
369
370impl InMemSignalProtocolStore {
371 pub fn new(key_pair: IdentityKeyPair, registration_id: u32) -> Result<Self> {
374 Ok(Self {
375 session_store: InMemSessionStore::new(),
376 pre_key_store: InMemPreKeyStore::new(),
377 signed_pre_key_store: InMemSignedPreKeyStore::new(),
378 kyber_pre_key_store: InMemKyberPreKeyStore::new(),
379 identity_store: InMemIdentityKeyStore::new(key_pair, registration_id),
380 sender_key_store: InMemSenderKeyStore::new(),
381 })
382 }
383
384 pub fn all_pre_key_ids(&self) -> impl Iterator<Item = &PreKeyId> {
386 self.pre_key_store.all_pre_key_ids()
387 }
388
389 pub fn all_signed_pre_key_ids(&self) -> impl Iterator<Item = &SignedPreKeyId> {
391 self.signed_pre_key_store.all_signed_pre_key_ids()
392 }
393
394 pub fn all_kyber_pre_key_ids(&self) -> impl Iterator<Item = &KyberPreKeyId> {
396 self.kyber_pre_key_store.all_kyber_pre_key_ids()
397 }
398}
399
400#[async_trait(?Send)]
401impl traits::IdentityKeyStore for InMemSignalProtocolStore {
402 async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair> {
403 self.identity_store.get_identity_key_pair().await
404 }
405
406 async fn get_local_registration_id(&self) -> Result<u32> {
407 self.identity_store.get_local_registration_id().await
408 }
409
410 async fn save_identity(
411 &mut self,
412 address: &ProtocolAddress,
413 identity: &IdentityKey,
414 ) -> Result<bool> {
415 self.identity_store.save_identity(address, identity).await
416 }
417
418 async fn is_trusted_identity(
419 &self,
420 address: &ProtocolAddress,
421 identity: &IdentityKey,
422 direction: traits::Direction,
423 ) -> Result<bool> {
424 self.identity_store
425 .is_trusted_identity(address, identity, direction)
426 .await
427 }
428
429 async fn get_identity(&self, address: &ProtocolAddress) -> Result<Option<IdentityKey>> {
430 self.identity_store.get_identity(address).await
431 }
432}
433
434#[async_trait(?Send)]
435impl traits::PreKeyStore for InMemSignalProtocolStore {
436 async fn get_pre_key(&self, id: PreKeyId) -> Result<PreKeyRecord> {
437 self.pre_key_store.get_pre_key(id).await
438 }
439
440 async fn save_pre_key(&mut self, id: PreKeyId, record: &PreKeyRecord) -> Result<()> {
441 self.pre_key_store.save_pre_key(id, record).await
442 }
443
444 async fn remove_pre_key(&mut self, id: PreKeyId) -> Result<()> {
445 self.pre_key_store.remove_pre_key(id).await
446 }
447}
448
449#[async_trait(?Send)]
450impl traits::SignedPreKeyStore for InMemSignalProtocolStore {
451 async fn get_signed_pre_key(&self, id: SignedPreKeyId) -> Result<SignedPreKeyRecord> {
452 self.signed_pre_key_store.get_signed_pre_key(id).await
453 }
454
455 async fn save_signed_pre_key(
456 &mut self,
457 id: SignedPreKeyId,
458 record: &SignedPreKeyRecord,
459 ) -> Result<()> {
460 self.signed_pre_key_store
461 .save_signed_pre_key(id, record)
462 .await
463 }
464}
465
466#[async_trait(?Send)]
467impl traits::KyberPreKeyStore for InMemSignalProtocolStore {
468 async fn get_kyber_pre_key(&self, kyber_prekey_id: KyberPreKeyId) -> Result<KyberPreKeyRecord> {
469 self.kyber_pre_key_store
470 .get_kyber_pre_key(kyber_prekey_id)
471 .await
472 }
473
474 async fn save_kyber_pre_key(
475 &mut self,
476 kyber_prekey_id: KyberPreKeyId,
477 record: &KyberPreKeyRecord,
478 ) -> Result<()> {
479 self.kyber_pre_key_store
480 .save_kyber_pre_key(kyber_prekey_id, record)
481 .await
482 }
483
484 async fn mark_kyber_pre_key_used(&mut self, kyber_prekey_id: KyberPreKeyId) -> Result<()> {
485 self.kyber_pre_key_store
486 .mark_kyber_pre_key_used(kyber_prekey_id)
487 .await
488 }
489}
490
491#[async_trait(?Send)]
492impl traits::SessionStore for InMemSignalProtocolStore {
493 async fn load_session(&self, address: &ProtocolAddress) -> Result<Option<SessionRecord>> {
494 self.session_store.load_session(address).await
495 }
496
497 async fn store_session(
498 &mut self,
499 address: &ProtocolAddress,
500 record: &SessionRecord,
501 ) -> Result<()> {
502 self.session_store.store_session(address, record).await
503 }
504}
505
506#[async_trait(?Send)]
507impl traits::SenderKeyStore for InMemSignalProtocolStore {
508 async fn store_sender_key(
509 &mut self,
510 sender: &ProtocolAddress,
511 distribution_id: Uuid,
512 record: &SenderKeyRecord,
513 ) -> Result<()> {
514 self.sender_key_store
515 .store_sender_key(sender, distribution_id, record)
516 .await
517 }
518
519 async fn load_sender_key(
520 &mut self,
521 sender: &ProtocolAddress,
522 distribution_id: Uuid,
523 ) -> Result<Option<SenderKeyRecord>> {
524 self.sender_key_store
525 .load_sender_key(sender, distribution_id)
526 .await
527 }
528}
529
530impl traits::ProtocolStore for InMemSignalProtocolStore {}