"""MQTT worker that transcribes VAD-segmented WAV files with faster-whisper.

Subscribes to VAD segment notifications, transcribes the referenced WAV file,
and publishes either a transcript or an error payload on per-device topics.
All settings come from environment variables (see the constants below).
"""

import json
import logging
import os
import time
from pathlib import Path
from typing import Optional

import paho.mqtt.client as mqtt
from faster_whisper import WhisperModel

# logging only accepts upper-case level names, so normalise the value; an empty
# LOG_LEVEL falls back to INFO instead of raising ValueError at import time.
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO").strip().upper() or "INFO",
    format="%(asctime)s %(levelname)s %(message)s",
)
log = logging.getLogger("evs-stt-worker")


def getenv_bool(name: str, default: bool) -> bool:
    """Read a boolean environment variable.

    Returns *default* when the variable is unset; otherwise True only for
    "1", "true", "yes" or "on" (case-insensitive, surrounding whitespace ignored).
    """
    val = os.getenv(name)
    if val is None:
        return default
    return val.strip().lower() in {"1", "true", "yes", "on"}


# --- MQTT connection / topic configuration ---------------------------------
MQTT_HOST = os.getenv("MQTT_HOST", "localhost")
MQTT_PORT = int(os.getenv("MQTT_PORT", "1883"))
MQTT_USER = os.getenv("MQTT_USER", "")
MQTT_PASSWORD = os.getenv("MQTT_PASSWORD", "")
MQTT_BASE_TOPIC = os.getenv("MQTT_BASE_TOPIC", "evs")
# Subscription filter; "+" matches any single device-id level.
MQTT_VAD_TOPIC = os.getenv("MQTT_VAD_TOPIC", f"{MQTT_BASE_TOPIC}/+/vad_segment")
# Output topic templates; "{device_id}" is filled in with str.format().
MQTT_TRANSCRIPT_TOPIC_TEMPLATE = os.getenv(
    "MQTT_TRANSCRIPT_TOPIC_TEMPLATE", f"{MQTT_BASE_TOPIC}" + "/{device_id}/transcript"
)
MQTT_STT_ERROR_TOPIC_TEMPLATE = os.getenv(
    "MQTT_STT_ERROR_TOPIC_TEMPLATE", f"{MQTT_BASE_TOPIC}" + "/{device_id}/stt_error"
)

# --- faster-whisper configuration ------------------------------------------
STT_MODEL = os.getenv("STT_MODEL", "small")
STT_DEVICE = os.getenv("STT_DEVICE", "cpu")
STT_COMPUTE_TYPE = os.getenv("STT_COMPUTE_TYPE", "int8")
STT_BEAM_SIZE = int(os.getenv("STT_BEAM_SIZE", "1"))
# Empty means "let the model detect the language".
STT_LANGUAGE = os.getenv("STT_LANGUAGE", "").strip()
STT_CONDITION_ON_PREV_TEXT = getenv_bool("STT_CONDITION_ON_PREV_TEXT", False)


class WorkerState:
    """Process-wide mutable state shared by the MQTT callbacks."""

    def __init__(self) -> None:
        # Set by main() once the MQTT client is created; publish_json is a
        # no-op until then.
        self.client: Optional[mqtt.Client] = None
        # Lazily loaded on the first transcription (see model_instance).
        self.model: Optional[WhisperModel] = None
        # Path of the most recently accepted segment; an immediately repeated
        # notification for the same path is skipped.
        self.last_wav_path: str = ""

    def model_instance(self) -> WhisperModel:
        """Return the Whisper model, loading it on first use."""
        if self.model is None:
            log.info(
                "loading model: model=%s device=%s compute_type=%s",
                STT_MODEL,
                STT_DEVICE,
                STT_COMPUTE_TYPE,
            )
            self.model = WhisperModel(STT_MODEL, device=STT_DEVICE, compute_type=STT_COMPUTE_TYPE)
            log.info("model loaded")
        return self.model


state = WorkerState()


def publish_json(topic: str, payload: dict) -> None:
    """Publish *payload* as JSON on *topic* (QoS 0, not retained).

    Best-effort: does nothing before the client exists, and logs rather than
    raises on failure so a single bad publish cannot stop the worker.
    """
    if not state.client:
        return
    try:
        info = state.client.publish(topic, json.dumps(payload), qos=0, retain=False)
    except Exception:
        log.exception("mqtt publish failed: topic=%s", topic)
        return
    # Some failures (e.g. not connected) are reported via the return code
    # instead of an exception; surface them rather than dropping silently.
    if info.rc != mqtt.MQTT_ERR_SUCCESS:
        log.warning("mqtt publish returned error: topic=%s rc=%s", topic, info.rc)


def topic_for_transcript(device_id: str) -> str:
    """Return the transcript topic for *device_id*."""
    return MQTT_TRANSCRIPT_TOPIC_TEMPLATE.format(device_id=device_id)


