use std::collections::HashSet;

use bevy::math::DVec2;
use bevy::time::{Fixed, Time};
use bevy::transform::TransformSystems;
use bevy_replicon::client::confirm_history::ConfirmHistory;
use bevy_replicon::prelude::Replicated;
use bevy_replicon::shared::replicon_tick::RepliconTick;

use crate::prelude::*;
use crate::shared::config::planet::{Planet, PlanetSpring};
use crate::shared::world_config::WorldConfigResource;
/// Minimum distance (in `Position` units) between the extrapolated server position and the
/// local prediction before a correction is scheduled; smaller discrepancies are ignored.
const SKIP_THRESHOLD: f64 = 10.0;
/// Locally predicted physics state of an entity.
///
/// Captured before replicon applies incoming server data, and written back into the
/// physics components afterwards so the local simulation continues from its own prediction.
#[derive(Default, Component)]
struct LocalState {
    position: DVec2,
    rotation_cos: f64,
    rotation_sin: f64,
    linvel: DVec2,
    angvel: f64,
}
/// Server-authoritative state for an entity, plus bookkeeping for the reconciliation pass.
#[derive(Default, Component)]
struct CorrectionTarget {
    /// Server tick (from `ConfirmHistory`) most recently examined for this entity.
    last_tick: Option<RepliconTick>,
    /// Set when the extrapolated server state diverged from the local prediction by at
    /// least `SKIP_THRESHOLD`; cleared once the resimulation pass has applied it.
    needs_correction: bool,
    /// Server state to rewind to when `needs_correction` is set.
    position: DVec2,
    rotation_cos: f64,
    rotation_sin: f64,
    linvel: DVec2,
    angvel: f64,
}
/// Number of fixed steps this client has simulated; incremented once per `FixedFirst` run
/// by `tick_local_clock`. Resimulation runs only `FixedUpdate`, so replayed steps do not
/// advance it.
#[derive(Resource, Default)]
struct LocalTick(u64);
/// Smoothed estimate of how many local ticks separate consecutive server updates.
///
/// The rounded `horizon` is used as the number of steps to extrapolate server state and to
/// resimulate corrected entities.
#[derive(Resource)]
struct LagEstimate {
    /// Exponential moving average of the tick gap, in local ticks.
    horizon: f64,
    /// `LocalTick` value at which the last sample was taken.
    last_local_tick: u64,
}

impl Default for LagEstimate {
    fn default() -> Self {
        // Start with a guess of four ticks until real samples arrive.
        let horizon = 4.0;
        Self {
            last_local_tick: 0,
            horizon,
        }
    }
}
/// `(position, mass)` of every planet, gathered each frame before replicon receives server
/// data; used by `record_server_correction` to extrapolate server state under gravity.
#[derive(Resource, Default)]
struct PlanetPositions(Vec<(DVec2, f64)>);
/// Indicates whether `FixedUpdate` is currently being replayed for a server correction.
///
/// Public so future thrust-prediction systems can replay buffered inputs at the correct tick.
#[derive(Resource, Default, Debug, Clone, Copy)]
pub struct Resimulating {
    /// `true` only while resimulation steps are running; `false` during normal fixed steps.
    pub active: bool,
    /// The absolute `LocalTick` value corresponding to the current resimulation step.
    pub resim_tick: u64,
}
/// Plain copy of one entity's physics state (position, rotation, velocities).
struct PhysicsSnapshot {
    position: DVec2,
    rotation_cos: f64,
    rotation_sin: f64,
    linvel: DVec2,
    angvel: f64,
}

impl PhysicsSnapshot {
    /// Builds a snapshot from the server state stored in a [`CorrectionTarget`].
    fn from_target(t: &CorrectionTarget) -> Self {
        Self {
            position: t.position,
            rotation_cos: t.rotation_cos,
            rotation_sin: t.rotation_sin,
            linvel: t.linvel,
            angvel: t.angvel,
        }
    }

    /// Builds a snapshot from an entity's current physics components.
    fn from_components(
        p: &Position,
        r: &Rotation,
        lv: &LinearVelocity,
        av: &AngularVelocity,
    ) -> Self {
        Self {
            position: p.0,
            rotation_cos: r.cos,
            rotation_sin: r.sin,
            linvel: lv.0,
            angvel: av.0,
        }
    }

