KAD/zad3/kmeans.py
2022-02-07 15:24:23 +01:00

146 lines
5.1 KiB
Python

import matplotlib.pyplot as plt
import utils
from matplotlib.animation import FuncAnimation
from random import sample, shuffle
import numpy as np
def plot_kmeans(all_data, k, name_suffix):
    """Animate a recorded k-means run, save it as a GIF and display it.

    Args:
        all_data: Sequence of ``(centroids, clusters)`` snapshots, one per
            animation frame. ``clusters`` maps a cluster index to a list of
            points and ``centroids`` is indexed by the same keys. Points are
            unpacked into exactly two coordinates, so the data must be 2-D.
        k: Number of clusters; shown in the title and used to derive one
            color per cluster via ``utils.get_color(index / k)``.
        name_suffix: Appended to the output file name
            ``animationKMEANS{name_suffix}.gif``.
    """
    fig, ax = plt.subplots()
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_title(f'k={k}')
    time_text = ax.text(0.05, 0.95, 'iter=0', horizontalalignment='left',
                        verticalalignment='top', transform=ax.transAxes)
    plt.grid(True)

    # Centroid artists are stored positionally and later looked up with the
    # cluster key, so the keys are assumed to be 0..k-1 in insertion order.
    centroid_scatters = []
    # Point artists are created lazily: only clusters that have held at least
    # one point own an entry here.
    cluster_scatters = {}

    centroids, clusters = all_data[0]
    for key in clusters:
        color = utils.get_color(key / k)
        if clusters[key]:
            lst_x, lst_y = zip(*clusters[key])
            cluster_scatters[key] = ax.scatter(
                list(lst_x), list(lst_y), color=color)
        centroid_scatters.append(ax.scatter(
            [centroids[key][0]], [centroids[key][1]],
            color=color, marker='X'))

    def update_plot_kmeans(i):
        """Redraw frame ``i`` and return every artist that may have changed."""
        centroids, clusters = all_data[i]
        time_text.set_text(f'iter={i}')
        for key in clusters:
            centroid_scatters[key].set_offsets(centroids[key])
            points = clusters[key]
            scatter = cluster_scatters.get(key)
            if scatter is not None:
                # A cluster that became empty must be cleared explicitly,
                # otherwise the points of the previous frame stay visible.
                scatter.set_offsets(points if points else np.empty((0, 2)))
            elif points:
                lst_x, lst_y = zip(*points)
                cluster_scatters[key] = ax.scatter(
                    list(lst_x), list(lst_y), color=utils.get_color(key / k))
        # blit=True requires the callback to return the artists to redraw.
        return centroid_scatters + list(cluster_scatters.values()) + [time_text]

    anim = FuncAnimation(fig, update_plot_kmeans,
                         frames=len(all_data), blit=True)
    anim.save(f'animationKMEANS{name_suffix}.gif')
    plt.show()
def calc_error(centroids, clusters, k):
    """Return the summed per-cluster mean squared deviation from the centroid.

    For each cluster index ``0..k-1`` the squared coordinate differences
    between every member point and the cluster's centroid are averaged over
    all points and coordinates; the per-cluster averages are then added up.
    Empty clusters contribute ``0``.

    Args:
        centroids: Indexable collection of centroid coordinates.
        clusters: Mapping (or sequence) from cluster index to its points.
        k: Number of clusters to evaluate.

    Returns:
        The accumulated error; the integer ``0`` when every cluster is empty.
    """
    total = 0
    for idx in range(k):
        members = np.array(clusters[idx])
        if len(members) == 0:
            # Nothing to measure for an empty cluster.
            continue
        # Broadcasting subtracts the centroid from every member at once.
        deviation = members - np.asarray(centroids[idx])
        total = total + np.mean(deviation ** 2)
    return total
def plot_error_data(error_data):
    """Show the clustering error as a function of ``k``.

    Args:
        error_data: Iterable of ``(k, err)`` pairs; drawn as a red line with
            circular markers on an x-axis fixed to the range 2..20.
    """
    _fig, ax = plt.subplots()
    ax.set_xlabel('k')
    ax.set_ylabel('err')
    ax.set_xlim(2, 20)
    plt.title('Errors')
    plt.grid(True)
    # Transpose the pairs into one list per axis.
    k_values, err_values = (list(column) for column in zip(*error_data))
    ax.plot(k_values, err_values, 'ro-')
    plt.show()
