Skip to content

Middleware

Middleware functions process requests before they reach your route handlers, and process responses before they’re sent to clients. Zenith provides a comprehensive set of production-ready middleware.

Built-in Middleware (Production-Ready Security)

Section titled “Built-in Middleware (Production-Ready Security)”

Security Headers (Your First Line of Defense)

Section titled “Security Headers (Your First Line of Defense)”
from zenith import Zenith
from zenith.middleware import SecurityHeadersMiddleware, SecurityConfig
# Security headers guard the app against several common attacks.
app = Zenith(
    middleware=[
        SecurityHeadersMiddleware(
            SecurityConfig(
                # FORCE HTTPS
                #   Why:  prevents man-in-the-middle attacks
                #   Sets: 301 redirect from http:// to https://
                force_https=True,

                # HSTS (HTTP Strict Transport Security)
                #   Why:  tells browsers "NEVER use HTTP for this site"
                #   Sets: Strict-Transport-Security: max-age=31536000
                hsts_max_age=31536000,  # 1 year in seconds

                # CONTENT TYPE SNIFFING PROTECTION
                #   Why:    prevents browsers from guessing content types
                #   Sets:   X-Content-Type-Options: nosniff
                #   Blocks: <script> tags served as text/plain
                content_type_nosniff=True,

                # CLICKJACKING PROTECTION
                #   Why:  prevents your site from being embedded in hidden iframes
                #   Sets: X-Frame-Options: DENY
                #   Use "SAMEORIGIN" to allow your own iframes
                frame_deny=True,

                # XSS PROTECTION (for older browsers)
                #   Why:  enables browser XSS filters
                #   Sets: X-XSS-Protection: 1; mode=block
                xss_protection=True,

                # CONTENT SECURITY POLICY (CSP)
                #   Why:   controls which resources can load
                #   Sets:  Content-Security-Policy: default-src 'self'
                #   Means: only load scripts/styles/images from the same origin
                csp="default-src 'self'",
                # A more permissive policy, for reference:
                #   csp=(
                #       "default-src 'self'; "
                #       "script-src 'self' 'unsafe-inline' cdn.jsdelivr.net; "
                #       "style-src 'self' 'unsafe-inline' fonts.googleapis.com; "
                #       "img-src 'self' data: https:; "
                #       "font-src 'self' fonts.gstatic.com"
                #   )
            )
        ),
    ]
)

# Together these headers prevent:
#   - XSS attacks (Cross-Site Scripting)
#   - Clickjacking (hidden iframes)
#   - MIME type confusion
#   - Protocol downgrade attacks
#   - Content injection

CORS (Let Your API Talk to Other Websites)

Section titled “CORS (Let Your API Talk to Other Websites)”
from zenith.middleware import CORSMiddleware, CORSConfig
# CORS = Cross-Origin Resource Sharing
#   Problem:  browsers block requests between different domains
#   Solution: CORS headers tell the browser "it's OK to share"
app.add_middleware(
    CORSMiddleware,
    CORSConfig(
        # ALLOWED ORIGINS: who can call your API?
        # Use ["*"] to allow ALL origins (risky!)
        allow_origins=[
            "https://example.com",    # Your production frontend
            "http://localhost:3000",  # Local development
        ],

        # ALLOWED METHODS: which HTTP methods can they use?
        #   GET:    read data
        #   POST:   create data
        #   PUT:    update data (add "PATCH" to this list if your API
        #           accepts partial updates; it is not allowed as written)
        #   DELETE: remove data
        allow_methods=["GET", "POST", "PUT", "DELETE"],

        # ALLOWED HEADERS: which headers can they send?
        # Or be specific: ["Content-Type", "Authorization"]
        allow_headers=["*"],  # Allow all headers

        # CREDENTIALS: can they send cookies/auth?
        #   True:  browser sends cookies and auth headers
        #   False: no cookies (more secure if not needed)
        #   ⚠️ Can't use credentials=True with origins=["*"]
        allow_credentials=True,

        # PREFLIGHT CACHE: how long the browser may cache the OPTIONS response.
        # The browser remembers "this origin is OK" for 24 hours, which
        # reduces preflight OPTIONS requests.
        max_age=86400,  # 24 hours in seconds
    ),
)

