Skip to content

Commit d2e401f

Browse files
committed
chacha: Use Overlapping in the implementation of the fallback impl.
Eliminate all of the `unsafe` in the fallback implementation.
1 parent 504685d commit d2e401f

File tree

4 files changed

+144
-20
lines changed

4 files changed

+144
-20
lines changed

src/aead/chacha/fallback.rs

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
// Adapted from the public domain, estream code by D. Bernstein.
1616
// Adapted from the BoringSSL crypto/chacha/chacha.c.
1717

18-
use super::{Counter, Key, Overlapping, BLOCK_LEN};
18+
use super::{super::overlapping::IndexError, Counter, Key, Overlapping, BLOCK_LEN};
1919
use crate::{constant_time, polyfill::sliceutil};
20-
use core::{mem::size_of, slice};
20+
use core::mem::size_of;
2121

22-
pub(super) fn ChaCha20_ctr32(key: &Key, counter: Counter, in_out: Overlapping<'_>) {
22+
pub(super) fn ChaCha20_ctr32(key: &Key, counter: Counter, mut in_out: Overlapping<'_>) {
2323
const SIGMA: [u32; 4] = [
2424
u32::from_le_bytes(*b"expa"),
2525
u32::from_le_bytes(*b"nd 3"),
@@ -35,31 +35,34 @@ pub(super) fn ChaCha20_ctr32(key: &Key, counter: Counter, in_out: Overlapping<'_
3535
key[6], key[7], counter[0], counter[1], counter[2], counter[3],
3636
];
3737

38-
let (mut input, mut output, mut in_out_len) = in_out.into_input_output_len();
38+
let mut in_out_len = in_out.len();
3939

4040
let mut buf = [0u8; BLOCK_LEN];
4141
while in_out_len > 0 {
4242
chacha_core(&mut buf, &state);
4343
state[12] += 1;
4444

45+
debug_assert_eq!(in_out_len, in_out.len());
46+
4547
// Both branches do the same thing, but the duplication helps the
4648
// compiler optimize (vectorize) the `BLOCK_LEN` case.
4749
if in_out_len >= BLOCK_LEN {
48-
let input = unsafe { slice::from_raw_parts(input, BLOCK_LEN) };
49-
constant_time::xor_assign_at_start(&mut buf, input);
50-
let output = unsafe { slice::from_raw_parts_mut(output, BLOCK_LEN) };
51-
sliceutil::overwrite_at_start(output, &buf);
50+
in_out = in_out
51+
.split_first_chunk::<BLOCK_LEN>(|in_out| {
52+
constant_time::xor_assign_at_start(&mut buf, in_out.input());
53+
sliceutil::overwrite_at_start(in_out.into_unwritten_output(), &buf);
54+
})
55+
.unwrap_or_else(|IndexError { .. }| {
56+
// Since `in_out_len == in_out.len() && in_out_len >= BLOCK_LEN`.
57+
unreachable!()
58+
});
5259
} else {
53-
let input = unsafe { slice::from_raw_parts(input, in_out_len) };
54-
constant_time::xor_assign_at_start(&mut buf, input);
55-
let output = unsafe { slice::from_raw_parts_mut(output, in_out_len) };
56-
sliceutil::overwrite_at_start(output, &buf);
60+
constant_time::xor_assign_at_start(&mut buf, in_out.input());
61+
sliceutil::overwrite_at_start(in_out.into_unwritten_output(), &buf);
5762
break;
5863
}
5964

6065
in_out_len -= BLOCK_LEN;
61-
input = unsafe { input.add(BLOCK_LEN) };
62-
output = unsafe { output.add(BLOCK_LEN) };
6366
}
6467
}
6568

src/aead/overlapping/array.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// Copyright 2024 Brian Smith.
2+
//
3+
// Permission to use, copy, modify, and/or distribute this software for any
4+
// purpose with or without fee is hereby granted, provided that the above
5+
// copyright notice and this permission notice appear in all copies.
6+
//
7+
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHORS DISCLAIM ALL WARRANTIES
8+
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
9+
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY
10+
// SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
11+
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
12+
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
13+
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
14+
15+
#![cfg_attr(not(test), allow(dead_code))]
16+
17+
use super::Overlapping;
18+
use core::array::TryFromSliceError;
19+
20+
pub struct Array<'o, T, const N: usize> {
21+
// Invariant: N != 0.
22+
// Invariant: `self.in_out.len() == N`.
23+
in_out: Overlapping<'o, T>,
24+
}
25+
26+
impl<'o, T, const N: usize> Array<'o, T, N> {
27+
pub(super) fn new(in_out: Overlapping<'o, T>) -> Result<Self, LenMismatchError> {
28+
if N == 0 || in_out.len() != N {
29+
return Err(LenMismatchError::new(N));
30+
}
31+
Ok(Self { in_out })
32+
}
33+
34+
pub fn into_unwritten_output(self) -> &'o mut [T; N]
35+
where
36+
&'o mut [T]: TryInto<&'o mut [T; N], Error = TryFromSliceError>,
37+
{
38+
self.in_out
39+
.into_unwritten_output()
40+
.try_into()
41+
.unwrap_or_else(|TryFromSliceError { .. }| {
42+
unreachable!() // Due to invariant
43+
})
44+
}
45+
}
46+
47+
impl<T, const N: usize> Array<'_, T, N> {
48+
pub fn input<'s>(&'s self) -> &'s [T; N]
49+
where
50+
&'s [T]: TryInto<&'s [T; N], Error = TryFromSliceError>,
51+
{
52+
self.in_out
53+
.input()
54+
.try_into()
55+
.unwrap_or_else(|TryFromSliceError { .. }| {
56+
unreachable!() // Due to invariant
57+
})
58+
}
59+
}
60+
61+
pub struct LenMismatchError {
62+
#[allow(dead_code)]
63+
len: usize,
64+
}
65+
66+
impl LenMismatchError {
67+
#[cold]
68+
#[inline(never)]
69+
fn new(len: usize) -> Self {
70+
Self { len }
71+
}
72+
}

