# -*- coding: utf-8 -*-
import inspect
import warnings
import numpy as np
import matplotlib.pyplot as plt
import scipy.io as si
import scipy.signal as ss
from .mvarmodel import Mvar
from .conn import *
from .load.loaders import signalml_loader
from six.moves import range
[docs]class Data(object):
'''
Class governing the communication between data array and
connectivity estimators.
Args:
*data* : numpy.array or str
* array with data (kXNxR, k - channels nr, N - data points,
R - nr of trials)
* str - path to file with appropieate format
*fs* = 1: int
sampling frequency
*chan_names* = []: list
names of channels
*data_info* = '': string
other information about the data
'''
def __init__(self, data, fs=1., chan_names=[], data_info=''):
self.__fs = fs
self.__data = self._load_file(data, data_info)
self.__channames = ["x"+str(i) for i in range(self.__chans_number)]
if self.__data.shape[0] == len(chan_names):
self.__channames = chan_names
self.__channames_original = self.__channames
self.data_info = data_info
self._parameters = {}
self._parameters["mvar"] = False
def _load_file(self, data_what, data_info):
'''
Data loader.
Args:
*data_what* : str/numpy.array
path to file with appropieate format or numpy data array
*data_info* = '' : str
additional file with data settings if needed
Returns:
*data* : np.array
'''
dt_type = type(data_what)
if dt_type == np.ndarray:
data = data_what
elif dt_type == str:
dt_type = data_what.split('.')[-1] # catch file extension
if dt_type == 'mat':
mat_dict = si.loadmat(data_what)
if data_info:
key = data_info
else:
key = data_what[:-4].split('/')[-1]
data = mat_dict[key]
if dt_type == 'raw' and data_info == 'sml':
data, sml = signalml_loader(data_what[:-4])
self.__fs = sml['samplingFrequency']
self.__channames = sml['channelNames']
self.smldict = sml # here SignalML data is stored
else:
return False
if data.ndim > 2:
self.__multitrial = data.shape[2]
else:
self.__multitrial = False
# in original number of channels after loading is stored,
# self.__chans_number can be modified
self.__chans_number_original = self.__chans_number = data.shape[0]
self._channels = np.arange(self.__chans_number)
self.__length = data.shape[1]
return data
[docs] def select_channels(self, channels=None):
'''
Selecting channels to plot or further analysis.
Args:
*channels* : list(int)
List of channel indices. If None all channels are taken
into account.
'''
if np.max(channels) >= self.__chans_number_original:
raise ValueError("Indices are not correct")
if channels is None:
self._channels = np.arange(self.__chans_number_original)
self.__chans_number = self.__chans_number_original
self.__channames = self.__channames_original
else:
self._channels = channels
self.__chans_number = len(channels)
self.__channames = [i for e, i in enumerate(self.__channames_original) if e in channels]
[docs] def filter(self, b, a):
'''
Filter each channel of data using forward-backward filter
*filtfilt* from *scipy.signal*.
Args:
*b, a* : np.array
Numerator *b* / denominator *a* polynomials of
the IIR filter.
'''
if self.__multitrial:
for r in range(self.__multitrial):
self.__data[:, :, r] = ss.filtfilt(b, a, self.__data)
else:
self.__data = ss.filtfilt(b, a, self.__data)
[docs] def resample(self, fs_new):
'''
Signal resampling to new sampling frequency *new_fs* using
*resample* function from *scipy.signal* (basing on Fourier method).
Args:
*fs_new* : int
new sampling frequency
'''
new_nr_samples = int((self.__length*1./self.__fs)*fs_new)
self.__data = ss.resample(self.__data, new_nr_samples, axis=1)
self.__fs = fs_new
[docs] def fit_mvar(self, p=None, method='yw'):
'''
Fitting MVAR coefficients.
Args:
*p* = None : int
estimation order, default None
*method* = 'yw' : str {'yw', 'ns', 'vm'}
method of MVAR parameters estimation
all avaiable methods you can find in *fitting_algorithms*
'''
self.__Ar, self.__Vr = Mvar().fit(self.__data, p, method)
self._parameters["mvar"] = True
self._parameters["p"] = p
self._parameters["mvarmethod"] = method
[docs] def conn(self, method, **params):
'''
Estimate connectivity pattern.
Args:
*p* = None : int
estimation order, default None
*method* : str
method of connectivity estimation
all avaiable methods you can find in *conn_estim_dc*
'''
connobj = conn_estim_dc[method]()
if isinstance(connobj, ConnectAR):
self.__estim = connobj.calculate(self.__Ar, self.__Vr,
self.__fs, **params)
else:
if not self.__multitrial:
self.__estim = connobj.calculate(self.__data[self._channels, :], **params)
else:
for r in range(self.__multitrial):
if r == 0:
self.__estim = connobj.calculate(self.__data[self._channels, :, r], **params)
continue
self.__estim += connobj.calculate(self.__data[self._channels, :, r], **params)
self.__estim = self.__estim/self.__multitrial
self._parameters["method"] = method
self._parameters["y_lim"] = connobj.values_range
self._parameters.update(params)
return self.__estim
[docs] def short_time_conn(self, method, nfft=None, no=None, **params):
'''
Short-time connectivity.
Args:
*method* = 'yw' : str {'yw', 'ns', 'vm'}
method of estimation
all avaiable methods you can find in *fitting_algorithms*
*nfft* = None : int
number of data points in window; if None, it is signal length
N/5.
*no* = None : int
number of data points in overlap; if None, it is signal length
N/10.
*params*
other parameters for specific estimator
'''
connobj = conn_estim_dc[method]()
self._parameters.update(params)
arg = inspect.getargspec(connobj.calculate)
newparams = self.__make_params_dict(arg[0])
if "p" not in self._parameters:
if "order" in params:
self._parameters["p"] = params["order"]
else:
self._parameters["p"] = None
if not nfft:
nfft = int(self.__length/5)
if not no:
no = int(self.__length/10)
if "resolution" not in self._parameters:
self._parameters["resolution"] = 100
if isinstance(connobj, ConnectAR):
self.__shtimest = connobj.short_time(self.__data[self._channels, :], nfft=nfft, no=no,
fs=self.__fs, order=self._parameters["p"],
resol=self._parameters["resolution"])
else:
if self.__multitrial:
for r in range(self.__multitrial):
if r == 0:
self.__shtimest = connobj.short_time(self.__data[self._channels, :, r], nfft=nfft, no=no, **newparams)
continue
self.__shtimest += connobj.short_time(self.__data[self._channels, :, r], nfft=nfft, no=no, **newparams)
self.__shtimest /= self.__multitrial
else:
self.__shtimest = connobj.short_time(self.__data[self._channels, :], nfft=nfft, no=no, **newparams)
self._parameters["shorttime"] = method
self._parameters["nfft"] = nfft
self._parameters["no"] = no
return self.__shtimest
[docs] def significance(self, Nrep=100, alpha=0.05, verbose=True, **params):
'''
Statistical significance values of connectivity estimation method.
Args:
*Nrep* = 100 : int
number of resamples
*alpha* = 0.05 : float
type I error rate (significance level)
*verbose* = True : bool
if True it prints dot on every realization
Returns:
*signi*: numpy.array
matrix in shape of (k, k) with values for each pair of
channels
'''
connobj = conn_estim_dc[self._parameters["method"]]()
self._parameters.update(params)
arg = inspect.getargspec(connobj.calculate)
newparams = self.__make_params_dict(arg[0])
if not self.__multitrial:
if isinstance(connobj, ConnectAR):
self.__signific = connobj.surrogate(self.__data[self._channels, :], Nrep=Nrep, alpha=alpha,
method=self._parameters["mvarmethod"],
fs=self.__fs, order=self._parameters["p"],
verbose=verbose, **newparams)
else:
self.__signific = connobj.surrogate(self.__data[self._channels, :], Nrep=Nrep,
alpha=alpha, verbose=verbose,
**newparams)
else:
if isinstance(connobj, ConnectAR):
self.__signific = connobj.bootstrap(self.__data[self._channels, :, :], Nrep=Nrep, alpha=alpha,
method=self._parameters["mvarmethod"],
fs=self.__fs, order=self._parameters["p"],
verbose=verbose, **newparams)
else:
self.__signific = connobj.bootstrap(self.__data[self._channels, :, :], Nrep=Nrep,
alpha=alpha, verbose=verbose,
**newparams)
return self.__signific
[docs] def short_time_significance(self, Nrep=100, alpha=0.05, nfft=None,
no=None, verbose=True, **params):
'''
Statistical significance values of short-time version of
connectivity estimation method.
Args:
*Nrep* = 100 : int
number of resamples
*alpha* = 0.05 : float
type I error rate (significance level)
*nfft* = None : int
number of data points in window; if None, it is taken from
:func:`Data.short_time_conn` method.
*no* = None : int
number of data points in overlap; if None, it is taken from
*short_time_conn* method.
*verbose* = True : bool
if True it prints dot on every realization
Returns:
*signi*: numpy.array
matrix in shape of (k, k) with values for each pair of
channels
'''
if not nfft:
nfft = self._parameters["nfft"]
if not no:
no = self._parameters["no"]
connobj = conn_estim_dc[self._parameters["shorttime"]]()
self._parameters.update(params)
arg = inspect.getargspec(connobj.calculate)
newparams = self.__make_params_dict(arg[0])
if self.__multitrial:
temp_dat = self.__data[self._channels, :, :]
else:
temp_dat = self.__data[self._channels, :]
if isinstance(connobj, ConnectAR):
self.__st_signific = connobj.short_time_significance(temp_dat, Nrep=Nrep, alpha=alpha,
method=self._parameters["mvarmethod"],
fs=self.__fs, order=self._parameters["p"],
nfft=nfft, no=no, verbose=verbose, **newparams)
else:
self.__st_signific = connobj.short_time_significance(temp_dat, Nrep=Nrep,
nfft=nfft, no=no,
alpha=alpha, verbose=verbose, **newparams)
return self.__st_signific
[docs] def plot_data(self, trial=0, show=True):
'''
Plot data in a subplot for each channel.
Args:
*trial* = 0 : int
if there is multichannel data it should be a number of trial
you want to plot.
*show* = True : boolean
show the plot or not
'''
time = np.arange(0, self.__length)*1./self.__fs
if self.__multitrial:
plotdata = self.__data[self._channels, :, trial]
else:
plotdata = self.__data[self._channels, :]
if self.__chans_number>10:
warnings.warn("""Number of channels > 10.
Consider picking only some channels.""", Warning)
fig, axes = plt.subplots(self.__chans_number, 1)
for i in np.arange(self.__chans_number):
axes[i].plot(time, plotdata[i, :], 'g')
if self.__channames:
axes[i].set_title(self.__channames[i])
if show:
plt.show()
[docs] def plot_conn(self, name='', ylim=None, xlim=None, signi=True,
show=True):
'''
Plot connectivity estimation results.
Args:
*name* = '' : str
title of the plot
*ylim* = None : list
range of y-axis values shown, e.g. [0,1]
*None* means that default values of given estimator are taken
into account
*xlim* = None : list [from (int), to (int)]
range of y-axis values shown, if None it is from 0 to Nyquist frequency
*signi* = True : boolean
if significance levels are calculated they are shown in the plot
*show* = True : boolean
show the plot or not
'''
assert hasattr(self, '_Data__estim') is True, "No valid data!, Use calculation method first."
fig, axes = plt.subplots(self.__chans_number, self.__chans_number)
freqs = np.linspace(0, self.__fs//2, self.__estim.shape[0])
if not xlim:
xlim = [0, np.max(freqs)]
two_sides = False
if signi and hasattr(self, '_Data__signific'):
flag_sig = True
if self.__signific.ndim > 2:
two_sides = True
else:
flag_sig = False
if not ylim:
ylim = self._parameters["y_lim"]
if ylim[0] is None:
ylim[0] = np.min(self.__estim)
if flag_sig:
ylim[0] = np.min((ylim[0], np.min(self.__signific)))
if ylim[1] is None:
ylim[1] = np.max(self.__estim)
if flag_sig:
ylim[1] = np.max((ylim[1], np.max(self.__signific)))
# selecting right channels
if self.__estim.shape[-1] != self.__chans_number:
estim = self.__estim[:, [[c] for c in self._channels], self._channels]
if two_sides:
signific = self.__signific[:, [[c] for c in self._channels], self._channels]
else:
signific = self.__signific[[[c] for c in self._channels], self._channels]
else:
estim = self.__estim
signific = self.__signific
# plotting loop
for i in np.arange(self.__chans_number):
for j in np.arange(self.__chans_number):
if self.__channames and i == 0:
axes[i, j].set_title(self.__channames[j]+" >", fontsize=12)
if self.__channames and j == 0:
axes[i, j].set_ylabel(self.__channames[i])
axes[i, j].fill_between(freqs, estim[:, i, j], 0)
if flag_sig:
if two_sides:
l_u = axes[i, j].axhline(y=signific[0, i, j], color='r')
l_d = axes[i, j].axhline(y=signific[1, i, j], color='r')
else:
l = axes[i, j].axhline(y=signific[i, j], color='r')
axes[i, j].set_xlim(xlim)
axes[i, j].set_ylim(ylim)
if i != self.__chans_number-1:
axes[i, j].get_xaxis().set_visible(False)
if j != 0:
axes[i, j].get_yaxis().set_visible(False)
plt.suptitle(name, y=0.98)
plt.tight_layout()
plt.subplots_adjust(top=0.92)
if show:
plt.show()
[docs] def plot_short_time_conn(self, name='', signi=True, percmax=1.,
show=True):
'''
Plot short-time version of estimation results.
Args:
*name* = '' : str
title of the plot
*signi* = True : boolean
reset irrelevant values; it works only after short time
significance calculation using *short_time_significance*
*percmax* = 1. : float (0,1)
percent of maximal value which is maximum on the color map
*show* = True : boolean
show the plot or not
'''
assert hasattr(self, '_Data__shtimest') == True, "No valid data! Use calculation method first."
shtvalues = self.__shtimest
# selecting right channels
flag_channels_changed = False
if shtvalues.shape[-1] != self.__chans_number:
shtvalues = shtvalues[:, :, [[c] for c in self._channels], self._channels]
flag_channels_changed = True
# masking values if unsignificant
if signi and hasattr(self, '_Data__st_signific'):
if self.__st_signific.ndim > 3:
if flag_channels_changed:
shtvalues = self.fill_nans(shtvalues, self.__st_signific[:, 0, self._channels, self._channels])
shtvalues = self.fill_nans(shtvalues, self.__st_signific[:, 1, self._channels, self._channels])
else:
shtvalues = self.fill_nans(shtvalues, self.__st_signific[:, 0, :, :])
shtvalues = self.fill_nans(shtvalues, self.__st_signific[:, 1, :, :])
else:
if flag_channels_changed:
shtvalues = self.fill_nans(shtvalues, self.__st_signific[:, self._channels, self._channels])
else:
shtvalues = self.fill_nans(shtvalues, self.__st_signific)
fig, axes = plt.subplots(self.__chans_number, self.__chans_number)
# currently not used:
# freqs = np.linspace(0, self.__fs//2, 4)
# time = np.linspace(0, self.__length/self.__fs, 5)
# ticks_time = [0, self.__fs//2]
# ticks_freqs = [0, self.__length//self.__fs]
# mask diagonal values to not contaminate the plot
mask = np.zeros(shtvalues.shape)
for i in range(self.__chans_number):
mask[:, :, i, i] = 1
masked_shtimest = np.ma.array(shtvalues, mask=mask)
dtmax = np.nanmax(masked_shtimest)*percmax
dtmin = np.nanmin(masked_shtimest)
cmap = plt.get_cmap('rainbow')
cmap.set_bad(color='w', alpha=1)
for i in np.arange(self.__chans_number):
for j in np.arange(self.__chans_number):
if self.__channames and i == 0:
axes[i, j].set_title(self.__channames[j]+" >", fontsize=12)
if self.__channames and j == 0:
if i == self.__chans_number//2:
axes[i, j].set_ylabel("f [Hz]\n" + self.__channames[i])
else:
axes[i, j].set_ylabel(self.__channames[i])
elif j == 0 and i == self.__chans_number//2:
axes[i, j].set_ylabel("f [Hz]")
img = axes[i, j].imshow(shtvalues[:, :, i, j].T, cmap=cmap, aspect='auto',
extent=[0, self.__length/self.__fs, 0, self.__fs//2],
interpolation='none', origin='lower',
vmin=dtmin, vmax=dtmax)
if i != self.__chans_number-1:
axes[i, j].get_xaxis().set_visible(False)
else:
labels = axes[i, j].get_xticklabels()
for label in labels[::2]:
label.set_visible(False)
if j == self.__chans_number//2:
axes[i, j].set_xlabel("time [s]")
if j != 0:
axes[i, j].get_yaxis().set_visible(False)
else:
labels = axes[i, j].get_yticklabels()
for label in labels[::2]:
label.set_visible(False)
# xt = np.array(axes[i, j].get_xticks())/self.__fs
plt.suptitle(name, y=0.98)
plt.tight_layout()
fig.subplots_adjust(top=0.92, right=0.91, wspace=0.05, hspace=0.05)
axes
cbar_ax = fig.add_axes([0.93, 0.1, 0.02, 0.7])
cbar_ax.tick_params(labelsize=10)
fig.colorbar(img, cax=cbar_ax)
if show:
plt.show()
[docs] def export_trans3d(self, mod=0, filename='conntrans3d.dat', freq_band=[]):
'''
Export connectivity data to trans3D data file in order to make
3D arrow plots.
Args:
*mod* = 0 : int
0 - :func:`Data.conn` results
1 - :func:`Data.short_time_conn` results
*filename* = 'conn_trnas3d.dat' : str
title of the plot
*freq_band* = [] : list
frequency range [from_value, to_value] in Hz.
'''
content = ";electrodes = " + " ".join(self.__channames)
content += "\r\n;start = -0.500000\r\n"
content += ";samplerate = 12\r\n"
# two following lines define initial position of head model
content += ";transform_default = 1 0 0 0 0 1 0 0 0 0 1 -8 0 0 0 1\r\n"
content += ";transform = 0.005148315144289768 0.007407087180879943 -3.4522594293647604 0.0 -2.727691946877633 2.116098522619611 0.000472475639 0.0 2.1160923126171305 2.7276819307525173 0.009008154977586572 -8.0 0.000000 0.000000 0.000000 1.000000\r\n"
content += "\r\n"
# integrate value of estimator in given frequency band
freqs = np.linspace(0, int(self.__fs/2), self.__estim.shape[0])
if len(freq_band) == 0:
ind1 = 0
ind2 = len(freqs)
else:
ind1 = np.where(freqs >= freq_band[0])[0][0]
ind2 = np.where(freqs >= freq_band[1])[0][0]
if mod == 0:
assert hasattr(self, '_Data__estim') is True, "No valid data! Use calculation method first."
cnest = np.mean(self.__estim[ind1:ind2, :, :], axis=0)
for i in range(self.__chans_number):
content += " " + " ".join(['{:.4f}'.format(x) for x in cnest[i]]) + "\r\n"
elif mod == 1:
assert hasattr(self, '_Data__shtimest') is True, "No valid data! Use calculation method first."
for k in range(self.__shtimest.shape[0]):
cnest = np.mean(self.__shtimest[k, ind1:ind2, :, :], axis=0)
for i in range(self.__chans_number):
content += " " + " ".join(['{:.4f}'.format(x) for x in cnest[i]]) + "\r\n"
content += "\r\n"
with open(filename, 'wb') as fl:
fl.write(content)
# auxiliary methods:
def __make_params_dict(self, args):
"""
Making list of parameters from *self._parameters*
Args:
*args* : list
list with parameters of *calculate* method of specific
estimator
Returns:
*newparams* : dict
dictionary with new parameters
"""
newparams = {}
for ag in args[1:]:
if ag in ['data']:
continue
if ag in self._parameters:
newparams[ag] = self._parameters[ag]
return newparams
[docs] def fill_nans(self, values, borders):
'''
Fill nans where *values* < *borders* (independent of frequency).
Args:
*values* : numpy.array
array of shape (time, freqs, channels, channels) to fill nans
*borders* : numpy.array
array of shape (time, channels, channels) with limes
values
Returns:
*values_nans* : numpy.array
array of shape (time, freq, channels, channels) with nans
where values were less than appropieate value from *borders*
'''
tm, fr, k, k = values.shape
for i in range(fr):
values[:, i, :, :][values[:, i, :, :] < borders] = np.nan
return values
# accessors:
@property
def mvar_coefficients(self):
"Returns mvar coefficients if calculated"
if hasattr(self, '_Data__Ar') and hasattr(self, '_Data__Vr'):
return (self.__Ar, self.__Vr)
else:
return (None, None)
@property
def mvarcoef(self):
"Returns mvar coefficients if calculated"
return self.mvar_coefficients
@property
def data(self):
return self.__data[self._channels]
@property
def fs(self):
return self.__fs
@property
def srate(self):
return self.__fs
@property
def channelnames(self):
return self.__channames
@property
def channels(self):
return self._channels
@property
def channelsnr(self):
return self.__chans_number