zkgroup/common/array_utils.rs

//
// Copyright 2021 Signal Messenger, LLC.
// SPDX-License-Identifier: AGPL-3.0-only
//
5
6use std::ops::Index;
7
8use partial_default::PartialDefault;
9use serde::{Deserialize, Serialize};
10
/// Abstracts over fixed-length arrays (and similar types) with an element type `T`.
///
/// Provides `iter` and `Index` rather than `Deref` or `AsRef<[T]>` to allow for alternate forms of
/// indexing, for which exposing a slice could be confusing. See [`OneBased`].
pub trait ArrayLike<T>: Index<usize, Output = T> {
    /// The number of elements held by every value of the implementing type.
    const LEN: usize;

    /// Builds a new value, obtaining each element by calling `create_element`.
    fn create(create_element: impl FnMut() -> T) -> Self;

    /// Iterates over the elements.
    ///
    /// This is a plain slice iterator, so it yields elements in their underlying storage order
    /// regardless of how `Index` numbers them (e.g. for [`OneBased`], the first item yielded is
    /// the one at index 1).
    fn iter(&self) -> std::slice::Iter<'_, T>;
}

impl<T, const LEN: usize> ArrayLike<T> for [T; LEN] {
    const LEN: usize = LEN;

    /// Calls `create_element` exactly `LEN` times, filling indices `0..LEN` in order.
    fn create(mut create_element: impl FnMut() -> T) -> Self {
        // `from_fn` invokes its closure for indices 0, 1, ..., LEN - 1 in order, without first
        // materializing a throwaway placeholder array.
        std::array::from_fn(|_index| create_element())
    }

    fn iter(&self) -> std::slice::Iter<'_, T> {
        // Go through the slice explicitly: `self.iter()` would resolve to this very trait method
        // (the receiver `&[T; LEN]` matches before unsizing to `[T]`) and recurse forever.
        self.as_slice().iter()
    }
}
30
31/// A wrapper around an array or slice to use one-based indexing.
32#[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Serialize, Deserialize)]
33pub struct OneBased<T>(pub T);
34
35impl<T> Index<usize> for OneBased<T>
36where
37    T: Index<usize>,
38{
39    type Output = T::Output;
40    fn index(&self, index: usize) -> &Self::Output {
41        assert!(index > 0, "one-based index cannot be zero");
42        &self.0[index - 1]
43    }
44}
45
46impl<T, Ts> ArrayLike<T> for OneBased<Ts>
47where
48    Ts: ArrayLike<T>,
49{
50    const LEN: usize = Ts::LEN;
51
52    fn create(create_element: impl FnMut() -> T) -> Self {
53        OneBased(Ts::create(create_element))
54    }
55
56    fn iter(&self) -> std::slice::Iter<T> {
57        self.0.iter()
58    }
59}
60
61pub(crate) fn collect_permutation<T: PartialDefault + Clone>(
62    iter: impl ExactSizeIterator<Item = (T, usize)>,
63) -> Vec<T> {
64    let mut result = vec![T::partial_default(); iter.len()];
65
66    for (value, position) in iter {
67        result[position] = value
68    }
69
70    result
71}
72
#[cfg(test)]
mod tests {
    use rand::Rng as _;

    use super::*;

    #[test]
    fn test_one_based_indexing() {
        let array = OneBased([10, 20, 30]);
        assert_eq!(10, array[1]);
        assert_eq!(20, array[2]);
        assert_eq!(30, array[3]);
    }

    // Pin the message so the test cannot pass on an unrelated panic (e.g. `index - 1` underflow).
    #[test]
    #[should_panic(expected = "one-based index cannot be zero")]
    fn test_one_based_indexing_with_zero() {
        let array = OneBased([10, 20, 30]);
        let _ = array[0];
    }

    #[test]
    #[should_panic]
    fn test_one_based_indexing_past_end() {
        let array = OneBased([10, 20, 30]);
        let _ = array[4];
    }

    #[test]
    fn test_one_based_iter() {
        let array = OneBased([10, 20, 30]);
        assert_eq!(vec![10, 20, 30], array.iter().copied().collect::<Vec<_>>());
    }

    #[test]
    fn test_array_create() {
        let mut counter = 0;
        let array = <[i32; 3] as ArrayLike<i32>>::create(|| {
            counter += 1;
            counter
        });
        assert_eq!([1, 2, 3], array);
        assert_eq!(3, <[i32; 3] as ArrayLike<i32>>::LEN);
    }

    #[test]
    fn test_one_based_create() {
        let mut counter = 0;
        let array = <OneBased<[i32; 3]> as ArrayLike<i32>>::create(|| {
            counter += 1;
            counter
        });
        assert_eq!(OneBased([1, 2, 3]), array);
        assert_eq!(1, array[1]);
        assert_eq!(3, <OneBased<[i32; 3]> as ArrayLike<i32>>::LEN);
    }

    #[test]
    fn test_permute_simple() {
        let elements = [5, 6, 7, 8];
        let permutation = [3, 2, 1, 0];
        let result = collect_permutation(elements.into_iter().zip(permutation));
        assert_eq!([8, 7, 6, 5].as_slice(), result.as_slice());
    }

    #[test]
    fn test_permute_scramble_and_unscramble() {
        for _ in 0..100 {
            let mut elements = [0u32; 512];
            rand::thread_rng().fill(&mut elements);

            let mut elements_with_indexes: Vec<_> = elements.into_iter().zip(0..).collect();
            elements_with_indexes.sort_unstable();

            let result = collect_permutation(elements_with_indexes.into_iter());
            assert_eq!(elements.as_slice(), result.as_slice());
        }
    }
}
128}