1use std::fmt::Debug;
15
16use derive_where::derive_where;
17use partial_default::PartialDefault;
18use poksho::ShoApi;
19use rayon::iter::{IndexedParallelIterator as _, ParallelIterator as _};
20use serde::{Deserialize, Serialize};
21use zkcredential::attributes::Attribute as _;
22
23use crate::common::array_utils;
24use crate::common::serialization::ReservedByte;
25use crate::groups::{GroupSecretParams, UuidCiphertext};
26use crate::{
27 crypto, RandomnessBytes, ServerPublicParams, ServerSecretParams, Timestamp,
28 ZkGroupDeserializationFailure, ZkGroupVerificationFailure, SECONDS_PER_DAY,
29};
30
/// Number of seconds in one hour (60 minutes of 60 seconds each).
const SECONDS_PER_HOUR: u64 = 60 * 60;
32
/// A server-side endorsement key pair derived for one specific expiration timestamp.
///
/// Values are built by [`GroupSendDerivedKeyPair::for_expiration`] and are serializable via
/// serde. The fields are kept in this order because the derived (de)serializers follow it.
#[derive(Serialize, Deserialize, PartialDefault)]
pub struct GroupSendDerivedKeyPair {
    /// Leading reserved byte.
    // NOTE(review): presumably a format/version marker -- confirm against `ReservedByte`.
    reserved: ReservedByte,
    /// The key pair derived from the server's root endorsement key pair.
    key_pair: zkcredential::endorsements::ServerDerivedKeyPair,
    /// The expiration timestamp this key pair was derived for.
    expiration: Timestamp,
}
44
45impl GroupSendDerivedKeyPair {
46 fn tag_info(expiration: Timestamp) -> impl poksho::ShoApi + Clone {
49 let mut sho = poksho::ShoHmacSha256::new(b"20240215_Signal_GroupSendEndorsement");
50 sho.absorb_and_ratchet(&expiration.to_be_bytes());
51 sho
52 }
53
54 pub fn for_expiration(expiration: Timestamp, params: &ServerSecretParams) -> Self {
56 Self {
57 reserved: ReservedByte::default(),
58 key_pair: params
59 .endorsement_key_pair
60 .derive_key(Self::tag_info(expiration)),
61 expiration,
62 }
63 }
64}
65
/// The server's response carrying a batch of endorsements plus their expiration.
///
/// Created by [`GroupSendEndorsementsResponse::issue`] and consumed by the `receive_with_*`
/// methods. Serializable via serde; the field order is followed by the derived
/// (de)serializers.
#[derive(Serialize, Deserialize, PartialDefault, Debug)]
pub struct GroupSendEndorsementsResponse {
    /// Leading reserved byte.
    // NOTE(review): presumably a format/version marker -- confirm against `ReservedByte`.
    reserved: ReservedByte,
    /// The issued endorsements, as produced by `EndorsementResponse::issue`.
    endorsements: zkcredential::endorsements::EndorsementResponse,
    /// The expiration of the key pair the endorsements were issued with.
    expiration: Timestamp,
}
76
77impl GroupSendEndorsementsResponse {
78 pub fn default_expiration(current_time: Timestamp) -> Timestamp {
79 let current_time_in_seconds = current_time.epoch_seconds();
82 let start_of_day = current_time_in_seconds - (current_time_in_seconds % SECONDS_PER_DAY);
83 let mut expiration = start_of_day + 2 * SECONDS_PER_DAY;
84 if (expiration - current_time_in_seconds) < SECONDS_PER_DAY + SECONDS_PER_HOUR {
85 expiration += SECONDS_PER_DAY;
86 }
87 Timestamp::from_epoch_seconds(expiration)
88 }
89
90 fn sort_points(points: &mut [(usize, curve25519_dalek_signal::RistrettoPoint)]) {
97 debug_assert!(points.iter().enumerate().all(|(i, (j, _))| i == *j));
98 let sort_keys = curve25519_dalek_signal::RistrettoPoint::double_and_compress_batch(
99 points.iter().map(|(_i, point)| point),
100 );
101 points.sort_unstable_by_key(|(i, _point)| sort_keys[*i].as_bytes());
102 }
103
104 pub fn issue(
108 member_ciphertexts: impl IntoIterator<Item = UuidCiphertext>,
109 key_pair: &GroupSendDerivedKeyPair,
110 randomness: RandomnessBytes,
111 ) -> Self {
112 let mut points_to_sign: Vec<(usize, curve25519_dalek_signal::RistrettoPoint)> =
116 member_ciphertexts
117 .into_iter()
118 .map(|ciphertext| ciphertext.ciphertext.as_points()[0])
119 .enumerate()
120 .collect();
121 Self::sort_points(&mut points_to_sign);
122
123 let endorsements = zkcredential::endorsements::EndorsementResponse::issue(
124 points_to_sign.iter().map(|(_i, point)| *point),
125 &key_pair.key_pair,
126 randomness,
127 );
128
129 Self {
134 reserved: ReservedByte::default(),
135 endorsements,
136 expiration: key_pair.expiration,
137 }
138 }
139
140 pub fn expiration(&self) -> Timestamp {
142 self.expiration
143 }
144
145 fn derive_public_signing_key_from_expiration(
152 &self,
153 now: Timestamp,
154 server_params: &ServerPublicParams,
155 ) -> Result<zkcredential::endorsements::ServerDerivedPublicKey, ZkGroupVerificationFailure>
156 {
157 if !self.expiration.is_day_aligned() {
158 return Err(ZkGroupVerificationFailure);
161 }
162 let time_remaining_in_seconds = self.expiration.saturating_seconds_since(now);
163 if time_remaining_in_seconds < 2 * SECONDS_PER_HOUR {
164 return Err(ZkGroupVerificationFailure);
168 }
169 if time_remaining_in_seconds > 7 * SECONDS_PER_DAY {
170 return Err(ZkGroupVerificationFailure);
173 }
174
175 Ok(server_params
176 .endorsement_public_key
177 .derive_key(GroupSendDerivedKeyPair::tag_info(self.expiration)))
178 }
179
180 pub fn receive_with_service_ids_single_threaded(
186 self,
187 user_ids: impl IntoIterator<Item = libsignal_core::ServiceId>,
188 now: Timestamp,
189 group_params: &GroupSecretParams,
190 server_params: &ServerPublicParams,
191 ) -> Result<Vec<ReceivedEndorsement>, ZkGroupVerificationFailure> {
192 let derived_key = self.derive_public_signing_key_from_expiration(now, server_params)?;
193
194 let mut member_points: Vec<(usize, curve25519_dalek_signal::RistrettoPoint)> = user_ids
199 .into_iter()
200 .map(|user_id| {
201 group_params.uid_enc_key_pair.a1 * crypto::uid_struct::UidStruct::calc_M1(user_id)
202 })
203 .enumerate()
204 .collect();
205 Self::sort_points(&mut member_points);
206
207 let endorsements = self
208 .endorsements
209 .receive(member_points.iter().map(|(_i, point)| *point), &derived_key)
210 .map_err(|_| ZkGroupVerificationFailure)?;
211
212 Ok(array_utils::collect_permutation(
213 endorsements
214 .compressed
215 .into_iter()
216 .zip(endorsements.decompressed)
217 .map(|(compressed, decompressed)| ReceivedEndorsement {
218 compressed: GroupSendEndorsement {
219 reserved: ReservedByte::default(),
220 endorsement: compressed,
221 },
222 decompressed: GroupSendEndorsement {
223 reserved: ReservedByte::default(),
224 endorsement: decompressed,
225 },
226 })
227 .zip(member_points.iter().map(|(i, _)| *i)),
228 ))
229 }
230
231 pub fn receive_with_service_ids<T>(
239 self,
240 user_ids: T,
241 now: Timestamp,
242 group_params: &GroupSecretParams,
243 server_params: &ServerPublicParams,
244 ) -> Result<Vec<ReceivedEndorsement>, ZkGroupVerificationFailure>
245 where
246 T: rayon::iter::IntoParallelIterator<Item = libsignal_core::ServiceId>,
247 T::Iter: rayon::iter::IndexedParallelIterator,
248 {
249 let derived_key = self.derive_public_signing_key_from_expiration(now, server_params)?;
250
251 let mut member_points: Vec<(usize, curve25519_dalek_signal::RistrettoPoint)> = user_ids
256 .into_par_iter()
257 .map(|user_id| {
258 group_params.uid_enc_key_pair.a1 * crypto::uid_struct::UidStruct::calc_M1(user_id)
259 })
260 .enumerate()
261 .collect();
262 Self::sort_points(&mut member_points);
263
264 let endorsements = self
265 .endorsements
266 .receive(member_points.iter().map(|(_i, point)| *point), &derived_key)
267 .map_err(|_| ZkGroupVerificationFailure)?;
268
269 Ok(array_utils::collect_permutation(
270 endorsements
271 .compressed
272 .into_iter()
273 .zip(endorsements.decompressed)
274 .map(|(compressed, decompressed)| ReceivedEndorsement {
275 compressed: GroupSendEndorsement {
276 reserved: ReservedByte::default(),
277 endorsement: compressed,
278 },
279 decompressed: GroupSendEndorsement {
280 reserved: ReservedByte::default(),
281 endorsement: decompressed,
282 },
283 })
284 .zip(member_points.iter().map(|(i, _)| *i)),
285 ))
286 }
287
288 pub fn receive_with_ciphertexts(
297 self,
298 member_ciphertexts: impl IntoIterator<Item = UuidCiphertext>,
299 now: Timestamp,
300 server_params: &ServerPublicParams,
301 ) -> Result<Vec<ReceivedEndorsement>, ZkGroupVerificationFailure> {
302 let derived_key = self.derive_public_signing_key_from_expiration(now, server_params)?;
303
304 let mut points_to_check: Vec<_> = member_ciphertexts
308 .into_iter()
309 .map(|ciphertext| ciphertext.ciphertext.as_points()[0])
310 .enumerate()
311 .collect();
312 Self::sort_points(&mut points_to_check);
313
314 let endorsements = self
315 .endorsements
316 .receive(
317 points_to_check.iter().map(|(_i, point)| *point),
318 &derived_key,
319 )
320 .map_err(|_| ZkGroupVerificationFailure)?;
321
322 Ok(array_utils::collect_permutation(
323 endorsements
324 .compressed
325 .into_iter()
326 .zip(endorsements.decompressed)
327 .map(|(compressed, decompressed)| ReceivedEndorsement {
328 compressed: GroupSendEndorsement {
329 reserved: ReservedByte::default(),
330 endorsement: compressed,
331 },
332 decompressed: GroupSendEndorsement {
333 reserved: ReservedByte::default(),
334 endorsement: decompressed,
335 },
336 })
337 .zip(points_to_check.iter().map(|(i, _)| *i)),
338 ))
339 }
340}
341
/// A single endorsement, generic over how its underlying point is stored.
///
/// `Storage` defaults to `RistrettoPoint`; this file also uses the `CompressedRistretto`
/// form. `PartialEq` is only provided when `Storage: subtle::ConstantTimeEq`, and
/// `PartialDefault` requires `Storage: Identity`. The field order is followed by the derived
/// serde (de)serializers.
#[derive(Serialize, Deserialize, PartialDefault, Clone, Copy)]
#[partial_default(bound = "Storage: curve25519_dalek_signal::traits::Identity")]
#[derive_where(PartialEq; Storage: subtle::ConstantTimeEq)]
pub struct GroupSendEndorsement<Storage = curve25519_dalek_signal::RistrettoPoint> {
    /// Leading reserved byte; `combine` and `remove` treat it as a version that must match.
    reserved: ReservedByte,
    /// The wrapped zkcredential endorsement.
    endorsement: zkcredential::endorsements::Endorsement<Storage>,
}
354
355impl Debug for GroupSendEndorsement<curve25519_dalek_signal::RistrettoPoint> {
356 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
357 f.debug_struct("GroupSendEndorsement")
358 .field("reserved", &self.reserved)
359 .field("endorsement", &self.endorsement)
360 .finish()
361 }
362}
363
364impl Debug for GroupSendEndorsement<curve25519_dalek_signal::ristretto::CompressedRistretto> {
365 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366 f.debug_struct("GroupSendEndorsement")
367 .field("reserved", &self.reserved)
368 .field("endorsement", &self.endorsement)
369 .finish()
370 }
371}
372
/// One endorsement as returned by the `receive_with_*` methods of
/// [`GroupSendEndorsementsResponse`], held in two storage forms.
///
/// Both fields are built from the paired `compressed` / `decompressed` outputs of
/// `EndorsementResponse::receive`.
#[allow(missing_docs)]
#[derive(Clone, Copy, PartialDefault)]
pub struct ReceivedEndorsement {
    /// The endorsement stored as a `CompressedRistretto`.
    pub compressed: GroupSendEndorsement<curve25519_dalek_signal::ristretto::CompressedRistretto>,
    /// The corresponding endorsement stored as a `RistrettoPoint` (the default storage).
    pub decompressed: GroupSendEndorsement,
}
395
396impl GroupSendEndorsement<curve25519_dalek_signal::ristretto::CompressedRistretto> {
397 pub fn decompress(
405 self,
406 ) -> Result<
407 GroupSendEndorsement<curve25519_dalek_signal::RistrettoPoint>,
408 ZkGroupDeserializationFailure,
409 > {
410 Ok(GroupSendEndorsement {
411 reserved: self.reserved,
412 endorsement: self
413 .endorsement
414 .decompress()
415 .map_err(|_| ZkGroupDeserializationFailure::new::<Self>())?,
416 })
417 }
418}
419
420impl GroupSendEndorsement<curve25519_dalek_signal::RistrettoPoint> {
421 pub fn compress(
426 self,
427 ) -> GroupSendEndorsement<curve25519_dalek_signal::ristretto::CompressedRistretto> {
428 GroupSendEndorsement {
429 reserved: self.reserved,
430 endorsement: self.endorsement.compress(),
431 }
432 }
433}
434
435impl GroupSendEndorsement {
436 pub fn combine(
443 endorsements: impl IntoIterator<Item = GroupSendEndorsement>,
444 ) -> GroupSendEndorsement {
445 let mut endorsements = endorsements.into_iter();
446 let Some(mut result) = endorsements.next() else {
447 return GroupSendEndorsement {
451 reserved: ReservedByte::default(),
452 endorsement: Default::default(),
453 };
454 };
455 for next in endorsements {
456 assert_eq!(
457 result.reserved, next.reserved,
458 "endorsements must all have the same version"
459 );
460 result.endorsement = result.endorsement.combine_with(&next.endorsement);
461 }
462 result
463 }
464
465 pub fn remove(&self, unwanted_endorsements: &GroupSendEndorsement) -> GroupSendEndorsement {
473 assert_eq!(
474 self.reserved, unwanted_endorsements.reserved,
475 "endorsements must have the same version"
476 );
477 GroupSendEndorsement {
478 reserved: self.reserved,
479 endorsement: self.endorsement.remove(&unwanted_endorsements.endorsement),
480 }
481 }
482
483 pub fn to_token(&self, group_params: &GroupSecretParams) -> GroupSendToken {
488 let client_key =
489 zkcredential::endorsements::ClientDecryptionKey::for_first_point_of_attribute(
490 &group_params.uid_enc_key_pair,
491 );
492 let raw_token = self.endorsement.to_token(&client_key);
493 GroupSendToken {
494 reserved: ReservedByte::default(),
495 raw_token,
496 }
497 }
498}
499
/// A token produced from an endorsement by `GroupSendEndorsement::to_token`.
///
/// It carries no expiration of its own; [`GroupSendToken::into_full_token`] attaches one.
/// Serializable via serde; the field order is followed by the derived (de)serializers.
#[derive(Serialize, Deserialize, PartialDefault)]
pub struct GroupSendToken {
    /// Leading reserved byte.
    // NOTE(review): presumably a format/version marker -- confirm against `ReservedByte`.
    reserved: ReservedByte,
    /// The raw token bytes returned by the inner endorsement's `to_token`.
    raw_token: Box<[u8]>,
}
509
510impl Debug for GroupSendToken {
511 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
512 f.debug_struct("GroupSendToken")
513 .field("reserved", &self.reserved)
514 .field("raw_token", &zkcredential::PrintAsHex(&*self.raw_token))
515 .finish()
516 }
517}
518
519impl GroupSendToken {
520 pub fn into_full_token(self, expiration: Timestamp) -> GroupSendFullToken {
524 GroupSendFullToken {
525 reserved: self.reserved,
526 raw_token: self.raw_token,
527 expiration,
528 }
529 }
530}
531
/// A [`GroupSendToken`] together with its expiration, as built by
/// `GroupSendToken::into_full_token` and checked by [`GroupSendFullToken::verify`].
///
/// Serializable via serde; the field order is followed by the derived (de)serializers.
#[derive(Serialize, Deserialize, PartialDefault)]
pub struct GroupSendFullToken {
    /// Leading reserved byte, copied from the originating token.
    reserved: ReservedByte,
    /// The raw token bytes, copied from the originating token.
    raw_token: Box<[u8]>,
    /// The expiration attached to the token.
    expiration: Timestamp,
}
541
542impl Debug for GroupSendFullToken {
543 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
544 f.debug_struct("GroupSendFullToken")
545 .field("reserved", &self.reserved)
546 .field("raw_token", &zkcredential::PrintAsHex(&*self.raw_token))
547 .field("expiration", &self.expiration)
548 .finish()
549 }
550}
551
552impl GroupSendFullToken {
553 pub fn expiration(&self) -> Timestamp {
554 self.expiration
555 }
556
557 pub fn verify(
560 &self,
561 user_ids: impl IntoIterator<Item = libsignal_core::ServiceId>,
562 now: Timestamp,
563 key_pair: &GroupSendDerivedKeyPair,
564 ) -> Result<(), ZkGroupVerificationFailure> {
565 if now > self.expiration {
566 return Err(ZkGroupVerificationFailure);
567 }
568 assert_eq!(
569 self.expiration, key_pair.expiration,
570 "wrong key pair used for this token"
571 );
572
573 let user_id_sum: curve25519_dalek_signal::RistrettoPoint = user_ids
574 .into_iter()
575 .map(crypto::uid_struct::UidStruct::calc_M1)
576 .sum();
577
578 key_pair
579 .key_pair
580 .verify(&user_id_sum, &self.raw_token)
581 .map_err(|_| ZkGroupVerificationFailure)
582 }
583}