"""MQTT Hermes client base class"""
import asyncio
import io
import json
import logging
import queue
import subprocess
import threading
import typing
import wave
from concurrent.futures import CancelledError
from pathlib import Path
from .asr import AsrTrain
from .audioserver import AudioFrame, AudioSessionFrame, AudioSummary
from .base import Message
from .nlu import NluTrain
# -----------------------------------------------------------------------------
# Keyword arguments that are expanded into ``Message.topic(**topic_args)``
# when a message is published.
TopicArgs = typing.Mapping[str, typing.Any]
# What message handlers produce: an async stream whose items are ``None``
# (skipped by ``publish_all``), a message to publish, or a
# ``(message, topic_args)`` pair to publish with topic arguments.
GeneratorType = typing.AsyncIterable[
typing.Optional[typing.Union[Message, typing.Tuple[Message, TopicArgs]]]
]
# -----------------------------------------------------------------------------
class HermesClient:
    """Base class for Hermes MQTT clients.

    Hooks this object's ``mqtt_on_*`` methods into the supplied MQTT client,
    keeps track of requested/active topic subscriptions, and remembers the
    audio format used by the WAV helper methods.
    """

    def __init__(
        self,
        client_name: str,
        mqtt_client,
        site_ids: typing.Optional[typing.List[str]] = None,
        sample_rate: int = 16000,
        sample_width: int = 2,
        channels: int = 1,
        loop: typing.Optional[asyncio.AbstractEventLoop] = None,
    ):
        """Create a client.

        Args:
            client_name: Name of the client; also used as the logger name.
            mqtt_client: MQTT client object whose ``on_connect``,
                ``on_disconnect`` and ``on_message`` attributes are set to
                this object's handlers.
            site_ids: Site ids this client accepts. Empty/``None`` means all
                sites; the first entry becomes ``site_id``.
            sample_rate: Required audio sample rate.
            sample_width: Required audio sample width (multiplied by 8 to get
                the bit depth passed to sox).
            channels: Required number of audio channels.
            loop: Event loop used to hand work over from MQTT callbacks; may
                be filled in later by ``handle_messages_async``.
        """
        # Identity and logging
        self.client_name = client_name
        self.logger = logging.getLogger(client_name)

        # Event loop (optional until message handling starts)
        self.loop: typing.Optional[asyncio.AbstractEventLoop] = loop

        # Required audio format
        self.sample_rate = sample_rate
        self.sample_width = sample_width
        self.channels = channels

        # Valid site ids (empty set accepts every site)
        self.site_ids: typing.Set[str] = set(site_ids or ())
        self.site_id = site_ids[0] if site_ids else "default"

        # MQTT client with its callbacks routed to this object
        self.mqtt_client = mqtt_client
        self.mqtt_client.on_connect = self.mqtt_on_connect
        self.mqtt_client.on_disconnect = self.mqtt_on_disconnect
        self.mqtt_client.on_message = self.mqtt_on_message

        # Connection state; the event is set from mqtt_on_connect
        self.mqtt_connected_event: asyncio.Event = asyncio.Event()
        self.is_connected: bool = False

        # Subscription bookkeeping
        self.subscribe_lock = threading.Lock()
        self.pending_mqtt_topics: typing.Set[str] = set()
        self.subscribed_types: typing.Set[typing.Type[Message]] = set()
        self.subscribed_topics: typing.Set[str] = set()
        # Every topic ever subscribed to, kept so they can be re-requested
        # after a reconnect
        self.all_mqtt_topics: typing.Set[str] = set()

        # Incoming messages: async queue once the loop runs, plain queue before
        self.in_queue: typing.Optional[asyncio.Queue] = None
        self.pre_queue: queue.Queue = queue.Queue()
