Files
EVS-Embedded-Voice-System/bridge/app.py
Kai 5f20b38088
Some checks failed
Build and Push EVS Bridge Image / docker (push) Has been cancelled
Simplify client modes and add VAD retention policy
2026-02-13 17:09:54 +01:00

549 lines
18 KiB
Python

import asyncio
import base64
import json
import logging
import os
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional
import aiohttp
import paho.mqtt.client as mqtt
import websockets
from websockets.server import WebSocketServerProtocol
# Root logging configuration. LOG_LEVEL is normalised (whitespace stripped,
# upper-cased) because logging only accepts upper-case level names: a value such
# as "info" would otherwise make basicConfig raise ValueError at import time.
# An unset or blank LOG_LEVEL falls back to INFO.
logging.basicConfig(
    level=(os.getenv("LOG_LEVEL", "").strip().upper() or "INFO"),
    format="%(asctime)s %(levelname)s %(message)s",
)
# Module logger used throughout the bridge.
log = logging.getLogger("evs-bridge")
def getenv_bool(name: str, default: bool) -> bool:
    """Interpret environment variable *name* as a boolean flag.

    Returns *default* when the variable is unset. Otherwise the value is true
    only for "1", "true", "yes" or "on" (case-insensitive, surrounding
    whitespace ignored); any other value, including "", is false.
    """
    raw = os.getenv(name)
    if raw is not None:
        return raw.strip().lower() in ("1", "true", "yes", "on")
    return default
def _getenv_int(name: str, default: int) -> int:
    """Read environment variable *name* as an int, falling back to *default*."""
    return int(os.getenv(name, str(default)))


# --- WebSocket server ------------------------------------------------------
WS_HOST = os.getenv("WS_HOST", "0.0.0.0")
WS_PORT = _getenv_int("WS_PORT", 8765)
WS_PATH = os.getenv("WS_PATH", "/audio")
# Send every received binary frame straight back to the device.
ECHO_ENABLED = getenv_bool("ECHO_ENABLED", True)

# --- MQTT ------------------------------------------------------------------
MQTT_ENABLED = getenv_bool("MQTT_ENABLED", True)
MQTT_HOST = os.getenv("MQTT_HOST", "localhost")
MQTT_PORT = _getenv_int("MQTT_PORT", 1883)
MQTT_USER = os.getenv("MQTT_USER", "")
MQTT_PASSWORD = os.getenv("MQTT_PASSWORD", "")
MQTT_BASE_TOPIC = os.getenv("MQTT_BASE_TOPIC", "evs")
# Topic filter for audio to play on a device; the "+" level is the device id.
MQTT_TTS_TOPIC = os.getenv("MQTT_TTS_TOPIC", f"{MQTT_BASE_TOPIC}/+/play_pcm16le")
MQTT_STATUS_RETAIN = getenv_bool("MQTT_STATUS_RETAIN", True)

# --- Home Assistant --------------------------------------------------------
HA_WEBHOOK_URL = os.getenv("HA_WEBHOOK_URL", "").strip()

# --- PTT session recording (PCM16 mono) ------------------------------------
SAVE_SESSIONS = getenv_bool("SAVE_SESSIONS", True)
SESSIONS_DIR = Path(os.getenv("SESSIONS_DIR", "/data/sessions"))
PCM_SAMPLE_RATE = _getenv_int("PCM_SAMPLE_RATE", 16000)
# Cap for the in-memory PCM buffer used when SAVE_SESSIONS is off.
MAX_SESSION_BYTES = _getenv_int("MAX_SESSION_BYTES", 16000000)
# Upper bound on the size of one WAV part file, header included.
WAV_SEGMENT_MAX_BYTES = _getenv_int("WAV_SEGMENT_MAX_BYTES", 20 * 1024 * 1024)
WAV_KEEP_FILES = _getenv_int("WAV_KEEP_FILES", 10)
WAV_HEADER_BYTES = 44

