-
Notifications
You must be signed in to change notification settings - Fork 78
/
fit_3D_landmarks.py
134 lines (106 loc) · 6.05 KB
/
fit_3D_landmarks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
'''
Max-Planck-Gesellschaft zur Foerderung der Wissenschaften e.V. (MPG) is holder of all proprietary rights on this
computer program.
You can only use this computer program if you have closed a license agreement with MPG or you get the right to use
the computer program from someone who is authorized to grant you that right.
Any use of the computer program without a valid license is prohibited and liable to prosecution.
Copyright 2019 Max-Planck-Gesellschaft zur Foerderung der Wissenschaften e.V. (MPG). acting on behalf of its
Max Planck Institute for Intelligent Systems and the Max Planck Institute for Biological Cybernetics.
All rights reserved.
More information about FLAME is available at http://flame.is.tue.mpg.de.
For comments or questions, please email us at flame@tue.mpg.de
'''
import os
import six
import numpy as np
import tensorflow as tf
from psbody.mesh import Mesh
from psbody.mesh.meshviewer import MeshViewer
from utils.landmarks import load_embedding, tf_get_model_lmks, create_lmk_spheres
from tf_smpl.batch_smpl import SMPL
from tensorflow.contrib.opt import ScipyOptimizerInterface as scipy_pt
def fit_lmk3d(target_3d_lmks, model_fname, lmk_face_idx, lmk_b_coords, weights, show_fitting=True):
    '''
    Fit FLAME to 3D landmarks.

    The fit runs in two L-BFGS-B stages: first only the global translation and rotation are optimized
    against the landmark term, then all parameters are optimized against the landmark term plus the
    weighted regularizers.

    :param target_3d_lmks:      target 3D landmarks provided as (num_lmks x 3) matrix
    :param model_fname:         saved Tensorflow FLAME model
    :param lmk_face_idx:        face indices of the landmark embedding in the FLAME topology
    :param lmk_b_coords:        barycentric coordinates of the landmark embedding in the FLAME topology
                                (i.e. weighting of the three vertices of the triangle the landmark is embedded in)
    :param weights:             weights of the individual objective functions; must provide the keys
                                'lmk', 'shape', 'expr', 'neck_pose', 'jaw_pose' and 'eyeballs_pose'
    :param show_fitting:        if True, show the target and fitted landmarks together with the fitted mesh
                                in a MeshViewer and wait for user input before returning
    :return: a mesh with the fitting results
    '''

    # Free variables of the fit, all initialized to zero.
    tf_trans = tf.Variable(np.zeros((1,3)), name="trans", dtype=tf.float64, trainable=True)
    # Named "rot" so that it does not collide with the "pose" variable below.
    tf_rot = tf.Variable(np.zeros((1,3)), name="rot", dtype=tf.float64, trainable=True)
    tf_pose = tf.Variable(np.zeros((1,12)), name="pose", dtype=tf.float64, trainable=True)
    tf_shape = tf.Variable(np.zeros((1,300)), name="shape", dtype=tf.float64, trainable=True)
    tf_exp = tf.Variable(np.zeros((1,100)), name="expression", dtype=tf.float64, trainable=True)

    # The model takes shape and expression as one concatenated vector, and likewise
    # global rotation and articulated pose as one concatenated vector.
    smpl = SMPL(model_fname)
    tf_model = tf.squeeze(smpl(tf_trans,
                               tf.concat((tf_shape, tf_exp), axis=-1),
                               tf.concat((tf_rot, tf_pose), axis=-1)))

    with tf.Session() as session:
        session.run(tf.global_variables_initializer())

        # Sum of squared landmark residuals; the residual is scaled by 1000 before squaring
        # (presumably metres to millimetres -- TODO confirm).
        lmks = tf_get_model_lmks(tf_model, smpl.f, lmk_face_idx, lmk_b_coords)
        lmk_dist = tf.reduce_sum(tf.square(1000 * tf.subtract(lmks, target_3d_lmks)))

        # Squared-norm regularizers; tf_pose is split into neck [0:3], jaw [3:6] and eyeballs [6:].
        neck_pose_reg = tf.reduce_sum(tf.square(tf_pose[:,:3]))
        jaw_pose_reg = tf.reduce_sum(tf.square(tf_pose[:,3:6]))
        eyeballs_pose_reg = tf.reduce_sum(tf.square(tf_pose[:,6:]))
        shape_reg = tf.reduce_sum(tf.square(tf_shape))
        exp_reg = tf.reduce_sum(tf.square(tf_exp))

        # Optimize global transformation first
        rigid_vars = [tf_trans, tf_rot]
        rigid_loss = weights['lmk'] * lmk_dist
        optimizer = scipy_pt(loss=rigid_loss, var_list=rigid_vars, method='L-BFGS-B',
                             options={'disp': 1, 'ftol': 5e-6})
        print('Optimize rigid transformation')
        optimizer.minimize(session)

        # Optimize for the model parameters
        model_vars = [tf_trans, tf_rot, tf_pose, tf_shape, tf_exp]
        model_loss = (weights['lmk'] * lmk_dist
                      + weights['shape'] * shape_reg
                      + weights['expr'] * exp_reg
                      + weights['neck_pose'] * neck_pose_reg
                      + weights['jaw_pose'] * jaw_pose_reg
                      + weights['eyeballs_pose'] * eyeballs_pose_reg)
        optimizer = scipy_pt(loss=model_loss, var_list=model_vars, method='L-BFGS-B',
                             options={'disp': 1, 'ftol': 5e-6})
        print('Optimize model parameters')
        optimizer.minimize(session)

        print('Fitting done')

        # Evaluate the fitted vertices once; they are used for both the viewer and the result.
        fitted_vertices = session.run(tf_model)

        if show_fitting:
            # Visualize landmark fitting
            mv = MeshViewer()
            mv.set_static_meshes(create_lmk_spheres(target_3d_lmks, 0.001, [255.0, 0.0, 0.0]))
            mv.set_dynamic_meshes([Mesh(fitted_vertices, smpl.f)]
                                  + create_lmk_spheres(session.run(lmks), 0.001, [0.0, 0.0, 255.0]),
                                  blocking=True)
            six.moves.input('Press key to continue')

        return Mesh(fitted_vertices, smpl.f)
