# Proyecto / app.py
# Author: Cesar017 · last commit: "Update app.py" (9a96194, verified)
import gradio as gr
import numpy as np
import pandas as pd
from tensorflow import keras
from PIL import Image
import io
import base64
from sklearn.metrics.pairwise import cosine_similarity
PATH_MODEL = "./autoencoder.keras"
PATH_DB = "./mnist_train_small.csv"
# ── Load the model and reference data once at startup ───────────────────────
# The saved autoencoder is loaded whole and its sub-models are fetched by
# layer name; find_similar only uses the encoder.
model = keras.models.load_model(PATH_MODEL)
encoder = model.get_layer("encoder")
decoder = model.get_layer("decoder")
# Headerless CSV. Column 0 is skipped and the remaining columns are used as
# pixel values (presumably label + 784 pixels, MNIST CSV layout — confirm).
data = pd.read_csv(PATH_DB, header=None)
# Scale pixels from [0, 255] to [0, 1], the same range image_to_array produces.
X_ref = data.iloc[:, 1:].values.astype("float32") / 255
# Encode every reference image once so each query only has to encode the sketch.
X_latent = encoder.predict(X_ref, verbose=0)
# Take the latent size from the encoder's real output rather than hard-coding
# it, so it cannot fall out of sync with the saved model.
LATENT_DIM = X_latent.shape[1]
# ── Helper: Sketchpad drawing → array (1, 784) ────────────────────────────────
def image_to_array(canva):
    """Convert a Sketchpad drawing into a flattened 28x28 MNIST-style vector.

    Args:
        canva: Value produced by ``gr.Sketchpad(type='pil')`` — a dict whose
            ``'composite'`` entry is a PIL image of the drawing. A bare PIL
            image is also accepted.

    Returns:
        A float32 array of shape (1, 784) with values in [0, 1]. Dark pixels
        of the drawing map to values near 1 and light pixels to values near 0.

    Raises:
        ValueError: If there is no drawing (``canva`` or its composite is None).
    """
    img = canva.get("composite") if isinstance(canva, dict) else canva
    if img is None:
        raise ValueError("No drawing to convert: the sketchpad is empty.")
    # Flatten any transparency onto white first. Converting RGBA straight to
    # "L" ignores alpha, so transparent (often black) background pixels would
    # otherwise be read as ink after the inversion below.
    if img.mode in ("RGBA", "LA", "PA") or "transparency" in img.info:
        rgba = img.convert("RGBA")
        white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        img = Image.alpha_composite(white, rgba)
    img = img.convert("L").resize((28, 28))
    # Invert so dark-on-light strokes become light-on-dark, matching X_ref.
    arr = 1 - np.array(img, dtype="float32") / 255
    return arr.reshape(1, 784)
def find_similar(img, top_k):
    """Find the reference images closest to a sketch in the encoder's latent space.

    Args:
        img: Sketchpad value (a dict with a ``'composite'`` PIL image).
        top_k: Number of neighbours to return. It is coerced to ``int``
            because the Gradio slider may deliver a float.

    Returns:
        A tuple ``(table, gallery_imgs)``:
        - ``table`` is a list of ``[row_index, cosine_similarity]`` pairs,
          with similarity rounded to 4 decimals, ordered from most to least
          similar.
        - ``gallery_imgs`` holds the matching 28x28 grayscale PIL images
          rebuilt from ``X_ref``.
        Both lists are empty when nothing is drawn or ``top_k`` is not
        positive.
    """
    k = max(int(top_k), 0)
    # Empty canvas or non-positive k: return empty results instead of crashing.
    if k == 0 or img is None or (isinstance(img, dict) and img.get("composite") is None):
        return [], []
    query_vec = encoder.predict(image_to_array(img), verbose=0)
    sims = cosine_similarity(query_vec, X_latent)[0]
    # Sort descending by similarity and keep the first k row indices.
    top_idx = np.argsort(sims)[::-1][:k]
    table = [[int(i), round(float(sims[i]), 4)] for i in top_idx]
    # Round before casting: values like 127.99999 would otherwise truncate to 127.
    gallery_imgs = [
        Image.fromarray(np.rint(X_ref[i].reshape(28, 28) * 255).astype(np.uint8))
        for i in top_idx
    ]
    return table, gallery_imgs
with gr.Blocks() as demo:
    # Search tab: draw a digit, choose how many neighbours to fetch, view them.
    with gr.Tab("Búsqueda"):
        gr.Markdown("## Búsqueda en espacio latente")
        with gr.Row():
            with gr.Column():
                canvas = gr.Sketchpad(label="Dibuja", type="pil")
            with gr.Column():
                topk = gr.Slider(minimum=1, maximum=50, value=10, step=1, label="top_k")
                btn = gr.Button("Buscar similares")
        gallery = gr.Gallery(
            label="Imágenes similares",
            columns=5,
            object_fit="contain",
        )
    # Metadata tab: index and cosine similarity of each returned neighbour.
    with gr.Tab("Metadatos"):
        results = gr.Dataframe(
            headers=["index", "cosine_similarity"],
            datatype=["number", "number"],
            label="Ranking",
            interactive=False,
        )
    # find_similar returns (table, gallery images) in this output order.
    btn.click(fn=find_similar, inputs=[canvas, topk], outputs=[results, gallery])

demo.launch(server_port=7860)