def print_stats(k, data):
    """Print summary statistics for a batch of k-means results.

    Args:
        k: Number of clusters; printed as the header line.
        data: Iterable of ``((centroids, clusters), err)`` entries, where
            ``clusters`` is a mapping from cluster index to its points.

    The report contains the mean, standard deviation and minimum of the
    errors, followed by the mean and standard deviation of the number of
    empty clusters per entry.
    """
    print(f'k={k}')
    states, errs = zip(*data)
    _, cluster_maps = zip(*states)

    mean_err = np.mean(errs)
    std_err = np.std(errs)
    lowest_err = np.min(errs)

    # Count, for every entry, how many clusters ended up without points.
    empty_counts = []
    for cluster_map in cluster_maps:
        empty_counts.append(
            sum(1 for members in cluster_map.values() if not members))
    empty_mean = sum(empty_counts) / len(empty_counts)
    empty_std = np.std(empty_counts)

    report = [
        f'MSE={mean_err}',
        f'std={std_err}',
        f'min(err)={lowest_err}',
        f'Mean of empty clusters count={empty_mean}',
        f'Standard deviation of empty clusters count={empty_std}',
        '=' * 20,
    ]
    print(*report, sep='\n')
def _assign_to_nearest(data, centroids, k):
    """Group every point of ``data`` under the index of its closest centroid."""
    clusters = {i: [] for i in range(k)}
    for point in data:
        # utils.calc_length is presumably a distance measure -- confirm in utils.
        lengths = [utils.calc_length(c, point) for c in centroids]
        clusters[int(np.argmin(lengths))].append(point)
    return clusters


def kmeans(data, method, k, runs=100, max_iter=100):
    """Run k-means several times and record every iteration of every run.

    Args:
        data: Collection of points to cluster.
        method: Initialization method forwarded to ``init_units``.
        k: Number of clusters.
        runs: Number of independent restarts (default 100).
        max_iter: Maximum number of update steps per restart (default 100).

    Returns:
        A list with one ``(history, err)`` tuple per restart. ``history`` is
        the list of ``(centroids, clusters)`` snapshots, starting with the
        initial assignment, and ``err`` is ``calc_error`` evaluated on the
        final state of that restart.
    """
    kmeans_with_err = []
    for _ in range(runs):
        centroids = init_units(data, k, method=method)
        clusters = _assign_to_nearest(data, centroids, k)
        # Snapshots hold a copy of the centroid list because it is updated
        # in place below.
        centroids_with_clusters = [(list(centroids), clusters)]

        for _step in range(max_iter):
            # Move each centroid to the mean of its points; a centroid whose
            # cluster is empty keeps its previous position.
            for key in clusters:
                if clusters[key]:
                    centroids[key] = np.mean(clusters[key], axis=0)
            clusters = _assign_to_nearest(data, centroids, k)
            centroids_with_clusters.append((list(centroids), clusters))

            # Stop once no centroid moved (within np.isclose tolerance).
            current = centroids_with_clusters[-1][0]
            previous = centroids_with_clusters[-2][0]
            if all(all(np.isclose(current[i], previous[i]))
                   for i in range(k)):
                break

        err = calc_error(centroids, clusters, k)
        kmeans_with_err.append((centroids_with_clusters, err))
    return kmeans_with_err
def init_units(data, k, method='forgy'):  # TODO: add k-means++ initialization
    """Choose ``k`` initial centroids for k-means.

    Args:
        data: Sequence of points.
        k: Number of centroids to produce.
        method: ``'forgy'`` draws ``k`` distinct points from ``data`` with
            ``random.sample``. ``'random_partition'`` shuffles a copy of
            ``data``, cuts it into ``k`` contiguous chunks of near-equal
            size and returns the mean of each chunk.

    Returns:
        A list of ``k`` centroids.

    Raises:
        ValueError: If ``k`` is not in ``1..len(data)`` for
            ``'random_partition'`` (``random.sample`` raises it for an
            oversized ``k`` with ``'forgy'``).
        NotImplementedError: If ``method`` is not one of the supported names.
    """
    if method == 'forgy':
        return sample(data, k)

    if method == 'random_partition':
        shuffled = list(data)
        # More chunks than points would leave some chunks empty and their
        # mean would silently become NaN; k == 0 would divide by zero.
        if not 1 <= k <= len(shuffled):
            raise ValueError(
                f'k must be between 1 and {len(shuffled)}, got {k}')
        shuffle(shuffled)
        div = len(shuffled) / k
        # Chunk boundaries; rounding spreads the remainder across chunks.
        bounds = [int(round(div * i)) for i in range(k + 1)]
        return [np.mean(shuffled[start:stop], axis=0)
                for start, stop in zip(bounds, bounds[1:])]

    raise NotImplementedError(
        f'method {method} is not implemented yet')