Python

Seaborn

Seaborn is a Python library for creating statistical visualizations using matplotlib.

Seaborn Cheatsheet

Introduction

Seaborn is a Python visualization library based on matplotlib that provides a high-level interface for drawing attractive statistical graphics.

Basic Steps for Creating Plots

  1. Prepare some data
  2. Control figure aesthetics
  3. Plot with Seaborn
  4. Further customize your plot
  5. Show your plot

1. Data

Import Libraries

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

Create Sample Data

# 10 x 12 matrix of uniform random values in [0, 1) -- used by the heatmap example
uniform_data = np.random.random_sample(size=(10, 12))

# 100-row DataFrame: 'x' holds the integers 1..100 and 'y' holds normally
# distributed noise with mean 0 and standard deviation 4
data = pd.DataFrame(
    {
        'x': np.arange(1, 100 + 1),
        'y': np.random.normal(loc=0, scale=4, size=100),
    }
)

Load Built-in Datasets

# Load two of seaborn's bundled example datasets by name
titanic = sns.load_dataset(name="titanic")
iris = sns.load_dataset(name="iris")

2. Figure Aesthetics

Seaborn Styles

# Apply seaborn's default theme settings
sns.set()

# Switch to a single named style
sns.set_style(style="whitegrid")

# Named style plus rc overrides: lengthen the major ticks on both axes
sns.set_style(
    style="ticks",
    rc={
        "xtick.major.size": 8,
        "ytick.major.size": 8,
    },
)

# Look up the parameter dict behind a named style (the call returns it)
sns.axes_style(style="whitegrid")

Context Functions

# Switch to the "talk" plotting context
sns.set_context(context="talk")

# "notebook" context with larger fonts and thicker lines
sns.set_context(
    context="notebook",
    font_scale=1.5,
    rc={"lines.linewidth": 2.5},
)

Color Palette

# Use three colors from the "husl" palette for subsequent plots
sns.set_palette(palette="husl", n_colors=3)

# Build a list of colors from the "husl" palette (returned, not installed)
sns.color_palette(palette="husl")

# Custom palette from hex color codes, then make it the active palette
flatui = [
    "#9b59b6",
    "#3498db",
    "#95a5a6",
    "#e74c3c",
    "#34495e",
    "#2ecc71",
]
sns.set_palette(flatui)

3. Plotting With Seaborn

Axis Grids

FacetGrid

# Create a FacetGrid: one subplot per (sex, survived) combination.
# dropna=True makes map() drop missing values from the plotted column before
# drawing. titanic's "age" column has missing values, and plt.hist cannot
# auto-bin data containing NaN. Recent seaborn defaults to dropna=False.
g = sns.FacetGrid(titanic,
    col="survived",
    row="sex",
    dropna=True
)
# Draw a histogram of "age" in every facet; map() returns the grid itself
g = g.map(plt.hist, "age")

Figure-Level Plots (Built on FacetGrid)

# Categorical plot. factorplot() was renamed to catplot() and later removed.
# kind="point" keeps factorplot's point-estimate style; catplot's own
# default is a strip plot.
sns.catplot(x="pclass",
    y="survived",
    hue="sex",
    data=titanic,
    kind="point"
)

# Regression plot: scatter plus a fitted regression line for each species
sns.lmplot(x="sepal_width",
    y="sepal_length",
    hue="species",
    data=iris
)

Distribution Plots

# Simple distribution plot: a histogram of data.y.
# distplot() is deprecated since seaborn 0.11. histplot() is its histogram
# replacement and, like distplot(kde=False), draws no density curve by default.
plot = sns.histplot(data.y,
    color="b"
)

Matrix Plots

# Heatmap of the random matrix, with the color scale pinned to [0, 1]
sns.heatmap(data=uniform_data, vmin=0, vmax=1)

# Grid of pairwise relationships between the iris variables
sns.pairplot(data=iris)

# PairGrid: the same layout as pairplot, but you choose what each cell
# draws; map() applies plt.scatter to every cell and returns the grid
h = sns.PairGrid(data=iris).map(plt.scatter)

Categorical Plots

Scatter Plots

# Strip plot: one jittered point per flower, grouped by species
sns.stripplot(
    data=iris,
    x="species",
    y="petal_length",
)

# Swarm plot: same grouping, with points arranged so they do not overlap
sns.swarmplot(
    data=iris,
    x="species",
    y="petal_length",
)

Bar Plots

# Bar plot: estimated survival rate per sex, split into bars by class
sns.barplot(
    data=titanic,
    x="sex",
    y="survived",
    hue="class",
)

# Count plot: number of passengers per deck, using the "Greens_d" palette.
# NOTE(review): seaborn >= 0.13 warns when `palette` is passed without `hue`;
# the equivalent there is hue="deck", legend=False.
sns.countplot(
    data=titanic,
    x="deck",
    palette="Greens_d",
)

Statistical Plots

# Point plot: survival estimate per class, one series per sex, with
# explicit colors plus one marker and one line style per hue level
# (" " draws no connecting line for that level)
sns.pointplot(
    data=titanic,
    x="class",
    y="survived",
    hue="sex",
    palette={"male": "g", "female": "m"},
    markers=["^", "o"],
    linestyles=[" ", "--"],
)

# Box plot: age distribution for each "alive" value, split by adult_male
sns.boxplot(
    data=titanic,
    x="alive",
    y="age",
    hue="adult_male",
)

# Violin plot: age on the x axis, one violin per sex, split by survival
sns.violinplot(
    data=titanic,
    x="age",
    y="sex",
    hue="survived",
)

4. Further Customizations

Axis Grid Objects (FacetGrid and PairGrid)

# FacetGrid `g`: hide the left spine on every facet
g.despine(left=True)

# FacetGrid `g`: labels and tick-label rotation
g.set_ylabels("Survived")
g.set_xticklabels(rotation=45)
# (x label, y label) -- this replaces the y label set just above
g.set_axis_labels("Survived", "Sex")

# PairGrid `h`: identical limits and tick positions on every panel
h.set(
    xlim=(0, 5),
    ylim=(0, 5),
    xticks=[0, 2.5, 5],
    yticks=[0, 2.5, 5],
)

Plot Customization

# Title and axis labels for the current Axes
plt.title("A Title")
plt.ylabel("Survived")
plt.xlabel("Sex")

# Axis limits
plt.ylim(0, 100)
plt.xlim(0, 10)

# Plot properties: plt.setp needs an explicit artist, so bind the current
# Axes to `ax` before using it (an unbound `ax` raises NameError)
ax = plt.gca()
plt.setp(ax, yticks=[0, 5])
# Adjust subplot padding so labels and titles do not overlap
plt.tight_layout()

5. Show or Save Plot

Display Plot

# Display all open figures (in a plain script this typically blocks until
# the windows are closed)
plt.show()

Save Plot

# Save the current figure as a PNG file.
# In a script, call savefig() BEFORE plt.show(): after the shown window is
# closed the figure is gone, and a later savefig() writes a blank image.
plt.savefig(fname="foo.png")

# Same file, but with a transparent background
plt.savefig(fname="foo.png", transparent=True)

Close & Clear

# Clear the contents of the current Axes (the Axes itself remains)
plt.cla()
 
# Clear the whole current figure (removes all of its Axes)
plt.clf()
 
# Close the current figure window
plt.close()

Last updated on