Zad3
This commit is contained in:
		
							
								
								
									
										12
									
								
								zad3/generate_points.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								zad3/generate_points.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,12 @@ | ||||
| from math import pi, cos, sin, sqrt | ||||
| from random import random | ||||
| from typing import Tuple | ||||
|  | ||||
|  | ||||
def get_random_point(center: Tuple[float, float], radius: float) -> Tuple[float, float]:
    """Return a random point inside the disc of the given centre and radius.

    A random angle in ``[0, 2*pi)`` and a random distance from the centre are
    drawn; the distance is ``radius * sqrt(u)`` for a uniform ``u`` so that
    points are not concentrated near the centre.

    Args:
        center: ``(x, y)`` coordinates of the disc's centre.
        radius: Radius of the disc.

    Returns:
        The ``(x, y)`` coordinates of the generated point.
    """
    center_x, center_y = center

    # Draw the angle first, then the distance, one random() call each.
    angle = random() * 2 * pi
    distance = radius * sqrt(random())

    x = distance * cos(angle) + center_x
    y = distance * sin(angle) + center_y
    return x, y
							
								
								
									
										162
									
								
								zad3/zad3.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										162
									
								
								zad3/zad3.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,162 @@ | ||||
| import matplotlib.pyplot as plt | ||||
| from matplotlib.animation import FuncAnimation | ||||
| from random import sample | ||||
| from generate_points import get_random_point | ||||
| import numpy as np | ||||
|  | ||||
|  | ||||
def get_color(i):
    """Return the colour that matplotlib's 'tab20' colormap assigns to ``i``."""
    palette = plt.get_cmap('tab20')
    return palette(i)
|  | ||||
|  | ||||
def get_data1():
    """Return 200 random points from the disc of radius 1 centred at (0, 0)."""
    return [get_random_point((0, 0), 1) for _ in range(200)]
|  | ||||
|  | ||||
def get_data2():
    """Return 200 random points split evenly between two small discs.

    The first 100 points come from the disc of radius 0.5 centred at (3, 0),
    the next 100 from the disc of the same radius centred at (-3, 0).
    """
    centers = [(3 * ((-1) ** i), 0) for i in range(2)]
    return [
        get_random_point(center, 0.5)
        for center in centers
        for _ in range(100)
    ]
|  | ||||
|  | ||||
def plot_data(data):
    """Display a scatter plot of ``data``, an iterable of ``(x, y)`` points."""
    # Transpose the point list into one list per coordinate.
    xs, ys = (list(coords) for coords in zip(*data))

    plt.figure(1)
    axes = plt.axes()
    axes.scatter(xs, ys)
    axes.set_xlabel('X')
    axes.set_ylabel('Y')
    plt.grid(True)
    plt.show()
|  | ||||
|  | ||||
def plot_kmeans(all_data, k):
    """Animate the k-means iterations recorded in ``all_data``.

    Each cluster is drawn as a scatter of its points plus an 'X' marker for
    its centroid, both in the same colour. Every animation frame moves the
    existing artists to the positions stored for that iteration.

    Args:
        all_data: Sequence of ``(centroids, clusters)`` snapshots, one per
            iteration. ``centroids`` is indexable by cluster index and holds
            ``(x, y)`` pairs; ``clusters`` maps each cluster index to the
            list of points assigned to it. The artists are stored by
            position, so the cluster indices are expected to be
            ``0 .. k-1`` in iteration order.
        k: Number of clusters; shown in the title and used to spread the
            cluster indices over the colormap.
    """
    def as_offsets(points):
        # Always produce an (n, 2) array. A bare empty list (an empty
        # cluster) cannot be unpacked with zip(*...) and is rejected by
        # set_offsets, so it is normalised to shape (0, 2) here.
        return np.asarray(points, dtype=float).reshape(-1, 2)

    fig, ax = plt.subplots()
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_title(f'k={k}')
    time_text = ax.text(0.05, 0.95, 'iter=0', horizontalalignment='left',
                        verticalalignment='top', transform=ax.transAxes)
    plt.grid(True)

    centroid_scatters = []
    cluster_scatters = []
    centroids, clusters = all_data[0]
    for key in clusters:
        points = as_offsets(clusters[key])
        center = as_offsets(centroids[key])
        color = get_color(key/k)
        cluster_scatters.append(
            ax.scatter(points[:, 0], points[:, 1], color=color))
        centroid_scatters.append(
            ax.scatter(center[:, 0], center[:, 1], color=color, marker='X'))

    def update_plot_kmeans(i):
        centroids, clusters = all_data[i]
        time_text.set_text(f'iter={i}')
        for key in clusters:
            centroid_scatters[key].set_offsets(as_offsets(centroids[key]))
            cluster_scatters[key].set_offsets(as_offsets(clusters[key]))
        # With blit=True the callback must return every artist it changed.
        return centroid_scatters+cluster_scatters+[time_text, ]

    # Keep a reference to the animation so it is not garbage-collected
    # before plt.show() runs it.
    anim = FuncAnimation(fig, update_plot_kmeans,
                         frames=len(all_data), blit=True)
    # anim.save('animation.mp4')

    plt.show()
