dreamlessx commited on
Commit
199f152
·
verified ·
1 Parent(s): 32f036f

Update landmarkdiff/metrics_agg.py to v0.3.2

Browse files
Files changed (1) hide show
  1. landmarkdiff/metrics_agg.py +15 -19
landmarkdiff/metrics_agg.py CHANGED
@@ -41,12 +41,8 @@ class MetricsAggregator:
41
  """
42
 
43
  HIGHER_BETTER = {
44
- "ssim": True,
45
- "psnr": True,
46
- "identity_sim": True,
47
- "lpips": False,
48
- "fid": False,
49
- "nme": False,
50
  }
51
 
52
  def __init__(self) -> None:
@@ -61,15 +57,13 @@ class MetricsAggregator:
61
  **metadata: Any,
62
  ) -> None:
63
  """Add a single evaluation record."""
64
- self.records.append(
65
- MetricRecord(
66
- experiment=experiment,
67
- procedure=procedure,
68
- metrics=metrics,
69
- checkpoint_step=checkpoint_step,
70
- metadata=metadata,
71
- )
72
- )
73
 
74
  def add_batch(
75
  self,
@@ -82,9 +76,7 @@ class MetricsAggregator:
82
  """
83
  for rec in records:
84
  proc = rec.get("procedure", "all")
85
- metrics = {
86
- k: v for k, v in rec.items() if k != "procedure" and isinstance(v, (int, float))
87
- }
88
  self.add(experiment, proc, metrics)
89
 
90
  @property
@@ -219,7 +211,10 @@ class MetricsAggregator:
219
  val = self.mean(exp, metric, procedure)
220
  if math.isnan(val):
221
  continue
222
- if (higher_better and val > best_val) or (not higher_better and val < best_val):
 
 
 
223
  best_val = val
224
  best_exp = exp
225
 
@@ -309,5 +304,6 @@ class MetricsAggregator:
309
  procedure=rec["procedure"],
310
  metrics=rec["metrics"],
311
  checkpoint_step=rec.get("checkpoint_step"),
 
312
  )
313
  return agg
 
41
  """
42
 
43
  HIGHER_BETTER = {
44
+ "ssim": True, "psnr": True, "identity_sim": True,
45
+ "lpips": False, "fid": False, "nme": False,
 
 
 
 
46
  }
47
 
48
  def __init__(self) -> None:
 
57
  **metadata: Any,
58
  ) -> None:
59
  """Add a single evaluation record."""
60
+ self.records.append(MetricRecord(
61
+ experiment=experiment,
62
+ procedure=procedure,
63
+ metrics=metrics,
64
+ checkpoint_step=checkpoint_step,
65
+ metadata=metadata,
66
+ ))
 
 
67
 
68
  def add_batch(
69
  self,
 
76
  """
77
  for rec in records:
78
  proc = rec.get("procedure", "all")
79
+ metrics = {k: v for k, v in rec.items() if k != "procedure" and isinstance(v, (int, float))}
 
 
80
  self.add(experiment, proc, metrics)
81
 
82
  @property
 
211
  val = self.mean(exp, metric, procedure)
212
  if math.isnan(val):
213
  continue
214
+ if higher_better and val > best_val:
215
+ best_val = val
216
+ best_exp = exp
217
+ elif not higher_better and val < best_val:
218
  best_val = val
219
  best_exp = exp
220
 
 
304
  procedure=rec["procedure"],
305
  metrics=rec["metrics"],
306
  checkpoint_step=rec.get("checkpoint_step"),
307
+ **rec.get("metadata", {}),
308
  )
309
  return agg