|
- import re
- from ipaddress import (
- IPv4Address,
- IPv4Interface,
- IPv4Network,
- IPv6Address,
- IPv6Interface,
- IPv6Network,
- _BaseAddress,
- _BaseNetwork,
- )
- from typing import (
- TYPE_CHECKING,
- Any,
- Collection,
- Dict,
- Generator,
- Optional,
- Pattern,
- Set,
- Tuple,
- Type,
- Union,
- cast,
- no_type_check,
- )
-
- from . import errors
- from .utils import Representation, update_not_none
- from .validators import constr_length_validator, str_validator
-
if TYPE_CHECKING:
    # Everything in this branch exists only for static type checkers; it is never executed at runtime.
    import email_validator
    from typing_extensions import TypedDict

    from .config import BaseConfig
    from .fields import ModelField
    from .typing import AnyCallable

    # Return type of the ``__get_validators__`` classmethods defined in this module.
    CallableGenerator = Generator[AnyCallable, None, None]

    class Parts(TypedDict, total=False):
        # URL components; the keys mirror the named groups of ``url_regex()``.
        scheme: str
        user: Optional[str]
        password: Optional[str]
        ipv4: Optional[str]
        ipv6: Optional[str]
        domain: Optional[str]
        port: Optional[str]
        path: Optional[str]
        query: Optional[str]
        fragment: Optional[str]

else:
    # Placeholder for the optional dependency; ``import_email_validator()`` rebinds this global
    # to the real module on first use.
    email_validator = None
-
# Inputs accepted by IPvAnyInterface.validate / IPvAnyNetwork.validate: a plain address
# (str, bytes or int) or a two-tuple of (address, prefix/netmask).
NetworkType = Union[str, bytes, int, Tuple[Union[str, bytes, int], Union[str, int]]]

# Public names exported by ``from ... import *``.
__all__ = [
    'AnyUrl',
    'AnyHttpUrl',
    'FileUrl',
    'HttpUrl',
    'stricturl',
    'EmailStr',
    'NameEmail',
    'IPvAnyAddress',
    'IPvAnyInterface',
    'IPvAnyNetwork',
    'PostgresDsn',
    'AmqpDsn',
    'RedisDsn',
    'KafkaDsn',
    'validate_email',
]

# Lazily compiled patterns: filled in on first call by url_regex(), ascii_domain_regex()
# and int_domain_regex() respectively, then reused.
_url_regex_cache = None
_ascii_domain_regex_cache = None
_int_domain_regex_cache = None
-
-
def url_regex() -> Pattern[str]:
    """
    Return the compiled regex that splits a URL into named parts, compiling it on first use.

    Named groups: ``scheme``, ``user``, ``password``, ``ipv4``, ``ipv6``, ``domain``, ``port``, ``path``,
    ``query`` and ``fragment``. Every part is optional, so the pattern always matches at the start of
    the string; callers must check that the whole string was consumed. Host names are only split out
    here, their validation happens later.
    """
    global _url_regex_cache
    if _url_regex_cache is None:
        _url_regex_cache = re.compile(
            # RFC 3986: scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." ), so a single letter is valid.
            r'(?:(?P<scheme>[a-z][a-z0-9+\-.]*)://)?'  # scheme https://tools.ietf.org/html/rfc3986#appendix-A
            r'(?:(?P<user>[^\s:/]*)(?::(?P<password>[^\s/]*))?@)?'  # user info
            r'(?:'
            # ``[0-9]`` instead of ``\d``: in a str pattern ``\d`` also matches non-ASCII decimal digits.
            r'(?P<ipv4>(?:[0-9]{1,3}\.){3}[0-9]{1,3})(?=$|[/:#?])|'  # ipv4
            r'(?P<ipv6>\[[A-F0-9]*:[A-F0-9:]+\])(?=$|[/:#?])|'  # ipv6
            r'(?P<domain>[^\s/:?#]+)'  # domain, validation occurs later
            r')?'
            r'(?::(?P<port>[0-9]+))?'  # port
            r'(?P<path>/[^\s?#]*)?'  # path
            r'(?:\?(?P<query>[^\s#]*))?'  # query
            r'(?:#(?P<fragment>[^\s#]*))?',  # fragment
            re.IGNORECASE,
        )
    return _url_regex_cache
