// zkgroup/common/array_utils.rs

use std::ops::Index;

use partial_default::PartialDefault;
use serde::{Deserialize, Serialize};

/// Abstraction over fixed-length, indexable collections whose elements are `T`.
///
/// Implemented for `[T; LEN]` and for [`OneBased`] wrappers around such types. On top of the
/// `Index<usize>` access required by the supertrait, it exposes the length as an associated
/// constant, a constructor that fills every slot, and a borrowing iterator.
pub trait ArrayLike<T>: Index<usize, Output = T> {
    /// Number of elements held by the collection.
    const LEN: usize;

    /// Builds a collection by calling `create_element` once for each of its `LEN` slots.
    fn create(create_element: impl FnMut() -> T) -> Self;

    /// Iterates over the elements by reference, in storage order.
    fn iter(&self) -> std::slice::Iter<'_, T>;
}

21impl<T, const LEN: usize> ArrayLike<T> for [T; LEN] {
22 const LEN: usize = LEN;
23 fn create(mut create_element: impl FnMut() -> T) -> Self {
24 [0; LEN].map(|_| create_element())
25 }
26 fn iter(&self) -> std::slice::Iter<T> {
27 self[..].iter()
28 }
29}
30
31#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Serialize, Deserialize)]
33pub struct OneBased<T>(pub T);
34
35impl<T> Index<usize> for OneBased<T>
36where
37 T: Index<usize>,
38{
39 type Output = T::Output;
40 fn index(&self, index: usize) -> &Self::Output {
41 assert!(index > 0, "one-based index cannot be zero");
42 &self.0[index - 1]
43 }
44}
45
46impl<T, Ts> ArrayLike<T> for OneBased<Ts>
47where
48 Ts: ArrayLike<T>,
49{
50 const LEN: usize = Ts::LEN;
51
52 fn create(create_element: impl FnMut() -> T) -> Self {
53 OneBased(Ts::create(create_element))
54 }
55
56 fn iter(&self) -> std::slice::Iter<T> {
57 self.0.iter()
58 }
59}
60
61pub(crate) fn collect_permutation<T: PartialDefault + Clone>(
62 iter: impl ExactSizeIterator<Item = (T, usize)>,
63) -> Vec<T> {
64 let mut result = vec![T::partial_default(); iter.len()];
65
66 for (value, position) in iter {
67 result[position] = value
68 }
69
70 result
71}
72
/// Unit tests for the array helpers in this module.
#[cfg(test)]
mod tests {
    use rand::Rng as _;

    use super::*;

    #[test]
    fn test_one_based_indexing() {
        let array = OneBased([10, 20, 30]);
        assert_eq!(10, array[1]);
        assert_eq!(20, array[2]);
        assert_eq!(30, array[3]);
    }

    #[test]
    #[should_panic]
    fn test_one_based_indexing_with_zero() {
        let array = OneBased([10, 20, 30]);
        let _ = array[0];
    }

    #[test]
    #[should_panic]
    fn test_one_based_indexing_past_end() {
        let array = OneBased([10, 20, 30]);
        let _ = array[4];
    }

    #[test]
    fn test_one_based_iter() {
        let array = OneBased([10, 20, 30]);
        assert_eq!(vec![10, 20, 30], array.iter().copied().collect::<Vec<_>>());
    }

    #[test]
    fn test_array_create_in_order() {
        // `create` must invoke the generator once per slot, front to back.
        let mut counter = 0;
        let array = <[i32; 3] as ArrayLike<i32>>::create(|| {
            counter += 1;
            counter
        });
        assert_eq!([1, 2, 3], array);
        assert_eq!(3, <[i32; 3] as ArrayLike<i32>>::LEN);
    }

    #[test]
    fn test_one_based_create() {
        // `OneBased` forwards construction and length to the wrapped array.
        let mut counter = 0;
        let array = <OneBased<[i32; 3]> as ArrayLike<i32>>::create(|| {
            counter += 10;
            counter
        });
        assert_eq!(OneBased([10, 20, 30]), array);
        assert_eq!(3, <OneBased<[i32; 3]> as ArrayLike<i32>>::LEN);
    }

    #[test]
    fn test_permute_empty() {
        let result: Vec<u32> = collect_permutation(std::iter::empty());
        assert!(result.is_empty());
    }

    #[test]
    fn test_permute_simple() {
        let elements = [5, 6, 7, 8];
        let permutation = [3, 2, 1, 0];
        let result = collect_permutation(elements.into_iter().zip(permutation));
        assert_eq!([8, 7, 6, 5].as_slice(), result.as_slice());
    }

    #[test]
    fn test_permute_scramble_and_unscramble() {
        for _ in 0..100 {
            let mut elements = [0u32; 512];
            rand::thread_rng().fill(&mut elements);

            // Sorting scrambles the elements while each keeps its original index, so
            // scattering by that index must restore the original order.
            let mut elements_with_indexes: Vec<_> = elements.into_iter().zip(0..).collect();
            elements_with_indexes.sort_unstable();

            let result = collect_permutation(elements_with_indexes.into_iter());
            assert_eq!(elements.as_slice(), result.as_slice());
        }
    }
}