zad3 wip2
This commit is contained in:
		
							
								
								
									
										1
									
								
								zad3/data1.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								zad3/data1.json
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								zad3/data1_errors.png
									
									LFS
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								zad3/data1_errors.png
									
									LFS
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										1
									
								
								zad3/data2.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								zad3/data2.json
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								zad3/data2_errors.png
									
									LFS
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								zad3/data2_errors.png
									
									LFS
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										
											BIN
										
									
								
								zad3/ml_195642_zad3.odt
									
									LFS
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								zad3/ml_195642_zad3.odt
									
									LFS
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							
							
								
								
									
										120
									
								
								zad3/zad3.py
									
									
									
									
									
								
							
							
						
						
									
										120
									
								
								zad3/zad3.py
									
									
									
									
									
								
							| @@ -1,8 +1,12 @@ | |||||||
| import matplotlib.pyplot as plt | import matplotlib.pyplot as plt | ||||||
| from matplotlib.animation import FuncAnimation | from matplotlib.animation import FuncAnimation | ||||||
| from random import sample | from random import sample, shuffle | ||||||
| from generate_points import get_random_point | from generate_points import get_random_point | ||||||
| import numpy as np | import numpy as np | ||||||
|  | import json | ||||||
|  |  | ||||||
|  |  | ||||||
|  | METHODS = ['forgy', 'random_partition'] | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_color(i): | def get_color(i): | ||||||
| @@ -49,11 +53,12 @@ def plot_kmeans(all_data, k): | |||||||
|     cluster_scatters = [] |     cluster_scatters = [] | ||||||
|     centroids, clusters = all_data[0] |     centroids, clusters = all_data[0] | ||||||
|     for key in clusters: |     for key in clusters: | ||||||
|         lst_x, lst_y = zip(*clusters[key]) |  | ||||||
|         lst_x = list(lst_x) |  | ||||||
|         lst_y = list(lst_y) |  | ||||||
|         color = get_color(key/k) |         color = get_color(key/k) | ||||||
|         cluster_scatters.append(ax.scatter(lst_x, lst_y, color=color)) |         if clusters[key]: | ||||||
|  |             lst_x, lst_y = zip(*clusters[key]) | ||||||
|  |             lst_x = list(lst_x) | ||||||
|  |             lst_y = list(lst_y) | ||||||
|  |             cluster_scatters.append(ax.scatter(lst_x, lst_y, color=color)) | ||||||
|         centroid_scatters.append(ax.scatter([centroids[key][0]], [ |         centroid_scatters.append(ax.scatter([centroids[key][0]], [ | ||||||
|                                  centroids[key][1]], color=color, marker='X')) |                                  centroids[key][1]], color=color, marker='X')) | ||||||
|  |  | ||||||
| @@ -77,10 +82,16 @@ def calc_length(a, b): | |||||||
|     return (b[0]-a[0])**2+(b[1]-a[1])**2 |     return (b[0]-a[0])**2+(b[1]-a[1])**2 | ||||||
|  |  | ||||||
|  |  | ||||||
| def init_centroids(data, k, method='forgy'): #TODO: Add k-means++ and Random Partition | def init_centroids(data, k, method='forgy'):  # TODO: Add k-means++ and Random Partition | ||||||
|     match method: |     match method: | ||||||
|         case 'forgy': |         case 'forgy': | ||||||
|             return sample(data, k) |             return sample(data, k) | ||||||
|  |         case 'random_partition': | ||||||
|  |             shuffled = list(data) | ||||||
|  |             shuffle(shuffled) | ||||||
|  |             div = len(shuffled)/k | ||||||
|  |             partition = [shuffled[int(round(div*i)):int(round(div*(i+1)))] for i in range(k)] | ||||||
|  |             return [np.mean(prt, axis=0) for prt in partition] | ||||||
|         case _: |         case _: | ||||||
|             raise NotImplementedError( |             raise NotImplementedError( | ||||||
|                 f'method {method} is not implemented yet') |                 f'method {method} is not implemented yet') | ||||||
| @@ -91,9 +102,9 @@ def calc_error(centroids, clusters, k): | |||||||
|     for i in range(k): |     for i in range(k): | ||||||
|         cluster = np.array(clusters[i]) |         cluster = np.array(clusters[i]) | ||||||
|         centroid = np.array([centroids[i] for _ in range(len(cluster))]) |         centroid = np.array([centroids[i] for _ in range(len(cluster))]) | ||||||
|         errors = centroid - cluster |         errors = cluster - centroid | ||||||
|         squared_errors.append([e**2 for e in errors]) |         squared_errors.append([e**2 for e in errors]) | ||||||
|     return sum([np.mean(err) for err in squared_errors]) |     return sum([np.mean(err) if err else 0 for err in squared_errors]) | ||||||
|  |  | ||||||
|  |  | ||||||
| def plot_error_data(error_data): | def plot_error_data(error_data): | ||||||
| @@ -112,28 +123,29 @@ def plot_error_data(error_data): | |||||||
|     plt.show() |     plt.show() | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(): | def print_stats(k, data): | ||||||
|     for get_data in [get_data1, get_data2]: |     print('='*20) | ||||||
|         data = get_data() |     print(f'k={k}') | ||||||
|  |     errs = [x[1] for x in data] | ||||||
|  |     m = np.mean(errs) | ||||||
|  |     std = np.std(errs) | ||||||
|  |     min_err = np.min(errs) | ||||||
|  |     lst_empty = [sum([1 for cluster in centroids_with_clusters[1] if not cluster]) for centroids_with_clusters,_ in data] | ||||||
|  |     print(lst_empty) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def main(datas): | ||||||
|  |     # for get_data in [get_data1, get_data2]: | ||||||
|  |     #    data = get_data() | ||||||
|  |     for data in datas: | ||||||
|         plot_data(data) |         plot_data(data) | ||||||
|         kmeans_data = {} |         for method in METHODS: | ||||||
|         for k in range(2, 21): |             kmeans_data = {} | ||||||
|             kmeans_with_err = [] |             for k in [20]:  # range(2, 21): | ||||||
|             for _ in range(100): |                 kmeans_with_err = [] | ||||||
|                 all_data = [] |  | ||||||
|                 centroids = init_centroids(data, k) |  | ||||||
|                 clusters = {} |  | ||||||
|                 for i in range(k): |  | ||||||
|                     clusters[i] = [] |  | ||||||
|                 for point in data: |  | ||||||
|                     lengths = [calc_length(c, point) for c in centroids] |  | ||||||
|                     index_min = np.argmin(lengths) |  | ||||||
|                     clusters[index_min].append(point) |  | ||||||
|                 all_data.append((list(centroids), clusters)) |  | ||||||
|                 for _ in range(100): |                 for _ in range(100): | ||||||
|                     for key in clusters: |                     centroids_with_clusters = [] | ||||||
|                         if clusters[key]: |                     centroids = init_centroids(data, k, method=method) | ||||||
|                             centroids[key] = np.mean(clusters[key], axis=0) |  | ||||||
|                     clusters = {} |                     clusters = {} | ||||||
|                     for i in range(k): |                     for i in range(k): | ||||||
|                         clusters[i] = [] |                         clusters[i] = [] | ||||||
| @@ -141,22 +153,42 @@ def main(): | |||||||
|                         lengths = [calc_length(c, point) for c in centroids] |                         lengths = [calc_length(c, point) for c in centroids] | ||||||
|                         index_min = np.argmin(lengths) |                         index_min = np.argmin(lengths) | ||||||
|                         clusters[index_min].append(point) |                         clusters[index_min].append(point) | ||||||
|                     all_data.append((list(centroids), clusters)) |                     centroids_with_clusters.append((list(centroids), clusters)) | ||||||
|                     if all([all(np.isclose(all_data[-1][0][i], all_data[-2][0][i])) for i in range(k)]): |                     for _ in range(100): | ||||||
|                         break |                         for key in clusters: | ||||||
|                 err = calc_error(centroids, clusters, k) |                             if clusters[key]: | ||||||
|                 kmeans_with_err.append((all_data, err)) |                                 centroids[key] = np.mean(clusters[key], axis=0) | ||||||
|             min_err = kmeans_with_err[0][1] |                         clusters = {} | ||||||
|             kmeans = kmeans_with_err[0][0] |                         for i in range(k): | ||||||
|             for temp_kmeans, err in kmeans_with_err: |                             clusters[i] = [] | ||||||
|                 if err < min_err: |                         for point in data: | ||||||
|                     min_err = err |                             lengths = [calc_length(c, point) | ||||||
|                     kmeans = temp_kmeans |                                     for c in centroids] | ||||||
|             kmeans_data[k] = (kmeans, min_err) |                             index_min = np.argmin(lengths) | ||||||
|             plot_kmeans(kmeans, k) |                             clusters[index_min].append(point) | ||||||
|         error_data = [[i, kmeans_data[i][1]] for i in range(2, 21, 2)] |                         centroids_with_clusters.append( | ||||||
|         plot_error_data(error_data) |                             (list(centroids), clusters)) | ||||||
|  |                         if all([all(np.isclose(centroids_with_clusters[-1][0][i], centroids_with_clusters[-2][0][i])) for i in range(k)]): | ||||||
|  |                             break | ||||||
|  |                     err = calc_error(centroids, clusters, k) | ||||||
|  |                     kmeans_with_err.append((centroids_with_clusters, err)) | ||||||
|  |                 print_stats(k, [(iterations[-1],err) for iterations, err in kmeans_with_err]) | ||||||
|  |                 min_err = kmeans_with_err[0][1] | ||||||
|  |                 kmeans = kmeans_with_err[0][0] | ||||||
|  |                 for temp_kmeans, err in kmeans_with_err: | ||||||
|  |                     if err < min_err: | ||||||
|  |                         min_err = err | ||||||
|  |                         kmeans = temp_kmeans | ||||||
|  |                 kmeans_data[k]=(kmeans, min_err) | ||||||
|  |                 plot_kmeans(kmeans, k) | ||||||
|  |             #error_data = [[i, kmeans_data[i][1]] for i in range(2, 21, 2)] | ||||||
|  |             #plot_error_data(error_data) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||||
|     main() |     datas = [] | ||||||
|  |     with open('data1.json', 'r') as d: | ||||||
|  |         datas.append(json.loads(d.read())) | ||||||
|  |     with open('data2.json', 'r') as d: | ||||||
|  |         datas.append(json.loads(d.read())) | ||||||
|  |     main(datas) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user