Zad3
This commit is contained in:
parent
fc5a5d8599
commit
0e4795ed43
12
zad3/generate_points.py
Normal file
12
zad3/generate_points.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
from math import pi, cos, sin, sqrt
|
||||||
|
from random import random
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
|
||||||
|
def get_random_point(center: Tuple[float, float], radius: float) -> Tuple[float, float]:
|
||||||
|
shift_x, shift_y = center
|
||||||
|
|
||||||
|
a = random() * 2 * pi
|
||||||
|
r = radius * sqrt(random())
|
||||||
|
|
||||||
|
return r * cos(a) + shift_x, r * sin(a) + shift_y
|
162
zad3/zad3.py
Normal file
162
zad3/zad3.py
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.animation import FuncAnimation
|
||||||
|
from random import sample
|
||||||
|
from generate_points import get_random_point
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def get_color(i):
|
||||||
|
return plt.get_cmap('tab20')(i)
|
||||||
|
|
||||||
|
|
||||||
|
def get_data1():
|
||||||
|
data = []
|
||||||
|
for _ in range(200):
|
||||||
|
data.append(get_random_point((0, 0), 1))
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def get_data2():
|
||||||
|
data = []
|
||||||
|
for i in range(2):
|
||||||
|
for _ in range(100):
|
||||||
|
data.append(get_random_point((3*((-1)**i), 0), 0.5))
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def plot_data(data):
|
||||||
|
lst_x, lst_y = zip(*data)
|
||||||
|
lst_x = list(lst_x)
|
||||||
|
lst_y = list(lst_y)
|
||||||
|
plt.figure(1)
|
||||||
|
ax = plt.axes()
|
||||||
|
ax.scatter(lst_x, lst_y)
|
||||||
|
ax.set_xlabel('X')
|
||||||
|
ax.set_ylabel('Y')
|
||||||
|
plt.grid(True)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def plot_kmeans(all_data, k):
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
ax.set_xlabel('X')
|
||||||
|
ax.set_ylabel('Y')
|
||||||
|
ax.set_title(f'k={k}')
|
||||||
|
time_text = ax.text(0.05, 0.95, 'iter=0', horizontalalignment='left',
|
||||||
|
verticalalignment='top', transform=ax.transAxes)
|
||||||
|
plt.grid(True)
|
||||||
|
centroid_scatters = []
|
||||||
|
cluster_scatters = []
|
||||||
|
centroids, clusters = all_data[0]
|
||||||
|
for key in clusters:
|
||||||
|
lst_x, lst_y = zip(*clusters[key])
|
||||||
|
lst_x = list(lst_x)
|
||||||
|
lst_y = list(lst_y)
|
||||||
|
color = get_color(key/k)
|
||||||
|
cluster_scatters.append(ax.scatter(lst_x, lst_y, color=color))
|
||||||
|
centroid_scatters.append(ax.scatter([centroids[key][0]], [
|
||||||
|
centroids[key][1]], color=color, marker='X'))
|
||||||
|
|
||||||
|
def update_plot_kmeans(i):
|
||||||
|
centroids, clusters = all_data[i]
|
||||||
|
time_text.set_text(f'iter={i}')
|
||||||
|
for key in clusters:
|
||||||
|
centroid_scatters[key].set_offsets(centroids[key])
|
||||||
|
cluster_scatters[key].set_offsets(clusters[key])
|
||||||
|
return centroid_scatters+cluster_scatters+[time_text, ]
|
||||||
|
anim = FuncAnimation(fig, update_plot_kmeans,
|
||||||
|
frames=len(all_data), blit=True)
|
||||||
|
# anim.save('animation.mp4')
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def calc_length(a, b):
|
||||||
|
# return ((b[0]-a[0])**2+(b[1]-a[1])**2)**0.5
|
||||||
|
# no need to calculate square root for comparison
|
||||||
|
return (b[0]-a[0])**2+(b[1]-a[1])**2
|
||||||
|
|
||||||
|
|
||||||
|
def init_centroids(data, k, method='forgy'):
|
||||||
|
match method:
|
||||||
|
case 'forgy':
|
||||||
|
return sample(data, k)
|
||||||
|
case _:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f'method {method} is not implemented yet')
|
||||||
|
|
||||||
|
|
||||||
|
def calc_error(centroids, clusters, k):
|
||||||
|
squared_errors = []
|
||||||
|
for i in range(k):
|
||||||
|
cluster = np.array(clusters[i])
|
||||||
|
centroid = np.array([centroids[i] for _ in range(len(cluster))])
|
||||||
|
errors = centroid - cluster
|
||||||
|
squared_errors.append([e**2 for e in errors])
|
||||||
|
return sum([np.mean(err) for err in squared_errors])
|
||||||
|
|
||||||
|
|
||||||
|
def plot_error_data(error_data):
|
||||||
|
fig, ax = plt.subplots()
|
||||||
|
ax.set_xlabel('k')
|
||||||
|
ax.set_ylabel('err')
|
||||||
|
ax.set_xlim(2, 20)
|
||||||
|
plt.title('Errors')
|
||||||
|
plt.grid(True)
|
||||||
|
|
||||||
|
lst_x, lst_y = zip(*error_data)
|
||||||
|
lst_x = list(lst_x)
|
||||||
|
lst_y = list(lst_y)
|
||||||
|
ax.plot(lst_x, lst_y, 'ro-')
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
for get_data in [get_data1, get_data2]:
|
||||||
|
data = get_data()
|
||||||
|
plot_data(data)
|
||||||
|
kmeans_data = {}
|
||||||
|
for k in range(2, 21):
|
||||||
|
kmeans_with_err = []
|
||||||
|
for _ in range(10):
|
||||||
|
all_data = []
|
||||||
|
centroids = init_centroids(data, k)
|
||||||
|
clusters = {}
|
||||||
|
for i in range(k):
|
||||||
|
clusters[i] = []
|
||||||
|
for point in data:
|
||||||
|
lengths = [calc_length(c, point) for c in centroids]
|
||||||
|
index_min = np.argmin(lengths)
|
||||||
|
clusters[index_min].append(point)
|
||||||
|
all_data.append((list(centroids), clusters))
|
||||||
|
for _ in range(100):
|
||||||
|
for key in clusters:
|
||||||
|
if clusters[key]:
|
||||||
|
centroids[key] = np.mean(clusters[key], axis=0)
|
||||||
|
clusters = {}
|
||||||
|
for i in range(k):
|
||||||
|
clusters[i] = []
|
||||||
|
for point in data:
|
||||||
|
lengths = [calc_length(c, point) for c in centroids]
|
||||||
|
index_min = np.argmin(lengths)
|
||||||
|
clusters[index_min].append(point)
|
||||||
|
all_data.append((list(centroids), clusters))
|
||||||
|
if all([all(np.isclose(all_data[-1][0][i], all_data[-2][0][i])) for i in range(k)]):
|
||||||
|
break
|
||||||
|
err = calc_error(centroids, clusters, k)
|
||||||
|
kmeans_with_err.append((all_data, err))
|
||||||
|
min_err = kmeans_with_err[0][1]
|
||||||
|
kmeans = kmeans_with_err[0][0]
|
||||||
|
for temp_kmeans, err in kmeans_with_err:
|
||||||
|
if err < min_err:
|
||||||
|
min_err = err
|
||||||
|
kmeans = temp_kmeans
|
||||||
|
kmeans_data[k] = (kmeans, min_err)
|
||||||
|
plot_kmeans(kmeans, k)
|
||||||
|
error_data = [[i, kmeans_data[i][1]] for i in range(2, 21, 2)]
|
||||||
|
plot_error_data(error_data)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
Loading…
Reference in New Issue
Block a user