Last active
April 7, 2024 22:53
-
-
Save autumnjolitz/08d0774826737dcf267f80991252e3ea to your computer and use it in GitHub Desktop.
Subscribe sockets to multicast groups, allow send/recv
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
""" | |
multicast - subscribe/send/recv for IPv4/IPv6 multicast groups with UDP. | |
:author: Autumn Jolitz | |
:date: 2024-04-06 | |
:license: BSD-2-Clause | |
:tags: networking, multicast, sockets, udp | |
Copyright (c) 2024, Autumn Jolitz | |
Redistribution and use in source and binary forms, with or without modification, are permitted | |
provided that the following conditions are met: | |
Redistributions of source code must retain the above copyright notice, this list of | |
conditions and the following disclaimer. | |
Redistributions in binary form must reproduce the above copyright notice, this list of | |
conditions and the following disclaimer in the documentation and/or other materials | |
provided with the distribution. | |
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS | |
OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY | |
AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER | |
OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | |
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY | |
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR | |
OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY | |
OF SUCH DAMAGE | |
""" | |
import ipaddress | |
import socket | |
import struct | |
import weakref | |
from abc import abstractmethod, abstractproperty | |
from collections.abc import Iterable, Buffer | |
from contextlib import ExitStack, suppress | |
from enum import IntEnum | |
from ipaddress import IPv4Address, IPv6Address | |
from typing import ( | |
assert_never, | |
overload, | |
ClassVar, | |
Literal, | |
NamedTuple, | |
Sized as ImplementsSized, | |
TYPE_CHECKING, | |
) | |
IPAddress = IPv4Address | IPv6Address | |
MaybeIPAddress = str | int | bytes | |
HostAddress = Literal["", "::", "0.0.0.0"] | IPv4Address | IPv6Address | int | str | |
class SizedBuffer(ImplementsSized, Buffer): | |
pass | |
class MulticastFlags(IntEnum):
    """Option bits accepted by ``create_multicast_on``.

    Members are distinct powers of two so they can be OR-ed together; because
    this is an ``IntEnum``, such a combination is a plain ``int``.  Bit 0
    (value 1) is unassigned.
    """

    NONE = 0
    MULTICAST_LOOP = 1 << 1  # request IP_MULTICAST_LOOP / IPV6_MULTICAST_LOOP
    REUSEADDR = 1 << 2  # set SO_REUSEADDR before binding
    REUSEPORT = 1 << 3  # set SO_REUSEPORT before binding
class ReadOnlyMulticastConfig:
    """Mixin adding address-family helpers to an immutable multicast config.

    Concrete subclasses must provide an ``address`` attribute holding an
    ``IPv4Address`` or ``IPv6Address``.
    """

    __slots__ = ()

    @property
    def family(self) -> int:
        """Return the ``socket`` address family matching ``self.address``.

        :raises ValueError: if the address reports a version other than 4 or 6.
        """
        version = self.address.version
        if version == 4:
            return socket.AF_INET
        if version == 6:
            return socket.AF_INET6
        # ipaddress objects only report 4 or 6; anything else is a foreign
        # object, so fail with an explicit message instead of a TypeError.
        raise ValueError(f"unsupported IP version {version!r}")

    @property
    def version(self) -> int:
        """Return the IP version (4 or 6) of ``self.address``."""
        return self.address.version

    if TYPE_CHECKING:

        @property
        def address(self) -> IPAddress:
            pass
