Plotting Polars data frames with Seaborn

Recently, I found myself using Polars more and more as my data frame library of choice. Compared to Pandas, it has much better support for complex column types, like lists of structs, and also better performance. Pandas, however, is still clearly the most popular data frame library in the ecosystem. Case in point: Seaborn does not support Polars data frames out of the box. In this post, I would like to walk you through a quick way to combine both libraries.

To interface Polars with Seaborn, we follow these steps:

  1. evaluate any columns to be plotted and collect their names
  2. convert the Polars data frame to Pandas
  3. call Seaborn with the converted data frame and the relevant column names

Thanks to the uniform interface of Seaborn, it is enough to write a single wrapper that performs these steps:

def call_seaborn(seaborn_func, /, df, **kwargs):
    """Call a Seaborn plotting function with data from a Polars data frame.

    The keyword arguments ``x``, ``y``, ``hue``, ``col`` and ``row`` may be
    column names (strings) or Polars expressions. They are evaluated on
    ``df``, the result is converted to Pandas, and ``seaborn_func`` is called
    with the converted frame as first argument and the output names of the
    expressions in place of the original argument values. All other keyword
    arguments are forwarded unchanged.

    Returns whatever ``seaborn_func`` returns.
    """
    column_args = ("x", "y", "hue", "col", "row")
    selected = []

    for name in column_args:
        value = kwargs.get(name)

        # arguments that were not supplied need no column
        if value is None:
            continue

        # plain column names become Polars expressions
        if isinstance(value, str):
            value = pl.col(value)

        selected.append(value)

        # Seaborn looks the column up by the name the expression produces
        kwargs[name] = value.meta.output_name()

    # evaluate the expressions, then hand a Pandas frame to Seaborn
    pandas_df = df.select(selected).to_pandas()
    return pandas_df.pipe(seaborn_func, **kwargs)

The full implementation, see below, has some additional features:

  • it de-duplicates expressions to allow columns to be used multiple times
  • it handles both lazy and eager Polars data frames
  • it is packaged as a Polars data frame namespace with wrappers for the common Seaborn plotting functions

To demonstrate the functionality, load the penguins data set and convert it to a Polars data frame

# Seaborn returns the data set as a Pandas frame; convert it to Polars.
df = pl.from_pandas(sns.load_dataset("penguins"))

Plots can use column names

# Columns are referenced by name; the remaining keyword arguments
# (here ``height``) are forwarded to Seaborn unchanged.
df.sns.lmplot(
    x="bill_length_mm",
    y="bill_depth_mm",
    hue="species",
    height=5,
)

or computed expressions

# Expressions are evaluated by Polars before plotting; the alias is the
# column name that gets handed to Seaborn.
df.sns.boxenplot(
    x=pl.col("species").alias("Species"),
    y=(pl.col("body_mass_g") / pl.col("flipper_length_mm") ** 2).alias(
        "BMI $g/mm^2$"
    ),
)

It is also possible to use lazy data frames to combine computations

# The filter stays lazy: the frame is collected by the namespace only
# after the plotted columns have been selected.
(
    df.lazy()
    .filter(pl.col("species") != "Gentoo")
    .sns.boxenplot(
        x=pl.col("species").alias("Species"),
        y=(pl.col("body_mass_g") / pl.col("flipper_length_mm") ** 2).alias(
            "BMI $g/mm^2$"
        ),
    )
)

To use this functionality, feel free to copy the code below. You can find the example as a Python script here.


# Copyright (c) 2023 Christopher Prohm
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
import functools as ft
from dataclasses import dataclass
from typing import Union

import polars as pl
import seaborn as sns

