use std::iter;
use bitfield::bitfield;
use bitvec::prelude::*;
use probe_rs_target::ScanChainElement;
use crate::probe::{
BatchExecutionError, ChainParams, CommandResult, DebugProbe, DebugProbeError,
DeferredResultSet, JTAGAccess, JtagChainItem, JtagCommand, JtagCommandQueue,
};
/// Packs up to 32 bits into a `u32`, LSB first.
///
/// The first item becomes bit 0, the second bit 1, and so on. Items beyond
/// the 32nd are ignored.
pub(crate) fn bits_to_byte(bits: impl IntoIterator<Item = bool>) -> u32 {
    bits.into_iter()
        .take(32)
        .enumerate()
        .filter(|&(_, set)| set)
        .fold(0u32, |packed, (position, _)| packed | (1u32 << position))
}
bitfield! {
    /// A 32-bit JTAG IDCODE register value.
    ///
    /// Bit ranges as declared below: version 31..28, part number 27..12,
    /// manufacturer 11..1 (JEP106 continuation count 11..8 plus identity
    /// 7..1), and bit 0, which `IdCode::valid` requires to be set.
    #[derive(Copy, Clone, Eq, PartialEq)]
    pub struct IdCode(u32);
    impl Debug;
    // The bare type lines (`u8;`, `u16;`, `bool;`) set the getter/setter
    // type for the fields that follow them.
    u8;
    // Version number, bits 31..28.
    pub version, set_version: 31, 28;
    u16;
    // Part number, bits 27..12.
    pub part_number, set_part_number: 27, 12;
    // Whole 11-bit manufacturer code, bits 11..1.
    pub manufacturer, set_manufacturer: 11, 1;
    u8;
    // JEP106 continuation-code count (bank), bits 11..8.
    pub manufacturer_continuation, set_manufacturer_continuation: 11, 8;
    // JEP106 identity code within that bank, bits 7..1.
    pub manufacturer_identity, set_manufacturer_identity: 7, 1;
    bool;
    // Bit 0; a TAP with an IDCODE shifts out 1 here (a bypass TAP shifts 0).
    pub lsbit, set_lsbit: 0;
}
impl std::fmt::Display for IdCode {
    /// Formats the raw IDCODE as eight upper-case hex digits. When the JEP106
    /// manufacturer is known, its name follows in parentheses.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.manufacturer_name() {
            Some(mfn) => write!(f, "0x{:08X} ({})", self.0, mfn),
            None => write!(f, "0x{:08X}", self.0),
        }
    }
}
impl IdCode {
    /// Returns `true` if this looks like a genuine IDCODE.
    ///
    /// Bit 0 must be set. The 11-bit manufacturer code must be neither 0 nor
    /// 127 (0x7F, the JEP106 continuation code); neither names a manufacturer.
    pub fn valid(&self) -> bool {
        self.lsbit() && !matches!(self.manufacturer(), 0 | 127)
    }

