use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;
use std::time::{Duration, Instant};
use rand::{Rng, SeedableRng};
use crate::congestion::bbr::bw_estimation::BandwidthEstimation;
use crate::congestion::bbr::min_max::MinMax;
use crate::connection::RttEstimator;
use super::{Controller, ControllerFactory, BASE_DATAGRAM_SIZE};
mod bw_estimation;
mod min_max;
/// State of one BBR congestion controller instance.
///
/// All fields are updated from the `Controller` callbacks (`on_sent`, `on_ack`,
/// `on_congestion_event`, `on_end_acks`, `on_mtu_update`); the resulting window is read
/// back through `window()`.
#[derive(Debug, Clone)]
pub struct Bbr {
/// Shared configuration; only `initial_window` is read from it.
config: Arc<BbrConfig>,
/// Current MTU; derives `min_cwnd` and is used as slack when deciding to leave ProbeRtt.
current_mtu: u64,
/// Bandwidth estimator fed by `on_sent`/`on_ack`; `get_estimate()` is the bandwidth used
/// for target-window and pacing computations.
max_bandwidth: BandwidthEstimation,
/// Total bytes acknowledged over the controller's lifetime.
acked_bytes: u64,
/// Current state-machine mode.
mode: Mode,
/// Bytes lost since the last `on_end_acks` call.
loss_state: LossState,
/// Current loss-recovery phase.
recovery_state: RecoveryState,
/// Window cap applied while in recovery; 0 means "initialise from in-flight on the next
/// `calculate_recovery_window` call".
recovery_window: u64,
/// Set once bandwidth stops growing in Startup (or recovery is entered while checking).
is_at_full_bandwidth: bool,
/// Multiplier applied to the bandwidth estimate to obtain the pacing rate.
pacing_gain: f32,
/// Pacing gain used in Startup.
high_gain: f32,
/// Pacing gain used in Drain.
drain_gain: f32,
/// Multiplier applied to the bandwidth-delay product to obtain the target window.
cwnd_gain: f32,
/// Cwnd gain used in Startup and Drain.
high_cwnd_gain: f32,
/// When the current ProbeBw gain-cycle phase began.
last_cycle_start: Option<Instant>,
/// Index of the current phase in `K_PACING_GAIN`.
current_cycle_offset: u8,
/// Initial window; also the fallback target window while no BDP estimate exists.
init_cwnd: u64,
/// Lower bound enforced on `cwnd` and `recovery_window`.
min_cwnd: u64,
/// Bytes in flight at the end of the previous `on_end_acks` call.
prev_in_flight_count: u64,
/// Earliest time ProbeRtt may end; `None` until in-flight has dropped low enough.
exit_probe_rtt_at: Option<Instant>,
/// When ProbeRtt was last entered; drives `is_min_rtt_expired`.
probe_rtt_last_started_at: Option<Instant>,
/// Minimum RTT estimate; zero until the first sample is recorded.
min_rtt: Duration,
/// While true, ProbeRtt is not entered; cleared on every ack batch.
/// NOTE(review): never set to `true` within this file — confirm whether anything sets it.
exiting_quiescence: bool,
/// Current pacing rate, in the same units as the bandwidth estimate; 0 until computed.
/// NOTE(review): not read outside `calculate_pacing_rate` in this file.
pacing_rate: u64,
/// Largest packet number reported as acknowledged.
max_acked_packet_number: u64,
/// Packet number of the most recently sent packet.
max_sent_packet_number: u64,
/// Recovery may end once a packet beyond this number is acked with no new losses.
end_recovery_at_packet_number: u64,
/// Congestion window in bytes.
cwnd: u64,
/// A new round starts when a packet beyond this number is acknowledged.
current_round_trip_end_packet_number: u64,
/// Number of round trips observed.
round_count: u64,
/// Bandwidth recorded at the last round that met the Startup growth target.
bw_at_last_round: u64,
/// Consecutive rounds that failed to meet the Startup growth target.
round_wo_bw_gain: u64,
/// Tracks bytes acknowledged in excess of the bandwidth estimate.
ack_aggregation: AckAggregationState,
/// Chooses the starting phase of the ProbeBw gain cycle.
random_number_generator: rand::rngs::StdRng,
}
impl Bbr {
    /// Creates a controller in `Startup` mode for a path with the given MTU.
    ///
    /// The minimum window is `calculate_min_window(current_mtu)`. Both the initial window
    /// used by the algorithm and the starting congestion window are
    /// `config.initial_window` raised to that minimum. `on_mtu_update` applies the same
    /// clamping when the MTU changes later.
    pub fn new(config: Arc<BbrConfig>, current_mtu: u16) -> Self {
        let current_mtu = u64::from(current_mtu);
        let min_cwnd = calculate_min_window(current_mtu);
        // Same clamping as `on_mtu_update`, so a small configured window can never leave
        // the controller below its own minimum.
        let init_cwnd = config.initial_window.max(min_cwnd);
        Self {
            config,
            current_mtu,
            max_bandwidth: BandwidthEstimation::default(),
            acked_bytes: 0,
            mode: Mode::Startup,
            loss_state: Default::default(),
            recovery_state: RecoveryState::NotInRecovery,
            recovery_window: 0,
            is_at_full_bandwidth: false,
            pacing_gain: K_DEFAULT_HIGH_GAIN,
            high_gain: K_DEFAULT_HIGH_GAIN,
            drain_gain: 1.0 / K_DEFAULT_HIGH_GAIN,
            cwnd_gain: K_DEFAULT_HIGH_GAIN,
            high_cwnd_gain: K_DEFAULT_HIGH_GAIN,
            last_cycle_start: None,
            current_cycle_offset: 0,
            init_cwnd,
            min_cwnd,
            prev_in_flight_count: 0,
            exit_probe_rtt_at: None,
            probe_rtt_last_started_at: None,
            min_rtt: Default::default(),
            exiting_quiescence: false,
            pacing_rate: 0,
            max_acked_packet_number: 0,
            max_sent_packet_number: 0,
            end_recovery_at_packet_number: 0,
            cwnd: init_cwnd,
            current_round_trip_end_packet_number: 0,
            round_count: 0,
            bw_at_last_round: 0,
            round_wo_bw_gain: 0,
            ack_aggregation: AckAggregationState::default(),
            random_number_generator: rand::rngs::StdRng::from_entropy(),
        }
    }