class IMulticastSocket: | |
__slots__ = ("__weakref__",) | |
def __init_subclass__(cls): | |
errors = [] | |
for key in dir(cls): | |
value = getattr(cls, key, None) | |
if isinstance(value, abstractproperty): | |
errors.append(NotImplementedError(f"You must implement abstractproperty {key!r}")) | |
elif getattr(value, "__isabstractmethod__", False) is True: | |
errors.append(NotImplementedError(f"You must implement abstractmethod {key!r}")) | |
if errors: | |
if len(errors) > 1: | |
raise ExceptionGroup(f"Multiple issues on class {cls}", tuple(errors)) | |
raise NotImplementedError(str(errors[0]) + f" on class {cls}") | |
return cls | |
@abstractproperty | |
def groups(self): | |
pass | |
@abstractproperty | |
def group_set(self): | |
pass | |
@abstractmethod | |
def add(self, group): | |
pass | |
@abstractmethod | |
def remove(self, group): | |
pass | |
@abstractproperty | |
def port(self): | |
pass | |
@abstractproperty | |
def flags(self): | |
pass | |
@abstractmethod | |
def __enter__(self): | |
pass | |
@abstractmethod | |
def __exit__(self, exc_type, exc_value, traceback): | |
pass | |
@abstractmethod | |
def shutdown(self, direction: int): | |
pass | |
@abstractmethod | |
def send(self, b: bytes, /, group) -> int: | |
pass | |
@abstractmethod | |
def recvfrom(self, n: int, /): | |
pass | |
@abstractmethod | |
def recvfrom_into(self, buffer: SizedBuffer, nbytes: int = -1, flags: int = 0, /): | |
pass | |
@abstractproperty | |
def closed(self): | |
pass | |
@abstractmethod | |
def close(self): | |
pass | |
@abstractmethod | |
def broadcast_send(self, b: bytes, /) -> tuple[int, ...]: | |
pass | |
class BaseMulticastSocket(IMulticastSocket):
    """Shared implementation of a UDP socket subscribed to multicast groups.

    Concrete subclasses set ``AddressType`` (``IPv4Address`` or
    ``IPv6Address``); every group this socket handles is of that type.
    ``readonly`` holds the immutable bind configuration, and ``_socket``
    becomes ``None`` once :meth:`close` has run.
    """

    __slots__ = ("readonly", "_socket", "_groups", "_group_set")

    _socket: socket.socket | None

    if TYPE_CHECKING:
        AddressType: ClassVar[type[IPv4Address] | type[IPv6Address]]

    def __init__(self, sock, /, multicast_config, groups):
        """Wrap *sock*, recording *groups* as its current subscriptions.

        ``__init__`` does not join *groups* itself; presumably the caller has
        already joined them on *sock*.

        :param sock: a ``socket.socket`` whose family matches *multicast_config*.
        :param multicast_config: a ``ReadOnlyMulticastConfig``.
        :param groups: tuple of ``AddressType`` group addresses, no duplicates.
        :raises TypeError: on wrongly typed arguments.
        :raises ValueError: on a family mismatch or duplicate groups.
        """
        if not isinstance(sock, socket.socket):
            raise TypeError(f"{sock!r} is not a socket!")
        if not (
            isinstance(groups, tuple)
            and all(isinstance(group, self.AddressType) for group in groups)
        ):
            raise TypeError(f"groups {groups!r} must be a tuple of {self.AddressType!r}")
        if not isinstance(multicast_config, ReadOnlyMulticastConfig):
            raise TypeError("multicast_config must be a ReadOnlyMulticastConfig!")
        self._socket = sock
        self.readonly = multicast_config
        if sock.family != self.readonly.family:
            raise ValueError(
                f"mismatch between socket family {sock.family!r} and {self.readonly.family!r}"
            )
        self._groups = groups
        self._group_set = set(groups)
        if len(self._group_set) != len(groups):
            raise ValueError("groups has duplicates!")

    @property
    def groups(self):
        """Subscribed groups as a tuple, in the order they were joined."""
        return self._groups

    @property
    def group_set(self):
        """Subscribed groups as a frozenset snapshot."""
        return frozenset(self._group_set)

    def _coerce_group(self, group):
        """Return *group* as ``AddressType``, requiring it to be multicast.

        :raises TypeError: if *group* is not a str, int or ``AddressType``.
        :raises ValueError: if it does not parse or is not a multicast address.
        """
        ip_cls = self.AddressType
        if isinstance(group, ip_cls):
            group_addr = group
        elif isinstance(group, (str, int)):
            group_addr = ip_cls(group)
        else:
            raise TypeError(f"{group!r} is not a str, int or {ip_cls.__name__}")
        if not group_addr.is_multicast:
            raise ValueError(f"{group!s} is not a multicast address!")
        return group_addr

    def add(self, group):
        """Join *group* (str, int or ``AddressType``); a no-op if already joined.

        :returns: ``self``.
        :raises ValueError: if the socket is closed or *group* is not multicast.
        :raises TypeError: if *group* has an unsupported type.
        """
        if self.closed:
            raise ValueError("I/O operation on closed socket")
        group_addr = self._coerce_group(group)
        if group_addr not in self._group_set:
            self._add(group_addr)
        return self

    def _add(self, group):
        """Join an already validated *group* and record it."""
        assert self._socket is not None
        assert isinstance(group, self.AddressType), f"{group!r} is not a {self.AddressType!r}"
        assert group.is_multicast, f"{group!r} not a multicast address!"
        assert (
            group.version == self.readonly.version
        ), f"ip version mismatch between socket {self.readonly.version} and group {group.version}!"
        _add_group_to(self._socket, group, self.readonly)
        self._groups = (*self._groups, group)
        self._group_set.add(group)

    def remove(self, group):
        """Leave *group* (str, int or ``AddressType``); a no-op if not joined.

        :returns: ``self``.
        :raises ValueError: if the socket is closed or *group* is not multicast.
        :raises TypeError: if *group* has an unsupported type.
        """
        if self.closed:
            raise ValueError("I/O operation on closed socket")
        # Validate and compare the *parsed* address: a str group must match
        # the address objects stored in ``_group_set``.
        group_addr = self._coerce_group(group)
        if group_addr in self._group_set:
            self._remove(group_addr)
        return self

    def _remove(self, group):
        """Leave an already validated *group* and forget it."""
        assert self._socket is not None
        assert isinstance(group, self.AddressType), f"{group!r} is not a {self.AddressType!r}"
        assert group.is_multicast, f"{group!r} not a multicast address!"
        assert (
            group.version == self.readonly.version
        ), f"ip version mismatch between socket {self.readonly.version} and group {group.version}!"
        _remove_group_from(self._socket, group, self.readonly)
        self._groups = tuple(existing for existing in self._groups if existing != group)
        self._group_set.remove(group)

    @property
    def address(self):
        """The configured bind/interface address."""
        return self.readonly.address

    @property
    def port(self):
        """The UDP port the socket is bound to and sends to."""
        return self.readonly.port

    @property
    def flags(self):
        """The ``MulticastFlags`` (or int combination) used at creation."""
        return self.readonly.flags

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def shutdown(self, direction: int):
        """Shut down *direction* on the socket; a no-op once closed.

        :raises ValueError: if *direction* is not a ``socket.SHUT_*`` constant.
        """
        if direction not in (socket.SHUT_RD, socket.SHUT_RDWR, socket.SHUT_WR):
            raise ValueError(f"invalid shutdown direction {direction!r}")
        if self.closed:
            return None
        assert self._socket is not None
        return self._socket.shutdown(direction)

    def send(self, b: bytes, /, group="") -> int:
        """Send *b* to *group* on :attr:`port` and return the bytes sent.

        When *group* is falsy (the default ``""``), the socket must be
        subscribed to exactly one group, which becomes the destination.

        :raises ValueError: if closed, the group is ambiguous, or not multicast.
        :raises TypeError: if *group* has an unsupported type.
        :raises LookupError: if the socket is not subscribed to *group*.
        """
        if self.closed:
            raise ValueError("operation on closed socket!")
        assert self._socket is not None
        ip_cls = self.AddressType
        if not isinstance(group, (ip_cls, str, int)):
            raise TypeError(f"group must be a str, int or {ip_cls.__name__}")
        dest: IPAddress
        if not group:
            try:
                (dest,) = self._groups
            except ValueError:
                raise ValueError("group must be specified!") from None
        elif isinstance(group, (str, int)):
            dest = ip_cls(group)
        else:
            dest = group
        if not dest.is_multicast:
            raise ValueError("Not a multicast address!")
        if dest not in self._group_set:
            raise LookupError(f"Not subscribed to {dest!s}")
        return self._socket.sendto(b, (str(dest), self.port))

    def recvfrom(self, n: int, /):
        """Receive one datagram of up to *n* bytes (4096 when *n* < 0).

        :returns: ``(data, address)`` as returned by ``socket.recvfrom``.
        """
        if self.closed:
            raise ValueError("operation on closed socket!")
        assert self._socket is not None
        if n < 0:
            n = 4096
        return self._socket.recvfrom(n)

    def recvfrom_into(self, buffer: SizedBuffer, nbytes: int = -1, flags: int = 0, /):
        """Receive one datagram into *buffer*.

        ``nbytes=-1`` (the default) uses the whole buffer; *flags* are honoured
        either way.

        :returns: ``(nbytes_received, address)`` from ``socket.recvfrom_into``.
        """
        if self.closed:
            raise ValueError("operation on closed socket!")
        assert self._socket is not None
        if nbytes == -1:
            # socket.recvfrom_into treats nbytes=0 as "size of the buffer",
            # which lets *flags* be forwarded without an explicit size.
            nbytes = 0
        return self._socket.recvfrom_into(buffer, nbytes, flags)

    @property
    def closed(self):
        """``True`` once :meth:`close` has run."""
        return self._socket is None

    def close(self):
        """Close the socket; calling it again is a no-op."""
        if self._socket:
            self._socket.close()
            self._socket = None

    @property
    def socket(self):
        """The wrapped ``socket.socket``, or ``None`` once closed."""
        return self._socket

    def broadcast_send(self, b: bytes, /) -> tuple[int, ...]:
        """Send *b* to every subscribed group; one byte count per group."""
        return tuple(self.send(b, group) for group in self._groups)
