33use crate :: str:: pyunicode_new:: * ;
44
55use core:: arch:: x86_64:: {
6- _mm256_and_si256 , _mm256_cmpgt_epu8_mask , _mm256_cmpneq_epi8_mask , _mm256_loadu_epi8 ,
7- _mm256_mask_cmpneq_epi8_mask , _mm256_maskz_loadu_epi8 , _mm256_max_epu8 , _mm256_set1_epi8 ,
6+ _mm512_and_si512 , _mm512_cmpgt_epu8_mask , _mm512_cmpneq_epi8_mask , _mm512_loadu_epi8 ,
7+ _mm512_mask_cmpneq_epi8_mask , _mm512_maskz_loadu_epi8 , _mm512_max_epu8 , _mm512_set1_epi8 ,
88} ;
99
1010#[ inline( never) ]
1111#[ target_feature( enable = "avx512f,avx512bw,avx512vl,bmi2" ) ]
1212pub ( crate ) unsafe fn create_str_impl_avx512vl ( buf : & str ) -> * mut pyo3_ffi:: PyObject {
13- const STRIDE : usize = 32 ;
13+ const STRIDE : usize = 64 ;
1414
1515 let buf_ptr = buf. as_bytes ( ) . as_ptr ( ) . cast :: < i8 > ( ) ;
1616 let buf_len = buf. len ( ) ;
@@ -20,40 +20,40 @@ pub(crate) unsafe fn create_str_impl_avx512vl(buf: &str) -> *mut pyo3_ffi::PyObj
2020 let num_loops = buf_len / STRIDE ;
2121 let remainder = buf_len % STRIDE ;
2222
23- let remainder_mask: u32 = !( u32 :: MAX << remainder) ;
24- let mut str_vec = _mm256_maskz_loadu_epi8 ( remainder_mask, buf_ptr) ;
23+ let remainder_mask: u64 = !( u64 :: MAX << remainder) ;
24+ let mut str_vec = _mm512_maskz_loadu_epi8 ( remainder_mask, buf_ptr) ;
2525 let sptr = buf_ptr. add ( remainder) ;
2626
2727 for i in 0 ..num_loops {
28- str_vec = _mm256_max_epu8 (
28+ str_vec = _mm512_max_epu8 (
2929 str_vec,
30- _mm256_loadu_epi8 ( sptr. add ( STRIDE * i) . cast :: < i8 > ( ) ) ,
30+ _mm512_loadu_epi8 ( sptr. add ( STRIDE * i) . cast :: < i8 > ( ) ) ,
3131 ) ;
3232 }
3333
3434 #[ allow( overflowing_literals) ]
35- let vec_128 = _mm256_set1_epi8 ( 0b10000000i8 ) ;
36- if _mm256_cmpgt_epu8_mask ( str_vec, vec_128) == 0 {
35+ let vec_128 = _mm512_set1_epi8 ( 0b10000000i8 ) ;
36+ if _mm512_cmpgt_epu8_mask ( str_vec, vec_128) == 0 {
3737 pyunicode_ascii ( buf. as_bytes ( ) . as_ptr ( ) , buf_len)
3838 } else {
3939 #[ allow( overflowing_literals) ]
40- let is_four = _mm256_cmpgt_epu8_mask ( str_vec, _mm256_set1_epi8 ( 239i8 ) ) != 0 ;
40+ let is_four = _mm512_cmpgt_epu8_mask ( str_vec, _mm512_set1_epi8 ( 239i8 ) ) != 0 ;
4141 #[ allow( overflowing_literals) ]
42- let is_not_latin = _mm256_cmpgt_epu8_mask ( str_vec, _mm256_set1_epi8 ( 195i8 ) ) != 0 ;
42+ let is_not_latin = _mm512_cmpgt_epu8_mask ( str_vec, _mm512_set1_epi8 ( 195i8 ) ) != 0 ;
4343 #[ allow( overflowing_literals) ]
44- let multibyte = _mm256_set1_epi8 ( 0b11000000i8 ) ;
44+ let multibyte = _mm512_set1_epi8 ( 0b11000000i8 ) ;
4545
46- let mut num_chars = _mm256_mask_cmpneq_epi8_mask (
46+ let mut num_chars = _mm512_mask_cmpneq_epi8_mask (
4747 remainder_mask,
48- _mm256_and_si256 ( _mm256_maskz_loadu_epi8 ( remainder_mask, buf_ptr) , multibyte) ,
48+ _mm512_and_si512 ( _mm512_maskz_loadu_epi8 ( remainder_mask, buf_ptr) , multibyte) ,
4949 vec_128,
5050 )
5151 . count_ones ( ) as usize ;
5252
5353 for i in 0 ..num_loops {
54- num_chars += _mm256_cmpneq_epi8_mask (
55- _mm256_and_si256 (
56- _mm256_loadu_epi8 ( sptr. add ( STRIDE * i) . cast :: < i8 > ( ) ) ,
54+ num_chars += _mm512_cmpneq_epi8_mask (
55+ _mm512_and_si512 (
56+ _mm512_loadu_epi8 ( sptr. add ( STRIDE * i) . cast :: < i8 > ( ) ) ,
5757 multibyte,
5858 ) ,
5959 vec_128,
0 commit comments