Skip to content

Commit 96b270b

Browse files
committed
create_str_impl_avx512vl(): remove support for AVX10.2/256 (switch from 256-bit to 512-bit vector intrinsics)
1 parent 86d63a2 commit 96b270b

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

src/str/avx512.rs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
use crate::str::pyunicode_new::*;
44

55
use core::arch::x86_64::{
6-
_mm256_and_si256, _mm256_cmpgt_epu8_mask, _mm256_cmpneq_epi8_mask, _mm256_loadu_epi8,
7-
_mm256_mask_cmpneq_epi8_mask, _mm256_maskz_loadu_epi8, _mm256_max_epu8, _mm256_set1_epi8,
6+
_mm512_and_si512, _mm512_cmpgt_epu8_mask, _mm512_cmpneq_epi8_mask, _mm512_loadu_epi8,
7+
_mm512_mask_cmpneq_epi8_mask, _mm512_maskz_loadu_epi8, _mm512_max_epu8, _mm512_set1_epi8,
88
};
99

1010
#[inline(never)]
1111
#[target_feature(enable = "avx512f,avx512bw,avx512vl,bmi2")]
1212
pub(crate) unsafe fn create_str_impl_avx512vl(buf: &str) -> *mut pyo3_ffi::PyObject {
13-
const STRIDE: usize = 32;
13+
const STRIDE: usize = 64;
1414

1515
let buf_ptr = buf.as_bytes().as_ptr().cast::<i8>();
1616
let buf_len = buf.len();
@@ -20,40 +20,40 @@ pub(crate) unsafe fn create_str_impl_avx512vl(buf: &str) -> *mut pyo3_ffi::PyObj
2020
let num_loops = buf_len / STRIDE;
2121
let remainder = buf_len % STRIDE;
2222

23-
let remainder_mask: u32 = !(u32::MAX << remainder);
24-
let mut str_vec = _mm256_maskz_loadu_epi8(remainder_mask, buf_ptr);
23+
let remainder_mask: u64 = !(u64::MAX << remainder);
24+
let mut str_vec = _mm512_maskz_loadu_epi8(remainder_mask, buf_ptr);
2525
let sptr = buf_ptr.add(remainder);
2626

2727
for i in 0..num_loops {
28-
str_vec = _mm256_max_epu8(
28+
str_vec = _mm512_max_epu8(
2929
str_vec,
30-
_mm256_loadu_epi8(sptr.add(STRIDE * i).cast::<i8>()),
30+
_mm512_loadu_epi8(sptr.add(STRIDE * i).cast::<i8>()),
3131
);
3232
}
3333

3434
#[allow(overflowing_literals)]
35-
let vec_128 = _mm256_set1_epi8(0b10000000i8);
36-
if _mm256_cmpgt_epu8_mask(str_vec, vec_128) == 0 {
35+
let vec_128 = _mm512_set1_epi8(0b10000000i8);
36+
if _mm512_cmpgt_epu8_mask(str_vec, vec_128) == 0 {
3737
pyunicode_ascii(buf.as_bytes().as_ptr(), buf_len)
3838
} else {
3939
#[allow(overflowing_literals)]
40-
let is_four = _mm256_cmpgt_epu8_mask(str_vec, _mm256_set1_epi8(239i8)) != 0;
40+
let is_four = _mm512_cmpgt_epu8_mask(str_vec, _mm512_set1_epi8(239i8)) != 0;
4141
#[allow(overflowing_literals)]
42-
let is_not_latin = _mm256_cmpgt_epu8_mask(str_vec, _mm256_set1_epi8(195i8)) != 0;
42+
let is_not_latin = _mm512_cmpgt_epu8_mask(str_vec, _mm512_set1_epi8(195i8)) != 0;
4343
#[allow(overflowing_literals)]
44-
let multibyte = _mm256_set1_epi8(0b11000000i8);
44+
let multibyte = _mm512_set1_epi8(0b11000000i8);
4545

46-
let mut num_chars = _mm256_mask_cmpneq_epi8_mask(
46+
let mut num_chars = _mm512_mask_cmpneq_epi8_mask(
4747
remainder_mask,
48-
_mm256_and_si256(_mm256_maskz_loadu_epi8(remainder_mask, buf_ptr), multibyte),
48+
_mm512_and_si512(_mm512_maskz_loadu_epi8(remainder_mask, buf_ptr), multibyte),
4949
vec_128,
5050
)
5151
.count_ones() as usize;
5252

5353
for i in 0..num_loops {
54-
num_chars += _mm256_cmpneq_epi8_mask(
55-
_mm256_and_si256(
56-
_mm256_loadu_epi8(sptr.add(STRIDE * i).cast::<i8>()),
54+
num_chars += _mm512_cmpneq_epi8_mask(
55+
_mm512_and_si512(
56+
_mm512_loadu_epi8(sptr.add(STRIDE * i).cast::<i8>()),
5757
multibyte,
5858
),
5959
vec_128,

0 commit comments

Comments
 (0)