import numpy as np


class SIMPLE:
    """Steady incompressible flow solver based on the SIMPLE algorithm.

    Unknowns live on a staggered grid whose pressure field has shape
    ``shape = (rows, cols)``:

    * ``p`` -- shape ``(rows, cols)``, one value per cell;
    * ``u`` -- shape ``(rows, cols + 1)``; ``u[i][j]`` sits on the face
      between pressure cells ``(i, j - 1)`` and ``(i, j)``;
    * ``v`` -- shape ``(rows + 1, cols)``; ``v[i][j]`` sits on the face
      between pressure cells ``(i - 1, j)`` and ``(i, j)``.

    Cells with ``i < bfs_shape[0]`` and ``j < bfs_shape[1]`` are skipped by
    the momentum and pressure-correction sweeps, and their pressure is pinned
    to 0 in :meth:`iterate` (presumably the solid block of a backward-facing
    step -- confirm).

    :meth:`iterate` performs one outer SIMPLE iteration; :meth:`avg_error`
    reports the continuity residual left by the last pressure correction.
    """

    def __init__(self, shape, bfs_shape, step, Re, alpha=0.8):
        """Allocate all fields.

        Args:
            shape: ``(rows, cols)`` of the pressure grid.
            bfs_shape: ``(rows, cols)`` extent of the excluded corner block
                (cells with ``i < bfs_shape[0]`` and ``j < bfs_shape[1]``).
            step: grid spacing; used as the face area ``A`` in the ``d``
                coefficients and in the convection terms (presumably a
                uniform square cell -- confirm).
            Re: Reynolds number; the kinematic viscosity is ``1 / Re``.
            alpha: under-relaxation factor applied to the pressure
                correction in :meth:`correct_pressure`.

        Note:
            This changes NumPy's *global* print options. ``p_star`` is also
            filled from NumPy's global random state (``np.random.rand``), so
            runs are only reproducible if the caller seeds that state.
        """
        np.set_printoptions(precision=2, floatmode="maxprec", suppress=True)

        self.Re = Re
        self.nu = 1 / Re
        self.alpha = alpha

        self.step = step
        self.bfs_shape = bfs_shape

        # Allocations (staggered layout, see class docstring).
        self.u = np.zeros(shape=(shape[0], shape[1] + 1), dtype=float)
        self.u_star = np.zeros(shape=(shape[0], shape[1] + 1), dtype=float)

        self.v = np.zeros(shape=(shape[0] + 1, shape[1]), dtype=float)
        self.v_star = np.zeros(shape=(shape[0] + 1, shape[1]), dtype=float)

        self.p = np.zeros(shape=shape, dtype=float)
        self.p_star = np.random.rand(*shape)  # random initial pressure guess
        self.p_prime = np.zeros(shape=shape, dtype=float)

        # d = A / a_P per face; entries never written by the momentum sweep
        # (excluded block, outer rows/columns) stay 0, which makes the
        # velocity correction a no-op on those faces.
        self.d_e = np.zeros(shape=self.u.shape, dtype=float)
        self.d_n = np.zeros(shape=self.v.shape, dtype=float)
        # Per-cell continuity residual, filled by correct_pressure().
        self.b = np.zeros(shape=shape, dtype=float)

    def assert_positive(self, value, tol=0.01):
        """Return ``value`` after checking that it is (nearly) positive.

        Args:
            value: the number to check.
            tol: how far below zero ``value`` may fall and still be accepted.

        Returns:
            ``value`` unchanged, so the call can wrap an expression inline.

        Raises:
            AssertionError: if ``value <= -tol`` or ``value`` is NaN. The
                error is raised explicitly instead of via an ``assert``
                statement so the check is not stripped under ``python -O``.
        """
        # `not value > -tol` (rather than `value <= -tol`) also rejects NaN,
        # exactly like the former `assert value > -0.01`.
        if not value > -tol:
            raise AssertionError(f'WARNING: Value must be positive: {value}')
        return value

    def solve_momentum_equations(self):
        """Predict ``u_star``/``v_star`` from the momentum equations.

        Every interior face outside the excluded block gets one explicit
        update from the previous iteration's ``u``, ``v`` and the guessed
        pressure ``p_star``; nothing read here is written here, so the
        update order does not matter. Convection uses central differencing
        (``a_nb = nu -/+ 0.5 * F``) and diffusion uses the coefficient
        ``nu``. The ``d_e``/``d_n`` coefficients (face area over central
        coefficient) are stored for the pressure and velocity corrections.

        Raises:
            AssertionError: via :meth:`assert_positive` when a neighbour
                coefficient drops below ``-0.01`` (i.e. the local cell
                Peclet number ``|u| * step / nu`` exceeds about 2).
        """
        # Momentum along X direction
        for i in range(1, self.u.shape[0] - 1):
            for j in range(1, self.u.shape[1] - 1):
                if i >= self.bfs_shape[0] or j >= self.bfs_shape[1]:
                    # Velocities interpolated onto the faces of the u control volume.
                    u_W = 0.5 * (self.u[i][j] + self.u[i][j - 1])
                    u_E = 0.5 * (self.u[i][j] + self.u[i][j + 1])

                    v_S = 0.5 * (self.v[i][j - 1] + self.v[i][j])
                    v_N = 0.5 * (self.v[i + 1][j - 1] + self.v[i + 1][j])

                    a_E = self.assert_positive(-0.5 * u_E * self.step + self.nu)
                    a_W = self.assert_positive(+0.5 * u_W * self.step + self.nu)
                    a_N = self.assert_positive(-0.5 * v_N * self.step + self.nu)
                    a_S = self.assert_positive(+0.5 * v_S * self.step + self.nu)

                    # Central coefficient = sum(a_nb) + net outflow.
                    a_e = 0.5 * self.step * (u_E - u_W + v_N - v_S) + 4 * self.nu
                    A_e = self.step

                    self.d_e[i][j] = A_e / a_e

                    self.u_star[i][j] = (
                        a_E * self.u[i][j + 1] +
                        a_W * self.u[i][j - 1] +
                        a_N * self.u[i + 1][j] +
                        a_S * self.u[i - 1][j] +
                        0  # source term disabled (was: self.b[i][j - 1])
                    ) / a_e + self.d_e[i][j] * (self.p_star[i][j - 1] - self.p_star[i][j]) # p - p_e

        # Momentum along Y direction
        for i in range(1, self.v.shape[0] - 1):
            for j in range(1, self.v.shape[1] - 1):
                if i >= self.bfs_shape[0] or j >= self.bfs_shape[1]:
                    # Velocities interpolated onto the faces of the v control volume.
                    u_W = 0.5 * (self.u[i - 1][j] + self.u[i][j])
                    u_E = 0.5 * (self.u[i - 1][j + 1] + self.u[i][j + 1])

                    v_N = 0.5 * (self.v[i][j] + self.v[i + 1][j])
                    v_S = 0.5 * (self.v[i][j] + self.v[i - 1][j])

                    a_E = self.assert_positive(-0.5 * u_E * self.step + self.nu)
                    a_W = self.assert_positive(+0.5 * u_W * self.step + self.nu)
                    a_N = self.assert_positive(-0.5 * v_N * self.step + self.nu)
                    a_S = self.assert_positive(+0.5 * v_S * self.step + self.nu)

                    # Central coefficient = sum(a_nb) + net outflow.
                    a_n = 0.5 * self.step * (u_E - u_W + v_N - v_S) + 4 * self.nu
                    A_n = self.step

                    self.d_n[i][j] = A_n / a_n

                    self.v_star[i][j] = (
                        a_E * self.v[i][j + 1] +
                        a_W * self.v[i][j - 1] +
                        a_N * self.v[i + 1][j] +
                        a_S * self.v[i - 1][j] +
                        0  # source term disabled (was: self.b[i - 1][j])
                    ) / a_n + self.d_n[i][j] * (self.p_star[i - 1][j] - self.p_star[i][j]) # p - p_n

    def correct_pressure(self):
        """Compute the pressure correction ``p_prime`` and update ``p``.

        ``p_prime`` is reset to zeros, then one in-place sweep visits every
        interior cell outside the excluded block. ``b`` receives the cell's
        continuity residual ``step * (du_star + dv_star)``. Finally
        ``p = p_star + alpha * p_prime``, and ``p_star`` is re-bound to that
        same array (the two names alias until ``p_star`` is reassigned).
        """
        self.p_prime = np.zeros(shape=self.p.shape, dtype=float)
        for i in range(1, self.p.shape[0] - 1):
            for j in range(1, self.p.shape[1] - 1):
                if i >= self.bfs_shape[0] or j >= self.bfs_shape[1]:
                    # NOTE(review): j <= shape[1] - 2 and i <= shape[0] - 2 in
                    # these loops, so the `j == shape[1] - 1` / `i == shape[0] - 1`
                    # tests below are never true -- confirm the intended
                    # east/north boundary treatment.
                    # NOTE(review): each coefficient is -d * step (<= 0), so
                    # assert_positive here bounds |d * step| < 0.01 rather than
                    # checking the sign.
                    a_E = 0 if j == self.p.shape[1] - 1 else self.assert_positive(-self.d_e[i][j+1] * self.step)
                    a_W = 0 if j == 1 else self.assert_positive(-self.d_e[i][j] * self.step)
                    a_N = 0 if i == self.p.shape[0] - 1 else self.assert_positive(-self.d_n[i+1][j] * self.step)
                    a_S = 0 if i == 1 else self.assert_positive(-self.d_n[i][j] * self.step)
                    a_P = a_E + a_W + a_N + a_S

                    # Continuity residual of the predicted velocities.
                    self.b[i][j] = self.step * (
                        (self.u_star[i][j+1] - self.u_star[i][j]) +
                        (self.v_star[i+1][j] - self.v_star[i][j])
                    )

                    # NOTE(review): because every a_X above is <= 0, each
                    # `if a_X > 0 else 0` guard discards its neighbour term and
                    # this reduces to p' = b / a_P (a purely local correction).
                    # Textbook SIMPLE keeps the neighbour terms; confirm
                    # whether dropping them is intentional.
                    self.p_prime[i][j] = (
                        (a_E * self.p_prime[i][j+1] if a_E > 0 else 0) +
                        (a_W * self.p_prime[i][j-1] if a_W > 0 else 0) +
                        (a_N * self.p_prime[i+1][j] if a_N > 0 else 0) +
                        (a_S * self.p_prime[i-1][j] if a_S > 0 else 0) +
                        self.b[i][j]
                    ) / a_P

        # Under-relaxed pressure update; p_star now aliases p.
        self.p = self.p_star + self.p_prime * self.alpha
        self.p_star = self.p

    def correct_velocities(self):
        """Correct ``u``/``v`` from ``u_star``/``v_star`` using ``p_prime``.

        ``u = u_star + d_e * (p'_west - p'_cell)`` and
        ``v = v_star + d_n * (p'_south - p'_cell)``. There is no block mask
        here; on faces where ``d_e``/``d_n`` were never set (still 0) the
        result is just the predicted velocity.
        """
        for i in range(self.u.shape[0]):
            for j in range(1, self.u.shape[1] - 1):
                self.u[i][j] = self.u_star[i][j] + self.d_e[i][j] * (self.p_prime[i][j - 1] - self.p_prime[i][j])

        for i in range(1, self.v.shape[0] - 1):
            for j in range(self.v.shape[1]):
                self.v[i][j] = self.v_star[i][j] + self.d_n[i][j] * (self.p_prime[i - 1][j] - self.p_prime[i][j])

    def iterate(self):
        """Run one outer SIMPLE iteration.

        Predict velocities, impose boundary values on the predicted fields,
        solve the pressure correction, correct the velocities, then impose
        the same boundary values on ``u``/``v`` and zero the pressure inside
        the excluded block.
        """
        self.solve_momentum_equations()

        # Boundary: u[:, 0] is set so that mean(u[:, 0], u[:, 1]) == 1.
        self.u_star[:, 0] = 2 - self.u_star[:, 1]
        self.v_star[:, 0] = 0

        self.v_star[-2, :] = -self.v_star[-1, :]
        self.v_star[1, :] = -self.v_star[0, :]

        # Mirror values across the edges of the excluded block.
        self.v_star[self.bfs_shape[0], :self.bfs_shape[1]] = self.v_star[self.bfs_shape[0] - 1, :self.bfs_shape[1]]
        self.u_star[:self.bfs_shape[0], self.bfs_shape[1]] = self.u_star[:self.bfs_shape[0], self.bfs_shape[1] - 1]

        self.p_star[:self.bfs_shape[0], :self.bfs_shape[1]] = 0

        self.correct_pressure()
        self.correct_velocities()

        # Boundary enforce (same conditions, applied to the corrected fields).
        self.u[:, 0] = 2 - self.u[:, 1]
        self.v[:, 0] = 0

        self.v[-2, :] = -self.v[-1, :]
        self.v[1, :] = -self.v[0, :]

        self.v[self.bfs_shape[0], :self.bfs_shape[1]] = self.v[self.bfs_shape[0] - 1, :self.bfs_shape[1]]
        self.u[:self.bfs_shape[0], self.bfs_shape[1]] = self.u[:self.bfs_shape[0], self.bfs_shape[1] - 1]

        # p_star aliases p after correct_pressure(), so this zeroes both.
        self.p[:self.bfs_shape[0], :self.bfs_shape[1]] = 0

    def avg_error(self):
        """Return the total absolute continuity residual ``sum(|b|)``.

        Note: despite the name this is a sum over all cells, not a mean.
        """
        return np.absolute(self.b).sum()

    def save(self, path):
        """Write ``u``, ``v``, ``p`` and ``b`` (in that order) to ``path``.

        The arrays are appended one after another with ``np.save`` to a
        single binary file; :meth:`load` reads them back in the same order.
        """
        print('SAVE', path)
        with open(path, 'wb') as file:
            np.save(file, self.u)
            np.save(file, self.v)
            np.save(file, self.p)
            np.save(file, self.b)

    def load(self, path):
        """Read ``u``, ``v``, ``p`` and ``b`` back from a file written by :meth:`save`.

        ``p_star`` is re-bound to the loaded ``p`` (same array). The
        ``*_star``, ``p_prime`` and ``d_*`` work arrays are left untouched.
        """
        print('LOAD', path)
        with open(path, 'rb') as file:
            self.u = np.load(file)
            self.v = np.load(file)
            self.p = np.load(file)
            self.b = np.load(file)
            self.p_star = self.p