    /// Switches to `Startup`, restoring the high pacing and cwnd gains.
    fn enter_startup_mode(&mut self) {
        self.mode = Mode::Startup;
        self.pacing_gain = self.high_gain;
        self.cwnd_gain = self.high_cwnd_gain;
    }

    /// Switches to `ProbeBw` and starts the pacing-gain cycle at a random phase.
    ///
    /// The index is drawn from `0..len-1` and shifted past index 1, so the cycle never
    /// starts on the 0.75 phase of `K_PACING_GAIN`.
    fn enter_probe_bandwidth_mode(&mut self, now: Instant) {
        self.mode = Mode::ProbeBw;
        self.cwnd_gain = K_DERIVED_HIGH_CWNDGAIN;
        self.last_cycle_start = Some(now);
        let mut rand_index = self
            .random_number_generator
            .gen_range(0..K_PACING_GAIN.len() as u8 - 1);
        if rand_index >= 1 {
            rand_index += 1;
        }
        self.current_cycle_offset = rand_index;
        self.pacing_gain = K_PACING_GAIN[rand_index as usize];
    }

    /// Advances the loss-recovery state machine for this batch of acks.
    ///
    /// Any loss moves the recovery end point up to the newest sent packet. The first loss
    /// enters `Conservation`. A round start while conserving moves to `Growth`. Recovery
    /// ends once a packet past the end point is acked with no new losses.
    fn update_recovery_state(&mut self, is_round_start: bool) {
        let has_losses = self.loss_state.has_losses();
        if has_losses {
            self.end_recovery_at_packet_number = self.max_sent_packet_number;
        }
        match self.recovery_state {
            RecoveryState::NotInRecovery if has_losses => {
                self.recovery_state = RecoveryState::Conservation;
                // 0 makes `calculate_recovery_window` initialise the window from in-flight.
                self.recovery_window = 0;
                // Conservation should last a full round, so restart the round now.
                self.current_round_trip_end_packet_number = self.max_sent_packet_number;
            }
            RecoveryState::NotInRecovery => {}
            RecoveryState::Growth | RecoveryState::Conservation => {
                if self.recovery_state == RecoveryState::Conservation && is_round_start {
                    self.recovery_state = RecoveryState::Growth;
                }
                if !has_losses && self.max_acked_packet_number > self.end_recovery_at_packet_number
                {
                    self.recovery_state = RecoveryState::NotInRecovery;
                }
            }
        }
    }

