Created
January 29, 2025 14:15
-
-
Save jsam/e27856f8fbc9e788c8a0ff60506caeef to your computer and use it in GitHub Desktop.
server
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import socket | |
import numpy as np | |
import soundfile as sf | |
import time | |
import signal | |
import sys | |
import threading | |
import queue | |
from dataclasses import dataclass | |
from typing import List, Optional, Tuple | |
from collections import deque | |
import logging | |
import os | |
from datetime import datetime | |
###############################################################################
# Configuration
###############################################################################
# Central table of tunables; components look values up by key.
CONFIG = dict(
    # --- Server / network ---------------------------------------------------
    HOST="0.0.0.0",                  # bind address (all interfaces)
    PORT=8080,
    BUFFER_SIZE=32768,               # nominal read size in bytes; NOTE: handle_client() currently derives its read size from the stream rate instead
    SOCKET_TIMEOUT=0.5,              # accept() timeout so the run loop can poll is_running
    SOCKET_RCVBUF=2 * 1024 * 1024,   # 2MB kernel receive buffer
    # --- Audio --------------------------------------------------------------
    TARGET_CHANNELS=2,               # channel count of the saved recording
    BYTES_PER_SAMPLE=4,              # 32-bit float samples
    MAX_SAMPLE_VALUE=1.0,            # kept for reference; normalization does not read it
    DIAGNOSTIC_INTERVAL=5.0,         # seconds between network diagnostic reports
    # --- Logging ------------------------------------------------------------
    LOG_LEVEL="INFO",                # DEBUG, INFO, WARNING, ERROR or CRITICAL
    # --- Paths --------------------------------------------------------------
    OUTPUT_DIR="recordings",
)
###############################################################################
# Logging Setup
###############################################################################
def setup_logger():
    """
    Configure root logging to stdout at the level named by CONFIG["LOG_LEVEL"].

    The level name is matched case-insensitively against the ``logging``
    module's level constants; an unknown name falls back to INFO.
    """
    level_name = CONFIG["LOG_LEVEL"].upper()
    logging.basicConfig(
        stream=sys.stdout,
        level=getattr(logging, level_name, logging.INFO),
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )


# Module-wide logger used by every component below.
logger = logging.getLogger(__name__)
###############################################################################
# Data Classes & Helper Structures
###############################################################################
@dataclass
class AudioConfig:
    """
    Holds audio configuration parameters negotiated with a client.
    """
    channels: int  # number of interleaved channels in the stream
    sample_rate: int  # frames per second
    bytes_per_frame: int  # channels * CONFIG["BYTES_PER_SAMPLE"]

    @classmethod
    def from_client_data(cls, channel_data: bytes, sample_rate_data: bytes) -> 'AudioConfig':
        """
        Create AudioConfig from initial client connection data.

        The client is expected to send two signed little-endian integers:
          - 4 bytes for channels
          - 4 bytes for sample_rate

        Args:
            channel_data: The 4 raw bytes encoding the channel count.
            sample_rate_data: The 4 raw bytes encoding the sample rate.

        Returns:
            The decoded configuration.

        Raises:
            ValueError: If either field is not exactly 4 bytes long, or the
                decoded channel count or sample rate is not positive. Both
                values are later used as divisors (frame size, durations),
                so zero or negative values are rejected here.
        """
        if len(channel_data) != 4 or len(sample_rate_data) != 4:
            raise ValueError(
                f"Audio configuration fields must be 4 bytes each, got "
                f"{len(channel_data)} and {len(sample_rate_data)}."
            )
        channels = int.from_bytes(channel_data, byteorder='little', signed=True)
        sample_rate = int.from_bytes(sample_rate_data, byteorder='little', signed=True)
        if channels <= 0:
            raise ValueError(f"Invalid channel count from client: {channels}")
        if sample_rate <= 0:
            raise ValueError(f"Invalid sample rate from client: {sample_rate}")
        return cls(
            channels=channels,
            sample_rate=sample_rate,
            bytes_per_frame=channels * CONFIG["BYTES_PER_SAMPLE"]
        )
