# evs-bridge: device audio over WebSocket, bridged to MQTT topics and an optional Home Assistant webhook.
import array
import asyncio
import base64
import json
import logging
import os
import re
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional

import aiohttp
import paho.mqtt.client as mqtt
import websockets
from websockets.server import WebSocketServerProtocol
|
|
|
|
|
|
# Root logging configuration. LOG_LEVEL is matched case-insensitively
# ("debug", " Info ") because logging.basicConfig only accepts upper-case
# level names and would otherwise raise ValueError at import time.
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO").strip().upper(),
    format="%(asctime)s %(levelname)s %(message)s",
)

log = logging.getLogger("evs-bridge")
|
|
|
|
|
|
def getenv_bool(name: str, default: bool) -> bool:
    """Read environment variable *name* as a boolean flag.

    Returns *default* when the variable is unset. Otherwise the value is
    trimmed and lower-cased, and only "1", "true", "yes" or "on" count as True;
    any other value (including an empty string) is False.
    """
    raw_value = os.getenv(name)
    if raw_value is not None:
        return raw_value.strip().lower() in ("1", "true", "yes", "on")
    return default
|
|
|
|
|
|
# --- WebSocket server -------------------------------------------------------
# Bind address/port. Connections whose path does not start with WS_PATH are
# closed by ws_handler with code 1008.
WS_HOST = os.getenv("WS_HOST", "0.0.0.0")
WS_PORT = int(os.getenv("WS_PORT", "8765"))
WS_PATH = os.getenv("WS_PATH", "/audio")
# When true, every received audio frame is sent back out to the paired output
# device (the sender itself unless DEVICE_PAIR_MAP says otherwise).
ECHO_ENABLED = getenv_bool("ECHO_ENABLED", True)

# --- MQTT -------------------------------------------------------------------
MQTT_ENABLED = getenv_bool("MQTT_ENABLED", True)
MQTT_HOST = os.getenv("MQTT_HOST", "localhost")
MQTT_PORT = int(os.getenv("MQTT_PORT", "1883"))
MQTT_USER = os.getenv("MQTT_USER", "")
MQTT_PASSWORD = os.getenv("MQTT_PASSWORD", "")
# Prefix of the per-device topics: <base>/<device_id>/status, /mic_level, ...
MQTT_BASE_TOPIC = os.getenv("MQTT_BASE_TOPIC", "evs")
# Subscription for audio to play on a device (raw PCM16LE or JSON with a
# base64 "pcm16le_b64" field); "+" matches the device id level.
MQTT_TTS_TOPIC = os.getenv("MQTT_TTS_TOPIC", f"{MQTT_BASE_TOPIC}/+/play_pcm16le")
# Retain flag for messages published to the generic <base>/<device_id>/status topic.
MQTT_STATUS_RETAIN = getenv_bool("MQTT_STATUS_RETAIN", True)
# Optional JSON object {"source_device_id": "output_device_id"} used to route
# echoed audio to a different device.
DEVICE_PAIR_MAP_JSON = os.getenv("DEVICE_PAIR_MAP", "").strip()

# --- Home Assistant / session recording ------------------------------------
# Empty URL disables the webhook calls.
HA_WEBHOOK_URL = os.getenv("HA_WEBHOOK_URL", "").strip()
# When true, push-to-talk audio is written to WAV files under SESSIONS_DIR.
SAVE_SESSIONS = getenv_bool("SAVE_SESSIONS", True)
SESSIONS_DIR = Path(os.getenv("SESSIONS_DIR", "/data/sessions"))
# Sample rate of the incoming mono 16-bit PCM (2 bytes per sample).
PCM_SAMPLE_RATE = int(os.getenv("PCM_SAMPLE_RATE", "16000"))
# Cap on the in-memory PTT buffer used when SAVE_SESSIONS is off (newest bytes are kept).
MAX_SESSION_BYTES = int(os.getenv("MAX_SESSION_BYTES", "16000000"))
# Size limit of one session WAV file, header included; longer sessions rotate to a new file.
WAV_SEGMENT_MAX_BYTES = int(os.getenv("WAV_SEGMENT_MAX_BYTES", str(20 * 1024 * 1024)))
# Number of most recent session WAV files kept in SESSIONS_DIR.
WAV_KEEP_FILES = int(os.getenv("WAV_KEEP_FILES", "10"))
# Size of the canonical PCM WAV header, subtracted from WAV_SEGMENT_MAX_BYTES.
WAV_HEADER_BYTES = 44