# --- Voice activity detection (VAD) ----------------------------------------
VAD_ENABLED = getenv_bool("VAD_ENABLED", True)
VAD_DIR = Path(os.getenv("VAD_DIR", "/data/vad"))
VAD_KEEP_FILES = _getenv_int("VAD_KEEP_FILES", 200)
VAD_MAX_AGE_DAYS = _getenv_int("VAD_MAX_AGE_DAYS", 7)
VAD_PREROLL_MS = _getenv_int("VAD_PREROLL_MS", 1000)
VAD_POSTROLL_MS = _getenv_int("VAD_POSTROLL_MS", 1000)
# Mean-absolute-amplitude thresholds (see pcm_avg_abs) for opening/closing a segment.
VAD_START_THRESHOLD = _getenv_int("VAD_START_THRESHOLD", 900)
VAD_STOP_THRESHOLD = _getenv_int("VAD_STOP_THRESHOLD", 600)
VAD_MIN_SPEECH_MS = _getenv_int("VAD_MIN_SPEECH_MS", 300)
@dataclass
class DeviceSession:
    """State kept for one connected audio device (one WebSocket connection)."""

    # Identifier taken from the connection URL by parse_device_id().
    device_id: str
    # The device's WebSocket; also used to send pong/echo/playback audio back.
    ws: WebSocketServerProtocol
    # Wall-clock time the session object was created.
    connected_at: float = field(default_factory=time.time)
    # True between a "start" and a "stop" text message (push-to-talk).
    ptt_active: bool = False
    # In-memory PCM for the current PTT session, used only when SAVE_SESSIONS
    # is off; trimmed to the newest MAX_SESSION_BYTES.
    pcm_bytes: bytearray = field(default_factory=bytearray)
    # Time the last binary frame arrived.
    last_rx_ts: float = field(default_factory=time.time)
    # Bytes received while PTT was active; reset on "start".
    rx_bytes_total: int = 0
    # Number used in the next session WAV part's file name; reset on "start".
    segment_index: int = 0
    # PTT audio not yet written to a WAV part file.
    segment_pcm_buffer: bytearray = field(default_factory=bytearray)
    # Paths of WAV part files written during the current PTT session.
    saved_wavs: List[str] = field(default_factory=list)
    # True while the VAD considers a speech segment to be open.
    vad_active: bool = False
    # Trailing quiet time counted toward closing the open VAD segment.
    vad_silence_ms: int = 0
    # Milliseconds counted by process_vad_frame() for the open VAD segment.
    vad_speech_ms: int = 0
    # Rolling window of the most recent audio, prepended when a segment opens.
    vad_preroll_buffer: bytearray = field(default_factory=bytearray)
    # Audio of the currently open VAD segment.
    vad_segment_buffer: bytearray = field(default_factory=bytearray)
    # Number used in the next VAD WAV file name; not reset on "start".
    vad_segment_index: int = 0
class BridgeState:
    """Process-wide mutable state shared by the WebSocket handlers and MQTT callbacks."""

    def __init__(self) -> None:
        # Event loop running the WebSocket server. main() fills this in so the
        # MQTT callback can schedule coroutines with run_coroutine_threadsafe.
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        # Currently registered devices, keyed by device_id.
        self.devices: Dict[str, DeviceSession] = {}
        # MQTT client, or None when MQTT is disabled or not yet set up.
        self.mqtt_client: Optional[mqtt.Client] = None

    def publish_status(self, device_id: str, payload: dict) -> None:
        """Publish *payload* as JSON on ``<MQTT_BASE_TOPIC>/<device_id>/status``.

        Does nothing without an MQTT client. Serialisation or publish errors
        are logged and never raised.
        """
        client = self.mqtt_client
        if not client:
            return
        status_topic = f"{MQTT_BASE_TOPIC}/{device_id}/status"
        try:
            client.publish(
                status_topic,
                json.dumps(payload),
                qos=0,
                retain=MQTT_STATUS_RETAIN,
            )
        except Exception:
            log.exception("mqtt publish failed")

    async def send_binary_to_device(self, device_id: str, pcm_data: bytes) -> bool:
        """Send *pcm_data* as one binary frame to the device's WebSocket.

        Returns True on success. Returns False when the device is not
        registered or the send failed (the failure is logged).
        """
        target = self.devices.get(device_id)
        if not target:
            return False
        try:
            await target.ws.send(pcm_data)
        except Exception:
            log.exception("ws send to device failed")
            return False
        return True
# Module-wide singleton shared by the WebSocket handlers and the MQTT callbacks.
state = BridgeState()
def build_metrics(device_id: str, session: DeviceSession) -> dict:
    """Summarise a device session's received audio for status/webhook payloads.

    ``duration_s`` is derived from ``rx_bytes_total`` assuming 16-bit mono
    samples at PCM_SAMPLE_RATE, rounded to milliseconds.
    """
    sample_count = session.rx_bytes_total // 2
    return dict(
        device_id=device_id,
        ptt_active=session.ptt_active,
        rx_bytes=session.rx_bytes_total,
        duration_s=round(sample_count / float(PCM_SAMPLE_RATE), 3),
        last_rx_ts=session.last_rx_ts,
    )