    /// Looks up the JEP106 manufacturer name from the continuation count and
    /// identity fields, returning `None` if the code is unknown.
    pub fn manufacturer_name(&self) -> Option<&'static str> {
        let code = jep106::JEP106Code::new(
            self.manufacturer_continuation(),
            self.manufacturer_identity(),
        );
        code.get()
    }
}
/// Errors produced while decoding JTAG scan-chain captures.
#[derive(Debug, thiserror::Error)]
pub enum ScanChainError {
    /// The DR scan contained an IDCODE that was cut short or failed
    /// `IdCode::valid`.
    #[error("Invalid IDCODE")]
    InvalidIdCode,
    /// The IR scan could not be split into per-TAP instruction register
    /// lengths (too few start patterns, mismatch with configured lengths, or
    /// an ambiguous layout).
    #[error("Invalid IR scan chain")]
    InvalidIR,
}
/// Turns IR start offsets into per-TAP lengths.
///
/// Each length is the distance to the next start. The final entry is
/// whatever part of `total` the earlier lengths do not cover. An empty
/// `starts` yields `[total]`.
fn starts_to_lengths(starts: &[usize], total: usize) -> Vec<usize> {
    let mut lengths = Vec::with_capacity(starts.len().max(1));
    let mut covered: usize = 0;
    for pair in starts.windows(2) {
        let gap = pair[1] - pair[0];
        covered += gap;
        lengths.push(gap);
    }
    lengths.push(total - covered);
    lengths
}
/// Splits a captured DR scan into per-TAP entries.
///
/// Bits are walked LSB-first. A `1` starts a 32-bit IDCODE, and a `0` is a
/// single-bit bypass TAP (recorded as `None`). Parsing stops at the first
/// all-ones 32-bit word, or at an all-ones tail shorter than 32 bits. Both
/// are taken to be the ones shifted in on TDI after the last TAP; `scan_chain`
/// shifts in `0xFF` bytes.
///
/// # Errors
///
/// Returns [`ScanChainError::InvalidIdCode`] in two cases:
/// - an IDCODE is cut short by the end of the scan and the remaining bits
///   are not pure TDI fill;
/// - a decoded word fails `IdCode::valid`.
pub(crate) fn extract_idcodes(
    mut dr: &BitSlice<u8>,
) -> Result<Vec<Option<IdCode>>, ScanChainError> {
    let mut idcodes = Vec::new();
    while !dr.is_empty() {
        if dr[0] {
            if dr.len() < 32 {
                // A short run of ones after the last TAP is TDI fill that did
                // not make up a whole 32-bit word (e.g. a bypass TAP shifted
                // the chain by one bit). It is not a truncated IDCODE.
                if dr.all() {
                    break;
                }
                tracing::error!("Truncated IDCODE: {dr:02X?}");
                return Err(ScanChainError::InvalidIdCode);
            }
            let idcode = dr[0..32].load_le::<u32>();
            if idcode == u32::MAX {
                // 32 ones in a row: the rest is TDI fill, not another TAP.
                break;
            }
            let idcode = IdCode(idcode);
            if !idcode.valid() {
                tracing::error!("Invalid IDCODE: {:08X}", idcode.0);
                return Err(ScanChainError::InvalidIdCode);
            }
            tracing::info!("Found IDCODE: {idcode}");
            idcodes.push(Some(idcode));
            dr = &dr[32..];
        } else {
            idcodes.push(None);
            tracing::info!("Found bypass TAP");
            dr = &dr[1..];
        }
    }
    Ok(idcodes)
}
/// Returns the longest prefix of `a` on which `a` and `b` agree bit for bit.
///
/// The result borrows from `a`. Its length is at most the shorter of the two
/// inputs.
pub(crate) fn common_sequence<'a, S: BitStore>(
    a: &'a BitSlice<S>,
    b: &BitSlice<S>,
) -> &'a BitSlice<S> {
    let prefix_len = a
        .iter()
        .zip(b.iter())
        .position(|(left, right)| *left != *right)
        .unwrap_or_else(|| a.len().min(b.len()));
    &a[..prefix_len]
}
/// Splits a captured IR scan into per-TAP instruction register lengths.
///
/// Every `1` immediately followed by `0` in `ir` (LSB-first) is treated as a
/// candidate start of a TAP's IR capture pattern. The candidates are
/// reconciled with `n_taps`, the TAP count from the DR scan. When `expected`
/// is provided, they are also checked against the user-configured lengths.
///
/// # Errors
///
/// Returns [`ScanChainError::InvalidIR`] when any of these holds:
/// - `n_taps` is zero;
/// - there are fewer candidate starts than TAPs;
/// - `ir` does not begin with a start pattern;
/// - `expected` disagrees with the scan;
/// - the lengths remain ambiguous.
pub(crate) fn extract_ir_lengths(
    ir: &BitSlice<u8>,
    n_taps: usize,
    expected: Option<&[usize]>,
) -> Result<Vec<usize>, ScanChainError> {
    // Indices i where ir[i] == 1 and ir[i + 1] == 0: possible IR starts.
    let starts = ir
        .windows(2)
        .enumerate()
        .filter(|(_, w)| w[0] && !w[1])
        .map(|(i, _)| i)
        .collect::<Vec<usize>>();
    tracing::trace!("Possible IR start positions: {starts:?}");
    if n_taps == 0 {
        tracing::error!("Cannot scan IR without at least one TAP");
        Err(ScanChainError::InvalidIR)
    } else if n_taps > starts.len() {
        tracing::error!("Fewer IRs detected than TAPs");
        Err(ScanChainError::InvalidIR)
    } else if starts[0] != 0 {
        // `starts` is non-empty here (1 <= n_taps <= starts.len()). The first
        // TAP's IR must begin at bit 0.
        tracing::error!("IR chain does not begin with a valid start pattern");
        Err(ScanChainError::InvalidIR)
    } else if let Some(expected) = expected {
        if expected.len() != n_taps {
            tracing::error!(
                "Number of provided IR lengths ({}) does not match \
                number of detected TAPs ({n_taps})",
                expected.len()
            );
            Err(ScanChainError::InvalidIR)
        } else if expected.iter().sum::<usize>() != ir.len() {
            tracing::error!(
                "Sum of provided IR lengths ({}) does not match \
                length of IR scan ({} bits)",
                expected.iter().sum::<usize>(),
                ir.len()
            );
            Err(ScanChainError::InvalidIR)
        } else {
            // Running prefix sums of `expected`: the start bit of each
            // configured IR.
            let exp_starts = expected
                .iter()
                .scan(0, |a, &x| {
                    let b = *a;
                    *a += x;
                    Some(b)
                })
                .collect::<Vec<usize>>();
            tracing::trace!("Provided IR start positions: {exp_starts:?}");
            // Each configured start must coincide with a detected `1,0` pair.
            let unsupported = exp_starts.iter().filter(|s| !starts.contains(s)).count();
            if unsupported > 0 {
                tracing::error!(
                    "Provided IR lengths imply an IR start position \
                    which is not supported by the IR scan"
                );
                Err(ScanChainError::InvalidIR)
            } else {
                tracing::debug!("Verified provided IR lengths against IR scan");
                // The sum was checked above, so this reproduces `expected`.
                Ok(starts_to_lengths(&exp_starts, ir.len()))
            }
        }
    } else if n_taps == 1 {
        // A single TAP owns the whole scan, even if its capture value
        // contains extra `1,0` pairs.
        tracing::info!("Only one TAP detected, IR length {}", ir.len());
        Ok(vec![ir.len()])
    } else if n_taps == starts.len() {
        let irlens = starts_to_lengths(&starts, ir.len());
        tracing::info!("IR lengths are unambiguous: {irlens:?}");
        Ok(irlens)
    } else {
        // More candidate starts than TAPs. An IR capture beginning `1,0,1,0`
        // shows a spurious start two bits in. Fold every 2-bit segment into
        // the segment after it, and accept the result only if it yields
        // exactly `n_taps` lengths. (The condition below always holds here:
        // the `>` and `==` cases were handled above.)
        if n_taps < starts.len() {
            let mut irlens = starts_to_lengths(&starts, ir.len()).into_iter();
            let mut merged = Vec::new();
            while let Some(len) = irlens.next() {
                if len == 2 {
                    if let Some(next) = irlens.next() {
                        merged.push(len + next);
                        continue;
                    }
                }
                merged.push(len);
            }
            if merged.len() == n_taps {
                tracing::info!("IR lengths after merging 101xx prefixes: {merged:?}");
                return Ok(merged);
            }
        }
        tracing::error!("IR lengths are ambiguous and must be explicitly configured.");
        Err(ScanChainError::InvalidIR)
    }
}
/// Position within one register column (DR or IR) of the TAP controller
/// state diagram.
#[derive(Clone, Copy, PartialEq, Debug)]
pub(crate) enum RegisterState {
    Select,
    Capture,
    Shift,
    Exit1,
    Pause,
    Exit2,
    Update,
}
impl RegisterState {
    /// Returns the TMS value that moves one step from `self` toward `target`
    /// inside the same column.
    ///
    /// Panics on `Update`, which `JtagState::step_toward` handles itself.
    fn step_toward(self, target: Self) -> bool {
        match (self, target) {
            // Select always continues to Capture within the column.
            (Self::Select, _) => false,
            // TMS low keeps us on the short path when the goal lies ahead.
            (Self::Capture, Self::Shift)
            | (Self::Exit1, Self::Pause | Self::Exit2)
            | (Self::Exit2, Self::Shift | Self::Exit1 | Self::Pause) => false,
            (Self::Update, _) => {
                unreachable!("This is a bug, this case should have been handled by JtagState.")
            }
            // Every other move leaves the current state with TMS high.
            _ => true,
        }
    }

