#include "risc0/zkvm/prove/riscv.h"
#include "risc0/core/log.h"
#include "risc0/zkp/core/rou.h"
#include "risc0/zkvm/prove/exec.h"
#include "risc0/zkvm/verify/riscv.h"
#include "oneapi/tbb/parallel_for.h"
using oneapi::tbb::parallel_for;
namespace risc0 {
namespace {
/// ProveCircuit implementation for the RISC-V circuit.
///
/// The object owns the execution state built from an ELF path and the
/// accumulation buffer, and keeps a reference to the caller's MemoryHandler
/// (so that handler must outlive this object). `execute()` sets the
/// power-of-two trace size that `accumulate()` and `evalCheck()` read.
class RiscVProveCircuit : public ProveCircuit {
public:
  /// @param elfFile passed to the ExecState constructor.
  /// @param io      memory handler handed to ExecState::run() in execute();
  ///                stored by reference.
  RiscVProveCircuit(const std::string& elfFile, MemoryHandler& io);

  TapSetRef getTaps() const override { return getRiscVTaps(); }

  void execute(WriteIOP& iop) override;
  void accumulate(WriteIOP& iop) override;
  void evalCheck(AccelSlice<Fp> check,
                 AccelConstSlice<Fp> codeEval,
                 AccelConstSlice<Fp> dataEval,
                 AccelConstSlice<Fp> accumEval,
                 Fp4 polyMix) const override;

