DAHS / src /simulator.py
Vittal-M's picture
Upload 39 files
f3fc7bb verified
"""
simulator.py β€” Discrete-Event Warehouse Simulation Engine (DAHS_2)
Implements a realistic e-commerce fulfillment warehouse with 8 zones,
37 stations, 5 job types, stochastic disruptions, and pluggable heuristics.
NEW in DAHS_2:
- save_state() -> dict β€” snapshot full simulation state for fork training
- from_state(state_dict, heuristic_fn) -> WarehouseSimulator (classmethod)
- get_partial_metrics(since_time) -> SimulationMetrics β€” for 20-min fork windows
"""
from __future__ import annotations
import copy
import logging
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import simpy
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Data Structures
# ---------------------------------------------------------------------------
@dataclass
class ZoneConfig:
"""Configuration for a single warehouse zone."""
zone_id: int
name: str
num_stations: int
zone_type: str # e.g. "receiving", "picking", "packing", "shipping"
@dataclass
class JobType:
"""Specification for a category of warehouse jobs."""
name: str # "A" – "E"
route: List[int] # ordered zone IDs
proc_time_ranges: List[Tuple[float, float]] # (min, max) minutes per zone
due_date_offset: float # minutes from arrival to due date
frequency: float # relative arrival weight
priority_weight: float # higher = more important
@dataclass
class Operation:
"""One processing step of a job at a specific zone/station."""
zone_id: int
nominal_proc_time: float
actual_proc_time: float = 0.0
start_time: float = -1.0
end_time: float = -1.0
station_id: int = -1
@dataclass
class Job:
"""A single warehouse order moving through the system."""
job_id: int
job_type: str
arrival_time: float
due_date: float
operations: List[Operation]
current_op_idx: int = 0
priority: int = 1 # 1=standard, 2=expedited, 3=VIP
status: str = "waiting" # waiting / processing / done / late
completion_time: float = -1.0
priority_escalated: bool = False
@property
def is_complete(self) -> bool:
return self.current_op_idx >= len(self.operations)
@property
def next_zone_id(self) -> Optional[int]:
if self.is_complete:
return None
return self.operations[self.current_op_idx].zone_id
def remaining_proc_time(self) -> float:
"""Sum of nominal proc times for all remaining operations."""
return sum(op.nominal_proc_time for op in self.operations[self.current_op_idx:])
@dataclass
class StationState:
"""Runtime state of a single processing station."""
station_id: int
zone_id: int
is_broken: bool = False
repair_end_time: float = 0.0
current_job: Optional[int] = None # job_id or None
busy_until: float = 0.0
@dataclass
class SimulationMetrics:
"""All performance metrics from one simulation run."""
makespan: float = 0.0
total_tardiness: float = 0.0
sla_breach_rate: float = 0.0
avg_cycle_time: float = 0.0
zone_utilization: Dict[int, float] = field(default_factory=dict)
throughput: float = 0.0
queue_max: int = 0
queue_history: List[Tuple[float, Dict[int, int]]] = field(default_factory=list)
completed_jobs: int = 0
total_jobs: int = 0
# ---------------------------------------------------------------------------
# Simulator
# ---------------------------------------------------------------------------
class WarehouseSimulator:
"""
SimPy-based discrete-event simulator for an e-commerce fulfillment center.
Simulation parameters are calibrated to published warehouse operations research:
- Zone structure & station counts (37 total, 8 zones):
De Koster et al. (2007), EJOR 182(2):481-501 β€” 20-50 stations typical for
mid-scale distribution centers.
Gu et al. (2010), EJOR 203(3):539-549 β€” warehouse design benchmarks.
- Arrival rate (BASE_ARRIVAL_RATE = 1.5 jobs/min = 90/hr):
Gu et al. (2010) β€” 60-150 orders/hour for mid-scale DCs.
(Default constructor arg is 2.5, calibrated preset uses 1.5.)
- Processing time ranges (Picking 5-18 min, Receiving 3-8 min):
Tompkins et al. (2010), Facilities Planning, Wiley 4th ed.
Bartholdi & Hackman (2019), Warehouse & Distribution Science, GT.
- Breakdown frequency (BREAKDOWN_PROB = 0.003):
Inman (1999), Prod. & Inv. Mgmt. Journal 40(2):67-71 β€” 2-5% of
operational hours. 0.003/min Γ— 37 stations Γ— 600 min β‰ˆ 2.7% exposure.
- Repair time mean (18 min, Exponential):
Goetschalckx & Ashayeri (1989) β€” 10-30 min MTTR for conveyor/AGV.
- Batch arrival size (30 jobs, every 45 min):
Bartholdi & Hackman (2019) β€” 20-60 items per truck unload;
30-60 min between truck docks for mid-scale DC.
- Processing time variability (lognormal Οƒ = 0.30, CV β‰ˆ 30%):
De Koster et al. (2007) β€” CV of 20-35% for manual warehouse operations.
- Lunch productivity penalty (1.3Γ—, 30% slowdown):
Garg et al. (2017), Int. J. Industrial Engineering 24(3):181-192 β€”
20-40% productivity drop during scheduled breaks.
- Worker utilization target (implicit 65-80%):
Frazelle (2016), World-Class Warehousing, McGraw-Hill 2nd ed.
- Due date SLA windows (60-320 min, spanning 1-5.3 hours):
Industry standard SLA windows of 1-8 hours for e-commerce fulfillment.
Frazelle (2016) β€” 2-10% SLA breach acceptable in well-run warehouses.
Parameters
----------
seed : int
Random seed for full reproducibility.
heuristic_fn : Callable
Dispatch function: (jobs, current_time, zone_id) -> ordered List[Job].
feature_extractor : optional
FeatureExtractor instance used when running in hybrid-ML mode.
"""
# Zone configuration: 8 zones with station counts summing to 37
# Total 37 stations within published 20-50 range for mid-scale DCs
# Ref: De Koster et al. (2007), EJOR 182(2):481-501
# Ref: Gu et al. (2010), EJOR 203(3):539-549
ZONE_SPECS: List[Tuple[int, str, int, str]] = [
(0, "Receiving", 3, "receiving"),
(1, "Sorting", 4, "sorting"),
(2, "Picking-A", 6, "picking"),
(3, "Picking-B", 8, "picking"),
(4, "Value-Add", 5, "value_add"),
(5, "QC", 4, "quality"),
(6, "Packing", 3, "packing"),
(7, "Shipping", 4, "shipping"),
]
# Job-type definitions (name, route, proc_time_ranges, due_date_offset_min, freq, prio_weight)
# Processing time ranges (min, max) in minutes:
# Receiving ops (3-8 min): Bartholdi & Hackman (2019) β€” upper-end realistic with inspection
# Picking ops (5-18 min): Tompkins et al. (2010), Facilities Planning β€” 2-15 min/order
# Value-Add (8-18 min): Tompkins et al. (2010) β€” extended operations
# Due date offsets (60-320 min, spanning 1-5.3 hours):
# Ref: Frazelle (2016) β€” typical SLA windows 1-8 hours for e-commerce fulfillment
JOB_TYPE_SPECS = [
("A", [0, 1, 2, 6, 7], [(3,8),(2,5),(5,12),(4,9),(2,4)], 120, 0.25, 2.0),
("B", [0, 1, 3, 5, 6, 7], [(3,8),(2,5),(6,14),(3,7),(4,9),(2,4)], 160, 0.30, 1.5),
("C", [0, 1, 4, 5, 6, 7], [(3,8),(2,5),(8,18),(3,7),(4,9),(2,4)], 240, 0.20, 1.0),
("D", [0, 1, 2, 4, 5, 6, 7], [(3,8),(2,5),(5,12),(8,18),(3,7),(4,9),(2,4)], 320, 0.15, 0.8),
("E", [1, 3, 7], [(2,5),(4,10),(1,3)], 60, 0.10, 3.0), # express β€” tight SLA
]
# Base arrival rate: 2.5 jobs/min = 150/hr (peak); calibrated preset uses 1.5 (90/hr = mid-scale)
# Published range: 60-150 orders/hour for mid-scale distribution centers
# Ref: Gu et al. (2010), EJOR 203(3):539-549
BASE_ARRIVAL_RATE = 2.5 # jobs per minute
SIM_DURATION = 600.0 # minutes (one 10-hour shift)
def __init__(
self,
seed: int,
heuristic_fn: Callable,
feature_extractor=None,
# breakdown_prob: 0.003/min β‰ˆ 2.7% exposure over 600 min Γ— 37 stations
# Published range: 2-5% of operational hours β€” Inman (1999)
base_arrival_rate: float = 2.5,
breakdown_prob: float = 0.003,
# batch_arrival_size: 30 items per truck β€” within published 20-60 range
# Ref: Bartholdi & Hackman (2019), Warehouse & Distribution Science
batch_arrival_size: int = 30,
# lunch_penalty_factor: 1.3x = 30% productivity drop during break
# Published range: 20-40% β€” Garg et al. (2017), Int. J. Industrial Engineering
lunch_penalty_factor: float = 1.3,
# Preset overrides β€” leave empty/1.0 for default behavior
job_type_frequencies: Optional[Dict[str, float]] = None,
due_date_tightness: float = 1.0,
processing_time_scale: float = 1.0,
) -> None:
self.seed = seed
self.heuristic_fn = heuristic_fn
self.feature_extractor = feature_extractor
self._base_arrival_rate = base_arrival_rate
self._breakdown_prob = breakdown_prob
self._batch_arrival_size = batch_arrival_size
self._lunch_penalty_factor = lunch_penalty_factor
self._job_type_frequencies = job_type_frequencies or {}
self._due_date_tightness = due_date_tightness
self._processing_time_scale = processing_time_scale
# Validate preset frequency overrides sum to ~1.0
if self._job_type_frequencies:
total = sum(self._job_type_frequencies.values())
if total > 0 and abs(total - 1.0) > 0.01:
logger.warning("job_type_frequencies sum=%.3f (expected ~1.0)", total)
self.rng = np.random.default_rng(seed)
self.env = simpy.Environment()
self.zones: Dict[int, ZoneConfig] = {}
self.job_types: Dict[str, JobType] = {}
self.stations: Dict[int, StationState] = {}
self.station_resources: Dict[int, simpy.Resource] = {}
# Zone-level queues (list of Job)
self.zone_queues: Dict[int, List[Job]] = {}
# Job registry
self.all_jobs: Dict[int, Job] = {}
self.completed_jobs: List[Job] = []
self._job_counter = 0
# Metrics tracking
self._zone_busy_time: Dict[int, float] = {}
self._queue_snapshots: List[Tuple[float, Dict[int, int]]] = []
self._max_queue: int = 0
self._lunch_active: bool = False
self._setup_zones()
self._setup_job_types()
# ------------------------------------------------------------------
# Setup helpers
# ------------------------------------------------------------------
def _setup_zones(self) -> None:
station_id = 0
self.dispatcher_triggers = {}
for zone_id, name, n_stations, zone_type in self.ZONE_SPECS:
self.zones[zone_id] = ZoneConfig(zone_id, name, n_stations, zone_type)
self.zone_queues[zone_id] = []
self.dispatcher_triggers[zone_id] = self.env.event()
self._zone_busy_time[zone_id] = 0.0
for _ in range(n_stations):
st = StationState(station_id=station_id, zone_id=zone_id)
self.stations[station_id] = st
self.station_resources[station_id] = simpy.Resource(self.env, capacity=1)
station_id += 1
def _setup_job_types(self) -> None:
for name, route, proc_ranges, due_offset, freq, prio_w in self.JOB_TYPE_SPECS:
effective_freq = self._job_type_frequencies.get(name, freq) if self._job_type_frequencies else freq
effective_due = due_offset * self._due_date_tightness
scaled_ranges = [
(lo * self._processing_time_scale, hi * self._processing_time_scale)
for lo, hi in proc_ranges
]
self.job_types[name] = JobType(
name=name,
route=route,
proc_time_ranges=scaled_ranges,
due_date_offset=effective_due,
frequency=effective_freq,
priority_weight=prio_w,
)
# ------------------------------------------------------------------
# Utility
# ------------------------------------------------------------------
def _next_job_id(self) -> int:
jid = self._job_counter
self._job_counter += 1
return jid
# Time-varying composition profile β€” reflects realistic daily order-mix shifts
# observed in e-commerce fulfillment centres:
# morning (0-120 min): overnight standard-order backlog β†’ Type A dominant
# mid-morning (120-240): diversifying mix β€” bulk Type B/C joins the floor
# afternoon (240-420): heavy bulk (C, D) as truck deliveries concentrate
# evening peak (420-600): same-day cut-off surge β€” Type E express dominates
# Values are anchor points; _get_composition_profile interpolates linearly
# between them so the distribution shifts smoothly rather than in hard steps.
# Refs: Bartholdi & Hackman (2019) Β§6; De Koster et al. (2007) EJOR 182(2);
# Boysen et al. (2019) EJOR 277(2):396-411 β€” e-commerce warehousing patterns.
_COMPOSITION_PROFILE = [
(0.0, {"A": 0.55, "B": 0.18, "C": 0.10, "D": 0.09, "E": 0.08}),
(120.0, {"A": 0.45, "B": 0.22, "C": 0.13, "D": 0.10, "E": 0.10}),
(240.0, {"A": 0.25, "B": 0.32, "C": 0.20, "D": 0.13, "E": 0.10}),
(360.0, {"A": 0.15, "B": 0.25, "C": 0.30, "D": 0.20, "E": 0.10}),
(480.0, {"A": 0.12, "B": 0.18, "C": 0.22, "D": 0.13, "E": 0.35}),
(600.0, {"A": 0.10, "B": 0.14, "C": 0.12, "D": 0.08, "E": 0.56}),
]
# Composition noise: Gaussian perturbation Οƒ applied per component, then
# renormalised to sum to 1. Keeps the profile from being artificially smooth
# while preserving the overall daily trend. Low enough (Οƒ=0.03) that no single
# solver is accidentally favoured by random fluctuations.
_COMPOSITION_NOISE_SIGMA = 0.03
# Intraday arrival-rate multiplier anchors (time in minutes from shift start).
# Bimodal curve with a mild morning plateau, lunch dip, and a strong evening
# peak reflecting the same-day cut-off surge that is characteristic of
# e-commerce fulfilment centres. Values are interpolated linearly between
# anchors and a small multiplicative noise band is applied per sample.
# Refs: Boysen et al. (2019) EJOR 277(2); Bartholdi & Hackman (2019) Β§2.3;
# De Koster et al. (2007) EJOR 182(2) β€” workload profiles in DCs.
_SURGE_PROFILE = [
(0.0, 0.55), # shift start β€” overnight backlog, still warming up
(60.0, 0.95), # morning ramp complete
(120.0, 1.05), # morning baseline
(180.0, 1.15), # pre-lunch mild peak
(240.0, 0.60), # lunch dip (productivity drop)
(300.0, 0.95), # post-lunch recovery
(360.0, 1.20), # afternoon ramp
(420.0, 1.45), # approaching evening peak
(480.0, 1.65), # evening peak β€” same-day cut-off surge
(540.0, 1.50), # late evening (still elevated)
(600.0, 1.30), # shift close (slight taper)
]
# Multiplicative noise band applied per surge evaluation; keeps arrivals
# stochastic without systematically biasing any heuristic.
_SURGE_NOISE_LO = 0.93
_SURGE_NOISE_HI = 1.07
def _get_composition_profile(self, t: float) -> Dict[str, float]:
"""Per-type probability vector at time t.
If the caller supplied explicit ``job_type_frequencies`` (used by
calibration tests and heuristic-biased presets) those are returned
verbatim. Otherwise the profile is **linearly interpolated** between the
anchor points in ``_COMPOSITION_PROFILE`` and a small Gaussian noise
term is added so the distribution is not artificially deterministic.
The noisy vector is clipped to be non-negative and renormalised to 1.
"""
if self._job_type_frequencies:
return dict(self._job_type_frequencies)
types = ("A", "B", "C", "D", "E")
# Find the two anchor points bracketing t
anchors = self._COMPOSITION_PROFILE
if t <= anchors[0][0]:
base = anchors[0][1]
elif t >= anchors[-1][0]:
base = anchors[-1][1]
else:
base = anchors[0][1]
for (t_a, p_a), (t_b, p_b) in zip(anchors[:-1], anchors[1:]):
if t_a <= t < t_b:
alpha = (t - t_a) / max(t_b - t_a, 1e-9)
base = {k: (1 - alpha) * p_a[k] + alpha * p_b[k] for k in types}
break
# Stochastic perturbation for realism (seeded via self.rng).
if self._COMPOSITION_NOISE_SIGMA > 0:
noisy = {
k: max(0.0, base[k] + float(self.rng.normal(0.0, self._COMPOSITION_NOISE_SIGMA)))
for k in types
}
total = sum(noisy.values())
if total > 0:
return {k: v / total for k, v in noisy.items()}
return dict(base)
def _sample_job_type(self) -> str:
profile = self._get_composition_profile(self.env.now)
types = list(self.job_types.keys())
weights = [profile.get(t, self.job_types[t].frequency) for t in types]
total = sum(weights)
if total <= 0:
weights = [self.job_types[t].frequency for t in types]
total = sum(weights)
probs = [w / total for w in weights]
return self.rng.choice(types, p=probs)
def _create_job(self, job_type_name: str, arrival_time: float) -> Job:
jt = self.job_types[job_type_name]
operations = []
for zone_id, (lo, hi) in zip(jt.route, jt.proc_time_ranges):
nominal = float(self.rng.uniform(lo, hi))
operations.append(Operation(zone_id=zone_id, nominal_proc_time=nominal))
return Job(
job_id=self._next_job_id(),
job_type=job_type_name,
arrival_time=arrival_time,
due_date=arrival_time + jt.due_date_offset,
operations=operations,
priority=3 if job_type_name == "E" else 1,
)
def _surge_base_rate(self, current_time: float) -> float:
"""Deterministic trend value of the surge multiplier at time ``t``.
Pure anchor-point interpolation β€” no RNG calls, so this is safe to
invoke from informational paths (state snapshots, feature extraction)
without disturbing the arrival-process sample stream.
"""
anchors = self._SURGE_PROFILE
if current_time <= anchors[0][0]:
return float(anchors[0][1])
if current_time >= anchors[-1][0]:
return float(anchors[-1][1])
for (t_a, v_a), (t_b, v_b) in zip(anchors[:-1], anchors[1:]):
if t_a <= current_time < t_b:
alpha = (current_time - t_a) / max(t_b - t_a, 1e-9)
return float((1.0 - alpha) * v_a + alpha * v_b)
return float(anchors[-1][1])
def _get_surge_multiplier(self, current_time: float) -> float:
"""Time-of-day arrival-rate multiplier (t in minutes from shift start).
The curve is a linear interpolation between the anchor points in
``_SURGE_PROFILE`` plus a small multiplicative noise term drawn from
``U(_SURGE_NOISE_LO, _SURGE_NOISE_HI)`` β€” so the instantaneous rate is
both deterministically trended (bimodal with evening peak) and
stochastically perturbed each time the process samples an arrival.
Returns a strictly positive multiplier.
"""
base = self._surge_base_rate(current_time)
noise = float(self.rng.uniform(self._SURGE_NOISE_LO, self._SURGE_NOISE_HI))
return max(0.05, base * noise)
def _record_queue_snapshot(self) -> None:
snapshot = {z: len(q) for z, q in self.zone_queues.items()}
self._queue_snapshots.append((self.env.now, snapshot))
total = sum(snapshot.values())
if total > self._max_queue:
self._max_queue = total
# ------------------------------------------------------------------
# SimPy processes
# ------------------------------------------------------------------
def _arrival_process(self):
"""Continuous Poisson arrival of individual jobs."""
while True:
surge = self._get_surge_multiplier(self.env.now)
rate = self._base_arrival_rate * surge
inter_arrival = float(self.rng.exponential(1.0 / rate))
yield self.env.timeout(inter_arrival)
jt_name = self._sample_job_type()
job = self._create_job(jt_name, self.env.now)
self.all_jobs[job.job_id] = job
self.env.process(self._process_job(job))
def _batch_arrival_process(self):
"""Truck arrival every 45 min delivering configurable batch of orders.
Interval: 30-60 min between truck docks is typical for mid-scale DCs.
Batch size: 20-60 items per truck unload.
Ref: Bartholdi & Hackman (2019), Warehouse & Distribution Science.
"""
while True:
yield self.env.timeout(45.0) # 45 min interval β€” within 30-60 min published range
half = max(1, self._batch_arrival_size // 2)
batch_size = int(self.rng.integers(half, self._batch_arrival_size + 1))
for _ in range(batch_size):
jt_name = self._sample_job_type()
job = self._create_job(jt_name, self.env.now)
self.all_jobs[job.job_id] = job
self.env.process(self._process_job(job))
def _station_breakdown_process(self, station: StationState):
"""Per-station breakdown process; rate and repair time are configurable.
BREAKDOWN_PROB = 0.003/min: at 37 stations Γ— 600 min, expected total
breakdown exposure β‰ˆ 2.7%, within published 2-5% range.
Ref: Inman (1999), Prod. & Inv. Mgmt. Journal 40(2):67-71.
Repair time mean = 18 min (Exponential): within 10-30 min MTTR for
conveyor/AGV equipment in warehouse environments.
Ref: Goetschalckx & Ashayeri (1989), Logistics World 2(2):99-106.
"""
while True:
ttf = float(self.rng.exponential(1.0 / max(self._breakdown_prob, 1e-9)))
yield self.env.timeout(ttf)
station.is_broken = True
repair_time = float(self.rng.exponential(18.0)) # mean 18 min MTTR
station.repair_end_time = self.env.now + repair_time
yield self.env.timeout(repair_time)
station.is_broken = False
self._trigger_dispatcher(station.zone_id)
def _lunch_break_process(self):
"""Lunch break from t=300 to t=360 (13:00-14:00)."""
yield self.env.timeout(300.0)
self._lunch_active = True
yield self.env.timeout(60.0)
self._lunch_active = False
def _priority_escalation_process(self):
"""Every 5 minutes, escalate 5% of standard waiting jobs."""
while True:
yield self.env.timeout(5.0)
waiting = [
j for j in self.all_jobs.values()
if j.status == "waiting" and j.priority == 1 and not j.priority_escalated
]
n_escalate = max(0, int(len(waiting) * 0.05))
if n_escalate:
chosen = self.rng.choice(len(waiting), size=n_escalate, replace=False)
for idx in chosen:
waiting[idx].priority = 2
waiting[idx].priority_escalated = True
def _snapshot_process(self):
"""Record queue depths every 5 minutes."""
while True:
self._record_queue_snapshot()
yield self.env.timeout(5.0)
# ------------------------------------------------------------------
# Job processing
# ------------------------------------------------------------------
def _process_job(self, job: Job):
"""Route a job through all its operations sequentially."""
for op_idx, op in enumerate(job.operations):
zone_id = op.zone_id
self.zone_queues[zone_id].append(job)
job.status = "waiting"
job._dispatch_event = self.env.event()
self._trigger_dispatcher(zone_id)
yield job._dispatch_event
station_id = self._pick_station(zone_id)
op.station_id = station_id
resource = self.station_resources[station_id]
st = self.stations[station_id]
st.current_job = job.job_id
with resource.request() as req:
yield req
# Re-check breakdown: station may have broken while job was queued.
while st.is_broken:
wait_time = max(0.1, st.repair_end_time - self.env.now)
yield self.env.timeout(wait_time)
job.status = "processing"
job.current_op_idx = op_idx
# Lognormal sigma = 0.30 β†’ CV β‰ˆ 30%, within published 20-35% range
# Ref: De Koster et al. (2007), EJOR 182(2):481-501
variability = float(self.rng.lognormal(0, 0.30))
lunch_penalty = self._lunch_penalty_factor if self._lunch_active else 1.0
actual_time = op.nominal_proc_time * variability * lunch_penalty
op.actual_proc_time = actual_time
op.start_time = self.env.now
self._zone_busy_time[zone_id] = (
self._zone_busy_time.get(zone_id, 0.0) + actual_time
)
yield self.env.timeout(actual_time)
op.end_time = self.env.now
st.busy_until = self.env.now
st.current_job = None
self._trigger_dispatcher(zone_id)
# Job fully processed
job.status = "done"
job.completion_time = self.env.now
job.current_op_idx = len(job.operations)
self.completed_jobs.append(job)
def _trigger_dispatcher(self, zone_id: int):
"""Wake up the zone dispatcher if it's idle."""
if not self.dispatcher_triggers[zone_id].triggered:
self.dispatcher_triggers[zone_id].succeed()
def _zone_dispatcher(self, zone_id: int):
"""Centralized dispatcher process for a zone."""
while True:
yield self.dispatcher_triggers[zone_id]
self.dispatcher_triggers[zone_id] = self.env.event()
while True:
queue = self.zone_queues[zone_id]
if not queue:
break
free_stations = [
sid for sid, st in self.stations.items()
if st.zone_id == zone_id and not st.is_broken
and self.station_resources[sid].count + len(self.station_resources[sid].queue) == 0
]
if not free_stations:
break
ordered = self.heuristic_fn(queue, self.env.now, zone_id)
best_job = ordered[0]
queue.remove(best_job)
best_job._dispatch_event.succeed()
yield self.env.timeout(0)
def _pick_station(self, zone_id: int) -> int:
"""Pick a free non-broken station, else fallback to least-busy."""
free_stations = [
sid for sid, st in self.stations.items()
if st.zone_id == zone_id and not st.is_broken
and self.station_resources[sid].count + len(self.station_resources[sid].queue) == 0
]
if free_stations:
return free_stations[0]
zone_stations = [
sid for sid, st in self.stations.items()
if st.zone_id == zone_id and not st.is_broken
]
if not zone_stations:
zone_stations = [sid for sid, st in self.stations.items() if st.zone_id == zone_id]
return min(zone_stations, key=lambda sid: self.stations[sid].busy_until)
# ------------------------------------------------------------------
# Streaming API (for WebSocket backend)
# ------------------------------------------------------------------
def init(self) -> None:
"""Set up all SimPy processes without running. Call step_to() to advance."""
self._lunch_active = False
self._processes_registered = True
self.env.process(self._arrival_process())
self.env.process(self._batch_arrival_process())
self.env.process(self._priority_escalation_process())
self.env.process(self._lunch_break_process())
self.env.process(self._snapshot_process())
for zone_id in self.zones:
self.env.process(self._zone_dispatcher(zone_id))
for station in self.stations.values():
self.env.process(self._station_breakdown_process(station))
def step_to(self, t: float) -> None:
"""Advance simulation to time t (must have called init() first)."""
self.env.run(until=t)
def get_visual_snapshot(self) -> Dict[str, Any]:
"""Return the current visual state for the frontend canvas."""
now = self.env.now
completed = self.completed_jobs
n = len(completed)
total_tard = sum(max(0.0, j.completion_time - j.due_date) for j in completed)
n_late = sum(1 for j in completed if j.completion_time > j.due_date)
sla = n_late / n if n else 0.0
avg_cycle = (sum(j.completion_time - j.arrival_time for j in completed) / n
if n else 0.0)
throughput = (n / max(now, 0.001)) * 60.0
active_jobs: List[Dict[str, Any]] = []
for zone_id, queue in self.zone_queues.items():
for job in queue:
active_jobs.append({
"id": job.job_id, "type": job.job_type,
"zoneId": zone_id, "status": "waiting",
"priority": job.priority,
})
for job in self.all_jobs.values():
if job.status == "processing" and job.current_op_idx < len(job.operations):
active_jobs.append({
"id": job.job_id, "type": job.job_type,
"zoneId": job.operations[job.current_op_idx].zone_id,
"status": "processing",
"priority": job.priority,
})
active_jobs = active_jobs[:50]
zone_active = [
sum(1 for j in self.all_jobs.values()
if j.status == "processing"
and j.current_op_idx < len(j.operations)
and j.operations[j.current_op_idx].zone_id == z)
for z in range(8)
]
return {
"time": round(now, 2),
"activeJobs": active_jobs,
"zoneQueueLengths": [len(self.zone_queues.get(z, [])) for z in range(8)],
"zoneActiveCounts": zone_active,
"metrics": {
"completed": n,
"completedJobs": n,
"totalTardiness": round(total_tard, 1),
"slaBreachRate": round(sla, 4),
"avgCycleTime": round(avg_cycle, 2),
"throughput": round(throughput, 2),
"jobsPerHour": round(throughput, 2),
},
}
# ------------------------------------------------------------------
# Run (batch mode)
# ------------------------------------------------------------------
def run(self, duration: float = 600.0) -> SimulationMetrics:
"""Execute a full shift simulation and return performance metrics."""
if not hasattr(self, "_processes_registered") or not self._processes_registered:
self.init()
self.env.run(until=duration)
return self._compute_metrics(duration)
def _compute_metrics(self, duration: float) -> SimulationMetrics:
"""Calculate all 7 performance metrics from the completed simulation."""
completed = self.completed_jobs
total_jobs = len(self.all_jobs)
n_completed = len(completed)
if not completed:
return SimulationMetrics(
makespan=duration,
zone_utilization={z: 0.0 for z in self.zones},
queue_history=self._queue_snapshots,
)
makespan = max((j.completion_time for j in completed), default=duration)
total_tardiness = sum(
max(0.0, j.completion_time - j.due_date) for j in completed
)
n_late = sum(1 for j in completed if j.completion_time > j.due_date)
sla_breach_rate = n_late / n_completed if n_completed else 0.0
avg_cycle_time = float(np.mean(
[j.completion_time - j.arrival_time for j in completed]
)) if completed else 0.0
zone_utilization = {}
for zone_id, zone in self.zones.items():
busy = self._zone_busy_time.get(zone_id, 0.0)
capacity = zone.num_stations * duration
zone_utilization[zone_id] = min(1.0, busy / capacity) if capacity > 0 else 0.0
throughput = (n_completed / duration) * 60.0
queue_max = self._max_queue
return SimulationMetrics(
makespan=makespan,
total_tardiness=total_tardiness,
sla_breach_rate=sla_breach_rate,
avg_cycle_time=avg_cycle_time,
zone_utilization=zone_utilization,
throughput=throughput,
queue_max=queue_max,
queue_history=self._queue_snapshots,
completed_jobs=n_completed,
total_jobs=total_jobs,
)
def get_state_snapshot(self) -> Dict[str, Any]:
"""Return current system state for feature extraction."""
now = self.env.now
n_broken = sum(1 for st in self.stations.values() if st.is_broken)
queue_sizes = {z: len(q) for z, q in self.zone_queues.items()}
waiting_jobs = [j for j in self.all_jobs.values() if j.status == "waiting"]
return {
"current_time": now,
"n_orders_in_system": len(waiting_jobs) + sum(
1 for j in self.all_jobs.values() if j.status == "processing"
),
"n_express_orders": sum(1 for j in waiting_jobs if j.job_type == "E"),
"queue_sizes": queue_sizes,
"zone_utilization": {
z: min(1.0, self._zone_busy_time.get(z, 0.0) / max(1.0, now * self.zones[z].num_stations))
for z in self.zones
},
"n_broken_stations": n_broken,
"lunch_active": self._lunch_active,
"surge_multiplier": self._surge_base_rate(now),
"completed_so_far": len(self.completed_jobs),
"waiting_jobs": waiting_jobs,
"completed_jobs": self.completed_jobs,
"all_jobs": self.all_jobs,
"zones": self.zones,
"stations": self.stations,
}
# ------------------------------------------------------------------
# NEW in DAHS_2: State save/restore for snapshot-fork training
# ------------------------------------------------------------------
@staticmethod
def _serialize_job(job: Job) -> Dict[str, Any]:
"""Convert a Job to a plain dict (avoids deepcopy of SimPy events)."""
return {
"job_id": job.job_id,
"job_type": job.job_type,
"arrival_time": job.arrival_time,
"due_date": job.due_date,
"operations": [
{
"zone_id": op.zone_id,
"nominal_proc_time": op.nominal_proc_time,
"actual_proc_time": op.actual_proc_time,
"start_time": op.start_time,
"end_time": op.end_time,
"station_id": op.station_id,
}
for op in job.operations
],
"current_op_idx": job.current_op_idx,
"priority": job.priority,
"status": job.status,
"completion_time": job.completion_time,
"priority_escalated": job.priority_escalated,
}
@staticmethod
def _deserialize_job(d: Dict[str, Any]) -> Job:
"""Reconstruct a Job from a plain dict."""
ops = [
Operation(
zone_id=o["zone_id"],
nominal_proc_time=o["nominal_proc_time"],
actual_proc_time=o["actual_proc_time"],
start_time=o["start_time"],
end_time=o["end_time"],
station_id=o["station_id"],
)
for o in d["operations"]
]
job = Job(
job_id=d["job_id"],
job_type=d["job_type"],
arrival_time=d["arrival_time"],
due_date=d["due_date"],
operations=ops,
current_op_idx=d["current_op_idx"],
priority=d["priority"],
status=d["status"],
completion_time=d["completion_time"],
priority_escalated=d["priority_escalated"],
)
return job
def save_state(self) -> Dict[str, Any]:
"""Capture complete simulation state for snapshot-fork training.
Returns a pickling-safe dict (no SimPy objects) containing:
- env.now (current time)
- Serialized jobs, completed_jobs, zone_queues (as job IDs)
- All station states (is_broken, repair_end_time, current_job, busy_until)
- RNG state via rng.bit_generator.state
- _job_counter, _zone_busy_time, _lunch_active, queue snapshot history
NOTE: The from_state() classmethod creates a fresh SimPy environment and
re-initializes processes from the saved data point.
"""
state = {
"env_time": self.env.now,
"seed": self.seed,
"_job_counter": self._job_counter,
"_max_queue": self._max_queue,
"_lunch_active": self._lunch_active,
"_zone_busy_time": dict(self._zone_busy_time),
"_queue_snapshots": list(self._queue_snapshots),
"rng_state": self.rng.bit_generator.state,
# Simulator config for reconstruction
"_base_arrival_rate": self._base_arrival_rate,
"_breakdown_prob": self._breakdown_prob,
"_batch_arrival_size": self._batch_arrival_size,
"_lunch_penalty_factor": self._lunch_penalty_factor,
"_job_type_frequencies": dict(self._job_type_frequencies),
"_due_date_tightness": self._due_date_tightness,
"_processing_time_scale": self._processing_time_scale,
# Serialized job data (can't deepcopy β€” SimPy events aren't picklable)
"all_jobs": {
jid: self._serialize_job(job)
for jid, job in self.all_jobs.items()
},
"completed_jobs": [self._serialize_job(j) for j in self.completed_jobs],
"zone_queues": {z: [j.job_id for j in q] for z, q in self.zone_queues.items()},
# Station states
"stations": {
sid: {
"station_id": st.station_id,
"zone_id": st.zone_id,
"is_broken": st.is_broken,
"repair_end_time": st.repair_end_time,
"current_job": st.current_job,
"busy_until": st.busy_until,
}
for sid, st in self.stations.items()
},
}
return state
@classmethod
def from_state(
cls,
state_dict: Dict[str, Any],
heuristic_fn: Callable,
) -> "WarehouseSimulator":
"""Create a new simulator from a saved state (for fork evaluation).
Creates a fresh SimPy environment initialized at saved_time,
restores all job/station/queue data, and continues RNG from saved state.
Parameters
----------
state_dict : dict
Output of save_state().
heuristic_fn : Callable
Dispatch function to use in the forked simulation.
Returns
-------
WarehouseSimulator
Ready to run from state_dict["env_time"] forward.
"""
saved_time = state_dict["env_time"]
# Reconstruct simulator with original config
sim = cls(
seed=state_dict["seed"],
heuristic_fn=heuristic_fn,
base_arrival_rate=state_dict["_base_arrival_rate"],
breakdown_prob=state_dict["_breakdown_prob"],
batch_arrival_size=state_dict["_batch_arrival_size"],
lunch_penalty_factor=state_dict["_lunch_penalty_factor"],
job_type_frequencies=state_dict["_job_type_frequencies"],
due_date_tightness=state_dict["_due_date_tightness"],
processing_time_scale=state_dict["_processing_time_scale"],
)
# Restore RNG from saved state (deterministic continuation)
sim.rng.bit_generator.state = state_dict["rng_state"]
# Restore job counter and metrics
sim._job_counter = state_dict["_job_counter"]
sim._max_queue = state_dict["_max_queue"]
sim._lunch_active = state_dict["_lunch_active"]
sim._zone_busy_time = dict(state_dict["_zone_busy_time"])
sim._queue_snapshots = list(state_dict["_queue_snapshots"])
# Restore jobs from serialized dicts
sim.all_jobs = {
jid: cls._deserialize_job(jdata)
for jid, jdata in state_dict["all_jobs"].items()
}
sim.completed_jobs = [
cls._deserialize_job(jdata)
for jdata in state_dict["completed_jobs"]
]
# Restore zone queues (using saved job IDs to reference restored jobs)
job_by_id = sim.all_jobs
for z, queue_job_ids in state_dict["zone_queues"].items():
sim.zone_queues[int(z)] = [
job_by_id[jid] for jid in queue_job_ids
if jid in job_by_id
]
# Restore station states
for sid_str, st_data in state_dict["stations"].items():
sid = int(sid_str)
if sid in sim.stations:
sim.stations[sid].is_broken = st_data["is_broken"]
sim.stations[sid].repair_end_time = st_data["repair_end_time"]
sim.stations[sid].current_job = st_data["current_job"]
sim.stations[sid].busy_until = st_data["busy_until"]
# Create a SimPy environment starting at saved_time
sim.env = simpy.Environment(initial_time=saved_time)
# Re-create SimPy resources for the new environment
for sid in sim.stations:
sim.station_resources[sid] = simpy.Resource(sim.env, capacity=1)
# Re-create dispatcher trigger events for new environment
for zone_id in sim.zones:
sim.dispatcher_triggers[zone_id] = sim.env.event()
# Re-register dispatchers and breakdown/arrival processes
sim.env.process(sim._arrival_process())
sim.env.process(sim._batch_arrival_process())
sim.env.process(sim._priority_escalation_process())
# Re-register lunch process correctly based on saved time
if saved_time < 300.0:
sim.env.process(sim._lunch_break_process())
elif saved_time < 360.0:
# Currently in lunch β€” restore the remaining lunch period
remaining_lunch = 360.0 - saved_time
def _remaining_lunch():
yield sim.env.timeout(remaining_lunch)
sim._lunch_active = False
sim.env.process(_remaining_lunch())
sim.env.process(sim._snapshot_process())
for zone_id in sim.zones:
sim.env.process(sim._zone_dispatcher(zone_id))
for station in sim.stations.values():
if station.is_broken:
remaining_repair = max(0.1, station.repair_end_time - saved_time)
def _resume_repair(st=station, t=remaining_repair):
yield sim.env.timeout(t)
st.is_broken = False
sim._trigger_dispatcher(st.zone_id)
# Continue with future breakdowns
while True:
ttf = float(sim.rng.exponential(1.0 / max(sim._breakdown_prob, 1e-9)))
yield sim.env.timeout(ttf)
st.is_broken = True
repair_time = float(sim.rng.exponential(18.0))
st.repair_end_time = sim.env.now + repair_time
yield sim.env.timeout(repair_time)
st.is_broken = False
sim._trigger_dispatcher(st.zone_id)
sim.env.process(_resume_repair())
else:
sim.env.process(sim._station_breakdown_process(station))
# Resume WAITING jobs in zone queues:
# These need a full _process_job-like coroutine that waits for dispatch
# then routes through remaining operations.
for zone_id, queue in sim.zone_queues.items():
for job in queue:
job._dispatch_event = sim.env.event()
sim.env.process(sim._resume_waiting_job(job, zone_id))
if queue:
sim._trigger_dispatcher(zone_id)
# Resume PROCESSING jobs with correct remaining time:
# At save time, op.start_time and op.actual_proc_time are set,
# but op.end_time is still -1.0 (only set after timeout completes).
# Remaining = (start_time + actual_proc_time) - saved_time
for job in sim.all_jobs.values():
if job.status == "processing" and job.current_op_idx < len(job.operations):
op = job.operations[job.current_op_idx]
if op.start_time >= 0 and op.actual_proc_time > 0:
expected_end = op.start_time + op.actual_proc_time
remaining = max(0.0, expected_end - saved_time)
else:
remaining = 0.0
sim.env.process(sim._resume_job(job, remaining))
return sim
def _resume_job(self, job: Job, remaining_time: float):
"""Continue processing a job that was in-progress at save_state time."""
op_idx = job.current_op_idx
op = job.operations[op_idx]
yield self.env.timeout(remaining_time)
op.end_time = self.env.now
# Continue with remaining operations
for next_op_idx in range(op_idx + 1, len(job.operations)):
next_op = job.operations[next_op_idx]
zone_id = next_op.zone_id
self.zone_queues[zone_id].append(job)
job.status = "waiting"
job._dispatch_event = self.env.event()
self._trigger_dispatcher(zone_id)
yield job._dispatch_event
station_id = self._pick_station(zone_id)
next_op.station_id = station_id
resource = self.station_resources[station_id]
st = self.stations[station_id]
st.current_job = job.job_id
with resource.request() as req:
yield req
while st.is_broken:
wait_time = max(0.1, st.repair_end_time - self.env.now)
yield self.env.timeout(wait_time)
job.status = "processing"
job.current_op_idx = next_op_idx
variability = float(self.rng.lognormal(0, 0.30))
lunch_penalty = self._lunch_penalty_factor if self._lunch_active else 1.0
actual_time = next_op.nominal_proc_time * variability * lunch_penalty
next_op.actual_proc_time = actual_time
next_op.start_time = self.env.now
self._zone_busy_time[zone_id] = self._zone_busy_time.get(zone_id, 0.0) + actual_time
yield self.env.timeout(actual_time)
next_op.end_time = self.env.now
st.busy_until = self.env.now
st.current_job = None
self._trigger_dispatcher(zone_id)
job.status = "done"
job.completion_time = self.env.now
job.current_op_idx = len(job.operations)
self.completed_jobs.append(job)
def _resume_waiting_job(self, job: Job, current_zone_id: int):
"""Resume a job that was waiting in a zone queue at save_state time.
This replaces the missing _process_job coroutine for waiting jobs
restored via from_state(). The job waits for dispatch in its current
zone, processes that operation, then routes through all remaining ops.
"""
# Wait for dispatcher to select this job in the current zone
yield job._dispatch_event
# Process the current operation (the one the job was waiting for)
op_idx = job.current_op_idx
op = job.operations[op_idx]
zone_id = current_zone_id
station_id = self._pick_station(zone_id)
op.station_id = station_id
resource = self.station_resources[station_id]
st = self.stations[station_id]
st.current_job = job.job_id
with resource.request() as req:
yield req
while st.is_broken:
wait_time = max(0.1, st.repair_end_time - self.env.now)
yield self.env.timeout(wait_time)
job.status = "processing"
job.current_op_idx = op_idx
variability = float(self.rng.lognormal(0, 0.30))
lunch_penalty = self._lunch_penalty_factor if self._lunch_active else 1.0
actual_time = op.nominal_proc_time * variability * lunch_penalty
op.actual_proc_time = actual_time
op.start_time = self.env.now
self._zone_busy_time[zone_id] = self._zone_busy_time.get(zone_id, 0.0) + actual_time
yield self.env.timeout(actual_time)
op.end_time = self.env.now
st.busy_until = self.env.now
st.current_job = None
self._trigger_dispatcher(zone_id)
# Continue with remaining operations (same as _resume_job)
for next_op_idx in range(op_idx + 1, len(job.operations)):
next_op = job.operations[next_op_idx]
next_zone_id = next_op.zone_id
self.zone_queues[next_zone_id].append(job)
job.status = "waiting"
job._dispatch_event = self.env.event()
self._trigger_dispatcher(next_zone_id)
yield job._dispatch_event
station_id = self._pick_station(next_zone_id)
next_op.station_id = station_id
resource = self.station_resources[station_id]
st = self.stations[station_id]
st.current_job = job.job_id
with resource.request() as req:
yield req
while st.is_broken:
wait_time = max(0.1, st.repair_end_time - self.env.now)
yield self.env.timeout(wait_time)
job.status = "processing"
job.current_op_idx = next_op_idx
variability = float(self.rng.lognormal(0, 0.30))
lunch_penalty = self._lunch_penalty_factor if self._lunch_active else 1.0
actual_time = next_op.nominal_proc_time * variability * lunch_penalty
next_op.actual_proc_time = actual_time
next_op.start_time = self.env.now
self._zone_busy_time[next_zone_id] = self._zone_busy_time.get(next_zone_id, 0.0) + actual_time
yield self.env.timeout(actual_time)
next_op.end_time = self.env.now
st.busy_until = self.env.now
st.current_job = None
self._trigger_dispatcher(next_zone_id)
job.status = "done"
job.completion_time = self.env.now
job.current_op_idx = len(job.operations)
self.completed_jobs.append(job)
# ------------------------------------------------------------------
# NEW in DAHS_2: Partial metrics for fork evaluation windows
# ------------------------------------------------------------------
def get_partial_metrics(self, since_time: float) -> SimulationMetrics:
"""Compute metrics only for jobs completed between since_time and env.now.
Used in the 20-minute fork evaluation window during data generation.
Parameters
----------
since_time : float
Start of evaluation window (simulation time).
Returns
-------
SimulationMetrics
Metrics computed only over jobs completed in [since_time, now].
"""
now = self.env.now
window_jobs = [
j for j in self.completed_jobs
if j.completion_time >= since_time
]
if not window_jobs:
return SimulationMetrics(
makespan=now,
zone_utilization={z: 0.0 for z in self.zones},
)
n = len(window_jobs)
total_tardiness = sum(max(0.0, j.completion_time - j.due_date) for j in window_jobs)
n_late = sum(1 for j in window_jobs if j.completion_time > j.due_date)
sla_breach_rate = n_late / n
avg_cycle_time = float(np.mean([j.completion_time - j.arrival_time for j in window_jobs]))
duration = max(now - since_time, 1.0)
throughput = (n / duration) * 60.0
zone_utilization = {
z: min(1.0, self._zone_busy_time.get(z, 0.0) / max(1.0, now * self.zones[z].num_stations))
for z in self.zones
}
return SimulationMetrics(
makespan=max(j.completion_time for j in window_jobs),
total_tardiness=total_tardiness,
sla_breach_rate=sla_breach_rate,
avg_cycle_time=avg_cycle_time,
zone_utilization=zone_utilization,
throughput=throughput,
queue_max=self._max_queue,
completed_jobs=n,
total_jobs=len(self.all_jobs),
)