# --- Voice activity detection ----------------------------------------------
VAD_ENABLED = getenv_bool("VAD_ENABLED", True)
VAD_DIR = Path(os.getenv("VAD_DIR", "/data/vad"))
# Retention for VAD WAV files: at most this many files, none older than this many days.
VAD_KEEP_FILES = int(os.getenv("VAD_KEEP_FILES", "200"))
VAD_MAX_AGE_DAYS = int(os.getenv("VAD_MAX_AGE_DAYS", "7"))
# Audio kept from before the trigger, and trailing silence needed to end a segment.
VAD_PREROLL_MS = int(os.getenv("VAD_PREROLL_MS", "1000"))
VAD_POSTROLL_MS = int(os.getenv("VAD_POSTROLL_MS", "1000"))
# Thresholds on the mean absolute sample amplitude of a frame (pcm_avg_abs):
# a segment starts at >= START and counts silence while below STOP.
VAD_START_THRESHOLD = int(os.getenv("VAD_START_THRESHOLD", "900"))
VAD_STOP_THRESHOLD = int(os.getenv("VAD_STOP_THRESHOLD", "600"))
# Minimum length (ms) a VAD segment needs before it is written to a WAV file.
VAD_MIN_SPEECH_MS = int(os.getenv("VAD_MIN_SPEECH_MS", "300"))
|
|
|
|
|
|
@dataclass
class DeviceSession:
    """Mutable per-connection state for one device WebSocket."""

    # Id taken from the connection's ``device_id`` query parameter.
    device_id: str
    # The device's WebSocket connection (used to send audio and pong replies).
    ws: WebSocketServerProtocol
    # Creation time of this session (time.time()).
    connected_at: float = field(default_factory=time.time)
    # True between the device's "start" and "stop" text messages.
    ptt_active: bool = False
    # PTT audio kept in memory when SAVE_SESSIONS is off (capped at MAX_SESSION_BYTES).
    pcm_bytes: bytearray = field(default_factory=bytearray)
    # time.time() of the most recent binary (audio) frame.
    last_rx_ts: float = field(default_factory=time.time)
    # PTT audio bytes received since the last "start".
    rx_bytes_total: int = 0
    # Index used in the name of the next session WAV file.
    segment_index: int = 0
    # PTT audio not yet written to a session WAV file.
    segment_pcm_buffer: bytearray = field(default_factory=bytearray)
    # Paths of the session WAV files written during the current PTT session.
    saved_wavs: List[str] = field(default_factory=list)
    # True while a VAD segment is open.
    vad_active: bool = False
    # Consecutive below-stop-threshold audio (ms) in the open VAD segment.
    vad_silence_ms: int = 0
    # Speech-time counter (ms) for the open segment, maintained by process_vad_frame.
    vad_speech_ms: int = 0
    # Rolling window of the most recent audio, copied into a segment when it starts.
    vad_preroll_buffer: bytearray = field(default_factory=bytearray)
    # Audio of the currently open VAD segment.
    vad_segment_buffer: bytearray = field(default_factory=bytearray)
    # Index used in the name of the next VAD WAV file.
    vad_segment_index: int = 0
|
|
|
|
|
|
class BridgeState:
    """Process-wide bridge state: event loop, connected device sessions and MQTT client."""

    # Message types published to their own per-device topic (never retained);
    # every other type goes to <base>/<device_id>/status.
    _EVENT_TOPIC_TYPES = frozenset({"mic_level", "vad_segment", "transcript", "stt_error"})

    def __init__(self) -> None:
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        self.devices: Dict[str, DeviceSession] = {}
        self.mqtt_client: Optional[mqtt.Client] = None

    def publish_status(self, device_id: str, payload: dict) -> None:
        """Publish *payload* as JSON on the device's MQTT topic chosen by ``payload["type"]``.

        Event types in ``_EVENT_TOPIC_TYPES`` go to ``<base>/<device_id>/<type>``
        with retain off. Everything else goes to ``<base>/<device_id>/status``
        with retain = MQTT_STATUS_RETAIN. No-op without an MQTT client.
        Publish errors are logged, not raised.
        """
        if not self.mqtt_client:
            return
        msg_type = str(payload.get("type", "status"))
        if msg_type in self._EVENT_TOPIC_TYPES:
            topic = f"{MQTT_BASE_TOPIC}/{device_id}/{msg_type}"
            retain = False
        else:
            topic = f"{MQTT_BASE_TOPIC}/{device_id}/status"
            retain = MQTT_STATUS_RETAIN
        try:
            self.mqtt_client.publish(topic, json.dumps(payload), qos=0, retain=retain)
        except Exception:
            log.exception("mqtt publish failed")

    async def send_binary_to_device(self, device_id: str, pcm_data: bytes) -> bool:
        """Send *pcm_data* as a binary WebSocket frame to *device_id*.

        Returns True on success. Returns False if the device is not connected
        or the send fails.
        """
        session = self.devices.get(device_id)
        if not session:
            return False
        try:
            await session.ws.send(pcm_data)
        except websockets.ConnectionClosed:
            # Expected when the device drops mid-stream; a traceback per audio
            # frame would only flood the log.
            log.debug("ws send skipped, connection closed: %s", device_id)
            return False
        except Exception:
            log.exception("ws send to device failed")
            return False
        return True
