Middleware
What is Middleware?
Section titled “What is Middleware?”Middleware are functions that process requests before they reach your route handlers and responses before they’re sent to clients. Zenith provides a comprehensive set of production-ready middleware.
Built-in Middleware (Production-Ready Security)
Section titled “Built-in Middleware (Production-Ready Security)”Security Headers (Your First Line of Defense)
Section titled “Security Headers (Your First Line of Defense)”from zenith import Zenithfrom zenith.middleware import SecurityHeadersMiddleware, SecurityConfig
# Configure security headers to protect against common attacksapp = Zenith( middleware=[ SecurityHeadersMiddleware(SecurityConfig( # FORCE HTTPS - Redirect all HTTP to HTTPS force_https=True, # Why: Prevents man-in-the-middle attacks # Sets: Redirect 301 from http:// to https://
# HSTS - HTTP Strict Transport Security hsts_max_age=31536000, # 1 year in seconds # Why: Tells browsers "NEVER use HTTP for this site" # Sets: Strict-Transport-Security: max-age=31536000
# CONTENT TYPE SNIFFING PROTECTION content_type_nosniff=True, # Why: Prevents browsers from guessing content types # Sets: X-Content-Type-Options: nosniff # Blocks: <script> tags served as text/plain
# CLICKJACKING PROTECTION frame_deny=True, # Why: Prevents your site from being embedded in hidden iframes # Sets: X-Frame-Options: DENY # Use "SAMEORIGIN" to allow your own iframes
# XSS PROTECTION (for older browsers) xss_protection=True, # Why: Enables browser XSS filters # Sets: X-XSS-Protection: 1; mode=block
# CONTENT SECURITY POLICY (CSP) csp="default-src 'self'" # Why: Controls what resources can load # Sets: Content-Security-Policy: default-src 'self' # Means: Only load scripts/styles/images from same origin # Advanced example: # csp=( # "default-src 'self'; " # "script-src 'self' 'unsafe-inline' cdn.jsdelivr.net; " # "style-src 'self' 'unsafe-inline' fonts.googleapis.com; " # "img-src 'self' data: https:; " # "font-src 'self' fonts.gstatic.com" # ) )) ])
# These headers prevent:# XSS attacks (Cross-Site Scripting)# Clickjacking (hidden iframes)# MIME type confusion# Protocol downgrade attacks# Content injectionCORS (Let Your API Talk to Other Websites)
Section titled “CORS (Let Your API Talk to Other Websites)”from zenith.middleware import CORSMiddleware, CORSConfig
# CORS = Cross-Origin Resource Sharing# Problem: Browsers block requests between different domains# Solution: CORS headers tell browser "it's OK to share"
app.add_middleware(CORSMiddleware, CORSConfig( # ALLOWED ORIGINS - Who can call your API? allow_origins=[ "https://example.com", # Your production frontend "http://localhost:3000" # Local development ], # Use ["*"] to allow ALL origins (risky!)
# ALLOWED METHODS - What HTTP methods can they use? allow_methods=["GET", "POST", "PUT", "DELETE"], # GET: Read data # POST: Create data # PUT: Update data (add "PATCH" to the list to allow partial updates) # DELETE: Remove data
# ALLOWED HEADERS - What headers can they send? allow_headers=["*"], # Allow all headers # Or be specific: ["Content-Type", "Authorization"]
# CREDENTIALS - Can they send cookies/auth? allow_credentials=True, # True: Browser sends cookies and auth headers # False: No cookies (more secure if not needed) # ⚠️ Can't combine allow_credentials=True with allow_origins=["*"]
# PREFLIGHT CACHE - How long to cache OPTIONS response max_age=86400 # 24 hours in seconds # Browser remembers "this origin is OK" for 24 hours # Reduces preflight OPTIONS requests))
# How CORS works:# 1. Browser: "Can I call api.example.com from app.example.com?"# 2. Your API: "Yes, here are the CORS headers"# 3. Browser: "OK, making the actual request now"Rate Limiting (Prevent Abuse & DDoS)
Section titled “Rate Limiting (Prevent Abuse & DDoS)”from zenith.middleware import RateLimitMiddleware, RateLimitConfig
# Rate limiting prevents:# - DDoS attacks# - API abuse# - Server overload# - Expensive operation spam
app.add_middleware(RateLimitMiddleware, RateLimitConfig( # DEFAULT LIMITS - Apply to all endpoints default_limits=[ "100/minute", # Max 100 requests per minute "1000/hour" # Max 1000 requests per hour ], # Format: "count/period" # Periods: second, minute, hour, day
# KEY FUNCTION - How to identify users? key_func=lambda request: request.client.host, # Options: # - request.client.host (IP address - default) # - request.headers.get("X-API-Key") (API key) # - request.user.id (authenticated user)
# STORAGE BACKEND - Where to track counts? storage="redis://localhost:6379", # Redis (distributed) # OR # storage="memory" # In-memory (single server only)
# Redis pros: Works across multiple servers # Memory pros: No dependencies, faster
# RESPONSE HEADERS - Tell clients their limits headers_enabled=True # Adds headers: # X-RateLimit-Limit: 100 # X-RateLimit-Remaining: 45 # X-RateLimit-Reset: 1634567890))
# PER-ENDPOINT LIMITS - Override for specific routes@app.get("/api/expensive", rate_limit="10/minute")async def expensive_operation(): """This endpoint is expensive, limit to 10 calls/minute.""" result = await perform_heavy_computation() return {"result": result}
@app.post("/api/free", rate_limit="1000/minute")
async def free_operation():
    """This is cheap, allow more calls.

    Returns:
        A dict with a single "timestamp" key holding the current time as a
        timezone-aware UTC datetime.
    """
    # Imported locally so the example is self-contained.
    from datetime import datetime, timezone

    # datetime.utcnow() is deprecated since Python 3.12 and yields a naive
    # datetime; an explicit UTC tzinfo keeps the timestamp unambiguous.
    return {"timestamp": datetime.now(timezone.utc)}
# What happens when rate limit exceeded:# 1. Returns 429 Too Many Requests# 2. Body: {"error": "Rate limit exceeded"}# 3. Headers show when they can try againAuthentication
Section titled “Authentication”from zenith.middleware import AuthMiddlewarefrom zenith.auth import JWTConfig
app.add_middleware(AuthMiddleware, { "jwt_config": JWTConfig( secret_key="your-secret-key", algorithm="HS256", expire_minutes=30 ), "exclude_paths": ["/auth/login", "/auth/register", "/health"]})Request Logging
Section titled “Request Logging”from zenith.middleware import LoggingMiddlewareimport logging
logging.basicConfig(level=logging.INFO)
app.add_middleware(LoggingMiddleware, { "log_request_body": False, # Privacy consideration "log_response_body": False, "log_headers": ["User-Agent", "X-Request-ID"], "exclude_paths": ["/health", "/metrics"]})Compression
Section titled “Compression”from zenith.middleware import CompressionMiddleware
app.add_middleware(CompressionMiddleware, { "minimum_size": 1024, # Only compress responses > 1KB "gzip_level": 6, "br_quality": 4, # Brotli quality (0-11) "exclude_types": ["image/jpeg", "image/png"] # Already compressed})Request ID
Section titled “Request ID”from zenith.middleware import RequestIDMiddleware
app.add_middleware(RequestIDMiddleware, { "header_name": "X-Request-ID", "generate": lambda: str(uuid.uuid4()), "trust_header": False # Don't trust client-provided IDs})CSRF Protection
Section titled “CSRF Protection”from zenith.middleware import CSRFMiddleware
app.add_middleware(CSRFMiddleware, { "cookie_name": "_csrf_token", "header_name": "X-CSRF-Token", "safe_methods": ["GET", "HEAD", "OPTIONS"], "cookie_secure": True, # HTTPS only "cookie_samesite": "strict"})Custom Middleware
Section titled “Custom Middleware”Basic Middleware (Measure Response Time)
Section titled “Basic Middleware (Measure Response Time)”from zenith import Request, Responsefrom typing import Callable, Awaitableimport time
class TimingMiddleware:
    """Add response time header to all responses.

    Middleware structure:
    1. __init__: Store the next middleware/app
    2. __call__: Process request and response
    """

    def __init__(self, app: Callable[[Request], Awaitable[Response]]):
        """Initialize with the next middleware in chain.

        Args:
            app: The next middleware or the actual app; awaited with the
                request and expected to return the response.
        """
        self.app = app  # Next middleware or actual app

    async def __call__(self, request: Request) -> Response:
        """Process the request/response.

        This method:
        1. Runs BEFORE the request is handled
        2. Calls the next middleware/handler
        3. Runs AFTER the response is generated

        Args:
            request: The incoming request.

        Returns:
            The response from the rest of the chain, with an
            ``X-Response-Time`` header added.
        """
        # BEFORE REQUEST: Start timing.
        # perf_counter() is monotonic, so the measured duration cannot go
        # negative or jump when the system clock is adjusted (time.time() can).
        start = time.perf_counter()
        # Could also log: print(f"Request: {request.method} {request.url.path}")

        # PROCESS REQUEST: Call next middleware or route handler.
        # This goes through all remaining middleware and the actual route.
        response = await self.app(request)

        # AFTER RESPONSE: Calculate and add timing
        duration = time.perf_counter() - start
        response.headers["X-Response-Time"] = f"{duration:.3f}s"

        # Could also log slow requests:
        if duration > 1.0:  # Longer than 1 second
            print(f"⚠️ Slow request: {request.url.path} took {duration:.3f}s")

        return response  # Pass response back up the chain
# Register the middlewareapp.add_middleware(TimingMiddleware)
# Now every response includes:# X-Response-Time: 0.042sMiddleware with Configuration
Section titled “Middleware with Configuration”class APIKeyMiddleware: """Validate API key."""
def __init__(self, app, config: dict): self.app = app self.api_keys = config.get("api_keys", []) self.header_name = config.get("header_name", "X-API-Key")
async def __call__(self, request: Request) -> Response:
    """Let the request through only if it carries a configured API key.

    Requests to "/health" and "/docs" bypass the check. Any other request
    must send one of ``self.api_keys`` in the ``self.header_name`` header;
    otherwise a 401 JSON response is returned and the rest of the chain is
    not called.
    """
    # Imported locally so the example is self-contained.
    import secrets

    # Skip for excluded paths
    if request.url.path in ["/health", "/docs"]:
        return await self.app(request)

    # Check API key
    api_key = request.headers.get(self.header_name)
    authorized = False
    if api_key:
        candidate = api_key.encode()
        # The header value is untrusted: compare in constant time, and
        # compare against every key (no short-circuit) so response timing
        # does not reveal how much of a key matched or which key matched.
        results = [
            secrets.compare_digest(candidate, known.encode())
            for known in self.api_keys
        ]
        authorized = any(results)

    if not authorized:
        return JSONResponse(
            {"error": "Invalid API key"},
            status_code=401
        )

    return await self.app(request)
app.add_middleware(APIKeyMiddleware, { "api_keys": ["key1", "key2"], "header_name": "X-API-Key"})Async Middleware
Section titled “Async Middleware”class DatabaseMiddleware: """Provide database connection."""
def __init__(self, app, config: dict): self.app = app self.db_url = config["database_url"] self.pool = None
async def startup(self): """Initialize connection pool.""" self.pool = await create_pool(self.db_url)
async def shutdown(self): """Close connection pool.""" if self.pool: await self.pool.close()
async def __call__(self, request: Request) -> Response: async with self.pool.acquire() as conn: request.state.db = conn return await self.app(request)
app.add_middleware(DatabaseMiddleware, { "database_url": "postgresql://localhost/mydb"})Middleware Order (First In, Last Out)
Section titled “Middleware Order (First In, Last Out)”app = Zenith( middleware=[ # REQUEST FLOW ↓ RESPONSE FLOW ↑ SecurityHeadersMiddleware({}), # 1st → ← 5th (last) CORSMiddleware({}), # 2nd → ← 4th RateLimitMiddleware({}), # 3rd → ← 3rd AuthMiddleware({}), # 4th → ← 2nd LoggingMiddleware({}) # 5th → ← 1st (first) ])
# How it works (like Russian dolls):## REQUEST comes in:# 1. SecurityHeaders checks protocol# 2. CORS validates origin# 3. RateLimit counts request# 4. Auth verifies user# 5. Logging records request# 6. → Your route handler runs ←## RESPONSE goes out:# 5. Logging records response (first)# 4. Auth adds user headers# 3. RateLimit adds limit headers# 2. CORS adds access headers# 1. SecurityHeaders adds security headers (last)
# WHY ORDER MATTERS:## Good order:app = Zenith( middleware=[ SecurityHeadersMiddleware({}), # Security first CORSMiddleware({}), # Then CORS RateLimitMiddleware({}), # Limit before auth (prevent brute force) AuthMiddleware({}), # Authenticate valid requests CompressionMiddleware({}), # Compress authenticated responses LoggingMiddleware({}) # Log everything ])
# Bad order:app = Zenith( middleware=[ LoggingMiddleware({}), # Logs before rate limit AuthMiddleware({}), # Auth before rate limit (wastes resources) RateLimitMiddleware({}), # Too late, already did work SecurityHeadersMiddleware({}), # Security should be first CORSMiddleware({}), # CORS should be early ])Conditional Middleware
Section titled “Conditional Middleware”from zenith import Zenithimport os
app = Zenith()
# Only add in productionif os.getenv("ENVIRONMENT") == "production": app.add_middleware(SecurityHeadersMiddleware, { "force_https": True }) app.add_middleware(RateLimitMiddleware, { "default_limits": ["100/minute"] })
# Always add CORSapp.add_middleware(CORSMiddleware, { "allow_origins": os.getenv("CORS_ORIGINS", "*").split(",")})Middleware Groups
Section titled “Middleware Groups”Organize middleware into logical groups:
def setup_security_middleware(app: Zenith): """Security-related middleware.""" app.add_middleware(SecurityHeadersMiddleware, { "force_https": True }) app.add_middleware(CSRFMiddleware, { "cookie_secure": True }) app.add_middleware(RateLimitMiddleware, { "default_limits": ["100/minute"] })
def setup_monitoring_middleware(app: Zenith): """Monitoring and observability.""" app.add_middleware(RequestIDMiddleware) app.add_middleware(LoggingMiddleware, { "exclude_paths": ["/health"] }) app.add_middleware(MetricsMiddleware)
# Apply groupsapp = Zenith()setup_security_middleware(app)setup_monitoring_middleware(app)Performance Considerations
Section titled “Performance Considerations”Lightweight Middleware
Section titled “Lightweight Middleware”class FastMiddleware: """Minimal overhead middleware."""
__slots__ = ['app', 'config'] # Memory optimization
def __init__(self, app, config: dict): self.app = app self.config = config
async def __call__(self, request: Request) -> Response: # Quick check, minimal processing if self.should_skip(request): return await self.app(request)
# Fast operation request.state.processed = True return await self.app(request)
def should_skip(self, request: Request) -> bool: # O(1) lookup return request.url.path in self.config.get('skip_paths', set())Caching Middleware
Section titled “Caching Middleware”from zenith.middleware import CacheMiddleware
app.add_middleware(CacheMiddleware, { "backend": "redis://localhost:6379", "default_ttl": 300, # 5 minutes "key_prefix": "zenith:cache:", "methods": ["GET", "HEAD"], "status_codes": [200, 301, 308]})
# Per-endpoint caching@app.get("/api/data", cache_ttl=3600) # 1 hourasync def get_data(): return expensive_computation()Testing Middleware
Section titled “Testing Middleware”from zenith.testing import TestClientimport pytest
@pytest.mark.asyncioasync def test_rate_limit(): app = Zenith() app.add_middleware(RateLimitMiddleware, { "default_limits": ["5/minute"] })
@app.get("/test") async def test_endpoint(): return {"ok": True}
async with TestClient(app) as client: # Should succeed for first 5 requests for _ in range(5): response = await client.get("/test") assert response.status_code == 200
# Should fail on 6th request response = await client.get("/test") assert response.status_code == 429 assert "X-RateLimit-Remaining" in response.headersNext Steps
Section titled “Next Steps”- Implement Authentication with middleware
- Learn about Database middleware
- Explore Performance optimization