# -------------------------------------------------------------------------
# User Methods
# -------------------------------------------------------------------------
def subscribe(self, *message_types: typing.Type[Message], **topic_args):
    """Subscribe to one or more Hermes message types.

    Every given type is recorded in ``subscribed_types`` and its topic is
    passed on to ``subscribe_topics``. If the client is limited to specific
    site ids, a topic is requested per (site id, message type) pair using
    ``message_type.topic(site_id=...)``; otherwise ``message_type.topic()``
    is used once per type.

    NOTE(review): ``topic_args`` is accepted but not used by this method.
    """
    # One set of topic keyword arguments per site; a single empty set when
    # the client accepts all sites.
    per_site_kwargs: typing.List[typing.Dict[str, str]] = [
        {"site_id": site_id} for site_id in self.site_ids
    ] or [{}]

    topics: typing.List[str] = []
    for site_kwargs in per_site_kwargs:
        for message_type in message_types:
            topics.append(message_type.topic(**site_kwargs))
            self.subscribed_types.add(message_type)

    # Hand all collected MQTT topics to the topic-level subscriber
    self.subscribe_topics(*topics)
def subscribe_topics(self, *topics):
    """Subscribe to one or more MQTT topics.

    Topics are first added to the pending set. While disconnected nothing
    else happens. When connected, every pending topic is recorded in
    ``all_mqtt_topics``, subscribed on the MQTT client unless it is already
    in ``subscribed_topics``, and the pending set is emptied. All of this
    runs while holding ``subscribe_lock``.
    """
    with self.subscribe_lock:
        self.pending_mqtt_topics.update(topics)

        if not self.is_connected:
            # Leave the topics pending
            return

        for pending_topic in self.pending_mqtt_topics:
            self.all_mqtt_topics.add(pending_topic)

            if pending_topic in self.subscribed_topics:
                # Already active; don't re-subscribe
                continue

            self.mqtt_client.subscribe(pending_topic)
            self.subscribed_topics.add(pending_topic)
            self.logger.debug("Subscribed to %s", pending_topic)

        self.pending_mqtt_topics.clear()
async def on_message(
    self,
    message: Message,
    site_id: typing.Optional[str] = None,
    session_id: typing.Optional[str] = None,
    topic: typing.Optional[str] = None,
) -> GeneratorType:
    """Override to handle Hermes messages.

    ``handle_messages_async`` runs this handler in a separate task and
    publishes whatever it yields. The base implementation yields a single
    ``None``, which results in nothing being published.
    """
    yield
async def on_message_blocking(
    self,
    message: Message,
    site_id: typing.Optional[str] = None,
    session_id: typing.Optional[str] = None,
    topic: typing.Optional[str] = None,
) -> GeneratorType:
    """Override to handle Hermes messages and block.

    ``handle_messages_async`` awaits this handler (publishing whatever it
    yields) before moving on. The base implementation yields a single
    ``None``, which results in nothing being published.
    """
    yield
async def on_raw_message(self, topic: str, payload: bytes):
    """Override to handle raw MQTT messages.

    Receives the topic and payload of each incoming MQTT message. The base
    implementation does nothing.
    """
    return None
# -------------------------------------------------------------------------
# MQTT Event Handlers
# -------------------------------------------------------------------------
def mqtt_on_connect(self, client, userdata, flags, rc):
    """Handle the MQTT client's connect callback.

    On a successful connection the client is marked connected, the set of
    active topics is reset, every previously known topic is queued again,
    pending subscriptions are issued via ``subscribe()``, and (when an event
    loop is known) ``mqtt_connected_event`` is set from that loop.

    A non-zero ``rc`` is treated as a refused connection: it is logged and
    no state is changed, so the connected event is only set on success.
    ``rc=None`` is accepted as success for callers that invoke this handler
    manually without a result code.

    Args:
        client: MQTT client that fired the callback (unused).
        userdata: MQTT client user data (unused).
        flags: Connection flags from the broker (unused).
        rc: Connection result code; ``0`` means success.

    Exceptions are logged and never propagate into the MQTT client.
    """
    try:
        if (rc is not None) and (rc != 0):
            self.logger.error("Failed to connect to MQTT broker (rc=%s)", rc)
            return

        self.is_connected = True
        self.logger.debug("Connected to MQTT broker")

        # subscribe_topics() iterates these sets while holding the lock, so
        # mutate them under the same lock. The lock is released before
        # subscribe() runs because subscribe_topics() acquires it itself.
        with self.subscribe_lock:
            # Clear topic cache
            self.subscribed_topics.clear()

            # Re-subscribe to everything if previously disconnected
            self.pending_mqtt_topics.update(self.all_mqtt_topics)

        # Handle subscriptions
        self.subscribe()

        if self.loop:
            self.loop.call_soon_threadsafe(self.mqtt_connected_event.set)
    except Exception:
        self.logger.exception("on_connect")
