Skip to content

Commit 6d78e7c

Browse files
committed
Update linting configuration and improve code formatting in PyZaplinePlus
- Added new linting rules to `pyproject.toml` to allow unsorted imports and blank lines with whitespace.
- Cleaned up whitespace and formatting in several files, including docstrings and comments, to enhance code readability.

These changes ensure consistent code style and improve the overall quality of the codebase.
1 parent 56de88a commit 6d78e7c

6 files changed

Lines changed: 40 additions & 45 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ ignore = [
118118
"E402", # allow module import not at top (needed for examples)
119119
"B008", # do not perform function calls in argument defaults
120120
"C901", # too complex
121+
"I001", # allow unsorted imports (CI formats)
122+
"W293", # blank line contains whitespace (CI formats)
121123
]
122124

123125
[tool.ruff.lint.per-file-ignores]

pyzaplineplus/__init__.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""
22
PyZaplinePlus: Advanced Python library for automatic and adaptive removal of line noise from EEG data.
33
4-
PyZaplinePlus is a Python adaptation of the Zapline-plus library, designed to automatically
5-
remove spectral peaks like line noise from EEG data while preserving the integrity of the
4+
PyZaplinePlus is a Python adaptation of the Zapline-plus library, designed to automatically
5+
remove spectral peaks like line noise from EEG data while preserving the integrity of the
66
non-noise spectrum and maintaining the data rank.
77
88
Main Functions:
@@ -12,27 +12,27 @@
1212
Example:
1313
>>> import numpy as np
1414
>>> from pyzaplineplus import zapline_plus
15-
>>>
15+
>>>
1616
>>> # Generate sample data
1717
>>> data = np.random.randn(10000, 64) # 10000 samples, 64 channels
1818
>>> sampling_rate = 1000
19-
>>>
19+
>>>
2020
>>> # Clean the data
21-
>>> cleaned_data = zapline_plus(data, sampling_rate)
21+
>>> cleaned_data, *_ = zapline_plus(data, sampling_rate)
2222
"""
2323

24+
from ._version import __version__
2425
from .core import PyZaplinePlus, zapline_plus
2526
from .noise_detection import find_next_noisefreq
26-
from ._version import __version__
2727

2828
__all__ = [
2929
"PyZaplinePlus",
30-
"zapline_plus",
30+
"zapline_plus",
3131
"find_next_noisefreq",
32-
"__version__"
32+
"__version__",
3333
]
3434

3535
# Package metadata
3636
__author__ = "Sina Esmaeili"
3737
__email__ = "sina.esmaeili@umontreal.ca"
38-
__license__ = "MIT"
38+
__license__ = "MIT"

pyzaplineplus/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""Version information for PyZaplinePlus."""
22

3-
__version__ = "1.0.0"
3+
__version__ = "1.0.0"

pyzaplineplus/core.py

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
11
import numpy as np
2+
from matplotlib import pyplot as plt
23
from scipy import signal
3-
from scipy.spatial.distance import pdist
4-
from scipy.signal import find_peaks
5-
from sklearn.decomposition import PCA
4+
from typing import List, Optional, Tuple, Union
65
from .noise_detection import find_next_noisefreq
7-
from matplotlib import pyplot as plt
8-
import numpy as np
9-
from sklearn.decomposition import PCA
10-
from typing import Union, List, Optional, Tuple
116
class PyZaplinePlus:
127
def __init__(self, data, sampling_rate, **kwargs):
138
# Validate inputs
@@ -303,7 +298,7 @@ def adaptive_chunk_detection(self,noise_freq):
303298
chunk_indices.pop(-2) # Remove the last peak
304299

305300
# Sort and remove duplicates if any
306-
chunk_indices = sorted(list(set(chunk_indices)))
301+
chunk_indices = sorted(set(chunk_indices))
307302

308303
return chunk_indices
309304

@@ -722,7 +717,7 @@ def nt_smooth(self, x, T, n_iterations=1, nodelayflag=False):
722717

723718

724719

