"""EVS WebSocket audio bridge.

Devices connect over a WebSocket at ``WS_PATH`` (optionally passing
``?device_id=...``) and stream raw mono PCM16LE audio as binary frames plus
JSON control messages as text frames. The bridge:

* records push-to-talk sessions (``start`` / ``stop`` messages) into
  size-rotated WAV files under ``SESSIONS_DIR``;
* runs a simple average-amplitude voice-activity detector over the stream and
  saves each detected segment as a WAV file under ``VAD_DIR``;
* optionally echoes every incoming audio frame to a paired output device
  (``DEVICE_PAIR_MAP``) or back to the sender;
* publishes status/events to MQTT and to an optional HTTP webhook
  (``HA_WEBHOOK_URL``);
* forwards PCM received on ``MQTT_TTS_TOPIC`` to the addressed device.

All configuration is read from environment variables at import time.
"""

import asyncio
import base64
import json
import logging
import os
import re
import sys
import time
import wave
from array import array
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional
from urllib.parse import parse_qs, urlsplit

import aiohttp
import paho.mqtt.client as mqtt
import websockets
from websockets.server import WebSocketServerProtocol

logging.basicConfig(
    # basicConfig only accepts upper-case level names; normalise "info" etc.
    level=os.getenv("LOG_LEVEL", "INFO").strip().upper(),
    format="%(asctime)s %(levelname)s %(message)s",
)
log = logging.getLogger("evs-bridge")


def getenv_bool(name: str, default: bool) -> bool:
    """Read a boolean environment variable.

    Returns *default* when the variable is unset; otherwise True iff the value
    (case-insensitive, surrounding whitespace ignored) is 1/true/yes/on.
    """
    val = os.getenv(name)
    if val is None:
        return default
    return val.strip().lower() in {"1", "true", "yes", "on"}


# --- WebSocket server -------------------------------------------------------
WS_HOST = os.getenv("WS_HOST", "0.0.0.0")
WS_PORT = int(os.getenv("WS_PORT", "8765"))
WS_PATH = os.getenv("WS_PATH", "/audio")
ECHO_ENABLED = getenv_bool("ECHO_ENABLED", True)

# --- MQTT -------------------------------------------------------------------
MQTT_ENABLED = getenv_bool("MQTT_ENABLED", True)
MQTT_HOST = os.getenv("MQTT_HOST", "localhost")
MQTT_PORT = int(os.getenv("MQTT_PORT", "1883"))
MQTT_USER = os.getenv("MQTT_USER", "")
MQTT_PASSWORD = os.getenv("MQTT_PASSWORD", "")
MQTT_BASE_TOPIC = os.getenv("MQTT_BASE_TOPIC", "evs")
MQTT_TTS_TOPIC = os.getenv("MQTT_TTS_TOPIC", f"{MQTT_BASE_TOPIC}/+/play_pcm16le")
MQTT_STATUS_RETAIN = getenv_bool("MQTT_STATUS_RETAIN", True)

DEVICE_PAIR_MAP_JSON = os.getenv("DEVICE_PAIR_MAP", "").strip()
HA_WEBHOOK_URL = os.getenv("HA_WEBHOOK_URL", "").strip()

# --- Session recording ------------------------------------------------------
SAVE_SESSIONS = getenv_bool("SAVE_SESSIONS", True)
SESSIONS_DIR = Path(os.getenv("SESSIONS_DIR", "/data/sessions"))
PCM_SAMPLE_RATE = int(os.getenv("PCM_SAMPLE_RATE", "16000"))
MAX_SESSION_BYTES = int(os.getenv("MAX_SESSION_BYTES", "16000000"))
WAV_SEGMENT_MAX_BYTES = int(os.getenv("WAV_SEGMENT_MAX_BYTES", str(20 * 1024 * 1024)))
WAV_KEEP_FILES = int(os.getenv("WAV_KEEP_FILES", "10"))
WAV_HEADER_BYTES = 44

