| import gradio as gr |
| import numpy as np |
| import pandas as pd |
| from tensorflow import keras |
| from PIL import Image |
| import io |
| import base64 |
|
|
| from sklearn.metrics.pairwise import cosine_similarity |
|
|
# Artifact locations, resolved relative to the current working directory.
PATH_MODEL = "./autoencoder.keras"
PATH_DB = "./mnist_train_small.csv"

# Width of the autoencoder's latent code.
# NOTE(review): not referenced in this file — confirm it matches the encoder output.
LATENT_DIM = 32

# Load the trained autoencoder and pull out its two halves by layer name.
model = keras.models.load_model(PATH_MODEL)
encoder, decoder = model.get_layer("encoder"), model.get_layer("decoder")

# Reference set: the CSV has no header row. Column 0 is discarded
# (presumably the digit label) and the remaining pixel columns are scaled to [0, 1].
data = pd.read_csv(PATH_DB, header=None)
X_ref = data.iloc[:, 1:].to_numpy(dtype="float32") / 255

# Latent embedding of every reference row, computed once at start-up so
# queries only need to encode the user's sketch.
X_latent = encoder.predict(X_ref, verbose=0)
|
|
| |
def image_to_array(canva):
    """Convert a sketchpad drawing into a flattened, MNIST-like input row.

    The drawing is flattened onto white (when it has an alpha channel),
    converted to grayscale, downscaled to 28x28 and inverted, so dark
    strokes on a light background become bright strokes on a dark one.

    Args:
        canva: Mapping produced by the sketchpad component. Only its
            ``"composite"`` entry is read, and it must be a PIL image
            (presumably from ``gr.Sketchpad(type="pil")`` — confirm at caller).

    Returns:
        ``np.ndarray`` of dtype float32 and shape ``(1, 784)`` with values
        in ``[0, 1]``.

    Raises:
        ValueError: If ``canva`` is empty/None or has no composite image.
    """
    composite = canva.get("composite") if canva else None
    if composite is None:
        raise ValueError("image_to_array: the sketchpad has no composite image")

    # convert("L") simply drops alpha, so transparent pixels would keep
    # whatever RGB lies under them (commonly black). After the inversion
    # below they would turn bright. Flatten onto white first.
    if "A" in composite.getbands():
        rgba = composite.convert("RGBA")
        white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        composite = Image.alpha_composite(white, rgba)

    gray = composite.convert("L").resize((28, 28))
    # Invert so ink is high and background low, presumably matching the
    # light-on-dark digits of the reference CSV.
    arr = 1 - np.array(gray, dtype="float32") / 255
    return arr.reshape(1, 784)
|
|
def _row_to_image(row):
    """Turn one flattened [0, 1] reference row into a 28x28 grayscale PIL image."""
    # Round before casting: float rounding may leave k/255*255 a hair below k,
    # and astype(uint8) truncates, which would darken that pixel by one level.
    pixels = np.clip(np.rint(row.reshape(28, 28) * 255), 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)


def find_similar(img, top_k):
    """Retrieve the reference digits closest to a sketch in latent space.

    The sketch is encoded with ``encoder`` and compared with the
    precomputed ``X_latent`` embeddings by cosine similarity.

    Args:
        img: Sketchpad value forwarded by Gradio (passed to ``image_to_array``).
        top_k: Number of neighbours to return. It is cast to ``int`` because
            the slider value may arrive as a float; values below 1 yield
            empty results.

    Returns:
        Tuple ``(table, gallery_imgs)``:
            - ``table``: list of ``[index, similarity]`` rows, best match
              first, with the similarity rounded to 4 decimals.
            - ``gallery_imgs``: the matching reference digits as PIL images,
              in the same order.

    Raises:
        gr.Error: If nothing has been drawn on the sketchpad.
    """
    if img is None or img.get("composite") is None:
        raise gr.Error("Dibuja un dígito antes de buscar.")

    query_vec = encoder.predict(image_to_array(img), verbose=0)
    sims = cosine_similarity(query_vec, X_latent)[0]

    # Clamp at 0: a negative slice bound would silently return almost everything.
    k = max(int(top_k), 0)
    top_idx = np.argsort(sims)[::-1][:k]

    table = [[int(i), round(float(sims[i]), 4)] for i in top_idx]
    gallery_imgs = [_row_to_image(X_ref[i]) for i in top_idx]
    return table, gallery_imgs
|
|
# ---- User interface -------------------------------------------------------
# Inside each context manager, the order in which components are created fixes
# their on-screen position, so the statements below keep that order.
with gr.Blocks() as demo:
    with gr.Tab(label="Búsqueda"):
        gr.Markdown(value="## Búsqueda en espacio latente")
        with gr.Row():
            # Left column: the drawing surface.
            with gr.Column():
                canvas = gr.Sketchpad(type="pil", label="Dibuja")
            # Right column: how many neighbours to fetch, plus the trigger.
            with gr.Column():
                topk = gr.Slider(
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=10,
                    label="top_k",
                )
                btn = gr.Button(value="Buscar similares")

        gallery = gr.Gallery(
            label="Imágenes similares",
            object_fit="contain",
            columns=5,
        )

    with gr.Tab(label="Metadatos"):
        results = gr.Dataframe(
            label="Ranking",
            headers=["index", "cosine_similarity"],
            datatype=["number", "number"],
            interactive=False,
        )

    # find_similar returns (table, gallery images), in this output order.
    btn.click(fn=find_similar, inputs=[canvas, topk], outputs=[results, gallery])
|
|
# Start the web server only when run as a script. Importing this module
# (tests, tooling, reload mode) should not block on a running server.
if __name__ == "__main__":
    demo.launch(server_port=7860)