    /// Applies one TCK edge with the given TMS value.
    ///
    /// Panics for `Update` (any TMS) and for `Select` with TMS high. Those
    /// transitions leave the column and are handled by `JtagState::update`.
    fn update(self, tms: bool) -> Self {
        match (self, tms) {
            (Self::Select, false) => Self::Capture,
            (Self::Capture | Self::Shift | Self::Exit2, false) => Self::Shift,
            (Self::Capture | Self::Shift, true) => Self::Exit1,
            (Self::Exit1 | Self::Pause, false) => Self::Pause,
            (Self::Pause, true) => Self::Exit2,
            (Self::Exit1 | Self::Exit2, true) => Self::Update,
            (Self::Select, true) | (Self::Update, _) => {
                unreachable!("This is a bug, this case should have been handled by JtagState.")
            }
        }
    }
}
/// TAP controller state as tracked by the driver.
#[derive(Clone, Copy, PartialEq, Debug)]
pub(crate) enum JtagState {
    /// Test-Logic-Reset.
    Reset,
    /// Run-Test/Idle.
    Idle,
    /// A state in the data register column.
    Dr(RegisterState),
    /// A state in the instruction register column.
    Ir(RegisterState),
}
impl JtagState {
    /// Returns the TMS bit that moves one step closer to `target`. Returns
    /// `None` once `self` already equals `target`.
    pub fn step_toward(self, target: Self) -> Option<bool> {
        if self == target {
            return None;
        }
        let tms = match self {
            Self::Reset => false,
            Self::Idle => true,
            // From Select-DR, enter the DR column only if the goal is there.
            Self::Dr(RegisterState::Select) => !matches!(target, Self::Dr(_)),
            // From Select-IR, enter the IR column only if the goal is there.
            Self::Ir(RegisterState::Select) => !matches!(target, Self::Ir(_)),
            // From Update, go back to Select-DR for register goals, else Idle.
            Self::Dr(RegisterState::Update) | Self::Ir(RegisterState::Update) => {
                matches!(target, Self::Ir(_) | Self::Dr(_))
            }
            // Inside a column: head for the goal if it is in the same column,
            // otherwise head for Update to leave the column.
            Self::Dr(current) => current.step_toward(match target {
                Self::Dr(goal) => goal,
                _ => RegisterState::Update,
            }),
            Self::Ir(current) => current.step_toward(match target {
                Self::Ir(goal) => goal,
                _ => RegisterState::Update,
            }),
        };
        Some(tms)
    }

