In part 1 (addition) we saw multiple ways to represent our field elements, which is key to optimization.
The final version was 4 × u52 and 1 × u48 limbs to store our 256 bits, each kept in a u64 so that the spare upper bits are left for the carry.
To make it easier we can start working with 4 u64 digits as we did for the addition, so
I labeled each column with
Since we represent
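Now we have an issue here, because most of these terms can exceed 128 bits while we only have a u128. For example, with 64-bit digits the column $a_0 b_1 + a_1 b_0$ can be as large as $2\,(2^{64}-1)^2$, which needs 129 bits.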
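Fortunately, if we go back to our u52 strategy, we don’t have this issue anymore, because at worst a column is the sum of five 104-bit products plus a carry, i.e. less than $5 \cdot 2^{104} + 2^{55} < 2^{107}$, which fits comfortably in a u128.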
Let’s get back to our u52 baby
Now our beautiful multiplication looks more like:
We can also see our result R as
Now because of the size of
The final purpose is to be able to make multiplications of numbers close to
Similarly to the addition, we have 3 cases:
For more clarity, let’s define
Cool, now we have new carries to propagate:
It’s actually way easier to see it this way:
/// Multiplies `self` by `b` and returns the product as a new element.
///
/// The operands are read as five limbs `d[0]..d[4]`, least significant first. The
/// masks used below treat limbs 0-3 as 52-bit digits and limb 4 as a 48-bit digit
/// (4 * 52 + 48 = 256 bits).
///
/// The routine has three phases:
/// 1. a schoolbook multiplication producing ten limbs `t0..t9`, where `t5..t9` are
///    re-aligned so that they start exactly at bit 256;
/// 2. a first reduction that folds the high limbs into the low ones using
///    `P0 = 2^32 + 977`;
/// 3. a second pass that folds the overflow of phase 2 back in the same way.
///
/// NOTE(review): the last statement masks limb 4 with `M48`, so a carry out of
/// bit 256 produced by the 2nd pass would be discarded, and the result is not
/// compared against the modulus here — confirm callers can rely on that.
pub fn mul(&self, b: &Self) -> Self {
const M52: u128 = 0x000fffffffffffffu128; // 2^52 - 1
const M48: u128 = 0x0000ffffffffffffu128; // 2^48 - 1
const P0: u128 = 0x1000003d1u128; // 2^32 + 977
// Limbs of both operands, least significant first.
let (a0, a1, a2, a3, a4) = (self.d[0], self.d[1], self.d[2], self.d[3], self.d[4]);
let (b0, b1, b2, b3, b4) = (b.d[0], b.d[1], b.d[2], b.d[3], b.d[4]);
// t0..t4 receive the low half of the product, t5..t9 the high half.
let (
mut t0, mut t1, mut t2,
mut t3, mut t4, mut t5,
mut t6, mut t7, mut t8
): (
u64, u64, u64,
u64, u64, u64,
u64, u64, u64
);
let t9: u64;
// `t` is the 128-bit column accumulator; `c` holds bits moved between limbs.
let mut t: u128;
let mut c: u128;
// Column 0: keep the low 52 bits, carry the rest into the next column.
t = a0 as u128 * b0 as u128;
t0 = (t & M52) as u64;
t >>= 52;
// Column 1.
t += a0 as u128 * b1 as u128 +
a1 as u128 * b0 as u128;
t1 = (t & M52) as u64;
t >>= 52;
// Column 2.
t += a0 as u128 * b2 as u128 +
a1 as u128 * b1 as u128 +
a2 as u128 * b0 as u128;
t2 = (t & M52) as u64;
t >>= 52;
// Column 3.
t += a0 as u128 * b3 as u128 +
a1 as u128 * b2 as u128 +
a2 as u128 * b1 as u128 +
a3 as u128 * b0 as u128;
t3 = (t & M52) as u64;
t >>= 52;
// Column 4 (the widest one: five partial products).
t += a0 as u128 * b4 as u128 +
a1 as u128 * b3 as u128 +
a2 as u128 * b2 as u128 +
a3 as u128 * b1 as u128 +
a4 as u128 * b0 as u128;
t4 = (t & M52) as u64;
t >>= 52;
// Limb 4 only keeps 48 bits: its top 4 bits are moved into `c` so that the
// high half (t5..t9) starts exactly at bit 256.
c = t4 as u128 >> 48;
t4 &= M48 as u64;
// Column 5.
t += a1 as u128 * b4 as u128 +
a2 as u128 * b3 as u128 +
a3 as u128 * b2 as u128 +
a4 as u128 * b1 as u128;
t5 = (t & M52) as u64;
t >>= 52;
// Re-align by 4 bits: shift the limb up, insert the 4 bits carried from the
// previous limb, then split off whatever now exceeds 52 bits into `c`.
t5 = t5 << 4 | c as u64;
c = t5 as u128 >> 52;
t5 &= M52 as u64;
// Column 6, re-aligned the same way.
t += a2 as u128 * b4 as u128 +
a3 as u128 * b3 as u128 +
a4 as u128 * b2 as u128;
t6 = (t & M52) as u64;
t >>= 52;
t6 = t6 << 4 | c as u64;
c = t6 as u128 >> 52;
t6 &= M52 as u64;
// Column 7, re-aligned the same way.
t += a3 as u128 * b4 as u128 +
a4 as u128 * b3 as u128;
t7 = (t & M52) as u64;
t >>= 52;
t7 = t7 << 4 | c as u64;
c = t7 as u128 >> 52;
t7 &= M52 as u64;
// Column 8, re-aligned the same way.
t += a4 as u128 * b4 as u128;
t8 = (t & M52) as u64;
t >>= 52;
t8 = t8 << 4 | c as u64;
c = t8 as u128 >> 52;
t8 &= M52 as u64;
// Whatever is left in the accumulator becomes the top limb (also shifted by 4).
t9 = (t << 4 | c) as u64;
// 1st reduction R = R1 + R0 * P0
// Each high limb t(5+i) is multiplied by P0 and added onto low limb t(i),
// with the carry chain running through `t`.
t = t0 as u128 + t5 as u128 * P0;
t0 = (t & M52) as u64;
t >>= 52;
t += t1 as u128 + t6 as u128 * P0;
t1 = (t & M52) as u64;
t >>= 52;
t += t2 as u128 + t7 as u128 * P0;
t2 = (t & M52) as u64;
t >>= 52;
t += t3 as u128 + t8 as u128 * P0;
t3 = (t & M52) as u64;
t >>= 52;
t += t4 as u128 + t9 as u128 * P0;
t4 = (t & M52) as u64;
t >>= 52;
// Collect everything at or above bit 256 into `c`: the top 4 bits of the
// 52-bit t4 plus the remaining accumulator shifted up by those 4 bits.
c = (t4 >> 48) as u128;
t4 &= M48 as u64;
c = c | t << 4;
// 2nd pass
// Fold the overflow `c` back into limb 0 via P0 and ripple the carry upwards.
t = t0 as u128 + c * P0;
t0 = (t & M52) as u64;
t >>= 52;
t += t1 as u128;
t1 = (t & M52) as u64;
t >>= 52;
t += t2 as u128;
t2 = (t & M52) as u64;
t >>= 52;
t += t3 as u128;
t3 = (t & M52) as u64;
t >>= 52;
t += t4 as u128;
// Only the low 48 bits are kept; higher bits of `t` are discarded here.
t4 = (t & M48) as u64;
Self { d: [t0, t1, t2, t3, t4] }
}
Finally let’s write a quick test:
#[test]
fn it_multiply_field_elements() {
    // Operands sit just below the modulus p (words are passed most significant first):
    //   lhs = p - 2^42 = 0xffff...ffff_fffffbfefffffc2f
    //   rhs = p - 2^43 = 0xffff...ffff_fffff7fefffffc2f
    let lhs = Fe::new(u64::MAX, u64::MAX, u64::MAX, 0xffff_fbfe_ffff_fc2fu64);
    let rhs = Fe::new(u64::MAX, u64::MAX, u64::MAX, 0xffff_f7fe_ffff_fc2fu64);

    let product = lhs.mul(&rhs);

    // (p - 2^42) * (p - 2^43) mod p = 2^42 * 2^43 = 2^85,
    // i.e. bit 21 of the second-lowest 64-bit word.
    let want = Fe::new(0u64, 0u64, 1u64 << 21, 0u64);
    assert_eq!(product, want);
}
and we now have a multiplication.
Verbose, but it works.
Okay, now let’s think about it: we actually don’t have to propagate the carry all the way after the 2nd pass. Same as for the addition, we can just use our carry-save storage.
A very simple way to prove it is to work with the worst case scenario:
After the 1st pass, it gives us a result
Now applying the 2nd pass:
And we don’t have to go further for now: we can still perform a few operations before we overflow our 320 bits of storage.
Let’s remove the last carry propagation on t3 and t4:
// ...
// 2nd pass: fold the overflow `c` of the 1st pass back into the low limbs.
t = t0 as u128 + c * P0;
t0 = (t & M52) as u64;
t >>= 52;
t += t1 as u128;
t1 = (t & M52) as u64;
t >>= 52;
t += t2 as u128;
t2 = (t & M52) as u64;
// Carry-save: the last carry is simply added to t3 without masking, so no
// further propagation through t3 and t4 is needed. Without this line the carry
// out of the lower limbs would be lost.
t3 += (t >> 52) as u64;
Self { d: [t0, t1, t2, t3, t4] }
Another optimization would be to combine the operations on `t0`, which is currently computed once per pass:
// 1st pass
= t0 as u128 + t5 as u128 * P0;
t = (t & M52) as u64;
t0 // ...
// 2nd pass
= t0 as u128 + c * P0;
t = (t & M52) as u64; t0
Could be nice to combine those just so we save some instructions, better to do
We can use the fact that we don’t need to propagate the carry over
pub fn mul(&mut self, b: &Self) -> Self {
const M52: u128 = 0x000fffffffffffffu128; // 2^52 - 1
const M48: u64 = 0x0000ffffffffffffu64; // 2^48 - 1
const P0: u128 = 0x1000003d1u128; // 2^32 + 977
const P1: u128 = 0x1000003d10u128; // 2^32 + 977 << 4
let (a0, a1, a2, a3, a4) = (
self.d[0] as u128, self.d[1] as u128, self.d[2] as u128,
self.d[3] as u128, self.d[4] as u128
;
)let (b0, b1, b2, b3, b4) = (
.d[0] as u128, b.d[1] as u128, b.d[2] as u128,
b.d[3] as u128, b.d[4] as u128
b;
)let mut tx: u128;
let mut cx: u128;
let (t0, t1, t2, mut t3, mut t4, mut t5): (u64, u64, u64, u64, u64, u128);
let c4: u64;
// t3
= a0 * b3 + a1 * b2 + a2 * b1 + a3 * b0;
tx // t8
= a4 * b4;
cx // t3 + t8 * P1
+= (cx & M52) * P1;
tx >>= 52;
cx = (tx & M52) as u64;
t3 >>= 52;
tx
// t4
+= a0 * b4 + a1 * b3 + a2 * b2 + a3 * b1 + a4 * b0;
tx // (c3 + t4) + (c8 + t9) * P1
+= cx * P1;
tx = (tx & M52) as u64;
t4 >>= 52;
tx = t4 >> 48;
c4 &= M48;
t4
// t5
= tx + a1 * b4 + a2 * b3 + a3 * b2 + a4 * b1;
cx // t0
= a0 * b0;
tx = cx & M52;
t5 >>= 52;
cx = (t5 << 4) | c4 as u128;
t5 // c9 + t0 + (c4 + t5) * P0
+= t5 * P0;
tx = (tx & M52) as u64;
t0 >>= 52;
tx
// t1
+= a0 * b1 + a1 * b0;
tx // t6
+= a2 * b4 + a3 * b3 + a4 * b2;
cx // c0 + t1 + (c5 + t6) * P1
+= (cx & M52) * P1;
tx >>= 52;
cx = (tx & M52) as u64;
t1 >>= 52;
tx
// t2
+= a0 * b2 + a1 * b1 + a2 * b0;
tx // t7
+= a3 * b4 + a4 * b3;
cx // t2 + t7 * P1
+= (cx & M52) * P1;
tx >>= 52;
cx = (tx & M52) as u64;
t2 >>= 52;
tx
// t23
+= cx * P1 + t3 as u128;
tx = (tx & M52) as u64;
t3 >>= 52;
tx // t24
+= t4 as u128;
tx = tx as u64;
t4
Self { d: [t0, t1, t2, t3, t4] }
}