Skip to content

Commit 2be1083

Browse files
committed
Improve autovectorization of to_lowercase / to_uppercase functions
Refactor the code in the `convert_while_ascii` helper function to make it more suitable for auto-vectorization and to process the full ASCII prefix of the string. The generic case-conversion logic is now invoked only from the first non-ASCII character onward. On a microbenchmark with a small ASCII-only input, the runtime decreases from ~55ns to ~18ns per iteration. The new implementation also reduces the amount of unsafe code and encapsulates all of it inside the helper function. Fixes #123712
1 parent 033becf commit 2be1083

File tree

2 files changed

+67
-49
lines changed

2 files changed

+67
-49
lines changed

library/alloc/benches/string.rs

+7
Original file line number | Diff line number | Diff line change
@@ -162,3 +162,10 @@ fn bench_insert_str_long(b: &mut Bencher) {
162162
x
163163
})
164164
}
165+
166+
#[bench]
167+
fn bench_to_lowercase(b: &mut Bencher) {
168+
let s = "Hello there, the quick brown fox jumped over the lazy dog! \
169+
Lorem ipsum dolor sit amet, consectetur. ";
170+
b.iter(|| s.to_lowercase())
171+
}

library/alloc/src/str.rs

+60-49
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
use core::borrow::{Borrow, BorrowMut};
1111
use core::iter::FusedIterator;
1212
use core::mem;
13+
use core::mem::MaybeUninit;
1314
use core::ptr;
1415
use core::str::pattern::{DoubleEndedSearcher, Pattern, ReverseSearcher, Searcher};
1516
use core::unicode::conversions;
@@ -366,14 +367,7 @@ impl str {
366367
without modifying the original"]
367368
#[stable(feature = "unicode_case_mapping", since = "1.2.0")]
368369
pub fn to_lowercase(&self) -> String {
369-
let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_lowercase);
370-
371-
// Safety: we know this is a valid char boundary since
372-
// out.len() is only progressed if ascii bytes are found
373-
let rest = unsafe { self.get_unchecked(out.len()..) };
374-
375-
// Safety: We have written only valid ASCII to our vec
376-
let mut s = unsafe { String::from_utf8_unchecked(out) };
370+
let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_lowercase);
377371

378372
for (i, c) in rest[..].char_indices() {
379373
if c == 'Σ' {
@@ -457,14 +451,7 @@ impl str {
457451
without modifying the original"]
458452
#[stable(feature = "unicode_case_mapping", since = "1.2.0")]
459453
pub fn to_uppercase(&self) -> String {
460-
let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_uppercase);
461-
462-
// Safety: we know this is a valid char boundary since
463-
// out.len() is only progressed if ascii bytes are found
464-
let rest = unsafe { self.get_unchecked(out.len()..) };
465-
466-
// Safety: We have written only valid ASCII to our vec
467-
let mut s = unsafe { String::from_utf8_unchecked(out) };
454+
let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_uppercase);
468455

469456
for c in rest.chars() {
470457
match conversions::to_upper(c) {
@@ -613,50 +600,74 @@ pub unsafe fn from_boxed_utf8_unchecked(v: Box<[u8]>) -> Box<str> {
613600
unsafe { Box::from_raw(Box::into_raw(v) as *mut str) }
614601
}
615602

616-
/// Converts the bytes while the bytes are still ascii.
603+
/// Converts leading ascii bytes in `s` by calling the `convert` function.
604+
///
617605
/// For better average performance, this happens in chunks of `2*size_of::<usize>()`.
618-
/// Returns a vec with the converted bytes.
606+
///
607+
/// Returns a tuple of the converted prefix and the remainder starting from
608+
/// the first non-ascii character.
619609
#[inline]
620610
#[cfg(not(test))]
621611
#[cfg(not(no_global_oom_handling))]
622-
// Converts the leading ASCII bytes of `s` with `convert` and returns the
// converted prefix together with the unconverted remainder of `s`.
//
// * `s`       - input string; only its leading run of ASCII bytes is converted.
// * `convert` - per-byte conversion. It must map ASCII bytes to ASCII bytes
//               (e.g. `u8::to_ascii_lowercase`), because the output buffer is
//               turned into a `String` without re-validation.
//
// Returns `(prefix, rest)` where `prefix` holds the converted ASCII bytes and
// `rest` starts at the first non-ASCII character of `s` (or is empty when the
// whole input is ASCII).
fn convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> (String, &str) {
    const USIZE_SIZE: usize = mem::size_of::<usize>();
    const MAGIC_UNROLL: usize = 2;
    // Chunk width in bytes; fixed-size chunks give the optimizer loops with a
    // constant trip count that it can vectorize.
    const N: usize = USIZE_SIZE * MAGIC_UNROLL;

    let mut slice = s.as_bytes();
    // The prefix can never be longer than the input, so one allocation is enough.
    let mut out: Vec<u8> = Vec::with_capacity(slice.len());
    // Uninitialized tail of `out`; advanced in lockstep with `slice`.
    let mut out_slice = out.spare_capacity_mut();

    // Number of bytes written to `out` so far.
    let mut ascii_prefix_len = 0_usize;

    // Process the input in chunks to enable auto-vectorization.
    while slice.len() >= N {
        let chunk = &slice[..N];
        let mut is_ascii = [false; N];

        for j in 0..N {
            is_ascii[j] = chunk[j] <= 127;
        }

        // Auto-vectorization for this check is a bit fragile; summing the flags
        // and comparing against the chunk size gives the best result,
        // specifically a pmovmsk instruction on x86.
        // Iterating by reference keeps this independent of the edition-specific
        // behavior of `into_iter()` on arrays.
        if is_ascii.iter().map(|x| *x as u8).sum::<u8>() as usize != N {
            // The chunk contains a non-ASCII byte; the byte-wise loop below
            // converts whatever ASCII bytes precede it.
            break;
        }

        for j in 0..N {
            out_slice[j] = MaybeUninit::new(convert(&chunk[j]));
        }

        ascii_prefix_len += N;
        slice = &slice[N..];
        out_slice = &mut out_slice[N..];
    }

    // Handle the remainder as individual bytes, stopping at the first
    // non-ASCII byte.
    while !slice.is_empty() {
        let byte = slice[0];
        if byte > 127 {
            break;
        }
        out_slice[0] = MaybeUninit::new(convert(&byte));
        ascii_prefix_len += 1;
        slice = &slice[1..];
        out_slice = &mut out_slice[1..];
    }

    // SAFETY: exactly `ascii_prefix_len` leading bytes of the buffer were
    // initialized by the two loops above, and `ascii_prefix_len <= s.len()`,
    // which does not exceed the requested capacity.
    unsafe { out.set_len(ascii_prefix_len) };

    // SAFETY: every byte written is the result of `convert` applied to an
    // ASCII byte, and `convert` is required to produce ASCII, so `out` is
    // valid UTF-8.
    let ascii_string = unsafe { String::from_utf8_unchecked(out) };

    // SAFETY: `ascii_prefix_len <= s.len()`, and only single-byte (ASCII)
    // characters were skipped, so the index lies on a char boundary.
    let rest = unsafe { s.get_unchecked(ascii_prefix_len..) };

    (ascii_string, rest)
}

0 commit comments

Comments
 (0)