diff --git a/zad3/data1.json b/zad3/data1.json new file mode 100644 index 0000000..763ca3a --- /dev/null +++ b/zad3/data1.json @@ -0,0 +1 @@ +[[0.26707464387488744, 0.874854952853165], [0.705437371776663, -0.34949767414070093], [-0.5123772864623964, 0.6642064994779734], [-0.6177949905888044, -0.08653100644282426], [0.2306447201597256, 0.15997733585871163], [-0.17981113007117286, -0.5662529061011463], [0.5210997923912827, -0.4328726372954992], [0.098444558908785, 0.9463812329618818], [-0.08349153183022817, -0.9536270256945143], [0.4733394651575273, -0.7495341015398066], [-0.6065455160048521, 0.1681939852588296], [0.33182453244095295, -0.5072685611099996], [0.9123852245993803, 0.20912199256357955], [0.5760493149872555, 0.5349177709112336], [-0.34401542908441646, -0.6829244249035414], [-0.5444059397567688, -0.8072194131021861], [-0.348044369392786, -0.7953176738397602], [0.3879884431435828, 0.6301030241070193], [0.3014344191306867, 0.1955985852142356], [0.2008763711882078, -0.6581882714512488], [-0.4298467671032608, 0.07063965484877796], [0.5714395567933921, -0.3411581119355833], [0.5130674745215648, 0.17082236936224793], [-0.31330386010949196, -0.7232347065738357], [0.3910415194714711, -0.8619910578621013], [0.37231010226526, -0.09949715555220831], [0.011172274885089867, 0.6957560233959222], [0.2042916905450732, 0.8933295265723123], [0.8772299632735555, -0.4599154183794484], [-0.6745371421227153, -0.6246224283985878], [-0.8548459622621303, 0.026203763525584668], [0.0003960555111902596, -0.2851993384779069], [0.13132744802058663, -0.8372460708371889], [0.6070920361038301, 0.7481629718139915], [-0.33085832366473544, 0.1265368675267927], [0.27903484181199883, 0.7651311160372127], [0.15800545539286157, -0.2511562767613446], [0.3743138554175838, 0.2591288724936859], [0.5814623997248104, 0.3735091839192642], [-0.43803851777676744, 0.02554483674462911], [-0.6332547100290917, -0.6126079756852014], [0.580874356779032, -0.5044583569611079], [-0.17533686527341164, -0.0041740262222879085], [-0.7944419041111201, 0.21566814287601763], [-0.8513159614224243, -0.45868850291904967], [0.6214933580430171, 0.3343114816776216], [-0.8543237265124555, 0.1106665871050883], [-0.5549208862251612, -0.6489338122175173], [0.45568997517605603, 0.6444683755231256], [-0.6048703175530287, -0.012780126690070353], [-0.16077571375999355, 0.5283910865401149], [0.10544050990938172, -0.5293124040022554], [-0.4167369042996385, 0.19455253960946214], [-0.1674194590747637, 0.571379676102006], [-0.675987752288574, 0.548631195266286], [0.23808787417822755, -0.5401261345428957], [0.875481281158248, 0.36696888311713516], [0.0032516528309900175, -0.1959173726852134], [-0.03264751267940699, -0.4423419510181568], [-0.2079660605331547, 0.4543529329466901], [0.19299466230881132, 0.28830685077653373], [0.5377729862858684, -0.7242490744221382], [0.217577389761842, 0.4612400302420327], [0.8207616971635188, -0.5242199691124612], [0.4150564189238725, 0.12168136108987514], [-0.45147811088354134, -0.3201352197111406], [-0.6375066490596355, -0.3906385783929273], [-0.21694262947054296, -0.4865687734231412], [-0.7536021068296546, -0.05953741704137909], [-0.2077938444479637, -0.820394986674174], [-0.413187840312592, 0.5419512753220819], [-0.0626200267476199, -0.4554300852485216], [-0.579173237508243, 0.38108877820925385], [-0.38639485556039166, 0.8236972812557243], [0.30285172576280256, 0.6973254196563166], [-0.18643077301152702, 0.8104664169281803], [-0.5056702408881234, -0.5645139129285524], [-0.19965409241022655, 0.020567382601630758], [0.7906613651062897, 0.4526155702182911], [0.22926769838935906, -0.09593168962690546], [0.2655874410918289, 0.775614912458297], [0.22446631035820688, 0.633307350304088], [-0.5252416870256676, -0.06105251694880269], [0.5350692099326566, 0.7171627237384379], [-0.07607819853592684, 0.5383928073863181], [-0.23214352774173766, -0.6687475643171613], [-0.945739293667666, 0.0278862483993709], [-0.9516580049184689, -0.20550528518977657], [0.432592032135796, -0.04357131293906445], [-0.15624644689389688, -0.9324537874524754], [0.9245085417241353, -0.37397169960687243], [-0.6264100577306938, -0.2698820209933208], [0.28159509543403355, -0.4085865726487092], [0.5111554547290363, 0.5996614930399004], [0.49207426523039743, 0.7772600457052795], [0.3979500438013701, -0.7575119633086485], [-0.11051877213392854, -0.7379223877397723], [0.25011225416657534, -0.2445013997226545], [0.535125138882435, -0.5058053749999026], [-0.05424425728455065, -0.09152221099526703], [-0.2076310905771738, 0.3906242865798983], [-0.20781512209852343, -0.8321550845773823], [0.804606639403553, -0.19200269746384477], [-0.3737101095721845, -0.4757043370633175], [0.09933587075907778, -0.9911738502964657], [-0.23469955934654554, 0.1187513870199458], [0.13712031648874362, 0.2582211898922505], [-0.3300882243125398, 0.9331843402218332], [0.08406339925746308, 0.31166138120285697], [-0.6025364808434707, -0.546624972487031], [0.030589506521855696, 0.20344277046948397], [-0.4889611552697711, -0.6726500006036894], [0.5259962114486337, 0.3335387807645858], [-0.3137622108076408, -0.22670442965431523], [-0.44041663701966516, -0.5362009295087813], [0.4289380762737359, -0.048396931462774594], [0.1402832721612381, -0.3005390384423581], [0.5518180374815734, -0.5513020444816217], [-0.6070097126426262, -0.6550079398529972], [-0.8545304113079695, -0.007231640909044598], [-0.08338154238719535, -0.2166365910739847], [-0.1961635872834149, -0.11846590845033937], [0.38982849820571147, -0.08116921140415956], [-0.29230039187553, 0.3852987341942868], [-0.23316835881075765, -0.46843184014002553], [0.28512272628547797, -0.9483365028210736], [-0.1603101940447538, -0.22365186300119747], [-0.05223840429221204, 0.19874583359019984], [0.9107689609322829, -0.3266972875245784], [-0.6440074053282822, -0.7448966170017388], [0.491856140519848, -0.14725042837283064], [0.16726756107863908, 0.895634630224659], [0.07730244196043036, -0.2720983094439567], [0.2153401146789877, -0.32970289649346834], [0.8764960643859633, 0.2515791663506035], [0.8708466012093414, 0.018596395338106465], [-0.31383417975851685, -0.5462977493947762], [0.3664271681544787, 0.7887794737819468], [0.8256700081039111, 0.42333404045101136], [0.6272771405060873, -0.4581639888769162], [-0.3934007907867641, 0.3158847851771906], [0.26590206006873024, 0.2139776944534792], [0.2062872294068216, -0.6501984964808394], [0.7860492933093629, 0.41427593342742963], [-0.12931368187382483, 0.7246259301611381], [-0.39461760494605314, 0.8444501205939923], [0.28071962307289605, 0.756034366150398], [0.43394991199815264, 0.8774794786784456], [-0.7691508542109727, 0.2589398437546813], [0.09579132524007974, 0.2202897242078923], [0.0419756576546234, 0.2851294155669188], [-0.5206825514617193, 0.42956759094004604], [0.7258441254347748, 0.4726294763955478], [0.6965162807543556, -0.6893337153195686], [0.9430149787858686, -0.07652252627742206], [0.2973264540027509, 0.1573255247340431], [-0.17432236049189564, 0.8550399548199028], [0.7998607726714959, -0.42236640603834674], [0.7834998977523709, 0.5280985809199821], [0.1758999533469528, 0.5496425899816588], [-0.7848407264598516, -0.37433119334672527], [0.9731192557956453, -0.08341222019687589], [-0.5484007683416774, -0.5754807021960964], [-0.15573072332823795, 0.6818640926455134], [-0.480603524450811, -0.04219705610377576], [0.4820153845413311, 0.5390940165684492], [0.4327404654547406, 0.4624574226997848], [0.8576526155768542, -0.071000300197826], [-0.44079058955980566, -0.1975700482760791], [-0.14980077554127644, 0.7726978304420852], [-0.15451803172992853, 0.09221539239600149], [-0.6825668816326079, 0.2743592085009418], [-0.08621696454063776, -0.6719703783743122], [0.645363614377212, 0.06484922797936292], [-0.44869096596681224, -0.49329119573696434], [-0.05056987301920263, 0.15561814437907556], [0.4877274877308552, -0.5865878277205919], [0.6199083504842986, 0.20159098946156892], [0.21428827727021013, -0.30490520356949363], [-0.19294383016998534, 0.6572588971144847], [0.43206332177260437, -0.4335597085144989], [0.3278468969238784, 0.6825749582585258], [-0.11943115869267891, 0.3344624110574163], [0.47683035026220044, 0.2805787197991178], [-0.7355239533092265, 0.34146320656797546], [-0.5474200793019207, -0.2677571401453806], [-0.6624711459621938, -0.17147727083699515], [0.07210704572265987, -0.7696809013692562], [-0.7178512893709064, -0.5495526519705619], [-0.03245671564899693, -0.1687385357341344], [0.7100367505310209, 0.4349894290349571], [0.38405185254792357, 0.5098819944163963], [0.1901563778159929, -0.7614200952516973], [0.22742031156137307, -0.7152521320464627], [-0.413834733449932, 0.44357088542636713], [0.24638896213563324, 0.15732317617858493], [0.0024132713865741497, 0.4751052090842128], [-0.07908333245546112, 0.39262775955040946], [-0.9571926690537242, -0.019760469155857115], [-0.007295379318281298, -0.9106660061499262]] \ No newline at end of file diff --git a/zad3/data1_errors.png b/zad3/data1_errors.png new file mode 100644 index 0000000..290f4db --- /dev/null +++ b/zad3/data1_errors.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ce4ddc242499f89264fdb8ba0d43d995cc617c2180a17c49be9e3a6a41612a82 +size 20193 diff --git a/zad3/data2.json b/zad3/data2.json new file mode 100644 index 0000000..7818ee5 --- /dev/null +++ b/zad3/data2.json @@ -0,0 +1 @@ +[[3.094031536539631, 0.39478546662283415], [2.7296437622770164, -0.2613398001285549], [2.9699785275240655, 0.4463305680463877], [2.63945978783313, -0.2860601346977574], [2.918014288162325, -0.3568879841771503], [2.9757077054866814, 0.1976143367014519], [3.030438075231262, 0.4250929929744947], [3.389530100177308, -0.28379538762741097], [3.1280427502674355, -0.12422476057554413], [3.451993640418272, 0.11973477733048372], [2.825035233117628, 0.16844395811472174], [3.250405577067777, -0.3676711376419543], [2.9296598477439826, -0.26590506043713863], [2.6794456023856763, -0.12101255993467745], [3.237000886012415, 0.3819004399412394], [2.661815093143124, 0.22666139264932847], [2.6713788682959203, -0.34563361203295895], [3.240482686883239, 0.1677585999982612], [3.0973924359504146, 0.37249814406887993], [2.8380071737589554, 0.07353401757908563], [2.855735732068298, 0.4727569854083527], [2.8220073260916707, -0.0020349015143730337], [3.007589963913807, 0.357216588461496], [2.819100032462428, 0.0026571757221326713], [3.2508800545495435, -0.3498361050839355], [3.2549892755756202, 0.14644705092779428], [3.269429777042236, 0.3465606634114739], [3.0147066002832337, 0.13524061815175958], [2.8697470754307726, 0.1372217971061337], [3.153536760377613, -0.1735954819644789], [2.9196684366844967, 0.3928049158589175], [2.7741725972722118, 0.3174951846161572], [2.6246826297124923, -0.24390228004865894], [2.6465821469347914, 0.10099189983994017], [2.9894162122480594, -0.018957683361648596], [3.272372524749967, -0.3182278360178188], [2.7845186439092413, 0.2330590136221214], [3.3241642272982816, 0.19936259900087505], [2.637239228934715, -0.14945730545751948], [3.0569695904528102, -0.14168723128445732], [3.136663061100583, 0.06909264349911241], [2.8432459255435365, 0.23629856040733016], [2.9964783703067464, 0.36366748306186203], [3.0352668093513127, 0.45785855252148266], [2.7977989505847995, 0.40416634824352954], [2.8036737928213635, -0.22195768515214714], [2.8633099806245417, -0.04172977987118372], [2.8053201565064385, 0.4083142710359155], [3.2937899779903526, 0.02502122682780065], [3.3961471519685023, 0.11548942625563141], [2.8980582533006434, 0.16170524524275945], [3.052396561374095, -0.14751345927434364], [3.4808088724687383, 0.10776147953507777], [3.2320208834063897, 0.13902064532866526], [2.8723401627154423, -0.0618408361738482], [3.0110168800587758, -0.1975882364041726], [3.4435751747516044, 0.17336389672014899], [3.123770891508169, 0.218750700325495], [3.233468416647203, 0.39261047141457234], [3.3088259360819237, -0.02404211923635316], [3.2911137049215187, 0.19993310691994767], [2.9839675824264753, 0.34622554985083004], [2.7681799597580254, 0.37527818718611217], [2.659068813259608, -0.30485871211723553], [2.681049215993094, -0.05365454362877191], [3.3554716203575654, 0.04702582517075264], [2.968515135235644, 0.10167015360697616], [3.0281443502162673, 0.25745990983509415], [2.7522164367568043, -0.05010853276984887], [2.848107514587062, 0.28789197706160663], [2.689297538332607, 0.057032666174576904], [3.251593036340039, -0.006296478982475766], [2.8623357938003884, -0.33271560228443925], [2.9021321837186864, -0.034216129493711514], [2.8783397935406203, -0.20830490016032274], [3.37052245934321, -0.17867832853970655], [3.256738044521217, -0.14306882122389006], [3.051519849699072, -0.37063523405290044], [2.586013195778654, 0.13251703170846652], [2.8092383522372564, -0.3141947031835265], [3.070372515141633, 0.029622219888343418], [3.169710799172379, -0.19918120307397788], [2.6836256482768492, -0.2608653430820558], [2.985741891558439, -0.41565693178489244], [2.5737435060748637, -0.1474698365372057], [3.085649617057237, -0.47876543660040344], [3.4731902976084563, -0.069047382669738], [3.100161356611954, 0.23655154109851398], [3.0371434142328897, 0.27078625809072426], [3.449409419239475, 0.13355750906227024], [2.6647132454964297, 0.2932420879849981], [2.8122583414926825, 0.40098434255985227], [3.2520531486756967, 0.21151280977179407], [3.3970822388918718, -0.024129974958073622], [2.922083319441422, 0.1404386700531341], [3.2895817368399043, 0.3959636346306004], [3.073212126649572, -0.003508802617335682], [3.4527549459670466, -0.13271925061247078], [3.478178772683139, 0.09114978420919956], [2.893451017778235, -0.4373461986385599], [-2.984661213344486, 0.3875116986004827], [-2.611250198500886, -0.10896780970899189], [-2.83778554430435, -0.339613851311361], [-3.186936813845085, 0.1508338183285309], [-2.7775126048753775, 0.22246292314371108], [-2.8693125105494026, -0.34663681920880113], [-3.301292170394524, 0.2812956497152523], [-2.863555961893672, 0.005846716286817034], [-3.1400519769700734, 0.16912627936262317], [-3.3579728826709716, 0.18339066269473084], [-2.9794415309707456, 0.371048505060653], [-2.9619734917349487, -0.4162590320591322], [-3.1916434843484494, 0.3930596800257602], [-2.739676519020144, -0.2746237164651906], [-3.477287708280446, 0.08794583763286529], [-2.641179337577712, -0.24656055869765633], [-3.192958828685941, -0.40596441987355275], [-2.66728018534682, 0.015296292921890971], [-3.243394681782726, 0.10914935090626089], [-2.5767087902449686, 0.018994912120394444], [-3.046120594438075, 0.24505475044899694], [-2.8713012749005804, 0.05812974957644922], [-2.919742470528474, -0.11427546984819405], [-3.2280803942222254, 0.29531970707687344], [-2.7923064976652885, 0.2876803068611328], [-2.6661635242567425, 0.06392652067534244], [-2.8953909180322936, -0.46805641950266147], [-2.8405518401536907, 0.4539577714105206], [-2.9643072281192095, 0.49576918695805], [-3.1776044128951595, 0.15788776534829027], [-2.5119394469075167, 0.09146153035024664], [-2.808327177612772, -0.3193890078204129], [-3.145685547182948, -0.10268648823123505], [-3.1681732349983465, 0.21764236848494273], [-3.2273337385007634, 0.20143052611738516], [-3.293054609846102, -0.12775497188853152], [-3.4181473943166676, -0.11936026002227963], [-2.9854856336987905, -0.058646691635120145], [-3.12324790932266, 0.0724753631883691], [-2.734801869963408, 0.186108901818203], [-3.233654440356901, 0.1739772970072226], [-3.1732921252809234, 0.0785963160474773], [-2.592929360195901, -0.0008697400348726154], [-2.987739366344891, 0.3763207964230079], [-2.9868452789526687, -0.13323919067887582], [-2.8269434436899252, -0.1370143514621371], [-3.393552903260004, -0.25127504544422136], [-3.18573137056079, -0.24541635139031076], [-2.812378646039027, 0.25597213923369605], [-2.5546310976475106, -0.21179261713096403], [-2.7127053938017753, 0.25945780539627517], [-3.042619727023157, 0.3218361308505062], [-2.6191384550299905, 0.22030429326165332], [-2.6588562737455717, -0.22697491093214317], [-3.32454514542011, 0.05825464269714356], [-2.737878522815576, 0.15617247487410915], [-2.61286073710987, -0.08792486137119064], [-2.725256114187041, 0.35509737798340385], [-2.552421791685839, 0.19987662809885076], [-3.1705068727578745, -0.10879966386664662], [-2.9246883344324823, 0.29450397032336273], [-2.6849287864788867, -0.2501522740327516], [-3.134549732214868, 0.24866652991145977], [-3.349985596435894, 0.14676816933438713], [-3.4512250960186965, -0.11456847714892437], [-3.189767847888351, 0.04811038004418914], [-3.1800350188076187, 0.41503807650395735], [-2.889347790242332, 0.10117300853769502], [-3.407514580224879, 0.1148442248714858], [-2.638066220215395, -0.19221983313237395], [-2.768419727901171, 0.3068404691134517], [-2.5529014744225655, 0.0909389460801169], [-3.1205909226384008, 0.03278732824746637], [-3.20740923255135, -0.41138779834111655], [-2.9002919587629785, -0.4508037670919345], [-2.6788354159452887, -0.3378152019068404], [-3.1394502892455636, -0.16792947872411246], [-2.608904732781579, 0.09704789682820804], [-2.670526869776713, -0.14413113310196823], [-3.087606327613503, -0.4704164151817391], [-3.1021139014913386, 0.43828758779774557], [-2.727284797406029, -0.13833010046518265], [-3.0893679308834443, 0.30969673013938404], [-2.6286455012661873, 0.041785307705411535], [-2.6992017234281747, -0.39429175937351824], [-3.2034814132240705, 0.1692125654157948], [-2.9177739815158388, -0.3839218347527531], [-3.269781028177342, 0.06555798500274322], [-2.916449359832331, -0.15600293952455863], [-3.1304632341282743, -0.37002790566139476], [-3.0423607822154057, -0.19438034492019438], [-3.3219599062569656, -0.10562088889619237], [-2.831192371032275, 0.3568987800579648], [-2.6404856144561024, 0.20506082317330807], [-3.0010621180822685, 0.34669633155897933], [-2.8557056080926624, 0.21155027147580938], [-2.6323820641815723, -0.1962197433392866], [-3.32019282368624, -0.18822375581559767], [-2.552299053442381, 0.022675498993799725], [-3.2679792167958905, -0.33509612824806023]] \ No newline at end of file diff --git a/zad3/data2_errors.png b/zad3/data2_errors.png new file mode 100644 index 0000000..bce4f2c --- /dev/null +++ b/zad3/data2_errors.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c73cce33452f045cd249468cfa1de26a3e9048a0452680aceb217102ea9fe0e5 +size 23167 diff --git a/zad3/ml_195642_zad3.odt b/zad3/ml_195642_zad3.odt new file mode 100644 index 0000000..b1b0004 --- /dev/null +++ b/zad3/ml_195642_zad3.odt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6e7a4a417d8c0c27e15b86dc40de94e044ee3152320f923b2021c155c355a633 +size 708612 diff --git a/zad3/zad3.py b/zad3/zad3.py index 0817f95..9589a2f 100644 --- a/zad3/zad3.py +++ b/zad3/zad3.py @@ -1,8 +1,12 @@ import matplotlib.pyplot as plt from matplotlib.animation import FuncAnimation -from random import sample +from random import sample, shuffle from generate_points import get_random_point import numpy as np +import json + + +METHODS = ['forgy', 'random_partition'] def get_color(i): @@ -49,11 +53,12 @@ def plot_kmeans(all_data, k): cluster_scatters = [] centroids, clusters = all_data[0] for key in clusters: - lst_x, lst_y = zip(*clusters[key]) - lst_x = list(lst_x) - lst_y = list(lst_y) color = get_color(key/k) - cluster_scatters.append(ax.scatter(lst_x, lst_y, color=color)) + if clusters[key]: + lst_x, lst_y = zip(*clusters[key]) + lst_x = list(lst_x) + lst_y = list(lst_y) + cluster_scatters.append(ax.scatter(lst_x, lst_y, color=color)) centroid_scatters.append(ax.scatter([centroids[key][0]], [ centroids[key][1]], color=color, marker='X')) @@ -77,10 +82,16 @@ def calc_length(a, b): return (b[0]-a[0])**2+(b[1]-a[1])**2 -def init_centroids(data, k, method='forgy'): #TODO: Add k-means++ and Random Partition +def init_centroids(data, k, method='forgy'): # TODO: Add k-means++ and Random Partition match method: case 'forgy': return sample(data, k) + case 'random_partition': + shuffled = list(data) + shuffle(shuffled) + div = len(shuffled)/k + partition = [shuffled[int(round(div*i)):int(round(div*(i+1)))] for i in range(k)] + return [np.mean(prt, axis=0) for prt in partition] case _: raise NotImplementedError( f'method {method} is not implemented yet') @@ -91,9 +102,9 @@ def calc_error(centroids, clusters, k): for i in range(k): cluster = np.array(clusters[i]) centroid = np.array([centroids[i] for _ in range(len(cluster))]) - errors = centroid - cluster + errors = cluster - centroid squared_errors.append([e**2 for e in errors]) - return sum([np.mean(err) for err in squared_errors]) + return sum([np.mean(err) if err else 0 for err in squared_errors]) def plot_error_data(error_data): @@ -112,28 +123,29 @@ def plot_error_data(error_data): plt.show() -def main(): - for get_data in [get_data1, get_data2]: - data = get_data() +def print_stats(k, data): + print('='*20) + print(f'k={k}') + errs = [x[1] for x in data] + m = np.mean(errs) + std = np.std(errs) + min_err = np.min(errs) + lst_empty = [sum([1 for cluster in centroids_with_clusters[1] if not cluster]) for centroids_with_clusters,_ in data] + print(lst_empty) + + +def main(datas): + # for get_data in [get_data1, get_data2]: + # data = get_data() + for data in datas: plot_data(data) - kmeans_data = {} - for k in range(2, 21): - kmeans_with_err = [] - for _ in range(100): - all_data = [] - centroids = init_centroids(data, k) - clusters = {} - for i in range(k): - clusters[i] = [] - for point in data: - lengths = [calc_length(c, point) for c in centroids] - index_min = np.argmin(lengths) - clusters[index_min].append(point) - all_data.append((list(centroids), clusters)) + for method in METHODS: + kmeans_data = {} + for k in [20]: # range(2, 21): + kmeans_with_err = [] for _ in range(100): - for key in clusters: - if clusters[key]: - centroids[key] = np.mean(clusters[key], axis=0) + centroids_with_clusters = [] + centroids = init_centroids(data, k, method=method) clusters = {} for i in range(k): clusters[i] = [] @@ -141,22 +153,42 @@ def main(): lengths = [calc_length(c, point) for c in centroids] index_min = np.argmin(lengths) clusters[index_min].append(point) - all_data.append((list(centroids), clusters)) - if all([all(np.isclose(all_data[-1][0][i], all_data[-2][0][i])) for i in range(k)]): - break - err = calc_error(centroids, clusters, k) - kmeans_with_err.append((all_data, err)) - min_err = kmeans_with_err[0][1] - kmeans = kmeans_with_err[0][0] - for temp_kmeans, err in kmeans_with_err: - if err < min_err: - min_err = err - kmeans = temp_kmeans - kmeans_data[k] = (kmeans, min_err) - plot_kmeans(kmeans, k) - error_data = [[i, kmeans_data[i][1]] for i in range(2, 21, 2)] - plot_error_data(error_data) + centroids_with_clusters.append((list(centroids), clusters)) + for _ in range(100): + for key in clusters: + if clusters[key]: + centroids[key] = np.mean(clusters[key], axis=0) + clusters = {} + for i in range(k): + clusters[i] = [] + for point in data: + lengths = [calc_length(c, point) + for c in centroids] + index_min = np.argmin(lengths) + clusters[index_min].append(point) + centroids_with_clusters.append( + (list(centroids), clusters)) + if all([all(np.isclose(centroids_with_clusters[-1][0][i], centroids_with_clusters[-2][0][i])) for i in range(k)]): + break + err = calc_error(centroids, clusters, k) + kmeans_with_err.append((centroids_with_clusters, err)) + print_stats(k, [(iterations[-1],err) for iterations, err in kmeans_with_err]) + min_err = kmeans_with_err[0][1] + kmeans = kmeans_with_err[0][0] + for temp_kmeans, err in kmeans_with_err: + if err < min_err: + min_err = err + kmeans = temp_kmeans + kmeans_data[k]=(kmeans, min_err) + plot_kmeans(kmeans, k) + #error_data = [[i, kmeans_data[i][1]] for i in range(2, 21, 2)] + #plot_error_data(error_data) if __name__ == '__main__': - main() + datas = [] + with open('data1.json', 'r') as d: + datas.append(json.loads(d.read())) + with open('data2.json', 'r') as d: + datas.append(json.loads(d.read())) + main(datas)