Skip to content

Security Guide

This comprehensive guide covers security best practices for Zenith applications—from authentication and authorization to protecting against common vulnerabilities. Learn how to build secure APIs that protect user data and maintain trust.

Security is not optional. This guide covers:

  • Authentication strategies (JWT, OAuth, API keys)
  • Authorization patterns (RBAC, ABAC, permissions)
  • Common vulnerabilities and prevention (OWASP Top 10)
  • Data protection and encryption
  • Security headers and middleware
  • Rate limiting and DDoS protection
  • Security testing and auditing

JSON Web Tokens provide stateless authentication:

from datetime import datetime, timedelta
from typing import Optional, Dict, Any
from jose import JWTError, jwt
from passlib.context import CryptContext
from zenith import Zenith
app = Zenith()

# Configuration
# IMPORTANT: Use a high-entropy signing key (at least 32 random bytes).
# Generate with: python -c 'import secrets; print(secrets.token_urlsafe(32))'
# The key is read from the environment so real secrets never live in source
# control. The fallback only keeps this example runnable locally; never deploy
# with it.
import os

SECRET_KEY = os.environ.get("SECRET_KEY", "your-secret-key-change-in-production")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7

# Password hashing (Zenith uses pwdlib with Argon2 by default)
from pwdlib import PasswordHash

pwd_hash = PasswordHash.recommended()  # Argon2id - modern, secure

# Or customize Argon2 parameters. Note: this assignment replaces the default
# above; keep only one of the two in real code.
from pwdlib.hashers.argon2 import Argon2Hasher

