11import logging
22from typing import Optional
33
4- from transformers import pipeline
5-
6- from whisperplus .model .load_model import load_model_whisper
4+ import torch
5+ from transformers import AutoModelForSpeechSeq2Seq , AutoProcessor , pipeline
76
87logging .basicConfig (level = logging .INFO , format = '%(asctime)s - %(levelname)s - %(message)s' )
98
@@ -28,6 +27,75 @@ def __init__(
2827 else :
2928 logging .info ("Model already loaded." )
3029
30+ def compile_model (self , model ):
31+ model .model .encoder .forward = torch .compile (
32+ model .model .encoder .forward , mode = "reduce-overhead" , fullgraph = True )
33+ model .model .decoder .forward = torch .compile (
34+ model .model .decoder .forward , mode = "reduce-overhead" , fullgraph = True )
35+ return model
36+
    def hqq_compile_model(self, model_id, quant_config, device):
        """Load a Whisper seq2seq model, HQQ-quantize its encoder/decoder, and compile it.

        Args:
            model_id: Hugging Face model identifier passed to ``from_pretrained``.
            quant_config: HQQ quantization config forwarded to
                ``AutoHQQHFModel.quantize_model`` for both encoder and decoder.
            device: Target device for quantization (forwarded to HQQ).

        Returns:
            Tuple of ``(model, processor)`` where ``model`` has been quantized,
            patched for inference, and compiled via ``self.compile_model``.
        """
        # hqq is an optional dependency; imported lazily so the rest of the
        # class works without it.
        import hqq.models.base as hqq_base
        import torch._dynamo
        from hqq.core.quantize import HQQBackend, HQQLinear
        from hqq.models.hf.base import AutoHQQHFModel
        from hqq.utils.patching import prepare_for_inference

        # Let torch.compile fall back to eager on graph breaks instead of raising.
        torch._dynamo.config.suppress_errors = True

        # NOTE(review): flash_attention_2 is hard-coded here, unlike
        # load_model_whisper which makes it optional — presumably required by
        # the HQQ/torchao path; confirm on hardware without FA2 support.
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")

        processor = AutoProcessor.from_pretrained(model_id)
        HQQLinear.set_backend(HQQBackend.PYTORCH)

        # Quantize encoder and decoder separately with the same config.
        AutoHQQHFModel.quantize_model(
            model.model.encoder, quant_config=quant_config, compute_dtype=torch.bfloat16, device=device)

        AutoHQQHFModel.quantize_model(
            model.model.decoder, quant_config=quant_config, compute_dtype=torch.bfloat16, device=device)

        # Register HQQLinear as a quantizable layer type, then patch each
        # submodule for inference. Order matters: tags must be set before
        # prepare_for_inference.
        hqq_base._QUANT_LAYERS = [torch.nn.Linear, HQQLinear]
        AutoHQQHFModel.set_auto_linear_tags(model.model.encoder)
        prepare_for_inference(model.model.encoder)

        AutoHQQHFModel.set_auto_linear_tags(model.model.decoder)
        # Decoder uses the torchao int4 backend; encoder keeps the default.
        prepare_for_inference(model.model.decoder, backend="torchao_int4")

        # Finally compile encoder/decoder forward passes (torch.compile).
        model = self.compile_model(model)

        return model, processor
68+
69+ def load_model_whisper (
70+ self ,
71+ model_id : str = "distil-whisper/distil-large-v3" ,
72+ quant_config = None ,
73+ hqq_compile : bool = False ,
74+ flash_attention_2 : bool = False ,
75+ device = None ):
76+
77+ if hqq_compile :
78+ return self .hqq_compile_model (model_id , quant_config , device )
79+ else :
80+ if flash_attention_2 :
81+ attn_implementation = "flash_attention_2"
82+ else :
83+ attn_implementation = "sdpa"
84+
85+ model = AutoModelForSpeechSeq2Seq .from_pretrained (
86+ model_id ,
87+ quantization_config = quant_config ,
88+ low_cpu_mem_usage = True ,
89+ use_safetensors = True ,
90+ attn_implementation = attn_implementation ,
91+ torch_dtype = torch .bfloat16 ,
92+ device_map = device ,
93+ )
94+
95+ processor = AutoProcessor .from_pretrained (model_id )
96+
97+ return model , processor
98+
3199 def load_plus_model (
32100 self ,
33101 model_id : str = "distil-whisper/distil-large-v3" ,
@@ -36,7 +104,7 @@ def load_plus_model(
36104 flash_attention_2 : bool = True ,
37105 ):
38106
39- model , processor = load_model_whisper (
107+ model , processor = self . load_model_whisper (
40108 model_id = model_id ,
41109 quant_config = quant_config ,
42110 hqq_compile = hqq_compile ,
0 commit comments