def topic_for_error(device_id: str) -> str:
    """Return the STT error topic for *device_id*."""
    return MQTT_STT_ERROR_TOPIC_TEMPLATE.format(device_id=device_id)


def transcribe_wav(device_id: str, wav_path: str) -> None:
    """Transcribe *wav_path* and publish the result for *device_id*.

    Publishes an ``stt_error`` payload with ``error="wav_not_found"`` if the
    path is not a regular file, or with the exception text if transcription
    fails. A path equal to the previously accepted one is skipped. It is
    recorded before transcription starts, so a failed file is not retried by
    a repeated notification.
    """
    path = Path(wav_path)
    if not path.is_file():
        payload = {
            "type": "stt_error",
            "ts": time.time(),
            "device_id": device_id,
            "wav_path": wav_path,
            "error": "wav_not_found",
        }
        publish_json(topic_for_error(device_id), payload)
        log.warning("wav not found: device=%s path=%s", device_id, wav_path)
        return

    if wav_path == state.last_wav_path:
        return
    state.last_wav_path = wav_path

    try:
        model = state.model_instance()
        kwargs = {
            "beam_size": STT_BEAM_SIZE,
            "condition_on_previous_text": STT_CONDITION_ON_PREV_TEXT,
        }
        if STT_LANGUAGE:
            kwargs["language"] = STT_LANGUAGE
        segments, info = model.transcribe(str(path), **kwargs)
        # Consuming `segments` happens inside the try, so decode errors raised
        # while iterating are reported like any other transcription failure.
        text = " ".join(seg.text.strip() for seg in segments if seg.text and seg.text.strip()).strip()
        payload = {
            "type": "transcript",
            "ts": time.time(),
            "device_id": device_id,
            "wav_path": wav_path,
            "text": text,
            "language": getattr(info, "language", ""),
            "language_probability": getattr(info, "language_probability", 0.0),
            "model": STT_MODEL,
        }
        publish_json(topic_for_transcript(device_id), payload)
        log.info("transcript: device=%s chars=%s wav=%s", device_id, len(text), wav_path)
    except Exception as exc:
        payload = {
            "type": "stt_error",
            "ts": time.time(),
            "device_id": device_id,
            "wav_path": wav_path,
            "error": str(exc),
        }
        publish_json(topic_for_error(device_id), payload)
        log.exception("transcription failed: device=%s wav=%s", device_id, wav_path)


def on_mqtt_connect(client: mqtt.Client, _userdata, _flags, reason_code, _properties=None):
    """Subscribe to the VAD topic after a successful (re)connect.

    Subscribing here rather than once in main() means the subscription is
    restored after every reconnect.
    """
    if reason_code == 0:
        log.info("mqtt connected")
        client.subscribe(MQTT_VAD_TOPIC, qos=0)
        log.info("subscribed: %s", MQTT_VAD_TOPIC)
    else:
        log.error("mqtt connect failed reason=%s", reason_code)


def on_mqtt_message(_client: mqtt.Client, _userdata, msg: mqtt.MQTTMessage):
    """Handle a VAD segment notification by transcribing the referenced WAV.

    Expects a UTF-8 JSON object with ``type == "vad_segment"``, a truthy
    ``device_id`` and a non-empty string ``wav_path``. Anything else is
    ignored. Any exception escaping this callback would be re-raised by paho
    out of ``loop_forever()`` and stop the worker, so untrusted input is
    validated before use.

    NOTE(review): transcription runs synchronously in this callback, so the
    network loop is blocked for its duration. Confirm keepalive=30 is enough
    for the longest expected segment.
    """
    try:
        payload = json.loads(msg.payload.decode("utf-8"))
    except Exception:
        log.debug("ignoring undecodable message: topic=%s", msg.topic)
        return
    # Valid JSON need not be an object; .get() on a list/str/number would raise.
    if not isinstance(payload, dict):
        log.debug("ignoring non-object JSON message: topic=%s", msg.topic)
        return
    if payload.get("type") != "vad_segment":
        return
    device_id = payload.get("device_id", "")
    wav_path = payload.get("wav_path", "")
    # A non-string path (number, list, ...) would make Path() raise TypeError
    # outside transcribe_wav's error handling.
    if not isinstance(wav_path, str):
        log.debug("ignoring vad_segment with non-string wav_path: topic=%s", msg.topic)
        return
    if not device_id or not wav_path:
        return
    transcribe_wav(device_id, wav_path)


def main() -> None:
    """Create the MQTT client, connect, and process messages forever."""
    client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2, client_id="evs-stt-worker")
    if MQTT_USER:
        client.username_pw_set(MQTT_USER, MQTT_PASSWORD)
    client.on_connect = on_mqtt_connect
    client.on_message = on_mqtt_message
    state.client = client
    client.connect(MQTT_HOST, MQTT_PORT, keepalive=30)
    client.loop_forever()


if __name__ == "__main__":
    main()