    /// Writes this snapshot into `entity`'s physics components.
    ///
    /// Each component is written only if the entity has it. If the entity no longer exists,
    /// this does nothing instead of panicking.
    fn apply(&self, world: &mut World, entity: Entity) {
        let Ok(mut e) = world.get_entity_mut(entity) else {
            return;
        };
        if let Some(mut p) = e.get_mut::<Position>() {
            p.0 = self.position;
        }
        if let Some(mut r) = e.get_mut::<Rotation>() {
            r.cos = self.rotation_cos;
            r.sin = self.rotation_sin;
        }
        if let Some(mut lv) = e.get_mut::<LinearVelocity>() {
            lv.0 = self.linvel;
        }
        if let Some(mut av) = e.get_mut::<AngularVelocity>() {
            av.0 = self.angvel;
        }
    }
}
/// Registers client-side prediction: its resources, the observer that sets up predicted
/// entities, and the save / reconcile / resimulate systems.
pub fn prediction_plugin(app: &mut App) {
    app.init_resource::<LocalTick>()
        .init_resource::<LagEstimate>()
        .init_resource::<Resimulating>()
        .init_resource::<PlanetPositions>();

    app.add_observer(init_prediction);
    app.add_systems(FixedFirst, tick_local_clock);

    // Capture the predicted state and planet positions before replicon writes server data
    // into the physics components, then reconcile immediately after it has.
    app.add_systems(
        PreUpdate,
        (
            (save_local_state, collect_planet_positions)
                .before(bevy_replicon::client::ClientSystems::Receive),
            record_server_correction.after(bevy_replicon::client::ClientSystems::Receive),
        ),
    );

    // Replay physics before transform propagation so corrected poses are visible this frame.
    app.add_systems(
        PostUpdate,
        perform_resimulation.before(TransformSystems::Propagate),
    );
}
fn tick_local_clock(mut tick: ResMut<LocalTick>) {
tick.0 += 1;
}
/// Observer: when `Replicated` is added to an entity that carries physics state and is
/// neither a `Planet` nor a `PlanetSpring`, attach the prediction bookkeeping components.
///
/// `LocalState` is seeded from the entity's current components; a missing
/// `AngularVelocity` is treated as `0.0`.
fn init_prediction(
    trigger: On<Add, Replicated>,
    query: Query<
        (&Position, &Rotation, &LinearVelocity, Option<&AngularVelocity>),
        (Without<PlanetSpring>, Without<Planet>),
    >,
    mut commands: Commands,
) {
    let target = trigger.event_target();
    if let Ok((pos, rot, linvel, angvel)) = query.get(target) {
        let seeded = LocalState {
            position: pos.0,
            rotation_cos: rot.cos,
            rotation_sin: rot.sin,
            linvel: linvel.0,
            angvel: angvel.map_or(0.0, |a| a.0),
        };
        commands
            .entity(target)
            .insert((seeded, CorrectionTarget::default()));
    }
}
/// Refreshes [`PlanetPositions`] with the `(position, mass)` of every planet entity
/// (excluding planet springs), converting the mass to `f64`.
fn collect_planet_positions(
    mut resource: ResMut<PlanetPositions>,
    planets: Query<(&Position, &Mass), (With<Planet>, Without<PlanetSpring>)>,
) {
    let snapshot = &mut resource.0;
    snapshot.clear();
    snapshot.extend(
        planets
            .iter()
            .map(|(position, mass)| (position.0, mass.0 as f64)),
    );
}
fn save_local_state(
mut query: Query<(&Position, &Rotation, &LinearVelocity, &AngularVelocity, &mut LocalState)>,
) {
for (pos, rot, linvel, angvel, mut local) in &mut query {
if pos.0.is_nan() { continue; }
local.position = pos.0;
local.linvel = linvel.0;
local.angvel = angvel.0;
local.rotation_cos = rot.cos;
local.rotation_sin = rot.sin;
}
}
/// Reconciles freshly received server state with the local prediction.
///
/// Scheduled right after `ClientSystems::Receive`. At that point each predicted entity's
/// physics components hold the values just received from the server. For every entity
/// whose `ConfirmHistory` reports a server tick not yet examined:
///
/// * the server position/velocity is extrapolated `LagEstimate::horizon` (rounded) ticks
///   forward under planetary gravity (see `extrapolate`);
/// * if the result lies at least [`SKIP_THRESHOLD`] from the local prediction, the server
///   state is copied into `CorrectionTarget` and `needs_correction` is set; otherwise
///   `needs_correction` is cleared.
///
/// Afterwards, every entity's physics components are reset to the prediction saved in
/// `LocalState`. Without a world config, entities are only reset.
///
/// The lag estimate averages the local ticks between received server updates. It is sampled
/// at most once per call, because several entities confirming in the same frame belong to a
/// single update. Sampling per entity would record a gap of 0 (clamped to 1) for every
/// entity after the first, dragging the estimate toward 1.
fn record_server_correction(
    mut query: Query<(
        &mut Position,
        &mut Rotation,
        &mut LinearVelocity,
        &mut AngularVelocity,
        &ConfirmHistory,
        &mut CorrectionTarget,
        &LocalState,
    )>,
    planet_positions: Res<PlanetPositions>,
    world_config: Res<WorldConfigResource>,
    local_tick: Res<LocalTick>,
    mut lag: ResMut<LagEstimate>,
) {
    let Some(cfg) = &world_config.config else {
        for (mut pos, mut rot, mut linvel, mut angvel, _, _, local) in &mut query {
            restore_predicted_state(&mut pos, &mut rot, &mut linvel, &mut angvel, local);
        }
        return;
    };

    // One lag sample per frame in which any entity received a new server tick.
    let received_update = query
        .iter()
        .any(|(.., history, target, _)| target.last_tick != Some(history.last_tick()));
    if received_update {
        let elapsed = local_tick.0.saturating_sub(lag.last_local_tick).clamp(1, 20) as f64;
        lag.horizon = lag.horizon * 0.9 + elapsed * 0.1;
        lag.last_local_tick = local_tick.0;
    }

    let n = lag.horizon.round() as u64;
    let dt = 1.0 / crate::shared::plugins::TICK_RATE;

    for (mut pos, mut rot, mut linvel, mut angvel, history, mut target, local) in &mut query {
        let tick = history.last_tick();
        if target.last_tick != Some(tick) {
            target.last_tick = Some(tick);
            // The components currently hold the server's values.
            let server_pos = pos.0;
            let server_vel = linvel.0;
            let (ext_pos, _) = extrapolate(
                server_pos,
                server_vel,
                &planet_positions.0,
                n,
                dt,
                cfg.world.gravity,
            );
            let error = (ext_pos - local.position).length();
            if error < SKIP_THRESHOLD {
                target.needs_correction = false;
            } else {
                target.needs_correction = true;
                target.position = server_pos;
                target.linvel = server_vel;
                target.angvel = angvel.0;
                target.rotation_cos = rot.cos;
                target.rotation_sin = rot.sin;
            }
        }
        restore_predicted_state(&mut pos, &mut rot, &mut linvel, &mut angvel, local);
    }
}