# --- Voice activity detection ----------------------------------------------
VAD_ENABLED = getenv_bool("VAD_ENABLED", True)
VAD_DIR = Path(os.getenv("VAD_DIR", "/data/vad"))
VAD_KEEP_FILES = int(os.getenv("VAD_KEEP_FILES", "200"))
VAD_MAX_AGE_DAYS = int(os.getenv("VAD_MAX_AGE_DAYS", "7"))
VAD_PREROLL_MS = int(os.getenv("VAD_PREROLL_MS", "1000"))
VAD_POSTROLL_MS = int(os.getenv("VAD_POSTROLL_MS", "1000"))
VAD_START_THRESHOLD = int(os.getenv("VAD_START_THRESHOLD", "900"))
VAD_STOP_THRESHOLD = int(os.getenv("VAD_STOP_THRESHOLD", "600"))
VAD_MIN_SPEECH_MS = int(os.getenv("VAD_MIN_SPEECH_MS", "300"))

# Mono 16-bit PCM: two bytes per sample.
PCM_BYTES_PER_SAMPLE = 2

DEFAULT_DEVICE_ID = "esp32-unknown"
DEVICE_ID_MAX_LEN = 64
# Device ids come straight from the client's URL and end up in file names and
# MQTT topics, so anything outside this conservative set (notably "/", "\",
# "+", "#") is replaced.
_DEVICE_ID_UNSAFE_CHARS = re.compile(r"[^A-Za-z0-9_.:-]")


@dataclass
class DeviceSession:
    """Mutable per-connection state for one connected device."""

    device_id: str
    ws: WebSocketServerProtocol
    connected_at: float = field(default_factory=time.time)
    # Push-to-talk state, toggled by "start"/"stop" text messages.
    ptt_active: bool = False
    # In-memory capture, used only when SAVE_SESSIONS is off (capped at MAX_SESSION_BYTES).
    pcm_bytes: bytearray = field(default_factory=bytearray)
    last_rx_ts: float = field(default_factory=time.time)
    rx_bytes_total: int = 0
    # Index of the next WAV part file written for the current PTT session.
    segment_index: int = 0
    # PCM received but not yet written to a WAV part file.
    segment_pcm_buffer: bytearray = field(default_factory=bytearray)
    saved_wavs: List[str] = field(default_factory=list)
    # Voice-activity-detection state.
    vad_active: bool = False
    vad_silence_ms: int = 0
    vad_speech_ms: int = 0
    vad_preroll_buffer: bytearray = field(default_factory=bytearray)
    vad_segment_buffer: bytearray = field(default_factory=bytearray)
    vad_segment_index: int = 0


# Event types published (non-retained) on their own "<base>/<device>/<type>" topic;
# every other payload type goes to "<base>/<device>/status".
_EVENT_TOPIC_TYPES = frozenset({"mic_level", "vad_segment", "transcript", "stt_error"})


class BridgeState:
    """Process-wide shared state: event loop, connected devices and MQTT client."""

    def __init__(self) -> None:
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        self.devices: Dict[str, DeviceSession] = {}
        self.mqtt_client: Optional[mqtt.Client] = None

    def publish_status(self, device_id: str, payload: dict) -> None:
        """Publish *payload* as JSON to the device's MQTT topic (no-op without MQTT).

        Event types in ``_EVENT_TOPIC_TYPES`` go to their own non-retained topic;
        everything else goes to ``<base>/<device>/status`` with retain set by
        ``MQTT_STATUS_RETAIN``. Publish errors are logged, never raised.
        """
        if not self.mqtt_client:
            return
        msg_type = str(payload.get("type", "status"))
        if msg_type in _EVENT_TOPIC_TYPES:
            topic = f"{MQTT_BASE_TOPIC}/{device_id}/{msg_type}"
            retain = False
        else:
            topic = f"{MQTT_BASE_TOPIC}/{device_id}/status"
            retain = MQTT_STATUS_RETAIN
        try:
            self.mqtt_client.publish(topic, json.dumps(payload), qos=0, retain=retain)
        except Exception:
            log.exception("mqtt publish failed")

    async def send_binary_to_device(self, device_id: str, pcm_data: bytes) -> bool:
        """Send a binary frame to a connected device.

        Returns True on success, False if the device is not connected or the
        send failed (the failure is logged).
        """
        session = self.devices.get(device_id)
        if not session:
            return False
        try:
            await session.ws.send(pcm_data)
            return True
        except Exception:
            log.exception("ws send to device failed")
            return False