class _MulticastIPv4(NamedTuple):
    """Field layout of the read-only IPv4 multicast configuration."""

    # Bind address; also used as the IPv4 multicast interface address.
    address: IPv4Address
    # UDP port that is bound and used as the send destination port.
    port: int
    # A MulticastFlags member, or a plain int when several flags are combined.
    flags: MulticastFlags | int
class _MulticastIPv6(NamedTuple):
    """Field layout of the read-only IPv6 multicast configuration."""

    # Bind address, carrying its ``%scope`` suffix.
    address: IPv6Address
    # UDP port that is bound and used as the send destination port.
    port: int
    # A MulticastFlags member, or a plain int when several flags are combined.
    flags: MulticastFlags | int
    # Interface index used for group membership and IPV6_MULTICAST_IF.
    scope_id: int
    # Interface name corresponding to ``scope_id``.
    device: str
class ReadOnlyMulticastIPv4Config(_MulticastIPv4, ReadOnlyMulticastConfig):
    """Immutable IPv4 socket configuration ``(address, port, flags)``.

    ``family``/``version`` come from ``ReadOnlyMulticastConfig``.
    """

    __slots__ = ()
class ReadOnlyMulticastIPv6Config(_MulticastIPv6, ReadOnlyMulticastConfig):
    """Immutable IPv6 socket configuration ``(address, port, flags, scope_id, device)``.

    ``family``/``version`` come from ``ReadOnlyMulticastConfig``.
    """

    __slots__ = ()
class MulticastIPv4Socket(BaseMulticastSocket):
    """A multicast socket whose groups are ``IPv4Address`` objects."""

    __slots__ = ()

    # Every group is coerced to / validated against this class.
    AddressType: ClassVar[type[IPv4Address]] = IPv4Address
class MulticastIPv6Socket(BaseMulticastSocket):
    """A multicast socket whose groups are ``IPv6Address`` objects."""

    __slots__ = ()

    # Every group is coerced to / validated against this class.
    AddressType: ClassVar[type[IPv6Address]] = IPv6Address

    @property
    def scope_id(self) -> int:
        """Interface index taken from the read-only configuration."""
        config = self.readonly
        return config.scope_id

    @property
    def device(self) -> str:
        """Interface name taken from the read-only configuration."""
        config = self.readonly
        return config.device