/// Overwrites the live physics components with the locally predicted state in `local`.
fn restore_predicted_state(
    pos: &mut Position,
    rot: &mut Rotation,
    linvel: &mut LinearVelocity,
    angvel: &mut AngularVelocity,
    local: &LocalState,
) {
    pos.0 = local.position;
    rot.cos = local.rotation_cos;
    rot.sin = local.rotation_sin;
    linvel.0 = local.linvel;
    angvel.0 = local.angvel;
}
/// Predicts a body's state after `n` steps of length `dt` under planetary gravity.
///
/// Integrates with semi-implicit Euler: velocity is updated first, then position is
/// advanced with the new velocity. Each `(position, mass)` entry in `planets` pulls the
/// body toward it with acceleration `gravity * mass / max(distance², 1)`. The floor of 1
/// keeps the pull finite near a planet's centre. A body exactly on a planet's position
/// gets no pull from that planet, rather than a NaN direction.
///
/// Returns `(position, velocity)` after the last step; with `n == 0` the inputs are
/// returned unchanged.
fn extrapolate(
    mut pos: DVec2,
    mut vel: DVec2,
    planets: &[(DVec2, f64)],
    n: u64,
    dt: f64,
    gravity: f64,
) -> (DVec2, DVec2) {
    for _ in 0..n {
        let mut accel = DVec2::ZERO;
        for &(planet_pos, planet_mass) in planets {
            let diff = planet_pos - pos;
            let dist_sq = diff.length_squared().max(1.0);
            // `normalize` returns NaN for a zero vector, which would poison every later
            // step; a coincident planet simply contributes no acceleration.
            accel += diff.normalize_or_zero() * gravity * planet_mass / dist_sq;
        }
        vel += accel * dt;
        pos += vel * dt;
    }
    (pos, vel)
}
/// Rewinds corrected entities to their server state and replays fixed steps to catch up.
///
/// For every entity whose `CorrectionTarget::needs_correction` is set:
/// 1. the server snapshot is written into its physics components;
/// 2. `FixedUpdate` is run `LagEstimate::horizon` (rounded) times with
///    `Resimulating::active` set (see `replay_fixed_steps`);
/// 3. every other entity's physics components and `Transform` are restored to their
///    pre-replay values, so only corrected entities move;
/// 4. `needs_correction` is cleared.
///
/// The server state is treated as `n` ticks stale. Replayed steps are therefore labelled
/// with the past local ticks `base_tick - n + 1 ..= base_tick`, ending on the tick the local
/// simulation has already reached. Entities that no longer exist after the replay are
/// skipped during restoration.
fn perform_resimulation(world: &mut World) {
    let corrections: Vec<(Entity, PhysicsSnapshot)> = {
        let mut q = world.query::<(Entity, &CorrectionTarget)>();
        q.iter(world)
            .filter(|(_, t)| t.needs_correction)
            .map(|(e, t)| (e, PhysicsSnapshot::from_target(t)))
            .collect()
    };
    if corrections.is_empty() {
        return;
    }
    let saved: Vec<(Entity, PhysicsSnapshot)> = {
        let mut q = world.query::<(
            Entity,
            &Position,
            &Rotation,
            &LinearVelocity,
            &AngularVelocity,
        )>();
        q.iter(world)
            .map(|(e, p, r, lv, av)| (e, PhysicsSnapshot::from_components(p, r, lv, av)))
            .collect()
    };
    // Save Transforms separately: avian2d's Writeback overwrites Transform in each
    // FixedUpdate step, so non-corrected entities would visually jump if not restored.
    let saved_transforms: Vec<(Entity, Vec3, Quat)> = {
        let mut q = world.query::<(Entity, &Transform)>();
        q.iter(world)
            .map(|(e, t)| (e, t.translation, t.rotation))
            .collect()
    };
    let corrected_set: HashSet<Entity> = corrections.iter().map(|(e, _)| *e).collect();

    for (entity, snapshot) in &corrections {
        snapshot.apply(world, *entity);
    }

    let n = world.resource::<LagEstimate>().horizon.round() as u64;
    let base_tick = world.resource::<LocalTick>().0;
    let first_tick = base_tick.saturating_sub(n) + 1;
    replay_fixed_steps(world, first_tick, n);

    for (entity, snapshot) in &saved {
        if corrected_set.contains(entity) || world.get_entity(*entity).is_err() {
            continue;
        }
        snapshot.apply(world, *entity);
    }
    // Restore transforms for non-corrected entities so their visuals don't jump.
    for (entity, translation, rotation) in &saved_transforms {
        if corrected_set.contains(entity) {
            continue;
        }
        if let Ok(mut e) = world.get_entity_mut(*entity) {
            if let Some(mut t) = e.get_mut::<Transform>() {
                t.translation = *translation;
                t.rotation = *rotation;
            }
        }
    }
    for (entity, _) in &corrections {
        if let Ok(mut e) = world.get_entity_mut(*entity) {
            if let Some(mut t) = e.get_mut::<CorrectionTarget>() {
                t.needs_correction = false;
            }
        }
    }
}

/// Runs `FixedUpdate` `steps` times, labelling step `k` as tick `first_tick + k`.
///
/// This runs from `PostUpdate`, outside Bevy's fixed-step runner, where the generic `Time`
/// holds virtual (frame) time. It is temporarily swapped for fixed time, as Bevy's runner
/// does, so replayed systems see the fixed timestep, then restored. `Resimulating::active`
/// is `true` only for the duration of the replay.
fn replay_fixed_steps(world: &mut World, first_tick: u64, steps: u64) {
    let fixed_time = world.resource::<Time<Fixed>>().as_generic();
    let outer_time = std::mem::replace(&mut *world.resource_mut::<Time>(), fixed_time);

    world.resource_mut::<Resimulating>().active = true;
    for step in 0..steps {
        world.resource_mut::<Resimulating>().resim_tick = first_tick + step;
        world.run_schedule(FixedUpdate);
    }
    world.resource_mut::<Resimulating>().active = false;

    *world.resource_mut::<Time>() = outer_time;
}