Coverage for /home/runner/work/viur-core/viur-core/viur/src/viur/core/tasks.py: 22%
420 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-27 07:59 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-02-27 07:59 +0000
1import abc
2import datetime
3import functools
4import json
5import logging
6import os
7import sys
8import time
9import traceback
10import typing as t
12import grpc
13import requests
14from google import protobuf
15from google.cloud import tasks_v2
17from viur.core import current, db, errors, utils
18from viur.core.config import conf
19from viur.core.decorators import exposed, skey
20from viur.core.module import Module
CUSTOM_OBJ = t.TypeVar("CUSTOM_OBJ")  # A JSON serializable object


class CustomEnvironmentHandler(abc.ABC):
    """
    Interface for carrying project-specific environment data along with a task.

    Implementations provide a pair of hooks: one that captures the data
    and one that writes it back.
    """

    @abc.abstractmethod
    def serialize(self) -> CUSTOM_OBJ:
        """Capture the custom environment data.

        Takes no arguments.

        :returns: A JSON serializable object holding the desired information.
        """
        ...

    @abc.abstractmethod
    def restore(self, obj: CUSTOM_OBJ) -> None:
        """Write previously captured environment data back.

        :param obj: The object that :meth:`serialize` returned. Its content
            should be applied to the environment of the deferred request.
        """
        ...
_gaeApp = os.environ.get("GAE_APPLICATION")

# Region of the Cloud Tasks queues. Stays None when it cannot be determined,
# in which case deferred calls are run inline instead of being queued.
queueRegion = None
if _gaeApp:
    try:
        headers = {"Metadata-Flavor": "Google"}
        # This runs at import time: without a timeout a stalled metadata server would block startup forever.
        r = requests.get(
            "http://metadata.google.internal/computeMetadata/v1/instance/region",
            headers=headers,
            timeout=5,
        )
        # An error page must not be parsed as a region name - use the fallback below instead.
        r.raise_for_status()
        # r.text should look like this: "projects/(project-number)/region/(region)"
        # e.g. "projects/1234567890/region/europe-west3"
        queueRegion = r.text.split("/")[-1]
    except Exception as e:  # Something went wrong with the Google Metadata Server, we use the old way
        logging.warning(f"Can't obtain queueRegion from Google MetaData Server due exception {e=}")
        # Fallback: derive the region from the prefix of GAE_APPLICATION (the part before "~")
        regionPrefix = _gaeApp.split("~")[0]
        regionMap = {
            "h": "europe-west3",
            "e": "europe-west1"
        }
        queueRegion = regionMap.get(regionPrefix)

if not queueRegion and conf.instance.is_dev_server and os.getenv("TASKS_EMULATOR") is None:
    # Probably local development server
    logging.warning("Taskqueue disabled, tasks will run inline!")

if not conf.instance.is_dev_server or os.getenv("TASKS_EMULATOR") is None:
    taskClient = tasks_v2.CloudTasksClient()
else:
    # Development server with TASKS_EMULATOR set: talk to the emulator over an insecure gRPC channel
    taskClient = tasks_v2.CloudTasksClient(
        transport=tasks_v2.services.cloud_tasks.transports.CloudTasksGrpcTransport(
            channel=grpc.insecure_channel(os.getenv("TASKS_EMULATOR"))
        )
    )
    queueRegion = "local"
# Periodic tasks per cron name: cronName -> {task callable: minimum interval between two runs}
_periodicTasks: dict[str, dict[t.Callable, datetime.timedelta]] = {}
# Classes registered through @CallableTask, keyed by their `key` attribute
_callableTasks = {}
# Functions wrapped by @CallDeferred, keyed by f"{func.__name__}.{func.__module__}"
_deferred_tasks = {}
# Functions registered through @StartupTask, in registration order
_startupTasks = []
# Values of HTTP_X_APPENGINE_USER_IP which TaskHandler._validate_request accepts
_appengineServiceIPs = {"10.0.0.1", "0.1.0.1", "0.1.0.2"}
class PermanentTaskFailure(Exception):
    """Signals that a task has failed and that a retry will never succeed."""
def removePeriodicTask(task: t.Callable) -> None:
    """
    Unregister a periodic task from every cron queue it is registered in.

    Useful to unqueue a task that has been inherited from an overridden module.

    :param task: A callable carrying a ``periodicTaskName`` attribute.
    """
    assert hasattr(task, "periodicTaskName"), "This is not a periodic task? "
    for registered in _periodicTasks.values():
        # Dropping a task that is not part of this queue is a no-op
        registered.pop(task, None)
