103 lines
2.4 KiB
Go
103 lines
2.4 KiB
Go
package middleware
|
|
|
|
import (
|
|
"log"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
"users_management/m/utils/common"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/uuid"
|
|
"golang.org/x/time/rate"
|
|
)
|
|
|
|
var (
|
|
rateLimiters = make(map[string]*rate.Limiter)
|
|
mu sync.Mutex
|
|
)
|
|
|
|
func getRateLimiter(userID string) *rate.Limiter {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
limiter, exists := rateLimiters[userID]
|
|
if !exists {
|
|
limiter = rate.NewLimiter(rate.Every(1*time.Minute), 50) // 50 requests per minute
|
|
rateLimiters[userID] = limiter
|
|
}
|
|
|
|
return limiter
|
|
}
|
|
|
|
func getLoginLimiter() *rate.Limiter {
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
|
|
limiter, exists := rateLimiters["login"]
|
|
if !exists {
|
|
limiter = rate.NewLimiter(rate.Every(1*time.Minute), 10) // 10 requests per minute for login
|
|
rateLimiters["login"] = limiter
|
|
}
|
|
|
|
return limiter
|
|
}
|
|
|
|
func RateLimitMiddleware() gin.HandlerFunc{
|
|
return func(c *gin.Context) {
|
|
if c.Request.Method == http.MethodOptions {
|
|
c.Next()
|
|
return
|
|
}
|
|
|
|
userID, exists := c.Get("userID")
|
|
if !exists {
|
|
common.ErrorResponses(c, http.StatusUnauthorized, "Unauthorized: No user ID found")
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
// Convert UUID to string for rate limiting
|
|
var userIDStr string
|
|
switch v := userID.(type) {
|
|
case uuid.UUID:
|
|
userIDStr = v.String()
|
|
case string:
|
|
userIDStr = v
|
|
default:
|
|
log.Printf("Unexpected userID type: %T", userID)
|
|
common.ErrorResponses(c, http.StatusInternalServerError, "Invalid user ID type")
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
log.Printf("User ID for rate limiting: %s", userIDStr)
|
|
|
|
limiter := getRateLimiter(userIDStr)
|
|
|
|
if !limiter.Allow() {
|
|
common.ErrorResponses(c, http.StatusTooManyRequests, "Too many requests")
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func RateLoginMiddleware() gin.HandlerFunc{
|
|
return func(c *gin.Context) {
|
|
if c.Request.Method == http.MethodOptions {
|
|
c.Next()
|
|
return
|
|
}
|
|
limiter := getLoginLimiter()
|
|
|
|
if !limiter.Allow() {
|
|
common.ErrorResponses(c, http.StatusTooManyRequests, "Too many requests")
|
|
c.Abort()
|
|
return
|
|
}
|
|
c.Next()
|
|
}
|
|
} |