diff --git a/zad3/generate_points.py b/zad3/generate_points.py new file mode 100644 index 0000000..e2b8a23 --- /dev/null +++ b/zad3/generate_points.py @@ -0,0 +1,12 @@ +from math import pi, cos, sin, sqrt +from random import random +from typing import Tuple + + +def get_random_point(center: Tuple[float, float], radius: float) -> Tuple[float, float]: + shift_x, shift_y = center + + a = random() * 2 * pi + r = radius * sqrt(random()) + + return r * cos(a) + shift_x, r * sin(a) + shift_y diff --git a/zad3/zad3.py b/zad3/zad3.py new file mode 100644 index 0000000..154f2db --- /dev/null +++ b/zad3/zad3.py @@ -0,0 +1,162 @@ +import matplotlib.pyplot as plt +from matplotlib.animation import FuncAnimation +from random import sample +from generate_points import get_random_point +import numpy as np + + +def get_color(i): + return plt.get_cmap('tab20')(i) + + +def get_data1(): + data = [] + for _ in range(200): + data.append(get_random_point((0, 0), 1)) + return data + + +def get_data2(): + data = [] + for i in range(2): + for _ in range(100): + data.append(get_random_point((3*((-1)**i), 0), 0.5)) + return data + + +def plot_data(data): + lst_x, lst_y = zip(*data) + lst_x = list(lst_x) + lst_y = list(lst_y) + plt.figure(1) + ax = plt.axes() + ax.scatter(lst_x, lst_y) + ax.set_xlabel('X') + ax.set_ylabel('Y') + plt.grid(True) + plt.show() + + +def plot_kmeans(all_data, k): + fig, ax = plt.subplots() + ax.set_xlabel('X') + ax.set_ylabel('Y') + ax.set_title(f'k={k}') + time_text = ax.text(0.05, 0.95, 'iter=0', horizontalalignment='left', + verticalalignment='top', transform=ax.transAxes) + plt.grid(True) + centroid_scatters = [] + cluster_scatters = [] + centroids, clusters = all_data[0] + for key in clusters: + lst_x, lst_y = zip(*clusters[key]) + lst_x = list(lst_x) + lst_y = list(lst_y) + color = get_color(key/k) + cluster_scatters.append(ax.scatter(lst_x, lst_y, color=color)) + centroid_scatters.append(ax.scatter([centroids[key][0]], [ + centroids[key][1]], color=color, marker='X')) + + def update_plot_kmeans(i): + centroids, clusters = all_data[i] + time_text.set_text(f'iter={i}') + for key in clusters: + centroid_scatters[key].set_offsets(centroids[key]) + cluster_scatters[key].set_offsets(clusters[key]) + return centroid_scatters+cluster_scatters+[time_text, ] + anim = FuncAnimation(fig, update_plot_kmeans, + frames=len(all_data), blit=True) + # anim.save('animation.mp4') + + plt.show() + + +def calc_length(a, b): + # return ((b[0]-a[0])**2+(b[1]-a[1])**2)**0.5 + # no need to calculate square root for comparison + return (b[0]-a[0])**2+(b[1]-a[1])**2 + + +def init_centroids(data, k, method='forgy'): + match method: + case 'forgy': + return sample(data, k) + case _: + raise NotImplementedError( + f'method {method} is not implemented yet') + + +def calc_error(centroids, clusters, k): + squared_errors = [] + for i in range(k): + cluster = np.array(clusters[i]) + centroid = np.array([centroids[i] for _ in range(len(cluster))]) + errors = centroid - cluster + squared_errors.append([e**2 for e in errors]) + return sum([np.mean(err) for err in squared_errors]) + + +def plot_error_data(error_data): + fig, ax = plt.subplots() + ax.set_xlabel('k') + ax.set_ylabel('err') + ax.set_xlim(2, 20) + plt.title('Errors') + plt.grid(True) + + lst_x, lst_y = zip(*error_data) + lst_x = list(lst_x) + lst_y = list(lst_y) + ax.plot(lst_x, lst_y, 'ro-') + + plt.show() + + +def main(): + for get_data in [get_data1, get_data2]: + data = get_data() + plot_data(data) + kmeans_data = {} + for k in range(2, 21): + kmeans_with_err = [] + for _ in range(10): + all_data = [] + centroids = init_centroids(data, k) + clusters = {} + for i in range(k): + clusters[i] = [] + for point in data: + lengths = [calc_length(c, point) for c in centroids] + index_min = np.argmin(lengths) + clusters[index_min].append(point) + all_data.append((list(centroids), clusters)) + for _ in range(100): + for key in clusters: + if clusters[key]: + centroids[key] = np.mean(clusters[key], axis=0) + clusters = {} + for i in range(k): + clusters[i] = [] + for point in data: + lengths = [calc_length(c, point) for c in centroids] + index_min = np.argmin(lengths) + clusters[index_min].append(point) + all_data.append((list(centroids), clusters)) + if all([all(np.isclose(all_data[-1][0][i], all_data[-2][0][i])) for i in range(k)]): + break + err = calc_error(centroids, clusters, k) + kmeans_with_err.append((all_data, err)) + min_err = kmeans_with_err[0][1] + kmeans = kmeans_with_err[0][0] + for temp_kmeans, err in kmeans_with_err: + if err < min_err: + min_err = err + kmeans = temp_kmeans + kmeans_data[k] = (kmeans, min_err) + plot_kmeans(kmeans, k) + error_data = [[i, kmeans_data[i][1]] for i in range(2, 21, 2)] + plot_error_data(error_data) + + +if __name__ == '__main__': + main()