async def call_ha_webhook(event: str, payload: dict) -> None:
    """POST ``{"event": event, **payload}`` as JSON to HA_WEBHOOK_URL.

    Does nothing when HA_WEBHOOK_URL is empty. This is best effort and never
    raises: HTTP error statuses and network failures are logged as warnings,
    and anything unexpected is logged with a traceback.

    Args:
        event: Event name placed under the "event" key.
        payload: Extra JSON-serialisable fields merged into the body.
    """
    if not HA_WEBHOOK_URL:
        return
    data = {"event": event, **payload}
    # aiohttp expects a ClientTimeout object; a bare number is only accepted
    # for backward compatibility.
    timeout = aiohttp.ClientTimeout(total=10)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as client:
            async with client.post(HA_WEBHOOK_URL, json=data) as resp:
                if resp.status >= 400:
                    log.warning("ha webhook error status=%s", resp.status)
    except (aiohttp.ClientError, asyncio.TimeoutError) as exc:
        # Expected when Home Assistant is down or slow: log briefly instead of
        # dumping a traceback for every start/stop/connect event.
        log.warning("ha webhook call failed event=%s: %r", event, exc)
    except Exception:
        log.exception("ha webhook call failed")
def enforce_wav_retention(directory: Path, keep_files: int, max_age_days: int = 0) -> None:
    """Prune ``*.wav`` files in *directory* according to a retention policy.

    Files older than *max_age_days* (by mtime; 0 disables) are deleted first.
    Then, if *keep_files* > 0, the oldest remaining files are deleted until at
    most *keep_files* are left. With both limits <= 0 nothing happens. The
    directory is created if missing. Errors are logged, never raised.
    """
    if keep_files <= 0 and max_age_days <= 0:
        return
    try:
        directory.mkdir(parents=True, exist_ok=True)
        started = time.time()
        age_limit_s = max_age_days * 86400
        survivors = []  # (path, mtime) of files that passed the age check
        for candidate in directory.glob("*.wav"):
            if not candidate.is_file():
                continue
            try:
                mtime = candidate.stat().st_mtime
            except Exception:
                # File vanished or is unreadable: skip it.
                continue
            too_old = age_limit_s > 0 and (started - mtime) > age_limit_s
            if not too_old:
                survivors.append((candidate, mtime))
                continue
            try:
                candidate.unlink()
                log.info("deleted old wav by age: %s", candidate)
            except Exception:
                log.exception("failed to delete old wav by age: %s", candidate)

        # Oldest first; the sort is stable, so ties keep glob order.
        survivors.sort(key=lambda entry: entry[1])
        if keep_files > 0:
            surplus = max(len(survivors) - keep_files, 0)
            for victim, _mtime in survivors[:surplus]:
                try:
                    victim.unlink()
                    log.info("deleted old wav: %s", victim)
                except Exception:
                    log.exception("failed to delete old wav: %s", victim)
    except Exception:
        log.exception("failed to enforce wav retention")
