Skip to content

Instantly share code, notes, and snippets.

@jsam
Created January 29, 2025 14:15
Show Gist options
  • Save jsam/e27856f8fbc9e788c8a0ff60506caeef to your computer and use it in GitHub Desktop.
Save jsam/e27856f8fbc9e788c8a0ff60506caeef to your computer and use it in GitHub Desktop.
server
#!/usr/bin/env python3
import socket
import numpy as np
import soundfile as sf
import time
import signal
import sys
import threading
import queue
from dataclasses import dataclass
from typing import List, Optional, Tuple
from collections import deque
import logging
import os
from datetime import datetime
###############################################################################
# Configuration
###############################################################################
# Central settings dictionary read by the logging setup, the audio helpers and
# AudioServer. Values are plain literals; nothing here is validated at import.
CONFIG = {
    # Server and Network
    "HOST": "0.0.0.0",  # Address the listening socket binds to (all IPv4 interfaces)
    "PORT": 8080,  # TCP port the listening socket binds to
    "BUFFER_SIZE": 32768, # NOTE(review): not referenced by the visible code; handle_client() sizes its reads from the stream's data rate instead
    "SOCKET_TIMEOUT": 0.5,  # Seconds accept() waits before the run loop re-checks is_running
    "SOCKET_RCVBUF": 2 * 1024 * 1024, # 2MB, requested via SO_RCVBUF on the listening and accepted sockets
    # Audio
    "TARGET_CHANNELS": 2,  # Channel count recordings are mixed to before saving
    "BYTES_PER_SAMPLE": 4, # 32-bit float
    "MAX_SAMPLE_VALUE": 1.0, # Not used in float normalization, but kept for clarity
    "DIAGNOSTIC_INTERVAL": 5.0, # Seconds between diagnostic prints
    # Logging
    "LOG_LEVEL": "INFO", # Options: DEBUG, INFO, WARNING, ERROR, CRITICAL
    # Paths
    "OUTPUT_DIR": "recordings"  # Directory recordings are written to; created on server start
}
###############################################################################
# Logging Setup
###############################################################################
def setup_logger():
    """
    Configure the root logger to write timestamped records to stdout.

    The threshold comes from CONFIG["LOG_LEVEL"] (case-insensitive); a name
    that is not an attribute of the logging module falls back to INFO.
    """
    level_name = CONFIG["LOG_LEVEL"].upper()
    logging.basicConfig(
        stream=sys.stdout,
        level=getattr(logging, level_name, logging.INFO),
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )


# Module-level logger shared by every class and function in this file.
logger = logging.getLogger(__name__)
###############################################################################
# Data Classes & Helper Structures
###############################################################################
@dataclass
class AudioConfig:
    """
    Holds the audio parameters a client announces when it connects.

    Raises:
        ValueError: on construction, if channels or sample_rate is not a
            positive number. Both values are later used as divisors, so
            rejecting them here gives one clear error instead of a
            ZeroDivisionError deep inside the receive path.
    """
    channels: int  # number of channels in each frame
    sample_rate: int  # frames per second (Hz)
    bytes_per_frame: int  # size of one frame: channels * bytes per sample

    def __post_init__(self) -> None:
        """Reject stream parameters that cannot describe real audio."""
        if self.channels <= 0:
            raise ValueError(f"Invalid channel count: {self.channels}")
        if self.sample_rate <= 0:
            raise ValueError(f"Invalid sample rate: {self.sample_rate}")

    @classmethod
    def from_client_data(cls, channel_data: bytes, sample_rate_data: bytes) -> 'AudioConfig':
        """
        Create AudioConfig from initial client connection data.

        The client is expected to send:
        - 4 bytes for channels
        - 4 bytes for sample_rate
        Both fields are decoded as little-endian signed integers.

        Args:
            channel_data: raw bytes of the channel-count field.
            sample_rate_data: raw bytes of the sample-rate field.

        Returns:
            An AudioConfig whose bytes_per_frame is
            channels * CONFIG["BYTES_PER_SAMPLE"].

        Raises:
            ValueError: if the decoded channel count or sample rate is not
                positive (the values come straight from the network).
        """
        channels = int.from_bytes(channel_data, byteorder='little', signed=True)
        sample_rate = int.from_bytes(sample_rate_data, byteorder='little', signed=True)
        return cls(
            channels=channels,
            sample_rate=sample_rate,
            bytes_per_frame=channels * CONFIG["BYTES_PER_SAMPLE"],
        )
