KAD/zad3/som.py

161 lines
5.7 KiB
Python
Raw Permalink Normal View History

2022-02-09 22:58:16 +01:00
import matplotlib.pyplot as plt
import utils as u
from matplotlib.animation import FuncAnimation
from random import shuffle
import numpy as np
2022-02-11 16:42:38 +01:00
def find_bmu(som, exhausted, x):
    '''Return the (g, h) grid index of the best matching unit (BMU) for ``x``.

    Each cell's squared Euclidean distance to ``x`` (summed over axis 2 of
    ``som``) is multiplied by that cell's ``exhausted`` weight before the
    minimum is taken, so cells with a larger weight are less likely to win.
    '''
    squared_diff = np.square(som - x)
    weighted_dist = exhausted * squared_diff.sum(axis=2)
    flat_winner = np.argmin(weighted_dist, axis=None)
    # Convert the flat argmin position back into grid coordinates.
    return np.unravel_index(flat_winner, weighted_dist.shape)
2022-02-11 16:42:38 +01:00
def dist_comp(som, exhausted, x):
    '''Rank every SOM cell by its weighted distance to sample ``x``.

    The per-cell value is ``exhausted[i][j] * u.calc_length(x, som[i][j])``.
    Returns a list of ``[(i, j), value]`` entries sorted ascending by value;
    the sort is stable, so ties keep row-major grid order.
    '''
    n_rows, n_cols = som.shape[0], som.shape[1]
    ranking = [
        [(row, col), exhausted[row][col] * u.calc_length(x, som[row][col])]
        for row in range(n_rows)
        for col in range(n_cols)
    ]
    ranking.sort(key=lambda entry: entry[1])
    return ranking
2022-02-11 16:42:38 +01:00
def update_weights(som, exhausted, train_ex, learn_rate, radius_sq,
                   bmu_coord, algorithm):
    '''Update the weights of the SOM cells for a single training example.

    Args:
        som: grid of neuron weights indexed as ``som[i, j, :]``; it is
            modified in place.
        exhausted: per-cell weights, forwarded to ``dist_comp`` when ranking
            cells for the 'neuron gas' algorithm.
        train_ex: the training sample.
        learn_rate: step size of the update.
        radius_sq: squared neighbourhood radius; below 1e-3 only the BMU is
            moved.
        bmu_coord: ``(g, h)`` grid index of the best matching unit.
        algorithm: ``'kohonen'`` or ``'neuron gas'``.

    Returns:
        The same ``som`` array, updated in place.

    Raises:
        NotImplementedError: if ``algorithm`` is not one of the above. This is
            checked up front, so it is raised even when the radius is tiny.
    '''
    if algorithm not in ('kohonen', 'neuron gas'):
        raise NotImplementedError(
            f'algorithm {algorithm} is not implemented yet')
    g, h = bmu_coord
    # If the radius is close to zero, only the BMU is changed.
    if radius_sq < 1e-3:
        som[g, h, :] += learn_rate * (train_ex - som[g, h, :])
        return som
    if algorithm == 'kohonen':
        # Gaussian neighbourhood over *grid* distance to the BMU, applied to
        # every cell at once (same arithmetic as the per-cell loop).
        rows, cols = np.indices(som.shape[:2])
        grid_dist_sq = np.square(rows - g) + np.square(cols - h)
        neighbourhood = np.exp(-grid_dist_sq / 2 / radius_sq)
        som += learn_rate * neighbourhood[:, :, np.newaxis] * (train_ex - som)
    else:  # 'neuron gas'
        # Cells are ranked by their (exhausted-weighted) distance to the
        # sample; the update strength decays exponentially with the rank.
        ranking = dist_comp(som, exhausted, train_ex)
        for rank, ((i, j), _) in enumerate(ranking):
            dist_func = np.exp(-rank / 2 / np.sqrt(radius_sq))
            som[i, j, :] += learn_rate * dist_func * (train_ex - som[i, j, :])
    return som
def train_som(som, train_data, learn_rate=.1, radius_sq=1,
              lr_decay=.1, radius_decay=.1, epochs=20, algorithm='kohonen'):
    '''Main routine for training an SOM.

    It requires an initialized SOM grid or a partially trained grid as a
    parameter; ``som`` is updated in place by ``update_weights``.

    Args:
        som: neuron weight grid.
        train_data: indexable sequence of samples (list or numpy array). It
            is no longer shuffled in place; a private index order is shuffled
            instead.
        learn_rate, radius_sq: initial learning rate and squared radius.
        lr_decay, radius_decay: exponential decay factors applied per epoch.
        epochs: maximum number of epochs.
        algorithm: passed through to ``update_weights``.

    Returns:
        A list of ``(som_snapshot, error)`` pairs. The first pair is the
        initial state; one pair follows per completed epoch. Training stops
        early once the error drops below 1e-3.
    '''
    grid_shape = (som.shape[0], som.shape[1])
    exhausted = np.ones(grid_shape)
    learn_rate_0 = learn_rate
    radius_0 = radius_sq
    soms_with_error = [
        (som.copy(), calc_som_error(som, exhausted, train_data))]
    # Shuffle a persistent index permutation rather than train_data itself.
    # This leaves the caller's data untouched. It also avoids random.shuffle
    # on 2-D numpy arrays, where swapping row views duplicates samples. The
    # visit order matches shuffling a list of the same length.
    order = list(range(len(train_data)))
    for epoch in np.arange(epochs):
        shuffle(order)
        for idx in order:
            train_ex = train_data[idx]
            g, h = find_bmu(som, exhausted, train_ex)
            som = update_weights(som, exhausted, train_ex,
                                 learn_rate, radius_sq, (g, h), algorithm)
            # Each win raises this cell's weight in find_bmu's distance, so
            # frequent winners become less likely to win again this epoch.
            exhausted[g][h] += 1
        # Update learning rate and radius. NOTE(review): exp(-0 * decay) == 1,
        # so epochs 0 and 1 both use the initial values.
        learn_rate = learn_rate_0 * np.exp(-epoch * lr_decay)
        radius_sq = radius_0 * np.exp(-epoch * radius_decay)
        # Reset the win counters, so the error below uses plain distances.
        exhausted = np.ones(grid_shape)
        error = calc_som_error(som, exhausted, train_data)
        soms_with_error.append((som.copy(), error))
        if error < 1e-3:
            break
    return soms_with_error