def mqtt_on_disconnect(self, client, userdata, flags=None, rc=None):
    """Handle the MQTT client's disconnect callback and try to reconnect.

    The trailing parameters are optional so the handler works both when it
    is invoked as ``(client, userdata, rc)`` and as
    ``(client, userdata, flags, rc)``. In the three-argument form the result
    code arrives in the ``flags`` position.

    Args:
        client: MQTT client that fired the callback (unused).
        userdata: MQTT client user data (unused).
        flags: Disconnect flags, or the result code in the three-argument
            form.
        rc: Disconnect result code (only used for logging).

    Exceptions are logged and never propagate into the MQTT client.
    """
    try:
        if rc is None:
            # Three-argument call: the result code is in the flags slot
            rc = flags

        # Record the state change first so it holds even if a later step fails
        self.is_connected = False
        self.logger.warning("Disconnected. Trying to reconnect...")
        self.logger.debug("Disconnect result code: %s", rc)

        if self.loop:
            self.loop.call_soon_threadsafe(self.mqtt_connected_event.clear)

        # Automatically reconnect
        self.mqtt_client.reconnect()
    except Exception:
        self.logger.exception("on_disconnect")
def mqtt_on_message(self, client, userdata, msg):
    """Handle a message delivered by the MQTT client.

    When both an event loop and the async input queue are available, the
    message is put on that queue from the loop via ``call_soon_threadsafe``.
    Otherwise it is stored in ``pre_queue`` so ``handle_messages_async`` can
    pick it up later. Exceptions are logged and not re-raised.
    """
    try:
        loop_ready = self.loop and self.in_queue
        if not loop_ready:
            # Event loop not running yet: park the message
            self.pre_queue.put(msg)
            return

        # Hand the message over to the event loop
        self.loop.call_soon_threadsafe(self.in_queue.put_nowait, msg)
    except Exception:
        self.logger.exception("on_message")
