Skip to content

Commit

Permalink
added an amination functionality to the workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
ronshnapp committed Oct 23, 2024
1 parent 73ded3a commit 7847078
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 1 deletion.
8 changes: 8 additions & 0 deletions example/params_file.yml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,14 @@
t0: 0
te: -1

- animate_trajectories:
file_name: smoothed_trajectories
min_length: 2
f_start: None
f_end: None
fps: 25
tail_length: 3

- run_extension:
path_to_extention: the_absolute_path_to_the_script_containing_the_code
action_name: the_name_of_the_class_that_needs_to_run
Expand Down
29 changes: 29 additions & 0 deletions example/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(self, param_file, action):
'manual_matching',
'fiber_orientations',
'plot_trajectories',
'animate_trajectories',
'run_extention']


Expand Down Expand Up @@ -108,6 +109,9 @@ def __init__(self, param_file, action):

elif action == 'plot_trajectories':
self.do_plot_trajectories()

elif action == 'animate_trajectories':
self.do_animate_trajectories()

elif action == 'run_extention':
self.do_run_extention()
Expand Down Expand Up @@ -1425,6 +1429,31 @@ def do_plot_trajectories(self):



def do_animate_trajectories(self):
'''
This function is used to generate a 3D animation of the trajectories
in a given file.
'''
from myptv.makePlots.plot_trajectories import animate_trajectories

# fetching the parameters
fname = self.get_param('animate_trajectories', 'file_name')
min_length = self.get_param('animate_trajectories', 'min_length')
f0 = self.get_param('animate_trajectories', 'f_start')
fe = self.get_param('animate_trajectories', 'f_end')
fps = self.get_param('animate_trajectories', 'fps')
tail_length = self.get_param('animate_trajectories', 'tail_length')

at = animate_trajectories(fname, min_length, fps=fps,
tail_length=tail_length,
f0=f0, fe=fe)
at.animate()

print('')
print('animation saved. Done!')





def do_run_extention(self):
Expand Down
124 changes: 123 additions & 1 deletion myptv/makePlots/plot_trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@
from numpy import ptp, array, arange, amin, amax
import matplotlib.pyplot as plt

from moviepy.video.io.bindings import mplfig_to_npimage
import moviepy.editor as mpy






def plot_trajectories(fname, min_length, write_trajID=False, t0=0, te=-1):
Expand Down Expand Up @@ -111,6 +117,123 @@ def plot_trajectories(fname, min_length, write_trajID=False, t0=0, te=-1):





class animate_trajectories(object):

def __init__(self, fname, min_length, f0=None, fe=None, fps=25,
tail_length=4):



data = read_csv(fname, header=None, sep='\t')

self.trajectories = dict([(g, array(k.values))
for g,k in data.groupby(0) if g!=-1])

self.longs = [k for k in self.trajectories.keys()
if len(self.trajectories[k])>=min_length]

x_lst, y_lst, z_lst = [], [], []
for i in self.longs:
x_lst += list(self.trajectories[i][:,1])
y_lst += list(self.trajectories[i][:,2])
z_lst += list(self.trajectories[i][:,3])

self.xmax = amax(x_lst) ; self.xmin = amin(x_lst)
self.ymax = amax(y_lst) ; self.ymin = amin(y_lst)
self.zmax = amax(z_lst) ; self.zmin = amin(z_lst)

if f0 is None:
f0 = int(min(data[data.columns[-1]]))

if fe is None:
fe = int(max(data[data.columns[-1]]))

self.fps = fps
self.counter = 0
self.frames = list(range(f0, fe+1))
self.duration = (len(self.frames)-1)/self.fps
self.tl = tail_length
self.min_length = min_length




def update(self, frame):
frame = self.frames[self.counter]
cmap = plt.get_cmap('viridis')
self.ax.clear()

for k in self.longs:
tr = self.trajectories[k]
whr = arange(len(tr))[tr[:,-1]==frame]
if any(whr):
ind = whr[0]
x = tr[ind-self.tl:ind+1,1]
y = tr[ind-self.tl:ind+1,2]
z = tr[ind-self.tl:ind+1,3]
if len(x)==0: continue
dx = ((x[-1]-x[0])**2 + (y[-1]-y[0])**2 + (z[-1]-z[0])**2)**0.5
c = min([(dx/self.tl) / self.vscale,1])
self.ax.plot(x, z, y, '-', color = cmap(c*0.9)) #color=(0.1+c*0.9,0,0.8*(1-c)))

self.ax.set_xlim(self.xmin, self.xmax)
self.ax.set_zlim(self.ymin, self.ymax)
self.ax.set_ylim(self.zmin, self.zmax)

self.ax.set_xlabel('x')
self.ax.set_ylabel('z')
self.ax.set_zlabel('y')


self.ax.w_xaxis.set_pane_color((0.6,0.6,1,0.1))
self.ax.w_yaxis.set_pane_color((0.6,0.6,1,0.15))
self.ax.w_zaxis.set_pane_color((0.7,0.6,1,0.2))

self.ax.grid(False)

self.ax.set_box_aspect((self.xmax-self.xmin,
self.zmax-self.zmin,
self.ymax-self.ymin))

plt.tight_layout(0.5)

self.counter += 1

return mplfig_to_npimage(self.fig) # RGB image of the figure



def animate(self):
'''
will animate the particle's location, and save the animation
'''
#self.prepare_for_animation()

self.vscale = 0
for i in self.longs:
tr = self.trajectories[i]
dt = int(self.min_length/2)
dx = sum([(tr[dt,j] - tr[0,j])**2 for j in [1,2,3]])**0.5
if self.vscale<dx/dt:
self.vscale = dx/dt

self.fig = plt.figure(figsize=(9,9))
self.ax = self.fig.add_subplot(projection='3d')

animation = mpy.VideoClip(self.update, duration=self.duration)
animation.write_videofile('animation.mp4',fps = self.fps)
return animation









def getSamplesFromLongTrajectories(fname, min_len):
'''
Reads a trajectory file and returns an array with its samples that
Expand All @@ -134,7 +257,6 @@ def getSamplesFromLongTrajectories(fname, min_len):





def PlotParticlePositionHistogram(fname):
'''
Expand Down

0 comments on commit 7847078

Please sign in to comment.