zkgroup/common/
serialization.rs

1//
2// Copyright 2023 Signal Messenger, LLC.
3// SPDX-License-Identifier: AGPL-3.0-only
4//
5
6use bincode::Options;
7use partial_default::PartialDefault;
8use serde::{Deserialize, Serialize};
9
10use crate::ZkGroupDeserializationFailure;
11
12fn zkgroup_bincode_options() -> impl bincode::Options {
13    bincode::DefaultOptions::new()
14        .with_fixint_encoding()
15        .with_little_endian()
16        .reject_trailing_bytes()
17}
18
19/// Deserializes a type using the standard zkgroup encoding (based on bincode).
20///
21/// The type must support [`PartialDefault`] to save on code size.
22pub fn deserialize<'a, T: Deserialize<'a> + PartialDefault>(
23    bytes: &'a [u8],
24) -> Result<T, ZkGroupDeserializationFailure> {
25    let mut result = T::partial_default();
26    // Use the same encoding options as plain bincode::deserialize, which we used historically,
27    // but also reject trailing bytes.
28    // See https://docs.rs/bincode/1.3.3/bincode/config/index.html#options-struct-vs-bincode-functions.
29    T::deserialize_in_place(
30        &mut bincode::Deserializer::from_slice(bytes, zkgroup_bincode_options()),
31        &mut result,
32    )
33    .map_err(|_| ZkGroupDeserializationFailure::new::<T>())?;
34    Ok(result)
35}
36
37/// Serializes a type using the standard zkgroup encoding (based on bincode).
38pub fn serialize<T: Serialize>(value: &T) -> Vec<u8> {
39    zkgroup_bincode_options()
40        .serialize(value)
41        .expect("cannot fail")
42}
43
/// The constant version number `C`, lifted to the type level.
///
/// Zero-sized type that converts to and from the value `C` via `Into`,
/// `TryFrom`, [`Serialize`], and [`Deserialize`]. Used for providing a version
/// tag at the beginning of serialized structs.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct VersionByte<const C: u8>;

impl<const C: u8> From<VersionByte<C>> for u8 {
    /// Yields the constant `C`; the marker value itself carries no data.
    fn from(_: VersionByte<C>) -> Self {
        C
    }
}
57
// Error produced when a byte does not equal the expected version constant `EXPECTED`.
//
// NOTE: the doc comment below is not just documentation. `displaydoc::Display` uses it as
// the format string for this type's `Display` output, so editing it changes the message.
/// version byte was {found}, not {EXPECTED:?}
#[derive(Copy, Clone, Debug, Eq, PartialEq, displaydoc::Display)]
pub struct VersionMismatchError<const EXPECTED: u8> {
    // The byte that was actually supplied in place of `EXPECTED`.
    found: u8,
}
63
64impl<const C: u8> TryFrom<u8> for VersionByte<C> {
65    type Error = VersionMismatchError<C>;
66    fn try_from(value: u8) -> Result<Self, Self::Error> {
67        (value == C)
68            .then_some(VersionByte::<C>)
69            .ok_or(VersionMismatchError::<C> { found: value })
70    }
71}
72
73impl<const C: u8> Serialize for VersionByte<C> {
74    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
75    where
76        S: serde::Serializer,
77    {
78        u8::serialize(&C, serializer)
79    }
80}
81
82impl<'de, const C: u8> Deserialize<'de> for VersionByte<C> {
83    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
84    where
85        D: serde::Deserializer<'de>,
86    {
87        let v = u8::deserialize(deserializer)?;
88        v.try_into().map_err(|_| {
89            <D::Error as serde::de::Error>::invalid_value(
90                serde::de::Unexpected::Unsigned(v.into()),
91                &format!("version `{C}`").as_str(),
92            )
93        })
94    }
95}
96
/// Placeholder byte that always serializes as `0u8` and only deserializes from `0u8`.
///
/// This is [`VersionByte`] with the constant fixed to `0`.
pub type ReservedByte = VersionByte<0>;
99
#[cfg(test)]
mod test {
    use std::fmt::Debug;

    use test_case::test_case;

    use super::*;

    /// Test fixture: a string payload preceded by a leading marker of type `T`.
    #[derive(Debug, Serialize, Deserialize, PartialEq, PartialDefault)]
    struct WithLeadingByte<T> {
        leading: T,
        string: String,
    }

    impl<T: Default> WithLeadingByte<T> {
        /// Fixture value: a default-constructed leading marker followed by a fixed string.
        fn test_value() -> Self {
            let leading = T::default();
            let string = String::from("a string");
            Self { leading, string }
        }
    }

    type WithReservedByte = WithLeadingByte<ReservedByte>;
    type WithVersionByte = WithLeadingByte<VersionByte<42>>;

    /// Serializing then deserializing yields the same value, and the leading marker is
    /// written as the first byte.
    #[test_case(WithReservedByte::test_value(), 0)]
    #[test_case(WithVersionByte::test_value(), 42)]
    fn round_trip<T>(test_value: T, expected_first_byte: u8)
    where
        T: Serialize + for<'a> Deserialize<'a> + PartialEq + PartialDefault + Debug,
    {
        let bytes = crate::serialize(&test_value);
        assert_eq!(bytes[0], expected_first_byte);

        let decoded: T = crate::deserialize(&bytes).expect("can deserialize");
        assert_eq!(decoded, test_value);
    }

    /// Changing the leading byte makes deserialization fail.
    #[test_case(WithReservedByte::test_value())]
    #[test_case(WithVersionByte::test_value())]
    fn version_byte_wrong<T>(test_value: T)
    where
        T: Serialize + for<'a> Deserialize<'a> + PartialEq + PartialDefault + Debug,
    {
        let mut bytes = crate::serialize(&test_value);
        // Bump the first byte so it no longer matches the expected marker.
        bytes[0] += 1;
        crate::deserialize::<T>(&bytes).expect_err("invalid version");
    }

    /// Plain bincode is used here so that the underlying serde error text can be inspected.
    #[test]
    fn version_byte_error_message() {
        let mut encoded =
            bincode::serialize(&WithVersionByte::test_value()).expect("should serialize");
        encoded[0] = 41;

        let failure =
            bincode::deserialize::<WithVersionByte>(&encoded).expect_err("should fail");
        let rendered = failure.to_string();
        assert_eq!(
            rendered,
            "invalid value: integer `41`, expected version `42`"
        );
    }
}