-
-
def ascii_domain_regex() -> Pattern[str]:
    """
    Return the (lazily compiled, then cached) case-insensitive pattern for a plain-ASCII domain name.

    The optional named group ``tld`` captures a final ``.label`` made of 2-63 letters; a single
    trailing dot is tolerated.
    """
    global _ascii_domain_regex_cache
    if _ascii_domain_regex_cache is not None:
        return _ascii_domain_regex_cache

    # One label: 1-63 characters from [-_0-9a-z], never starting or ending with '-'.
    label = r'[_0-9a-z](?:[-_0-9a-z]{0,61}[_0-9a-z])?'
    tail = r'(?P<tld>\.[a-z]{2,63})?\.?'
    _ascii_domain_regex_cache = re.compile(r'(?:' + label + r'\.)*?' + label + tail, re.IGNORECASE)
    return _ascii_domain_regex_cache
-
-
def int_domain_regex() -> Pattern[str]:
    """
    Return the (lazily compiled, then cached) case-insensitive pattern for an internationalised domain.

    The optional named group ``tld`` captures either a final label of 2-63 letters or an
    ``xn--`` punycode label; a single trailing dot is tolerated.
    """
    global _int_domain_regex_cache
    if _int_domain_regex_cache is not None:
        return _int_domain_regex_cache

    # Label characters: '_', ASCII digits, or any code point from 'a' up to U+40000
    # (note this range also admits some ASCII punctuation such as '{', '|', '}' and '~').
    edge = r'[_0-9a-\U00040000]'
    label = edge + r'(?:[-_0-9a-\U00040000]{0,61}' + edge + r')?'
    tail = r'(?P<tld>(\.[^\W\d_]{2,63})|(\.(?:xn--)[_0-9a-z-]{2,63}))?\.?'
    pattern = r'(?:{0}\.)*?{0}{1}'.format(label, tail)
    _int_domain_regex_cache = re.compile(pattern, re.IGNORECASE)
    return _int_domain_regex_cache