pwd_hash = PasswordHash((Argon2Hasher(
    time_cost=3,
    memory_cost=65536,
    parallelism=4,
),))
def create_access_token(
    data: Dict[str, Any],
    expires_delta: Optional[timedelta] = None,
) -> str:
    """
    Create a signed JWT access token.

    Access tokens should:
    - Be short-lived (15-30 minutes)
    - Contain minimal data
    - Include token type identifier

    Args:
        data: Claims to embed (e.g. ``{"sub": email}``). Not mutated.
        expires_delta: Optional lifetime override; defaults to
            ``ACCESS_TOKEN_EXPIRE_MINUTES``. Useful in tests.

    Returns:
        The encoded JWT string, carrying ``exp``, ``iat``, ``type="access"``
        and a unique ``jti`` that can be passed to ``revoke_token``.
    """
    import uuid
    from datetime import timezone

    # One timezone-aware "now" so iat and exp come from the same instant
    # (datetime.utcnow() is naive and deprecated).
    now = datetime.now(timezone.utc)
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)

    to_encode = data.copy()
    to_encode.update({
        "exp": now + expires_delta,
        "iat": now,
        "type": "access",
        # Unique token ID so an individual token can be revoked.
        "jti": uuid.uuid4().hex,
    })
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def create_refresh_token(
    data: Dict[str, Any],
    expires_delta: Optional[timedelta] = None,
) -> str:
    """
    Create a signed JWT refresh token.

    Refresh tokens should:
    - Be long-lived (days/weeks)
    - Contain only user identifier
    - Be stored securely (httpOnly cookies)

    Args:
        data: Claims to embed (ideally just ``{"sub": user_id}``). Not mutated.
        expires_delta: Optional lifetime override; defaults to
            ``REFRESH_TOKEN_EXPIRE_DAYS``.

    Returns:
        The encoded JWT string with ``exp``, ``iat``, ``type="refresh"`` and a
        unique ``jti``. The ``jti`` lets a refresh token be revoked when it is
        rotated or the user logs out; without it, a stolen refresh token stays
        valid until it expires.
    """
    import uuid
    from datetime import timezone

    now = datetime.now(timezone.utc)
    if expires_delta is None:
        expires_delta = timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)

    to_encode = data.copy()
    to_encode.update({
        "exp": now + expires_delta,
        "iat": now,
        "type": "refresh",
        "jti": uuid.uuid4().hex,
    })
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def verify_token(token: str, expected_type: str = "access") -> Dict[str, Any]:
    """
    Verify and decode a JWT.

    Checks:
    - Signature validity (only ``ALGORITHM`` is accepted, so a token cannot
      pick its own algorithm)
    - Expiration
    - Token type (``"access"`` vs ``"refresh"``)
    - Revocation status (via the ``jti`` claim)

    Returns:
        The decoded claims.

    Raises:
        AuthenticationError: if any check fails. The underlying ``JWTError``
            is chained as ``__cause__`` for debugging.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        # Prevents e.g. a refresh token being used as an access token.
        if payload.get("type") != expected_type:
            raise JWTError("Invalid token type")
        if is_token_revoked(payload.get("jti")):
            raise JWTError("Token has been revoked")
        return payload
    except JWTError as e:
        raise AuthenticationError(f"Invalid token: {e}") from e
# Token revocation (blacklist).
# In-memory only: entries vanish on restart and are not shared between worker
# processes. Use Redis in production, with a TTL matching the token lifetime.
revoked_tokens = set()


def revoke_token(token_id: str):
    """Mark the token whose ``jti`` claim equals *token_id* as revoked."""
    # In production, store in Redis with expiry
    revoked_tokens.update((token_id,))


def is_token_revoked(token_id: str) -> bool:
    """Return True if *token_id* has previously been passed to ``revoke_token``."""
    revoked = token_id in revoked_tokens
    return revoked

Implement OAuth for third-party authentication:

from authlib.integrations.starlette_client import OAuth
from app.models import User
# Configure OAuth
# Registers a client named "google" on the Authlib OAuth registry. Google's
# OpenID discovery document (server_metadata_url) supplies the authorize,
# token and userinfo endpoints. Credentials come from settings, so keep them
# out of source control.
oauth = OAuth()
oauth.register(
    name='google',
    client_id=settings.GOOGLE_CLIENT_ID,
    client_secret=settings.GOOGLE_CLIENT_SECRET,
    server_metadata_url='https://accounts.google.com/.well-known/openid-configuration',
    # "openid" makes this an OpenID Connect flow; the callback below reads the
    # user's identity from token['userinfo'].
    client_kwargs={'scope': 'openid email profile'}
)
@app.get("/auth/google")
async def google_login(request: Request):
    """Begin Google sign-in by redirecting the browser to Google.

    Google sends the user back to the route named ``google_callback``.
    """
    callback_url = request.url_for('google_callback')
    response = await oauth.google.authorize_redirect(request, callback_url)
    return response
@app.get("/auth/google/callback")
async def google_callback(request: Request, session: AsyncSession = Depends(get_session)):
    """Handle the Google OAuth callback and issue a Zenith access token.

    Exchanges the authorization code, finds the local user by email (creating
    one on first login) and returns a bearer access token.

    Raises:
        AuthenticationError: the code exchange fails, Google returns no user
            info, the Google account's email is not verified, or any other
            error occurs while completing the login.
    """
    try:
        # Exchange code for token
        token = await oauth.google.authorize_access_token(request)
        user_info = token.get('userinfo')
        if not user_info:
            raise AuthenticationError("Failed to get user info")
        # Only trust the address if Google says it is verified. Otherwise an
        # unverified email could be matched to (and log into) an existing
        # account, and the new user below would be wrongly marked verified.
        if not user_info.get('email_verified'):
            raise AuthenticationError("Google account email is not verified")

        # session.exec() is a coroutine: await it first, then call .first()
        # on the result.
        result = await session.exec(
            select(User).where(User.email == user_info['email'])
        )
        user = result.first()
        if not user:
            # Create new user from OAuth
            user = User(
                email=user_info['email'],
                name=user_info.get('name'),
                avatar_url=user_info.get('picture'),
                oauth_provider='google',
                oauth_id=user_info['sub'],
                is_verified=True,  # email_verified checked above
            )
            session.add(user)
            await session.commit()
            # Reload so the generated user.id is available. If the session
            # expires objects on commit (SQLAlchemy's default), reading it
            # would otherwise trigger a lazy load.
            await session.refresh(user)

        access_token = create_access_token(
            data={"sub": user.email, "user_id": user.id}
        )
        return {"access_token": access_token, "token_type": "bearer"}
    except AuthenticationError:
        # Already a meaningful authentication error; don't re-wrap it.
        raise
    except Exception as e:
        raise AuthenticationError(f"OAuth failed: {e}") from e

For machine-to-machine authentication:

import secrets
import hashlib
from datetime import datetime
class APIKey(SQLModel, table=True):
    """API key model for service authentication.

    Only a SHA-256 hash of each key is stored; the plaintext is shown to the
    caller once, when it is generated. ``prefix`` stores the first 10
    characters of the plaintext key ("zk_" plus 7 random chars). The
    ``get_api_key`` dependency looks keys up by ``api_key[:10]`` before doing
    the constant-time hash comparison.
    """
    id: Optional[int] = Field(primary_key=True)
    name: str = Field(index=True)
    key_hash: str = Field(unique=True)  # Store hash, not plain key
    prefix: str = Field(index=True)  # First 10 chars of the key, for lookup
    scopes: List[str] = Field(sa_column_kwargs={"type_": JSON})
    rate_limit: int = Field(default=1000)  # Requests per hour
    expires_at: Optional[datetime]
    last_used_at: Optional[datetime]
    created_at: datetime = Field(default_factory=datetime.utcnow)
    revoked: bool = Field(default=False)

    @staticmethod
    def hash_key(key: str) -> str:
        """Return the hex SHA-256 digest used to store and compare keys."""
        # A fast hash is fine here (unlike passwords): keys carry 32 bytes of
        # randomness, so brute-forcing the hash is infeasible.
        return hashlib.sha256(key.encode()).hexdigest()

    @classmethod
    def generate_key(cls) -> tuple[str, str]:
        """Generate a new API key.

        Returns:
            ``(key, key_hash)``. Show ``key`` to the user once; persist
            ``key_hash`` and ``key[:10]`` (as ``prefix``).
        """
        raw_key = secrets.token_urlsafe(32)
        key = f"zk_{raw_key}"  # Prefix for identification
        return key, cls.hash_key(key)

    def verify_key(self, provided_key: str) -> bool:
        """Return True if *provided_key* hashes to the stored hash (constant-time)."""
        return secrets.compare_digest(self.key_hash, self.hash_key(provided_key))
# API Key authentication dependency
async def get_api_key(
    api_key: str = Header(..., alias="X-API-Key"),
    session: AsyncSession = Depends(get_session)
) -> APIKey:
    """Resolve the ``X-API-Key`` header to a non-revoked APIKey row.

    Candidates are narrowed by the key's first 10 characters, then checked
    with ``APIKey.verify_key``. On success, ``last_used_at`` is updated and
    committed.

    Raises:
        HTTPException(401): bad format, no matching key, or key expired.
    """
    if not api_key.startswith("zk_"):
        raise HTTPException(status_code=401, detail="Invalid API key format")

    query = select(APIKey).where(
        APIKey.prefix == api_key[:10],
        APIKey.revoked == False,  # noqa: E712 - SQL expression, not a bool test
    )
    candidates = (await session.exec(query)).all()

    match = next((cand for cand in candidates if cand.verify_key(api_key)), None)
    if match is None:
        raise HTTPException(status_code=401, detail="Invalid API key")

    if match.expires_at and match.expires_at < datetime.utcnow():
        raise HTTPException(status_code=401, detail="API key expired")

    match.last_used_at = datetime.utcnow()
    await session.commit()
    return match
# Usage
@app.get("/api/data")
async def get_data(api_key: APIKey = Depends(get_api_key)):
    """Return protected data; the API key must include the ``read:data`` scope."""
    required_scope = "read:data"
    if required_scope in api_key.scopes:
        return {"data": "sensitive information"}
    raise HTTPException(status_code=403, detail="Insufficient permissions")
from enum import Enum
from typing import List, Set
class Role(str, Enum):
    """User roles with hierarchy.

    Subclasses ``str``, so members compare equal to their string values
    (``Role.ADMIN == "admin"``). Relative rank is defined in
    ``ROLE_HIERARCHY``; granted permissions in ``ROLE_PERMISSIONS``.
    """
    SUPER_ADMIN = "super_admin"
    ADMIN = "admin"
    MODERATOR = "moderator"
    USER = "user"
    GUEST = "guest"
# Role hierarchy
# Numeric rank per role. User.has_role() grants access when the user's rank is
# >= the required role's rank, so higher numbers are more privileged.
ROLE_HIERARCHY = {
    Role.SUPER_ADMIN: 100,
    Role.ADMIN: 80,
    Role.MODERATOR: 60,
    Role.USER: 40,
    Role.GUEST: 20
}
# Role permissions
# Permission strings have the form "resource:action" or "resource:action:own".
# "*" grants everything. Other entries may contain shell-style wildcards
# ("posts:*"), which User.has_permission() matches against the requested
# permission; note "posts:*" also matches "posts:write:own".
# NOTE(review): nothing in this guide enforces the ":own" scope; handlers must
# check resource ownership themselves.
ROLE_PERMISSIONS = {
    Role.SUPER_ADMIN: ["*"],  # All permissions
    Role.ADMIN: [
        "users:*",
        "posts:*",
        "comments:*",
        "settings:read",
        "settings:write"
    ],
    Role.MODERATOR: [
        "posts:read",
        "posts:write",
        "posts:delete",
        "comments:*",
        "users:read"
    ],
    Role.USER: [
        "posts:read",
        "posts:write:own",
        "comments:read",
        "comments:write:own",
        "users:read:own",
        "users:write:own"
    ],
    Role.GUEST: [
        "posts:read",
        "comments:read"
    ]
}
class User(SQLModel, table=True):
    """User model with a role plus optional per-user permission grants."""
    id: Optional[int] = Field(primary_key=True)
    email: str = Field(unique=True)
    role: Role = Field(default=Role.USER)
    custom_permissions: List[str] = Field(sa_column_kwargs={"type_": JSON}, default_factory=list)
    is_active: bool = Field(default=True)

    def has_permission(self, permission: str) -> bool:
        """Return True if this user is granted *permission*.

        Inactive users have no permissions. Otherwise the checks are:
        1. an exact match in ``custom_permissions``;
        2. the role's permissions, where ``"*"`` grants everything and
           entries may use shell-style wildcards (``"posts:*"`` matches
           ``"posts:write"``). Matching is case-sensitive.
        """
        if not self.is_active:
            return False
        if permission in self.custom_permissions:
            return True
        role_perms = ROLE_PERMISSIONS.get(self.role, [])
        if permission in role_perms or "*" in role_perms:
            return True
        return any(self._matches_wildcard(permission, perm) for perm in role_perms)

    def _matches_wildcard(self, permission: str, pattern: str) -> bool:
        """Case-sensitive shell-style match of *permission* against *pattern*."""
        import fnmatch
        # fnmatchcase, not fnmatch: fnmatch applies os.path.normcase, which is
        # case-insensitive on Windows. With it, "USERS:WRITE" would match
        # "users:*" on one platform but not on another.
        return fnmatch.fnmatchcase(permission, pattern)

    def has_role(self, required_role: Role) -> bool:
        """Return True if the user's role ranks at or above *required_role*.

        Unknown user roles rank 0. An unknown required role ranks 100, so only
        a role of that rank (SUPER_ADMIN) passes.
        """
        user_level = ROLE_HIERARCHY.get(self.role, 0)
        required_level = ROLE_HIERARCHY.get(required_role, 100)
        return user_level >= required_level
# Permission decorator
def require_permission(permission: str):
    """Decorator factory that enforces *permission* on an async endpoint.

    The decorated endpoint must receive the authenticated user as the keyword
    argument ``current_user`` (e.g. ``current_user: User =
    Depends(get_current_user)``). ``functools.wraps`` exposes the original
    function via ``__wrapped__``, so ``inspect.signature``-based dependency
    injection still sees that parameter.

    Raises:
        HTTPException(403): ``current_user`` is missing or lacks *permission*.
    """
    # The guide never imported `wraps`; without this, decoration raised NameError.
    from functools import wraps

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Get current user from request context
            user = kwargs.get('current_user')
            if not user or not user.has_permission(permission):
                raise HTTPException(
                    status_code=403,
                    detail=f"Permission denied: {permission} required"
                )
            return await func(*args, **kwargs)
        return wrapper
    return decorator
# Usage
@app.post("/admin/users")
@require_permission("users:write")
async def create_user(
    user_data: UserCreate,
    current_user: User = Depends(get_current_user)
):
    """Create a user; the ``users:write`` permission is enforced by the decorator.

    ``current_user`` is consumed by ``require_permission``. The handler only
    forwards the payload to the service layer.
    """
    created = await create_user_service(user_data)
    return created

More flexible authorization based on attributes:

from typing import Dict, Any
import json
class Policy:
    """Attribute-based access control (ABAC) policy.

    A policy is an ordered list of rules. Each rule has:

    - ``action``: the action it applies to, or ``"*"``;
    - ``effect``: ``"allow"`` (the default) or anything else, which denies;
    - ``conditions``: a list that must all hold.

    The first rule whose action and conditions match decides the outcome. If
    no rule matches, access is denied.

    A condition is ``{"attribute": path, "operator": op, "value": v}``.
    ``path`` is a dotted path rooted at ``subject.``, ``resource.`` or
    ``context.``. ``v`` is either a literal or another such path (e.g.
    ``"resource.owner_id"``); a path value is resolved before comparing, so
    attributes can be compared with each other. Missing attributes,
    unknown operators and type mismatches all evaluate to False (deny).
    """

    _ROOTS = ("subject.", "resource.", "context.")

    _OPERATORS = {
        "equals": lambda actual, value: actual == value,
        "not_equals": lambda actual, value: actual != value,
        "contains": lambda actual, value: value in actual,
        "in": lambda actual, value: actual in value,
        "greater_than": lambda actual, value: actual > value,
        "less_than": lambda actual, value: actual < value,
    }

    def __init__(self, name: str, rules: List[Dict[str, Any]]):
        self.name = name
        self.rules = rules

    def evaluate(self, subject: Dict, resource: Dict, action: str, context: Dict = None) -> bool:
        """Return True if the first rule matching *action* allows it, else False."""
        for rule in self.rules:
            if self._evaluate_rule(rule, subject, resource, action, context):
                return rule.get("effect", "allow") == "allow"
        return False

    def _evaluate_rule(self, rule, subject, resource, action, context):
        """Return True if *rule* targets *action* and all of its conditions hold."""
        if rule.get("action") not in (action, "*"):
            return False
        return all(
            self._evaluate_condition(condition, subject, resource, context)
            for condition in rule.get("conditions", [])
        )

    def _resolve(self, path, subject, resource, context):
        """Resolve a rooted dotted *path*.

        Returns ``(True, value)`` if *path* is a string starting with one of
        ``_ROOTS``. Otherwise returns ``(False, None)``.
        """
        if not isinstance(path, str):
            return False, None
        sources = {"subject.": subject, "resource.": resource, "context.": context or {}}
        for prefix in self._ROOTS:
            if path.startswith(prefix):
                return True, self._get_nested(sources[prefix], path[len(prefix):])
        return False, None

    def _evaluate_condition(self, condition, subject, resource, context):
        """Evaluate one condition; anything that cannot be evaluated is False."""
        is_path, actual = self._resolve(condition.get("attribute"), subject, resource, context)
        if not is_path:
            return False

        value = condition.get("value")
        value_is_path, referenced = self._resolve(value, subject, resource, context)
        if value_is_path:
            # Attribute-to-attribute comparison (e.g. subject.id vs
            # resource.owner_id). Fail closed if either side is missing;
            # otherwise None == None would grant access.
            if actual is None or referenced is None:
                return False
            value = referenced

        op = self._OPERATORS.get(condition.get("operator"))
        if op is None:
            return False
        try:
            return bool(op(actual, value))
        except TypeError:
            # e.g. None > 5, "a" < 3, or 3 in None: never grant on a mismatch.
            return False

    def _get_nested(self, obj, path):
        """Walk a dotted *path* through nested dicts; None if any step is missing."""
        for key in path.split("."):
            if not isinstance(obj, dict):
                return None
            obj = obj.get(key)
            if obj is None:
                return None
        return obj
# Example policies
# Rules are checked in order and the first match decides (Policy.evaluate):
#   1. anyone may read a public document;
#   2. a document's owner may read it;
#   3. the owner may write it while it is not locked.
# NOTE(review): rules 2 and 3 use the string "resource.owner_id" as the
# comparison value, intending an attribute-to-attribute check. That only works
# if Policy._evaluate_condition resolves such strings as attribute references.
# If it compares them literally, subject.id never equals the string and owners
# are denied. Confirm against the Policy implementation.
document_policy = Policy(
    name="document_access",
    rules=[
        {
            "action": "read",
            "effect": "allow",
            "conditions": [
                {"attribute": "resource.public", "operator": "equals", "value": True}
            ]
        },
        {
            "action": "read",
            "effect": "allow",
            "conditions": [
                {"attribute": "subject.id", "operator": "equals", "value": "resource.owner_id"}
            ]
        },
        {
            "action": "write",
            "effect": "allow",
            "conditions": [
                {"attribute": "subject.id", "operator": "equals", "value": "resource.owner_id"},
                {"attribute": "resource.locked", "operator": "equals", "value": False}
            ]
        }
    ]
)
# Usage
@app.get("/documents/{document_id}")
async def get_document(
    document_id: int,
    current_user: User = Depends(get_current_user),
    session: AsyncSession = Depends(get_session)
):
    """Return a document if ``document_policy`` lets the current user read it.

    Raises:
        HTTPException(404): no document with this id.
        HTTPException(403): the policy denies access.
    """
    document = await session.get(Document, document_id)
    # session.get() returns None for a missing row. Without this check, the
    # attribute access below raised AttributeError and produced a 500.
    if document is None:
        raise HTTPException(status_code=404, detail="Document not found")

    # ABAC evaluation
    can_access = document_policy.evaluate(
        subject={"id": current_user.id, "role": current_user.role},
        resource={"id": document.id, "owner_id": document.owner_id, "public": document.is_public},
        action="read"
    )
    if not can_access:
        raise HTTPException(status_code=403, detail="Access denied")
    return document

Zenith's validated query builders protect against SQL injection by design. When you must write raw SQL, always pass user input as bound parameters, never via string formatting.

# VULNERABLE: String concatenation
# Anti-example only. `query` is interpolated directly into the SQL text, so
# input such as  ' OR '1'='1  changes the statement's meaning (SQL injection).
# Never register this route in a real app. It also uses the same path as
# search_users_safe below; with a router that matches routes in registration
# order (as Starlette does), it would shadow the safe handler.
@app.get("/users/search")
async def search_users_vulnerable(query: str, session: AsyncSession = Depends(get_session)):
    # NEVER DO THIS!
    sql = f"SELECT * FROM users WHERE email = '{query}'"
    result = await session.exec(text(sql))
    return result.all()
# SAFE: ZenithModel with validated query builder
@app.get("/users/search")
async def search_users_safe(email: str):
    """Find users by exact email using the query builder.

    The email is passed as a value, never spliced into SQL; QueryBuilder
    validates column names and prevents injection.
    """
    return await User.where(email=email).all()
# SAFE: Direct column access (also validated)
@app.get("/users")
async def get_users(sort: str = "created_at"):
    """Return up to 10 active users ordered by the column named in *sort*.

    ``order_by()`` validates column names and raises ``ValueError`` for
    unknown ones; that is reported to the client as a 400.
    """
    try:
        # NOTE(review): the "-" prefix presumably means descending order
        # (Django-style); confirm against QueryBuilder.order_by.
        users = await User.where(active=True).order_by(f"-{sort}").limit(10)
    except ValueError as e:
        # Invalid column name provided; chain to keep the original traceback.
        raise HTTPException(400, detail=str(e)) from e
    return users
def _escape_like(value: str, escape_char: str = "!") -> str:
    """Escape LIKE metacharacters in *value* so they match literally.

    Must be paired with ``ESCAPE '<escape_char>'`` in the SQL. ``!`` is used
    instead of backslash because backslash is itself special in some
    dialects' string literals (e.g. MySQL).
    """
    # Escape the escape character first so later replacements aren't doubled.
    return (
        value.replace(escape_char, escape_char * 2)
        .replace("%", escape_char + "%")
        .replace("_", escape_char + "_")
    )


# SAFE: With raw SQL
@app.get("/users/advanced-search")
async def advanced_search(
    email: str,
    name: str,
    session: AsyncSession = Depends(get_session)
):
    """Search users by exact email and a name substring, using parameterized SQL.

    Both inputs are bound parameters. ``%`` and ``_`` in *name* are escaped,
    so users cannot inject LIKE wildcards (e.g. ``%`` to match every name).
    """
    stmt = text("""
        SELECT * FROM users
        WHERE email = :email
        AND name LIKE :name ESCAPE '!'
    """)
    # SQLModel's exec() takes bind parameters via the keyword-only `params`
    # argument; passing the dict positionally raises TypeError.
    result = await session.exec(
        stmt,
        params={"email": email, "name": f"%{_escape_like(name)}%"},
    )
    return result.all()
import bleach
from markupsafe import escape
# Input sanitization
def sanitize_html(content: str) -> str:
    """Sanitize user-supplied HTML to prevent XSS.

    Only a small allow-list of formatting tags survives; other tags are
    stripped, but their text is kept. ``href`` values on links are limited to
    http/https/mailto, which removes ``javascript:`` URLs. Finally, bare URLs
    in the cleaned text are turned into links.

    ``None`` is treated as empty content and returns ``""``; bleach would
    otherwise raise ``TypeError``.
    """
    if content is None:
        return ""
    # Allow only safe tags and attributes
    allowed_tags = ['p', 'br', 'strong', 'em', 'a', 'ul', 'ol', 'li']
    allowed_attributes = {'a': ['href', 'title']}
    # Same as bleach's default, spelled out so the URL-scheme policy is visible.
    allowed_protocols = ['http', 'https', 'mailto']
    cleaned = bleach.clean(
        content,
        tags=allowed_tags,
        attributes=allowed_attributes,
        protocols=allowed_protocols,
        strip=True
    )
    # linkify validates nothing: it only converts bare URLs in the
    # already-cleaned text into <a> tags (adding rel="nofollow" by default).
    return bleach.linkify(cleaned)
# Content Security Policy
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
    """Attach a Content-Security-Policy and other hardening headers to every response."""
    response = await call_next(request)
    # Content Security Policy.
    # script-src deliberately omits 'unsafe-inline'. Allowing inline scripts
    # lets an injected <script> run and defeats the CSP's XSS protection; use
    # nonces or external files instead.
    # NOTE: allow-listing a whole public CDN lets any script hosted there
    # run; prefer pinning exact URLs with Subresource Integrity.
    response.headers["Content-Security-Policy"] = (
        "default-src 'self'; "
        "script-src 'self' https://cdn.jsdelivr.net; "
        "style-src 'self' 'unsafe-inline'; "
        "img-src 'self' data: https:; "
        "font-src 'self' data:; "
        "connect-src 'self'; "
        "object-src 'none'; "
        "frame-ancestors 'none'; "
        "base-uri 'self'; "
        "form-action 'self'"
    )
    # Other security headers
    response.headers["X-Content-Type-Options"] = "nosniff"
    response.headers["X-Frame-Options"] = "DENY"
    # Note: X-XSS-Protection removed - deprecated and can create vulnerabilities
    response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
    return response
# Output encoding
@app.get("/user/{user_id}/profile")
async def get_user_profile(user_id: int):
    """Return a user's profile with user-controlled fields made safe for HTML.

    - ``name`` is HTML-escaped.
    - ``bio`` is sanitized (a limited set of tags is allowed).
    - ``website`` is escaped and returned only for http(s) URLs. Escaping
      alone does not stop a ``javascript:`` URL from running if a client puts
      it in an ``href``.

    Missing fields stay ``None``; ``markupsafe.escape(None)`` would return the
    string ``"None"``.

    Raises:
        HTTPException(404): the user does not exist.
    """
    from urllib.parse import urlsplit

    user = await get_user(user_id)
    # NOTE(review): assumes get_user returns None for unknown ids; confirm.
    if user is None:
        raise HTTPException(status_code=404, detail="User not found")

    def _escape_optional(value):
        return None if value is None else escape(value)

    website = user.website
    if website is not None:
        try:
            scheme = urlsplit(website).scheme.lower()
        except ValueError:  # malformed URL, e.g. an unbalanced "[" in the host
            scheme = ""
        if scheme not in ("http", "https"):
            website = None

    # Always escape user input when rendering
    return {
        "name": _escape_optional(user.name),
        "bio": None if user.bio is None else sanitize_html(user.bio),
        "website": _escape_optional(website),
    }

Cross-Site Request Forgery (CSRF) Prevention
import secrets
from zenith import Form
class CSRFProtection:
    """Synchronizer-token CSRF protection helper.

    A random token is generated per session and must be echoed back by the
    client (header or form field). A state-changing request is accepted only
    when the echoed token matches the one stored server-side.
    """

    def __init__(self, secret_key: str):
        # Kept for API compatibility. Tokens are random and not signed, so
        # the key is currently unused.
        self.secret_key = secret_key

    def generate_token(self) -> str:
        """Return a new URL-safe token with 32 bytes of randomness."""
        return secrets.token_urlsafe(32)

    def validate_token(self, token: str, session_token: str) -> bool:
        """Compare the submitted and stored tokens in constant time.

        Returns False, instead of raising, when either value is missing or not
        a string. Both are compared as UTF-8 bytes because
        ``secrets.compare_digest`` raises ``TypeError`` for ``str`` arguments
        containing non-ASCII characters. The submitted token is
        attacker-controlled, so that error would otherwise surface as a 500.
        """
        if not isinstance(token, str) or not isinstance(session_token, str):
            return False
        return secrets.compare_digest(token.encode("utf-8"), session_token.encode("utf-8"))
# CSRF middleware
csrf_protection = CSRFProtection(settings.SECRET_KEY)

# Media types an HTML <form> can submit. For these, the token may arrive as
# the `csrf_token` form field instead of the X-CSRF-Token header.
_CSRF_FORM_MEDIA_TYPES = ("application/x-www-form-urlencoded", "multipart/form-data")


@app.middleware("http")
async def csrf_middleware(request: Request, call_next):
    """Reject state-changing requests whose CSRF token doesn't match the session's.

    GET, HEAD and OPTIONS pass through. For other methods, the token is read
    from the ``X-CSRF-Token`` header or, for form submissions, the
    ``csrf_token`` field. It is compared with ``request.session["csrf_token"]``,
    so a session must be available on the request.
    """
    # Skip for safe methods
    if request.method in ["GET", "HEAD", "OPTIONS"]:
        return await call_next(request)

    token = request.headers.get("X-CSRF-Token")
    # Compare only the media type. Browsers send e.g.
    # "application/x-www-form-urlencoded; charset=UTF-8" or a multipart
    # boundary parameter, which an exact string comparison would miss.
    media_type = request.headers.get("content-type", "").split(";", 1)[0].strip().lower()
    if not token and media_type in _CSRF_FORM_MEDIA_TYPES:
        # NOTE(review): reading the form here consumes the body; confirm the
        # framework lets the endpoint read it again.
        form = await request.form()
        token = form.get("csrf_token")

    session_token = request.session.get("csrf_token")
    # isinstance: a multipart field named csrf_token could be an upload object.
    if (
        not isinstance(token, str)
        or not token
        or not session_token
        or not csrf_protection.validate_token(token, session_token)
    ):
        return JSONResponse(
            status_code=403,
            content={"detail": "CSRF token validation failed"}
        )
    return await call_next(request)
# Usage in forms
@app.get("/form")
async def get_form(request: Request):
    """Serve a demo HTML form that carries a fresh CSRF token.

    The token is saved in the session under ``csrf_token`` and embedded as a
    hidden input, so ``csrf_middleware`` can compare them on submit. Each GET
    replaces the previously issued token.
    """
    csrf_token = csrf_protection.generate_token()
    request.session["csrf_token"] = csrf_token
    page = f"""
