Skip to content
203 changes: 174 additions & 29 deletions neuroanalysis/event_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,45 +95,175 @@ def zero_crossing_events(data, min_length=3, min_peak=0.0, min_sum=0.0, noise_th

return events

def deal_unbalanced_initial_off(omit_ends, on_inds, off_inds):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Start function names with _ if you don't intend then to be used externally.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do.

"""Deals with situation where there is an "off" crossing from above to below threshold
at the beginning of a trace without there first being an "on" crossing from below to above
threshold. Note that the usage of this function is looking for extreme regions
where a trace is below a negative threshold or above a positive threshold, thus, the
sign of the trace value at *on_inds* and *off_inds* can be positive or negative
"""
if not omit_ends:
on_inds = [0] + on_inds #prepend the edge as on on ind
else:
off_inds = off_inds[1:] #remove the off ind
return on_inds, off_inds

def deal_unbalanced_termination_on(omit_ends, on_inds, off_inds, off_to_add):
if not omit_ends:
off_inds = off_inds + [off_to_add] #append the index of the last data point
else:
on_inds = on_inds[:-1] #remove the last on indicie
return on_inds, off_inds


def threshold_events(trace, threshold, adjust_times=True, baseline=0.0, omit_ends=True):
def threshold_events(trace, threshold, adjust_times=True, baseline=0.0, omit_ends=True, debug=False):
"""
Finds regions in a trace that cross a threshold value (as measured by distance from baseline). Returns the index, length, peak, and sum of each event.
Optionally adjusts index to an extrapolated baseline-crossing.
Finds regions in a trace that cross a threshold value (as measured by distance from baseline) and then
recross threshold (bumps). If a threshold is crossed at the end of the trace, an event may be excluded
or the beginning/end may be used as the the start/end of the event (depending on the value of *omit_ends*).
Optionally adjusts start and end index of an event to an extrapolated baseline-crossing and calculates
values.

Parameters
==========
trace: *Trace* instance
threshold: float or np.array with dimensions of *trace.data*
algorithm checks if waveform crosses both positive and negative thresholds.
i.e. if -5. is provided, the algorithm looks for places where the waveform crosses +/-5.
If an array is provided the *threshold* is dynamic

adjust_times: boolean
if True, move the start and end times of the event outward, estimating the zero-crossing point for the event

Returns
=======
events: numpy structured array.
An event is a region of the *Trace.data* waveform that crosses above *threshold* and then falls below threshold again
Sometimes referred to as a 'bump'. There are additional criteria (not listed here) for a bump to be considered an event.
Each index contains information about an event. Fields as follows:
index: int
index of the initial crossing of the *threshold*
len: int
index length of the event
sum: float
sum of the values in the array between the start and end of the event
peak: float
peak value of event
peak_index: int
index value of the peak
time: float, or np.nan if timing not available
time of the onset of an event
duration: float, or np.nan if timing not available
duration of time of the event
area: float, or np.nan if timing not available
area under the curve of the event
peak_time: float, or np.nan if timing not available
time of peak
"""
threshold = abs(threshold)


data = trace.data
data1 = data - baseline
# convert threshold array
if isinstance(threshold, float):
threshold = np.ones(len(data)) * abs(threshold)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why convert to array? this just slows down the comparisons that come later.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made it so that the input threshold could be an array so that it doesn't need to be constant (see line 130). For current clamp I alter the threshold across the duration of the stimulation pulse. (See spike_detection.py line 154)



#if (hasattr(data, 'implements') and data.implements('MetaArray')):

## find all threshold crossings
masks = [(data1 > threshold).astype(np.byte), (data1 < -threshold).astype(np.byte)]
## find all positive and negative threshold crossings of baseline adjusted data

if debug:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not so keen on merging debugging code unless there's a good reason to keep it here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I figured, you would want it removed, but I didn't want to do it before you looked at the rest of the code incase I needed to look at stuff again. The debug plots are particularly insightful for your questions below about removing the +1 on line 191. Once you are happy it can be removed before merge along with the 'debug' argument into the functions.

import pdb; pdb.set_trace()
# FYI: can't use matplot lib before debugger is on
# type *continue to see plot
import matplotlib.pyplot as mpl


masks = [(data1 > threshold).astype(np.byte), (data1 < -threshold).astype(np.byte)] # 1 where data is [above threshold, below negative threshold]
hits = []
for mask in masks:
diff = mask[1:] - mask[:-1]
on_inds = list(np.argwhere(diff==1)[:,0] + 1)
off_inds = list(np.argwhere(diff==-1)[:,0] + 1)
if len(on_inds) == 0 or len(off_inds) == 0:
continue
if off_inds[0] < on_inds[0]:
if omit_ends:
off_inds = off_inds[1:]
if len(off_inds) == 0:
continue
else:
on_inds.insert(0, 0)
if off_inds[-1] < on_inds[-1]:
if omit_ends:
on_inds = on_inds[:-1]
else:
off_inds.append(len(diff))
diff = mask[1:] - mask[:-1] # -1 (or +1) when crosses from above to below threshold (or visa versa if threshold is negative). Note above threshold refers to value furthest from zero, i.e. it can be positive or negative

# TODO: It might be a good idea to make offindexes one less so that region above threshold looks symmetrical
# find start and end inicies (above threshold) where waveform is above threshold
on_inds = list(np.argwhere(diff==1)[:,0] + 1) #where crosses from below to above threshold Note taking index after this happens
off_inds = list(np.argwhere(diff==-1)[:,0]) #where crosses from above to below threshold. Note taking index before this happens

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you remove the +1 here? The idea is you could take data[on_ind:off_ind] and get exactly the region above threshold.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The last index was actually below the threshold and I only wanted to use indexes above threshold. The +1 includes an area and a value below threshold in the sum and the area.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will also remove the #TODO on 188 as I did implement that here.


if debug:
mpl.figure()
mpl.plot(data1, '.-')
for on in on_inds:
mpl.axvline(x=on, color='g', linestyle='--')
for off in off_inds:
mpl.axvline(x=off, color='r', linestyle='--')
mpl.plot(threshold, color='y', linestyle='--')
mpl.plot(-threshold, color='y', linestyle='--')
mpl.show(block=False)



# sometimes an event happens at the beginning of the pulse window and the trace hasn't

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a generic function for detecting threshold crossings; not a good place for comments that require an understanding of your specific use case.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can make this statement more generic. But this if statement is here for the case when an event happens at the beginning.

# been able to drop below threshold because it hasn't recovered from the artifact.
if len(off_inds) > 0: #if there are some off indicies
if len(on_inds) > 0: #and there are also on indicies
if on_inds[0] > off_inds[0]: #check if off happens before on
on_inds, off_inds = deal_unbalanced_initial_off(omit_ends, on_inds, off_inds)
else: #there are no on indicies
on_inds, off_inds = deal_unbalanced_initial_off(omit_ends, on_inds, off_inds)

if len(on_inds) > 0: #if there are some on indicies
if len(off_inds) > 0: #and there are also off indicies
if on_inds[-1] > off_inds[-1]: #check if off happens before on
on_inds, off_inds = deal_unbalanced_termination_on(omit_ends, on_inds, off_inds, len(data1))
else: #there are no off indicies
on_inds, off_inds = deal_unbalanced_termination_on(omit_ends, on_inds, off_inds, len(data1))

# this is insufficient because it ignores when there are just off or just on
# not sure why this is here if haven't decided whether or not to omit ends
# if len(on_inds) == 0 or len(off_inds) == 0:
# continue
# ## if there are unequal number of crossing from one direction, either remove the ends or add the appropriate initial or end index (which will be the beginning or end of the waveform)
# if off_inds[0] < on_inds[0]:
# if omit_ends:
# off_inds = off_inds[1:]
# if len(off_inds) == 0:
# continue
# else:
# on_inds.insert(0, 0)
# if off_inds[-1] < on_inds[-1]:
# if omit_ends:
# on_inds = on_inds[:-1]
# else:
# off_inds.append(len(diff))

# Add both events above +threshold and those below -threshold to a list
for i in range(len(on_inds)):
if on_inds[i] == off_inds[i]:

# remove any point where an on off is seperated by less than 2 indicies
if (on_inds[i] + 1) == off_inds[i]:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? Again, this is a generic threshold-crossing function; we shouldn't make assumptions about the use case here.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was useful in the downstream stream code, I can pull it out and put in the other function.

continue
if (on_inds[i] - 1) == off_inds[i]:
continue
if on_inds[i] == off_inds[i]:
continue

hits.append((on_inds[i], off_inds[i]))

## sort hits ## NOTE: this can be sped up since we already know how to interleave the events..
hits.sort(key=lambda a: a[0])
if debug is True:
# FYI: can't use matplot lib before debugger is on
# type *continue to see plot
mpl.figure()
mpl.title('first round before adjustment')
mpl.plot(data1, '.-')
for hit in hits:
mpl.axvline(x=hit[0], color='g', linestyle='--')
mpl.axvline(x=hit[1], color='r', linestyle='--')
mpl.plot(threshold, color='y', linestyle='--')
mpl.plot(-threshold, color='y', linestyle='--')
mpl.show(block=False)

n_events = len(hits)
events = np.empty(n_events, dtype=[
Expand Down Expand Up @@ -162,18 +292,18 @@ def threshold_events(trace, threshold, adjust_times=True, baseline=0.0, omit_end
else:
peak_ind = np.argmin(ev_data)
peak = ev_data[peak_ind]
peak_ind += ind1
peak_ind += ind1 # adjust peak_ind from local event data to entire waveform

#print "event %f: %d" % (xvals[ind1], ind1)
if adjust_times: ## Move start and end times outward, estimating the zero-crossing point for the event

## adjust ind1 first
mind = np.argmax(ev_data)
pdiff = abs(peak - ev_data[0])
if pdiff == 0:
mind = np.argmax(ev_data) # max of whole trace
pdiff = abs(peak - ev_data[0]) # find the how high the peak is from the front of event
if pdiff == 0:
adj1 = 0
else:
adj1 = int(threshold * mind / pdiff)
adj1 = int(threshold * mind / pdiff) # (max value of whole trace)* 1/(hight of peak from first data point)
adj1 = min(ln, adj1)
ind1 -= adj1

Expand Down Expand Up @@ -233,6 +363,18 @@ def threshold_events(trace, threshold, adjust_times=True, baseline=0.0, omit_end

## remove masked events
events = events[mask]

if debug:
mpl.figure()
mpl.title('adjusted')
mpl.plot(data1, '.-')
for on in events['index']:
mpl.axvline(x=on, color='g', linestyle='--')
# for off in off_inds:
# mpl.axvline(x=off, color='r', linestyle='--')
mpl.axhline(y=threshold, color='y', linestyle='--')
mpl.axhline(y=-threshold, color='y', linestyle='--')
mpl.show(block=False)

# add in timing information if available:
if trace.has_timing:
Expand All @@ -249,6 +391,9 @@ def threshold_events(trace, threshold, adjust_times=True, baseline=0.0, omit_end
ev['area'] = np.nan
ev['peak_time'] = np.nan


if debug:
pdb.set_trace()
return events


Expand Down
Loading