-
Notifications
You must be signed in to change notification settings - Fork 54
/
mnist_ae.py
85 lines (67 loc) · 1.87 KB
/
mnist_ae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_openml
from sklearn.metrics import mean_squared_error
from lassonet import LassoNetAutoEncoder
# Load MNIST from OpenML and keep only the images of the digit 3.
X, y = fetch_openml(name="mnist_784", return_X_y=True)
# Boolean row mask. Labels are compared against the string "3".
# Named `is_three` rather than `filter` so the builtin filter() is not shadowed.
is_three = y == "3"
# assumes X is a pandas DataFrame (fetch_openml's default frame output), so
# `.values` yields a NumPy array; pixel intensities presumably span 0-255,
# so dividing by 255 rescales them to [0, 1] — confirm for other datasets.
X = X[is_three].values / 255
# Train a LassoNet autoencoder and sweep its regularization path on X.
# NOTE(review): per the lassonet docs, M is the hierarchy coefficient,
# n_iters is (epochs for the initial dense fit, epochs per path step) and
# path_multiplier is the factor lambda grows by between steps — confirm.
model = LassoNetAutoEncoder(
    verbose=True,
    path_multiplier=1.05,
    n_iters=(3000, 500),
    M=30,
)
# Sequence of per-lambda checkpoints (each exposes state_dict, selected, lambda_).
path = model.path(X)
# Render the model's per-pixel feature importances as a 28x28 heat map.
plt.title("Feature importance to reconstruct 3")
plt.imshow(model.feature_importances_.reshape(28, 28))
plt.colorbar()
plt.savefig("mnist-ae-importance.png")
# Replay every checkpoint on the regularization path and record, per step,
# how many input features it keeps, its reconstruction MSE on X, and its lambda.
n_selected = []
score = []
lambda_ = []
for save in path:
    # Side effect: after the loop `model` holds the LAST checkpoint's weights.
    model.load(save.state_dict)
    X_pred = model.predict(X)
    # `selected.sum()` may be a 0-d tensor/array scalar; store a plain int so
    # later comparisons, titles and filenames (f"mnist-ae-{i}.png") are clean.
    n_selected.append(int(save.selected.sum()))
    # sklearn's signature is mean_squared_error(y_true, y_pred). MSE is
    # symmetric, so this only corrects the argument convention.
    score.append(mean_squared_error(X, X_pred))
    lambda_.append(save.lambda_)
# Snapshot the skip (linear) layer at a few sparsity levels. Thresholds are
# consumed from the end (largest first) because n_selected is expected to
# shrink as lambda grows along the path.
to_plot = [160, 220, 300]
for i, save in zip(n_selected, path):
    if not to_plot:
        break
    # May be a 0-d tensor; an int keeps "tensor(...)" out of the title/filename.
    i = int(i)
    if i > to_plot[-1]:
        continue
    # Consume every threshold this checkpoint has reached, so one big jump in
    # sparsity does not leave stale thresholds that fire on later checkpoints.
    while to_plot and i <= to_plot[-1]:
        to_plot.pop()
    plt.clf()
    plt.title(f"Linear model with {i} features")
    weight = save.state_dict["skip.weight"]
    # NOTE(review): each weight row has 784 (= 28*28) input entries. In the
    # autoencoder the rows presumably index reconstructed pixels, so the
    # former `weight[1] - weight[0]` (carried over from the classifier
    # example) compared the weights for output pixels 0 and 1. The
    # column-wise L2 norm instead shows how strongly each INPUT pixel is
    # used by the skip layer — confirm this matches the intended figure.
    img = ((weight**2).sum(0) ** 0.5).reshape(28, 28)
    plt.imshow(img)
    plt.colorbar()
    plt.savefig(f"mnist-ae-{i}.png")
# Summarise the regularization path in three stacked panels:
# MSE vs. feature count, MSE vs. lambda, and feature count vs. lambda.
plt.clf()
fig = plt.figure(figsize=(12, 12))
panels = [
    # (subplot code, x values, y values, x label, y label, log-scale x?)
    (311, n_selected, score, "number of selected features", "MSE", False),
    (312, lambda_, score, "lambda", "MSE", True),
    (313, lambda_, n_selected, "lambda", "number of selected features", True),
]
for position, xs, ys, x_label, y_label, log_x in panels:
    plt.subplot(position)
    plt.grid(True)
    plt.plot(xs, ys, ".-")
    plt.xlabel(x_label)
    if log_x:
        plt.xscale("log")
    plt.ylabel(y_label)
plt.savefig("mnist-ae-training.png")
# Show two sample digits (left) next to their reconstructions (right).
# `model` still holds the last checkpoint loaded while replaying the path
# (presumably the sparsest one, since lambda grows along the path).
# A fresh figure keeps these panels from being drawn over the training-curve
# axes, and the result is saved so it is not silently discarded.
plt.figure()
for row, idx in enumerate((150, 250)):
    # Slice rather than index so predict() receives a 2-D (1, 784) batch.
    sample = X[idx : idx + 1]
    plt.subplot(2, 2, 2 * row + 1)
    plt.imshow(sample.reshape(28, 28))
    plt.subplot(2, 2, 2 * row + 2)
    plt.imshow(model.predict(sample).reshape(28, 28))
plt.savefig("mnist-ae-reconstruction.png")