|
10 | 10 | use core::borrow::{Borrow, BorrowMut};
|
11 | 11 | use core::iter::FusedIterator;
|
12 | 12 | use core::mem;
|
| 13 | +use core::mem::MaybeUninit; |
13 | 14 | use core::ptr;
|
14 | 15 | use core::str::pattern::{DoubleEndedSearcher, Pattern, ReverseSearcher, Searcher};
|
15 | 16 | use core::unicode::conversions;
|
@@ -366,14 +367,7 @@ impl str {
|
366 | 367 | without modifying the original"]
|
367 | 368 | #[stable(feature = "unicode_case_mapping", since = "1.2.0")]
|
368 | 369 | pub fn to_lowercase(&self) -> String {
|
369 |
| - let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_lowercase); |
370 |
| - |
371 |
| - // Safety: we know this is a valid char boundary since |
372 |
| - // out.len() is only progressed if ascii bytes are found |
373 |
| - let rest = unsafe { self.get_unchecked(out.len()..) }; |
374 |
| - |
375 |
| - // Safety: We have written only valid ASCII to our vec |
376 |
| - let mut s = unsafe { String::from_utf8_unchecked(out) }; |
| 370 | + let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_lowercase); |
377 | 371 |
|
378 | 372 | for (i, c) in rest[..].char_indices() {
|
379 | 373 | if c == 'Σ' {
|
@@ -457,14 +451,7 @@ impl str {
|
457 | 451 | without modifying the original"]
|
458 | 452 | #[stable(feature = "unicode_case_mapping", since = "1.2.0")]
|
459 | 453 | pub fn to_uppercase(&self) -> String {
|
460 |
| - let out = convert_while_ascii(self.as_bytes(), u8::to_ascii_uppercase); |
461 |
| - |
462 |
| - // Safety: we know this is a valid char boundary since |
463 |
| - // out.len() is only progressed if ascii bytes are found |
464 |
| - let rest = unsafe { self.get_unchecked(out.len()..) }; |
465 |
| - |
466 |
| - // Safety: We have written only valid ASCII to our vec |
467 |
| - let mut s = unsafe { String::from_utf8_unchecked(out) }; |
| 454 | + let (mut s, rest) = convert_while_ascii(self, u8::to_ascii_uppercase); |
468 | 455 |
|
469 | 456 | for c in rest.chars() {
|
470 | 457 | match conversions::to_upper(c) {
|
@@ -613,50 +600,74 @@ pub unsafe fn from_boxed_utf8_unchecked(v: Box<[u8]>) -> Box<str> {
|
613 | 600 | unsafe { Box::from_raw(Box::into_raw(v) as *mut str) }
|
614 | 601 | }
|
615 | 602 |
|
616 |
| -/// Converts the bytes while the bytes are still ascii. |
| 603 | +/// Converts leading ascii bytes in `s` by calling the `convert` function. |
| 604 | +/// |
617 | 605 | /// For better average performance, this happens in chunks of `2*size_of::<usize>()`.
|
618 |
| -/// Returns a vec with the converted bytes. |
| 606 | +/// |
| 607 | +/// Returns a tuple of the converted prefix and the remainder starting from |
| 608 | +/// the first non-ascii character. |
619 | 609 | #[inline]
|
620 | 610 | #[cfg(not(test))]
|
621 | 611 | #[cfg(not(no_global_oom_handling))]
|
622 |
| -fn convert_while_ascii(b: &[u8], convert: fn(&u8) -> u8) -> Vec<u8> { |
623 |
| - let mut out = Vec::with_capacity(b.len()); |
624 |
| - |
| 612 | +fn convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> (String, &str) { |
625 | 613 | const USIZE_SIZE: usize = mem::size_of::<usize>();
|
626 | 614 | const MAGIC_UNROLL: usize = 2;
|
627 | 615 | const N: usize = USIZE_SIZE * MAGIC_UNROLL;
|
628 |
| - const NONASCII_MASK: usize = usize::from_ne_bytes([0x80; USIZE_SIZE]); |
629 | 616 |
|
630 |
| - let mut i = 0; |
631 |
| - unsafe { |
632 |
| - while i + N <= b.len() { |
633 |
| - // Safety: we have checks the sizes `b` and `out` to know that our |
634 |
| - let in_chunk = b.get_unchecked(i..i + N); |
635 |
| - let out_chunk = out.spare_capacity_mut().get_unchecked_mut(i..i + N); |
636 |
| - |
637 |
| - let mut bits = 0; |
638 |
| - for j in 0..MAGIC_UNROLL { |
639 |
| - // read the bytes 1 usize at a time (unaligned since we haven't checked the alignment) |
640 |
| - // safety: in_chunk is valid bytes in the range |
641 |
| - bits |= in_chunk.as_ptr().cast::<usize>().add(j).read_unaligned(); |
642 |
| - } |
643 |
| - // if our chunks aren't ascii, then return only the prior bytes as init |
644 |
| - if bits & NONASCII_MASK != 0 { |
645 |
| - break; |
646 |
| - } |
| 617 | + let mut slice = s.as_bytes(); |
| 618 | + let mut out = Vec::with_capacity(slice.len()); |
| 619 | + let mut out_slice = out.spare_capacity_mut(); |
647 | 620 |
|
648 |
| - // perform the case conversions on N bytes (gets heavily autovec'd) |
649 |
| - for j in 0..N { |
650 |
| - // safety: in_chunk and out_chunk is valid bytes in the range |
651 |
| - let out = out_chunk.get_unchecked_mut(j); |
652 |
| - out.write(convert(in_chunk.get_unchecked(j))); |
653 |
| - } |
| 621 | + let mut i = 0_usize; |
654 | 622 |
|
655 |
| - // mark these bytes as initialised |
656 |
| - i += N; |
| 623 | + // process the input in chunks to enable auto-vectorization |
| 624 | + while slice.len() >= N { |
| 625 | + let chunk = &slice[..N]; |
| 626 | + let mut is_ascii = [false; N]; |
| 627 | + |
| 628 | + for j in 0..N { |
| 629 | + is_ascii[j] = chunk[j] <= 127; |
657 | 630 | }
|
658 |
| - out.set_len(i); |
| 631 | + |
| 632 | + // auto-vectorization for this check is a bit fragile, |
| 633 | + // sum and comparing against the chunk size gives the best result, |
| 634 | + // specifically a pmovmsk instruction on x86. |
| 635 | + if is_ascii.into_iter().map(|x| x as u8).sum::<u8>() as usize != N { |
| 636 | + break; |
| 637 | + } |
| 638 | + |
| 639 | + for j in 0..N { |
| 640 | + out_slice[j] = MaybeUninit::new(convert(&chunk[j])); |
| 641 | + } |
| 642 | + |
| 643 | + i += N; |
| 644 | + slice = &slice[N..]; |
| 645 | + out_slice = &mut out_slice[N..]; |
| 646 | + } |
| 647 | + |
| 648 | + // handle the remainder as individual bytes |
| 649 | + while !slice.is_empty() { |
| 650 | + let byte = slice[0]; |
| 651 | + if byte > 127 { |
| 652 | + break; |
| 653 | + } |
| 654 | + out_slice[0] = MaybeUninit::new(convert(&byte)); |
| 655 | + i += 1; |
| 656 | + slice = &slice[1..]; |
| 657 | + out_slice = &mut out_slice[1..]; |
659 | 658 | }
|
660 | 659 |
|
661 |
| - out |
| 660 | + unsafe { |
| 661 | + // SAFETY: i bytes have been initialized above |
| 662 | + out.set_len(i); |
| 663 | + |
| 664 | + // SAFETY: We have written only valid ascii to the output vec |
| 665 | + let ascii_string = String::from_utf8_unchecked(out); |
| 666 | + |
| 667 | + // SAFETY: we know this is a valid char boundary |
| 668 | + // since we only skipped over leading ascii bytes |
| 669 | + let rest = core::str::from_utf8_unchecked(slice); |
| 670 | + |
| 671 | + (ascii_string, rest) |
| 672 | + } |
662 | 673 | }
|
0 commit comments