限制HTTP请求速率(通过http.HandlerFunc中间件)

12

我想编写一个小型的速率限制中间件,它能够:

  1. 允许我为每个远程 IP 设置合理的速率(例如,每秒10个请求)
  2. 可能(但不一定)允许并发请求
  3. 关闭超过速率限制的连接,并返回HTTP 429

然后我可以将其包装在身份验证路由或其他可能受到暴力攻击的路由周围(例如使用过期令牌的密码重置URL等)。虽然暴力破解16或24字节的令牌的机会真的很低,但采取额外的防范措施也无妨。

我已经查看了https://code.google.com/p/go-wiki/wiki/RateLimiting,但不确定如何与http.Request(s)协调。此外,我不确定如何“跟踪”来自给定IP的请求。

理想情况下,我希望得到像这样的东西,注意我在反向代理(nginx)后面,因此我们检查的是REMOTE_ADDR HTTP标头而不是使用r.RemoteAddr

// Rate-limiting middleware
func rateLimit(h http.HandlerFunc) http.HandlerFunc {
    return func(w http.ResponseWriter, r *http.Request) {

        remoteIP := r.Header.Get("REMOTE_ADDR")
        for req := range (what here?) {
            // what here?
            // w.WriteHeader(429) and close the request if it exceeds the limit
            // else pass to the next handler in the chain
            h.ServeHTTP(w, r)
        }
}

// Example routes
r.HandleFunc("/login", use(loginForm, rateLimit, csrf)
r.HandleFunc("/form", use(editHandler, rateLimit, csrf)

// Middleware wrapper, for context
func use(h http.HandlerFunc, middleware ...func(http.HandlerFunc) http.HandlerFunc) http.HandlerFunc {
    for _, m := range middleware {
        h = m(h)
    }

    return h
}

我希望你能在这里给我一些指导。

4个回答

11

你提供的速率限制示例是一个通用示例。它使用范围(range)是因为它通过通道获取请求。

但对于HTTP请求来说,情况则不同,但这里并没有什么真正复杂的地方。请注意,您不会迭代请求通道或其他任何东西--您的HandlerFunc将为每个传入请求分别调用。

func rateLimit(h http.HandlerFunc) http.HandlerFunc {
    return func(w http.ResponseWriter, r *http.Request) {
        remoteIP := r.Header.Get("REMOTE_ADDR")
        if exceededTheLimit(remoteIP) {
            w.WriteHeader(429)
            // it then returns, not passing the request down the chain
        } else {
            h.ServeHTTP(w, r);
        }
    }       
}

现在,选择存储速率限制计数器的位置由您决定。一种解决方案是简单地使用全局映射(不要忘记安全并发访问),将IP映射到它们的请求计数器。但是,您必须了解请求是多久之前进行的。

Sergio建议使用Redis。它的键值特性非常适合这样的简单结构,并且您可以免费获得过期功能。


谢谢,Justinas。经过设计思考:某个关键字是路径+远程地址的简单计数器组合应该可以解决问题,并带有(可配置的)到期时间? 即 "SETEX",path + remoteIP,reqLimit,counter",其中计数器是在每个请求中检索、检查/增加并发送回Redis的整数? 但需要注意的是,在到期之前我们必须确保不会“重置”计数器。 - elithrar
1
我建议您使用SETEX path+remoteIP expirationPeriod counter,然后在每个请求上进行INCR。问题是,INCR不会重置到期时间。 - justinas
啊,好主意。查看Redis文档后发现他们实际上通过使用path+IP+time.Now().Second()作为键来覆盖了这个(简化的)实现。有什么缺点吗?这似乎是一种明智的方法。另外,对于任何竞争条件还是要谨慎,虽然这可能会被Redis自身解决(?)。 - elithrar
1
我仍然会选择path+IP作为键,计数器作为值,并设置过期时间,在每个请求上进行INCR操作。 INCR是一个原子操作,不应该存在竞争条件。 - justinas
我已经准备了一个快速示例:根据我所确定的(在漫长的一周结束时!),我们应该执行 INCR path+IP(如果键不存在,则将我们的键设置为1),检查返回值是否大于限制(返回 429 并跳过我们剩余的处理程序),否则,如果当前值等于1,则执行 SETEX path+IP, 1 second, 1 (在“新鲜”键上设置到期时间),否则将其传递给链中的下一个处理程序。这样做对吗?(在此之后我就让你走了!);) - elithrar
显示剩余2条评论

4
我今天早上做了一些简单而类似的事情,我认为它可能有助于你的情况。
package main

import (
    "log"
    "net/http"
    "strings"
    "time"
)

func main() {
    fs := http.FileServer(http.Dir("./html/"))
    http.Handle("/", fs)
    log.Println("Listening..")
    go clearLastRequestsIPs()
    go clearBlockedIPs()
    err := http.ListenAndServe(":8080", middleware(nil))
    if err != nil {
        log.Fatalln(err)
    }
}

// Stores last requests IPs
var lastRequestsIPs []string

// Block IP for 6 hours
var blockedIPs []string

func middleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        ipAddr := strings.Split(r.RemoteAddr, ":")[0]
        if existsBlockedIP(ipAddr) {
            http.Error(w, "", http.StatusTooManyRequests)
            return
        }
        // how many requests the current IP made in last 5 mins
        requestCounter := 0
        for _, ip := range lastRequestsIPs {
            if ip == ipAddr {
                requestCounter++
            }
        }
        if requestCounter >= 1000 {
            blockedIPs = append(blockedIPs, ipAddr)
            http.Error(w, "", http.StatusTooManyRequests)
            return
        }
        lastRequestsIPs = append(lastRequestsIPs, ipAddr)

        // Don't cut the chain of middlewares
        if next == nil {
            http.DefaultServeMux.ServeHTTP(w, r)
            return
        }
        next.ServeHTTP(w, r)
    })
}

