Spaces:
Running
Running
File size: 2,444 Bytes
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from api.inference import predict_claim, load_models, get_available_models
from api.schema import ClaimRequest, PredictionResponse, EvidenceItem
from api.rag import load_rag_data, get_rag_result
from dotenv import load_dotenv
import asyncio
import os
from contextlib import asynccontextmanager
load_dotenv()
STATIC_DIR = "frontend/dist/frontend/browser"
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the startup steps once before the application begins serving.

    Each step prints a progress banner and then calls its loader, in order:
    ``load_models`` first, then ``load_rag_data``. Control is yielded to the
    application afterwards; nothing is executed after the yield.
    """
    startup_steps = (
        ("Loading models...", load_models),
        ("Building RAG index...", load_rag_data),
    )
    for banner, run_step in startup_steps:
        print(banner)
        run_step()
    print("Startup complete.")
    yield
# Application instance; `lifespan` performs the model/RAG loading at startup.
app = FastAPI(lifespan=lifespan)

# CORS is fully open: any origin, any method, any header.
# NOTE(review): a wildcard origin combined with allow_credentials=True is a
# very permissive setup -- confirm it is intended outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/health")
def health():
    """Report liveness together with the currently available models."""
    models = get_available_models()
    payload = {"status": "ok"}
    payload["available_models"] = models
    return payload
@app.get("/models")
def list_models():
    """Return the available models under the ``available_models`` key."""
    models = get_available_models()
    return {"available_models": models}
@app.post("/predict")
async def predict_claim_endpoint(claim_request: ClaimRequest):
    """Classify a claim and attach retrieved evidence plus a justification.

    The claim text is stripped of surrounding whitespace before use. The
    prediction is computed first; its label is then passed to the RAG lookup,
    which returns the raw evidence items and a justification.

    Args:
        claim_request: Request body; ``claim`` and ``model`` are read from it.

    Returns:
        A ``PredictionResponse`` with the predicted label, the probabilities
        (if the prediction result contains them), the evidence items (``None``
        when the lookup returned nothing) and the justification.

    Raises:
        HTTPException: 400 when the stripped claim is empty; 500 for any other
            failure, with the original exception chained as the cause.
    """
    try:
        claim = claim_request.claim.strip()
        if not claim:
            raise HTTPException(status_code=400, detail="Claim text cannot be empty")

        # Both calls go through asyncio.to_thread so they run in a worker
        # thread rather than directly inside this coroutine.
        result = await asyncio.to_thread(predict_claim, claim, claim_request.model)
        predicted_label = result["prediction"]
        evidence_raw, justification = await asyncio.to_thread(
            get_rag_result, claim, predicted_label
        )

        evidence = [EvidenceItem(**e) for e in evidence_raw] if evidence_raw else None
        return PredictionResponse(
            prediction=predicted_label,
            probabilities=result.get("probabilities"),
            evidence=evidence,
            justification=justification,
        )
    except HTTPException:
        # Deliberate HTTP errors (such as the 400 above) pass through untouched.
        raise
    except Exception as e:
        # Chain the original exception so its traceback stays attached as the
        # explicit cause instead of appearing as an incidental context.
        raise HTTPException(
            status_code=500, detail=f"Error processing claim: {str(e)}"
        ) from e
# The catch-all SPA route is only registered when the built frontend exists.
if os.path.exists(STATIC_DIR):

    @app.get("/{full_path:path}")
    async def serve_spa(full_path: str):
        """Serve a built frontend asset, falling back to ``index.html``.

        Args:
            full_path: The request path captured by the catch-all route.

        Returns:
            A ``FileResponse`` for the requested file when it is a regular
            file located inside ``STATIC_DIR``; otherwise a ``FileResponse``
            for ``index.html`` so client-side routing can handle the path.
        """
        asset_path = None
        if full_path:
            static_root = os.path.abspath(STATIC_DIR)
            # Normalise the joined path so ".." segments and absolute inputs
            # (os.path.join discards the root when the second part is
            # absolute) are resolved before the containment check.
            candidate = os.path.abspath(os.path.join(static_root, full_path))
            try:
                inside_root = (
                    os.path.commonpath([static_root, candidate]) == static_root
                )
            except ValueError:
                # commonpath raises when the paths cannot be compared, e.g.
                # different drives on Windows; treat that as outside the root.
                inside_root = False
            # Only files that stay within the static directory are served;
            # anything else falls through to the SPA entry point.
            if inside_root and os.path.isfile(candidate):
                asset_path = candidate

        if asset_path is not None:
            return FileResponse(asset_path)
        return FileResponse(os.path.join(STATIC_DIR, "index.html"))
|