summaryrefslogtreecommitdiff
path: root/src/plotter.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/plotter.py')
-rw-r--r--src/plotter.py39
1 files changed, 21 insertions, 18 deletions
diff --git a/src/plotter.py b/src/plotter.py
index 45d95da..c322de9 100644
--- a/src/plotter.py
+++ b/src/plotter.py
@@ -20,7 +20,7 @@ class Plotter:
if self.plt:
self.plt.remove()
- def plot(self, model, normalize=True, density=False, save_path=''):
+ def plot(self, model, normalize=True, density=False, save_path='', streamplot=False):
self.clear()
axes.set_title('Velocity field (normalized)')
@@ -28,38 +28,41 @@ class Plotter:
plt.xlabel('X')
plt.ylabel('Y')
- u, v = model.u, model.v
- if normalize:
- factor = np.sqrt(u ** 2 + v ** 2)
- u = u / factor
- v = v / factor
+ shape = model.p.shape
+
+ u = np.zeros(shape, dtype=float)
+ v = np.zeros(shape, dtype=float)
+
+ for i in range(shape[0]):
+ for j in range(shape[1]):
+ u[i][j] = 0.5 * (model.u[i][j] + model.u[i][j + 1])
+ v[i][j] = 0.5 * (model.v[i][j] + model.v[i + 1][j])
+ assert not v[0, :].any()
+ assert not v[-1, :].any()
- shape = (model.p.shape[0] + 1, model.p.shape[1] + 1)
x, y = np.meshgrid(
np.linspace(0, shape[1] * model.step, shape[1]),
np.linspace(0, shape[0] * model.step, shape[0]),
)
- u = copy(model.u)
- u.resize(shape)
- v = copy(model.v)
- v.resize(shape)
- p = copy(model.p)
- p.resize(shape)
-
- print(shape, u.shape, v.shape)
+ if normalize:
+ factor = np.sqrt(u ** 2 + v ** 2)
+ u = u / factor
+ v = v / factor
# density = density or int((max(model.domain_size) / model.step) / 40)
- plt.contourf(x, y, p)
+ plt.contourf(x, y, model.p)
# self.patch = axes.add_patch(Rectangle((0, 0), *reversed(model.bfs_size), color='gray'))
- # TODO: allow using streamplot
- self.plt = plt.quiver(
+
+ plotter = plt.streamplot if streamplot else plt.quiver
+ self.plt = plotter(
x,
y,
u,
v,
+ color='black'
)
self.colorbar = plt.colorbar(label='Pressure')