zad3 wip2
This commit is contained in:
parent
baab800220
commit
d30f72f4c1
1
zad3/data1.json
Normal file
1
zad3/data1.json
Normal file
File diff suppressed because one or more lines are too long
BIN
zad3/data1_errors.png
(Stored with Git LFS)
Normal file
BIN
zad3/data1_errors.png
(Stored with Git LFS)
Normal file
Binary file not shown.
1
zad3/data2.json
Normal file
1
zad3/data2.json
Normal file
File diff suppressed because one or more lines are too long
BIN
zad3/data2_errors.png
(Stored with Git LFS)
Normal file
BIN
zad3/data2_errors.png
(Stored with Git LFS)
Normal file
Binary file not shown.
BIN
zad3/ml_195642_zad3.odt
(Stored with Git LFS)
Normal file
BIN
zad3/ml_195642_zad3.odt
(Stored with Git LFS)
Normal file
Binary file not shown.
120
zad3/zad3.py
120
zad3/zad3.py
@ -1,8 +1,12 @@
|
|||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from matplotlib.animation import FuncAnimation
|
from matplotlib.animation import FuncAnimation
|
||||||
from random import sample
|
from random import sample, shuffle
|
||||||
from generate_points import get_random_point
|
from generate_points import get_random_point
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
METHODS = ['forgy', 'random_partition']
|
||||||
|
|
||||||
|
|
||||||
def get_color(i):
|
def get_color(i):
|
||||||
@ -49,11 +53,12 @@ def plot_kmeans(all_data, k):
|
|||||||
cluster_scatters = []
|
cluster_scatters = []
|
||||||
centroids, clusters = all_data[0]
|
centroids, clusters = all_data[0]
|
||||||
for key in clusters:
|
for key in clusters:
|
||||||
lst_x, lst_y = zip(*clusters[key])
|
|
||||||
lst_x = list(lst_x)
|
|
||||||
lst_y = list(lst_y)
|
|
||||||
color = get_color(key/k)
|
color = get_color(key/k)
|
||||||
cluster_scatters.append(ax.scatter(lst_x, lst_y, color=color))
|
if clusters[key]:
|
||||||
|
lst_x, lst_y = zip(*clusters[key])
|
||||||
|
lst_x = list(lst_x)
|
||||||
|
lst_y = list(lst_y)
|
||||||
|
cluster_scatters.append(ax.scatter(lst_x, lst_y, color=color))
|
||||||
centroid_scatters.append(ax.scatter([centroids[key][0]], [
|
centroid_scatters.append(ax.scatter([centroids[key][0]], [
|
||||||
centroids[key][1]], color=color, marker='X'))
|
centroids[key][1]], color=color, marker='X'))
|
||||||
|
|
||||||
@ -77,10 +82,16 @@ def calc_length(a, b):
|
|||||||
return (b[0]-a[0])**2+(b[1]-a[1])**2
|
return (b[0]-a[0])**2+(b[1]-a[1])**2
|
||||||
|
|
||||||
|
|
||||||
def init_centroids(data, k, method='forgy'): #TODO: Add k-means++ and Random Partition
|
def init_centroids(data, k, method='forgy'): # TODO: Add k-means++ and Random Partition
|
||||||
match method:
|
match method:
|
||||||
case 'forgy':
|
case 'forgy':
|
||||||
return sample(data, k)
|
return sample(data, k)
|
||||||
|
case 'random_partition':
|
||||||
|
shuffled = list(data)
|
||||||
|
shuffle(shuffled)
|
||||||
|
div = len(shuffled)/k
|
||||||
|
partition = [shuffled[int(round(div*i)):int(round(div*(i+1)))] for i in range(k)]
|
||||||
|
return [np.mean(prt, axis=0) for prt in partition]
|
||||||
case _:
|
case _:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f'method {method} is not implemented yet')
|
f'method {method} is not implemented yet')
|
||||||
@ -91,9 +102,9 @@ def calc_error(centroids, clusters, k):
|
|||||||
for i in range(k):
|
for i in range(k):
|
||||||
cluster = np.array(clusters[i])
|
cluster = np.array(clusters[i])
|
||||||
centroid = np.array([centroids[i] for _ in range(len(cluster))])
|
centroid = np.array([centroids[i] for _ in range(len(cluster))])
|
||||||
errors = centroid - cluster
|
errors = cluster - centroid
|
||||||
squared_errors.append([e**2 for e in errors])
|
squared_errors.append([e**2 for e in errors])
|
||||||
return sum([np.mean(err) for err in squared_errors])
|
return sum([np.mean(err) if err else 0 for err in squared_errors])
|
||||||
|
|
||||||
|
|
||||||
def plot_error_data(error_data):
|
def plot_error_data(error_data):
|
||||||
@ -112,28 +123,29 @@ def plot_error_data(error_data):
|
|||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def print_stats(k, data):
|
||||||
for get_data in [get_data1, get_data2]:
|
print('='*20)
|
||||||
data = get_data()
|
print(f'k={k}')
|
||||||
|
errs = [x[1] for x in data]
|
||||||
|
m = np.mean(errs)
|
||||||
|
std = np.std(errs)
|
||||||
|
min_err = np.min(errs)
|
||||||
|
lst_empty = [sum([1 for cluster in centroids_with_clusters[1] if not cluster]) for centroids_with_clusters,_ in data]
|
||||||
|
print(lst_empty)
|
||||||
|
|
||||||
|
|
||||||
|
def main(datas):
|
||||||
|
# for get_data in [get_data1, get_data2]:
|
||||||
|
# data = get_data()
|
||||||
|
for data in datas:
|
||||||
plot_data(data)
|
plot_data(data)
|
||||||
kmeans_data = {}
|
for method in METHODS:
|
||||||
for k in range(2, 21):
|
kmeans_data = {}
|
||||||
kmeans_with_err = []
|
for k in [20]: # range(2, 21):
|
||||||
for _ in range(100):
|
kmeans_with_err = []
|
||||||
all_data = []
|
|
||||||
centroids = init_centroids(data, k)
|
|
||||||
clusters = {}
|
|
||||||
for i in range(k):
|
|
||||||
clusters[i] = []
|
|
||||||
for point in data:
|
|
||||||
lengths = [calc_length(c, point) for c in centroids]
|
|
||||||
index_min = np.argmin(lengths)
|
|
||||||
clusters[index_min].append(point)
|
|
||||||
all_data.append((list(centroids), clusters))
|
|
||||||
for _ in range(100):
|
for _ in range(100):
|
||||||
for key in clusters:
|
centroids_with_clusters = []
|
||||||
if clusters[key]:
|
centroids = init_centroids(data, k, method=method)
|
||||||
centroids[key] = np.mean(clusters[key], axis=0)
|
|
||||||
clusters = {}
|
clusters = {}
|
||||||
for i in range(k):
|
for i in range(k):
|
||||||
clusters[i] = []
|
clusters[i] = []
|
||||||
@ -141,22 +153,42 @@ def main():
|
|||||||
lengths = [calc_length(c, point) for c in centroids]
|
lengths = [calc_length(c, point) for c in centroids]
|
||||||
index_min = np.argmin(lengths)
|
index_min = np.argmin(lengths)
|
||||||
clusters[index_min].append(point)
|
clusters[index_min].append(point)
|
||||||
all_data.append((list(centroids), clusters))
|
centroids_with_clusters.append((list(centroids), clusters))
|
||||||
if all([all(np.isclose(all_data[-1][0][i], all_data[-2][0][i])) for i in range(k)]):
|
for _ in range(100):
|
||||||
break
|
for key in clusters:
|
||||||
err = calc_error(centroids, clusters, k)
|
if clusters[key]:
|
||||||
kmeans_with_err.append((all_data, err))
|
centroids[key] = np.mean(clusters[key], axis=0)
|
||||||
min_err = kmeans_with_err[0][1]
|
clusters = {}
|
||||||
kmeans = kmeans_with_err[0][0]
|
for i in range(k):
|
||||||
for temp_kmeans, err in kmeans_with_err:
|
clusters[i] = []
|
||||||
if err < min_err:
|
for point in data:
|
||||||
min_err = err
|
lengths = [calc_length(c, point)
|
||||||
kmeans = temp_kmeans
|
for c in centroids]
|
||||||
kmeans_data[k] = (kmeans, min_err)
|
index_min = np.argmin(lengths)
|
||||||
plot_kmeans(kmeans, k)
|
clusters[index_min].append(point)
|
||||||
error_data = [[i, kmeans_data[i][1]] for i in range(2, 21, 2)]
|
centroids_with_clusters.append(
|
||||||
plot_error_data(error_data)
|
(list(centroids), clusters))
|
||||||
|
if all([all(np.isclose(centroids_with_clusters[-1][0][i], centroids_with_clusters[-2][0][i])) for i in range(k)]):
|
||||||
|
break
|
||||||
|
err = calc_error(centroids, clusters, k)
|
||||||
|
kmeans_with_err.append((centroids_with_clusters, err))
|
||||||
|
print_stats(k, [(iterations[-1],err) for iterations, err in kmeans_with_err])
|
||||||
|
min_err = kmeans_with_err[0][1]
|
||||||
|
kmeans = kmeans_with_err[0][0]
|
||||||
|
for temp_kmeans, err in kmeans_with_err:
|
||||||
|
if err < min_err:
|
||||||
|
min_err = err
|
||||||
|
kmeans = temp_kmeans
|
||||||
|
kmeans_data[k]=(kmeans, min_err)
|
||||||
|
plot_kmeans(kmeans, k)
|
||||||
|
#error_data = [[i, kmeans_data[i][1]] for i in range(2, 21, 2)]
|
||||||
|
#plot_error_data(error_data)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
datas = []
|
||||||
|
with open('data1.json', 'r') as d:
|
||||||
|
datas.append(json.loads(d.read()))
|
||||||
|
with open('data2.json', 'r') as d:
|
||||||
|
datas.append(json.loads(d.read()))
|
||||||
|
main(datas)
|
||||||
|
Loading…
Reference in New Issue
Block a user