class AudioBuffer:
    """
    Accumulates raw audio bytes and hands back whole frames as a NumPy array.
    """

    def __init__(self, config: AudioConfig):
        """
        Start with an empty buffer and log the stream parameters.

        Args:
            config: channel count and sample rate of the incoming stream.
        """
        sample_width = CONFIG["BYTES_PER_SAMPLE"]
        self.config = config
        self.buffer = bytearray()
        self.total_bytes_received = 0
        self.bytes_per_frame = config.channels * sample_width
        self.expected_bytes_per_sec = config.channels * config.sample_rate * sample_width

        logger.info("Initialized audio buffer:")
        logger.info(f" Channels: {config.channels}")
        logger.info(f" Sample rate: {config.sample_rate} Hz")
        logger.info(f" Bytes per frame: {self.bytes_per_frame}")
        logger.info(f" Expected data rate: {self.expected_bytes_per_sec/1024/1024:.2f} MB/s")

    def add_chunk(self, chunk: bytes) -> None:
        """
        Append a received chunk to the buffer and update the byte counter.

        Args:
            chunk: raw bytes as read from the socket; it does not need to end
                on a frame boundary.
        """
        received = len(chunk)
        self.total_bytes_received += received
        self.buffer.extend(chunk)

        buffered_frames = len(self.buffer) // self.bytes_per_frame
        buffered_seconds = buffered_frames / self.config.sample_rate
        logger.debug(
            f"Added {received} bytes. "
            f"Buffer length: {len(self.buffer)}. "
            f"Frames in buffer: {buffered_frames} (~{buffered_seconds:.2f}s)"
        )

    def get_audio_array(self) -> np.ndarray:
        """
        Decode all complete frames in the buffer as float32 samples.

        Returns:
            An array of shape (frames, channels). The decoded bytes are
            removed from the buffer; a trailing partial frame stays behind.
            An empty 1-D float32 array is returned when there is nothing to
            decode or when decoding fails (the failure is logged).
        """
        pending = len(self.buffer)
        if not pending:
            logger.warning("Empty buffer, no audio to process.")
            return np.array([], dtype=np.float32)

        frame_count, leftover = divmod(pending, self.bytes_per_frame)
        frame_bytes = pending - leftover
        if not frame_bytes:
            logger.warning("No complete frames in buffer.")
            return np.array([], dtype=np.float32)

        logger.info("Processing buffered audio:")
        logger.info(f" Total bytes received: {self.total_bytes_received/1024/1024:.2f} MB")
        logger.info(f" Complete frames: {frame_count}")
        logger.info(f" Usable bytes: {frame_bytes/1024/1024:.2f} MB")
        logger.info(f" Remainder bytes: {leftover}")

        try:
            # Slicing copies the bytes, so the array stays valid after the
            # buffer is trimmed below.
            flat = np.frombuffer(self.buffer[:frame_bytes], dtype=np.float32)
            sample_count = len(flat)
            # Verify sample count is divisible by channel count
            if sample_count % self.config.channels:
                raise ValueError(
                    f"Sample count {sample_count} not divisible by channel count {self.config.channels}."
                )
            frames = flat.reshape(-1, self.config.channels)
            seconds = len(frames) / self.config.sample_rate
            logger.info(f" Created array of shape {frames.shape}")
            logger.info(f" Audio duration: {seconds:.2f} seconds")
            # Drop processed bytes from the buffer
            del self.buffer[:frame_bytes]
            return frames
        except Exception as e:
            logger.error(f"Error processing audio data: {e}", exc_info=True)
            return np.array([], dtype=np.float32)
