In part 1 (addition), we saw multiple ways to represent our field elements, a choice that is key to optimization.
The final version used 4 × u52 and 1 × u48 limbs (each stored in a u64) to hold our 256 bits, leaving the extra bits for the carry.
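As a formula, a normalized field element $a$ stored in `d` (with $a_i$ = `d[i]`) is:

$$a = a_0 + a_1 2^{52} + a_2 2^{104} + a_3 2^{156} + a_4 2^{208}, \qquad 0 \le a_0, \dots, a_3 < 2^{52}, \quad 0 \le a_4 < 2^{48}.$$

Each limb lives in a u64, so `d[0]` to `d[3]` have 12 spare bits and `d[4]` has 16, for a total of 64 spare bits where carries can accumulate.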
To make it easier, we can start with 4 u64 digits, as we did for the addition, and compute a × b:
I labeled each column $t_i$, with $i = 0, \dots, 7$, so it’s easier to follow. We also define $c_i$, with $i = 0, \dots, 6$, as the carry of each column.
Since we represent a and b with 4 digits, there are now 8 terms in our result:
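Writing $a = \sum_{i=0}^{3} a_i 2^{64i}$ and $b = \sum_{j=0}^{3} b_j 2^{64j}$, those 8 columns and their carries can be written as:

$$
\begin{aligned}
t_k &= \sum_{i + j = k} a_i b_j + c_{k-1}, \qquad c_{-1} = 0, \quad k = 0, \dots, 6,\\
c_k &= \left\lfloor t_k / 2^{64} \right\rfloor, \qquad t_7 = c_6,\\
a \times b &= \sum_{k=0}^{7} \left(t_k \bmod 2^{64}\right) 2^{64k}.
\end{aligned}
$$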
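Now we have an issue here, because most of these terms can exceed 128 bits while we only have u128. For example, $t_1 = a_0 b_1 + a_1 b_0 + c_0$: each product alone can be as large as $(2^{64} - 1)^2$, just under $2^{128}$, so the sum of two of them can already overflow.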
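To make the overflow concrete, here is a small check using checked arithmetic (`column_t1_u64` is a hypothetical name, and the carry $c_0$ is left out): a single 64 × 64-bit product always fits in a u128, but adding a second one may not.

```rust
/// Column t1 = a0*b1 + a1*b0 with 64-bit digits (carry c0 left out).
/// Returns None if the sum no longer fits in a u128.
fn column_t1_u64(a0: u64, a1: u64, b0: u64, b1: u64) -> Option<u128> {
    // Each product is at most (2^64 - 1)^2 < 2^128, so it always fits.
    let p01 = a0 as u128 * b1 as u128;
    let p10 = a1 as u128 * b0 as u128;
    // The sum of two such products can reach roughly 2^129.
    p01.checked_add(p10)
}
```

With all digits set to `u64::MAX`, `column_t1_u64` returns `None`.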
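Fortunately, if we go back to our u52 strategy, this issue disappears: each product of two 52-bit digits is below $2^{104}$, so even at worst a column such as $t_3$, which sums four products plus a carry, stays far below $2^{128}$.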
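A quick sanity check of that headroom, as a sketch: `worst_column` and `bit_len` are hypothetical helpers, and the incoming carry is generously bounded by 64 bits.

```rust
/// Largest 52-bit digit.
const MAX52: u128 = (1u128 << 52) - 1;

/// Worst case for a column summing `n` products of 52-bit digits plus a carry
/// of at most `carry_bits` bits (carry_bits < 128). Checked arithmetic makes
/// any u128 overflow show up as None.
fn worst_column(n: u32, carry_bits: u32) -> Option<u128> {
    let product = MAX52.checked_mul(MAX52)?; // (2^52 - 1)^2 < 2^104
    let products = product.checked_mul(n as u128)?;
    products.checked_add((1u128 << carry_bits) - 1)
}

/// Number of significant bits in `x`.
fn bit_len(x: u128) -> u32 {
    128 - x.leading_zeros()
}
```

Both a four-product column like $t_3$ and the widest five-product column need only 107 bits, leaving 21 bits of slack in a u128.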
Let’s get back to our u52 baby
Now our beautiful multiplication looks more like:
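With five digits in base $2^{52}$, the product has ten columns. One way to write them, reusing the $t_k$ and $c_k$ naming from the 64-bit case:

$$
\begin{aligned}
t_k &= \sum_{\substack{i + j = k \\ 0 \le i, j \le 4}} a_i b_j + c_{k-1}, \qquad c_{-1} = 0, \quad k = 0, \dots, 8,\\
c_k &= \left\lfloor t_k / 2^{52} \right\rfloor, \qquad t_9 = c_8,\\
a \times b &= \sum_{k=0}^{9} \left(t_k \bmod 2^{52}\right) 2^{52k}.
\end{aligned}
$$

The widest column, $t_4$, sums five products.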
We can also see our result R as $R = R_1 \cdot 2^{256} + R_0$, with:
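Writing $r_k$ for the 52-bit digits of the product ($R = \sum_{k=0}^{9} r_k 2^{52k}$), the split falls 48 bits into $r_4$, because $4 \cdot 52 + 48 = 256$:

$$
\begin{aligned}
R_0 &= R \bmod 2^{256} = r_0 + r_1 2^{52} + r_2 2^{104} + r_3 2^{156} + (r_4 \bmod 2^{48})\, 2^{208},\\
R_1 &= \left\lfloor R / 2^{256} \right\rfloor = \left\lfloor r_4 / 2^{48} \right\rfloor + r_5 2^{4} + r_6 2^{56} + r_7 2^{108} + r_8 2^{160} + r_9 2^{212}.
\end{aligned}
$$

This is why the code takes the top 4 bits of `t4` and shifts each higher digit left by 4 bits (`t5 << 4 | c`, and so on).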
Now, because of the size of R1, it’s actually more interesting to reduce our result modulo P directly in the multiplication, unlike our add function, where we could store the carry in the extra bits.
The goal is to multiply numbers close to P, and in that case the carry R1 (up to 256 bits) won’t fit in our 64 bits of reserved space.
Similarly to the addition, we have 3 cases:
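$R = R_0$ when $R_1 = 0$ and $R_0 < P$
$R = R_0 - P$ when $R_1 = 0$ and $R_0 \ge P$
$R \equiv R_0 + R_1 \cdot (2^{256} - P) \pmod P$ when $R_1 > 0$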
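To see these cases in action without 256-bit numbers, here is a scaled-down sketch on a toy pseudo-Mersenne prime, $P = 2^{16} - 15 = 65521$, chosen purely for illustration (it is not part of the article's field code, and `toy_mul_mod` is a hypothetical name). The fold relies on $2^{16} \equiv 15 \pmod P$, the toy analogue of $2^{256} \equiv 2^{256} - P$.

```rust
/// Toy pseudo-Mersenne prime: P = 2^16 - 15 = 65521 (illustration only).
const P: u32 = 65_521;
/// 2^16 - P, the toy analogue of 2^256 - P.
const P0: u32 = 65_536 - P;

/// Multiply two residues (< P) and reduce the product modulo P by folding.
fn toy_mul_mod(a: u32, b: u32) -> u32 {
    // a, b < P, so the full product fits in a u32 (P^2 < 2^32).
    let r = a * b;
    // R = R1 * 2^16 + R0, which is congruent to R0 + R1 * P0 (mod P).
    let r = (r & 0xffff) + (r >> 16) * P0;
    // The new high part is at most 4 bits: fold once more.
    let r = (r & 0xffff) + (r >> 16) * P0;
    // Now r < 2P, so a single conditional subtraction (the R0 - P case) finishes.
    if r >= P { r - P } else { r }
}
```

When $R_1 = 0$ the folds add nothing, so the same code covers all three cases; each fold preserves the value modulo P, which is why the result matches `(a * b) % P`.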
For more clarity, let’s define $P_0 = 2^{256} - P = 2^{32} + 977$; we can now write:
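Spelled out, the reduction relies on $2^{256} = P + P_0 \equiv P_0 \pmod P$, and with $R_1 < 2^{256}$ and $P_0 < 2^{33}$ the folded value is bounded:

$$R = R_1 \cdot 2^{256} + R_0 \equiv R_0 + R_1 P_0 \pmod P, \qquad R_0 + R_1 P_0 < 2^{256} + 2^{256} \cdot 2^{33} < 2^{290}.$$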
Cool, now we have new carries to propagate:
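Splitting the first-pass result again at $2^{256}$ gives a much smaller high part, which a second pass folds back in:

$$R_0 + R_1 P_0 = R'_1 \cdot 2^{256} + R'_0, \qquad R'_1 < 2^{34}, \qquad R \equiv R'_0 + R'_1 P_0 \pmod P, \qquad R'_1 P_0 < 2^{67}.$$

$R'_1$ plays the role of the carry `c` fed into the 2nd pass of the code, and $R'_1 P_0$ fits comfortably in a u128.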
It’s actually way easier to see it this way:
pub fn mul(&self, b: &Self) -> Self {
const M52: u128 = 0x000fffffffffffffu128; // 2^52 - 1
const M48: u128 = 0x0000ffffffffffffu128; // 2^48 - 1
const P0: u128 = 0x1000003d1u128; // 2^32 + 977
let (a0, a1, a2, a3, a4) = (self.d[0], self.d[1], self.d[2], self.d[3], self.d[4]);
let (b0, b1, b2, b3, b4) = (b.d[0], b.d[1], b.d[2], b.d[3], b.d[4]);
let (
mut t0, mut t1, mut t2,
mut t3, mut t4, mut t5,
mut t6, mut t7, mut t8
): (
u64, u64, u64,
u64, u64, u64,
u64, u64, u64
);
let t9: u64;
let mut t: u128;
let mut c: u128;
t = a0 as u128 * b0 as u128;
t0 = (t & M52) as u64;
t >>= 52;
t += a0 as u128 * b1 as u128 +
a1 as u128 * b0 as u128;
t1 = (t & M52) as u64;
t >>= 52;
t += a0 as u128 * b2 as u128 +
a1 as u128 * b1 as u128 +
a2 as u128 * b0 as u128;
t2 = (t & M52) as u64;
t >>= 52;
t += a0 as u128 * b3 as u128 +
a1 as u128 * b2 as u128 +
a2 as u128 * b1 as u128 +
a3 as u128 * b0 as u128;
t3 = (t & M52) as u64;
t >>= 52;
t += a0 as u128 * b4 as u128 +
a1 as u128 * b3 as u128 +
a2 as u128 * b2 as u128 +
a3 as u128 * b1 as u128 +
a4 as u128 * b0 as u128;
t4 = (t & M52) as u64;
t >>= 52;
c = t4 as u128 >> 48;
t4 &= M48 as u64;
t += a1 as u128 * b4 as u128 +
a2 as u128 * b3 as u128 +
a3 as u128 * b2 as u128 +
a4 as u128 * b1 as u128;
t5 = (t & M52) as u64;
t >>= 52;
t5 = t5 << 4 | c as u64;
c = t5 as u128 >> 52;
t5 &= M52 as u64;
t += a2 as u128 * b4 as u128 +
a3 as u128 * b3 as u128 +
a4 as u128 * b2 as u128;
t6 = (t & M52) as u64;
t >>= 52;
t6 = t6 << 4 | c as u64;
c = t6 as u128 >> 52;
t6 &= M52 as u64;
t += a3 as u128 * b4 as u128 +
a4 as u128 * b3 as u128;
t7 = (t & M52) as u64;
t >>= 52;
t7 = t7 << 4 | c as u64;
c = t7 as u128 >> 52;
t7 &= M52 as u64;
t += a4 as u128 * b4 as u128;
t8 = (t & M52) as u64;
t >>= 52;
t8 = t8 << 4 | c as u64;
c = t8 as u128 >> 52;
t8 &= M52 as u64;
t9 = (t << 4 | c) as u64;
// 1st reduction: R = R0 + R1 * P0
t = t0 as u128 + t5 as u128 * P0;
t0 = (t & M52) as u64;
t >>= 52;
t += t1 as u128 + t6 as u128 * P0;
t1 = (t & M52) as u64;
t >>= 52;
t += t2 as u128 + t7 as u128 * P0;
t2 = (t & M52) as u64;
t >>= 52;
t += t3 as u128 + t8 as u128 * P0;
t3 = (t & M52) as u64;
t >>= 52;
t += t4 as u128 + t9 as u128 * P0;
t4 = (t & M52) as u64;
t >>= 52;
c = (t4 >> 48) as u128;
t4 &= M48 as u64;
c = c | t << 4;
// 2nd pass
t = t0 as u128 + c * P0;
t0 = (t & M52) as u64;
t >>= 52;
t += t1 as u128;
t1 = (t & M52) as u64;
t >>= 52;
t += t2 as u128;
t2 = (t & M52) as u64;
t >>= 52;
t += t3 as u128;
t3 = (t & M52) as u64;
t >>= 52;
t += t4 as u128;
t4 = (t & M48) as u64;
Self { d: [t0, t1, t2, t3, t4] }
}
Finally, let’s write a quick test:
#[test]
fn it_multiply_field_elements() {
// A=0xfffffffffffffffffffffffffffffffffffffffffffffffffffffbfefffffc2f = p - 2^42
// B=0xfffffffffffffffffffffffffffffffffffffffffffffffffffff7fefffffc2f = p - 2^43
let a = Fe::new(
0xffffffffffffffffu64,
0xffffffffffffffffu64,
0xffffffffffffffffu64,
0xfffffbfefffffc2fu64,
);
let b = Fe::new(
0xffffffffffffffffu64,
0xffffffffffffffffu64,
0xffffffffffffffffu64,
0xfffff7fefffffc2fu64,
);
let r = a.mul(&b);
// r = ((p - 2^42) * (p - 2^43)) % p = 2^85
let expected = Fe::new(
0x0000000000000000u64,
0x0000000000000000u64,
0x0000000000200000u64,
0x0000000000000000u64,
);
assert_eq!(r, expected);
}
and we now have a multiplication.
Verbose, but it works.
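The test relies on `Fe::new` from part 1, which isn't shown here. Judging from the expected value ($2^{85}$ appears in the third argument), it takes the four u64 words most significant first. Assuming that, and assuming normalized limbs, the packing into 4 × 52-bit + 1 × 48-bit limbs could look like this sketch (`to_limbs` and `from_limbs` are hypothetical names):

```rust
/// Pack four u64 words, most significant first, into 4 x 52-bit + 1 x 48-bit limbs.
fn to_limbs(w3: u64, w2: u64, w1: u64, w0: u64) -> [u64; 5] {
    const M52: u64 = (1u64 << 52) - 1;
    [
        w0 & M52,                          // bits 0..52
        (w0 >> 52) | ((w1 << 12) & M52),   // bits 52..104
        (w1 >> 40) | ((w2 << 24) & M52),   // bits 104..156
        (w2 >> 28) | ((w3 << 36) & M52),   // bits 156..208
        w3 >> 16,                          // bits 208..256 (48 bits)
    ]
}

/// Inverse of `to_limbs` for normalized limbs; returns words most significant first.
fn from_limbs(l: [u64; 5]) -> (u64, u64, u64, u64) {
    let w0 = l[0] | (l[1] << 52);
    let w1 = (l[1] >> 12) | (l[2] << 40);
    let w2 = (l[2] >> 24) | (l[3] << 28);
    let w3 = (l[3] >> 36) | (l[4] << 16);
    (w3, w2, w1, w0)
}
```

For example, the expected value $2^{85}$ lands in limb 1 as $2^{33}$, since $85 - 52 = 33$.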
Okay, now let’s think about it: we don’t actually have to propagate the carry all the way through the 2nd pass. As with the addition, we can just rely on our carry-save storage.
A very simple way to check this is to work with the worst-case scenario: $(P-1)(P-1)$.
After the 1st pass, this gives us $R_0 = \mathtt{1000003d0fffffffffffffffffffffffffffffffffffffffffffffffefffff85dfff16f60}_{16}$, or $R_0 = 2^{256} \times \mathtt{1000003d0}_{16} + \mathtt{fffffffffffffffffffffffffffffffffffffffffffffffefffff85dfff16f60}_{16}$.
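Where does this value come from? Writing $q = P_0 + 1$, so that $P - 1 = 2^{256} - q$, the product $(P-1)^2$ splits into $R_1 = 2^{256} - 2q$ and $R_0 = q^2$, and the 1st pass yields:

$$
\begin{aligned}
(P-1)^2 &= (2^{256} - 2q)\, 2^{256} + q^2,\\
q^2 + (2^{256} - 2q) P_0 &= P_0\, 2^{256} + q^2 - 2qP_0 = P_0\, 2^{256} + 1 - P_0^2\\
&= (P_0 - 1)\, 2^{256} + (2^{256} + 1 - P_0^2).
\end{aligned}
$$

The high part $P_0 - 1 = \mathtt{1000003d0}_{16}$ is the carry that the 2nd pass has to fold back in.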
Now applying the 2nd pass:
$\mathtt{1000003d0}_{16} \times (2^{32} + 977)$ is a 65-bit number, so after adding it to $t_0$ our first digit needs 65 bits. We only have 12 extra bits on the 1st digit d0 (64 bits in total), so we need to propagate the carry into the 2nd digit d1.
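A quick way to double-check the bit count is to redo the arithmetic in u128, taking the worst case for $t_0$ (all 52 bits set). This is only a sanity check, not part of the field code; `bit_len` and `second_pass_digit0` are hypothetical helpers.

```rust
/// Carry out of the 1st pass in the (P-1)(P-1) worst case: P0 - 1.
const C: u128 = 0x1000003d0;
/// P0 = 2^32 + 977.
const P0: u128 = 0x1000003d1;
/// Largest value a normalized 52-bit digit t0 can hold.
const MAX_T0: u128 = (1u128 << 52) - 1;

/// Number of significant bits in `x`.
fn bit_len(x: u128) -> u32 {
    128 - x.leading_zeros()
}

/// Worst-case value of t0 + c * P0 at the start of the 2nd pass.
fn second_pass_digit0() -> u128 {
    MAX_T0 + C * P0
}
```

Both `C * P0` and `second_pass_digit0()` are 65 bits wide, one bit more than a u64 limb, and the part above the low 52 bits is a 13-bit carry for d1.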
And we don’t have to go further for now: we can still perform a few more operations before we overflow our 320-bit storage.
Let’s remove the remaining carry propagation through t2, t3 and t4:
// ...
// 2nd pass
t = t0 as u128 + c * P0;
t0 = (t & M52) as u64;
t >>= 52;
t += t1 as u128;
t1 = t as u64; // no mask: any carry stays in t1's spare bits (carry-save)
Self { d: [t0, t1, t2, t3, t4] }
Another optimization is to combine the two operations on t0. Right now we have:
// 1st pass
t = t0 as u128 + t5 as u128 * P0;
t0 = (t & M52) as u64;
// ...
// 2nd pass
t = t0 as u128 + c * P0;
t0 = (t & M52) as u64;
It would be nice to combine those to save some instructions: computing $t_0 + (a+b)P_0$ is cheaper than $t_0 + aP_0 + bP_0$.
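Here is the idea in isolation. Because `P1 = P0 << 4`, a limb weighted by $P_1$ and a 4-bit carry weighted by $P_0$ can share a single multiplication once they are packed together. A minimal sketch, assuming the carry `c4` is below 16 (the function names are hypothetical):

```rust
/// P0 = 2^256 - p = 2^32 + 977 for secp256k1.
const P0: u128 = 0x1000003d1;
/// P1 = P0 << 4, the factor for a limb sitting 4 bits above the 2^256 boundary.
const P1: u128 = P0 << 4;

/// Two multiplications: fold the 4-bit carry c4 and the limb t5 separately.
fn fold_separate(t0: u128, c4: u128, t5: u128) -> u128 {
    t0 + c4 * P0 + t5 * P1
}

/// One multiplication: pack both into a single factor first.
/// Requires c4 < 16, so that `|` behaves like `+`.
fn fold_combined(t0: u128, c4: u128, t5: u128) -> u128 {
    debug_assert!(c4 < 16);
    t0 + ((t5 << 4) | c4) * P0
}
```

Since $((t_5 \ll 4) \mid c_4) \cdot P_0 = t_5 P_1 + c_4 P_0$, both functions return the same value.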
We can use the fact that we don’t need to propagate the carry over t3 and t4 on the 2nd pass. Therefore, we can work with them first.
pub fn mul(&self, b: &Self) -> Self {
    const M52: u128 = 0x000fffffffffffffu128; // 2^52 - 1
    const M48: u64 = 0x0000ffffffffffffu64; // 2^48 - 1
    const P0: u128 = 0x1000003d1u128; // 2^32 + 977
    const P1: u128 = 0x1000003d10u128; // (2^32 + 977) << 4
    let (a0, a1, a2, a3, a4) = (
        self.d[0] as u128, self.d[1] as u128, self.d[2] as u128,
        self.d[3] as u128, self.d[4] as u128
    );
    let (b0, b1, b2, b3, b4) = (
        b.d[0] as u128, b.d[1] as u128, b.d[2] as u128,
        b.d[3] as u128, b.d[4] as u128
    );
    // tx accumulates the low columns, cx the high columns folded back via P0/P1
    let mut tx: u128;
    let mut cx: u128;
    let (t0, t1, t2, mut t3, mut t4, mut t5): (u64, u64, u64, u64, u64, u128);
    let c4: u64;
    // t3
    tx = a0 * b3 + a1 * b2 + a2 * b1 + a3 * b0;
    // t8
    cx = a4 * b4;
    // t3 + t8 * P1
    tx += (cx & M52) * P1;
    cx >>= 52;
    t3 = (tx & M52) as u64;
    tx >>= 52;
    // t4
    tx += a0 * b4 + a1 * b3 + a2 * b2 + a3 * b1 + a4 * b0;
    // (c3 + t4) + (c8 + t9) * P1
    tx += cx * P1;
    t4 = (tx & M52) as u64;
    tx >>= 52;
    c4 = t4 >> 48;
    t4 &= M48;
    // t5
    cx = tx + a1 * b4 + a2 * b3 + a3 * b2 + a4 * b1;
    // t0
    tx = a0 * b0;
    t5 = cx & M52;
    cx >>= 52;
    t5 = (t5 << 4) | c4 as u128;
    // c9 + t0 + (c4 + t5) * P0
    tx += t5 * P0;
    t0 = (tx & M52) as u64;
    tx >>= 52;
    // t1
    tx += a0 * b1 + a1 * b0;
    // t6
    cx += a2 * b4 + a3 * b3 + a4 * b2;
    // c0 + t1 + (c5 + t6) * P1
    tx += (cx & M52) * P1;
    cx >>= 52;
    t1 = (tx & M52) as u64;
    tx >>= 52;
    // t2
    tx += a0 * b2 + a1 * b1 + a2 * b0;
    // t7
    cx += a3 * b4 + a4 * b3;
    // t2 + t7 * P1
    tx += (cx & M52) * P1;
    cx >>= 52;
    t2 = (tx & M52) as u64;
    tx >>= 52;
    // t23
    tx += cx * P1 + t3 as u128;
    t3 = (tx & M52) as u64;
    tx >>= 52;
    // t24
    tx += t4 as u128;
    t4 = tx as u64; // left unmasked: carry-save storage
    Self { d: [t0, t1, t2, t3, t4] }
}