Add VAD segmentation and Docker ENV defaults
Some checks failed
Build and Push EVS Bridge Image / docker (push) Has been cancelled

This commit is contained in:
Kai
2026-02-13 16:47:54 +01:00
parent 9dc1ac3099
commit d4d4c7224b
4 changed files with 201 additions and 6 deletions

View File

@@ -51,6 +51,15 @@ WAV_SEGMENT_MAX_BYTES = int(os.getenv("WAV_SEGMENT_MAX_BYTES", str(20 * 1024 * 1
# Keep-count handed to enforce_wav_retention() for SESSIONS_DIR after a
# session segment is saved.
WAV_KEEP_FILES = int(os.getenv("WAV_KEEP_FILES", "10"))
# NOTE(review): presumably the size of a canonical PCM WAV header; its use is
# outside this view -- confirm.
WAV_HEADER_BYTES = 44
# Master switch: when false, process_vad_frame() and write_vad_wav_segment()
# return immediately.
VAD_ENABLED = getenv_bool("VAD_ENABLED", True)
# Directory that receives VAD segment WAV files (created on demand).
VAD_DIR = Path(os.getenv("VAD_DIR", "/data/vad"))
# Keep-count handed to enforce_wav_retention() for VAD_DIR; a value <= 0
# makes that function return without pruning.
VAD_KEEP_FILES = int(os.getenv("VAD_KEEP_FILES", "100"))
# Milliseconds of audio kept in the rolling pre-roll buffer and copied into a
# segment when it starts.
VAD_PREROLL_MS = int(os.getenv("VAD_PREROLL_MS", "1000"))
# Milliseconds of consecutive below-stop-threshold audio required before an
# active segment is closed.
VAD_POSTROLL_MS = int(os.getenv("VAD_POSTROLL_MS", "1000"))
# Frame level (pcm_avg_abs) at or above which an inactive VAD starts a segment.
VAD_START_THRESHOLD = int(os.getenv("VAD_START_THRESHOLD", "900"))
# Frame level below which a frame counts as silence while a segment is active.
VAD_STOP_THRESHOLD = int(os.getenv("VAD_STOP_THRESHOLD", "600"))
# Minimum-duration threshold (ms) checked by process_vad_frame() before a
# finished segment is written to disk.
VAD_MIN_SPEECH_MS = int(os.getenv("VAD_MIN_SPEECH_MS", "300"))
@dataclass
class DeviceSession:
@@ -64,6 +73,12 @@ class DeviceSession:
segment_index: int = 0
segment_pcm_buffer: bytearray = field(default_factory=bytearray)
saved_wavs: List[str] = field(default_factory=list)
vad_active: bool = False
vad_silence_ms: int = 0
vad_speech_ms: int = 0
vad_preroll_buffer: bytearray = field(default_factory=bytearray)
vad_segment_buffer: bytearray = field(default_factory=bytearray)
vad_segment_index: int = 0
class BridgeState:
@@ -121,16 +136,16 @@ async def call_ha_webhook(event: str, payload: dict) -> None:
log.exception("ha webhook call failed")
def enforce_wav_retention() -> None:
if not SAVE_SESSIONS:
def enforce_wav_retention(directory: Path, keep_files: int) -> None:
if keep_files <= 0:
return
try:
SESSIONS_DIR.mkdir(parents=True, exist_ok=True)
directory.mkdir(parents=True, exist_ok=True)
wavs = sorted(
[p for p in SESSIONS_DIR.glob("*.wav") if p.is_file()],
[p for p in directory.glob("*.wav") if p.is_file()],
key=lambda p: p.stat().st_mtime,
)
while len(wavs) > WAV_KEEP_FILES:
while len(wavs) > keep_files:
oldest = wavs.pop(0)
try:
oldest.unlink()
@@ -158,7 +173,7 @@ def write_wav_segment(session: DeviceSession, pcm: bytes) -> Optional[str]:
wf.writeframes(pcm)
wf.close()
session.saved_wavs.append(str(path))
enforce_wav_retention()
enforce_wav_retention(SESSIONS_DIR, WAV_KEEP_FILES)
session.segment_index += 1
return str(path)
except Exception:
@@ -166,6 +181,112 @@ def write_wav_segment(session: DeviceSession, pcm: bytes) -> Optional[str]:
return None
def write_vad_wav_segment(session: DeviceSession, pcm: bytes) -> Optional[str]:
    """Write one VAD segment to ``VAD_DIR`` as a 16-bit mono WAV file.

    The file name is built from the device id, the current wall-clock time in
    milliseconds and the session's running VAD segment index. After a
    successful write the index is advanced and retention is enforced on
    ``VAD_DIR``.

    Args:
        session: Session whose ``device_id`` and ``vad_segment_index`` name
            the file; ``vad_segment_index`` is incremented on success.
        pcm: Raw PCM payload written as the WAV frames.

    Returns:
        The path of the written file as a string, or ``None`` when VAD is
        disabled, ``pcm`` is empty, or writing failed (the failure is logged,
        not raised).
    """
    if not VAD_ENABLED or not pcm:
        return None
    try:
        import wave

        VAD_DIR.mkdir(parents=True, exist_ok=True)
        ts_ms = int(time.time() * 1000)
        name = f"{session.device_id}_{ts_ms}_vad{session.vad_segment_index:05d}.wav"
        path = VAD_DIR / name
        # The context manager closes the file (and finalises the header) even
        # when writeframes() raises, so a failed write cannot leak the handle.
        with wave.open(str(path), "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(PCM_SAMPLE_RATE)
            wf.writeframes(pcm)
        session.vad_segment_index += 1
        enforce_wav_retention(VAD_DIR, VAD_KEEP_FILES)
        return str(path)
    except Exception:
        # Best effort: a failed segment write must not break audio handling.
        log.exception("failed to write vad wav segment for %s", session.device_id)
        return None
def pcm_avg_abs(data: bytes) -> int:
    """Return the mean absolute sample value of a 16-bit mono PCM frame.

    Args:
        data: Raw PCM bytes, two bytes per sample. A trailing odd byte (an
            incomplete sample) is ignored.

    Returns:
        The integer (floor) average of ``abs(sample)`` over all complete
        samples, or ``0`` when ``data`` holds no complete sample.
    """
    # The frame length is not validated by callers, and memoryview.cast("h")
    # raises ValueError when the length is not a multiple of two, so only the
    # complete samples are viewed.
    usable = len(data) - (len(data) % 2)
    if usable < 2:
        return 0
    # cast("h") reads native-endian int16, which matches PCM16LE on
    # little-endian hosts.
    samples = memoryview(data)[:usable].cast("h")
    # sum/map/abs iterate in C instead of a per-sample Python loop.
    return sum(map(abs, samples)) // len(samples)
def pcm_duration_ms(pcm_bytes: int) -> int:
    """Convert a PCM byte count into a duration in whole milliseconds.

    The byte rate is taken as ``PCM_SAMPLE_RATE`` samples per second at two
    bytes per sample; the fractional part of the result is truncated.
    """
    byte_rate = 2 * PCM_SAMPLE_RATE
    millis = (1000 * pcm_bytes) / byte_rate
    return int(millis)
def process_vad_frame(device_id: str, session: DeviceSession, data: bytes) -> None:
    """Feed one PCM frame through the level-based VAD state machine.

    While inactive, frames are kept in a rolling pre-roll buffer; a frame
    whose level reaches ``VAD_START_THRESHOLD`` starts a segment seeded with
    that pre-roll. While active, frames are appended to the segment until
    ``VAD_POSTROLL_MS`` of consecutive frames below ``VAD_STOP_THRESHOLD``
    has accumulated. The segment is then optionally written to disk, a
    ``vad_segment`` status is published, and the VAD state is reset.

    Args:
        device_id: Identifier used in log lines and the published status.
        session: Per-device state; its ``vad_*`` fields are mutated in place.
        data: Raw PCM frame (16-bit mono, as assumed by ``pcm_avg_abs`` and
            ``pcm_duration_ms``).

    Returns:
        None. Does nothing when VAD is disabled or ``data`` is empty.
    """
    if not VAD_ENABLED or not data:
        return

    # Rolling pre-roll window: keep only the newest VAD_PREROLL_MS of audio.
    preroll_max_bytes = max(0, (PCM_SAMPLE_RATE * 2 * VAD_PREROLL_MS) // 1000)
    if preroll_max_bytes > 0:
        session.vad_preroll_buffer.extend(data)
        if len(session.vad_preroll_buffer) > preroll_max_bytes:
            drop = len(session.vad_preroll_buffer) - preroll_max_bytes
            del session.vad_preroll_buffer[:drop]

    level = pcm_avg_abs(data)
    frame_ms = pcm_duration_ms(len(data))
    if frame_ms <= 0:
        # Frame shorter than one millisecond: nothing to account for.
        return

    if not session.vad_active:
        if level >= VAD_START_THRESHOLD:
            session.vad_active = True
            session.vad_silence_ms = 0
            session.vad_speech_ms = frame_ms
            session.vad_segment_buffer.clear()
            preroll = session.vad_preroll_buffer
            # The pre-roll window normally already ends with this frame. When
            # pre-roll is disabled or shorter than one frame it holds at most
            # a truncated tail, so seed the segment with the full trigger
            # frame instead of dropping it.
            session.vad_segment_buffer.extend(
                preroll if len(preroll) >= len(data) else data
            )
            log.info("vad_start: device=%s level=%s", device_id, level)
        return

    # NOTE(review): the segment buffer has no size cap here; a level that never
    # drops below VAD_STOP_THRESHOLD keeps it growing -- confirm whether a
    # maximum segment length is wanted.
    session.vad_segment_buffer.extend(data)
    session.vad_speech_ms += frame_ms
    if level < VAD_STOP_THRESHOLD:
        session.vad_silence_ms += frame_ms
    else:
        # Any frame at or above the stop threshold restarts the post-roll count.
        session.vad_silence_ms = 0
    if session.vad_silence_ms < VAD_POSTROLL_MS:
        return

    segment_bytes = bytes(session.vad_segment_buffer)
    segment_ms = pcm_duration_ms(len(segment_bytes))
    # Time from the trigger frame through the last frame at or above the stop
    # threshold. The whole-segment duration also contains pre-roll and
    # post-roll, so comparing it against VAD_MIN_SPEECH_MS could never reject
    # a short blip.
    voiced_ms = max(0, session.vad_speech_ms - session.vad_silence_ms)
    path = None
    if voiced_ms >= VAD_MIN_SPEECH_MS:
        path = write_vad_wav_segment(session, segment_bytes)

    payload = {
        "type": "vad_segment",
        "ts": time.time(),
        "device_id": device_id,
        "duration_s": round(segment_ms / 1000.0, 3),
        "level": level,
    }
    if path:
        payload["wav_path"] = path
    state.publish_status(device_id, payload)
    log.info(
        "vad_stop: device=%s duration_s=%s wav=%s",
        device_id,
        payload["duration_s"],
        path or "-",
    )

    # Return to the idle state; the pre-roll buffer keeps rolling.
    session.vad_active = False
    session.vad_silence_ms = 0
    session.vad_speech_ms = 0
    session.vad_segment_buffer.clear()
def append_pcm_with_rotation(session: DeviceSession, data: bytes) -> None:
if not SAVE_SESSIONS or not data:
return
@@ -206,6 +327,11 @@ async def handle_text_message(device_id: str, session: DeviceSession, raw: str)
session.saved_wavs.clear()
session.segment_index = 0
session.segment_pcm_buffer.clear()
session.vad_active = False
session.vad_silence_ms = 0
session.vad_speech_ms = 0
session.vad_preroll_buffer.clear()
session.vad_segment_buffer.clear()
payload = {"type": "start", "ts": time.time(), "device_id": device_id}
state.publish_status(device_id, payload)
await call_ha_webhook("start", payload)
@@ -215,6 +341,11 @@ async def handle_text_message(device_id: str, session: DeviceSession, raw: str)
if msg_type == "stop":
session.ptt_active = False
flush_pending_segment(session)
session.vad_active = False
session.vad_silence_ms = 0
session.vad_speech_ms = 0
session.vad_preroll_buffer.clear()
session.vad_segment_buffer.clear()
metrics = build_metrics(device_id, session)
payload = {"type": "stop", "ts": time.time(), "device_id": device_id, **metrics}
if session.saved_wavs:
@@ -270,6 +401,7 @@ async def handle_binary_message(device_id: str, session: DeviceSession, data: by
# Keep newest data within cap to avoid unbounded memory growth.
drop = len(session.pcm_bytes) - MAX_SESSION_BYTES
del session.pcm_bytes[:drop]
process_vad_frame(device_id, session, data)
if ECHO_ENABLED:
await session.ws.send(data)
@@ -313,6 +445,11 @@ async def ws_handler(ws: WebSocketServerProtocol, path: str) -> None:
if session.ptt_active:
flush_pending_segment(session)
session.ptt_active = False
session.vad_active = False
session.vad_silence_ms = 0
session.vad_speech_ms = 0
session.vad_preroll_buffer.clear()
session.vad_segment_buffer.clear()
if state.devices.get(device_id) is session:
del state.devices[device_id]
state.publish_status(device_id, {"type": "disconnected", "ts": time.time(), "device_id": device_id})