Middleware API
Middleware Overview
Section titled “Middleware Overview”Middleware in Zenith processes requests and responses, providing cross-cutting functionality like authentication, logging, and security headers.
Built-in Middleware
Section titled “Built-in Middleware”SecurityHeadersMiddleware
Section titled “SecurityHeadersMiddleware”Adds security headers to protect against common attacks:
from zenith.middleware import SecurityHeadersMiddleware, SecurityConfig
# Recommended: Using SecurityConfig classconfig = SecurityConfig( # HTTPS enforcement force_https=True,
# HSTS (HTTP Strict Transport Security) hsts_max_age=31536000, # 1 year hsts_include_subdomains=True, hsts_preload=True,
# Content security content_type_nosniff=True, frame_options="DENY", xss_protection="1; mode=block",
# Content Security Policy csp_policy="default-src 'self'; script-src 'self' 'unsafe-inline'", csp_report_only=False,
# Referrer policy referrer_policy="strict-origin-when-cross-origin",
# Permissions policy (formerly Feature Policy) permissions_policy="geolocation=(), microphone=()")
app.add_middleware(SecurityHeadersMiddleware, config=config)SecurityConfig Parameters
Section titled “SecurityConfig Parameters”| Parameter | Type | Default | Description |
|---|---|---|---|
force_https | bool | False | Redirect HTTP to HTTPS |
force_https_permanent | bool | False | Use 301 redirect instead of 302 |
hsts_max_age | int | 31536000 | HSTS max age in seconds |
hsts_include_subdomains | bool | True | Include subdomains in HSTS |
hsts_preload | bool | False | Enable HSTS preload |
content_type_nosniff | bool | True | Add X-Content-Type-Options: nosniff |
frame_options | str | "DENY" | X-Frame-Options header value |
xss_protection | str | "1; mode=block" | X-XSS-Protection header value |
csp_policy | str | None | Content Security Policy |
csp_report_only | bool | False | Use CSP report-only mode |
referrer_policy | str | "strict-origin-when-cross-origin" | Referrer policy |
permissions_policy | str | None | Permissions policy header |
trusted_proxies | list[str] | [] | List of trusted proxy IPs |
CORSMiddleware
Section titled “CORSMiddleware”Handles Cross-Origin Resource Sharing:
from zenith.middleware import CORSMiddleware, CORSConfig
# Recommended: Using CORSConfig classconfig = CORSConfig( allow_origins=["https://example.com", "https://app.example.com"], allow_origin_regex=r"https://.*\.example\.com", allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["Authorization", "Content-Type"], expose_headers=["X-Request-ID", "X-Process-Time"], allow_credentials=True, max_age_secs=86400 # 24 hours)
app.add_middleware(CORSMiddleware, config=config)
# Alternative: Direct parameters (backward compatibility)app.add_middleware( CORSMiddleware, allow_origins=["https://example.com"], allow_methods=["GET", "POST"], allow_credentials=True)CORSConfig Parameters
Section titled “CORSConfig Parameters”| Parameter | Type | Default | Description |
|---|---|---|---|
allow_origins | list[str] | [] | Allowed origin URLs |
allow_origin_regex | str | None | Regex pattern for allowed origins |
allow_methods | list[str] | ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"] | Allowed HTTP methods |
allow_headers | list[str] | ["*"] | Allowed request headers |
allow_credentials | bool | False | Allow credentials in CORS requests |
expose_headers | list[str] | [] | Headers exposed to browser |
max_age_secs | int | 600 | Preflight cache duration in seconds |
RateLimitMiddleware
Section titled “RateLimitMiddleware”Implements rate limiting with multiple strategies:
from zenith.middleware import RateLimitMiddleware, RateLimitConfig, RateLimitfrom zenith.middleware import MemoryRateLimitStorage, RedisRateLimitStorage
# Basic configurationconfig = RateLimitConfig( default_limits=[ RateLimit(requests=100, window=60), # 100/minute RateLimit(requests=1000, window=3600), # 1000/hour ], storage=MemoryRateLimitStorage(), exempt_paths=["/health", "/metrics"], error_message="Rate limit exceeded", include_headers=True)
app.add_middleware(RateLimitMiddleware, config=config)
# Redis storage for distributed systemsredis_storage = RedisRateLimitStorage("redis://localhost:6379/0")redis_config = RateLimitConfig( default_limits=[RateLimit(requests=1000, window=60)], storage=redis_storage)
app.add_middleware(RateLimitMiddleware, config=redis_config)Custom Rate Limit Key Functions
Section titled “Custom Rate Limit Key Functions”# By user ID (requires authentication)def user_rate_limit_key(request): user = getattr(request.state, 'user', None) return f"user:{user.id}" if user else request.client.host
# By API keydef api_key_rate_limit_key(request): api_key = request.headers.get('X-API-Key') return f"api_key:{api_key}" if api_key else request.client.host
# By user tierdef tiered_rate_limit_key(request): user = getattr(request.state, 'user', None) if user: return f"user:{user.id}:tier:{user.tier}" return request.client.host
app.add_middleware(RateLimitMiddleware, { "key_func": tiered_rate_limit_key, "default_limits": ["100/minute"], "tier_limits": { "premium": ["1000/minute"], "enterprise": ["10000/minute"] }})AuthMiddleware
Section titled “AuthMiddleware”Handles authentication for protected routes:
from zenith.middleware import AuthMiddlewarefrom zenith.auth import JWTConfig
app.add_middleware(AuthMiddleware, { # JWT configuration "jwt_config": JWTConfig( secret_key="your-secret-key", algorithm="HS256", access_token_expire_minutes=30 ),
# Token extraction "token_url": "/auth/token", "header_name": "Authorization", "header_scheme": "Bearer", "cookie_name": "access_token",
# Path exclusions "exclude_paths": [ "/", "/health", "/auth/login", "/auth/register", "/docs", "/openapi.json" ],
# Optional authentication "optional_auth_paths": ["/api/public"],
# Auto error handling "auto_error": True,
# Custom user loader "user_loader": async_get_user_by_id})LoggingMiddleware
Section titled “LoggingMiddleware”Structured request/response logging:
from zenith.middleware import LoggingMiddlewareimport logging
logging.basicConfig(level=logging.INFO)
app.add_middleware(LoggingMiddleware, { # Basic settings "logger_name": "zenith.requests", "log_level": logging.INFO,
# Request logging "log_request_body": False, # Privacy consideration "log_request_headers": ["User-Agent", "X-Request-ID"], "max_body_size": 1024, # Truncate large bodies
# Response logging "log_response_body": False, "log_response_headers": ["Content-Type", "X-Process-Time"],
# Filtering "exclude_paths": ["/health", "/metrics"], "exclude_methods": ["OPTIONS"],
# Format customization "format_string": "{method} {url} - {status_code} {process_time}ms",
# Sensitive data masking "sensitive_headers": ["Authorization", "X-API-Key"], "mask_value": "[REDACTED]"})CompressionMiddleware
Section titled “CompressionMiddleware”Gzip and Brotli compression:
from zenith.middleware import CompressionMiddleware
app.add_middleware(CompressionMiddleware, { # Compression settings "minimum_size": 1024, # Only compress responses > 1KB
# Gzip settings "gzip_enabled": True, "gzip_level": 6, # 1-9, higher = more compression
# Brotli settings (if brotli package installed) "brotli_enabled": True, "brotli_quality": 4, # 0-11 "brotli_mode": 0, # 0=generic, 1=text, 2=font
# Content type filters "include_media_types": [ "application/json", "application/javascript", "text/css", "text/html", "text/plain", "text/xml" ],
"exclude_media_types": [ "image/jpeg", "image/png", "image/gif", "application/zip", "video/mp4" ]})RequestIDMiddleware
Section titled “RequestIDMiddleware”Adds unique request identifiers:
from zenith.middleware import RequestIDMiddlewareimport uuid
app.add_middleware(RequestIDMiddleware, { # ID generation "generate_id": lambda: str(uuid.uuid4()),
# Headers "request_id_header": "X-Request-ID", "response_id_header": "X-Request-ID",
# Trust client-provided IDs "trust_incoming_id": False,
# Validation "validate_id": lambda id: len(id) > 0 and len(id) < 100})Specialized Middleware
Section titled “Specialized Middleware”These middleware serve specific use cases and can be added when your application needs their functionality.
Built-in Optimizations
Section titled “Built-in Optimizations”Zenith includes several performance optimizations that work automatically without any configuration:
Database Optimization:
- Automatic request-scoped session reuse
- Connection pooling with smart defaults
- 15-25% performance improvement
Server-Sent Events:
- Built-in backpressure handling
- 10x more concurrent connections
- Memory-efficient streaming
Database Optimization (Built-in)
Section titled “Database Optimization (Built-in)”Zenith automatically optimizes database connections with request-scoped session reuse:
# No configuration needed - automatic optimizationfrom zenith import Database
db = Database(url="postgresql+asyncpg://user:pass@localhost/db")Key Features:
- Automatic request-scoped session reuse
- Connection pooling with smart defaults
- 15-25% performance improvement
- Zero configuration required
The built-in Server-Sent Events support prevents memory issues in SSE connections with intelligent backpressure:
Server-Sent Events (Built-in)
Section titled “Server-Sent Events (Built-in)”Zenith provides built-in SSE support with automatic backpressure handling:
from zenith import create_sse_response
@app.get("/events")async def stream_events(): async def event_generator(): for i in range(100): yield {"type": "update", "data": {"count": i}} await asyncio.sleep(1)
return create_sse_response(event_generator())Key Features:
- Automatic backpressure handling
- Handles 10x more concurrent connections
- Memory-efficient streaming
- Clean async generator API
CSRFMiddleware
Section titled “CSRFMiddleware”Cross-Site Request Forgery protection:
from zenith.middleware import CSRFMiddleware
app.add_middleware(CSRFMiddleware, { # Token settings "secret_key": "your-csrf-secret", "token_length": 32,
# Cookie settings "cookie_name": "csrftoken", "cookie_domain": None, "cookie_path": "/", "cookie_secure": True, "cookie_httponly": True, "cookie_samesite": "strict", "cookie_max_age": 3600,
# Header settings "header_name": "X-CSRFToken",
# Method exemptions "safe_methods": ["GET", "HEAD", "OPTIONS", "TRACE"],
# Path exemptions "exempt_paths": ["/api/webhook"],
# Error handling "failure_status_code": 403, "failure_message": "CSRF token missing or invalid"})Creating Custom Middleware
Section titled “Creating Custom Middleware”Basic Custom Middleware
Section titled “Basic Custom Middleware”from zenith import Request, Responsefrom typing import Callable, Awaitable
class TimingMiddleware:
    """Add an ``X-Process-Time`` header (elapsed seconds) to every response.

    Args:
        app: The next handler in the chain; awaited with the request and
            expected to return a response exposing a mutable ``headers``
            mapping.
    """

    def __init__(self, app: Callable[[Request], Awaitable[Response]]):
        self.app = app

    async def __call__(self, request: Request) -> Response:
        """Time the downstream handler and record the duration on the response.

        Args:
            request: The incoming request, passed through unchanged.

        Returns:
            The downstream response with ``X-Process-Time`` set.
        """
        import time

        # perf_counter() is monotonic and high-resolution; time.time() can
        # jump (NTP adjustments) and yield negative or skewed durations.
        started = time.perf_counter()

        # Process request
        response = await self.app(request)

        # Add timing header. Microsecond precision keeps fast handlers from
        # being reported as exactly 0.000 seconds.
        elapsed = time.perf_counter() - started
        response.headers["X-Process-Time"] = f"{elapsed:.6f}"

        return response
# Add to applicationapp.add_middleware(TimingMiddleware)Configurable Custom Middleware
Section titled “Configurable Custom Middleware”class CustomLoggingMiddleware: """Custom logging middleware with configuration."""
def __init__(self, app, config: dict): self.app = app self.log_requests = config.get('log_requests', True) self.log_responses = config.get('log_responses', True) self.exclude_paths = set(config.get('exclude_paths', [])) self.logger = logging.getLogger(config.get('logger_name', __name__))
async def __call__(self, request: Request) -> Response: # Skip excluded paths if request.url.path in self.exclude_paths: return await self.app(request)
# Log request if self.log_requests: self.logger.info(f"Request: {request.method} {request.url}")
# Process request response = await self.app(request)
# Log response if self.log_responses: self.logger.info(f"Response: {response.status_code}")
return response
# Add with configurationapp.add_middleware(CustomLoggingMiddleware, { 'log_requests': True, 'log_responses': False, 'exclude_paths': ['/health', '/metrics'], 'logger_name': 'my_app.requests'})Async Middleware with External Services
Section titled “Async Middleware with External Services”import aiohttpimport asynciofrom typing import Dict, Any
class GeoLocationMiddleware:
    """Add geolocation data based on client IP.

    When enabled and an API key is configured, looks up the client's
    address and stores the result on ``request.state.geo_location`` before
    calling the next handler. Lookups are best-effort: any failure yields
    an empty dict and never blocks the request.

    Config keys (all optional):
        api_key: Key for the geolocation API; lookups are skipped without it.
        cache_ttl: Seconds a cached lookup stays valid (default 3600).
        enabled: Set to False to bypass the middleware (default True).
    """

    GEO_API_URL = "https://api.ipgeolocation.io/ipgeo"
    REQUEST_TIMEOUT_SECS = 2

    def __init__(self, app, config: dict):
        self.app = app
        self.api_key = config.get('api_key')
        # Simple in-memory cache: ip -> (data, monotonic timestamp).
        # NOTE: entries are only replaced on refresh, never evicted, so the
        # cache grows with the number of distinct client IPs.
        self.cache = {}
        self.cache_ttl = config.get('cache_ttl', 3600)
        self.enabled = config.get('enabled', True)

    async def __call__(self, request: Request) -> Response:
        """Annotate the request with geolocation data, then delegate."""
        if self.enabled and self.api_key:
            # NOTE(review): guards against a request with no client info;
            # confirm whether Zenith can actually deliver ``client`` as None.
            client = request.client
            client_ip = client.host if client is not None else None

            # Skip private/local/unparseable addresses
            if client_ip and self._is_public_ip(client_ip):
                request.state.geo_location = await self._get_geolocation(client_ip)

        return await self.app(request)

    async def _get_geolocation(self, ip: str) -> Dict[str, Any]:
        """Return geolocation data for ``ip``, or ``{}`` if unavailable."""
        # Imported locally so the class does not rely on a module-level
        # ``import time`` being present.
        import time

        # Check cache first. A monotonic clock keeps the TTL immune to
        # wall-clock adjustments.
        now = time.monotonic()
        cached = self.cache.get(ip)
        if cached is not None:
            cached_data, stamp = cached
            if now - stamp < self.cache_ttl:
                return cached_data

        # Fetch from API. ``params`` URL-encodes the key and address
        # (IPv6 literals contain ':'), unlike manual string interpolation.
        try:
            timeout = aiohttp.ClientTimeout(total=self.REQUEST_TIMEOUT_SECS)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                query = {"apiKey": self.api_key, "ip": ip}
                async with session.get(self.GEO_API_URL, params=query) as response:
                    if response.status == 200:
                        data = await response.json()
                        # Cache the result
                        self.cache[ip] = (data, time.monotonic())
                        return data
        except (aiohttp.ClientError, asyncio.TimeoutError, ValueError):
            # Deliberate best-effort: geolocation is optional enrichment, so
            # network errors, timeouts and malformed JSON fall through to {}.
            pass

        return {}

    def _is_public_ip(self, ip: str) -> bool:
        """Return True only for globally routable IPv4/IPv6 addresses.

        Private ranges (10/8, 172.16/12, 192.168/16), loopback, link-local
        and other reserved ranges return False, as does any string that is
        not a valid IP address (for example a test-client hostname).
        """
        import ipaddress

        try:
            return ipaddress.ip_address(ip).is_global
        except ValueError:
            return False
# Usageapp.add_middleware(GeoLocationMiddleware, { 'api_key': 'your-api-key', 'cache_ttl': 3600, 'enabled': True})Error Handling Middleware
Section titled “Error Handling Middleware”import tracebackfrom zenith import HTTPException, JSONResponse
class ErrorHandlingMiddleware:
    """Global error handling and formatting.

    Converts exceptions raised by downstream handlers into JSON error
    responses:

    * ``HTTPException`` -> its own status code and detail.
    * ``ValueError`` -> 400 with the exception message.
    * anything else -> 500 with a generated ``error_id`` that is also
      written to the ``error_handler`` logger.

    Args:
        app: The next handler in the chain.
        config: Optional settings. ``debug`` (default False) adds the
            traceback to 500 responses. May be omitted, so the middleware
            can be registered as ``app.add_middleware(ErrorHandlingMiddleware)``.
    """

    def __init__(self, app, config: "dict | None" = None):
        # Accept a missing config so registration without options works.
        config = config or {}
        self.app = app
        self.debug = config.get('debug', False)
        self.logger = logging.getLogger('error_handler')

    async def __call__(self, request: Request) -> Response:
        """Run the downstream handler, translating exceptions to JSON."""
        try:
            return await self.app(request)

        except HTTPException as exc:
            # Handle HTTP exceptions
            return JSONResponse(
                status_code=exc.status_code,
                content={
                    "error": {
                        "type": "http_exception",
                        "message": exc.detail,
                        "status_code": exc.status_code,
                    }
                },
            )

        except ValueError as exc:
            # Handle validation errors
            return JSONResponse(
                status_code=400,
                content={
                    "error": {
                        "type": "validation_error",
                        "message": str(exc),
                    }
                },
            )

        except Exception as exc:
            # Handle unexpected errors. The id lets a client-reported
            # failure be matched to the corresponding log entry.
            error_id = str(uuid.uuid4())

            # Log the error with its traceback (ERROR level, exc_info set).
            self.logger.exception("Unexpected error [%s]: %s", error_id, exc)

            # Return formatted error response
            error_response = {
                "error": {
                    "type": "internal_error",
                    "message": "An unexpected error occurred",
                    "error_id": error_id,
                }
            }

            # Include traceback in debug mode only; it exposes internals.
            if self.debug:
                error_response["error"]["traceback"] = traceback.format_exc()

            return JSONResponse(
                status_code=500,
                content=error_response,
            )
# Add to applicationapp.add_middleware(ErrorHandlingMiddleware, { 'debug': app.debug})Middleware Order
Section titled “Middleware Order”Middleware execution order is important. Middleware are executed in the order they’re added for requests, and in reverse order for responses:
# Recommended order for standard middlewareapp.add_middleware(ErrorHandlingMiddleware) # 1st request, last responseapp.add_middleware(SecurityHeadersMiddleware) # 2nd request, 2nd-last responseapp.add_middleware(CORSMiddleware) # 3rd request, 3rd-last responseapp.add_middleware(RequestIDMiddleware) # 4th request, 4th-last responseapp.add_middleware(LoggingMiddleware) # 5th request, 5th-last responseapp.add_middleware(AuthMiddleware) # 6th request, 6th-last responseapp.add_middleware(RateLimitMiddleware) # 7th request, 7th-last responseapp.add_middleware(CompressionMiddleware) # Last request, 1st response
# Recommended order with specialized middleware (add as needed)app.add_middleware(ErrorHandlingMiddleware) # Error handling (outermost)app.add_middleware(SecurityHeadersMiddleware) # Security headersapp.add_middleware(CORSMiddleware) # CORS handling# Note: Database and SSE optimizations are built-in and automaticapp.add_middleware(RequestIDMiddleware) # Request trackingapp.add_middleware(LoggingMiddleware) # Request loggingapp.add_middleware(AuthenticationMiddleware) # Authenticationapp.add_middleware(RateLimitMiddleware) # Rate limitingapp.add_middleware(CompressionMiddleware) # Response compression (last)Testing Middleware
Section titled “Testing Middleware”from zenith.testing import TestClientimport pytest
@pytest.mark.asyncioasync def test_timing_middleware(): app = Zenith() app.add_middleware(TimingMiddleware)
@app.get("/test") async def test_endpoint(): return {"message": "test"}
async with TestClient(app) as client: response = await client.get("/test")
assert response.status_code == 200 assert "X-Process-Time" in response.headers assert float(response.headers["X-Process-Time"]) > 0
@pytest.mark.asyncioasync def test_custom_middleware_config(): app = Zenith() app.add_middleware(CustomLoggingMiddleware, { 'log_requests': True, 'exclude_paths': ['/health'] })
@app.get("/test") async def test_endpoint(): return {"message": "test"}
@app.get("/health") async def health_endpoint(): return {"status": "ok"}
async with TestClient(app) as client: # Both should work, but middleware behavior differs response1 = await client.get("/test") response2 = await client.get("/health")
assert response1.status_code == 200 assert response2.status_code == 200Middleware provides comprehensive cross-cutting functionality in Zenith applications. The built-in middleware covers common production needs, while the custom middleware system allows for application-specific requirements.