-
-
class AnyUrl(str):
    """
    A ``str`` subclass holding a validated URL together with its parsed components.

    Validation is configured through class attributes that subclasses (and ``stricturl``) override:
    ``strip_whitespace``, ``min_length``/``max_length``, ``allowed_schemes`` (``None`` allows any scheme),
    ``tld_required``, ``user_required``, ``host_required`` and ``hidden_parts`` (parts that ``build`` may
    leave out of the URL string when they hold their default value).
    """

    strip_whitespace = True
    min_length = 1
    max_length = 2 ** 16
    allowed_schemes: Optional[Collection[str]] = None
    tld_required: bool = False
    user_required: bool = False
    host_required: bool = True
    hidden_parts: Set[str] = set()

    __slots__ = ('scheme', 'user', 'password', 'host', 'tld', 'host_type', 'port', 'path', 'query', 'fragment')

    @no_type_check
    def __new__(cls, url: Optional[str], **kwargs) -> object:
        # With ``url=None`` the string value is assembled from the keyword parts via ``build``.
        return str.__new__(cls, cls.build(**kwargs) if url is None else url)

    def __init__(
        self,
        url: str,
        *,
        scheme: str,
        user: Optional[str] = None,
        password: Optional[str] = None,
        host: Optional[str] = None,
        tld: Optional[str] = None,
        host_type: str = 'domain',
        port: Optional[str] = None,
        path: Optional[str] = None,
        query: Optional[str] = None,
        fragment: Optional[str] = None,
    ) -> None:
        # The string value itself is fixed by ``__new__``; here only the parsed parts are recorded.
        self.scheme = scheme
        self.user = user
        self.password = password
        self.host = host
        self.tld = tld
        self.host_type = host_type
        self.port = port
        self.path = path
        self.query = query
        self.fragment = fragment

    @classmethod
    def build(
        cls,
        *,
        scheme: str,
        user: Optional[str] = None,
        password: Optional[str] = None,
        host: str,
        port: Optional[str] = None,
        path: Optional[str] = None,
        query: Optional[str] = None,
        fragment: Optional[str] = None,
        **_kwargs: str,
    ) -> str:
        """
        Assemble a URL string from its components.

        Empty/``None`` components are left out. When ``'port'`` is in ``hidden_parts`` the port is
        omitted only if it equals the default port ``get_default_parts`` reports for these parts, so an
        explicit non-default port is never lost. Extra keyword arguments (such as ``tld`` and
        ``host_type`` forwarded by ``__new__``) are ignored.
        """
        url = scheme + '://'
        if user:
            url += user
        if password:
            url += ':' + password
        if user or password:
            url += '@'
        url += host
        if port and 'port' in cls.hidden_parts:
            # Every Parts key is supplied so get_default_parts implementations can index freely;
            # only the 'port' entry of the result is consulted.
            default_parts = cls.get_default_parts(
                {
                    'scheme': scheme,
                    'user': user,
                    'password': password,
                    'ipv4': None,
                    'ipv6': None,
                    'domain': host,
                    'port': port,
                    'path': path,
                    'query': query,
                    'fragment': fragment,
                }
            )
            if default_parts.get('port') == port:
                port = None
        if port:
            url += ':' + port
        if path:
            url += path
        if query:
            url += '?' + query
        if fragment:
            url += '#' + fragment
        return url

    @classmethod
    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
        """Add ``minLength``/``maxLength`` and ``format: uri`` to the JSON schema of the field."""
        update_not_none(field_schema, minLength=cls.min_length, maxLength=cls.max_length, format='uri')

    @classmethod
    def __get_validators__(cls) -> 'CallableGenerator':
        yield cls.validate

    @classmethod
    def validate(cls, value: Any, field: 'ModelField', config: 'BaseConfig') -> 'AnyUrl':
        """
        Validate ``value`` and return an instance of ``cls``.

        Steps: coerce to ``str``, optionally strip whitespace, enforce the length limits, split with
        ``url_regex()``, fill defaults via ``apply_default_parts``, then run ``validate_parts`` and
        ``validate_host``. The original string is kept unless the host had to be IDNA-encoded, in which
        case the URL is rebuilt from its parts. Raises ``UrlExtraError`` when the regex does not consume
        the whole string.
        """
        if value.__class__ == cls:
            return value
        value = str_validator(value)
        if cls.strip_whitespace:
            value = value.strip()
        url: str = cast(str, constr_length_validator(value, field, config))

        m = url_regex().match(url)
        # the regex should always match, if it doesn't please report with details of the URL tried
        assert m, 'URL regex failed unexpectedly'

        original_parts = cast('Parts', m.groupdict())
        parts = cls.apply_default_parts(original_parts)
        parts = cls.validate_parts(parts)

        host, tld, host_type, rebuild = cls.validate_host(parts)

        if m.end() != len(url):
            raise errors.UrlExtraError(extra=url[m.end() :])

        return cls(
            None if rebuild else url,
            scheme=parts['scheme'],
            user=parts['user'],
            password=parts['password'],
            host=host,
            tld=tld,
            host_type=host_type,
            port=parts['port'],
            path=parts['path'],
            query=parts['query'],
            fragment=parts['fragment'],
        )

    @classmethod
    def validate_parts(cls, parts: 'Parts') -> 'Parts':
        """
        A method used to validate parts of an URL.
        Could be overridden to set default values for parts if missing

        Raises ``UrlSchemeError`` when there is no scheme, ``UrlSchemePermittedError`` when the
        lower-cased scheme is not in ``allowed_schemes``, ``UrlPortError`` for a port above 65535 and
        ``UrlUserInfoError`` when ``user_required`` is set but no user was given.
        """
        scheme = parts['scheme']
        if scheme is None:
            raise errors.UrlSchemeError()

        if cls.allowed_schemes and scheme.lower() not in cls.allowed_schemes:
            raise errors.UrlSchemePermittedError(set(cls.allowed_schemes))

        port = parts['port']
        if port is not None and int(port) > 65_535:
            raise errors.UrlPortError()

        user = parts['user']
        if cls.user_required and user is None:
            raise errors.UrlUserInfoError()

        return parts

    @classmethod
    def validate_host(cls, parts: 'Parts') -> Tuple[str, Optional[str], str, bool]:
        """
        Pick the host out of ``parts`` and validate it.

        Returns ``(host, tld, host_type, rebuild)``. ``host_type`` is ``'domain'``, ``'ipv4'``, ``'ipv6'``
        or ``'int_domain'``; an internationalised domain is IDNA-encoded and sets ``rebuild`` so the URL
        string is regenerated. When no host is present and ``host_required`` is false, ``host_type`` is
        ``None``. ``tld`` is returned without its leading dot.

        Raises ``UrlHostError`` for a missing (when required) or malformed host and ``UrlHostTldError``
        when ``tld_required`` is set but no TLD was found. ``str.encode('idna')`` may additionally raise
        ``UnicodeError`` for labels that codec rejects.
        """
        host, tld, host_type, rebuild = None, None, None, False
        for f in ('domain', 'ipv4', 'ipv6'):
            host = parts[f]  # type: ignore[misc]
            if host:
                host_type = f
                break

        if not host:
            # No usable host: covers ``None`` as well as an empty string coming from get_default_parts.
            if cls.host_required:
                raise errors.UrlHostError()
        elif host_type == 'domain':
            is_international = False
            d = ascii_domain_regex().fullmatch(host)
            if d is None:
                d = int_domain_regex().fullmatch(host)
                if d is None:
                    raise errors.UrlHostError()
                is_international = True

            tld = d.group('tld')
            if tld is None and not is_international:
                # An ASCII host whose last label is not letters-only (e.g. a punycode "xn--" TLD)
                # gets a second chance with the international pattern.
                d = int_domain_regex().fullmatch(host)
                assert d is not None
                tld = d.group('tld')
                is_international = True

            if tld is not None:
                tld = tld[1:]
            elif cls.tld_required:
                raise errors.UrlHostTldError()

            if is_international:
                host_type = 'int_domain'
                rebuild = True
                host = host.encode('idna').decode('ascii')
                if tld is not None:
                    tld = tld.encode('idna').decode('ascii')

        return host, tld, host_type, rebuild  # type: ignore

    @staticmethod
    def get_default_parts(parts: 'Parts') -> 'Parts':
        """Return default values for URL parts; subclasses override this. The base class has none."""
        return {}

    @classmethod
    def apply_default_parts(cls, parts: 'Parts') -> 'Parts':
        """Fill every empty/missing part of ``parts`` (in place) from ``get_default_parts`` and return it."""
        for key, value in cls.get_default_parts(parts).items():
            if not parts[key]:  # type: ignore[misc]
                parts[key] = value  # type: ignore[misc]
        return parts

    def __repr__(self) -> str:
        # The string value followed by every parsed part that is not None, in ``__slots__`` order.
        extra = ', '.join(f'{n}={getattr(self, n)!r}' for n in self.__slots__ if getattr(self, n) is not None)
        return f'{self.__class__.__name__}({super().__repr__()}, {extra})'
