| import io, base64 |
| import numpy as np |
| from PIL import Image |
| import torch |
| from safetensors.torch import load_file |
| from model import TinyConv |
|
|
# Class names in output order: position i labels index i of the model's
# softmax output ("0" through "9").
LABELS = list(map(str, range(10)))
|
|
def load_model(path: str = ""):
    """Construct a ``TinyConv`` in eval mode and try to load its weights.

    Args:
        path: Where the checkpoint lives. If it names an existing directory,
            ``model.safetensors`` inside that directory is loaded; otherwise
            it is used as a plain string prefix, i.e.
            ``f"{path}model.safetensors"`` (so the default ``""`` means
            ``model.safetensors`` in the current working directory).

    Returns:
        The ``TinyConv`` instance, in eval mode.

    Loading is best-effort: any exception raised while reading or applying
    the checkpoint is logged as a warning (with traceback) and the model is
    returned anyway, in which case it may not hold the checkpoint's weights.
    """
    import logging
    import os

    # Treat "weights" and "weights/" the same when they name a real
    # directory; any other string keeps the original prefix semantics.
    if path and os.path.isdir(path):
        weights_file = os.path.join(path, "model.safetensors")
    else:
        weights_file = f"{path}model.safetensors"

    m = TinyConv().eval()
    try:
        state = load_file(weights_file)
        # strict=True: missing/unexpected keys raise instead of being ignored.
        m.load_state_dict(state, strict=True)
    except Exception:
        # Deliberately non-fatal (this runs at import time via MODEL), but
        # never silent: without the checkpoint, predictions come from
        # weights that are not the trained ones.
        logging.getLogger(__name__).warning(
            "Could not load weights from %r; continuing without checkpoint "
            "weights.",
            weights_file,
            exc_info=True,
        )
    return m
|
|
# Single shared model instance, built (and its weights load attempted) once
# at import time; predict_from_pil reads this global.
MODEL = load_model(path="")
|
|
def preprocess_pil(img: Image.Image) -> torch.Tensor:
    """Turn a PIL image into the model's input tensor.

    The image is converted to 8-bit grayscale (mode ``"L"``), resized to
    28x28, and its pixel values are scaled from 0..255 to floats in [0, 1].

    Returns:
        A float32 tensor of shape ``(1, 1, 28, 28)`` (batch, channel, H, W).
    """
    side = 28
    gray = img.convert("L")
    gray = gray.resize((side, side))
    pixels = torch.from_numpy(np.array(gray))
    scaled = pixels.float() / 255.0
    # Prepend the batch and channel axes: (28, 28) -> (1, 1, 28, 28).
    return scaled[None, None]
|
|
def predict_from_pil(img: Image.Image):
    """Classify a PIL image with the shared ``MODEL``.

    Returns:
        A dict mapping each entry of ``LABELS`` to the softmax probability
        at the same index of the model's output for the (single) batch row.
    """
    batch = preprocess_pil(img)
    with torch.inference_mode():
        scores = torch.softmax(MODEL(batch), dim=-1)
        first_row = scores[0].tolist()
    return {name: float(first_row[idx]) for idx, name in enumerate(LABELS)}
|
|
def predict_from_base64(b64: str):
    """Decode a base64-encoded image and classify it.

    Accepts either a bare base64 payload or a data URL such as
    ``"data:image/png;base64,iVBOR..."``; for ``str`` input the
    ``data:...,`` header is stripped before decoding. Bytes input is passed
    to ``base64.b64decode`` unchanged, as before.

    Returns:
        The ``{label: probability}`` mapping produced by ``predict_from_pil``.

    Raises:
        Whatever ``base64.b64decode`` or ``Image.open`` raise for malformed
        input (e.g. ``binascii.Error`` for bad padding, or Pillow's error for
        an unrecognised image format).
    """
    # b64decode (validate=False) silently drops non-alphabet characters such
    # as ':' ';' ',' but keeps the header's letters and '/', which would
    # misalign the payload and corrupt the decoded bytes.
    if isinstance(b64, str) and b64.startswith("data:"):
        _, _, b64 = b64.partition(",")
    img = Image.open(io.BytesIO(base64.b64decode(b64)))
    return predict_from_pil(img)
|
|