Add MQTT-based STT worker for VAD segments
Some checks failed
Build and Push EVS Bridge Image / docker (push) Has been cancelled
Some checks failed
Build and Push EVS Bridge Image / docker (push) Has been cancelled
This commit is contained in:
26
stt-worker/Dockerfile
Normal file
26
stt-worker/Dockerfile
Normal file
@@ -0,0 +1,26 @@
|
||||
# Container image for the EVS STT worker (runs app.py on Python 3.11 slim).
FROM python:3.11-slim

WORKDIR /app

# Install Python dependencies before copying the application code.
COPY requirements.txt ./
RUN pip install --no-cache-dir --requirement requirements.txt

COPY app.py ./

# Logging defaults.
ENV LOG_LEVEL=INFO

# MQTT broker connection and topic defaults.
ENV MQTT_HOST=localhost \
    MQTT_PORT=1883 \
    MQTT_USER="" \
    MQTT_PASSWORD="" \
    MQTT_BASE_TOPIC=evs \
    MQTT_VAD_TOPIC=evs/+/vad_segment \
    MQTT_TRANSCRIPT_TOPIC_TEMPLATE=evs/{device_id}/transcript \
    MQTT_STT_ERROR_TOPIC_TEMPLATE=evs/{device_id}/stt_error

# Speech-to-text defaults.
ENV STT_MODEL=small \
    STT_DEVICE=cpu \
    STT_COMPUTE_TYPE=int8 \
    STT_BEAM_SIZE=1 \
    STT_LANGUAGE=de \
    STT_CONDITION_ON_PREV_TEXT=false

CMD ["python", "app.py"]
|
||||
14
stt-worker/README.md
Normal file
14
stt-worker/README.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# EVS STT Worker
|
||||
|
||||
This worker subscribes to VAD events from MQTT, transcribes the referenced WAV files, and publishes text back to MQTT.
|
||||
|
||||
Flow:
|
||||
- input topic: `evs/<device_id>/vad_segment`
|
||||
- reads: `wav_path` from JSON payload
|
||||
- output topic: `evs/<device_id>/transcript`
|
||||
- error topic: `evs/<device_id>/stt_error`
|
||||
|
||||
Default model:
|
||||
- `STT_MODEL=small`
|
||||
- `STT_DEVICE=cpu`
|
||||
- `STT_COMPUTE_TYPE=int8`
|
||||
175
stt-worker/app.py
Normal file
175
stt-worker/app.py
Normal file
@@ -0,0 +1,175 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import paho.mqtt.client as mqtt
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
|
||||
# Resolve the log level from LOG_LEVEL. The logging module only accepts
# upper-case level names and raises ValueError for anything else, so the value
# is normalised first and an unknown name falls back to INFO instead of
# crashing the worker at import time.
_log_level_name = os.getenv("LOG_LEVEL", "INFO").strip().upper()
_log_level = logging.getLevelName(_log_level_name)
_log_level_valid = isinstance(_log_level, int)

logging.basicConfig(
    level=_log_level if _log_level_valid else logging.INFO,
    format="%(asctime)s %(levelname)s %(message)s",
)
log = logging.getLogger("evs-stt-worker")

if not _log_level_valid:
    log.warning("unknown LOG_LEVEL=%r, falling back to INFO", _log_level_name)
|
||||
|
||||
|
||||
def getenv_bool(name: str, default: bool) -> bool:
    """Interpret environment variable *name* as a boolean flag.

    Returns *default* when the variable is not set. Otherwise returns True
    only if the value, ignoring case and surrounding whitespace, is one of
    "1", "true", "yes" or "on"; any other value (including "") is False.
    """
    raw = os.environ.get(name)
    if raw is not None:
        return raw.strip().lower() in ("1", "true", "yes", "on")
    return default
|
||||
|
||||
|
||||
# --- MQTT connection settings (read once from the environment) ---
MQTT_HOST = os.environ.get("MQTT_HOST", "localhost")
MQTT_PORT = int(os.environ.get("MQTT_PORT", "1883"))
MQTT_USER = os.environ.get("MQTT_USER", "")
MQTT_PASSWORD = os.environ.get("MQTT_PASSWORD", "")

