|
import socket
from abc import abstractmethod
from io import IOBase
from ipaddress import IPv4Address, IPv6Address
from socket import AddressFamily
from types import TracebackType
from typing import (
    Any, AsyncContextManager, Callable, Collection, Dict, List, Mapping, Optional, Tuple, Type,
    TypeVar, Union)

from .._core._typedattr import TypedAttributeProvider, TypedAttributeSet, typed_attribute
from ._streams import ByteStream, Listener, T_Stream, UnreliableObjectStream
from ._tasks import TaskGroup

# An IP address given either as text or as an ``ipaddress`` object
IPAddressType = Union[str, IPv4Address, IPv6Address]
# A (host, port) pair
IPSockAddrType = Tuple[str, int]
# Either an IP socket address or a plain string (presumably a UNIX socket path -- confirm)
SockAddrType = Union[IPSockAddrType, str]
# A UDP payload paired with the IP socket address it is associated with
UDPPacketType = Tuple[bytes, IPSockAddrType]
T_Retval = TypeVar('T_Retval')


class _NullAsyncContextManager:
    """An async context manager that does nothing on entry or exit."""

    async def __aenter__(self) -> None:
        """Enter the context; nothing is set up and ``None`` is bound by ``as``."""
        pass

    async def __aexit__(self, exc_type: Optional[Type[BaseException]],
                        exc_val: Optional[BaseException],
                        exc_tb: Optional[TracebackType]) -> Optional[bool]:
        """Exit the context; returning ``None`` lets any in-flight exception propagate."""
        return None


class SocketAttribute(TypedAttributeSet):
    """Typed attributes that socket based streams and listeners can expose."""

    #: the address family of the underlying socket
    family: AddressFamily = typed_attribute()
    #: the local socket address of the underlying socket
    local_address: SockAddrType = typed_attribute()
    #: for IP addresses, the local port the underlying socket is bound to
    local_port: int = typed_attribute()
    #: the underlying stdlib socket object
    raw_socket: socket.socket = typed_attribute()
    #: the remote address the underlying socket is connected to
    remote_address: SockAddrType = typed_attribute()
    #: for IP addresses, the remote port the underlying socket is connected to
    remote_port: int = typed_attribute()


class _SocketProvider(TypedAttributeProvider):
    """Mixin that derives :class:`SocketAttribute` values from ``_raw_socket``."""

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """
        Map each available :class:`SocketAttribute` to a zero-argument getter.

        The family, local address and raw socket are always present. The remote address is only
        included when ``getpeername()`` succeeds, and the port attributes only for
        ``AF_INET``/``AF_INET6`` sockets.
        """
        # Imported here rather than at module level (presumably to avoid a circular import)
        from .._core._sockets import convert_ipv6_sockaddr as convert

        attributes: Dict[Any, Callable[[], Any]] = {
            SocketAttribute.family: lambda: self._raw_socket.family,
            SocketAttribute.local_address: lambda: convert(self._raw_socket.getsockname()),
            SocketAttribute.raw_socket: lambda: self._raw_socket
        }

        # The peer address is looked up once, right now; the getters below return this snapshot
        try:
            peername: Optional[Tuple[str, int]] = convert(self._raw_socket.getpeername())
        except OSError:
            # getpeername() failed, so the socket is treated as having no remote end
            peername = None

        # Provide the remote address for connected sockets
        if peername is not None:
            attributes[SocketAttribute.remote_address] = lambda: peername

        # Provide local and remote ports for IP based sockets
        if self._raw_socket.family in (AddressFamily.AF_INET, AddressFamily.AF_INET6):
            attributes[SocketAttribute.local_port] = lambda: self._raw_socket.getsockname()[1]
            if peername is not None:
                remote_port = peername[1]
                attributes[SocketAttribute.remote_port] = lambda: remote_port

        return attributes

    @property
    @abstractmethod
    def _raw_socket(self) -> socket.socket:
        """Return the stdlib socket object that the attributes are read from."""
        pass


class SocketStream(ByteStream, _SocketProvider):
    """
    Transports bytes over a socket.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """


class UNIXSocketStream(SocketStream):
    """A :class:`SocketStream` that can also pass file descriptors to and from the peer."""

    @abstractmethod
    async def send_fds(self, message: bytes, fds: Collection[Union[int, IOBase]]) -> None:
        """
        Send file descriptors along with a message to the peer.

        :param message: a non-empty bytestring
        :param fds: a collection of files (either numeric file descriptors or open file or socket
            objects)
        """

    @abstractmethod
    async def receive_fds(self, msglen: int, maxfds: int) -> Tuple[bytes, List[int]]:
        """
        Receive file descriptors along with a message from the peer.

        :param msglen: length of the message to expect from the peer
        :param maxfds: maximum number of file descriptors to expect from the peer
        :return: a tuple of (message, file descriptors)
        """


class SocketListener(Listener[SocketStream], _SocketProvider):
    """
    Listens to incoming socket connections.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """

    @abstractmethod
    async def accept(self) -> SocketStream:
        """Accept an incoming connection."""

    async def serve(self, handler: Callable[[T_Stream], Any],
                    task_group: Optional[TaskGroup] = None) -> None:
        """
        Accept connections in an endless loop, starting ``handler`` in a task for each one.

        :param handler: callable started via ``task_group.start_soon()`` with the accepted stream
        :param task_group: the task group to start handlers in; if omitted, one is created here
            and kept entered while this method runs
        """
        from .. import create_task_group

        context_manager: AsyncContextManager
        if task_group is None:
            task_group = context_manager = create_task_group()
        else:
            # A caller-supplied task group is not entered or exited here.
            # Can be replaced with AsyncExitStack once on py3.7+
            context_manager = _NullAsyncContextManager()

        async with context_manager:
            while True:
                stream = await self.accept()
                task_group.start_soon(handler, stream)


class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketProvider):
    """
    Represents an unconnected UDP socket.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """

    async def sendto(self, data: bytes, host: str, port: int) -> None:
        """Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, (host, port)))."""
        return await self.send((data, (host, port)))


class ConnectedUDPSocket(UnreliableObjectStream[bytes], _SocketProvider):
    """
    Represents a connected UDP socket.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """
|