#include "risc0/zkp/prove/prove.h"
#include "risc0/core/log.h"
#include "risc0/core/rng.h"
#include "risc0/zkp/core/constants.h"
#include "risc0/zkp/core/poly.h"
#include "risc0/zkp/core/rou.h"
#include "risc0/zkp/prove/fri.h"
#include "risc0/zkp/prove/poly_group.h"
#include "risc0/zkp/prove/write_iop.h"
namespace risc0 {
// Copies `vec` into a new AccelSlice and transforms it in place with batchInterpolateNTT,
// treating the buffer as `count` polynomials (presumably evaluation form -> coefficient
// form; confirm against the NTT helpers). Outside CIRCUIT_DEBUG builds the buffer is then
// passed through zkShiftAccel. Returns the transformed buffer.
static AccelSlice<Fp> makeCoeffs(const std::vector<Fp>& vec, size_t count) {
  auto coeffs = AccelSlice<Fp>::copy(vec);
  batchInterpolateNTT(coeffs, count);
#ifndef CIRCUIT_DEBUG
  // The shift step is skipped entirely in CIRCUIT_DEBUG builds.
  zkShiftAccel(coeffs, count);
#endif
  return coeffs;
}
// Builds a proof for `circuit` and returns the serialized proof words from the WriteIOP.
//
// Every random value below is drawn from `iop` after the preceding commits/writes, so the
// order of the steps must be preserved:
//   1. execute the circuit; commit the CODE and DATA register groups;
//   2. run the accumulate phase; commit the ACCUM group;
//   3. draw polyMix, evaluate the check polynomial and commit it as kCheckSize polys;
//   4. draw Z, evaluate every tap at Z * backOne^back (and the check polys at Z^4),
//      interpolate per-register polynomials into coeffU, write coeffU and commit its hash;
//   5. draw `mix`, mix all polynomials into comboCount + 1 combo polynomials, subtract
//      coeffU and divide out each evaluation point (every division must be exact);
//   6. sum the combos into the final polynomial and run friProve, opening all four Merkle
//      trees at each index the callback is given.
std::vector<uint32_t> prove(ProveCircuit& circuit) {
  TapSetRef tapSet = circuit.getTaps();
  WriteIOP iop;

  // Step 1: execute and commit CODE / DATA.
  circuit.execute(iop);
  uint32_t po2 = circuit.getPo2();
  REQUIRE(po2 <= kMaxCyclesPo2);
  size_t size = size_t(1) << po2;
  size_t codeSize = tapSet.groupSize(RegisterGroup::CODE);
  size_t dataSize = tapSet.groupSize(RegisterGroup::DATA);
  size_t accumSize = tapSet.groupSize(RegisterGroup::ACCUM);
  size_t comboCount = tapSet.combosSize();
  PolyGroup codeGroup(makeCoeffs(circuit.getCode(), codeSize), codeSize, size);
  PolyGroup dataGroup(makeCoeffs(circuit.getData(), dataSize), dataSize, size);
  codeGroup.getMerkle().commit(iop);
  dataGroup.getMerkle().commit(iop);
  LOG(1, "codeGroup: " << codeGroup.getMerkle().getRoot());
  LOG(1, "dataGroup: " << dataGroup.getMerkle().getRoot());

  // Step 2: the accumulate phase runs only after CODE/DATA are committed.
  circuit.accumulate(iop);
  LOG(1, "size = " << size << ", accumSize = " << accumSize);
  LOG(1, "getAccum.size() = " << circuit.getAccum().size());
  PolyGroup accumGroup(makeCoeffs(circuit.getAccum(), accumSize), accumSize, size);
  accumGroup.getMerkle().commit(iop);
  LOG(1, "accumGroup: " << accumGroup.getMerkle().getRoot());

  // Step 3: evaluate the check polynomial over a domain of size * kInvRate points.
  Fp4 polyMix = Fp4::random(iop);
  size_t domain = size * kInvRate;
  auto checkPoly = AccelSlice<Fp>::allocate(4 * domain);
  circuit.evalCheck(checkPoly,
                    codeGroup.getEvaluated(),
                    dataGroup.getEvaluated(),
                    accumGroup.getEvaluated(),
                    polyMix);
#ifdef CIRCUIT_DEBUG
  // Debug only: scan every 4th element of checkPoly for the first nonzero entry and
  // remember a corresponding point so it can be forced as Z below.
  Fp4 badZ;
  {
    AccelReadLock<Fp> lock(checkPoly);
    for (size_t i = 0; i < 4 * domain; i += 4) {
      if (lock[i] != 0) {
        LOG(1, "ERROR AT i = " << i << " value = " << lock[i]);
        badZ = Fp4(pow(kRouFwd[po2], i / 4));
        break;
      }
    }
  }
#endif
  // checkPoly is interpolated as 4 polynomials, then committed as kCheckSize polynomials
  // of `size` coefficients each.
  batchInterpolateNTT(checkPoly, 4);
  PolyGroup checkGroup(checkPoly, kCheckSize, size);
  checkGroup.getMerkle().commit(iop);
  LOG(1, "checkGroup: " << checkGroup.getMerkle().getRoot());

  // Step 4: draw the evaluation point Z.
  Fp4 Z = Fp4::random(iop);
#ifdef CIRCUIT_DEBUG
  if (badZ != Fp4(0)) {
    Z = badZ;
  }
  // Debug builds write Z into the proof (presumably so the forced Z can be read back).
  iop.write(&Z, 1);
#endif
  LOG(1, "Z = " << Z);
  // Multiplying by backOne^back shifts an evaluation point `back` steps (kRouRev; presumably
  // the inverse root of unity for the 2^po2 domain).
  Fp backOne = kRouRev[po2];

  // Evaluates every tap of a register group at pow(backOne, tap.back()) * Z, appending the
  // points to allXs and the resulting values to evalU in tap order.
  std::vector<Fp4> allXs;
  std::vector<Fp4> evalU;
  auto evalGroup = [&](RegisterGroup id, PolyGroup& pg) {
    AccelConstSlice<Fp> coeffs = pg.getCoeffs();
    std::vector<uint32_t> which;
    std::vector<Fp4> xs;
    for (auto tap : tapSet.groupTaps(id)) {
      which.push_back(tap.offset());
      xs.push_back(pow(backOne, tap.back()) * Z);
      allXs.push_back(xs.back());
    }
    auto whichAccel = AccelSlice<uint32_t>::copy(which);
    auto xsAccel = AccelSlice<Fp4>::copy(xs);
    auto outAccel = AccelSlice<Fp4>::allocate(which.size());
    batchEvaluateAny(coeffs, pg.getCount(), whichAccel, xsAccel, outAccel);
    {
      AccelReadLock out(outAccel);
      std::copy(out.data(), out.data() + out.size(), std::back_inserter(evalU));
    }
  };
  // The interpolation loop below walks tapSet.regs() over allXs/evalU, so this
  // ACCUM, CODE, DATA order is assumed to match that ordering.
  evalGroup(RegisterGroup::ACCUM, accumGroup);
  evalGroup(RegisterGroup::CODE, codeGroup);
  evalGroup(RegisterGroup::DATA, dataGroup);

  // Interpolate each register's tap evaluations into reg.size() coefficients, stored in
  // coeffU at the same offsets as the evaluations.
  size_t curPos = 0;
  std::vector<Fp4> coeffU(evalU.size());
  for (auto reg : tapSet.regs()) {
    polyInterpolate(
        coeffU.data() + curPos, allXs.data() + curPos, evalU.data() + curPos, reg.size());
    curPos += reg.size();
  }

  // Evaluate each of the kCheckSize check polynomials at Z^4 and append the values.
  Fp4 Z4 = pow(Z, 4);
  {
    std::vector<uint32_t> which;
    for (size_t i = 0; i < kCheckSize; i++) {
      which.push_back(i);
    }
    std::vector<Fp4> xs(kCheckSize, Z4);
    auto out = AccelSlice<Fp4>::allocate(kCheckSize);
    batchEvaluateAny(checkGroup.getCoeffs(),
                     kCheckSize,
                     AccelSlice<uint32_t>::copy(which),
                     AccelSlice<Fp4>::copy(xs),
                     out);
    AccelReadLock outLock(out);
    std::copy(outLock.data(), outLock.data() + out.size(), std::back_inserter(coeffU));
  }
  LOG(1, "Size of U = " << coeffU.size());
  iop.write(coeffU.data(), coeffU.size());
  // Each Fp4 is hashed as 4 consecutive Fp elements.
  auto hashU = shaHash(reinterpret_cast<const Fp*>(coeffU.data()), coeffU.size() * 4, 1, false);
  iop.commit(hashU);

  // Step 5: mix all polynomials into comboCount + 1 combo polynomials of `size` coefficients.
  // curMix advances by mix^groupSize per group (presumably one power of mix per register).
  Fp4 mix = Fp4::random(iop);
  LOG(1, "Mix = " << mix);
  auto combos = AccelSlice<Fp4>::copy(std::vector<Fp4>(size * (comboCount + 1)));
  Fp4 curMix(1);
  auto mixGroup = [&](RegisterGroup id, PolyGroup& pg) {
    std::vector<uint32_t> which;
    for (auto reg : tapSet.groupRegs(id)) {
      which.push_back(reg.comboID());
    }
    auto whichAccel = AccelSlice<uint32_t>::copy(which);
    auto curMixAccel = AccelSlice<Fp4>::copy(&curMix, 1);
    auto mixAccel = AccelSlice<Fp4>::copy(&mix, 1);
    size_t gsize = tapSet.groupSize(id);
    mixPolyCoeffsAccel(combos, curMixAccel, mixAccel, pg.getCoeffs(), whichAccel, gsize, size);
    curMix *= pow(mix, gsize);
  };
  mixGroup(RegisterGroup::ACCUM, accumGroup);
  mixGroup(RegisterGroup::CODE, codeGroup);
  mixGroup(RegisterGroup::DATA, dataGroup);
  {
    // All check polynomials are mixed into the extra combo at index comboCount.
    std::vector<uint32_t> which(kCheckSize, comboCount);
    auto whichAccel = AccelSlice<uint32_t>::copy(which);
    auto curMixAccel = AccelSlice<Fp4>::copy(&curMix, 1);
    auto mixAccel = AccelSlice<Fp4>::copy(&mix, 1);
    mixPolyCoeffsAccel(
        combos, curMixAccel, mixAccel, checkGroup.getCoeffs(), whichAccel, kCheckSize, size);
  }

  // Host-side pass over combos: subtract the mixed coeffU polynomials, then divide out every
  // evaluation point. The lock is scoped so it is released before the accelerator reads
  // `combos` again.
  {
    AccelReadWriteLock<Fp4> comboCpu(combos);
    curPos = 0;
    Fp4 cur = Fp4(1);
    for (auto reg : tapSet.regs()) {
      for (size_t i = 0; i < reg.size(); i++) {
        comboCpu[size * reg.comboID() + i] -= cur * coeffU[curPos + i];
      }
      cur *= mix;
      curPos += reg.size();
    }
    // Each check value is a single point evaluation, so it only affects coefficient 0.
    for (size_t i = 0; i < kCheckSize; i++) {
      comboCpu[size * comboCount] -= cur * coeffU[curPos++];
      cur *= mix;
    }
    // polyDivide mutates the combo in place; it is deliberately kept outside the REQUIRE
    // condition so the division still happens even if REQUIRE is ever compiled out.
    for (size_t combo = 0; combo < comboCount; combo++) {
      for (size_t back : tapSet.getCombo(combo)) {
        auto remainder = polyDivide(comboCpu.data() + combo * size, size, Z * pow(backOne, back));
        REQUIRE(remainder == Fp4(0));
      }
    }
    auto checkRemainder = polyDivide(comboCpu.data() + comboCount * size, size, Z4);
    REQUIRE(checkRemainder == Fp4(0));
  }

  // Step 6: collapse the combos into the final polynomial and run FRI on it.
  auto finalPolyCoeffs = AccelSlice<Fp>::allocate(size * 4);
  eltwiseSumFp4Accel(finalPolyCoeffs, combos);
  batchBitReverse(finalPolyCoeffs, 4);
  LOG(1, "FRI-proof, size = " << finalPolyCoeffs.size() / 4);
  friProve(iop, finalPolyCoeffs, [&](WriteIOP& iop, size_t idx) {
    accumGroup.getMerkle().prove(iop, idx);
    codeGroup.getMerkle().prove(iop, idx);
    dataGroup.getMerkle().prove(iop, idx);
    checkGroup.getMerkle().prove(iop, idx);
  });
  std::vector<uint32_t> ret = iop.getProof();
  LOG(1, "Proof size = " << ret.size());
  return ret;
}
}