File size: 2,444 Bytes
45fa780
9861b96
 
45fa780
 
 
 
 
 
9861b96
45fa780
9861b96
45fa780
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9861b96
 
 
 
45fa780
 
 
 
 
 
9861b96
 
 
45fa780
 
 
 
 
9861b96
45fa780
 
9861b96
45fa780
 
 
9861b96
45fa780
9861b96
45fa780
 
 
 
 
 
 
 
 
 
 
 
 
9861b96
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import asyncio
import logging
import os
from contextlib import asynccontextmanager

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

from api.inference import predict_claim, load_models, get_available_models
from api.schema import ClaimRequest, PredictionResponse, EvidenceItem
from api.rag import load_rag_data, get_rag_result

load_dotenv()

STATIC_DIR = "frontend/dist/frontend/browser"


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the one-time startup work before the application serves requests.

    Each startup step prints a progress banner and then runs synchronously:
    first ``load_models()``, then ``load_rag_data()``. Control is handed back
    to the framework at ``yield``; nothing is executed after it, so there is
    no shutdown-time cleanup.
    """
    startup_steps = (
        ("Loading models...", load_models),
        ("Building RAG index...", load_rag_data),
    )
    for banner, run_step in startup_steps:
        print(banner)
        run_step()
    print("Startup complete.")
    yield


app = FastAPI(lifespan=lifespan)

# Allowed CORS origins come from the comma-separated CORS_ALLOW_ORIGINS
# environment variable (read after load_dotenv() above); when it is unset or
# empty every origin is allowed, as before.
_cors_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# FastAPI's CORS documentation states that allow_origins must not be ["*"]
# when allow_credentials is True, so credentialed cross-origin requests are
# only enabled once an explicit origin list has been configured.
_cors_allow_credentials = "*" not in _cors_origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=_cors_allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
def health():
    """Return a fixed ``"ok"`` status plus the result of ``get_available_models()``."""
    models = get_available_models()
    payload = {"status": "ok", "available_models": models}
    return payload


@app.get("/models")
def list_models():
    """Return the result of ``get_available_models()`` under ``available_models``."""
    return dict(available_models=get_available_models())


@app.post("/predict")
async def predict_claim_endpoint(claim_request: ClaimRequest):
    """Classify a claim and attach retrieved evidence and a justification.

    The claim text is stripped of surrounding whitespace. ``predict_claim`` is
    then run in a worker thread with the claim and the requested model, and
    ``get_rag_result`` is run in a worker thread with the claim and the
    predicted label.

    Args:
        claim_request: Request body; its ``claim`` and ``model`` attributes
            are read.

    Returns:
        A ``PredictionResponse`` carrying the predicted label, the
        ``"probabilities"`` entry of the prediction result (``None`` if
        absent), the evidence items (``None`` when no evidence came back) and
        the justification.

    Raises:
        HTTPException: 400 when the stripped claim is empty; 500 when any
            other error occurs while processing the claim.
    """
    try:
        claim = claim_request.claim.strip()
        if not claim:
            raise HTTPException(status_code=400, detail="Claim text cannot be empty")

        # Both calls are blocking, so they run off the event loop.
        result = await asyncio.to_thread(predict_claim, claim, claim_request.model)
        predicted_label = result["prediction"]

        evidence_raw, justification = await asyncio.to_thread(
            get_rag_result, claim, predicted_label
        )

        evidence = [EvidenceItem(**e) for e in evidence_raw] if evidence_raw else None

        return PredictionResponse(
            prediction=predicted_label,
            probabilities=result.get("probabilities"),
            evidence=evidence,
            justification=justification,
        )
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) pass through unchanged.
        raise
    except Exception as e:
        # Record the traceback server-side; otherwise the failure is only
        # visible as the short message embedded in the 500 response.
        logging.getLogger(__name__).exception("Error processing claim")
        raise HTTPException(
            status_code=500, detail=f"Error processing claim: {str(e)}"
        ) from e


if os.path.exists(STATIC_DIR):
    @app.get("/{full_path:path}")
    async def serve_spa(full_path: str):
        """Serve a built frontend asset, falling back to ``index.html``.

        Args:
            full_path: The request path captured by the catch-all route.

        Returns:
            A ``FileResponse`` for the requested file when it is a regular
            file located inside ``STATIC_DIR``; otherwise a ``FileResponse``
            for ``index.html`` in ``STATIC_DIR``.
        """
        static_root = os.path.realpath(STATIC_DIR)
        # Resolve ".." segments and symlinks before checking containment.
        # os.path.join() also discards static_root entirely when full_path is
        # absolute, so the raw joined path must never be served unchecked.
        candidate = os.path.realpath(os.path.join(static_root, full_path))
        try:
            inside_root = os.path.commonpath([static_root, candidate]) == static_root
        except ValueError:
            # commonpath() raises when the paths cannot be compared
            # (e.g. different drives on Windows); treat that as outside.
            inside_root = False

        if full_path and inside_root and os.path.isfile(candidate):
            return FileResponse(candidate)
        return FileResponse(os.path.join(STATIC_DIR, "index.html"))