Files
EVS-Embedded-Voice-System/stt-worker/app.py
Kai 5294c24b08
Some checks failed
Build and Push EVS Bridge Image / docker (push) Has been cancelled
Add MQTT-based STT worker for VAD segments
2026-02-13 17:49:26 +01:00

176 lines
5.5 KiB
Python

import json
import logging
import os
import time
from pathlib import Path
from typing import Optional
import paho.mqtt.client as mqtt
from faster_whisper import WhisperModel
def _resolve_log_level(raw: Optional[str], default: str = "INFO") -> "int | str":
    """Normalise a ``LOG_LEVEL`` value into a form ``logging`` accepts.

    ``logging.basicConfig(level=...)`` only accepts integers or the exact
    upper-case level names, so ``LOG_LEVEL=debug``, ``LOG_LEVEL=10`` or an
    empty ``LOG_LEVEL=`` would otherwise abort start-up with
    ``ValueError: Unknown level``.

    Args:
        raw: The raw environment value, or ``None`` when unset.
        default: Level name used when *raw* is unset or blank.

    Returns:
        An ``int`` for purely numeric input, otherwise the trimmed,
        upper-cased name. Unknown names are still rejected by ``logging``,
        so a genuine misconfiguration keeps failing loudly.
    """
    text = (raw or "").strip()
    if not text:
        return default
    # isdecimal (not isdigit) matches exactly what int() can parse.
    if text.isdecimal():
        return int(text)
    return text.upper()


logging.basicConfig(
    level=_resolve_log_level(os.getenv("LOG_LEVEL")),
    format="%(asctime)s %(levelname)s %(message)s",
)
log = logging.getLogger("evs-stt-worker")
def getenv_bool(name: str, default: bool) -> bool:
    """Read environment variable *name* as a boolean flag.

    Returns *default* when the variable is unset. Otherwise the value is
    trimmed and lower-cased; only ``1``, ``true``, ``yes`` and ``on`` count
    as true, and every other value (including an empty string) is false.
    """
    raw = os.getenv(name)
    return default if raw is None else raw.strip().lower() in ("1", "true", "yes", "on")
# --- MQTT broker connection -------------------------------------------------
MQTT_HOST = os.getenv("MQTT_HOST", "localhost")
MQTT_PORT = int(os.getenv("MQTT_PORT", "1883"))
# main() only calls username_pw_set() when MQTT_USER is non-empty.
MQTT_USER = os.getenv("MQTT_USER", "")
MQTT_PASSWORD = os.getenv("MQTT_PASSWORD", "")

# --- MQTT topics ------------------------------------------------------------
# Defaults are derived from MQTT_BASE_TOPIC. "+" is MQTT's single-level
# wildcard, so the default subscription matches every device's segment topic.
MQTT_BASE_TOPIC = os.getenv("MQTT_BASE_TOPIC", "evs")
MQTT_VAD_TOPIC = os.getenv("MQTT_VAD_TOPIC", MQTT_BASE_TOPIC + "/+/vad_segment")
# These templates are later filled via str.format(device_id=...), so
# "{device_id}" must remain a literal placeholder here (plain concatenation,
# not an f-string).
# NOTE(review): braces inside MQTT_BASE_TOPIC would also be treated as format
# fields by that str.format call.
MQTT_TRANSCRIPT_TOPIC_TEMPLATE = os.getenv(
    "MQTT_TRANSCRIPT_TOPIC_TEMPLATE", MQTT_BASE_TOPIC + "/{device_id}/transcript"
)
MQTT_STT_ERROR_TOPIC_TEMPLATE = os.getenv(
    "MQTT_STT_ERROR_TOPIC_TEMPLATE", MQTT_BASE_TOPIC + "/{device_id}/stt_error"
)