-
-
class AnyHttpUrl(AnyUrl):
    """URL restricted to the ``http`` and ``https`` schemes; no TLD is required."""

    allowed_schemes = {
        'http',
        'https',
    }
-
-
class HttpUrl(AnyHttpUrl):
    """
    ``http``/``https`` URL that requires a TLD and is at most 2083 characters long.

    ``port`` is listed in ``hidden_parts``; when the URL carries no explicit port it defaults to
    ``80`` for ``http`` and ``443`` otherwise (see ``get_default_parts``).
    """

    tld_required = True
    # https://stackoverflow.com/questions/417142/what-is-the-maximum-length-of-a-url-in-different-browsers
    max_length = 2083
    hidden_parts = {'port'}

    @staticmethod
    def get_default_parts(parts: 'Parts') -> 'Parts':
        # URL schemes are case-insensitive and the scheme check (``scheme.lower()``) accepts e.g.
        # 'HTTP', so compare case-insensitively. The scheme may still be None here because defaults
        # are applied before the scheme itself is validated.
        scheme = parts['scheme']
        is_plain_http = scheme is not None and scheme.lower() == 'http'
        return {'port': '80' if is_plain_http else '443'}
-
-
class FileUrl(AnyUrl):
    """``file`` URL whose host may be omitted (e.g. ``file:///etc/hosts``)."""

    host_required = False
    allowed_schemes = {'file'}
-
-
class PostgresDsn(AnyUrl):
    """PostgreSQL connection URL: a user name is mandatory and only the schemes below are accepted."""

    user_required = True
    # ``postgres``/``postgresql`` plus the ``postgresql+<driver>`` variants.
    allowed_schemes = {
        'postgres',
        'postgresql',
        'postgresql+asyncpg',
        'postgresql+pg8000',
        'postgresql+psycopg2',
        'postgresql+psycopg2cffi',
        'postgresql+py-postgresql',
        'postgresql+pygresql',
    }
