176 lines
5.5 KiB
Python
176 lines
5.5 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import paho.mqtt.client as mqtt
|
|
from faster_whisper import WhisperModel
|
|
|
|
|
|
# Configure root logging once at import time.
# LOG_LEVEL is normalized (trimmed, upper-cased) before it is handed to
# logging, because logging rejects unknown level names with ValueError and
# the registered names are upper-case; without this, LOG_LEVEL=debug would
# abort the worker at startup. An unset or blank value falls back to INFO.
_log_level = (os.getenv("LOG_LEVEL") or "").strip().upper() or "INFO"
logging.basicConfig(
    level=_log_level,
    format="%(asctime)s %(levelname)s %(message)s",
)
log = logging.getLogger("evs-stt-worker")
|
|
|
|
|
|
def getenv_bool(name: str, default: bool) -> bool:
    """Read environment variable *name* as a boolean flag.

    Returns *default* when the variable is unset. Otherwise the value is
    trimmed and lower-cased and is true only for "1", "true", "yes" or "on";
    every other value, including the empty string, is false.
    """
    raw = os.environ.get(name)
    if raw is not None:
        return raw.strip().lower() in ("1", "true", "yes", "on")
    return default
|
|
|
|
|
|
# --- MQTT connection settings (each overridable through the environment) ---
MQTT_HOST = os.environ.get("MQTT_HOST", "localhost")
MQTT_PORT = int(os.environ.get("MQTT_PORT", "1883"))
# Credentials are only applied by main() when MQTT_USER is non-empty.
MQTT_USER = os.environ.get("MQTT_USER", "")
MQTT_PASSWORD = os.environ.get("MQTT_PASSWORD", "")

# --- Topic layout ---
# All default topics are rooted at MQTT_BASE_TOPIC. The default VAD topic has
# "+" in the second topic level.
MQTT_BASE_TOPIC = os.environ.get("MQTT_BASE_TOPIC", "evs")
MQTT_VAD_TOPIC = os.environ.get("MQTT_VAD_TOPIC", MQTT_BASE_TOPIC + "/+/vad_segment")
# The two templates keep a literal "{device_id}" placeholder; it is filled in
# with str.format() by topic_for_transcript() / topic_for_error().
MQTT_TRANSCRIPT_TOPIC_TEMPLATE = os.environ.get(
    "MQTT_TRANSCRIPT_TOPIC_TEMPLATE",
    MQTT_BASE_TOPIC + "/{device_id}/transcript",
)
MQTT_STT_ERROR_TOPIC_TEMPLATE = os.environ.get(
    "MQTT_STT_ERROR_TOPIC_TEMPLATE",
    MQTT_BASE_TOPIC + "/{device_id}/stt_error",
)
|
|
|
|
# --- Speech-to-text settings (each overridable through the environment) ---
# STT_MODEL, STT_DEVICE and STT_COMPUTE_TYPE are passed to WhisperModel when
# the model is first constructed.
STT_MODEL = os.environ.get("STT_MODEL", "small")
STT_DEVICE = os.environ.get("STT_DEVICE", "cpu")
STT_COMPUTE_TYPE = os.environ.get("STT_COMPUTE_TYPE", "int8")
STT_BEAM_SIZE = int(os.environ.get("STT_BEAM_SIZE", "1"))
# A blank STT_LANGUAGE means no "language" argument is passed to transcribe().
STT_LANGUAGE = os.environ.get("STT_LANGUAGE", "").strip()
STT_CONDITION_ON_PREV_TEXT = getenv_bool("STT_CONDITION_ON_PREV_TEXT", False)
|
|
|
|
|
|
class WorkerState:
    """Mutable state shared by the MQTT callbacks in this module.

    Holds the MQTT client used for publishing, the lazily constructed
    Whisper model, and the path of the most recently accepted WAV file.
    """

    def __init__(self) -> None:
        # Assigned by main() before connecting; None until then.
        self.client: Optional[mqtt.Client] = None
        # Constructed on the first model_instance() call.
        self.model: Optional[WhisperModel] = None
        # Compared by transcribe_wav() to drop an immediate repeat of a path.
        self.last_wav_path: str = ""

    def model_instance(self) -> WhisperModel:
        """Return the Whisper model, constructing it on the first call.

        Later calls return the same cached instance without reloading.
        """
        if self.model is not None:
            return self.model

        log.info(
            "loading model: model=%s device=%s compute_type=%s",
            STT_MODEL,
            STT_DEVICE,
            STT_COMPUTE_TYPE,
        )
        loaded = WhisperModel(STT_MODEL, device=STT_DEVICE, compute_type=STT_COMPUTE_TYPE)
        self.model = loaded
        log.info("model loaded")
        return self.model
