import matplotlib.pyplot as plt from matplotlib.animation import FuncAnimation from random import sample, shuffle from generate_points import get_random_point import numpy as np import json METHODS = ['forgy', 'random_partition'] def get_color(i): return plt.get_cmap('tab20')(i) def get_data1(): data = [] for _ in range(200): data.append(get_random_point((0, 0), 1)) return data def get_data2(): data = [] for i in range(2): for _ in range(100): data.append(get_random_point((3*((-1)**i), 0), 0.5)) return data def plot_data(data): lst_x, lst_y = zip(*data) lst_x = list(lst_x) lst_y = list(lst_y) plt.figure(1) ax = plt.axes() ax.scatter(lst_x, lst_y) ax.set_xlabel('X') ax.set_ylabel('Y') plt.grid(True) plt.show() def plot_kmeans(all_data, k): fig, ax = plt.subplots() ax.set_xlabel('X') ax.set_ylabel('Y') ax.set_title(f'k={k}') time_text = ax.text(0.05, 0.95, 'iter=0', horizontalalignment='left', verticalalignment='top', transform=ax.transAxes) plt.grid(True) centroid_scatters = [] cluster_scatters = [] centroids, clusters = all_data[0] for key in clusters: color = get_color(key/k) if clusters[key]: lst_x, lst_y = zip(*clusters[key]) lst_x = list(lst_x) lst_y = list(lst_y) cluster_scatters.append(ax.scatter(lst_x, lst_y, color=color)) centroid_scatters.append(ax.scatter([centroids[key][0]], [ centroids[key][1]], color=color, marker='X')) def update_plot_kmeans(i): centroids, clusters = all_data[i] time_text.set_text(f'iter={i}') for key in clusters: centroid_scatters[key].set_offsets(centroids[key]) cluster_scatters[key].set_offsets(clusters[key]) return centroid_scatters+cluster_scatters+[time_text, ] anim = FuncAnimation(fig, update_plot_kmeans, frames=len(all_data), blit=True) # anim.save('animation.mp4') plt.show() def calc_length(a, b): # return ((b[0]-a[0])**2+(b[1]-a[1])**2)**0.5 # no need to calculate square root for comparison return (b[0]-a[0])**2+(b[1]-a[1])**2 def init_centroids(data, k, method='forgy'): # TODO: Add k-means++ and Random Partition match method: case 'forgy': return sample(data, k) case 'random_partition': shuffled = list(data) shuffle(shuffled) div = len(shuffled)/k partition = [shuffled[int(round(div*i)):int(round(div*(i+1)))] for i in range(k)] return [np.mean(prt, axis=0) for prt in partition] case _: raise NotImplementedError( f'method {method} is not implemented yet') def calc_error(centroids, clusters, k): squared_errors = [] for i in range(k): cluster = np.array(clusters[i]) centroid = np.array([centroids[i] for _ in range(len(cluster))]) errors = cluster - centroid squared_errors.append([e**2 for e in errors]) return sum([np.mean(err) if err else 0 for err in squared_errors]) def plot_error_data(error_data): fig, ax = plt.subplots() ax.set_xlabel('k') ax.set_ylabel('err') ax.set_xlim(2, 20) plt.title('Errors') plt.grid(True) lst_x, lst_y = zip(*error_data) lst_x = list(lst_x) lst_y = list(lst_y) ax.plot(lst_x, lst_y, 'ro-') plt.show() def print_stats(k, data): print('='*20) print(f'k={k}') errs = [x[1] for x in data] m = np.mean(errs) std = np.std(errs) min_err = np.min(errs) lst_empty = [sum([1 for cluster in centroids_with_clusters[1] if not cluster]) for centroids_with_clusters,_ in data] print(lst_empty) def main(datas): # for get_data in [get_data1, get_data2]: # data = get_data() for data in datas: plot_data(data) for method in METHODS: kmeans_data = {} for k in [20]: # range(2, 21): kmeans_with_err = [] for _ in range(100): centroids_with_clusters = [] centroids = init_centroids(data, k, method=method) clusters = {} for i in range(k): clusters[i] = [] for point in data: lengths = [calc_length(c, point) for c in centroids] index_min = np.argmin(lengths) clusters[index_min].append(point) centroids_with_clusters.append((list(centroids), clusters)) for _ in range(100): for key in clusters: if clusters[key]: centroids[key] = np.mean(clusters[key], axis=0) clusters = {} for i in range(k): clusters[i] = [] for point in data: lengths = [calc_length(c, point) for c in centroids] index_min = np.argmin(lengths) clusters[index_min].append(point) centroids_with_clusters.append( (list(centroids), clusters)) if all([all(np.isclose(centroids_with_clusters[-1][0][i], centroids_with_clusters[-2][0][i])) for i in range(k)]): break err = calc_error(centroids, clusters, k) kmeans_with_err.append((centroids_with_clusters, err)) print_stats(k, [(iterations[-1],err) for iterations, err in kmeans_with_err]) min_err = kmeans_with_err[0][1] kmeans = kmeans_with_err[0][0] for temp_kmeans, err in kmeans_with_err: if err < min_err: min_err = err kmeans = temp_kmeans kmeans_data[k]=(kmeans, min_err) plot_kmeans(kmeans, k) #error_data = [[i, kmeans_data[i][1]] for i in range(2, 21, 2)] #plot_error_data(error_data) if __name__ == '__main__': datas = [] with open('data1.json', 'r') as d: datas.append(json.loads(d.read())) with open('data2.json', 'r') as d: datas.append(json.loads(d.read())) main(datas)