#include "risc0/zkp/prove/fri.h"
#include <memory>
#include "risc0/core/log.h"
#include "risc0/core/util.h"
#include "risc0/zkp/core/constants.h"
#include "risc0/zkp/prove/merkle.h"
namespace risc0 {
namespace {
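// Per-round state for the FRI prover: the expanded evaluations of the current
// polynomial, the Merkle commitment over them, and the folded coefficients
// that become the input to the next round.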
struct ProveRoundInfo {
size_t size;
size_t domain;
AccelSlice<Fp> evaluated;
AccelSlice<Fp> outCoeffs;
std::unique_ptr<MerkleTreeProver> merkle;
ProveRoundInfo(WriteIOP& iop, AccelConstSlice<Fp> coeffs)
: size(coeffs.size() / 4)
, domain(size * kInvRate)
, evaluated(AccelSlice<Fp>::allocate(domain * 4))
, outCoeffs(AccelSlice<Fp>::allocate(size / kFriFold * 4)) {
LOG(1, "Doing FRI folding");
batchExpand(evaluated, coeffs, 4);
batchEvaluateNTT(evaluated, 4, log2Ceil(kInvRate));
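// Merkle-commit to the evaluations; each leaf packs the kFriFold values
// (times 4 Fp components) that get folded together.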
merkle =
std::make_unique<MerkleTreeProver>(evaluated, domain / kFriFold, kFriFold * 4, kQueries);
merkle->commit(iop);
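// Draw the folding challenge from the transcript, then fold the coefficients
// down by a factor of kFriFold.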
Fp4 foldMix = Fp4::random(iop);
friFoldAccel(outCoeffs, coeffs, AccelSlice<Fp4>::copy(&foldMix, 1));
}
void proveQuery(WriteIOP& iop, size_t* pos) const {
size_t group = *pos % (domain / kFriFold);
merkle->prove(iop, group);
*pos = group;
}
};
} // namespace
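
// FRI prover: fold the committed polynomial round by round until it is small
// enough to send in the clear, then answer kQueries query positions, emitting
// the caller's inner proof plus the per-round Merkle openings for each one.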
void friProve(WriteIOP& iop, AccelConstSlice<Fp> coeffs, InnerProve inner) {
size_t origDomain = coeffs.size() / 4 * kInvRate;
std::vector<ProveRoundInfo> rounds;
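// Fold rounds: keep folding by kFriFold until at most kFriMinDegree Fp4 coefficients remain.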
while (coeffs.size() / 4 > kFriMinDegree) {
rounds.emplace_back(iop, coeffs);
coeffs = rounds.back().outCoeffs;
}
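// Copy out the final coefficients and bit-reverse them into natural order before sending.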
auto final = AccelSlice<Fp>::allocate(coeffs.size());
eltwiseCopyFpAccel(final, coeffs);
batchBitReverse(final, 4);
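// Send the final polynomial in the clear and commit its hash to the transcript.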
{
AccelReadLock<Fp> finalCpu(final);
iop.write(finalCpu.data(), finalCpu.size());
auto digest = shaHash(finalCpu.data(), finalCpu.size(), 1, false);
iop.commit(digest);
}
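// Query phase: each query position is derived from the transcript and reduced round by round.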
LOG(1, "Doing Queries");
for (size_t q = 0; q < kQueries; q++) {
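// Derive a pseudo-random position over the original evaluation domain.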
uint32_t rng = iop.generate();
size_t pos = rng % origDomain;
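// The caller proves the inner claim at this position, then each round adds its Merkle opening.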
inner(iop, pos);
for (auto& round : rounds) {
round.proveQuery(iop, &pos);
}
}
}
} // namespace risc0