class AudioBuffer:
    """
    Handles audio data buffering and frame alignment.

    Raw bytes from the network are appended unchanged. Conversion to a NumPy
    array consumes only whole frames (channels * CONFIG["BYTES_PER_SAMPLE"]
    bytes each); a trailing partial frame stays in the buffer.
    """

    def __init__(self, config: AudioConfig):
        """
        Args:
            config: Stream parameters for the connection.

        Raises:
            ValueError: If config.channels or config.sample_rate is not
                positive; both are used as divisors when computing frame
                counts and durations.
        """
        if config.channels <= 0 or config.sample_rate <= 0:
            raise ValueError(
                f"Invalid audio config: channels={config.channels}, "
                f"sample_rate={config.sample_rate}"
            )
        self.config = config
        self.buffer = bytearray()
        self.total_bytes_received = 0
        self.bytes_per_frame = config.channels * CONFIG["BYTES_PER_SAMPLE"]
        self.expected_bytes_per_sec = config.channels * config.sample_rate * CONFIG["BYTES_PER_SAMPLE"]

        logger.info("Initialized audio buffer:")
        logger.info(f" Channels: {config.channels}")
        logger.info(f" Sample rate: {config.sample_rate} Hz")
        logger.info(f" Bytes per frame: {self.bytes_per_frame}")
        logger.info(f" Expected data rate: {self.expected_bytes_per_sec/1024/1024:.2f} MB/s")

    def add_chunk(self, chunk: bytes) -> None:
        """
        Add a new chunk of audio data to the buffer.

        This runs once per socket read, so the debug message uses lazy
        %-style arguments and is only formatted when DEBUG is enabled.
        """
        chunk_size = len(chunk)
        self.total_bytes_received += chunk_size
        self.buffer.extend(chunk)

        frames_in_buffer = len(self.buffer) // self.bytes_per_frame
        logger.debug(
            "Added %d bytes. Buffer length: %d. Frames in buffer: %d (~%.2fs)",
            chunk_size,
            len(self.buffer),
            frames_in_buffer,
            frames_in_buffer / self.config.sample_rate,
        )

    def get_audio_array(self) -> np.ndarray:
        """
        Convert buffered data to a float32 NumPy array of shape (frames, channels).

        Only complete frames are processed; leftover bytes remain in the buffer.

        Returns:
            The decoded audio. When there is no complete frame, or conversion
            fails, an empty array of shape (0, channels) is returned, so the
            result is always 2-D.

        Side effects:
            On success, the consumed bytes are removed from ``self.buffer``.
        """
        channels = self.config.channels
        empty = np.empty((0, channels), dtype=np.float32)

        buffer_size = len(self.buffer)
        if buffer_size == 0:
            logger.warning("Empty buffer, no audio to process.")
            return empty

        complete_frames = buffer_size // self.bytes_per_frame
        usable_bytes = complete_frames * self.bytes_per_frame
        if usable_bytes == 0:
            logger.warning("No complete frames in buffer.")
            return empty

        logger.info("Processing buffered audio:")
        logger.info(f" Total bytes received: {self.total_bytes_received/1024/1024:.2f} MB")
        logger.info(f" Complete frames: {complete_frames}")
        logger.info(f" Usable bytes: {usable_bytes/1024/1024:.2f} MB")
        logger.info(f" Remainder bytes: {buffer_size - usable_bytes}")

        try:
            # Slicing a bytearray makes a copy, so the array does not alias
            # self.buffer and the `del` below is safe.
            data = np.frombuffer(self.buffer[:usable_bytes], dtype=np.float32)
            # usable_bytes is a whole number of frames, so the sample count
            # is always a multiple of the channel count.
            audio = data.reshape(-1, channels)
        except ValueError as e:
            logger.error(f"Error processing audio data: {e}", exc_info=True)
            return empty

        duration = len(audio) / self.config.sample_rate
        logger.info(f" Created array of shape {audio.shape}")
        logger.info(f" Audio duration: {duration:.2f} seconds")

        # Drop processed bytes from the buffer; any partial frame is kept.
        del self.buffer[:usable_bytes]
        return audio