    /// Advances the tracked state by one TCK edge with the given TMS value.
    pub fn update(&mut self, tms: bool) {
        let next = match (*self, tms) {
            (Self::Reset, true) => Self::Reset,
            (Self::Reset, false) | (Self::Idle, false) => Self::Idle,
            (Self::Idle, true) => Self::Dr(RegisterState::Select),
            (Self::Dr(RegisterState::Select), true) => Self::Ir(RegisterState::Select),
            (Self::Ir(RegisterState::Select), true) => Self::Reset,
            (Self::Dr(RegisterState::Update) | Self::Ir(RegisterState::Update), true) => {
                Self::Dr(RegisterState::Select)
            }
            (Self::Dr(RegisterState::Update) | Self::Ir(RegisterState::Update), false) => {
                Self::Idle
            }
            (Self::Dr(register), _) => Self::Dr(register.update(tms)),
            (Self::Ir(register), _) => Self::Ir(register.update(tms)),
        };
        *self = next;
    }
}
/// JTAG bookkeeping kept by drivers that implement `RawJtagIo`.
#[derive(Debug)]
pub(crate) struct JtagDriverState {
    /// Tracked TAP controller state, used to plan TMS sequences.
    pub state: JtagState,
    /// Largest instruction value `prepare_write_register` accepts. It is
    /// derived from the selected TAP's IR length in `select_target`.
    pub max_ir_address: u32,
    /// Optional user-described chain. Its `ir_len` entries are checked
    /// against the IR scan in `scan_chain`.
    pub expected_scan_chain: Option<Vec<ScanChainElement>>,
    /// TAPs found by the last `scan_chain` call.
    pub scan_chain: Vec<JtagChainItem>,
    /// IR length and pre/post bit counts for the selected TAP.
    pub chain_params: ChainParams,
    /// Extra Run-Test/Idle clocks inserted after each DR shift.
    pub jtag_idle_cycles: usize,
}
impl Default for JtagDriverState {
    /// Starts in Test-Logic-Reset, with no known chain, no idle cycles and a
    /// 4-bit (0x0F) instruction limit.
    fn default() -> Self {
        JtagDriverState {
            state: JtagState::Reset,
            max_ir_address: 0x0F,
            expected_scan_chain: None,
            scan_chain: Vec::new(),
            chain_params: ChainParams::default(),
            jtag_idle_cycles: 0,
        }
    }
}
/// Bit-level JTAG access implemented by probe drivers.
///
/// Implementors provide single-bit shifting, retrieval of captured bits and
/// access to a [`JtagDriverState`]. The provided methods build chain reset
/// and TAP selection on top of those.
pub(crate) trait RawJtagIo {
    /// Mutable access to the driver's JTAG bookkeeping.
    fn state_mut(&mut self) -> &mut JtagDriverState;

    /// Shared access to the driver's JTAG bookkeeping.
    fn state(&self) -> &JtagDriverState;

    /// Calls [`RawJtagIo::shift_bit`] once per `(tms, tdi, cap)` triple.
    ///
    /// The three sequences are zipped, so shifting stops at the end of the
    /// shortest one. Callers may pass infinite iterators for the others.
    fn shift_bits(
        &mut self,
        tms: impl IntoIterator<Item = bool>,
        tdi: impl IntoIterator<Item = bool>,
        cap: impl IntoIterator<Item = bool>,
    ) -> Result<(), DebugProbeError> {
        for ((tms, tdi), cap) in tms.into_iter().zip(tdi.into_iter()).zip(cap.into_iter()) {
            self.shift_bit(tms, tdi, cap)?;
        }
        Ok(())
    }

    /// Shifts one bit with the given TMS and TDI values.
    ///
    /// `capture` presumably asks the driver to keep the corresponding TDO bit
    /// for a later [`RawJtagIo::read_captured_bits`]. NOTE(review): confirm
    /// per driver.
    fn shift_bit(&mut self, tms: bool, tdi: bool, capture: bool) -> Result<(), DebugProbeError>;

