1use std::fmt::Debug;
15
16use derive_where::derive_where;
17use partial_default::PartialDefault;
18use poksho::ShoApi;
19use rayon::iter::{IndexedParallelIterator as _, ParallelIterator as _};
20use serde::{Deserialize, Serialize};
21use zkcredential::attributes::Attribute as _;
22
23use crate::common::array_utils;
24use crate::common::serialization::ReservedByte;
25use crate::crypto::uid_encryption;
26use crate::groups::{GroupSecretParams, UuidCiphertext};
27use crate::{
28 crypto, RandomnessBytes, Timestamp, ZkGroupDeserializationFailure, ZkGroupVerificationFailure,
29 SECONDS_PER_DAY,
30};
31
/// Number of seconds in one hour.
const SECONDS_PER_HOUR: u64 = 3_600;
33
/// A server-side key pair derived from the root key pair for one specific expiration.
///
/// Built by [`GroupSendDerivedKeyPair::for_expiration`]. It is passed to
/// [`GroupSendEndorsementsResponse::issue`] and to [`GroupSendFullToken::verify`]; the latter
/// asserts that the token's expiration matches this key pair's.
///
/// Field order determines the serialized layout (serde's derived impls use declaration
/// order), so do not reorder fields.
#[derive(Serialize, Deserialize, PartialDefault)]
pub struct GroupSendDerivedKeyPair {
    // Set to `ReservedByte::default()` by `for_expiration`.
    reserved: ReservedByte,
    // Derived from the root key pair using `tag_info(expiration)`.
    key_pair: zkcredential::endorsements::ServerDerivedKeyPair,
    // The expiration this key pair was derived for; copied into issued responses.
    expiration: Timestamp,
}
45
46impl GroupSendDerivedKeyPair {
47 fn tag_info(expiration: Timestamp) -> impl poksho::ShoApi + Clone {
50 let mut sho = poksho::ShoHmacSha256::new(b"20240215_Signal_GroupSendEndorsement");
51 sho.absorb_and_ratchet(&expiration.to_be_bytes());
52 sho
53 }
54
55 pub fn for_expiration(
57 expiration: Timestamp,
58 root: impl AsRef<zkcredential::endorsements::ServerRootKeyPair>,
59 ) -> Self {
60 Self {
61 reserved: ReservedByte::default(),
62 key_pair: root.as_ref().derive_key(Self::tag_info(expiration)),
63 expiration,
64 }
65 }
66}
67
/// The server's response carrying endorsements for a set of group members, plus the
/// expiration they were issued for.
///
/// Created by [`GroupSendEndorsementsResponse::issue`] and consumed by the `receive_*`
/// methods on the client side.
///
/// Field order determines the serialized layout (serde's derived impls use declaration
/// order), so do not reorder fields.
#[derive(Serialize, Deserialize, PartialDefault, Debug)]
pub struct GroupSendEndorsementsResponse {
    // Set to `ReservedByte::default()` by `issue`.
    reserved: ReservedByte,
    endorsements: zkcredential::endorsements::EndorsementResponse,
    // Copied from the `GroupSendDerivedKeyPair` used to issue the response.
    expiration: Timestamp,
}
78
79impl GroupSendEndorsementsResponse {
80 pub fn default_expiration(current_time: Timestamp) -> Timestamp {
81 let current_time_in_seconds = current_time.epoch_seconds();
84 let start_of_day = current_time_in_seconds - (current_time_in_seconds % SECONDS_PER_DAY);
85 let mut expiration = start_of_day + 2 * SECONDS_PER_DAY;
86 if (expiration - current_time_in_seconds) < SECONDS_PER_DAY + SECONDS_PER_HOUR {
87 expiration += SECONDS_PER_DAY;
88 }
89 Timestamp::from_epoch_seconds(expiration)
90 }
91
92 fn sort_points(points: &mut [(usize, curve25519_dalek_signal::RistrettoPoint)]) {
99 debug_assert!(points.iter().enumerate().all(|(i, (j, _))| i == *j));
100 let sort_keys = curve25519_dalek_signal::RistrettoPoint::double_and_compress_batch(
101 points.iter().map(|(_i, point)| point),
102 );
103 points.sort_unstable_by_key(|(i, _point)| sort_keys[*i].as_bytes());
104 }
105
106 pub fn issue(
110 member_ciphertexts: impl IntoIterator<Item = UuidCiphertext>,
111 key_pair: &GroupSendDerivedKeyPair,
112 randomness: RandomnessBytes,
113 ) -> Self {
114 let mut points_to_sign: Vec<(usize, curve25519_dalek_signal::RistrettoPoint)> =
118 member_ciphertexts
119 .into_iter()
120 .map(|ciphertext| ciphertext.ciphertext.as_points()[0])
121 .enumerate()
122 .collect();
123 Self::sort_points(&mut points_to_sign);
124
125 let endorsements = zkcredential::endorsements::EndorsementResponse::issue(
126 points_to_sign.iter().map(|(_i, point)| *point),
127 &key_pair.key_pair,
128 randomness,
129 );
130
131 Self {
136 reserved: ReservedByte::default(),
137 endorsements,
138 expiration: key_pair.expiration,
139 }
140 }
141
142 pub fn expiration(&self) -> Timestamp {
144 self.expiration
145 }
146
147 fn derive_public_signing_key_from_expiration(
154 &self,
155 now: Timestamp,
156 root_public_key: impl AsRef<zkcredential::endorsements::ServerRootPublicKey>,
157 ) -> Result<zkcredential::endorsements::ServerDerivedPublicKey, ZkGroupVerificationFailure>
158 {
159 if !self.expiration.is_day_aligned() {
160 return Err(ZkGroupVerificationFailure);
163 }
164 let time_remaining_in_seconds = self.expiration.saturating_seconds_since(now);
165 if time_remaining_in_seconds < 2 * SECONDS_PER_HOUR {
166 return Err(ZkGroupVerificationFailure);
170 }
171 if time_remaining_in_seconds > 7 * SECONDS_PER_DAY {
172 return Err(ZkGroupVerificationFailure);
175 }
176
177 Ok(root_public_key
178 .as_ref()
179 .derive_key(GroupSendDerivedKeyPair::tag_info(self.expiration)))
180 }
181
182 pub fn receive_with_service_ids_single_threaded(
188 self,
189 user_ids: impl IntoIterator<Item = libsignal_core::ServiceId>,
190 now: Timestamp,
191 group_params: &GroupSecretParams,
192 root_public_key: impl AsRef<zkcredential::endorsements::ServerRootPublicKey>,
193 ) -> Result<Vec<ReceivedEndorsement>, ZkGroupVerificationFailure> {
194 let derived_key = self.derive_public_signing_key_from_expiration(now, root_public_key)?;
195
196 let mut member_points: Vec<(usize, curve25519_dalek_signal::RistrettoPoint)> = user_ids
201 .into_iter()
202 .map(|user_id| {
203 group_params.uid_enc_key_pair.a1 * crypto::uid_struct::UidStruct::calc_M1(user_id)
204 })
205 .enumerate()
206 .collect();
207 Self::sort_points(&mut member_points);
208
209 let endorsements = self
210 .endorsements
211 .receive(member_points.iter().map(|(_i, point)| *point), &derived_key)
212 .map_err(|_| ZkGroupVerificationFailure)?;
213
214 Ok(array_utils::collect_permutation(
215 endorsements
216 .compressed
217 .into_iter()
218 .zip(endorsements.decompressed)
219 .map(|(compressed, decompressed)| ReceivedEndorsement {
220 compressed: GroupSendEndorsement {
221 reserved: ReservedByte::default(),
222 endorsement: compressed,
223 },
224 decompressed: GroupSendEndorsement {
225 reserved: ReservedByte::default(),
226 endorsement: decompressed,
227 },
228 })
229 .zip(member_points.iter().map(|(i, _)| *i)),
230 ))
231 }
232
233 pub fn receive_with_service_ids<T>(
241 self,
242 user_ids: T,
243 now: Timestamp,
244 group_params: &GroupSecretParams,
245 root_public_key: impl AsRef<zkcredential::endorsements::ServerRootPublicKey>,
246 ) -> Result<Vec<ReceivedEndorsement>, ZkGroupVerificationFailure>
247 where
248 T: rayon::iter::IntoParallelIterator<
249 Item = libsignal_core::ServiceId,
250 Iter: rayon::iter::IndexedParallelIterator,
251 >,
252 {
253 let derived_key = self.derive_public_signing_key_from_expiration(now, root_public_key)?;
254
255 let mut member_points: Vec<(usize, curve25519_dalek_signal::RistrettoPoint)> = user_ids
260 .into_par_iter()
261 .map(|user_id| {
262 group_params.uid_enc_key_pair.a1 * crypto::uid_struct::UidStruct::calc_M1(user_id)
263 })
264 .enumerate()
265 .collect();
266 Self::sort_points(&mut member_points);
267
268 let endorsements = self
269 .endorsements
270 .receive(member_points.iter().map(|(_i, point)| *point), &derived_key)
271 .map_err(|_| ZkGroupVerificationFailure)?;
272
273 Ok(array_utils::collect_permutation(
274 endorsements
275 .compressed
276 .into_iter()
277 .zip(endorsements.decompressed)
278 .map(|(compressed, decompressed)| ReceivedEndorsement {
279 compressed: GroupSendEndorsement {
280 reserved: ReservedByte::default(),
281 endorsement: compressed,
282 },
283 decompressed: GroupSendEndorsement {
284 reserved: ReservedByte::default(),
285 endorsement: decompressed,
286 },
287 })
288 .zip(member_points.iter().map(|(i, _)| *i)),
289 ))
290 }
291
292 pub fn receive_with_ciphertexts(
301 self,
302 member_ciphertexts: impl IntoIterator<Item = UuidCiphertext>,
303 now: Timestamp,
304 root_public_key: impl AsRef<zkcredential::endorsements::ServerRootPublicKey>,
305 ) -> Result<Vec<ReceivedEndorsement>, ZkGroupVerificationFailure> {
306 let derived_key = self.derive_public_signing_key_from_expiration(now, root_public_key)?;
307
308 let mut points_to_check: Vec<_> = member_ciphertexts
312 .into_iter()
313 .map(|ciphertext| ciphertext.ciphertext.as_points()[0])
314 .enumerate()
315 .collect();
316 Self::sort_points(&mut points_to_check);
317
318 let endorsements = self
319 .endorsements
320 .receive(
321 points_to_check.iter().map(|(_i, point)| *point),
322 &derived_key,
323 )
324 .map_err(|_| ZkGroupVerificationFailure)?;
325
326 Ok(array_utils::collect_permutation(
327 endorsements
328 .compressed
329 .into_iter()
330 .zip(endorsements.decompressed)
331 .map(|(compressed, decompressed)| ReceivedEndorsement {
332 compressed: GroupSendEndorsement {
333 reserved: ReservedByte::default(),
334 endorsement: compressed,
335 },
336 decompressed: GroupSendEndorsement {
337 reserved: ReservedByte::default(),
338 endorsement: decompressed,
339 },
340 })
341 .zip(points_to_check.iter().map(|(i, _)| *i)),
342 ))
343 }
344}
345
/// An endorsement for one member, or for a set of members after
/// [`GroupSendEndorsement::combine`].
///
/// The endorsement is stored either decompressed (`RistrettoPoint`, the default `Storage`) or
/// compressed (`CompressedRistretto`).
///
/// `PartialEq` is derived via `derive_where` with a `Storage: subtle::ConstantTimeEq` bound.
/// Presumably this makes comparisons constant-time; confirm against `Endorsement`'s impl.
///
/// Field order determines the serialized layout (serde's derived impls use declaration
/// order), so do not reorder fields.
#[derive(Serialize, Deserialize, PartialDefault, Clone, Copy)]
#[partial_default(bound = "Storage: curve25519_dalek_signal::traits::Identity")]
#[derive_where(PartialEq; Storage: subtle::ConstantTimeEq)]
pub struct GroupSendEndorsement<Storage = curve25519_dalek_signal::RistrettoPoint> {
    // Must match across endorsements passed to `combine` / `remove` (asserted there).
    reserved: ReservedByte,
    endorsement: zkcredential::endorsements::Endorsement<Storage>,
}
358
359impl Debug for GroupSendEndorsement<curve25519_dalek_signal::RistrettoPoint> {
360 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
361 f.debug_struct("GroupSendEndorsement")
362 .field("reserved", &self.reserved)
363 .field("endorsement", &self.endorsement)
364 .finish()
365 }
366}
367
368impl Debug for GroupSendEndorsement<curve25519_dalek_signal::ristretto::CompressedRistretto> {
369 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
370 f.debug_struct("GroupSendEndorsement")
371 .field("reserved", &self.reserved)
372 .field("endorsement", &self.endorsement)
373 .finish()
374 }
375}
376
/// One member's endorsement as returned by the `receive_*` methods of
/// [`GroupSendEndorsementsResponse`], provided in both representations.
#[allow(missing_docs)]
#[derive(Clone, Copy, PartialDefault)]
pub struct ReceivedEndorsement {
    /// Compressed form, built from the same verified result as `decompressed`.
    pub compressed: GroupSendEndorsement<curve25519_dalek_signal::ristretto::CompressedRistretto>,
    /// Decompressed form, usable with [`GroupSendEndorsement::combine`],
    /// [`GroupSendEndorsement::remove`], and [`GroupSendEndorsement::to_token`].
    pub decompressed: GroupSendEndorsement,
}
399
400impl GroupSendEndorsement<curve25519_dalek_signal::ristretto::CompressedRistretto> {
401 pub fn decompress(
409 self,
410 ) -> Result<
411 GroupSendEndorsement<curve25519_dalek_signal::RistrettoPoint>,
412 ZkGroupDeserializationFailure,
413 > {
414 Ok(GroupSendEndorsement {
415 reserved: self.reserved,
416 endorsement: self
417 .endorsement
418 .decompress()
419 .map_err(|_| ZkGroupDeserializationFailure::new::<Self>())?,
420 })
421 }
422}
423
424impl GroupSendEndorsement<curve25519_dalek_signal::RistrettoPoint> {
425 pub fn compress(
430 self,
431 ) -> GroupSendEndorsement<curve25519_dalek_signal::ristretto::CompressedRistretto> {
432 GroupSendEndorsement {
433 reserved: self.reserved,
434 endorsement: self.endorsement.compress(),
435 }
436 }
437}
438
439impl GroupSendEndorsement {
440 pub fn combine(
447 endorsements: impl IntoIterator<Item = GroupSendEndorsement>,
448 ) -> GroupSendEndorsement {
449 let mut endorsements = endorsements.into_iter();
450 let Some(mut result) = endorsements.next() else {
451 return GroupSendEndorsement {
455 reserved: ReservedByte::default(),
456 endorsement: Default::default(),
457 };
458 };
459 for next in endorsements {
460 assert_eq!(
461 result.reserved, next.reserved,
462 "endorsements must all have the same version"
463 );
464 result.endorsement = result.endorsement.combine_with(&next.endorsement);
465 }
466 result
467 }
468
469 pub fn remove(&self, unwanted_endorsements: &GroupSendEndorsement) -> GroupSendEndorsement {
477 assert_eq!(
478 self.reserved, unwanted_endorsements.reserved,
479 "endorsements must have the same version"
480 );
481 GroupSendEndorsement {
482 reserved: self.reserved,
483 endorsement: self.endorsement.remove(&unwanted_endorsements.endorsement),
484 }
485 }
486
487 pub fn to_token<T: AsRef<uid_encryption::KeyPair>>(&self, key_pair: T) -> GroupSendToken {
492 let client_key =
493 zkcredential::endorsements::ClientDecryptionKey::for_first_point_of_attribute(
494 key_pair.as_ref(),
495 );
496 let raw_token = self.endorsement.to_token(&client_key);
497 GroupSendToken {
498 reserved: ReservedByte::default(),
499 raw_token,
500 }
501 }
502}
503
/// A token produced by [`GroupSendEndorsement::to_token`], without an expiration attached.
///
/// Use [`GroupSendToken::into_full_token`] to attach one.
///
/// Field order determines the serialized layout (serde's derived impls use declaration
/// order), so do not reorder fields.
#[derive(Serialize, Deserialize, PartialDefault)]
pub struct GroupSendToken {
    reserved: ReservedByte,
    // Opaque token bytes; `Debug` prints them as hex.
    raw_token: Box<[u8]>,
}
513
514impl Debug for GroupSendToken {
515 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
516 f.debug_struct("GroupSendToken")
517 .field("reserved", &self.reserved)
518 .field("raw_token", &zkcredential::PrintAsHex(&*self.raw_token))
519 .finish()
520 }
521}
522
523impl GroupSendToken {
524 pub fn into_full_token(self, expiration: Timestamp) -> GroupSendFullToken {
528 GroupSendFullToken {
529 reserved: self.reserved,
530 raw_token: self.raw_token,
531 expiration,
532 }
533 }
534}
535
/// A [`GroupSendToken`] together with its expiration, checked by
/// [`GroupSendFullToken::verify`].
///
/// Field order determines the serialized layout (serde's derived impls use declaration
/// order), so do not reorder fields.
#[derive(Serialize, Deserialize, PartialDefault)]
pub struct GroupSendFullToken {
    reserved: ReservedByte,
    // Opaque token bytes; `Debug` prints them as hex.
    raw_token: Box<[u8]>,
    // `verify` rejects the token once `now` is later than this.
    expiration: Timestamp,
}
545
546impl Debug for GroupSendFullToken {
547 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
548 f.debug_struct("GroupSendFullToken")
549 .field("reserved", &self.reserved)
550 .field("raw_token", &zkcredential::PrintAsHex(&*self.raw_token))
551 .field("expiration", &self.expiration)
552 .finish()
553 }
554}
555
556impl GroupSendFullToken {
557 pub fn expiration(&self) -> Timestamp {
558 self.expiration
559 }
560
561 pub fn verify(
564 &self,
565 user_ids: impl IntoIterator<Item = libsignal_core::ServiceId>,
566 now: Timestamp,
567 key_pair: &GroupSendDerivedKeyPair,
568 ) -> Result<(), ZkGroupVerificationFailure> {
569 if now > self.expiration {
570 return Err(ZkGroupVerificationFailure);
571 }
572 assert_eq!(
573 self.expiration, key_pair.expiration,
574 "wrong key pair used for this token"
575 );
576
577 let user_id_sum: curve25519_dalek_signal::RistrettoPoint = user_ids
578 .into_iter()
579 .map(crypto::uid_struct::UidStruct::calc_M1)
580 .sum();
581
582 key_pair
583 .key_pair
584 .verify(&user_id_sum, &self.raw_token)
585 .map_err(|_| ZkGroupVerificationFailure)
586 }
587}