    /// Moves to the next `ProbeBw` gain phase when the current one has run its course.
    ///
    /// A phase normally lasts one `min_rtt`. A >1.0 phase is held until in-flight has
    /// reached its target, unless losses occurred. A <1.0 phase ends early once in-flight
    /// has drained to the 1.0-gain target. With `DRAIN_TO_TARGET`, the drain gain is also
    /// kept past its phase while in-flight is still above that target.
    fn update_gain_cycle_phase(&mut self, now: Instant, in_flight: u64) {
        let mut should_advance_gain_cycling = self
            .last_cycle_start
            .map(|last_cycle_start| now.duration_since(last_cycle_start) > self.min_rtt)
            .unwrap_or(false);
        if self.pacing_gain > 1.0
            && !self.loss_state.has_losses()
            && self.prev_in_flight_count < self.get_target_cwnd(self.pacing_gain)
        {
            should_advance_gain_cycling = false;
        }
        if self.pacing_gain < 1.0 && in_flight <= self.get_target_cwnd(1.0) {
            should_advance_gain_cycling = true;
        }
        if should_advance_gain_cycling {
            self.current_cycle_offset = (self.current_cycle_offset + 1) % K_PACING_GAIN.len() as u8;
            self.last_cycle_start = Some(now);
            if DRAIN_TO_TARGET
                && self.pacing_gain < 1.0
                && (K_PACING_GAIN[self.current_cycle_offset as usize] - 1.0).abs() < f32::EPSILON
                && in_flight > self.get_target_cwnd(1.0)
            {
                return;
            }
            self.pacing_gain = K_PACING_GAIN[self.current_cycle_offset as usize];
        }
    }

    /// Leaves `Startup` for `Drain` once full bandwidth is reached. Leaves `Drain` for
    /// `ProbeBw` once in-flight fits the 1.0-gain target window.
    fn maybe_exit_startup_or_drain(&mut self, now: Instant, in_flight: u64) {
        if self.mode == Mode::Startup && self.is_at_full_bandwidth {
            self.mode = Mode::Drain;
            self.pacing_gain = self.drain_gain;
            self.cwnd_gain = self.high_cwnd_gain;
        }
        if self.mode == Mode::Drain && in_flight <= self.get_target_cwnd(1.0) {
            self.enter_probe_bandwidth_mode(now);
        }
    }

    /// Returns `true` when the min RTT should be refreshed.
    ///
    /// This is the case when the sample is not app-limited and either ProbeRtt has never
    /// started or it last started more than 10 s ago.
    fn is_min_rtt_expired(&self, now: Instant, app_limited: bool) -> bool {
        !app_limited
            && self
                .probe_rtt_last_started_at
                .map(|last| now.saturating_duration_since(last) > Duration::from_secs(10))
                .unwrap_or(true)
    }

