# Source code for visbrain.objects.connect_obj

"""Base class for objects of type connectivity."""
import logging

import numpy as np
from collections import Counter

from vispy import scene
from vispy.scene import visuals

from .visbrain_obj import VisbrainObject, CombineObjects
from .source_obj import SourceObj
from ..utils import (array2colormap, color2vb, wrap_properties,
                     vector_to_opacity)


# Logger used by every ConnectObj method in this module; messages are emitted
# under the package-wide 'visbrain' logger name (not this module's __name__).
logger = logging.getLogger('visbrain')


class ConnectObj(VisbrainObject):
    """Create a connectivity object.

    Parameters
    ----------
    name : string
        The name of the connectivity object.
    nodes : array_like
        Array of nodes coordinates of shape (n_nodes, 3) or (n_nodes, 2). In
        the latter case, the missing coordinate is filled with `_z`.
    edges : array_like
        Array of ponderations for edges of shape (n_nodes, n_nodes). Masked
        arrays are accepted (masked edges are not displayed). The input is
        copied and never modified.
    select : array_like | None
        Array to select edges to display. This should be an array of boolean
        values of shape (n_nodes, n_nodes). When provided, it replaces any
        mask carried by `edges`.
    line_width : float | 3.
        Connectivity line width.
    color_by : {'strength', 'count', 'causal'}
        Coloring method:

            * 'strength' : color edges according to their connection strength
              defined by the `edges` input. Only the upper triangle of the
              connectivity array is considered.
            * 'count' : color edges according to the number of connections
              per node. Only the upper triangle of the connectivity array is
              considered.
            * 'causal' : color edges according to the connectivity strength
              but this time, the upper and lower triangles of the connectivity
              array in `edges` are considered.

    custom_colors : dict | None
        Use a dictionary to colorize edges. For example, {1.2: 'red',
        2.8: 'green', None: 'black'} turns connections that have a 1.2 and 2.8
        strength into red and green. All other connections are set to black
        (black is also used when the None key is missing).
    alpha : float | 1.
        Transparency level (if dynamic is None).
    antialias : bool | False
        Use smoothed lines.
    dynamic : tuple | None
        Control the dynamic opacity. For example, if dynamic=(0, 1), strong
        connections will be more opaque than weaker connections.
    dynamic_order : int | 1
        If 1, the dynamic transparency is linearly modulated by the
        connectivity. If 2, the transparency follows a x**2 curve etc.
    dynamic_orientation : str | 'ascending'
        Define the transparency behavior :

            * 'ascending' : from translucent to opaque
            * 'center' : from opaque to translucent and finish by opaque
            * 'descending' : from opaque to translucent

    cmap : string | 'viridis'
        Colormap to use if custom_colors is None.
    clim : tuple | None
        Colorbar limits. If None, the (min, max) of the displayed values is
        used.
    vmin : float | None
        Lower threshold of the colormap if custom_colors is None.
    under : string | None
        Color to use for values under vmin if custom_colors is None.
    vmax : float | None
        Higher threshold of the colormap if custom_colors is None.
    over : string | None
        Color to use for values over vmax if custom_colors is None.
    transform : VisPy.visuals.transforms | None
        VisPy transformation to set to the parent node.
    parent : VisPy.parent | None
        Line object parent.
    verbose : string
        Verbosity level.
    _z : float | -10.
        In case of (n_nodes, 2) use _z to specify the elevation.
    kw : dict | {}
        Optional arguments are used to control the colorbar
        (See :class:`ColorbarObj`).

    Notes
    -----
    List of supported shortcuts :

        * **s** : save the figure
        * **<delete>** : reset camera

    Examples
    --------
    >>> import numpy as np
    >>> from visbrain.objects import ConnectObj
    >>> n_nodes = 100
    >>> nodes = np.random.rand(n_nodes, 3)
    >>> edges = np.random.uniform(low=-10., high=10., size=(n_nodes, n_nodes))
    >>> select = np.logical_and(edges >= 0, edges <= 1.)
    >>> c = ConnectObj('Connect', nodes, edges, select=select, cmap='inferno',
    ...                antialias=True)
    >>> c.preview(axis=True)
    """

    ###########################################################################
    # BUILT IN
    ###########################################################################

    def __init__(self, name, nodes, edges, select=None, line_width=3.,
                 color_by='strength', custom_colors=None, alpha=1.,
                 antialias=False, dynamic=None, dynamic_order=1,
                 dynamic_orientation='ascending', cmap='viridis', clim=None,
                 vmin=None, vmax=None, under='gray', over='red',
                 transform=None, parent=None, verbose=None, _z=-10., **kw):
        """Init."""
        VisbrainObject.__init__(self, name, parent, transform, verbose, **kw)
        self._update_cbar_args(cmap, clim, vmin, vmax, under, over)

        # _______________________ CHECKING _______________________
        # Nodes : (n_nodes, 3), or (n_nodes, 2) completed with a constant z.
        assert isinstance(nodes, np.ndarray) and nodes.ndim == 2
        sh = nodes.shape
        self._n_nodes = sh[0]
        assert sh[1] in (2, 3), ("`nodes` should be an array of shape "
                                 "(n_nodes, 2) or (n_nodes, 3)")
        pos = nodes if sh[1] == 3 else np.c_[nodes, np.full((len(self),), _z)]
        self._pos = pos.astype(np.float32)
        logger.info(" %i nodes detected" % self._pos.shape[0])
        # Colorby (validated first because it drives the edge masking) :
        assert color_by in ['strength', 'count', 'causal']
        self._color_by = color_by
        # Edges : private copy so that the caller's array and mask are never
        # modified by the masking below.
        edges = np.ma.array(edges, copy=True)
        assert edges.shape == (len(self), len(self))
        # Select : replaces any mask carried by `edges`
        if isinstance(select, np.ndarray):
            assert select.shape == edges.shape and select.dtype == bool
            base_mask = np.invert(select)
        else:
            # getmaskarray always returns a full boolean array (even for
            # plain arrays or masked arrays with `nomask`)
            base_mask = np.ma.getmaskarray(edges).copy()
        # Keep the user mask so the triangle masking can be recomputed when
        # `color_by` changes :
        self._base_mask = base_mask
        self._edges = edges
        self._update_edges_mask()
        # Dynamic :
        if dynamic is not None:
            assert len(dynamic) == 2
        self._dynamic = dynamic
        assert isinstance(dynamic_order, int) and dynamic_order > 0
        self._dyn_order = dynamic_order
        self._dyn_orient = dynamic_orientation
        # Custom color :
        if custom_colors is not None:
            assert isinstance(custom_colors, dict)
        self._custom_colors = custom_colors
        # Alpha :
        assert 0. <= alpha <= 1.
        self._alpha = alpha

        # _______________________ LINE _______________________
        self._connect = visuals.Line(name='ConnectObjLine', width=line_width,
                                     antialias=antialias, parent=self._node,
                                     connect='segments')
        self._connect.set_gl_state('translucent', depth_test=False,
                                   cull_face=False)
        self._build_line()

    def __len__(self):
        """Get the number of nodes."""
        return self._n_nodes

    def update(self):
        """Update the line."""
        self._connect.update()

    def _update_edges_mask(self):
        """Recompute the mask of the edges according to `color_by`.

        Starts from the user mask (select or input mask), hides the lower
        triangle for 'strength' and 'count' and always hides the diagonal.
        """
        mask = self._base_mask.copy()
        if self._color_by != 'causal':
            # Only the upper triangle is considered (k=0 includes diagonal)
            mask[np.tril_indices(len(self), 0)] = True
        # Self-connections are never drawn :
        mask[np.diag_indices(len(self))] = True
        self._edges.mask = mask

    def _build_line(self):
        """Build the connectivity line."""
        pos, edges = self._pos, self._edges
        # Color either edges or nodes :
        logger.info(" %s coloring method for connectivity" % self._color_by)
        # Switch between coloring method :
        if self._color_by in ['strength', 'count']:
            # Build line position : one segment (2 vertices) per visible edge
            nnz_x, nnz_y = np.where(~edges.mask)
            indices = np.c_[nnz_x, nnz_y].flatten()
            line_pos = pos[indices, :]
            if self._color_by == 'strength':
                # compressed() follows the same row-major order as np.where
                nnz_values = edges.compressed()
                values = np.c_[nnz_values, nnz_values].flatten()
            else:  # 'count'
                # Number of visible connections of each node, then looked-up
                # for every vertex of the line
                node_count = np.bincount(indices, minlength=len(self))
                values = node_count[indices]
        else:  # 'causal'
            idx = np.array(np.where(~edges.mask)).T
            # If the array is not symmetric, the line needs to be drawn
            # between points. If it's symmetric, line should stop at the
            # middle point. Here, we get the mask value of the symmetric and
            # use it to ponderate middle point calculation :
            pond = (~np.array(edges.mask))[idx[:, 1], idx[:, 0]]
            pond = pond.astype(float).reshape(-1, 1)
            div = pond + 1.
            # Build line pos :
            line_pos = np.zeros((2 * idx.shape[0], 3), dtype=float)
            line_pos[0::2, :] = pos[idx[:, 0], :]
            line_pos[1::2, :] = (pos[idx[:, 1]] + pond * pos[idx[:, 0]]) / div
            # Build values : starting vertices get the minimum strength
            values = np.full((line_pos.shape[0],), edges.min(), dtype=float)
            values[1::2] = edges.compressed()
        # Each link is made of two vertices ('segments' connection)
        logger.info(" %i connectivity links displayed" %
                    (line_pos.shape[0] // 2))
        self._minmax = (values.min(), values.max())
        if self._clim is None:
            self._clim = self._minmax

        # Get the color according to values :
        if isinstance(self._custom_colors, dict):  # custom color
            if None in self._custom_colors:  # {None : 'color'}
                color = color2vb(self._custom_colors[None], length=len(values))
            else:  # black by default
                color = np.zeros((len(values), 4), dtype=np.float32)
            for val, col in self._custom_colors.items():
                if val is not None:  # the None key is the default color
                    color[values == val, :] = color2vb(col)
        else:
            color = array2colormap(values, **self.to_kwargs())
        color[:, -1] = self._alpha

        # Dynamic color (overrides the uniform alpha) :
        if self._dynamic is not None:
            color[:, 3] = vector_to_opacity(values, clim=self._clim,
                                            dyn=self._dynamic,
                                            order=self._dyn_order,
                                            orientation=self._dyn_orient)

        # Send data to the connectivity object :
        self._connect.set_data(pos=line_pos, color=color)

    def get_nb_connections_per_node(self, sort='index', order='ascending'):
        """Get the number of connections per node.

        Parameters
        ----------
        sort : {'index', 'count'}
            Sort either by node index ('index') or according to the number of
            connections per node ('count').
        order : {'ascending', 'descending'}
            Sorting order.

        Returns
        -------
        nb_connect : array_like
            Array of shape (n_nodes, 2) where the first column is the node
            index and the second one its number of displayed connections.
        """
        return self._get_nb_connect(self._edges.mask, sort, order)

    def analyse_connections(self, roi_obj='talairach', group_by=None,
                            get_centroids=False, replace_bad=True,
                            bad_patterns=None, distance=None,
                            replace_with='Not found', keep_only=None):
        """Analyse connections.

        Parameters
        ----------
        roi_obj : string/list | 'talairach'
            The ROI object to use. Use either 'talairach', 'brodmann' or 'aal'
            to use a predefined ROI template. Otherwise, use a RoiObj object or
            a list of RoiObj.
        group_by : str | None
            Name of the column inside the dataframe for grouping connectivity
            results.
        get_centroids : bool | False
            Only used if `group_by` is not None. Add the (X, Y, Z) centroid of
            each group to the ROI dataframe. Groups whose label can not be
            found in the ROI object are dropped.
        replace_bad : bool | True
            Replace bad values (True) or not (False).
        bad_patterns : list | None
            Bad patterns to replace if replace_bad is True. If None,
            [-1, 'undefined', 'None'] is used.
        distance : float | None
            Forwarded to :meth:`SourceObj.analyse_sources`.
        replace_with : string | 'Not found'
            Replace bad patterns with this string.
        keep_only : list | None
            List of string patterns to keep only sources that match.

        Returns
        -------
        df : pandas.DataFrame
            If `group_by` is None, the DataFrame (or list of DataFrames if
            roi_obj is a list) returned by the analysis of the nodes.
        x_r, labels, df_roi : MaskedArray, list, pandas.DataFrame
            If `group_by` is provided : the (n_labels, n_labels) masked array
            between groups (upper triangle only), the group labels (aligned
            with the rows / columns of x_r) and a per-group DataFrame.
        """
        if bad_patterns is None:
            bad_patterns = [-1, 'undefined', 'None']
        # Get anatomical info of sources :
        s_obj = SourceObj('analyse', self._pos)
        df = s_obj.analyse_sources(roi_obj=roi_obj, replace_bad=replace_bad,
                                   bad_patterns=bad_patterns,
                                   distance=distance,
                                   replace_with=replace_with,
                                   keep_only=keep_only)
        # If no column, return the full dataframe :
        if group_by is None:
            return df
        # Group DataFrame column :
        grp = df.groupby(group_by).groups
        labels, index = list(grp.keys()), list(grp.values())
        # Prepare the new connectivity array :
        n_labels = len(labels)
        x_r = np.zeros((n_labels, n_labels), dtype=float)
        mask_r = np.ones((n_labels, n_labels), dtype=bool)
        # Loop over the upper triangle :
        row, col = np.triu_indices(n_labels)
        data, mask = self._edges.data, self._edges.mask
        for r, c in zip(row, col):
            m = tuple(np.meshgrid(index[r], index[c]))
            # NOTE: the mean is taken over the raw data (masked edges
            # included); the entry is masked when every edge is masked.
            x_r[r, c], mask_r[r, c] = data[m].mean(), mask[m].all()
        # Define a ROI dataframe :
        import pandas as pd
        columns = [group_by, "Mean connectivity strength inside ROI",
                   "Number of connections per node"]
        df_roi = pd.DataFrame({}, columns=columns)
        df_roi[group_by] = labels
        df_roi[columns[1]] = np.diag(x_r)
        df_roi[columns[2]] = [len(k) for k in index]
        # Get (x, y, z) ROI centroids :
        if get_centroids:
            # Define the RoiObj (predefined template or user object) :
            from .roi_obj import RoiObj
            r_obj = RoiObj(roi_obj) if isinstance(roi_obj, str) else roi_obj
            assert isinstance(r_obj, RoiObj)
            # Search where is the label :
            idx, roi_labels, rm_rows = [], [], []
            for k, lab in enumerate(labels):
                _idx = r_obj.where_is(lab, exact=True)
                if not len(_idx):
                    rm_rows += [k]
                else:
                    idx += [_idx[0]]
                    roi_labels += [lab]
            xyz = r_obj.get_centroids(idx)
            x_r = np.delete(x_r, rm_rows, axis=0)
            x_r = np.delete(x_r, rm_rows, axis=1)
            mask_r = np.delete(mask_r, rm_rows, axis=0)
            mask_r = np.delete(mask_r, rm_rows, axis=1)
            df_roi.drop(rm_rows, inplace=True)
            df_roi.index = pd.RangeIndex(len(df_roi.index))
            df_roi['X'] = xyz[:, 0]
            df_roi['Y'] = xyz[:, 1]
            df_roi['Z'] = xyz[:, 2]
            # Keep labels aligned with the remaining rows / columns of x_r :
            labels = roi_labels
        x_r = np.ma.masked_array(x_r, mask=mask_r)
        return x_r, labels, df_roi

    @staticmethod
    def _get_nb_connect(mask, sort, order):
        """Sub-function to get the number of connections per node.

        Returns an (n_nodes, 2) int array [node index, number of unmasked
        connections] sorted by `sort` in the requested `order`.
        """
        assert sort in ['index', 'count'], (
            "`sort` should either be 'index' or 'count'")
        assert order in ['ascending', 'descending'], (
            "`order` should either be 'ascending' or 'descending'")
        logger.info(" Get the number of connections per node")
        n_nodes = mask.shape[0]
        # Get the number of connections per nodes (a node is counted both as
        # a source and as a target) :
        nnz_x, nnz_y = np.where(~np.asarray(mask))
        counts = np.bincount(np.concatenate((nnz_x, nnz_y)),
                             minlength=n_nodes)
        # Full number of connections :
        nb_connect = np.zeros((n_nodes, 2), dtype=int)
        nb_connect[:, 0] = np.arange(n_nodes)
        nb_connect[:, 1] = counts
        # Sort according to node index or number of connections per node
        # (stable sort so ties keep the node index order) :
        idx = 0 if sort == 'index' else 1
        args = np.argsort(nb_connect[:, idx], kind='stable')
        # Ascending or descending sorting :
        if order == 'descending':
            args = args[::-1]
        return nb_connect[args, :]

    def _get_camera(self):
        """Get the most adapted camera."""
        d_mean = self._pos.mean(0)
        dist = np.sqrt(np.sum(d_mean ** 2))
        cam = scene.cameras.TurntableCamera(center=d_mean, scale_factor=dist)
        self.camera = cam
        return cam

    ###########################################################################
    # PROPERTIES
    ###########################################################################

    # ----------- LINE_WIDTH -----------
    @property
    def line_width(self):
        """Get the line_width value."""
        return self._connect.width

    @line_width.setter
    @wrap_properties
    def line_width(self, value):
        """Set line_width value."""
        assert isinstance(value, (int, float))
        self._connect._width = value
        self.update()

    # ----------- COLOR_BY -----------
    @property
    def color_by(self):
        """Get the color_by value."""
        return self._color_by

    @color_by.setter
    @wrap_properties
    def color_by(self, value):
        """Set color_by value."""
        assert value in ['strength', 'count', 'causal']
        self._color_by = value
        # 'causal' uses both triangles, the other methods only the upper one
        self._update_edges_mask()
        self._build_line()

    # ----------- DYNAMIC -----------
    @property
    def dynamic(self):
        """Get the dynamic value."""
        return self._dynamic

    @dynamic.setter
    @wrap_properties
    def dynamic(self, value):
        """Set dynamic value."""
        assert value is None or len(value) == 2
        self._dynamic = value
        self._build_line()

    # ----------- ALPHA -----------
    @property
    def alpha(self):
        """Get the alpha value."""
        return self._alpha

    @alpha.setter
    @wrap_properties
    def alpha(self, value):
        """Set alpha value."""
        assert 0. <= value <= 1.
        self._alpha = value
        # Rebuild so the new colors are actually sent to the line (an
        # in-place edit of the line colors is not flagged as changed) and so
        # that dynamic opacities keep precedence over the uniform alpha.
        self._build_line()
class CombineConnect(CombineObjects):
    """Combine connectivity objects.

    Parameters
    ----------
    cobjs : ConnectObj/list | None
        List of connectivity objects.
    select : string | None
        The name of the connectivity object to select.
    parent : VisPy.parent | None
        Parent forwarded to CombineObjects.
    """

    def __init__(self, cobjs=None, select=None, parent=None):
        """Init."""
        # Forward ConnectObj as the object type handled by CombineObjects.
        CombineObjects.__init__(self, ConnectObj, cobjs, select, parent)