You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

763 lines
27 KiB

  1. import asyncio
  2. import contextlib
  3. import functools
  4. import inspect
  5. import re
  6. import sys
  7. import traceback
  8. import types
  9. import typing
  10. import warnings
  11. from enum import Enum
  12. from starlette.concurrency import run_in_threadpool
  13. from starlette.convertors import CONVERTOR_TYPES, Convertor
  14. from starlette.datastructures import URL, Headers, URLPath
  15. from starlette.exceptions import HTTPException
  16. from starlette.requests import Request
  17. from starlette.responses import PlainTextResponse, RedirectResponse
  18. from starlette.types import ASGIApp, Receive, Scope, Send
  19. from starlette.websockets import WebSocket, WebSocketClose
  20. if sys.version_info >= (3, 7):
  21. from contextlib import asynccontextmanager # pragma: no cover
  22. else:
  23. from contextlib2 import asynccontextmanager # pragma: no cover
  24. class NoMatchFound(Exception):
  25. """
  26. Raised by `.url_for(name, **path_params)` and `.url_path_for(name, **path_params)`
  27. if no matching route exists.
  28. """
  29. class Match(Enum):
  30. NONE = 0
  31. PARTIAL = 1
  32. FULL = 2
  33. def iscoroutinefunction_or_partial(obj: typing.Any) -> bool:
  34. """
  35. Correctly determines if an object is a coroutine function,
  36. including those wrapped in functools.partial objects.
  37. """
  38. while isinstance(obj, functools.partial):
  39. obj = obj.func
  40. return inspect.iscoroutinefunction(obj)
  41. def request_response(func: typing.Callable) -> ASGIApp:
  42. """
  43. Takes a function or coroutine `func(request) -> response`,
  44. and returns an ASGI application.
  45. """
  46. is_coroutine = iscoroutinefunction_or_partial(func)
  47. async def app(scope: Scope, receive: Receive, send: Send) -> None:
  48. request = Request(scope, receive=receive, send=send)
  49. if is_coroutine:
  50. response = await func(request)
  51. else:
  52. response = await run_in_threadpool(func, request)
  53. await response(scope, receive, send)
  54. return app
  55. def websocket_session(func: typing.Callable) -> ASGIApp:
  56. """
  57. Takes a coroutine `func(session)`, and returns an ASGI application.
  58. """
  59. # assert asyncio.iscoroutinefunction(func), "WebSocket endpoints must be async"
  60. async def app(scope: Scope, receive: Receive, send: Send) -> None:
  61. session = WebSocket(scope, receive=receive, send=send)
  62. await func(session)
  63. return app
  64. def get_name(endpoint: typing.Callable) -> str:
  65. if inspect.isfunction(endpoint) or inspect.isclass(endpoint):
  66. return endpoint.__name__
  67. return endpoint.__class__.__name__
  68. def replace_params(
  69. path: str,
  70. param_convertors: typing.Dict[str, Convertor],
  71. path_params: typing.Dict[str, str],
  72. ) -> typing.Tuple[str, dict]:
  73. for key, value in list(path_params.items()):
  74. if "{" + key + "}" in path:
  75. convertor = param_convertors[key]
  76. value = convertor.to_string(value)
  77. path = path.replace("{" + key + "}", value)
  78. path_params.pop(key)
  79. return path, path_params
  80. # Match parameters in URL paths, eg. '{param}', and '{param:int}'
  81. PARAM_REGEX = re.compile("{([a-zA-Z_][a-zA-Z0-9_]*)(:[a-zA-Z_][a-zA-Z0-9_]*)?}")
  82. def compile_path(
  83. path: str,
  84. ) -> typing.Tuple[typing.Pattern, str, typing.Dict[str, Convertor]]:
  85. """
  86. Given a path string, like: "/{username:str}", return a three-tuple
  87. of (regex, format, {param_name:convertor}).
  88. regex: "/(?P<username>[^/]+)"
  89. format: "/{username}"
  90. convertors: {"username": StringConvertor()}
  91. """
  92. path_regex = "^"
  93. path_format = ""
  94. duplicated_params = set()
  95. idx = 0
  96. param_convertors = {}
  97. for match in PARAM_REGEX.finditer(path):
  98. param_name, convertor_type = match.groups("str")
  99. convertor_type = convertor_type.lstrip(":")
  100. assert (
  101. convertor_type in CONVERTOR_TYPES
  102. ), f"Unknown path convertor '{convertor_type}'"
  103. convertor = CONVERTOR_TYPES[convertor_type]
  104. path_regex += re.escape(path[idx : match.start()])
  105. path_regex += f"(?P<{param_name}>{convertor.regex})"
  106. path_format += path[idx : match.start()]
  107. path_format += "{%s}" % param_name
  108. if param_name in param_convertors:
  109. duplicated_params.add(param_name)
  110. param_convertors[param_name] = convertor
  111. idx = match.end()
  112. if duplicated_params:
  113. names = ", ".join(sorted(duplicated_params))
  114. ending = "s" if len(duplicated_params) > 1 else ""
  115. raise ValueError(f"Duplicated param name{ending} {names} at path {path}")
  116. path_regex += re.escape(path[idx:]) + "$"
  117. path_format += path[idx:]
  118. return re.compile(path_regex), path_format, param_convertors
  119. class BaseRoute:
  120. def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
  121. raise NotImplementedError() # pragma: no cover
  122. def url_path_for(self, name: str, **path_params: str) -> URLPath:
  123. raise NotImplementedError() # pragma: no cover
  124. async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
  125. raise NotImplementedError() # pragma: no cover
  126. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  127. """
  128. A route may be used in isolation as a stand-alone ASGI app.
  129. This is a somewhat contrived case, as they'll almost always be used
  130. within a Router, but could be useful for some tooling and minimal apps.
  131. """
  132. match, child_scope = self.matches(scope)
  133. if match == Match.NONE:
  134. if scope["type"] == "http":
  135. response = PlainTextResponse("Not Found", status_code=404)
  136. await response(scope, receive, send)
  137. elif scope["type"] == "websocket":
  138. websocket_close = WebSocketClose()
  139. await websocket_close(scope, receive, send)
  140. return
  141. scope.update(child_scope)
  142. await self.handle(scope, receive, send)
  143. class Route(BaseRoute):
  144. def __init__(
  145. self,
  146. path: str,
  147. endpoint: typing.Callable,
  148. *,
  149. methods: typing.List[str] = None,
  150. name: str = None,
  151. include_in_schema: bool = True,
  152. ) -> None:
  153. assert path.startswith("/"), "Routed paths must start with '/'"
  154. self.path = path
  155. self.endpoint = endpoint
  156. self.name = get_name(endpoint) if name is None else name
  157. self.include_in_schema = include_in_schema
  158. endpoint_handler = endpoint
  159. while isinstance(endpoint_handler, functools.partial):
  160. endpoint_handler = endpoint_handler.func
  161. if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
  162. # Endpoint is function or method. Treat it as `func(request) -> response`.
  163. self.app = request_response(endpoint)
  164. if methods is None:
  165. methods = ["GET"]
  166. else:
  167. # Endpoint is a class. Treat it as ASGI.
  168. self.app = endpoint
  169. if methods is None:
  170. self.methods = None
  171. else:
  172. self.methods = {method.upper() for method in methods}
  173. if "GET" in self.methods:
  174. self.methods.add("HEAD")
  175. self.path_regex, self.path_format, self.param_convertors = compile_path(path)
  176. def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
  177. if scope["type"] == "http":
  178. match = self.path_regex.match(scope["path"])
  179. if match:
  180. matched_params = match.groupdict()
  181. for key, value in matched_params.items():
  182. matched_params[key] = self.param_convertors[key].convert(value)
  183. path_params = dict(scope.get("path_params", {}))
  184. path_params.update(matched_params)
  185. child_scope = {"endpoint": self.endpoint, "path_params": path_params}
  186. if self.methods and scope["method"] not in self.methods:
  187. return Match.PARTIAL, child_scope
  188. else:
  189. return Match.FULL, child_scope
  190. return Match.NONE, {}
  191. def url_path_for(self, name: str, **path_params: str) -> URLPath:
  192. seen_params = set(path_params.keys())
  193. expected_params = set(self.param_convertors.keys())
  194. if name != self.name or seen_params != expected_params:
  195. raise NoMatchFound()
  196. path, remaining_params = replace_params(
  197. self.path_format, self.param_convertors, path_params
  198. )
  199. assert not remaining_params
  200. return URLPath(path=path, protocol="http")
  201. async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
  202. if self.methods and scope["method"] not in self.methods:
  203. if "app" in scope:
  204. raise HTTPException(status_code=405)
  205. else:
  206. response = PlainTextResponse("Method Not Allowed", status_code=405)
  207. await response(scope, receive, send)
  208. else:
  209. await self.app(scope, receive, send)
  210. def __eq__(self, other: typing.Any) -> bool:
  211. return (
  212. isinstance(other, Route)
  213. and self.path == other.path
  214. and self.endpoint == other.endpoint
  215. and self.methods == other.methods
  216. )
  217. class WebSocketRoute(BaseRoute):
  218. def __init__(
  219. self, path: str, endpoint: typing.Callable, *, name: str = None
  220. ) -> None:
  221. assert path.startswith("/"), "Routed paths must start with '/'"
  222. self.path = path
  223. self.endpoint = endpoint
  224. self.name = get_name(endpoint) if name is None else name
  225. if inspect.isfunction(endpoint) or inspect.ismethod(endpoint):
  226. # Endpoint is function or method. Treat it as `func(websocket)`.
  227. self.app = websocket_session(endpoint)
  228. else:
  229. # Endpoint is a class. Treat it as ASGI.
  230. self.app = endpoint
  231. self.path_regex, self.path_format, self.param_convertors = compile_path(path)
  232. def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
  233. if scope["type"] == "websocket":
  234. match = self.path_regex.match(scope["path"])
  235. if match:
  236. matched_params = match.groupdict()
  237. for key, value in matched_params.items():
  238. matched_params[key] = self.param_convertors[key].convert(value)
  239. path_params = dict(scope.get("path_params", {}))
  240. path_params.update(matched_params)
  241. child_scope = {"endpoint": self.endpoint, "path_params": path_params}
  242. return Match.FULL, child_scope
  243. return Match.NONE, {}
  244. def url_path_for(self, name: str, **path_params: str) -> URLPath:
  245. seen_params = set(path_params.keys())
  246. expected_params = set(self.param_convertors.keys())
  247. if name != self.name or seen_params != expected_params:
  248. raise NoMatchFound()
  249. path, remaining_params = replace_params(
  250. self.path_format, self.param_convertors, path_params
  251. )
  252. assert not remaining_params
  253. return URLPath(path=path, protocol="websocket")
  254. async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
  255. await self.app(scope, receive, send)
  256. def __eq__(self, other: typing.Any) -> bool:
  257. return (
  258. isinstance(other, WebSocketRoute)
  259. and self.path == other.path
  260. and self.endpoint == other.endpoint
  261. )
  262. class Mount(BaseRoute):
  263. def __init__(
  264. self,
  265. path: str,
  266. app: ASGIApp = None,
  267. routes: typing.Sequence[BaseRoute] = None,
  268. name: str = None,
  269. ) -> None:
  270. assert path == "" or path.startswith("/"), "Routed paths must start with '/'"
  271. assert (
  272. app is not None or routes is not None
  273. ), "Either 'app=...', or 'routes=' must be specified"
  274. self.path = path.rstrip("/")
  275. if app is not None:
  276. self.app: ASGIApp = app
  277. else:
  278. self.app = Router(routes=routes)
  279. self.name = name
  280. self.path_regex, self.path_format, self.param_convertors = compile_path(
  281. self.path + "/{path:path}"
  282. )
  283. @property
  284. def routes(self) -> typing.List[BaseRoute]:
  285. return getattr(self.app, "routes", None)
  286. def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
  287. if scope["type"] in ("http", "websocket"):
  288. path = scope["path"]
  289. match = self.path_regex.match(path)
  290. if match:
  291. matched_params = match.groupdict()
  292. for key, value in matched_params.items():
  293. matched_params[key] = self.param_convertors[key].convert(value)
  294. remaining_path = "/" + matched_params.pop("path")
  295. matched_path = path[: -len(remaining_path)]
  296. path_params = dict(scope.get("path_params", {}))
  297. path_params.update(matched_params)
  298. root_path = scope.get("root_path", "")
  299. child_scope = {
  300. "path_params": path_params,
  301. "app_root_path": scope.get("app_root_path", root_path),
  302. "root_path": root_path + matched_path,
  303. "path": remaining_path,
  304. "endpoint": self.app,
  305. }
  306. return Match.FULL, child_scope
  307. return Match.NONE, {}
  308. def url_path_for(self, name: str, **path_params: str) -> URLPath:
  309. if self.name is not None and name == self.name and "path" in path_params:
  310. # 'name' matches "<mount_name>".
  311. path_params["path"] = path_params["path"].lstrip("/")
  312. path, remaining_params = replace_params(
  313. self.path_format, self.param_convertors, path_params
  314. )
  315. if not remaining_params:
  316. return URLPath(path=path)
  317. elif self.name is None or name.startswith(self.name + ":"):
  318. if self.name is None:
  319. # No mount name.
  320. remaining_name = name
  321. else:
  322. # 'name' matches "<mount_name>:<child_name>".
  323. remaining_name = name[len(self.name) + 1 :]
  324. path_kwarg = path_params.get("path")
  325. path_params["path"] = ""
  326. path_prefix, remaining_params = replace_params(
  327. self.path_format, self.param_convertors, path_params
  328. )
  329. if path_kwarg is not None:
  330. remaining_params["path"] = path_kwarg
  331. for route in self.routes or []:
  332. try:
  333. url = route.url_path_for(remaining_name, **remaining_params)
  334. return URLPath(
  335. path=path_prefix.rstrip("/") + str(url), protocol=url.protocol
  336. )
  337. except NoMatchFound:
  338. pass
  339. raise NoMatchFound()
  340. async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
  341. await self.app(scope, receive, send)
  342. def __eq__(self, other: typing.Any) -> bool:
  343. return (
  344. isinstance(other, Mount)
  345. and self.path == other.path
  346. and self.app == other.app
  347. )
  348. class Host(BaseRoute):
  349. def __init__(self, host: str, app: ASGIApp, name: str = None) -> None:
  350. self.host = host
  351. self.app = app
  352. self.name = name
  353. self.host_regex, self.host_format, self.param_convertors = compile_path(host)
  354. @property
  355. def routes(self) -> typing.List[BaseRoute]:
  356. return getattr(self.app, "routes", None)
  357. def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]:
  358. if scope["type"] in ("http", "websocket"):
  359. headers = Headers(scope=scope)
  360. host = headers.get("host", "").split(":")[0]
  361. match = self.host_regex.match(host)
  362. if match:
  363. matched_params = match.groupdict()
  364. for key, value in matched_params.items():
  365. matched_params[key] = self.param_convertors[key].convert(value)
  366. path_params = dict(scope.get("path_params", {}))
  367. path_params.update(matched_params)
  368. child_scope = {"path_params": path_params, "endpoint": self.app}
  369. return Match.FULL, child_scope
  370. return Match.NONE, {}
  371. def url_path_for(self, name: str, **path_params: str) -> URLPath:
  372. if self.name is not None and name == self.name and "path" in path_params:
  373. # 'name' matches "<mount_name>".
  374. path = path_params.pop("path")
  375. host, remaining_params = replace_params(
  376. self.host_format, self.param_convertors, path_params
  377. )
  378. if not remaining_params:
  379. return URLPath(path=path, host=host)
  380. elif self.name is None or name.startswith(self.name + ":"):
  381. if self.name is None:
  382. # No mount name.
  383. remaining_name = name
  384. else:
  385. # 'name' matches "<mount_name>:<child_name>".
  386. remaining_name = name[len(self.name) + 1 :]
  387. host, remaining_params = replace_params(
  388. self.host_format, self.param_convertors, path_params
  389. )
  390. for route in self.routes or []:
  391. try:
  392. url = route.url_path_for(remaining_name, **remaining_params)
  393. return URLPath(path=str(url), protocol=url.protocol, host=host)
  394. except NoMatchFound:
  395. pass
  396. raise NoMatchFound()
  397. async def handle(self, scope: Scope, receive: Receive, send: Send) -> None:
  398. await self.app(scope, receive, send)
  399. def __eq__(self, other: typing.Any) -> bool:
  400. return (
  401. isinstance(other, Host)
  402. and self.host == other.host
  403. and self.app == other.app
  404. )
  405. _T = typing.TypeVar("_T")
  406. class _AsyncLiftContextManager(typing.AsyncContextManager[_T]):
  407. def __init__(self, cm: typing.ContextManager[_T]):
  408. self._cm = cm
  409. async def __aenter__(self) -> _T:
  410. return self._cm.__enter__()
  411. async def __aexit__(
  412. self,
  413. exc_type: typing.Optional[typing.Type[BaseException]],
  414. exc_value: typing.Optional[BaseException],
  415. traceback: typing.Optional[types.TracebackType],
  416. ) -> typing.Optional[bool]:
  417. return self._cm.__exit__(exc_type, exc_value, traceback)
  418. def _wrap_gen_lifespan_context(
  419. lifespan_context: typing.Callable[[typing.Any], typing.Generator]
  420. ) -> typing.Callable[[typing.Any], typing.AsyncContextManager]:
  421. cmgr = contextlib.contextmanager(lifespan_context)
  422. @functools.wraps(cmgr)
  423. def wrapper(app: typing.Any) -> _AsyncLiftContextManager:
  424. return _AsyncLiftContextManager(cmgr(app))
  425. return wrapper
  426. class _DefaultLifespan:
  427. def __init__(self, router: "Router"):
  428. self._router = router
  429. async def __aenter__(self) -> None:
  430. await self._router.startup()
  431. async def __aexit__(self, *exc_info: object) -> None:
  432. await self._router.shutdown()
  433. def __call__(self: _T, app: object) -> _T:
  434. return self
  435. class Router:
  436. def __init__(
  437. self,
  438. routes: typing.Sequence[BaseRoute] = None,
  439. redirect_slashes: bool = True,
  440. default: ASGIApp = None,
  441. on_startup: typing.Sequence[typing.Callable] = None,
  442. on_shutdown: typing.Sequence[typing.Callable] = None,
  443. lifespan: typing.Callable[[typing.Any], typing.AsyncContextManager] = None,
  444. ) -> None:
  445. self.routes = [] if routes is None else list(routes)
  446. self.redirect_slashes = redirect_slashes
  447. self.default = self.not_found if default is None else default
  448. self.on_startup = [] if on_startup is None else list(on_startup)
  449. self.on_shutdown = [] if on_shutdown is None else list(on_shutdown)
  450. if lifespan is None:
  451. self.lifespan_context: typing.Callable[
  452. [typing.Any], typing.AsyncContextManager
  453. ] = _DefaultLifespan(self)
  454. elif inspect.isasyncgenfunction(lifespan):
  455. warnings.warn(
  456. "async generator function lifespans are deprecated, "
  457. "use an @contextlib.asynccontextmanager function instead",
  458. DeprecationWarning,
  459. )
  460. self.lifespan_context = asynccontextmanager(
  461. lifespan, # type: ignore[arg-type]
  462. )
  463. elif inspect.isgeneratorfunction(lifespan):
  464. warnings.warn(
  465. "generator function lifespans are deprecated, "
  466. "use an @contextlib.asynccontextmanager function instead",
  467. DeprecationWarning,
  468. )
  469. self.lifespan_context = _wrap_gen_lifespan_context(
  470. lifespan, # type: ignore[arg-type]
  471. )
  472. else:
  473. self.lifespan_context = lifespan
  474. async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None:
  475. if scope["type"] == "websocket":
  476. websocket_close = WebSocketClose()
  477. await websocket_close(scope, receive, send)
  478. return
  479. # If we're running inside a starlette application then raise an
  480. # exception, so that the configurable exception handler can deal with
  481. # returning the response. For plain ASGI apps, just return the response.
  482. if "app" in scope:
  483. raise HTTPException(status_code=404)
  484. else:
  485. response = PlainTextResponse("Not Found", status_code=404)
  486. await response(scope, receive, send)
  487. def url_path_for(self, name: str, **path_params: str) -> URLPath:
  488. for route in self.routes:
  489. try:
  490. return route.url_path_for(name, **path_params)
  491. except NoMatchFound:
  492. pass
  493. raise NoMatchFound()
  494. async def startup(self) -> None:
  495. """
  496. Run any `.on_startup` event handlers.
  497. """
  498. for handler in self.on_startup:
  499. if asyncio.iscoroutinefunction(handler):
  500. await handler()
  501. else:
  502. handler()
  503. async def shutdown(self) -> None:
  504. """
  505. Run any `.on_shutdown` event handlers.
  506. """
  507. for handler in self.on_shutdown:
  508. if asyncio.iscoroutinefunction(handler):
  509. await handler()
  510. else:
  511. handler()
  512. async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None:
  513. """
  514. Handle ASGI lifespan messages, which allows us to manage application
  515. startup and shutdown events.
  516. """
  517. started = False
  518. app = scope.get("app")
  519. await receive()
  520. try:
  521. async with self.lifespan_context(app):
  522. await send({"type": "lifespan.startup.complete"})
  523. started = True
  524. await receive()
  525. except BaseException:
  526. exc_text = traceback.format_exc()
  527. if started:
  528. await send({"type": "lifespan.shutdown.failed", "message": exc_text})
  529. else:
  530. await send({"type": "lifespan.startup.failed", "message": exc_text})
  531. raise
  532. else:
  533. await send({"type": "lifespan.shutdown.complete"})
  534. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  535. """
  536. The main entry point to the Router class.
  537. """
  538. assert scope["type"] in ("http", "websocket", "lifespan")
  539. if "router" not in scope:
  540. scope["router"] = self
  541. if scope["type"] == "lifespan":
  542. await self.lifespan(scope, receive, send)
  543. return
  544. partial = None
  545. for route in self.routes:
  546. # Determine if any route matches the incoming scope,
  547. # and hand over to the matching route if found.
  548. match, child_scope = route.matches(scope)
  549. if match == Match.FULL:
  550. scope.update(child_scope)
  551. await route.handle(scope, receive, send)
  552. return
  553. elif match == Match.PARTIAL and partial is None:
  554. partial = route
  555. partial_scope = child_scope
  556. if partial is not None:
  557. #  Handle partial matches. These are cases where an endpoint is
  558. # able to handle the request, but is not a preferred option.
  559. # We use this in particular to deal with "405 Method Not Allowed".
  560. scope.update(partial_scope)
  561. await partial.handle(scope, receive, send)
  562. return
  563. if scope["type"] == "http" and self.redirect_slashes and scope["path"] != "/":
  564. redirect_scope = dict(scope)
  565. if scope["path"].endswith("/"):
  566. redirect_scope["path"] = redirect_scope["path"].rstrip("/")
  567. else:
  568. redirect_scope["path"] = redirect_scope["path"] + "/"
  569. for route in self.routes:
  570. match, child_scope = route.matches(redirect_scope)
  571. if match != Match.NONE:
  572. redirect_url = URL(scope=redirect_scope)
  573. response = RedirectResponse(url=str(redirect_url))
  574. await response(scope, receive, send)
  575. return
  576. await self.default(scope, receive, send)
  577. def __eq__(self, other: typing.Any) -> bool:
  578. return isinstance(other, Router) and self.routes == other.routes
  579. # The following usages are now discouraged in favour of configuration
  580. #  during Router.__init__(...)
  581. def mount(self, path: str, app: ASGIApp, name: str = None) -> None:
  582. route = Mount(path, app=app, name=name)
  583. self.routes.append(route)
  584. def host(self, host: str, app: ASGIApp, name: str = None) -> None:
  585. route = Host(host, app=app, name=name)
  586. self.routes.append(route)
  587. def add_route(
  588. self,
  589. path: str,
  590. endpoint: typing.Callable,
  591. methods: typing.List[str] = None,
  592. name: str = None,
  593. include_in_schema: bool = True,
  594. ) -> None:
  595. route = Route(
  596. path,
  597. endpoint=endpoint,
  598. methods=methods,
  599. name=name,
  600. include_in_schema=include_in_schema,
  601. )
  602. self.routes.append(route)
  603. def add_websocket_route(
  604. self, path: str, endpoint: typing.Callable, name: str = None
  605. ) -> None:
  606. route = WebSocketRoute(path, endpoint=endpoint, name=name)
  607. self.routes.append(route)
  608. def route(
  609. self,
  610. path: str,
  611. methods: typing.List[str] = None,
  612. name: str = None,
  613. include_in_schema: bool = True,
  614. ) -> typing.Callable:
  615. def decorator(func: typing.Callable) -> typing.Callable:
  616. self.add_route(
  617. path,
  618. func,
  619. methods=methods,
  620. name=name,
  621. include_in_schema=include_in_schema,
  622. )
  623. return func
  624. return decorator
  625. def websocket_route(self, path: str, name: str = None) -> typing.Callable:
  626. def decorator(func: typing.Callable) -> typing.Callable:
  627. self.add_websocket_route(path, func, name=name)
  628. return func
  629. return decorator
  630. def add_event_handler(self, event_type: str, func: typing.Callable) -> None:
  631. assert event_type in ("startup", "shutdown")
  632. if event_type == "startup":
  633. self.on_startup.append(func)
  634. else:
  635. self.on_shutdown.append(func)
  636. def on_event(self, event_type: str) -> typing.Callable:
  637. def decorator(func: typing.Callable) -> typing.Callable:
  638. self.add_event_handler(event_type, func)
  639. return func
  640. return decorator