use std::cmp::Ordering;
use std::ops::{Add, Div, Rem};
/// Broadcasts the most significant bit of `a` into every bit of the result.
///
/// Returns `0xFF` when `a >= 0x80` and `0x00` otherwise, using only arithmetic (no branch).
fn expand_top_bit(a: u8) -> u8 {
    let top = a >> 7; // either 0 or 1
    // Two's-complement negation maps 0 -> 0x00 and 1 -> 0xFF.
    top.wrapping_neg()
}
/// Returns `0xFF` if `a == 0` and `0x00` otherwise, without branching on `a`.
fn ct_is_zero(a: u8) -> u8 {
    // `a - 1` has its top bit set only when `a == 0` (wraps to 0xFF) or `a >= 0x81`.
    // `!a` has its top bit clear for every `a >= 0x80`, so the AND keeps the top bit
    // set exactly when `a == 0`.
    let predecessor = a.wrapping_sub(1);
    expand_top_bit(predecessor & !a)
}
/// Returns `0xFF` if `a == b` and `0x00` otherwise, without branching on the inputs.
fn ct_is_eq(a: u8, b: u8) -> u8 {
    // Equal bytes XOR to zero, so equality reduces to a branch-free zero test.
    let difference = a ^ b;
    ct_is_zero(difference)
}
/// Returns `0xFF` if `a < b` (unsigned) and `0x00` otherwise, without branching.
fn ct_is_lt(a: u8, b: u8) -> u8 {
    // Case 1: top bits of `a` and `b` differ. Then `differ` has its top bit set, the OR
    // below does too, and the final XOR yields the inverted top bit of `a`. That bit is
    // set exactly when `a` is in the lower half, i.e. when `a < b`.
    // Case 2: top bits agree. Then `a` and `b` are less than 0x80 apart, so `a - b`
    // wraps (sets its top bit) exactly when `a < b`. XOR-ing with `a` twice cancels
    // `a`'s own top bit and leaves the top bit of `a - b`.
    let differ = a ^ b;
    let borrow_vs_a = a.wrapping_sub(b) ^ a;
    expand_top_bit(a ^ (differ | borrow_vs_a))
}
/// Chooses between `a` and `b` with a bit mask instead of a branch.
///
/// `mask` must be `0xFF` (yields `a`) or `0x00` (yields `b`); this precondition is
/// checked only in debug builds.
fn ct_select(mask: u8, a: u8, b: u8) -> u8 {
    debug_assert!(mask == 0 || mask == 0xFF);
    // Each result bit comes from `a` where the mask bit is 1 and from `b` where it is 0.
    (a & mask) | (b & !mask)
}
/// Compares two byte strings, giving a shorter slice priority over any content.
///
/// If the lengths differ, the shorter slice orders first and the contents are not read.
/// For equal lengths, the result equals lexicographic slice ordering: the first
/// differing byte decides. The per-byte loop is written with mask-based helpers and no
/// early exit, so it does not branch on the byte values. Only the final mapping of the
/// verdict to an [`Ordering`] branches.
pub(crate) fn constant_time_cmp(x: &[u8], y: &[u8]) -> Ordering {
    if x.len() != y.len() {
        return x.len().cmp(&y.len());
    }

    // Scan from the last byte towards the first. Every differing byte overwrites the
    // running verdict, so the one left at the end comes from the first (most significant)
    // difference. Encoding: 0 = equal so far, 1 = x < y, 255 = x > y.
    let verdict = x
        .iter()
        .rev()
        .zip(y.iter().rev())
        .fold(0u8, |so_far, (&a, &b)| {
            let if_differs = ct_select(ct_is_lt(a, b), 1, 255);
            ct_select(ct_is_eq(a, b), so_far, if_differs)
        });
    debug_assert!(verdict == 0 || verdict == 1 || verdict == 255);

    match verdict {
        0 => Ordering::Equal,
        1 => Ordering::Less,
        _ => Ordering::Greater,
    }
}
/// Divides `dividend` by `divisor`, rounding the quotient toward positive infinity.
///
/// Works for unsigned and signed types. `/` truncates toward zero, so the truncated
/// quotient is bumped by one only when the remainder is non-zero and has the same sign
/// as the divisor, i.e. when the exact quotient is positive and was rounded down.
///
/// # Panics
///
/// For primitive integers, panics when `divisor` is zero, as the underlying `/` and `%`
/// operators do.
#[inline]
pub(crate) fn div_ceil<
    T: Copy + Div<Output = T> + Rem<Output = T> + Add<Output = T> + Ord + From<u8>,
