iliaspap commited on
Commit
ae9934c
·
1 Parent(s): 712fb76

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +33 -5
handler.py CHANGED
@@ -5,6 +5,33 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
5
  # check for GPU
6
  device = 0 if torch.cuda.is_available() else -1
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  # multi-model list
9
  multi_model_list = [
10
  {"model_id": "distilbert-base-uncased-finetuned-sst-2-english", "task": "text-classification"},
@@ -32,9 +59,10 @@ class EndpointHandler():
32
  raise ValueError(f"model_id: {model_id} is not valid. Available models are: {list(self.multi_model.keys())}")
33
 
34
  # pass inputs with all kwargs in data
35
- if parameters is not None:
36
- prediction = self.multi_model[model_id](inputs, **parameters)
37
- else:
38
- prediction = self.multi_model[model_id](inputs)
39
- # postprocess the prediction
 
40
  return prediction
 
5
  # check for GPU
6
  device = 0 if torch.cuda.is_available() else -1
7
 
8
+
9
+ from PIL import Image
10
+ from torchvision.transforms import Compose, ConvertImageDtype, Normalize, PILToTensor, Resize
11
+ from torchvision.transforms.functional import InterpolationMode
12
+ from pyrovision.models import model_from_hf_hub
13
+
14
+ # model = model_from_hf_hub("pyronear/mobilenet_v3_small").eval()
15
+
16
+ # img = Image.open(path_to_an_image).convert("RGB")
17
+
18
+ # # Preprocessing
19
+ # config = model.default_cfg
20
+ # transform = Compose([
21
+ # Resize(config['input_shape'][1:], interpolation=InterpolationMode.BILINEAR),
22
+ # PILToTensor(),
23
+ # ConvertImageDtype(torch.float32),
24
+ # Normalize(config['mean'], config['std'])
25
+ # ])
26
+
27
+ # input_tensor = transform(img).unsqueeze(0)
28
+
29
+ # # Inference
30
+ # with torch.inference_mode():
31
+ # output = model(input_tensor)
32
+ # probs = output.squeeze(0).softmax(dim=0)
33
+
34
+
35
  # multi-model list
36
  multi_model_list = [
37
  {"model_id": "distilbert-base-uncased-finetuned-sst-2-english", "task": "text-classification"},
 
59
  raise ValueError(f"model_id: {model_id} is not valid. Available models are: {list(self.multi_model.keys())}")
60
 
61
  # pass inputs with all kwargs in data
62
+ prediction = {'output':'test'}
63
+ # if parameters is not None:
64
+ # prediction = self.multi_model[model_id](inputs, **parameters)
65
+ # else:
66
+ # prediction = self.multi_model[model_id](inputs)
67
+ # # postprocess the prediction
68
  return prediction