#pragma once
#include <array>
#include <cstddef>
#include <cstdint>
#include <map>
#include <set>
#include <string>
#include <vector>

#include "risc0/core/key.h"
#include "risc0/zkp/core/fp.h"
#include "risc0/zkvm/circuit/constants.h"
namespace risc0 {
/// Growable byte buffer (the payload type of IoHandler::onWrite / onCommit).
using BufferU8 = std::vector<std::uint8_t>;
/// Growable buffer of 32-bit words.
using BufferU32 = std::vector<std::uint32_t>;
/// A single recorded memory access.
///
/// Events are ordered by address first and by cycle second. `isWrite` and
/// `data` do not take part in the comparison, so two events with the same
/// address and cycle compare equivalent. An ordered container such as
/// std::set keeps only the first one inserted.
struct MemoryEvent {
  uint32_t addr;
  uint32_t cycle;
  bool isWrite;
  uint32_t data;

  /// Lexicographic (addr, cycle) ordering.
  bool operator<(const MemoryEvent& rhs) const {
    return (addr == rhs.addr) ? (cycle < rhs.cycle) : (addr < rhs.addr);
  }
};
/// Memory contents together with a set of recorded access events.
/// All member functions are defined out of line; their exact semantics are
/// not visible in this header.
struct MemoryState {
  // Presumably word address -> word value; confirm against the accessors.
  std::map<uint32_t, uint32_t> data;
  // Access events, ordered by (addr, cycle) through MemoryEvent::operator<.
  std::set<MemoryEvent> history;

  // --- Reads -------------------------------------------------------------
  uint8_t loadByte(uint32_t addr);
  uint32_t load(uint32_t addr);
  // Presumably a big-endian word load (inferred from the name; confirm).
  uint32_t loadBE(uint32_t addr);
  // Presumably copies `len` bytes starting at `addr` into `ptr`; confirm.
  void loadRegion(uint32_t addr, void* ptr, uint32_t len);
  // Presumably the length of a NUL-terminated string at `addr`; confirm.
  size_t strlen(uint32_t addr);

  // --- Writes ------------------------------------------------------------
  void storeByte(uint32_t addr, uint8_t byte);
  void store(uint32_t addr, uint32_t value);
  // Presumably copies `len` bytes from `ptr` to memory at `addr`; confirm.
  void store(uint32_t addr, const void* ptr, uint32_t len);

  // --- Diagnostics -------------------------------------------------------
  void dump(size_t logLevel);
};
/// Interface through which I/O-related events are reported.
///
/// Every hook except getKeyStore() has a default implementation, so a
/// subclass only needs to override the events it cares about.
/// getKeyStore() is pure virtual and must be provided.
struct IoHandler {
  /// Virtual so that a subclass can be safely destroyed through an
  /// IoHandler pointer. The class has other virtual functions and is meant
  /// to be derived from, so without this, deleting through the base is UB.
  virtual ~IoHandler() = default;

  /// Called with the memory state; the default does nothing.
  virtual void onInit(MemoryState& /*mem*/) {}
  /// Receives a written data buffer; the default ignores it.
  virtual void onWrite(const BufferU8& /*data*/) {}
  /// Receives a committed data buffer; the default ignores it.
  virtual void onCommit(const BufferU8& /*data*/) {}
  /// Reports a fault message. The default implementation is defined out of
  /// line and is not visible here.
  virtual void onFault(const std::string& msg);
  /// Returns the key store associated with this handler.
  virtual KeyStore& getKeyStore() = 0;
};
/// Overridable hooks invoked around memory activity.
///
/// onLoaded and onHalt default to no-ops, and onRead defaults to returning
/// 0. The constructors, onInit and onWrite are defined out of line and are
/// not visible here.
class MemoryHandler {
public:
  MemoryHandler();
  MemoryHandler(IoHandler* io);
  /// Virtual so that a subclass can be safely destroyed through a
  /// MemoryHandler pointer. All hooks are virtual, so this type is meant to
  /// be derived from.
  virtual ~MemoryHandler() = default;

  virtual void onInit(MemoryState& mem);
  virtual void onLoaded(MemoryState& /*mem*/) {}
  virtual uint32_t onRead(MemoryState& /*mem*/, uint32_t /*addr*/) { return 0; }
  virtual void onWrite(MemoryState& mem, uint32_t cycle, uint32_t addr, uint32_t value);
  virtual void onHalt(const MemoryState& /*mem*/, const std::array<uint32_t, 8>& /*output*/) {}

private:
  // Handler given at construction; presumably used by the out-of-line hooks
  // (confirm). Defaults to nullptr so it is never left indeterminate if a
  // constructor does not set it.
  IoHandler* io = nullptr;
};
/// Mutable state passed by reference to the step functions
/// (dataStepExec, dataStepCheck, accumStep).
///
/// All member functions are defined out of line; the exact meaning of their
/// `offset` / `back` / `bits` / `size` parameters is not visible here.
struct StepContext {
  // Memory-event handler. Ownership is not expressed by this type;
  // presumably borrowed from the caller (confirm). Null until assigned.
  MemoryHandler* io = nullptr;
  MemoryState mem;
  // Zero-initialized so a default-constructed context never holds
  // indeterminate step counters.
  uint32_t curStep = 0;
  uint32_t numSteps = 0;
  Fp globals[kGlobalSize];

  // --- Buffer accessors (presumably `back` selects an earlier step; confirm)
  Fp get(const Fp* buf, size_t offset, size_t back);
  void set(Fp* buf, size_t offset, Fp val);
  Fp getDigits(const Fp* buf, size_t bits, size_t offset, size_t back, size_t size);
  Fp setDigits(Fp* buf, size_t bits, size_t offset, size_t size, Fp val);
  Fp getMux(const Fp* buf, size_t offset, size_t back, size_t size);
  void setMux(Fp* buf, size_t offset, size_t size, Fp val);

  // --- Memory and arithmetic helpers
  void memWrite(Fp cycle, Fp addr, Fp low, Fp high);
  std::array<Fp, 2> memRead(Fp cycle, Fp addr);
  std::array<Fp, 5> memCheck();
  std::array<Fp, 4> divide(Fp numerLow, Fp numerHigh, Fp denomLow, Fp denomHigh);

  // --- Constraint checks (each takes a message, presumably reported on failure)
  void requireDigits(Fp* buf, size_t bits, size_t offset, size_t size);
  void requireMux(Fp* buf, size_t offset, size_t size, const char* msg);
  void requireZero(Fp val, const char* msg);
};
// Step-level entry points. All are defined out of line; the buffer layouts
// they read and write are not visible in this header.

/// Populates `code` for `numSteps` steps, starting at `startAddr`, from the
/// memory `image` (presumably address -> word; confirm).
void setupCode(Fp* code, size_t numSteps, uint32_t startAddr, const std::map<uint32_t, uint32_t>& image);

/// Executes one step, reading `code` and writing into `data`.
void dataStepExec(StepContext& ctx, const Fp* code, Fp* data);

/// Checks one step against `code` and `data`.
void dataStepCheck(StepContext& ctx, const Fp* code, Fp* data);

/// Computes one step of `accum` from read-only `code` and `data`.
void accumStep(StepContext& ctx, const Fp* code, const Fp* data, Fp* accum);
}