class SocketGroupRead(NamedTuple):
    """One datagram returned by ``MulticastSocketGroup.recvfrom``.

    ``socket_index`` selects the originating socket in the group's
    ``sockets``.  The group is referenced weakly, so a result does not keep
    the group alive.
    """

    socket_index: int
    # ``(payload, sender address)`` exactly as returned by ``socket.recvfrom``.
    value: tuple[bytes, tuple]
    socket_group_ref: weakref.ReferenceType["MulticastSocketGroup"]

    @property
    def data(self) -> bytes:
        """The datagram payload (``value[0]``)."""
        return self.value[0]

    @property
    def from_addr(self) -> tuple:
        """The sender's socket address (``value[1]``)."""
        return self.value[1]

    @property
    def socket_group(self) -> "None | MulticastSocketGroup":
        """The owning group, or ``None`` if it has been garbage-collected."""
        return self.socket_group_ref()

    @property
    def socket(self) -> "MulticastIPv6Socket | MulticastIPv4Socket | None":
        """The socket that produced this read, or ``None`` if the group is gone."""
        group = self.socket_group
        if group is not None:
            return group.sockets[self.socket_index]
        return None
class SocketGroupBufferRead(NamedTuple): | |
socket_index: int | |
span: tuple[int, int] | |
from_addr: tuple[str, int] | |
socket_group_ref: weakref.ReferenceType["MulticastSocketGroup"] | |
@property | |
def length(self) -> int: | |
index, endex = self.span | |
return endex - index | |
@property | |
def socket_group(self) -> "None | MulticastSocketGroup": | |
return self.socket_group_ref() | |
@property | |
def socket(self) -> MulticastIPv6Socket | MulticastIPv4Socket | None: | |
g = self.socket_group | |
if g is not None: | |
return g.sockets[self.socket_index] | |
return None | |
class MulticastSocketGroup(IMulticastSocket):
    """Several multicast sockets (e.g. one IPv4 and one IPv6) driven as one.

    Every member is switched to non-blocking mode: ``recvfrom`` and
    ``recvfrom_into`` poll each member once and skip members that would
    block.  Read results refer back to the group through a weak reference.
    """

    # "readonly" is reserved but never assigned by this class.
    __slots__ = ("sockets", "readonly", "in_ctx")

    def __init__(self, sock: MulticastIPv4Socket | MulticastIPv6Socket, *sockets):
        self.sockets = (sock, *sockets)
        # ExitStack while inside a ``with`` block, ``None`` otherwise.
        self.in_ctx = None
        for member in self.sockets:
            member.socket.setblocking(False)

    @property
    def groups(self):
        """All subscribed groups, member by member, in member order."""
        return tuple(group for member in self.sockets for group in member.groups)

    @property
    def group_set(self):
        """All subscribed groups as a frozenset."""
        return frozenset(self.groups)

    @property
    def group_types(self):
        """The address classes (IPv4Address / IPv6Address) served by the members."""
        return frozenset(member.AddressType for member in self.sockets)

    @property
    def port(self):
        """Port of the first member socket."""
        return self.sockets[0].port

    @property
    def flags(self) -> tuple[MulticastFlags | int, ...]:
        """Creation flags of each member, in member order."""
        return tuple(member.flags for member in self.sockets)

    def _coerce_group(self, group):
        """Parse *group* and check that some member serves its address family.

        :raises TypeError: for unsupported types or an unserved address family.
        :raises ValueError: if a str/int does not parse as an IP address.
        """
        if isinstance(group, (int, str)):
            group_addr = ipaddress.ip_address(group)
        elif isinstance(group, (IPv4Address, IPv6Address)):
            group_addr = group
        else:
            raise TypeError(f"{group!r} is not a str, int, IPv4Address or IPv6Address")
        types = tuple(self.group_types)
        if not isinstance(group_addr, types):
            raise TypeError("{} is not a {}".format(group, " or ".join(repr(x) for x in types)))
        return group_addr

    def add(self, group):
        """Join *group* on the first member serving its family.

        A no-op if any member is already subscribed to it.

        :returns: ``self``.
        :raises LookupError: if no member serves the address family.
        """
        group_addr = self._coerce_group(group)
        if group_addr in self.group_set:
            return self
        for member in self.sockets:
            if isinstance(group_addr, member.AddressType):
                member.add(group_addr)
                return self
        raise LookupError(f"Unable to find a group for {group_addr!r}")

    def remove(self, group):
        """Leave *group* on every member subscribed to it.

        :returns: ``self``.
        """
        group_addr = self._coerce_group(group)
        for member in self.sockets:
            if group_addr in member.group_set:
                member.remove(group_addr)
        return self

    def __enter__(self):
        """Enter every member's context; closing happens on exit.

        :raises ValueError: if the group is already inside a ``with`` block.
        """
        if self.in_ctx is not None:
            raise ValueError("reentrancy NOT supported")
        stack = ExitStack()
        for member in self.sockets:
            stack.enter_context(member)
        self.in_ctx = stack
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            return self.in_ctx.__exit__(exc_type, exc_value, traceback)
        finally:
            self.in_ctx = None

    def send(self, b: bytes, /, group="") -> int:
        """Send *b* to *group* through the member subscribed to it.

        A falsy *group* is delegated to the first member, which sends to its
        only group (or raises if it has several).

        :raises TypeError: for unsupported *group* types or families.
        :raises LookupError: if no member is subscribed to *group*.
        """
        if not group:
            return self.sockets[0].send(b, group)
        group_addr = self._coerce_group(group)
        for member in self.sockets:
            if group_addr in member.group_set:
                return member.send(b, group_addr)
        raise LookupError(group)

    def broadcast_send(self, b: bytes, /) -> tuple[int, ...]:
        """Send *b* to every group of every member; flat tuple of byte counts."""
        return tuple(count for member in self.sockets for count in member.broadcast_send(b))

    def close(self):
        """Close every member, ignoring ``OSError`` from individual closes."""
        for member in self.sockets:
            with suppress(OSError):
                member.close()

    def shutdown(self, direction: int):
        """Shut down *direction* on every member.

        :raises ExceptionGroup: collecting every member's failure.
        """
        errors = []
        for member in self.sockets:
            try:
                member.shutdown(direction)
            except Exception as exc:
                errors.append(exc)
        if errors:
            raise ExceptionGroup("Unable to shutdown", tuple(errors))

    @property
    def closed(self):
        """``True`` only when every member is closed."""
        return all(member.closed for member in self.sockets)

    def recvfrom(self, n: int, /) -> tuple[SocketGroupRead, ...]:
        """Read at most one datagram of up to *n* bytes from each member.

        :returns: one ``SocketGroupRead`` per member that had data ready.
        """
        results = []
        for index, member in enumerate(self.sockets):
            try:
                value = member.recvfrom(n)
            except BlockingIOError:
                continue
            results.append(SocketGroupRead(index, value, weakref.ref(self)))
        return tuple(results)

    def recvfrom_into(
        self, buffer: SizedBuffer, nbytes: int = -1, flags: int = 0, /
    ) -> tuple[SocketGroupBufferRead, ...]:
        """Fill *buffer* with datagrams read from the members, back to back.

        Each member is polled at most once; its datagram is written directly
        after the previous one.  ``nbytes`` of ``-1`` or ``0`` means the whole
        buffer, matching ``socket.recvfrom_into``.

        :returns: one ``SocketGroupBufferRead`` (with its ``span``) per read.
        :raises ValueError: if *nbytes* is negative (other than ``-1``).
        """
        if nbytes < -1:
            raise ValueError("negative buffersize in recvfrom_into")
        left = len(buffer) if nbytes in (-1, 0) else nbytes
        spans = []
        offset = 0
        with memoryview(buffer) as view:
            for sock_index, member in enumerate(self.sockets):
                if left <= 0:
                    break
                # A fresh window per read, released right away, so no export
                # of *buffer* outlives this call.
                with view[offset : offset + left] as window:
                    try:
                        length, from_addr = member.recvfrom_into(window, left, flags)
                    except BlockingIOError:
                        continue
                spans.append(
                    SocketGroupBufferRead(
                        sock_index, (offset, offset + length), from_addr, weakref.ref(self)
                    )
                )
                offset += length
                left -= length
        return tuple(spans)
