mirror of https://gitee.com/Nocallback/glibc.git
Browse Source
As indicated in a recent thread, this is a simple brute-force
algorithm that checks the whole needle at a matching character pair
(and does so 1 byte at a time after the first 64 bytes of a needle).
Also it never skips ahead and thus can match at every haystack
position after trying to match all of the needle, which the generic
implementation avoids.
As indicated by Wilco, a 4x larger needle and 16x larger haystack gives
a clear 65x slowdown for both basic_strstr and __strstr_avx512:
"ifuncs": ["basic_strstr", "twoway_strstr", "__strstr_avx512",
"__strstr_sse2_unaligned", "__strstr_generic"],
{
"len_haystack": 65536,
"len_needle": 1024,
"align_haystack": 0,
"align_needle": 0,
"fail": 1,
"desc": "Difficult bruteforce needle",
"timings": [4.0948e+07, 15094.5, 3.20818e+07, 108558, 10839.2]
},
{
"len_haystack": 1048576,
"len_needle": 4096,
"align_haystack": 0,
"align_needle": 0,
"fail": 1,
"desc": "Difficult bruteforce needle",
"timings": [2.69767e+09, 100797, 2.08535e+09, 495706, 82666.9]
}
PS: I don't have an AVX512 capable machine to verify these issues, but
skimming through the code it does seem to follow what Wilco has
described.
Reviewed-by: Noah Goldstein <goldstein.w.n@gmail.com>
master
4 changed files with 4 additions and 248 deletions
@ -1,218 +0,0 @@ |
|||
/* strstr optimized with 512-bit AVX-512 instructions
   Copyright (C) 2022-2024 Free Software Foundation, Inc.
   This file is part of the GNU C Library.

   The GNU C Library is free software; you can redistribute it and/or
   modify it under the terms of the GNU Lesser General Public
   License as published by the Free Software Foundation; either
   version 2.1 of the License, or (at your option) any later version.

   The GNU C Library is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   Lesser General Public License for more details.

   You should have received a copy of the GNU Lesser General Public
   License along with the GNU C Library; if not, see
   <https://www.gnu.org/licenses/>.  */
|
|||
|
|||
#include <immintrin.h> |
|||
#include <inttypes.h> |
|||
#include <stdbool.h> |
|||
#include <string.h> |
|||
|
|||
/* All 64 lanes selected in a 64-bit byte mask.  */
#define FULL_MMASK64 0xffffffffffffffff
/* The value 1 as a 64-bit unsigned constant, used for mask arithmetic.  */
#define ONE_64BIT 0x1ull
/* Number of bytes held by one ZMM register.  */
#define ZMM_SIZE_IN_BYTES 64
/* Page size used by the page-crossing checks in this file.  */
#define PAGESIZE 4096

/* Scalar stand-ins for mask-register operations.  The whole expansion of
   cvtmask64_u64 is parenthesized so that the cast cannot bind to a
   neighbouring operator (sizeof, postfix operators) at the use site.  */
#define cvtmask64_u64(...) ((uint64_t) (__VA_ARGS__))
#define kshiftri_mask64(x, y) ((x) >> (y))
#define kand_mask64(x, y) ((x) & (y))
|||
|
|||
/*
 Returns the index of the first edge within the needle, returns 0 if no edge
 is found. Example: 'ab' is the first edge in 'aaaaaaaaaabaarddg'

 An "edge" is the first position IND such that NED[IND] differs from
 NED[IND + 1] and NED[IND + 1] is not the terminating null byte.  A return
 value of 0 is therefore ambiguous: it means either that the edge is at
 index 0 or that every byte of the needle is the same.
 */
static inline size_t
find_edge_in_needle (const char *ned)
{
  /* An empty needle has no NED[1]; do not read past its terminator.  */
  if (ned[0] == '\0')
    return 0;

  size_t ind = 0;
  while (ned[ind + 1] != '\0')
    {
      if (ned[ind] != ned[ind + 1])
        return ind;
      ind = ind + 1;
    }
  return 0;
}
|||
|
|||
/*
 Compare needle with haystack byte by byte at specified location

 Starting at offset IND of the needle, checks that every remaining needle
 byte equals the haystack byte at HAY_INDEX plus the same offset.  Returns
 true once the needle's terminating null byte is reached, false on the
 first mismatch.  A haystack that ends early mismatches on its own null
 byte, so the comparison never continues past the end of the haystack.
 */
