Coverage for /home/runner/work/viur-core/viur-core/viur/src/viur/core/request.py: 6%
446 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-11 20:18 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-11 20:18 +0000
"""
This module implements the WSGI (Web Server Gateway Interface) layer for ViUR. This is the main entry
point for incoming HTTP requests. The main class is :class:`Router`. Each request gets its
own instance of that class, which then holds the references to the request and response objects.
Additionally, this module defines the RequestValidator interface, which provides a very early hook into the
request processing (useful for global rate limiting, DDoS prevention or access control).
"""
8import datetime
9import fnmatch
10import json
11import logging
12import os
13import re
14import time
15import traceback
16import typing as t
17import unicodedata
18from abc import ABC, abstractmethod
19from urllib import parse
20from urllib.parse import quote, unquote, urljoin, urlparse
22import webob
24from viur.core import current, db, errors, session, utils
25from viur.core.config import conf
26from viur.core.logging import client as loggingClient, requestLogger, requestLoggingRessource
27from viur.core.module import Method
28from viur.core.securityheaders import extendCsp
29from viur.core.tasks import _appengineServiceIPs
# Name of the request parameter that selects the template style. Router._route stores its value in
# Router.template_style instead of passing it as a keyword argument to the routed method.
TEMPLATE_STYLE_KEY = "style"
class RequestValidator(ABC):
    """
    RequestValidators can be used to validate a request very early on. If the validate method returns a tuple,
    the request is aborted. Can be used to block requests from bots.

    To register or remove a validator, access it in main.py through
    :attr: viur.core.request.Router.requestValidators
    """
    # Internal name to trace which validator aborted the request (used in the rejection log message)
    name = "RequestValidator"

    @staticmethod
    @abstractmethod
    def validate(request: "Router") -> t.Optional[tuple[int, str, str]]:
        """
        The function that checks the current request. If the request is valid, simply return None.
        If the request should be blocked, it must return a tuple of
            - The HTTP status code (as int)
            - The description of that status code (e.g. "Forbidden")
            - The response body (can be a simple string or an HTML page)

        :param request: The Router instance handling the current request.
        :return: None on success, an error tuple otherwise.
        """
        raise NotImplementedError()