def run_3d_lmk_fitting():
    '''
    Fit FLAME to the example 3D landmark file and write the fitted mesh as a PLY file.
    '''
    # Tensorflow FLAME model used for the fit
    # (alternatives: './models/female_model.pkl', './models/male_model.pkl')
    model_fname = './models/generic_model.pkl'

    # Landmark embedding file into the FLAME surface
    flame_lmk_path = './data/flame_static_embedding.pkl'

    # 3D landmarks to fit; they must correspond to the defined FLAME landmarks.
    # See "img1_lmks_visualized.jpeg" or "img2_lmks_visualized.jpeg" for the landmark order.
    target_lmk_path = './data/landmark_3d.npy'

    # Where the fitted mesh is written
    out_mesh_fname = './results/landmark_3d.ply'

    lmk_face_idx, lmk_b_coords = load_embedding(flame_lmk_path)
    lmk_3d = np.load(target_lmk_path)

    weights = {
        # Landmark distance term
        'lmk': 1.0,
        # Shape regularizer
        'shape': 1.0,
        # Expression regularizer
        'expr': 1.0,
        # Neck pose (i.e. rotation around the neck) regularizer
        'neck_pose': 1000.0,
        # Jaw pose (i.e. jaw rotation for opening the mouth) regularizer
        'jaw_pose': 1.0,
        # Eyeball pose (i.e. eyeball rotations) regularizer
        'eyeballs_pose': 10.0,
    }

    # Show landmark fitting (default: red = target landmarks, blue = fitting landmarks)
    show_fitting = True

    result_mesh = fit_lmk3d(lmk_3d, model_fname, lmk_face_idx, lmk_b_coords, weights,
                            show_fitting=show_fitting)

    # Create the output directory on first use
    out_dir = os.path.dirname(out_mesh_fname)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    result_mesh.write_ply(out_mesh_fname)
# Run the example fit only when this file is executed as a script.
if __name__ == "__main__":
    run_3d_lmk_fitting()