class AudioProcessor:
    """
    Stateless audio processing helpers (channel mixing and peak normalization).
    """

    @staticmethod
    def mix_channels(audio_array: np.ndarray, target_channels: int) -> np.ndarray:
        """
        Mix input channels to the target channel count.

        Downmixing splits the input channels into target_channels contiguous
        groups and averages each group; upmixing repeats the input channels
        cyclically.

        Args:
            audio_array: samples shaped (frames, channels); a 1-D array is
                treated as a single channel.
            target_channels: desired number of output channels (>= 1).

        Returns:
            An array shaped (frames, target_channels) with the input dtype.
            The input itself is returned when it is empty or already has the
            target channel count.

        Raises:
            ValueError: if target_channels is less than 1.
        """
        if audio_array.size == 0:
            return audio_array
        if target_channels < 1:
            raise ValueError(f"target_channels must be >= 1, got {target_channels}")
        if audio_array.ndim == 1:
            audio_array = audio_array.reshape(-1, 1)

        current_channels = audio_array.shape[1]
        if current_channels == target_channels:
            return audio_array

        if current_channels > target_channels:
            # Downmix: group channels and average them
            groups = np.array_split(np.arange(current_channels), target_channels)
            mixed = np.zeros((len(audio_array), target_channels), dtype=audio_array.dtype)
            for i, group in enumerate(groups):
                mixed[:, i] = np.mean(audio_array[:, group], axis=1)
            return mixed

        # Upmix: duplicate existing channels (output i takes input i % current)
        source_index = np.arange(target_channels) % current_channels
        return audio_array[:, source_index]

    @staticmethod
    def normalize_audio(audio_array: np.ndarray, headroom_db: float = 3.0) -> np.ndarray:
        """
        Scale audio so its peak sits headroom_db below full scale (1.0).

        The peak is measured over finite samples only. The data arrives as
        raw bytes from the network, so NaN/Inf bit patterns are possible; a
        NaN peak used to skip normalization silently and an Inf peak used to
        collapse the whole signal to zero. Non-finite samples themselves are
        left as they are.

        Args:
            audio_array: samples of any shape.
            headroom_db: distance in dB between the peak and full scale.

        Returns:
            The scaled array, or the input unchanged when it is empty, silent
            or contains no finite sample.
        """
        if audio_array.size == 0:
            return audio_array

        magnitudes = np.abs(audio_array)
        finite_magnitudes = magnitudes[np.isfinite(magnitudes)]
        if finite_magnitudes.size == 0:
            return audio_array

        max_val = finite_magnitudes.max()
        if max_val > 0:
            headroom_gain = 10 ** (-headroom_db / 20)
            return audio_array * (headroom_gain / max_val)
        return audio_array
class NetworkDiagnostics:
    """
    Tracks and reports network performance metrics for one stream.

    Chunk arrival times and sizes are kept only for the most recent 1000
    chunks (sliding windows); the byte total and the arrival time of the very
    first chunk cover the whole stream.
    """

    def __init__(self):
        self.chunk_timestamps = deque(maxlen=1000)
        self.chunk_sizes = deque(maxlen=1000)
        self.last_report = time.time()
        self.total_bytes = 0
        # Arrival time of the first chunk ever recorded. Kept separately
        # because chunk_timestamps forgets everything older than 1000 chunks.
        self.first_chunk_time: Optional[float] = None

    def add_chunk(self, size: int) -> None:
        """
        Record information about a received chunk.

        Args:
            size: number of bytes in the chunk.
        """
        now = time.time()
        if self.first_chunk_time is None:
            self.first_chunk_time = now
        self.chunk_timestamps.append(now)
        self.chunk_sizes.append(size)
        self.total_bytes += size

    def should_report(self) -> bool:
        """
        Check if it's time for a diagnostic report.

        Returns:
            True once CONFIG["DIAGNOSTIC_INTERVAL"] seconds have passed since
            the last generated report (or since construction).
        """
        return (time.time() - self.last_report) >= CONFIG["DIAGNOSTIC_INTERVAL"]

    def generate_report(self) -> str:
        """
        Generate and return a diagnostic report string.

        Chunk size, interval and jitter are computed over the sliding window.
        The data rate is the whole-stream average: total bytes divided by the
        time since the first chunk. Dividing by the age of the oldest
        *retained* timestamp would overstate the rate once more than 1000
        chunks have arrived.

        Returns:
            The multi-line report, or a short notice when fewer than two
            chunks have been recorded (last_report is not updated then).
        """
        if len(self.chunk_timestamps) < 2:
            return "Insufficient data for diagnostics"

        time_diffs = np.diff(list(self.chunk_timestamps))
        current_time = time.time()

        stream_start = self.first_chunk_time
        if stream_start is None:
            # Windows were filled without add_chunk(); use the oldest sample.
            stream_start = self.chunk_timestamps[0]
        elapsed = current_time - stream_start
        # Coarse clocks can report zero elapsed time for very fast bursts.
        data_rate_kb = self.total_bytes / elapsed / 1024 if elapsed > 0 else 0.0

        report = [
            "\nNetwork Diagnostics:",
            f" Total data received: {self.total_bytes / (1024*1024):.2f} MB",
            f" Average chunk size: {np.mean(self.chunk_sizes):.0f} bytes",
            f" Average chunk interval: {np.mean(time_diffs)*1000:.2f} ms",
            f" Network jitter: {np.std(time_diffs)*1000:.2f} ms",
            f" Data rate: {data_rate_kb:.1f} KB/s",
        ]
        self.last_report = current_time
        return "\n".join(report)