# How CORS works:
#   1. Browser:  "Can I call api.example.com from app.example.com?"
#   2. Your API: "Yes, here are the CORS headers"
#   3. Browser:  "OK, making the actual request now"
from zenith.middleware import RateLimitMiddleware, RateLimitConfig
# Rate limiting prevents:
#   - DDoS attacks
#   - API abuse
#   - Server overload
#   - Expensive operation spam
app.add_middleware(
    RateLimitMiddleware,
    RateLimitConfig(
        # DEFAULT LIMITS: applied to all endpoints.
        #   Format:  "count/period"
        #   Periods: second, minute, hour, day
        default_limits=[
            "100/minute",  # Max 100 requests per minute
            "1000/hour",   # Max 1000 requests per hour
        ],

        # KEY FUNCTION: how to identify users?
        #   - request.client.host (IP address - default)
        #   - request.headers.get("X-API-Key") (API key)
        #   - request.user.id (authenticated user)
        # NOTE(review): behind a reverse proxy the client host is presumably
        # the proxy's address, so all users would share one bucket - confirm
        # how Zenith populates request.client before relying on it.
        key_func=lambda request: request.client.host,

        # STORAGE BACKEND: where to track counts?
        #   Redis pros:  works across multiple servers
        #   Memory pros: no dependencies, faster
        storage="redis://localhost:6379",  # Redis (distributed)
        # OR
        # storage="memory"  # In-memory (single server only)

        # RESPONSE HEADERS: tell clients their limits. Adds:
        #   X-RateLimit-Limit: 100
        #   X-RateLimit-Remaining: 45
        #   X-RateLimit-Reset: 1634567890
        headers_enabled=True,
    ),
)
# PER-ENDPOINT LIMITS: the rate_limit argument overrides the limit for one route.
@app.get("/api/expensive", rate_limit="10/minute")
async def expensive_operation():
    """Run the heavy computation; this route is capped at 10 calls/minute."""
    return {"result": await perform_heavy_computation()}
@app.post("/api/free", rate_limit="1000/minute")
async def free_operation():
    """Cheap endpoint, so it is given a higher limit (1000 calls/minute).

    Returns:
        A dict whose ``timestamp`` value is the current time as a
        timezone-aware UTC ``datetime``.
    """
    # Imported locally because nothing earlier in this example imports
    # ``datetime``; without it the handler raises NameError on first call.
    from datetime import datetime, timezone

    # datetime.utcnow() is deprecated and returns a naive value that is easy
    # to mistake for local time; ask for an aware UTC timestamp instead.
    return {"timestamp": datetime.now(timezone.utc)}
# What happens when rate limit exceeded:
# 1. Returns 429 Too Many Requests
# 2. Body: {"error": "Rate limit exceeded"}
# 3. Headers show when they can try again
from zenith.middleware import AuthMiddleware
from zenith.auth import JWTConfig
import os

# The JWT signing secret is read from the environment rather than written
# into source control. Indexing os.environ (instead of os.getenv) makes a
# missing secret fail at startup with KeyError rather than signing tokens
# with a guessable placeholder.
app.add_middleware(AuthMiddleware, {
    "jwt_config": JWTConfig(
        secret_key=os.environ["JWT_SECRET_KEY"],
        algorithm="HS256",
        expire_minutes=30,
    ),
    # Paths listed here are excluded from the auth check.
    "exclude_paths": ["/auth/login", "/auth/register", "/health"],
})
from zenith.middleware import LoggingMiddleware
import logging
# Emit INFO-level records and above through the root logger.
logging.basicConfig(level=logging.INFO)

app.add_middleware(
    LoggingMiddleware,
    {
        # Bodies are left out of the logs (privacy consideration).
        "log_request_body": False,
        "log_response_body": False,
        # Only these request headers are configured for logging.
        "log_headers": ["User-Agent", "X-Request-ID"],
        # Paths left out of the logs.
        "exclude_paths": ["/health", "/metrics"],
    },
)
from zenith.middleware import CompressionMiddleware
app.add_middleware(
    CompressionMiddleware,
    {
        # Only compress responses > 1KB
        "minimum_size": 1024,
        "gzip_level": 6,
        # Brotli quality (0-11)
        "br_quality": 4,
        # These formats are already compressed, so skip them.
        "exclude_types": ["image/jpeg", "image/png"],
    },
)
from zenith.middleware import RequestIDMiddleware
# The "generate" callback below needs the uuid module, which this example
# previously used without importing (NameError on the first generated ID).
import uuid