# --- faster-whisper transcription -------------------------------------------
STT_MODEL = os.getenv("STT_MODEL", "small")
STT_DEVICE = os.getenv("STT_DEVICE", "cpu")
STT_COMPUTE_TYPE = os.getenv("STT_COMPUTE_TYPE", "int8")
STT_BEAM_SIZE = int(os.getenv("STT_BEAM_SIZE", "1"))
# Empty means no `language` argument is passed to transcribe().
STT_LANGUAGE = os.getenv("STT_LANGUAGE", "").strip()
STT_CONDITION_ON_PREV_TEXT = getenv_bool("STT_CONDITION_ON_PREV_TEXT", False)
class WorkerState:
    """Process-wide mutable state shared by the MQTT callbacks.

    Holds the MQTT client used for publishing, the lazily built Whisper
    model, and the path of the last WAV segment accepted for transcription
    (used to skip an immediate repeat of the same segment).
    """

    def __init__(self) -> None:
        # Set by main() once the client object exists; publishing is a no-op
        # until then.
        self.client: Optional[mqtt.Client] = None
        # Built on first use by model_instance().
        self.model: Optional[WhisperModel] = None
        self.last_wav_path: str = ""

    def model_instance(self) -> WhisperModel:
        """Return the shared WhisperModel, constructing it on first call.

        If construction raises, ``self.model`` stays ``None`` and the next
        call tries again.
        """
        if self.model is not None:
            return self.model
        log.info(
            "loading model: model=%s device=%s compute_type=%s",
            STT_MODEL,
            STT_DEVICE,
            STT_COMPUTE_TYPE,
        )
        self.model = WhisperModel(STT_MODEL, device=STT_DEVICE, compute_type=STT_COMPUTE_TYPE)
        log.info("model loaded")
        return self.model


# Module-level singleton read by the publish/transcribe helpers and callbacks.
state = WorkerState()
def publish_json(topic: str, payload: dict) -> None:
    """Serialise *payload* as JSON and publish it on *topic* (QoS 0, not retained).

    Best-effort and never raises. If no client has been attached yet the call
    is a no-op. Serialisation or publish exceptions are logged, and so is a
    non-success return code from ``publish()`` (e.g. ``MQTT_ERR_NO_CONN``
    while the broker connection is down), because paho reports those through
    the return value rather than by raising.
    """
    client = state.client
    if client is None:
        log.debug("mqtt client not ready, dropping publish: topic=%s", topic)
        return
    try:
        info = client.publish(topic, json.dumps(payload), qos=0, retain=False)
    except Exception:
        log.exception("mqtt publish failed: topic=%s", topic)
        return
    if info.rc != mqtt.MQTT_ERR_SUCCESS:
        # A QoS-0 message that was not queued is simply lost; make that visible.
        log.warning(
            "mqtt publish not queued: topic=%s rc=%s (%s)",
            topic,
            info.rc,
            mqtt.error_string(info.rc),
        )
def topic_for_transcript(device_id: str) -> str:
    """Return the topic that *device_id*'s transcripts are published on.

    Fills ``{device_id}`` in MQTT_TRANSCRIPT_TOPIC_TEMPLATE via ``str.format``;
    a template with any other named field raises ``KeyError``.
    """
    template = MQTT_TRANSCRIPT_TOPIC_TEMPLATE
    return template.format(device_id=device_id)
def topic_for_error(device_id: str) -> str:
    """Return the topic that *device_id*'s ``stt_error`` events are published on.

    Fills ``{device_id}`` in MQTT_STT_ERROR_TOPIC_TEMPLATE via ``str.format``;
    a template with any other named field raises ``KeyError``.
    """
    template = MQTT_STT_ERROR_TOPIC_TEMPLATE
    return template.format(device_id=device_id)
def _publish_stt_error(device_id: str, wav_path: str, error: str) -> None:
    """Publish an ``stt_error`` event for *wav_path* on *device_id*'s error topic."""
    publish_json(
        topic_for_error(device_id),
        {
            "type": "stt_error",
            "ts": time.time(),
            "device_id": device_id,
            "wav_path": wav_path,
            "error": error,
        },
    )


def transcribe_wav(device_id: str, wav_path: str) -> None:
    """Transcribe the WAV file at *wav_path* and publish the result for *device_id*.

    On success a ``transcript`` event (text, detected language and its
    probability, model name) is published on the device's transcript topic.
    A missing file publishes ``stt_error`` with ``error="wav_not_found"``.
    Any exception while loading the model or transcribing is caught, logged
    and published as ``stt_error`` with the exception message.

    A call whose *wav_path* equals the previously accepted path is ignored.
    """
    # NOTE(review): wav_path comes verbatim from an MQTT message, so any file
    # readable by this process can be requested; consider restricting it to
    # the expected segment directory.
    path = Path(wav_path)
    if not path.is_file():
        _publish_stt_error(device_id, wav_path, "wav_not_found")
        log.warning("wav not found: device=%s path=%s", device_id, wav_path)
        return
    # Skip a segment identical to the previous one. The path is recorded
    # before transcribing, so a segment that failed is not retried if the
    # same message arrives again.
    if wav_path == state.last_wav_path:
        return
    state.last_wav_path = wav_path
    try:
        model = state.model_instance()
        kwargs = {
            "beam_size": STT_BEAM_SIZE,
            "condition_on_previous_text": STT_CONDITION_ON_PREV_TEXT,
        }
        if STT_LANGUAGE:
            kwargs["language"] = STT_LANGUAGE
        segments, info = model.transcribe(str(path), **kwargs)
        # Consuming `segments` is what performs the decoding work, so it must
        # stay inside the try block.
        text = " ".join(seg.text.strip() for seg in segments if seg.text and seg.text.strip()).strip()
        payload = {
            "type": "transcript",
            "ts": time.time(),
            "device_id": device_id,
            "wav_path": wav_path,
            "text": text,
            "language": getattr(info, "language", ""),
            "language_probability": getattr(info, "language_probability", 0.0),
            "model": STT_MODEL,
        }
        publish_json(topic_for_transcript(device_id), payload)
        log.info("transcript: device=%s chars=%s wav=%s", device_id, len(text), wav_path)
    except Exception as exc:
        # Some exceptions stringify to "", which would publish a blank error;
        # fall back to the exception class name in that case.
        _publish_stt_error(device_id, wav_path, str(exc) or type(exc).__name__)
        log.exception("transcription failed: device=%s wav=%s", device_id, wav_path)