src/aead/overlapping/base.rs

Lines changed: 51 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
// OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
1313
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
1414

15-
use core::ops::RangeFrom;
15+
use super::{Array, LenMismatchError};
16+
use core::{mem, ops::RangeFrom};
1617

1718
pub struct Overlapping<'o, T> {
1819
// Invariant: self.src.start <= in_out.len().
@@ -28,7 +29,7 @@ impl<'o, T> Overlapping<'o, T> {
2829
pub fn new(in_out: &'o mut [T], src: RangeFrom<usize>) -> Result<Self, IndexError> {
2930
match in_out.get(src.clone()) {
3031
Some(_) => Ok(Self { in_out, src }),
31-
None => Err(IndexError::new(src)),
32+
None => Err(IndexError::new(src.start)),
3233
}
3334
}
3435

@@ -51,7 +52,7 @@ impl<'o, T> Overlapping<'o, T> {
5152
(self.in_out, self.src)
5253
}
5354

54-
pub(super) fn into_unwritten_output(self) -> &'o mut [T] {
55+
pub fn into_unwritten_output(self) -> &'o mut [T] {
5556
let len = self.len();
5657
self.in_out.get_mut(..len).unwrap_or_else(|| {
5758
// The invariant ensures this succeeds.
@@ -83,14 +84,58 @@ impl<T> Overlapping<'_, T> {
8384
let input = unsafe { output_const.add(self.src.start) };
8485
(input, output, len)
8586
}
87+
88+
// Perhaps unlike `slice::split_first_chunk_mut`, this is biased,
89+
// performance-wise, against the case where `N > self.len()`, so callers
90+
// should be structured to avoid that.
91+
//
92+
// If the result is `Err` then nothing was written to `self`; if anything
93+
// was written then the result will not be `Err`.
94+
#[cfg_attr(not(test), allow(dead_code))]
95+
pub fn split_first_chunk<const N: usize>(
96+
mut self,
97+
f: impl for<'a> FnOnce(Array<'a, T, N>),
98+
) -> Result<Self, IndexError> {
99+
let src = self.src.clone();
100+
let end = self
101+
.src
102+
.start
103+
.checked_add(N)
104+
.ok_or_else(|| IndexError::new(N))?;
105+
let first = self
106+
.in_out
107+
.get_mut(..end)
108+
.ok_or_else(|| IndexError::new(N))?;
109+
let first = Overlapping::new(first, src).unwrap_or_else(|IndexError { .. }| {
110+
// Since `end == src.start + N`.
111+
unreachable!()
112+
});
113+
let first = Array::new(first).unwrap_or_else(|LenMismatchError { .. }| {
114+
// Since `end == src.start + N`.
115+
unreachable!()
116+
});
117+
// Once we call `f`, we must return `Ok` because `f` may have written
118+
// over (part of) the input.
119+
Ok({
120+
f(first);
121+
let tail = mem::take(&mut self.in_out).get_mut(N..).unwrap_or_else(|| {
122+
// There are at least `N` elements since `end == src.start + N`.
123+
unreachable!()
124+
});
125+
Self::new(tail, self.src).unwrap_or_else(|IndexError { .. }| {
126+
// Follows from `end == src.start + N`.
127+
unreachable!()
128+
})
129+
})
130+
}
86131
}
87132

88-
pub struct IndexError(#[allow(dead_code)] RangeFrom<usize>);
133+
pub struct IndexError(#[allow(dead_code)] usize);
89134

90135
impl IndexError {
91136
#[cold]
92137
#[inline(never)]
93-
fn new(src: RangeFrom<usize>) -> Self {
94-
Self(src)
138+
fn new(index: usize) -> Self {
139+
Self(index)
95140
}
96141
}

src/aead/overlapping/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,13 @@
1313
// CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
1414

1515
pub use self::{
16+
array::Array,
1617
base::{IndexError, Overlapping},
1718
partial_block::PartialBlock,
1819
};
1920

21+
use self::array::LenMismatchError;
22+
23+
mod array;
2024
mod base;
2125
mod partial_block;

0 commit comments

Comments
 (0)