class AudioProcessor:
    """
    Stateless helpers that reshape and rescale captured audio arrays.
    """

    @staticmethod
    def mix_channels(audio_array: np.ndarray, target_channels: int) -> np.ndarray:
        """
        Remap ``audio_array`` to ``target_channels`` channels.

        * Empty input is returned unchanged.
        * 1-D input is treated as a single (mono) channel.
        * Downmix: the source channel indices are split into
          ``target_channels`` contiguous groups with ``np.array_split`` and
          each group is averaged into one output channel.
        * Upmix: output channel ``k`` copies source channel
          ``k % source_channels``.

        The output keeps the input dtype.
        """
        if not audio_array.size:
            return audio_array

        frames = audio_array.reshape(-1, 1) if audio_array.ndim == 1 else audio_array
        source_channels = frames.shape[1]

        if source_channels == target_channels:
            return frames

        if source_channels < target_channels:
            upmixed = np.zeros((len(frames), target_channels), dtype=frames.dtype)
            for out_ch in range(target_channels):
                upmixed[:, out_ch] = frames[:, out_ch % source_channels]
            return upmixed

        channel_groups = np.array_split(np.arange(source_channels), target_channels)
        downmixed = np.zeros((len(frames), target_channels), dtype=frames.dtype)
        for out_ch, members in enumerate(channel_groups):
            downmixed[:, out_ch] = frames[:, members].mean(axis=1)
        return downmixed

    @staticmethod
    def normalize_audio(audio_array: np.ndarray, headroom_db: float = 3.0) -> np.ndarray:
        """
        Scale audio so its absolute peak sits ``headroom_db`` dB below 1.0.

        Empty input, and input whose peak is not strictly positive (all zeros
        or a NaN peak), is returned unchanged.
        """
        if not audio_array.size:
            return audio_array

        peak = np.max(np.abs(audio_array))
        if not peak > 0:
            return audio_array

        target_peak = 10 ** (-headroom_db / 20)
        return audio_array * (target_peak / peak)
class NetworkDiagnostics:
    """
    Tracks and reports network performance metrics.

    Per-chunk timestamps and sizes are kept in bounded windows (the most
    recent 1000 chunks) for interval/jitter statistics. Totals and the
    average data rate cover the whole stream.
    """

    def __init__(self):
        self.chunk_timestamps = deque(maxlen=1000)
        self.chunk_sizes = deque(maxlen=1000)
        self.last_report = time.time()
        self.total_bytes = 0
        # Arrival time of the very first chunk. The windowed deques drop old
        # entries after 1000 chunks, so they cannot anchor a whole-stream
        # data rate computed from the cumulative total_bytes.
        self.first_chunk_time = None

    def add_chunk(self, size: int) -> None:
        """
        Record information about a received chunk of ``size`` bytes.
        """
        now = time.time()
        if self.first_chunk_time is None:
            self.first_chunk_time = now
        self.chunk_timestamps.append(now)
        self.chunk_sizes.append(size)
        self.total_bytes += size

    def should_report(self) -> bool:
        """
        Check if at least CONFIG["DIAGNOSTIC_INTERVAL"] seconds have passed
        since the last generated report.
        """
        return (time.time() - self.last_report) >= CONFIG["DIAGNOSTIC_INTERVAL"]

    def generate_report(self) -> str:
        """
        Generate and return a diagnostic report string.

        Chunk size, interval and jitter are computed over the recent window.
        Total data and data rate are computed over the whole stream. Resets
        the report timer when a full report is produced.
        """
        if len(self.chunk_timestamps) < 2:
            return "Insufficient data for diagnostics"

        time_diffs = np.diff(list(self.chunk_timestamps))
        current_time = time.time()

        elapsed = current_time - self.first_chunk_time
        # Guard against coarse clocks giving identical timestamps.
        data_rate_kbs = self.total_bytes / elapsed / 1024 if elapsed > 0 else 0.0

        report = [
            "\nNetwork Diagnostics:",
            f" Total data received: {self.total_bytes / (1024*1024):.2f} MB",
            f" Average chunk size: {np.mean(self.chunk_sizes):.0f} bytes",
            f" Average chunk interval: {np.mean(time_diffs)*1000:.2f} ms",
            f" Network jitter: {np.std(time_diffs)*1000:.2f} ms",
            f" Data rate: {data_rate_kbs:.1f} KB/s"
        ]
        self.last_report = current_time
        return "\n".join(report)