2022-02-11 16:42:38 +01:00
def calc_som_error(som, exhausted, train_data):
    '''Return the mean SOM error over ``train_data``.

    For each sample, the BMU is located with ``find_bmu`` (using the given
    ``exhausted`` weights), and ``u.calc_length`` is evaluated between the
    sample and that BMU's weights. The square roots of those values are then
    averaged.
    '''
    def length_to_bmu(sample):
        # Distance measure from the sample to its best matching unit.
        row, col = find_bmu(som, exhausted, sample)
        return u.calc_length(sample, som[row][col])

    lengths = np.asarray([length_to_bmu(sample) for sample in train_data])
    return np.mean(np.sqrt(lengths))
def plot_with_data(soms, data, name_suffix='_'):
    '''Animate how the SOM neurons move over the epochs, drawn over the data.

    The training points are drawn once as a scatter plot. The neurons of each
    snapshot are drawn as a black line with 'X' markers, one animation frame
    per entry of ``soms``. The animation is saved as
    ``animationSOMs{name_suffix}.gif`` and then shown.

    Args:
        soms: sequence of snapshots; ``soms[i][0]`` holds the neuron points
            (e.g. the ``(som, error)`` pairs returned by ``train_som``).
            NOTE(review): ``zip(*soms[i][0])`` must yield exactly two
            sequences, so an ``(N, 2)`` array of points is expected. A
            ``(1, k, 2)`` grid would need reshaping first; confirm with the
            caller.
        data: iterable of (x, y) training points.
        name_suffix: appended to the output GIF file name.
    '''
    def as_xy_lists(points):
        # Split a sequence of (x, y) pairs into separate x and y lists.
        xs, ys = zip(*points)
        return list(xs), list(ys)

    fig, ax = plt.subplots()
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    time_text = ax.text(0.05, 0.95, 'epoch=0', horizontalalignment='left',
                        verticalalignment='top', transform=ax.transAxes)
    data_x, data_y = as_xy_lists(data)
    ax.scatter(data_x, data_y)
    neuron_x, neuron_y = as_xy_lists(soms[0][0])
    som_plot, = ax.plot(neuron_x, neuron_y, color='black', marker='X')
    plt.grid(True)

    def update_plot_som(frame):
        snapshot = soms[frame]
        time_text.set_text(f'epoch={frame}')
        som_plot.set_data(*as_xy_lists(snapshot[0]))
        return [time_text, som_plot]

    anim = FuncAnimation(fig, update_plot_som, frames=len(soms), blit=True)
    anim.save(f'animationSOMs{name_suffix}.gif')
    plt.show()
def init_neurons(data, k, rand: np.random.RandomState = None, method='random'):
match method:
case 'zeros':
return np.zeros((1, k, 2))
case 'random':
lst_x, lst_y = zip(*data)
minimal = min(min(lst_x), min(lst_y))
maximal = max(max(lst_x), max(lst_y))
return (maximal - minimal) * rand.random_sample((1, k, 2)) + minimal
case _:
raise NotImplementedError(
f'method {method} is not implemented yet')
def print_som_stats(soms_with_errors, train_data):
    '''Print summary statistics for a sequence of SOM snapshots.

    Args:
        soms_with_errors: sequence of ``(som, error)`` pairs, e.g. the list
            returned by ``train_som``.
        train_data: samples used to decide which neurons ever win.

    Prints the mean, standard deviation and minimum of the errors. It also
    prints the mean and standard deviation of the number of inactive ("dead")
    neurons across the snapshots, i.e. neurons that are the BMU for no
    training sample. The neuron count is taken from each grid's shape rather
    than being hard-coded to 20.
    '''
    print('=' * 20)
    first_som = soms_with_errors[0][0]
    # Unit weights: BMUs are found by plain (unpenalized) distance.
    exhausted = np.ones((first_som.shape[0], first_som.shape[1]))
    soms, errs = zip(*soms_with_errors)
    m = np.mean(errs)
    std = np.std(errs)
    min_err = np.min(errs)
    dead_neurons_count = []
    for som in soms:
        neuron_count = som.shape[0] * som.shape[1]
        winners = {find_bmu(som, exhausted, x) for x in train_data}
        dead_neurons_count.append(neuron_count - len(winners))
    print("Średni błąd: ", m)
    print("Odchylenie standardowe: ", std)
    print("Błąd minimalny: ", min_err)
    print(
        f'Średnia liczba nieaktywnych neuronów: {np.mean(dead_neurons_count)}')
    print(
        f'Odchylenie standardowe liczby nieaktywnych neuronów: {np.std(dead_neurons_count)}')
    print('=' * 20)