app.add_middleware(RequestIDMiddleware, {
    "header_name": "X-Request-ID",
    # Called with no arguments; returns a fresh random UUID string.
    "generate": lambda: str(uuid.uuid4()),
    # Don't trust client-provided IDs
    "trust_header": False,
})
from zenith.middleware import CSRFMiddleware
app.add_middleware(
    CSRFMiddleware,
    {
        # Cookie and header names used for the CSRF token.
        "cookie_name": "_csrf_token",
        "header_name": "X-CSRF-Token",
        # Methods configured as safe.
        "safe_methods": ["GET", "HEAD", "OPTIONS"],
        # HTTPS only
        "cookie_secure": True,
        "cookie_samesite": "strict",
    },
)
from zenith import Request, Response
from typing import Callable, Awaitable
import time
class TimingMiddleware:
    """Add a response-time header to every response.

    Middleware structure:
        1. ``__init__`` stores the next middleware/app in the chain.
        2. ``__call__`` processes the request and the response.
    """

    def __init__(self, app: "Callable[[Request], Awaitable[Response]]"):
        """Store the next middleware (or the actual app) in the chain."""
        self.app = app

    async def __call__(self, request: "Request") -> "Response":
        """Time the downstream handling of ``request``.

        Code before ``await self.app(request)`` runs BEFORE the request is
        handled; code after it runs AFTER the response has been generated.

        Returns:
            The downstream response, with ``X-Response-Time`` set to the
            elapsed time in seconds (three decimals, ``s`` suffix).
        """
        # BEFORE REQUEST: start timing. perf_counter() is monotonic, so the
        # measurement cannot go negative or jump when the system clock is
        # adjusted, unlike time.time().
        start = time.perf_counter()
        # Could also log: print(f"Request: {request.method} {request.url.path}")

        # PROCESS REQUEST: runs all remaining middleware and the route handler.
        response = await self.app(request)

        # AFTER RESPONSE: compute the duration and attach it.
        duration = time.perf_counter() - start
        response.headers["X-Response-Time"] = f"{duration:.3f}s"

        # Report slow requests (longer than 1 second).
        if duration > 1.0:
            print(f"⚠️ Slow request: {request.url.path} took {duration:.3f}s")

        return response  # Pass the response back up the chain
# Register the middleware with the application.
app.add_middleware(TimingMiddleware)
# Every response now includes a timing header, for example:
# X-Response-Time: 0.042s
class APIKeyMiddleware:
    """Reject requests that do not carry a known API key.

    Config keys (all optional):
        api_keys: list of accepted key strings (default: empty list).
        header_name: request header holding the key (default "X-API-Key").
        exclude_paths: exact paths that skip the check
            (default "/health" and "/docs").
    """

    def __init__(self, app, config: dict):
        """Store the next app in the chain and read settings from ``config``."""
        self.app = app
        self.api_keys = config.get("api_keys", [])
        self.header_name = config.get("header_name", "X-API-Key")
        # Formerly hard-coded in __call__; same default, now overridable.
        self.exclude_paths = frozenset(
            config.get("exclude_paths", ("/health", "/docs"))
        )

    async def __call__(self, request: "Request") -> "Response":
        """Forward the request when its key is valid, otherwise answer 401."""
        # Skip the check for excluded paths.
        if request.url.path in self.exclude_paths:
            return await self.app(request)

        api_key = request.headers.get(self.header_name)
        if not api_key or not self._is_known_key(api_key):
            # NOTE(review): JSONResponse is not imported anywhere in this
            # example; import it from the framework's response module.
            return JSONResponse(
                {"error": "Invalid API key"},
                status_code=401
            )
        return await self.app(request)

    def _is_known_key(self, candidate: str) -> bool:
        """Return True when ``candidate`` equals one of the configured keys."""
        import hmac

        # compare_digest avoids leaking, through response timing, how much of
        # a key matched. Keys are compared as bytes because compare_digest
        # raises TypeError for non-ASCII str. Every key is checked (no early
        # exit) before the results are combined.
        supplied = candidate.encode("utf-8")
        results = [
            hmac.compare_digest(supplied, known.encode("utf-8"))
            for known in self.api_keys
        ]
        return any(results)
