Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/spikeinterface/widgets/unit_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
**unitwaveformdensitymapwidget_kwargs,
)
col_counter += 1
ax_waveform_density.set_xlabel(None)
ax_waveform_density.set_ylabel(None)

if sorting_analyzer.has_extension("correlograms"):
Expand Down
58 changes: 23 additions & 35 deletions src/spikeinterface/widgets/unit_waveforms_density_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def __init__(
templates = ext_templates.get_templates(unit_ids=unit_ids)
bin_min = np.min(templates) * 1.3
bin_max = np.max(templates) * 1.3
bin_size = (bin_max - bin_min) / 100
bins = np.arange(bin_min, bin_max, bin_size)
num_bins = 100
bins = np.linspace(bin_min, bin_max, num_bins + 1)

# 2d histograms
if same_axis:
Expand Down Expand Up @@ -121,14 +121,9 @@ def __init__(
wfs = wfs_

# make histogram density
wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1)
hist2d = np.zeros((wfs_flat.shape[1], bins.size))
indexes0 = np.arange(wfs_flat.shape[1])

wf_bined = np.floor((wfs_flat - bin_min) / bin_size).astype("int32")
wf_bined = wf_bined.clip(0, bins.size - 1)
for d in wf_bined:
hist2d[indexes0, d] += 1
wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1) # num_spikes x (num_channels * timepoints)
hists_per_timepoint = [np.histogram(one_timepoint, bins=bins)[0] for one_timepoint in wfs_flat.T]
hist2d = np.stack(hists_per_timepoint)

if same_axis:
if all_hist2d is None:
Expand Down Expand Up @@ -162,60 +157,53 @@ def __init__(
bin_min=bin_min,
bin_max=bin_max,
all_hist2d=all_hist2d,
sampling_frequency=sorting_analyzer.sampling_frequency,
templates_flat=templates_flat,
template_width=wfs.shape[1],
)

BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

def plot_matplotlib(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
from .utils_matplotlib import make_mpl_figure

dp = to_attr(data_plot)

if backend_kwargs["axes"] is not None or backend_kwargs["ax"] is not None:
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
else:
if dp.same_axis:
num_axes = 1
else:
num_axes = len(dp.unit_ids)
if backend_kwargs["axes"] is None and backend_kwargs["ax"] is None:
backend_kwargs["ncols"] = 1
backend_kwargs["num_axes"] = num_axes
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
backend_kwargs["num_axes"] = 1 if dp.same_axis else len(dp.unit_ids)
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

freq_khz = dp.sampling_frequency / 1000 # samples / msec
if dp.same_axis:
ax = self.ax
hist2d = dp.all_hist2d
im = ax.imshow(
x_max = len(hist2d) / freq_khz # in milliseconds
self.ax.imshow(
hist2d.T,
interpolation="nearest",
origin="lower",
aspect="auto",
extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max),
extent=(0, x_max, dp.bin_min, dp.bin_max),
cmap="hot",
)
else:
for unit_index, unit_id in enumerate(dp.unit_ids):
for ax, unit_id in zip(self.axes.flatten(), dp.unit_ids):
hist2d = dp.all_hist2d[unit_id]
ax = self.axes.flatten()[unit_index]
im = ax.imshow(
x_max = len(hist2d) / freq_khz # in milliseconds
ax.imshow(
hist2d.T,
interpolation="nearest",
origin="lower",
aspect="auto",
extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max),
extent=(0, x_max, dp.bin_min, dp.bin_max),
cmap="hot",
)

for unit_index, unit_id in enumerate(dp.unit_ids):
if dp.same_axis:
ax = self.ax
else:
ax = self.axes.flatten()[unit_index]
ax = self.ax if dp.same_axis else self.axes.flatten()[unit_index]
color = dp.unit_colors[unit_id]
ax.plot(dp.templates_flat[unit_id], color=color, lw=1)
x = np.arange(len(dp.templates_flat[unit_id])) / freq_khz
ax.plot(x, dp.templates_flat[unit_id], color=color, lw=1)

# final cosmetics
for unit_index, unit_id in enumerate(dp.unit_ids):
Expand All @@ -228,11 +216,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
chan_inds = dp.channel_inds[unit_id]
for i, chan_ind in enumerate(chan_inds):
if i != 0:
ax.axvline(i * dp.template_width, color="w", lw=3)
ax.axvline(i * dp.template_width / freq_khz, color="w", lw=3)
channel_id = dp.channel_ids[chan_ind]
x = i * dp.template_width + dp.template_width // 2
x = (i + 0.5) * dp.template_width / freq_khz
y = (dp.bin_max + dp.bin_min) / 2.0
ax.text(x, y, f"chan_id {channel_id}", color="w", ha="center", va="center")

ax.set_xticks([])
ax.set_xlabel("Time [ms]")
ax.set_ylabel(f"unit_id {unit_id}")