static inline bool
verify_string_match (const char *hay, const size_t hay_index, const char *ned,
                     size_t ind)
{
  while (ned[ind] != '\0')
    {
      if (ned[ind] != hay[hay_index + ind])
        return false;
      ind = ind + 1;
    }
  return true;
}
|||
|
|||
/*
|
|||
Compare needle with haystack at specified location. The first 64 bytes are |
|||
compared using a ZMM register. |
|||
*/ |
|||
static inline bool |
|||
verify_string_match_avx512 (const char *hay, const size_t hay_index, |
|||
const char *ned, const __mmask64 ned_mask, |
|||
const __m512i ned_zmm) |
|||
{ |
|||
/* check first 64 bytes using zmm and then scalar */ |
|||
__m512i hay_zmm = _mm512_loadu_si512 (hay + hay_index); // safe to do so
|
|||
__mmask64 match = _mm512_mask_cmpneq_epi8_mask (ned_mask, hay_zmm, ned_zmm); |
|||
if (match != 0x0) // failed the first few chars
|
|||
return false; |
|||
else if (ned_mask == FULL_MMASK64) |
|||
return verify_string_match (hay, hay_index, ned, ZMM_SIZE_IN_BYTES); |
|||
return true; |
|||
} |
|||
|
|||
char * |
|||
__strstr_avx512 (const char *haystack, const char *ned) |
|||
{ |
|||
char first = ned[0]; |
|||
if (first == '\0') |
|||
return (char *)haystack; |
|||
if (ned[1] == '\0') |
|||
return (char *)strchr (haystack, ned[0]); |
|||
|
|||
size_t edge = find_edge_in_needle (ned); |
|||
|
|||
/* ensure haystack is as long as the pos of edge in needle */ |
|||
for (int ii = 0; ii < edge; ++ii) |
|||
{ |
|||
if (haystack[ii] == '\0') |
|||
return NULL; |
|||
} |
|||
|
|||
/*
|
|||
Load 64 bytes of the needle and save it to a zmm register |
|||
Read one cache line at a time to avoid loading across a page boundary |
|||
*/ |
|||
__mmask64 ned_load_mask = _bzhi_u64 ( |
|||
FULL_MMASK64, 64 - ((uintptr_t) (ned) & 63)); |
|||
__m512i ned_zmm = _mm512_maskz_loadu_epi8 (ned_load_mask, ned); |
|||
__mmask64 ned_nullmask |
|||
= _mm512_mask_testn_epi8_mask (ned_load_mask, ned_zmm, ned_zmm); |
|||
|
|||
if (__glibc_unlikely (ned_nullmask == 0x0)) |
|||
{ |
|||
ned_zmm = _mm512_loadu_si512 (ned); |
|||
ned_nullmask = _mm512_testn_epi8_mask (ned_zmm, ned_zmm); |
|||
ned_load_mask = ned_nullmask ^ (ned_nullmask - ONE_64BIT); |
|||
if (ned_nullmask != 0x0) |
|||
ned_load_mask = ned_load_mask >> 1; |
|||
} |
|||
else |
|||
{ |
|||
ned_load_mask = ned_nullmask ^ (ned_nullmask - ONE_64BIT); |
|||
ned_load_mask = ned_load_mask >> 1; |
|||
} |
|||
const __m512i ned0 = _mm512_set1_epi8 (ned[edge]); |
|||
const __m512i ned1 = _mm512_set1_epi8 (ned[edge + 1]); |
|||
|
|||
/*
|
|||
Read the bytes of haystack in the current cache line |
|||
*/ |
|||
size_t hay_index = edge; |
|||
__mmask64 loadmask = _bzhi_u64 ( |
|||
FULL_MMASK64, 64 - ((uintptr_t) (haystack + hay_index) & 63)); |
|||
/* First load is a partial cache line */ |
|||
__m512i hay0 = _mm512_maskz_loadu_epi8 (loadmask, haystack + hay_index); |
|||
/* Search for NULL and compare only till null char */ |
|||
uint64_t nullmask |
|||
= cvtmask64_u64 (_mm512_mask_testn_epi8_mask (loadmask, hay0, hay0)); |
|||
uint64_t cmpmask = nullmask ^ (nullmask - ONE_64BIT); |
|||
cmpmask = cmpmask & cvtmask64_u64 (loadmask); |
|||
/* Search for the 2 characters of needle */ |
|||
__mmask64 k0 = _mm512_cmpeq_epi8_mask (hay0, ned0); |
|||
__mmask64 k1 = _mm512_cmpeq_epi8_mask (hay0, ned1); |
|||
k1 = kshiftri_mask64 (k1, 1); |
|||
/* k2 masks tell us if both chars from needle match */ |
|||
uint64_t k2 = cvtmask64_u64 (kand_mask64 (k0, k1)) & cmpmask; |
|||
/* For every match, search for the entire needle for a full match */ |
|||
while (k2) |
|||
{ |
|||
uint64_t bitcount = _tzcnt_u64 (k2); |
|||
k2 = _blsr_u64 (k2); |
|||
size_t match_pos = hay_index + bitcount - edge; |
|||
if (((uintptr_t) (haystack + match_pos) & (PAGESIZE - 1)) |
|||
< PAGESIZE - 1 - ZMM_SIZE_IN_BYTES) |
|||
{ |
|||
/*
|
|||
* Use vector compare as long as you are not crossing a page |
|||
*/ |
|||
if (verify_string_match_avx512 (haystack, match_pos, ned, |
|||
ned_load_mask, ned_zmm)) |
|||
return (char *)haystack + match_pos; |
|||
} |
|||
else |
|||
{ |
|||
if (verify_string_match (haystack, match_pos, ned, 0)) |
|||
return (char *)haystack + match_pos; |
|||
} |
|||
} |
|||
/* We haven't checked for potential match at the last char yet */ |
|||
haystack = (const char *)(((uintptr_t) (haystack + hay_index) | 63)); |
|||
hay_index = 0; |
|||
|
|||
/*
|
|||
Loop over one cache line at a time to prevent reading over page |
|||
boundary |
|||
*/ |
|||
__m512i hay1; |
|||
while (nullmask == 0) |
|||
{ |
|||
hay0 = _mm512_loadu_si512 (haystack + hay_index); |
|||
hay1 = _mm512_load_si512 (haystack + hay_index |
|||
+ 1); // Always 64 byte aligned
|
|||
nullmask = cvtmask64_u64 (_mm512_testn_epi8_mask (hay1, hay1)); |
|||
/* Compare only till null char */ |
|||
cmpmask = nullmask ^ (nullmask - ONE_64BIT); |
|||
k0 = _mm512_cmpeq_epi8_mask (hay0, ned0); |
|||
k1 = _mm512_cmpeq_epi8_mask (hay1, ned1); |
|||
/* k2 masks tell us if both chars from needle match */ |
|||
k2 = cvtmask64_u64 (kand_mask64 (k0, k1)) & cmpmask; |
|||
/* For every match, compare full strings for potential match */ |
|||
while (k2) |
|||
{ |
|||
uint64_t bitcount = _tzcnt_u64 (k2); |
|||
k2 = _blsr_u64 (k2); |
|||
size_t match_pos = hay_index + bitcount - edge; |
|||
if (((uintptr_t) (haystack + match_pos) & (PAGESIZE - 1)) |
|||
< PAGESIZE - 1 - ZMM_SIZE_IN_BYTES) |
|||
{ |
|||
/*
|
|||
* Use vector compare as long as you are not crossing a page |
|||
*/ |
|||
if (verify_string_match_avx512 (haystack, match_pos, ned, |
|||
ned_load_mask, ned_zmm)) |
|||
return (char *)haystack + match_pos; |
|||
} |
|||
else |
|||
{ |
|||
/* Compare byte by byte */ |
|||
if (verify_string_match (haystack, match_pos, ned, 0)) |
|||
return (char *)haystack + match_pos; |
|||
} |
|||
} |
|||
hay_index += ZMM_SIZE_IN_BYTES; |
|||
} |
|||
return NULL; |
|||
} |
|||
Loading…
Reference in new issue