def as_multicast_address(address: "IPAddress | MaybeIPAddress") -> "IPAddress | None":
    """Return *address* as an ``ipaddress`` object if it is a multicast address.

    :param address: an ``IPv4Address``/``IPv6Address``, or a ``str``, ``int``
        or packed ``bytes`` value understood by ``ipaddress.ip_address``.
    :returns: the address when it is multicast; ``None`` when it is not
        multicast or does not parse as an IP address at all.
    :raises TypeError: if *address* has an unsupported type.
    """
    if isinstance(address, (IPv4Address, IPv6Address)):
        addr = address
    elif isinstance(address, (str, int, bytes)):
        try:
            addr = ipaddress.ip_address(address)
        except ValueError:
            # Report unparseable input as "not multicast" so callers can
            # collect every bad entry instead of aborting on the first one.
            return None
    else:
        raise TypeError(f"{address!r} is not a str, int, bytes, IPv4Address or IPv6Address")
    return addr if addr.is_multicast else None
def _set_membership(sock: socket.socket, group, config, *, join: bool) -> None:
    """Join (``join=True``) or leave multicast *group* on *sock*.

    IPv4 passes a ``struct ip_mreq``: the group address followed by the local
    interface address ``config.address``.  IPv6 passes a ``struct ipv6_mreq``:
    the group address followed by the interface index ``config.scope_id``
    packed as a native int.

    :raises NotImplementedError: if *sock* is neither AF_INET nor AF_INET6.
    """
    assert isinstance(group, (IPv4Address, IPv6Address))
    assert isinstance(config, ReadOnlyMulticastConfig), f"{config!r}"
    assert group.is_multicast
    if sock.family == socket.AddressFamily.AF_INET:
        assert isinstance(group, IPv4Address)
        assert isinstance(config, ReadOnlyMulticastIPv4Config)
        option = socket.IP_ADD_MEMBERSHIP if join else socket.IP_DROP_MEMBERSHIP
        sock.setsockopt(socket.IPPROTO_IP, option, group.packed + config.address.packed)
    elif sock.family == socket.AddressFamily.AF_INET6:
        assert isinstance(group, IPv6Address)
        assert isinstance(config, ReadOnlyMulticastIPv6Config)
        option = socket.IPV6_JOIN_GROUP if join else socket.IPV6_LEAVE_GROUP
        request = group.packed + struct.pack("i", config.scope_id)
        sock.setsockopt(socket.IPPROTO_IPV6, option, request)
    else:
        raise NotImplementedError(f"unsupported socket family {sock.family!r}")


