4-way Keccak-f[1600] using AVX2 to speed up Dilithium #111

Merged (22 commits) on May 21, 2020

Conversation

@bwesterb (Member) commented May 8, 2020

Will also be useful to speed up XMSS[MT], SPHINCS+, etc.

Non-AVX2:

BenchmarkPermutationFunction-8   	 3594723	       326 ns/op	 613.13 MB/s

AVX2:

BenchmarkF1600x4-8   	 2926183	       399 ns/op

The latter computes four permutations at once, so it's a 3.2x speedup (four times the work in 399 ns versus 326 ns).
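For concreteness, a minimal illustrative sketch (not the circl API) of the state layout the AVX2 code uses: the four 25-word Keccak states are interleaved into a single [100]uint64, matching the 32-byte stride (Disp: 32 * offset) in the generator further down, so every 256-bit load or store touches the same word of all four states at once.

func word(state *[100]uint64, lane, w int) *uint64 {
	// Word w (0..24) of lane (0..3) lives at index 4*w + lane.
	return &state[4*w+lane]
}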

Applied to Dilithium (mode3) we get (new):

BenchmarkSkUnpack-8                      	   37950	     30844 ns/op
BenchmarkPkUnpack-8                      	   39169	     30442 ns/op
BenchmarkVerify-8                        	   76672	     15513 ns/op
BenchmarkSign-8                          	   10000	    125782 ns/op
BenchmarkGenerateKey-8                   	   17419	     68635 ns/op
BenchmarkPublicFromPrivate-8             	  134430	      9250 ns/op

old:

BenchmarkSkUnpack-8            	   18980	     61840 ns/op
BenchmarkPkUnpack-8            	   19743	     60087 ns/op
BenchmarkVerify-8              	   39998	     30421 ns/op
BenchmarkSign-8                	    5420	    228631 ns/op
BenchmarkGenerateKey-8         	   12037	     99672 ns/op
BenchmarkPublicFromPrivate-8   	  107148	     11512 ns/op

The non-optimized C reference implementation:

BenchmarkKeygen-8   	   10000	    109216 ns/op
BenchmarkSign-8     	    2206	    545616 ns/op
BenchmarkVerify-8   	   10000	    108419 ns/op

The AVX2-optimized C implementation:

BenchmarkKeygen-8   	   36510	     32242 ns/op
BenchmarkSign-8     	   10000	    117060 ns/op
BenchmarkVerify-8   	   31605	     34123 ns/op

See #113

@bwesterb bwesterb requested a review from armfazh May 8, 2020 18:34
@mmcloughlin (Contributor) left a comment:

This is awesome! So exciting to see you using avo 😄

return Mem{Base: state_ptr, Disp: 32 * offset}
}

rc_ptr := Load(Param("rc"), GP64())
Contributor:

If you don't mind duplicating the round constants you could also define them as a DATA section.

https://github.com/mmcloughlin/avo/tree/master/examples/data
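For illustration, a rough sketch of what that could look like in the generator; keccakRC and the label name are placeholders, not code from this PR:

// Emit the 24 round constants as a read-only global, so the generated
// function would no longer need the rc parameter.
rcData := GLOBL("keccakRC", RODATA|NOPTR)
for i, c := range keccakRC { // keccakRC: []uint64 holding the 24 Keccak round constants
	DATA(8*i, U64(c))
}
// In the unrolled rounds a constant could then be broadcast directly, e.g.
//   VPBROADCASTQ(rcData.Offset(8*round), rcReg)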

// We want state to be 16 byte aligned. Go only guarantees an
// (array of) uint64 to be aligned on 8 bytes, unless it's heap
// allocated. Thus we do not add the noescape pragma, to ensure that
// state is heap allocated and thus 16 byte aligned.
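For context, a sketch of the Go-side stub being discussed (the signature is taken from the TEXT declaration later in this PR); the point of the comment above is that the //go:noescape directive is deliberately omitted:

// Declared without //go:noescape, so a state pointer passed in from a
// non-inlined caller escapes to the heap, where (per the comment above) it
// ends up 16-byte aligned in practice. Adding the directive could keep the
// state on the stack, which only guarantees 8-byte alignment.
func f1600x4AVX2(state *[100]uint64, rc *[24]uint64)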
Contributor:

Hmm, this feels fragile. Do you need alignment for VMOVDQA, or is there another reason? Have you tried VMOVDQU? I suspect the unaligned version will incur a small or negligible performance hit, probably worth it for correctness.

Contributor:

Do you know how to request aligned memory from the Go compiler?

Member Author:

@mmcloughlin Memory is not just loaded at the start; it's read and written between each round. I think it'll have a measurable impact. (Vector operations that do allow unaligned pointers are generally still faster with aligned pointers.)

@armfazh I don’t think there is good support for it yet.

Contributor:

@bwesterb I made the change to VMOVDQU locally and benchmarked both versions.

$ cat aligned.txt unaligned.txt 
goos: linux
goarch: amd64
pkg: github.com/cloudflare/circl/internal/shakex4
BenchmarkF1600x4-8   	 2147116	       556 ns/op
PASS
ok  	github.com/cloudflare/circl/internal/shakex4	1.762s
goos: linux
goarch: amd64
pkg: github.com/cloudflare/circl/internal/shakex4
BenchmarkF1600x4-8   	 2194231	       546 ns/op
PASS
ok  	github.com/cloudflare/circl/internal/shakex4	1.754s

Unaligned was faster in this quick, uncontrolled test, which is probably noise, but that just goes to show how small the effect is. I encourage you to benchmark this yourself; I think you'll find the performance change (if any) is so small that it's not worth the trouble of forcing Go to give you aligned memory, since, as you say, there is no good support for it.

https://lemire.me/blog/2012/05/31/data-alignment-for-speed-myth-or-reality/

I'm happy to be proven wrong, but I just think you should use the unaligned access until you can demonstrate alignment really helps.

Contributor:

The difference in timing is not observed on current processors and laptops, say anything after the Haswell microarchitecture (roughly 2014). But it will show up on older processors; on those, one would probably see a difference in performance.

Contributor:

256-bit registers appeared with the Sandy Bridge architecture, under the name AVX. Hence, aligned and unaligned loads existed before Haswell.

Contributor:

Yes, but for the purposes of this PR, that doesn't matter. Other instructions in this code require AVX2.

@bwesterb (Member Author) commented May 9, 2020:

I ran my own benchmarks. There doesn't seem to be a big difference between VMOVDQA and VMOVDQU for 16 byte aligned memory. If I change alignment, I get a noticeable slowdown. (I ran the benchmarks interleaved and then sorted them to account for my CPU throttling.)

BenchmarkF1600x4-8               2783442           428 ns/op                     
BenchmarkF1600x4-8               2850756           421 ns/op                     
BenchmarkF1600x4-8               2858080           431 ns/op                     
BenchmarkF1600x4aligned-8        2700562           430 ns/op                     
BenchmarkF1600x4aligned-8        2888199           437 ns/op                     
BenchmarkF1600x4aligned-8        2889950           425 ns/op                     
BenchmarkF1600x4plusEight-8      2475714           487 ns/op                     
BenchmarkF1600x4plusEight-8      2558049           455 ns/op                     
BenchmarkF1600x4plusEight-8      2564606           465 ns/op                     
BenchmarkF1600x4plusFour-8       2580600           463 ns/op                     
BenchmarkF1600x4plusFour-8       2585941           475 ns/op                     
BenchmarkF1600x4plusFour-8       2585942           467 ns/op                     
BenchmarkF1600x4plusOne-8        2505846           477 ns/op                     
BenchmarkF1600x4plusOne-8        2581089           479 ns/op                     
BenchmarkF1600x4plusOne-8        2647069           463 ns/op                     
BenchmarkF1600x4plusTwo-8        2512466           474 ns/op                     
BenchmarkF1600x4plusTwo-8        2519284           470 ns/op                     
BenchmarkF1600x4plusTwo-8        2531073           478 ns/op

The plusN benchmarks are the variants where I misalign the buffer by N bytes. The aligned variant uses VMOVDQA, and the plain one uses VMOVDQU. Code for plusTwo:

func BenchmarkF1600x4plusTwo(b *testing.B) {
	// Over-allocate by two words and offset the pointer by two bytes
	// to misalign the state buffer.
	a := new([102]uint64)

	for i := 0; i < b.N; i++ {
		f1600x4AVX2((*[100]uint64)(unsafe.Pointer(&(*[404]byte)(unsafe.Pointer(a))[2])),
			&shake.RC)
	}
}

and the rest is similar.

Contributor:

Oh awesome, interesting results.

I would still recommend switching to VMOVDQU, but keep the no-noescape technique you are using to get aligned memory from Go. You'll get the performance gain, but it will continue to work if you don't get aligned memory for whatever reason.

Member Author:

The no-noescape approach caused quite a bit of allocator overhead, so instead I opted to allocate a [103]uint64 and offset into it to ensure 32-byte alignment.
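A minimal sketch of that trick (illustrative, not the actual circl code): over-allocating by three words guarantees that a 32-byte-aligned window of 100 words fits somewhere inside the buffer.

import "unsafe"

// alignedState returns a 32-byte-aligned view of 100 words inside buf.
// Three spare words absorb the worst-case 24-byte misalignment of an
// 8-byte-aligned allocation relative to a 32-byte boundary.
func alignedState(buf *[103]uint64) *[100]uint64 {
	addr := uintptr(unsafe.Pointer(&buf[0]))
	off := ((32 - addr%32) % 32) / 8 // 0, 1, 2 or 3 words of padding
	return (*[100]uint64)(unsafe.Pointer(&buf[off]))
}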

func main() {
	ConstraintExpr("amd64")

	TEXT("f1600x4AVX2", 0, "func(state *[100]uint64, rc *[24]uint64)")
Contributor:

Doesn't look like you are using stack space? In that case you can use the NOSPLIT flag.

@bwesterb (Member Author) commented May 8, 2020:

Yes, you’re right. Thanks.

@bwesterb (Member Author):

I've tried not unrolling the round loop, but simply looping over the 24 rounds:

func f1600x4AVX2small() {
	TEXT("f1600x4AVX2small", NOSPLIT, "func(state *[100]uint64, rc *[24]uint64)")

	state_ptr := Load(Param("state"), GP64())
	state := func(offset int) Op {
		return Mem{Base: state_ptr, Disp: 32 * offset}
	}

	rc_ptr := Load(Param("rc"), GP64())

	round := GP64()
	XORQ(round, round)

	Label("loop")

	// Compute parities: p[i] = a[i] ^ a[i + 5] ^ ... ^ a[i + 20].
	p := []Op{YMM(), YMM(), YMM(), YMM(), YMM()}
	for i := 0; i < 5; i++ {
		VMOVDQA(state(i), p[i])
	}
	for j := 1; j < 5; j++ {
		for i := 0; i < 5; i++ {
			VPXOR(state(5*j+i), p[i], p[i])
		}
	}

	// Rotate and xor parities: d[i] = rotate_left(p[i+1], 1) ^ p[i-1]
	t := []Op{YMM(), YMM(), YMM(), YMM(), YMM()}
	d := []Op{YMM(), YMM(), YMM(), YMM(), YMM()}
	for i := 0; i < 5; i++ {
		VPSLLQ(U8(1), p[(i+1)%5], t[i])
	}
	for i := 0; i < 5; i++ {
		VPSRLQ(U8(63), p[(i+1)%5], d[i])
	}
	for i := 0; i < 5; i++ {
		VPOR(t[i], d[i], d[i])
	}
	for i := 0; i < 5; i++ {
		VPXOR(d[i], p[(i+4)%5], d[i])
	}

	// Inverse of the permutation π
	invPi := []int{0, 6, 12, 18, 24, 3, 9, 10, 16, 22, 1, 7,
		13, 19, 20, 4, 5, 11, 17, 23, 2, 8, 14, 15, 21}

	// Appropriate rotations
	rot := []int{0, 44, 43, 21, 14, 28, 20, 3, 45, 61, 1, 6,
		25, 8, 18, 27, 36, 10, 15, 56, 62, 55, 39, 41, 2}

	// Instead of executing ρ, π, χ one at a time, writing back the
	// state after each step, we postpone the permutation π to the end
	// (using π⁻¹ χ π instead of χ) and do this in five independent chunks.
	for j := 0; j < 5; j++ {
		s := []Op{YMM(), YMM(), YMM(), YMM(), YMM()}

		// Load the right five words from the state and XOR d into them.
		for i := 0; i < 5; i++ {
			idx := invPi[5*j+i]
			VPXOR(state(idx), d[idx%5], s[i])
		}

		// Rotate each s[i] by the appropriate amount --- this is ρ
		for i := 0; i < 5; i++ {
			cr := rot[5*j+i]
			if cr != 0 {
				VPSLLQ(U8(cr), s[i], t[i])
			}
		}
		for i := 0; i < 5; i++ {
			cr := rot[5*j+i]
			if cr != 0 {
				VPSRLQ(U8(64-cr), s[i], s[i])
			}
		}
		for i := 0; i < 5; i++ {
			if rot[5*j+i] != 0 {
				VPOR(t[i], s[i], s[i])
			}
		}

		// Compute the new words s[i] ^ (s[i+2] & ~s[i+1]) --- this is χ
		for i := 0; i < 5; i++ {
			VPANDN(s[(i+2)%5], s[(i+1)%5], t[i])
		}
		for i := 0; i < 5; i++ {
			VPXOR(s[i], t[i], t[i])
		}

		// Round constant
		if j == 0 {
			rc := YMM()
			VPBROADCASTQ(Mem{Base: rc_ptr, Scale: 8, Index: round}, rc)
			VPXOR(rc, t[0], t[0])
		}

		// Store back into state
		for i := 0; i < 5; i++ {
			VMOVDQA(t[i], state(invPi[j*5+i]))
		}
	}

	// Finally execute π^{-1}.
	// XXX optimize
	state1 := YMM()
	tmp1 := YMM()
	last := 1
	VMOVDQA(state(1), state1)
	for j := 0; j < 23; j += 1 {
		VMOVDQA(state(invPi[last]), tmp1)
		VMOVDQA(tmp1, state(last))
		last = invPi[last]
	}
	VMOVDQA(state1, state(10))

	INCQ(round)
	CMPQ(round, Imm(24))
	JNE(LabelRef("loop"))

	RET()
}

This is unfortunately slower:

BenchmarkF1600x4AVX2-8        	 2941527	       415 ns/op
BenchmarkF1600x4AVX2small-8   	 2277976	       531 ns/op

(even if I remove the patently unoptimised permutation π⁻¹ at the end.)

@bwesterb bwesterb changed the title [WIP] 4way SHAKE using AVX2 [WIP] 4-way Keccak-f[1600] using AVX2 to speed up Dilithium May 10, 2020
@bwesterb bwesterb force-pushed the shakex4 branch 3 times, most recently from 2c7c717 to c7f3d2f Compare May 12, 2020 09:16
Non AVX2:

    BenchmarkPermutationFunction-8   	 3594723	       326 ns/op

AVX2:

    BenchmarkF1600x4-8   	 2926183	       399 ns/op

Dilithium Mode 3 using 4-way f[1600]:

    BenchmarkSkUnpack-8            	   40710	     29802 ns/op
    BenchmarkPkUnpack-8            	   40216	     29125 ns/op
    BenchmarkVerify-8              	   41216	     29493 ns/op
    BenchmarkSign-8                	    5734	    225791 ns/op
    BenchmarkGenerateKey-8         	   17356	     68622 ns/op
    BenchmarkPublicFromPrivate-8   	  105301	     11576 ns/op

Without:

    BenchmarkSkUnpack-8            	   18980	     61840 ns/op
    BenchmarkPkUnpack-8            	   19743	     60087 ns/op
    BenchmarkVerify-8              	   39998	     30421 ns/op
    BenchmarkSign-8                	    5420	    228631 ns/op
    BenchmarkGenerateKey-8         	   12037	     99672 ns/op
    BenchmarkPublicFromPrivate-8   	  107148	     11512 ns/op
@bwesterb bwesterb changed the title [WIP] 4-way Keccak-f[1600] using AVX2 to speed up Dilithium 4-way Keccak-f[1600] using AVX2 to speed up Dilithium May 13, 2020
@bwesterb (Member Author):

@armfazh Ready for review.

@armfazh (Contributor) left a comment:

Just minor changes.

sign/dilithium/mode3/internal/sample.go (inline review comment, outdated, resolved)
var perm f1600x4.State
state := perm.Initialize()

// Absorb the seed in the four states
Contributor:

This code looks to be repeated in the previous function; maybe it could go into a small helper function.

Member Author:

It's slightly different: the bound on i is 4 in one and 6 in the other.

sign/dilithium/mode3/internal/sample.go (inline review comment, outdated, resolved)
sign/dilithium/mode3/internal/sample_test.go (inline review comment, resolved)
@armfazh (Contributor) left a comment:

Couple of comments.
Everything else looks good.

sign/dilithium/mode3/internal/sample.go (inline review comment, outdated, resolved)
sign/dilithium/mode3/internal/sample.go (inline review comment, outdated, resolved)
sign/dilithium/mode3/internal/sample.go (inline review comment, outdated, resolved)
@bwesterb (Member Author):

@armfazh I addressed all issues you raised.

@armfazh armfazh self-requested a review May 21, 2020 16:33