A scalable scatter plot extension for Jupyter Lab and Notebook
IMPORTANT: THIS IS VERY EARLY WORK! THE API WILL LIKELY CHANGE. Anyway, you're more than welcome to give the extension a try and let me know what you think :) All feedback is welcome!
Why? Imagine trying to explore an embedding space of millions of data points. Besides plotting the space as a 2D scatter, the exploration typically involves three things: First, we want to interactively adjust the view (e.g., via panning & zooming) and the visual point encoding (e.g., the point color, opacity, or size). Second, we want to be able to select/highlight points. And third, we want to compare multiple embeddings (e.g., via animation, color, or point connections). The goal of jupyter-scatter is to support all three requirements and scale to millions of points.
How? Internally, jupyter-scatter uses regl-scatterplot for rendering and ipywidgets for linking the scatter plot to the IPython kernel.
# Install extension
pip install jupyter-scatter
# Activate extension in Jupyter Lab
jupyter labextension install jupyter-scatter
# Activate extension in Jupyter Notebook
jupyter nbextension install --py --sys-prefix jscatter
jupyter nbextension enable --py --sys-prefix jscatter
For a minimal working example, take a look at test-environment.
In the simplest case, you can pass the x/y coordinates to the plot function as follows:
import jscatter
import numpy as np
x = np.random.rand(500)
y = np.random.rand(500)
jscatter.plot(x, y)
Say your data is stored in a Pandas dataframe like the following:
import numpy as np
import pandas as pd
# Just some random float and int values
data = np.random.rand(500, 4)
data[:,3] = np.round(data[:,3] * 7).astype(int)
df = pd.DataFrame(data, columns=['mass', 'speed', 'pval', 'group'])
# We'll convert the group column to categorical data for later use
df['group'] = df['group'].astype('int').astype('category').map(lambda c: chr(65 + c), na_action=None)
| | mass | speed | pval | group |
|---|---|---|---|---|
| 0 | 0.13 | 0.27 | 0.51 | G |
| 1 | 0.87 | 0.93 | 0.80 | B |
| 2 | 0.10 | 0.25 | 0.25 | F |
| 3 | 0.03 | 0.90 | 0.01 | G |
| 4 | 0.19 | 0.78 | 0.65 | D |
You can then visualize this data by referencing column names:
jscatter.plot(data=df, x='mass', y='speed')
Often you want to customize the visual encoding, such as the point color, size, and opacity.
jscatter.plot(
data=df,
x='mass',
y='speed',
size=8, # static encoding
color_by='group', # data-driven encoding
opacity_by='density', # view-driven encoding
)
In the above example, we chose a static point size of 8. In contrast, the point color is data-driven and assigned based on the categorical group value. The point opacity is view-driven and defined dynamically by the number of points currently visible in the view.
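To make "view-driven" more concrete, here is a minimal sketch of one way opacity could depend on the number of visible points: count the points inside the current view rectangle and scale opacity inversely with that count. The formula and the helper names (`visible_count`, `density_opacity`, `target`, `min_opacity`, `max_opacity`) are assumptions for illustration only, not how jscatter or regl-scatterplot actually compute density-based opacity.

```python
import random


def visible_count(points, view):
    """Count the points inside the axis-aligned view rectangle (x0, x1, y0, y1)."""
    x0, x1, y0, y1 = view
    return sum(1 for x, y in points if x0 <= x <= x1 and y0 <= y <= y1)


def density_opacity(points, view, target=100, min_opacity=0.05, max_opacity=1.0):
    """Hypothetical view-driven opacity: the fewer points visible, the more opaque.

    `target`, `min_opacity`, and `max_opacity` are illustrative assumptions,
    not parameters of jscatter or regl-scatterplot.
    """
    n = visible_count(points, view)
    if n == 0:
        return max_opacity
    # Scale opacity inversely with the number of visible points, then clamp
    return max(min_opacity, min(max_opacity, target / n))


random.seed(0)
points = [(random.random(), random.random()) for _ in range(10_000)]
print(density_opacity(points, (0.0, 1.0, 0.0, 1.0)))  # full view: 100 / 10000 is clamped to 0.05
print(density_opacity(points, (0.0, 0.1, 0.0, 0.1)))  # zoomed in: about 100 visible, so near 1.0
```

Zooming in shrinks the view, fewer points remain visible, and each one becomes more opaque, which is the behavior the paragraph above describes.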
Also notice how jscatter picks an appropriate color map by default based on the data type used for color encoding. In this example, jscatter uses the colorblind-safe color map from Okabe and Ito because the data type is categorical and there are fewer than 9 categories.
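The rule above can be sketched in a few lines: the Okabe-Ito palette has eight colors, so it can only give every category its own color when there are fewer than 9 categories. The helper `categorical_colors` is hypothetical, and assigning colors to categories in sorted order is an assumption; jscatter may order or pick colors differently.

```python
# Okabe & Ito's colorblind-safe palette (eight colors)
OKABE_ITO = [
    "#E69F00",  # orange
    "#56B4E9",  # sky blue
    "#009E73",  # bluish green
    "#F0E442",  # yellow
    "#0072B2",  # blue
    "#D55E00",  # vermillion
    "#CC79A7",  # reddish purple
    "#000000",  # black
]


def categorical_colors(values):
    """Assign one Okabe-Ito color per category (hypothetical helper).

    Mirrors the rule described above: the palette only applies when there
    are fewer than 9 categories.
    """
    categories = sorted(set(values))
    if len(categories) >= 9:
        raise ValueError("more than 8 categories: choose a different color map")
    return dict(zip(categories, OKABE_ITO))


print(categorical_colors(["G", "B", "F", "G", "D"]))
# {'B': '#E69F00', 'D': '#56B4E9', 'F': '#009E73', 'G': '#F0E442'}
```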
Important: in order for jscatter to recognize categorical data, the dtype of the corresponding column needs to be category!
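A quick way to check and fix this in pandas is shown below. The toy column mirrors the group values from the example table; integer codes start out with an integer dtype, so per the note above they would not be recognized as categorical until converted.

```python
import pandas as pd

# Integer group codes: the dtype is an integer type, not 'category'
df = pd.DataFrame({"group": [6, 1, 5, 6, 3]})
print(df["group"].dtype)

# Map codes to letters (6 -> 'G', 1 -> 'B', ...) and convert to the category dtype
df["group"] = df["group"].map(lambda c: chr(65 + c)).astype("category")
print(isinstance(df["group"].dtype, pd.CategoricalDtype))  # True
print(list(df["group"].cat.categories))  # ['B', 'D', 'F', 'G']
```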
You can of course customize the color map and many other parameters of the visual encoding as shown next.
The flat API can get overwhelming when you want to customize a lot of properties. Therefore, jscatter provides a functional API that groups properties by type.
scatter = jscatter.Scatter(data=df, x='mass', y='speed')
scatter.selection(df.query('mass < 0.5').index)
scatter.color(by='mass', map='plasma', order='reverse')
scatter.opacity(by='density')
scatter.size(by='pval', map=[2, 4, 6, 8, 10])
scatter.height(480)
scatter.background('black')
scatter.show()
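The example passes a list of five sizes (map=[2, 4, 6, 8, 10]) for the continuous pval column, but does not spell out how values are matched to list entries. One plausible interpretation, sketched here purely as an assumption and not as jscatter's documented behavior, is to min-max normalize the values and split that range into equal-width bins, one per size. `bin_to_sizes` is a hypothetical helper.

```python
def bin_to_sizes(values, sizes):
    """Assign each value one entry of `sizes` via equal-width bins (hypothetical).

    Assumes values are min-max normalized to [0, 1] and that range is split
    into len(sizes) bins; jscatter may map values differently.
    """
    lo, hi = min(values), max(values)
    span = (hi - lo) or 1.0  # avoid dividing by zero when all values are equal
    n = len(sizes)
    result = []
    for v in values:
        t = (v - lo) / span  # normalized value in [0, 1]
        idx = min(int(t * n), n - 1)  # the maximum lands in the last bin
        result.append(sizes[idx])
    return result


# pval values from the example table
print(bin_to_sizes([0.51, 0.80, 0.25, 0.01, 0.65], [2, 4, 6, 8, 10]))
# [8, 10, 4, 2, 10]
```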
When you update properties dynamically, i.e., after having called scatter.show(), the plot will update automatically. For instance, try calling scatter.xy('speed', 'mass') and you will see how the points are mirrored along the diagonal.
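Why does swapping the columns mirror the plot? Exchanging x and y turns every point (a, b) into (b, a), which is exactly a reflection across the line y = x, i.e., multiplication by the matrix [[0, 1], [1, 0]]. The two helpers below are illustrative only and compute the same result both ways.

```python
def swap_xy(points):
    """Swap each point's coordinates, like plotting speed on x and mass on y."""
    return [(y, x) for x, y in points]


def reflect_across_diagonal(point):
    """Reflect a point across the line y = x via the matrix [[0, 1], [1, 0]]."""
    a, b = point
    # [[0, 1], [1, 0]] @ [a, b] = [0*a + 1*b, 1*a + 0*b] = [b, a]
    return (0 * a + 1 * b, 1 * a + 0 * b)


# Values from the first rows of the example table
points = [(0.13, 0.27), (0.87, 0.93), (0.10, 0.25)]
print(swap_xy(points))  # [(0.27, 0.13), (0.93, 0.87), (0.25, 0.1)]
```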
Moreover, all arguments are optional. If you specify arguments, the methods will act as setters and change the properties. If you call a method without any arguments it will act as a getter and return the property (or properties). For example, scatter.selection() will return the currently selected points.
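The getter/setter convention can be illustrated with a small stand-in class. `ScatterSketch` is hypothetical and is not jscatter's implementation; its `on_change` callback is an assumption standing in for whatever mechanism re-renders the plot after a property changes (in jscatter, the widget link to the kernel is handled via ipywidgets).

```python
class ScatterSketch:
    """Hypothetical stand-in illustrating the getter/setter convention.

    This is not jscatter's implementation. `on_change` stands in for whatever
    mechanism re-renders the plot after a property changes.
    """

    def __init__(self, on_change=None):
        self._selection = []
        self._on_change = on_change

    def selection(self, point_idxs=None):
        # Getter: no argument, so return a copy of the current selection
        if point_idxs is None:
            return list(self._selection)
        # Setter: store the new selection and notify the listener
        self._selection = list(point_idxs)
        if self._on_change is not None:
            self._on_change("selection", self._selection)


changes = []
sketch = ScatterSketch(on_change=lambda name, value: changes.append(name))
sketch.selection([0, 3])   # acts as a setter
print(sketch.selection())  # acts as a getter: [0, 3]
print(changes)             # ['selection']
```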
Finally, the scatter plot is interactive and supports two-way communication. Hence, if you select some points with the lasso tool and then call scatter.selection(), you will get the current selection.
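Conceptually, a lasso selection boils down to a point-in-polygon test. regl-scatterplot presumably performs this in the browser; the Python sketch below only illustrates the common even-odd (ray casting) rule and is not its actual implementation. `point_in_polygon` and `lasso_select` are hypothetical helpers, and the returned list of indices is the kind of value one would expect scatter.selection() to report back to the kernel.

```python
def point_in_polygon(x, y, polygon):
    """Even-odd rule: cast a ray to +x and count how many edges it crosses."""
    inside = False
    n = len(polygon)
    for i in range(n):
        x1, y1 = polygon[i]
        x2, y2 = polygon[(i + 1) % n]  # wrap around to close the polygon
        # Only edges that straddle the ray's y coordinate can be crossed
        if (y1 > y) != (y2 > y):
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
            if x < x_cross:
                inside = not inside
    return inside


def lasso_select(points, lasso):
    """Return the indices of all points inside the lasso polygon."""
    return [i for i, (x, y) in enumerate(points) if point_in_polygon(x, y, lasso)]


# Points from the example table and a square lasso around the lower-left corner
points = [(0.13, 0.27), (0.87, 0.93), (0.10, 0.25), (0.03, 0.90), (0.19, 0.78)]
lasso = [(0.0, 0.0), (0.5, 0.0), (0.5, 0.5), (0.0, 0.5)]
print(lasso_select(points, lasso))  # [0, 2]
```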
For a complete example, take a look at notebooks/example.ipynb
Setting up a development environment
Requirements:
- Conda >= 4.8
Installation:
git clone https://github.com/flekschas/jupyter-scatter/ jscatter && cd jscatter
conda env create -f environment.yml && conda activate jscatter
pip install -e .
Enable the Notebook Extension:
jupyter nbextension install --py --symlink --sys-prefix jscatter
jupyter nbextension enable --py --sys-prefix jscatter
Enable the Lab Extension:
jupyter labextension develop --overwrite jscatter
After changing Python code: simply restart the kernel.
After changing JavaScript code: run cd js && npm run build and reload the browser tab.