def _add_group_to(sock: socket.socket, group, config):
    """Subscribe *sock* to multicast *group* on the interface described by *config*."""
    _set_membership(sock, group, config, join=True)


def _remove_group_from(sock: socket.socket, group, config):
    """Unsubscribe *sock* from multicast *group* on the interface described by *config*."""
    _set_membership(sock, group, config, join=False)
if TYPE_CHECKING:
    # Typing-only overloads of the runtime ``create_multicast_on`` below.
    # Type checkers use the first matching overload, so the single-family
    # signatures come before the general ones, and ``str`` precedes
    # ``Iterable[str]`` (a str is itself an Iterable[str]).

    @overload
    def create_multicast_on(
        multicast_groups: IPv4Address,
        port: int,
        /,
        bind_on: Literal["0.0.0.0", ""] | IPv4Address = "",
        flags: MulticastFlags | int = MulticastFlags.NONE,
        *,
        hop_limit: int = -1,
    ) -> MulticastIPv4Socket:
        pass

    @overload
    def create_multicast_on(
        multicast_groups: IPv6Address,
        port: int,
        /,
        bind_on: Literal["::", ""] | IPv6Address = "",
        flags: MulticastFlags | int = MulticastFlags.NONE,
        *,
        hop_limit: int = -1,
    ) -> MulticastIPv6Socket:
        pass

    @overload
    def create_multicast_on(
        multicast_groups: Iterable[IPv4Address],
        port: int,
        /,
        bind_on: Literal["0.0.0.0", ""] | IPv4Address = "",
        flags: MulticastFlags | int = MulticastFlags.NONE,
        *,
        hop_limit: int = -1,
    ) -> MulticastIPv4Socket:
        pass

    @overload
    def create_multicast_on(
        multicast_groups: Iterable[IPv6Address],
        port: int,
        /,
        bind_on: Literal["::", ""] | IPv6Address = "",
        flags: MulticastFlags | int = MulticastFlags.NONE,
        *,
        hop_limit: int = -1,
    ) -> MulticastIPv6Socket:
        pass

    @overload
    def create_multicast_on(
        multicast_groups: Iterable[IPAddress],
        port: int,
        /,
        bind_on: HostAddress = "",
        flags: MulticastFlags | int = MulticastFlags.NONE,
        *,
        hop_limit: int = -1,
    ) -> MulticastIPv6Socket | MulticastIPv4Socket | MulticastSocketGroup:
        pass

    @overload
    def create_multicast_on(
        multicast_groups: str,
        port: int,
        /,
        bind_on: HostAddress = "",
        flags: MulticastFlags | int = MulticastFlags.NONE,
        *,
        hop_limit: int = -1,
    ) -> MulticastIPv6Socket | MulticastIPv4Socket | MulticastSocketGroup:
        pass

    @overload
    def create_multicast_on(
        multicast_addresses: Iterable[str],
        port: int,
        /,
        bind_on: HostAddress = "",
        flags: MulticastFlags | int = MulticastFlags.NONE,
        *,
        hop_limit: int = -1,
    ) -> MulticastIPv6Socket | MulticastIPv4Socket | MulticastSocketGroup:
        pass
def _normalize_multicast_flags(flags) -> MulticastFlags | int:
    """Validate *flags*, returning a ``MulticastFlags`` member when possible.

    Combinations of several flags are not members of the ``IntEnum``; they are
    returned unchanged as a plain ``int``.

    :raises TypeError: if *flags* is not an int.
    :raises ValueError: if *flags* sets bits that no ``MulticastFlags`` defines.
    """
    if not isinstance(flags, int):
        raise TypeError("flags are not an int or MulticastFlags")
    try:
        return MulticastFlags(flags)
    except ValueError:
        known_bits = 0
        for flag in MulticastFlags:
            known_bits |= flag
        if flags & ~known_bits:
            raise
        return flags


def _partition_multicast_groups(
    addresses,
) -> tuple[tuple[IPv4Address, ...], tuple[IPv6Address, ...]]:
    """Split *addresses* into de-duplicated IPv4 and IPv6 multicast groups.

    Order of first appearance is preserved within each family.

    :raises ValueError: for one invalid entry, or when no groups are given.
    :raises ExceptionGroup: when several entries are invalid.
    """
    errors = []
    # dicts double as insertion-ordered sets.
    ipv4_groups: dict[IPv4Address, None] = {}
    ipv6_groups: dict[IPv6Address, None] = {}
    for addr in addresses:
        group_address = as_multicast_address(addr)
        if group_address is None:
            errors.append(ValueError(f"{addr!r} is not a valid multicast address!"))
        elif group_address.version == 4:
            ipv4_groups.setdefault(group_address, None)
        else:
            ipv6_groups.setdefault(group_address, None)
    if len(errors) == 1:
        raise errors[0]
    if errors:
        raise ExceptionGroup("Multiple invalid multicast addresses given", tuple(errors))
    if not (ipv4_groups or ipv6_groups):
        raise ValueError("No multicast addresses given!")
    return tuple(ipv4_groups), tuple(ipv6_groups)


def _looks_like_dns_name(host: str) -> bool:
    """Heuristically decide whether *host* must be resolved with DNS.

    Strings containing ``:`` are treated as IPv6 literals; two to four
    dot-separated all-digit parts are treated as an IPv4 literal.
    """
    if ":" in host:
        return False
    parts = host.split(".")
    is_ipv4_literal = 1 < len(parts) < 5 and all(part.isdigit() for part in parts)
    return not is_ipv4_literal