async def handle_messages_async(
    self, loop: typing.Optional[asyncio.AbstractEventLoop] = None
):
    """Consume incoming MQTT messages in the event loop and dispatch them.

    Sets ``self.loop`` (argument, existing value, or the running loop),
    creates the async input queue, moves anything waiting in ``pre_queue``
    onto it, and then processes messages until a ``None`` is received.

    For each MQTT message:
      * ``on_raw_message`` is started as a background task;
      * the message is parsed against ``subscribed_types``; results whose
        site id fails ``valid_site_id`` are skipped;
      * ``on_message_blocking`` is awaited (its output is published), then
        ``on_message`` is started as a background task.

    An exception while handling one message is logged and processing
    continues with the next message; it no longer terminates the loop.
    Background tasks are referenced until they finish (the event loop only
    keeps weak references to tasks) and their exceptions are logged.

    Args:
        loop: Event loop to remember for thread-safe hand-offs from the MQTT
            callbacks.

    NOTE(review): on Python 3.8+ ``asyncio.CancelledError`` is a
    ``BaseException`` distinct from ``concurrent.futures.CancelledError``,
    so cancelling this coroutine propagates out of the loop instead of
    reaching the ``except CancelledError`` branch.
    """
    self.loop = loop or self.loop or asyncio.get_running_loop()
    self.in_queue = asyncio.Queue()

    # Pull in messages from pre-queue (EAFP: stop when it reports empty)
    while True:
        try:
            self.in_queue.put_nowait(self.pre_queue.get_nowait())
        except queue.Empty:
            break

    # Strong references to fire-and-forget tasks so they cannot be garbage
    # collected before they complete.
    background_tasks: typing.Set[asyncio.Future] = set()

    def task_finished(task):
        """Drop a finished background task and log its failure, if any."""
        background_tasks.discard(task)
        if task.cancelled():
            return

        error = task.exception()
        if error is not None:
            self.logger.error(
                "handle_messages_async (background task)", exc_info=error
            )

    def spawn(coroutine):
        """Start a coroutine as a tracked background task."""
        task = asyncio.create_task(coroutine)
        background_tasks.add(task)
        task.add_done_callback(task_finished)

    # Main loop
    while True:
        try:
            mqtt_message = await self.in_queue.get()
            if mqtt_message is None:
                # Sentinel: stop handling messages
                break

            # Fire and forget
            spawn(self.on_raw_message(mqtt_message.topic, mqtt_message.payload))

            # Check against all known message types
            for message, site_id, session_id in HermesClient.parse_mqtt_message(
                mqtt_message.topic,
                mqtt_message.payload,
                self.subscribed_types,
                logger=self.logger,
            ):
                if not self.valid_site_id(site_id):
                    continue

                # Log messages
                if message.is_binary_payload():
                    # Class name + size (audio frames are not logged)
                    if not isinstance(message, (AudioFrame, AudioSessionFrame)):
                        self.logger.debug(
                            "<- %s(%s byte(s))",
                            message.__class__.__name__,
                            len(mqtt_message.payload),
                        )
                elif isinstance(message, (AsrTrain, NluTrain)):
                    # Just class name
                    self.logger.debug("<- %s", message.__class__.__name__)
                elif not isinstance(message, AudioSummary):
                    # Entire message
                    self.logger.debug("<- %s", message)

                # Publish all responses (blocking)
                await self.publish_all(
                    self.on_message_blocking(
                        message,
                        site_id=site_id,
                        session_id=session_id,
                        topic=mqtt_message.topic,
                    )
                )

                # Publish all responses (non-blocking)
                spawn(
                    self.publish_all(
                        self.on_message(
                            message,
                            site_id=site_id,
                            session_id=session_id,
                            topic=mqtt_message.topic,
                        )
                    )
                )
        except KeyboardInterrupt:
            break
        except CancelledError:
            break
        except Exception:
            # One bad message must not stop the client: log and keep going
            self.logger.exception("handle_messages_async")
@classmethod
def parse_mqtt_message(
    cls,
    topic: str,
    payload: typing.Union[str, bytes],
    subscribed_types: typing.Iterable[typing.Type[Message]],
    logger=None,
) -> typing.Iterable[
    typing.Tuple[Message, typing.Optional[str], typing.Optional[str]]
]:
    """Deserialize an MQTT message into a Hermes message object.

    The first type in ``subscribed_types`` whose ``is_topic(topic)`` is true
    is used; at most one ``(message, site_id, session_id)`` tuple is yielded.

    * Binary message types are constructed directly from ``payload``; the
      site id comes from the topic when the type carries it there.
    * Other types are parsed as JSON and built with ``from_dict``; the site
      id comes from the topic when present there, else from the ``siteId``
      key of the JSON payload.
    * The session id is taken from the topic when the type carries it there.

    Any exception is logged (to ``logger``, or the ``logging`` module when
    no logger is given) and ends the generator without raising.
    """
    try:
        for message_type in subscribed_types:
            if not message_type.is_topic(topic):
                continue

            site_id: typing.Optional[str] = None

            if message_type.is_binary_payload():
                # Binary payload
                if message_type.is_site_in_topic():
                    site_id = message_type.get_site_id(topic)

                # Assume payload is only argument to constructor
                message = message_type(payload)  # type: ignore
            else:
                # JSON payload
                json_payload = json.loads(payload)
                site_id = (
                    message_type.get_site_id(topic)
                    if message_type.is_site_in_topic()
                    else json_payload.get("siteId")
                )

                # Load from JSON
                message = message_type.from_dict(json_payload)

            session_id: typing.Optional[str] = (
                message_type.get_session_id(topic)
                if message_type.is_session_in_topic()
                else None
            )

            yield (message, site_id, session_id)

            # Assume only one message type will match
            return
    except Exception:
        (logger or logging).exception("parse_mqtt_message (topic=%s)", topic)
