
Middleware API

Middleware in Zenith processes requests and responses, providing cross-cutting functionality like authentication, logging, and security headers.
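Most examples in this reference attach middleware to an application instance. A minimal sketch of that pattern follows; the `from zenith import Zenith` import path is assumed here for illustration.

from zenith import Zenith  # import path assumed for illustration
from zenith.middleware import RequestIDMiddleware, CompressionMiddleware

app = Zenith()

# Built-in middleware is attached with add_middleware(); each layer wraps
# request handling and can modify the request, the response, or both.
app.add_middleware(RequestIDMiddleware)
app.add_middleware(CompressionMiddleware)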

SecurityHeadersMiddleware adds security headers to protect against common attacks:

from zenith.middleware import SecurityHeadersMiddleware, SecurityConfig

# Recommended: Using SecurityConfig class
config = SecurityConfig(
    # HTTPS enforcement
    force_https=True,

    # HSTS (HTTP Strict Transport Security)
    hsts_max_age=31536000,  # 1 year
    hsts_include_subdomains=True,
    hsts_preload=True,

    # Content security
    content_type_nosniff=True,
    frame_options="DENY",
    xss_protection="1; mode=block",

    # Content Security Policy
    csp_policy="default-src 'self'; script-src 'self' 'unsafe-inline'",
    csp_report_only=False,

    # Referrer policy
    referrer_policy="strict-origin-when-cross-origin",

    # Permissions policy (formerly Feature Policy)
    permissions_policy="geolocation=(), microphone=()",
)

app.add_middleware(SecurityHeadersMiddleware, config=config)
Parameter | Type | Default | Description
force_https | bool | False | Redirect HTTP to HTTPS
force_https_permanent | bool | False | Use 301 redirect instead of 302
hsts_max_age | int | 31536000 | HSTS max age in seconds
hsts_include_subdomains | bool | True | Include subdomains in HSTS
hsts_preload | bool | False | Enable HSTS preload
content_type_nosniff | bool | True | Add X-Content-Type-Options: nosniff
frame_options | str | "DENY" | X-Frame-Options header value
xss_protection | str | "1; mode=block" | X-XSS-Protection header value
csp_policy | str | None | Content Security Policy
csp_report_only | bool | False | Use CSP report-only mode
referrer_policy | str | "strict-origin-when-cross-origin" | Referrer policy
permissions_policy | str | None | Permissions policy header
trusted_proxies | list[str] | [] | List of trusted proxy IPs
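As a variation on the example above, the same parameters can be used to trial a Content Security Policy in report-only mode before enforcing it. This is a sketch built only from the parameters in the table, reusing the imports and `app` from the example above.

# Sketch: roll out a strict CSP without blocking pages yet.
report_only_config = SecurityConfig(
    csp_policy="default-src 'self'",
    csp_report_only=True,        # browsers report violations instead of blocking
    frame_options="SAMEORIGIN",
)
app.add_middleware(SecurityHeadersMiddleware, config=report_only_config)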

CORSMiddleware handles Cross-Origin Resource Sharing (CORS):

from zenith.middleware import CORSMiddleware, CORSConfig

# Recommended: Using CORSConfig class
config = CORSConfig(
    allow_origins=["https://example.com", "https://app.example.com"],
    allow_origin_regex=r"https://.*\.example\.com",
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["Authorization", "Content-Type"],
    expose_headers=["X-Request-ID", "X-Process-Time"],
    allow_credentials=True,
    max_age_secs=86400,  # 24 hours
)
app.add_middleware(CORSMiddleware, config=config)

# Alternative: Direct parameters (backward compatibility)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["https://example.com"],
    allow_methods=["GET", "POST"],
    allow_credentials=True,
)
Parameter | Type | Default | Description
allow_origins | list[str] | [] | Allowed origin URLs
allow_origin_regex | str | None | Regex pattern for allowed origins
allow_methods | list[str] | ["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH"] | Allowed HTTP methods
allow_headers | list[str] | ["*"] | Allowed request headers
allow_credentials | bool | False | Allow credentials in CORS requests
expose_headers | list[str] | [] | Headers exposed to browser
max_age_secs | int | 600 | Preflight cache duration in seconds
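For local development, a looser configuration is often convenient. The sketch below uses only parameters from the table above and reuses the imports from the earlier example.

# Sketch: allow any local dev server port via allow_origin_regex.
dev_cors = CORSConfig(
    allow_origin_regex=r"http://localhost:\d+",
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["*"],
    max_age_secs=600,
)
app.add_middleware(CORSMiddleware, config=dev_cors)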

RateLimitMiddleware implements rate limiting with multiple strategies:

from zenith.middleware import RateLimitMiddleware, RateLimitConfig, RateLimit
from zenith.middleware import MemoryRateLimitStorage, RedisRateLimitStorage

# Basic configuration
config = RateLimitConfig(
    default_limits=[
        RateLimit(requests=100, window=60),     # 100/minute
        RateLimit(requests=1000, window=3600),  # 1000/hour
    ],
    storage=MemoryRateLimitStorage(),
    exempt_paths=["/health", "/metrics"],
    error_message="Rate limit exceeded",
    include_headers=True,
)
app.add_middleware(RateLimitMiddleware, config=config)

# Redis storage for distributed systems
redis_storage = RedisRateLimitStorage("redis://localhost:6379/0")
redis_config = RateLimitConfig(
    default_limits=[RateLimit(requests=1000, window=60)],
    storage=redis_storage,
)
app.add_middleware(RateLimitMiddleware, config=redis_config)

# By user ID (requires authentication)
def user_rate_limit_key(request):
    user = getattr(request.state, 'user', None)
    return f"user:{user.id}" if user else request.client.host

# By API key
def api_key_rate_limit_key(request):
    api_key = request.headers.get('X-API-Key')
    return f"api_key:{api_key}" if api_key else request.client.host

# By user tier
def tiered_rate_limit_key(request):
    user = getattr(request.state, 'user', None)
    if user:
        return f"user:{user.id}:tier:{user.tier}"
    return request.client.host

app.add_middleware(RateLimitMiddleware, {
    "key_func": tiered_rate_limit_key,
    "default_limits": ["100/minute"],
    "tier_limits": {
        "premium": ["1000/minute"],
        "enterprise": ["10000/minute"],
    },
})
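Rate limiting can be exercised end to end with TestClient. The sketch below assumes the middleware rejects over-limit requests with HTTP 429; the exact status code is an assumption, and the `Zenith` import path is assumed as before.

import pytest
from zenith import Zenith  # import path assumed
from zenith.middleware import RateLimitMiddleware, RateLimitConfig, RateLimit
from zenith.testing import TestClient

@pytest.mark.asyncio
async def test_rate_limit_enforced():
    app = Zenith()
    app.add_middleware(
        RateLimitMiddleware,
        config=RateLimitConfig(default_limits=[RateLimit(requests=2, window=60)]),
    )

    @app.get("/limited")
    async def limited():
        return {"ok": True}

    async with TestClient(app) as client:
        assert (await client.get("/limited")).status_code == 200
        assert (await client.get("/limited")).status_code == 200
        # Third request inside the window should be rejected (429 assumed).
        assert (await client.get("/limited")).status_code == 429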

AuthMiddleware handles authentication for protected routes:

from zenith.middleware import AuthMiddleware
from zenith.auth import JWTConfig

app.add_middleware(AuthMiddleware, {
    # JWT configuration
    "jwt_config": JWTConfig(
        secret_key="your-secret-key",
        algorithm="HS256",
        access_token_expire_minutes=30,
    ),

    # Token extraction
    "token_url": "/auth/token",
    "header_name": "Authorization",
    "header_scheme": "Bearer",
    "cookie_name": "access_token",

    # Path exclusions
    "exclude_paths": [
        "/",
        "/health",
        "/auth/login",
        "/auth/register",
        "/docs",
        "/openapi.json",
    ],

    # Optional authentication
    "optional_auth_paths": ["/api/public"],

    # Auto error handling
    "auto_error": True,

    # Custom user loader
    "user_loader": async_get_user_by_id,
})
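A client-side view of this configuration, as a sketch: the /auth/token response shape and the protected /api/profile route are hypothetical, and an httpx-style `json=`/`headers=` interface on TestClient is assumed.

from zenith.testing import TestClient

async def call_protected_route(app):
    async with TestClient(app) as client:
        # Obtain a token from the configured token_url (payload shape is hypothetical).
        login = await client.post("/auth/token", json={"username": "demo", "password": "secret"})
        token = login.json()["access_token"]

        # Send it using the configured header_name and header_scheme.
        response = await client.get(
            "/api/profile",
            headers={"Authorization": f"Bearer {token}"},
        )
        return response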

LoggingMiddleware provides structured request/response logging:

from zenith.middleware import LoggingMiddleware
import logging

logging.basicConfig(level=logging.INFO)

app.add_middleware(LoggingMiddleware, {
    # Basic settings
    "logger_name": "zenith.requests",
    "log_level": logging.INFO,

    # Request logging
    "log_request_body": False,  # Privacy consideration
    "log_request_headers": ["User-Agent", "X-Request-ID"],
    "max_body_size": 1024,  # Truncate large bodies

    # Response logging
    "log_response_body": False,
    "log_response_headers": ["Content-Type", "X-Process-Time"],

    # Filtering
    "exclude_paths": ["/health", "/metrics"],
    "exclude_methods": ["OPTIONS"],

    # Format customization
    "format_string": "{method} {url} - {status_code} {process_time}ms",

    # Sensitive data masking
    "sensitive_headers": ["Authorization", "X-API-Key"],
    "mask_value": "[REDACTED]",
})
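Because the middleware writes to the named logger configured above, its output can be shaped with standard library handlers. A minimal sketch:

import logging

# Route the "zenith.requests" logger (logger_name above) to its own handler and format.
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s %(levelname)s %(name)s %(message)s"))

request_logger = logging.getLogger("zenith.requests")
request_logger.addHandler(handler)
request_logger.setLevel(logging.INFO)
request_logger.propagate = False  # avoid duplicate lines via the root logger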

CompressionMiddleware provides gzip and Brotli compression:

from zenith.middleware import CompressionMiddleware

app.add_middleware(CompressionMiddleware, {
    # Compression settings
    "minimum_size": 1024,  # Only compress responses > 1KB

    # Gzip settings
    "gzip_enabled": True,
    "gzip_level": 6,  # 1-9, higher = more compression

    # Brotli settings (if brotli package installed)
    "brotli_enabled": True,
    "brotli_quality": 4,  # 0-11
    "brotli_mode": 0,  # 0=generic, 1=text, 2=font

    # Content type filters
    "include_media_types": [
        "application/json",
        "application/javascript",
        "text/css",
        "text/html",
        "text/plain",
        "text/xml",
    ],
    "exclude_media_types": [
        "image/jpeg",
        "image/png",
        "image/gif",
        "application/zip",
        "video/mp4",
    ],
})
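Compression can be checked end to end. The sketch below assumes a hypothetical /large-json route whose body exceeds minimum_size, and an httpx-style `headers=` argument on TestClient.

import pytest
from zenith.testing import TestClient

@pytest.mark.asyncio
async def test_compression_applied():
    async with TestClient(app) as client:
        response = await client.get("/large-json", headers={"Accept-Encoding": "gzip, br"})
        # Only responses above minimum_size with an included media type are compressed.
        assert response.headers.get("Content-Encoding") in {"gzip", "br"}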

RequestIDMiddleware adds unique request identifiers to each request:

from zenith.middleware import RequestIDMiddleware
import uuid

app.add_middleware(RequestIDMiddleware, {
    # ID generation
    "generate_id": lambda: str(uuid.uuid4()),

    # Headers
    "request_id_header": "X-Request-ID",
    "response_id_header": "X-Request-ID",

    # Trust client-provided IDs
    "trust_incoming_id": False,

    # Validation
    "validate_id": lambda id: len(id) > 0 and len(id) < 100,
})
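For example, client-supplied IDs can be accepted but restricted to well-formed UUIDs. This is a sketch; falling back to a generated ID when validation fails is assumed behavior.

import uuid

def is_uuid(value: str) -> bool:
    """Accept only well-formed, hyphenated UUIDs from upstream clients."""
    try:
        return str(uuid.UUID(value)) == value.lower()
    except ValueError:
        return False

app.add_middleware(RequestIDMiddleware, {
    "trust_incoming_id": True,   # accept IDs forwarded by a trusted gateway
    "validate_id": is_uuid,      # invalid IDs are assumed to be replaced with a generated one
})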

These middleware serve specific use cases and can be added when your application needs their functionality.

Zenith includes several performance optimizations that work automatically without any configuration:

Database Optimization:

  • Automatic request-scoped session reuse
  • Connection pooling with smart defaults
  • 15-25% performance improvement

Server-Sent Events:

  • Built-in backpressure handling
  • 10x more concurrent connections
  • Memory-efficient streaming

Zenith automatically optimizes database connections with request-scoped session reuse:

# No configuration needed - automatic optimization
from zenith import Database
db = Database(url="postgresql+asyncpg://user:pass@localhost/db")

Key Features:

  • Automatic request-scoped session reuse
  • Connection pooling with smart defaults
  • 15-25% performance improvement
  • Zero configuration required

Zenith provides built-in SSE support with automatic backpressure handling, preventing memory issues on slow or stalled connections:

import asyncio

from zenith import create_sse_response

@app.get("/events")
async def stream_events():
    async def event_generator():
        for i in range(100):
            yield {"type": "update", "data": {"count": i}}
            await asyncio.sleep(1)
    return create_sse_response(event_generator())

Key Features:

  • Automatic backpressure handling
  • Handles 10x more concurrent connections
  • Memory-efficient streaming
  • Clean async generator API
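A client can consume the /events stream sketched above with any SSE-capable HTTP client. The example below uses httpx and assumes the app is served at the URL shown and emits standard "data:" framing.

import asyncio
import httpx

async def consume_events(url: str = "http://localhost:8000/events") -> None:
    # Stream the SSE endpoint; each event arrives as "data: ..." lines.
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream("GET", url) as response:
            async for line in response.aiter_lines():
                if line.startswith("data:"):
                    print(line.removeprefix("data:").strip())

if __name__ == "__main__":
    asyncio.run(consume_events())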

CSRFMiddleware provides Cross-Site Request Forgery (CSRF) protection:

from zenith.middleware import CSRFMiddleware

app.add_middleware(CSRFMiddleware, {
    # Token settings
    "secret_key": "your-csrf-secret",
    "token_length": 32,

    # Cookie settings
    "cookie_name": "csrftoken",
    "cookie_domain": None,
    "cookie_path": "/",
    "cookie_secure": True,
    "cookie_httponly": True,
    "cookie_samesite": "strict",
    "cookie_max_age": 3600,

    # Header settings
    "header_name": "X-CSRFToken",

    # Method exemptions
    "safe_methods": ["GET", "HEAD", "OPTIONS", "TRACE"],

    # Path exemptions
    "exempt_paths": ["/api/webhook"],

    # Error handling
    "failure_status_code": 403,
    "failure_message": "CSRF token missing or invalid",
})
Custom middleware can be written as a class whose instances wrap the downstream application and implement an async __call__:

import time
from typing import Awaitable, Callable

from zenith import Request, Response

class TimingMiddleware:
    """Add response time header to all responses."""

    def __init__(self, app: Callable[[Request], Awaitable[Response]]):
        self.app = app

    async def __call__(self, request: Request) -> Response:
        start_time = time.time()

        # Process request
        response = await self.app(request)

        # Add timing header
        process_time = time.time() - start_time
        response.headers["X-Process-Time"] = f"{process_time:.3f}"
        return response

# Add to application
app.add_middleware(TimingMiddleware)
Middleware can also take a configuration dict supplied at registration time:

import logging

class CustomLoggingMiddleware:
    """Custom logging middleware with configuration."""

    def __init__(self, app, config: dict):
        self.app = app
        self.log_requests = config.get('log_requests', True)
        self.log_responses = config.get('log_responses', True)
        self.exclude_paths = set(config.get('exclude_paths', []))
        self.logger = logging.getLogger(config.get('logger_name', __name__))

    async def __call__(self, request: Request) -> Response:
        # Skip excluded paths
        if request.url.path in self.exclude_paths:
            return await self.app(request)

        # Log request
        if self.log_requests:
            self.logger.info(f"Request: {request.method} {request.url}")

        # Process request
        response = await self.app(request)

        # Log response
        if self.log_responses:
            self.logger.info(f"Response: {response.status_code}")

        return response

# Add with configuration
app.add_middleware(CustomLoggingMiddleware, {
    'log_requests': True,
    'log_responses': False,
    'exclude_paths': ['/health', '/metrics'],
    'logger_name': 'my_app.requests',
})
Middleware can call external services as well, for example to enrich requests with geolocation data:

import asyncio
import time
from typing import Any, Dict

import aiohttp

class GeoLocationMiddleware:
    """Add geolocation data based on client IP."""

    def __init__(self, app, config: dict):
        self.app = app
        self.api_key = config.get('api_key')
        self.cache = {}  # Simple in-memory cache
        self.cache_ttl = config.get('cache_ttl', 3600)
        self.enabled = config.get('enabled', True)

    async def __call__(self, request: Request) -> Response:
        if self.enabled and self.api_key:
            # Get client IP
            client_ip = request.client.host

            # Skip private/local IPs
            if not self._is_public_ip(client_ip):
                return await self.app(request)

            # Get geolocation data and add it to request state
            geo_data = await self._get_geolocation(client_ip)
            request.state.geo_location = geo_data

        return await self.app(request)

    async def _get_geolocation(self, ip: str) -> Dict[str, Any]:
        # Check cache first
        if ip in self.cache:
            cached_data, timestamp = self.cache[ip]
            if time.time() - timestamp < self.cache_ttl:
                return cached_data

        # Fetch from API
        try:
            async with aiohttp.ClientSession() as session:
                url = f"https://api.ipgeolocation.io/ipgeo?apiKey={self.api_key}&ip={ip}"
                async with session.get(url, timeout=2) as response:
                    if response.status == 200:
                        data = await response.json()
                        # Cache the result
                        self.cache[ip] = (data, time.time())
                        return data
        except (aiohttp.ClientError, asyncio.TimeoutError):
            pass

        return {}

    def _is_public_ip(self, ip: str) -> bool:
        # Simple check for public IP (implement more comprehensive logic)
        return not (
            ip.startswith('192.168.')
            or ip.startswith('10.')
            or ip.startswith('172.')
            or ip == '127.0.0.1'
        )

# Usage
app.add_middleware(GeoLocationMiddleware, {
    'api_key': 'your-api-key',
    'cache_ttl': 3600,
    'enabled': True,
})
A middleware class can also centralize error handling and response formatting:

import logging
import traceback
import uuid

from zenith import HTTPException, JSONResponse

class ErrorHandlingMiddleware:
    """Global error handling and formatting."""

    def __init__(self, app, config: dict):
        self.app = app
        self.debug = config.get('debug', False)
        self.logger = logging.getLogger('error_handler')

    async def __call__(self, request: Request) -> Response:
        try:
            return await self.app(request)

        except HTTPException as exc:
            # Handle HTTP exceptions
            return JSONResponse(
                status_code=exc.status_code,
                content={
                    "error": {
                        "type": "http_exception",
                        "message": exc.detail,
                        "status_code": exc.status_code,
                    }
                },
            )

        except ValueError as exc:
            # Handle validation errors
            return JSONResponse(
                status_code=400,
                content={
                    "error": {
                        "type": "validation_error",
                        "message": str(exc),
                    }
                },
            )

        except Exception as exc:
            # Handle unexpected errors
            error_id = str(uuid.uuid4())

            # Log the error
            self.logger.error(
                f"Unexpected error [{error_id}]: {str(exc)}",
                exc_info=True,
            )

            # Return formatted error response
            error_response = {
                "error": {
                    "type": "internal_error",
                    "message": "An unexpected error occurred",
                    "error_id": error_id,
                }
            }

            # Include traceback in debug mode
            if self.debug:
                error_response["error"]["traceback"] = traceback.format_exc()

            return JSONResponse(
                status_code=500,
                content=error_response,
            )

# Add to application
app.add_middleware(ErrorHandlingMiddleware, {
    'debug': app.debug,
})

Middleware execution order is important. Middleware are executed in the order they’re added for requests, and in reverse order for responses:

# Recommended order for standard middleware
app.add_middleware(ErrorHandlingMiddleware)    # 1st for requests, last for responses
app.add_middleware(SecurityHeadersMiddleware)  # 2nd for requests, 2nd-last for responses
app.add_middleware(CORSMiddleware)             # 3rd for requests, 3rd-last for responses
app.add_middleware(RequestIDMiddleware)        # 4th for requests, 4th-last for responses
app.add_middleware(LoggingMiddleware)          # 5th for requests, 5th-last for responses
app.add_middleware(AuthMiddleware)             # 6th for requests, 6th-last for responses
app.add_middleware(RateLimitMiddleware)        # 7th for requests, 7th-last for responses
app.add_middleware(CompressionMiddleware)      # Last for requests, 1st for responses

# Recommended order with specialized middleware (add as needed)
app.add_middleware(ErrorHandlingMiddleware)    # Error handling (outermost)
app.add_middleware(SecurityHeadersMiddleware)  # Security headers
app.add_middleware(CORSMiddleware)             # CORS handling
# Note: Database and SSE optimizations are built-in and automatic
app.add_middleware(RequestIDMiddleware)        # Request tracking
app.add_middleware(LoggingMiddleware)          # Request logging
app.add_middleware(AuthMiddleware)             # Authentication
app.add_middleware(RateLimitMiddleware)        # Rate limiting
app.add_middleware(CompressionMiddleware)      # Response compression (last)
Middleware behavior can be verified with TestClient:

import pytest

from zenith import Zenith
from zenith.testing import TestClient

@pytest.mark.asyncio
async def test_timing_middleware():
    app = Zenith()
    app.add_middleware(TimingMiddleware)

    @app.get("/test")
    async def test_endpoint():
        return {"message": "test"}

    async with TestClient(app) as client:
        response = await client.get("/test")
        assert response.status_code == 200
        assert "X-Process-Time" in response.headers
        assert float(response.headers["X-Process-Time"]) > 0

@pytest.mark.asyncio
async def test_custom_middleware_config():
    app = Zenith()
    app.add_middleware(CustomLoggingMiddleware, {
        'log_requests': True,
        'exclude_paths': ['/health'],
    })

    @app.get("/test")
    async def test_endpoint():
        return {"message": "test"}

    @app.get("/health")
    async def health_endpoint():
        return {"status": "ok"}

    async with TestClient(app) as client:
        # Both should work, but middleware behavior differs
        response1 = await client.get("/test")
        response2 = await client.get("/health")
        assert response1.status_code == 200
        assert response2.status_code == 200

Middleware provides comprehensive cross-cutting functionality in Zenith applications. The built-in middleware covers common production needs, while the custom middleware system allows for application-specific requirements.