You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

565 lines
16 KiB

  1. import re
  2. from ipaddress import (
  3. IPv4Address,
  4. IPv4Interface,
  5. IPv4Network,
  6. IPv6Address,
  7. IPv6Interface,
  8. IPv6Network,
  9. _BaseAddress,
  10. _BaseNetwork,
  11. )
  12. from typing import (
  13. TYPE_CHECKING,
  14. Any,
  15. Collection,
  16. Dict,
  17. Generator,
  18. Optional,
  19. Pattern,
  20. Set,
  21. Tuple,
  22. Type,
  23. Union,
  24. cast,
  25. no_type_check,
  26. )
  27. from . import errors
  28. from .utils import Representation, update_not_none
  29. from .validators import constr_length_validator, str_validator
  if TYPE_CHECKING:
      # Everything in this branch exists for static type checkers only and is
      # never executed at runtime.
      import email_validator
      from typing_extensions import TypedDict

      from .config import BaseConfig
      from .fields import ModelField
      from .typing import AnyCallable

      CallableGenerator = Generator[AnyCallable, None, None]

      class Parts(TypedDict, total=False):
          # Keys mirror the named groups of the pattern built by url_regex().
          scheme: str
          user: Optional[str]
          password: Optional[str]
          ipv4: Optional[str]
          ipv6: Optional[str]
          domain: Optional[str]
          port: Optional[str]
          path: Optional[str]
          query: Optional[str]
          fragment: Optional[str]


  else:
      # Runtime placeholder; import_email_validator() rebinds this global to
      # the real module the first time it is needed.
      email_validator = None
  # Accepted inputs for the interface/network validators: a single value, or a
  # two-item tuple of (value, str-or-int).
  NetworkType = Union[str, bytes, int, Tuple[Union[str, bytes, int], Union[str, int]]]

  # Public names exported by this module.
  __all__ = [
      'AnyUrl',
      'AnyHttpUrl',
      'FileUrl',
      'HttpUrl',
      'stricturl',
      'EmailStr',
      'NameEmail',
      'IPvAnyAddress',
      'IPvAnyInterface',
      'IPvAnyNetwork',
      'PostgresDsn',
      'AmqpDsn',
      'RedisDsn',
      'KafkaDsn',
      'validate_email',
  ]

  # Lazily-filled caches for the compiled patterns returned by url_regex(),
  # ascii_domain_regex() and int_domain_regex() respectively.
  _url_regex_cache = None
  _ascii_domain_regex_cache = None
  _int_domain_regex_cache = None
  def url_regex() -> Pattern[str]:
      """
      Return the compiled, case-insensitive pattern used to split a URL into parts.

      The pattern is compiled on the first call and stored in the module-level
      ``_url_regex_cache``; later calls return that same object. Every component
      is optional, and the named groups are ``scheme``, ``user``, ``password``,
      ``ipv4``, ``ipv6``, ``domain``, ``port``, ``path``, ``query`` and ``fragment``.
      """
      global _url_regex_cache
      if _url_regex_cache is not None:
          return _url_regex_cache

      # scheme https://tools.ietf.org/html/rfc3986#appendix-A
      scheme_part = r'(?:(?P<scheme>[a-z][a-z0-9+\-.]+)://)?'
      # user info
      user_info_part = r'(?:(?P<user>[^\s:/]*)(?::(?P<password>[^\s/]*))?@)?'
      # host: ipv4, bracketed ipv6, or a domain (the domain is validated later)
      host_part = (
          r'(?:'
          r'(?P<ipv4>(?:\d{1,3}\.){3}\d{1,3})(?=$|[/:#?])|'
          r'(?P<ipv6>\[[A-F0-9]*:[A-F0-9:]+\])(?=$|[/:#?])|'
          r'(?P<domain>[^\s/:?#]+)'
          r')?'
      )
      port_part = r'(?::(?P<port>\d+))?'
      path_part = r'(?P<path>/[^\s?#]*)?'
      query_part = r'(?:\?(?P<query>[^\s#]*))?'
      fragment_part = r'(?:#(?P<fragment>[^\s#]*))?'

      full_pattern = ''.join(
          (scheme_part, user_info_part, host_part, port_part, path_part, query_part, fragment_part)
      )
      _url_regex_cache = re.compile(full_pattern, re.IGNORECASE)
      return _url_regex_cache
  def ascii_domain_regex() -> Pattern[str]:
      """
      Return the compiled, case-insensitive pattern for ASCII domain names.

      Compiled once and kept in the module-level ``_ascii_domain_regex_cache``.
      The optional ``tld`` group captures the final label including its leading
      dot; a single trailing dot after the domain is also accepted.
      """
      global _ascii_domain_regex_cache
      if _ascii_domain_regex_cache is not None:
          return _ascii_domain_regex_cache

      label = r'[_0-9a-z](?:[-_0-9a-z]{0,61}[_0-9a-z])?'
      ending = r'(?P<tld>\.[a-z]{2,63})?\.?'
      _ascii_domain_regex_cache = re.compile(fr'(?:{label}\.)*?{label}{ending}', re.IGNORECASE)
      return _ascii_domain_regex_cache
  def int_domain_regex() -> Pattern[str]:
      """
      Return the compiled, case-insensitive pattern for international domain names.

      Compiled once and kept in the module-level ``_int_domain_regex_cache``.
      Labels may contain characters up to U+40000; the optional ``tld`` group
      captures either a letters-only ending or an ``xn--`` (punycode) ending,
      including the leading dot.
      """
      global _int_domain_regex_cache
      if _int_domain_regex_cache is not None:
          return _int_domain_regex_cache

      label = r'[_0-9a-\U00040000](?:[-_0-9a-\U00040000]{0,61}[_0-9a-\U00040000])?'
      ending = r'(?P<tld>(\.[^\W\d_]{2,63})|(\.(?:xn--)[_0-9a-z-]{2,63}))?\.?'
      _int_domain_regex_cache = re.compile(fr'(?:{label}\.)*?{label}{ending}', re.IGNORECASE)
      return _int_domain_regex_cache
  class AnyUrl(str):
      """
      A ``str`` subclass that also stores the parsed components of a URL.

      Subclasses tune validation through the class attributes below and may
      override ``get_default_parts`` / ``validate_parts`` to supply defaults or
      extra checks.
      """

      strip_whitespace = True
      min_length = 1
      max_length = 2 ** 16
      # when set, the lower-cased scheme must be one of these
      allowed_schemes: Optional[Collection[str]] = None
      tld_required: bool = False
      user_required: bool = False
      host_required: bool = True
      # names of parts that build() leaves out when they hold the class default
      hidden_parts: Set[str] = set()

      __slots__ = ('scheme', 'user', 'password', 'host', 'tld', 'host_type', 'port', 'path', 'query', 'fragment')

      @no_type_check
      def __new__(cls, url: Optional[str], **kwargs) -> object:
          # url=None means "assemble the string from the keyword parts"
          return str.__new__(cls, cls.build(**kwargs) if url is None else url)

      def __init__(
          self,
          url: str,
          *,
          scheme: str,
          user: Optional[str] = None,
          password: Optional[str] = None,
          host: Optional[str] = None,
          tld: Optional[str] = None,
          host_type: str = 'domain',
          port: Optional[str] = None,
          path: Optional[str] = None,
          query: Optional[str] = None,
          fragment: Optional[str] = None,
      ) -> None:
          """
          Store the URL components on the instance.

          The string value itself has already been fixed by ``__new__``; ``url``
          is accepted here only so both methods share one call signature.
          """
          str.__init__(url)
          self.scheme = scheme
          self.user = user
          self.password = password
          self.host = host
          self.tld = tld
          self.host_type = host_type
          self.port = port
          self.path = path
          self.query = query
          self.fragment = fragment

      @classmethod
      def build(
          cls,
          *,
          scheme: str,
          user: Optional[str] = None,
          password: Optional[str] = None,
          host: str,
          port: Optional[str] = None,
          path: Optional[str] = None,
          query: Optional[str] = None,
          fragment: Optional[str] = None,
          **_kwargs: str,
      ) -> str:
          """
          Assemble a URL string from its components.

          Empty or ``None`` components are skipped. When ``'port'`` is listed in
          ``hidden_parts`` the port is omitted only if it equals the default
          port reported by ``get_default_parts``; any other port is kept so the
          rebuilt URL still addresses the same endpoint. Extra keyword arguments
          are accepted and ignored, except ``host_type`` which is consulted when
          looking up the default port.
          """
          url = scheme + '://'
          if user:
              url += user
          if password:
              url += ':' + password
          if user or password:
              url += '@'
          url += host

          if port:
              omit_port = 'port' in cls.hidden_parts
              if omit_port:
                  # Supply every ``Parts`` key so get_default_parts() overrides
                  # that index into the mapping do not hit a KeyError.
                  host_type = _kwargs.get('host_type')
                  lookup = {
                      'scheme': scheme,
                      'user': user,
                      'password': password,
                      'ipv4': host if host_type == 'ipv4' else None,
                      'ipv6': host if host_type == 'ipv6' else None,
                      'domain': host if host_type not in ('ipv4', 'ipv6') else None,
                      'port': port,
                      'path': path,
                      'query': query,
                      'fragment': fragment,
                  }
                  omit_port = cls.get_default_parts(lookup).get('port') == port  # type: ignore[arg-type]
              if not omit_port:
                  url += ':' + port

          if path:
              url += path
          if query:
              url += '?' + query
          if fragment:
              url += '#' + fragment
          return url

      @classmethod
      def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
          update_not_none(field_schema, minLength=cls.min_length, maxLength=cls.max_length, format='uri')

      @classmethod
      def __get_validators__(cls) -> 'CallableGenerator':
          yield cls.validate

      @classmethod
      def validate(cls, value: Any, field: 'ModelField', config: 'BaseConfig') -> 'AnyUrl':
          """
          Validate ``value`` and return an instance of ``cls``.

          An object whose class is exactly ``cls`` is returned unchanged.
          Otherwise the value is coerced to ``str``, optionally stripped,
          length-checked, split with ``url_regex()`` and checked by
          ``validate_parts`` and ``validate_host``. ``errors.UrlExtraError`` is
          raised when text remains after the matched URL.
          """
          if value.__class__ == cls:
              return value
          value = str_validator(value)
          if cls.strip_whitespace:
              value = value.strip()
          url: str = cast(str, constr_length_validator(value, field, config))

          m = url_regex().match(url)
          # the regex should always match, if it doesn't please report with details of the URL tried
          assert m, 'URL regex failed unexpectedly'

          original_parts = cast('Parts', m.groupdict())
          parts = cls.apply_default_parts(original_parts)
          parts = cls.validate_parts(parts)

          host, tld, host_type, rebuild = cls.validate_host(parts)

          if m.end() != len(url):
              raise errors.UrlExtraError(extra=url[m.end() :])

          return cls(
              # passing None makes __new__ rebuild the string from the parts
              None if rebuild else url,
              scheme=parts['scheme'],
              user=parts['user'],
              password=parts['password'],
              host=host,
              tld=tld,
              host_type=host_type,
              port=parts['port'],
              path=parts['path'],
              query=parts['query'],
              fragment=parts['fragment'],
          )

      @classmethod
      def validate_parts(cls, parts: 'Parts') -> 'Parts':
          """
          A method used to validate parts of an URL.
          Could be overridden to set default values for parts if missing

          Raises ``errors.UrlSchemeError`` when the scheme is missing,
          ``errors.UrlSchemePermittedError`` when it is not in ``allowed_schemes``,
          ``errors.UrlPortError`` when the port exceeds 65535 and
          ``errors.UrlUserInfoError`` when a required user is absent.
          """
          scheme = parts['scheme']
          if scheme is None:
              raise errors.UrlSchemeError()

          if cls.allowed_schemes and scheme.lower() not in cls.allowed_schemes:
              raise errors.UrlSchemePermittedError(set(cls.allowed_schemes))

          port = parts['port']
          if port is not None and int(port) > 65_535:
              raise errors.UrlPortError()

          user = parts['user']
          if cls.user_required and user is None:
              raise errors.UrlUserInfoError()

          return parts

      @classmethod
      def validate_host(cls, parts: 'Parts') -> Tuple[str, Optional[str], str, bool]:
          """
          Pick the host out of ``parts`` and validate it.

          Returns ``(host, tld, host_type, rebuild)``. ``rebuild`` is ``True``
          when the host was IDNA-encoded, meaning the URL string has to be
          re-assembled from its parts. Raises ``errors.UrlHostError`` for a
          missing (when required) or malformed host and
          ``errors.UrlHostTldError`` when a TLD is required but absent.
          """
          host, tld, host_type, rebuild = None, None, None, False
          # first non-empty candidate wins
          for f in ('domain', 'ipv4', 'ipv6'):
              host = parts[f]  # type: ignore[misc]
              if host:
                  host_type = f
                  break

          if host is None:
              if cls.host_required:
                  raise errors.UrlHostError()
          elif host_type == 'domain':
              is_international = False
              d = ascii_domain_regex().fullmatch(host)
              if d is None:
                  d = int_domain_regex().fullmatch(host)
                  if d is None:
                      raise errors.UrlHostError()
                  is_international = True

              tld = d.group('tld')
              if tld is None and not is_international:
                  # the ASCII pattern found no TLD; retry with the international
                  # pattern, whose tld group also accepts ``xn--`` endings
                  d = int_domain_regex().fullmatch(host)
                  assert d is not None
                  tld = d.group('tld')
                  is_international = True

              if tld is not None:
                  # drop the leading dot captured by the regex
                  tld = tld[1:]
              elif cls.tld_required:
                  raise errors.UrlHostTldError()

              if is_international:
                  host_type = 'int_domain'
                  rebuild = True
                  host = host.encode('idna').decode('ascii')
                  if tld is not None:
                      tld = tld.encode('idna').decode('ascii')

          return host, tld, host_type, rebuild  # type: ignore

      @staticmethod
      def get_default_parts(parts: 'Parts') -> 'Parts':
          """
          Return default values for URL parts; the base class defines none.
          """
          return {}

      @classmethod
      def apply_default_parts(cls, parts: 'Parts') -> 'Parts':
          """
          Fill empty entries of ``parts`` (in place) from ``get_default_parts``.
          """
          for key, value in cls.get_default_parts(parts).items():
              if not parts[key]:  # type: ignore[misc]
                  parts[key] = value  # type: ignore[misc]
          return parts

      def __repr__(self) -> str:
          extra = ', '.join(f'{n}={getattr(self, n)!r}' for n in self.__slots__ if getattr(self, n) is not None)
          return f'{self.__class__.__name__}({super().__repr__()}, {extra})'
  class AnyHttpUrl(AnyUrl):
      """
      URL whose scheme must be ``http`` or ``https``.
      """

      allowed_schemes = {'http', 'https'}
  class HttpUrl(AnyHttpUrl):
      """
      ``http``/``https`` URL that requires a TLD and has a scheme-based default port.
      """

      tld_required = True
      # https://stackoverflow.com/questions/417142/what-is-the-maximum-length-of-a-url-in-different-browsers
      max_length = 2083
      hidden_parts = {'port'}

      @staticmethod
      def get_default_parts(parts: 'Parts') -> 'Parts':
          """
          Default the port to 80 for ``http`` and to 443 for any other scheme.
          """
          default_port = '80' if parts['scheme'] == 'http' else '443'
          return {'port': default_port}
  class FileUrl(AnyUrl):
      """
      URL with the ``file`` scheme; a host is not required.
      """

      allowed_schemes = {'file'}
      host_required = False
  class PostgresDsn(AnyUrl):
      """
      PostgreSQL connection URL; user info is mandatory.
      """

      # plain schemes plus the ``postgresql+<driver>`` variants
      allowed_schemes = {
          'postgres',
          'postgresql',
          'postgresql+asyncpg',
          'postgresql+pg8000',
          'postgresql+psycopg2',
          'postgresql+psycopg2cffi',
          'postgresql+py-postgresql',
          'postgresql+pygresql',
      }
      user_required = True
  class AmqpDsn(AnyUrl):
      """
      URL with the ``amqp`` or ``amqps`` scheme; a host is not required.
      """

      allowed_schemes = {'amqp', 'amqps'}
      host_required = False
  class RedisDsn(AnyUrl):
      """
      URL with the ``redis`` or ``rediss`` scheme, with defaults for host, port and path.
      """

      allowed_schemes = {'redis', 'rediss'}
      host_required = False

      @staticmethod
      def get_default_parts(parts: 'Parts') -> 'Parts':
          """
          Default to ``localhost`` (unless an IP host was given), port 6379 and path ``/0``.
          """
          has_ip_host = parts['ipv4'] or parts['ipv6']
          return {
              'domain': '' if has_ip_host else 'localhost',
              'port': '6379',
              'path': '/0',
          }
  class KafkaDsn(AnyUrl):
      """
      URL with the ``kafka`` scheme, defaulting to ``localhost`` and port 9092.
      """

      allowed_schemes = {'kafka'}

      @staticmethod
      def get_default_parts(parts: 'Parts') -> 'Parts':
          """
          Supply the default domain and port.
          """
          defaults = {
              'domain': 'localhost',
              'port': '9092',
          }
          return defaults
  def stricturl(
      *,
      strip_whitespace: bool = True,
      min_length: int = 1,
      max_length: int = 2 ** 16,
      tld_required: bool = True,
      host_required: bool = True,
      allowed_schemes: Optional[Collection[str]] = None,
  ) -> Type[AnyUrl]:
      """
      Create an ``AnyUrl`` subclass named ``UrlValue`` with the given settings.

      Each keyword argument becomes the class attribute of the same name on
      the returned type.
      """
      # explicit keyword parameters (rather than **kwargs) aid IDE type hinting
      class_attrs = {
          'strip_whitespace': strip_whitespace,
          'min_length': min_length,
          'max_length': max_length,
          'tld_required': tld_required,
          'host_required': host_required,
          'allowed_schemes': allowed_schemes,
      }
      return type('UrlValue', (AnyUrl,), class_attrs)
  def import_email_validator() -> None:
      """
      Import ``email_validator`` and bind it to the module-level global of that name.

      Raises ``ImportError`` with an installation hint, chained to the original
      error, when the package is not available.
      """
      global email_validator
      try:
          import email_validator
      except ImportError as exc:
          hint = 'email-validator is not installed, run `pip install pydantic[email]`'
          raise ImportError(hint) from exc
  class EmailStr(str):
      """
      String type validated as an email address via ``validate_email``.
      """

      @classmethod
      def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
          field_schema.update({'type': 'string', 'format': 'email'})

      @classmethod
      def __get_validators__(cls) -> 'CallableGenerator':
          # included here and below so the error happens straight away
          import_email_validator()

          yield str_validator
          yield cls.validate

      @classmethod
      def validate(cls, value: Union[str]) -> str:
          """
          Return the address component of ``validate_email(value)``.
          """
          _name, address = validate_email(value)
          return address
  class NameEmail(Representation):
      """
      A display name paired with an email address, rendered as ``name <email>``.
      """

      __slots__ = 'name', 'email'

      def __init__(self, name: str, email: str):
          self.name = name
          self.email = email

      def __eq__(self, other: Any) -> bool:
          # only another NameEmail with the same name and email compares equal
          if not isinstance(other, NameEmail):
              return False
          return (self.name, self.email) == (other.name, other.email)

      @classmethod
      def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
          field_schema.update({'type': 'string', 'format': 'name-email'})

      @classmethod
      def __get_validators__(cls) -> 'CallableGenerator':
          import_email_validator()

          yield cls.validate

      @classmethod
      def validate(cls, value: Any) -> 'NameEmail':
          """
          Return ``value`` unchanged if its class is exactly ``cls``; otherwise
          coerce it to ``str`` and parse it with ``validate_email``.
          """
          if value.__class__ == cls:
              return value
          text = str_validator(value)
          name, email = validate_email(text)
          return cls(name, email)

      def __str__(self) -> str:
          return f'{self.name} <{self.email}>'
  class IPvAnyAddress(_BaseAddress):
      """
      Field type accepting either an IPv4 or an IPv6 address.
      """

      @classmethod
      def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
          field_schema.update({'type': 'string', 'format': 'ipvanyaddress'})

      @classmethod
      def __get_validators__(cls) -> 'CallableGenerator':
          yield cls.validate

      @classmethod
      def validate(cls, value: Union[str, bytes, int]) -> Union[IPv4Address, IPv6Address]:
          """
          Parse ``value`` as IPv4 first, then as IPv6.

          Raises ``errors.IPvAnyAddressError`` when neither parser accepts it.
          """
          try:
              return IPv4Address(value)
          except ValueError:
              # not IPv4; fall through to the IPv6 attempt
              pass

          try:
              return IPv6Address(value)
          except ValueError:
              raise errors.IPvAnyAddressError()
  class IPvAnyInterface(_BaseAddress):
      """
      Field type accepting either an IPv4 or an IPv6 interface.
      """

      @classmethod
      def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
          field_schema.update({'type': 'string', 'format': 'ipvanyinterface'})

      @classmethod
      def __get_validators__(cls) -> 'CallableGenerator':
          yield cls.validate

      @classmethod
      def validate(cls, value: NetworkType) -> Union[IPv4Interface, IPv6Interface]:
          """
          Parse ``value`` as an IPv4 interface first, then as an IPv6 interface.

          Raises ``errors.IPvAnyInterfaceError`` when neither parser accepts it.
          """
          try:
              return IPv4Interface(value)
          except ValueError:
              # not IPv4; fall through to the IPv6 attempt
              pass

          try:
              return IPv6Interface(value)
          except ValueError:
              raise errors.IPvAnyInterfaceError()
  class IPvAnyNetwork(_BaseNetwork):  # type: ignore
      """
      Field type accepting either an IPv4 or an IPv6 network.
      """

      @classmethod
      def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
          field_schema.update({'type': 'string', 'format': 'ipvanynetwork'})

      @classmethod
      def __get_validators__(cls) -> 'CallableGenerator':
          yield cls.validate

      @classmethod
      def validate(cls, value: NetworkType) -> Union[IPv4Network, IPv6Network]:
          """
          Parse ``value`` as an IPv4 network first, then as an IPv6 network.

          Raises ``errors.IPvAnyNetworkError`` when neither parser accepts it.
          """
          # Assume IP Network is defined with a default value for ``strict`` argument.
          # Define your own class if you want to specify network address check strictness.
          try:
              return IPv4Network(value)
          except ValueError:
              # not IPv4; fall through to the IPv6 attempt
              pass

          try:
              return IPv6Network(value)
          except ValueError:
              raise errors.IPvAnyNetworkError()
  # Matches "Some Name <address>" style input: group 1 is the name, group 2 the address.
  pretty_email_regex = re.compile(r'([\w ]*?) *<(.*)> *')

  # Inputs longer than this are rejected before the regex runs. The lazy name
  # group followed by `` *`` lets runs of spaces be split in many ways, so match
  # time grows faster than linearly with input length on untrusted data.
  MAX_EMAIL_LENGTH = 2048


  def validate_email(value: Union[str]) -> Tuple[str, str]:
      """
      Brutally simple email address validation. Note unlike most email address validation
      * raw ip address (literal) domain parts are not allowed.
      * "John Doe <local_part@domain.com>" style "pretty" email addresses are processed
      * the local part check is extremely basic. This raises the possibility of unicode spoofing, but no better
        solution is really possible.
      * spaces are stripped from the beginning and end of addresses but no error is raised
      * input longer than ``MAX_EMAIL_LENGTH`` characters is rejected outright

      See RFC 5322 but treat it with suspicion, there seems to exist no universally acknowledged test for a valid email!

      Returns a ``(name, email)`` tuple. ``name`` falls back to the local part
      when no display name is given, and the domain part of ``email`` is
      lower-cased. Raises ``errors.EmailError`` for invalid or over-long input.
      """
      if email_validator is None:
          import_email_validator()

      # bound the work done by pretty_email_regex on attacker-controlled input
      if len(value) > MAX_EMAIL_LENGTH:
          raise errors.EmailError()

      m = pretty_email_regex.fullmatch(value)
      name: Optional[str] = None
      if m:
          name, value = m.groups()

      email = value.strip()

      try:
          email_validator.validate_email(email, check_deliverability=False)
      except email_validator.EmailNotValidError as e:
          raise errors.EmailError() from e

      at_index = email.index('@')
      local_part = email[:at_index]  # RFC 5321, local part must be case-sensitive.
      global_part = email[at_index:].lower()

      return name or local_part, local_part + global_part