From 3f4df26b0fc59d79f824da9074f3414f19e69d5d Mon Sep 17 00:00:00 2001 From: flaport Date: Thu, 7 Dec 2023 11:16:13 -0800 Subject: [PATCH] more bugfixes in export from gpu (#66) --- fdtd/backend.py | 2 -- fdtd/grid.py | 13 +++++++++++-- fdtd/visualization.py | 16 ++++++++-------- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/fdtd/backend.py b/fdtd/backend.py index 9abc0d0..dd29bd6 100644 --- a/fdtd/backend.py +++ b/fdtd/backend.py @@ -338,8 +338,6 @@ def array(self, arr, dtype=None, **kwargs): return arr.clone().to(device="cuda", dtype=dtype, **kwargs) return torch.tensor(arr, device="cuda", dtype=dtype, **kwargs) - # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - # The same warning applies here. def numpy(self, arr): """convert the array to numpy array""" if torch.is_tensor(arr): diff --git a/fdtd/grid.py b/fdtd/grid.py index 8b115af..c49ed1e 100644 --- a/fdtd/grid.py +++ b/fdtd/grid.py @@ -494,12 +494,21 @@ def save_data(self): Parameters: None """ + def _numpyfy(item): + if isinstance(item, list): + return [_numpyfy(el) for el in item] + elif bd.is_array(item): + return bd.numpy(item) + else: + return item + if self.folder is None: raise Exception( "Save location not initialized. Please read about 'fdtd.Grid.saveSimulation()' or try running 'grid.saveSimulation()'." ) dic = {} for detector in self.detectors: - dic[detector.name + " (E)"] = [x for x in detector.detector_values()["E"]] - dic[detector.name + " (H)"] = [x for x in detector.detector_values()["H"]] + values = detector.detector_values() + dic[detector.name + " (E)"] = _numpyfy(values['E']) + dic[detector.name + " (H)"] = _numpyfy(values['H']) savez(path.join(self.folder, "detector_readings"), **dic) diff --git a/fdtd/visualization.py b/fdtd/visualization.py index beb6690..adffc42 100644 --- a/fdtd/visualization.py +++ b/fdtd/visualization.py @@ -360,14 +360,14 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16"): a[i].append(max(temp) - min(temp)) peakVal, minVal = max(map(max, a)), min(map(min, a)) - print( - "Peak at:", - [ - [[i, j] for j, y in enumerate(x) if y == peakVal] - for i, x in enumerate(a) - if peakVal in x - ], - ) + #print( + # "Peak at:", + # [ + # [[i, j] for j, y in enumerate(x) if y == peakVal] + # for i, x in enumerate(a) + # if peakVal in x + # ], + #) a = 10 * log10([[y / minVal for y in x] for x in a]) plt.title("dB map of Electrical waves in detector region")