    /// Enters `ProbeRtt` when the min RTT has expired, and leaves it when done.
    ///
    /// While in ProbeRtt, an exit deadline 200 ms out is armed once in-flight is below the
    /// ProbeRtt window plus one MTU. The mode is left at the first round start after that
    /// deadline, returning to `ProbeBw` or `Startup` depending on `is_at_full_bandwidth`.
    fn maybe_enter_or_exit_probe_rtt(
        &mut self,
        now: Instant,
        is_round_start: bool,
        bytes_in_flight: u64,
        app_limited: bool,
    ) {
        let min_rtt_expired = self.is_min_rtt_expired(now, app_limited);
        if min_rtt_expired && !self.exiting_quiescence && self.mode != Mode::ProbeRtt {
            self.mode = Mode::ProbeRtt;
            self.pacing_gain = 1.0;
            self.exit_probe_rtt_at = None;
            self.probe_rtt_last_started_at = Some(now);
        }
        if self.mode == Mode::ProbeRtt {
            let exit_probe_rtt_at = self.exit_probe_rtt_at;
            match exit_probe_rtt_at {
                None => {
                    if bytes_in_flight < self.get_probe_rtt_cwnd() + self.current_mtu {
                        const K_PROBE_RTT_TIME: Duration = Duration::from_millis(200);
                        self.exit_probe_rtt_at = Some(now + K_PROBE_RTT_TIME);
                    }
                }
                Some(exit_at) => {
                    if is_round_start && now >= exit_at {
                        if self.is_at_full_bandwidth {
                            self.enter_probe_bandwidth_mode(now);
                        } else {
                            self.enter_startup_mode();
                        }
                    }
                }
            }
        }
        self.exiting_quiescence = false;
    }

    /// Returns `gain * bandwidth * min_rtt` in bytes, clamped below by `min_cwnd`.
    ///
    /// Falls back to `init_cwnd` when the product is zero (no bandwidth or RTT sample yet).
    fn get_target_cwnd(&self, gain: f32) -> u64 {
        let bw = self.max_bandwidth.get_estimate();
        // Multiply in f64: each factor is exact in f64, so the product rounds exactly as an
        // integer product converted to f64 would, but it cannot overflow u64.
        let bdp = self.min_rtt.as_micros() as f64 * bw as f64;
        // min_rtt is in microseconds, so divide by 1e6 to get bytes.
        let cwnd = (f64::from(gain) * bdp / 1_000_000f64) as u64;
        if cwnd == 0 {
            return self.init_cwnd;
        }
        cwnd.max(self.min_cwnd)
    }

    /// Window used while in `ProbeRtt`: 0.75x the target window when
    /// `PROBE_RTT_BASED_ON_BDP` is set, otherwise `min_cwnd`.
    fn get_probe_rtt_cwnd(&self) -> u64 {
        const K_MODERATE_PROBE_RTT_MULTIPLIER: f32 = 0.75;
        if PROBE_RTT_BASED_ON_BDP {
            return self.get_target_cwnd(K_MODERATE_PROBE_RTT_MULTIPLIER);
        }
        self.min_cwnd
    }

    /// Updates `pacing_rate` from the bandwidth estimate and `pacing_gain`.
    ///
    /// At full bandwidth the rate tracks the target exactly. Before that, the first rate
    /// is `init_cwnd / min_rtt`, and later the rate never decreases.
    fn calculate_pacing_rate(&mut self) {
        let bw = self.max_bandwidth.get_estimate();
        if bw == 0 {
            return;
        }
        let target_rate = (bw as f64 * self.pacing_gain as f64) as u64;
        if self.is_at_full_bandwidth {
            self.pacing_rate = target_rate;
            return;
        }
        if self.pacing_rate == 0 && self.min_rtt.as_nanos() != 0 {
            self.pacing_rate = BandwidthEstimation::bw_from_delta(self.init_cwnd, self.min_rtt)
                .expect("min_rtt is non-zero");
            return;
        }
        if self.pacing_rate < target_rate {
            self.pacing_rate = target_rate;
        }
    }