  /// log2 of the trace size; 0 until execute() has run.
  uint32_t getPo2() const override { return po2_; }
  const std::vector<Fp>& getCode() const override { return exec_.code; }
  const std::vector<Fp>& getData() const override { return exec_.data; }
  const std::vector<Fp>& getAccum() const override { return accum_; }

private:
  ExecState exec_;
  std::vector<Fp> accum_;
  MemoryHandler& io_;
  // Default-initialized so that reading it before execute() is well defined
  // rather than a read of an indeterminate value.
  uint32_t po2_ = 0;
};
}
// Constructs the circuit: builds exec_ from elfFile and binds io_ to the
// caller's MemoryHandler (held by reference, not copied). accum_ starts
// default-constructed; po2_ is assigned later by execute().
RiscVProveCircuit::RiscVProveCircuit(const std::string& elfFile, MemoryHandler& io)
: exec_(elfFile), io_(io) {}
/// Runs the execution and fills in the data trace.
///
/// Steps, in order:
///  1. Calls exec_.run(kMaxCycles, io_).
///  2. Sets po2_ to log2Ceil of the resulting step count; the trace size is
///     2^po2_.
///  3. Writes kOutputRegs 32-bit values and then po2_ to `iop`.
///  4. Calls dataStepCheck() for every step below `size - kZkCycles`.
///  5. Post-processes exec_.data (see the CIRCUIT_DEBUG split below).
///
/// @param iop sink that receives the output register values and po2_.
void RiscVProveCircuit::execute(WriteIOP& iop) {
  exec_.run(kMaxCycles, io_);
  po2_ = log2Ceil(exec_.context.numSteps);
  const size_t size = size_t(1) << po2_;
  LOG(1, "size = " << size);

  // Each output value is assembled from two consecutive globals: the even
  // index supplies the low 16 bits, the odd index the high 16 bits.
  const Fp* globals = exec_.context.globals;
  for (size_t i = 0; i < kOutputRegs; i++) {
    uint32_t regVal = globals[2 * i].asUInt32() | (globals[2 * i + 1].asUInt32() << 16);
    LOG(2, "x" << i + 1 << " = " << hex(regVal));
    iop.write(&regVal, 1);
  }
  iop.write(&po2_, 1);

  // NOTE(review): `size - kZkCycles` wraps around if size < kZkCycles;
  // presumably the executor guarantees enough steps -- confirm.
  for (size_t i = 0; i < size - kZkCycles; i++) {
    exec_.context.curStep = i;
    dataStepCheck(exec_.context, exec_.code.data(), exec_.data.data());
  }

#ifdef CIRCUIT_DEBUG
  // Debug builds: replace every entry still equal to Fp::invalid() with a
  // recognizable marker value.
  for (Fp& x : exec_.data) {
    if (x == Fp::invalid()) {
      x = 0xdead;
    }
  }
#else
  // Release builds: overwrite the last kZkCycles entries of each of the
  // kDataSize columns (column i occupies [i * size, (i + 1) * size)) with
  // random field elements.
  for (size_t i = 0; i < kDataSize; i++) {
    for (size_t j = size - kZkCycles; j < size; j++) {
      exec_.data[i * size + j] = Fp::random(CryptoRng::shared());
    }
  }
#endif
}
/// Fills the accumulation trace (accum_) for the size chosen by execute().
///
/// Draws kAccumMixGlobalSize field elements from `iop` into the globals
/// starting at kAccumMixGlobalOffset, sizes accum_ to kAccumSize columns of
/// 2^po2_ entries, runs accumStep() for every step below
/// `2^po2_ - kZkCycles`, then post-processes the buffer (see the
/// CIRCUIT_DEBUG split at the end).
///
/// @param iop source passed to Fp::random() for the mix globals.
void RiscVProveCircuit::accumulate(WriteIOP& iop) {
  const size_t cycles = size_t(1) << po2_;
  const size_t liveSteps = cycles - kZkCycles;
  LOG(1, "size = " << cycles);

  Fp* const mixGlobals = exec_.context.globals + kAccumMixGlobalOffset;
  for (size_t k = 0; k != kAccumMixGlobalSize; ++k) {
    mixGlobals[k] = Fp::random(iop);
  }

  // Debug builds pre-fill with Fp::invalid() so that entries never written
  // by accumStep() can be detected afterwards.
#ifdef CIRCUIT_DEBUG
  accum_.resize(kAccumSize * cycles, Fp::invalid());
#else
  accum_.resize(kAccumSize * cycles);
#endif
  LOG(1, "accum_.size() == " << accum_.size());

  for (size_t step = 0; step != liveSteps; ++step) {
    exec_.context.curStep = step;
    accumStep(exec_.context, exec_.code.data(), exec_.data.data(), accum_.data());
  }

#ifdef CIRCUIT_DEBUG
  // Swap any entry still equal to Fp::invalid() for a marker value.
  for (Fp& cell : accum_) {
    if (cell == Fp::invalid()) {
      cell = 0xdead;
    }
  }
#else
  // Randomize the final kZkCycles entries of every column; column `col`
  // occupies [col * cycles, (col + 1) * cycles).
  for (size_t col = 0; col != kAccumSize; ++col) {
    Fp* const column = accum_.data() + col * cycles;
    for (size_t row = liveSteps; row < cycles; ++row) {
      column[row] = Fp::random(CryptoRng::shared());
    }
  }
#endif
}
namespace {
// Running state threaded through the do_begin / do_assert_zero / do_combine
// macros in evalCheck(). It is brace-initialized positionally there, so the
// member order (tot, then mul) matters.
struct MixState {
// Sum accumulated so far.
Fp4 tot;
// Factor applied to the next term added to `tot`; do_assert_zero multiplies
// it by polyMix after each assertion.
Fp4 mul;
};
}
// Evaluates the check polynomial at every point of a domain of
// 2^po2_ * kInvRate entries and writes the result into `check`.
//
// check     - output; four planes of `domain` Fp values, one per Fp4 element.
// codeEval  - input buffer read through the `code` pointer.
// dataEval  - input buffer read through the `data` pointer.
// accumEval - input buffer read through the `accum` pointer.
// polyMix   - Fp4 factor by which the mix multiplier grows after each
//             do_assert_zero.
//
// The per-point computation is supplied by the included step.cpp.inc, which
// is expanded through the do_* macros defined inside the lambda.
void RiscVProveCircuit::evalCheck( AccelSlice<Fp> check, AccelConstSlice<Fp> codeEval, AccelConstSlice<Fp> dataEval, AccelConstSlice<Fp> accumEval, Fp4 polyMix) const {
size_t size = size_t(1) << po2_;
size_t domain = size * kInvRate;
// Used with `&` to wrap shifted indices back into [0, domain).
// NOTE(review): this equals a modulo only if domain is a power of two, i.e.
// if kInvRate is a power of two -- confirm.
uint32_t mask = domain - 1;
Fp* out = check.devicePointer();
constexpr size_t expPo2 = log2Ceil(kInvRate);
const Fp* code = codeEval.devicePointer();
const Fp* data = dataEval.devicePointer();
const Fp* accum = accumEval.devicePointer();
const Fp* global = exec_.context.globals;
// One independent lambda invocation per domain index; each writes only its
// own out[k * domain + idx] slots.
parallel_for<size_t>(0, domain, [&](size_t idx) {
#define CHECK_EVAL
// do_const: declare Fp v<out> holding the constant cval.
#define do_const(out, cval) Fp v##out = Fp(cval);
// do_get: declare Fp v<out> read from row `reg` of `buf`, at index idx moved
// back by kInvRate * back and wrapped with mask. `id` is not used here.
#define do_get(out, buf, reg, back, id) \
Fp v##out = buf[reg * domain + ((idx - kInvRate * back) & mask)];
// do_get_global: declare Fp v<out> read from global[reg].
#define do_get_global(out, reg) Fp v##out = global[reg];
// do_begin: start a fresh MixState m<out> with tot = 0 and mul = 1.
#define do_begin(out) MixState m##out = {Fp4(0), Fp4(1)};
// do_assert_zero: add mul * v<zval> to the total and multiply mul by polyMix.
// `loc` is not used here.
#define do_assert_zero(out, in, zval, loc) \
MixState m##out = {m##in.tot + m##in.mul * v##zval, m##in.mul * polyMix};
// do_combine: fold the `inner` state into `prev`, with inner's total scaled
// by v<cond> and prev's multiplier. `loc` is not used here.
#define do_combine(out, prev, cond, inner, loc) \
MixState m##out = \
MixState{m##prev.tot + v##cond * m##prev.mul * m##inner.tot, m##prev.mul * m##inner.mul};
#define do_add(out, a, b) Fp v##out = v##a + v##b;
#define do_sub(out, a, b) Fp v##out = v##a - v##b;
#define do_mul(out, a, b) Fp v##out = v##a * v##b;
// do_result: declare `ret`, the Fp4 total of the final MixState; it is used
// after the include below.
#define do_result(out) Fp4 ret = m##out.tot;
#include "risc0/zkvm/circuit/step.cpp.inc"
#undef CHECK_EVAL
#undef do_const
#undef do_get
#undef do_get_global
#undef do_begin
#undef do_assert_zero
#undef do_combine
#undef do_add
#undef do_sub
#undef do_mul
#undef do_result
// x = kRouFwd[po2_ + expPo2] raised to idx.
// NOTE(review): presumably the idx-th point of the evaluation domain, with
// kRouFwd a table of roots of unity -- confirm.
Fp x = pow(kRouFwd[po2_ + expPo2], idx);
// Divide ret by ((3 * x)^(2^po2_) - 1).
// NOTE(review): looks like division by the vanishing polynomial of the trace
// domain on a coset shifted by 3 -- confirm.
ret = ret * inv(pow(Fp(3) * x, (1 << (po2_))) - 1);
// Store the four Fp4 elements in four consecutive planes of `domain` entries.
out[0 * domain + idx] = ret.elems[0];
out[1 * domain + idx] = ret.elems[1];
out[2 * domain + idx] = ret.elems[2];
out[3 * domain + idx] = ret.elems[3];
});
}
using oneapi::tbb::parallel_for;
/// Factory for the RISC-V prover circuit.
///
/// @param elfFile forwarded to the RiscVProveCircuit constructor.
/// @param io      forwarded by reference to the RiscVProveCircuit
///                constructor, which keeps the reference; it must outlive
///                the returned circuit.
/// @return the newly created circuit, owned by the caller, typed as the
///         ProveCircuit base.
std::unique_ptr<ProveCircuit> getRiscVProveCircuit(const std::string& elfFile,
                                                   MemoryHandler& io) {
  std::unique_ptr<ProveCircuit> circuit = std::make_unique<RiscVProveCircuit>(elfFile, io);
  return circuit;
}
}