|
|
|
|
|
|
# Single shared bridge state used by the WebSocket handlers and MQTT callbacks.
state = BridgeState()

# Routing table {source_device_id: output_device_id}, parsed once at import
# from the DEVICE_PAIR_MAP environment variable (must be a JSON object).
DEVICE_PAIR_MAP: Dict[str, str] = {}

if DEVICE_PAIR_MAP_JSON:
    try:
        raw = json.loads(DEVICE_PAIR_MAP_JSON)
        if not isinstance(raw, dict):
            log.warning("DEVICE_PAIR_MAP must be a JSON object")
        else:
            # Stringify ids and skip entries whose key or value is empty.
            DEVICE_PAIR_MAP = dict(
                (src, dst)
                for src, dst in ((str(k), str(v)) for k, v in raw.items())
                if src and dst
            )
            log.info("device pair map loaded: %s", DEVICE_PAIR_MAP)
    except Exception:
        log.exception("failed to parse DEVICE_PAIR_MAP")
|
|
|
|
|
|
def paired_output_device(device_id: str) -> str:
    """Return the device that should play audio captured by *device_id*.

    This is the DEVICE_PAIR_MAP entry for *device_id* if there is one,
    otherwise *device_id* itself.
    """
    if device_id in DEVICE_PAIR_MAP:
        return DEVICE_PAIR_MAP[device_id]
    return device_id
|
|
|
|
|
|
def build_metrics(device_id: str, session: DeviceSession) -> dict:
    """Summarise the current PTT session of *device_id* for status payloads.

    ``duration_s`` treats the received bytes as 16-bit samples at
    PCM_SAMPLE_RATE and is rounded to milliseconds.
    """
    sample_count = session.rx_bytes_total // 2
    duration_s = sample_count / float(PCM_SAMPLE_RATE)
    return dict(
        device_id=device_id,
        ptt_active=session.ptt_active,
        rx_bytes=session.rx_bytes_total,
        duration_s=round(duration_s, 3),
        last_rx_ts=session.last_rx_ts,
    )
|
|
|
|
|
|
async def call_ha_webhook(event: str, payload: dict) -> None:
    """POST ``{"event": event, **payload}`` as JSON to HA_WEBHOOK_URL.

    No-op when HA_WEBHOOK_URL is empty. The call is best-effort with a 10 s
    total timeout: HTTP error statuses, timeouts and other failures are
    logged and never raised to the caller.
    """
    if not HA_WEBHOOK_URL:
        return
    data = {"event": event, **payload}
    # ClientTimeout is aiohttp's documented timeout type; a bare number is only
    # accepted for backward compatibility.
    timeout = aiohttp.ClientTimeout(total=10)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as client:
            async with client.post(HA_WEBHOOK_URL, json=data) as resp:
                if resp.status >= 400:
                    log.warning("ha webhook error status=%s", resp.status)
    except asyncio.TimeoutError:
        # An unreachable HA instance is an operational condition, not a bug: no traceback.
        log.warning("ha webhook timed out: event=%s", event)
    except Exception:
        log.exception("ha webhook call failed")
|
|
|
|
|
|
def enforce_wav_retention(directory: Path, keep_files: int, max_age_days: int = 0) -> None:
    """Prune ``*.wav`` files in *directory* by age, then by count.

    Files modified more than *max_age_days* days ago are deleted first
    (disabled when <= 0). Of the files that remain, only the *keep_files* most
    recently modified are kept (disabled when <= 0). The directory is created
    if missing. All failures are logged; nothing is raised.
    """
    if keep_files <= 0 and max_age_days <= 0:
        return

    try:
        directory.mkdir(parents=True, exist_ok=True)
        now = time.time()
        age_limit_s = max_age_days * 86400
        kept = []  # (path, mtime) of files that survived the age check

        for wav in directory.glob("*.wav"):
            if not wav.is_file():
                continue
            try:
                mtime = wav.stat().st_mtime
            except Exception:
                continue

            expired = age_limit_s > 0 and now - mtime > age_limit_s
            if not expired:
                kept.append((wav, mtime))
                continue
            try:
                wav.unlink()
                log.info("deleted old wav by age: %s", wav)
            except Exception:
                log.exception("failed to delete old wav by age: %s", wav)

        if keep_files > 0:
            # Oldest first; sort on mtime only so ties keep directory order.
            kept.sort(key=lambda entry: entry[1])
            surplus = len(kept) - keep_files
            for stale, _mtime in kept[:max(surplus, 0)]:
                try:
                    stale.unlink()
                    log.info("deleted old wav: %s", stale)
                except Exception:
                    log.exception("failed to delete old wav: %s", stale)
    except Exception:
        log.exception("failed to enforce wav retention")