# Register the middleware; the dict is passed to it as its config.
app.add_middleware(
    APIKeyMiddleware,
    {
        "api_keys": ["key1", "key2"],
        "header_name": "X-API-Key",
    },
)
class DatabaseMiddleware:
    """Attach a pooled database connection to each request.

    Config keys:
        database_url: connection URL handed to ``create_pool`` (required).
    """

    def __init__(self, app, config: dict):
        """Store the next app and the database URL; the pool is created later.

        Raises:
            KeyError: if ``config`` has no ``"database_url"`` entry.
        """
        self.app = app
        self.db_url = config["database_url"]
        self.pool = None  # Set by startup()

    async def startup(self):
        """Initialize the connection pool."""
        # NOTE(review): create_pool is not defined or imported in this
        # example; it must come from your database driver.
        self.pool = await create_pool(self.db_url)

    async def shutdown(self):
        """Close the connection pool, if one was created."""
        if self.pool is not None:
            await self.pool.close()
            # Drop the closed pool so a second shutdown() is a no-op and a
            # late request gets the explicit error below.
            self.pool = None

    async def __call__(self, request: "Request") -> "Response":
        """Handle the request with a connection exposed as ``request.state.db``.

        Raises:
            RuntimeError: if the pool has not been created by ``startup()``.
        """
        if self.pool is None:
            # Previously surfaced as an AttributeError on None.
            raise RuntimeError(
                "DatabaseMiddleware pool is not initialized; "
                "call startup() before handling requests"
            )
        # The connection is held for the whole downstream call and released
        # when the block exits.
        async with self.pool.acquire() as conn:
            request.state.db = conn
            return await self.app(request)
# Register the middleware; "database_url" is a required config key.
app.add_middleware(
    DatabaseMiddleware,
    {"database_url": "postgresql://localhost/mydb"},
)
app = Zenith(
    middleware=[
        #                                REQUEST FLOW ↓   RESPONSE FLOW ↑
        SecurityHeadersMiddleware({}),   # 1st →          ← 5th (last)
        CORSMiddleware({}),              # 2nd →          ← 4th
        RateLimitMiddleware({}),         # 3rd →          ← 3rd
        AuthMiddleware({}),              # 4th →          ← 2nd
        LoggingMiddleware({}),           # 5th →          ← 1st (first)
    ]
)

# How it works (like Russian dolls):
#
# REQUEST comes in:
#   1. SecurityHeaders checks protocol
#   2. CORS validates origin
#   3. RateLimit counts request
#   4. Auth verifies user
#   5. Logging records request
#   6. → Your route handler runs ←
#
# RESPONSE goes out (same layers, reverse order):
#   5. Logging records response (first)
#   4. Auth adds user headers
#   3. RateLimit adds limit headers
#   2. CORS adds access headers
#   1. SecurityHeaders adds security headers (last)
# WHY ORDER MATTERS:
# The two assignments below are alternatives shown for comparison; each one
# rebinds ``app``, so keep only one of them in a real application.
#
# Good order:
app = Zenith(
    middleware=[
        SecurityHeadersMiddleware({}),  # Security first
        CORSMiddleware({}),             # Then CORS
        RateLimitMiddleware({}),        # Limit before auth (prevent brute force)
        AuthMiddleware({}),             # Authenticate valid requests
        CompressionMiddleware({}),      # Compress authenticated responses
        LoggingMiddleware({}),          # Log everything
    ]
)

# Bad order:
app = Zenith(
    middleware=[
        LoggingMiddleware({}),          # Logs before rate limit
        AuthMiddleware({}),             # Auth before rate limit (wastes resources)
        RateLimitMiddleware({}),        # Too late, already did work
        SecurityHeadersMiddleware({}),  # Security should be first
        CORSMiddleware({}),             # CORS should be early
    ]
)
from zenith import Zenith
import os
app = Zenith()