# --- MQTT topics; the defaults are derived from MQTT_BASE_TOPIC ---
MQTT_BASE_TOPIC = os.environ.get("MQTT_BASE_TOPIC", "evs")
MQTT_VAD_TOPIC = os.environ.get("MQTT_VAD_TOPIC", MQTT_BASE_TOPIC + "/+/vad_segment")
# The "{device_id}" placeholder is left in place here and filled per message.
MQTT_TRANSCRIPT_TOPIC_TEMPLATE = os.environ.get(
    "MQTT_TRANSCRIPT_TOPIC_TEMPLATE",
    MQTT_BASE_TOPIC + "/{device_id}/transcript",
)
MQTT_STT_ERROR_TOPIC_TEMPLATE = os.environ.get(
    "MQTT_STT_ERROR_TOPIC_TEMPLATE",
    MQTT_BASE_TOPIC + "/{device_id}/stt_error",
)

# --- Speech-to-text settings ---
STT_MODEL = os.environ.get("STT_MODEL", "small")
STT_DEVICE = os.environ.get("STT_DEVICE", "cpu")
STT_COMPUTE_TYPE = os.environ.get("STT_COMPUTE_TYPE", "int8")
STT_BEAM_SIZE = int(os.environ.get("STT_BEAM_SIZE", "1"))
# An empty language means no explicit language is passed to the model.
STT_LANGUAGE = os.environ.get("STT_LANGUAGE", "").strip()
STT_CONDITION_ON_PREV_TEXT = getenv_bool("STT_CONDITION_ON_PREV_TEXT", False)
|
||||
|
||||
|
||||
class WorkerState:
    """Mutable process-wide state shared by the MQTT callbacks."""

    def __init__(self) -> None:
        # MQTT client used for publishing; None until one is assigned.
        self.client: Optional[mqtt.Client] = None
        # Whisper model; created on the first model_instance() call.
        self.model: Optional[WhisperModel] = None
        # Most recently accepted WAV path, used to skip an immediate repeat.
        self.last_wav_path: str = ""

    def model_instance(self) -> WhisperModel:
        """Return the shared WhisperModel, constructing it on first use."""
        if self.model is not None:
            return self.model

        log.info(
            "loading model: model=%s device=%s compute_type=%s",
            STT_MODEL,
            STT_DEVICE,
            STT_COMPUTE_TYPE,
        )
        loaded = WhisperModel(
            STT_MODEL,
            device=STT_DEVICE,
            compute_type=STT_COMPUTE_TYPE,
        )
        self.model = loaded
        log.info("model loaded")
        return loaded


state = WorkerState()
|
||||
|
||||
|
||||
def publish_json(topic: str, payload: dict) -> None:
    """Publish *payload* as a JSON string to *topic* (QoS 0, not retained).

    Best effort: this function never raises. Every way a message can be lost
    is logged so that dropped transcripts are visible in the worker log.
    """
    client = state.client
    if client is None:
        log.warning("mqtt publish skipped, no client: topic=%s", topic)
        return

    try:
        result = client.publish(topic, json.dumps(payload), qos=0, retain=False)
    except Exception:
        # Covers both serialisation errors and errors raised by publish().
        log.exception("mqtt publish failed: topic=%s", topic)
        return

    # publish() can also report a failure through its return code without
    # raising (for example when the client is not connected).
    if result.rc != mqtt.MQTT_ERR_SUCCESS:
        log.warning("mqtt publish not sent: topic=%s rc=%s", topic, result.rc)
|
||||
|
||||
|
||||
def topic_for_transcript(device_id: str) -> str:
    """Return the transcript topic for *device_id* from the configured template."""
    template = MQTT_TRANSCRIPT_TOPIC_TEMPLATE
    return template.format(device_id=device_id)
|
||||
|
||||
|
||||
def topic_for_error(device_id: str) -> str:
    """Return the STT error topic for *device_id* from the configured template."""
    template = MQTT_STT_ERROR_TOPIC_TEMPLATE
    return template.format(device_id=device_id)
|
||||
|
||||
|
||||
def _publish_stt_error(device_id: str, wav_path: str, error: str) -> None:
    """Publish an ``stt_error`` event for *wav_path* on the device's error topic."""
    publish_json(
        topic_for_error(device_id),
        {
            "type": "stt_error",
            "ts": time.time(),
            "device_id": device_id,
            "wav_path": wav_path,
            "error": error,
        },
    )


