zkgroup/common/
serialization.rs1use bincode::Options;
7use partial_default::PartialDefault;
8use serde::{Deserialize, Serialize};
9
10use crate::ZkGroupDeserializationFailure;
11
12fn zkgroup_bincode_options() -> impl bincode::Options {
13 bincode::DefaultOptions::new()
14 .with_fixint_encoding()
15 .with_little_endian()
16 .reject_trailing_bytes()
17}
18
19pub fn deserialize<'a, T: Deserialize<'a> + PartialDefault>(
23 bytes: &'a [u8],
24) -> Result<T, ZkGroupDeserializationFailure> {
25 let mut result = T::partial_default();
26 T::deserialize_in_place(
30 &mut bincode::Deserializer::from_slice(bytes, zkgroup_bincode_options()),
31 &mut result,
32 )
33 .map_err(|_| ZkGroupDeserializationFailure::new::<T>())?;
34 Ok(result)
35}
36
37pub fn serialize<T: Serialize>(value: &T) -> Vec<u8> {
39 zkgroup_bincode_options()
40 .serialize(value)
41 .expect("cannot fail")
42}
43
/// A zero-sized marker standing for the constant byte `C`.
///
/// The `Serialize` impl in this file writes it as the single byte `C`, and the `Deserialize`
/// impl accepts only that byte.
#[derive(Copy, Clone, Debug, Default, Eq, PartialEq)]
pub struct VersionByte<const C: u8>;

/// Converting the marker into a `u8` yields the constant it stands for.
impl<const C: u8> From<VersionByte<C>> for u8 {
    fn from(_: VersionByte<C>) -> Self {
        // The marker carries no data; the value lives entirely in the type parameter.
        C
    }
}
57
// Error returned by `VersionByte::<EXPECTED>::try_from` when the input byte is not `EXPECTED`.
//
// NOTE(review): `displaydoc::Display` generates the `Display` text from this item's doc
// comment, and none is visible in this excerpt -- confirm it exists in the real file. Plain `//`
// comments are used here on purpose so that the derived message is not altered.
#[derive(Copy, Clone, Debug, Eq, PartialEq, displaydoc::Display)]
pub struct VersionMismatchError<const EXPECTED: u8> {
    // The byte that was encountered instead of `EXPECTED`.
    found: u8,
}
63
64impl<const C: u8> TryFrom<u8> for VersionByte<C> {
65 type Error = VersionMismatchError<C>;
66 fn try_from(value: u8) -> Result<Self, Self::Error> {
67 (value == C)
68 .then_some(VersionByte::<C>)
69 .ok_or(VersionMismatchError::<C> { found: value })
70 }
71}
72
73impl<const C: u8> Serialize for VersionByte<C> {
74 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
75 where
76 S: serde::Serializer,
77 {
78 u8::serialize(&C, serializer)
79 }
80}
81
82impl<'de, const C: u8> Deserialize<'de> for VersionByte<C> {
83 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
84 where
85 D: serde::Deserializer<'de>,
86 {
87 let v = u8::deserialize(deserializer)?;
88 v.try_into().map_err(|_| {
89 <D::Error as serde::de::Error>::invalid_value(
90 serde::de::Unexpected::Unsigned(v.into()),
91 &format!("version `{C}`").as_str(),
92 )
93 })
94 }
95}
96
/// A [`VersionByte`] fixed at zero: it serializes as a single `0` byte, and deserialization
/// rejects any other value.
pub type ReservedByte = VersionByte<0>;
99
#[cfg(test)]
mod test {
    use std::fmt::Debug;

    use test_case::test_case;

    use super::*;

    // Fixture: a leading marker of type `T` followed by a string payload.
    #[derive(Debug, Serialize, Deserialize, PartialEq, PartialDefault)]
    struct WithLeadingByte<T> {
        leading: T,
        string: String,
    }

    impl<T: Default> WithLeadingByte<T> {
        // Canonical fixture value: a default marker plus a fixed payload.
        fn test_value() -> Self {
            let leading = T::default();
            let string = String::from("a string");
            Self { leading, string }
        }
    }

    type WithReservedByte = WithLeadingByte<ReservedByte>;
    type WithVersionByte = WithLeadingByte<VersionByte<42>>;

    // Serializing and then deserializing must give back the same value, and the marker must be
    // the first byte of the encoding.
    #[test_case(WithReservedByte::test_value(), 0)]
    #[test_case(WithVersionByte::test_value(), 42)]
    fn round_trip<T: Serialize + for<'a> Deserialize<'a> + PartialEq + PartialDefault + Debug>(
        test_value: T,
        expected_first_byte: u8,
    ) {
        let encoded = crate::serialize(&test_value);
        assert_eq!(encoded[0], expected_first_byte);

        let decoded: T = crate::deserialize(&encoded).expect("can deserialize");
        assert_eq!(decoded, test_value);
    }

    // Changing the leading byte must make deserialization fail.
    #[test_case(WithReservedByte::test_value())]
    #[test_case(WithVersionByte::test_value())]
    fn version_byte_wrong<
        T: Serialize + for<'a> Deserialize<'a> + PartialEq + PartialDefault + Debug,
    >(
        test_value: T,
    ) {
        let mut corrupted = crate::serialize(&test_value);
        corrupted[0] += 1;

        let outcome = crate::deserialize::<T>(&corrupted);
        outcome.expect_err("invalid version");
    }

    // `crate::deserialize` discards the underlying error, so plain bincode is used here to
    // inspect the message produced by the `VersionByte` deserializer.
    #[test]
    fn version_byte_error_message() {
        let fixture = WithVersionByte::test_value();
        let mut encoded = bincode::serialize(&fixture).expect("should serialize");
        encoded[0] = 41;

        let failure = bincode::deserialize::<WithVersionByte>(&encoded).expect_err("should fail");
        assert_eq!(
            failure.to_string(),
            "invalid value: integer `41`, expected version `42`"
        );
    }
}