state = BridgeState()


def _load_device_pair_map(raw_json: str) -> Dict[str, str]:
    """Parse the DEVICE_PAIR_MAP JSON object (source id -> output id); {} on any problem."""
    if not raw_json:
        return {}
    try:
        raw = json.loads(raw_json)
    except ValueError:
        log.exception("failed to parse DEVICE_PAIR_MAP")
        return {}
    if not isinstance(raw, dict):
        log.warning("DEVICE_PAIR_MAP must be a JSON object")
        return {}
    pair_map = {str(k): str(v) for k, v in raw.items() if str(k) and str(v)}
    log.info("device pair map loaded: %s", pair_map)
    return pair_map


DEVICE_PAIR_MAP: Dict[str, str] = _load_device_pair_map(DEVICE_PAIR_MAP_JSON)


def paired_output_device(device_id: str) -> str:
    """Return the device that should receive *device_id*'s echoed audio (itself if unpaired)."""
    return DEVICE_PAIR_MAP.get(device_id, device_id)


def build_metrics(device_id: str, session: DeviceSession) -> dict:
    """Summarise the current PTT session: byte count and duration at PCM_SAMPLE_RATE."""
    samples = session.rx_bytes_total // PCM_BYTES_PER_SAMPLE
    seconds = samples / float(PCM_SAMPLE_RATE)
    return {
        "device_id": device_id,
        "ptt_active": session.ptt_active,
        "rx_bytes": session.rx_bytes_total,
        "duration_s": round(seconds, 3),
        "last_rx_ts": session.last_rx_ts,
    }


async def call_ha_webhook(event: str, payload: dict) -> None:
    """POST ``{"event": event, **payload}`` to HA_WEBHOOK_URL, if configured.

    Best-effort: HTTP error statuses are logged as warnings and any exception
    is logged; nothing is raised to the caller.
    """
    if not HA_WEBHOOK_URL:
        return
    data = {"event": event, **payload}
    try:
        timeout = aiohttp.ClientTimeout(total=10)
        async with aiohttp.ClientSession(timeout=timeout) as client:
            async with client.post(HA_WEBHOOK_URL, json=data) as resp:
                if resp.status >= 400:
                    log.warning("ha webhook error status=%s", resp.status)
    except Exception:
        log.exception("ha webhook call failed")


def enforce_wav_retention(directory: Path, keep_files: int, max_age_days: int = 0) -> None:
    """Prune ``*.wav`` files in *directory*.

    Files older than *max_age_days* (when > 0) are deleted first; then, if
    *keep_files* > 0, the oldest remaining files (by mtime) are deleted until
    at most *keep_files* remain. Errors are logged, never raised.
    """
    if keep_files <= 0 and max_age_days <= 0:
        return
    try:
        directory.mkdir(parents=True, exist_ok=True)
        now = time.time()
        max_age_seconds = max_age_days * 86400
        survivors = []
        for p in directory.glob("*.wav"):
            if not p.is_file():
                continue
            try:
                mtime = p.stat().st_mtime
            except OSError:
                continue
            if max_age_seconds > 0 and (now - mtime) > max_age_seconds:
                try:
                    p.unlink()
                    log.info("deleted old wav by age: %s", p)
                except OSError:
                    log.exception("failed to delete old wav by age: %s", p)
                continue
            survivors.append((p, mtime))

        if keep_files <= 0:
            return
        survivors.sort(key=lambda item: item[1])
        excess = len(survivors) - keep_files
        for oldest, _ in survivors[:max(0, excess)]:
            try:
                oldest.unlink()
                log.info("deleted old wav: %s", oldest)
            except OSError:
                log.exception("failed to delete old wav: %s", oldest)
    except Exception:
        log.exception("failed to enforce wav retention")


