Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Circle CI committed on Jul 18, 2024
1 parent 11b8822, commit b0eca50
Showing 47 changed files with 1,093 additions and 97 deletions.
Binary file modified (not shown): BIN -84 Bytes (99%) dev/_downloads/07fcc19ba03226cd3d83d4e40ec44385/auto_examples_python.zip
115 changes: 115 additions & 0 deletions
dev/_downloads/286dcc8a82b9a5553a5d809cc0f6fa61/demo_ne_methods_affinity_matcher.ipynb
@@ -0,0 +1,115 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Comparison of different DR methods and the use of affinity matcher\n\nWe illustrate the basic usage of TorchDR with different Neighbor Embedding methods\non the swiss roll dataset.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import torch\nimport matplotlib.pyplot as plt\nfrom sklearn.datasets import make_swiss_roll\n\nfrom torchdr import (\n    AffinityMatcher,\n    SNE,\n    UMAP,\n    TSNE,\n    EntropicAffinity,\n    NormalizedGaussianAffinity,\n)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load toy images\n\nFirst, let's load 5 classes of the digits dataset from sklearn.\n\n" | ||
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"torch.manual_seed(0)\nn_samples = 500\nX, t = make_swiss_roll(n_samples=n_samples, noise=0.1, random_state=0)\n\ninit_embedding = torch.normal(0, 1, size=(n_samples, 2), dtype=torch.double)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Compute the different embedding\n\nTune the different hyperparameters for better results.\n\n" | ||
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"perplexity = 30\nlr = 1e-1\noptim_params = {\n    \"init\": init_embedding,\n    \"early_exaggeration_iter\": 0,\n    \"optimizer\": \"Adam\",\n    \"optimizer_kwargs\": None,\n    \"early_exaggeration\": 1.0,\n    \"max_iter\": 100,\n}\n\nsne = SNE(n_components=2, perplexity=perplexity, lr=lr, **optim_params)\n\numap = UMAP(n_neighbors=perplexity, n_components=2, lr=lr, **optim_params)\n\ntsne = TSNE(n_components=2, perplexity=perplexity, lr=lr, **optim_params)\n\nall_methods = {\n    \"TSNE\": tsne,\n    \"SNE\": sne,\n    \"UMAP\": umap,\n}\n\nfor method_name, method in all_methods.items():\n    print(\"--- Computing {} ---\".format(method_name))\n    method.fit(X)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plot the different embeddings\n\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"fig = plt.figure(figsize=(15, 4))\nfs = 24\nax = fig.add_subplot(1, 4, 1, projection=\"3d\")\nax.scatter(X[:, 0], X[:, 1], X[:, 2], c=t, s=20)\nax.set_title(\"Swiss Roll in ambient space\", font=\"Times New Roman\", fontsize=fs)\nax.view_init(azim=-66, elev=12)\n\nfor i, (method_name, method) in enumerate(all_methods.items()):\n    ax = fig.add_subplot(1, 4, i + 2)\n    emb = method.embedding_.detach().numpy()  # get the embedding\n    ax.scatter(emb[:, 0], emb[:, 1], c=t, s=20)\n    ax.set_title(\"{0}\".format(method_name), font=\"Times New Roman\", fontsize=fs)\n    ax.set_xticks([])\n    ax.set_yticks([])\nplt.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Using AffinityMatcher\n\nWe can reproduce the same kind of results using the\nflexible class AffinityMatcher\n:class:`torchdr.AffinityMatcher`. It take as input\ntwo affinities and minimize a certain matching loss\nbetween them. To reproduce the SNE algorithm\nwe can match with the cross entropy loss\nan EntropicAffinity\n:class:`torchdr.EntropicAffinity` with given\nperplexity and a NormalizedGaussianAffinity\n:class:`torchdr.NormalizedGaussianAffinity`.\n\n" | ||
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"sne_affinity_matcher = AffinityMatcher(\n    n_components=2,\n    # SNE matches an EntropicAffinity\n    affinity_in=EntropicAffinity(perplexity=perplexity),\n    # with a Gaussian kernel normalized by row\n    affinity_out=NormalizedGaussianAffinity(normalization_dim=1),\n    loss_fn=\"cross_entropy_loss\",  # and the cross_entropy loss\n    init=init_embedding,\n    max_iter=200,\n    lr=lr,\n)\nsne_affinity_matcher.fit(X)\n\nfig = plt.figure(figsize=(10, 4))\nfs = 24\ntwo_sne_dict = {\"SNE\": sne, \"SNE (with affinity matcher)\": sne_affinity_matcher}\nfor i, (method_name, method) in enumerate(two_sne_dict.items()):\n    ax = fig.add_subplot(1, 2, i + 1)\n    emb = method.embedding_.detach().numpy()  # get the embedding\n    ax.scatter(emb[:, 0], emb[:, 1], c=t, s=20)\n    ax.set_title(\"{0}\".format(method_name), font=\"Times New Roman\", fontsize=fs)\n    ax.set_xticks([])\n    ax.set_yticks([])\nplt.tight_layout()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
122 changes: 122 additions & 0 deletions
dev/_downloads/3e483e2c5442faf121372695ae6355bc/demo_ne_methods_affinity_matcher.py
@@ -0,0 +1,122 @@
r""" | ||
Comparison of different DR methods and the use of affinity matcher | ||
================================================================== | ||
We illustrate the basic usage of TorchDR with different Neighbor Embedding methods | ||
on the swiss roll dataset. | ||
""" | ||
|
||
# %% | ||
import torch | ||
import matplotlib.pyplot as plt | ||
from sklearn.datasets import make_swiss_roll | ||
|
||
from torchdr import ( | ||
AffinityMatcher, | ||
SNE, | ||
UMAP, | ||
TSNE, | ||
EntropicAffinity, | ||
NormalizedGaussianAffinity, | ||
) | ||
|
||
# %%
# Load the swiss roll dataset
# ---------------------------
#
# First, let's generate a toy swiss roll dataset using sklearn.
torch.manual_seed(0)
n_samples = 500
X, t = make_swiss_roll(n_samples=n_samples, noise=0.1, random_state=0)

init_embedding = torch.normal(0, 1, size=(n_samples, 2), dtype=torch.double)
# %%
# Compute the different embeddings
# --------------------------------
#
# The hyperparameters can be tuned for better results.
perplexity = 30
lr = 1e-1
optim_params = {
    "init": init_embedding,
    "early_exaggeration_iter": 0,
    "optimizer": "Adam",
    "optimizer_kwargs": None,
    "early_exaggeration": 1.0,
    "max_iter": 100,
}

sne = SNE(n_components=2, perplexity=perplexity, lr=lr, **optim_params)

umap = UMAP(n_neighbors=perplexity, n_components=2, lr=lr, **optim_params)

tsne = TSNE(n_components=2, perplexity=perplexity, lr=lr, **optim_params)

all_methods = {
    "TSNE": tsne,
    "SNE": sne,
    "UMAP": umap,
}

for method_name, method in all_methods.items():
    print("--- Computing {} ---".format(method_name))
    method.fit(X)

# %%
# Plot the different embeddings
# -----------------------------
fig = plt.figure(figsize=(15, 4))
fs = 24
ax = fig.add_subplot(1, 4, 1, projection="3d")
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=t, s=20)
ax.set_title("Swiss Roll in ambient space", font="Times New Roman", fontsize=fs)
ax.view_init(azim=-66, elev=12)

for i, (method_name, method) in enumerate(all_methods.items()):
    ax = fig.add_subplot(1, 4, i + 2)
    emb = method.embedding_.detach().numpy()  # get the embedding
    ax.scatter(emb[:, 0], emb[:, 1], c=t, s=20)
    ax.set_title("{0}".format(method_name), font="Times New Roman", fontsize=fs)
    ax.set_xticks([])
    ax.set_yticks([])
plt.tight_layout()
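
# %%
# Quantitative sanity check
# -------------------------
#
# A minimal sketch, not part of the committed example: the swiss roll is a
# one-dimensional manifold parametrized by t, so a rough quality check is the
# Spearman correlation between t and the embedding coordinates. This assumes
# scipy is installed alongside sklearn.
from scipy.stats import spearmanr

for method_name, method in all_methods.items():
    emb = method.embedding_.detach().numpy()
    # keep the best-correlated coordinate as a crude "unrolling" score
    rho = max(abs(spearmanr(t, emb[:, d])[0]) for d in range(2))
    print("{}: |Spearman rho| = {:.2f}".format(method_name, rho))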
# %%
# Using AffinityMatcher
# ---------------------
#
# We can reproduce the same kind of results using the flexible class
# :class:`torchdr.AffinityMatcher`. It takes as input two affinities and
# minimizes a matching loss between them. To reproduce the SNE algorithm,
# we can use the cross-entropy loss to match an
# :class:`torchdr.EntropicAffinity` with a given perplexity and a
# :class:`torchdr.NormalizedGaussianAffinity`.

sne_affinity_matcher = AffinityMatcher(
    n_components=2,
    # SNE matches an EntropicAffinity
    affinity_in=EntropicAffinity(perplexity=perplexity),
    # with a Gaussian kernel normalized by row
    affinity_out=NormalizedGaussianAffinity(normalization_dim=1),
    loss_fn="cross_entropy_loss",  # and the cross_entropy loss
    init=init_embedding,
    max_iter=200,
    lr=lr,
)
sne_affinity_matcher.fit(X)

fig = plt.figure(figsize=(10, 4))
fs = 24
two_sne_dict = {"SNE": sne, "SNE (with affinity matcher)": sne_affinity_matcher}
for i, (method_name, method) in enumerate(two_sne_dict.items()):
    ax = fig.add_subplot(1, 2, i + 1)
    emb = method.embedding_.detach().numpy()  # get the embedding
    ax.scatter(emb[:, 0], emb[:, 1], c=t, s=20)
    ax.set_title("{0}".format(method_name), font="Times New Roman", fontsize=fs)
    ax.set_xticks([])
    ax.set_yticks([])
plt.tight_layout()
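
# %%
# A possible follow-up, not part of the committed example: the two SNE variants
# should agree up to rotation and scaling, which Procrustes analysis can
# quantify (lower disparity means closer embeddings). Assumes scipy is
# installed.
from scipy.spatial import procrustes

emb_sne = sne.embedding_.detach().numpy()
emb_am = sne_affinity_matcher.embedding_.detach().numpy()
_, _, disparity = procrustes(emb_sne, emb_am)
print("Procrustes disparity between the two SNE embeddings: {:.3f}".format(disparity))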
Binary file modified (not shown): BIN -45 Bytes (100%) dev/_downloads/6f1e7a639e0699d6164445b55e6c116d/auto_examples_jupyter.zip