summaryrefslogtreecommitdiff
path: root/src/research.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/research.py')
-rw-r--r--src/research.py52
1 files changed, 52 insertions, 0 deletions
diff --git a/src/research.py b/src/research.py
new file mode 100644
index 0000000..f0374cf
--- /dev/null
+++ b/src/research.py
@@ -0,0 +1,52 @@
+import os
+from plotter import Plotter
+
+
+class Research:
+ def __init__(self, model, name):
+ self.model = model
+ self.name = name
+ self.plotter = Plotter()
+ self.path = os.path.join('data', name)
+ os.makedirs(self.path, exist_ok=True)
+
+ self.model_path = os.path.join(self.path, 'model.npy')
+ self.solution_path = os.path.join(self.path, 'solution.npy')
+
+ def load(self):
+ if os.path.exists(self.solution_path):
+ self.model.load(self.solution_path)
+ return 1
+
+ if os.path.exists(self.model_path):
+ self.model.load(self.model_path)
+
+ return 0
+
+ def solve(self, precision=10 ** -4, preview=True, save_plot=False, save_model=False):
+ error = 1
+ iteration = 0
+ while error > precision:
+ iteration += 1
+ self.model.iterate()
+ error = self.model.avg_error()
+ print(f'Iteration {iteration}, avg error: {error}')
+
+ if iteration % 10 == 0 or iteration == 1:
+ if preview or save_plot:
+ self.plotter.plot(self.model)
+ if preview:
+ self.plotter.show()
+
+ if iteration % 50 == 0:
+ if save_model:
+ self.model.save(self.model_path)
+
+ self.model.save(self.solution_path)
+ self.inspect()
+
+ def inspect(self):
+ self.plotter.plot(self.model, density=1)
+ while True:
+ self.plotter.show()
+