From d4d4c7224be5edab45d7f26a92f41f99bb20d5df Mon Sep 17 00:00:00 2001 From: Kai Date: Fri, 13 Feb 2026 16:47:54 +0100 Subject: [PATCH] Add VAD segmentation and Docker ENV defaults --- bridge/.env.example | 9 +++ bridge/Dockerfile | 30 +++++++++ bridge/README.md | 19 ++++++ bridge/app.py | 149 ++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 201 insertions(+), 6 deletions(-) diff --git a/bridge/.env.example b/bridge/.env.example index 5bd68ca..7bc4534 100644 --- a/bridge/.env.example +++ b/bridge/.env.example @@ -23,3 +23,12 @@ PCM_SAMPLE_RATE=16000 MAX_SESSION_BYTES=16000000 WAV_SEGMENT_MAX_BYTES=20971520 WAV_KEEP_FILES=10 + +VAD_ENABLED=true +VAD_DIR=/data/vad +VAD_KEEP_FILES=100 +VAD_PREROLL_MS=1000 +VAD_POSTROLL_MS=1000 +VAD_START_THRESHOLD=900 +VAD_STOP_THRESHOLD=600 +VAD_MIN_SPEECH_MS=300 diff --git a/bridge/Dockerfile b/bridge/Dockerfile index dfa0329..ca1dac6 100644 --- a/bridge/Dockerfile +++ b/bridge/Dockerfile @@ -7,4 +7,34 @@ RUN pip install --no-cache-dir -r requirements.txt COPY app.py . +# Default runtime configuration. Can be overridden with docker run -e or Portainer stack env. 
+ENV WS_HOST=0.0.0.0 \ + WS_PORT=8765 \ + WS_PATH=/audio \ + ECHO_ENABLED=true \ + LOG_LEVEL=INFO \ + MQTT_ENABLED=true \ + MQTT_HOST=localhost \ + MQTT_PORT=1883 \ + MQTT_USER= \ + MQTT_PASSWORD= \ + MQTT_BASE_TOPIC=evs \ + MQTT_TTS_TOPIC=evs/+/play_pcm16le \ + MQTT_STATUS_RETAIN=true \ + HA_WEBHOOK_URL= \ + SAVE_SESSIONS=true \ + SESSIONS_DIR=/data/sessions \ + PCM_SAMPLE_RATE=16000 \ + MAX_SESSION_BYTES=16000000 \ + WAV_SEGMENT_MAX_BYTES=20971520 \ + WAV_KEEP_FILES=10 \ + VAD_ENABLED=true \ + VAD_DIR=/data/vad \ + VAD_KEEP_FILES=100 \ + VAD_PREROLL_MS=1000 \ + VAD_POSTROLL_MS=1000 \ + VAD_START_THRESHOLD=900 \ + VAD_STOP_THRESHOLD=600 \ + VAD_MIN_SPEECH_MS=300 + CMD ["python", "app.py"] diff --git a/bridge/README.md b/bridge/README.md index c094c14..b469d7f 100644 --- a/bridge/README.md +++ b/bridge/README.md @@ -7,9 +7,12 @@ It provides: - MQTT status/events (`evs//status`) - MQTT playback input (`evs//play_pcm16le`) - Optional Home Assistant webhook callbacks (`connected`, `start`, `stop`, `disconnected`) +- VAD auto-segmentation (`vad_segment`) with pre-roll/post-roll ## 1) Start the bridge +The image already contains sane default `ENV` values. A custom `.env` is optional. + 1. Copy env template: ```bash cp .env.example .env @@ -58,6 +61,7 @@ Then upload firmware. 
- Status/events published by bridge: - `evs//status` (JSON) + - includes `type: "vad_segment"` when a speech segment is finalized - Playback input to device: - `evs//play_pcm16le` - payload options: @@ -86,6 +90,13 @@ You can build automations on these events (for STT/TTS pipelines or Node-RED han - `WAV_SEGMENT_MAX_BYTES` max size per `.wav` file (default: `20971520` = 20 MB) - `WAV_KEEP_FILES` max number of `.wav` files to keep (default: `10`) - `MAX_SESSION_BYTES` is only used if session file saving is disabled +- Voice activity detection (VAD): + - `VAD_ENABLED=true` enables automatic speech segment detection + - `VAD_PREROLL_MS=1000` keeps 1s before speech start + - `VAD_POSTROLL_MS=1000` keeps 1s after speech end + - `VAD_START_THRESHOLD` / `VAD_STOP_THRESHOLD` tune sensitivity + - `VAD_DIR` stores per-utterance WAV files + - `VAD_KEEP_FILES` limits stored VAD WAV files - MQTT is recommended for control/events, WebSocket for streaming audio ## 7) Build and push to Gitea registry @@ -135,6 +146,14 @@ services: PCM_SAMPLE_RATE: "16000" WAV_SEGMENT_MAX_BYTES: "20971520" WAV_KEEP_FILES: "10" + VAD_ENABLED: "true" + VAD_DIR: "/data/vad" + VAD_KEEP_FILES: "100" + VAD_PREROLL_MS: "1000" + VAD_POSTROLL_MS: "1000" + VAD_START_THRESHOLD: "900" + VAD_STOP_THRESHOLD: "600" + VAD_MIN_SPEECH_MS: "300" volumes: - evs_bridge_data:/data diff --git a/bridge/app.py b/bridge/app.py index 1461c60..bc5bc01 100644 --- a/bridge/app.py +++ b/bridge/app.py @@ -51,6 +51,15 @@ WAV_SEGMENT_MAX_BYTES = int(os.getenv("WAV_SEGMENT_MAX_BYTES", str(20 * 1024 * 1 WAV_KEEP_FILES = int(os.getenv("WAV_KEEP_FILES", "10")) WAV_HEADER_BYTES = 44 +VAD_ENABLED = getenv_bool("VAD_ENABLED", True) +VAD_DIR = Path(os.getenv("VAD_DIR", "/data/vad")) +VAD_KEEP_FILES = int(os.getenv("VAD_KEEP_FILES", "100")) +VAD_PREROLL_MS = int(os.getenv("VAD_PREROLL_MS", "1000")) +VAD_POSTROLL_MS = int(os.getenv("VAD_POSTROLL_MS", "1000")) +VAD_START_THRESHOLD = int(os.getenv("VAD_START_THRESHOLD", "900")) 
def write_vad_wav_segment(session: DeviceSession, pcm: bytes) -> Optional[str]:
    """Write one VAD utterance to ``VAD_DIR`` as a mono 16-bit WAV file.

    Args:
        session: Device session. ``device_id`` and ``vad_segment_index`` are
            used to build the file name; the index is incremented after a
            successful write.
        pcm: Raw PCM payload of the utterance, stored at ``PCM_SAMPLE_RATE``.

    Returns:
        The written file's path as a string, or ``None`` when VAD is disabled,
        ``pcm`` is empty, or writing failed (the failure is logged).
    """
    if not VAD_ENABLED or not pcm:
        return None
    try:
        import wave

        VAD_DIR.mkdir(parents=True, exist_ok=True)
        ts_ms = int(time.time() * 1000)
        # NOTE(review): device_id is interpolated into the file name unchanged;
        # confirm callers restrict it to path-safe characters.
        name = f"{session.device_id}_{ts_ms}_vad{session.vad_segment_index:05d}.wav"
        path = VAD_DIR / name

        # Context manager guarantees the handle is closed (and the header
        # finalized) even if one of the writes raises.
        with wave.open(str(path), "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(PCM_SAMPLE_RATE)
            wf.writeframes(pcm)

        session.vad_segment_index += 1
        enforce_wav_retention(VAD_DIR, VAD_KEEP_FILES)
        return str(path)
    except Exception:
        log.exception("failed to write vad wav segment for %s", session.device_id)
        return None


def pcm_avg_abs(data: bytes) -> int:
    """Return the mean absolute sample value of a PCM16LE mono buffer.

    A trailing odd byte (incomplete sample) is ignored instead of raising, and
    samples are always decoded as little-endian regardless of host byte order.
    Returns 0 when the buffer holds no complete sample.
    """
    import struct

    sample_count = len(data) // 2
    if sample_count == 0:
        return 0
    # Zero-copy view over the whole samples only; "<h" is an explicit
    # little-endian 2-byte signed integer.
    view = memoryview(data)[: sample_count * 2]
    samples = struct.unpack(f"<{sample_count}h", view)
    return sum(map(abs, samples)) // sample_count


def pcm_duration_ms(pcm_bytes: int) -> int:
    """Convert a byte count of 16-bit mono PCM into whole milliseconds.

    Returns 0 for a non-positive byte count or a non-positive
    ``PCM_SAMPLE_RATE`` rather than dividing by zero.
    """
    bytes_per_second = PCM_SAMPLE_RATE * 2
    if bytes_per_second <= 0 or pcm_bytes <= 0:
        return 0
    return (pcm_bytes * 1000) // bytes_per_second


def process_vad_frame(device_id: str, session: DeviceSession, data: bytes) -> None:
    """Feed one PCM frame through the energy-based VAD state machine.

    While idle, frames are kept in a rolling pre-roll window. A frame whose
    level reaches ``VAD_START_THRESHOLD`` opens a segment seeded with that
    window. While active, frames are appended until the level has stayed below
    ``VAD_STOP_THRESHOLD`` for ``VAD_POSTROLL_MS``; the segment is then written
    (if it contains at least ``VAD_MIN_SPEECH_MS`` of speech), a
    ``vad_segment`` status is published, and the per-segment state is reset.

    Args:
        device_id: Identifier used for logging and the status payload.
        session: Device session holding the ``vad_*`` state mutated here.
        data: One frame of PCM16LE mono audio.
    """
    if not VAD_ENABLED or not data:
        return

    # Rolling window of the most recent audio, capped at the pre-roll length.
    preroll_max_bytes = max(0, (PCM_SAMPLE_RATE * 2 * VAD_PREROLL_MS) // 1000)
    if preroll_max_bytes > 0:
        session.vad_preroll_buffer.extend(data)
        overflow = len(session.vad_preroll_buffer) - preroll_max_bytes
        if overflow > 0:
            del session.vad_preroll_buffer[:overflow]

    level = pcm_avg_abs(data)
    frame_ms = pcm_duration_ms(len(data))
    if frame_ms <= 0:
        return

    if not session.vad_active:
        if level >= VAD_START_THRESHOLD:
            session.vad_active = True
            session.vad_silence_ms = 0
            session.vad_speech_ms = frame_ms
            session.vad_segment_buffer.clear()
            # The pre-roll window already ends with this frame. When pre-roll
            # is disabled or shorter than one frame, fall back to the frame
            # itself so the audio that triggered the start is never lost.
            if len(session.vad_preroll_buffer) >= len(data):
                session.vad_segment_buffer.extend(session.vad_preroll_buffer)
            else:
                session.vad_segment_buffer.extend(data)
            log.info("vad_start: device=%s level=%s", device_id, level)
        return

    # NOTE(review): the segment buffer has no upper bound; input that never
    # drops below VAD_STOP_THRESHOLD grows it indefinitely. Consider a cap.
    session.vad_segment_buffer.extend(data)
    session.vad_speech_ms += frame_ms

    if level >= VAD_STOP_THRESHOLD:
        # Still speaking: a loud frame can never end the segment, even when
        # VAD_POSTROLL_MS is configured as 0.
        session.vad_silence_ms = 0
        return

    session.vad_silence_ms += frame_ms
    if session.vad_silence_ms < VAD_POSTROLL_MS:
        return

    segment_bytes = bytes(session.vad_segment_buffer)
    segment_ms = pcm_duration_ms(len(segment_bytes))
    # Time from the trigger frame to the last frame above the stop threshold.
    # Pre-roll and trailing silence are excluded so VAD_MIN_SPEECH_MS actually
    # filters short noise bursts instead of always passing.
    voiced_ms = max(0, session.vad_speech_ms - session.vad_silence_ms)
    path = None
    if voiced_ms >= VAD_MIN_SPEECH_MS:
        path = write_vad_wav_segment(session, segment_bytes)

    payload = {
        "type": "vad_segment",
        "ts": time.time(),
        "device_id": device_id,
        "duration_s": round(segment_ms / 1000.0, 3),
        "level": level,
    }
    if path:
        payload["wav_path"] = path
    state.publish_status(device_id, payload)
    log.info(
        "vad_stop: device=%s duration_s=%s wav=%s",
        device_id,
        payload["duration_s"],
        path or "-",
    )

    session.vad_active = False
    session.vad_silence_ms = 0
    session.vad_speech_ms = 0
    session.vad_segment_buffer.clear()
drop = len(session.pcm_bytes) - MAX_SESSION_BYTES del session.pcm_bytes[:drop] + process_vad_frame(device_id, session, data) if ECHO_ENABLED: await session.ws.send(data) @@ -313,6 +445,11 @@ async def ws_handler(ws: WebSocketServerProtocol, path: str) -> None: if session.ptt_active: flush_pending_segment(session) session.ptt_active = False + session.vad_active = False + session.vad_silence_ms = 0 + session.vad_speech_ms = 0 + session.vad_preroll_buffer.clear() + session.vad_segment_buffer.clear() if state.devices.get(device_id) is session: del state.devices[device_id] state.publish_status(device_id, {"type": "disconnected", "ts": time.time(), "device_id": device_id})