11import numpy as np
2+ from matplotlib import pyplot as plt
23from scipy import signal
3- from scipy .spatial .distance import pdist
4- from scipy .signal import find_peaks
5- from sklearn .decomposition import PCA
4+ from typing import List , Optional , Tuple , Union
65from .noise_detection import find_next_noisefreq
7- from matplotlib import pyplot as plt
8- import numpy as np
9- from sklearn .decomposition import PCA
10- from typing import Union , List , Optional , Tuple
116class PyZaplinePlus :
127 def __init__ (self , data , sampling_rate , ** kwargs ):
138 # Validate inputs
@@ -303,7 +298,7 @@ def adaptive_chunk_detection(self,noise_freq):
303298 chunk_indices .pop (- 2 ) # Remove the last peak
304299
305300 # Sort and remove duplicates if any
306- chunk_indices = sorted (list ( set (chunk_indices ) ))
301+ chunk_indices = sorted (set (chunk_indices ))
307302
308303 return chunk_indices
309304
@@ -722,7 +717,7 @@ def nt_smooth(self, x, T, n_iterations=1, nodelayflag=False):
722717
723718
724719
725- def nt_pca (self , x , shifts = [ 0 ] , nkeep = None , threshold = 0 , w = None ):
720+ def nt_pca (self , x , shifts = None , nkeep = None , threshold = 0 , w = None ):
726721 """
727722 Apply PCA with time shifts and retain a specified number of components.
728723
@@ -743,9 +738,12 @@ def nt_pca(self, x, shifts=[0], nkeep=None, threshold=0, w=None):
743738 """
744739
745740 # Ensure shifts is a numpy array
746- shifts = np .array (shifts ).flatten ()
747- if len (shifts ) == 0 :
741+ if shifts is None :
748742 shifts = np .array ([0 ])
743+ else :
744+ shifts = np .array (shifts ).flatten ()
745+ if len (shifts ) == 0 :
746+ shifts = np .array ([0 ])
749747 if np .any (shifts < 0 ):
750748 raise ValueError ("All shifts must be non-negative." )
751749
@@ -913,7 +911,7 @@ def nt_cov(
913911 if w is not None and not isinstance (w , list ):
914912 raise ValueError ("Weights `w` must be a list if `x` is a list (cell array)." )
915913
916- o = len ( x ) # Number of cells/trials
914+ # Number of cells/trials not required explicitly here
917915 # Determine number of channels
918916 if len (x ) == 0 :
919917 raise ValueError ("Input list `x` is empty." )
@@ -987,7 +985,7 @@ def nt_cov(
987985 elif isinstance (x , np .ndarray ):
988986 # Handle NumPy array input
989987 data = x .copy ()
990- original_shape = data . shape
988+ # original_shape not used; keep minimal state
991989
992990 # Determine data dimensionality
993991 if data .ndim == 1 :
@@ -1175,7 +1173,6 @@ def nt_xcov(
11751173 - tw: total weight.
11761174 """
11771175 if x .ndim == 2 and y .ndim == 2 :
1178- n_old_dim = 2
11791176 x = x [:,:,np .newaxis ]
11801177 y = y [:,:,np .newaxis ]
11811178 w = w [:,:,np .newaxis ]
@@ -1222,8 +1219,7 @@ def nt_xcov(
12221219 raise ValueError ("Input list `x` is empty." )
12231220
12241221 # Determine number of channels from the first cell
1225- first_x_shape = x [0 ].shape
1226- first_y_shape = y [0 ].shape
1222+ # shapes of first elements inferred below when needed
12271223 if x [0 ].ndim == 1 :
12281224 n_channels_x = 1
12291225 elif x [0 ].ndim == 2 :
@@ -1486,7 +1482,7 @@ def nt_bias_fft(self, x: np.ndarray, freq: np.ndarray, nfft: int) -> tuple:
14861482 def nt_dss0 (self , c0 , c1 , keep1 = None , keep2 = 10 ** - 9 ):
14871483 """
14881484 Compute DSS from covariance matrices.
1489-
1485+
14901486 Parameters:
14911487 - c0: baseline covariance
14921488 - c1: biased covariance
@@ -1541,7 +1537,7 @@ def nt_dss0(self, c0, c1, keep1=None, keep2=10**-9):
15411537 def nt_tsr (self , x , ref , shifts = None , wx = None , wref = None , keep = None , thresh = 1e-20 ):
15421538 """
15431539 Perform time-shift regression (TSPCA) to denoise data.
1544-
1540+
15451541 Parameters:
15461542 x (np.ndarray): Data to denoise (time x channels x trials).
15471543 ref (np.ndarray): Reference data (time x channels x trials).
@@ -1550,7 +1546,7 @@ def nt_tsr(self, x, ref, shifts=None, wx=None, wref=None, keep=None, thresh=1e-2
15501546 wref (np.ndarray): Weights to apply to ref (time x 1 x trials).
15511547 keep (int): Number of shifted-ref PCs to retain (default: all).
15521548 thresh (float): Threshold to ignore small shifted-ref PCs (default: 1e-20).
1553-
1549+
15541550 Returns:
15551551 y (np.ndarray): Denoised data.
15561552 idx (np.ndarray): Indices where x(idx) is aligned with y.
@@ -1616,9 +1612,7 @@ def nt_tsr(self, x, ref, shifts=None, wx=None, wref=None, keep=None, thresh=1e-2
16161612 mref , nref , oref = ref .shape
16171613 elif x .ndim == 2 :
16181614 mx , nx = x .shape
1619- ox = 1
16201615 mref , nref = ref .shape
1621- oref = 1
16221616 else :
16231617 raise ValueError ('x should be 2D or 3D' )
16241618
@@ -1675,7 +1669,6 @@ def nt_tsr(self, x, ref, shifts=None, wx=None, wref=None, keep=None, thresh=1e-2
16751669
16761670 # idx for alignment
16771671 idx_output = np .arange (offset1 , offset1 + y .shape [0 ])
1678- mn = mn1 + mn2
16791672 w = wref
16801673
16811674 # Return outputs
@@ -2521,7 +2514,7 @@ def run(self):
25212514 def add_back_flat_channels (self , clean_data ):
25222515 """
25232516 Add back flat channels that were removed during preprocessing.
2524-
2517+
25252518 Parameters:
25262519 clean_data (np.ndarray): Cleaned data with flat channels removed
25272520
0 commit comments