def _resolve_bind_host(host: str, port: int) -> list[IPv4Address | IPv6Address]:
    """Resolve *host* to its unique IPv4/IPv6 addresses (IPv6 with ``%scope``)."""
    resolved: dict[IPv4Address | IPv6Address, None] = {}
    infos = socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
    for family, _type, _proto, _canonname, sock_addr in infos:
        if family == socket.AddressFamily.AF_INET:
            resolved.setdefault(ipaddress.ip_address(sock_addr[0]), None)
        elif family == socket.AddressFamily.AF_INET6:
            address, _port, _flow_info, scope_id = sock_addr
            resolved.setdefault(ipaddress.ip_address(f"{address}%{scope_id}"), None)
    return list(resolved)


def _resolve_ipv6_scope(bind_address: IPv6Address, port: int) -> tuple[IPv6Address, int, str]:
    """Return *bind_address* with a scope, plus its interface index and name.

    An unscoped address gets its scope from ``getaddrinfo``.  Scope id 0 means
    "no specific interface"; it maps to device ``""`` and the kernel picks the
    interface.

    :raises ValueError: for IPv4-mapped addresses or an unknown interface.
    """
    if bind_address.ipv4_mapped:
        raise ValueError(f"{bind_address} is IPv4-mapped; bind an IPv4Address instead")
    if bind_address.scope_id is None:
        info = socket.getaddrinfo(
            str(bind_address), port, socket.AF_INET6, socket.SOCK_DGRAM, socket.IPPROTO_UDP
        )[0]
        _address, _port, _flow_info, looked_up_scope = info[-1]
        bind_address = IPv6Address(f"{bind_address!s}%{looked_up_scope}")
    scope = bind_address.scope_id
    try:
        if scope.isdigit():
            scope_id = int(scope)
            # if_indextoname(0) fails, yet 0 is the valid "any interface" index.
            device = socket.if_indextoname(scope_id) if scope_id else ""
        else:
            scope_id = socket.if_nametoindex(scope)
            device = scope
    except OSError as err:
        raise ValueError(f"IPv6 scope_id ({scope!r}) is invalid!") from err
    return bind_address, scope_id, device


def _open_multicast_socket(
    bind_address, port, create_flags, hop_limit, ipv4_multicast, ipv6_multicast
):
    """Create, bind and configure one UDP socket in *bind_address*'s family.

    The socket is bound to the wildcard address on *port*; *bind_address*
    selects the interface used for membership and outgoing multicast.  Only
    the groups of the socket's own family are joined.  The socket is closed
    if any step fails.
    """
    if bind_address.version == 6:
        family = socket.AF_INET6
        bind_address, scope_id, device = _resolve_ipv6_scope(bind_address, port)
        config = ReadOnlyMulticastIPv6Config(bind_address, port, create_flags, scope_id, device)
    else:
        family = socket.AF_INET
        config = ReadOnlyMulticastIPv4Config(bind_address, port, create_flags)

    sock = socket.socket(family, socket.SOCK_DGRAM)
    try:
        if MulticastFlags.REUSEADDR & create_flags:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        if MulticastFlags.REUSEPORT & create_flags:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        if family == socket.AF_INET6 and hasattr(socket, "IPV6_V6ONLY"):
            # Restrict the IPv6 socket to IPv6 so an IPv4 socket can bind the
            # same port next to it (dual-stack groups).
            sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)
        sock.bind(("", port))
        # MULTICAST_LOOP is only ever switched on; otherwise the OS default stays.
        if family == socket.AF_INET:
            for group in ipv4_multicast:
                _add_group_to(sock, group, config)
            if MulticastFlags.MULTICAST_LOOP & create_flags:
                sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_LOOP, True)
            sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_TTL, struct.pack("B", hop_limit))
            sock.setsockopt(socket.IPPROTO_IP, socket.IP_MULTICAST_IF, bind_address.packed)
            return MulticastIPv4Socket(sock, config, ipv4_multicast)
        for group in ipv6_multicast:
            _add_group_to(sock, group, config)
        sock.setsockopt(
            socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, struct.pack("i", config.scope_id)
        )
        if MulticastFlags.MULTICAST_LOOP & create_flags:
            sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_LOOP, True)
        sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_HOPS, hop_limit)
        return MulticastIPv6Socket(sock, config, ipv6_multicast)
    except BaseException:
        sock.close()
        raise


def _open_multicast_group(
    bind_addresses, port, create_flags, hop_limit, ipv4_multicast, ipv6_multicast
) -> MulticastSocketGroup:
    """Open one socket per bind address and wrap them in a group.

    Addresses whose family has no groups are skipped.  Sockets opened before
    a failure are closed before the error propagates.

    :raises ValueError: if no address has groups to join.
    """
    sockets = []
    try:
        for address in bind_addresses:
            if not (ipv4_multicast if address.version == 4 else ipv6_multicast):
                continue
            sockets.append(
                _open_multicast_socket(
                    address, port, create_flags, hop_limit, ipv4_multicast, ipv6_multicast
                )
            )
    except BaseException:
        for opened in sockets:
            opened.close()
        raise
    if not sockets:
        raise ValueError("no interfaces to bind to")
    return MulticastSocketGroup(*sockets)


