From ecdafb45dd9c416cb0810d6687a20ac97e480ac9 Mon Sep 17 00:00:00 2001 From: eug-vs Date: Wed, 11 Dec 2024 02:12:22 +0100 Subject: refactor: allow for other solver implementations --- src/main.rs | 5 +- src/midpoint.rs | 174 ------------------------------------------------- src/particle_system.rs | 4 -- src/solver/midpoint.rs | 18 +++++ src/solver/mod.rs | 164 ++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 185 insertions(+), 180 deletions(-) delete mode 100644 src/midpoint.rs create mode 100644 src/solver/midpoint.rs create mode 100644 src/solver/mod.rs (limited to 'src') diff --git a/src/main.rs b/src/main.rs index 8984bf6..0a51789 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,8 @@ -use particle_system::{Particle, ParticleSystem, Point, Solver, Vector}; +use particle_system::{Particle, ParticleSystem, Point, Vector}; +use solver::Solver; -mod midpoint; mod particle_system; +mod solver; fn main() { let dt = 0.01; diff --git a/src/midpoint.rs b/src/midpoint.rs deleted file mode 100644 index 1aa1de8..0000000 --- a/src/midpoint.rs +++ /dev/null @@ -1,174 +0,0 @@ -use crate::particle_system::{ParticleSystem, Point, Scalar, Solver, Vector, N}; -use nalgebra::{Const, DVector, Dyn, Matrix, ViewStorage}; - -/// A vector of concatenated position and velocity components of each particle -#[derive(Debug, Clone)] -pub struct PhaseSpace(DVector); -type ParticleView<'a> = Matrix< - f32, - Const<{ PhaseSpace::PARTICLE_DIM }>, - Const<1>, - ViewStorage<'a, f32, Const<{ PhaseSpace::PARTICLE_DIM }>, Const<1>, Const<1>, Dyn>, ->; - -impl PhaseSpace { - /// Each particle spans 2N elements in a vector - /// first N for position, then N more for velocity - const PARTICLE_DIM: usize = N * 2; - - pub fn new(particle_count: usize) -> Self { - let dimension = particle_count * PhaseSpace::PARTICLE_DIM; - Self(DVector::::zeros(dimension)) - } - - pub fn particle_view(&self, i: usize) -> ParticleView { - self.0 - .fixed_rows::<{ PhaseSpace::PARTICLE_DIM }>(i * PhaseSpace::PARTICLE_DIM) - } - - pub fn set_particle(&mut self, i: usize, position: Point, velocity: Vector) { - let mut view = self - .0 - .fixed_rows_mut::<{ PhaseSpace::PARTICLE_DIM }>(i * PhaseSpace::PARTICLE_DIM); - for i in 0..N { - view[i] = position[i]; - view[i + N] = velocity[i]; - } - } -} - -impl ParticleSystem { - fn collect_phase_space(&self) -> PhaseSpace { - let mut phase_space = PhaseSpace::new(self.particles.len()); - for (particle_index, particle) in self.particles.iter().enumerate() { - phase_space.set_particle(particle_index, particle.position, particle.velocity); - } - phase_space - } - - fn compute_derivative(&self) -> PhaseSpace { - let mut phase_space = PhaseSpace::new(self.particles.len()); - for (particle_index, particle) in self.particles.iter().enumerate() { - phase_space.set_particle( - particle_index, - particle.velocity.into(), - particle.force / particle.mass, - ); - } - phase_space - } - - fn scatter_phase_space(&mut self, phase_space: &PhaseSpace) { - for (particle_index, particle) in &mut self.particles.iter_mut().enumerate() { - let view = phase_space.particle_view(particle_index); - - for i in 0..N { - particle.position[i] = view[i]; - particle.velocity[i] = view[i + N]; - } - } - } -} - -impl Solver for ParticleSystem { - fn step(&mut self, dt: Scalar) { - let mut state = self.collect_phase_space(); - - // Shift to the midpoint - self.scatter_phase_space(&PhaseSpace { - 0: state.0.clone() + self.compute_derivative().0 * dt / 2.0, - }); - - state.0 += self.compute_derivative().0 * dt; - self.scatter_phase_space(&state); - - self.t += dt; - } -} - -#[cfg(test)] -mod tests { - use super::{ParticleSystem, PhaseSpace, Point, Scalar, Vector}; - use crate::particle_system::{Particle, Solver}; - - #[test] - fn test_collect_phase_space() { - let system = ParticleSystem { - particles: vec![Particle::new(Point::new(2.0, 3.0), 1.0)], - t: 0.0, - }; - let phase_space = system.collect_phase_space(); - - assert!( - !phase_space.0.is_empty(), - "Phase space has to contain non-zero values" - ); - } - - #[test] - fn test_scatter_phase_space() { - let mut phase_space = PhaseSpace::new(2); - phase_space.set_particle(1, Point::new(5.0, 7.0), Vector::x()); - - let mut system = ParticleSystem { - particles: vec![ - Particle::new(Point::new(0.0, 0.0), 1.0), - Particle::new(Point::new(0.0, 0.0), 1.0), - ], - t: 0.0, - }; - - system.scatter_phase_space(&phase_space); - - assert!( - !system.particles[1].velocity.is_empty(), - "Velocity has to be set" - ); - assert!( - !system.particles[1].position.is_empty(), - "Position has to be set" - ); - } - - fn simulate_falling_ball(fall_time: Scalar, dt: Scalar) -> (Vector, Vector) { - let gravity = -9.8 * Vector::y(); - - let mut system = ParticleSystem { - particles: vec![Particle::new(Point::origin(), 1.0)], - t: 0.0, - }; - - let iterations = (fall_time / dt) as usize; - - for _ in 0..iterations { - for particle in &mut system.particles { - particle.reset_force(); - particle.apply_force(gravity); - } - system.step(dt); - } - - let expected_velocity = gravity * fall_time; // vt - let expected_position = gravity * fall_time * fall_time / 2.0; // at^2 / 2 - - ( - system.particles[0].position.coords - expected_position, - system.particles[0].velocity - expected_velocity, - ) - } - - #[test] - fn ball_should_fall() { - let (position_error, velocity_error) = simulate_falling_ball(10.0, 0.01); - assert!( - position_error.norm() < 0.01, - "Position error is too high: {}", - position_error, - ); - assert!( - velocity_error.norm() < 0.01, - "Velocity error is too high: {}", - velocity_error, - ); - } -} diff --git a/src/particle_system.rs b/src/particle_system.rs index 52e6284..6f43c35 100644 --- a/src/particle_system.rs +++ b/src/particle_system.rs @@ -41,7 +41,3 @@ pub struct ParticleSystem { /// Simulation clock pub t: Scalar, } - -pub trait Solver { - fn step(&mut self, dt: Scalar); -} diff --git a/src/solver/midpoint.rs b/src/solver/midpoint.rs new file mode 100644 index 0000000..c5a01c4 --- /dev/null +++ b/src/solver/midpoint.rs @@ -0,0 +1,18 @@ +use crate::particle_system::{ParticleSystem, Scalar}; +use super::{PhaseSpace, Solver}; + +impl Solver for ParticleSystem { + fn step(&mut self, dt: Scalar) { + let mut state = self.collect_phase_space(); + + // Shift to the midpoint + self.scatter_phase_space(&&PhaseSpace { + 0: state.0.clone() + self.compute_derivative().0 * dt / 2.0, + }); + + state.0 += self.compute_derivative().0 * dt; + self.scatter_phase_space(&state); + + self.t += dt; + } +} diff --git a/src/solver/mod.rs b/src/solver/mod.rs new file mode 100644 index 0000000..4a5fec5 --- /dev/null +++ b/src/solver/mod.rs @@ -0,0 +1,164 @@ +use crate::particle_system::{ParticleSystem, Point, Scalar, Vector, N}; +use nalgebra::{Const, DVector, Dyn, Matrix, ViewStorage}; + +mod midpoint; + +/// A vector of concatenated position and velocity components of each particle +#[derive(Debug, Clone)] +pub struct PhaseSpace(DVector); +type ParticleView<'a> = Matrix< + f32, + Const<{ PhaseSpace::PARTICLE_DIM }>, + Const<1>, + ViewStorage<'a, f32, Const<{ PhaseSpace::PARTICLE_DIM }>, Const<1>, Const<1>, Dyn>, +>; + +impl PhaseSpace { + /// Each particle spans 2N elements in a vector + /// first N for position, then N more for velocity + const PARTICLE_DIM: usize = N * 2; + + pub fn new(particle_count: usize) -> Self { + let dimension = particle_count * PhaseSpace::PARTICLE_DIM; + Self(DVector::::zeros(dimension)) + } + + pub fn particle_view(&self, i: usize) -> ParticleView { + self.0 + .fixed_rows::<{ PhaseSpace::PARTICLE_DIM }>(i * PhaseSpace::PARTICLE_DIM) + } + + pub fn set_particle(&mut self, i: usize, position: Point, velocity: Vector) { + let mut view = self + .0 + .fixed_rows_mut::<{ PhaseSpace::PARTICLE_DIM }>(i * PhaseSpace::PARTICLE_DIM); + for i in 0..N { + view[i] = position[i]; + view[i + N] = velocity[i]; + } + } +} + +impl ParticleSystem { + fn collect_phase_space(&self) -> PhaseSpace { + let mut phase_space = PhaseSpace::new(self.particles.len()); + for (particle_index, particle) in self.particles.iter().enumerate() { + phase_space.set_particle(particle_index, particle.position, particle.velocity); + } + phase_space + } + + fn compute_derivative(&self) -> PhaseSpace { + let mut phase_space = PhaseSpace::new(self.particles.len()); + for (particle_index, particle) in self.particles.iter().enumerate() { + phase_space.set_particle( + particle_index, + particle.velocity.into(), + particle.force / particle.mass, + ); + } + phase_space + } + + fn scatter_phase_space(&mut self, phase_space: &PhaseSpace) { + for (particle_index, particle) in &mut self.particles.iter_mut().enumerate() { + let view = phase_space.particle_view(particle_index); + + for i in 0..N { + particle.position[i] = view[i]; + particle.velocity[i] = view[i + N]; + } + } + } +} + +pub trait Solver { + fn step(&mut self, dt: Scalar); +} + +#[cfg(test)] +mod tests { + use super::{ParticleSystem, PhaseSpace, Point, Scalar, Solver, Vector}; + use crate::particle_system::Particle; + + #[test] + fn test_collect_phase_space() { + let system = ParticleSystem { + particles: vec![Particle::new(Point::new(2.0, 3.0), 1.0)], + t: 0.0, + }; + let phase_space = system.collect_phase_space(); + + assert!( + !phase_space.0.is_empty(), + "Phase space has to contain non-zero values" + ); + } + + #[test] + fn test_scatter_phase_space() { + let mut phase_space = PhaseSpace::new(2); + phase_space.set_particle(1, Point::new(5.0, 7.0), Vector::x()); + + let mut system = ParticleSystem { + particles: vec![ + Particle::new(Point::new(0.0, 0.0), 1.0), + Particle::new(Point::new(0.0, 0.0), 1.0), + ], + t: 0.0, + }; + + system.scatter_phase_space(&phase_space); + + assert!( + !system.particles[1].velocity.is_empty(), + "Velocity has to be set" + ); + assert!( + !system.particles[1].position.is_empty(), + "Position has to be set" + ); + } + + fn simulate_falling_ball(fall_time: Scalar, dt: Scalar) -> (Vector, Vector) { + let gravity = -9.8 * Vector::y(); + + let mut system = ParticleSystem { + particles: vec![Particle::new(Point::origin(), 1.0)], + t: 0.0, + }; + + let iterations = (fall_time / dt) as usize; + + for _ in 0..iterations { + for particle in &mut system.particles { + particle.reset_force(); + particle.apply_force(gravity); + } + system.step(dt); + } + + let expected_velocity = gravity * fall_time; // vt + let expected_position = gravity * fall_time * fall_time / 2.0; // at^2 / 2 + + ( + system.particles[0].position.coords - expected_position, + system.particles[0].velocity - expected_velocity, + ) + } + + #[test] + fn ball_should_fall() { + let (position_error, velocity_error) = simulate_falling_ball(10.0, 0.01); + assert!( + position_error.norm() < 0.01, + "Position error is too high: {}", + position_error, + ); + assert!( + velocity_error.norm() < 0.01, + "Velocity error is too high: {}", + velocity_error, + ); + } +} -- cgit v1.2.3