###############################################################################
# Main Server Class
###############################################################################
class AudioServer:
    """
    Main server class handling connections and audio recording.

    Clients are served one at a time. Each client sends an 8-byte header
    (channel count, sample rate), then streams float32 sample bytes. The
    stream ends when the client closes the connection, sends a chunk that is
    exactly b"DONE", or the server is asked to stop. The captured audio is
    mixed to CONFIG["TARGET_CHANNELS"], normalized, and written as a WAV file
    in CONFIG["OUTPUT_DIR"].
    """

    # Read timeout (seconds) while waiting for the connection header, so the
    # header read can re-check is_running instead of blocking forever on a
    # client that never sends anything.
    HEADER_POLL_TIMEOUT = 0.5

    def __init__(self):
        self.is_running = True
        self.setup_signal_handlers()
        self.setup_output_directory()

    def setup_signal_handlers(self) -> None:
        """
        Set up graceful shutdown handlers for SIGINT and SIGTERM.
        """
        signal.signal(signal.SIGINT, self.signal_handler)
        signal.signal(signal.SIGTERM, self.signal_handler)

    def setup_output_directory(self) -> None:
        """
        Create output directory for recordings (no error if it already exists).
        """
        self.output_dir = CONFIG["OUTPUT_DIR"]
        os.makedirs(self.output_dir, exist_ok=True)

    def signal_handler(self, signum, frame) -> None:
        """
        Handle shutdown signals by clearing ``is_running``.

        The accept and receive loops poll this flag between socket timeouts
        and exit on their next iteration.
        """
        logger.info("Shutdown signal received, stopping server...")
        self.is_running = False

    def configure_socket(self) -> socket.socket:
        """
        Create and configure the main server socket.

        Raises:
            OSError: If an option cannot be set; the socket is closed first.
        """
        server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            # Increase receive buffer if desired
            server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, CONFIG["SOCKET_RCVBUF"])
        except OSError:
            server_socket.close()
            raise
        return server_socket

    def save_recording(self, audio_array: np.ndarray, config: AudioConfig) -> str:
        """
        Save processed audio to a WAV file in the output directory.

        The name is ``recording_<YYYYmmdd-HHMMSS>.wav``. If that file already
        exists (two recordings finished within the same second), a numeric
        suffix is appended rather than overwriting it.

        Args:
            audio_array: 2-D array of shape (frames, channels).
            config: Stream configuration supplying the sample rate.

        Returns:
            The path of the written file.
        """
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        filename = os.path.join(self.output_dir, f"recording_{timestamp}.wav")
        suffix = 1
        while os.path.exists(filename):
            filename = os.path.join(self.output_dir, f"recording_{timestamp}_{suffix}.wav")
            suffix += 1

        frames, channels = audio_array.shape
        duration = frames / config.sample_rate
        size_mb = audio_array.nbytes / (1024 * 1024)

        logger.info(f"Saving audio to: {filename}")
        logger.info(f" Frames: {frames}")
        logger.info(f" Channels: {channels}")
        logger.info(f" Sample rate: {config.sample_rate}")
        logger.info(f" Duration: {duration:.2f} seconds")
        logger.info(f" Size: {size_mb:.2f} MB")

        sf.write(filename, audio_array, config.sample_rate, subtype='FLOAT')

        # Double-check the saved file
        info = sf.info(filename)
        logger.info(f"Verified saved file {filename}:")
        logger.info(f" Duration: {info.duration:.2f}s")
        logger.info(f" Channels: {info.channels}")
        logger.info(f" Sample rate: {info.samplerate}")
        return filename

    def _recv_exact(self, sock: socket.socket, num_bytes: int) -> bytes:
        """
        Read exactly ``num_bytes`` from ``sock``.

        TCP is a byte stream, so one recv() may return fewer bytes than
        requested. This loops until enough bytes arrive, the peer closes the
        connection, or shutdown is requested. Read timeouts are treated as
        "poll is_running again".

        Returns:
            The received bytes. The result is shorter than ``num_bytes``
            only if the peer closed the connection or the server is stopping.
        """
        received = bytearray()
        while len(received) < num_bytes and self.is_running:
            try:
                part = sock.recv(num_bytes - len(received))
            except socket.timeout:
                continue
            if not part:
                break
            received.extend(part)
        return bytes(received)

    def handle_client(self, client_socket: socket.socket, client_address: Tuple[str, int]) -> None:
        """
        Handle individual client connection in a blocking manner (single-client server).

        Errors are logged rather than raised, and the client socket is always
        closed on return.
        """
        logger.info(f"Client connected from {client_address}")
        start_time = time.time()
        try:
            # 1. Receive audio configuration (channels & sample_rate). Each
            #    field is read fully since TCP may split the header, using a
            #    polling timeout so shutdown is still noticed.
            client_socket.settimeout(self.HEADER_POLL_TIMEOUT)
            channel_data = self._recv_exact(client_socket, 4)
            sample_rate_data = self._recv_exact(client_socket, 4)
            if not self.is_running:
                logger.info("Shutdown requested before client sent its configuration.")
                return
            if len(channel_data) != 4 or len(sample_rate_data) != 4:
                raise ValueError("Incomplete audio configuration received from client.")

            config = AudioConfig.from_client_data(channel_data, sample_rate_data)
            if config.channels <= 0 or config.sample_rate <= 0:
                # expected_bps is used as a divisor below.
                raise ValueError(
                    f"Invalid audio configuration from client: "
                    f"channels={config.channels}, sample_rate={config.sample_rate}"
                )
            expected_bps = config.channels * config.sample_rate * CONFIG["BYTES_PER_SAMPLE"]

            logger.info("Audio config from client:")
            logger.info(f" Channels: {config.channels}")
            logger.info(f" Sample rate: {config.sample_rate}")
            logger.info(f" Bytes/frame: {config.bytes_per_frame}")
            logger.info(f" Expected data rate: {expected_bps/1024/1024:.2f} MB/s")

            # 2. Initialize components
            audio_buffer = AudioBuffer(config)
            diagnostics = NetworkDiagnostics()

            # 3. Optimize client socket
            optimal_buffer = max(131072, expected_bps // 4)  # ~250ms worth of data, at least 128 KiB
            socket_buffer = optimal_buffer * 4  # ~1 second worth of data
            client_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, socket_buffer)
            client_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            client_socket.settimeout(0.1)  # 100ms read timeout for responsiveness

            logger.info(f"Socket configured with a {optimal_buffer} bytes read chunk "
                        f"(socket buffer: {socket_buffer/1024/1024:.1f} MB).")
            logger.info(f"Minimal needed rate: {expected_bps/1024/1024:.1f} MB/s to keep up.")

            # 4. Main receive loop. Errors inside the loop end the stream but
            #    still fall through to step 5, so already-received audio is saved.
            bytes_received = 0
            while self.is_running:
                try:
                    data = client_socket.recv(optimal_buffer)
                    if not data:
                        # recv() returns b"" only after the peer closed its
                        # side, and keeps doing so; stop instead of spinning.
                        logger.info("Stream ended (client closed the connection).")
                        break

                    # Special termination signal. NOTE: only recognized when it
                    # arrives as its own recv() chunk; TCP may coalesce it with
                    # preceding audio bytes.
                    if data == b"DONE":
                        logger.info("Received termination signal from client.")
                        break

                    data_size = len(data)
                    bytes_received += data_size

                    # Calculate approximate duration of audio
                    current_duration = bytes_received / expected_bps
                    logger.debug(
                        "Received %d bytes, total duration: %.2fs so far.",
                        data_size, current_duration,
                    )

                    # Add data to buffer & update diagnostics
                    audio_buffer.add_chunk(data)
                    diagnostics.add_chunk(data_size)

                    # Periodic diagnostics
                    if diagnostics.should_report():
                        logger.info(diagnostics.generate_report())
                        expected_duration = time.time() - start_time
                        logger.info(
                            f"Expected vs Actual Duration: "
                            f"{expected_duration:.1f}s vs {current_duration:.1f}s"
                        )
                except socket.timeout:
                    # No data in this period; loop again so is_running is re-checked.
                    continue
                except Exception as e:
                    logger.error(f"Error in receive loop: {e}", exc_info=True)
                    break

            # 5. Final processing after the loop
            logger.info(f"Client recording ended after {time.time() - start_time:.2f} seconds.")
            audio_array = audio_buffer.get_audio_array()

            if audio_array.size > 0:
                logger.info(f"Raw audio array shape: {audio_array.shape}")
                logger.info(f"Raw duration: {len(audio_array)/config.sample_rate:.2f}s")

                # Mix channels to TARGET_CHANNELS
                mixed_array = AudioProcessor.mix_channels(audio_array, CONFIG["TARGET_CHANNELS"])
                logger.info(f"Mixed audio array shape: {mixed_array.shape}")

                # Normalize
                normalized_array = AudioProcessor.normalize_audio(mixed_array)
                logger.info(f"Final audio array shape: {normalized_array.shape}")

                # Save to file
                self.save_recording(normalized_array, config)
            else:
                logger.warning("No audio data was received or buffer is empty.")

        except Exception as e:
            logger.error(f"Error handling client: {e}", exc_info=True)
        finally:
            client_socket.close()
            logger.info(f"Client {client_address} disconnected.")

    def run(self) -> None:
        """
        Main server loop: listens for a single client, handles connection,
        then waits for the next.

        The accept timeout (CONFIG["SOCKET_TIMEOUT"]) lets the loop re-check
        ``is_running`` periodically so shutdown signals take effect.
        """
        server_socket = self.configure_socket()
        try:
            server_socket.bind((CONFIG["HOST"], CONFIG["PORT"]))
            server_socket.listen(1)
            server_socket.settimeout(CONFIG["SOCKET_TIMEOUT"])
            logger.info(f"Server listening on {CONFIG['HOST']}:{CONFIG['PORT']}")

            while self.is_running:
                try:
                    client_socket, client_address = server_socket.accept()
                except socket.timeout:
                    # Check if we should stop
                    continue
                except Exception as e:
                    logger.error(f"Error accepting connection: {e}", exc_info=True)
                    continue
                # handle_client() sets TCP_NODELAY / SO_RCVBUF itself and always
                # closes the socket, so nothing is configured here where a
                # failure could leak it. One client at a time, blocking.
                self.handle_client(client_socket, client_address)

        except Exception as e:
            logger.error(f"Server error: {e}", exc_info=True)
        finally:
            server_socket.close()
            logger.info("Server shutdown complete.")
###############################################################################
# Main Entry Point
###############################################################################
def main():
    """
    Configure logging, then run an AudioServer until it is told to stop.

    Logging is set up first so the start-up message is emitted; the server's
    run() call blocks until shutdown.
    """
    setup_logger()
    logger.info("Starting Professional Audio Recording Server...")
    AudioServer().run()


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment