# from napari.layers.base.base import Layer
# from napari.utils.events import Event
# from napari.utils.colormaps import AVAILABLE_COLORMAPS
from typing import Dict, List, Union
from warnings import warn
import numpy as np
from ...utils.colormaps import AVAILABLE_COLORMAPS, Colormap
from ...utils.events import Event
from ..base import Layer
from ._track_utils import TrackManager
class Tracks(Layer):
    """Tracks layer.

    Parameters
    ----------
    data : array (N, D+1)
        Coordinates for N points in D+1 dimensions. ID,T,(Z),Y,X. The first
        axis is the integer ID of the track. D is either 3 or 4 for planar
        or volumetric timeseries respectively.
    properties : dict {str: array (N,)}, DataFrame
        Properties for each point. Each property should be an array of length
        N, where N is the number of points.
    graph : dict {int: list}
        Graph representing associations between tracks. Dictionary defines the
        mapping between a track ID and the parents of the track. This can be
        one (the track has one parent, and the parent has >=1 child) in the
        case of track splitting, or more than one (the track has multiple
        parents, but only one child) in the case of track merging.
        See examples/tracks_3d_with_graph.py
    color_by : str
        Track property (from property keys) by which to color vertices.
    tail_width : float
        Width of the track tails in pixels.
    tail_length : float
        Length of the track tails in units of time.
    colormap : str
        Default colormap to use to set vertex colors. Specialized colormaps,
        relating to specified properties can be passed to the layer via
        colormaps_dict.
    colormaps_dict : dict {str: napari.utils.Colormap}
        Optional dictionary mapping each property to a colormap for that
        property. This allows each property to be assigned a specific
        colormap, rather than having a global colormap for everything.
    name : str
        Name of the layer.
    metadata : dict
        Layer metadata.
    scale : tuple of float
        Scale factors for the layer.
    translate : tuple of float
        Translation values for the layer.
    rotate : float, 3-tuple of float, or n-D array.
        If a float convert into a 2D rotation matrix using that value as an
        angle. If 3-tuple convert into a 3D rotation matrix, using a yaw,
        pitch, roll convention. Otherwise assume an nD rotation. Angles are
        assumed to be in degrees. They can be converted from radians with
        np.degrees if needed.
    shear : 1-D array or n-D array
        Either a vector of upper triangular values, or an nD shear matrix with
        ones along the main diagonal.
    affine : n-D array or napari.utils.transforms.Affine
        (N+1, N+1) affine transformation matrix in homogeneous coordinates.
        The first (N, N) entries correspond to a linear transform and
        the final column is a length N translation vector and a 1 or a napari
        AffineTransform object. If provided then translate, scale, rotate, and
        shear values are ignored.
    opacity : float
        Opacity of the layer visual, between 0.0 and 1.0.
    blending : str
        One of a list of preset blending modes that determines how RGB and
        alpha values of the layer visual get mixed. Allowed values are
        {'opaque', 'translucent', and 'additive'}.
    visible : bool
        Whether the layer visual is currently being displayed.
    """

    # The max number of tracks that will ever be used to render the thumbnail
    # If more tracks are present then they are randomly subsampled
    _max_tracks_thumbnail = 1024

    def __init__(
        self,
        data,
        *,
        properties=None,
        graph=None,
        tail_width=2,
        tail_length=30,
        name=None,
        metadata=None,
        scale=None,
        translate=None,
        rotate=None,
        shear=None,
        affine=None,
        opacity=1,
        blending='additive',
        visible=True,
        colormap='turbo',
        color_by='track_id',
        colormaps_dict=None,
    ):
        # if not provided with any data, set up an empty layer in 2D+t
        if data is None:
            data = np.empty((0, 4))
        else:
            # convert data to a numpy array if it is not already one
            data = np.asarray(data)

        # in absence of properties make the default an empty dict
        if properties is None:
            properties = {}

        # set the track data dimensions (remove ID from data)
        ndim = data.shape[1] - 1

        super().__init__(
            data,
            ndim,
            name=name,
            metadata=metadata,
            scale=scale,
            translate=translate,
            rotate=rotate,
            shear=shear,
            affine=affine,
            opacity=opacity,
            blending=blending,
            visible=visible,
        )

        self.events.add(
            tail_width=Event,
            tail_length=Event,
            display_id=Event,
            display_tail=Event,
            display_graph=Event,
            color_by=Event,
            colormap=Event,
            properties=Event,
            rebuild_tracks=Event,
            rebuild_graph=Event,
        )

        # track manager deals with data slicing, graph building and properties
        self._manager = TrackManager()
        self._track_colors = None
        self._colormaps_dict = colormaps_dict or {}  # additional colormaps
        # Private state is seeded first because the ``data`` and
        # ``properties`` setters below read ``_color_by`` / ``_colormap``.
        self._color_by = color_by  # default color by ID
        self._colormap = colormap

        # use this to update shaders when the displayed dims change
        self._current_displayed_dims = None

        # track display properties
        self.tail_width = tail_width
        self.tail_length = tail_length
        self.display_id = False
        self.display_tail = True
        self.display_graph = True

        # set the data, properties and graph; the ``data`` setter resets
        # properties and graph, so they must be assigned after it
        self.data = data
        self.properties = properties
        self.graph = graph or {}

        # now that properties exist, validate and apply the coloring options
        self.color_by = color_by
        self.colormap = colormap

        self._update_dims()

        # reset the display before returning
        self._current_displayed_dims = None

    @property
    def _extent_data(self) -> np.ndarray:
        """Extent of layer in data coordinates.

        The leading track ID column of ``data`` is not part of the extent.

        Returns
        -------
        extent_data : array, shape (2, D)
            Row 0 holds the per-column minima, row 1 the maxima. All values
            are NaN when the layer holds no data.
        """
        if len(self.data) == 0:
            # The layer is initialised with ndim = data.shape[1] - 1, i.e.
            # the ID column is already excluded, so this placeholder must
            # not be sliced again (that returned shape (2, D - 1)).
            return np.full((2, self.ndim), np.nan)

        mins = np.min(self.data, axis=0)
        maxs = np.max(self.data, axis=0)
        # drop the leading track ID column
        return np.vstack([mins, maxs])[:, 1:]

    def _get_ndim(self) -> int:
        """Determine number of dimensions of the layer.

        Returns
        -------
        ndim : int
            The value reported by the track manager.
        """
        return self._manager.ndim

    def _get_state(self):
        """Get dictionary of layer state.

        Returns
        -------
        state : dict
            Dictionary of layer state: the base layer state extended with
            the tracks-specific entries.
        """
        state = self._get_base_state()
        state.update(
            {
                'data': self.data,
                'properties': self.properties,
                'graph': self.graph,
                'color_by': self.color_by,
                'colormap': self.colormap,
                'colormaps_dict': self.colormaps_dict,
                'tail_width': self.tail_width,
                'tail_length': self.tail_length,
            }
        )
        return state

    def _set_view_slice(self):
        """Sets the view given the indices to slice with.

        Only reacts to a change in the displayed dimensions, in which case
        the rebuild events are fired so that the shaders get updated.
        """
        # if the displayed dims have changed, update the shader data
        if self._dims_displayed != self._current_displayed_dims:
            # store the new dims
            self._current_displayed_dims = self._dims_displayed
            # fire the events to update the shaders
            self.events.rebuild_tracks()
            self.events.rebuild_graph()

    def _get_value(self, position) -> int:
        """Value of the data at a position in data coordinates.

        The lookup of the nearest track is delegated to the track manager.

        Parameters
        ----------
        position : tuple
            Position in data coordinates.

        Returns
        -------
        value : int or None
            Index of track that is at the current coordinate if any.
        """
        return self._manager.get_value(np.array(position))

    def _update_thumbnail(self):
        """Update thumbnail with current points and colors.

        Track vertices are rasterised into an image of ``_thumbnail_shape``;
        at most ``_max_tracks_thumbnail`` randomly chosen vertices are drawn.
        The alpha of every vertex fades with its distance in time from
        ``current_time``, scaled by ``tail_length``.
        """
        # start from an opaque black image
        colormapped = np.zeros(self._thumbnail_shape)
        colormapped[..., 3] = 1

        if self._view_data is not None and self.track_colors is not None:
            de = self._extent_data
            min_vals = [de[0, i] for i in self._dims_displayed]
            shape = np.ceil(
                [de[1, i] - de[0, i] + 1 for i in self._dims_displayed]
            ).astype(int)
            # uniform zoom so that the last two displayed dims fit
            zoom_factor = np.divide(
                self._thumbnail_shape[:2], shape[-2:]
            ).min()

            if len(self._view_data) > self._max_tracks_thumbnail:
                # random subsample (with replacement) of the vertices
                thumbnail_indices = np.random.randint(
                    0, len(self._view_data), self._max_tracks_thumbnail
                )
                points = self._view_data[thumbnail_indices]
            else:
                points = self._view_data
                thumbnail_indices = range(len(self._view_data))

            # get the track coords here
            # NOTE(review): ``points[:, :2]`` is (x, y) after
            # ``_pad_display_data`` while ``min_vals`` follows the displayed
            # dims order; ``min_vals[1:]`` has one entry in 2D display and
            # is broadcast to both columns. Confirm this offset is intended.
            coords = np.floor(
                (points[:, :2] - min_vals[1:] + 0.5) * zoom_factor
            ).astype(int)
            coords = np.clip(
                coords, 0, np.subtract(self._thumbnail_shape[:2], 1)
            )

            # modulate track colors as per colormap/current_time
            colors = self.track_colors[thumbnail_indices]
            times = self.track_times[thumbnail_indices]
            # NOTE(review): a ``tail_length`` of 0 divides by zero here.
            alpha = (self.current_time - times) / self.tail_length
            # vertices later than the current time become fully transparent
            alpha[times > self.current_time] = 1.0
            colors[:, -1] = np.clip(1.0 - alpha, 0.0, 1.0)
            colormapped[coords[:, 1], coords[:, 0]] = colors

        colormapped[..., 3] *= self.opacity
        self.thumbnail = colormapped

    @property
    def _view_data(self):
        """Track vertices restricted to the displayed dims, as (x, y, z)."""
        return self._pad_display_data(self._manager.track_vertices)

    @property
    def _view_graph(self):
        """Graph vertices restricted to the displayed dims, as (x, y, z)."""
        return self._pad_display_data(self._manager.graph_vertices)

    def _pad_display_data(self, vertices):
        """Pad display data when moving between 2d and 3d.

        Parameters
        ----------
        vertices : array or None
            Vertices with one column per layer dimension.

        Returns
        -------
        data : array or None
            The displayed columns reordered to x, y, z. For a 2D display a
            zero-filled third column is appended. None if ``vertices`` is
            None.
        """
        if vertices is None:
            return None

        data = vertices[:, self._dims_displayed]
        # if we're only displaying two dimensions, then pad the display dim
        # with zeros
        if self._ndisplay == 2:
            data = np.pad(data, ((0, 0), (0, 1)), 'constant')
            return data[:, (1, 0, 2)]  # y, x, z -> x, y, z
        else:
            return data[:, (2, 1, 0)]  # z, y, x -> x, y, z

    @property
    def current_time(self):
        """Current time according to the first dimension."""
        # TODO(arl): get the correct index here
        time_step = self._slice_indices[0]

        if isinstance(time_step, slice):
            # if we are visualizing all time, then just set to the maximum
            # timestamp of the dataset
            return self._manager.max_time
        return time_step

    @property
    def use_fade(self) -> bool:
        """Toggle whether we fade the tail of the track, depending on whether
        the time dimension is displayed.

        True when dimension 0 is among the dimensions not displayed.
        """
        return 0 in self._dims_not_displayed

    @property
    def data(self) -> np.ndarray:
        """array (N, D+1): Coordinates for N points in D+1 dimensions."""
        return self._manager.data

    @data.setter
    def data(self, data: np.ndarray):
        """Set the data and build the vispy arrays for display.

        Assigning new data resets both the properties and the graph.
        """
        # set the data and build the tracks
        self._manager.data = data
        self._manager.build_tracks()

        # reset the properties and recolor the tracks
        self.properties = {}
        self._recolor_tracks()

        # reset the graph
        self._manager.graph = {}
        self._manager.build_graph()

        # fire events to update shaders
        self.events.rebuild_tracks()
        self.events.rebuild_graph()
        self.events.data(value=self.data)
        self._set_editable()
        self._update_dims()

    @property
    def properties(self) -> Dict[str, np.ndarray]:
        """dict {str: np.ndarray (N,)}, DataFrame: Properties for each track.
        """
        return self._manager.properties

    @properties.setter
    def properties(self, properties: Dict[str, np.ndarray]):
        """Set track properties.

        If the current ``color_by`` key is missing from the new properties
        a ``UserWarning`` is issued and coloring falls back to 'track_id'.
        """
        if self._color_by not in [*properties.keys(), 'track_id']:
            warn(
                (
                    f"Previous color_by key {self._color_by!r} not present in"
                    " new properties. Falling back to track_id"
                ),
                UserWarning,
            )
            self._color_by = 'track_id'
        self._manager.properties = properties
        self.events.properties()
        self.events.color_by()

    @property
    def properties_to_color_by(self) -> List[str]:
        """Track properties that can be used for coloring etc..."""
        return list(self.properties.keys())

    @property
    def graph(self) -> Dict[int, Union[int, List[int]]]:
        """dict {int: list}: Graph representing associations between tracks.
        """
        return self._manager.graph

    @graph.setter
    def graph(self, graph: Dict[int, Union[int, List[int]]]):
        """Set the track graph, rebuild it and notify listeners."""
        self._manager.graph = graph
        self._manager.build_graph()
        self.events.rebuild_graph()

    @property
    def tail_width(self) -> Union[int, float]:
        """float: Width of the track tails in pixels."""
        return self._tail_width

    @tail_width.setter
    def tail_width(self, tail_width: Union[int, float]):
        self._tail_width = tail_width
        self.events.tail_width()

    @property
    def tail_length(self) -> Union[int, float]:
        """float: Length of the track tails in units of time."""
        return self._tail_length

    @tail_length.setter
    def tail_length(self, tail_length: Union[int, float]):
        self._tail_length = tail_length
        self.events.tail_length()

    @property
    def display_id(self) -> bool:
        """bool: Display the track id."""
        return self._display_id

    @display_id.setter
    def display_id(self, value: bool):
        self._display_id = value
        self.events.display_id()
        self.refresh()

    @property
    def display_tail(self) -> bool:
        """bool: Display the track tail."""
        return self._display_tail

    @display_tail.setter
    def display_tail(self, value: bool):
        self._display_tail = value
        self.events.display_tail()

    @property
    def display_graph(self) -> bool:
        """bool: Display the graph edges."""
        return self._display_graph

    @display_graph.setter
    def display_graph(self, value: bool):
        self._display_graph = value
        self.events.display_graph()

    @property
    def color_by(self) -> str:
        """str: Property key currently used to color the vertices."""
        return self._color_by

    @color_by.setter
    def color_by(self, color_by: str):
        """Set the property to color vertices by.

        Raises
        ------
        ValueError
            If ``color_by`` is not one of ``properties_to_color_by``.
        """
        if color_by not in self.properties_to_color_by:
            raise ValueError(f'{color_by} is not a valid property key')
        self._color_by = color_by
        self._recolor_tracks()
        self.events.color_by()

    @property
    def colormap(self) -> str:
        """str: Name of the default colormap."""
        return self._colormap

    @colormap.setter
    def colormap(self, colormap: str):
        """Set the default colormap.

        Raises
        ------
        ValueError
            If ``colormap`` is not a key of ``AVAILABLE_COLORMAPS``.
        """
        if colormap not in AVAILABLE_COLORMAPS:
            raise ValueError(f'Colormap {colormap} not available')
        self._colormap = colormap
        self._recolor_tracks()
        self.events.colormap()

    @property
    def colormaps_dict(self) -> Dict[str, Colormap]:
        """dict {str: Colormap}: Colormaps keyed by property name.

        A property listed here is colored with its own colormap instead of
        the default ``colormap``.
        """
        return self._colormaps_dict

    @colormaps_dict.setter
    def colormaps_dict(self, colormaps_dict: Dict[str, Colormap]):
        # validate the dictionary entries?
        self._colormaps_dict = colormaps_dict

    # The setter above used to be declared under the misspelled name
    # ``colomaps_dict``, which left ``colormaps_dict`` read-only. Keep the
    # misspelling as an alias so code that relied on it still works.
    colomaps_dict = colormaps_dict

    def _recolor_tracks(self):
        """Recolor the tracks.

        Rebuilds the vertex colors from the property selected by
        ``color_by``. A colormap registered for that property in
        ``colormaps_dict`` is applied to the raw values; otherwise the
        values are rescaled to [0, 1] and mapped with the default colormap.
        """
        # this catch prevents a problem coloring the tracks if the data is
        # updated before the properties are. properties should always contain
        # a track_id key
        if self.color_by not in self.properties_to_color_by:
            # the ``color_by`` setter performs the complete recoloring
            # itself, so there is nothing left to do here afterwards
            self.color_by = 'track_id'
            return

        # if we change the coloring, rebuild the vertex colors array
        vertex_properties = self._manager.vertex_properties(self.color_by)

        def _norm(p):
            # the floor on the range avoids dividing by zero when all
            # values are identical
            return (p - np.min(p)) / np.max([1e-10, np.ptp(p)])

        if self.color_by in self.colormaps_dict:
            colormap = self.colormaps_dict[self.color_by]
        else:
            # if we don't have a colormap, get one and scale the properties
            colormap = AVAILABLE_COLORMAPS[self.colormap]
            vertex_properties = _norm(vertex_properties)

        # actually set the vertex colors
        self._track_colors = colormap.map(vertex_properties)

    @property
    def track_connex(self) -> np.ndarray:
        """Vertex connections for drawing track lines."""
        return self._manager.track_connex

    @property
    def track_colors(self) -> np.ndarray:
        """Return the vertex colors according to the currently selected
        property."""
        return self._track_colors

    @property
    def graph_connex(self) -> np.ndarray:
        """Vertex connections for drawing the graph."""
        return self._manager.graph_connex

    @property
    def track_times(self) -> np.ndarray:
        """Time points associated with each track vertex."""
        return self._manager.track_times

    @property
    def graph_times(self) -> np.ndarray:
        """Time points associated with each graph vertex."""
        return self._manager.graph_times

    @property
    def track_labels(self) -> tuple:
        """Return track labels at the current time.

        Returns
        -------
        labels, positions : tuple
            The labels and their positions padded for display, or
            ``(None, (None, None))`` when there are no labels.
        """
        labels, positions = self._manager.track_labels(self.current_time)

        # if there are no labels, return empty for vispy
        if not labels:
            return None, (None, None)

        padded_positions = self._pad_display_data(positions)
        return labels, padded_positions