Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,6 @@ repos:
name: Flake8 linting

- repo: https://github.com/PyCQA/docformatter
rev: v1.7.7
rev: v1.7.8-rc1
hooks:
- id: docformatter
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
</div>
<a href="https://trendshift.io/repositories/6007" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6007" alt="kadirnar%2Fwhisper-plus | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>

<div>
<div>
<a href="https://pypi.org/project/whisperplus" target="_blank">
<img src="https://img.shields.io/pypi/pyversions/whisperplus.svg?color=%2334D058" alt="Supported Python versions">
</a>
Expand Down
1 change: 0 additions & 1 deletion whisperplus/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ def speaker_diarization(url, model_id, device, num_speakers, min_speaker, max_sp
transcript (str): The transcript of the speech-to-text conversion.
video_path (str): The path of the downloaded video.
"""

pipeline = ASRDiarizationPipeline.from_pretrained(
asr_model=model_id,
diarizer_model="pyannote/speaker-diarization",
Expand Down
5 changes: 2 additions & 3 deletions whisperplus/pipelines/lightning_whisper_mlx/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
-------
A NumPy array containing the audio waveform, in float32 dtype.
"""

# This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
# fmt: off
Expand Down Expand Up @@ -79,8 +78,8 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
@lru_cache(maxsize=None)
def mel_filters(n_mels: int) -> mx.array:
"""
Load the mel filterbank matrix for projecting STFT into a Mel spectrogram. Allows decoupling librosa
dependency; saved using:
Load the mel filterbank matrix for projecting STFT into a Mel spectrogram. Allows decoupling librosa
dependency; saved using:

np.savez_compressed( "mel_filters.npz", mel_80=librosa.filters.mel(sr=16000, n_fft=400,
n_mels=80), mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), )
Expand Down
4 changes: 3 additions & 1 deletion whisperplus/pipelines/lightning_whisper_mlx/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,9 @@ def get_encoding(name: str = "gpt2", num_languages: int = 99):
return tiktoken.Encoding(
name=os.path.basename(vocab_path),
explicit_n_vocab=n_vocab,
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
pat_str=\
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
,
mergeable_ranks=ranks,
special_tokens=special_tokens,
)
Expand Down
1 change: 0 additions & 1 deletion whisperplus/pipelines/lightning_whisper_mlx/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def transcribe_audio(
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""

dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
model = ModelHolder.get_model(path_or_hf_repo, dtype)

Expand Down
5 changes: 2 additions & 3 deletions whisperplus/pipelines/mlx_whisper/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
-------
A NumPy array containing the audio waveform, in float32 dtype.
"""

# This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
# fmt: off
Expand Down Expand Up @@ -79,8 +78,8 @@ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
@lru_cache(maxsize=None)
def mel_filters(n_mels: int) -> mx.array:
"""
Load the mel filterbank matrix for projecting STFT into a Mel spectrogram. Allows decoupling librosa
dependency; saved using:
Load the mel filterbank matrix for projecting STFT into a Mel spectrogram. Allows decoupling librosa
dependency; saved using:

np.savez_compressed( "mel_filters.npz", mel_80=librosa.filters.mel(sr=16000, n_fft=400,
n_mels=80), mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128), )
Expand Down
1 change: 0 additions & 1 deletion whisperplus/pipelines/mlx_whisper/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ def load_torch_model(
model : Whisper
The Whisper ASR model instance
"""

if download_root is None:
download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")

Expand Down
4 changes: 3 additions & 1 deletion whisperplus/pipelines/mlx_whisper/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,9 @@ def get_encoding(name: str = "gpt2", num_languages: int = 99):
return tiktoken.Encoding(
name=os.path.basename(vocab_path),
explicit_n_vocab=n_vocab,
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
pat_str=\
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
,
mergeable_ranks=ranks,
special_tokens=special_tokens,
)
Expand Down
1 change: 0 additions & 1 deletion whisperplus/pipelines/mlx_whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@ def transcribe(
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""

dtype = mx.float16 if decode_options.get("fp16", True) else mx.float32
model = ModelHolder.get_model(path_or_hf_repo, dtype)

Expand Down
17 changes: 8 additions & 9 deletions whisperplus/pipelines/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,17 @@ def hqq_compile_model(self, model_id, quant_config, device):
prepare_for_inference(model.model.decoder, backend="torchao_int4")

model = self.compile_model(model)

return model, processor

def load_model_whisper(
self,
model_id: str = "distil-whisper/distil-large-v3",
quant_config=None,
hqq_compile: bool = False,
flash_attention_2: bool = False,
device=None):
self,
model_id: str = "distil-whisper/distil-large-v3",
quant_config=None,
hqq_compile: bool = False,
flash_attention_2: bool = False,
device=None):

if hqq_compile:
return self.hqq_compile_model(model_id, quant_config, device)
else:
Expand Down Expand Up @@ -134,7 +134,6 @@ def __call__(
Returns:
str: Transcribed text from the audio.
"""

pipe = pipeline(
"automatic-speech-recognition",
model=self.model,
Expand Down
4 changes: 2 additions & 2 deletions whisperplus/pipelines/whisper_diarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def __call__(
**kwargs,
):
"""
Transcribe the audio sequence(s) given as inputs to text and label with speaker information. The input
audio is first passed to the speaker diarization pipeline, which returns timestamps for 'who spoke
Transcribe the audio sequence(s) given as inputs to text and label with speaker information. The input
audio is first passed to the speaker diarization pipeline, which returns timestamps for 'who spoke
when'. The audio is then passed to the ASR pipeline, which returns utterance-level transcriptions and
their corresponding timestamps. The speaker diarizer timestamps are aligned with the ASR transcription
timestamps to give speaker-labelled transcriptions. We cannot use the speaker diarization timestamps
Expand Down
Loading