def transcribe_wav(device_id: str, wav_path: str) -> None:
    """Transcribe the WAV file at *wav_path* and publish the outcome for *device_id*.

    On success a ``transcript`` event is published. An ``stt_error`` event is
    published when the file does not exist or when transcription raises.
    A path equal to the most recently accepted one is skipped. Transcription
    errors are reported and logged, not raised.
    """
    path = Path(wav_path)
    if not path.is_file():
        _publish_stt_error(device_id, wav_path, "wav_not_found")
        log.warning("wav not found: device=%s path=%s", device_id, wav_path)
        return

    if wav_path == state.last_wav_path:
        log.debug("duplicate segment skipped: device=%s wav=%s", device_id, wav_path)
        return
    state.last_wav_path = wav_path

    try:
        model = state.model_instance()
        kwargs = {
            "beam_size": STT_BEAM_SIZE,
            "condition_on_previous_text": STT_CONDITION_ON_PREV_TEXT,
        }
        # Only pass a language when one is configured.
        if STT_LANGUAGE:
            kwargs["language"] = STT_LANGUAGE

        segments, info = model.transcribe(str(path), **kwargs)
        pieces = (seg.text.strip() for seg in segments if seg.text)
        text = " ".join(piece for piece in pieces if piece)

        publish_json(
            topic_for_transcript(device_id),
            {
                "type": "transcript",
                "ts": time.time(),
                "device_id": device_id,
                "wav_path": wav_path,
                "text": text,
                "language": getattr(info, "language", ""),
                "language_probability": getattr(info, "language_probability", 0.0),
                "model": STT_MODEL,
            },
        )
        log.info("transcript: device=%s chars=%s wav=%s", device_id, len(text), wav_path)
    except Exception as exc:
        # Forget the path again: it was never transcribed, so a re-sent event
        # for the same file must not be discarded as a duplicate.
        if state.last_wav_path == wav_path:
            state.last_wav_path = ""
        # Some exceptions have an empty message; fall back to the type name.
        _publish_stt_error(device_id, wav_path, str(exc) or type(exc).__name__)
        log.exception("transcription failed: device=%s wav=%s", device_id, wav_path)
|
||||
|
||||
|
||||
def on_mqtt_connect(client: mqtt.Client, _userdata, _flags, reason_code, _properties=None):
    """Connect callback: subscribe to the VAD topic when the connection succeeded."""
    connected = reason_code == 0
    if not connected:
        log.error("mqtt connect failed reason=%s", reason_code)
        return

    log.info("mqtt connected")
    # Subscribing from this callback means it happens on every successful connect.
    client.subscribe(MQTT_VAD_TOPIC, qos=0)
    log.info("subscribed: %s", MQTT_VAD_TOPIC)
|
||||
|
||||
|
||||
def on_mqtt_message(_client: mqtt.Client, _userdata, msg: mqtt.MQTTMessage):
    """Message callback: validate a ``vad_segment`` event and transcribe its WAV.

    Messages that are not valid UTF-8 JSON, are not a JSON object, have a
    ``type`` other than ``vad_segment``, or lack a usable ``device_id`` /
    ``wav_path`` are ignored silently.
    """
    try:
        payload = json.loads(msg.payload.decode("utf-8"))
    except ValueError:
        # UnicodeDecodeError and json.JSONDecodeError are both ValueErrors.
        return

    # Valid JSON is not necessarily an object; a list, string or number has
    # no .get() and would otherwise raise out of this callback.
    if not isinstance(payload, dict):
        return
    if payload.get("type") != "vad_segment":
        return

    device_id = payload.get("device_id", "")
    wav_path = payload.get("wav_path", "")
    if not device_id or not wav_path:
        return
    # The path is handed to pathlib, which raises TypeError for non-strings.
    if not isinstance(wav_path, str):
        return

    transcribe_wav(device_id, wav_path)
|
||||
|
||||
|
||||
def main() -> None:
    """Set up the MQTT client and run the worker until the process is stopped."""
    # Load the model before connecting. Messages are handled inside
    # loop_forever(), so loading it lazily on the first segment would stall
    # MQTT processing (keepalive is 30 s) for the whole load time.
    try:
        state.model_instance()
    except Exception:
        # Not fatal: model_instance() is attempted again for each segment,
        # and a failure there is reported on the stt_error topic.
        log.exception("model preload failed; will retry on first segment")

    client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION2, client_id="evs-stt-worker")
    # Credentials are only configured when a user name is set.
    if MQTT_USER:
        client.username_pw_set(MQTT_USER, MQTT_PASSWORD)
    client.on_connect = on_mqtt_connect
    client.on_message = on_mqtt_message
    state.client = client

    client.connect(MQTT_HOST, MQTT_PORT, keepalive=30)
    client.loop_forever()


if __name__ == "__main__":
    main()
|
||||
2
stt-worker/requirements.txt
Normal file
2
stt-worker/requirements.txt
Normal file
@@ -0,0 +1,2 @@
|
||||
paho-mqtt==2.1.0
|
||||
faster-whisper==1.1.1
|
||||
Reference in New Issue
Block a user