    /// Grows `cwnd` towards the target window by at most `bytes_acked` per call.
    ///
    /// Does nothing in `ProbeRtt`. At full bandwidth, `cwnd` is capped at the target plus
    /// the recent maximum ack height. Before full bandwidth, it keeps growing while below
    /// the target (plus `excess_acked`) or while fewer than `init_cwnd` bytes have been
    /// acknowledged. `cwnd` never drops below `min_cwnd`.
    fn calculate_cwnd(&mut self, bytes_acked: u64, excess_acked: u64) {
        if self.mode == Mode::ProbeRtt {
            return;
        }
        let mut target_window = self.get_target_cwnd(self.cwnd_gain);
        target_window += if self.is_at_full_bandwidth {
            self.ack_aggregation.max_ack_height.get()
        } else {
            excess_acked
        };
        if self.is_at_full_bandwidth {
            self.cwnd = target_window.min(self.cwnd + bytes_acked);
        } else if self.cwnd < target_window || self.acked_bytes < self.init_cwnd {
            // Compare the window (bytes) with the target (bytes). Comparing the
            // dimensionless gain here would make this branch unconditional.
            self.cwnd += bytes_acked;
        }
        self.cwnd = self.cwnd.max(self.min_cwnd);
    }

    /// Updates `recovery_window` while in recovery.
    ///
    /// On entry (window == 0) it starts at in-flight plus newly acked bytes. Afterwards it
    /// shrinks by lost bytes (falling back to one MTU on underflow) and grows by acked
    /// bytes in `Growth`. It is always at least in-flight + acked and at least `min_cwnd`.
    fn calculate_recovery_window(&mut self, bytes_acked: u64, bytes_lost: u64, in_flight: u64) {
        if !self.recovery_state.in_recovery() {
            return;
        }
        if self.recovery_window == 0 {
            self.recovery_window = self.min_cwnd.max(in_flight + bytes_acked);
            return;
        }
        self.recovery_window = if self.recovery_window >= bytes_lost {
            self.recovery_window - bytes_lost
        } else {
            self.current_mtu
        };
        if self.recovery_state == RecoveryState::Growth {
            self.recovery_window += bytes_acked;
        }
        self.recovery_window = self
            .recovery_window
            .max(in_flight + bytes_acked)
            .max(self.min_cwnd);
    }

    /// Called once per round in Startup. Declares full bandwidth when the estimate fails
    /// to grow by `K_STARTUP_GROWTH_TARGET` for enough consecutive rounds, or immediately
    /// when in recovery. App-limited rounds are ignored.
    fn check_if_full_bw_reached(&mut self, app_limited: bool) {
        if app_limited {
            return;
        }
        let target = (self.bw_at_last_round as f64 * K_STARTUP_GROWTH_TARGET as f64) as u64;
        let bw = self.max_bandwidth.get_estimate();
        if bw >= target {
            self.bw_at_last_round = bw;
            self.round_wo_bw_gain = 0;
            self.ack_aggregation.max_ack_height.reset();
            return;
        }
        self.round_wo_bw_gain += 1;
        if self.round_wo_bw_gain >= u64::from(K_ROUND_TRIPS_WITHOUT_GROWTH_BEFORE_EXITING_STARTUP)
            || self.recovery_state.in_recovery()
        {
            self.is_at_full_bandwidth = true;
        }
    }
}
impl Controller for Bbr {
    /// Records a sent packet: its number becomes the newest sent packet number, and its
    /// size is reported to the bandwidth estimator.
    fn on_sent(&mut self, now: Instant, bytes: u64, last_packet_number: u64) {
        self.max_sent_packet_number = last_packet_number;
        self.max_bandwidth.on_sent(now, bytes);
    }