def create_multicast_on(
    multicast_addresses,
    port: int,
    /,
    bind_on: HostAddress = "",
    flags=MulticastFlags.NONE,
    *,
    hop_limit: int = -1,
) -> MulticastSocketGroup | MulticastIPv4Socket | MulticastIPv6Socket:
    """Create UDP socket(s) subscribed to *multicast_addresses* on *port*.

    :param multicast_addresses: one group (str, IPv4Address or IPv6Address) or
        an iterable of groups; duplicates are ignored.
    :param port: UDP port (1-65535) to bind and to send to.
    :param bind_on: interface address.  ``""``, ``"::"``, ``"0.0.0.0"`` and
        ``0`` mean "any"; a host name is resolved with ``getaddrinfo``.
    :param flags: ``MulticastFlags`` bits, or an int combination of them.
    :param hop_limit: multicast TTL / hop limit 1-255; negative (default) means 1.
    :returns: a ``MulticastIPv4Socket`` or ``MulticastIPv6Socket``, or a
        ``MulticastSocketGroup`` when several sockets are needed (both
        families on the wildcard address, or a host name resolving to both).
    :raises TypeError: for wrongly typed arguments.
    :raises ValueError: for invalid ports, flags, hop limits or addresses.
    :raises ExceptionGroup: when several group addresses are invalid.
    :raises OSError: if socket setup fails (sockets are closed first).
    """
    if not isinstance(multicast_addresses, (str, IPv4Address, IPv6Address, Iterable)):
        raise TypeError(f"{multicast_addresses!r} is not an address or an iterable of addresses")
    if not isinstance(port, int):
        raise TypeError(f"port must be an int, not {type(port).__name__}")
    if not 0 < port <= 65535:
        raise ValueError(f"port {port} is out of range 1-65535")
    create_flags = _normalize_multicast_flags(flags)
    if hop_limit == 0 or hop_limit > 255:
        raise ValueError(f"hop_limit {hop_limit} must be negative (default) or 1-255")
    if hop_limit < 0:
        hop_limit = 1

    if isinstance(multicast_addresses, (str, IPv4Address, IPv6Address)):
        multicast_addresses = (multicast_addresses,)
    ipv4_multicast, ipv6_multicast = _partition_multicast_groups(multicast_addresses)
    options = (port, create_flags, hop_limit, ipv4_multicast, ipv6_multicast)

    if bind_on in ("", "::", "0.0.0.0", 0):
        if ipv4_multicast and ipv6_multicast:
            wildcards = (IPv4Address("0.0.0.0"), IPv6Address("::"))
            return _open_multicast_group(wildcards, *options)
        wildcard = ipaddress.ip_address("0.0.0.0" if ipv4_multicast else "::")
        return _open_multicast_socket(wildcard, *options)

    if isinstance(bind_on, str) and _looks_like_dns_name(bind_on):
        resolved = _resolve_bind_host(bind_on, port)
        if not resolved:
            raise ValueError("no interfaces to bind to")
        if len({address.version for address in resolved}) > 1:
            return _open_multicast_group(resolved, *options)
        bind_address = resolved[0]
    elif isinstance(bind_on, (IPv4Address, IPv6Address)):
        bind_address = bind_on
    else:
        # Raises ValueError for strings/ints that are not IP addresses.
        bind_address = ipaddress.ip_address(bind_on)
    return _open_multicast_socket(bind_address, *options)
if __name__ == "__main__":
    # Loopback demo: join groups, send to them and print what comes back.
    import argparse

    parser = argparse.ArgumentParser()
    parser.set_defaults(use_ipv4=True, use_ipv6=False)
    excl_ipv4 = parser.add_mutually_exclusive_group()
    excl_ipv4.add_argument("-4", "--ipv4", dest="use_ipv4", action="store_true")
    excl_ipv4.add_argument("--no-ipv4", dest="use_ipv4", action="store_false")
    excl_ipv6 = parser.add_mutually_exclusive_group()
    excl_ipv6.add_argument("-6", "--ipv6", dest="use_ipv6", action="store_true")
    excl_ipv6.add_argument("--no-ipv6", dest="use_ipv6", action="store_false")
    args = parser.parse_args()

    if args.use_ipv4:
        with create_multicast_on(
            {"224.0.1.187"}, 8885, flags=MulticastFlags.MULTICAST_LOOP
        ) as conn1, create_multicast_on(
            {"224.0.1.188"}, 8886, flags=MulticastFlags.MULTICAST_LOOP
        ) as conn2:
            for conn in conn1, conn2:
                # Print each field of the socket's read-only configuration.
                for name in conn.readonly._fields:
                    print(repr(getattr(conn, name)))
            # The group makes both sockets non-blocking, so the reads below
            # only return datagrams that have already arrived.
            conn = MulticastSocketGroup(conn1, conn2)
            print(conn.send(b"1234", "224.0.1.187"))
            print(conn.send(b"1234", "224.0.1.188"))
            print(conn.broadcast_send(b"1234519"))
            print(conn.recvfrom(10))
            b = bytearray(20)
            print(conn.recvfrom_into(b))
            print(b)

    if args.use_ipv6:
        ipv6_group = "FF00::FD"
        with create_multicast_on(ipv6_group, 8885, flags=MulticastFlags.MULTICAST_LOOP) as conn:
            # Pass the destination group explicitly rather than relying on a default.
            print(conn.send(b"1234", ipv6_group))
            print(conn.recvfrom(4))
Author
autumnjolitz
commented
Apr 7, 2024
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment