Coverage for /home/runner/work/viur-core/viur-core/viur/src/viur/core/ratelimit.py: 0%

69 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-02-27 07:59 +0000

1import datetime 

2 

3from viur.core import current, db, errors, utils 

4from viur.core.tasks import PeriodicTask, DeleteEntitiesIter 

5import typing as t 

6from datetime import timedelta 

7 

8 

class RateLimit:
    """
    Restricts access to a resource to *maxRate* attempts within a window of *minutes* minutes.

    Usage: create an instance of this class in your module's ``__init__`` function. Then call
    :meth:`isQuotaAvailable` (or :meth:`assertQuotaIsAvailable`) before executing the action to
    check whether quota is left, and call :meth:`decrementQuota` after executing the action.

    Attempts are counted in datastore entities of kind :attr:`rateLimitKind`. There is one entity
    per resource, endpoint (IP address / IPv6 prefix, or user key) and time bucket. The window is
    split into ``min(minutes, 5)`` buckets of ``secondsPerStep`` seconds each. Buckets are counted
    from midnight of the datetime returned by ``utils.utcNow()``.
    """
    rateLimitKind = "viur-ratelimit"  # datastore kind holding the per-bucket attempt counters

    def __init__(self, resource: str, maxRate: int, minutes: int, method: t.Literal["ip", "user"]):
        """
        Initializes a new RateLimit gate.

        :param resource: Name of the resource to protect
        :param maxRate: Amount of tries allowed in the given time-span
        :param minutes: Length of the time-span in minutes; must be at least 1
        :param method: Lock by IP or by the current user
        :raises ValueError: if *minutes* is smaller than 1
        """
        super().__init__()
        if minutes < 1:
            # 0 would divide by zero below; negative values would silently disable the limiter.
            raise ValueError("minutes must be at least 1")
        self.resource = resource
        self.maxRate = maxRate
        self.minutes = minutes
        # At most 5 buckets per window. For integer minutes >= 1 each bucket is a whole number of
        # seconds (60 for minutes <= 5, otherwise 12 * minutes).
        self.steps = min(minutes, 5)
        self.secondsPerStep = 60 * (float(minutes) / float(self.steps))
        assert method in ["ip", "user"], "method must be 'ip' or 'user'"
        self.useUser = method == "user"

    def _getEndpointKey(self) -> "db.Key | str":
        """
        :warning:
            It's invalid to call _getEndpointKey if method is set to user and there's no user logged in!

        :return: the key associated with the current endpoint. This is the key of the current user,
            or the client's IP address (reduced to its first 64 bits for IPv6).
        """
        if self.useUser:
            user = current.user.get()
            assert user, "Cannot decrement usage from guest!"
            return user["key"]
        return self._normalizeRemoteAddr(current.request.get().request.remote_addr)

    @staticmethod
    def _normalizeRemoteAddr(remoteAddr: str) -> str:
        """
        Maps a client address to the string used as rate-limit endpoint.

        - IPv4 addresses, and IPv4-mapped IPv6 addresses, are returned in dotted-quad form.
        - IPv6 addresses are reduced to their first 64 bits (the network prefix), written as four
          zero-padded lowercase hex groups. The interface identifier (last 64 bits) is easily
          controlled by the client, so every address in the same /64 must yield the same key,
          regardless of how it was abbreviated.
        - Anything that does not parse as an IP address is returned unchanged.

        :param remoteAddr: The remote address as reported by the request.
        :return: The normalized endpoint string.
        """
        import ipaddress

        try:
            address = ipaddress.ip_address(remoteAddr)
        except ValueError:
            return remoteAddr
        if isinstance(address, ipaddress.IPv6Address):
            if address.ipv4_mapped is not None:
                return str(address.ipv4_mapped)
            # Work on the integer value so an optional scope id ("%eth0") cannot interfere.
            prefix = int(address) >> 64
            return ":".join(f"{(prefix >> shift) & 0xFFFF:04x}" for shift in (48, 32, 16, 0))
        return str(address)

    def _getBucket(self, dateTime: datetime.datetime) -> tuple[datetime.datetime, int]:
        """
        :param dateTime: The point in time to locate.
        :return: The start time of the bucket containing *dateTime*, and the bucket's index
            counted in steps of ``secondsPerStep`` seconds since midnight of that day.
        """
        midnight = dateTime.replace(hour=0, minute=0, second=0, microsecond=0)
        index = int((dateTime - midnight).total_seconds() / self.secondsPerStep)
        return midnight + timedelta(seconds=index * self.secondsPerStep), index

    def _getTimeKey(self, dateTime: datetime.datetime) -> str:
        """
        :param dateTime: The point in time to locate.
        :return: The identifier of the bucket containing *dateTime*, formatted as ``YYYY-MM-DD-<index>``.
        """
        _, index = self._getBucket(dateTime)
        return f"{dateTime.strftime('%Y-%m-%d')}-{index}"

    def _getCurrentTimeKey(self) -> str:
        """
        :return: the current lock period used in the last position of the lock key
        """
        return self._getTimeKey(utils.utcNow())

    def _getWindowTimeKeys(self, dateTime: datetime.datetime) -> list[str]:
        """
        :param dateTime: The end of the window, normally the current time.
        :return: The identifiers of the ``steps`` most recent buckets, newest first. The first
            is the bucket containing *dateTime*. Walking back crosses midnight into the previous
            day, so the window does not reset at the day boundary.
        """
        keys = []
        cursor = dateTime
        for _ in range(self.steps):
            keys.append(self._getTimeKey(cursor))
            bucketStart, _index = self._getBucket(cursor)
            # Jump to the last microsecond of the preceding bucket. On the previous day, that
            # day's final bucket may be shorter than secondsPerStep, so it must not be skipped.
            cursor = bucketStart - timedelta(microseconds=1)
        return keys

    def decrementQuota(self) -> None:
        """
        Removes one attempt from the pool of available Quota for that user/ip
        """

        def updateTxn(cacheKey: str) -> None:
            key = db.Key(self.rateLimitKind, cacheKey)
            obj = db.Get(key)
            if obj is None:
                obj = db.Entity(key)
                obj["value"] = 0
            obj["value"] += 1
            # Keep the counter around for twice the window so isQuotaAvailable can still see it.
            obj["expires"] = utils.utcNow() + timedelta(minutes=2 * self.minutes)
            db.Put(obj)

        lockKey = f"{self.resource}-{self._getEndpointKey()}-{self._getCurrentTimeKey()}"
        db.RunInTransaction(updateTxn, lockKey)

    def isQuotaAvailable(self) -> bool:
        """
        Checks if there's currently quota available for the current user/ip.

        :return: True if fewer than *maxRate* attempts were recorded within the window,
            False otherwise.
        """
        endPoint = self._getEndpointKey()
        currentDateTime = utils.utcNow()
        cacheKeys = [
            db.Key(self.rateLimitKind, f"{self.resource}-{endPoint}-{timeKey}")
            for timeKey in self._getWindowTimeKeys(currentDateTime)
        ]
        used = sum(
            entry["value"]
            for entry in db.Get(cacheKeys)
            if entry and currentDateTime < entry["expires"]
        )
        # decrementQuota runs *after* the action, so `used` attempts already happened.
        # Another attempt is only allowed while we are still strictly below the limit.
        return used < self.maxRate

    def assertQuotaIsAvailable(self, setRetryAfterHeader: bool = True) -> bool:
        """Assert quota is available.

        If no quota is available, a :class:`viur.core.errors.TooManyRequests`
        exception will be raised.

        :param setRetryAfterHeader: Set the Retry-After header on the
            current request response, if the quota is exceeded.
        :return: True if quota is available.
        :raises: :exc:`viur.core.errors.TooManyRequests`, if no quota is available.
        """
        if self.isQuotaAvailable():
            return True
        if setRetryAfterHeader:
            # The window spans `minutes` minutes, so that is an upper bound for the wait.
            current.request.get().response.headers["Retry-After"] = str(self.minutes * 60)

        raise errors.TooManyRequests(
            f"{self.maxRate} requests allowed per {self.minutes} minute(s). Try again later."
        )

128 

129 

@PeriodicTask(interval=datetime.timedelta(hours=1))
def cleanOldRateLocks(*args, **kwargs) -> None:
    """
    Periodic task, scheduled every hour, that removes stale rate-limit counters.

    Selects every entity of kind ``RateLimit.rateLimitKind`` whose ``expires`` timestamp is
    before the current ``utils.utcNow()`` and passes that query to
    ``DeleteEntitiesIter.startIterOnQuery``.
    """
    cutoff = utils.utcNow()
    expiredCounters = db.Query(RateLimit.rateLimitKind).filter("expires <", cutoff)
    DeleteEntitiesIter.startIterOnQuery(expiredCounters)