func existsBlockedIP(ipAddr string) bool {
    for _, ip := range blockedIPs {
        if ip == ipAddr {
            return true
        }
    }
    return false
}

func existsLastRequest(ipAddr string) bool {
    for _, ip := range lastRequestsIPs {
        if ip == ipAddr {
            return true
        }
    }
    return false
}

// Clears lastRequestsIPs array every 5 mins
func clearLastRequestsIPs() {
    for {
        lastRequestsIPs = []string{}
        time.Sleep(time.Minute * 5)
    }
}

// Clears blockedIPs array every 6 hours
func clearBlockedIPs() {
    for {
        blockedIPs = []string{}
        time.Sleep(time.Hour * 6)
    }
}

目前它还不是很精确,但是它可以作为速率限制器的简单示例。你可以通过添加请求路径、HTTP方法甚至身份验证来改进它,以决定流量是否是攻击。


你需要保护(同步)对全局切片的写入/读取。 - elithrar
我用一个 map 简化了你的示例:https://play.golang.org/p/hmEguNLOMO - AESTHETICS

4

您可以使用Redis存储数据。以下是一条非常有用的命令,甚至在其文档中提到了速率限制应用程序:INCR。 Redis还将处理旧数据的清理(通过过期旧键)。

此外,将Redis作为速率限制器存储,您可以使用多个前端进程共享此中央存储。

有人会认为每次调用外部进程都很昂贵。但密码重置页面不是绝对需要最佳性能的页面。此外,如果将Redis放在同一台机器上,则延迟应该非常低。


我已经通过redistore在服务器端运行Redis进行会话管理(在更高的级别上),但是这里的最高性能并不是问题,所以我不介意依赖它。一些Flask中间件也利用Redis实现了相同的目标。然而,我不确定如何开始处理http.Request逻辑:一旦我解决了这个问题,我可能可以弄清楚如何使用Redigo。 - elithrar

2
这是我的速率限制中间件实现。它非常适合作为全局速率限制器或单个请求的速率限制器。我在我的应用程序中广泛使用它。
以下是它的功能:
  • 无外部依赖
  • 可测试
  • 可配置
  • 添加标头,以便客户端可以了解他们在达到限制之前还剩下多少请求等。
  • 自动删除过期数据。
首先是实现代码:
r := router.New()
stats := stats.New()
r.With(middleware.RateLimit(1, time.Minute * 1, stats)).Post("/contact", c.Contact)

这个中间件将允许每分钟发送一个请求,当向/contact发起一个POST请求时。

以下是该中间件:

package middleware

import (
    "net/http"
    "strconv"
    "time"
)

// Stats is an interface to an underlying hash table/map data
// structure. Implement it however you'd like.
type Stats interface {
    // Reset will reset the map.
    Reset()

    // Add would add "count" to the map at the key of "identifier",
    // and returns an int which is the total count of the value 
    // at that key.
    Add(identifier string, count int) int
}

// RateLimit middleware is a generic rate limiter that can be used in any scenario
// because it allows granular rate limiting for each specific request. Or you can
// set the rate limiter on the entire router group. It's just a HandlerFunc.
func RateLimit(limit int, window time.Duration, stats Stats) func(next http.Handler) http.Handler {
    var windowStart time.Time

    // Clear the rate limit stats after each window.
    ticker := time.NewTicker(window)
    go func() {
        windowStart = time.Now()

        for range ticker.C {
            windowStart = time.Now()
            stats.Reset()
        }
    }()

    return func(next http.Handler) http.Handler {
        h := func(w http.ResponseWriter, r *http.Request) {
            value := int(stats.Add(identifyRequest(r), 1))

            XRateLimitRemaining := limit - value
            if XRateLimitRemaining < 0 {
                XRateLimitRemaining = 0
            }

            w.Header().Add("X-Rate-Limit-Limit", strconv.Itoa(limit))
            w.Header().Add("X-Rate-Limit-Remaining", strconv.Itoa(XRateLimitRemaining))
            w.Header().Add("X-Rate-Limit-Reset", strconv.Itoa(int(window.Seconds()-time.Since(windowStart).Seconds())+1))

            if value >= limit {
                w.WriteHeader(429)
                // Do something else...
            } else {
                next.ServeHTTP(w, r)
            }
        }

        return http.HandlerFunc(h)
    }
}

// identifyRequest gets an identifier from the request context.
func identifyRequest(r *http.Request) string {
    // Identify your request here (get IP address, etc.)
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接