    /// Records one acknowledged packet.
    ///
    /// Feeds the bandwidth estimator, adds to the lifetime acked-byte count, and refreshes
    /// `min_rtt` from `rtt.min()` in three cases: no min RTT has been recorded yet, the
    /// current one has expired, or the new value is lower.
    fn on_ack(
        &mut self,
        now: Instant,
        sent: Instant,
        bytes: u64,
        app_limited: bool,
        rtt: &RttEstimator,
    ) {
        self.max_bandwidth
            .on_ack(now, sent, bytes, self.round_count, app_limited);
        self.acked_bytes += bytes;
        let sample_min_rtt = rtt.min();
        // `min_rtt` starts at zero, and zero is never greater than a sample. Without this
        // check, an app-limited flow (for which the expiry test is always false) would never
        // record a min RTT at all.
        let no_min_rtt_yet = self.min_rtt.as_nanos() == 0;
        if no_min_rtt_yet
            || self.is_min_rtt_expired(now, app_limited)
            || self.min_rtt > sample_min_rtt
        {
            self.min_rtt = sample_min_rtt;
        }
    }

    /// Runs the per-batch BBR update after a group of acks has been processed.
    ///
    /// Steps, in order: ack-aggregation accounting, round tracking, recovery state, gain
    /// cycling, full-bandwidth detection, mode transitions, then pacing rate, cwnd and
    /// recovery window. Finally, clears the accumulated losses.
    fn on_end_acks(
        &mut self,
        now: Instant,
        in_flight: u64,
        app_limited: bool,
        largest_packet_num_acked: Option<u64>,
    ) {
        let bytes_acked = self.max_bandwidth.bytes_acked_this_window();
        let excess_acked = self.ack_aggregation.update_ack_aggregation_bytes(
            bytes_acked,
            now,
            self.round_count,
            self.max_bandwidth.get_estimate(),
        );
        self.max_bandwidth.end_acks(self.round_count, app_limited);
        if let Some(largest_acked_packet) = largest_packet_num_acked {
            self.max_acked_packet_number = largest_acked_packet;
        }
        // A round ends when a packet sent after the previous round's end is acknowledged.
        let mut is_round_start = false;
        if bytes_acked > 0 {
            is_round_start =
                self.max_acked_packet_number > self.current_round_trip_end_packet_number;
            if is_round_start {
                self.current_round_trip_end_packet_number = self.max_sent_packet_number;
                self.round_count += 1;
            }
        }
        self.update_recovery_state(is_round_start);
        if self.mode == Mode::ProbeBw {
            self.update_gain_cycle_phase(now, in_flight);
        }
        if is_round_start && !self.is_at_full_bandwidth {
            self.check_if_full_bw_reached(app_limited);
        }
        self.maybe_exit_startup_or_drain(now, in_flight);
        self.maybe_enter_or_exit_probe_rtt(now, is_round_start, in_flight, app_limited);
        self.calculate_pacing_rate();
        self.calculate_cwnd(bytes_acked, excess_acked);
        self.calculate_recovery_window(bytes_acked, self.loss_state.lost_bytes, in_flight);
        self.prev_in_flight_count = in_flight;
        self.loss_state.reset();
    }

    /// Accumulates lost bytes; they are consumed and cleared by the next `on_end_acks`.
    fn on_congestion_event(
        &mut self,
        _now: Instant,
        _sent: Instant,
        _is_persistent_congestion: bool,
        lost_bytes: u64,
    ) {
        self.loss_state.lost_bytes += lost_bytes;
    }

    /// Applies a new MTU: recomputes the minimum window and raises the initial and current
    /// windows to it if needed.
    fn on_mtu_update(&mut self, new_mtu: u16) {
        self.current_mtu = new_mtu as u64;
        self.min_cwnd = calculate_min_window(self.current_mtu);
        self.init_cwnd = self.config.initial_window.max(self.min_cwnd);
        self.cwnd = self.cwnd.max(self.min_cwnd);
    }

    /// Current congestion window.
    ///
    /// In ProbeRtt this is the reduced ProbeRtt window. Outside Startup, while recovering,
    /// it is capped by the recovery window. Otherwise it is `cwnd`.
    fn window(&self) -> u64 {
        if self.mode == Mode::ProbeRtt {
            self.get_probe_rtt_cwnd()
        } else if self.recovery_state.in_recovery() && self.mode != Mode::Startup {
            self.cwnd.min(self.recovery_window)
        } else {
            self.cwnd
        }
    }