|
|
|
|
|
|
# Module-level WorkerState instance read and written by the functions below.
state = WorkerState()
|
|
|
|
|
|
def publish_json(topic: str, payload: dict) -> None:
    """Serialize *payload* as JSON and publish it to *topic* (QoS 0, not retained).

    Best-effort: does nothing when ``state.client`` is not set, and never
    raises. Exceptions from serialization or publishing are logged with a
    traceback. A publish that the client reports as unsuccessful (a
    non-success ``rc`` on the returned message info) is logged as a warning
    rather than being dropped without a trace.
    """
    client = state.client
    if not client:
        return
    try:
        info = client.publish(topic, json.dumps(payload), qos=0, retain=False)
        # publish() reports some failures through the return value instead of
        # raising, so the result code has to be checked explicitly.
        if info.rc != mqtt.MQTT_ERR_SUCCESS:
            log.warning("mqtt publish not sent: topic=%s rc=%s", topic, info.rc)
    except Exception:
        log.exception("mqtt publish failed: topic=%s", topic)
|
|
|
|
|
|
def topic_for_transcript(device_id: str) -> str:
    """Return the transcript topic for *device_id*.

    Fills the ``{device_id}`` placeholder of MQTT_TRANSCRIPT_TOPIC_TEMPLATE.
    """
    template = MQTT_TRANSCRIPT_TOPIC_TEMPLATE
    return template.format(device_id=device_id)
|
|
|
|
|
|
def topic_for_error(device_id: str) -> str:
    """Return the STT error topic for *device_id*.

    Fills the ``{device_id}`` placeholder of MQTT_STT_ERROR_TOPIC_TEMPLATE.
    """
    template = MQTT_STT_ERROR_TOPIC_TEMPLATE
    return template.format(device_id=device_id)
|
|
|
|
|
|
def _publish_stt_error(device_id: str, wav_path: str, error: str) -> None:
    """Publish an ``stt_error`` message for *wav_path* on the device's error topic."""
    publish_json(
        topic_for_error(device_id),
        {
            "type": "stt_error",
            "ts": time.time(),
            "device_id": device_id,
            "wav_path": wav_path,
            "error": error,
        },
    )


def transcribe_wav(device_id: str, wav_path: str) -> None:
    """Transcribe the WAV file at *wav_path* and publish the result for *device_id*.

    Outcomes:
      * file missing -> an ``stt_error`` with ``error="wav_not_found"`` is
        published and a warning is logged;
      * *wav_path* equals the previously accepted path -> silently skipped;
      * success -> a ``transcript`` message is published on the device's
        transcript topic;
      * any exception during model loading or transcription -> an
        ``stt_error`` carrying the exception text is published and the
        traceback is logged.

    Nothing is returned; results are delivered over MQTT via publish_json().
    """
    path = Path(wav_path)
    if not path.is_file():
        _publish_stt_error(device_id, wav_path, "wav_not_found")
        log.warning("wav not found: device=%s path=%s", device_id, wav_path)
        return

    # Only the single most recent path is remembered, and it is recorded
    # before transcription starts, so an immediate repeat of the same path is
    # skipped even if the first attempt failed.
    if wav_path == state.last_wav_path:
        return
    state.last_wav_path = wav_path

    try:
        model = state.model_instance()
        kwargs = {
            "beam_size": STT_BEAM_SIZE,
            "condition_on_previous_text": STT_CONDITION_ON_PREV_TEXT,
        }
        # Leave "language" out entirely when it is not configured.
        if STT_LANGUAGE:
            kwargs["language"] = STT_LANGUAGE

        segments, info = model.transcribe(str(path), **kwargs)
        # segments is consumed here, inside the try, so an error raised while
        # iterating it is reported as an stt_error as well.
        pieces = (seg.text.strip() for seg in segments if seg.text)
        text = " ".join(piece for piece in pieces if piece)
        payload = {
            "type": "transcript",
            "ts": time.time(),
            "device_id": device_id,
            "wav_path": wav_path,
            "text": text,
            "language": getattr(info, "language", ""),
            "language_probability": getattr(info, "language_probability", 0.0),
            "model": STT_MODEL,
        }
        publish_json(topic_for_transcript(device_id), payload)
        log.info("transcript: device=%s chars=%s wav=%s", device_id, len(text), wav_path)
    except Exception as exc:
        # Some exceptions have an empty str(); fall back to the class name so
        # the published error field is never blank.
        _publish_stt_error(device_id, wav_path, str(exc) or type(exc).__name__)
        log.exception("transcription failed: device=%s wav=%s", device_id, wav_path)
