"""Drive an iterative model to convergence while plotting and persisting progress."""
import math
import os

from plotter import Plotter


class Research:
    """Solve, checkpoint, plot and reload a model under ``<root>/<name>``.

    The wrapped ``model`` must provide ``iterate()``, ``avg_error()``,
    ``save(path)`` and ``load(path)`` -- these are the only methods this
    class calls on it.
    """

    def __init__(self, model, name, root='data'):
        """Bind *model* to an output directory and create that directory.

        Args:
            model: The model being solved (see the class docstring for the
                methods it must provide).
            name: Sub-directory under *root* where checkpoints, the final
                solution and plots are written.
            root: Parent directory for all runs; defaults to ``'data'``.
        """
        self.model = model
        self.name = name
        self.plotter = Plotter()
        self.path = os.path.join(root, name)
        # Create the run directory eagerly so later saves never hit a missing dir.
        os.makedirs(self.path, exist_ok=True)
        # Periodic checkpoint, written by solve() when save_model=True.
        self.model_path = os.path.join(self.path, 'model.npy')
        # Final state, written by solve() once the error threshold is reached.
        self.solution_path = os.path.join(self.path, 'solution.npy')

    def load(self):
        """Restore the model from disk, preferring the final solution.

        Returns:
            1 if the solution file existed and was loaded; 0 otherwise,
            whether or not the periodic checkpoint was loaded instead.
        """
        if os.path.exists(self.solution_path):
            self.model.load(self.solution_path)
            return 1
        if os.path.exists(self.model_path):
            # Only an intermediate checkpoint exists; presumably the caller
            # still needs to run solve() -- hence the 0 below.
            self.model.load(self.model_path)
        return 0

    def solve(self, precision=10 ** -4, preview=True, save_plot=False,
              save_model=False, plot_every=10, checkpoint_every=50):
        """Iterate the model until its average error is at most *precision*.

        The error is printed every iteration. On iteration 1 and every
        *plot_every* iterations the normalized model is plotted: shown when
        *preview* is set, written to ``<iteration>.png`` when *save_plot* is
        set. Every *checkpoint_every* iterations the model is saved to
        ``model_path`` when *save_model* is set. After convergence the model
        is saved to ``solution_path``, a streamplot is written to
        ``streamplot.png``, and :meth:`inspect` is called -- which loops
        forever, so this method does not return normally.

        Args:
            precision: Convergence threshold on ``model.avg_error()``.
            preview: Show intermediate plots.
            save_plot: Save intermediate plots as PNG files.
            save_model: Write periodic checkpoints.
            plot_every: Positive plot interval in iterations (default 10).
            checkpoint_every: Positive checkpoint interval (default 50).

        Raises:
            FloatingPointError: If ``avg_error()`` returns NaN. NaN compares
                false against *precision*, so without this check a diverged
                model would be treated and saved as converged.
        """
        # Start above every possible threshold so at least one iteration runs,
        # even when precision >= 1.
        error = math.inf
        iteration = 0
        while error > precision:
            iteration += 1
            self.model.iterate()
            error = self.model.avg_error()
            print(f'Iteration {iteration}, avg error: {error}')
            if math.isnan(error):
                raise FloatingPointError(
                    f'Average error became NaN at iteration {iteration}; '
                    'refusing to treat the model as converged.'
                )
            if iteration == 1 or iteration % plot_every == 0:
                self._plot_progress(iteration, preview, save_plot)
            if save_model and iteration % checkpoint_every == 0:
                self.model.save(self.model_path)
        self.model.save(self.solution_path)
        self.plotter.plot(self.model, streamplot=True)
        self.plotter.save(os.path.join(self.path, 'streamplot.png'))
        self.inspect()

    def _plot_progress(self, iteration, preview, save_plot):
        """Plot the normalized model, then show it and/or save ``<iteration>.png``."""
        if not (preview or save_plot):
            return
        self.plotter.plot(self.model, normalize=True)
        if preview:
            self.plotter.show()
        if save_plot:
            self.plotter.save(os.path.join(self.path, f'{iteration}.png'))

    def inspect(self):
        """Draw a streamplot of the model and keep displaying it indefinitely.

        NOTE(review): this loop never exits on its own. It presumably relies
        on ``Plotter.show()`` returning after a short refresh and on the user
        interrupting the process -- confirm against ``Plotter``.
        """
        self.plotter.plot(self.model, streamplot=True)
        while True:
            self.plotter.show()