zad3 wip2
This commit is contained in:
File diff suppressed because one or more lines are too long
BIN
Binary file not shown.
File diff suppressed because one or more lines are too long
BIN
Binary file not shown.
BIN
Binary file not shown.
+50
-18
@@ -1,8 +1,12 @@
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.animation import FuncAnimation
|
||||
from random import sample
|
||||
from random import sample, shuffle
|
||||
from generate_points import get_random_point
|
||||
import numpy as np
|
||||
import json
|
||||
|
||||
|
||||
METHODS = ['forgy', 'random_partition']
|
||||
|
||||
|
||||
def get_color(i):
|
||||
@@ -49,10 +53,11 @@ def plot_kmeans(all_data, k):
|
||||
cluster_scatters = []
|
||||
centroids, clusters = all_data[0]
|
||||
for key in clusters:
|
||||
color = get_color(key/k)
|
||||
if clusters[key]:
|
||||
lst_x, lst_y = zip(*clusters[key])
|
||||
lst_x = list(lst_x)
|
||||
lst_y = list(lst_y)
|
||||
color = get_color(key/k)
|
||||
cluster_scatters.append(ax.scatter(lst_x, lst_y, color=color))
|
||||
centroid_scatters.append(ax.scatter([centroids[key][0]], [
|
||||
centroids[key][1]], color=color, marker='X'))
|
||||
@@ -81,6 +86,12 @@ def init_centroids(data, k, method='forgy'): #TODO: Add k-means++ and Random Par
|
||||
match method:
|
||||
case 'forgy':
|
||||
return sample(data, k)
|
||||
case 'random_partition':
|
||||
shuffled = list(data)
|
||||
shuffle(shuffled)
|
||||
div = len(shuffled)/k
|
||||
partition = [shuffled[int(round(div*i)):int(round(div*(i+1)))] for i in range(k)]
|
||||
return [np.mean(prt, axis=0) for prt in partition]
|
||||
case _:
|
||||
raise NotImplementedError(
|
||||
f'method {method} is not implemented yet')
|
||||
@@ -91,9 +102,9 @@ def calc_error(centroids, clusters, k):
|
||||
for i in range(k):
|
||||
cluster = np.array(clusters[i])
|
||||
centroid = np.array([centroids[i] for _ in range(len(cluster))])
|
||||
errors = centroid - cluster
|
||||
errors = cluster - centroid
|
||||
squared_errors.append([e**2 for e in errors])
|
||||
return sum([np.mean(err) for err in squared_errors])
|
||||
return sum([np.mean(err) if err else 0 for err in squared_errors])
|
||||
|
||||
|
||||
def plot_error_data(error_data):
|
||||
@@ -112,16 +123,29 @@ def plot_error_data(error_data):
|
||||
plt.show()
|
||||
|
||||
|
||||
def main():
|
||||
for get_data in [get_data1, get_data2]:
|
||||
data = get_data()
|
||||
def print_stats(k, data):
    """Print summary statistics for repeated k-means runs at a given k.

    Args:
        k: Number of clusters used for the runs (only echoed in the header).
        data: Iterable of ``((centroids, clusters), err)`` pairs, one per run.
            ``clusters`` is normally the dict built in ``main`` that maps a
            cluster index to its list of points; a plain sequence of point
            lists is accepted as well.

    Prints a separator, the value of k, the mean / standard deviation /
    minimum of the errors (skipped when there are no runs), and finally a
    list holding the number of empty clusters in each run.
    """
    runs = list(data)
    print('='*20)
    print(f'k={k}')

    errs = [err for _, err in runs]
    if errs:
        # np.min raises ValueError on an empty sequence, so only report the
        # aggregate statistics when at least one run is present.
        m = np.mean(errs)
        std = np.std(errs)
        min_err = np.min(errs)
        print(f'mean error={m:.4f}, std={std:.4f}, min error={min_err:.4f}')

    lst_empty = []
    for centroids_with_clusters, _ in runs:
        clusters = centroids_with_clusters[1]
        # Iterating a dict yields its keys, not the point lists; look at the
        # values so that genuinely empty clusters are the ones counted.
        groups = clusters.values() if isinstance(clusters, dict) else clusters
        lst_empty.append(sum(1 for cluster in groups if not cluster))
    print(lst_empty)
|
||||
|
||||
|
||||
def main(datas):
|
||||
# for get_data in [get_data1, get_data2]:
|
||||
# data = get_data()
|
||||
for data in datas:
|
||||
plot_data(data)
|
||||
for method in METHODS:
|
||||
kmeans_data = {}
|
||||
for k in range(2, 21):
|
||||
for k in [20]: # range(2, 21):
|
||||
kmeans_with_err = []
|
||||
for _ in range(100):
|
||||
all_data = []
|
||||
centroids = init_centroids(data, k)
|
||||
centroids_with_clusters = []
|
||||
centroids = init_centroids(data, k, method=method)
|
||||
clusters = {}
|
||||
for i in range(k):
|
||||
clusters[i] = []
|
||||
@@ -129,7 +153,7 @@ def main():
|
||||
lengths = [calc_length(c, point) for c in centroids]
|
||||
index_min = np.argmin(lengths)
|
||||
clusters[index_min].append(point)
|
||||
all_data.append((list(centroids), clusters))
|
||||
centroids_with_clusters.append((list(centroids), clusters))
|
||||
for _ in range(100):
|
||||
for key in clusters:
|
||||
if clusters[key]:
|
||||
@@ -138,14 +162,17 @@ def main():
|
||||
for i in range(k):
|
||||
clusters[i] = []
|
||||
for point in data:
|
||||
lengths = [calc_length(c, point) for c in centroids]
|
||||
lengths = [calc_length(c, point)
|
||||
for c in centroids]
|
||||
index_min = np.argmin(lengths)
|
||||
clusters[index_min].append(point)
|
||||
all_data.append((list(centroids), clusters))
|
||||
if all([all(np.isclose(all_data[-1][0][i], all_data[-2][0][i])) for i in range(k)]):
|
||||
centroids_with_clusters.append(
|
||||
(list(centroids), clusters))
|
||||
if all([all(np.isclose(centroids_with_clusters[-1][0][i], centroids_with_clusters[-2][0][i])) for i in range(k)]):
|
||||
break
|
||||
err = calc_error(centroids, clusters, k)
|
||||
kmeans_with_err.append((all_data, err))
|
||||
kmeans_with_err.append((centroids_with_clusters, err))
|
||||
print_stats(k, [(iterations[-1],err) for iterations, err in kmeans_with_err])
|
||||
min_err = kmeans_with_err[0][1]
|
||||
kmeans = kmeans_with_err[0][0]
|
||||
for temp_kmeans, err in kmeans_with_err:
|
||||
@@ -154,9 +181,14 @@ def main():
|
||||
kmeans = temp_kmeans
|
||||
kmeans_data[k]=(kmeans, min_err)
|
||||
plot_kmeans(kmeans, k)
|
||||
error_data = [[i, kmeans_data[i][1]] for i in range(2, 21, 2)]
|
||||
plot_error_data(error_data)
|
||||
#error_data = [[i, kmeans_data[i][1]] for i in range(2, 21, 2)]
|
||||
#plot_error_data(error_data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Load the two point sets from their JSON files and run the k-means
    # experiment on each of them. main() takes the list of data sets as its
    # only argument, so it must not be called without one.
    datas = []
    for path in ('data1.json', 'data2.json'):
        with open(path, 'r') as d:
            datas.append(json.load(d))
    main(datas)
|
||||
|
||||
Reference in New Issue
Block a user