diff options
Diffstat (limited to 'src/research.py')
-rw-r--r-- | src/research.py | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/src/research.py b/src/research.py new file mode 100644 index 0000000..f0374cf --- /dev/null +++ b/src/research.py @@ -0,0 +1,52 @@ +import os +from plotter import Plotter + + +class Research: + def __init__(self, model, name): + self.model = model + self.name = name + self.plotter = Plotter() + self.path = os.path.join('data', name) + os.makedirs(self.path, exist_ok=True) + + self.model_path = os.path.join(self.path, 'model.npy') + self.solution_path = os.path.join(self.path, 'solution.npy') + + def load(self): + if os.path.exists(self.solution_path): + self.model.load(self.solution_path) + return 1 + + if os.path.exists(self.model_path): + self.model.load(self.model_path) + + return 0 + + def solve(self, precision=10 ** -4, preview=True, save_plot=False, save_model=False): + error = 1 + iteration = 0 + while error > precision: + iteration += 1 + self.model.iterate() + error = self.model.avg_error() + print(f'Iteration {iteration}, avg error: {error}') + + if iteration % 10 == 0 or iteration == 1: + if preview or save_plot: + self.plotter.plot(self.model) + if preview: + self.plotter.show() + + if iteration % 50 == 0: + if save_model: + self.model.save(self.model_path) + + self.model.save(self.solution_path) + self.inspect() + + def inspect(self): + self.plotter.plot(self.model, density=1) + while True: + self.plotter.show() + |