# from napari.layers.base.base import Layer
# from napari.utils.events import Event
# from napari.utils.colormaps import AVAILABLE_COLORMAPS
from typing import Dict, List, Union
from warnings import warn
import numpy as np
from ...utils.colormaps import AVAILABLE_COLORMAPS, Colormap
from ...utils.events import Event
from ..base import Layer
from ._track_utils import TrackManager
[docs]class Tracks(Layer):
    """Tracks layer.
    Parameters
    ----------
    data : array (N, D+1)
        Coordinates for N points in D+1 dimensions. ID,T,(Z),Y,X. The first
        axis is the integer ID of the track. D is either 3 or 4 for planar
        or volumetric timeseries respectively.
    properties : dict {str: array (N,)}, DataFrame
        Properties for each point. Each property should be an array of length N,
        where N is the number of points.
    graph : dict {int: list}
        Graph representing associations between tracks. Dictionary defines the
        mapping between a track ID and the parents of the track. This can be
        one (the track has one parent, and the parent has >=1 child) in the
        case of track splitting, or more than one (the track has multiple
        parents, but only one child) in the case of track merging.
        See examples/tracks_3d_with_graph.py
    color_by: str
        Track property (from property keys) by which to color vertices.
    tail_width : float
        Width of the track tails in pixels.
    tail_length : float
        Length of the track tails in units of time.
    colormap : str
        Default colormap to use to set vertex colors. Specialized colormaps,
        relating to specified properties can be passed to the layer via
        colormaps_dict.
    colormaps_dict : dict {str: napari.utils.Colormap}
        Optional dictionary mapping each property to a colormap for that
        property. This allows each property to be assigned a specific colormap,
        rather than having a global colormap for everything.
    name : str
        Name of the layer.
    metadata : dict
        Layer metadata.
    scale : tuple of float
        Scale factors for the layer.
    translate : tuple of float
        Translation values for the layer.
    rotate : float, 3-tuple of float, or n-D array.
        If a float convert into a 2D rotation matrix using that value as an
        angle. If 3-tuple convert into a 3D rotation matrix, using a yaw,
        pitch, roll convention. Otherwise assume an nD rotation. Angles are
        assumed to be in degrees. They can be converted from radians with
        np.degrees if needed.
    shear : 1-D array or n-D array
        Either a vector of upper triangular values, or an nD shear matrix with
        ones along the main diagonal.
    affine : n-D array or napari.utils.transforms.Affine
        (N+1, N+1) affine transformation matrix in homogeneous coordinates.
        The first (N, N) entries correspond to a linear transform and
        the final column is a lenght N translation vector and a 1 or a napari
        AffineTransform object. If provided then translate, scale, rotate, and
        shear values are ignored.
    opacity : float
        Opacity of the layer visual, between 0.0 and 1.0.
    blending : str
        One of a list of preset blending modes that determines how RGB and
        alpha values of the layer visual get mixed. Allowed values are
        {'opaque', 'translucent', and 'additive'}.
    visible : bool
        Whether the layer visual is currently being displayed.
    """
    # The max number of tracks that will ever be used to render the thumbnail
    # If more tracks are present then they are randomly subsampled
    _max_tracks_thumbnail = 1024
    def __init__(
        self,
        data,
        *,
        properties=None,
        graph=None,
        tail_width=2,
        tail_length=30,
        name=None,
        metadata=None,
        scale=None,
        translate=None,
        rotate=None,
        shear=None,
        affine=None,
        opacity=1,
        blending='additive',
        visible=True,
        colormap='turbo',
        color_by='track_id',
        colormaps_dict=None,
    ):
        # if not provided with any data, set up an empty layer in 2D+t
        if data is None:
            data = np.empty((0, 4))
        else:
            # convert data to a numpy array if it is not already one
            data = np.asarray(data)
        # in absence of properties make the default an empty dict
        if properties is None:
            properties = {}
        # set the track data dimensions (remove ID from data)
        ndim = data.shape[1] - 1
        super().__init__(
            data,
            ndim,
            name=name,
            metadata=metadata,
            scale=scale,
            translate=translate,
            rotate=rotate,
            shear=shear,
            affine=affine,
            opacity=opacity,
            blending=blending,
            visible=visible,
        )
        self.events.add(
            tail_width=Event,
            tail_length=Event,
            display_id=Event,
            display_tail=Event,
            display_graph=Event,
            color_by=Event,
            colormap=Event,
            properties=Event,
            rebuild_tracks=Event,
            rebuild_graph=Event,
        )
        # track manager deals with data slicing, graph building and properties
        self._manager = TrackManager()
        self._track_colors = None
        self._colormaps_dict = colormaps_dict or {}  # additional colormaps
        self._color_by = color_by  # default color by ID
        self._colormap = colormap
        # use this to update shaders when the displayed dims change
        self._current_displayed_dims = None
        # track display properties
        self.tail_width = tail_width
        self.tail_length = tail_length
        self.display_id = False
        self.display_tail = True
        self.display_graph = True
        # set the data, properties and graph
        self.data = data
        self.properties = properties
        self.graph = graph or {}
        self.color_by = color_by
        self.colormap = colormap
        self._update_dims()
        # reset the display before returning
        self._current_displayed_dims = None
    @property
    def _extent_data(self) -> np.ndarray:
        """Extent of layer in data coordinates.
        Returns
        -------
        extent_data : array, shape (2, D)
        """
        if len(self.data) == 0:
            extrema = np.full((2, self.ndim), np.nan)
        else:
            maxs = np.max(self.data, axis=0)
            mins = np.min(self.data, axis=0)
            extrema = np.vstack([mins, maxs])
        return extrema[:, 1:]
    def _get_ndim(self) -> int:
        """Determine number of dimensions of the layer."""
        return self._manager.ndim
    def _get_state(self):
        """Get dictionary of layer state.
        Returns
        -------
        state : dict
            Dictionary of layer state.
        """
        state = self._get_base_state()
        state.update(
            {
                'data': self.data,
                'properties': self.properties,
                'graph': self.graph,
                'color_by': self.color_by,
                'colormap': self.colormap,
                'colormaps_dict': self.colormaps_dict,
                'tail_width': self.tail_width,
                'tail_length': self.tail_length,
            }
        )
        return state
    def _set_view_slice(self):
        """Sets the view given the indices to slice with."""
        # if the displayed dims have changed, update the shader data
        if self._dims_displayed != self._current_displayed_dims:
            # store the new dims
            self._current_displayed_dims = self._dims_displayed
            # fire the events to update the shaders
            self.events.rebuild_tracks()
            self.events.rebuild_graph()
        return
    def _get_value(self, position) -> int:
        """Value of the data at a position in data coordinates.
        Use a kd-tree to lookup the ID of the nearest tree.
        Parameters
        ----------
        position : tuple
            Position in data coordinates.
        Returns
        -------
        value : int or None
            Index of track that is at the current coordinate if any.
        """
        return self._manager.get_value(np.array(position))
    def _update_thumbnail(self):
        """Update thumbnail with current points and colors."""
        colormapped = np.zeros(self._thumbnail_shape)
        colormapped[..., 3] = 1
        if self._view_data is not None and self.track_colors is not None:
            de = self._extent_data
            min_vals = [de[0, i] for i in self._dims_displayed]
            shape = np.ceil(
                [de[1, i] - de[0, i] + 1 for i in self._dims_displayed]
            ).astype(int)
            zoom_factor = np.divide(
                self._thumbnail_shape[:2], shape[-2:]
            ).min()
            if len(self._view_data) > self._max_tracks_thumbnail:
                thumbnail_indices = np.random.randint(
                    0, len(self._view_data), self._max_tracks_thumbnail
                )
                points = self._view_data[thumbnail_indices]
            else:
                points = self._view_data
                thumbnail_indices = range(len(self._view_data))
            # get the track coords here
            coords = np.floor(
                (points[:, :2] - min_vals[1:] + 0.5) * zoom_factor
            ).astype(int)
            coords = np.clip(
                coords, 0, np.subtract(self._thumbnail_shape[:2], 1)
            )
            # modulate track colors as per colormap/current_time
            colors = self.track_colors[thumbnail_indices]
            times = self.track_times[thumbnail_indices]
            alpha = (self.current_time - times) / self.tail_length
            alpha[times > self.current_time] = 1.0
            colors[:, -1] = np.clip(1.0 - alpha, 0.0, 1.0)
            colormapped[coords[:, 1], coords[:, 0]] = colors
        colormapped[..., 3] *= self.opacity
        self.thumbnail = colormapped
    @property
    def _view_data(self):
        """ return a view of the data """
        return self._pad_display_data(self._manager.track_vertices)
    @property
    def _view_graph(self):
        """ return a view of the graph """
        return self._pad_display_data(self._manager.graph_vertices)
    def _pad_display_data(self, vertices):
        """ pad display data when moving between 2d and 3d """
        if vertices is None:
            return
        data = vertices[:, self._dims_displayed]
        # if we're only displaying two dimensions, then pad the display dim
        # with zeros
        if self._ndisplay == 2:
            data = np.pad(data, ((0, 0), (0, 1)), 'constant')
            return data[:, (1, 0, 2)]  # y, x, z -> x, y, z
        else:
            return data[:, (2, 1, 0)]  # z, y, x -> x, y, z
    @property
    def current_time(self):
        """ current time according to the first dimension """
        # TODO(arl): get the correct index here
        time_step = self._slice_indices[0]
        if isinstance(time_step, slice):
            # if we are visualizing all time, then just set to the maximum
            # timestamp of the dataset
            return self._manager.max_time
        return time_step
    @property
    def use_fade(self) -> bool:
        """toggle whether we fade the tail of the track, depending on whether
        the time dimension is displayed"""
        return 0 in self._dims_not_displayed
    @property
    def data(self) -> np.ndarray:
        """ array (N, D+1): Coordinates for N points in D+1 dimensions. """
        return self._manager.data
    @data.setter
    def data(self, data: np.ndarray):
        """ set the data and build the vispy arrays for display """
        # set the data and build the tracks
        self._manager.data = data
        self._manager.build_tracks()
        # reset the properties and recolor the tracks
        self.properties = {}
        self._recolor_tracks()
        # reset the graph
        self._manager.graph = {}
        self._manager.build_graph()
        # fire events to update shaders
        self.events.rebuild_tracks()
        self.events.rebuild_graph()
        self.events.data(value=self.data)
        self._set_editable()
        self._update_dims()
    @property
    def properties(self) -> Dict[str, np.ndarray]:
        """dict {str: np.ndarray (N,)}, DataFrame: Properties for each track."""
        return self._manager.properties
    @property
    def properties_to_color_by(self) -> List[str]:
        """ track properties that can be used for coloring etc... """
        return list(self.properties.keys())
    @properties.setter
    def properties(self, properties: Dict[str, np.ndarray]):
        """ set track properties """
        if self._color_by not in [*properties.keys(), 'track_id']:
            warn(
                (
                    f"Previous color_by key {self._color_by!r} not present in"
                    " new properties. Falling back to track_id"
                ),
                UserWarning,
            )
            self._color_by = 'track_id'
        self._manager.properties = properties
        self.events.properties()
        self.events.color_by()
    @property
    def graph(self) -> Dict[int, Union[int, List[int]]]:
        """dict {int: list}: Graph representing associations between tracks."""
        return self._manager.graph
    @graph.setter
    def graph(self, graph: Dict[int, Union[int, List[int]]]):
        """ Set the track graph. """
        self._manager.graph = graph
        self._manager.build_graph()
        self.events.rebuild_graph()
    @property
    def tail_width(self) -> Union[int, float]:
        """float: Width for all vectors in pixels."""
        return self._tail_width
    @tail_width.setter
    def tail_width(self, tail_width: Union[int, float]):
        self._tail_width = tail_width
        self.events.tail_width()
    @property
    def tail_length(self) -> Union[int, float]:
        """float: Width for all vectors in pixels."""
        return self._tail_length
    @tail_length.setter
    def tail_length(self, tail_length: Union[int, float]):
        self._tail_length = tail_length
        self.events.tail_length()
    @property
    def display_id(self) -> bool:
        """ display the track id """
        return self._display_id
    @display_id.setter
    def display_id(self, value: bool):
        self._display_id = value
        self.events.display_id()
        self.refresh()
    @property
    def display_tail(self) -> bool:
        """ display the track tail """
        return self._display_tail
    @display_tail.setter
    def display_tail(self, value: bool):
        self._display_tail = value
        self.events.display_tail()
    @property
    def display_graph(self) -> bool:
        """ display the graph edges """
        return self._display_graph
    @display_graph.setter
    def display_graph(self, value: bool):
        self._display_graph = value
        self.events.display_graph()
    @property
    def color_by(self) -> str:
        return self._color_by
    @color_by.setter
    def color_by(self, color_by: str):
        """ set the property to color vertices by """
        if color_by not in self.properties_to_color_by:
            raise ValueError(f'{color_by} is not a valid property key')
        self._color_by = color_by
        self._recolor_tracks()
        self.events.color_by()
    @property
    def colormap(self) -> str:
        return self._colormap
    @colormap.setter
    def colormap(self, colormap: str):
        """ set the default colormap """
        if colormap not in AVAILABLE_COLORMAPS:
            raise ValueError(f'Colormap {colormap} not available')
        self._colormap = colormap
        self._recolor_tracks()
        self.events.colormap()
    @property
    def colormaps_dict(self) -> Dict[str, Colormap]:
        return self._colormaps_dict
    @colormaps_dict.setter
    def colomaps_dict(self, colormaps_dict: Dict[str, Colormap]):
        # validate the dictionary entries?
        self._colormaps_dict = colormaps_dict
    def _recolor_tracks(self):
        """ recolor the tracks """
        # this catch prevents a problem coloring the tracks if the data is
        # updated before the properties are. properties should always contain
        # a track_id key
        if self.color_by not in self.properties_to_color_by:
            self.color_by = 'track_id'
        # if we change the coloring, rebuild the vertex colors array
        vertex_properties = self._manager.vertex_properties(self.color_by)
        def _norm(p):
            return (p - np.min(p)) / np.max([1e-10, np.ptp(p)])
        if self.color_by in self.colormaps_dict:
            colormap = self.colormaps_dict[self.color_by]
        else:
            # if we don't have a colormap, get one and scale the properties
            colormap = AVAILABLE_COLORMAPS[self.colormap]
            vertex_properties = _norm(vertex_properties)
        # actually set the vertex colors
        self._track_colors = colormap.map(vertex_properties)
    @property
    def track_connex(self) -> np.ndarray:
        """ vertex connections for drawing track lines """
        return self._manager.track_connex
    @property
    def track_colors(self) -> np.ndarray:
        """return the vertex colors according to the currently selected
        property"""
        return self._track_colors
    @property
    def graph_connex(self) -> np.ndarray:
        """ vertex connections for drawing the graph """
        return self._manager.graph_connex
    @property
    def track_times(self) -> np.ndarray:
        """ time points associated with each track vertex """
        return self._manager.track_times
    @property
    def graph_times(self) -> np.ndarray:
        """ time points assocaite with each graph vertex """
        return self._manager.graph_times
    @property
    def track_labels(self) -> tuple:
        """ return track labels at the current time """
        labels, positions = self._manager.track_labels(self.current_time)
        # if there are no labels, return empty for vispy
        if not labels:
            return None, (None, None)
        padded_positions = self._pad_display_data(positions)
        return labels, padded_positions