<form method="post" action="/submit">
<input type="hidden" name="csrf_token" value="{csrf_token}">
<input type="text" name="data">
<button type="submit">Submit</button>
</form>
"""
    return HTMLResponse(page)
import base64
import json
from typing import Any

from cryptography.fernet import Fernet
from cryptography.hazmat.primitives import hashes
# The PBKDF2 key-derivation class in `cryptography` is PBKDF2HMAC; importing
# `PBKDF2` fails with ImportError.
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC


class EncryptionService:
    """Symmetric (Fernet) encryption for sensitive values.

    The Fernet key is derived from *master_key* with PBKDF2-HMAC-SHA256. Data
    can only be decrypted with the same ``master_key``, ``salt`` and
    ``iterations`` used to encrypt it, so treat all three as persistent
    configuration.
    """

    def __init__(self, master_key: str, salt: bytes = b'stable_salt', iterations: int = 100000):
        """
        Args:
            master_key: Secret from configuration; never hard-code it.
            salt: KDF salt. The default is a fixed, publicly known value kept
                for backward compatibility. In production, generate a random
                salt once (e.g. ``os.urandom(16)``), store it with your
                configuration, and pass it here.
            iterations: PBKDF2 work factor. Higher values are slower to
                brute-force; changing it changes the derived key.
        """
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=32,  # Fernet keys are 32 bytes before base64 encoding
            salt=salt,
            iterations=iterations,
        )
        key = base64.urlsafe_b64encode(kdf.derive(master_key.encode()))
        self.fernet = Fernet(key)

    def encrypt(self, data: str) -> str:
        """Encrypt a string; returns a URL-safe base64 Fernet token as ``str``."""
        return self.fernet.encrypt(data.encode()).decode()

    def decrypt(self, encrypted_data: str) -> str:
        """Decrypt a token produced by :meth:`encrypt`.

        Raises:
            cryptography.fernet.InvalidToken: wrong key, or data was tampered with.
        """
        return self.fernet.decrypt(encrypted_data.encode()).decode()

    def encrypt_field(self, value: Any) -> str:
        """JSON-serialize *value*, then encrypt it for database storage."""
        json_str = json.dumps(value)
        return self.encrypt(json_str)

    def decrypt_field(self, encrypted_value: str) -> Any:
        """Decrypt a value from :meth:`encrypt_field` and JSON-decode it."""
        json_str = self.decrypt(encrypted_value)
        return json.loads(json_str)
# Custom encrypted field
class EncryptedString(TypeDecorator):
    """SQLAlchemy column type that transparently encrypts ``str`` values.

    Values pass through the supplied EncryptionService on the way into the
    database (encrypt) and on the way out (decrypt). ``None`` is passed
    through untouched in both directions.
    """

    impl = String
    cache_ok = True

    def __init__(self, encryption_service: EncryptionService, *args, **kwargs):
        self.encryption_service = encryption_service
        super().__init__(*args, **kwargs)

    def process_bind_param(self, value, dialect):
        """Encrypt before storing."""
        return None if value is None else self.encryption_service.encrypt(value)

    def process_result_value(self, value, dialect):
        """Decrypt after loading."""
        return None if value is None else self.encryption_service.decrypt(value)
# Usage in models
# A single service keyed from settings.MASTER_KEY, shared by the encrypted
# columns below.
encryption_service = EncryptionService(settings.MASTER_KEY)
class SensitiveData(SQLModel, table=True):
    """Example table whose PII columns are stored encrypted.

    ``ssn`` and ``credit_card`` go through EncryptedString, so they are
    encrypted on write and decrypted on read.
    NOTE(review): if the ciphertext is randomized (Fernet's is), equality
    filters such as ``where(SensitiveData.ssn == x)`` will not match. Store a
    separate keyed hash if lookups are needed; confirm before relying on this.
    """
    id: Optional[int] = Field(primary_key=True)
    # Encrypted fields
    ssn: str = Field(sa_column=Column(EncryptedString(encryption_service)))
    credit_card: str = Field(sa_column=Column(EncryptedString(encryption_service)))
    # Regular field
    user_id: int
from passlib.context import CryptContext
import secrets
import string
# Configure password hashing (Zenith v0.0.11+ uses pwdlib with Argon2)
from pwdlib import PasswordHash
from pwdlib.hashers.argon2 import Argon2Hasher
# Argon2 hasher with explicit cost parameters. All three raise the cost of
# each guess for an attacker (and of each login for you), so tune them for
# your hardware.
pwd_hash = PasswordHash((
    Argon2Hasher(
        time_cost=3,
        memory_cost=65536,  # NOTE(review): KiB in argon2-cffi (i.e. 64 MiB); confirm
        parallelism=4
    ),
))
# Hash a password once (store `hashed`), then verify a login attempt against it.
hashed = pwd_hash.hash("user_password")
is_valid = pwd_hash.verify("user_password", hashed)
class PasswordValidator:
    """Validate password strength and generate strong random passwords.

    A password must pass a minimum length, character-class requirements, a
    common-password blocklist (case-insensitive) and a rough entropy
    estimate. By default the blocklist is read from ``COMMON_PASSWORDS_FILE``,
    one password per line.
    """

    MIN_LENGTH = 12
    REQUIRE_UPPERCASE = True
    REQUIRE_LOWERCASE = True
    REQUIRE_DIGITS = True
    REQUIRE_SPECIAL = True
    COMMON_PASSWORDS_FILE = "common_passwords.txt"

    def __init__(self, common_passwords: Optional[Iterable[str]] = None):
        """
        Args:
            common_passwords: Optional in-memory blocklist. When omitted,
                ``COMMON_PASSWORDS_FILE`` is read relative to the current
                working directory. A missing file raises
                ``FileNotFoundError`` rather than silently disabling the check.
        """
        if common_passwords is None:
            # Explicit encoding (not the platform default); leaked-password
            # lists often contain invalid UTF-8, so replace bad bytes instead
            # of crashing.
            with open(self.COMMON_PASSWORDS_FILE, encoding="utf-8", errors="replace") as f:
                common_passwords = f.read().splitlines()
        # Normalize once; blank lines are ignored.
        self.common_passwords = {p.strip().lower() for p in common_passwords if p.strip()}

    def validate(self, password: str) -> tuple[bool, List[str]]:
        """Return ``(is_valid, errors)``, where *errors* lists every failed rule."""
        errors = []
        if len(password) < self.MIN_LENGTH:
            errors.append(f"Password must be at least {self.MIN_LENGTH} characters")
        if self.REQUIRE_UPPERCASE and not any(c.isupper() for c in password):
            errors.append("Password must contain uppercase letter")
        if self.REQUIRE_LOWERCASE and not any(c.islower() for c in password):
            errors.append("Password must contain lowercase letter")
        if self.REQUIRE_DIGITS and not any(c.isdigit() for c in password):
            errors.append("Password must contain digit")
        if self.REQUIRE_SPECIAL and not any(c in string.punctuation for c in password):
            errors.append("Password must contain special character")
        if password.lower() in self.common_passwords:
            errors.append("Password is too common")
        # Upper-bound estimate (length x log2(charset size)); it rejects short,
        # low-variety passwords but cannot detect patterns like "Aaaa1!...".
        if self._calculate_entropy(password) < 50:
            errors.append("Password is too predictable")
        return not errors, errors

    def _calculate_entropy(self, password: str) -> float:
        """Estimate entropy in bits as ``len(password) * log2(charset size)``."""
        import math
        charset_size = 0
        if any(c.islower() for c in password):
            charset_size += 26
        if any(c.isupper() for c in password):
            charset_size += 26
        if any(c.isdigit() for c in password):
            charset_size += 10
        if any(c in string.punctuation for c in password):
            charset_size += len(string.punctuation)
        return len(password) * math.log2(charset_size) if charset_size > 0 else 0

    def generate_strong_password(self, length: int = 16) -> str:
        """Return a random password of exactly *length* characters.

        The result contains at least one uppercase letter, lowercase letter,
        digit and punctuation character, drawn with the ``secrets`` CSPRNG.

        Raises:
            ValueError: *length* is too small to fit one of each required class.
        """
        required_classes = [
            string.ascii_uppercase,
            string.ascii_lowercase,
            string.digits,
            string.punctuation,
        ]
        if length < len(required_classes):
            raise ValueError(f"length must be at least {len(required_classes)}")
        password = [secrets.choice(chars) for chars in required_classes]
        all_chars = string.ascii_letters + string.digits + string.punctuation
        password += [secrets.choice(all_chars) for _ in range(length - len(required_classes))]
        # Shuffle so the guaranteed characters aren't always in front.
        secrets.SystemRandom().shuffle(password)
        return ''.join(password)
# Usage
validator = PasswordValidator()


@app.post("/register")
async def register(user_data: UserCreate):
    """Register a new user after enforcing the password policy.

    Raises:
        ValidationError: the password fails PasswordValidator; the error
            carries the list of failed rules.
    """
    is_valid, errors = validator.validate(user_data.password)
    if not is_valid:
        raise ValidationError(errors)
    # Hash with the Argon2 `pwd_hash` defined in this guide; no `pwd_context`
    # exists here. Never store the plaintext password.
    hashed_password = pwd_hash.hash(user_data.password)
    user = User(
        email=user_data.email,
        password_hash=hashed_password
    )
    # ...
from zenith import Request
from datetime import datetime, timedelta
class SecurityHeadersMiddleware:
    """Add comprehensive security headers to every response.

    NOTE(review): ``__call__(request, call_next)`` is the dispatch-function
    style, not the raw ASGI ``(scope, receive, send)`` interface. Confirm how
    Zenith expects this class to be registered.
    """

    def __init__(self, app):
        self.app = app

    async def __call__(self, request: Request, call_next):
        response = await call_next(request)
        # Strict Transport Security: HTTPS only for 2 years, subdomains
        # included; "preload" opts in to browser preload lists (hard to undo).
        response.headers["Strict-Transport-Security"] = (
            "max-age=63072000; "  # 2 years
            "includeSubDomains; "
            "preload"
        )
        response.headers["Content-Security-Policy"] = self._get_csp()
        response.headers["X-Frame-Options"] = "DENY"
        response.headers["X-Content-Type-Options"] = "nosniff"
        # Referrer-Policy (X-XSS-Protection removed - deprecated and insecure)
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
        # Deny powerful browser features this API never needs.
        response.headers["Permissions-Policy"] = (
            "accelerometer=(), "
            "camera=(), "
            "geolocation=(), "
            "gyroscope=(), "
            "magnetometer=(), "
            "microphone=(), "
            "payment=(), "
            "usb=()"
        )
        response.headers["X-Permitted-Cross-Domain-Policies"] = "none"
        # Expect-CT is intentionally not sent. The header is deprecated and
        # ignored by current browsers, which enforce Certificate Transparency
        # by default. With settings.CT_REPORT_URI unset it also advertised
        # report-uri="None".
        return response

    def _get_csp(self) -> str:
        """Build the Content-Security-Policy header value.

        Directives with no values (``upgrade-insecure-requests``) are emitted
        bare. ``report-uri`` is included only when ``settings.CSP_REPORT_URI``
        is set.
        """
        directives = {
            "default-src": ["'self'"],
            # No 'unsafe-inline' - use nonces or hashes for inline scripts/styles.
            "script-src": ["'self'"],
            "style-src": ["'self'"],
            "img-src": ["'self'", "data:", "https:"],
            "font-src": ["'self'", "data:"],
            "connect-src": ["'self'"],
            "media-src": ["'self'"],
            "object-src": ["'none'"],
            "frame-src": ["'none'"],
            "frame-ancestors": ["'none'"],
            "base-uri": ["'self'"],
            "form-action": ["'self'"],
            "upgrade-insecure-requests": [],
            "report-uri": [settings.CSP_REPORT_URI] if settings.CSP_REPORT_URI else []
        }
        csp_parts = []
        for directive, values in directives.items():
            if values:
                csp_parts.append(f"{directive} {' '.join(values)}")
            elif directive == "report-uri":
                # Skip an empty report-uri; emitting it bare would be invalid.
                continue
            else:
                csp_parts.append(directive)
        return "; ".join(csp_parts)
from collections import defaultdict
from datetime import datetime, timedelta
import asyncio
from typing import Optional
import redis.asyncio as redis
class RateLimiter:
    """Redis-backed rate limiting with several strategies.

    Each check returns ``(allowed, info)``. ``info`` has ``limit``,
    ``remaining`` and ``reset`` (seconds), suitable for X-RateLimit-* headers.

    NOTE: every strategy issues several separate Redis commands, so
    concurrent requests for the same key can race (e.g. both pass the check).
    For strict limits, move the logic into a Lua script or a MULTI/EXEC
    transaction.
    """

    def __init__(self, redis_client: redis.Redis):
        self.redis = redis_client

    async def check_rate_limit(
        self,
        key: str,
        limit: int,
        window: int,
        strategy: str = "sliding_window"
    ) -> tuple[bool, dict]:
        """
        Check if request is within rate limit.

        Strategies:
        - fixed_window: Simple, resets at window boundaries
        - sliding_window: More accurate, rolling window
        - token_bucket: Allows bursts

        Raises:
            ValueError: unknown *strategy*.
        """
        if strategy == "fixed_window":
            return await self._fixed_window(key, limit, window)
        elif strategy == "sliding_window":
            return await self._sliding_window(key, limit, window)
        elif strategy == "token_bucket":
            return await self._token_bucket(key, limit, window)
        else:
            raise ValueError(f"Unknown strategy: {strategy}")

    async def _fixed_window(self, key: str, limit: int, window: int):
        """Count requests in a counter that expires *window* seconds after the first one."""
        current = await self.redis.incr(key)
        reset_time = await self.redis.ttl(key)
        # Set the expiry whenever the key has none (TTL < 0), not only when
        # current == 1. If the process died between INCR and EXPIRE, the
        # counter would otherwise never expire and block the client forever.
        if reset_time < 0:
            await self.redis.expire(key, window)
            reset_time = window
        return current <= limit, {
            "limit": limit,
            "remaining": max(0, limit - current),
            "reset": reset_time
        }

    async def _sliding_window(self, key: str, limit: int, window: int):
        """Sliding window using a Redis sorted set of request timestamps."""
        import time
        import uuid

        # time.time() is a true Unix timestamp. datetime.utcnow().timestamp()
        # treats the naive UTC value as local time, which is wrong on hosts not
        # running in UTC and jumps at DST changes.
        now = time.time()
        await self.redis.zremrangebyscore(key, 0, now - window)
        current = await self.redis.zcard(key)
        allowed = current < limit
        if allowed:
            # Unique member: str(now) alone would merge requests with the same
            # timestamp into one entry and undercount them.
            await self.redis.zadd(key, {f"{now}:{uuid.uuid4().hex}": now})
            await self.redis.expire(key, window)
        remaining = max(0, limit - current - (1 if allowed else 0))
        return allowed, {
            "limit": limit,
            "remaining": remaining,
            "reset": int(window)
        }

    async def _token_bucket(self, key: str, limit: int, window: int):
        """Token bucket: up to *limit* burst, refilled at ``limit / window`` tokens per second."""
        import math
        import time

        bucket_key = f"{key}:bucket"
        timestamp_key = f"{key}:timestamp"
        tokens = await self.redis.get(bucket_key)
        last_update = await self.redis.get(timestamp_key)
        now = time.time()

        if tokens is None or last_update is None:
            # New (or partially expired) bucket starts full.
            tokens = float(limit)
            last_update = now
        else:
            tokens = float(tokens)
            last_update = float(last_update)

        refill_rate = limit / window  # tokens per second
        # max(0, ...) guards against clock skew between app servers.
        tokens = min(limit, tokens + max(0.0, now - last_update) * refill_rate)

        allowed = tokens >= 1
        if allowed:
            tokens -= 1
            # An idle bucket is full again after `window` seconds, so expiring
            # the keys then changes nothing; it just stops abandoned keys from
            # accumulating in Redis forever.
            ttl = max(1, math.ceil(window))
            await self.redis.set(bucket_key, tokens, ex=ttl)
            await self.redis.set(timestamp_key, now, ex=ttl)

        # Seconds until at least one token is available (0 if one already is).
        reset = math.ceil(max(0.0, 1 - tokens) / refill_rate)
        return allowed, {
            "limit": limit,
            "remaining": int(tokens),
            "reset": reset
        }
# Rate limit decorator
def rate_limit(
    requests: int = 100,
    window: int = 60,
    strategy: str = "sliding_window",
    key_func=None
):
    """Decorator factory that rate-limits an async endpoint via RateLimiter.

    The endpoint must receive the request object as the keyword argument
    ``request``; if it does not, the limit is silently NOT applied. Keys
    default to ``rate_limit:<client host>:<path>``; pass
    ``key_func(request) -> str`` to key by user, API key, etc. Uses the
    module-level ``redis_client``.

    Raises:
        HTTPException(429): the limit is exceeded (X-RateLimit-* headers set).
    """
    # The guide never imported `wraps`; without this, decoration raised NameError.
    from functools import wraps

    def decorator(func):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            request = kwargs.get('request')
            if not request:
                # No request to key on, so nothing is limited. Declare a
                # `request` parameter on the endpoint to enable limiting.
                return await func(*args, **kwargs)

            if key_func:
                key = key_func(request)
            else:
                # request.client can be None (e.g. some test clients or servers).
                # Such requests share one "unknown" bucket.
                client_host = request.client.host if request.client else "unknown"
                key = f"rate_limit:{client_host}:{request.url.path}"

            limiter = RateLimiter(redis_client)
            allowed, info = await limiter.check_rate_limit(
                key, requests, window, strategy
            )

            # Stash headers so a response middleware can add them.
            request.state.rate_limit_headers = {
                "X-RateLimit-Limit": str(info["limit"]),
                "X-RateLimit-Remaining": str(info["remaining"]),
                "X-RateLimit-Reset": str(info["reset"])
            }
            if not allowed:
                raise HTTPException(
                    status_code=429,
                    detail="Rate limit exceeded",
                    headers=request.state.rate_limit_headers
                )
            return await func(*args, **kwargs)
        return wrapper
    return decorator
# Usage
@app.get("/api/search")
@rate_limit(requests=10, window=60)  # 10 requests per minute
async def search(query: str, request: Request):
    """Search endpoint limited to 10 requests per minute per client and path.

    ``request`` must be declared: ``rate_limit`` reads it from the keyword
    arguments, and without it the limit was silently never applied.
    """
    return {"results": await perform_search(query)}
import logging
from typing import Optional
from datetime import datetime
class SecurityAuditLog:
    """Record security-relevant events as JSON log lines and pass them to storage."""

    def __init__(self, logger: logging.Logger):
        self.logger = logger

    async def log_event(
        self,
        event_type: str,
        user_id: Optional[int],
        ip_address: str,
        user_agent: str,
        details: dict,
        severity: str = "INFO"
    ):
        """Log one security event and hand it to :meth:`_store_event`.

        Args:
            event_type: Event key, e.g. one of ``SECURITY_EVENTS``.
            user_id: Acting user, if known.
            ip_address: Client address.
            user_agent: Client User-Agent header.
            details: Extra context. Values that aren't JSON-serializable are
                stringified rather than dropping the event.
            severity: Logging level name, case-insensitive (e.g.
                ``"warning"``). Unknown names fall back to INFO.
        """
        from datetime import timezone

        event = {
            # Timezone-aware UTC, so the ISO string carries an explicit offset.
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "event_type": event_type,
            "user_id": user_id,
            "ip_address": ip_address,
            "user_agent": user_agent,
            "details": details,
            "severity": severity
        }
        # getattr(logging, "info") would return the logging.info *function*,
        # which logger.log() rejects; accept only integer level constants.
        level = getattr(logging, str(severity).upper(), None)
        if not isinstance(level, int):
            level = logging.INFO
        # default=str: an audit entry must never be lost because `details`
        # contains e.g. a datetime or UUID.
        self.logger.log(level, json.dumps(event, default=str))
        await self._store_event(event)

    async def _store_event(self, event: dict):
        """Store event in database (no-op placeholder)."""
        # Implementation depends on your storage
        pass
# Security events to log
# Maps event_type keys (as passed to SecurityAuditLog.log_event, e.g.
# "LOGIN_SUCCESS" by the middleware below) to human-readable descriptions.
SECURITY_EVENTS = {
    "LOGIN_SUCCESS": "Successful login",
    "LOGIN_FAILED": "Failed login attempt",
    "LOGIN_BLOCKED": "Login blocked due to rate limiting",
    "PASSWORD_RESET": "Password reset requested",
    "PERMISSION_DENIED": "Access denied to resource",
    "SUSPICIOUS_ACTIVITY": "Suspicious activity detected",
    "API_KEY_CREATED": "API key created",
    "API_KEY_REVOKED": "API key revoked",
    "DATA_EXPORT": "User data exported",
    "ACCOUNT_DELETED": "Account deleted"
}
# Middleware for automatic logging
@app.middleware("http")
async def security_audit_middleware(request: Request, call_next):
    """Write an audit event for every request to ``/auth/login``.

    A 200 response is logged as LOGIN_SUCCESS and anything else as
    LOGIN_FAILED. Other paths pass through untouched. Uses the module-level
    ``audit_log``.
    """
    response = await call_next(request)
    if request.url.path != "/auth/login":
        return response

    event_type = "LOGIN_SUCCESS" if response.status_code == 200 else "LOGIN_FAILED"
    await audit_log.log_event(
        event_type=event_type,
        user_id=None,  # Get from response if available
        # request.client can be None (e.g. some test clients / servers).
        ip_address=request.client.host if request.client else "unknown",
        user_agent=request.headers.get("user-agent", ""),
        details={"path": request.url.path}
    )
    return response
# tests/test_security.py
import pytest
from httpx import AsyncClient
class TestSecurityHeaders:
    """Test security headers are present."""

    async def test_security_headers(self, client: AsyncClient):
        """The root page must carry the baseline hardening headers."""
        headers = (await client.get("/")).headers
        for required in ("Strict-Transport-Security", "Content-Security-Policy"):
            assert required in headers
        assert "X-Frame-Options" in headers
        assert headers["X-Frame-Options"] == "DENY"
        assert "X-Content-Type-Options" in headers
        assert headers["X-Content-Type-Options"] == "nosniff"
class TestAuthentication:
    """Test authentication security."""

    async def test_password_hashing(self):
        """Passwords are hashed, verify correctly, and reject wrong input."""
        from app.auth import hash_password, verify_password
        password = "SecurePassword123!"
        hashed = hash_password(password)
        # Hash should be different from plaintext
        assert hashed != password
        # Should verify correctly
        assert verify_password(password, hashed)
        # Wrong password should fail
        assert not verify_password("WrongPassword", hashed)

    async def test_jwt_expiration(self):
        """An expired access token must be rejected on decode."""
        from datetime import timedelta

        from jose import jwt
        from jose.exceptions import ExpiredSignatureError

        # SECRET_KEY/ALGORITHM were previously used without being imported.
        from app.auth import ALGORITHM, SECRET_KEY, create_access_token

        # Issue a token that expired a minute ago instead of sleeping:
        # time.sleep() in an async test blocks the event loop and slows the suite.
        token = create_access_token(
            data={"sub": "test@example.com"},
            expires_delta=timedelta(minutes=-1),
        )
        # ExpiredSignatureError (a JWTError subclass) proves the token failed
        # because of expiry, not e.g. a bad signature.
        with pytest.raises(ExpiredSignatureError):
            jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
class TestVulnerabilities:
    """Test for common vulnerabilities."""

    async def test_sql_injection_prevention(self, client: AsyncClient):
        """An injection payload must be handled as data, not executed as SQL."""
        malicious_input = "'; DROP TABLE users; --"
        # params= URL-encodes the payload. `email` is the parameter declared
        # by the safe /users/search handler; the old `q` matched no declared
        # parameter, so the request was rejected before any query ran.
        response = await client.get("/users/search", params={"email": malicious_input})
        # Should not execute SQL
        assert response.status_code in [200, 400]  # Normal or validation error
        # The users table must still be queryable afterwards.
        follow_up = await client.get("/users")
        assert follow_up.status_code == 200

    async def test_xss_prevention(self, client: AsyncClient):
        """A stored <script> payload must not come back as markup."""
        xss_payload = "<script>alert('XSS')</script>"
        response = await client.post(
            "/comments",
            json={"content": xss_payload}
        )
        assert response.status_code in (200, 201)
        comment_id = response.json()["id"]
        response = await client.get(f"/comments/{comment_id}")
        content = response.json()["content"]
        # The script element must be gone (removed or escaped). Don't assert
        # on "alert(": stripping tags keeps their inner text, and escaping
        # keeps it as inert text, so it legitimately remains.
        assert "<script" not in content.lower()

    async def test_csrf_protection(self, client: AsyncClient):
        """A state-changing request without a CSRF token is rejected."""
        response = await client.post(
            "/api/sensitive-action",
            json={"data": "test"}
        )
        assert response.status_code == 403
        assert "CSRF" in response.json()["detail"]
  • Use environment variables for secrets
  • Enable debug mode only in development
  • Use HTTPS even in development (self-signed cert)
  • Test with security scanners (OWASP ZAP, Burp Suite)
  • Strong password requirements
  • Password hashing with Argon2 (Zenith default via pwdlib)
  • JWT tokens with expiration
  • Refresh token rotation
  • Account lockout after failed attempts
  • Two-factor authentication (2FA)
  • Encrypt sensitive data at rest
  • Use TLS 1.2+ for data in transit
  • Sanitize all user input
  • Encode output for its context (HTML, attribute, URL, JavaScript)
  • Implement field-level encryption for PII
  • Rate limiting on all endpoints
  • API versioning
  • Request size limits
  • Timeouts on all operations
  • Input validation with whitelisting
  • Regular security updates
  • Least privilege principle
  • Network segmentation
  • Firewall rules
  • Regular backups with encryption
  • Security event logging
  • Anomaly detection
  • Failed login monitoring
  • API usage analytics
  • Regular security audits

Token theft — Solution: Use short expiration, refresh token rotation, bind tokens to IP/device

Brute-force login attempts — Solution: Rate limiting, account lockout, CAPTCHA, monitoring

Session hijacking — Solution: Secure cookies, session timeout, regenerate session ID

Sensitive data in logs — Solution: Filter sensitive data, use structured logging, secure log storage


Need help? Check our FAQ or ask in GitHub Discussions.