summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoreug-vs <eugene@eug-vs.xyz>2024-12-11 01:41:27 +0100
committereug-vs <eugene@eug-vs.xyz>2024-12-11 01:41:27 +0100
commitfe464a9fe49e4319be9baa8321ac6cf0a7a7945c (patch)
tree6bf43b2568b8f97fdbde847d9f4edaea514bba8c
parent376edaead7be57470e74dd2d616f18a2e6bbc0b0 (diff)
downloadparticle-physics-fe464a9fe49e4319be9baa8321ac6cf0a7a7945c.tar.gz
feat: use midpoint method for more accuracy
-rw-r--r--src/main.rs5
-rw-r--r--src/midpoint.rs174
-rw-r--r--src/particle_system.rs165
3 files changed, 181 insertions, 163 deletions
diff --git a/src/main.rs b/src/main.rs
index fcf19c7..8984bf6 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,5 +1,6 @@
-use particle_system::{Particle, ParticleSystem, Point, Vector};
+use particle_system::{Particle, ParticleSystem, Point, Solver, Vector};
+mod midpoint;
mod particle_system;
fn main() {
@@ -21,7 +22,7 @@ fn main() {
particle.apply_force(gravity);
}
- system.euler_step(dt);
+ system.step(dt);
println!("{:?}", system);
}
diff --git a/src/midpoint.rs b/src/midpoint.rs
new file mode 100644
index 0000000..1aa1de8
--- /dev/null
+++ b/src/midpoint.rs
@@ -0,0 +1,174 @@
+use crate::particle_system::{ParticleSystem, Point, Scalar, Solver, Vector, N};
+use nalgebra::{Const, DVector, Dyn, Matrix, ViewStorage};
+
+/// A vector of concatenated position and velocity components of each particle
+#[derive(Debug, Clone)]
+pub struct PhaseSpace(DVector<Scalar>);
+type ParticleView<'a> = Matrix<
+ f32,
+ Const<{ PhaseSpace::PARTICLE_DIM }>,
+ Const<1>,
+ ViewStorage<'a, f32, Const<{ PhaseSpace::PARTICLE_DIM }>, Const<1>, Const<1>, Dyn>,
+>;
+
+impl PhaseSpace {
+ /// Each particle spans 2N elements in a vector
+ /// first N for position, then N more for velocity
+ const PARTICLE_DIM: usize = N * 2;
+
+ pub fn new(particle_count: usize) -> Self {
+ let dimension = particle_count * PhaseSpace::PARTICLE_DIM;
+ Self(DVector::<Scalar>::zeros(dimension))
+ }
+
+ pub fn particle_view(&self, i: usize) -> ParticleView {
+ self.0
+ .fixed_rows::<{ PhaseSpace::PARTICLE_DIM }>(i * PhaseSpace::PARTICLE_DIM)
+ }
+
+ pub fn set_particle(&mut self, i: usize, position: Point, velocity: Vector) {
+ let mut view = self
+ .0
+ .fixed_rows_mut::<{ PhaseSpace::PARTICLE_DIM }>(i * PhaseSpace::PARTICLE_DIM);
+ for i in 0..N {
+ view[i] = position[i];
+ view[i + N] = velocity[i];
+ }
+ }
+}
+
+impl ParticleSystem {
+ fn collect_phase_space(&self) -> PhaseSpace {
+ let mut phase_space = PhaseSpace::new(self.particles.len());
+ for (particle_index, particle) in self.particles.iter().enumerate() {
+ phase_space.set_particle(particle_index, particle.position, particle.velocity);
+ }
+ phase_space
+ }
+
+ fn compute_derivative(&self) -> PhaseSpace {
+ let mut phase_space = PhaseSpace::new(self.particles.len());
+ for (particle_index, particle) in self.particles.iter().enumerate() {
+ phase_space.set_particle(
+ particle_index,
+ particle.velocity.into(),
+ particle.force / particle.mass,
+ );
+ }
+ phase_space
+ }
+
+ fn scatter_phase_space(&mut self, phase_space: &PhaseSpace) {
+ for (particle_index, particle) in &mut self.particles.iter_mut().enumerate() {
+ let view = phase_space.particle_view(particle_index);
+
+ for i in 0..N {
+ particle.position[i] = view[i];
+ particle.velocity[i] = view[i + N];
+ }
+ }
+ }
+}
+
+impl Solver for ParticleSystem {
+ fn step(&mut self, dt: Scalar) {
+ let mut state = self.collect_phase_space();
+
+ // Shift to the midpoint
+ self.scatter_phase_space(&PhaseSpace {
+ 0: state.0.clone() + self.compute_derivative().0 * dt / 2.0,
+ });
+
+ state.0 += self.compute_derivative().0 * dt;
+ self.scatter_phase_space(&state);
+
+ self.t += dt;
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::{ParticleSystem, PhaseSpace, Point, Scalar, Vector};
+ use crate::particle_system::{Particle, Solver};
+
+ #[test]
+ fn test_collect_phase_space() {
+ let system = ParticleSystem {
+ particles: vec![Particle::new(Point::new(2.0, 3.0), 1.0)],
+ t: 0.0,
+ };
+ let phase_space = system.collect_phase_space();
+
+ assert!(
+ !phase_space.0.is_empty(),
+ "Phase space has to contain non-zero values"
+ );
+ }
+
+ #[test]
+ fn test_scatter_phase_space() {
+ let mut phase_space = PhaseSpace::new(2);
+ phase_space.set_particle(1, Point::new(5.0, 7.0), Vector::x());
+
+ let mut system = ParticleSystem {
+ particles: vec![
+ Particle::new(Point::new(0.0, 0.0), 1.0),
+ Particle::new(Point::new(0.0, 0.0), 1.0),
+ ],
+ t: 0.0,
+ };
+
+ system.scatter_phase_space(&phase_space);
+
+ assert!(
+ !system.particles[1].velocity.is_empty(),
+ "Velocity has to be set"
+ );
+ assert!(
+ !system.particles[1].position.is_empty(),
+ "Position has to be set"
+ );
+ }
+
+ fn simulate_falling_ball(fall_time: Scalar, dt: Scalar) -> (Vector, Vector) {
+ let gravity = -9.8 * Vector::y();
+
+ let mut system = ParticleSystem {
+ particles: vec![Particle::new(Point::origin(), 1.0)],
+ t: 0.0,
+ };
+
+ let iterations = (fall_time / dt) as usize;
+
+ for _ in 0..iterations {
+ for particle in &mut system.particles {
+ particle.reset_force();
+ particle.apply_force(gravity);
+ }
+ system.step(dt);
+ }
+
+ let expected_velocity = gravity * fall_time; // vt
+ let expected_position = gravity * fall_time * fall_time / 2.0; // at^2 / 2
+
+ (
+ system.particles[0].position.coords - expected_position,
+ system.particles[0].velocity - expected_velocity,
+ )
+ }
+
+ #[test]
+ fn ball_should_fall() {
+ let (position_error, velocity_error) = simulate_falling_ball(10.0, 0.01);
+ assert!(
+ position_error.norm() < 0.01,
+ "Position error is too high: {}",
+ position_error,
+ );
+ assert!(
+ velocity_error.norm() < 0.01,
+ "Velocity error is too high: {}",
+ velocity_error,
+ );
+ }
+}
diff --git a/src/particle_system.rs b/src/particle_system.rs
index 0556764..52e6284 100644
--- a/src/particle_system.rs
+++ b/src/particle_system.rs
@@ -1,6 +1,6 @@
-use nalgebra::{Const, DVector, Dyn, Matrix, Point as PointBase, SVector, ViewStorage};
+use nalgebra::{Point as PointBase, SVector};
-const N: usize = 2;
+pub const N: usize = 2;
pub type Scalar = f32;
pub type Vector = SVector<Scalar, N>;
@@ -34,42 +34,6 @@ impl Particle {
}
}
-/// A vector of concatenated position and velocity components of each particle
-#[derive(Debug)]
-pub struct PhaseSpace(DVector<Scalar>);
-type ParticleView<'a> = Matrix<
- f32,
- Const<{ PhaseSpace::PARTICLE_DIM }>,
- Const<1>,
- ViewStorage<'a, f32, Const<{ PhaseSpace::PARTICLE_DIM }>, Const<1>, Const<1>, Dyn>,
->;
-
-impl PhaseSpace {
- /// Each particle spans 2N elements in a vector
- /// first N for position, then N more for velocity
- const PARTICLE_DIM: usize = N * 2;
-
- pub fn new(particle_count: usize) -> Self {
- let dimension = particle_count * PhaseSpace::PARTICLE_DIM;
- Self(DVector::<Scalar>::zeros(dimension))
- }
-
- pub fn particle_view(&self, i: usize) -> ParticleView {
- self.0
- .fixed_rows::<{ PhaseSpace::PARTICLE_DIM }>(i * PhaseSpace::PARTICLE_DIM)
- }
-
- pub fn set_particle(&mut self, i: usize, position: Point, velocity: Vector) {
- let mut view = self
- .0
- .fixed_rows_mut::<{ PhaseSpace::PARTICLE_DIM }>(i * PhaseSpace::PARTICLE_DIM);
- for i in 0..N {
- view[i] = position[i];
- view[i + N] = velocity[i];
- }
- }
-}
-
#[derive(Debug)]
pub struct ParticleSystem {
pub particles: Vec<Particle>,
@@ -78,127 +42,6 @@ pub struct ParticleSystem {
pub t: Scalar,
}
-impl ParticleSystem {
- fn collect_phase_space(&self) -> PhaseSpace {
- let mut phase_space = PhaseSpace::new(self.particles.len());
- for (particle_index, particle) in self.particles.iter().enumerate() {
- phase_space.set_particle(particle_index, particle.position, particle.velocity);
- }
- phase_space
- }
-
- fn compute_derivative(&self) -> PhaseSpace {
- let mut phase_space = PhaseSpace::new(self.particles.len());
- for (particle_index, particle) in self.particles.iter().enumerate() {
- phase_space.set_particle(
- particle_index,
- particle.velocity.into(),
- particle.force / particle.mass,
- );
- }
- phase_space
- }
-
- fn scatter_phase_space(&mut self, phase_space: &PhaseSpace) {
- for (particle_index, particle) in &mut self.particles.iter_mut().enumerate() {
- let view = phase_space.particle_view(particle_index);
-
- for i in 0..N {
- particle.position[i] = view[i];
- particle.velocity[i] = view[i + N];
- }
- }
- }
-
- pub fn euler_step(&mut self, dt: Scalar) {
- let derivative = self.compute_derivative();
- let mut state = self.collect_phase_space();
-
- state.0 += derivative.0 * dt;
- self.scatter_phase_space(&state);
-
- self.t += dt;
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::{Particle, ParticleSystem, PhaseSpace, Point, Scalar, Vector};
-
- #[test]
- fn test_collect_phase_space() {
- let system = ParticleSystem {
- particles: vec![Particle::new(Point::new(2.0, 3.0), 1.0)],
- t: 0.0,
- };
- let phase_space = system.collect_phase_space();
-
- assert!(
- !phase_space.0.is_empty(),
- "Phase space has to contain non-zero values"
- );
- }
-
- #[test]
- fn test_scatter_phase_space() {
- let mut phase_space = PhaseSpace::new(2);
- phase_space.set_particle(1, Point::new(5.0, 7.0), Vector::x());
-
- let mut system = ParticleSystem {
- particles: vec![
- Particle::new(Point::new(0.0, 0.0), 1.0),
- Particle::new(Point::new(0.0, 0.0), 1.0),
- ],
- t: 0.0,
- };
-
- system.scatter_phase_space(&phase_space);
-
- assert!(
- !system.particles[1].velocity.is_empty(),
- "Velocity has to be set"
- );
- assert!(
- !system.particles[1].position.is_empty(),
- "Position has to be set"
- );
- }
-
- fn simulate_falling_ball(fall_time: Scalar, dt: Scalar) -> (Vector, Vector) {
- let gravity = -9.8 * Vector::y();
-
- let mut system = ParticleSystem {
- particles: vec![Particle::new(Point::origin(), 1.0)],
- t: 0.0,
- };
-
- let iterations = (fall_time / dt) as usize;
-
- for _ in 0..iterations {
- for particle in &mut system.particles {
- particle.reset_force();
- }
-
- for particle in &mut system.particles {
- particle.apply_force(gravity);
- }
-
- system.euler_step(dt);
- }
-
- let expected_velocity = gravity * fall_time; // vt
- let expected_position = gravity * fall_time * fall_time / 2.0; // at^2 / 2
-
- (
- system.particles[0].position.coords - expected_position,
- system.particles[0].velocity - expected_velocity,
- )
- }
-
- #[test]
- fn ball_should_fall() {
- let (position_error, velocity_error) = simulate_falling_ball(10.0, 0.01);
- assert!(position_error.norm() < 0.5, "Position error has is too high");
- assert!(velocity_error.norm() < 0.5, "Velocity error has is too high");
- }
+pub trait Solver {
+ fn step(&mut self, dt: Scalar);
}