Skip to content

Commit

Permalink
more bugfixes in export from gpu (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
flaport committed Dec 7, 2023
1 parent 91b5d93 commit 3f4df26
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 12 deletions.
2 changes: 0 additions & 2 deletions fdtd/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,6 @@ def array(self, arr, dtype=None, **kwargs):
return arr.clone().to(device="cuda", dtype=dtype, **kwargs)
return torch.tensor(arr, device="cuda", dtype=dtype, **kwargs)

# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# The same warning applies here.
def numpy(self, arr):
"""convert the array to numpy array"""
if torch.is_tensor(arr):
Expand Down
13 changes: 11 additions & 2 deletions fdtd/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,12 +494,21 @@ def save_data(self):
Parameters: None
"""
def _numpyfy(item):
if isinstance(item, list):
return [_numpyfy(el) for el in item]
elif bd.is_array(item):
return bd.numpy(item)
else:
return item

if self.folder is None:
raise Exception(
"Save location not initialized. Please read about 'fdtd.Grid.saveSimulation()' or try running 'grid.saveSimulation()'."
)
dic = {}
for detector in self.detectors:
dic[detector.name + " (E)"] = [x for x in detector.detector_values()["E"]]
dic[detector.name + " (H)"] = [x for x in detector.detector_values()["H"]]
values = detector.detector_values()
dic[detector.name + " (E)"] = _numpyfy(values['E'])
dic[detector.name + " (H)"] = _numpyfy(values['H'])
savez(path.join(self.folder, "detector_readings"), **dic)
16 changes: 8 additions & 8 deletions fdtd/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,14 +360,14 @@ def dB_map_2D(block_det=None, choose_axis=2, interpolation="spline16"):
a[i].append(max(temp) - min(temp))

peakVal, minVal = max(map(max, a)), min(map(min, a))
print(
"Peak at:",
[
[[i, j] for j, y in enumerate(x) if y == peakVal]
for i, x in enumerate(a)
if peakVal in x
],
)
#print(
# "Peak at:",
# [
# [[i, j] for j, y in enumerate(x) if y == peakVal]
# for i, x in enumerate(a)
# if peakVal in x
# ],
#)
a = 10 * log10([[y / minVal for y in x] for x in a])

plt.title("dB map of Electrical waves in detector region")
Expand Down

0 comments on commit 3f4df26

Please sign in to comment.