    fn clone_box(&self) -> Box<dyn Controller> {
        Box::new(self.clone())
    }

    /// The configured initial window, unclamped.
    fn initial_window(&self) -> u64 {
        self.config.initial_window
    }

    fn into_any(self: Box<Self>) -> Box<dyn Any> {
        self
    }
}
/// Configuration for the BBR congestion controller.
///
/// Shared via `Arc` between the `ControllerFactory` and every `Bbr` it builds.
#[derive(Debug, Clone)]
pub struct BbrConfig {
/// Initial congestion window in bytes; the default is
/// `K_MAX_INITIAL_CONGESTION_WINDOW * BASE_DATAGRAM_SIZE`.
initial_window: u64,
}
impl BbrConfig {
    /// Sets the initial congestion window, in bytes, for controllers built from this
    /// configuration.
    ///
    /// Returns the configuration so setters can be chained.
    pub fn initial_window(&mut self, value: u64) -> &mut Self {
        let config = self;
        config.initial_window = value;
        config
    }
}
impl Default for BbrConfig {
    /// Default configuration: an initial window of `K_MAX_INITIAL_CONGESTION_WINDOW`
    /// datagrams of `BASE_DATAGRAM_SIZE` bytes each.
    fn default() -> Self {
        let initial_window = K_MAX_INITIAL_CONGESTION_WINDOW * BASE_DATAGRAM_SIZE;
        Self { initial_window }
    }
}
impl ControllerFactory for Arc<BbrConfig> {
    /// Builds a new boxed `Bbr` controller that shares this configuration.
    fn build(&self, _now: Instant, current_mtu: u16) -> Box<dyn Controller> {
        let shared_config = Arc::clone(self);
        let controller = Bbr::new(shared_config, current_mtu);
        Box::new(controller)
    }
}
/// Tracks ACK aggregation: bytes acknowledged in excess of what the bandwidth estimate
/// predicts for the elapsed time.
#[derive(Debug, Default, Copy, Clone)]
struct AckAggregationState {
/// Maximum excess observed, recorded per round via `update_max`. Added to the target
/// window once at full bandwidth; reset when Startup sees bandwidth growth.
max_ack_height: MinMax,
/// Start of the current aggregation epoch; `None` before the first update.
aggregation_epoch_start_time: Option<Instant>,
/// Bytes acknowledged during the current epoch.
aggregation_epoch_bytes: u64,
}
impl AckAggregationState {
    /// Adds `newly_acked_bytes` (acknowledged at `now`) to the aggregation epoch and
    /// returns how many bytes the epoch has delivered beyond what `max_bandwidth`
    /// predicts for its elapsed time.
    ///
    /// `max_bandwidth` is scaled by elapsed microseconds / 1_000_000, i.e. it is treated
    /// as a per-second rate. If the epoch has not delivered more than expected, a new epoch
    /// starts at `now`, seeded with `newly_acked_bytes`, and 0 is returned. Otherwise the
    /// excess is also recorded in `max_ack_height` under `round`.
    fn update_ack_aggregation_bytes(
        &mut self,
        newly_acked_bytes: u64,
        now: Instant,
        round: u64,
        max_bandwidth: u64,
    ) -> u64 {
        let epoch_start = self.aggregation_epoch_start_time.unwrap_or(now);
        let elapsed_micros = now.saturating_duration_since(epoch_start).as_micros();
        // Multiply in u128 (saturating) so a high bandwidth estimate combined with an old
        // epoch start (e.g. after an idle period) cannot overflow, then clamp into u64.
        let expected_bytes_acked = (u128::from(max_bandwidth).saturating_mul(elapsed_micros)
            / 1_000_000)
            .min(u128::from(u64::MAX)) as u64;
        if self.aggregation_epoch_bytes <= expected_bytes_acked {
            // Acks are arriving no faster than the estimate: start a fresh epoch.
            self.aggregation_epoch_bytes = newly_acked_bytes;
            self.aggregation_epoch_start_time = Some(now);
            return 0;
        }
        // Include the newest acks in the epoch total before measuring the excess.
        self.aggregation_epoch_bytes += newly_acked_bytes;
        let excess = self.aggregation_epoch_bytes - expected_bytes_acked;
        self.max_ack_height.update_max(round, excess);
        excess
    }
}
/// Operating mode of the BBR state machine.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum Mode {
/// Initial mode: pacing and cwnd gains are `high_gain` / `high_cwnd_gain`.
Startup,
/// Entered from `Startup` once full bandwidth is detected. Paces at `drain_gain` until
/// bytes in flight fall to the 1.0-gain target window, then switches to `ProbeBw`.
Drain,
/// Steady state: the pacing gain cycles through `K_PACING_GAIN`.
ProbeBw,
/// Entered when the min RTT is considered expired. `window()` returns the ProbeRtt window,
/// and `calculate_cwnd` leaves `cwnd` untouched.
ProbeRtt,
}
/// Loss-recovery phase.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum RecoveryState {
/// Not recovering; the window is not capped by `recovery_window`.
NotInRecovery,
/// Entered on the first loss; the recovery window is only reduced by lost bytes.
Conservation,
/// Entered from `Conservation` at the next round start; the recovery window also grows by
/// acknowledged bytes.
Growth,
}
impl RecoveryState {
    /// Returns `true` in either recovery phase (`Conservation` or `Growth`).
    pub(super) fn in_recovery(&self) -> bool {
        match self {
            Self::NotInRecovery => false,
            Self::Conservation | Self::Growth => true,
        }
    }
}
/// Losses reported since the last `on_end_acks` call, which resets this state.
#[derive(Debug, Clone, Default)]
struct LossState {
/// Bytes declared lost, accumulated by `on_congestion_event`.
lost_bytes: u64,
}
impl LossState {
    /// Clears the lost-byte counter, ready for the next ack batch.
    pub(super) fn reset(&mut self) {
        self.lost_bytes = u64::default();
    }

