ML - Clustering


Developed Countries

World map showing country classifications as per the IMF (International Monetary Fund) and the UN (United Nations) (last updated 2022).

  • Blue: Developed countries

  • Orange: Developing countries

  • Red: Least developed countries

  • Gray: Data unavailable

Most commonly, the criteria for evaluating the degree of economic development are gross domestic product (GDP), gross national product (GNP), per capita income, the level of industrialization, the extent of infrastructure, and the general standard of living.

Question: Can we categorize countries based on these features without having any labels to begin with? Why are there only three categories?

Answer: Clustering! Clustering groups elements into categories without any previous labels, and the number of categories is something we choose.


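A standard clustering algorithm is K-means:

  • Input: Feature matrix \(X\) and a hyper-parameter \(K\) that determines the number of clusters.

  • Output: Cluster centroids (\(\mu_k\), \(k=1, 2, \ldots, K\)) and a label for each row of \(X\) indicating which cluster it belongs to.

K-means looks for the clusters \(C_k\) and centroids \(\mu_k\) that minimize the total squared distance between each element and the centroid of its cluster: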

\[ \min_{C_k, \mu_k} \sum_{k=1}^K \sum_{x_i \in C_k} \left\lVert x_i - \mu_k \right\rVert^2_2 \]

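Finding the exact minimum of this objective is an NP-hard problem: no polynomial-time algorithm for it is known, and none exists unless P = NP. (NP-hard problems are at least as hard as the hardest problems in NP.) In practice, we therefore use heuristics that converge to a local minimum, such as Lloyd's algorithm.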
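To make the objective and its difficulty concrete, here is a small NumPy sketch. The function names are illustrative and do not come from a library. `kmeans_objective` evaluates the sum above for a given labeling. `brute_force_kmeans` finds the exact minimum by trying every one of the \(K^n\) possible labelings. That works for four points, but for the 142 countries used below with \(K = 3\) there would be \(3^{142}\) (more than \(10^{67}\)) labelings. Exhaustive search only illustrates this blow-up; it is not a proof of NP-hardness. For a fitted scikit-learn `KMeans` model, the same objective value is exposed as `inertia_`.

```python
import itertools

import numpy as np


def kmeans_objective(X, labels, centroids):
    """Sum over clusters of the squared Euclidean distances ||x_i - mu_k||^2."""
    diffs = X - centroids[labels]  # row i holds x_i - mu_k for the cluster k containing x_i
    return float(np.sum(diffs ** 2))


def brute_force_kmeans(X, K):
    """Exact minimizer found by trying all K**n labelings (only feasible for tiny n).

    For a fixed labeling the best centroids are the cluster means, so searching
    over labelings is enough. Empty clusters keep a placeholder centroid that no
    point uses, so it does not affect the cost.
    """
    X = np.asarray(X, dtype=float)
    best_cost, best_labels = np.inf, None
    for labeling in itertools.product(range(K), repeat=len(X)):
        labels = np.array(labeling)
        centroids = np.zeros((K, X.shape[1]))
        for k in range(K):
            if np.any(labels == k):
                centroids[k] = X[labels == k].mean(axis=0)
        cost = kmeans_objective(X, labels, centroids)
        if cost < best_cost:
            best_cost, best_labels = cost, labels
    return best_cost, best_labels


X = np.array([[0.0], [1.0], [10.0], [11.0]])
cost, labels = brute_force_kmeans(X, K=2)
print(cost, labels)  # 1.0 [0 0 1 1]
print(2 ** len(X))   # 16 labelings searched
```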

Lloyd's Algorithm

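Lloyd's algorithm is the standard heuristic for the K-means problem. It starts from \(K\) initial centroids (for example, \(K\) randomly chosen rows of \(X\)) and labels each element with its closest centroid. It then repeats two steps until the labels stop changing:

  1. Compute the centroid of each cluster by averaging the positions of the elements currently in the cluster.

  2. Update the cluster label of each element by assigning it to the closest centroid.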
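The two steps can be sketched in a few lines of NumPy. This is a minimal illustration, not scikit-learn's implementation. The function names `lloyd` and `closest_centroid`, the random initialization from rows of \(X\), and the stopping rule (stop when the labels no longer change, or after `max_iter` iterations) are assumptions made for this sketch. By default, scikit-learn's `KMeans` uses k-means++ initialization instead.

```python
import numpy as np


def closest_centroid(X, centroids):
    """Label of the nearest centroid (squared Euclidean distance) for each row of X."""
    dists = ((X[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)  # shape (n, K)
    return dists.argmin(axis=1)


def lloyd(X, K, max_iter=100, seed=0):
    """Minimal Lloyd's algorithm. Returns (centroids, labels)."""
    X = np.asarray(X, dtype=float)
    rng = np.random.default_rng(seed)
    # Initialization (assumption): K distinct rows of X as starting centroids (a copy, not a view)
    centroids = X[rng.choice(len(X), size=K, replace=False)]
    labels = closest_centroid(X, centroids)
    for _ in range(max_iter):
        # Step 1: move each centroid to the mean of its current members
        for k in range(K):
            if np.any(labels == k):  # keep the old centroid if the cluster is empty
                centroids[k] = X[labels == k].mean(axis=0)
        # Step 2: relabel each element with its closest centroid
        new_labels = closest_centroid(X, centroids)
        if np.array_equal(new_labels, labels):
            break  # labels stopped changing, so the centroids will not move either
        labels = new_labels
    return centroids, labels


X = np.array([[0.0, 0.0], [0.5, 0.0], [0.0, 0.5],
              [10.0, 10.0], [10.5, 10.0], [10.0, 10.5]])
centroids, labels = lloyd(X, K=2)
print(labels)
print(centroids)
```

Each run only reaches a local minimum that depends on the starting centroids. This is why scikit-learn's `KMeans` repeats the procedure `n_init` times and keeps the result with the lowest objective.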

In this case, one video is worth more than a thousand pictures.

from IPython.display import YouTubeVideo


import numpy as np
import pandas as pd

from pathlib import Path
filepath = ""
# filepath = Path().resolve().parent / "data" / "gapminder.csv"  # If you are running locally
data = pd.read_csv(filepath, usecols=[1, 5, 6])  # columns: country, life_exp, gdp_cap
data.head()
country life_exp gdp_cap
0 Afghanistan 43.828 974.580338
1 Albania 76.423 5937.029526
2 Algeria 72.301 6223.367465
3 Angola 42.731 4797.231267
4 Argentina 75.320 12779.379640
from sklearn.cluster import KMeans

K = 3
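kmeans = KMeans(n_clusters=K, n_init=10)  # n_init set explicitly to avoid scikit-learn's FutureWarning
kmeans.fit(data.set_index("country"))  # cluster on life_exp and gdp_cap; country names become the index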
data["label"] = kmeans.labels_  # cluster index (0, 1 or 2) for each country
data.head()
country life_exp gdp_cap label
0 Afghanistan 43.828 974.580338 0
1 Albania 76.423 5937.029526 0
2 Algeria 72.301 6223.367465 0
3 Angola 42.731 4797.231267 0
4 Argentina 75.320 12779.379640 2
data.query("label == 2")
country life_exp gdp_cap label
4 Argentina 75.320 12779.379640 2
13 Botswana 50.728 12569.851770 2
14 Brazil 72.390 9065.800825 2
15 Bulgaria 73.005 10680.792820 2
23 Chile 78.553 13171.638850 2
29 Costa Rica 78.782 9645.061420 2
31 Croatia 75.748 14619.222720 2
32 Cuba 78.273 8948.102923 2
33 Czech Republic 76.486 22833.308510 2
40 Equatorial Guinea 51.579 12154.089750 2
45 Gabon 56.735 13206.484520 2
56 Hungary 73.338 18008.944440 2
60 Iran 70.964 11605.714490 2
70 Korea, Rep. 78.623 23348.139730 2
72 Lebanon 71.993 10461.058680 2
75 Libya 73.952 12057.499280 2
78 Malaysia 74.241 12451.655800 2
81 Mauritius 72.801 10956.991120 2
82 Mexico 76.195 11977.574960 2
84 Montenegro 74.543 9253.896111 2
96 Oman 75.640 22316.192870 2
98 Panama 75.537 9809.185636 2
102 Poland 75.563 15389.924680 2
103 Portugal 78.098 20509.647770 2
104 Puerto Rico 78.746 19328.709010 2
106 Romania 72.476 10808.475610 2
109 Saudi Arabia 72.777 21654.831940 2
111 Serbia 74.002 9786.534714 2
114 Slovak Republic 74.663 18678.314350 2
117 South Africa 49.339 9269.657808 2
129 Trinidad and Tobago 69.819 18008.509240 2
131 Turkey 71.777 8458.276384 2
135 Uruguay 76.384 10611.462990 2
136 Venezuela 73.747 11415.805690 2
data.query("label == 1")
country life_exp gdp_cap label
5 Australia 81.235 34435.36744 1
6 Austria 79.829 36126.49270 1
7 Bahrain 75.635 29796.04834 1
9 Belgium 79.441 33692.60508 1
20 Canada 80.653 36319.23501 1
34 Denmark 78.332 35278.41874 1
43 Finland 79.313 33207.08440 1
44 France 80.657 30470.01670 1
47 Germany 79.406 32170.37442 1
49 Greece 79.483 27538.41188 1
55 Hong Kong, China 82.208 39724.97867 1
57 Iceland 81.757 36180.78919 1
62 Ireland 78.885 40675.99635 1
63 Israel 80.745 25523.27710 1
64 Italy 80.546 28569.71970 1
66 Japan 82.603 31656.06806 1
71 Kuwait 77.588 47306.98978 1
90 Netherlands 79.762 36797.93332 1
91 New Zealand 80.204 25185.00911 1
95 Norway 80.196 49357.19017 1
113 Singapore 79.972 47143.17964 1
115 Slovenia 77.926 25768.25759 1
118 Spain 80.941 28821.06370 1
122 Sweden 80.884 33859.74835 1
123 Switzerland 81.701 37506.41907 1
125 Taiwan 78.400 28718.27684 1
133 United Kingdom 79.425 33203.26128 1
134 United States 78.242 42951.65309 1
data.query("label == 0")
country life_exp gdp_cap label
0 Afghanistan 43.828 974.580338 0
1 Albania 76.423 5937.029526 0
2 Algeria 72.301 6223.367465 0
3 Angola 42.731 4797.231267 0
8 Bangladesh 64.062 1391.253792 0
... ... ... ... ...
137 Vietnam 74.249 2441.576404 0
138 West Bank and Gaza 73.422 3025.349798 0
139 Yemen, Rep. 62.698 2280.769906 0
140 Zambia 42.384 1271.211593 0
141 Zimbabwe 43.487 469.709298 0

80 rows × 4 columns
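The groups appear to follow `gdp_cap` almost exclusively. For example, Botswana (life expectancy 50.7) and the Czech Republic (76.5) land in the same cluster. A plausible explanation is that K-means uses unscaled Euclidean distance: `gdp_cap` differences are in the thousands, while `life_exp` differences are tens of years. The short sketch below computes how much each feature contributes to the squared distance between those two rows. The values are copied from the output above.

```python
import numpy as np

# Rows copied from the output above, as (life_exp, gdp_cap)
botswana = np.array([50.728, 12569.851770])
czech_republic = np.array([76.486, 22833.308510])

sq_diff = (botswana - czech_republic) ** 2  # squared difference per feature
share = sq_diff / sq_diff.sum()             # fraction of the squared Euclidean distance
for name, s in zip(["life_exp", "gdp_cap"], share):
    print(f"{name}: {s:.4%}")
```

If both features should count, a common remedy is to standardize the columns before calling `fit` (for example, with scikit-learn's `StandardScaler`). The resulting clusters may then differ from the ones shown here.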

Another Example

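Clustering can also compress images. Treat each pixel's RGB values as a row of \(X\), group the pixels into \(K\) clusters, and replace every pixel with its cluster centroid. The resulting image contains only \(K\) distinct colors, so each pixel can be stored as a small cluster label instead of three full color bytes.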
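The size reduction can be estimated with simple arithmetic. An uncompressed RGB pixel takes 3 bytes (24 bits). After quantizing to \(K\) colors, each pixel only needs a label of \(\lceil \log_2 K \rceil\) bits, plus one shared palette of \(K\) colors. The sketch below is idealized: it assumes tightly packed labels and ignores file-format overhead. The 640 × 480 image size is hypothetical, and the function name `quantized_size_bits` is illustrative.

```python
import math


def quantized_size_bits(n_pixels, K, channels=3, bits_per_channel=8):
    """Idealized size of a K-color quantized image: packed labels plus a palette."""
    label_bits = n_pixels * math.ceil(math.log2(K))  # one label in [0, K) per pixel
    palette_bits = K * channels * bits_per_channel    # the K centroid colors
    return label_bits + palette_bits


n_pixels = 640 * 480              # hypothetical image size
original_bits = n_pixels * 3 * 8  # 24 bits per RGB pixel
compressed_bits = quantized_size_bits(n_pixels, K=8)
print(f"compression ratio: {original_bits / compressed_bits:.2f}x")  # compression ratio: 8.00x
```

With \(K = 8\), as in the code that follows, each pixel needs 3 bits instead of 24, so the ratio approaches 8 for large images. Real image formats store and compress data differently, so actual file sizes will vary.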

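from PIL import Image
import numpy as np
import requests
from sklearn.cluster import KMeans

url = ""
im_filepath = requests.get(url, stream=True).raw  # file-like object with the raw image bytes
im = Image.open(im_filepath).convert("RGB")  # ensure exactly 3 channels (R, G, B)
K = 8  # Number of clusters, i.e. number of colors in the compressed image
X = np.array(im.getdata())  # (n_pixels, 3) array of RGB values
kmeans = KMeans(n_clusters=K, n_init=10).fit(X)  # n_init set explicitly to avoid the FutureWarning
compressed_array = kmeans.cluster_centers_[kmeans.predict(X)]  # Replace each pixel by its centroid color
im_compressed = compressed_array.astype(np.uint8).reshape(im.size[1], im.size[0], 3)  # New image (height, width, 3)
Image.fromarray(im_compressed)  # uint8 array of shape (H, W, 3) is interpreted as RGB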