Adding PCA #6

Open
wants to merge 10 commits into base: main
3 changes: 1 addition & 2 deletions .env
@@ -2,8 +2,7 @@ PYTHONPATH=.

# JAX_ENABLE_X64=1
# XLA_PYTHON_CLIENT_PREALLOCATE=0
# CUDA_VISIBLE_DEVICES=0
# CUDA_VISIBLE_DEVICES=-1
# CUDA_VISIBLE_DEVICES=1 #-1
# JAX_PLATFORM_NAME=cpu
# JAX_DISABLE_JIT=1
# JAX_DEBUG_NANS=1
7 changes: 7 additions & 0 deletions .vscode/launch.json
@@ -50,6 +50,13 @@
"program": "${workspaceFolder}/examples/draft_examples.py",
"justMyCode": false
},
{
"name": "Draft presentation",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/examples/draft_presentation.py",
"justMyCode": false
},
{
"name": "Train graph",
"type": "python",
4 changes: 2 additions & 2 deletions .vscode/settings.json
@@ -1,7 +1,7 @@
{
"python.formatting.blackArgs": [
"--line-length=100",
"--target-version=py37"
// "--line-length=100",
// "--target-version=py37"
],
// Pylint
"python.linting.pylintArgs": [
2 changes: 1 addition & 1 deletion README.md
@@ -9,7 +9,7 @@

Contact mechanics models the behavior of physical bodies that come into contact with each other. It examines phenomena such as collisions, normal compliance, and friction. Most contact problems cannot be solved analytically and require a numerical procedure, such as the classical finite element method (FEM).

Conmech3d is an implementation of FEM for soft-body mechanical contact problems. The project is almost entirely self-contained and mainly aimed at research and didactic applications. Conmech3d is written in Python and uses [JAX](https://github.com/google/jax/tree/main), a library for high-performance numerical computing. Besides basic Python libraries, such as [Numpy](https://github.com/numpy/numpy) and [Scipy](https://scipy.org/), it also employs [pygmsh](https://github.com/meshpro/pygmsh) for mesh construction and [Numba](https://github.com/numba/numba) along with [Cython](https://github.com/cython/cython) to increase the speed of initial setup. Various options for visualization of simulation results are included, such as [Blender](https://github.com/blender/blender), [Three.js](https://github.com/mrdoob/three.js/) and [Matplotlib](https://github.com/matplotlib/matplotlib)
Conmech3d is an implementation of FEM for soft-body mechanical contact problems. The project is almost entirely self-contained and mainly aimed at research and didactic applications. Conmech3d is written in Python and uses [JAX](https://github.com/google/jax/tree/main), a library for high-performance numerical computing. Besides basic Python libraries, such as [Numpy](https://github.com/numpy/numpy) and [Scipy](https://scipy.org/), it also employs [pygmsh](https://github.com/meshpro/pygmsh) for mesh construction and [Numba](https://github.com/numba/numba) along with [Cython](https://github.com/cython/cython) to increase the speed of initial setup. Various options for visualization of simulation results are included, such as [Blender](https://github.com/blender/blender), [Three.js](https://github.com/mrdoob/three.js/) and [Matplotlib](https://github.com/matplotlib/matplotlib).

This repository also includes experimental implementations of model reduction techniques, such as tetrahedral skinning used in computer graphics and a new approach based on Graph Neural Networks.
<!-- PCA, Flax and Pytorch Geometric-->
6 changes: 4 additions & 2 deletions conmech/dynamics/dynamics.py
@@ -171,9 +171,11 @@ def get_rotation(self, displacement):
displacement, self.matrices.dx_big_jax
)
if not state.success:
raise Exception("Error calculating rotation")
raise ArithmeticError("Error calculating rotation")
# print(state.iteration, state.norm)
return complete_base(base_seed=np.array(final_rotation, dtype=np.float64))
return np.array(
complete_base(base_seed=np.array(final_rotation, dtype=np.float64))
)

# def iterate_self(self, acceleration, temperature=None):
# super().iterate_self(acceleration, temperature)
2 changes: 1 addition & 1 deletion conmech/helpers/cmh.py
@@ -98,11 +98,11 @@ def find_files_by_name(directory, name):


def get_base_for_comarison():
print("USING BASE FOR COMPARISON")
all_paths = glob(
"output/**/scenarios/*skinning_backwards*.scenes_comparer", recursive=True
)
assert len(all_paths) == 1
print("USING BASE FOR COMPARISON")
return all_paths[0]


6 changes: 3 additions & 3 deletions conmech/helpers/config.py
@@ -40,18 +40,18 @@ class SimulationConfig:
use_nonconvex_friction_law: bool
use_constant_contact_integral: bool
use_lhs_preconditioner: bool
use_pca: bool
with_self_collisions: bool
mesh_layer_proportion: int = None
mode: str = "normal" # "normal" "skinning" "net"
mode: str = "normal" # "normal" "skinning" "net" "pca"



@dataclass
class Config:
shell: bool = False
timestamp_skip: int = 10000
run_timestamp: float = int(time.time() * timestamp_skip)
current_time: str = datetime.now().strftime("%m.%d-%H.%M.%S")
current_time: str = datetime.now().strftime("%y.%m.%d-%H.%M.%S")
verbose: bool = True

animation_backend: str = "three" # "matplotlib blender three"
4 changes: 2 additions & 2 deletions conmech/helpers/interpolation_helpers.py
@@ -126,14 +126,14 @@ def interpolate_3d_corner_vectors(
nodes: np.ndarray, base: np.ndarray, corner_vectors: np.ndarray
):
# orthonormal matrix; inverse equals transposition
upward_nodes = lnh.get_in_base(nodes, base.T)
upward_nodes = lnh.get_in_base2(nodes, base)
scaled_nodes = scale_nodes_to_cube(upward_nodes)
upward_vectors_interpolation = interpolate_scaled_nodes_numba(
scaled_nodes=scaled_nodes,
corner_vectors=corner_vectors,
)

vectors_interpolation = lnh.get_in_base(upward_vectors_interpolation, base)
vectors_interpolation = lnh.get_in_base2(upward_vectors_interpolation, base.T)
# assert np.abs(np.mean(vectors_interpolation)) < 0.1
return vectors_interpolation
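
A minimal, self-contained sketch (not part of the PR) of the change-of-basis convention this hunk switches to: `get_in_base2(vectors, base)` computes `vectors @ base`, so nodes are mapped into an orthonormal base frame with `base` and the interpolated vectors are mapped back with `base.T`, the inverse of an orthonormal matrix.

```python
import numpy as np


def get_in_base2(vectors, base):
    # same contract as conmech.helpers.lnh.get_in_base2 introduced in this PR
    return vectors @ base


rng = np.random.default_rng(0)
base, _ = np.linalg.qr(rng.standard_normal((3, 3)))  # orthonormal 3x3 base
nodes = rng.standard_normal((10, 3))

in_frame = get_in_base2(nodes, base)    # express nodes in the base frame
back = get_in_base2(in_frame, base.T)   # orthonormal inverse = transpose

assert np.allclose(back, nodes)         # round trip reproduces the input
```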

31 changes: 16 additions & 15 deletions conmech/helpers/lnh.py
@@ -2,13 +2,10 @@
linear algebra helpers
"""

import jax.numpy as jnp
import numpy as np

from conmech.helpers import nph


def move_vector(vectors, index):
return np.roll(vectors, -index, axis=0)
from conmech.helpers import jxh, nph


def complete_base(base_seed):
@@ -19,32 +16,36 @@ def complete_base(base_seed):
return base


def __move_vector(vectors, index):
return jnp.roll(vectors, -index, axis=0)


def __orthonormalize_priority_gram_schmidt(base_seed, index):
prioritized_base_seed = move_vector(vectors=base_seed, index=index)
prioritized_base_seed = __move_vector(vectors=base_seed, index=index)
prioritized_base = __orthonormalize_gram_schmidt(prioritized_base_seed)
base = move_vector(vectors=prioritized_base, index=index)
base = __move_vector(vectors=prioritized_base, index=index)
return base


def __orthonormalize_gram_schmidt(base_seed):
normalized_base_seed = nph.normalize_euclidean_numba(base_seed)
normalized_base_seed = jxh.normalize_euclidean(base_seed)
unnormalized_base = __orthogonalize_gram_schmidt(normalized_base_seed)
base = nph.normalize_euclidean_numba(unnormalized_base)
base = jxh.normalize_euclidean(unnormalized_base)
return base


def __orthogonalize_gram_schmidt(vectors):
# Gram-Schmidt orthogonalization
b0 = vectors[0]
if len(vectors) == 1:
return np.array((b0))
return jnp.array((b0))

b1 = vectors[1] - (vectors[1] @ b0) * b0
if len(vectors) == 2:
return np.array((b0, b1))
return jnp.array((b0, b1))

b2 = np.cross(b0, b1)
return np.array((b0, b1, b2))
b2 = jnp.cross(b0, b1)
return jnp.array((b0, b1, b2))


def generate_base(dimension):
@@ -81,5 +82,5 @@ def correct_base(base):
return True


def get_in_base(vectors, base):
return vectors @ base.T
def get_in_base2(vectors, base):
return vectors @ base
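
The lnh.py hunks above port the Gram-Schmidt orthonormalization to jax.numpy. A standalone sketch of the same 3D scheme (a simplification: normalization is folded into each step instead of the module's separate `jxh.normalize_euclidean` calls):

```python
import jax.numpy as jnp


def orthonormalize_3d(base_seed):
    # keep the first seed direction, remove its component from the second,
    # and complete a right-handed orthonormal base with a cross product
    b0 = base_seed[0] / jnp.linalg.norm(base_seed[0])
    b1 = base_seed[1] - (base_seed[1] @ b0) * b0
    b1 = b1 / jnp.linalg.norm(b1)
    b2 = jnp.cross(b0, b1)  # already unit length when b0, b1 are orthonormal
    return jnp.stack((b0, b1, b2))


base = orthonormalize_3d(jnp.array([[2.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]))
assert jnp.allclose(base @ base.T, jnp.eye(3), atol=1e-6)  # rows form an orthonormal base
```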
148 changes: 93 additions & 55 deletions conmech/helpers/pca.py
@@ -4,9 +4,13 @@

import jax
import jax.numpy as jnp
import numpy as np
from tqdm import tqdm

from conmech.helpers import cmh, nph
from conmech.helpers import cmh, lnh, nph
from conmech.properties.mesh_properties import MeshProperties
from conmech.scene.scene import Scene
from conmech.simulations.simulation_runner import create_scene


def get_all_indices(data_path):
@@ -37,11 +41,14 @@ def get_scenes():
input_path = "/home/michal/Desktop/conmech3d/output"
scene_files = cmh.find_files_by_extension(input_path, "scenes") # scenes_data
path_id = "/scenarios/"
scene_files = [f for f in scene_files if path_id in f]
scene_files = [f for f in scene_files if path_id in f and "SAVED" not in f]

assert len(scene_files) == 1
# all_arrays_path = max(scene_files, key=os.path.getctime)
scenes = []
for all_arrays_path in scene_files:
if "SAVED" in all_arrays_path:
continue
all_arrays_name = os.path.basename(all_arrays_path).split("DATA")[0]
print(f"FILE: {all_arrays_name}")

@@ -57,34 +64,6 @@ def get_scenes():
return scenes


def get_projection(data, latent_dim=200):
projection_mean = 0 * data.mean(axis=0) # columnwise mean = 0
svd = jax.numpy.linalg.svd(data - projection_mean, full_matrices=False)
# (svd[0] @ jnp.diag(svd[1]) @ svd[2])
projection_matrix = svd[2][:latent_dim].T
return {"matrix": projection_matrix, "mean": projection_mean.reshape(-1, 1)}


def project_to_latent(projection, data_stack):
data_stack_zeroed = data_stack - projection["mean"]
latent = projection["matrix"].T @ data_stack_zeroed
return latent


def project_from_latent(projection, latent):
data_stack_zeroed = projection["matrix"] @ latent
data_stack = data_stack_zeroed + projection["mean"]
return data_stack


def p_to_vector(projection, vector):
return project_to_latent(projection, vector.reshape(-1, 1)).reshape(-1)


def p_from_vector(projection, vector):
return project_from_latent(projection, vector.reshape(-1, 1)).reshape(-1)


def save_pca(projection, file_path="./output/PCA"):
with open(file_path, "wb") as file:
pickle.dump(projection, file)
@@ -96,45 +75,104 @@ def load_pca(file_path="./output/PCA"):
return projection


def get_displacement_new(scene):
velocity = scene.velocity_old + scene.time_step * scene.exact_acceleration
displacement = scene.displacement_old + scene.time_step * velocity
return displacement


def get_data_scenes(scenes):
data_list = []
count = len(scenes)
for scene in scenes:
u = jnp.array(scene.get_last_displacement_step()) # scene.displacement_old)
u_stack = nph.stack_column(u)
for scene in tqdm(scenes):
# print(scene.moved_base)
u = scene.get_lifted_displacement()
u_stack = nph.stack(u)
data_list.append(u_stack)

data = jnp.array(data_list).reshape(count, -1)
data = jnp.array(data_list)
return data, u_stack, u


def get_data_dataset(dataloader):
def get_data_dataset(dataloader, scene):
data_list = []
count = 1000
for _ in tqdm(range(count)):
sample = next(iter(dataloader))
count = 3000
print(f"LIMIT TO {count}")
for i, sample in enumerate(tqdm(dataloader)): # check randomness
target = sample[0][1]

u = jnp.array(target.reduced_acceleration)
u_stack = nph.stack_column(u)
data_list.append(u_stack)
original_displacement = jnp.array(target["new_displacement"])

data = jnp.array(data_list).reshape(count, -1)
return data, u_stack, u
original_rotation = scene.get_rotation(original_displacement)
random_rotation = jnp.linalg.qr(np.random.rand(3, 3))[0]
new_rotation = original_rotation.T @ random_rotation

moved_nodes = scene.initial_nodes + original_displacement
displacement_mean = np.mean(moved_nodes, axis=0)
rotated_moved_nodes = lnh.get_in_base2(
(moved_nodes - displacement_mean), new_rotation
)
displacement = rotated_moved_nodes - scene.initial_nodes
displacement += displacement_mean # 20 * np.random.rand(3) ###

displacement_stack = nph.stack(displacement)
data_list.append(displacement_stack)
if i > count:
break

data = jnp.array(
data_list
) # Sort by displacement and get max, plot hist # np.linalg.norm(data_list[190])
return data, displacement_stack, displacement
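
The dataset branch above augments each sample by re-expressing the deformed shape in a random orthonormal base obtained from a QR factorization. A rough, self-contained sketch of that idea (the PR additionally factors out the sample's own rotation via `scene.get_rotation`, which this sketch omits; all names here are illustrative):

```python
import numpy as np


def augment_displacement(initial_nodes, displacement, rng):
    # random orthonormal base from the QR factorization of a random matrix
    random_rotation, _ = np.linalg.qr(rng.random((3, 3)))
    moved_nodes = initial_nodes + displacement
    mean = moved_nodes.mean(axis=0)
    rotated = (moved_nodes - mean) @ random_rotation  # rotate about the shape's mean
    return rotated - initial_nodes + mean             # back to a displacement field


rng = np.random.default_rng(0)
nodes = rng.standard_normal((100, 3))
displacement = 0.1 * rng.standard_normal((100, 3))
augmented = augment_displacement(nodes, displacement, rng)
```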


def get_projection(data, latent_dim):
projection_mean = 0 # data.mean(axis=0)

svd = jax.numpy.linalg.svd(data - projection_mean, full_matrices=False)
# (svd[0] @ jnp.diag(svd[1]) @ svd[2])

projection_matrix = svd[2][:latent_dim]
# projection_matrix = jax.experimental.sparse.eye(data.shape[1])

return {"matrix": projection_matrix, "mean": projection_mean}


def project_to_latent(projection, data):
data_zeroed = data - projection["mean"]
latent = projection["matrix"] @ data_zeroed
return latent


def project_from_latent(projection, latent):
data_stack_zeroed = projection["matrix"].T @ latent
data_stack = data_stack_zeroed + projection["mean"]
return data_stack


def p_to_vector(projection, data):
return project_to_latent(projection, nph.stack(data))


def p_from_vector(projection, latent):
return nph.unstack(project_from_latent(projection, latent), dim=3)
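
The projection above keeps the leading right singular vectors of the stacked displacement data (with the mean fixed at zero, as in `get_projection`); encoding multiplies by the matrix, decoding by its transpose. A self-contained round-trip sketch with random stand-in data:

```python
import jax.numpy as jnp
import numpy as np

# random stand-in for the (samples, stacked displacement) data matrix
samples = jnp.array(np.random.default_rng(0).standard_normal((500, 60)))
latent_dim = 8

_, _, vt = jnp.linalg.svd(samples, full_matrices=False)
projection_matrix = vt[:latent_dim]            # (latent_dim, 60), orthonormal rows

stacked = samples[0]                           # one stacked displacement vector
latent = projection_matrix @ stacked           # project_to_latent with mean == 0
reconstructed = projection_matrix.T @ latent   # project_from_latent

# lossy for latent_dim < 60; the error shrinks as latent_dim grows
print(float(jnp.abs(reconstructed - stacked).max()))
```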


def run(dataloader):
_ = dataloader
scenes = get_scenes()
data, sample_u_stack, sample_u = get_data_scenes(scenes)
# data, sample_u_stack, sample_u = get_data_dataset(dataloader)
def run(dataloader, latent_dim, scenario):
if dataloader is None:
scenes = get_scenes()
data, sample_u_stack, sample_u = get_data_scenes(scenes)
else:
scene = create_scene(scenario)
data, sample_u_stack, sample_u = get_data_dataset(
dataloader=dataloader, scene=scene
)

original_projection = get_projection(data)
original_projection = get_projection(data, latent_dim)
save_pca(original_projection)

projection = load_pca()
latent = project_to_latent(projection, sample_u_stack)
u_reprojected_stack = project_from_latent(projection, latent)
u_reprojected = nph.unstack(u_reprojected_stack, dim=3)
print("Error max: ", jnp.abs(u_reprojected - sample_u).max())
# projection = load_pca()
# latent = project_to_latent(projection, sample_u_stack)
# u_reprojected_stack = project_from_latent(projection, latent)
# u_reprojected = nph.unstack(u_reprojected_stack, dim=3)
# print("Error max: ", jnp.abs(u_reprojected - sample_u).max())
return 0
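
The round-trip sanity check is commented out in `run`. One possible way to keep it available is a small helper built from the functions defined in this module (`sample_u` being the displacement returned by `get_data_scenes` or `get_data_dataset`):

```python
def check_reconstruction(sample_u):
    # reload the saved projection and report the worst-case reprojection error
    projection = load_pca()
    u_reprojected = p_from_vector(projection, p_to_vector(projection, sample_u))
    print("Error max:", jnp.abs(u_reprojected - sample_u).max())
```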
2 changes: 1 addition & 1 deletion conmech/mesh/mesh_builders.py
@@ -44,7 +44,7 @@ def translate_nodes(nodes: np.ndarray, mesh_prop: MeshProperties):
if mesh_prop.mean_at_origin:
nodes -= np.mean(nodes, axis=0)
if mesh_prop.initial_base is not None:
nodes = lnh.get_in_base(nodes, mesh_prop.initial_base)
nodes = lnh.get_in_base2(nodes, mesh_prop.initial_base.T)
if mesh_prop.initial_position is not None:
nodes += mesh_prop.initial_position
return nodes
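
For reference, a standalone sketch of the node-placement pipeline in `translate_nodes` (no `MeshProperties`; argument names are illustrative): center the nodes, rotate them into the initial base using the `get_in_base2` convention, then translate.

```python
import numpy as np


def place_nodes(nodes, initial_base=None, initial_position=None, mean_at_origin=True):
    nodes = np.asarray(nodes, dtype=np.float64).copy()
    if mean_at_origin:
        nodes -= nodes.mean(axis=0)
    if initial_base is not None:
        nodes = nodes @ initial_base.T  # matches lnh.get_in_base2(nodes, initial_base.T)
    if initial_position is not None:
        nodes += initial_position
    return nodes
```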