    /// Returns the bits captured by earlier shifts. The batch code in this
    /// file consumes them in shift order.
    fn read_captured_bits(&mut self) -> Result<BitVec<u8, Lsb0>, DebugProbeError>;

    /// Clocks TMS high five times, which reaches Test-Logic-Reset from any
    /// state, then low once to land in Run-Test/Idle.
    ///
    /// TDI is held high and nothing is captured. The captured-bits read
    /// afterwards is only logged.
    fn reset_jtag_state_machine(&mut self) -> Result<(), DebugProbeError> {
        tracing::debug!("Resetting JTAG chain by setting tms high for 5 bits");
        let tms = [true, true, true, true, true, false];
        let tdi = iter::repeat(true);
        self.shift_bits(tms, tdi, iter::repeat(false))?;
        let response = self.read_captured_bits()?;
        tracing::debug!("Response to reset: {response}");
        Ok(())
    }

    /// Makes TAP number `target` of the scanned chain the active one.
    ///
    /// Updates `chain_params` and `max_ir_address` to match that TAP.
    ///
    /// # Errors
    ///
    /// [`DebugProbeError::TargetNotFound`] if `ChainParams::from_jtag_chain`
    /// cannot resolve `target`.
    fn select_target(&mut self, target: usize) -> Result<(), DebugProbeError> {
        let Some(params) = ChainParams::from_jtag_chain(&self.state().scan_chain, target) else {
            return Err(DebugProbeError::TargetNotFound);
        };
        // All-ones value of `irlen` bits. A plain `(1 << irlen) - 1` on u32
        // overflows for IRs of 32 bits or more (panic in debug, wrong mask in
        // release), so saturate to u32::MAX instead.
        let max_ir_address = u32::try_from(params.irlen)
            .ok()
            .and_then(|bits| 1u32.checked_shl(bits))
            .map_or(u32::MAX, |limit| limit - 1);
        tracing::debug!("Selecting JTAG TAP: {target}");
        tracing::debug!("Setting chain params: {params:?}");
        tracing::debug!("Setting max_ir_address to {max_ir_address}");
        let state = self.state_mut();
        state.max_ir_address = max_ir_address;
        state.chain_params = params;
        Ok(())
    }
}
/// Clocks TMS until the tracked TAP state equals `target`.
///
/// TDI is held low and nothing is captured while moving. Each TMS bit comes
/// from `JtagState::step_toward`.
fn jtag_move_to_state(
    protocol: &mut impl RawJtagIo,
    target: JtagState,
) -> Result<(), DebugProbeError> {
    tracing::trace!("Changing state: {:?} -> {:?}", protocol.state_mut().state, target);
    // Terminates only if `shift_bit` advances the tracked state.
    // NOTE(review): every `RawJtagIo` implementor must keep `state()` in sync,
    // otherwise this loops forever.
    while let Some(next_tms) = protocol.state().state.step_toward(target) {
        let (tdi, capture) = (false, false);
        protocol.shift_bit(next_tms, tdi, capture)?;
    }
    tracing::trace!("In state: {:?}", protocol.state_mut().state);
    Ok(())
}
/// Shifts `len` bits of `data` (LSB-first) into the selected TAP's IR.
///
/// Sequence:
/// 1. Move to Shift-IR.
/// 2. Clock `irpre` ones, the instruction bits, then `irpost` ones (all ones
///    is BYPASS for the other TAPs), with TMS raised on the final bit.
/// 3. Move to Update-IR.
///
/// When `capture_data` is set, only the `len` instruction bits are captured.
///
/// # Errors
///
/// [`DebugProbeError::Other`] if `len` is zero or `data` holds fewer than
/// `len` bits.
fn shift_ir(
    protocol: &mut impl RawJtagIo,
    data: &[u8],
    len: usize,
    capture_data: bool,
) -> Result<(), DebugProbeError> {
    tracing::debug!("Write IR: {:?}, len={}", data, len);
    if data.len() * 8 < len || len == 0 {
        // Report the bit count actually supplied, not the byte count.
        return Err(DebugProbeError::Other(format!(
            "Invalid data length. IR bits: {}, expected: {}",
            data.len() * 8,
            len
        )));
    }
    let pre_bits = protocol.state().chain_params.irpre;
    let post_bits = protocol.state().chain_params.irpost;
    let tms_data = iter::repeat(false).take(len - 1);
    jtag_move_to_state(protocol, JtagState::Ir(RegisterState::Shift))?;
    // TMS stays low for every bit except the last, which exits to Exit1-IR.
    let tms = iter::repeat(false)
        .take(pre_bits)
        .chain(tms_data)
        .chain(iter::repeat(false).take(post_bits))
        .chain(iter::once(true));
    let tdi = iter::repeat(true)
        .take(pre_bits)
        .chain(data.as_bits::<Lsb0>()[..len].iter().map(|b| *b))
        .chain(iter::repeat(true).take(post_bits));
    let capture = iter::repeat(false)
        .take(pre_bits)
        .chain(iter::repeat(capture_data).take(len))
        .chain(iter::repeat(false));
    tracing::trace!("tms: {:?}", tms.clone());
    tracing::trace!("tdi: {:?}", tdi.clone());
    protocol.shift_bits(tms, tdi, capture)?;
    jtag_move_to_state(protocol, JtagState::Ir(RegisterState::Update))?;
    Ok(())
}
/// Shifts `register_bits` bits of `data` (LSB-first) through the selected
/// TAP's DR.
///
/// Sequence:
/// 1. Move to Shift-DR.
/// 2. Clock `drpre` zeros, the data bits, then `drpost` zeros (one bit per
///    bypassed TAP), with TMS raised on the final bit.
/// 3. Move to Update-DR.
/// 4. If `jtag_idle_cycles` is non-zero, move to Run-Test/Idle and clock
///    that many extra cycles there.
///
/// Returns the number of bits queued for capture: `register_bits` when
/// `capture_data` is set, otherwise 0.
///
/// # Errors
///
/// [`DebugProbeError::Other`] if `register_bits` is zero or `data` holds
/// fewer than `register_bits` bits.
fn shift_dr(
    protocol: &mut impl RawJtagIo,
    data: &[u8],
    register_bits: usize,
    capture_data: bool,
) -> Result<usize, DebugProbeError> {
    tracing::debug!("Write DR: {:?}, len={}", data, register_bits);
    if data.len() * 8 < register_bits || register_bits == 0 {
        // Report the bit count actually supplied, not the byte count.
        return Err(DebugProbeError::Other(format!(
            "Invalid data length. DR bits: {}, expected: {}",
            data.len() * 8,
            register_bits
        )));
    }
    let tms_shift_out_value = iter::repeat(false).take(register_bits - 1);
    jtag_move_to_state(protocol, JtagState::Dr(RegisterState::Shift))?;
    let pre_bits = protocol.state().chain_params.drpre;
    let post_bits = protocol.state().chain_params.drpost;
    // TMS stays low for every bit except the last, which exits to Exit1-DR.
    let tms = iter::repeat(false)
        .take(pre_bits)
        .chain(tms_shift_out_value)
        .chain(iter::repeat(false).take(post_bits))
        .chain(iter::once(true));
    let tdi = iter::repeat(false)
        .take(pre_bits)
        .chain(data.as_bits::<Lsb0>()[..register_bits].iter().map(|b| *b))
        .chain(iter::repeat(false).take(post_bits));
    let capture = iter::repeat(false)
        .take(pre_bits)
        .chain(iter::repeat(capture_data).take(register_bits))
        .chain(iter::repeat(false));
    protocol.shift_bits(tms, tdi, capture)?;
    jtag_move_to_state(protocol, JtagState::Dr(RegisterState::Update))?;
    let idle_cycles = protocol.state().jtag_idle_cycles;
    if idle_cycles > 0 {
        jtag_move_to_state(protocol, JtagState::Idle)?;
        let tms = iter::repeat(false).take(idle_cycles);
        let tdi = iter::repeat(false).take(idle_cycles);
        protocol.shift_bits(tms, tdi, iter::repeat(false))?;
    }
    if capture_data {
        Ok(register_bits)
    } else {
        Ok(0)
    }
}
/// Queues an IR write of `address` followed by a `len`-bit DR shift of `data`.
///
/// Returns the number of DR bits queued for capture (see `shift_dr`).
///
/// # Errors
///
/// [`DebugProbeError::Other`] if `address` exceeds `max_ir_address`, plus
/// any error from `shift_ir` or `shift_dr`.
fn prepare_write_register(
    protocol: &mut impl RawJtagIo,
    address: u32,
    data: &[u8],
    len: u32,
    capture: bool,
) -> Result<usize, DebugProbeError> {
    let out_of_range = address > protocol.state().max_ir_address;
    if out_of_range {
        return Err(DebugProbeError::Other(format!(
            "Invalid instruction register access: {}",
            address
        )));
    }
    let ir_bits = protocol.state().chain_params.irlen;
    let instruction = address.to_le_bytes();
    shift_ir(protocol, &instruction, ir_bits, false)?;
    let dr_bits = len as usize;
    shift_dr(protocol, data, dr_bits, capture)
}
impl<Probe: DebugProbe + RawJtagIo + 'static> JTAGAccess for Probe {
fn scan_chain(&mut self) -> Result<(), DebugProbeError> {
const MAX_CHAIN: usize = 8;
self.reset_jtag_state_machine()?;
self.state_mut().chain_params = ChainParams::default();
let input = vec![0xFF; 4 * MAX_CHAIN];
shift_dr(self, &input, input.len() * 8, true)?;
let response = self.read_captured_bits()?;
tracing::debug!("DR: {:?}", response);
let idcodes = extract_idcodes(&response)?;
tracing::info!(
"JTAG DR scan complete, found {} TAPs. {:?}",
idcodes.len(),
idcodes
);
tracing::debug!("Scanning JTAG chain for IR lengths");
let input = vec![0xff; idcodes.len()];
shift_ir(self, &input, input.len() * 8, true)?;
let response = self.read_captured_bits()?;
tracing::debug!("IR scan: {}", response);
self.reset_jtag_state_machine()?;
let input = iter::repeat(0)
.take(idcodes.len())
.chain(input.iter().copied())
.collect::<Vec<_>>();
shift_ir(self, &input, input.len() * 8, true)?;
let response_zeros = self.read_captured_bits()?;
tracing::debug!("IR scan: {}", response_zeros);
let response = response.as_bitslice();
let response = common_sequence(response, response_zeros.as_bitslice());
tracing::debug!("IR scan: {}", response);
let ir_lens = extract_ir_lengths(
response,
idcodes.len(),
self.state()
.expected_scan_chain
.as_ref()
.map(|chain| {
chain
.iter()
.filter_map(|s| s.ir_len)
.map(|s| s as usize)
.collect::<Vec<usize>>()
})
.as_deref(),
)?;
tracing::info!("Found {} TAPs on reset scan", idcodes.len());
tracing::debug!("Detected IR lens: {:?}", ir_lens);
let chain = idcodes
.into_iter()
.zip(ir_lens)
.map(|(idcode, irlen)| JtagChainItem { irlen, idcode })
.collect::<Vec<_>>();
self.state_mut().scan_chain = chain;
Ok(())
}
fn tap_reset(&mut self) -> Result<(), DebugProbeError> {
self.reset_jtag_state_machine()
}
fn set_idle_cycles(&mut self, idle_cycles: u8) {
self.state_mut().jtag_idle_cycles = idle_cycles as usize;
}
fn idle_cycles(&self) -> u8 {
self.state().jtag_idle_cycles as u8
}
fn read_register(&mut self, address: u32, len: u32) -> Result<Vec<u8>, DebugProbeError> {
let data = vec![0u8; len.div_ceil(8) as usize];
self.write_register(address, &data, len)
}
fn write_register(
&mut self,
address: u32,
data: &[u8],
len: u32,
) -> Result<Vec<u8>, DebugProbeError> {
prepare_write_register(self, address, data, len, true)?;
let mut response = self.read_captured_bits()?;
response.force_align();
let result = response.into_vec();
tracing::trace!("recieve_write_dr result: {:?}", result);
Ok(result)
}
fn write_dr(&mut self, data: &[u8], len: u32) -> Result<Vec<u8>, DebugProbeError> {
shift_dr(self, data, len as usize, true)?;
let mut response = self.read_captured_bits()?;
response.force_align();
let result = response.into_vec();
tracing::trace!("recieve_write_dr result: {:?}", result);
Ok(result)
}
#[tracing::instrument(skip(self, writes))]
fn write_register_batch(
&mut self,
writes: &JtagCommandQueue,
) -> Result<DeferredResultSet, BatchExecutionError> {
let mut bits = Vec::with_capacity(writes.len());
let t1 = std::time::Instant::now();
tracing::debug!("Preparing {} writes...", writes.len());
for (idx, command) in writes.iter() {
let result = match command {
JtagCommand::WriteRegister(write) => prepare_write_register(
self,
write.address,
&write.data,
write.len,
idx.should_capture(),
),
JtagCommand::ShiftDr(write) => {
shift_dr(self, &write.data, write.len as usize, idx.should_capture())
}
};
let op =
result.map_err(|e| BatchExecutionError::new(e.into(), DeferredResultSet::new()))?;
bits.push((idx, command, op));
}
tracing::debug!("Sending to chip...");
let bitstream = self
.read_captured_bits()
.map_err(|e| BatchExecutionError::new(e.into(), DeferredResultSet::new()))?;
tracing::debug!("Got responses! Took {:?}! Processing...", t1.elapsed());
let mut responses = DeferredResultSet::with_capacity(bits.len());
let mut bitstream = bitstream.as_bitslice();
for (idx, command, bits) in bits.into_iter() {
if idx.should_capture() {
let mut reg_bits = bitstream[..bits].to_bitvec();
reg_bits.force_align();
let response = reg_bits.into_vec();
let result = match command {
JtagCommand::WriteRegister(command) => (command.transform)(command, response),
JtagCommand::ShiftDr(command) => (command.transform)(command, response),
};
match result {
Ok(response) => responses.push(idx, response),
Err(e) => return Err(BatchExecutionError::new(e, responses)),
}
} else {
responses.push(idx, CommandResult::None);
}
bitstream = &bitstream[bits..];
}
Ok(responses)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const ARM_TAP: IdCode = IdCode(0x4BA00477);
    const STM_BS_TAP: IdCode = IdCode(0x06433041);

    #[test]
    fn id_code_display() {
        let debug_fmt = format!("{idcode}", idcode = ARM_TAP);
        assert_eq!(debug_fmt, "0x4BA00477 (ARM Ltd)");
        let debug_fmt = format!("{idcode}", idcode = STM_BS_TAP);
        assert_eq!(debug_fmt, "0x06433041 (STMicroelectronics)");
    }

    #[test]
    fn extract_ir_lengths_with_one_tap() {
        let ir = &bitvec![u8, Lsb0; 1,0,0,0];
        let n_taps = 1;
        let expected = None;
        let ir_lengths = extract_ir_lengths(ir, n_taps, expected).unwrap();
        assert_eq!(ir_lengths, vec![4]);
    }

    #[test]
    fn extract_ir_lengths_with_two_taps() {
        let ir = &bitvec![u8, Lsb0; 1,0,0,0,1,0,0,0,0];
        let n_taps = 2;
        let expected = None;
        let ir_lengths = extract_ir_lengths(ir, n_taps, expected).unwrap();
        assert_eq!(ir_lengths, vec![4, 5]);
    }

    #[test]
    fn extract_ir_lengths_with_two_taps_101() {
        let ir = &bitvec![u8, Lsb0; 1,0,1,0,1,0,0,0,0];
        let n_taps = 2;
        let expected = None;
        let ir_lengths = extract_ir_lengths(ir, n_taps, expected).unwrap();
        assert_eq!(ir_lengths, vec![4, 5]);
    }

    #[test]
    fn extract_id_codes_one_tap() {
        let mut dr = bitvec![u8, Lsb0; 0; 32];
        dr[0..32].store_le(ARM_TAP.0);
        let idcodes = extract_idcodes(&dr).unwrap();
        assert_eq!(idcodes, vec![Some(ARM_TAP)]);
    }

    #[test]
    fn extract_id_codes_two_taps() {
        let mut dr = bitvec![u8, Lsb0; 0; 64];
        dr[0..32].store_le(ARM_TAP.0);
        dr[32..64].store_le(STM_BS_TAP.0);
        let idcodes = extract_idcodes(&dr).unwrap();
        assert_eq!(idcodes, vec![Some(ARM_TAP), Some(STM_BS_TAP)]);
    }

    #[test]
    fn extract_id_codes_tap_bypass_tap() {
        let mut dr = bitvec![u8, Lsb0; 0; 65];
        dr[0..32].store_le(ARM_TAP.0);
        dr.set(32, false);
        dr[33..65].store_le(STM_BS_TAP.0);
        let idcodes = extract_idcodes(&dr).unwrap();
        assert_eq!(idcodes, vec![Some(ARM_TAP), None, Some(STM_BS_TAP)]);
    }

    #[test]
    fn reset_from_ir_shift() {
        let mut state = JtagState::Ir(RegisterState::Shift);
        state.update(true);
        state.update(true);
        state.update(true);
        state.update(true);
        state.update(true);
        assert_eq!(state, JtagState::Reset);
    }

    #[test]
    fn idle_from_reset() {
        let mut state = JtagState::Reset;
        state.update(false);
        assert_eq!(state, JtagState::Idle);
    }

    /// Following `step_toward` from every state must reach every other state
    /// in a bounded number of clocks without hitting an `unreachable!`.
    #[test]
    fn generated_bits_lead_to_correct_state() {
        const REGISTER_STATES: [RegisterState; 7] = [
            RegisterState::Select,
            RegisterState::Capture,
            RegisterState::Shift,
            RegisterState::Exit1,
            RegisterState::Pause,
            RegisterState::Exit2,
            RegisterState::Update,
        ];
        let mut all_states = vec![JtagState::Reset, JtagState::Idle];
        all_states.extend(REGISTER_STATES.iter().map(|&s| JtagState::Dr(s)));
        all_states.extend(REGISTER_STATES.iter().map(|&s| JtagState::Ir(s)));

        for &start in &all_states {
            for &goal in &all_states {
                let mut state = start;
                let mut transitions = 0;
                while let Some(tms) = state.step_toward(goal) {
                    state.update(tms);
                    transitions += 1;
                    assert!(transitions < 10, "{start:?} -> {goal:?} did not converge");
                }
                assert_eq!(state, goal, "{start:?} -> {goal:?}");
            }
        }
    }
}