|
|
|
|
|
|
def on_mqtt_connect(client: mqtt.Client, _userdata, _flags, reason_code, _properties=None):
    """MQTT connect callback: subscribe to the VAD topic when the connect succeeded.

    A *reason_code* that compares equal to 0 is treated as success. Any other
    value is logged as an error and no subscription is made.
    """
    connected = reason_code == 0
    if not connected:
        log.error("mqtt connect failed reason=%s", reason_code)
        return

    log.info("mqtt connected")
    client.subscribe(MQTT_VAD_TOPIC, qos=0)
    log.info("subscribed: %s", MQTT_VAD_TOPIC)
|
|
|
|
|
|
def on_mqtt_message(_client: mqtt.Client, _userdata, msg: mqtt.MQTTMessage):
    """MQTT message callback: hand ``vad_segment`` messages to transcribe_wav().

    The payload must be a UTF-8 JSON object with ``type == "vad_segment"``, a
    non-empty ``device_id`` and a non-empty string ``wav_path``. Anything
    else is ignored, so that a malformed message cannot raise out of the
    callback.
    """
    try:
        payload = json.loads(msg.payload.decode("utf-8"))
    except (ValueError, RecursionError):
        # ValueError covers both UnicodeDecodeError and json.JSONDecodeError;
        # RecursionError guards against pathologically nested JSON.
        log.debug("ignoring mqtt payload that is not valid UTF-8 JSON")
        return

    # Valid JSON is not necessarily an object; a list, string or number has
    # no .get() and would raise AttributeError below.
    if not isinstance(payload, dict):
        log.debug("ignoring mqtt payload that is not a JSON object")
        return
    if payload.get("type") != "vad_segment":
        return

    device_id = payload.get("device_id", "")
    wav_path = payload.get("wav_path", "")
    if not device_id or not wav_path:
        return
    # transcribe_wav() builds a Path from wav_path before entering its own
    # try block; a non-string value would raise TypeError there.
    if not isinstance(wav_path, str):
        log.debug("ignoring vad_segment with non-string wav_path")
        return
    transcribe_wav(device_id, wav_path)
|
|
|
|
|
|
def main() -> None:
    """Create the MQTT client, attach the callbacks, connect and run its loop.

    The client is stored on ``state`` so that publish_json() can use it.
    Credentials are applied only when MQTT_USER is non-empty. The call ends
    in ``loop_forever()``.
    """
    mqtt_client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2, client_id="evs-stt-worker")
    state.client = mqtt_client

    mqtt_client.on_connect = on_mqtt_connect
    mqtt_client.on_message = on_mqtt_message
    if MQTT_USER:
        mqtt_client.username_pw_set(MQTT_USER, MQTT_PASSWORD)

    mqtt_client.connect(MQTT_HOST, MQTT_PORT, keepalive=30)
    mqtt_client.loop_forever()
|
|
|
|
|
|
# Start the worker only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|