###############################################################################
# Main Server Class
###############################################################################
class AudioServer:
    """
    Main server class handling connections and audio recording.

    One client is served at a time. The stream handled by handle_client() is:
      1. 4 bytes channel count followed by 4 bytes sample rate
         (decoded by AudioConfig.from_client_data).
      2. Raw sample bytes, decoded as float32 by AudioBuffer once the
         stream has ended.
      3. Optionally the marker b"DONE"; closing the connection also ends
         the recording.
    """

    # In-band end-of-stream marker a client may send after its last frame.
    TERMINATION_MARKER = b"DONE"

    def __init__(self):
        """Install the signal handlers and make sure the output directory exists."""
        self.is_running = True
        self.setup_signal_handlers()
        self.setup_output_directory()

    def setup_signal_handlers(self) -> None:
        """
        Set up graceful shutdown handlers for SIGINT and SIGTERM.
        """
        signal.signal(signal.SIGINT, self.signal_handler)
        signal.signal(signal.SIGTERM, self.signal_handler)

    def setup_output_directory(self) -> None:
        """
        Create output directory for recordings (no error if it exists).
        """
        self.output_dir = CONFIG["OUTPUT_DIR"]
        os.makedirs(self.output_dir, exist_ok=True)

    def signal_handler(self, signum, frame) -> None:
        """
        Handle shutdown signals by clearing is_running.

        The accept loop and the receive loop both poll the flag, so they end
        after their current socket timeout.
        """
        logger.info("Shutdown signal received, stopping server...")
        self.is_running = False

    def configure_socket(self) -> socket.socket:
        """
        Create and configure the main server socket.

        Returns:
            An unbound TCP socket with SO_REUSEADDR set and the receive
            buffer size from CONFIG["SOCKET_RCVBUF"] requested.
        """
        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # Increase receive buffer if desired
        server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, CONFIG["SOCKET_RCVBUF"])
        return server_socket

    def save_recording(self, audio_array: np.ndarray, config: "AudioConfig") -> str:
        """
        Save processed audio to a 32-bit float WAV file in the output directory.

        The file name carries a timestamp with one-second resolution, so two
        recordings finished within the same second share a name.

        Args:
            audio_array: samples shaped (frames, channels).
            config: stream configuration; only sample_rate is used.

        Returns:
            Path of the written file.
        """
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        filename = os.path.join(self.output_dir, f"recording_{timestamp}.wav")
        frames, channels = audio_array.shape
        duration = frames / config.sample_rate
        size_mb = audio_array.nbytes / (1024 * 1024)
        logger.info(f"Saving audio to: {filename}")
        logger.info(f" Frames: {frames}")
        logger.info(f" Channels: {channels}")
        logger.info(f" Sample rate: {config.sample_rate}")
        logger.info(f" Duration: {duration:.2f} seconds")
        logger.info(f" Size: {size_mb:.2f} MB")
        sf.write(filename, audio_array, config.sample_rate, subtype='FLOAT')
        # Double-check the saved file
        info = sf.info(filename)
        logger.info(f"Verified saved file {filename}:")
        logger.info(f" Duration: {info.duration:.2f}s")
        logger.info(f" Channels: {info.channels}")
        logger.info(f" Sample rate: {info.samplerate}")
        return filename

    @staticmethod
    def _recv_exact(sock: socket.socket, size: int) -> bytes:
        """Read exactly `size` bytes; fewer are returned only if the peer closes first."""
        received = bytearray()
        while len(received) < size:
            # recv() may legally return less than requested on a TCP stream.
            piece = sock.recv(size - len(received))
            if not piece:
                break
            received.extend(piece)
        return bytes(received)

    @classmethod
    def _split_termination(cls, data: bytes, bytes_before: int,
                           bytes_per_frame: int) -> Tuple[bytes, bool]:
        """Split a received chunk into (audio payload, end-of-stream flag)."""
        marker = cls.TERMINATION_MARKER
        if data == marker:
            return b"", True
        if data.endswith(marker) and bytes_per_frame > 0:
            # TCP may deliver the marker glued to the last audio bytes. Only
            # accept it when the audio before it ends on a frame boundary.
            # NOTE(review): the marker is in-band, so a sample whose bytes
            # spell "DONE" at that position is indistinguishable from it.
            audio_len = len(data) - len(marker)
            if (bytes_before + audio_len) % bytes_per_frame == 0:
                return data[:audio_len], True
        return data, False

    def handle_client(self, client_socket: socket.socket, client_address: Tuple[str, int]) -> None:
        """
        Handle individual client connection in a blocking manner (single-client server).

        Reads the 8-byte configuration header, receives audio until the
        client sends the termination marker, closes the connection, an error
        occurs or the server is asked to stop, then mixes, normalizes and
        saves whatever was received. All errors are logged, never raised;
        the client socket is always closed.

        Args:
            client_socket: connected socket returned by accept().
            client_address: peer address, used for logging only.
        """
        logger.info(f"Client connected from {client_address}")
        start_time = time.time()
        try:
            # 1. Receive audio configuration (channels & sample_rate).
            # A single recv(4) may return fewer bytes, so read exactly 8.
            header = self._recv_exact(client_socket, 8)
            if len(header) != 8:
                raise ValueError("Incomplete audio configuration received from client.")
            config = AudioConfig.from_client_data(header[:4], header[4:])
            # Both values are divisors below and come straight off the wire.
            if config.channels <= 0 or config.sample_rate <= 0:
                raise ValueError(
                    f"Invalid audio configuration from client: "
                    f"channels={config.channels}, sample_rate={config.sample_rate}"
                )
            expected_bps = config.channels * config.sample_rate * CONFIG["BYTES_PER_SAMPLE"]
            logger.info("Audio config from client:")
            logger.info(f" Channels: {config.channels}")
            logger.info(f" Sample rate: {config.sample_rate}")
            logger.info(f" Bytes/frame: {config.bytes_per_frame}")
            logger.info(f" Expected data rate: {expected_bps/1024/1024:.2f} MB/s")

            # 2. Initialize components
            audio_buffer = AudioBuffer(config)
            diagnostics = NetworkDiagnostics()

            # 3. Optimize client socket
            optimal_buffer = max(131072, expected_bps // 4)  # ~250ms worth of data
            socket_buffer = optimal_buffer * 4  # ~1 second worth of data
            client_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, socket_buffer)
            client_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            client_socket.settimeout(0.1)  # 100ms read timeout for responsiveness
            logger.info(f"Socket configured with a {optimal_buffer} bytes read chunk "
                        f"(socket buffer: {socket_buffer/1024/1024:.1f} MB).")
            logger.info(f"Minimal needed rate: {expected_bps/1024/1024:.1f} MB/s to keep up.")

            # 4. Main receive loop
            bytes_received = 0
            while self.is_running:
                try:
                    data = client_socket.recv(optimal_buffer)
                    if not data:
                        # An empty read is end-of-file: the peer closed the
                        # connection. A mere stall raises socket.timeout.
                        logger.info("Stream ended (client closed the connection).")
                        break

                    # Check for special termination signal
                    payload, finished = self._split_termination(
                        data, bytes_received, config.bytes_per_frame
                    )
                    if payload:
                        data_size = len(payload)
                        bytes_received += data_size
                        # Calculate approximate duration of audio
                        current_duration = bytes_received / expected_bps
                        logger.debug(
                            f"Received {data_size} bytes, total duration: {current_duration:.2f}s so far."
                        )
                        # Add data to buffer & update diagnostics
                        audio_buffer.add_chunk(payload)
                        diagnostics.add_chunk(data_size)
                        # Periodic diagnostics
                        if diagnostics.should_report():
                            report = diagnostics.generate_report()
                            logger.info(report)
                            expected_duration = time.time() - start_time
                            logger.info(
                                f"Expected vs Actual Duration: "
                                f"{expected_duration:.1f}s vs {current_duration:.1f}s"
                            )
                    if finished:
                        logger.info("Received termination signal from client.")
                        break
                except socket.timeout:
                    # No data in this period; loop again
                    continue
                except Exception as e:
                    # Stop receiving but still save what arrived so far.
                    logger.error(f"Error in receive loop: {e}", exc_info=True)
                    break

            # 5. Final processing after the loop
            logger.info(f"Client recording ended after {time.time() - start_time:.2f} seconds.")
            audio_array = audio_buffer.get_audio_array()
            if audio_array.size > 0:
                logger.info(f"Raw audio array shape: {audio_array.shape}")
                logger.info(f"Raw duration: {len(audio_array)/config.sample_rate:.2f}s")
                # Mix channels to TARGET_CHANNELS
                mixed_array = AudioProcessor.mix_channels(audio_array, CONFIG["TARGET_CHANNELS"])
                logger.info(f"Mixed audio array shape: {mixed_array.shape}")
                # Normalize
                normalized_array = AudioProcessor.normalize_audio(mixed_array)
                logger.info(f"Final audio array shape: {normalized_array.shape}")
                # Save to file
                self.save_recording(normalized_array, config)
            else:
                logger.warning("No audio data was received or buffer is empty.")
        except Exception as e:
            logger.error(f"Error handling client: {e}", exc_info=True)
        finally:
            client_socket.close()
            logger.info(f"Client {client_address} disconnected.")

    def run(self) -> None:
        """
        Main server loop: listens for a single client, handles connection,
        then waits for the next.

        accept() uses CONFIG["SOCKET_TIMEOUT"] so the loop can notice a
        shutdown request. The listening socket is always closed on exit.
        """
        server_socket = self.configure_socket()
        try:
            server_socket.bind((CONFIG["HOST"], CONFIG["PORT"]))
            server_socket.listen(1)
            server_socket.settimeout(CONFIG["SOCKET_TIMEOUT"])
            logger.info(f"Server listening on {CONFIG['HOST']}:{CONFIG['PORT']}")
            while self.is_running:
                try:
                    client_socket, client_address = server_socket.accept()
                    # Additional optimization for the accepted socket
                    client_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                    client_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, CONFIG["SOCKET_RCVBUF"])
                    # For simplicity, we handle one client at a time, blocking
                    self.handle_client(client_socket, client_address)
                except socket.timeout:
                    # Check if we should stop
                    continue
                except Exception as e:
                    logger.error(f"Error accepting connection: {e}", exc_info=True)
        except Exception as e:
            logger.error(f"Server error: {e}", exc_info=True)
        finally:
            server_socket.close()
            logger.info("Server shutdown complete.")
###############################################################################
# Main Entry Point
###############################################################################
def main():
    """
    Configure logging, then start the audio server and serve until it stops.
    """
    setup_logger()
    logger.info("Starting Professional Audio Recording Server...")
    AudioServer().run()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment