120 lines
4.6 KiB
Python
120 lines
4.6 KiB
Python
import json
import logging

from fastapi import Request, Response
from jose import jwt, JWTError
from sqlalchemy.exc import IntegrityError
from starlette.concurrency import run_in_threadpool
from starlette.datastructures import Headers
from starlette.types import ASGIApp, Receive, Scope, Send

from app import models, auth_utils
from app.database import SessionLocal
|
|
|
|
class IdempotencyMiddleware:
    """Pure-ASGI middleware that replays stored responses for repeated requests.

    For an HTTP request that carries an ``Idempotency-Key`` header and a
    ``Bearer`` JWT whose ``sub`` claim matches a ``models.User.email``, the
    middleware looks for a ``models.IdempotencyKey`` row keyed on
    (key, user id, request path). If one exists, its stored body is sent back
    without calling the wrapped app. Otherwise the app runs and, when it
    answers with a 2xx status, the UTF-8 response body is stored under that
    key for future replays.

    Non-HTTP scopes, requests without the header, and requests whose token is
    missing, invalid, or does not resolve to a user pass straight through to
    the wrapped app.

    Blocking work (JWT decoding and SQLAlchemy sessions) is run via
    Starlette's ``run_in_threadpool`` so it does not block the event loop.

    NOTE(review): only the body is persisted, so a replay is always sent with
    status 200 and ``application/json`` even if the original response was,
    e.g., a 201. Storing the status would need a column on IdempotencyKey.
    NOTE(review): the HTTP method is not part of the key (only the path is),
    and two concurrent first requests with the same key can both reach the
    app; the second save then fails with ``IntegrityError`` (presumably a
    unique constraint — confirm) and is discarded.
    """

    def __init__(self, app: ASGIApp):
        # The wrapped ASGI application.
        self.app = app

    @staticmethod
    def _bearer_token(auth_header):
        """Return the token from a ``Bearer <token>`` header, or None if absent/empty."""
        if not auth_header or not auth_header.startswith("Bearer "):
            return None
        token = auth_header[len("Bearer "):].strip()
        return token or None

    @staticmethod
    def _resolve_user_id(token):
        """Return the id of the user named by *token*'s ``sub`` claim, or None.

        Returns None when the token fails to decode, has no ``sub`` claim, or
        no ``models.User`` has that email. Blocking (JWT decode + DB query):
        call via ``run_in_threadpool``.
        """
        try:
            payload = jwt.decode(
                token, auth_utils.SECRET_KEY, algorithms=[auth_utils.ALGORITHM]
            )
        except JWTError:
            return None

        user_email = payload.get("sub")
        if not user_email:
            return None

        db = SessionLocal()
        try:
            user = db.query(models.User).filter(models.User.email == user_email).first()
            return user.id if user else None
        finally:
            db.close()

    @staticmethod
    def _load_stored_body(idempotency_key, user_id, endpoint):
        """Look up a previously stored response for (key, user, endpoint).

        Returns a ``(found, body)`` tuple; ``body`` is only meaningful when
        ``found`` is True. The session is closed before returning so no DB
        connection is held while a replay is written to the client.
        Blocking: call via ``run_in_threadpool``.
        """
        db = SessionLocal()
        try:
            existing_key = db.query(models.IdempotencyKey).filter(
                models.IdempotencyKey.key == idempotency_key,
                models.IdempotencyKey.user_id == user_id,
                models.IdempotencyKey.endpoint == endpoint,
            ).first()
            if existing_key is None:
                return False, None
            return True, existing_key.response_body
        finally:
            db.close()

    @staticmethod
    def _store_response(idempotency_key, user_id, endpoint, response_body):
        """Persist *response_body* under (key, user, endpoint).

        An ``IntegrityError`` on commit (e.g. another request stored the same
        key first) is rolled back and ignored. Any other error is rolled back
        and re-raised for the caller to log. Blocking: call via
        ``run_in_threadpool``.
        """
        db = SessionLocal()
        try:
            db.add(models.IdempotencyKey(
                key=idempotency_key,
                user_id=user_id,
                endpoint=endpoint,
                response_body=response_body,
            ))
            db.commit()
        except IntegrityError:
            db.rollback()
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        """ASGI entry point; see the class docstring for the replay rules."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        headers = Headers(scope=scope)
        idempotency_key = headers.get("idempotency-key")
        if not idempotency_key:
            await self.app(scope, receive, send)
            return

        # Identify the caller from the Bearer JWT (manual parse, as the
        # route-level auth dependency has not run yet at middleware level).
        token = self._bearer_token(headers.get("authorization"))
        user_id = None
        if token is not None:
            user_id = await run_in_threadpool(self._resolve_user_id, token)
        if user_id is None:
            await self.app(scope, receive, send)
            return

        endpoint = scope["path"]

        found, stored_body = await run_in_threadpool(
            self._load_stored_body, idempotency_key, user_id, endpoint
        )
        if found:
            # Replay the stored body instead of re-running the handler.
            response = Response(content=stored_body, media_type="application/json")
            await response(scope, receive, send)
            return

        status_code = 200  # default until http.response.start reports the real one
        body_chunks = []

        async def save_response(raw_body):
            # Best-effort persistence: failures are logged, never raised, so a
            # storage problem cannot turn an already-sent success into an error.
            # CancelledError (a BaseException) is deliberately not caught.
            logger = logging.getLogger(__name__)
            try:
                text_body = raw_body.decode("utf-8")
            except UnicodeDecodeError:
                logger.warning("Idempotency: response body is not UTF-8; not stored")
                return
            try:
                await run_in_threadpool(
                    self._store_response, idempotency_key, user_id, endpoint, text_body
                )
            except Exception:
                logger.exception("Save error")

        async def send_wrapper(message):
            nonlocal status_code
            is_body = message["type"] == "http.response.body"
            if message["type"] == "http.response.start":
                status_code = message["status"]
            elif is_body:
                body_chunks.append(message.get("body", b""))

            await send(message)

            # Persist only after the final body chunk has been sent, and only
            # for successful (2xx) responses.
            if is_body and not message.get("more_body", False) and 200 <= status_code < 300:
                await save_response(b"".join(body_chunks))

        await self.app(scope, receive, send_wrapper)