Skip to content
134 changes: 111 additions & 23 deletions neuroanalysis/event_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,40 +95,128 @@ def zero_crossing_events(data, min_length=3, min_peak=0.0, min_sum=0.0, noise_th

return events

def _deal_unbalanced_initial_off(omit_ends, on_inds, off_inds):
"""Deals with situation where there is an "off" crossing from above to below threshold
at the beginning of a trace without there first being an "on" crossing from below to above
threshold. Note that the usage of this function is looking for extreme regions
where a trace is below a negative threshold or above a positive threshold, thus, the
sign of the trace value at *on_inds* and *off_inds* can be positive or negative
"""
if not omit_ends:
on_inds = [0] + on_inds #prepend the edge as on on ind
else:
off_inds = off_inds[1:] #remove the off ind
return on_inds, off_inds

def _deal_unbalanced_termination_on(omit_ends, on_inds, off_inds, off_to_add):
"""Deals with situation where there is an "on" crossing from below to above threshold
toward the end of a trace without an "off" crossing happening thereafter. Note that
the usage of this function is looking for extreme regions
where a trace is below a negative threshold or above a positive threshold, thus, the
sign of the trace value at *on_inds* and *off_inds* can be positive or negative
"""
if not omit_ends:
off_inds = off_inds + [off_to_add] #append the index of the last data point
else:
on_inds = on_inds[:-1] #remove the last on indicie
return on_inds, off_inds


def threshold_events(trace, threshold, adjust_times=True, baseline=0.0, omit_ends=True):
"""
Finds regions in a trace that cross a threshold value (as measured by distance from baseline). Returns the index, length, peak, and sum of each event.
Optionally adjusts index to an extrapolated baseline-crossing.
Finds regions in a trace that cross a threshold value (as measured by distance from baseline) and then
recross threshold ('bumps'). If a threshold is crossed at the end of the trace, an event may be excluded
or the beginning/end may be used as the the start/end of the event (depending on the value of *omit_ends*).


Parameters
==========
trace: *Tseries* instance
threshold: float or np.array with dimensions of *trace.data*
Algorithm checks if waveform crosses both positive and negative *threshold* symetrically
around from the y-axis. i.e. if -5. is provided, the algorithm looks for places where
the waveform crosses +/-5. If an array is provided, each index of the *threshold* will
be compared with the data pointwise.
adjust_times: boolean
If True, move the start and end times of the event outward, estimating the zero-crossing point for the event
baseline: float
Value subtracted from the data.
omit_ends: boolean
If true, add the trace endpoint indices to incomplete events, i.e., events that started above threhold at the
beginning of trace, or crossed threshold but did not return below threshold at the end of a trace. If false,
remove the imcomplete events.


Returns
=======
events: numpy structured array.
An event ('bump') is a region of the *trace.data* waveform that crosses above *threshold* and then falls below
threshold again. Each index contains information about an event. Fields as follows:
index: int
index of the initial crossing of the *threshold*
len: int
index length of the event
sum: float
sum of the values in the array between the start and end of the event
peak: float
peak value of event
peak_index: int
index value of the peak
time: float, or np.nan if timing not available
time of the onset of an event
duration: float, or np.nan if timing not available
duration of time of the event
area: float, or np.nan if timing not available
area under the curve of the event
peak_time: float, or np.nan if timing not available
time of peak
"""
threshold = abs(threshold)


data = trace.data
data1 = data - baseline
#if (hasattr(data, 'implements') and data.implements('MetaArray')):

## find all threshold crossings
masks = [(data1 > threshold).astype(np.byte), (data1 < -threshold).astype(np.byte)]
# convert threshold array
if isinstance(threshold, float):
threshold = np.ones(len(data)) * abs(threshold)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why convert to array? this just slows down the comparisons that come later.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made it so that the input threshold could be an array so that it doesn't need to be constant (see line 130). For current clamp I alter the threshold across the duration of the stimulation pulse. (See spike_detection.py line 154)


## find all threshold crossings in both positive and negative directions
## deal with imcomplete events, and store events

# -1 (or +1) when crosses from above to below threshold (or visa versa if threshold is negative). Note above threshold refers to value furthest from zero, i.e. it can be positive or negative
masks = [(data1 > threshold).astype(np.byte), (data1 < -threshold).astype(np.byte)]

hits = []
for mask in masks:
diff = mask[1:] - mask[:-1]
on_inds = list(np.argwhere(diff==1)[:,0] + 1)
off_inds = list(np.argwhere(diff==-1)[:,0] + 1)
if len(on_inds) == 0 or len(off_inds) == 0:
continue
if off_inds[0] < on_inds[0]:
if omit_ends:
off_inds = off_inds[1:]
if len(off_inds) == 0:
continue
else:
on_inds.insert(0, 0)
if off_inds[-1] < on_inds[-1]:
if omit_ends:
on_inds = on_inds[:-1]
else:
off_inds.append(len(diff))
# indices where crosses from below to above threshold ('on')
on_inds = list(np.argwhere(diff==1)[:,0] + 1)
# indices where crosses from above to below threshold ('off')
off_inds = list(np.argwhere(diff==-1)[:,0] + 1)

# deal with cases when there are unmatched on and off indicies
if len(off_inds) > 0: #if there are some off indicies
if len(on_inds) > 0: #and there are also on indicies
if on_inds[0] > off_inds[0]: #check if off happens before on
on_inds, off_inds = _deal_unbalanced_initial_off(omit_ends, on_inds, off_inds)
else: #there are no on indicies
on_inds, off_inds = _deal_unbalanced_initial_off(omit_ends, on_inds, off_inds)

if len(on_inds) > 0: #if there are some on indicies
if len(off_inds) > 0: #and there are also off indicies
if on_inds[-1] > off_inds[-1]: #check if off happens before on
on_inds, off_inds = _deal_unbalanced_termination_on(omit_ends, on_inds, off_inds, len(data1))
else: #there are no off indicies
on_inds, off_inds = _deal_unbalanced_termination_on(omit_ends, on_inds, off_inds, len(data1))


# at this point every 'on' should have and 'off'
assert len(on_inds) == len(off_inds)

# put corresponding on and off indeces in a list
for i in range(len(on_inds)):
if on_inds[i] == off_inds[i]:
#something wierd happened
continue
hits.append((on_inds[i], off_inds[i]))

Expand All @@ -154,7 +242,7 @@ def threshold_events(trace, threshold, adjust_times=True, baseline=0.0, omit_end
## 2) adjust event times if requested, then recompute parameters
for i in range(n_events):
ind1, ind2 = hits[i]
ln = ind2-ind1
ln = ind2 - ind1
ev_data = data1[ind1:ind2]
sum = ev_data.sum()
if sum > 0:
Expand Down
Loading