Data Visualization

The objectives of this lesson are to interpret the basics of visualization theory and to apply them to different datasets using Python libraries.

Why plots?

  • 70% of the sensory receptors in the human body are dedicated to vision.

  • The brain has been shaped by evolution to process very large amounts of visual information.

“The eye and the visual cortex of the brain form a massively parallel processor that provides the highest bandwidth channel into human cognitive centers.”

Colin Ware, Information Visualization, 2004

Theory

Human Visual System

  • Good news

    • Graphs provide information that statistics might not.

    • Visual display is essential for understanding.

  • Bad news

    • Attention is selective and can easily be misled.

Advice

Noah Iliinsky, in his talk Four Pillars of Visualization, gives practical advice on how to design an effective visualization. The four pillars are:

  • Purpose

  • Information/Content

  • Encoding/Structure

  • Format

In other words, a visualization is not simply a by-product of whatever data happens to be available: it is designed and thought out first, and only then are appropriate sources of information sought.

Elements for creating a good visualization

  1. Honesty: visual representations must not mislead the viewer.

  2. Prioritization: the most important data should be encoded with the most easily perceived visual elements.

  3. Expressiveness: data should be encoded with visual elements whose attributes suit that data.

  4. Consistency: the visual encoding must be applied consistently, so that the data can be recovered from it.

The basic principle to respect is that a reader must be able to easily recover the original data from the graph.

Matplotlib

Matplotlib is a comprehensive library for creating static, animated, and interactive visualizations in Python. Matplotlib makes easy things easy and hard things possible.

  • Create publication quality plots.

  • Make interactive figures that can zoom, pan, update.

  • Customize visual style and layout.

  • Export to many file formats.

  • Embed in JupyterLab and Graphical User Interfaces.

  • Use a rich array of third-party packages built on Matplotlib.

Source: https://matplotlib.org/

Figure: anatomy of a Matplotlib figure.
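The anatomy figure labels the parts that make up every Matplotlib plot. The following minimal sketch, on made-up data, creates the main parts explicitly with the object-oriented interface: a `Figure` container holding one `Axes`, plus the lines, title, axis labels, legend, and grid attached to that `Axes`.

```python
import numpy as np
import matplotlib.pyplot as plt

# Illustrative data only (not from the lesson's dataset)
t = np.linspace(0.0, 2.0, 50)

# The Figure is the top-level container; the Axes is the plotting area inside it
fig, ax = plt.subplots(figsize=(6, 4))

# Each plotting call adds an Artist (here a Line2D) to the Axes
ax.plot(t, t**2, label="t^2")
ax.plot(t, t, linestyle="--", label="t")

# Title, axis labels, legend and grid all belong to the Axes
ax.set_title("Anatomy example")
ax.set_xlabel("t")
ax.set_ylabel("value")
ax.legend()
ax.grid(True)
```

In a script, call `plt.show()` or `fig.savefig(...)` to display or save the result; in a notebook the figure is rendered automatically.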

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

We will use the same Heart Disease dataset that we used in the last activity.

https://archive-beta.ics.uci.edu/dataset/45/heart+disease

# Missing values in the raw file are written as "?", so tell pandas to read them as NaN
heart_df = pd.read_csv(
    "https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.cleveland.data",
    names=[
        "age",
        "sex",
        "cp",
        "trestbps",
        "chol",
        "fbs",
        "restecg",
        "thalach",
        "exang",
        "oldpeak",
        "slope",
        "ca",
        "thal",
        "num",
    ],
    na_values="?",
)
heart_df.head()
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal num
0 63.0 1.0 1.0 145.0 233.0 1.0 2.0 150.0 0.0 2.3 3.0 0.0 6.0 0
1 67.0 1.0 4.0 160.0 286.0 0.0 2.0 108.0 1.0 1.5 2.0 3.0 3.0 2
2 67.0 1.0 4.0 120.0 229.0 0.0 2.0 129.0 1.0 2.6 2.0 2.0 7.0 1
3 37.0 1.0 3.0 130.0 250.0 0.0 0.0 187.0 0.0 3.5 3.0 0.0 3.0 0
4 41.0 0.0 2.0 130.0 204.0 0.0 2.0 172.0 0.0 1.4 1.0 0.0 3.0 0

Scatter Plot

# Columns to plot, as NumPy arrays
x = heart_df["age"].values
y = heart_df["chol"].values
c = heart_df["sex"].values  # 0/1 values, used later to color the points
fig, ax = plt.subplots()
ax.scatter(x, y)
# Same data, with each point colored by its "sex" value
fig, ax = plt.subplots()
ax.scatter(x, y, c=c)
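Coloring by `sex` adds a third variable, but without a legend the reader cannot tell which color means what (compare the Honesty and Expressiveness elements above). One way to fix this, sketched here on random stand-in data rather than the real columns, is `PathCollection.legend_elements()`, which returns one legend handle and label per distinct color value:

```python
import numpy as np
import matplotlib.pyplot as plt

# Random stand-in for the age/chol/sex columns (assumption: NOT the Heart Disease data)
rng = np.random.default_rng(0)
x = rng.uniform(30, 75, size=100)    # plays the role of "age"
y = rng.normal(245, 50, size=100)    # plays the role of "chol"
c = rng.integers(0, 2, size=100)     # plays the role of "sex", coded 0/1

fig, ax = plt.subplots()
scatter = ax.scatter(x, y, c=c)

# One legend entry per distinct value in c, drawn with the matching color
handles, labels = scatter.legend_elements()
ax.legend(handles, labels, title="sex")
ax.set_xlabel("age")
ax.set_ylabel("chol")
```

With the real data the numeric labels could be replaced, for example `ax.legend(handles, ["female", "male"], title="sex")`, assuming the UCI coding 0 = female, 1 = male; check the dataset description before relying on it.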

Line Plot

Since this dataset has no ordered, continuous variable, let’s use synthetic data instead and plot the function

\[ f(t) = e^{-t} \cos(2 \pi t) \]
def f(t):
    return np.exp(-t) * np.cos(2 * np.pi * t)

# Coarse (step 0.5) and fine (step 0.02) samples of the interval [0, 5)
t1 = np.arange(0.0, 5.0, 0.5)
t2 = np.arange(0.0, 5.0, 0.02)
fig, ax = plt.subplots()

# Markers show where f was actually evaluated; plot() joins them with straight lines
ax.plot(t1, f(t1), marker='o')
ax.grid()
ax.set_title("arange step = 0.5")
# Two stacked panels sharing both axes, so the two sampling steps can be compared directly
fig, (ax1, ax2) = plt.subplots(nrows=2, figsize=(10, 5), sharex=True, sharey=True)

# Top panel: coarse sampling
ax1.plot(t1, f(t1), marker='o')
ax1.grid()
ax1.set_ylabel("f(t)")
ax1.set_title("arange step = 0.5")

# Bottom panel: fine sampling
ax2.plot(t2, f(t2))
ax2.grid()
ax2.set_ylabel("f(t)")
ax2.set_title("arange step = 0.02")

# With sharex=True only the bottom panel needs an x label
ax2.set_xlabel("time")
plt.show()
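The two panels differ only in the sampling step. A quick numerical check, sketched below with a helper of our own (`max_interp_error` is a hypothetical name), measures how far the straight segments that `plot` draws between samples stray from the true curve. The coarse grid misses much of the oscillation, which is why the top panel looks jagged.

```python
import numpy as np

def f(t):
    return np.exp(-t) * np.cos(2 * np.pi * t)

def max_interp_error(step, t_max=5.0):
    """Largest gap between f and the straight-line path through samples taken every `step`.

    plot() joins consecutive samples with straight segments; np.interp
    reproduces that path, evaluated here on a much finer reference grid.
    """
    samples = np.arange(0.0, t_max, step)
    t_ref = np.linspace(0.0, samples[-1], 10_001)  # stay inside the sampled range
    polyline = np.interp(t_ref, samples, f(samples))
    return np.max(np.abs(polyline - f(t_ref)))

coarse = max_interp_error(0.5)
fine = max_interp_error(0.02)
print(f"step 0.5:  max error {coarse:.3f}")
print(f"step 0.02: max error {fine:.5f}")
```

The honest choice is therefore a step small enough that the drawn line cannot be mistaken for a feature of the function.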

Bar Plot

mean_age_per_chest_pain_type = heart_df.groupby("cp")["age"].mean()
mean_age_per_chest_pain_type
cp
1.0    55.869565
2.0    51.360000
3.0    53.697674
4.0    55.722222
Name: age, dtype: float64
chest_pain_types_names = ("typical angina", "atypical angina", "non-anginal pain", "asymptomatic")
chest_pain_types = mean_age_per_chest_pain_type.index.values
mean_age = mean_age_per_chest_pain_type.values
# Alternative: the state-based pyplot interface, which draws on the
# "current" Axes instead of an explicit Axes object
plt.figure(figsize=(8, 4))
plt.barh(chest_pain_types, mean_age, align='center', color="g")
# Replace the numeric codes 1-4 on the y-axis with readable names
plt.yticks(chest_pain_types, chest_pain_types_names)
plt.xlabel('Mean Age')
plt.ylabel('Chest Pain Type')
plt.show()
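For comparison, here is a sketch of the same horizontal bar chart with the object-oriented interface used in the previous sections. To stay self-contained it hard-codes the four mean ages printed by the `groupby` above and places the bars at positions 0 to 3 instead of the codes 1 to 4:

```python
import matplotlib.pyplot as plt

# Mean age per chest pain type, copied from the groupby output above
chest_pain_types_names = ("typical angina", "atypical angina",
                          "non-anginal pain", "asymptomatic")
mean_age = [55.869565, 51.360000, 53.697674, 55.722222]
positions = list(range(len(mean_age)))

# Object-oriented interface: every call targets an explicit Axes
fig, ax = plt.subplots(figsize=(8, 4))
ax.barh(positions, mean_age, align="center", color="g")
ax.set_yticks(positions)
ax.set_yticklabels(chest_pain_types_names)
ax.set_xlabel("Mean Age")
ax.set_ylabel("Chest Pain Type")
```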

Matplotlib follows the philosophy of Imperative Visualization: you are in control of everything, and you need to declare every single dot, bar, line, etc.

The Examples Gallery of matplotlib will be your best friend: https://matplotlib.org/stable/gallery/index.html

Seaborn

Seaborn is a Python data visualization library based on matplotlib. It provides a high-level interface for drawing attractive and informative statistical graphics.

Source: https://seaborn.pydata.org/

import seaborn as sns

sns.set_theme(style="whitegrid")

Scatter Plot

fig, ax = plt.subplots(figsize=(6.5, 6.5))
sns.scatterplot(
    x="age",
    y="chol",
    hue="sex",
    data=heart_df,
    ax=ax
)
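Seaborn's `hue` argument splits the rows into groups and gives each group its own color and legend entry. A rough imperative equivalent in plain Matplotlib, sketched on a small made-up DataFrame that only reuses the column names of `heart_df`, makes that bookkeeping explicit:

```python
import pandas as pd
import matplotlib.pyplot as plt

# Small made-up frame with the same column names as heart_df
df = pd.DataFrame({
    "age":  [63, 67, 37, 41, 56, 62, 57, 44],
    "chol": [233, 286, 250, 204, 236, 268, 354, 263],
    "sex":  [1.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0],
})

fig, ax = plt.subplots(figsize=(6.5, 6.5))
# One scatter call per group: this is the work that hue="sex" automates
for sex_value, group in df.groupby("sex"):
    ax.scatter(group["age"], group["chol"], label=f"{sex_value:g}")
ax.set_xlabel("age")
ax.set_ylabel("chol")
ax.legend(title="sex")
```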

Line Plot

# Load an example dataset with long-form data
fmri = sns.load_dataset("fmri")
fmri.head()
subject timepoint event region signal
0 s13 18 stim parietal -0.017552
1 s5 14 stim parietal -0.080883
2 s12 18 stim parietal -0.081033
3 s11 18 stim parietal -0.046134
4 s10 18 stim parietal -0.037970
# Plot the responses for different events and regions
sns.lineplot(
    x="timepoint",
    y="signal",
    hue="region",
    style="event",
    data=fmri
)
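In the `fmri` data each timepoint has several measurements (one per subject), so `sns.lineplot` aggregates them: by default the line shows the mean and the shaded band a confidence interval. The mean can be reproduced with pandas, as sketched below on a tiny made-up long-form table with the same column names; the confidence band, which seaborn estimates by bootstrapping, is left out.

```python
import pandas as pd

# Tiny made-up long-form table with the fmri column names
fmri_toy = pd.DataFrame({
    "subject":   ["s1", "s2", "s1", "s2", "s1", "s2", "s1", "s2"],
    "timepoint": [0, 0, 1, 1, 0, 0, 1, 1],
    "event":     ["stim"] * 8,
    "region":    ["parietal"] * 4 + ["frontal"] * 4,
    "signal":    [0.10, 0.30, 0.20, 0.40, -0.10, 0.10, 0.00, 0.20],
})

# One line per (region, event) pair; each point is the mean over subjects
line_points = (
    fmri_toy.groupby(["region", "event", "timepoint"])["signal"]
    .mean()
    .reset_index()
)
print(line_points)
```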

Bar Plot

chest_pain_type_dict = {
    1: "typical angina",
    2: "atypical angina",
    3: "non-anginal pain",
    4: "asymptomatic"
}

heart_df = heart_df.assign(chest_pain=lambda x: x["cp"].map(chest_pain_type_dict))
heart_df.head()
age sex cp trestbps chol fbs restecg thalach exang oldpeak slope ca thal num chest_pain
0 63.0 1.0 1.0 145.0 233.0 1.0 2.0 150.0 0.0 2.3 3.0 0.0 6.0 0 typical angina
1 67.0 1.0 4.0 160.0 286.0 0.0 2.0 108.0 1.0 1.5 2.0 3.0 3.0 2 asymptomatic
2 67.0 1.0 4.0 120.0 229.0 0.0 2.0 129.0 1.0 2.6 2.0 2.0 7.0 1 asymptomatic
3 37.0 1.0 3.0 130.0 250.0 0.0 0.0 187.0 0.0 3.5 3.0 0.0 3.0 0 non-anginal pain
4 41.0 0.0 2.0 130.0 204.0 0.0 2.0 172.0 0.0 1.4 1.0 0.0 3.0 0 atypical angina
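`Series.map` with a dictionary looks up each value among the dictionary keys and returns `NaN` when there is no match. The sketch below uses a hypothetical code `5.0` (not present in the dataset) to show that case; it also shows that the float codes stored in `cp` still match the integer keys, as the output above suggests:

```python
import pandas as pd

chest_pain_type_dict = {
    1: "typical angina",
    2: "atypical angina",
    3: "non-anginal pain",
    4: "asymptomatic",
}

# cp is stored as floats; pandas matches 1.0 to the key 1 by value.
# 5.0 is a hypothetical unknown code, used only to show the NaN result.
cp = pd.Series([1.0, 4.0, 3.0, 5.0])
chest_pain = cp.map(chest_pain_type_dict)
print(chest_pain.tolist())
```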
# Mean age per chest pain type; seaborn computes the means itself
g = sns.barplot(
    data=heart_df,
    x="age",
    y="chest_pain",
    estimator="mean",
    # hue="sex",
    # alpha=.6
)
g = sns.barplot(
    data=heart_df,
    x="age",
    y="chest_pain",
    hue="sex",
    estimator="mean",
)
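`sns.barplot` computes the bar lengths itself (and, by default, also draws error bars showing a confidence interval). For the `hue="sex"` version the bar lengths correspond to a grouped mean, sketched here with pandas on a made-up frame that reuses the column names:

```python
import pandas as pd

# Made-up rows, only the column names match heart_df
df = pd.DataFrame({
    "chest_pain": ["asymptomatic", "asymptomatic", "asymptomatic",
                   "typical angina", "typical angina", "non-anginal pain"],
    "sex": [1.0, 1.0, 0.0, 1.0, 0.0, 0.0],
    "age": [67, 61, 58, 63, 45, 41],
})

# Rows: chest pain type (bar groups); columns: sex (hue levels); cells: mean age
bar_lengths = df.groupby(["chest_pain", "sex"])["age"].mean().unstack("sex")
print(bar_lengths)
```

A combination with no rows (here non-anginal pain with sex 1.0) ends up as `NaN`, and no bar is drawn for it.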

Histogram

g = sns.histplot(
    data=heart_df,
    x="age"
)
g = sns.histplot(
    data=heart_df,
    x="age",
    kde=True
)
g = sns.histplot(
    data=heart_df,
    x="age",
    hue="sex",
)
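`sns.histplot` counts how many observations fall in each bin, and `kde=True` overlays a kernel density estimate: a smooth curve obtained by centering a small Gaussian bump on every observation and averaging them. The sketch below reproduces both ideas with NumPy on made-up ages. The `bins="auto"` rule and Scott's bandwidth rule are assumptions about reasonable defaults, not necessarily seaborn's exact settings, and `gaussian_kde_sketch` is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(42)
ages = rng.normal(54, 9, size=300)  # made-up ages, not the real column

# Histogram: counts per bin, with NumPy choosing equal-width bin edges
counts, edges = np.histogram(ages, bins="auto")

def gaussian_kde_sketch(data, grid):
    """Average of Gaussian bumps centred on each observation (area 1)."""
    n = data.size
    bandwidth = data.std(ddof=1) * n ** (-1 / 5)  # Scott's rule in one dimension
    z = (grid[:, None] - data[None, :]) / bandwidth
    bumps = np.exp(-0.5 * z**2) / (bandwidth * np.sqrt(2 * np.pi))
    return bumps.mean(axis=1)

grid = np.linspace(ages.min() - 30, ages.max() + 30, 2000)
density = gaussian_kde_sketch(ages, grid)

# To overlay the density on a count histogram, scale it by n * bin width
bin_width = edges[1] - edges[0]
kde_as_counts = density * ages.size * bin_width
```

The scaling step matters: a density integrates to 1, while the bars add up to the number of observations.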

Box Plot

g = sns.boxplot(
    data=heart_df,
    x="chest_pain",
    y="age",
    hue="sex",
)
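Each box summarizes a distribution: the box spans the first to third quartile, the inner line is the median, the whiskers reach the most extreme points within 1.5 times the interquartile range (IQR) of the box, and anything beyond is drawn as an outlier. That 1.5 factor is the usual default in Matplotlib and seaborn. A small NumPy sketch with a hypothetical `box_stats` helper makes the rule concrete; quartile interpolation details may differ slightly between libraries.

```python
import numpy as np

def box_stats(values, whis=1.5):
    """Quantities drawn by a standard box plot (whiskers at whis * IQR)."""
    values = np.asarray(values, dtype=float)
    q1, median, q3 = np.percentile(values, [25, 50, 75])
    iqr = q3 - q1
    low_fence, high_fence = q1 - whis * iqr, q3 + whis * iqr
    inside = values[(values >= low_fence) & (values <= high_fence)]
    return {
        "q1": q1, "median": median, "q3": q3,
        # Whiskers stop at the most extreme points still inside the fences
        "whisker_low": inside.min(), "whisker_high": inside.max(),
        # Points beyond the fences are drawn individually as outliers
        "outliers": values[(values < low_fence) | (values > high_fence)],
    }

stats = box_stats([1, 2, 3, 4, 5, 6, 7, 8, 9, 100])
print(stats)
```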

Heat Map

# Average resting blood pressure for each (age, chest pain type) pair;
# pivot_table aggregates with the mean by default
tmp_data = heart_df.pivot_table(
    index="age",
    columns="chest_pain",
    values="trestbps",
)
g = sns.heatmap(tmp_data)
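`pivot_table` averages `trestbps` over every row that shares the same age and chest pain type; combinations that never occur become `NaN`, which `sns.heatmap` leaves as empty cells. A tiny made-up example shows the reshaping:

```python
import pandas as pd

# Made-up rows, only the column names match heart_df
df = pd.DataFrame({
    "age":        [54, 54, 54, 60, 60],
    "chest_pain": ["asymptomatic", "asymptomatic", "typical angina",
                   "asymptomatic", "non-anginal pain"],
    "trestbps":   [120.0, 140.0, 130.0, 150.0, 110.0],
})

# Same call as above: mean trestbps per (age, chest_pain) combination
table = df.pivot_table(index="age", columns="chest_pain", values="trestbps")
print(table)
```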

Seaborn is a good library for Declarative Visualization: as you may have noticed, the input is the data, and you only need to declare how that data should be used. For example, Seaborn computed all the mean values for us.
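To see how much work a declarative call saves, here is a sketch of a grouped horizontal bar chart, similar to the `hue="sex"` bar plot, drawn imperatively with Matplotlib. The means are placeholder numbers, not results computed from the dataset; the point is that every bar's position, offset, and label has to be declared by hand:

```python
import numpy as np
import matplotlib.pyplot as plt

# Placeholder means (rows: chest pain type, columns: sex 0 / sex 1), NOT real results
groups = ["typical angina", "atypical angina", "non-anginal pain", "asymptomatic"]
means = np.array([
    [55.0, 56.0],
    [50.0, 52.0],
    [54.0, 53.0],
    [57.0, 55.0],
])
hue_labels = ["0", "1"]

fig, ax = plt.subplots()
positions = np.arange(len(groups))
bar_height = 0.4

# Imperative style: place every bar ourselves, shifting each hue level up or down
for j, label in enumerate(hue_labels):
    offset = (j - 0.5) * bar_height
    ax.barh(positions + offset, means[:, j], height=bar_height, label=label)

ax.set_yticks(positions)
ax.set_yticklabels(groups)
ax.set_xlabel("mean age")
ax.legend(title="sex")
```

Seaborn performs the grouping, averaging, offsetting, and legend construction in a single call.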

More examples are available in the Seaborn example gallery: https://seaborn.pydata.org/examples/index.html