>(
    dividend: T,
    divisor: T,
) -> T {
    let zero = T::from(0u8);
    let quotient = dividend / divisor;
    let remainder = dividend % divisor;
    let both_positive = remainder > zero && divisor > zero;
    let both_negative = remainder < zero && divisor < zero;
    if both_positive || both_negative {
        quotient + T::from(1u8)
    } else {
        quotient
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constant_time_cmp() {
        use rand::Rng;
        // Empty inputs and length-based ordering.
        assert_eq!(constant_time_cmp(&[], &[]), Ordering::Equal);
        assert_eq!(constant_time_cmp(&[], &[0]), Ordering::Less);
        assert_eq!(constant_time_cmp(&[1], &[1]), Ordering::Equal);
        assert_eq!(constant_time_cmp(&[0, 1], &[1]), Ordering::Greater);
        assert_eq!(constant_time_cmp(&[1], &[0, 1]), Ordering::Less);
        assert_eq!(constant_time_cmp(&[2], &[1, 0, 1]), Ordering::Less);
        // The first (most significant) differing byte decides, even when a later byte
        // differs in the opposite direction.
        assert_eq!(constant_time_cmp(&[1, 0], &[0, 0xFF]), Ordering::Greater);
        assert_eq!(constant_time_cmp(&[0xFF, 0], &[0xFF, 1]), Ordering::Less);
        let mut rng = rand::rngs::OsRng;
        for len in 1..320 {
            let x: Vec<u8> = (0..len).map(|_| rng.gen()).collect();
            let y: Vec<u8> = (0..len).map(|_| rng.gen()).collect();
            let expected = x.cmp(&y);
            let result = constant_time_cmp(&x, &y);
            assert_eq!(result, expected);
            let expected = y.cmp(&x);
            let result = constant_time_cmp(&y, &x);
            assert_eq!(result, expected);
            assert_eq!(constant_time_cmp(&x, &x), Ordering::Equal);
        }
    }

    #[test]
    fn test_expand_top_bit() {
        // Inclusive range: 0xFF is the key boundary for a top-bit trick.
        for a in 0..=255u8 {
            let expected = if a >= 0x80 { 0xFF } else { 0 };
            assert_eq!(expand_top_bit(a), expected);
        }
    }

    #[test]
    fn test_ct_select() {
        for a in 0..=255u8 {
            for b in 0..=255u8 {
                assert_eq!(ct_select(0xFF, a, b), a);
                assert_eq!(ct_select(0x00, a, b), b);
            }
        }
    }

    #[test]
    fn test_ct_is_zero() {
        assert_eq!(ct_is_zero(0), 0xFF);
        // Inclusive upper bound so that 255 (all bits set) is covered too.
        for i in 1..=255 {
            assert_eq!(ct_is_zero(i), 0x00);
        }
    }

    #[test]
    fn test_ct_is_lt() {
        // Exhaustive over every u8 pair, including 255.
        for x in 0..=255 {
            for y in 0..=255 {
                let expected = if x < y { 0xFF } else { 0 };
                let result = ct_is_lt(x, y);
                assert_eq!(result, expected);
            }
        }
    }

    #[test]
    fn test_ct_is_eq() {
        // Exhaustive over every u8 pair, including 255.
        for x in 0..=255 {
            for y in 0..=255 {
                let expected = if x == y { 0xFF } else { 0 };
                let result = ct_is_eq(x, y);
                assert_eq!(result, expected);
            }
        }
    }

    #[test]
    fn test_div_ceil() {
        assert_eq!(div_ceil(0_usize, 4), 0);
        assert_eq!(div_ceil(7_usize, 4), 2);
        assert_eq!(div_ceil(8_usize, 4), 2);
        assert_eq!(div_ceil(9_usize, 4), 3);
    }

    #[test]
    #[should_panic]
    fn test_div_ceil_panic() {
        _ = div_ceil(4_usize, 0);
    }

    #[test]
    fn test_div_ceil_isize() {
        let a: isize = 8;
        let b = 3;
        assert_eq!(div_ceil(a, b), 3);
        assert_eq!(div_ceil(a, -b), -2);
        assert_eq!(div_ceil(-a, b), -2);
        assert_eq!(div_ceil(-a, -b), 3);
    }
}