|  | ||||
|  | ||||
def calc_length(a, b):
    """Return the squared Euclidean distance between 2-D points ``a`` and ``b``.

    The square root is deliberately not taken: the result is only used to
    compare distances, and squaring preserves their order.
    """
    dx = b[0] - a[0]
    dy = b[1] - a[1]
    return dx ** 2 + dy ** 2
|  | ||||
|  | ||||
def init_centroids(data, k, method='forgy'):
    """Choose ``k`` initial centroids from ``data``.

    Args:
        data: Sequence of points to choose from.
        k: Number of centroids to pick.
        method: Initialisation strategy. Only ``'forgy'`` is available: it
            picks ``k`` distinct elements of ``data`` at random.

    Returns:
        A new list with the ``k`` chosen points.

    Raises:
        NotImplementedError: If ``method`` is anything other than ``'forgy'``.
    """
    if method == 'forgy':
        return sample(data, k)
    raise NotImplementedError(f'method {method} is not implemented yet')
|  | ||||
|  | ||||
def calc_error(centroids, clusters, k):
    """Return the clustering error summed over all ``k`` clusters.

    For each cluster the error is the mean of the squared coordinate
    differences between its points and its centroid (averaged over every
    coordinate of every point). The per-cluster values are then added up.

    Args:
        centroids: Indexable by cluster index; each entry is an ``(x, y)``
            pair.
        clusters: Maps each cluster index ``0 .. k-1`` to the list of points
            assigned to it.
        k: Number of clusters.

    Returns:
        The summed error. Empty clusters contribute nothing.
    """
    total = 0
    for i in range(k):
        points = np.asarray(clusters[i])
        if points.size == 0:
            # The mean of an empty array is NaN, which would poison the sum
            # and break every later comparison against this error.
            continue
        # Broadcasting subtracts the single centroid from every point.
        differences = np.asarray(centroids[i]) - points
        total += np.mean(differences ** 2)
    return total
|  | ||||
|  | ||||
def plot_error_data(error_data):
    """Display the clustering error as a function of ``k``.

    Args:
        error_data: Iterable of ``(k, error)`` pairs. They are drawn as red
            dots joined by a line; the x axis is fixed to the range 2..20.
    """
    fig, axes = plt.subplots()
    axes.set_xlabel('k')
    axes.set_ylabel('err')
    axes.set_xlim(2, 20)
    plt.title('Errors')
    plt.grid(True)

    # Transpose the pairs into one list of k values and one list of errors.
    k_values, errors = (list(column) for column in zip(*error_data))
    axes.plot(k_values, errors, 'ro-')

    plt.show()
|  | ||||
|  | ||||
def _assign_clusters(data, centroids, k):
    """Group every point of ``data`` under the index of its nearest centroid."""
    clusters = {i: [] for i in range(k)}
    for point in data:
        lengths = [calc_length(c, point) for c in centroids]
        clusters[int(np.argmin(lengths))].append(point)
    return clusters


def _run_kmeans(data, k, max_iter=100):
    """Run k-means once from a fresh random initialisation.

    Returns a ``(history, error)`` pair: ``history`` is the list of
    ``(centroids, clusters)`` snapshots, one for the initial assignment and
    one per iteration; ``error`` is ``calc_error`` of the final state.
    """
    centroids = init_centroids(data, k)
    clusters = _assign_clusters(data, centroids, k)
    # Store a copy of the centroid list: the list itself is updated in place.
    history = [(list(centroids), clusters)]
    for _ in range(max_iter):
        for key, members in clusters.items():
            # A cluster that lost all its points keeps its previous centroid.
            if members:
                centroids[key] = np.mean(members, axis=0)
        clusters = _assign_clusters(data, centroids, k)
        history.append((list(centroids), clusters))
        current, previous = history[-1][0], history[-2][0]
        # Stop once no centroid moved (within np.isclose tolerance).
        if all(np.allclose(new, old) for new, old in zip(current, previous)):
            break
    return history, calc_error(centroids, clusters, k)


def main():
    """Run the k-means experiment on both generated data sets.

    For each data set: plot the raw points, then for every ``k`` from 2 to 20
    run k-means 10 times, keep the run with the smallest error and animate
    it. Finally plot the best error for the even values of ``k``.
    """
    for get_data in [get_data1, get_data2]:
        data = get_data()
        plot_data(data)
        kmeans_data = {}
        for k in range(2, 21):
            runs = [_run_kmeans(data, k) for _ in range(10)]
            # min() keeps the first run among equals, like a strict
            # "err < min_err" scan would.
            kmeans, min_err = min(runs, key=lambda run: run[1])
            kmeans_data[k] = (kmeans, min_err)
            plot_kmeans(kmeans, k)
        # Only even k values are included in the error plot.
        error_data = [[i, kmeans_data[i][1]] for i in range(2, 21, 2)]
        plot_error_data(error_data)
|  | ||||
|  | ||||
# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()
		Reference in New Issue
	
	Block a user