Performance Guide
Learn about performance optimization Performance Guide →
This comprehensive guide covers security best practices for Zenith applications — from authentication and authorization to protecting against common vulnerabilities. Learn how to build secure APIs that protect user data and maintain trust.
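Security is not optional. This guide covers authentication (JWT, OAuth, and API keys), authorization (RBAC and ABAC), protection against SQL injection, XSS, and CSRF, data encryption and password security, security headers, rate limiting, audit logging, and security testing.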
JSON Web Tokens provide stateless authentication:
import os
from datetime import datetime, timedelta
from typing import Any, Dict, Optional

from jose import JWTError, jwt
from passlib.context import CryptContext
from pwdlib import PasswordHash
from pwdlib.hashers.argon2 import Argon2Hasher
from zenith import Zenith

app = Zenith()

# Configuration
# IMPORTANT: Generate secure keys with sufficient entropy (16+ unique chars)
# Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'
# The key is read from the SECRET_KEY environment variable so the real value
# never lives in source control; the placeholder is only a development fallback.
SECRET_KEY = os.environ.get("SECRET_KEY", "your-secret-key-change-in-production")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7

# Password hashing (Zenith uses pwdlib with Argon2 by default)
pwd_hash = PasswordHash.recommended()  # Argon2id - modern, secure

# Or customize Argon2 parameters (this rebinds pwd_hash, replacing the default above)
pwd_hash = PasswordHash((
    Argon2Hasher(time_cost=3, memory_cost=65536, parallelism=4),
))
def create_access_token(
    data: Dict[str, Any],
    expires_delta: Optional[timedelta] = None,
) -> str:
    """
    Create JWT access token.

    Access tokens should:
    - Be short-lived (15-30 minutes)
    - Contain minimal data
    - Include token type identifier

    Args:
        data: Claims to embed (e.g. ``{"sub": email}``). Not mutated.
        expires_delta: Optional lifetime override; defaults to
            ``ACCESS_TOKEN_EXPIRE_MINUTES``.

    Returns:
        The encoded JWT string.
    """
    import secrets
    from datetime import timezone

    # One timezone-aware "now" keeps "iat" and "exp" consistent;
    # datetime.utcnow() is deprecated and returns a naive datetime.
    now = datetime.now(timezone.utc)
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

    to_encode = data.copy()
    to_encode.update({
        "exp": now + expires_delta,
        "iat": now,
        "type": "access",
        "jti": secrets.token_urlsafe(16),  # Unique token ID for revocation
    })

    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_refresh_token(
    data: Dict[str, Any],
    expires_delta: Optional[timedelta] = None,
) -> str:
    """
    Create JWT refresh token.

    Refresh tokens should:
    - Be long-lived (days/weeks)
    - Contain only user identifier
    - Be stored securely (httpOnly cookies)

    Each token carries a unique "jti" claim so it can be revoked
    individually (verify_token checks the jti against the revocation list).

    Args:
        data: Claims to embed; not mutated.
        expires_delta: Optional lifetime override; defaults to
            ``REFRESH_TOKEN_EXPIRE_DAYS``.

    Returns:
        The encoded JWT string.
    """
    import secrets
    from datetime import timezone

    now = datetime.now(timezone.utc)
    if expires_delta is None:
        expires_delta = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)

    to_encode = data.copy()
    to_encode.update({
        "exp": now + expires_delta,
        "iat": now,
        "type": "refresh",
        # Without a jti a stolen long-lived refresh token could never be revoked.
        "jti": secrets.token_urlsafe(16),
    })

    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def verify_token(token: str, expected_type: str = "access") -> Dict[str, Any]:
    """
    Verify and decode JWT token.

    Checks:
    - Signature validity
    - Expiration
    - Token type
    - Revocation status

    Returns:
        The decoded claims.

    Raises:
        AuthenticationError: if any check fails; the originating JWTError is
            chained as ``__cause__`` for debugging.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])

        # Verify token type (stops a refresh token being used as an access token)
        if payload.get("type") != expected_type:
            raise JWTError("Invalid token type")

        # Check if token is revoked
        if is_token_revoked(payload.get("jti")):
            raise JWTError("Token has been revoked")
    except JWTError as e:
        raise AuthenticationError(f"Invalid token: {e}") from e

    return payload
# Token revocation (blacklist).
# The set lives in process memory only; use Redis in production.
revoked_tokens = set()


def revoke_token(token_id: str):
    """Mark the token whose ``jti`` claim is *token_id* as revoked.

    In production, store it in Redis with an expiry instead.
    """
    revoked_tokens.add(token_id)
def is_token_revoked(token_id: str) -> bool: """Check if token is revoked.""" return token_id in revoked_tokensImplement OAuth for third-party authentication:
from authlib.integrations.starlette_client import OAuth
from app.models import User

# Configure OAuth
oauth = OAuth()

# Google is registered via its OpenID Connect discovery document.
oauth.register(
    name="google",
    client_id=settings.GOOGLE_CLIENT_ID,
    client_secret=settings.GOOGLE_CLIENT_SECRET,
    server_metadata_url="https://accounts.google.com/.well-known/openid-configuration",
    client_kwargs={"scope": "openid email profile"},
)
@app.get("/auth/google")
async def google_login(request: Request):
    """Start the Google OAuth flow by redirecting the browser to Google."""
    callback_url = request.url_for("google_callback")
    return await oauth.google.authorize_redirect(request, callback_url)
@app.get("/auth/google/callback")async def google_callback(request: Request, session: AsyncSession = Depends(get_session)): """Handle Google OAuth callback.""" try: # Exchange code for token token = await oauth.google.authorize_access_token(request)
# Get user info user_info = token.get('userinfo') if not user_info: raise AuthenticationError("Failed to get user info")
# Find or create user user = await session.exec( select(User).where(User.email == user_info['email']) ).first()
if not user: # Create new user from OAuth user = User( email=user_info['email'], name=user_info.get('name'), avatar_url=user_info.get('picture'), oauth_provider='google', oauth_id=user_info['sub'], is_verified=True # Email verified by Google ) session.add(user) await session.commit()
# Create session access_token = create_access_token( data={"sub": user.email, "user_id": user.id} )
return {"access_token": access_token, "token_type": "bearer"}
except Exception as e: raise AuthenticationError(f"OAuth failed: {e}")For machine-to-machine authentication:
import hashlib
import secrets
from datetime import datetime


class APIKey(SQLModel, table=True):
    """API key model for service authentication.

    Only the SHA-256 hash of a key is stored (see ``generate_key``).
    """

    id: Optional[int] = Field(default=None, primary_key=True)
    name: str = Field(index=True)
    key_hash: str = Field(unique=True)  # Store hash, not plain key
    # First 10 chars of the plaintext key ("zk_" + 7 chars); get_api_key looks
    # keys up by key[:10] before comparing hashes.
    prefix: str = Field(index=True)
    scopes: List[str] = Field(sa_column_kwargs={"type_": JSON})
    rate_limit: int = Field(default=1000)  # Requests per hour
    expires_at: Optional[datetime] = None
    last_used_at: Optional[datetime] = None
    created_at: datetime = Field(default_factory=datetime.utcnow)
    revoked: bool = Field(default=False)

    @classmethod
    def generate_key(cls) -> tuple[str, str]:
        """Generate API key and its hash.

        Returns:
            ``(key, key_hash)``: the plaintext key for the client (store
            ``key[:10]`` as ``prefix``) and its SHA-256 hex digest for storage.
        """
        # Secure random key with a "zk_" prefix for identification
        key = f"zk_{secrets.token_urlsafe(32)}"
        key_hash = hashlib.sha256(key.encode()).hexdigest()
        return key, key_hash

    def verify_key(self, provided_key: str) -> bool:
        """Verify provided key against the stored hash using a constant-time comparison."""
        provided_hash = hashlib.sha256(provided_key.encode()).hexdigest()
        return secrets.compare_digest(self.key_hash, provided_hash)
# API Key authentication dependency
async def get_api_key(
    api_key: str = Header(..., alias="X-API-Key"),
    session: AsyncSession = Depends(get_session),
) -> APIKey:
    """Resolve the ``X-API-Key`` header to a stored, non-revoked APIKey.

    Raises:
        HTTPException: 401 if the key is malformed, unknown, revoked, or expired.
    """
    # Issued keys always start with "zk_"; reject others before querying.
    if not api_key.startswith("zk_"):
        raise HTTPException(status_code=401, detail="Invalid API key format")

    # Narrow candidates with the indexed 10-char prefix, then confirm by hash.
    query = select(APIKey).where(
        APIKey.prefix == api_key[:10],
        APIKey.revoked == False,  # noqa: E712 - SQL expression, not a Python truth test
    )
    candidates = (await session.exec(query)).all()

    match = next((k for k in candidates if k.verify_key(api_key)), None)
    if match is None:
        raise HTTPException(status_code=401, detail="Invalid API key")

    if match.expires_at and match.expires_at < datetime.utcnow():
        raise HTTPException(status_code=401, detail="API key expired")

    # Record usage
    match.last_used_at = datetime.utcnow()
    await session.commit()
    return match
# Usage
@app.get("/api/data")
async def get_data(api_key: APIKey = Depends(get_api_key)):
    """Return protected data to API keys that hold the ``read:data`` scope."""
    has_scope = "read:data" in api_key.scopes
    if not has_scope:
        raise HTTPException(status_code=403, detail="Insufficient permissions")
    return {"data": "sensitive information"}


from enum import Enum
from typing import List, Set
class Role(str, Enum):
    """User roles, declared from most to least privileged."""

    SUPER_ADMIN = "super_admin"
    ADMIN = "admin"
    MODERATOR = "moderator"
    USER = "user"
    GUEST = "guest"


# Role hierarchy: larger number = more privilege.
ROLE_HIERARCHY = {
    Role.SUPER_ADMIN: 100,
    Role.ADMIN: 80,
    Role.MODERATOR: 60,
    Role.USER: 40,
    Role.GUEST: 20,
}

# Role permissions. Entries use "resource:action[:scope]"; "*" is a wildcard.
ROLE_PERMISSIONS = {
    # Everything
    Role.SUPER_ADMIN: ["*"],
    Role.ADMIN: [
        "users:*",
        "posts:*",
        "comments:*",
        "settings:read",
        "settings:write",
    ],
    Role.MODERATOR: [
        "posts:read",
        "posts:write",
        "posts:delete",
        "comments:*",
        "users:read",
    ],
    # ":own" entries are limited to the user's own records
    Role.USER: [
        "posts:read",
        "posts:write:own",
        "comments:read",
        "comments:write:own",
        "users:read:own",
        "users:write:own",
    ],
    Role.GUEST: [
        "posts:read",
        "comments:read",
    ],
}
class User(SQLModel, table=True):
    """User model with roles and permissions."""

    id: Optional[int] = Field(default=None, primary_key=True)
    email: str = Field(unique=True)
    role: Role = Field(default=Role.USER)
    # Extra permissions granted on top of the role (matched exactly, no wildcards)
    custom_permissions: List[str] = Field(sa_column_kwargs={"type_": JSON}, default_factory=list)
    is_active: bool = Field(default=True)

    def has_permission(self, permission: str) -> bool:
        """Check if user has specific permission.

        Inactive users have no permissions. Otherwise a permission is granted
        by an exact entry in ``custom_permissions``, or by any entry of the
        user's role in ``ROLE_PERMISSIONS``. Role entries may use shell-style
        wildcards such as ``"users:*"`` or ``"*"``.
        """
        if not self.is_active:
            return False

        # Check custom permissions
        if permission in self.custom_permissions:
            return True

        # Check role permissions (exact match or wildcard pattern)
        role_perms = ROLE_PERMISSIONS.get(self.role, [])
        return any(
            permission == perm or self._matches_wildcard(permission, perm)
            for perm in role_perms
        )

    def _matches_wildcard(self, permission: str, pattern: str) -> bool:
        """Check if permission matches a shell-style wildcard pattern.

        Uses fnmatchcase: fnmatch.fnmatch normalizes case on case-insensitive
        platforms (e.g. Windows), where "USERS:WRITE" would match "users:*".
        """
        import fnmatch
        return fnmatch.fnmatchcase(permission, pattern)

    def has_role(self, required_role: Role) -> bool:
        """Check if user has required role or higher.

        A user role missing from ROLE_HIERARCHY ranks 0. A required role
        missing from it ranks 100, so only top-ranked roles satisfy it.
        """
        user_level = ROLE_HIERARCHY.get(self.role, 0)
        required_level = ROLE_HIERARCHY.get(required_role, 100)
        return user_level >= required_level
# Permission decorator
def require_permission(permission: str):
    """Decorator to require specific permission.

    The wrapped endpoint must receive the authenticated user as the
    ``current_user`` keyword argument (e.g. ``Depends(get_current_user)``).

    Raises:
        HTTPException: 403 when ``current_user`` is missing or lacks *permission*.
    """
    from functools import wraps

    def decorator(func):
        # wraps() copies the endpoint's name, docstring and __wrapped__.
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Get current user from request context
            user = kwargs.get("current_user")
            if not user or not user.has_permission(permission):
                raise HTTPException(
                    status_code=403,
                    detail=f"Permission denied: {permission} required",
                )
            return await func(*args, **kwargs)

        return wrapper

    return decorator
# Usage@app.post("/admin/users")@require_permission("users:write")async def create_user( user_data: UserCreate, current_user: User = Depends(get_current_user)): """Create user - requires users:write permission.""" return await create_user_service(user_data)More flexible authorization based on attributes:
import json
from collections.abc import Mapping
from typing import Any, Dict, List, Optional


class Policy:
    """ABAC policy definition.

    A policy is an ordered list of rules. Each rule has an ``action`` (or
    ``"*"``), an ``effect`` (``"allow"`` by default) and a list of
    ``conditions`` that must all hold. The first rule that matches decides
    the outcome; if no rule matches, access is denied.

    A condition compares the value at ``attribute`` (a dotted path such as
    ``"subject.id"`` or ``"context.request.hour"``) with ``value`` using
    ``operator``. A string ``value`` that is itself a ``subject.``,
    ``resource.`` or ``context.`` path (e.g. ``"resource.owner_id"``) is
    resolved first, so a rule can compare two attributes. A referenced
    attribute that is missing never satisfies a condition.
    """

    # Operator name -> predicate(actual, expected)
    _OPERATORS = {
        "equals": lambda actual, expected: actual == expected,
        "not_equals": lambda actual, expected: actual != expected,
        "contains": lambda actual, expected: expected in actual,
        "in": lambda actual, expected: actual in expected,
        "greater_than": lambda actual, expected: actual > expected,
        "less_than": lambda actual, expected: actual < expected,
    }

    def __init__(self, name: str, rules: List[Dict[str, Any]]):
        self.name = name
        self.rules = rules

    def evaluate(self, subject: Dict, resource: Dict, action: str, context: Optional[Dict] = None) -> bool:
        """Evaluate policy against attributes; return True if access is allowed."""
        for rule in self.rules:
            if self._evaluate_rule(rule, subject, resource, action, context):
                return rule.get("effect", "allow") == "allow"
        return False

    def _evaluate_rule(self, rule, subject, resource, action, context):
        """Return True if *rule* covers *action* and all its conditions hold."""
        if rule.get("action") not in (action, "*"):
            return False
        return all(
            self._evaluate_condition(condition, subject, resource, context)
            for condition in rule.get("conditions", [])
        )

    def _evaluate_condition(self, condition, subject, resource, context):
        """Evaluate a single condition; malformed conditions evaluate to False."""
        found, actual = self._resolve(condition.get("attribute"), subject, resource, context)
        if not found:
            return False

        value = condition.get("value")
        is_reference, referenced = self._resolve(value, subject, resource, context)
        if is_reference:
            if referenced is None:
                # e.g. subject.id == resource.owner_id must not pass because
                # both sides happen to be missing.
                return False
            value = referenced

        predicate = self._OPERATORS.get(condition.get("operator"))
        if predicate is None:
            return False
        try:
            return bool(predicate(actual, value))
        except TypeError:
            # e.g. "contains" on a missing (None) attribute, or ordering
            # incompatible types: treat as not satisfied instead of crashing.
            return False

    def _resolve(self, path, subject, resource, context):
        """Resolve a ``subject.``/``resource.``/``context.`` path.

        Returns ``(True, value)`` if *path* is such a reference (value is None
        when the attribute is missing), otherwise ``(False, None)``.
        """
        if not isinstance(path, str):
            return False, None
        root, dot, rest = path.partition(".")
        sources = {"subject": subject, "resource": resource, "context": context or {}}
        if not dot or root not in sources:
            return False, None
        return True, self._get_nested(sources[root], rest)

    def _get_nested(self, obj, path):
        """Follow dotted *path* through nested mappings; None if any step is missing."""
        for key in path.split("."):
            if not isinstance(obj, Mapping):
                return None
            obj = obj.get(key)
        return obj
# Example policies
# Rules are tried in order: the first whose action and conditions match
# decides the outcome, and a request matching no rule is denied.
# NOTE(review): the owner rules compare subject.id with the string
# "resource.owner_id"; they only work if Policy resolves such strings as
# attribute references. If it compares them literally, they never match.
document_policy = Policy(
    name="document_access",
    rules=[
        # Anyone may read a public document.
        {
            "action": "read",
            "effect": "allow",
            "conditions": [
                {"attribute": "resource.public", "operator": "equals", "value": True},
            ],
        },
        # Owners may read their own documents.
        {
            "action": "read",
            "effect": "allow",
            "conditions": [
                {"attribute": "subject.id", "operator": "equals", "value": "resource.owner_id"},
            ],
        },
        # Owners may write their own documents while unlocked.
        {
            "action": "write",
            "effect": "allow",
            "conditions": [
                {"attribute": "subject.id", "operator": "equals", "value": "resource.owner_id"},
                {"attribute": "resource.locked", "operator": "equals", "value": False},
            ],
        },
    ],
)
# Usage@app.get("/documents/{document_id}")async def get_document( document_id: int, current_user: User = Depends(get_current_user), session: AsyncSession = Depends(get_session)): document = await session.get(Document, document_id)
# ABAC evaluation can_access = document_policy.evaluate( subject={"id": current_user.id, "role": current_user.role}, resource={"id": document.id, "owner_id": document.owner_id, "public": document.is_public}, action="read" )
if not can_access: raise HTTPException(status_code=403, detail="Access denied")
return documentZenith provides built-in protection against SQL injection through validated query builders.
# VULNERABLE: String concatenation
@app.get("/users/search")
async def search_users_vulnerable(query: str, session: AsyncSession = Depends(get_session)):
    """Anti-pattern shown for contrast only -- do NOT copy.

    ``query`` is interpolated straight into the SQL text, so input such as
    ``' OR '1'='1`` rewrites the statement.
    """
    # NEVER DO THIS!
    unsafe_sql = f"SELECT * FROM users WHERE email = '{query}'"
    rows = await session.exec(text(unsafe_sql))
    return rows.all()
# SAFE: ZenithModel with validated query builder
@app.get("/users/search")
async def search_users_safe(email: str):
    """Look up users by email via the query builder.

    QueryBuilder validates column names and prevents injection.
    """
    return await User.where(email=email).all()
# SAFE: Direct column access (also validated)
@app.get("/users")
async def get_users(sort: str = "created_at"):
    """List up to 10 active users ordered by *sort* (prefixed with "-").

    order_by() validates column names and raises ValueError for unknown
    ones, which is reported to the client as a 400.
    """
    ordering = f"-{sort}"  # presumably "-" selects descending order - confirm
    try:
        return await User.where(active=True).order_by(ordering).limit(10)
    except ValueError as exc:
        # Invalid column name provided
        raise HTTPException(400, detail=str(exc))
# SAFE: With raw SQL
@app.get("/users/advanced-search")
async def advanced_search(
    email: str,
    name: str,
    session: AsyncSession = Depends(get_session),
):
    """Search users by exact email and partial name with parameterized SQL.

    Values are sent as bound parameters, never spliced into the SQL text.
    LIKE metacharacters in *name* (``%``, ``_`` and the ``!`` escape
    character) are escaped so user input is matched literally.
    """
    # Parameterized raw SQL. ESCAPE '!' is portable across PostgreSQL, MySQL
    # and SQLite (a backslash escape needs different quoting per database).
    stmt = text("""
        SELECT * FROM users
        WHERE email = :email
        AND name LIKE :name ESCAPE '!'
    """)
    escaped_name = name.replace("!", "!!").replace("%", "!%").replace("_", "!_")

    # SQLModel's exec() accepts bind parameters only as the keyword-only
    # `params` argument; passing the dict positionally raises TypeError.
    result = await session.exec(
        stmt,
        params={"email": email, "name": f"%{escaped_name}%"},
    )
    return result.all()


import bleach
from markupsafe import escape
# Input sanitization
def sanitize_html(content: str) -> str:
    """Sanitize user-supplied HTML to prevent XSS.

    Keeps only a small allow-list of formatting tags (plus ``href``/``title``
    on links), strips all other markup, and turns bare URLs into links.
    ``None`` or empty input returns ``""``.
    """
    from bleach.linkifier import LinkifyFilter
    from bleach.sanitizer import Cleaner

    if not content:
        return ""

    # Allow only safe tags and attributes
    allowed_tags = ["p", "br", "strong", "em", "a", "ul", "ol", "li"]
    allowed_attributes = {"a": ["href", "title"]}

    # Clean and linkify in one pass. linkify() only converts URLs to <a> tags
    # and is not a sanitizer; bleach recommends running LinkifyFilter inside
    # the Cleaner rather than linkifying already-cleaned output.
    cleaner = Cleaner(
        tags=allowed_tags,
        attributes=allowed_attributes,
        strip=True,
        filters=[LinkifyFilter],
    )
    return cleaner.clean(content)
# Content Security Policy
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
    """Attach a Content Security Policy and other protective headers to every response."""
    response = await call_next(request)

    # Content Security Policy.
    # - script-src omits 'unsafe-inline': allowing inline scripts lets any
    #   injected <script> run, which defeats CSP's XSS protection. Use nonces
    #   or hashes for scripts that must be inline.
    # - object-src 'none' blocks plugin content, which would otherwise fall
    #   back to default-src 'self'.
    response.headers["Content-Security-Policy"] = (
        "default-src 'self'; "
        "script-src 'self' https://cdn.jsdelivr.net; "
        "style-src 'self' 'unsafe-inline'; "
        "img-src 'self' data: https:; "
        "font-src 'self' data:; "
        "connect-src 'self'; "
        "object-src 'none'; "
        "frame-ancestors 'none'; "
        "base-uri 'self'; "
        "form-action 'self'"
    )

    # Other security headers
    response.headers["X-Content-Type-Options"] = "nosniff"
    response.headers["X-Frame-Options"] = "DENY"
    # Note: X-XSS-Protection removed - deprecated and can create vulnerabilities
    response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"

    return response
# Output encoding
@app.get("/user/{user_id}/profile")
async def get_user_profile(user_id: int):
    """Return a user's profile with user-controlled fields made safe for HTML.

    Raises:
        HTTPException: 404 if the user does not exist (assumes get_user
            returns None for unknown ids - confirm).
    """
    from urllib.parse import urlsplit

    user = await get_user(user_id)
    if user is None:
        raise HTTPException(status_code=404, detail="User not found")

    website = user.website
    # HTML escaping does not neutralise "javascript:" or "data:" URLs once
    # they are placed in an href, so only plain and http(s) links pass.
    if website and urlsplit(website.strip()).scheme.lower() not in ("", "http", "https"):
        website = None

    # Always escape user input when rendering. markupsafe.escape(None)
    # returns the text "None", so missing optional fields stay None.
    return {
        "name": escape(user.name) if user.name is not None else None,
        "bio": sanitize_html(user.bio) if user.bio else "",
        "website": escape(website) if website else None,
    }


import secrets
from zenith import Form
class CSRFProtection:
    """CSRF protection using random synchronizer tokens stored in the session."""

    def __init__(self, secret_key: str):
        # Stored for API compatibility; the token methods below do not use it.
        self.secret_key = secret_key

    def generate_token(self) -> str:
        """Generate a random, URL-safe CSRF token (32 bytes of randomness)."""
        return secrets.token_urlsafe(32)

    def validate_token(self, token: str, session_token: str) -> bool:
        """Compare a submitted token with the session token in constant time.

        Both values are UTF-8 encoded first. compare_digest() raises TypeError
        for str arguments containing non-ASCII characters, and the submitted
        token is attacker-controlled.
        """
        return secrets.compare_digest(token.encode("utf-8"), session_token.encode("utf-8"))
# CSRF middleware
csrf_protection = CSRFProtection(settings.SECRET_KEY)


@app.middleware("http")
async def csrf_middleware(request: Request, call_next):
    """Reject state-changing requests whose CSRF token does not match the session's."""
    # Skip for safe methods
    if request.method in ["GET", "HEAD", "OPTIONS"]:
        return await call_next(request)

    # Get token from header or form
    token = request.headers.get("X-CSRF-Token")
    # Compare only the media type: browsers append parameters such as
    # "; charset=UTF-8" or a multipart boundary to Content-Type.
    media_type = request.headers.get("content-type", "").split(";", 1)[0].strip().lower()
    if not token and media_type in ("application/x-www-form-urlencoded", "multipart/form-data"):
        # NOTE(review): in Starlette-style middleware, reading the form here may
        # leave the body unavailable to the endpoint - confirm for Zenith.
        form = await request.form()
        token = form.get("csrf_token")

    # Get session token
    session_token = request.session.get("csrf_token")

    # Validate. A non-str form value (e.g. an uploaded file) is never a valid token.
    if (
        not isinstance(token, str)
        or not token
        or not session_token
        or not csrf_protection.validate_token(token, session_token)
    ):
        return JSONResponse(
            status_code=403,
            content={"detail": "CSRF token validation failed"},
        )

    return await call_next(request)
# Usage in forms
@app.get("/form")
async def get_form(request: Request):
    """Render a form carrying a fresh CSRF token that is also stored in the session.

    token_urlsafe() output uses only [A-Za-z0-9_-], so interpolating it into
    the HTML attribute needs no escaping.
    """
    # Generate and store token
    csrf_token = csrf_protection.generate_token()
    request.session["csrf_token"] = csrf_token

    return HTMLResponse(f"""
        <form method="post" action="/submit">
            <input type="hidden" name="csrf_token" value="{csrf_token}">
            <input type="text" name="data">
            <button type="submit">Submit</button>
        </form>
    """)


import base64

from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
# cryptography names this KDF PBKDF2HMAC; there is no `PBKDF2` class, so the
# original import failed. The alias keeps the name EncryptionService uses.
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC as PBKDF2
class EncryptionService:
    """Service for encrypting sensitive data with Fernet."""

    def __init__(self, master_key: str, salt: bytes = b"stable_salt", iterations: int = 100000):
        """Derive a Fernet key from *master_key* with PBKDF2-HMAC-SHA256.

        Args:
            master_key: Secret the encryption key is derived from.
            salt: KDF salt. The default matches data encrypted by earlier
                versions; in production pass a random per-deployment salt
                and keep it with your configuration. Changing the salt (or
                *iterations*) makes existing ciphertexts undecryptable.
            iterations: PBKDF2 iteration count.
        """
        import base64
        from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,
            salt=salt,
            iterations=iterations,
        )
        key = base64.urlsafe_b64encode(kdf.derive(master_key.encode()))
        self.fernet = Fernet(key)

    def encrypt(self, data: str) -> str:
        """Encrypt string data and return the Fernet token as text."""
        return self.fernet.encrypt(data.encode()).decode()

    def decrypt(self, encrypted_data: str) -> str:
        """Decrypt a token produced by ``encrypt``.

        Raises:
            cryptography.fernet.InvalidToken: if the token was altered or was
                encrypted under a different key.
        """
        return self.fernet.decrypt(encrypted_data.encode()).decode()

    def encrypt_field(self, value: Any) -> str:
        """JSON-serialize *value* and encrypt it for database storage."""
        json_str = json.dumps(value)
        return self.encrypt(json_str)

    def decrypt_field(self, encrypted_value: str) -> Any:
        """Decrypt a value produced by ``encrypt_field`` and JSON-decode it."""
        json_str = self.decrypt(encrypted_value)
        return json.loads(json_str)
# Custom encrypted field
class EncryptedString(TypeDecorator):
    """SQLAlchemy string type whose values are encrypted at rest.

    Values go through the given EncryptionService: encrypted when bound to a
    statement and decrypted when read back. ``None`` passes through unchanged.
    """

    impl = String
    cache_ok = True

    def __init__(self, encryption_service: EncryptionService, *args, **kwargs):
        self.encryption_service = encryption_service
        super().__init__(*args, **kwargs)

    def process_bind_param(self, value, dialect):
        """Encrypt before storing."""
        if value is None:
            return value
        return self.encryption_service.encrypt(value)

    def process_result_value(self, value, dialect):
        """Decrypt after loading."""
        if value is None:
            return value
        return self.encryption_service.decrypt(value)
# Usage in models
encryption_service = EncryptionService(settings.MASTER_KEY)


class SensitiveData(SQLModel, table=True):
    """Example model whose sensitive columns use EncryptedString."""

    id: Optional[int] = Field(primary_key=True)

    # Encrypted fields: ciphertext in the database, plain str in Python
    ssn: str = Field(sa_column=Column(EncryptedString(encryption_service)))
    credit_card: str = Field(sa_column=Column(EncryptedString(encryption_service)))

    # Regular field
    user_id: int


from passlib.context import CryptContext
import secrets
import string
# Configure password hashing (Zenith v0.0.11+ uses pwdlib with Argon2)
from pwdlib import PasswordHash
from pwdlib.hashers.argon2 import Argon2Hasher

# Argon2 cost parameters: 3 passes, memory_cost=65536 (argon2-cffi counts
# this in KiB, i.e. 64 MiB), 4 parallel lanes.
pwd_hash = PasswordHash((Argon2Hasher(time_cost=3, memory_cost=65536, parallelism=4),))

# Verify and hash passwords
hashed = pwd_hash.hash("user_password")
is_valid = pwd_hash.verify("user_password", hashed)
class PasswordValidator:
    """Validate password strength and generate strong random passwords."""

    MIN_LENGTH = 12
    REQUIRE_UPPERCASE = True
    REQUIRE_LOWERCASE = True
    REQUIRE_DIGITS = True
    REQUIRE_SPECIAL = True
    COMMON_PASSWORDS_FILE = "common_passwords.txt"

    def __init__(self, common_passwords_file: Optional[str] = None):
        """Load the common-password blocklist (one password per line).

        Args:
            common_passwords_file: Path of the blocklist; defaults to
                ``COMMON_PASSWORDS_FILE``.

        Raises:
            OSError: if the file cannot be opened.
        """
        path = common_passwords_file or self.COMMON_PASSWORDS_FILE
        # Read as UTF-8 explicitly (the platform default encoding varies) and
        # replace undecodable bytes: public leak lists are often not clean UTF-8.
        with open(path, encoding="utf-8", errors="replace") as f:
            self.common_passwords = {line.strip().lower() for line in f if line.strip()}

    def validate(self, password: str) -> tuple[bool, List[str]]:
        """Validate password strength.

        Returns:
            ``(is_valid, errors)`` where *errors* lists every failed check.
        """
        errors = []

        # Length check
        if len(password) < self.MIN_LENGTH:
            errors.append(f"Password must be at least {self.MIN_LENGTH} characters")

        # Complexity checks
        if self.REQUIRE_UPPERCASE and not any(c.isupper() for c in password):
            errors.append("Password must contain uppercase letter")

        if self.REQUIRE_LOWERCASE and not any(c.islower() for c in password):
            errors.append("Password must contain lowercase letter")

        if self.REQUIRE_DIGITS and not any(c.isdigit() for c in password):
            errors.append("Password must contain digit")

        if self.REQUIRE_SPECIAL:
            special_chars = set(string.punctuation)
            if not any(c in special_chars for c in password):
                errors.append("Password must contain special character")

        # Common password check
        if password.lower() in self.common_passwords:
            errors.append("Password is too common")

        # Entropy check
        if self._calculate_entropy(password) < 50:
            errors.append("Password is too predictable")

        return not errors, errors

    def _calculate_entropy(self, password: str) -> float:
        """Estimate password entropy in bits as ``len * log2(charset size)``.

        The charset is the union of the character classes present. This
        assumes uniformly random characters, so it overestimates the entropy
        of human-chosen passwords.
        """
        import math

        charset_size = 0
        if any(c.islower() for c in password):
            charset_size += 26
        if any(c.isupper() for c in password):
            charset_size += 26
        if any(c.isdigit() for c in password):
            charset_size += 10
        if any(c in string.punctuation for c in password):
            charset_size += len(string.punctuation)

        return len(password) * math.log2(charset_size) if charset_size > 0 else 0

    def generate_strong_password(self, length: int = 16) -> str:
        """Generate a random password of exactly *length* characters.

        At least one uppercase letter, lowercase letter, digit and punctuation
        character is always included.

        Raises:
            ValueError: if *length* is less than 4.
        """
        if length < 4:
            raise ValueError("length must be at least 4 to include every character class")

        # Ensure all character types
        password = [
            secrets.choice(string.ascii_uppercase),
            secrets.choice(string.ascii_lowercase),
            secrets.choice(string.digits),
            secrets.choice(string.punctuation),
        ]

        # Fill remaining length
        all_chars = string.ascii_letters + string.digits + string.punctuation
        password.extend(secrets.choice(all_chars) for _ in range(length - 4))

        # Shuffle so the guaranteed characters are not always first
        secrets.SystemRandom().shuffle(password)
        return "".join(password)
# Usage
validator = PasswordValidator()


@app.post("/register")
async def register(user_data: UserCreate):
    """Register a user after enforcing the password policy.

    Raises:
        ValidationError: carrying the list of failed password checks.
    """
    # Validate password
    is_valid, errors = validator.validate(user_data.password)
    if not is_valid:
        raise ValidationError(errors)

    # Hash password with the pwdlib/Argon2 `pwd_hash` configured earlier in
    # this guide (`pwd_context` is not defined in this example).
    hashed_password = pwd_hash.hash(user_data.password)

    # Create user
    user = User(
        email=user_data.email,
        password_hash=hashed_password,
    )
    # ...


from zenith import Request
from datetime import datetime, timedelta
class SecurityHeadersMiddleware:
    """Add comprehensive security headers to every response.

    NOTE(review): ``__call__(request, call_next)`` is the dispatch-style
    middleware shape, not a raw ASGI app ``(scope, receive, send)``; confirm
    how Zenith expects this class to be registered.
    """

    def __init__(self, app):
        self.app = app

    async def __call__(self, request: Request, call_next):
        response = await call_next(request)

        # Strict Transport Security
        response.headers["Strict-Transport-Security"] = (
            "max-age=63072000; "  # 2 years
            "includeSubDomains; "
            "preload"
        )

        # Content Security Policy
        response.headers["Content-Security-Policy"] = self._get_csp()

        # X-Frame-Options
        response.headers["X-Frame-Options"] = "DENY"

        # X-Content-Type-Options
        response.headers["X-Content-Type-Options"] = "nosniff"

        # Referrer-Policy (X-XSS-Protection removed - deprecated and insecure)
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"

        # Permissions-Policy: "()" denies each feature to every origin
        response.headers["Permissions-Policy"] = (
            "accelerometer=(), "
            "camera=(), "
            "geolocation=(), "
            "gyroscope=(), "
            "magnetometer=(), "
            "microphone=(), "
            "payment=(), "
            "usb=()"
        )

        # X-Permitted-Cross-Domain-Policies
        response.headers["X-Permitted-Cross-Domain-Policies"] = "none"

        # Expect-CT is intentionally not sent: the header is deprecated and
        # browsers now enforce Certificate Transparency without it.

        return response

    def _get_csp(self) -> str:
        """Build Content Security Policy with modern secure defaults.

        Directives with an empty value list (``upgrade-insecure-requests``)
        are emitted bare. ``report-uri`` is included only when
        ``settings.CSP_REPORT_URI`` is set.
        """
        directives = {
            "default-src": ["'self'"],
            # Remove unsafe-inline in production - use nonces instead
            "script-src": ["'self'"],
            "style-src": ["'self'"],
            "img-src": ["'self'", "data:", "https:"],
            "font-src": ["'self'", "data:"],
            "connect-src": ["'self'"],
            "media-src": ["'self'"],
            "object-src": ["'none'"],
            "frame-src": ["'none'"],
            "frame-ancestors": ["'none'"],
            "base-uri": ["'self'"],
            "form-action": ["'self'"],
            "upgrade-insecure-requests": [],
        }
        # Added only when configured: an empty list here would otherwise be
        # emitted as a bare "report-uri", which is an invalid directive.
        if settings.CSP_REPORT_URI:
            directives["report-uri"] = [settings.CSP_REPORT_URI]

        return "; ".join(
            f"{directive} {' '.join(values)}" if values else directive
            for directive, values in directives.items()
        )


from collections import defaultdict
from datetime import datetime, timedelta
import asyncio
from typing import Optional
import redis.asyncio as redis
class RateLimiter:
    """Advanced rate limiting with multiple strategies, backed by Redis.

    NOTE: each strategy issues several separate Redis commands, so the
    check-and-update is not atomic; under heavy concurrency a client can
    slightly exceed the limit. Use a Lua script or MULTI for strict limits.
    """

    def __init__(self, redis_client: redis.Redis):
        self.redis = redis_client

    async def check_rate_limit(
        self,
        key: str,
        limit: int,
        window: int,
        strategy: str = "sliding_window",
    ) -> tuple[bool, dict]:
        """
        Check if request is within rate limit.

        Strategies:
        - fixed_window: Simple, resets at window boundaries
        - sliding_window: More accurate, rolling window
        - token_bucket: Allows bursts

        Returns:
            ``(allowed, info)`` where *info* holds ``limit``, ``remaining``
            and ``reset`` (seconds) for rate-limit response headers.

        Raises:
            ValueError: for an unknown *strategy*.
        """
        if strategy == "fixed_window":
            return await self._fixed_window(key, limit, window)
        elif strategy == "sliding_window":
            return await self._sliding_window(key, limit, window)
        elif strategy == "token_bucket":
            return await self._token_bucket(key, limit, window)
        else:
            raise ValueError(f"Unknown strategy: {strategy}")

    async def _fixed_window(self, key: str, limit: int, window: int):
        """Fixed window: a counter that expires *window* seconds after its first hit."""
        current = await self.redis.incr(key)
        reset_time = await self.redis.ttl(key)

        # A negative TTL means the key has no expiry. Setting it here covers the
        # first request and also repairs a counter left without a TTL (e.g. the
        # process died between INCR and EXPIRE), which would block forever.
        if reset_time < 0:
            await self.redis.expire(key, window)
            reset_time = window

        return current <= limit, {
            "limit": limit,
            "remaining": max(0, limit - current),
            "reset": reset_time,
        }

    async def _sliding_window(self, key: str, limit: int, window: int):
        """Sliding window using Redis sorted sets (one member per request)."""
        import secrets
        import time

        # time.time() is a true Unix timestamp. datetime.utcnow().timestamp()
        # treats the naive UTC time as local time and is skewed by the host's
        # UTC offset.
        now = time.time()
        window_start = now - window

        # Remove old entries
        await self.redis.zremrangebyscore(key, 0, window_start)

        # Count current window
        current = await self.redis.zcard(key)

        if current < limit:
            # Members must be unique: two requests with the same timestamp
            # would otherwise collapse into one entry and be under-counted.
            member = f"{now}:{secrets.token_hex(4)}"
            await self.redis.zadd(key, {member: now})
            await self.redis.expire(key, window)
            allowed = True
        else:
            allowed = False

        remaining = max(0, limit - current - (1 if allowed else 0))

        return allowed, {
            "limit": limit,
            "remaining": remaining,
            "reset": int(window),
        }

    async def _token_bucket(self, key: str, limit: int, window: int):
        """Token bucket: up to *limit* tokens, refilled at ``limit / window`` per second."""
        import time

        bucket_key = f"{key}:bucket"
        timestamp_key = f"{key}:timestamp"

        # Get current tokens and last update
        tokens = await self.redis.get(bucket_key)
        last_update = await self.redis.get(timestamp_key)

        now = time.time()

        if tokens is None or last_update is None:
            # Initialize (or repair a half-written) bucket as full
            tokens = float(limit)
            last_update = now
        else:
            tokens = float(tokens)
            last_update = float(last_update)

        # Add tokens based on time passed
        refill_rate = limit / window  # tokens per second
        tokens = min(limit, tokens + (now - last_update) * refill_rate)

        if tokens >= 1:
            # Consume token. An idle bucket is full again after `window`
            # seconds, so its state can expire then instead of living forever.
            tokens -= 1
            ttl = max(1, int(window))
            await self.redis.set(bucket_key, tokens, ex=ttl)
            await self.redis.set(timestamp_key, now, ex=ttl)
            allowed = True
        else:
            allowed = False

        return allowed, {
            "limit": limit,
            "remaining": int(tokens),
            # Seconds until the next token; 0 when one is already available
            "reset": max(0, int((1 - tokens) * (window / limit))),
        }
# Rate limit decorator
def rate_limit(
    requests: int = 100,
    window: int = 60,
    strategy: str = "sliding_window",
    key_func=None,
):
    """Decorator for rate limiting endpoints.

    The decorated endpoint must receive the incoming request as the
    ``request`` keyword argument. If it does not, no limit is applied.

    Args:
        requests: Allowed requests per window.
        window: Window length in seconds.
        strategy: ``"fixed_window"``, ``"sliding_window"`` or ``"token_bucket"``.
        key_func: Optional ``request -> str`` bucket key; defaults to client
            host plus URL path.

    Raises:
        HTTPException: 429 when the limit is exceeded.
    """
    from functools import wraps

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            request = kwargs.get("request")
            if not request:
                return await func(*args, **kwargs)

            # Get rate limit key
            if key_func:
                key = key_func(request)
            else:
                # request.client can be None (e.g. some test clients or unix
                # sockets); such requests share one bucket instead of crashing.
                client_host = request.client.host if request.client else "unknown"
                key = f"rate_limit:{client_host}:{request.url.path}"

            # Check rate limit
            limiter = RateLimiter(redis_client)
            allowed, info = await limiter.check_rate_limit(key, requests, window, strategy)

            # Stored on request.state; this decorator itself only attaches
            # them to the 429 response below.
            request.state.rate_limit_headers = {
                "X-RateLimit-Limit": str(info["limit"]),
                "X-RateLimit-Remaining": str(info["remaining"]),
                "X-RateLimit-Reset": str(info["reset"]),
            }

            if not allowed:
                raise HTTPException(
                    status_code=429,
                    detail="Rate limit exceeded",
                    headers=request.state.rate_limit_headers,
                )

            return await func(*args, **kwargs)

        return wrapper

    return decorator
# Usage
@app.get("/api/search")
@rate_limit(requests=10, window=60)  # 10 requests per minute
async def search(query: str, request: Request):
    """Search endpoint limited to 10 requests per minute per client.

    ``request`` must be declared: rate_limit reads it from the keyword
    arguments and skips limiting entirely when it is absent.
    """
    return {"results": await perform_search(query)}


import logging
from typing import Optional
from datetime import datetime
class SecurityAuditLog:
    """Log security-related events for auditing."""

    def __init__(self, logger: logging.Logger):
        self.logger = logger

    async def log_event(
        self,
        event_type: str,
        user_id: Optional[int],
        ip_address: str,
        user_agent: str,
        details: dict,
        severity: str = "INFO",
    ):
        """Log a security event as a JSON line and pass it to ``_store_event``.

        Args:
            severity: Logging level name ("DEBUG", "INFO", "WARNING",
                "ERROR", "CRITICAL"), case-insensitive. Unknown names are
                logged at INFO instead of raising.
        """
        from datetime import timezone

        event = {
            # Timezone-aware so the ISO string carries an explicit "+00:00".
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "event_type": event_type,
            "user_id": user_id,
            "ip_address": ip_address,
            "user_agent": user_agent,
            "details": details,
            "severity": severity,
        }

        # getLevelName() maps a known level name to its int and returns a str
        # for unknown names. getattr(logging, severity) would instead return
        # e.g. the logging.info *function* for "info" and crash logger.log().
        level = logging.getLevelName(severity.upper())
        if not isinstance(level, int):
            level = logging.INFO

        # Log to file/service
        self.logger.log(level, json.dumps(event))

        # Store in database for analysis
        await self._store_event(event)

    async def _store_event(self, event: dict):
        """Store event in database (no-op placeholder)."""
        # Implementation depends on your storage
        pass
# Security events to log: event type -> human-readable description
SECURITY_EVENTS = dict(
    LOGIN_SUCCESS="Successful login",
    LOGIN_FAILED="Failed login attempt",
    LOGIN_BLOCKED="Login blocked due to rate limiting",
    PASSWORD_RESET="Password reset requested",
    PERMISSION_DENIED="Access denied to resource",
    SUSPICIOUS_ACTIVITY="Suspicious activity detected",
    API_KEY_CREATED="API key created",
    API_KEY_REVOKED="API key revoked",
    DATA_EXPORT="User data exported",
    ACCOUNT_DELETED="Account deleted",
)
# Middleware for automatic logging
@app.middleware("http")
async def security_audit_middleware(request: Request, call_next):
    """Write an audit event for every request to /auth/login.

    NOTE(review): relies on a module-level ``audit_log`` (a SecurityAuditLog)
    that this example does not create - confirm it exists.
    """
    if request.url.path != "/auth/login":
        return await call_next(request)

    response = await call_next(request)

    # Any 2xx is a success; checking only 200 would log a 201/204 login as failed.
    succeeded = 200 <= response.status_code < 300
    await audit_log.log_event(
        event_type="LOGIN_SUCCESS" if succeeded else "LOGIN_FAILED",
        user_id=None,  # Get from response if available
        # request.client can be None (e.g. some test clients or unix sockets)
        ip_address=request.client.host if request.client else "unknown",
        user_agent=request.headers.get("user-agent", ""),
        details={"path": request.url.path},
    )

    return response


import pytest
from httpx import AsyncClient
class TestSecurityHeaders:
    """Test security headers are present."""

    async def test_security_headers(self, client: AsyncClient):
        """The root page carries HSTS, CSP, anti-framing and anti-sniffing headers."""
        headers = (await client.get("/")).headers

        # Check all required headers
        for name in ("Strict-Transport-Security", "Content-Security-Policy", "X-Frame-Options"):
            assert name in headers
        assert headers["X-Frame-Options"] == "DENY"
        assert "X-Content-Type-Options" in headers
        assert headers["X-Content-Type-Options"] == "nosniff"
class TestAuthentication:
    """Test authentication security."""

    async def test_password_hashing(self):
        """Ensure passwords are properly hashed."""
        from app.auth import hash_password, verify_password

        password = "SecurePassword123!"
        hashed = hash_password(password)

        # Hash should be different from plaintext
        assert hashed != password

        # Should verify correctly
        assert verify_password(password, hashed)

        # Wrong password should fail
        assert not verify_password("WrongPassword", hashed)

    async def test_jwt_expiration(self):
        """Test JWT tokens expire."""
        from datetime import timedelta

        # Assumes app.auth exports SECRET_KEY and ALGORITHM alongside
        # create_access_token - confirm against your module layout.
        from app.auth import ALGORITHM, SECRET_KEY, create_access_token
        from jose import JWTError, jwt

        # Issue an already-expired token instead of sleeping: a blocking
        # time.sleep() inside an async test stalls the event loop and slows
        # the suite.
        token = create_access_token(
            data={"sub": "test@example.com"},
            expires_delta=timedelta(seconds=-1),
        )

        # Should raise error (ExpiredSignatureError is a JWTError)
        with pytest.raises(JWTError):
            jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
class TestVulnerabilities: """Test for common vulnerabilities."""
async def test_sql_injection_prevention(self, client: AsyncClient): """Test SQL injection is prevented.""" # Try SQL injection malicious_input = "'; DROP TABLE users; --" response = await client.get(f"/users/search?q={malicious_input}")
# Should not execute SQL assert response.status_code in [200, 400] # Normal or validation error # Database should still be intact
async def test_xss_prevention(self, client: AsyncClient): """Test XSS is prevented.""" # Try XSS payload xss_payload = "<script>alert('XSS')</script>" response = await client.post( "/comments", json={"content": xss_payload} )
# Get the comment comment_id = response.json()["id"] response = await client.get(f"/comments/{comment_id}")
# Script should be escaped or removed content = response.json()["content"] assert "<script>" not in content assert "alert(" not in content
async def test_csrf_protection(self, client: AsyncClient): """Test CSRF protection is active.""" # Try to post without CSRF token response = await client.post( "/api/sensitive-action", json={"data": "test"} )
# Should be rejected assert response.status_code == 403 assert "CSRF" in response.json()["detail"]Solution: Use short expiration, refresh token rotation, bind tokens to IP/device
Solution: Rate limiting, account lockout, CAPTCHA, monitoring
Solution: Secure cookies, session timeout, regenerate session ID
Solution: Filter sensitive data, use structured logging, secure log storage
Performance Guide
Learn about performance optimization Performance Guide →
Need help? Check our FAQ or ask in GitHub Discussions.