    /// Returns `true` if any bytes have been reported lost since the last reset.
    pub(super) fn has_losses(&self) -> bool {
        self.lost_bytes > 0
    }
}
/// Minimum congestion window for a path: room for four MTU-sized packets.
fn calculate_min_window(current_mtu: u64) -> u64 {
    const MIN_WINDOW_PACKETS: u64 = 4;
    MIN_WINDOW_PACKETS * current_mtu
}
// Startup pacing and cwnd gain (≈ 2/ln 2); its reciprocal is the Drain pacing gain.
const K_DEFAULT_HIGH_GAIN: f32 = 2.885;
// Cwnd gain applied when entering ProbeBw.
const K_DERIVED_HIGH_CWNDGAIN: f32 = 2.0;
// ProbeBw pacing-gain cycle: one 1.25 phase, one 0.75 phase, six 1.0 phases.
const K_PACING_GAIN: [f32; 8] = [1.25, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0];
// Factor by which the bandwidth estimate must grow in a Startup round to count as growth.
const K_STARTUP_GROWTH_TARGET: f32 = 1.25;
// Consecutive Startup rounds without growth before full bandwidth is declared.
const K_ROUND_TRIPS_WITHOUT_GROWTH_BEFORE_EXITING_STARTUP: u8 = 3;
// Default initial window, in units of BASE_DATAGRAM_SIZE.
const K_MAX_INITIAL_CONGESTION_WINDOW: u64 = 200;
// When true, the ProbeRtt window is 0.75x the target window instead of min_cwnd.
const PROBE_RTT_BASED_ON_BDP: bool = true;
// When true, a sub-1.0 ProbeBw gain is kept past its phase while in-flight exceeds the
// 1.0-gain target window.
const DRAIN_TO_TARGET: bool = true;