def write_wav_segment(session: DeviceSession, pcm: bytes) -> Optional[str]:
    """Write one PTT session part to SESSIONS_DIR as a mono 16-bit WAV file.

    The file is named ``<device_id>_<epoch_ms>_part<NNN>.wav``. On success the
    path is appended to ``session.saved_wavs``, the WAV_KEEP_FILES retention
    policy is applied to SESSIONS_DIR, and ``session.segment_index`` advances.

    Returns:
        The written path, or None when saving is disabled, *pcm* holds no
        whole sample, or the write failed (the failure is logged).
    """
    if not SAVE_SESSIONS or not pcm:
        return None
    # Drop a trailing partial sample: an odd-length data chunk would also need
    # a RIFF pad byte, which the wave module does not write.
    if len(pcm) % 2:
        pcm = pcm[:-1]
        if not pcm:
            return None
    import wave

    try:
        SESSIONS_DIR.mkdir(parents=True, exist_ok=True)
        ts_ms = int(time.time() * 1000)
        name = f"{session.device_id}_{ts_ms}_part{session.segment_index:03d}.wav"
        path = SESSIONS_DIR / name
        # The context manager closes the file (and finalises its header) even
        # if writing the frames raises.
        with wave.open(str(path), "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(PCM_SAMPLE_RATE)
            wf.writeframes(pcm)
        session.saved_wavs.append(str(path))
        enforce_wav_retention(SESSIONS_DIR, WAV_KEEP_FILES)
        session.segment_index += 1
        return str(path)
    except Exception:
        log.exception("failed to write wav segment for %s", session.device_id)
        return None
def write_vad_wav_segment(session: DeviceSession, pcm: bytes) -> Optional[str]:
    """Write one VAD speech segment to VAD_DIR as a mono 16-bit WAV file.

    The file is named ``<device_id>_<epoch_ms>_vad<NNNNN>.wav``. On success
    ``session.vad_segment_index`` advances, and the VAD_KEEP_FILES /
    VAD_MAX_AGE_DAYS retention policy is applied to VAD_DIR.

    Returns:
        The written path, or None when VAD is disabled, *pcm* holds no whole
        sample, or the write failed (the failure is logged).
    """
    if not VAD_ENABLED or not pcm:
        return None
    # Drop a trailing partial sample: an odd-length data chunk would also need
    # a RIFF pad byte, which the wave module does not write.
    if len(pcm) % 2:
        pcm = pcm[:-1]
        if not pcm:
            return None
    import wave

    try:
        VAD_DIR.mkdir(parents=True, exist_ok=True)
        ts_ms = int(time.time() * 1000)
        name = f"{session.device_id}_{ts_ms}_vad{session.vad_segment_index:05d}.wav"
        path = VAD_DIR / name
        # The context manager closes the file (and finalises its header) even
        # if writing the frames raises.
        with wave.open(str(path), "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(PCM_SAMPLE_RATE)
            wf.writeframes(pcm)
        session.vad_segment_index += 1
        enforce_wav_retention(VAD_DIR, VAD_KEEP_FILES, VAD_MAX_AGE_DAYS)
        return str(path)
    except Exception:
        log.exception("failed to write vad wav segment for %s", session.device_id)
        return None
def pcm_avg_abs(data: bytes) -> int:
    """Return the mean absolute sample value of a PCM16 mono buffer.

    A trailing odd byte (an incomplete sample) is ignored. Previously it made
    ``memoryview.cast`` raise TypeError, which escaped to the WebSocket handler.
    Buffers shorter than one sample yield 0. The result is floor-divided to an int.

    NOTE: the "h" format uses the host's native byte order, so this assumes a
    little-endian host for PCM16LE input.
    """
    usable = len(data) - (len(data) % 2)
    if usable < 2:
        return 0
    samples = memoryview(data)[:usable].cast("h")
    # map/sum run the per-sample loop in C instead of Python bytecode.
    return sum(map(abs, samples)) // len(samples)
def pcm_duration_ms(pcm_bytes: int) -> int:
    """Convert a PCM16 mono byte count at PCM_SAMPLE_RATE to whole milliseconds.

    The fractional part is truncated toward zero.
    """
    millis = pcm_bytes * 1000 / (2 * PCM_SAMPLE_RATE)
    return int(millis)
def process_vad_frame(device_id: str, session: DeviceSession, data: bytes) -> None:
    """Feed one PCM16 mono frame to the energy-based voice activity detector.

    A segment opens when the frame's mean absolute level reaches
    VAD_START_THRESHOLD. The rolling pre-roll window (VAD_PREROLL_MS) is
    prepended to it. The segment closes once VAD_POSTROLL_MS of frames below
    VAD_STOP_THRESHOLD have accumulated, or when the buffered segment reaches
    MAX_SESSION_BYTES. On close, a ``vad_segment`` status is published. The
    audio is written to a WAV file only if at least VAD_MIN_SPEECH_MS of
    voiced (>= stop threshold) frames were seen.
    """
    if not VAD_ENABLED or not data:
        return

    # Rolling pre-roll window of the newest audio. Rounded down to a whole
    # sample so trimming never leaves the buffer starting mid-sample.
    preroll_max_bytes = max(0, (PCM_SAMPLE_RATE * 2 * VAD_PREROLL_MS) // 1000)
    preroll_max_bytes -= preroll_max_bytes % 2
    if preroll_max_bytes > 0:
        preroll = session.vad_preroll_buffer
        preroll.extend(data)
        excess = len(preroll) - preroll_max_bytes
        if excess > 0:
            del preroll[:excess]

    level = pcm_avg_abs(data)
    frame_ms = pcm_duration_ms(len(data))
    if frame_ms <= 0:
        return

    if not session.vad_active:
        if level >= VAD_START_THRESHOLD:
            session.vad_active = True
            session.vad_silence_ms = 0
            session.vad_speech_ms = frame_ms
            session.vad_segment_buffer.clear()
            if session.vad_preroll_buffer:
                # The pre-roll already ends with the current frame.
                session.vad_segment_buffer.extend(session.vad_preroll_buffer)
            else:
                # Pre-roll disabled: keep the triggering frame itself.
                session.vad_segment_buffer.extend(data)
            log.info("vad_start: device=%s level=%s", device_id, level)
        return

    session.vad_segment_buffer.extend(data)
    if level < VAD_STOP_THRESHOLD:
        session.vad_silence_ms += frame_ms
    else:
        session.vad_silence_ms = 0
        # Count only voiced audio toward the minimum-speech requirement. The
        # total segment length always includes the post-roll silence (and
        # pre-roll), so it cannot serve as a speech filter.
        session.vad_speech_ms += frame_ms

    # Bound the segment so continuous noise above the stop threshold cannot
    # grow the buffer indefinitely.
    over_cap = MAX_SESSION_BYTES > 0 and len(session.vad_segment_buffer) >= MAX_SESSION_BYTES
    if session.vad_silence_ms < VAD_POSTROLL_MS and not over_cap:
        return

    segment_bytes = bytes(session.vad_segment_buffer)
    segment_ms = pcm_duration_ms(len(segment_bytes))
    path = None
    if session.vad_speech_ms >= VAD_MIN_SPEECH_MS:
        path = write_vad_wav_segment(session, segment_bytes)
    payload = {
        "type": "vad_segment",
        "ts": time.time(),
        "device_id": device_id,
        "duration_s": round(segment_ms / 1000.0, 3),
        "speech_s": round(session.vad_speech_ms / 1000.0, 3),
        "level": level,
    }
    if path:
        payload["wav_path"] = path
    state.publish_status(device_id, payload)
    log.info(
        "vad_stop: device=%s duration_s=%s wav=%s",
        device_id,
        payload["duration_s"],
        path or "-",
    )
    session.vad_active = False
    session.vad_silence_ms = 0
    session.vad_speech_ms = 0
    session.vad_segment_buffer.clear()
def append_pcm_with_rotation(session: DeviceSession, data: bytes) -> None:
    """Buffer PTT audio and write a WAV part whenever a full part has accumulated.

    Parts are sized so a finished file (WAV_HEADER_BYTES header + PCM) does not
    exceed WAV_SEGMENT_MAX_BYTES. Any remainder stays in
    ``session.segment_pcm_buffer`` until more audio arrives or
    flush_pending_segment() writes it. Does nothing when SAVE_SESSIONS is off,
    *data* is empty, or WAV_SEGMENT_MAX_BYTES is too small (warning logged).
    """
    if not SAVE_SESSIONS or not data:
        return
    # At least one whole 16-bit sample must fit after the header. The original
    # check (<= header size) let a 1-byte part through despite its own message.
    min_file_bytes = WAV_HEADER_BYTES + 2
    if WAV_SEGMENT_MAX_BYTES < min_file_bytes:
        log.warning("WAV_SEGMENT_MAX_BYTES too small, minimum is %s", min_file_bytes)
        return
    # Round down to whole samples. An odd part size would split a sample across
    # two files, leaving every later part shifted by one byte (i.e. noise).
    max_pcm_per_file = WAV_SEGMENT_MAX_BYTES - WAV_HEADER_BYTES
    max_pcm_per_file -= max_pcm_per_file % 2
    pending = session.segment_pcm_buffer
    pending.extend(data)
    while len(pending) >= max_pcm_per_file:
        chunk = bytes(pending[:max_pcm_per_file])
        del pending[:max_pcm_per_file]
        write_wav_segment(session, chunk)
def flush_pending_segment(session: DeviceSession) -> None:
    """Write any PTT audio still buffered for *session* as a final WAV part.

    Does nothing when SAVE_SESSIONS is off or nothing is buffered. The buffer
    is emptied before the write is attempted.
    """
    if SAVE_SESSIONS and session.segment_pcm_buffer:
        leftover = bytes(session.segment_pcm_buffer)
        session.segment_pcm_buffer.clear()
        write_wav_segment(session, leftover)
async def handle_text_message(device_id: str, session: DeviceSession, raw: str) -> None:
    """Handle one JSON control message received from a device.

    Recognised ``type`` values:
      * ``start`` - begin a PTT session: reset counters/buffers and report it.
      * ``stop``  - end the PTT session: flush audio, report metrics and WAVs.
      * ``ping``  - answered with a ``pong`` text frame.
      * ``mic_level`` - forwarded as an MQTT status message.
    Other messages are only logged. Frames that are not a JSON object are
    logged and ignored, so a malformed message cannot tear down the connection.
    """
    try:
        msg = json.loads(raw)
    except (ValueError, RecursionError):
        log.warning("invalid text frame from %s: %s", device_id, raw[:80])
        return
    if not isinstance(msg, dict):
        # Valid JSON that is not an object (e.g. `42` or `[]`) has no .get().
        # The AttributeError would otherwise propagate out of ws_handler's
        # message loop and close the connection.
        log.warning("invalid text frame from %s: %s", device_id, raw[:80])
        return

    msg_type = msg.get("type")

    if msg_type == "start":
        if session.ptt_active:
            # "start" while already recording: write out audio buffered for the
            # previous press instead of silently discarding it below.
            flush_pending_segment(session)
        session.ptt_active = True
        session.pcm_bytes.clear()
        session.rx_bytes_total = 0
        session.saved_wavs.clear()
        session.segment_index = 0
        session.segment_pcm_buffer.clear()
        session.vad_active = False
        session.vad_silence_ms = 0
        session.vad_speech_ms = 0
        session.vad_preroll_buffer.clear()
        session.vad_segment_buffer.clear()
        payload = {"type": "start", "ts": time.time(), "device_id": device_id}
        state.publish_status(device_id, payload)
        await call_ha_webhook("start", payload)
        log.info("start: device=%s", device_id)
        return

    if msg_type == "stop":
        session.ptt_active = False
        flush_pending_segment(session)
        session.vad_active = False
        session.vad_silence_ms = 0
        session.vad_speech_ms = 0
        session.vad_preroll_buffer.clear()
        session.vad_segment_buffer.clear()
        metrics = build_metrics(device_id, session)
        payload = {"type": "stop", "ts": time.time(), "device_id": device_id, **metrics}
        if session.saved_wavs:
            payload["wav_path"] = session.saved_wavs[-1]
            # Copy so the payload is not tied to the session's live list.
            payload["wav_paths"] = list(session.saved_wavs)
        state.publish_status(device_id, payload)
        await call_ha_webhook("stop", payload)
        log.info(
            "stop: device=%s bytes=%s duration_s=%s wav_count=%s last_wav=%s",
            device_id,
            metrics.get("rx_bytes", 0),
            metrics.get("duration_s", 0),
            len(session.saved_wavs),
            session.saved_wavs[-1] if session.saved_wavs else "-",
        )
        return

    if msg_type == "ping":
        await session.ws.send(json.dumps({"type": "pong", "ts": time.time()}))
        return

    if msg_type == "mic_level":
        payload = {
            "type": "mic_level",
            "ts": time.time(),
            "device_id": device_id,
            "peak": msg.get("peak", 0),
            "avg_abs": msg.get("avg_abs", 0),
            "samples": msg.get("samples", 0),
        }
        state.publish_status(device_id, payload)
        log.info(
            "mic_level: device=%s peak=%s avg_abs=%s samples=%s",
            device_id,
            payload["peak"],
            payload["avg_abs"],
            payload["samples"],
        )
        return

    log.info("text msg from %s: %s", device_id, msg)
async def handle_binary_message(device_id: str, session: DeviceSession, data: bytes) -> None:
    """Handle one binary (PCM16 mono) audio frame from a device.

    While PTT is active the frame is counted and either streamed into WAV
    parts (SAVE_SESSIONS) or kept in a capped in-memory buffer. Every frame is
    then passed to the VAD and, if ECHO_ENABLED, sent straight back.
    """
    session.last_rx_ts = time.time()
    if session.ptt_active:
        session.rx_bytes_total += len(data)
        if not SAVE_SESSIONS:
            kept = session.pcm_bytes
            kept.extend(data)
            # Keep only the newest MAX_SESSION_BYTES so memory stays bounded.
            overflow = len(kept) - MAX_SESSION_BYTES
            if overflow > 0:
                del kept[:overflow]
        else:
            append_pcm_with_rotation(session, data)
    # The VAD sees every inbound frame, whether or not PTT is active.
    process_vad_frame(device_id, session, data)
    if ECHO_ENABLED:
        await session.ws.send(data)
def parse_device_id(path: str) -> str:
    """Extract a safe device id from the WebSocket request path.

    Expected forms::

        /audio                          -> "esp32-unknown"
        /audio?device_id=esp32-kitchen  -> "esp32-kitchen"

    The id is client-controlled and is used to build MQTT topics
    (publish_status) and WAV file names (write_wav_segment). It is therefore
    sanitised: any character outside ``[A-Za-z0-9_.-]`` becomes "_". That
    rules out path separators (``../`` traversal) and MQTT wildcards/levels
    (``+``, ``#``, ``/``). An empty result, or one made only of dots, falls
    back to "esp32-unknown". With repeated ``device_id`` parameters the first
    one is used.
    """
    import re
    from urllib.parse import parse_qs, urlsplit

    default = "esp32-unknown"
    if "?" not in path:
        return default
    try:
        values = parse_qs(urlsplit(path).query).get("device_id")
    except ValueError:
        return default
    if not values:
        return default
    cleaned = re.sub(r"[^A-Za-z0-9_.-]", "_", values[0].strip())
    if not cleaned.strip("."):
        return default
    return cleaned
async def ws_handler(ws: WebSocketServerProtocol, path: Optional[str] = None) -> None:
    """Serve one device connection from connect to disconnect.

    Rejects paths not starting with WS_PATH (close code 1008). Otherwise the
    device is registered in ``state.devices``, reported as connected, and its
    frames are dispatched: binary to handle_binary_message, text to
    handle_text_message. On exit, pending PTT audio is flushed and VAD state
    is cleared. "disconnected" is reported only if this connection was still
    the registered one for its device id.

    Args:
        ws: The client connection.
        path: Request path. Legacy websockets handlers receive it as a second
            argument. When omitted (newer single-argument handler API) it is
            read from ``ws.request.path`` or ``ws.path``.
    """
    if path is None:
        request = getattr(ws, "request", None)
        path = getattr(request, "path", None) or getattr(ws, "path", None) or ""
    if not path.startswith(WS_PATH):
        await ws.close(code=1008, reason="Invalid path")
        return

    device_id = parse_device_id(path)
    session = DeviceSession(device_id=device_id, ws=ws)
    state.devices[device_id] = session
    # Everything after registration runs inside try, so the finally below
    # unregisters the device even if we are cancelled during the webhook.
    try:
        state.publish_status(device_id, {"type": "connected", "ts": time.time(), "device_id": device_id})
        await call_ha_webhook("connected", {"device_id": device_id, "ts": time.time()})
        log.info("device connected: %s", device_id)
        async for message in ws:
            if isinstance(message, bytes):
                await handle_binary_message(device_id, session, message)
            else:
                await handle_text_message(device_id, session, message)
    except websockets.ConnectionClosed:
        pass
    except Exception:
        # Log with the device id instead of letting the library log an
        # anonymous "connection handler failed".
        log.exception("ws handler failed for device %s", device_id)
    finally:
        if session.ptt_active:
            flush_pending_segment(session)
        session.ptt_active = False
        session.vad_active = False
        session.vad_silence_ms = 0
        session.vad_speech_ms = 0
        session.vad_preroll_buffer.clear()
        session.vad_segment_buffer.clear()
        if state.devices.get(device_id) is session:
            del state.devices[device_id]
            # Only the registered connection reports the device as gone. A
            # superseded duplicate closing later must not overwrite the
            # (retained) status of the live connection.
            state.publish_status(device_id, {"type": "disconnected", "ts": time.time(), "device_id": device_id})
            await call_ha_webhook("disconnected", {"device_id": device_id, "ts": time.time()})
        log.info("device disconnected: %s", device_id)
def on_mqtt_connect(client: mqtt.Client, _userdata, _flags, reason_code, _properties=None):
    """paho-mqtt on_connect callback: subscribe to MQTT_TTS_TOPIC on success.

    A non-zero *reason_code* is logged as an error and nothing is subscribed.
    """
    if reason_code == 0:
        log.info("mqtt connected")
        client.subscribe(MQTT_TTS_TOPIC, qos=0)
        return
    log.error("mqtt connect failed reason=%s", reason_code)
def on_mqtt_message(_client: mqtt.Client, _userdata, msg: mqtt.MQTTMessage):
    """Forward audio published for a device to that device's WebSocket.

    Expected topic: ``<MQTT_BASE_TOPIC>/<device_id>/<command>``, e.g.
    ``evs/kitchen/play_pcm16le``. The device id is the level right after the
    (possibly multi-level) base topic. Payload options:
      1) raw binary PCM16LE, or
      2) JSON ``{"pcm16le_b64": "..."}``.
    A payload that starts with "{" but does not parse as a JSON object is
    treated as raw PCM: raw audio begins with that byte about once in 256
    messages and used to be dropped.

    This runs on paho's network thread. The send is scheduled on the asyncio
    loop and waited for up to 2 s; on timeout the scheduled send is cancelled.
    Errors are logged, never raised.
    """
    import concurrent.futures

    try:
        parts = msg.topic.split("/")
        base_depth = len(MQTT_BASE_TOPIC.split("/"))
        # Need at least <base...>/<device_id>/<command>.
        if len(parts) < base_depth + 2:
            return
        device_id = parts[base_depth]

        payload = bytes(msg.payload)
        pcm = None
        if payload.startswith(b"{"):
            try:
                doc = json.loads(payload.decode("utf-8"))
            except ValueError:  # covers UnicodeDecodeError and JSONDecodeError
                doc = None
            if isinstance(doc, dict):
                b64 = doc.get("pcm16le_b64", "")
                if not b64:
                    return
                pcm = base64.b64decode(b64)
        if pcm is None:
            pcm = payload

        if not state.loop:
            return
        fut = asyncio.run_coroutine_threadsafe(state.send_binary_to_device(device_id, pcm), state.loop)
        try:
            delivered = fut.result(timeout=2)
        except concurrent.futures.TimeoutError:
            # Don't let a stuck send complete later, out of order with newer audio.
            fut.cancel()
            log.warning("tts pcm delivery to %s timed out", device_id)
            return
        if not delivered:
            log.warning("tts pcm for %s not delivered (device not connected or send failed)", device_id)
    except Exception:
        log.exception("mqtt message handling failed")
def setup_mqtt(loop: asyncio.AbstractEventLoop) -> Optional[mqtt.Client]:
    """Create the MQTT client and start its background network thread.

    Returns None when MQTT_ENABLED is off. The connection is made with
    ``connect_async``, so an unreachable broker at startup no longer raises
    and aborts the bridge. paho's network thread (``loop_start``) performs
    the connect and keeps retrying. Subscriptions happen in on_mqtt_connect.

    Args:
        loop: Currently unused; kept for interface compatibility. Callbacks
            reach the event loop through ``state.loop``.
    """
    if not MQTT_ENABLED:
        log.info("mqtt disabled")
        return None
    client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2, client_id="evs-bridge")
    if MQTT_USER:
        client.username_pw_set(MQTT_USER, MQTT_PASSWORD)
    client.on_connect = on_mqtt_connect
    client.on_message = on_mqtt_message
    client.connect_async(MQTT_HOST, MQTT_PORT, keepalive=30)
    client.loop_start()
    log.info("mqtt connecting to %s:%s", MQTT_HOST, MQTT_PORT)
    return client
async def main():
    """Start MQTT (if enabled) and the WebSocket server, then serve until cancelled.

    On exit the WebSocket server is closed and the MQTT client is shut down.
    This also happens when the server fails to start, so the MQTT network
    thread is not left running.
    """
    state.loop = asyncio.get_running_loop()
    state.mqtt_client = setup_mqtt(state.loop)
    ws_server = None
    try:
        ws_server = await websockets.serve(ws_handler, WS_HOST, WS_PORT, max_size=2**22)
        log.info("ws listening on ws://%s:%s%s", WS_HOST, WS_PORT, WS_PATH)
        await ws_server.wait_closed()
    finally:
        if ws_server is not None:
            ws_server.close()
        if state.mqtt_client:
            # Disconnect while the network loop is still running so the
            # DISCONNECT packet is actually sent, then stop the thread.
            state.mqtt_client.disconnect()
            state.mqtt_client.loop_stop()
if __name__ == "__main__":
    # Script entry point: run the bridge until interrupted.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C ends the bridge without printing a traceback.
        pass