|
|
|
|
|
|
def write_wav_segment(session: DeviceSession, pcm: bytes) -> Optional[str]:
    """Write one chunk of PTT audio to SESSIONS_DIR as a mono 16-bit WAV file.

    The file name is ``<device_id>_<epoch_ms>_part<NNN>.wav``. On success the
    path is appended to ``session.saved_wavs``, session-file retention is
    enforced, ``session.segment_index`` is advanced, and the path is returned.
    Returns None when saving is disabled, *pcm* is empty, or writing fails
    (the failure is logged).
    """
    if not SAVE_SESSIONS or not pcm:
        return None
    try:
        import wave

        SESSIONS_DIR.mkdir(parents=True, exist_ok=True)
        ts_ms = int(time.time() * 1000)
        path = SESSIONS_DIR / f"{session.device_id}_{ts_ms}_part{session.segment_index:03d}.wav"
        # The context manager closes the file even if writeframes() raises
        # (e.g. disk full); the previous open/close pair leaked the handle.
        with wave.open(str(path), "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(PCM_SAMPLE_RATE)
            wf.writeframes(pcm)
        session.saved_wavs.append(str(path))
        enforce_wav_retention(SESSIONS_DIR, WAV_KEEP_FILES)
        session.segment_index += 1
        return str(path)
    except Exception:
        log.exception("failed to write wav segment for %s", session.device_id)
        return None
|
|
|
|
|
|
def write_vad_wav_segment(session: DeviceSession, pcm: bytes) -> Optional[str]:
    """Write one VAD segment to VAD_DIR as a mono 16-bit WAV file.

    The file name is ``<device_id>_<epoch_ms>_vad<NNNNN>.wav``. On success
    ``session.vad_segment_index`` is advanced, VAD-file retention (count and
    age) is enforced, and the path is returned. Returns None when VAD is
    disabled, *pcm* is empty, or writing fails (the failure is logged).
    """
    if not VAD_ENABLED or not pcm:
        return None
    try:
        import wave

        VAD_DIR.mkdir(parents=True, exist_ok=True)
        ts_ms = int(time.time() * 1000)
        path = VAD_DIR / f"{session.device_id}_{ts_ms}_vad{session.vad_segment_index:05d}.wav"
        # The context manager closes the file even if writeframes() raises.
        with wave.open(str(path), "wb") as wf:
            wf.setnchannels(1)
            wf.setsampwidth(2)
            wf.setframerate(PCM_SAMPLE_RATE)
            wf.writeframes(pcm)
        session.vad_segment_index += 1
        enforce_wav_retention(VAD_DIR, VAD_KEEP_FILES, VAD_MAX_AGE_DAYS)
        return str(path)
    except Exception:
        log.exception("failed to write vad wav segment for %s", session.device_id)
        return None
|
|
|
|
|
|
def pcm_avg_abs(data: bytes) -> int:
    """Return the mean absolute amplitude (floored) of mono PCM16LE audio.

    A trailing odd byte (an incomplete sample) is ignored. Before this, it
    made ``memoryview.cast("h")`` raise TypeError, which aborted the device's
    connection handler. Input shorter than one sample yields 0.
    """
    usable = len(data) - (len(data) % 2)
    if usable < 2:
        return 0
    samples = array.array("h")
    samples.frombytes(memoryview(data)[:usable])
    if sys.byteorder == "big":
        # array uses native byte order; the stream is little-endian.
        samples.byteswap()
    # sum/map/abs iterate in C instead of a per-sample Python loop.
    return sum(map(abs, samples)) // len(samples)
|
|
|
|
|
|
def pcm_duration_ms(pcm_bytes: int, sample_rate: Optional[int] = None) -> int:
    """Return the whole milliseconds of mono 16-bit PCM in *pcm_bytes* bytes.

    *sample_rate* defaults to PCM_SAMPLE_RATE. Uses exact integer arithmetic
    (floor). Returns 0 for a non-positive rate instead of dividing by zero.
    """
    rate = PCM_SAMPLE_RATE if sample_rate is None else sample_rate
    bytes_per_second = rate * 2
    if bytes_per_second <= 0:
        return 0
    return (pcm_bytes * 1000) // bytes_per_second
|
|
|
|
|
|
def _vad_reset_segment(session: DeviceSession) -> None:
    """Return the session's VAD state machine to idle (the rolling pre-roll is kept)."""
    session.vad_active = False
    session.vad_silence_ms = 0
    session.vad_speech_ms = 0
    session.vad_segment_buffer.clear()


def _vad_close_segment(device_id: str, session: DeviceSession, level: int) -> None:
    """Finish the open VAD segment: save it if it has enough speech, publish it, reset state."""
    segment_bytes = bytes(session.vad_segment_buffer)
    segment_ms = pcm_duration_ms(len(segment_bytes))
    speech_ms = session.vad_speech_ms

    path = None
    # Gate on voiced time, not total length. The total always includes the
    # pre-roll and the post-roll silence, so it would pass any sane minimum.
    if speech_ms >= VAD_MIN_SPEECH_MS:
        path = write_vad_wav_segment(session, segment_bytes)

    payload = {
        "type": "vad_segment",
        "ts": time.time(),
        "device_id": device_id,
        "duration_s": round(segment_ms / 1000.0, 3),
        "speech_s": round(speech_ms / 1000.0, 3),
        "level": level,
    }
    if path:
        payload["wav_path"] = path
    state.publish_status(device_id, payload)
    log.info(
        "vad_stop: device=%s duration_s=%s speech_s=%s wav=%s",
        device_id,
        payload["duration_s"],
        payload["speech_s"],
        path or "-",
    )
    _vad_reset_segment(session)


def process_vad_frame(device_id: str, session: DeviceSession, data: bytes) -> None:
    """Feed one PCM16LE frame into the per-device amplitude-threshold VAD.

    - A segment opens when the frame's mean absolute amplitude reaches
      VAD_START_THRESHOLD; it starts with up to VAD_PREROLL_MS of preceding audio.
    - Frames at or above VAD_STOP_THRESHOLD count as speech and reset the
      silence counter.
    - The segment closes after VAD_POSTROLL_MS of consecutive quieter frames,
      or when it reaches MAX_SESSION_BYTES.
    - A closed segment is always published as a ``vad_segment`` event. It is
      saved to a WAV file only if it contains at least VAD_MIN_SPEECH_MS of speech.
    - Frames shorter than 1 ms only update the pre-roll.
    """
    if not VAD_ENABLED or not data:
        return

    # Rolling window of the most recent audio (this frame included).
    preroll_max_bytes = max(0, (PCM_SAMPLE_RATE * 2 * VAD_PREROLL_MS) // 1000)
    if preroll_max_bytes > 0:
        preroll = session.vad_preroll_buffer
        preroll.extend(data)
        excess = len(preroll) - preroll_max_bytes
        if excess > 0:
            del preroll[:excess]

    level = pcm_avg_abs(data)
    frame_ms = pcm_duration_ms(len(data))
    if frame_ms <= 0:
        return

    if not session.vad_active:
        if level >= VAD_START_THRESHOLD:
            session.vad_active = True
            session.vad_silence_ms = 0
            session.vad_speech_ms = frame_ms
            session.vad_segment_buffer.clear()
            if preroll_max_bytes >= len(data):
                # The pre-roll window ends with this whole frame.
                session.vad_segment_buffer.extend(session.vad_preroll_buffer)
            else:
                # Pre-roll disabled or shorter than the frame: the window would
                # miss (part of) the triggering audio, so start from the frame.
                session.vad_segment_buffer.extend(data)
            log.info("vad_start: device=%s level=%s", device_id, level)
        return

    session.vad_segment_buffer.extend(data)
    if level < VAD_STOP_THRESHOLD:
        session.vad_silence_ms += frame_ms
    else:
        session.vad_silence_ms = 0
        session.vad_speech_ms += frame_ms

    # Bound memory when the level never drops below the stop threshold
    # (e.g. constant background noise).
    segment_full = 0 < MAX_SESSION_BYTES <= len(session.vad_segment_buffer)
    silence_done = session.vad_silence_ms >= VAD_POSTROLL_MS
    if not silence_done and not segment_full:
        return
    if not silence_done:
        log.warning("vad segment reached MAX_SESSION_BYTES, closing it: device=%s", device_id)
    _vad_close_segment(device_id, session, level)
|
|
|
|
|
|
def append_pcm_with_rotation(session: DeviceSession, data: bytes) -> None:
    """Buffer PTT audio and write it out as WAV files of at most WAV_SEGMENT_MAX_BYTES.

    Whenever the buffer holds a full file's worth of PCM, that much is written
    via write_wav_segment. Any remainder stays buffered until the next call or
    flush_pending_segment.
    """
    if not SAVE_SESSIONS or not data:
        return

    # Largest whole-sample (even) PCM payload that fits next to the header.
    # Rounding down keeps 16-bit samples from being split across files, which
    # would byte-shift every sample of the following file.
    max_pcm_per_file = (WAV_SEGMENT_MAX_BYTES - WAV_HEADER_BYTES) // 2 * 2
    if max_pcm_per_file < 2:
        log.warning("WAV_SEGMENT_MAX_BYTES too small, minimum is %s", WAV_HEADER_BYTES + 2)
        return

    pending = session.segment_pcm_buffer
    pending.extend(data)
    while len(pending) >= max_pcm_per_file:
        chunk = bytes(pending[:max_pcm_per_file])
        del pending[:max_pcm_per_file]
        write_wav_segment(session, chunk)
|
|
|
|
|
|
def flush_pending_segment(session: DeviceSession) -> None:
    """Write any still-buffered PTT audio as a final (possibly short) WAV file.

    Does nothing when session saving is off or nothing is pending. The buffer
    is emptied before writing.
    """
    pending = session.segment_pcm_buffer
    if not (SAVE_SESSIONS and pending):
        return
    remainder = bytes(pending)
    pending.clear()
    write_wav_segment(session, remainder)
|
|
|
|
|
|
def _reset_vad_tracking(session: DeviceSession) -> None:
    """Return the session's VAD state machine to idle and drop all buffered VAD audio."""
    session.vad_active = False
    session.vad_silence_ms = 0
    session.vad_speech_ms = 0
    session.vad_preroll_buffer.clear()
    session.vad_segment_buffer.clear()


async def handle_text_message(device_id: str, session: DeviceSession, raw: str) -> None:
    """Handle one JSON control frame from a device.

    Recognised ``type`` values:
    - ``start``: begin a push-to-talk session (resets counters, buffers and VAD).
    - ``stop``: end it (flushes audio and reports metrics and WAV paths).
    - ``ping``: answered with ``pong``.
    - ``mic_level``: forwarded to MQTT.

    Other types are logged. Frames that are not a JSON object are logged and
    ignored, so one bad frame cannot tear down the connection.
    """
    try:
        msg = json.loads(raw)
    except (ValueError, RecursionError):
        log.warning("invalid text frame from %s: %s", device_id, raw[:80])
        return
    if not isinstance(msg, dict):
        # Valid JSON but not an object (e.g. 5 or []): msg.get would raise AttributeError.
        log.warning("invalid text frame from %s: %s", device_id, raw[:80])
        return

    msg_type = msg.get("type")

    if msg_type == "start":
        session.ptt_active = True
        session.pcm_bytes.clear()
        session.rx_bytes_total = 0
        session.saved_wavs.clear()
        session.segment_index = 0
        session.segment_pcm_buffer.clear()
        _reset_vad_tracking(session)
        payload = {"type": "start", "ts": time.time(), "device_id": device_id}
        state.publish_status(device_id, payload)
        await call_ha_webhook("start", payload)
        log.info("start: device=%s", device_id)
        return

    if msg_type == "stop":
        session.ptt_active = False
        flush_pending_segment(session)
        _reset_vad_tracking(session)
        metrics = build_metrics(device_id, session)
        payload = {"type": "stop", "ts": time.time(), "device_id": device_id, **metrics}
        if session.saved_wavs:
            payload["wav_path"] = session.saved_wavs[-1]
        # Snapshot of the paths: the session list is cleared in place by the
        # next "start". Always present, possibly empty.
        payload["wav_paths"] = list(session.saved_wavs)
        state.publish_status(device_id, payload)
        await call_ha_webhook("stop", payload)
        log.info(
            "stop: device=%s bytes=%s duration_s=%s wav_count=%s last_wav=%s",
            device_id,
            metrics.get("rx_bytes", 0),
            metrics.get("duration_s", 0),
            len(session.saved_wavs),
            session.saved_wavs[-1] if session.saved_wavs else "-",
        )
        return

    if msg_type == "ping":
        await session.ws.send(json.dumps({"type": "pong", "ts": time.time()}))
        return

    if msg_type == "mic_level":
        payload = {
            "type": "mic_level",
            "ts": time.time(),
            "device_id": device_id,
            "peak": msg.get("peak", 0),
            "avg_abs": msg.get("avg_abs", 0),
            "samples": msg.get("samples", 0),
            "mic_gain": msg.get("mic_gain", 0),
        }
        state.publish_status(device_id, payload)
        log.info(
            "mic_level: device=%s peak=%s avg_abs=%s samples=%s mic_gain=%s",
            device_id,
            payload["peak"],
            payload["avg_abs"],
            payload["samples"],
            payload["mic_gain"],
        )
        return

    log.info("text msg from %s: %s", device_id, msg)
|
|
|
|
|
|
async def handle_binary_message(device_id: str, session: DeviceSession, data: bytes) -> None:
    """Handle one binary (PCM16LE audio) frame from a device.

    - Every frame feeds the VAD.
    - While push-to-talk is active, the frame is counted and either rotated
      into WAV files (SAVE_SESSIONS) or kept in a capped in-memory buffer.
    - With ECHO_ENABLED, the frame is forwarded to the paired output device.
    """
    session.last_rx_ts = time.time()
    # VAD runs on the continuous stream, even if explicit start/stop is missing.
    process_vad_frame(device_id, session, data)

    if session.ptt_active:
        session.rx_bytes_total += len(data)
        if not SAVE_SESSIONS:
            buffered = session.pcm_bytes
            buffered.extend(data)
            # Keep only the newest MAX_SESSION_BYTES to bound memory.
            overflow = len(buffered) - MAX_SESSION_BYTES
            if overflow > 0:
                del buffered[:overflow]
        else:
            append_pcm_with_rotation(session, data)

    if not ECHO_ENABLED:
        return
    target = paired_output_device(device_id)
    delivered = await state.send_binary_to_device(target, data)
    if not delivered and target != device_id:
        log.debug("paired output device not connected: src=%s target=%s", device_id, target)
|
|
|
|
|
|
# Characters allowed in a device id; everything else becomes "_".
_DEVICE_ID_UNSAFE_CHARS = re.compile(r"[^A-Za-z0-9._:-]")


def parse_device_id(path: str) -> str:
    """Extract the device id from the WebSocket request path.

    Expected forms::

        /audio                          -> "esp32-unknown"
        /audio?device_id=esp32-kitchen  -> "esp32-kitchen"

    The id is client-controlled and is embedded in WAV file names and MQTT
    topics, so every character outside ``[A-Za-z0-9._:-]`` is replaced by
    ``_``. This blocks path traversal (``..%2F``) and MQTT level separators
    or wildcards (``/``, ``+``, ``#``).
    """
    fallback = "esp32-unknown"
    if "?" not in path:
        return fallback
    try:
        from urllib.parse import parse_qs, urlsplit

        device_id = parse_qs(urlsplit(path).query).get("device_id", [fallback])[0]
    except Exception:
        return fallback
    return _DEVICE_ID_UNSAFE_CHARS.sub("_", device_id)
|
|
|
|
|
|
async def ws_handler(ws: WebSocketServerProtocol, path: Optional[str] = None) -> None:
    """Serve one device WebSocket connection for its whole lifetime.

    Binary frames are PCM audio; text frames are JSON control messages.
    Connect and disconnect are reported through MQTT status and the HA webhook.

    *path* is optional because newer websockets releases call the handler
    with the connection only; the request path is then read from the
    connection object.
    """
    if path is None:
        request = getattr(ws, "request", None)  # newer websockets asyncio API
        path = getattr(request, "path", None) or getattr(ws, "path", None) or ""
    if not path.startswith(WS_PATH):
        await ws.close(code=1008, reason="Invalid path")
        return

    device_id = parse_device_id(path)
    session = DeviceSession(device_id=device_id, ws=ws)
    if device_id in state.devices:
        log.info("device %s reconnected; replacing previous session", device_id)
    state.devices[device_id] = session
    state.publish_status(device_id, {"type": "connected", "ts": time.time(), "device_id": device_id})
    await call_ha_webhook("connected", {"device_id": device_id, "ts": time.time()})
    log.info("device connected: %s", device_id)

    try:
        async for message in ws:
            if isinstance(message, bytes):
                await handle_binary_message(device_id, session, message)
            else:
                await handle_text_message(device_id, session, message)
    except websockets.ConnectionClosed:
        pass
    finally:
        if session.ptt_active:
            flush_pending_segment(session)
        session.ptt_active = False
        session.vad_active = False
        session.vad_silence_ms = 0
        session.vad_speech_ms = 0
        session.vad_preroll_buffer.clear()
        session.vad_segment_buffer.clear()
        if state.devices.get(device_id) is session:
            del state.devices[device_id]
            state.publish_status(device_id, {"type": "disconnected", "ts": time.time(), "device_id": device_id})
            await call_ha_webhook("disconnected", {"device_id": device_id, "ts": time.time()})
            log.info("device disconnected: %s", device_id)
        else:
            # A newer connection for this device id has taken over. Reporting
            # "disconnected" now would overwrite its (retained) "connected" status.
            log.info("stale connection closed, newer session active: %s", device_id)
|
|
|
|
|
|
def on_mqtt_connect(client: mqtt.Client, _userdata, _flags, reason_code, _properties=None):
    """paho-mqtt connect callback (callback API v2).

    On success, subscribe to MQTT_TTS_TOPIC. paho invokes this callback on
    every (re)connect, so the subscription is re-established each time.
    On failure, log the reason code.
    """
    connected = reason_code == 0
    if not connected:
        log.error("mqtt connect failed reason=%s", reason_code)
        return
    log.info("mqtt connected")
    client.subscribe(MQTT_TTS_TOPIC, qos=0)
|
|
|
|
|
|
def _tts_topic_device_id(topic: str) -> str:
    """Return the device id from a topic like ``<MQTT_BASE_TOPIC>/<device_id>/play_pcm16le``.

    The base-topic prefix is stripped first, so a base topic containing "/"
    (e.g. "home/evs") still yields the right id. Returns "" if no level
    follows the id.
    """
    prefix = f"{MQTT_BASE_TOPIC}/"
    if topic.startswith(prefix):
        device_id, sep, _rest = topic[len(prefix):].partition("/")
        return device_id if sep else ""
    # Custom MQTT_TTS_TOPIC outside the base topic: keep the "second level" rule.
    parts = topic.split("/")
    return parts[1] if len(parts) >= 3 else ""


def _tts_payload_to_pcm(payload: bytes) -> bytes:
    """Decode a playback payload: raw PCM16LE, or JSON ``{"pcm16le_b64": "..."}``.

    Returns b"" when the JSON form carries no audio.
    """
    if payload.startswith(b"{"):
        try:
            doc = json.loads(payload.decode("utf-8"))
        except ValueError:
            # Not JSON after all: raw PCM whose first byte happens to be 0x7B ("{").
            doc = None
        if isinstance(doc, dict):
            b64 = doc.get("pcm16le_b64", "")
            return base64.b64decode(b64) if b64 else b""
    return bytes(payload)


def on_mqtt_message(_client: mqtt.Client, _userdata, msg: mqtt.MQTTMessage):
    """paho-mqtt message callback: play the received audio on the addressed device.

    Runs on paho's network thread. The WebSocket send is scheduled on the
    asyncio loop in ``state.loop``, and the thread waits up to 2 s for the
    result. All errors are logged.
    """
    try:
        device_id = _tts_topic_device_id(msg.topic)
        if not device_id:
            return
        pcm = _tts_payload_to_pcm(msg.payload)
        if not pcm:
            return  # nothing to play
        if not state.loop:
            return
        fut = asyncio.run_coroutine_threadsafe(state.send_binary_to_device(device_id, pcm), state.loop)
        if not fut.result(timeout=2):
            log.warning("tts playback not delivered: device=%s", device_id)
    except Exception:
        log.exception("mqtt message handling failed")
|
|
|
|
|
|
def setup_mqtt(loop: asyncio.AbstractEventLoop) -> Optional[mqtt.Client]:
    """Create and start the MQTT client, or return None when MQTT is disabled.

    *loop* is not used by this function. Connection happens on paho's
    background network thread: ``connect_async`` + ``loop_start`` make that
    thread perform the first connect and keep retrying while the broker is
    unreachable. A blocking ``connect()`` would instead raise here and abort
    startup of the WebSocket side too.
    """
    if not MQTT_ENABLED:
        log.info("mqtt disabled")
        return None

    client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2, client_id="evs-bridge")
    if MQTT_USER:
        client.username_pw_set(MQTT_USER, MQTT_PASSWORD)
    client.on_connect = on_mqtt_connect
    client.on_message = on_mqtt_message
    client.connect_async(MQTT_HOST, MQTT_PORT, keepalive=30)
    client.loop_start()
    log.info("mqtt connecting to %s:%s", MQTT_HOST, MQTT_PORT)
    return client
|
|
|
|
|
|
async def main():
    """Run the bridge: start MQTT (if enabled), then serve WebSockets until the server closes.

    The MQTT network thread is stopped and the client disconnected on the way out.
    """
    running_loop = asyncio.get_running_loop()
    state.loop = running_loop
    state.mqtt_client = setup_mqtt(running_loop)

    server = await websockets.serve(ws_handler, WS_HOST, WS_PORT, max_size=1 << 22)  # 4 MiB frames
    log.info("ws listening on ws://%s:%s%s", WS_HOST, WS_PORT, WS_PATH)
    try:
        await server.wait_closed()
    finally:
        client = state.mqtt_client
        if client:
            client.loop_stop()
            client.disconnect()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the bridge until the WebSocket server stops.
    # Ctrl+C (KeyboardInterrupt) ends the process quietly, without a traceback.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        pass
|