summaryrefslogtreecommitdiff
path: root/src/plotter.py
blob: b493c3e528eeccd8e51ea19ee85233387000911a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numpy as np
import matplotlib.pyplot as plt
from copy import copy
from matplotlib.patches import Rectangle

# Shared figure/axes created once, at import time. Every Plotter instance
# draws its contours, obstacle patch and vector field into `axes`.
figure, axes = plt.subplots()


class Plotter:
    """Draw a model's pressure contours and velocity vectors onto ``axes``.

    Every :meth:`plot` call first removes the artists created by the previous
    call, so one instance can redraw the shared figure repeatedly (e.g.
    ``plot`` followed by ``show`` after each solver iteration).
    """

    def __init__(self):
        self.plt = None           # vector-field artist: Quiver or StreamplotSet
        self.colorbar = None      # pressure colorbar
        self.patch = None         # gray rectangle covering model.bfs_shape
        self.contour = None       # filled pressure contour set
        self._stream_arrows = []  # arrow patches streamplot added to the axes

    @staticmethod
    def _remove_contour_set(contour_set):
        """Remove a filled-contour set from its axes on old and new matplotlib."""
        if hasattr(contour_set, 'remove'):
            contour_set.remove()
        else:
            # Fallback for matplotlib versions whose ContourSet has no
            # remove(): each level lives in a separate collection.
            for collection in contour_set.collections:
                collection.remove()

    def clear(self):
        """Remove every artist added by the previous :meth:`plot` call.

        Handles are reset to ``None`` once removed, so repeated calls are safe.
        """
        # Drop the colorbar before its mappable: Colorbar.remove() consults
        # the mappable's axes to restore the plot's original layout.
        if self.colorbar is not None:
            self.colorbar.remove()
            self.colorbar = None
        if self.contour is not None:
            self._remove_contour_set(self.contour)
            self.contour = None
        if self.patch is not None:
            self.patch.remove()
            self.patch = None
        if self.plt is not None:
            if hasattr(self.plt, 'remove'):
                self.plt.remove()        # Quiver
            else:
                self.plt.lines.remove()  # StreamplotSet has no remove() of its own
            self.plt = None
        for arrow in self._stream_arrows:
            arrow.remove()
        self._stream_arrows = []

    def plot(self, model, normalize=True, density=False, save_path='', streamplot=False):
        """Redraw *model*'s pressure contours and velocity field on ``axes``.

        Args:
            model: solver state. Read here: ``p`` (2-D pressure array),
                ``u`` / ``v`` (arrays with at least one extra column / row
                relative to ``p``), ``step`` (grid spacing), ``bfs_shape``
                (``(rows, cols)`` of the region painted gray, in cells) and
                ``avg_error()``.
            normalize: scale each velocity vector to unit length.
            density: stride used to thin out the vectors; a falsy value picks
                one from the grid size.
            save_path: currently ignored. NOTE(review): presumably meant to
                save the figure after drawing; call :meth:`save` instead.
            streamplot: draw streamlines instead of quiver arrows.
        """
        self.clear()

        axes.set_title('Velocity field (normalized)' if normalize else 'Velocity field')
        axes.figure.suptitle(f'Avg mass source per grid point = {model.avg_error()}')
        axes.set_xlabel('X')
        axes.set_ylabel('Y')

        rows, cols = model.p.shape
        u_faces = np.asarray(model.u, dtype=float)
        v_faces = np.asarray(model.v, dtype=float)
        # Average each pair of neighbouring staggered values onto the p grid:
        # u along columns, v along rows (vectorised form of a per-cell loop).
        u = 0.5 * (u_faces[:rows, :cols] + u_faces[:rows, 1:cols + 1])
        v = 0.5 * (v_faces[:rows, :cols] + v_faces[1:rows + 1, :cols])

        # Sanity check: the averaged v vanishes on the first and last rows.
        assert not v[0, :].any()
        assert not v[-1, :].any()

        # NOTE(review): these points are cols*step/(cols-1) apart rather than
        # step; confirm whether cell centres ((j + 0.5) * step) were intended.
        x, y = np.meshgrid(
            np.linspace(0, cols * model.step, cols),
            np.linspace(0, rows * model.step, rows),
        )

        if normalize:
            speed = np.hypot(u, v)
            # Zero-speed cells still become NaN as before; only the
            # RuntimeWarning this raised on every redraw is suppressed.
            with np.errstate(divide='ignore', invalid='ignore'):
                u = u / speed
                v = v / speed

        # A stride of 0 would make the [::stride] slices below raise on grids
        # smaller than 40 cells, so never go below 1.
        stride = density or max(1, min(rows, cols) // 40)

        self.contour = axes.contourf(x, y, model.p, cmap='inferno')
        self.colorbar = axes.figure.colorbar(self.contour, ax=axes, label='Pressure')

        bfs_rows, bfs_cols = model.bfs_shape
        self.patch = axes.add_patch(
            Rectangle((0, 0), bfs_cols * model.step, bfs_rows * model.step, color='gray')
        )

        field = (
            x[::stride, ::stride],
            y[::stride, ::stride],
            u[::stride, ::stride],
            v[::stride, ::stride],
        )
        if streamplot:
            existing = set(axes.patches)
            self.plt = axes.streamplot(*field, color='black')
            # streamplot adds its arrow heads to the axes as individual
            # patches that the returned StreamplotSet cannot remove; record
            # them so clear() does not leave them behind on the next redraw.
            self._stream_arrows = [p for p in axes.patches if p not in existing]
        else:
            self.plt = axes.quiver(*field, color='black')

    def save(self, path):
        """Write the figure this plotter draws into to *path* at 300 dpi."""
        return axes.figure.savefig(path, dpi=300)

    def show(self):
        """Refresh the display by running the GUI event loop briefly."""
        return plt.pause(0.0001)