# -------------------------------------------------------------------------
# Publishing Messages
# -------------------------------------------------------------------------
def publish(self, message: Message, **topic_args):
    """Publish a Hermes message to MQTT.

    The topic is built with ``message.topic(**topic_args)`` and the payload
    with ``message.payload()``. A debug line is logged first, except for
    audio frames and audio summaries; training messages are logged by class
    name only. Exceptions are logged and not re-raised.
    """
    try:
        topic = message.topic(**topic_args)
        payload = message.payload()

        if message.is_binary_payload():
            # Don't log audio frames
            if not isinstance(message, (AudioFrame, AudioSessionFrame)):
                self.logger.debug(
                    "-> %s(%s byte(s)) to %s",
                    message.__class__.__name__,
                    len(payload),
                    topic,
                )
        else:
            # Training messages are logged by class name only; audio
            # summaries are not logged at all.
            name_only = isinstance(message, (AsrTrain, NluTrain))
            if name_only or (not isinstance(message, AudioSummary)):
                self.logger.debug(
                    "-> %s", message.__class__.__name__ if name_only else message
                )
                self.logger.debug(
                    "Publishing %s bytes(s) to %s", len(payload), topic
                )

        self.mqtt_client.publish(topic, payload)
    except Exception:
        self.logger.exception(
            "publish (message=%s, topic_args=%s)",
            message.__class__.__name__,
            topic_args,
        )
async def publish_all(self, async_generator: GeneratorType):
    """Enumerate all messages in an async generator and publish them.

    ``None`` items are skipped. A bare message is published without topic
    arguments; any other item is unpacked as ``(message, topic_args)`` and
    published with those keyword arguments.
    """
    async for item in async_generator:
        if item is None:
            continue

        if isinstance(item, Message):
            message, topic_kwargs = item, {}
        else:
            message, topic_kwargs = item

        self.publish(message, **topic_kwargs)
# -------------------------------------------------------------------------
# Utility Methods
# -------------------------------------------------------------------------
def valid_site_id(self, site_id: typing.Optional[str]):
    """True if site id is valid for this client.

    A missing/empty site id is always accepted, and so is any site id when
    the client has no site restriction. Otherwise the site id must be one
    of ``self.site_ids``.
    """
    if (not site_id) or (not self.site_ids):
        # Nothing to check against
        return True

    return site_id in self.site_ids
def convert_wav(
    self,
    wav_bytes: bytes,
    sample_rate: typing.Optional[int] = None,
    sample_width: typing.Optional[int] = None,
    channels: typing.Optional[int] = None,
) -> bytes:
    """Convert WAV data to the required format with sox and return raw audio.

    Args:
        wav_bytes: WAV data, fed to sox on standard input.
        sample_rate: Target sample rate (defaults to ``self.sample_rate``).
        sample_width: Target sample width; multiplied by 8 for sox's ``-b``
            option (defaults to ``self.sample_width``).
        channels: Target channel count (defaults to ``self.channels``).

    Returns:
        The bytes sox writes to standard output (raw, signed-integer audio).

    Raises:
        subprocess.CalledProcessError: If sox exits with a non-zero status.
    """
    target_rate = self.sample_rate if sample_rate is None else sample_rate
    target_width = self.sample_width if sample_width is None else sample_width
    target_channels = self.channels if channels is None else channels

    sox_command = [
        "sox",
        # Input: WAV from stdin
        "-t",
        "wav",
        "-",
        # Output format
        "-r",
        str(target_rate),
        "-e",
        "signed-integer",
        "-b",
        str(target_width * 8),
        "-c",
        str(target_channels),
        # Output: raw audio to stdout
        "-t",
        "raw",
        "-",
    ]

    completed = subprocess.run(
        sox_command, check=True, stdout=subprocess.PIPE, input=wav_bytes
    )
    return completed.stdout