class CallableTaskBase:
    """
    Common base for tasks that can be triggered by a user.

    Meant to be subclassed: the defaults deny every call and implement nothing.
    """
    key = None  # Unique identifier for this task
    name = None  # Human-readable name
    descr = None  # Human-readable description
    kindName = "server-task"

    def canCall(self) -> bool:
        """
        Tell whether the current user may execute this task.

        :returns: Always False in this base class.
        """
        return False

    def dataSkel(self):
        """
        Provide a skeleton instance when the task needs additional data.

        The values are then passed to *execute*.

        :returns: None in this base class.
        """
        return None

    def execute(self):
        """
        The actual work of the task; has to be provided by the subclass.

        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()
class TaskHandler(Module):
    """
    Task Handler.
    Handles calling of Tasks (queued and periodic), and performs updatechecks
    Do not Modify. Do not Subclass.
    """
    adminInfo = None
    # A deferred task hitting exactly this retry count triggers an email to the admins
    retryCountWarningThreshold = 25

    def findBoundTask(self, task: t.Callable, obj: object, depth: int = 0) -> t.Optional[tuple[t.Callable, object]]:
        """
        Tries to locate the instance, this function belongs to.
        If it succeeds in finding it, it returns the function and its instance (-> its "self").
        Otherwise, None is returned.

        :param task: A callable decorated with @PeriodicTask
        :param obj: Object, which will be scanned in the current iteration.
        :param depth: Current iteration depth.
        :returns: Tuple of (bound callable, owning object) or None.
        """
        if depth > 3 or "periodicTaskName" not in dir(task):  # Limit the maximum amount of recursions
            return None
        for attr in dir(obj):
            if attr.startswith("_"):
                continue
            try:
                v = getattr(obj, attr)
            except AttributeError:
                continue
            # Tasks are matched by their periodicTaskName, not by identity
            if callable(v) and "periodicTaskName" in dir(v) and str(v.periodicTaskName) == str(task.periodicTaskName):
                return v, obj
            if not isinstance(v, str) and not callable(v):
                # Descend into non-callable attributes
                res = self.findBoundTask(task, v, depth + 1)
                if res:
                    return res
        return None

    @exposed
    def queryIter(self, *args, **kwargs):
        """
        This processes one chunk of a queryIter (see below).

        The request body is expected to be the JSON encoded query description
        created by QueryIter._requeueStep.
        """
        req = current.request.get().request
        self._validate_request()
        data = utils.json.loads(req.body)
        if data["classID"] not in MetaQueryIter._classCache:
            logging.error(f"""Could not continue queryIter - {data["classID"]} not known on this instance""")
        # NOTE(review): for an unknown classID the lookup below raises KeyError after the error
        # has been logged - confirm whether failing the request here is intended.
        MetaQueryIter._classCache[data["classID"]]._qryStep(data)

    @exposed
    def deferred(self, *args, **kwargs):
        """
        This catches one deferred call and routes it to its destination

        The request body is the JSON encoded tuple ``(cmd, (funcPath, args, kwargs, env))``
        as created by @CallDeferred. ``cmd`` is "rel" for a function that is resolved
        relative to ``conf.main_app`` and "unb" for one looked up in the registry
        of deferred functions.

        :raises errors.RequestTimeout: If the called function raised anything
            but PermanentTaskFailure, so that the task gets retried.
        """
        req = current.request.get().request
        self._validate_request()
        # Check if the retry count exceeds our warning threshold
        retryCount = req.headers.get("X-Appengine-Taskretrycount", None)
        if retryCount and int(retryCount) == self.retryCountWarningThreshold:
            from viur.core import email
            email.send_email_to_admins(
                "Deferred task retry counter exceeded warning threshold",
                f"""Task {req.headers.get("X-Appengine-Taskname", "")} is retried for the {retryCount}th time."""
            )

        cmd, data = utils.json.loads(req.body)
        funcPath, args, kwargs, env = data
        logging.debug(f"Call task {funcPath} with {cmd=} {args=} {kwargs=} {env=}")

        # Restore the environment which has been captured when the task was created
        if env:
            if "user" in env and env["user"]:
                current.session.get()["user"] = env["user"]
                # FIXME: We do not have a fully loaded session from the cookie here,
                #        but only a partial session.
                #        But we still leave `loaded` on False, which leads to problems.

                # Load current user into context variable if user module is there.
                if user_mod := getattr(conf.main_app.vi, "user", None):
                    current.user.set(user_mod.getCurrentUser())
            if "lang" in env and env["lang"]:
                current.language.set(env["lang"])
            if "transactionMarker" in env:
                # The task was created inside a transaction: only run it if the marker exists
                marker = db.Get(db.Key("viur-transactionmarker", env["transactionMarker"]))
                if not marker:
                    logging.info(f"""Dropping task, transaction {env["transactionMarker"]} did not apply""")
                    return
                else:
                    logging.info(f"""Executing task, transaction {env["transactionMarker"]} did succeed""")
            if "custom" in env and conf.tasks_custom_environment_handler:
                # Check if we need to restore additional environmental data
                conf.tasks_custom_environment_handler.restore(env["custom"])
        if cmd == "rel":
            # Walk down from the main app along the path segments to find the callable
            caller = conf.main_app
            pathlist = [x for x in funcPath.split("/") if x]
            for currpath in pathlist:
                if currpath not in dir(caller):
                    logging.error(f"Could not resolve {funcPath=} (failed part was {currpath!r})")
                    return
                caller = getattr(caller, currpath)
            try:
                caller(*args, **kwargs)
            except PermanentTaskFailure:
                # Swallowed on purpose: the request succeeds, so the task is not retried
                logging.error("PermanentTaskFailure")
            except Exception as e:
                logging.exception(e)
                raise errors.RequestTimeout()  # Task-API should retry
        elif cmd == "unb":
            if funcPath not in _deferred_tasks:
                # The KeyError raised by the lookup below ends up in the generic handler,
                # so such a task is answered with RequestTimeout as well.
                logging.error(f"Missed deferred task {funcPath=} ({args=},{kwargs=})")
            # We call the deferred function *directly* (without walking through the mkDeferred lambda), so ensure
            # that any hit to another deferred function will defer again
            current.request.get().DEFERRED_TASK_CALLED = True
            try:
                _deferred_tasks[funcPath](*args, **kwargs)
            except PermanentTaskFailure:
                logging.error("PermanentTaskFailure")
            except Exception as e:
                logging.exception(e)
                raise errors.RequestTimeout()  # Task-API should retry

    @exposed
    def cron(self, cronName="default", *args, **kwargs):
        """
        Calls all periodic tasks which are registered for the given cron name.

        A task with an interval is skipped while its last recorded call is more
        recent than that interval. A failing task is logged and does not stop
        the remaining ones.

        :param cronName: Name of the cron queue whose tasks should be called.
        """
        req = current.request.get()
        if not conf.instance.is_dev_server:
            self._validate_request(require_cron=True, require_taskname=False)
        if cronName not in _periodicTasks:
            logging.warning(f"Cron request {cronName} doesn't have any tasks")
        # We must defer from cron, as tasks will interpret it as a call originating from task-queue - causing deferred
        # functions to be called directly, wich causes calls with _countdown etc set to fail.
        req.DEFERRED_TASK_CALLED = True
        # An unknown cron name simply has nothing to run; it must not fail with a KeyError after the warning above
        for task, interval in _periodicTasks.get(cronName, {}).items():  # Call all periodic tasks bound to that queue
            periodicTaskName = task.periodicTaskName.lower()
            if interval:  # Ensure this task doesn't get called to often
                lastCall = db.Get(db.Key("viur-task-interval", periodicTaskName))
                if lastCall and utils.utcNow() - lastCall["date"] < interval:
                    logging.debug(f"Task {periodicTaskName!r} has already run recently - skipping.")
                    continue
            res = self.findBoundTask(task, conf.main_app)
            try:
                if res:  # Its bound, call it this way :)
                    res[0]()
                else:
                    task()  # It seems it wasn't bound - call it as a static method
            except Exception as e:
                logging.error(f"Error calling periodic task {periodicTaskName}")
                logging.exception(e)
            else:
                logging.debug(f"Successfully called task {periodicTaskName}")
            if interval:
                # Update its last-call timestamp
                entry = db.Entity(db.Key("viur-task-interval", periodicTaskName))
                entry["date"] = utils.utcNow()
                db.Put(entry)
        logging.debug("Periodic tasks complete")

    def _validate_request(
        self,
        *,
        require_cron: bool = False,
        require_taskname: bool = True,
    ) -> None:
        """
        Validate the header and metadata of a request

        If the request is valid, None will be returned.
        Otherwise, an exception will be raised.

        :param require_taskname: Require "X-AppEngine-TaskName" header
        :param require_cron: Require "X-Appengine-Cron" header
        :raises errors.Forbidden: If one of the checks fails.
        """
        req = current.request.get().request
        # The IP check is only skipped on a development server that uses the tasks emulator
        if (
            req.environ.get("HTTP_X_APPENGINE_USER_IP") not in _appengineServiceIPs
            and (not conf.instance.is_dev_server or os.getenv("TASKS_EMULATOR") is None)
        ):
            logging.critical("Detected an attempted XSRF attack. This request did not originate from Task Queue.")
            raise errors.Forbidden()
        if require_cron and "X-Appengine-Cron" not in req.headers:
            logging.critical('Detected an attempted XSRF attack. The header "X-AppEngine-Cron" was not set.')
            raise errors.Forbidden()
        if require_taskname and "X-AppEngine-TaskName" not in req.headers:
            logging.critical('Detected an attempted XSRF attack. The header "X-AppEngine-Taskname" was not set.')
            raise errors.Forbidden()

    @exposed
    def list(self, *args, **kwargs):
        """Lists all user-callable tasks which are callable by this user"""
        tasks = db.SkelListRef()
        tasks.extend([{
            "key": x.key,
            "name": str(x.name),
            "descr": str(x.descr)
        } for x in _callableTasks.values() if x().canCall()
        ])

        return self.render.list(tasks)

    @exposed
    @skey(allow_empty=True)
    def execute(self, taskID, *args, **kwargs):
        """
        Executes a user-callable task with the values provided by the client.

        Nothing is returned for an unknown task ID. Without client data, with invalid
        client data or when "bounce" is set, the add form is rendered instead of executing.

        :param taskID: Key of a task registered through @CallableTask.
        :raises errors.Unauthorized: If the task may not be called by the current user.
        """
        if taskID in _callableTasks:
            task = _callableTasks[taskID]()
        else:
            return
        if not task.canCall():
            raise errors.Unauthorized()
        skel = task.dataSkel()
        # NOTE(review): dataSkel() may return None (the base class does) - confirm that
        # fromClient / accessedValues are never reached in that case.
        if not kwargs or not skel.fromClient(kwargs) or utils.parse.bool(kwargs.get("bounce")):
            return self.render.add(skel)
        task.execute(**skel.accessedValues)
        return self.render.addSuccess(skel)
# Flags set on the class after its definition.
# NOTE(review): presumably these make the handler available in the admin, vi and html renderers - confirm.
TaskHandler.admin = True
TaskHandler.vi = True
TaskHandler.html = True
356# Decorators
def retry_n_times(retries: int, email_recipients: None | str | list[str] = None,
                  tpl: None | str = None) -> t.Callable:
    """
    Wrapper for deferred tasks to limit the amount of retries

    As long as the retry count of the current request is below ``retries``, an
    exception of the wrapped function is re-raised, so the task queue can retry it.
    Afterwards an optional email is sent and PermanentTaskFailure is raised instead.

    :param retries: Number of maximum allowed retries
    :param email_recipients: Email addresses to which a info should be sent
        when the retry limit is reached.
    :param tpl: Instead of the standard text, a custom template can be used.
        The name of an email template must be specified.
    :returns: A decorator for the function whose retries should be limited.
    """
    # language=Jinja2
    string_template = \
        """Task {{func_name}} failed {{retries}} times
        This was the last attempt.<br>
        <pre>{{func_module|escape}}.{{func_name|escape}}({{signature|escape}})</pre>
        <pre>{{traceback|escape}}</pre>"""

    def outer_wrapper(func):
        @functools.wraps(func)
        def inner_wrapper(*args, **kwargs):
            try:
                retry_count = int(current.request.get().request.headers.get("X-Appengine-Taskretrycount", -1))
            except AttributeError:
                # During warmup current.request is None (at least on local devserver)
                retry_count = -1
            try:
                return func(*args, **kwargs)
            except Exception as exc:
                logging.exception(f"Task {func.__qualname__} failed: {exc}")
                logging.info(
                    f"This was the {retry_count}. retry. "
                    f"{retries - retry_count} retries remaining. (total = {retries})"
                )
                if retry_count < retries:
                    # Raise the exception to mark this task as failed, so the task queue can retry it.
                    raise
                else:
                    if email_recipients:
                        args_repr = [repr(arg) for arg in args]
                        kwargs_repr = [f"{k!s}={v!r}" for k, v in kwargs.items()]
                        signature = ", ".join(args_repr + kwargs_repr)
                        try:
                            from viur.core import email
                            email.send_email(
                                dests=email_recipients,
                                tpl=tpl,
                                # The built-in text is only used if no custom template was given
                                stringTemplate=string_template if tpl is None else None,
                                # The following params provide information for the emails templates
                                func_name=func.__name__,
                                func_qualname=func.__qualname__,
                                func_module=func.__module__,
                                retries=retries,
                                args=args,
                                kwargs=kwargs,
                                signature=signature,
                                traceback=traceback.format_exc(),
                            )
                        except Exception:
                            # Best effort: a failing notification must not hide the permanent failure
                            logging.exception("Failed to send email to %r", email_recipients)
                    # Mark as permanently failed (could return nothing too)
                    raise PermanentTaskFailure() from exc

        return inner_wrapper

    return outer_wrapper
def noRetry(f):
    """
    Deprecated way to keep a deferred function from being called a second time.

    Logs a deprecation warning when applied and wraps *f* with ``retry_n_times(0)``.
    """
    logging.warning("Use of `@noRetry` is deprecated; Use `@retry_n_times(0)` instead!", stacklevel=2)
    without_retries = retry_n_times(0)
    return without_retries(f)
def CallDeferred(func: t.Callable) -> t.Callable:
    """
    This is a decorator, which always calls the wrapped method deferred.

    The call will be packed and queued into a Cloud Tasks queue.
    The Task Queue calls the TaskHandler which executes the wrapped function
    with the original arguments in a different request.

    If no queue region is known, nothing is queued: the call is appended to the
    pending tasks of the current request, or executed immediately if there is no request.

    The wrapper itself never passes a result of the wrapped function back,
    except when the function is called directly (see ``_call_deferred``).

    In addition to the arguments for the wrapped methods you can set these:

    _queue: Specify the queue in which the task should be pushed.
        If no value is given, the queue name set in `conf.tasks_default_queues`
        will be used. If the config does not have a value for this task, "default"
        is used as the default. The queue must exist (use the queue.yaml).

    _countdown: Specify a time in seconds after which the task should be called.
        This time is relative to the moment where the wrapped method has been called.
        Cannot be combined with _eta.

    _eta: Instead of a relative _countdown value you can specify a `datetime`
        when the task is scheduled to be attempted or retried.

    _name: Specify a custom name for the cloud task. Must be unique and can
        contain only letters ([A-Za-z]), numbers ([0-9]), hyphens (-), colons (:), or periods (.).

    _target_version: Specify a version on which to run this task.
        By default, a task will be run on the same version where the wrapped method has been called.

    _call_deferred: If set to False, the @CallDeferred decorated method is called directly
        instead of being queued.
        This is for example necessary, to call a super method which is decorated with @CallDeferred.

        .. code-block:: python

            # Example for use of the _call_deferred-parameter
            class A(Module):
                @CallDeferred
                def task(self):
                    ...

            class B(A):
                @CallDeferred
                def task(self):
                    super().task(_call_deferred=False)  # avoid secondary deferred call
                    ...

    See also:
        https://cloud.google.com/python/docs/reference/cloudtasks/latest/google.cloud.tasks_v2.types.Task
        https://cloud.google.com/python/docs/reference/cloudtasks/latest/google.cloud.tasks_v2.types.CreateTaskRequest

    :param func: The function or method which should be called deferred.
    :returns: The wrapper (or ``func`` itself if ``sys`` carries a ``viur_doc_build`` attribute).
    :raises ValueError: (by the wrapper) if both _eta and _countdown are set.
    """
    if "viur_doc_build" in dir(sys):
        return func

    # Sentinel to detect that the wrapper was called without any positional argument
    __undefinedFlag_ = object()

    def make_deferred(
        func: t.Callable,
        self=__undefinedFlag_,
        *args,
        _queue: str = None,
        _name: str | None = None,
        _call_deferred: bool = True,
        _target_version: str = conf.instance.app_version,
        _eta: datetime.datetime | None = None,
        _countdown: int = 0,
        **kwargs
    ):
        # `self` is merely the first positional argument of the call; for a plain
        # function it is the first regular argument (or the sentinel if there is none).
        if _eta is not None and _countdown:
            raise ValueError("You cannot set the _countdown and _eta argument together!")

        logging.debug(
            f"make_deferred {func=}, {self=}, {args=}, {kwargs=}, "
            f"{_queue=}, {_name=}, {_call_deferred=}, {_target_version=}, {_eta=}, {_countdown=}"
        )

        try:
            req = current.request.get()
        except Exception:  # This will fail for warmup requests
            req = None

        if not queueRegion:
            # Run tasks inline
            logging.debug(f"{func=} will be executed inline")

            @functools.wraps(func)
            def task():
                if self is __undefinedFlag_:
                    return func(*args, **kwargs)
                else:
                    return func(self, *args, **kwargs)

            if req:
                req.pendingTasks.append(task)  # This property only exists on development server!
            else:
                # Warmup request or something - we have to call it now as we can't defer it :/
                task()

            return  # Ensure no result gets passed back

        # It's the deferred method which is called from the task queue, this has to be called directly
        # (a retry-count header is present and nothing has marked the request as DEFERRED_TASK_CALLED yet)
        _call_deferred &= not (req and req.request.headers.get("X-Appengine-Taskretrycount")
                               and "DEFERRED_TASK_CALLED" not in dir(req))

        if not _call_deferred:
            if self is __undefinedFlag_:
                return func(*args, **kwargs)

            # Further deferred calls made during this request are queued again
            # NOTE(review): req may be None here (warmup request) - confirm this cannot happen
            req.DEFERRED_TASK_CALLED = True
            return func(self, *args, **kwargs)

        else:
            try:
                # Try to address the function relative to the application: "<modulePath>/<function name>"
                if self.__class__.__name__ == "index":
                    funcPath = func.__name__
                else:
                    funcPath = f"{self.modulePath}/{func.__name__}"
                command = "rel"
            except Exception:
                # No usable modulePath: address the function via the registry of deferred functions instead
                funcPath = f"{func.__name__}.{func.__module__}"
                if self is not __undefinedFlag_:
                    args = (self,) + args  # Re-append self to args, as this function is (hopefully) unbound
                command = "unb"

            if _queue is None:
                _queue = conf.tasks_default_queues.get(
                    funcPath, conf.tasks_default_queues.get("__default__", "default")
                )

            # Try to preserve the important data from the current environment
            try:  # We might get called inside a warmup request without session
                usr = current.session.get().get("user")
                # NOTE(review): this deletes the key from the object stored in the session itself - confirm intended
                if "password" in usr:
                    del usr["password"]
            except Exception:
                usr = None

            env = {"user": usr}

            try:
                env["lang"] = current.language.get()
            except AttributeError:  # This isn't originating from a normal request
                pass

            if db.IsInTransaction():
                # We have to ensure transaction guarantees for that task also
                env["transactionMarker"] = db.acquireTransactionSuccessMarker()
                # We move that task at least 90 seconds into the future so the transaction has time to settle
                # NOTE(review): the original comment says the countdown can be None, but max(90, None)
                # raises a TypeError - confirm that None is never passed.
                _countdown = max(90, _countdown)  # Countdown can be set to None

            if conf.tasks_custom_environment_handler:
                # Check if this project relies on additional environmental variables and serialize them too
                env["custom"] = conf.tasks_custom_environment_handler.serialize()

            # Create task description
            task = tasks_v2.Task(
                app_engine_http_request=tasks_v2.AppEngineHttpRequest(
                    body=utils.json.dumps((command, (funcPath, args, kwargs, env))).encode(),
                    http_method=tasks_v2.HttpMethod.POST,
                    relative_uri="/_tasks/deferred",
                    app_engine_routing=tasks_v2.AppEngineRouting(
                        version=_target_version,
                    ),
                ),
            )
            if _name is not None:
                task.name = taskClient.task_path(conf.instance.project_id, queueRegion, _queue, _name)

            # Set a schedule time in case eta (absolute) or countdown (relative) was set.
            if seconds := _countdown:
                _eta = utils.utcNow() + datetime.timedelta(seconds=seconds)
            if _eta:
                # We must send a Timestamp Protobuf instead of a date-string
                timestamp = protobuf.timestamp_pb2.Timestamp()
                timestamp.FromDatetime(_eta)
                task.schedule_time = timestamp

            # Use the client to build and send the task.
            parent = taskClient.queue_path(conf.instance.project_id, queueRegion, _queue)
            logging.debug(f"{parent=}, {task=}")
            taskClient.create_task(tasks_v2.CreateTaskRequest(parent=parent, task=task))

            logging.info(f"Created task {func.__name__}.{func.__module__} with {args=} {kwargs=} {env=}")

    # Register the undecorated function, so that "unb" calls can be resolved by the task handler
    global _deferred_tasks
    _deferred_tasks[f"{func.__name__}.{func.__module__}"] = func

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return make_deferred(func, *args, **kwargs)

    return wrapper
def callDeferred(func):
    """
    Deprecated spelling of :func:`CallDeferred`.

    Reports the deprecation through both logging and warnings,
    then decorates *func* with CallDeferred.
    """
    import logging
    import warnings

    msg = "Use of @callDeferred is deprecated, use @CallDeferred instead."
    for report in (logging.warning, warnings.warn):
        report(msg, stacklevel=3)

    return CallDeferred(func)
def PeriodicTask(interval: datetime.timedelta | int | float = 0, cronName: str = "default") -> t.Callable:
    """
    Decorator to call a function periodically during cron job execution.

    Interval defines a lower bound for the call-frequency for the given task, specified as a timedelta.

    The true interval of how often cron jobs are being executed is defined in the project's cron.yaml file.
    This defaults to 4 hours (see https://github.com/viur-framework/viur-base/blob/main/deploy/cron.yaml).
    In case the interval defined here is lower than 4 hours, the task will be fired once every 4 hours anyway.

    :param interval: Call at most the given timedelta. A plain number is taken as minutes
        if "tasks.periodic.useminutes" is part of ``conf.compatibility``.
    :param cronName: Name of the cron queue the task is registered in.
    :returns: A decorator which registers the function and returns it unchanged,
        apart from the ``periodicTaskName`` attribute it sets on it.
    :raises RuntimeError: (by the decorator) if the function name starts with an underscore.
    """
    def make_decorator(fn):
        if fn.__name__.startswith("_"):
            raise RuntimeError("Periodic called methods cannot start with an underscore! "
                               f"Please rename {fn.__name__!r}")

        if cronName not in _periodicTasks:
            _periodicTasks[cronName] = {}

        # Convert into a local value: scaling the closure variable in place would
        # multiply the interval again each time the same decorator object is re-used.
        effective_interval = interval
        if isinstance(interval, (int, float)) and "tasks.periodic.useminutes" in conf.compatibility:
            logging.warning(
                f"PeriodicTask assuming {interval=} minutes here. This is changed into seconds in future. "
                f"Please use `datetime.timedelta(minutes={interval})` for clarification.",
                stacklevel=2,
            )
            effective_interval = interval * 60

        _periodicTasks[cronName][fn] = utils.parse.timedelta(effective_interval)
        fn.periodicTaskName = f"{fn.__module__}_{fn.__qualname__}".replace(".", "_").lower()
        return fn

    return make_decorator
def CallableTask(fn: t.Callable) -> t.Callable:
    """
    Class decorator that registers a user-callable task under its ``key``.

    The class *should* extend CallableTaskBase and *must* provide its API.

    :param fn: The task class to register.
    :returns: The very same class.
    """
    _callableTasks[fn.key] = fn
    return fn
def StartupTask(fn: t.Callable) -> t.Callable:
    """
    Register a function that is called shortly after instance startup.

    It's *not* guaranteed that it actually runs on the instance that just started up!
    Wrapped functions must not take any arguments.

    :param fn: The function to register.
    :returns: The very same function.
    """
    _startupTasks.append(fn)
    return fn
@CallDeferred
def runStartupTasks():
    """
    Calls every registered startup task, in registration order.
    Do not call directly!
    """
    for startup_task in _startupTasks:
        startup_task()
class MetaQueryIter(type):
    """
    Meta class of all QueryIters.

    Its only purpose is to keep track of every class created with it, so that
    callbacks can be emitted on the correct class later on.
    """
    _classCache = {}  # Mapping className -> Class

    def __init__(cls, name, bases, dct):
        # The string representation of the class serves as its identifier
        class_id = str(cls)
        cls.__classID__ = class_id
        MetaQueryIter._classCache[class_id] = cls
        super().__init__(name, bases, dct)
class QueryIter(object, metaclass=MetaQueryIter):
    """
    BaseClass to run a database Query and process each entry matched.
    This will run each step deferred, so it is possible to process an arbitrary number of entries
    without being limited by time or memory.

    To use this class create a subclass, override the classmethods handleEntry and handleFinish and then
    call startIterOnQuery with an instance of a database Query (and possible some custom data to pass along)
    """
    queueName = "default"  # Name of the taskqueue we will run on

    @classmethod
    def startIterOnQuery(cls, query: db.Query, customData: t.Any = None) -> None:
        """
        Starts iterating the given query on this class. Will return immediately, the first batch will already
        run deferred.

        Warning: Any custom data *must* be json-serializable and *must* be passed in customData. You cannot store
        any data on this class as each chunk may run on a different instance!

        :param query: The query whose matching entries should be processed.
        :param customData: Data which is handed to every hook of this class.
        """
        assert not (query._customMultiQueryMerge or query._calculateInternalMultiQueryLimit), \
            "Cannot iter a query with postprocessing"
        assert isinstance(query.queries, db.QueryDefinition), "Unsatisfiable query or query with an IN filter"
        # Everything needed to rebuild the query in a later request
        qryDict = {
            "kind": query.kind,
            "srcSkel": query.srcSkel.kindName if query.srcSkel else None,
            "filters": query.queries.filters,
            "orders": [(propName, sortOrder.value) for propName, sortOrder in query.queries.orders],
            "startCursor": query.queries.startCursor,
            "endCursor": query.queries.endCursor,
            "origKind": query.origKind,
            "distinct": query.queries.distinct,
            "classID": cls.__classID__,
            "customData": customData,
            "totalCount": 0
        }
        cls._requeueStep(qryDict)

    @classmethod
    def _requeueStep(cls, qryDict: dict[str, t.Any]) -> None:
        """
        Internal use only. Pushes a new step defined in qryDict to either the taskqueue or append it to
        the current request if we are on the local development server.

        :param qryDict: The query description as created by startIterOnQuery.
        """
        if not queueRegion:  # Run tasks inline - hopefully development server
            req = current.request.get()
            task = lambda *args, **kwargs: cls._qryStep(qryDict)
            if req:
                req.pendingTasks.append(task)  # < This property will be only exist on development server!
                return
            # NOTE(review): without a request we fall through and create a task
            # although no queue region is known - confirm intended.
        taskClient.create_task(tasks_v2.CreateTaskRequest(
            parent=taskClient.queue_path(conf.instance.project_id, queueRegion, cls.queueName),
            task=tasks_v2.Task(
                app_engine_http_request=tasks_v2.AppEngineHttpRequest(
                    body=utils.json.dumps(qryDict).encode(),
                    http_method=tasks_v2.HttpMethod.POST,
                    relative_uri="/_tasks/queryIter",
                    app_engine_routing=tasks_v2.AppEngineRouting(
                        version=conf.instance.app_version,
                    ),
                )
            ),
        ))

    @classmethod
    def _qryStep(cls, qryDict: dict[str, t.Any]) -> None:
        """
        Internal use only. Processes one block of five entries from the query defined in qryDict and
        reschedules the next block.

        An entry whose handleEntry call fails is retried once after five seconds. If it fails
        again, handleError decides whether the iteration goes on.

        :param qryDict: The query description as created by startIterOnQuery.
        """
        from viur.core.skeleton import skeletonByKind
        # Rebuild the query from its serialized description
        qry = db.Query(qryDict["kind"])
        qry.srcSkel = skeletonByKind(qryDict["srcSkel"])() if qryDict["srcSkel"] else None
        qry.queries.filters = qryDict["filters"]
        qry.queries.orders = [(propName, db.SortOrder(sortOrder)) for propName, sortOrder in qryDict["orders"]]
        qry.setCursor(qryDict["startCursor"], qryDict["endCursor"])
        qry.origKind = qryDict["origKind"]
        qry.queries.distinct = qryDict["distinct"]
        if qry.srcSkel:
            qryIter = qry.fetch(5)
        else:
            qryIter = qry.run(5)
        for item in qryIter:
            try:
                cls.handleEntry(item, qryDict["customData"])
            # Only real errors are retried; KeyboardInterrupt / SystemExit must propagate
            except Exception:  # First exception - we'll try another time (probably/hopefully transaction collision)
                time.sleep(5)
                try:
                    cls.handleEntry(item, qryDict["customData"])
                except Exception as e:  # Second exception - call error_handler
                    try:
                        doCont = cls.handleError(item, qryDict["customData"], e)
                    except Exception as handler_exc:
                        logging.error(f"handleError failed on {item} - bailing out")
                        logging.exception(handler_exc)
                        doCont = False
                    if not doCont:
                        logging.error(f"Exiting queryIter on cursor {qry.getCursor()!r}")
                        return
            # Entries whose error got accepted by handleError are counted as well
            qryDict["totalCount"] += 1
        cursor = qry.getCursor()
        if cursor:
            # There may be more entries: continue behind the current cursor in a new step
            qryDict["startCursor"] = cursor
            cls._requeueStep(qryDict)
        else:
            cls.handleFinish(qryDict["totalCount"], qryDict["customData"])

    @classmethod
    def handleEntry(cls, entry, customData):
        """
        Overridable hook to process one entry. "entry" will be either an db.Entity or an
        SkeletonInstance (if that query has been created by skel.all())

        Warning: If your query has an sortOrder other than __key__ and you modify that property here
        it is possible to encounter that object later one *again* (as it may jump behind the current cursor).

        :param entry: The entry to process.
        :param customData: The custom data passed to startIterOnQuery.
        """
        logging.debug(f"handleEntry called on {cls} with {entry}.")

    @classmethod
    def handleFinish(cls, totalCount: int, customData):
        """
        Overridable hook that indicates the current run has been finished.

        :param totalCount: Number of entries which have been processed.
        :param customData: The custom data passed to startIterOnQuery.
        """
        logging.debug(f"handleFinish called on {cls} with {totalCount} total Entries processed")

    @classmethod
    def handleError(cls, entry, customData, exception) -> bool:
        """
        Handle a error occurred in handleEntry.
        If this function returns True, the queryIter continues, otherwise it breaks and prints the current cursor.

        :param entry: The entry on which handleEntry failed.
        :param customData: The custom data passed to startIterOnQuery.
        :param exception: The exception raised by the second attempt.
        :returns: True in this default implementation.
        """
        logging.debug(f"handleError called on {cls} with {entry}.")
        logging.exception(exception)
        return True
class DeleteEntitiesIter(QueryIter):
    """
    Simple Query-Iter that deletes every entity it encounters.

    ..Warning: When iterating over skeletons, make sure that the
        query was created using `Skeleton().all()`.
        This way the `Skeleton.delete()` method can be used and
        the appropriate post-processing can be done.
    """

    @classmethod
    def handleEntry(cls, entry, customData):
        """
        Delete one entry.

        A SkeletonInstance is deleted through its own delete(); anything else by its key.
        """
        from viur.core.skeleton import SkeletonInstance
        if not isinstance(entry, SkeletonInstance):
            db.Delete(entry.key)
            return
        entry.delete()