Today I saw a story about profiling and optimising some Rust code. It's a nice little intro to perf, but the author stops early and leaves quite a lot on the table. The code he ends up with is:
```rust
pub fn get(&self, bwt: &BWTSlice, r: usize, a: u8) -> usize {
    let i = r / self.k;
    let mut count = 0;
    // count all the matching bytes b/t the closest checkpoint and our desired lookup
    for idx in (i * self.k) + 1 .. r + 1 {
        if bwt[idx] == a {
            count += 1;
        }
    }
    // return the sampled checkpoint for this character + the manual count we just did
    self.occ[i][a as usize] + count
}
```
Let's factor out the hot loop to make it dead clear what's going on:
```rust
// BWTSlice is just [u8]
pub fn count(bwt: &[u8], a: u8) -> usize {
    let mut c = 0;
    for &x in bwt {
        if x == a {
            c += 1;
        }
    }
    c
}

pub fn get(&self, bwt: &BWTSlice, r: usize, a: u8) -> usize {
    let i = r / self.k;
    self.occ[i][a as usize] + count(&bwt[(i * self.k) + 1 .. r + 1], a)
}
```
It's just counting the number of times `a` occurs in the array `bwt`.
This code is totally reasonable, and if it didn't show up in the profiler you could just leave it at that, but as we'll see it's not optimal.

BTW I want to be clear that I have no idea what context this code is used in or whether there are higher level code changes that would make a bigger difference; I just want to focus on optimising the snippet from the original post.
x86 has had instructions to perform the same basic operation on more than one piece of data for quite a while now. For example there are instructions that operate on four floats at a time, instructions that operate on a pair of doubles, instructions that operate on sixteen 8-bit ints, and so on. Generally these are called SIMD instructions, and on x86 they fall under the MMX/SSE/AVX instruction sets. Since the loop we want to optimise performs the same operation on every element of the array independently, it looks like a good candidate for vectorisation (which is what we call rewriting normal code to use SIMD instructions).
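To make that concrete, here's a standalone toy (my own illustration, not part of the benchmark) that adds four floats at a time with a single SSE instruction:

```cpp
#include <xmmintrin.h> // SSE intrinsics
#include <stdio.h>

int main() {
    // pack four floats into each 128 bit register
    __m128 a = _mm_set_ps( 4.0f, 3.0f, 2.0f, 1.0f );
    __m128 b = _mm_set_ps( 40.0f, 30.0f, 20.0f, 10.0f );

    // one ADDPS instruction performs all four additions at once
    __m128 sum = _mm_add_ps( a, b );

    float out[ 4 ];
    _mm_storeu_ps( out, sum );
    printf( "%g %g %g %g\n", out[ 0 ], out[ 1 ], out[ 2 ], out[ 3 ] ); // 11 22 33 44
    return 0;
}
```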
I would have liked to optimise the Rust code, and it is totally possible, but the benchmarking code for rust-bio does not compile with stable Rust, and neither does the Rust SIMD library. There are few things I'd less rather do than spend ages dicking about downloading and installing other people's software to try and fix something that should really not be broken to begin with, so let's begin by rewriting the loop in C++. This is unfortunate because my timings aren't comparable to the numbers in the original blog post, and I'm not able to get numbers for the Rust version.
```cpp
size_t count_slow( u8 * haystack, size_t n, u8 needle ) {
    size_t c = 0;
    for( size_t i = 0; i < n; i++ ) {
        if( haystack[ i ] == needle ) {
            c++;
        }
    }
    return c;
}
```
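As an aside, this is exactly the operation `std::count` from `<algorithm>` performs, so the scalar baseline could also be written like this (the wrapper name is mine):

```cpp
#include <algorithm>
#include <stddef.h>

typedef unsigned char u8; // as in the full listing

size_t count_std( const u8 * haystack, size_t n, u8 needle ) {
    // std::count returns a signed difference type, hence the cast
    return ( size_t ) std::count( haystack, haystack + n, needle );
}
```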
As a fairly contrived benchmark, let's use `count_slow` to count how many times the letter 'o' appears in a string containing 10000 Lorem Ipsums. To actually perform the tests I disabled CPU frequency scaling (this saves about 1.5ms!), wrapped the code in some timing boilerplate, ran the code a few times (by hand, so not that many), and recorded the fastest result. See the end of the post for the full code listing if you want to try it yourself.
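The timing boilerplate itself isn't interesting. A minimal sketch of the kind of harness I mean, using `std::chrono` (the names here are mine, not taken from the full listing), might look like:

```cpp
#include <chrono>
#include <stdio.h>
#include <stddef.h>

// hypothetical harness: run f() once and report how long it took
template< typename F >
size_t time_it( const char * name, F f ) {
    auto start = std::chrono::steady_clock::now();
    size_t result = f(); // f returns the count, so printing it below stops the call being optimised away
    auto end = std::chrono::steady_clock::now();
    double ms = std::chrono::duration< double, std::milli >( end - start ).count();
    printf( "%s: %zu matches in %.2fms\n", name, result, ms );
    return result;
}

// usage, assuming data/n hold the Lorem Ipsum string:
// time_it( "count_slow", [&]() { return count_slow( data, n, 'o' ); } );
```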
If we build with `gcc -O2 -march=native` (we really only need `-mpopcnt`; FWIW `-march=native` helps the scalar code more than it helps mine) the benchmark completes in 21.3ms. If we build with `-O3` the autovectoriser kicks in and the benchmark completes in 7.09ms. Just to reiterate, it makes no sense to compare these with the numbers in the original article, but I expect that if I was able to compile the Rust version it would be about the same as `-O2`.
The algorithm we are going to use is as follows:

1. Handle bytes one at a time until `haystack` is aligned to 16 bytes (this is called "loop peeling").
2. Compare 16 bytes of `haystack` against `needle` at once to get a mask with [the PCMPEQB instruction](https://msdn.microsoft.com/en-us/library/bz5xk21a(v=vs.90).aspx). Note that matches are set to `0xff` (eight ones), rather than just a single one like C comparisons.
3. Popcount the mask to count the number of `needle`s in that 16 byte block; this step is sketched on its own after the list. x86 has `POPCNT` to count the number of ones in a 64 bit number, so we need to call that twice per 16 byte block. (BTW see the followup post for a better approach)
4. Handle any remaining bytes one at a time.
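Isolated from all the alignment handling, steps 2 and 3 on a single 16 byte block look something like this (a standalone sketch with my own names and typedefs, not code from the post):

```cpp
#include <emmintrin.h> // SSE2: _mm_set1_epi8, _mm_cmpeq_epi8, ...
#include <nmmintrin.h> // SSE4.2: _mm_popcnt_u64
#include <stddef.h>

typedef unsigned char u8;
typedef unsigned long long u64;

// count occurrences of needle in a single aligned 16 byte block
size_t count_block( const u8 * block, u8 needle ) {
    __m128i needles = _mm_set1_epi8( needle );
    __m128i chunk = _mm_load_si128( ( const __m128i * ) block );
    // matching bytes become 0xff, everything else becomes 0x00
    __m128i cmp = _mm_cmpeq_epi8( needles, chunk );
    u64 lo = _mm_cvtsi128_si64( cmp );
    u64 hi = _mm_cvtsi128_si64( _mm_unpackhi_epi64( cmp, cmp ) );
    // each match contributes eight set bits, so divide the total popcount by 8
    return ( _mm_popcnt_u64( lo ) + _mm_popcnt_u64( hi ) ) / 8;
}
```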
Unsurprisingly the implementation is quite a bit trickier:
```cpp
size_t count_fast( const u8 * haystack, size_t n, u8 needle ) {
    const u8 * one_past_end = haystack + n;
    size_t c = 0;

    // peel: handle single bytes until haystack is aligned to 16 bytes
    while( uintptr_t( haystack ) % 16 != 0 && haystack < one_past_end ) {
        if( *haystack == needle ) {
            c++;
        }
        haystack++;
    }

    // haystack is now aligned to 16 bytes
    // loop as long as we have 16 bytes left in haystack
    __m128i needles = _mm_set1_epi8( needle );
    while( haystack < one_past_end - 16 ) {
        __m128i chunk = _mm_load_si128( ( const __m128i * ) haystack );
        __m128i cmp = _mm_cmpeq_epi8( needles, chunk );
        u64 pophi = popcnt64( _mm_cvtsi128_si64( _mm_unpackhi_epi64( cmp, cmp ) ) );
        u64 poplo = popcnt64( _mm_cvtsi128_si64( cmp ) );
        // each matching byte contributes eight set bits, so divide by 8
        c += ( pophi + poplo ) / 8;
        haystack += 16;
    }

    // remainder: handle the last few bytes one at a time
    while( haystack < one_past_end ) {
        if( *haystack == needle ) {
            c++;
        }
        haystack++;
    }

    return c;
}
```
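All that fiddly pointer work is easy to get subtly wrong, so before timing anything it's worth checking the vectorised version agrees with the scalar one. A quick hypothetical check, assuming both functions are in scope:

```cpp
#include <assert.h>
#include <string.h>

void test_count_fast() {
    const char * s = "Lorem ipsum dolor sit amet, consectetur adipiscing elit";
    size_t n = strlen( s );
    // the two implementations must agree for every possible needle
    for( int needle = 0; needle < 256; needle++ ) {
        assert( count_fast( ( const u8 * ) s, n, ( u8 ) needle )
            == count_slow( ( u8 * ) s, n, ( u8 ) needle ) );
    }
}
```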
But it's totally worth it, because the new code runs in 2.74ms, which is 13% of the time of `-O2`, and 39% of the time of `-O3`!
Since the loop body is so short, evaluating the loop condition ends up consuming a non-negligible amount of time per iteration. The simplest fix for this is to check whether there are at least 32 bytes remaining instead, and run the loop body twice per iteration:
```cpp
while( haystack < one_past_end - 32 ) {
    {
        __m128i chunk = _mm_load_si128( ( const __m128i * ) haystack );
        haystack += 16; // note I also moved this up. seems to save some microseconds
        __m128i cmp = _mm_cmpeq_epi8( needles, chunk );
        u64 pophi = popcnt64( _mm_cvtsi128_si64( _mm_unpackhi_epi64( cmp, cmp ) ) );
        u64 poplo = popcnt64( _mm_cvtsi128_si64( cmp ) );
        c += ( pophi + poplo ) / 8;
    }
    {
        __m128i chunk = _mm_load_si128( ( const __m128i * ) haystack );
        haystack += 16;
        __m128i cmp = _mm_cmpeq_epi8( needles, chunk );
        u64 pophi = popcnt64( _mm_cvtsi128_si64( _mm_unpackhi_epi64( cmp, cmp ) ) );
        u64 poplo = popcnt64( _mm_cvtsi128_si64( cmp ) );
        c += ( pophi + poplo ) / 8;
    }
}
```
It's just a little bit faster, completing the benchmark in 2.45ms, which is 89% of the vectorised loop, 12% of `-O2`, and 35% of `-O3`.
For reference, here's a little table of results:
| Version | Time | % of -O2 | % of -O3 |
|---|---|---|---|
| Scalar -O2 | 21.3ms | 100% | - |
| Scalar -O3 | 7.09ms | 33% | 100% |
| Vectorised | 2.74ms | 13% | 39% |
| Unrolled | 2.45ms | 12% | 35% |
Hopefully this post has served as a decent introduction to vectorisation, and has shown you that not only can you beat the compiler, you can do a lot better than the compiler without too much difficulty.
I am no expert on this, so it's very possible that there's an even better approach (sure enough, there is). I should really try writing the loop in assembly by hand just to check the compiler hasn't done anything derpy.