# Only add in production
if os.getenv("ENVIRONMENT") == "production":
    app.add_middleware(SecurityHeadersMiddleware, {
        "force_https": True
    })
    app.add_middleware(RateLimitMiddleware, {
        "default_limits": ["100/minute"]
    })

# Always add CORS.
# CORS_ORIGINS is a comma-separated list. Entries are stripped and empty ones
# dropped, so "https://a.example, https://b.example" does not produce an
# origin with a leading space (which would never match), and a trailing comma
# does not produce an empty origin. When the variable is unset, the "*"
# default allows every origin.
app.add_middleware(CORSMiddleware, {
    "allow_origins": [
        origin.strip()
        for origin in os.getenv("CORS_ORIGINS", "*").split(",")
        if origin.strip()
    ]
})

Organize middleware into logical groups:

def setup_security_middleware(app: Zenith):
    """Register the security-related middleware on ``app``.

    Adds, in this order: security headers, CSRF protection, rate limiting.
    """
    registrations = (
        (SecurityHeadersMiddleware, {"force_https": True}),
        (CSRFMiddleware, {"cookie_secure": True}),
        (RateLimitMiddleware, {"default_limits": ["100/minute"]}),
    )
    for middleware_cls, options in registrations:
        app.add_middleware(middleware_cls, options)
def setup_monitoring_middleware(app: Zenith):
    """Register the monitoring and observability middleware on ``app``.

    Adds, in this order: request IDs, logging, metrics.
    """
    logging_options = {"exclude_paths": ["/health"]}

    app.add_middleware(RequestIDMiddleware)
    app.add_middleware(LoggingMiddleware, logging_options)
    # NOTE(review): MetricsMiddleware is not imported anywhere in this
    # example - confirm where it should come from.
    app.add_middleware(MetricsMiddleware)
# Apply the groups: the security group is registered before the monitoring group.
app = Zenith()
setup_security_middleware(app)
setup_monitoring_middleware(app)
class FastMiddleware:
    """Minimal-overhead middleware.

    Config keys (optional):
        skip_paths: iterable of exact paths that are forwarded untouched.
    """

    # Memory optimization: no per-instance __dict__.
    __slots__ = ['app', 'config', '_skip_paths']

    def __init__(self, app, config: dict):
        """Store the next app and config, and precompute the skip set."""
        self.app = app
        self.config = config
        # Built once here so should_skip() is a real O(1) hash lookup even
        # when skip_paths is configured as a list, and so no new set is
        # allocated on every request.
        self._skip_paths = frozenset(config.get('skip_paths', ()))

    async def __call__(self, request: "Request") -> "Response":
        """Mark non-skipped requests as processed, then call the next app."""
        if not self.should_skip(request):
            # Fast operation: a single attribute write.
            request.state.processed = True
        return await self.app(request)

    def should_skip(self, request: "Request") -> bool:
        """Return True when the request path is one of the skip paths."""
        return request.url.path in self._skip_paths
from zenith.middleware import CacheMiddleware
app.add_middleware(
    CacheMiddleware,
    {
        "backend": "redis://localhost:6379",
        # 5 minutes
        "default_ttl": 300,
        "key_prefix": "zenith:cache:",
        # Methods and status codes configured for caching.
        "methods": ["GET", "HEAD"],
        "status_codes": [200, 301, 308],
    },
)
# Per-endpoint caching: cache_ttl is set on the route itself.
@app.get("/api/data", cache_ttl=3600)  # 1 hour
async def get_data():
    """Return whatever ``expensive_computation()`` produces."""
    payload = expensive_computation()
    return payload
from zenith.testing import TestClient
import pytest
@pytest.mark.asyncio
async def test_rate_limit():
    """With a 5/minute limit, requests 1-5 succeed and request 6 gets 429."""
    app = Zenith()
    app.add_middleware(
        RateLimitMiddleware,
        {"default_limits": ["5/minute"]},
    )

    @app.get("/test")
    async def test_endpoint():
        return {"ok": True}

    async with TestClient(app) as client:
        # The first five calls fit inside the limit.
        for _attempt in range(1, 6):
            allowed = await client.get("/test")
            assert allowed.status_code == 200

        # The sixth call exceeds the limit and is rejected.
        response = await client.get("/test")
        assert response.status_code == 429
        assert "X-RateLimit-Remaining" in response.headers