60class FetchMetaDataValidator(RequestValidator):
61 """
62 This validator examines the headers "Sec-Fetch-Site", "sec-fetch-mode" and "sec-fetch-dest" as
63 recommended by https://web.dev/fetch-metadata/
64 """
65 name = "FetchMetaDataValidator"
67 @staticmethod
68 def validate(request: 'BrowseHandler') -> t.Optional[tuple[int, str, str]]:
69 """
70 This validator examines the headers "sec-fetch-site",
71 "sec-fetch-mode" and "sec-fetch-dest" as recommended
72 by https://web.dev/fetch-metadata/
73 """
74 headers = request.request.headers
76 match headers.get("sec-fetch-site"):
77 case None | "same-origin" | "none":
78 # A Request from our site, or browser didn't send "sec-fetch-site"
79 return None
80 case "same-site":
81 # We are accepting a request with same-site only in local dev mode
82 if conf.instance.is_dev_server:
83 return None
84 case _:
85 # Incoming navigation GET request
86 if (
87 not request.isPostRequest
88 and headers.get("sec-fetch-mode") == "navigate"
89 and headers.get('sec-fetch-dest') not in ("object", "embed")
90 ):
91 return None
93 return 403, "Forbidden", "Request rejected due to fetch metadata"
class Router:
    """
    This class accepts the requests, collects its parameters and routes the request
    to its destination function.

    The basic control flow is
        - Setting up internal variables
        - Running the Request validators
        - Emitting the headers (especially the security related ones)
        - Run the TLS check (ensure it's a secure connection or check if the URL is whitelisted)
        - Load or initialize a new session
        - Set up i18n (choosing the language etc)
        - Run the request preprocessor (if any)
        - Normalize & sanity check the parameters
        - Resolve the exposed function and call it
        - Save the session / tear down the request
        - Return the response generated

    :warning: Don't instantiate! Don't subclass! DON'T TOUCH! ;)
    """

    # List of requestValidators used to preflight-check a request before it's being dispatched within ViUR
    requestValidators = [FetchMetaDataValidator]

    # HTTP methods handled by _process(); everything else is answered with 405
    _SUPPORTED_METHODS = ("get", "post", "head", "options")

    def __init__(self, environ: dict):
        """
        Wrap the WSGI environment into webob request/response objects and process the request.

        :param environ: The WSGI environment of the incoming request.
        """
        super().__init__()
        self.startTime = time.time()

        self.request = webob.Request(environ)
        self.response = webob.Response()

        self.maxLogLevel = logging.DEBUG
        # Only the part before the first "/" of X-Cloud-Trace-Context is used; random id if the header is missing
        self._traceID = \
            self.request.headers.get("X-Cloud-Trace-Context", "").split("/")[0] or utils.string.random()
        self.is_deferred = False
        self.path = ""
        self.path_list = ()

        self.skey_checked = False  # indicates whether @skey-decorator-check has already performed within a request
        self.internalRequest = False
        self.disableCache = False  # Shall this request bypass the caches?
        self.pendingTasks = []
        self.args = ()
        self.kwargs = {}
        self.context = {}
        self.template_style: str | None = None
        self.cors_headers = ()

        # Normalize the HTTP method; unsupported methods are rejected in _process()
        self.method = self.request.method.lower()
        self.isPostRequest = self.method == "post"
        self.isSSLConnection = self.request.host_url.lower().startswith("https://")  # We have an encrypted channel

        db.current_db_access_log.set(set())

        # Set context variables
        current.language.set(conf.i18n.default_language)
        current.request.set(self)
        current.session.set(session.Session())
        current.request_data.set({})

        try:
            # Process actual request
            self._process()
            # Runs after _process(), because _route() collects the caller specific CORS headers
            self._cors()
        finally:
            # Unset context variables - also when _process() re-raises (conf.debug.trace_exceptions),
            # so no stale request/session/user stays bound to this context.
            current.language.set(None)
            current.request_data.set(None)
            current.session.set(None)
            current.request.set(None)
            current.user.set(None)

    @property
    def isDevServer(self) -> bool:
        """Deprecated alias of ``conf.instance.is_dev_server``; emits a DeprecationWarning and a log entry."""
        import warnings
        msg = "Use of `isDevServer` is deprecated; Use `conf.instance.is_dev_server` instead!"
        warnings.warn(msg, DeprecationWarning, stacklevel=2)
        logging.warning(msg)
        return conf.instance.is_dev_server

    def _select_language(self, path: str) -> str:
        """
        Tries to select the best language for the current request. Depending on the value of
        conf.i18n.language_method, we'll either try to load it from the session, determine it by the domain,
        extract it from the URL or read it from the Accept-Language header.
        The selected language is stored in ``current.language``.

        :param path: The request path.
        :return: The path stripped of its leading language segment when ``language_method`` is "url" and the
            path starts with a known language; otherwise the unchanged path.
        """

        def get_language_from_header() -> str | None:
            """
            Pick the best known language from the Accept-Language header.

            Entries are ranked by their numeric quality value ("q", default 1). Entries with q <= 0
            ("not acceptable") or a malformed q are skipped. Region subtags are ignored ("de-DE" -> "de").
            A "*" entry falls back to the first available language.
            """
            if not (accept_language := self.request.headers.get("accept-language")):
                return None

            locale_q_pairs = []
            for language in accept_language.split(","):
                locale, *params = (part.strip() for part in language.split(";"))
                if not locale:
                    continue

                quality = 1.0  # no q => q = 1
                for param in params:
                    name, sep, value = param.partition("=")
                    if sep and name.strip().lower() == "q":
                        try:
                            quality = float(value.strip())
                        except ValueError:
                            quality = 0.0  # malformed q => skip this language
                        break

                if quality > 0:
                    locale_q_pairs.append((locale, quality))

            # Sort by quality values; the sort is stable, so equal q keep their header order
            locale_q_pairs.sort(key=lambda pair: pair[1], reverse=True)
            for locale, _ in locale_q_pairs:
                lang = locale.split("-")[0]  # Check for de-DE
                if lang in conf.i18n.available_languages + list(conf.i18n.language_alias_map.keys()):
                    return lang
                if lang == "*":  # fallback
                    return conf.i18n.available_languages[0]
            return None

        if not conf.i18n.available_languages:
            # This project doesn't use the multi-language feature, nothing to do here
            return path

        if conf.i18n.language_method == "session":
            current_session = current.session.get()
            lang = conf.i18n.default_language
            # We save the language in the session, if it exists, and try to load it from there
            if "lang" in current_session:
                current.language.set(current_session["lang"])
                return path

            if header_lang := get_language_from_header():
                lang = header_lang
                current.language.set(lang)

            elif header_lang := self.request.headers.get("X-Appengine-Country"):
                header_lang = str(header_lang).lower()
                if header_lang in conf.i18n.available_languages + list(conf.i18n.language_alias_map.keys()):
                    lang = header_lang

            if current_session.loaded:
                current_session["lang"] = lang
            current.language.set(lang)

        elif conf.i18n.language_method == "domain":
            host = self.request.host_url.lower()
            host = host[host.find("://") + 3:].strip(" /")  # strip http(s)://
            if host.startswith("www."):
                host = host[4:]
            if lang := conf.i18n.domain_language_mapping.get(host):
                current.language.set(lang)
            # We have no language configured for this domain, try to read it from the HTTP Header
            elif lang := get_language_from_header():
                current.language.set(lang)

        elif conf.i18n.language_method == "url":
            tmppath = urlparse(path).path
            tmppath = [unquote(x) for x in tmppath.lower().strip("/").split("/")]
            if (
                len(tmppath) > 0
                and tmppath[0] in conf.i18n.available_languages + list(conf.i18n.language_alias_map.keys())
            ):
                current.language.set(tmppath[0])
                return path[len(tmppath[0]) + 1:]  # Return the path stripped by its language segment
            else:  # This URL doesn't contain a language prefix, try to read it from the headers
                if header_lang := get_language_from_header():
                    current.language.set(header_lang)
                elif header_lang := self.request.headers.get("X-Appengine-Country"):
                    lang = str(header_lang).lower()
                    if lang in conf.i18n.available_languages or lang in conf.i18n.language_alias_map:
                        current.language.set(lang)

        elif conf.i18n.language_method == "header":
            if lang := get_language_from_header():
                current.language.set(lang)

        return path

    def _process(self):
        """
        Run the request pipeline: method check, request validators, security headers, TLS enforcement,
        session/user/language setup, routing and error handling.
        The session is saved afterwards in any case; on the dev server pending tasks are executed, too.
        """
        if self.method not in self._SUPPORTED_METHODS:
            logging.error(f"{self.method=} not supported")
            # A 405 response must list the supported methods in the Allow header
            self.response.status = "405 Method Not Allowed"
            self.response.headers["Allow"] = ", ".join(sorted(self._SUPPORTED_METHODS)).upper()
            return

        if self.request.headers.get("X-AppEngine-TaskName", None) is not None:  # Check if we run in the appengine
            if self.request.environ.get("HTTP_X_APPENGINE_USER_IP") in _appengineServiceIPs:
                self.is_deferred = True
            elif os.getenv("TASKS_EMULATOR") is not None:
                self.is_deferred = True

        # Check if we should process or abort the request; stop at the first validator that rejects it
        for validator in self.requestValidators:
            if (reqValidatorResult := validator.validate(self)) is not None:
                logging.warning(f"Request rejected by validator {validator.name}")
                statusCode, statusStr, statusDescr = reqValidatorResult
                self.response.status = f"{statusCode} {statusStr}"
                self.response.write(statusDescr)
                return

        try:
            path = self.request.path
        except UnicodeDecodeError:  # webob can fail with UnicodeDecodeError on broken/invalid URLs
            self.response.status = "400 Bad Request"  # let's send the client onto a health cure in Bad Request ...
            return

        self._emit_security_headers()

        # Ensure that TLS is used if required
        if conf.security.force_ssl and not self.isSSLConnection and not conf.instance.is_dev_server:
            # Some URLs need to be whitelisted (as f.e. the Tasks-Queue doesn't call using https);
            # a trailing "*" in a whitelist entry acts as prefix match.
            isWhitelisted = any(
                path.startswith(testUrl[:-1]) if testUrl.endswith("*") else testUrl == path
                for testUrl in conf.security.no_ssl_check_urls
            )
            if not isWhitelisted:
                # Redirect the user to the startpage (using ssl this time)
                host = self.request.host_url.lower()
                host = host[host.find("://") + 3:].strip(" /")  # strip http(s)://
                self.response.status = "302 Found"
                self.response.headers['Location'] = f"https://{host}/"
                return

        if path.startswith("/_ah/warmup"):
            self.response.write("okay")
            return

        try:
            current.session.get().load()

            # Load current user into context variable if user module is there.
            if user_mod := getattr(conf.main_app.vi, "user", None):
                current.user.set(user_mod.getCurrentUser())

            path = self._select_language(path)[1:]

            # Check for closed system
            if conf.security.closed_system and self.method != "options":
                if not current.user.get():
                    if not any(fnmatch.fnmatch(path, pat) for pat in conf.security.closed_system_allowed_paths):
                        raise errors.Unauthorized()

            if conf.request_preprocessor:
                path = conf.request_preprocessor(path)

            self._route(path)

        except errors.Redirect as e:
            if conf.debug.trace_exceptions:
                logging.warning("""conf.debug.trace_exceptions is set, won't handle this exception""")
                raise
            self.response.status = f"{e.status} {e.name}"
            url = e.url
            url = unquote(url)  # decode first
            # safe = https://url.spec.whatwg.org/#url-path-segment-string
            url = quote(url, encoding="utf-8", safe="!$&'()*+,-./:;=?@_~#")  # re-encode all in utf-8
            if url.startswith(('.', '/')):
                url = str(urljoin(self.request.url, url))
            self.response.headers['Location'] = url

        except Exception as e:
            if conf.debug.trace_exceptions:
                logging.warning("""conf.debug.trace_exceptions is set, won't handle this exception""")
                raise
            self._handle_exception(e)

        finally:
            current.session.get().save()
            if conf.instance.is_dev_server and conf.debug.dev_server_cloud_logging:
                # Emit the outer log only on dev_appserver (we'll use the existing request log when live)
                self._log_dev_request()

            if conf.instance.is_dev_server:
                self.is_deferred = True

                # Deferred task emulation; list.pop() runs the tasks in reverse order of registration
                while self.pendingTasks:
                    task = self.pendingTasks.pop()
                    logging.debug(f"Deferred task emulation, executing {task=}")
                    try:
                        task()
                    except Exception:  # noqa
                        logging.exception(f"Deferred Task emulation {task} failed")

    def _emit_security_headers(self) -> None:
        """Set the CSP and the other security related response headers configured in ``conf.security``."""
        # Add CSP headers early (if any)
        if conf.security.content_security_policy and conf.security.content_security_policy["_headerCache"]:
            for k, v in conf.security.content_security_policy["_headerCache"].items():
                self.response.headers[k] = v
        if self.isSSLConnection:  # Check for HSTS headers only if we have a secure channel.
            if conf.security.strict_transport_security:
                self.response.headers["Strict-Transport-Security"] = conf.security.strict_transport_security
        # Check for X-Security-Headers we shall emit
        if conf.security.x_content_type_options:
            self.response.headers["X-Content-Type-Options"] = "nosniff"
        if conf.security.x_xss_protection is not None:
            if conf.security.x_xss_protection:
                self.response.headers["X-XSS-Protection"] = "1; mode=block"
            elif conf.security.x_xss_protection is False:
                self.response.headers["X-XSS-Protection"] = "0"
        if conf.security.x_frame_options is not None and isinstance(conf.security.x_frame_options, tuple):
            mode, uri = conf.security.x_frame_options
            if mode in ["deny", "sameorigin"]:
                self.response.headers["X-Frame-Options"] = mode
            elif mode == "allow-from":
                self.response.headers["X-Frame-Options"] = f"allow-from {uri}"
        if conf.security.x_permitted_cross_domain_policies is not None:
            self.response.headers["X-Permitted-Cross-Domain-Policies"] = conf.security.x_permitted_cross_domain_policies
        if conf.security.referrer_policy:
            self.response.headers["Referrer-Policy"] = conf.security.referrer_policy
        if conf.security.permissions_policy.get("_headerCache"):
            self.response.headers["Permissions-Policy"] = conf.security.permissions_policy["_headerCache"]
        if conf.security.enable_coep:
            self.response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
        if conf.security.enable_coop:
            self.response.headers["Cross-Origin-Opener-Policy"] = conf.security.enable_coop
        if conf.security.enable_corp:
            self.response.headers["Cross-Origin-Resource-Policy"] = conf.security.enable_corp

    def _handle_exception(self, e: Exception) -> None:
        """
        Turn an exception raised while processing the request into an error response.

        Must be called from within the ``except`` clause handling ``e``: the dev-server traceback is taken
        from the exception currently being handled.

        :param e: The exception raised while processing the request.
        """
        self.response.body = b""
        if isinstance(e, errors.HTTPException):
            logging.info(f"[{e.status}] {e.name}: {e.descr}", exc_info=conf.debug.trace)
            self.response.status = f"{e.status} {e.name}"
            # Set machine-readable x-viur-error response header in case there is an exception description.
            if e.descr:
                # Line breaks must not end up inside a header value
                self.response.headers["x-viur-error"] = e.descr.replace("\r", "").replace("\n", "")
        else:
            self.response.status = 500
            logging.error("ViUR has caught an unhandled exception!")
            logging.exception(e)

        res = None
        if conf.error_handler:
            try:
                res = conf.error_handler(e)
            except Exception as newE:
                logging.error("viur.error_handler failed!")
                logging.exception(newE)
                res = None

        if not res:
            descr = "The server encountered an unexpected error and is unable to process your request."

            if isinstance(e, errors.HTTPException):
                error_info = {
                    "status": e.status,
                    "reason": e.name,
                    "title": str(translate(e.name)),
                    "descr": e.descr,
                }
            else:
                error_info = {
                    "status": 500,
                    "reason": "Internal Server Error",
                    "title": str(translate("Internal Server Error")),
                    "descr": descr
                }

            if conf.instance.is_dev_server:
                error_info["traceback"] = traceback.format_exc()

            error_info["logo"] = conf.error_logo

            # The Content-Type header may have been removed (e.g. by the OPTIONS handling) or carry parameters
            content_type = self.response.headers.get("Content-Type") or ""
            if (len(self.path_list) > 0 and self.path_list[0] in ("vi", "json")) or \
                    content_type.split(";")[0].strip().lower() == "application/json":
                self.response.headers["Content-Type"] = "application/json"
                res = json.dumps(error_info)
            else:  # We render the error in html
                # Try to get the template from html/error/
                if filename := conf.main_app.render.getTemplateFileName((f"{error_info['status']}", "error"),
                                                                         raise_exception=False):
                    template = conf.main_app.render.getEnv().get_template(filename)
                    try:
                        uses_unsafe_inline = \
                            "unsafe-inline" in conf.security.content_security_policy["enforce"]["style-src"]
                    except (KeyError, TypeError):  # Not set
                        uses_unsafe_inline = False
                    if uses_unsafe_inline:
                        logging.info("Using style-src:unsafe-inline, don't create a nonce")
                        nonce = None
                    else:
                        nonce = utils.string.random(16)
                        extendCsp({"style-src": [f"nonce-{nonce}"]})
                    res = template.render(error_info, nonce=nonce)
                else:
                    res = (f'<!DOCTYPE html><html lang="en"><head><meta charset="UTF-8">'
                           f'<title>{error_info["status"]} - {error_info["reason"]}</title>'
                           f'</head><body><h1>{error_info["status"]} - {error_info["reason"]}</h1>'
                           f'</body></html>')

        # conf.error_handler may return bytes; convert like _route() does for regular results
        if not isinstance(res, bytes):
            res = str(res).encode("UTF-8")
        self.response.write(res)

    def _log_dev_request(self) -> None:
        """Write the outer request log entry to Cloud Logging (only used on the development server)."""
        severity = "DEBUG"
        if self.maxLogLevel >= logging.CRITICAL:
            severity = "CRITICAL"
        elif self.maxLogLevel >= logging.ERROR:
            severity = "ERROR"
        elif self.maxLogLevel >= logging.WARNING:
            severity = "WARNING"
        elif self.maxLogLevel >= logging.INFO:
            severity = "INFO"

        trace = "projects/{}/traces/{}".format(loggingClient.project, self._traceID)

        http_request = {
            'requestMethod': self.request.method,
            'requestUrl': self.request.url,
            'status': self.response.status_code,
            'userAgent': self.request.headers.get('USER-AGENT'),
            'responseSize': self.response.content_length,
            'latency': "%0.3fs" % (time.time() - self.startTime),
            'remoteIp': self.request.environ.get("HTTP_X_APPENGINE_USER_IP")
        }
        requestLogger.log_text(
            "",
            client=loggingClient,
            severity=severity,
            http_request=http_request,
            trace=trace,
            resource=requestLoggingRessource,
            operation={
                "first": True,
                "last": True,
                "id": self._traceID
            }
        )

    def _route(self, path: str) -> None:
        """
        Does the actual work of sanitizing the parameters, determining which exposed function to call
        (and with which parameters), calling it and writing its result into the response.

        :param path: The request path; as passed from _process() it has its language prefix and
            leading slash removed.
        :raises errors.BadRequest: On too many parameters, invalid unicode or reserved parameter names.
        :raises errors.NotFound: If the path cannot be resolved or the method is not exposed.
        :raises errors.MethodNotAllowed: If no callable method is found or POST is required.
        :raises errors.PreconditionFailed: If the method requires SSL and the request is not encrypted.
        """
        # Parse the URL
        if path := parse.urlparse(path).path:
            self.path = path
            self.path_list = tuple(unicodedata.normalize("NFC", parse.unquote(part))
                                   for part in path.strip("/").split("/"))

        # Prevent Hash-collision attacks
        if len(self.request.params) > conf.max_post_params_count:
            raise errors.BadRequest(
                f"Too many arguments supplied, exceeding maximum"
                f" of {conf.max_post_params_count} allowed arguments per request"
            )

        param_filter = conf.param_filter_function
        if param_filter and not callable(param_filter):
            raise ValueError(f"""{param_filter=} is not callable""")

        for key, value in self.request.params.items():
            try:
                key = unicodedata.normalize("NFC", key)
                value = unicodedata.normalize("NFC", value)
            except UnicodeError:
                # We received invalid unicode data (usually happens when
                # someone tries to exploit unicode normalisation bugs)
                raise errors.BadRequest()

            if param_filter and param_filter(key, value):
                continue

            if key == TEMPLATE_STYLE_KEY:
                self.template_style = value
                continue

            # Repeated parameters are collected into a list
            if key in self.kwargs:
                if isinstance(self.kwargs[key], list):
                    self.kwargs[key].append(value)
                else:  # Convert that key to a list
                    self.kwargs[key] = [self.kwargs[key], value]
            else:
                self.kwargs[key] = value

        if "self" in self.kwargs or "return" in self.kwargs:  # self or return is reserved for bound methods
            raise errors.BadRequest()

        caller = conf.main_resolver
        idx = 0  # Count how many items from *args we'd have consumed (so the rest can go into *args of the called func)
        path_found = True

        for part in self.path_list:
            # TODO: Remove canAccess guards... solve differently.
            if "canAccess" in caller and not caller["canAccess"]():
                # We have a canAccess function guarding that object,
                # and it returns False...
                raise errors.Unauthorized()

            idx += 1

            if part not in caller:
                part = "index"

            if caller := caller.get(part):
                if isinstance(caller, Method):
                    # A fallback to "index" didn't consume the current path segment
                    if part == "index":
                        idx -= 1

                    self.args = tuple(self.path_list[idx:])
                    break

                elif part == "index":
                    path_found = False
                    break

            else:
                path_found = False
                break

        if not path_found:
            raise errors.NotFound(
                f"""The path {utils.string.escape("/".join(self.path_list[:idx]))} could not be found""")

        if not isinstance(caller, Method):
            # try to find "index" function
            if (index := caller.get("index")) and isinstance(index, Method):
                caller = index
            else:
                raise errors.MethodNotAllowed()

        # Check for internal exposed
        if caller.exposed is False and not self.internalRequest:
            raise errors.NotFound()

        # Fill the Allow header of the response with the allowed HTTP methods
        if self.method == "options":
            self.response.headers["Allow"] = ", ".join(sorted(caller.methods)).upper()

        # Register caller specific CORS headers (evaluated later in _cors())
        self.cors_headers = [str(header).lower() for header in caller.cors_allow_headers or ()]

        # Check for @force_ssl flag
        if not self.internalRequest \
                and caller.ssl \
                and not self.request.host_url.lower().startswith("https://") \
                and not conf.instance.is_dev_server:
            raise errors.PreconditionFailed("You must use SSL to access this resource!")

        # Check for @force_post flag
        if not self.isPostRequest and caller.methods == ("POST",):
            raise errors.MethodNotAllowed("You must use POST to access this resource!")

        # Check if this request should bypass the caches
        if self.request.headers.get("X-Viur-Disable-Cache"):
            # No cache requested, check if the current user is allowed to do so
            if (user := current.user.get()) and "root" in user["access"]:
                logging.debug("Caching disabled by X-Viur-Disable-Cache header")
                self.disableCache = True

        # Distill context as self.context, if available
        if context := {k: v for k, v in self.kwargs.items() if k.startswith("@")}:
            # Remove context parameters from kwargs
            kwargs = {k: v for k, v in self.kwargs.items() if k not in context}
            # Remove leading "@" from context parameters
            self.context |= {k[1:]: v for k, v in context.items() if len(k) > 1}
        else:
            kwargs = self.kwargs

        if ((self.internalRequest and conf.debug.trace_internal_call_routing)
                or conf.debug.trace_external_call_routing):
            logging.debug(
                f"Calling {caller._func!r} with args={self.args!r}, {kwargs=} within context={self.context!r}"
            )

        if self.method == "options":
            # OPTIONS request doesn't have a body; the routed method is not called
            del self.response.app_iter
            del self.response.content_type
            self.response.status = "204 No Content"
            return

        # Now call the routed method!
        res = caller(*self.args, **kwargs)

        if not isinstance(res, bytes):  # Convert the result to bytes if it is not already!
            res = str(res).encode("UTF-8")
        self.response.write(res)

    def _cors(self) -> None:
        """
        Set CORS headers to the HTTP response.

        .. seealso::

            Option :attr:`core.config.Security.cors_origins`, etc.
            for cors settings.

            https://fetch.spec.whatwg.org/#http-cors-protocol

            https://enable-cors.org/server.html

            https://www.html5rocks.com/static/images/cors_server_flowchart.png
        """

        def test_candidates(value: str, *candidates: str | re.Pattern) -> bool:
            """
            Test if the value matches any candidate: strings are compared case-insensitively,
            compiled regexes are applied with ``re.match``.
            """
            # NOTE(review): re.match only anchors at the start of the value; a pattern without a trailing
            # "$" therefore also accepts e.g. "https://example.com.evil.net" - configure patterns accordingly.
            for candidate in candidates:
                if isinstance(candidate, re.Pattern):
                    if candidate.match(value):
                        return True
                elif isinstance(candidate, str):
                    if candidate.lower() == str(value).lower():
                        return True
                else:
                    raise TypeError(
                        f"Invalid setting {candidate}. "
                        f"Expected a string or a compiled regex."
                    )
            return False

        origin = self.request.headers.get("Origin")
        if not origin:
            return

        # Origin is set --> It's a CORS request
        cors_origins = conf.security.cors_origins or ()

        any_origin_allowed = (
            cors_origins == "*"
            or any(_origin == "*" for _origin in cors_origins)
            or any(_origin.pattern == r".*"
                   for _origin in cors_origins
                   if isinstance(_origin, re.Pattern))
        )

        if any_origin_allowed and conf.security.cors_origins_use_wildcard:
            if conf.security.cors_allow_credentials:
                raise RuntimeError(
                    "Invalid CORS config: "
                    "If credentials mode is \"include\", then `Access-Control-Allow-Origin` cannot be `*`. "
                    "See https://fetch.spec.whatwg.org/#cors-protocol-and-credentials"
                )
            self.response.headers["Access-Control-Allow-Origin"] = "*"

        elif any_origin_allowed or test_candidates(origin, *cors_origins):
            # Reflect the concrete origin. This also covers "any origin allowed" without the wildcard,
            # since a literal "*" candidate never equals an actual origin.
            self.response.headers["Access-Control-Allow-Origin"] = origin

        else:
            logging.warning(f"{origin=} not valid (must be one of {conf.security.cors_origins=})")
            return

        if conf.security.cors_allow_credentials:
            self.response.headers["Access-Control-Allow-Credentials"] = "true"

        if self.method == "options":
            method = (self.request.headers.get("Access-Control-Request-Method") or "").lower()

            if method in conf.security.cors_methods:
                # It's a CORS-preflight request
                # - MUST include Access-Control-Request-Method
                # - CAN include Access-Control-Request-Headers

                # The response can be cached
                if conf.security.cors_max_age is not None:
                    assert isinstance(conf.security.cors_max_age, datetime.timedelta)
                    self.response.headers["Access-Control-Max-Age"] = \
                        str(int(conf.security.cors_max_age.total_seconds()))

                # Allowed methods
                self.response.headers["Access-Control-Allow-Methods"] = ", ".join(
                    sorted(conf.security.cors_methods)).upper()

                # Allowed headers; Access-Control-Request-Headers is optional and may be missing entirely
                request_headers = self.request.headers.get("Access-Control-Request-Headers") or ""
                request_headers = [h.strip().lower() for h in request_headers.split(",") if h.strip()]
                if conf.security.cors_allow_headers == "*":
                    # Every header is allowed
                    allow_headers = request_headers[:]
                else:
                    # There are generally headers allowed and/or from the caller
                    allow_headers = [
                        header
                        for header in request_headers
                        if test_candidates(
                            header,
                            *(self.cors_headers or ()),  # caller specific
                            *(conf.security.cors_allow_headers or ())  # generally global
                        )
                    ]
                if allow_headers:
                    self.response.headers["Access-Control-Allow-Headers"] = ", ".join(sorted(allow_headers))

            else:
                logging.warning(
                    f"Access-Control-Request-Method: {method} is NOT a valid method of {conf.security.cors_methods=}. "
                    f"Don't append CORS-preflight request headers"
                )

    def saveSession(self) -> None:
        """Save the session bound to the current request context."""
        current.session.get().save()
783from .i18n import translate # noqa: E402