| from typing import * |
|
|
| import torch |
| import torch.distributed.rpc as rpc |
| from torch import Tensor |
| from torch._jit_internal import Future |
| from torch.distributed.rpc import RRef |
| from typing import Tuple |
|
|
|
|
# Template placeholder: the remote-module template instantiator replaces this
# with the actual module interface class when the template is specialized.
# It is referenced below in the annotation of `_remote_forward`'s first
# parameter; `None` is only the un-instantiated default.
module_interface_cls = None
|
|
|
|
def forward_async(self, *args, **kwargs):
    """Invoke ``forward`` on the remote module asynchronously.

    Prepends the remote-module bookkeeping values (``module_rref``,
    ``device``, ``is_device_map_set``) to the caller's positional args and
    dispatches ``_remote_forward`` to the RRef owner via ``rpc.rpc_async``.

    Returns:
        A Future that resolves to the remote ``forward`` result.
    """
    rpc_args = (self.module_rref, self.device, self.is_device_map_set) + args
    # Copy kwargs so the RPC layer never shares the caller's dict.
    rpc_kwargs = dict(kwargs)
    return rpc.rpc_async(
        self.module_rref.owner(),
        _remote_forward,
        rpc_args,
        rpc_kwargs,
    )
|
|
|
|
def forward(self, *args, **kwargs):
    """Invoke ``forward`` on the remote module and block until it returns.

    Same dispatch as :func:`forward_async` — an async RPC of
    ``_remote_forward`` to the RRef owner — but waits on the returned
    Future and hands back the concrete result.
    """
    rpc_args = (self.module_rref, self.device, self.is_device_map_set) + args
    # Copy kwargs so the RPC layer never shares the caller's dict.
    rpc_kwargs = dict(kwargs)
    fut = rpc.rpc_async(
        self.module_rref.owner(),
        _remote_forward,
        rpc_args,
        rpc_kwargs,
    )
    return fut.wait()
|
|
|
|
# Methods this template contributes to the generated remote-module class.
# The instantiator attaches each function in this list to the class it builds.
_generated_methods = [
    forward_async,
    forward,
]
|
|
|
|
|
|
|
|
def _remote_forward(
    module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, *args, **kwargs):
    """Run ``forward`` of the module held by ``module_rref`` on its owner.

    Executed on the RRef owner via RPC (see ``forward``/``forward_async``
    above). Moves tensor inputs onto ``device`` when it is a CUDA device,
    and moves tensor outputs back to CPU unless ``is_device_map_set``
    indicates the RPC device map handles that transfer.

    NOTE: the tuple-accumulation style (``Tuple[()]`` plus repeated
    concatenation) instead of list comprehensions appears deliberate —
    presumably for TorchScript compatibility when this template is
    instantiated for a scripted module interface; confirm before restyling.
    """
    module = module_rref.local_value()
    device = torch.device(device)

    # Non-CUDA target: inputs can be used as-is, no device transfer needed.
    if device.type != "cuda":
        return module.forward(*args, **kwargs)


    
    
    
    
    
    # Move every tensor positional arg onto the CUDA device; non-tensors
    # pass through unchanged. Rebuilt as a fresh tuple element by element.
    args = (*args,)
    out_args: Tuple[()] = ()
    for arg in args:
        arg = (arg.to(device),) if isinstance(arg, Tensor) else (arg,)
        out_args = out_args + arg

    # Same device move for tensor keyword args (in a copied dict, so the
    # caller's mapping is untouched; keys are never added or removed).
    kwargs = {**kwargs}
    for k, v in kwargs.items():
        if isinstance(v, Tensor):
            kwargs[k] = kwargs[k].to(device)

    # NOTE(review): when the device map is set, outputs are apparently
    # returned as-is — presumably the RPC device map moves them for us.
    if is_device_map_set:
        return module.forward(*out_args, **kwargs)


    
    
    
    
    
    # No device map: tensors cannot cross the RPC boundary on CUDA, so copy
    # each tensor element of the output back to CPU before returning.
    # Assumes module.forward returns an iterable here — TODO confirm the
    # instantiated interface guarantees that.
    ret: Tuple[()] = ()
    for i in module.forward(*out_args, **kwargs):
        i = (i.cpu(),) if isinstance(i, Tensor) else (i,)
        ret = ret + i
    return ret
|
|