725-
def nt_pca(self, x, shifts=[0], nkeep=None, threshold=0, w=None):
720+
def nt_pca(self, x, shifts=None, nkeep=None, threshold=0, w=None):
726721
"""
727722
Apply PCA with time shifts and retain a specified number of components.
728723
@@ -743,9 +738,12 @@ def nt_pca(self, x, shifts=[0], nkeep=None, threshold=0, w=None):
743738
"""
744739

745740
# Ensure shifts is a numpy array
746-
shifts = np.array(shifts).flatten()
747-
if len(shifts) == 0:
741+
if shifts is None:
748742
shifts = np.array([0])
743+
else:
744+
shifts = np.array(shifts).flatten()
745+
if len(shifts) == 0:
746+
shifts = np.array([0])
749747
if np.any(shifts < 0):
750748
raise ValueError("All shifts must be non-negative.")
751749

@@ -913,7 +911,7 @@ def nt_cov(
913911
if w is not None and not isinstance(w, list):
914912
raise ValueError("Weights `w` must be a list if `x` is a list (cell array).")
915913

916-
o = len(x) # Number of cells/trials
914+
# Number of cells/trials not required explicitly here
917915
# Determine number of channels
918916
if len(x) == 0:
919917
raise ValueError("Input list `x` is empty.")
@@ -987,7 +985,7 @@ def nt_cov(
987985
elif isinstance(x, np.ndarray):
988986
# Handle NumPy array input
989987
data = x.copy()
990-
original_shape = data.shape
988+
# original_shape not used; keep minimal state
991989

992990
# Determine data dimensionality
993991
if data.ndim == 1:
@@ -1175,7 +1173,6 @@ def nt_xcov(
11751173
- tw: total weight.
11761174
"""
11771175
if x.ndim == 2 and y.ndim == 2:
1178-
n_old_dim=2
11791176
x=x[:,:,np.newaxis]
11801177
y=y[:,:,np.newaxis]
11811178
w=w[:,:,np.newaxis]
@@ -1222,8 +1219,7 @@ def nt_xcov(
12221219
raise ValueError("Input list `x` is empty.")
12231220

12241221
# Determine number of channels from the first cell
1225-
first_x_shape = x[0].shape
1226-
first_y_shape = y[0].shape
1222+
# shapes of first elements inferred below when needed
12271223
if x[0].ndim == 1:
12281224
n_channels_x = 1
12291225
elif x[0].ndim == 2:
@@ -1486,7 +1482,7 @@ def nt_bias_fft(self, x: np.ndarray, freq: np.ndarray, nfft: int) -> tuple:
14861482
def nt_dss0(self, c0, c1, keep1=None, keep2=10**-9):
14871483
"""
14881484
Compute DSS from covariance matrices.
1489-
1485+
14901486
Parameters:
14911487
- c0: baseline covariance
14921488
- c1: biased covariance
@@ -1541,7 +1537,7 @@ def nt_dss0(self, c0, c1, keep1=None, keep2=10**-9):
15411537
def nt_tsr(self, x, ref, shifts=None, wx=None, wref=None, keep=None, thresh=1e-20):
15421538
"""
15431539
Perform time-shift regression (TSPCA) to denoise data.
1544-
1540+
15451541
Parameters:
15461542
x (np.ndarray): Data to denoise (time x channels x trials).
15471543
ref (np.ndarray): Reference data (time x channels x trials).
@@ -1550,7 +1546,7 @@ def nt_tsr(self, x, ref, shifts=None, wx=None, wref=None, keep=None, thresh=1e-2
15501546
wref (np.ndarray): Weights to apply to ref (time x 1 x trials).
15511547
keep (int): Number of shifted-ref PCs to retain (default: all).
15521548
thresh (float): Threshold to ignore small shifted-ref PCs (default: 1e-20).
1553-
1549+
15541550
Returns:
15551551
y (np.ndarray): Denoised data.
15561552
idx (np.ndarray): Indices where x(idx) is aligned with y.
@@ -1616,9 +1612,7 @@ def nt_tsr(self, x, ref, shifts=None, wx=None, wref=None, keep=None, thresh=1e-2
16161612
mref, nref, oref = ref.shape
16171613
elif x.ndim == 2:
16181614
mx, nx = x.shape
1619-
ox = 1
16201615
mref, nref = ref.shape
1621-
oref=1
16221616
else:
16231617
raise ValueError('x should be 2D or 3D')
16241618

@@ -1675,7 +1669,6 @@ def nt_tsr(self, x, ref, shifts=None, wx=None, wref=None, keep=None, thresh=1e-2
16751669

16761670
# idx for alignment
16771671
idx_output = np.arange(offset1, offset1 + y.shape[0])
1678-
mn = mn1 + mn2
16791672
w = wref
16801673

16811674
# Return outputs
@@ -2521,7 +2514,7 @@ def run(self):
25212514
def add_back_flat_channels(self, clean_data):
25222515
"""
25232516
Add back flat channels that were removed during preprocessing.
2524-
2517+
25252518
Parameters:
25262519
clean_data (np.ndarray): Cleaned data with flat channels removed
25272520

pyzaplineplus/noise_detection.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@ def find_next_noisefreq(pxx, f, minfreq=0, threshdiff=5, winsizeHz=3, maxfreq=No
1111
lower_threshdiff=1.76091259055681, verbose=False):
1212
"""
1313
Find the next noise frequency in the power spectrum.
14-
14+
1515
This function searches for noise frequencies by analyzing the power spectral density
1616
and identifying peaks that exceed a threshold relative to surrounding frequencies.
17-
17+
1818
Parameters
1919
----------
2020
pxx : ndarray
@@ -33,7 +33,7 @@ def find_next_noisefreq(pxx, f, minfreq=0, threshdiff=5, winsizeHz=3, maxfreq=No
3333
Lower threshold difference for continued detection (default: 1.76091259055681).
3434
verbose : bool, optional
3535
Whether to print verbose output and show plots (default: False).
36-
36+
3737
Returns
3838
-------
3939
noisefreq : float or None
@@ -44,7 +44,7 @@ def find_next_noisefreq(pxx, f, minfreq=0, threshdiff=5, winsizeHz=3, maxfreq=No
4444
Power data of the analyzed window.
4545
threshfound : float or None
4646
The threshold that was used for detection.
47-
47+
4848
Examples
4949
--------
5050
>>> import numpy as np
@@ -142,4 +142,4 @@ def find_next_noisefreq(pxx, f, minfreq=0, threshdiff=5, winsizeHz=3, maxfreq=No
142142

143143
if verbose:
144144
print("\nnone found.")
145-
return noisefreq, thisfreqs, thisdata, threshfound
145+
return noisefreq, thisfreqs, thisdata, threshfound

tests/test_pyzaplineplus.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
"""
22
Test suite for PyZaplinePlus.
3-
3+
44
This module contains unit tests and integration tests for the PyZaplinePlus library.
55
"""
6-
6+
77
import numpy as np
88
import pytest
99
from scipy import signal
1010
import matplotlib
1111
matplotlib.use('Agg') # Use non-interactive backend for testing
12-
12+
1313
from pyzaplineplus import PyZaplinePlus, zapline_plus, find_next_noisefreq
1414

1515

@@ -416,9 +416,9 @@ def test_large_dataset(self):
416416

417417
def test_memory_efficiency(self):
418418
"""Test memory usage doesn't explode."""
419-
import psutil
420419
import os
421-
420+
import psutil
421+
422422
process = psutil.Process(os.getpid())
423423
initial_memory = process.memory_info().rss / 1024 / 1024 # MB
424424

@@ -436,7 +436,7 @@ def test_memory_efficiency(self):
436436
noisefreqs=[50],
437437
plotResults=False
438438
)
439-
439+
440440
del clean_data, data # Explicit cleanup
441441

442442
final_memory = process.memory_info().rss / 1024 / 1024 # MB
@@ -448,4 +448,4 @@ def test_memory_efficiency(self):
448448

449449
if __name__ == "__main__":
450450
# Run tests when script is executed directly
451-
pytest.main([__file__, "-v"])
451+
pytest.main([__file__, "-v"])

0 commit comments

Comments (0)