@pl.api.register_dataframe_namespace("sns")
@pl.api.register_lazyframe_namespace("sns")
@dataclass
class SeabornPlotting:
    """Seaborn plotting functions exposed as the ``sns`` namespace on Polars frames.

    Registered for both eager and lazy frames, so ``df.sns.<plot>(...)``
    works on ``pl.DataFrame`` and ``pl.LazyFrame`` alike.
    """

    # the frame the namespace was accessed on
    df: Union[pl.DataFrame, pl.LazyFrame]

    def pipe(self, func, /, **kwargs):
        """Call ``func`` with a Pandas frame holding the columns named in ``kwargs``.

        The keyword arguments ``x``, ``y``, ``hue``, ``col``, ``row``,
        ``size``, ``style``, ``units`` and ``weights`` are treated as column
        references when their value is a string (a column name) or a Polars
        expression. The referenced expressions are selected from the frame,
        the frame is collected if it is lazy and converted to Pandas, and
        ``func`` is called with that frame as first argument. Column
        arguments are replaced by the output name of their expression; every
        other keyword argument (including ``None``, numbers or vectors given
        for the keys above) is forwarded unchanged.

        Args:
            func: The plotting function; it receives the Pandas frame as its
                first positional argument.
            **kwargs: Keyword arguments for ``func``.

        Returns:
            Whatever ``func`` returns.

        Raises:
            ValueError: If two arguments use different expressions that
                produce the same output column name.
        """
        column_args = (
            "x", "y", "hue", "col", "row", "size", "style", "units", "weights",
        )

        # output name -> expression; keyed by name so that a column used for
        # several arguments is selected only once
        exprs = {}
        for key in column_args:
            val = kwargs.get(key)
            if isinstance(val, str):
                expr = pl.col(val)
            elif isinstance(val, pl.Expr):
                expr = val
            else:
                # not supplied, or not a column reference (e.g. a numeric
                # marker ``size``): leave the argument untouched
                continue

            name = expr.meta.output_name()
            previous = exprs.setdefault(name, expr)
            if previous is not expr and not expr.meta.eq(previous):
                # Silently keeping only one of the expressions would plot the
                # wrong data for the other argument.
                raise ValueError(
                    f"Multiple different expressions produce the column "
                    f"{name!r} (argument {key!r}); use .alias(...) to give "
                    f"each expression a unique name"
                )
            kwargs[key] = name

        frame = self.df.select(list(exprs.values()))
        # select first, collect afterwards: only the plotted columns are
        # materialized for lazy frames
        if isinstance(frame, pl.LazyFrame):
            frame = frame.collect()

        return frame.to_pandas().pipe(func, **kwargs)

    # One wrapper per Seaborn function: ``df.sns.<name>(**kwargs)`` is
    # ``df.sns.pipe(sns.<name>, **kwargs)``.
    relplot = ft.partialmethod(pipe, sns.relplot)
    scatterplot = ft.partialmethod(pipe, sns.scatterplot)
    lineplot = ft.partialmethod(pipe, sns.lineplot)
    displot = ft.partialmethod(pipe, sns.displot)
    histplot = ft.partialmethod(pipe, sns.histplot)
    kdeplot = ft.partialmethod(pipe, sns.kdeplot)
    ecdfplot = ft.partialmethod(pipe, sns.ecdfplot)
    rugplot = ft.partialmethod(pipe, sns.rugplot)
    # NOTE(review): distplot is presumably deprecated in current Seaborn and
    # may not take the data frame as its first argument -- confirm before use.
    distplot = ft.partialmethod(pipe, sns.distplot)
    catplot = ft.partialmethod(pipe, sns.catplot)
    stripplot = ft.partialmethod(pipe, sns.stripplot)
    swarmplot = ft.partialmethod(pipe, sns.swarmplot)
    boxplot = ft.partialmethod(pipe, sns.boxplot)
    violinplot = ft.partialmethod(pipe, sns.violinplot)
    boxenplot = ft.partialmethod(pipe, sns.boxenplot)
    pointplot = ft.partialmethod(pipe, sns.pointplot)
    barplot = ft.partialmethod(pipe, sns.barplot)
    countplot = ft.partialmethod(pipe, sns.countplot)
    lmplot = ft.partialmethod(pipe, sns.lmplot)
    regplot = ft.partialmethod(pipe, sns.regplot)
    residplot = ft.partialmethod(pipe, sns.residplot)