mikejsavage.co.uk / blog

14 Aug 2017 / Rust performance: finishing the job

Today I saw a story about profiling and optimising some Rust code. It's a nice little intro to perf, but the author stops early and leaves quite a lot on the table. The code he ends up with is:

pub fn get(&self, bwt: &BWTSlice, r: usize, a: u8) -> usize {
        let i = r / self.k;

        let mut count = 0;

        // count all the matching bytes b/t the closest checkpoint and our desired lookup
        for idx in (i * self.k) + 1 .. r + 1 {
                if bwt[idx] == a {
                        count += 1;
                }
        }

        // return the sampled checkpoint for this character + the manual count we just did
        self.occ[i][a as usize] + count
}

Let's factor out the hot loop to make it dead clear what's going on:

// BWTSlice is just [u8]
pub fn count(bwt: &[u8], a: u8) -> usize {
        let mut c = 0;
        for x in bwt {
                if *x == a {
                        c += 1;
                }
        }
        c
}

pub fn get(&self, bwt: &BWTSlice, r: usize, a: u8) -> usize {
        let i = r / self.k;
        self.occ[i][a as usize] + count(&bwt[(i * self.k) + 1 .. r + 1], a)
}

It's just counting the number of times a occurs in the array bwt. This code is totally reasonable and if it didn't show up in the profiler you could just leave it at that, but as we'll see it's not optimal.

BTW I want to be clear that I have no idea what context this code is used in or whether there are higher level code changes that would make a bigger difference. I just want to focus on optimising the snippet from the original post.

SIMD

x86 has had instructions to perform the same basic operation on more than one piece of data for quite a while now. For example there are instructions that operate on four floats at a time, instructions that operate on a pair of doubles, instructions that operate on 16 8-bit ints, etc. Generally, these are called Single Instruction Multiple Data instructions, and on x86 they fall under the MMX/SSE/AVX instruction sets. Since the loop we want to optimise is doing the same operation to every element in the array independently of one another, it seems like a good candidate for vectorisation (which is what we call rewriting normal code to use SIMD instructions).
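To make "four floats at a time" concrete, here is a minimal sketch using the SSE intrinsics. The helper name add4 is made up for illustration, and it assumes an x86 compiler that ships xmmintrin.h:

```cpp
#include <cassert>
#include <xmmintrin.h> // SSE intrinsics, x86 only

// Hypothetical helper (not from the original post): adds four pairs of
// floats with one ADDPS instead of four scalar additions.
void add4( const float * a, const float * b, float * out ) {
        assert( a != nullptr && b != nullptr && out != nullptr );
        __m128 va = _mm_loadu_ps( a ); // load 4 floats, no alignment required
        __m128 vb = _mm_loadu_ps( b );
        _mm_storeu_ps( out, _mm_add_ps( va, vb ) ); // 4 additions in one instruction
}
```

The integer instructions used later in this post follow the same pattern, just with 16 bytes per register instead of four floats.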

Rewrite It In C++

I would have liked to have optimised the Rust code, and it is totally possible, but the benchmarking code for rust-bio does not compile with stable Rust, nor does the Rust SIMD library. There's not much I'd rather do less than spend ages dicking about downloading and installing other people's software to try and fix something that should really not be broken to begin with, so let's begin by rewriting the loop in C++. This is unfortunate because my timings aren't comparable to the numbers in the original blog post, and I'm not able to get numbers for the Rust version.

size_t count_slow( u8 * haystack, size_t n, u8 needle ) {
        size_t c = 0;
        for( size_t i = 0; i < n; i++ ) {
                if( haystack[ i ] == needle ) {
                        c++;
                }
        }
        return c;
}

As a fairly contrived benchmark, let's use this to count how many times the letter 'o' appears in a string containing 10000 Lorem Ipsums. To actually perform the tests I disabled CPU frequency scaling (this saves about 1.5ms!), wrapped the code in some timing boilerplate, ran the code a few times (by hand, so not that many), and recorded the fastest result. See the end of the post for the full code listing if you want to try it yourself.
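For a rough idea of what "timing boilerplate" can look like, here is a sketch that runs the scalar loop several times and keeps the fastest run. It is not the harness from the full listing: fastest_ms, its parameters and the use of std::chrono are all assumptions made for illustration, and uint8_t stands in for u8.

```cpp
#include <cassert>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <string>

// the scalar loop, repeated here so the sketch stands alone
size_t count_slow( const uint8_t * haystack, size_t n, uint8_t needle ) {
        size_t c = 0;
        for( size_t i = 0; i < n; i++ ) {
                if( haystack[ i ] == needle ) {
                        c++;
                }
        }
        return c;
}

// Hypothetical helper: times count_slow `runs` times, returns the fastest
// run in milliseconds and writes the last count to *result.
double fastest_ms( const std::string & text, uint8_t needle, int runs, size_t * result ) {
        assert( runs > 0 && result != nullptr );
        double best = -1.0;
        for( int i = 0; i < runs; i++ ) {
                auto before = std::chrono::steady_clock::now();
                size_t c = count_slow( ( const uint8_t * ) text.data(), text.size(), needle );
                auto after = std::chrono::steady_clock::now();

                double ms = std::chrono::duration< double, std::milli >( after - before ).count();
                if( best < 0.0 || ms < best ) {
                        best = ms; // keep the fastest run, like the numbers in this post
                }
                *result = c; // hand the count back so the work is actually used
        }
        return best;
}
```

Taking the minimum rather than the mean is a deliberate choice: slow runs are mostly noise from the rest of the system, so the fastest run is the closest thing to what the code itself costs.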

If we build with gcc -O2 -march=native (we really only need -mpopcnt. FWIW -march=native helps the scalar code more than it helps mine) the benchmark completes in 21.3ms. If we build with -O3 the autovectoriser kicks in and the benchmark completes in 7.09ms. Just to reiterate, it makes no sense to compare these with the numbers in the original article, but I expect if I was able to compile the Rust version it would be about the same as -O2.

Vectorising by hand

The algorithm we are going to use is as follows:

  1. The instructions we want to use only work on data that's aligned to a 16 byte boundary, so we need to run the slow loop a few times if haystack is not aligned (this is called "loop peeling")
  2. For each block of 16 bytes, we can compare all of them with needle at once to get a mask with the PCMPEQB instruction. Note that matches are set to 0xff (eight ones), rather than just a single one like C comparisons.
  3. We can count the number of ones in the mask, and divide it by eight to get the number of needles in that 16 byte block. x86 has POPCNT to count the number of ones in a 64 bit number, so we need to call that twice per 16 byte block.
  4. When we're down to less than 16 bytes remaining, fall back to the slow loop again.
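Steps 2 and 3 can be tried in isolation on a single 16 byte block. This is only a sketch under a few assumptions: count_block16 is a made-up name, uint8_t stands in for u8, __builtin_popcountll is the GCC/Clang builtin rather than whatever popcnt64 wraps in the full listing, and it uses an unaligned load so the block doesn't need peeling. Like the real code it needs x86-64 for _mm_cvtsi128_si64.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <emmintrin.h> // SSE2 intrinsics, x86 only

// Hypothetical helper: counts occurrences of needle in exactly 16 bytes.
size_t count_block16( const uint8_t * block, uint8_t needle ) {
        __m128i needles = _mm_set1_epi8( char( needle ) ); // 16 copies of needle
        __m128i chunk = _mm_loadu_si128( ( const __m128i * ) block ); // unaligned load
        __m128i cmp = _mm_cmpeq_epi8( needles, chunk ); // PCMPEQB: 0xff per match, 0x00 otherwise

        // split the 128 bit mask into two 64 bit halves for popcount
        uint64_t lo = uint64_t( _mm_cvtsi128_si64( cmp ) );
        uint64_t hi = uint64_t( _mm_cvtsi128_si64( _mm_unpackhi_epi64( cmp, cmp ) ) );

        // every matching byte contributes eight set bits
        size_t bits = size_t( __builtin_popcountll( lo ) ) + size_t( __builtin_popcountll( hi ) );
        assert( bits % 8 == 0 );
        return bits / 8;
}
```

For the 16 bytes "hello world foo!" and needle 'o', the mask has four 0xff bytes, so the two popcounts add up to 32 and the function returns 4.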

Unsurprisingly the implementation is quite a bit trickier:

size_t count_fast( const u8 * haystack, size_t n, u8 needle ) {
        const u8 * one_past_end = haystack + n;
        size_t c = 0;

        // peel
        while( uintptr_t( haystack ) % 16 != 0 && haystack < one_past_end ) {
                if( *haystack == needle ) {
                        c++;
                }
                haystack++;
        }

        // haystack is now aligned to 16 bytes
        // loop as long as we have 16 bytes left in haystack
        __m128i needles = _mm_set1_epi8( needle );
        while( one_past_end - haystack >= 16 ) {
                __m128i chunk = _mm_load_si128( ( const __m128i * ) haystack );
                __m128i cmp = _mm_cmpeq_epi8( needles, chunk );
                u64 pophi = popcnt64( _mm_cvtsi128_si64( _mm_unpackhi_epi64( cmp, cmp ) ) );
                u64 poplo = popcnt64( _mm_cvtsi128_si64( cmp ) );
                c += ( pophi + poplo ) / 8;
                haystack += 16;
        }

        // remainder
        while( haystack < one_past_end ) {
                if( *haystack == needle ) {
                        c++;
                }
                haystack++;
        }

        return c;
}

But it's totally worth it, because the new code runs in 2.74ms, which is 13% of the -O2 time and 39% of the -O3 time!

Unrolling

Since the loop body is so short, evaluating the loop condition ends up consuming a non-negligible amount of time per iteration. The simplest fix for this is to check whether there are 32 bytes remaining instead, and run the loop body twice per iteration:

while( one_past_end - haystack >= 32 ) {
        {
                __m128i chunk = _mm_load_si128( ( const __m128i * ) haystack );
                haystack += 16; // note I also moved this up. seems to save some microseconds
                __m128i cmp = _mm_cmpeq_epi8( needles, chunk );
                u64 pophi = popcnt64( _mm_cvtsi128_si64( _mm_unpackhi_epi64( cmp, cmp ) ) );
                u64 poplo = popcnt64( _mm_cvtsi128_si64( cmp ) );
                c += ( pophi + poplo ) / 8;
        }
        {
                __m128i chunk = _mm_load_si128( ( const __m128i * ) haystack );
                haystack += 16;
                __m128i cmp = _mm_cmpeq_epi8( needles, chunk );
                u64 pophi = popcnt64( _mm_cvtsi128_si64( _mm_unpackhi_epi64( cmp, cmp ) ) );
                u64 poplo = popcnt64( _mm_cvtsi128_si64( cmp ) );
                c += ( pophi + poplo ) / 8;
        }
}

It's just a little bit faster, completing the benchmark in 2.45ms, which is 89% of the vectorised loop, 12% of -O2, and 35% of -O3.

Conclusion

For reference, here's a little table of results:

Version       Time      % of -O2   % of -O3
Scalar -O2    21.3ms    100%       -
Scalar -O3    7.09ms    33%        100%
Vectorised    2.74ms    13%        39%
Unrolled      2.45ms    12%        35%

Hopefully this post has served as a decent introduction to vectorisation, and has shown you that not only can you beat the compiler, but you can really do a lot better than the compiler without too much difficulty.

I am no expert on this so it's very possible that there's an even better approach. I should really try writing the loop in assembly by hand just to check the compiler hasn't done anything derpy.

Full code

For reference, if you want to try it yourself. It should compile and run on any x86 machine; let me know if you have problems.