-
-
class AmqpDsn(AnyUrl):
    """AMQP broker URL using ``amqp`` or ``amqps``; the host part is optional."""

    host_required = False
    allowed_schemes = {'amqp', 'amqps'}
-
-
class RedisDsn(AnyUrl):
    """Redis URL (``redis``/``rediss``) defaulting to ``localhost``, port 6379 and database path ``/0``."""

    allowed_schemes = {'redis', 'rediss'}
    host_required = False

    @staticmethod
    def get_default_parts(parts: 'Parts') -> 'Parts':
        # Only default the domain when no IP-literal host was given.
        has_ip_host = bool(parts['ipv4'] or parts['ipv6'])
        return {
            'domain': '' if has_ip_host else 'localhost',
            'port': '6379',
            'path': '/0',
        }
-
-
class KafkaDsn(AnyUrl):
    """``kafka`` URL whose host and port default to ``localhost`` and ``9092``."""

    allowed_schemes = {'kafka'}

    @staticmethod
    def get_default_parts(parts: 'Parts') -> 'Parts':
        defaults: 'Parts' = {'domain': 'localhost', 'port': '9092'}
        return defaults
-
-
def stricturl(
    *,
    strip_whitespace: bool = True,
    min_length: int = 1,
    max_length: int = 2 ** 16,
    tld_required: bool = True,
    host_required: bool = True,
    allowed_schemes: Optional[Collection[str]] = None,
) -> Type[AnyUrl]:
    """
    Create an ``AnyUrl`` subclass named ``UrlValue`` configured with the given validation settings.

    Unlike ``AnyUrl`` itself, ``tld_required`` defaults to ``True`` here.
    """
    # Keyword-only arguments keep IDE hints useful; they become the subclass's class attributes.
    return type(
        'UrlValue',
        (AnyUrl,),
        {
            'strip_whitespace': strip_whitespace,
            'min_length': min_length,
            'max_length': max_length,
            'tld_required': tld_required,
            'host_required': host_required,
            'allowed_schemes': allowed_schemes,
        },
    )
-
-
def import_email_validator() -> None:
    """
    Import the optional ``email_validator`` package into this module's global namespace.

    Raises ``ImportError`` (chained to the original one) with installation instructions when the
    package is not available.
    """
    global email_validator
    try:
        import email_validator
    except ImportError as missing:
        hint = 'email-validator is not installed, run `pip install pydantic[email]`'
        raise ImportError(hint) from missing
-
-
class EmailStr(str):
    """String type whose validator returns the address produced by ``validate_email`` (domain lower-cased)."""

    @classmethod
    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
        field_schema['type'] = 'string'
        field_schema['format'] = 'email'

    @classmethod
    def __get_validators__(cls) -> 'CallableGenerator':
        # Checked here as well as in validate_email so a missing dependency fails immediately.
        import_email_validator()

        yield from (str_validator, cls.validate)

    @classmethod
    def validate(cls, value: str) -> str:
        _name, address = validate_email(value)
        return address
-
-
class NameEmail(Representation):
    """A display name plus e-mail address, rendered by ``str()`` as ``Name <address>``."""

    __slots__ = ('name', 'email')

    def __init__(self, name: str, email: str):
        self.name, self.email = name, email

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, NameEmail):
            return False
        return (self.name, self.email) == (other.name, other.email)

    @classmethod
    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
        field_schema['type'] = 'string'
        field_schema['format'] = 'name-email'

    @classmethod
    def __get_validators__(cls) -> 'CallableGenerator':
        # Fail straight away if the optional email-validator dependency is missing.
        import_email_validator()

        yield cls.validate

    @classmethod
    def validate(cls, value: Any) -> 'NameEmail':
        if value.__class__ == cls:
            return value
        name, email = validate_email(str_validator(value))
        return cls(name, email)

    def __str__(self) -> str:
        return '{} <{}>'.format(self.name, self.email)