def _write_pcm_wav(path: Path, pcm: bytes) -> None:
    """Write mono 16-bit PCM at PCM_SAMPLE_RATE to *path*; the file is always closed."""
    with wave.open(str(path), "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(PCM_BYTES_PER_SAMPLE)
        wf.setframerate(PCM_SAMPLE_RATE)
        wf.writeframes(pcm)


def write_wav_segment(session: DeviceSession, pcm: bytes) -> Optional[str]:
    """Write one PTT session part to SESSIONS_DIR; return its path, or None if skipped/failed."""
    if not SAVE_SESSIONS or not pcm:
        return None
    try:
        SESSIONS_DIR.mkdir(parents=True, exist_ok=True)
        ts_ms = int(time.time() * 1000)
        name = f"{session.device_id}_{ts_ms}_part{session.segment_index:03d}.wav"
        path = SESSIONS_DIR / name
        _write_pcm_wav(path, pcm)
        session.saved_wavs.append(str(path))
        enforce_wav_retention(SESSIONS_DIR, WAV_KEEP_FILES)
        session.segment_index += 1
        return str(path)
    except Exception:
        log.exception("failed to write wav segment for %s", session.device_id)
        return None


def write_vad_wav_segment(session: DeviceSession, pcm: bytes) -> Optional[str]:
    """Write one VAD segment to VAD_DIR; return its path, or None if skipped/failed."""
    if not VAD_ENABLED or not pcm:
        return None
    try:
        VAD_DIR.mkdir(parents=True, exist_ok=True)
        ts_ms = int(time.time() * 1000)
        name = f"{session.device_id}_{ts_ms}_vad{session.vad_segment_index:05d}.wav"
        path = VAD_DIR / name
        _write_pcm_wav(path, pcm)
        session.vad_segment_index += 1
        enforce_wav_retention(VAD_DIR, VAD_KEEP_FILES, VAD_MAX_AGE_DAYS)
        return str(path)
    except Exception:
        log.exception("failed to write vad wav segment for %s", session.device_id)
        return None


def pcm_avg_abs(data: bytes) -> int:
    """Return the mean absolute sample value of PCM16LE mono *data* (floor division).

    A trailing odd byte (half a sample) is ignored instead of raising, so a
    malformed frame cannot abort the connection.
    """
    usable = len(data) - (len(data) % PCM_BYTES_PER_SAMPLE)
    if usable < PCM_BYTES_PER_SAMPLE:
        return 0
    samples = array("h")
    samples.frombytes(data[:usable])
    if sys.byteorder == "big":
        # array("h") uses native byte order; the wire format is little-endian.
        samples.byteswap()
    return sum(map(abs, samples)) // len(samples)


def pcm_duration_ms(pcm_bytes: int) -> int:
    """Return the duration in whole milliseconds of *pcm_bytes* bytes of mono PCM16."""
    bytes_per_second = PCM_SAMPLE_RATE * PCM_BYTES_PER_SAMPLE
    return int((pcm_bytes * 1000) / bytes_per_second)


def _reset_vad_state(session: DeviceSession, *, clear_preroll: bool) -> None:
    """Leave the VAD 'active' state and drop any partially collected segment."""
    session.vad_active = False
    session.vad_silence_ms = 0
    session.vad_speech_ms = 0
    session.vad_segment_buffer.clear()
    if clear_preroll:
        session.vad_preroll_buffer.clear()


def process_vad_frame(device_id: str, session: DeviceSession, data: bytes) -> None:
    """Feed one audio frame to the session's energy-threshold VAD.

    A segment starts when a frame's mean absolute level reaches
    VAD_START_THRESHOLD (the last VAD_PREROLL_MS of audio is prepended) and ends
    after VAD_POSTROLL_MS of consecutive frames below VAD_STOP_THRESHOLD. Each
    ended segment is published as a ``vad_segment`` event; it is written to a
    WAV file only if its speech portion lasted at least VAD_MIN_SPEECH_MS.
    """
    if not VAD_ENABLED or not data:
        return

    # Rolling window of the most recent audio, so a segment includes the
    # lead-in before the frame that crossed the start threshold.
    preroll_max_bytes = max(0, (PCM_SAMPLE_RATE * PCM_BYTES_PER_SAMPLE * VAD_PREROLL_MS) // 1000)
    if preroll_max_bytes > 0:
        session.vad_preroll_buffer.extend(data)
        overflow = len(session.vad_preroll_buffer) - preroll_max_bytes
        if overflow > 0:
            del session.vad_preroll_buffer[:overflow]

    level = pcm_avg_abs(data)
    frame_ms = pcm_duration_ms(len(data))
    if frame_ms <= 0:
        return

    if not session.vad_active:
        if level >= VAD_START_THRESHOLD:
            session.vad_active = True
            session.vad_silence_ms = 0
            session.vad_speech_ms = frame_ms
            session.vad_segment_buffer.clear()
            if session.vad_preroll_buffer:
                # The pre-roll buffer already ends with the current frame.
                session.vad_segment_buffer.extend(session.vad_preroll_buffer)
            else:
                # No pre-roll configured: still keep the frame that triggered the start.
                session.vad_segment_buffer.extend(data)
            log.info("vad_start: device=%s level=%s", device_id, level)
        return

    session.vad_segment_buffer.extend(data)
    session.vad_speech_ms += frame_ms
    if level < VAD_STOP_THRESHOLD:
        session.vad_silence_ms += frame_ms
    else:
        session.vad_silence_ms = 0

    if session.vad_silence_ms < VAD_POSTROLL_MS:
        return

    segment_bytes = bytes(session.vad_segment_buffer)
    segment_ms = pcm_duration_ms(len(segment_bytes))
    # Time from the start frame to the last loud frame. The whole segment
    # always includes >= VAD_POSTROLL_MS of trailing silence (plus pre-roll),
    # so comparing segment_ms would make VAD_MIN_SPEECH_MS ineffective.
    speech_ms = max(0, session.vad_speech_ms - session.vad_silence_ms)
    path = None
    if speech_ms >= VAD_MIN_SPEECH_MS:
        path = write_vad_wav_segment(session, segment_bytes)

    payload = {
        "type": "vad_segment",
        "ts": time.time(),
        "device_id": device_id,
        "duration_s": round(segment_ms / 1000.0, 3),
        "level": level,
    }
    if path:
        payload["wav_path"] = path
    state.publish_status(device_id, payload)
    log.info(
        "vad_stop: device=%s duration_s=%s wav=%s",
        device_id,
        payload["duration_s"],
        path or "-",
    )
    # Keep the pre-roll window: it is simply the most recent audio.
    _reset_vad_state(session, clear_preroll=False)


def append_pcm_with_rotation(session: DeviceSession, data: bytes) -> None:
    """Buffer PTT audio and write a WAV part whenever a full file's worth has accumulated.

    Each part holds at most WAV_SEGMENT_MAX_BYTES including the WAV header,
    rounded down to whole samples so no sample is split across two files.
    """
    if not SAVE_SESSIONS or not data:
        return
    min_total_bytes = WAV_HEADER_BYTES + PCM_BYTES_PER_SAMPLE
    if WAV_SEGMENT_MAX_BYTES < min_total_bytes:
        log.warning("WAV_SEGMENT_MAX_BYTES too small, minimum is %s", min_total_bytes)
        return

    pcm_room = WAV_SEGMENT_MAX_BYTES - WAV_HEADER_BYTES
    max_pcm_per_file = pcm_room - (pcm_room % PCM_BYTES_PER_SAMPLE)
    session.segment_pcm_buffer.extend(data)
    while len(session.segment_pcm_buffer) >= max_pcm_per_file:
        chunk = bytes(session.segment_pcm_buffer[:max_pcm_per_file])
        del session.segment_pcm_buffer[:max_pcm_per_file]
        write_wav_segment(session, chunk)


def flush_pending_segment(session: DeviceSession) -> None:
    """Write any buffered (not yet rotated) PTT audio as a final WAV part."""
    if not SAVE_SESSIONS:
        return
    if not session.segment_pcm_buffer:
        return
    chunk = bytes(session.segment_pcm_buffer)
    session.segment_pcm_buffer.clear()
    write_wav_segment(session, chunk)


async def handle_text_message(device_id: str, session: DeviceSession, raw: str) -> None:
    """Handle a JSON control message from a device.

    Recognised ``type`` values: ``start`` / ``stop`` (push-to-talk session
    boundaries), ``ping`` (answered with ``pong``) and ``mic_level``
    (forwarded to MQTT). Anything else is logged. Malformed frames are logged
    and ignored so they cannot drop the connection.
    """
    try:
        msg = json.loads(raw)
    except Exception:
        log.warning("invalid text frame from %s: %s", device_id, raw[:80])
        return
    if not isinstance(msg, dict):
        # Valid JSON but not an object (list, number, ...): msg.get would raise.
        log.warning("invalid text frame from %s: %s", device_id, raw[:80])
        return

    msg_type = msg.get("type")

    if msg_type == "start":
        session.ptt_active = True
        session.pcm_bytes.clear()
        session.rx_bytes_total = 0
        session.saved_wavs.clear()
        session.segment_index = 0
        session.segment_pcm_buffer.clear()
        _reset_vad_state(session, clear_preroll=True)
        payload = {"type": "start", "ts": time.time(), "device_id": device_id}
        state.publish_status(device_id, payload)
        await call_ha_webhook("start", payload)
        log.info("start: device=%s", device_id)
        return

    if msg_type == "stop":
        session.ptt_active = False
        flush_pending_segment(session)
        _reset_vad_state(session, clear_preroll=True)
        metrics = build_metrics(device_id, session)
        payload = {"type": "stop", "ts": time.time(), "device_id": device_id, **metrics}
        if session.saved_wavs:
            payload["wav_path"] = session.saved_wavs[-1]
            payload["wav_paths"] = session.saved_wavs
        state.publish_status(device_id, payload)
        await call_ha_webhook("stop", payload)
        log.info(
            "stop: device=%s bytes=%s duration_s=%s wav_count=%s last_wav=%s",
            device_id,
            metrics.get("rx_bytes", 0),
            metrics.get("duration_s", 0),
            len(session.saved_wavs),
            session.saved_wavs[-1] if session.saved_wavs else "-",
        )
        return

    if msg_type == "ping":
        await session.ws.send(json.dumps({"type": "pong", "ts": time.time()}))
        return

    if msg_type == "mic_level":
        payload = {
            "type": "mic_level",
            "ts": time.time(),
            "device_id": device_id,
            "peak": msg.get("peak", 0),
            "avg_abs": msg.get("avg_abs", 0),
            "samples": msg.get("samples", 0),
            "mic_gain": msg.get("mic_gain", 0),
        }
        state.publish_status(device_id, payload)
        log.info(
            "mic_level: device=%s peak=%s avg_abs=%s samples=%s mic_gain=%s",
            device_id,
            payload["peak"],
            payload["avg_abs"],
            payload["samples"],
            payload["mic_gain"],
        )
        return

    log.info("text msg from %s: %s", device_id, msg)


async def handle_binary_message(device_id: str, session: DeviceSession, data: bytes) -> None:
    """Handle one binary audio frame: VAD, PTT recording and optional echo."""
    session.last_rx_ts = time.time()

    # VAD should work on continuous stream even if explicit start/stop is missing.
    process_vad_frame(device_id, session, data)

    if session.ptt_active:
        session.rx_bytes_total += len(data)
        if SAVE_SESSIONS:
            append_pcm_with_rotation(session, data)
        else:
            session.pcm_bytes.extend(data)
            if len(session.pcm_bytes) > MAX_SESSION_BYTES:
                # Keep newest data within cap to avoid unbounded memory growth.
                drop = len(session.pcm_bytes) - MAX_SESSION_BYTES
                del session.pcm_bytes[:drop]

    if ECHO_ENABLED:
        target_device = paired_output_device(device_id)
        ok = await state.send_binary_to_device(target_device, data)
        if not ok and target_device != device_id:
            log.debug("paired output device not connected: src=%s target=%s", device_id, target_device)


def sanitize_device_id(raw: str) -> str:
    """Make a client-supplied device id safe for file names and MQTT topics.

    Characters outside ``[A-Za-z0-9_.:-]`` become ``_`` and the result is
    truncated to DEVICE_ID_MAX_LEN; an empty result yields DEFAULT_DEVICE_ID.
    """
    cleaned = _DEVICE_ID_UNSAFE_CHARS.sub("_", raw.strip())[:DEVICE_ID_MAX_LEN]
    return cleaned or DEFAULT_DEVICE_ID


def parse_device_id(path: str) -> str:
    """Extract the ``device_id`` query parameter from a request path.

    Expected forms: ``/audio`` or ``/audio?device_id=esp32-kitchen``. Falls back
    to DEFAULT_DEVICE_ID when absent or unparsable. The value is sanitised
    because it is used to build file paths and MQTT topics.
    """
    if "?" not in path:
        return DEFAULT_DEVICE_ID
    try:
        query = parse_qs(urlsplit(path).query)
    except (ValueError, TypeError):
        return DEFAULT_DEVICE_ID
    return sanitize_device_id(query.get("device_id", [DEFAULT_DEVICE_ID])[0])


def _request_path(ws) -> str:
    """Return the request path (with query) of a connection, for any websockets API."""
    request = getattr(ws, "request", None)  # newer ServerConnection API
    path = getattr(request, "path", None)
    if path is None:
        path = getattr(ws, "path", None)  # legacy WebSocketServerProtocol
    return path or ""


async def ws_handler(ws: WebSocketServerProtocol, path: Optional[str] = None) -> None:
    """Serve one device connection until it closes.

    *path* is optional: older websockets releases pass it, newer ones call the
    handler with the connection only, in which case it is read from *ws*.
    """
    if path is None:
        path = _request_path(ws)
    if not path.startswith(WS_PATH):
        await ws.close(code=1008, reason="Invalid path")
        return

    device_id = parse_device_id(path)
    session = DeviceSession(device_id=device_id, ws=ws)
    # A reconnect with the same id replaces the previous session.
    state.devices[device_id] = session
    state.publish_status(device_id, {"type": "connected", "ts": time.time(), "device_id": device_id})
    await call_ha_webhook("connected", {"device_id": device_id, "ts": time.time()})
    log.info("device connected: %s", device_id)

    try:
        async for message in ws:
            if isinstance(message, bytes):
                await handle_binary_message(device_id, session, message)
            else:
                await handle_text_message(device_id, session, message)
    except websockets.ConnectionClosed:
        pass
    finally:
        if session.ptt_active:
            flush_pending_segment(session)
            session.ptt_active = False
        _reset_vad_state(session, clear_preroll=True)

        if state.devices.get(device_id) is session:
            del state.devices[device_id]
            state.publish_status(device_id, {"type": "disconnected", "ts": time.time(), "device_id": device_id})
            await call_ha_webhook("disconnected", {"device_id": device_id, "ts": time.time()})
            log.info("device disconnected: %s", device_id)
        else:
            # A newer connection for this id is live; announcing (and retaining)
            # "disconnected" now would misreport its status.
            log.info("superseded connection closed: %s", device_id)


def on_mqtt_connect(client: mqtt.Client, _userdata, _flags, reason_code, _properties=None):
    """Subscribe to the playback topic once the broker accepts the connection."""
    if reason_code == 0:
        log.info("mqtt connected")
        client.subscribe(MQTT_TTS_TOPIC, qos=0)
    else:
        log.error("mqtt connect failed reason=%s", reason_code)


def _device_id_from_topic(topic: str) -> Optional[str]:
    """Return the device id segment of a playback topic, or None if the topic is too short.

    For topics under MQTT_BASE_TOPIC (which may itself contain "/"), this is the
    segment right after the base; otherwise the second segment is used.
    """
    parts = topic.split("/")
    if len(parts) < 3:
        return None
    base_parts = MQTT_BASE_TOPIC.split("/")
    if len(parts) > len(base_parts) and parts[: len(base_parts)] == base_parts:
        return parts[len(base_parts)]
    return parts[1]


def on_mqtt_message(_client: mqtt.Client, _userdata, msg: mqtt.MQTTMessage):
    """Forward PCM from an MQTT playback message to the addressed device.

    Topic (default): ``<MQTT_BASE_TOPIC>/<device_id>/play_pcm16le``.
    Payload options: raw PCM16LE bytes, or JSON ``{"pcm16le_b64": "..."}``.
    Runs on the paho network thread started by ``loop_start``, so the send is
    scheduled on the asyncio loop and waited on for at most 2 s.
    """
    try:
        device_id = _device_id_from_topic(msg.topic)
        if not device_id:
            return

        payload = msg.payload
        if payload.startswith(b"{"):
            doc = json.loads(payload.decode("utf-8"))
            b64 = doc.get("pcm16le_b64", "")
            if not b64:
                return
            pcm = base64.b64decode(b64)
        else:
            pcm = bytes(payload)

        if not pcm or not state.loop:
            return
        fut = asyncio.run_coroutine_threadsafe(state.send_binary_to_device(device_id, pcm), state.loop)
        fut.result(timeout=2)
    except Exception:
        log.exception("mqtt message handling failed")


def setup_mqtt(loop: asyncio.AbstractEventLoop) -> Optional[mqtt.Client]:
    """Create and start the MQTT client, or return None when MQTT is disabled.

    *loop* is not used here (kept for API compatibility; the running loop is
    read from ``state.loop``).
    """
    if not MQTT_ENABLED:
        log.info("mqtt disabled")
        return None
    client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2, client_id="evs-bridge")
    if MQTT_USER:
        client.username_pw_set(MQTT_USER, MQTT_PASSWORD)
    client.on_connect = on_mqtt_connect
    client.on_message = on_mqtt_message
    # connect_async + loop_start: the network thread performs the initial
    # connection and keeps retrying, so an unreachable broker at startup does
    # not abort the whole bridge (a blocking connect() would raise here).
    client.connect_async(MQTT_HOST, MQTT_PORT, keepalive=30)
    client.loop_start()
    log.info("mqtt connecting to %s:%s", MQTT_HOST, MQTT_PORT)
    return client


async def main():
    """Start MQTT and the WebSocket server, and run until the server closes."""
    state.loop = asyncio.get_running_loop()
    state.mqtt_client = setup_mqtt(state.loop)

    ws_server = await websockets.serve(ws_handler, WS_HOST, WS_PORT, max_size=2**22)
    log.info("ws listening on ws://%s:%s%s", WS_HOST, WS_PORT, WS_PATH)
    try:
        await ws_server.wait_closed()
    finally:
        if state.mqtt_client:
            # Disconnect while the network loop is still running so the
            # DISCONNECT packet is actually sent, then stop the loop thread.
            state.mqtt_client.disconnect()
            state.mqtt_client.loop_stop()


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        pass