Skip to content

Commit 793466e

Browse files
feat: LID shadow mode, telemetry, reviewer fixes + null multilingual … (#680)
* feat: LID shadow mode, telemetry, reviewer fixes + null multilingual guard
* bump: version 0.10.21 → 0.10.22
* feat: export LIDProvider/SarvamLID, add multilingual+active to Transcriber model
* bump: version 0.10.22 → 0.10.23

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 9f21754 commit 793466e

7 files changed

Lines changed: 441 additions & 8 deletions

File tree

bolna/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.10.21"
1+
__version__ = "0.10.23"
22

33
import os
44
from bolna.helpers.logger_config import configure_logger

bolna/agent_manager/task_manager.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,9 @@ def __init__(
348348

349349
# Stores structured API call records for dashboard/backend persistence.
350350
self.function_tool_api_call_details = []
351+
# Records every manual switch_language tool call — used post-call to
352+
# compare against LID shadow detections (precision / latency analysis).
353+
self.language_switch_events: list[dict] = []
351354
self.hangup_task = None
352355

353356
self.conversation_config = None
@@ -1041,15 +1044,27 @@ def __setup_transcriber(self):
10411044
if label == active_label:
10421045
self.transcriber_provider = cfg.get("provider", cfg.get("model"))
10431046

1047+
# Audio LID tap.
1048+
# LID_PROVIDER — which backend to use (default: "sarvam").
1049+
# LID_MODE — "shadow" (default) logs detections without switching;
1050+
# "active" performs live transcriber/synthesizer/prompt swap.
1051+
# Keep "shadow" until detection quality is validated.
1052+
LID_PROVIDER = os.getenv("LID_PROVIDER", "sarvam")
1053+
_lid_config = {"telephony_provider": provider}
1054+
10441055
self.tools["transcriber"] = TranscriberPool(
10451056
transcribers=transcribers,
10461057
shared_input_queue=self.audio_queue,
10471058
output_queue=self.transcriber_output_queue,
10481059
active_label=active_label,
10491060
multilingual_config=multilingual,
1061+
lid_provider=LID_PROVIDER,
1062+
lid_config=_lid_config,
1063+
on_lid_switch=self.switch_language,
10501064
)
10511065
logger.info(
1052-
f"TranscriberPool created with labels={list(transcribers.keys())}, active='{active_label}'"
1066+
f"TranscriberPool created with labels={list(transcribers.keys())}, "
1067+
f"active='{active_label}', lid_provider={LID_PROVIDER!r}"
10531068
)
10541069
return
10551070

@@ -3329,14 +3344,27 @@ def _get_voice_name_for_label(self, label):
33293344
"""Get agent name for a language label from configured agent_names."""
33303345
return self.agent_names.get(label, "")
33313346

3332-
async def switch_language(self, label, components=None):
3347+
async def switch_language(self, label, components=None, triggered_by: str = "manual"):
33333348
"""Switch the active language for multilingual pools.
33343349
33353350
Args:
33363351
label: language label to switch to (e.g. "hi", "en").
33373352
components: list of component names to switch. Defaults to both.
3353+
triggered_by: "manual" (LLM tool call) or "lid" (automatic detection).
3354+
Used in post-call telemetry to compare LID shadow detections
3355+
against actual LLM-decided switches.
33383356
"""
33393357
components = components or ["transcriber", "synthesizer"]
3358+
3359+
# Record every switch so shadow-eval can compare LID detections vs.
3360+
# actual LLM-decided switches on the same call.
3361+
self.language_switch_events.append({
3362+
"to_label": label,
3363+
"from_label": self.language,
3364+
"triggered_by": triggered_by,
3365+
"switched_at": time.time(),
3366+
})
3367+
33403368
if "transcriber" in components and isinstance(self.tools.get("transcriber"), TranscriberPool):
33413369
await self.tools["transcriber"].switch(label)
33423370
if "synthesizer" in components and isinstance(self.tools.get("synthesizer"), SynthesizerPool):
@@ -4233,6 +4261,8 @@ async def run(self):
42334261
"conversation_time": time.time() - self.start_time,
42344262
"label_flow": self.label_flow,
42354263
"function_tool_api_call_details": copy.deepcopy(self.function_tool_api_call_details),
4264+
"lid_detection_events": list(getattr(self.tools.get("transcriber"), "lid_detection_events", [])),
4265+
"language_switch_events": list(self.language_switch_events),
42364266
"call_sid": self.call_sid,
42374267
"stream_sid": self.stream_sid,
42384268
"transcriber_duration": self.transcriber_duration,

bolna/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ class Transcriber(BaseModel):
122122
keywords: Optional[str] = None
123123
task: Optional[str] = "transcribe"
124124
provider: Optional[str] = "deepgram"
125+
multilingual: Optional[Dict[str, Any]] = None
126+
active: Optional[str] = None
125127

126128
@field_validator("provider")
127129
def validate_model(cls, value):

bolna/transcriber/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@
99
from .elevenlabs_transcriber import ElevenLabsTranscriber
1010
from .smallest_transcriber import SmallestTranscriber
1111
from .transcriber_pool import TranscriberPool
12+
from .lid_provider import LIDProvider, SarvamLID

bolna/transcriber/lid_provider.py

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
"""
2+
lid_provider.py — Language Identification (LID) via Sarvam saaras:v3.
3+
4+
Opens a dedicated WebSocket to Sarvam with language-code=unknown so the
5+
server auto-detects the spoken language and returns language_code in each
6+
data payload alongside the transcript. Audio is forwarded in real-time
7+
from the TranscriberPool audio router — zero added latency to the ASR path.
8+
9+
Usage (in TranscriberPool):
10+
lid = SarvamLID(on_language=callback, config={...})
11+
await lid.start()
12+
lid.feed(audio_chunk_bytes) # called for every incoming audio packet
13+
await lid.stop()
14+
"""
15+
16+
from __future__ import annotations
17+
18+
import asyncio
19+
import base64
20+
import io
21+
import json
22+
import os
23+
import wave
24+
from typing import Awaitable, Callable, Optional
25+
26+
from bolna.helpers.logger_config import configure_logger
27+
28+
logger = configure_logger(__name__)
29+
30+
# Signature: async def on_language(lang: str, confidence: float) -> None
31+
OnLanguageCallback = Callable[[str, float], Awaitable[None]]
32+
33+
34+
class SarvamLID:
35+
"""
36+
LID via Sarvam saaras:v3 with language_code=unknown.
37+
38+
Config keys (all optional, fall back to env vars):
39+
sarvam_api_key — SARVAM_API_KEY env var
40+
sarvam_host — api.sarvam.ai
41+
telephony_provider — "twilio" | "plivo" | other
42+
sampling_rate — 16000
43+
"""
44+
45+
_WS_BASE = "wss://{host}/speech-to-text/ws"
46+
47+
def __init__(self, on_language: OnLanguageCallback, config: dict):
48+
self.on_language = on_language
49+
self.config = config
50+
self._api_key = config.get("sarvam_api_key") or os.getenv("SARVAM_API_KEY", "")
51+
self._host = config.get("sarvam_host") or os.getenv("SARVAM_HOST", "api.sarvam.ai")
52+
self._telephony = config.get("telephony_provider", "")
53+
self._sr = int(config.get("sampling_rate", 16000))
54+
self._input_sr = 8000 if self._telephony in ("twilio", "plivo") else self._sr
55+
self._encoding = "mulaw" if self._telephony == "twilio" else "linear16"
56+
57+
# Bounded queue: LID is best-effort. If the Sarvam WS stalls, we drop
58+
# chunks rather than buffering unboundedly for the entire call duration.
59+
self._queue: asyncio.Queue = asyncio.Queue(maxsize=200)
60+
self._ws = None
61+
self._sender_task: Optional[asyncio.Task] = None
62+
self._receiver_task: Optional[asyncio.Task] = None
63+
# Set to True if the receiver loop exits abnormally (WS drop / error).
64+
# feed() will log a warning when dead so silent stat bias is visible.
65+
self._dead: bool = False
66+
67+
def _build_url(self) -> str:
68+
params = {
69+
"model": "saaras:v3",
70+
"mode": "transcribe",
71+
"language-code": "unknown",
72+
"high_vad_sensitivity": "true",
73+
}
74+
qs = "&".join(f"{k}={v}" for k, v in params.items())
75+
return f"{self._WS_BASE.format(host=self._host)}?{qs}"
76+
77+
def _convert_to_wav_b64(self, raw: bytes) -> Optional[str]:
78+
"""Convert telephony audio to 16kHz WAV base64 for Sarvam."""
79+
import audioop
80+
try:
81+
if self._encoding == "mulaw":
82+
raw = audioop.ulaw2lin(raw, 2)
83+
if self._input_sr != self._sr:
84+
raw, _ = audioop.ratecv(raw, 2, 1, self._input_sr, self._sr, None)
85+
buf = io.BytesIO()
86+
with wave.open(buf, "wb") as wf:
87+
wf.setnchannels(1)
88+
wf.setsampwidth(2)
89+
wf.setframerate(self._sr)
90+
wf.writeframes(raw)
91+
return base64.b64encode(buf.getvalue()).decode()
92+
except Exception as e:
93+
logger.warning(f"SarvamLID audio convert error: {e}")
94+
return None
95+
96+
async def start(self) -> None:
97+
import websockets as ws_lib
98+
url = self._build_url()
99+
headers = {"api-subscription-key": self._api_key}
100+
logger.info(f"SarvamLID: connecting to {url}")
101+
self._ws = await ws_lib.connect(url, additional_headers=headers)
102+
self._sender_task = asyncio.create_task(self._sender_loop())
103+
self._receiver_task = asyncio.create_task(self._receiver_loop())
104+
logger.info("SarvamLID: connected")
105+
106+
def feed(self, audio_bytes: bytes) -> None:
107+
if self._dead:
108+
logger.warning("SarvamLID: feed() called but WS is dead — chunk dropped (LID inactive)")
109+
return
110+
try:
111+
self._queue.put_nowait(audio_bytes)
112+
except asyncio.QueueFull:
113+
logger.debug("SarvamLID: audio queue full — chunk dropped (backpressure)")
114+
115+
async def _sender_loop(self) -> None:
116+
try:
117+
while True:
118+
chunk = await self._queue.get()
119+
if chunk is None:
120+
break
121+
b64 = self._convert_to_wav_b64(chunk)
122+
if b64:
123+
msg = {"audio": {"data": b64, "encoding": "audio/wav", "sample_rate": self._sr}}
124+
await self._ws.send(json.dumps(msg))
125+
except asyncio.CancelledError:
126+
pass
127+
except Exception as e:
128+
logger.error(f"SarvamLID sender error: {e}")
129+
self._dead = True
130+
logger.warning("SarvamLID: sender loop exited abnormally — LID inactive for remainder of call")
131+
132+
async def _receiver_loop(self) -> None:
133+
try:
134+
async for raw in self._ws:
135+
try:
136+
data = json.loads(raw) if isinstance(raw, str) else {}
137+
if data.get("type") == "data":
138+
payload = data.get("data", {})
139+
lang = payload.get("language_code", "")
140+
# Sarvam returns language_probability=None when operating in
141+
# unknown-language mode — the language_code is the signal.
142+
# conf is passed through for API compatibility but the pool's
143+
# confidence gate is skipped for Sarvam (see _handle_lid_signal).
144+
conf = float(payload.get("language_probability") or 0.0)
145+
if lang and lang != "unknown":
146+
short = lang.split("-")[0].lower()
147+
logger.info(f"SarvamLID: detected {lang!r} (short={short!r}, conf={conf:.2f})")
148+
await self.on_language(short, conf)
149+
except Exception as e:
150+
logger.error(f"SarvamLID receiver parse error: {e}")
151+
except asyncio.CancelledError:
152+
pass
153+
except Exception as e:
154+
logger.error(f"SarvamLID receiver error: {e}")
155+
self._dead = True
156+
logger.warning("SarvamLID: receiver loop exited abnormally — LID inactive for remainder of call")
157+
158+
async def stop(self) -> None:
159+
self._queue.put_nowait(None)
160+
for task in (self._sender_task, self._receiver_task):
161+
if task and not task.done():
162+
task.cancel()
163+
try:
164+
await task
165+
except asyncio.CancelledError:
166+
pass
167+
if self._ws:
168+
try:
169+
await self._ws.close()
170+
except Exception:
171+
pass
172+
logger.info("SarvamLID: stopped")
173+
174+
175+
# Thin factory shim for backward compatibility
176+
class LIDProvider:
    """Name-based factory for LID backends.

    Sarvam is the only backend constructed here; any other provider name is
    logged and then served by Sarvam as well.
    """

    @classmethod
    def create(cls, provider: str, on_language: OnLanguageCallback, config: dict) -> SarvamLID:
        """Return a SarvamLID for ``provider`` (name matched case-insensitively).

        Args:
            provider: requested backend name.
            on_language: callback passed through to the backend.
            config: backend configuration dict, passed through unchanged.
        """
        is_sarvam = provider.lower() == "sarvam"
        if not is_sarvam:
            logger.warning(f"LIDProvider: unknown provider '{provider}', falling back to sarvam")
        return SarvamLID(on_language=on_language, config=config)

0 commit comments

Comments
 (0)