Rate Limiting in a Go Backend


Rate limiting is a method of controlling the flow of requests received by a network, usually used to reduce strain on servers and protect against malicious actors, making it an important topic in backend development and security.

Over the past few months I’ve worked on a few Go backend projects, such as Compared. As I tinkered with Go, I decided to create go-plate, a Go backend boilerplate that brings together everything I’ve learned, including rate limiting.

For go-plate I’ve focused on two popular algorithms: Token Bucket and Sliding Window.

Rate limiting can be enforced at the REST API server, at the load balancer, or both. In this post I’ll break down the Token Bucket algorithm and how I implemented it in a Go REST API.

Table of Contents:

  1. Token Bucket
  2. Implementation
  3. Distributed Systems
  4. How to use in API
  5. Going Further

Token Bucket

This algorithm uses a fixed-capacity “bucket” that holds tokens. Tokens are added to the bucket at a fixed rate, up to the bucket’s capacity. When a request is received, it attempts to consume one token. If a token is available, the request is allowed; if not, the request is either rejected or delayed until a token becomes available. This allows bursts of traffic up to the bucket’s capacity while still enforcing a long-term average rate.

Token Bucket can accommodate large bursts by processing requests immediately while tokens are available, but this means the traffic reaching the backend can itself be bursty, with a less predictable output rate than algorithms that smooth requests out.

Implementation

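package ratelimiting

import (
	"log"
	"sync"
	"time"
)

// RateLimiter is the interface both limiters in this post satisfy
// (inferred from their Allow methods).
type RateLimiter interface {
	Allow(key string) (allowed bool, retryAfter time.Duration, err error)
}

// tokenBucket is the per-client state: the current token count and the
// last time the bucket was refilled.
type tokenBucket struct {
	tokens     float64
	lastRefill time.Time
}

// InMemoryTokenBucket keeps one tokenBucket per client key in a map.
type InMemoryTokenBucket struct {
	mu           sync.Mutex
	buckets      map[string]*tokenBucket
	rate         float64 // tokens added per second
	capacity     float64 // maximum tokens a bucket can hold
	cleanupEvery time.Duration
}

func NewInMemoryTokenBucket(rate, capacity float64, cleanupInterval time.Duration) RateLimiter {
	if rate <= 0 || capacity <= 0 || cleanupInterval <= 0 {
		log.Fatalf("[FATAL] Invalid parameters for InMemoryTokenBucket")
	}

	tb := &InMemoryTokenBucket{
		buckets:      make(map[string]*tokenBucket),
		rate:         rate,
		capacity:     capacity,
		cleanupEvery: cleanupInterval,
	}
	go tb.cleanup()
	return tb
}

// cleanup runs for the lifetime of the process, deleting buckets that have
// been idle for longer than cleanupEvery.
func (tb *InMemoryTokenBucket) cleanup() {
	for range time.Tick(tb.cleanupEvery) {
		tb.mu.Lock()
		for key, bucket := range tb.buckets {
			if time.Since(bucket.lastRefill) > tb.cleanupEvery {
				delete(tb.buckets, key)
			}
		}
		tb.mu.Unlock()
	}
}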

Each InMemoryTokenBucket can guard an endpoint or a set of endpoints. It uses a map from a client key (here, the user’s IP) to a tokenBucket. Each tokenBucket in turn holds the current number of tokens and, in lastRefill, the last time the bucket was refilled (which happens on every request).

The rate is how many tokens are added to a bucket per second, and capacity is the maximum number of tokens a bucket can hold.

Every cleanupEvery, a background goroutine deletes any bucket whose last refill was more than cleanupEvery ago. Since a deleted bucket is recreated with capacity - 1 tokens on the client’s next request, deleting it effectively refills it; this only matches the algorithm exactly if cleanupEvery is at least capacity / rate, the time an empty bucket needs to refill. It’s definitely not the cleanest cleanup mechanism, but it works.

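// Allow reports whether a request for key may proceed and, if not, how long
// the caller should wait before retrying.
func (tb *InMemoryTokenBucket) Allow(key string) (bool, time.Duration, error) {
	tb.mu.Lock()
	defer tb.mu.Unlock()

	now := time.Now()
	b, exists := tb.buckets[key]
	if !exists {
		// New client: start with a full bucket minus the token for this request.
		tb.buckets[key] = &tokenBucket{tokens: tb.capacity - 1, lastRefill: now}
		return true, 0, nil
	}

	// Add the tokens earned since the last refill, capped at capacity
	// (min is the Go 1.21+ builtin).
	elapsed := now.Sub(b.lastRefill).Seconds()
	b.tokens = min(tb.capacity, b.tokens+elapsed*tb.rate)
	b.lastRefill = now

	if b.tokens >= 1 {
		b.tokens--
		return true, 0, nil
	}

	// Time until one full token is available. Scale to nanoseconds before
	// converting, so fractional seconds are not truncated to zero.
	retryAfter := time.Duration((1 - b.tokens) / tb.rate * float64(time.Second))
	return false, retryAfter, nil
}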

The Allow method checks if the request should be allowed, and if not returns how long a user needs to wait before trying again.

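It holds a mutex for the whole check, so concurrent requests cannot read and update the map (or the same bucket) at the same time. Go maps are not safe for concurrent writes, so this is needed whenever requests are handled concurrently, not only under heavy load.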

If no bucket exists for the key (the IP is “new” and has not accessed the endpoint(s) yet), a new one is created with capacity - 1 tokens, accounting for the current request.

If it does exist, the elapsed time from the last refill is calculated, and the tokens are refilled based on that and the refill rate.

If there are enough tokens to satisfy the request (1 in this example), they are consumed and the request is allowed.

If not, it returns how long the user needs to wait until one token has been refilled.

This implementation works well for simple scenarios, but there is a problem.

When scaling a REST server, it is very common to scale horizontally by adding more server instances behind a load balancer. This is a problem for this implementation: each instance keeps its own token buckets, so a client whose requests are spread across N instances can get up to N times the intended limit, making rate limiting imprecise and unpredictable.
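A quick way to see the effect: the simulation below gives each server instance its own bucket of capacity 5 for the same client. It is a simplification that assumes round-robin load balancing and ignores refill during the burst. With one instance, 5 of 20 requests pass; with two, 10 do.

```go
package main

import "fmt"

// instance models one server with its own in-memory token count per client.
type instance struct {
	capacity int
	tokens   map[string]int
}

func (s *instance) allow(client string) bool {
	t, ok := s.tokens[client]
	if !ok {
		t = s.capacity // a new client starts with a full bucket
	}
	if t < 1 {
		return false
	}
	s.tokens[client] = t - 1
	return true
}

// allowedBehindBalancer sends n requests from one client, round-robin across
// the given number of instances, and counts how many were allowed.
func allowedBehindBalancer(instances, capacity, n int) int {
	servers := make([]*instance, instances)
	for i := range servers {
		servers[i] = &instance{capacity: capacity, tokens: map[string]int{}}
	}
	allowed := 0
	for i := 0; i < n; i++ {
		if servers[i%instances].allow("192.168.1.100") {
			allowed++
		}
	}
	return allowed
}

func main() {
	fmt.Println(allowedBehindBalancer(1, 5, 20)) // 5
	fmt.Println(allowedBehindBalancer(2, 5, 20)) // 10
}
```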

Distributed Systems

For such scenarios, a centralized data store is needed. The algorithm stays essentially the same, but instead of keeping the tokens in memory, they are stored in a fast shared store like Redis.

One could easily port the implementation above to Redis using the GET and SET commands, but there is a gotcha:

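Multiple clients (server instances) might read and write the same entry concurrently. If two instances GET a bucket before either SETs it back, both see the same token count and the second write overwrites the first, so more requests are allowed than intended.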
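This is a classic read-modify-write race. The sketch below replays one unlucky interleaving by hand, with a plain map standing in for Redis (an assumption to keep it self-contained): two instances read a bucket holding one token, both decide to allow their request, and both write back zero. Two requests get through on a single token.

```go
package main

import "fmt"

// store stands in for Redis: each get and set is atomic on its own,
// but nothing makes a get followed by a set atomic as a pair.
type store map[string]float64

func (s store) get(key string) float64 { return s[key] }

func (s store) set(key string, v float64) { s[key] = v }

// lostUpdate replays an interleaving where two instances read the bucket
// before either writes it back.
func lostUpdate() (allowed int, remaining float64) {
	s := store{"ratelimit:client": 1} // one token left

	a := s.get("ratelimit:client") // instance A reads 1
	b := s.get("ratelimit:client") // instance B reads 1 before A writes

	if a >= 1 {
		allowed++
		s.set("ratelimit:client", a-1) // A writes 0
	}
	if b >= 1 {
		allowed++
		s.set("ratelimit:client", b-1) // B overwrites with 0
	}
	return allowed, s.get("ratelimit:client")
}

func main() {
	allowed, remaining := lostUpdate()
	fmt.Println(allowed, remaining) // 2 0
}
```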

To prevent this, we can use a Lua script, which Redis executes atomically: while the script runs, no other client’s commands are executed, so all the operations inside it behave as a single atomic unit.

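// RedisTokenBucket stores each client's bucket in Redis, so every server
// instance shares the same state.
type RedisTokenBucket struct {
	redis         database.RedisStore
	rate          float64       // tokens added per second
	capacity      float64       // maximum tokens a bucket can hold
	keyExpiration time.Duration // TTL applied to each bucket key
	ctx           context.Context
	limiterID     string // namespaces this limiter's keys in Redis
}

func NewRedisTokenBucket(limiterID string, store database.RedisStore, rate, capacity float64, keyExpiration time.Duration) RateLimiter {
	// Redis EX takes whole seconds, so the expiration must be at least 1s.
	if rate <= 0 || capacity <= 0 || keyExpiration < time.Second || store == nil {
		log.Fatalf("[FATAL] Invalid parameters for RedisTokenBucket `%s`", limiterID)
	}

	return &RedisTokenBucket{
		redis:         store,
		rate:          rate,
		capacity:      capacity,
		keyExpiration: keyExpiration,
		ctx:           context.Background(),
		limiterID:     limiterID,
	}
}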

This implementation is similar, except that the data lives in Redis. Since the same Redis instance is shared by multiple rate limiters, each limiter also needs a unique limiterID so that it reads and writes only its own keys.

// The script is parsed once; Run tries EVALSHA first and falls back to EVAL.
var tokenBucketScript = redis.NewScript(luaTokenBucket)

func (tb *RedisTokenBucket) Allow(key string) (bool, time.Duration, error) {
	allowed, retryAfter, err := tb.checkBucket(key)
	if err != nil {
		return false, 0, err
	}
	if !allowed {
		return false, retryAfter, nil
	}
	return true, 0, nil
}

// checkBucket runs the token bucket script atomically in Redis and returns
// whether the request is allowed and, if not, how long to wait.
func (tb *RedisTokenBucket) checkBucket(key string) (bool, time.Duration, error) {
	now := float64(time.Now().UnixNano()) / 1e9 // current time in seconds
	keys := []string{tb.getRedisKey(key)}
	args := []interface{}{tb.rate, tb.capacity, now, int(tb.keyExpiration.Seconds())}

	client := tb.redis.GetNativeInstance().(*redis.Client)
	result, err := tokenBucketScript.Run(tb.ctx, client, keys, args...).Result()
	if err != nil {
		return false, 0, fmt.Errorf("failed to run Lua script: %w", err)
	}

	// The script returns {allowed, retry_after}. Redis converts Lua numbers
	// to integer replies, which go-redis surfaces as int64.
	resultSlice, ok := result.([]interface{})
	if !ok || len(resultSlice) != 2 {
		return false, 0, fmt.Errorf("unexpected result type or length: %T", result)
	}

	allowed, ok := resultSlice[0].(int64)
	if !ok {
		return false, 0, fmt.Errorf("unexpected 'allowed' type: %T", resultSlice[0])
	}

	var retryAfterSeconds int64
	if allowed != 1 {
		retryAfterSeconds, ok = resultSlice[1].(int64)
		if !ok {
			return false, 0, fmt.Errorf("unexpected 'retry_after' type: %T", resultSlice[1])
		}
	}

	return allowed == 1, time.Duration(retryAfterSeconds) * time.Second, nil
}

func (tb *RedisTokenBucket) getRedisKey(key string) string {
	return fmt.Sprintf("ratelimit:token_bucket:%s:%s", tb.limiterID, key)
}

As with the in-memory implementation, we check whether the request is allowed and, if not, return the retry time to the caller. To do that, we run a Lua script on Redis. Here’s where the magic happens:

var luaTokenBucket = `
local key = KEYS[1]
local rate = tonumber(ARGV[1])
local capacity = tonumber(ARGV[2])
local now = tonumber(ARGV[3])
local expire = tonumber(ARGV[4])

-- A missing key means a new (or expired) client with a full bucket.
local bucket = redis.call("GET", key)
local tokens = capacity
local last_refill = now

if bucket then
    local data = cjson.decode(bucket)
    tokens = data.tokens
    last_refill = data.last_refill

    -- Clamp at zero in case server clocks disagree slightly.
    local elapsed = math.max(0, now - last_refill)
    tokens = math.min(capacity, tokens + elapsed * rate)
    last_refill = math.max(last_refill, now)
end

local allowed = 0
local retry_after = 0

if tokens >= 1 then
    tokens = tokens - 1
    allowed = 1
else
    -- Redis truncates Lua numbers to integers in replies, so round up to
    -- whole seconds instead of returning 0 for sub-second waits.
    retry_after = math.ceil((1 - tokens) / rate)
end

local new_bucket = cjson.encode({tokens=tokens, last_refill=last_refill})
redis.call("SET", key, new_bucket, "EX", expire)

return {allowed, retry_after}
`

A Lua script is the simplest way to make this read-modify-write atomic. It receives the bucket’s Redis key (built from the limiterID and the client key), the refill rate, the capacity, the current time in seconds, and the key expiration in seconds.

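It behaves the same as the in-memory version, but cleanup is simpler: every SET refreshes the key’s TTL, so Redis itself deletes buckets that have been idle for keyExpiration, with no background goroutine.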

How to use in API

To use these in an actual API, I chain them into my middleware, creating different rate limiters depending on the endpoint.

There are a variety of things to consider: whether the endpoint accesses a centralized resource that could easily be strained, such as a database, or whether you want to limit how many sensitive operations a user can perform in a given time frame, such as login attempts or password changes.

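r := chi.NewRouter()
r.Group(func(r chi.Router) {
	// 0.025 tokens/s (1.5 requests per minute sustained), bursts of up to 200,
	// and bucket keys expire after 10 minutes without requests.
	r.Use(middleware.RateLimitMiddleware(ratelimiting.NewRedisTokenBucket("/posts", redis, 0.025, 200, 10*time.Minute)))

	r.Get("/user/{id}", getPosts)
})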

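To test their efficacy, you can stress test them locally with hey: hey -n 200 -c 10 -H "X-Real-IP: 192.168.1.100" http://localhost:8080/user/1. Since the limiter above allows an initial burst of 200 requests, raise -n above that (for example, -n 250) to actually see requests being rejected.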

Going Further

At its most basic, the goal is to protect the API while keeping the experience smooth for genuine users.

Beyond the basics, rate limiting is a deep topic with many aspects to consider:

Complex challenges can arise in accurately mapping requests to users, especially for public facing APIs. There are more advanced techniques beyond IP-based identification, such as API keys, OAuth tokens, and device fingerprinting.

Handling scale is another challenge, as it is with most things in web development: in high-traffic environments, it can be hard to balance accuracy with performance.

The algorithms themselves can also be extended, for example with adaptive rate limiting that adjusts limits based on server load or user behavior.

For some APIs, it can also make sense to have tiered limits for different types of users or subscriptions.
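Tiered limits can be as simple as choosing the rate and capacity per tier when constructing the limiters. The tier names and numbers below are made up for illustration. Each tier gets its own parameters, and with the Redis limiter each tier would also need its own limiterID so the tiers’ buckets don’t share keys.

```go
package main

import "fmt"

// limits holds the two token bucket parameters used by both limiters.
type limits struct {
	Rate     float64 // tokens per second
	Capacity float64 // maximum burst
}

// tierLimits is a hypothetical configuration: paying users refill faster
// and can burst higher.
var tierLimits = map[string]limits{
	"free": {Rate: 0.5, Capacity: 20},
	"pro":  {Rate: 5, Capacity: 200},
}

// limitsFor returns the parameters for a tier, falling back to "free" for
// unknown tiers so a typo never grants a higher limit.
func limitsFor(tier string) limits {
	if l, ok := tierLimits[tier]; ok {
		return l
	}
	return tierLimits["free"]
}

func main() {
	for _, tier := range []string{"free", "pro", "unknown"} {
		l := limitsFor(tier)
		fmt.Printf("%s: %.1f req/min sustained, burst %.0f\n", tier, l.Rate*60, l.Capacity)
	}
}
```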

Rate limiting can certainly encompass even more things, but hopefully these examples can provide a basic intro to rate limiting and how it can be implemented on a Go backend.