def on_mqtt_connect(client: mqtt.Client, _userdata, _flags, reason_code, _properties=None):
    """paho-mqtt ``on_connect`` callback: subscribe to the VAD topic once connected.

    A non-zero *reason_code* is only logged. Subscribing inside this callback
    means the subscription is re-issued whenever paho reconnects.
    """
    if reason_code != 0:
        log.error("mqtt connect failed reason=%s", reason_code)
        return
    log.info("mqtt connected")
    client.subscribe(MQTT_VAD_TOPIC, qos=0)
    log.info("subscribed: %s", MQTT_VAD_TOPIC)
def on_mqtt_message(_client: mqtt.Client, _userdata, msg: mqtt.MQTTMessage):
    """paho-mqtt ``on_message`` callback: transcribe announced VAD segments.

    Expects a UTF-8 JSON object with ``"type": "vad_segment"`` and non-empty
    string ``device_id`` and ``wav_path`` fields. Anything else is ignored.
    Never lets an exception escape, because paho re-raises callback
    exceptions out of ``loop_forever()`` (unless ``suppress_exceptions`` is
    set), which would stop the worker.
    """
    try:
        payload = json.loads(msg.payload.decode("utf-8"))
    except Exception:
        # Not UTF-8 / not JSON (or pathologically nested): not a VAD event.
        log.debug("ignoring undecodable message: topic=%s", msg.topic)
        return
    # json.loads accepts any JSON value; only an object has .get().
    if not isinstance(payload, dict) or payload.get("type") != "vad_segment":
        return
    device_id = payload.get("device_id")
    wav_path = payload.get("wav_path")
    # Require non-empty strings: e.g. a numeric wav_path would make Path()
    # raise before transcribe_wav's own error handling starts.
    if not (isinstance(device_id, str) and device_id and isinstance(wav_path, str) and wav_path):
        log.warning("ignoring malformed vad_segment: topic=%s", msg.topic)
        return
    try:
        transcribe_wav(device_id, wav_path)
    except Exception:
        # transcribe_wav reports transcription failures itself; this guards
        # the network loop against anything else escaping it.
        log.exception("unhandled error for vad_segment: device=%s wav=%s", device_id, wav_path)
def main() -> None:
    """Run the worker: warm up the STT model, connect to MQTT and serve forever.

    Blocks in ``loop_forever()``. Errors from the initial ``connect`` call
    propagate to the caller.
    """
    # on_mqtt_message transcribes synchronously, i.e. on the thread driving
    # loop_forever(). Loading the model lazily on the first segment would
    # stall the network loop (no keepalive traffic) for the whole load or
    # download, which can exceed the 30 s keepalive and drop the connection.
    # Load it before connecting instead. If that fails, keep running and rely
    # on model_instance() retrying on the first segment, as before.
    try:
        state.model_instance()
    except Exception:
        log.exception("model preload failed; will retry on first vad_segment")
    client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2, client_id="evs-stt-worker")
    if MQTT_USER:
        client.username_pw_set(MQTT_USER, MQTT_PASSWORD)
    client.on_connect = on_mqtt_connect
    client.on_message = on_mqtt_message
    # Attach before connecting so publish_json() can use it from callbacks.
    state.client = client
    client.connect(MQTT_HOST, MQTT_PORT, keepalive=30)
    client.loop_forever()


if __name__ == "__main__":
    main()