-
-
class IPvAnyAddress(_BaseAddress):
    """Field type accepting an IPv4 or IPv6 address; validation returns the concrete ``ipaddress`` object."""

    @classmethod
    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
        field_schema['type'] = 'string'
        field_schema['format'] = 'ipvanyaddress'

    @classmethod
    def __get_validators__(cls) -> 'CallableGenerator':
        yield from (cls.validate,)

    @classmethod
    def validate(cls, value: Union[str, bytes, int]) -> Union[IPv4Address, IPv6Address]:
        # IPv4 is tried first; its ValueError is fully handled before IPv6 is attempted.
        try:
            as_v4: Optional[IPv4Address] = IPv4Address(value)
        except ValueError:
            as_v4 = None
        if as_v4 is not None:
            return as_v4

        try:
            return IPv6Address(value)
        except ValueError:
            raise errors.IPvAnyAddressError()
-
-
class IPvAnyInterface(_BaseAddress):
    """Field type accepting an IPv4 or IPv6 interface (address plus optional prefix)."""

    @classmethod
    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
        field_schema['type'] = 'string'
        field_schema['format'] = 'ipvanyinterface'

    @classmethod
    def __get_validators__(cls) -> 'CallableGenerator':
        yield from (cls.validate,)

    @classmethod
    def validate(cls, value: NetworkType) -> Union[IPv4Interface, IPv6Interface]:
        # IPv4 is tried first; its ValueError is fully handled before IPv6 is attempted.
        try:
            as_v4: Optional[IPv4Interface] = IPv4Interface(value)
        except ValueError:
            as_v4 = None
        if as_v4 is not None:
            return as_v4

        try:
            return IPv6Interface(value)
        except ValueError:
            raise errors.IPvAnyInterfaceError()
-
-
class IPvAnyNetwork(_BaseNetwork):  # type: ignore
    """Field type accepting an IPv4 or IPv6 network, using the ``ipaddress`` constructors' default strictness."""

    @classmethod
    def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
        field_schema['type'] = 'string'
        field_schema['format'] = 'ipvanynetwork'

    @classmethod
    def __get_validators__(cls) -> 'CallableGenerator':
        yield from (cls.validate,)

    @classmethod
    def validate(cls, value: NetworkType) -> Union[IPv4Network, IPv6Network]:
        # Assume IP Network is defined with a default value for ``strict`` argument.
        # Define your own class if you want to specify network address check strictness.
        try:
            as_v4: Optional[IPv4Network] = IPv4Network(value)
        except ValueError:
            as_v4 = None
        if as_v4 is not None:
            return as_v4

        try:
            return IPv6Network(value)
        except ValueError:
            raise errors.IPvAnyNetworkError()
-
-
# Splits a "pretty" address such as ``John Doe <jd@example.com>`` into (name, address).
# The name may contain only word characters and spaces; used by validate_email().
pretty_email_regex = re.compile(r'([\w ]*?) *<(.*)> *')
-
-
def validate_email(value: Union[str]) -> Tuple[str, str]:
    """
    Validate an e-mail address and return ``(name, address)``.

    * "pretty" addresses such as ``John Doe <local_part@domain.com>`` are accepted: the text before ``<``
      (word characters and spaces only) becomes the name; otherwise the local part is used as the name.
    * the address itself is checked by ``email_validator.validate_email`` with deliverability (DNS)
      checks disabled; its ``EmailNotValidError`` is re-raised as ``errors.EmailError``. Under that
      library's default settings raw IP address (domain literal) domains are rejected.
    * whitespace around the address is stripped without raising an error.
    * the local part keeps its case (RFC 5321) while the domain part is lower-cased.

    See RFC 5322 but treat it with suspicion, there seems to exist no universally acknowledged test for a valid email!
    """
    if email_validator is None:
        import_email_validator()

    m = pretty_email_regex.fullmatch(value)
    name: Optional[str] = None
    if m:
        name, value = m.groups()

    email = value.strip()

    try:
        email_validator.validate_email(email, check_deliverability=False)
    except email_validator.EmailNotValidError as e:
        raise errors.EmailError() from e

    # Split at the *last* '@': the domain part can never contain one, whereas an RFC 5321
    # quoted local part may.
    local_part, _, domain_part = email.rpartition('@')
    # RFC 5321: the local part is case-sensitive, the domain is not.
    return name or local_part, local_part + '@' + domain_part.lower()
|