def maybe_convert_wav(
    self,
    wav_bytes: bytes,
    sample_rate: typing.Optional[int] = None,
    sample_width: typing.Optional[int] = None,
    channels: typing.Optional[int] = None,
) -> bytes:
    """Convert WAV data to the required format if necessary; return raw audio.

    The WAV header is compared with the requested format (each value
    defaulting to the client's own setting). On any mismatch the data is
    passed to ``convert_wav``; otherwise the WAV's frames are returned as is.
    """
    wanted_format = (
        self.sample_rate if sample_rate is None else sample_rate,
        self.sample_width if sample_width is None else sample_width,
        self.channels if channels is None else channels,
    )

    with io.BytesIO(wav_bytes) as wav_io:
        with wave.open(wav_io, "rb") as wav_file:
            actual_format = (
                wav_file.getframerate(),
                wav_file.getsampwidth(),
                wav_file.getnchannels(),
            )

            if actual_format == wanted_format:
                # Already in the right format: return original audio
                return wav_file.readframes(wav_file.getnframes())

            # Return converted audio
            wanted_rate, wanted_width, wanted_channels = wanted_format
            return self.convert_wav(
                wav_bytes,
                sample_rate=wanted_rate,
                sample_width=wanted_width,
                channels=wanted_channels,
            )
def to_wav_bytes(
    self,
    audio_data: bytes,
    sample_rate: typing.Optional[int] = None,
    sample_width: typing.Optional[int] = None,
    channels: typing.Optional[int] = None,
) -> bytes:
    """Wrap raw audio data in a WAV container.

    Format values that are not given default to the client's own
    ``sample_rate``, ``sample_width`` and ``channels``.
    """
    rate = self.sample_rate if sample_rate is None else sample_rate
    width = self.sample_width if sample_width is None else sample_width
    channel_count = self.channels if channels is None else channels

    with io.BytesIO() as wav_buffer:
        with wave.open(wav_buffer, mode="wb") as wav_writer:
            wav_writer.setnchannels(channel_count)
            wav_writer.setsampwidth(width)
            wav_writer.setframerate(rate)
            wav_writer.writeframes(audio_data)

        # Read the buffer only after the writer has been closed
        return wav_buffer.getvalue()
def reduce_noise(
    self, audio_data: bytes, noise_profile: Path, amount: float = 0.5
) -> bytes:
    """Reduce noise in raw audio using a sox noise profile.

    The audio is described to sox using the client's ``sample_rate``,
    ``sample_width`` (times 8 for the bit depth) and ``channels``, piped
    through the ``noisered`` effect, and returned as raw audio.

    Args:
        audio_data: Raw audio fed to sox on standard input.
        noise_profile: Path passed to the ``noisered`` effect.
        amount: Amount argument passed to the ``noisered`` effect.

    Returns:
        The bytes sox writes to standard output.

    Raises:
        subprocess.CalledProcessError: If sox exits with a non-zero status.
    """
    # Description of the raw input stream
    input_format = [
        "-r",
        str(self.sample_rate),
        "-e",
        "signed-integer",
        "-b",
        str(self.sample_width * 8),
        "-c",
        str(self.channels),
        "-t",
        "raw",
        "-",
    ]
    # Raw output on stdout
    output_format = ["-t", "raw", "-"]
    effect = ["noisered", str(noise_profile), str(amount)]

    completed = subprocess.run(
        ["sox", *input_format, *output_format, *effect],
        check=True,
        stdout=subprocess.PIPE,
        input=audio_data,
    )
    return completed.stdout