Skip to content

Commit 666d596

Browse files
authored
Merge pull request #132 from kadirnar/delete-model-file
Move model loading logic into SpeechToTextPipeline class
2 parents e2a6c5e + 90e04ca commit 666d596

3 files changed

Lines changed: 73 additions & 82 deletions

File tree

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ ffmpeg
1010
ffmpeg-python
1111
pre-commit
1212
fire
13+
transformers

whisperplus/model/load_model.py

Lines changed: 0 additions & 78 deletions
This file was deleted.

whisperplus/pipelines/whisper.py

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import logging
22
from typing import Optional
33

4-
from transformers import pipeline
5-
6-
from whisperplus.model.load_model import load_model_whisper
4+
import torch
5+
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
76

87
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
98

@@ -28,6 +27,75 @@ def __init__(
2827
else:
2928
logging.info("Model already loaded.")
3029

30+
def compile_model(self, model):
    """Wrap the encoder and decoder forward passes with ``torch.compile``.

    Both sub-modules are compiled in "reduce-overhead" mode with full-graph
    capture. The model is mutated in place and returned for chaining.
    """
    compile_kwargs = {"mode": "reduce-overhead", "fullgraph": True}
    encoder = model.model.encoder
    decoder = model.model.decoder
    encoder.forward = torch.compile(encoder.forward, **compile_kwargs)
    decoder.forward = torch.compile(decoder.forward, **compile_kwargs)
    return model
36+
37+
def hqq_compile_model(self, model_id, quant_config, device):
    """Load a Whisper seq2seq model, HQQ-quantize it, and torch.compile it.

    Requires the optional ``hqq`` package (imported lazily so the rest of
    the pipeline works without it). Returns ``(model, processor)``.
    """
    import hqq.models.base as hqq_base
    import torch._dynamo
    from hqq.core.quantize import HQQBackend, HQQLinear
    from hqq.models.hf.base import AutoHQQHFModel
    from hqq.utils.patching import prepare_for_inference

    # torch.compile may hit ops it cannot trace in the quantized graph;
    # fall back to eager execution instead of raising.
    torch._dynamo.config.suppress_errors = True

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
    processor = AutoProcessor.from_pretrained(model_id)

    HQQLinear.set_backend(HQQBackend.PYTORCH)

    # Quantize the two halves of the seq2seq stack separately.
    for submodule in (model.model.encoder, model.model.decoder):
        AutoHQQHFModel.quantize_model(
            submodule, quant_config=quant_config, compute_dtype=torch.bfloat16, device=device)

    hqq_base._QUANT_LAYERS = [torch.nn.Linear, HQQLinear]

    AutoHQQHFModel.set_auto_linear_tags(model.model.encoder)
    prepare_for_inference(model.model.encoder)

    AutoHQQHFModel.set_auto_linear_tags(model.model.decoder)
    # NOTE(review): decoder alone is prepared with the "torchao_int4"
    # backend — presumably intentional; confirm against hqq docs.
    prepare_for_inference(model.model.decoder, backend="torchao_int4")

    return self.compile_model(model), processor
68+
69+
def load_model_whisper(
        self,
        model_id: str = "distil-whisper/distil-large-v3",
        quant_config=None,
        hqq_compile: bool = False,
        flash_attention_2: bool = False,
        device=None):
    """Load a Whisper model and its processor from the HuggingFace hub.

    When ``hqq_compile`` is set, delegates to the HQQ quantize-and-compile
    path. Otherwise the model is loaded directly with either the
    flash-attention-2 or the SDPA attention implementation.
    Returns ``(model, processor)``.
    """
    if hqq_compile:
        return self.hqq_compile_model(model_id, quant_config, device)

    attn_implementation = "flash_attention_2" if flash_attention_2 else "sdpa"

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id,
        quantization_config=quant_config,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        attn_implementation=attn_implementation,
        torch_dtype=torch.bfloat16,
        device_map=device,
    )
    processor = AutoProcessor.from_pretrained(model_id)

    return model, processor
98+
3199
def load_plus_model(
32100
self,
33101
model_id: str = "distil-whisper/distil-large-v3",
@@ -36,7 +104,7 @@ def load_plus_model(
36104
flash_attention_2: bool = True,
37105
):
38106

39-
model, processor = load_model_whisper(
107+
model, processor = self.load_model_whisper(
40108
model_id=model_id,
41109
quant_config=quant_config,
42110
hqq_compile=hqq_compile,

0 commit comments

Comments
 (0)