| import json |
| import os |
|
|
| import torch |
|
|
| from diffusers import UNet1DModel |
|
|
|
|
# Pre-create the output layout for the converted Hopper-medium-v2 checkpoints
# (two UNet horizons and one value function) so the conversions can write into it.
for _subdir in ("unet/hor32", "unet/hor128", "value_function"):
    os.makedirs(f"hub/hopper-medium-v2/{_subdir}", exist_ok=True)
|
|
|
|
def unet(hor, checkpoint_dir="/Users/bglickenhaus/Documents/diffuser"):
    """Convert a Hopper-medium-v2 temporal-UNet checkpoint into a diffusers ``UNet1DModel``.

    Loads ``temporal_unet-hopper-mediumv2-hor{hor}.torch`` from *checkpoint_dir*,
    builds a ``UNet1DModel`` whose block layout matches that horizon, copies the
    weights across by position, and writes ``diffusion_pytorch_model.bin`` plus
    ``config.json`` into ``hub/hopper-medium-v2/unet/hor{hor}/``.

    Args:
        hor: Planning horizon of the source checkpoint. Only 32 and 128 have a
            known block layout.
        checkpoint_dir: Directory containing the source ``.torch`` file. Defaults
            to the location the script originally hard-coded.

    Raises:
        ValueError: If *hor* is not 32 or 128.
        RuntimeError: If the checkpoint and the target model hold a different
            number of state-dict entries, so a positional key mapping is invalid.
    """
    if hor == 128:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D")
    elif hor == 32:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 64, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D")
    else:
        # Without this guard an unsupported horizon only failed later, with a
        # missing-file error or an UnboundLocalError on down_block_types.
        raise ValueError(f"unsupported horizon {hor!r}: expected 32 or 128")

    # NOTE(review): torch.load unpickles a whole nn.Module here, which can execute
    # arbitrary code -- only point this at trusted checkpoints. Recent torch
    # releases default to weights_only=True and may refuse such a pickle; pass
    # weights_only=False explicitly if that happens.
    model = torch.load(os.path.join(checkpoint_dir, f"temporal_unet-hopper-mediumv2-hor{hor}.torch"))
    state_dict = model.state_dict()

    config = {
        "down_block_types": down_block_types,
        "block_out_channels": block_out_channels,
        "up_block_types": up_block_types,
        "layers_per_block": 1,
        "use_timestep_embedding": True,
        "out_block_type": "OutConv1DBlock",
        "norm_num_groups": 8,
        "downsample_each_block": False,
        "in_channels": 14,
        "out_channels": 14,
        "extra_in_channels": 0,
        "time_embedding_type": "positional",
        "flip_sin_to_cos": False,
        "freq_shift": 1,
        "sample_size": 65536,
        "mid_block_type": "MidResTemporalBlock1D",
        "act_fn": "mish",
    }
    hf_unet = UNet1DModel(**config)
    hf_state_dict = hf_unet.state_dict()
    print(f"length of state dict: {len(state_dict.keys())}")
    print(f"length of value function dict: {len(hf_state_dict.keys())}")

    # Keys are matched purely by position, which assumes both models enumerate
    # their parameters in the same order. zip() would silently drop surplus
    # entries, so require the counts to agree first.
    if len(state_dict) != len(hf_state_dict):
        raise RuntimeError(
            f"state dict size mismatch: checkpoint has {len(state_dict)} entries, "
            f"UNet1DModel expects {len(hf_state_dict)}"
        )
    # Build a fresh dict instead of renaming in place: popping and reinserting
    # could overwrite a not-yet-renamed source key that shares a target name.
    renamed = {new_key: state_dict[old_key] for old_key, new_key in zip(state_dict, hf_state_dict)}
    hf_unet.load_state_dict(renamed)

    out_dir = f"hub/hopper-medium-v2/unet/hor{hor}"
    # Create the directory here so the function does not rely on module-level setup.
    os.makedirs(out_dir, exist_ok=True)
    torch.save(hf_unet.state_dict(), f"{out_dir}/diffusion_pytorch_model.bin")
    with open(f"{out_dir}/config.json", "w") as f:
        json.dump(config, f)
|
|
|
|
def value_function(
    checkpoint_path="/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch",
):
    """Convert the Hopper-medium-v2 hor32 value-function checkpoint into a ``UNet1DModel``.

    Loads the checkpoint at *checkpoint_path*, builds a ``UNet1DModel`` with no
    up blocks and a ``ValueFunction`` output head, copies the weights across by
    position, and writes ``diffusion_pytorch_model.bin`` plus ``config.json``
    into ``hub/hopper-medium-v2/value_function/``.

    Args:
        checkpoint_path: Path of the source checkpoint. Defaults to the location
            the script originally hard-coded.

    Raises:
        RuntimeError: If the checkpoint and the target model hold a different
            number of state-dict entries, so a positional key mapping is invalid.
    """
    config = {
        "in_channels": 14,
        "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
        "up_block_types": (),
        "out_block_type": "ValueFunction",
        "mid_block_type": "ValueFunctionMidBlock1D",
        "block_out_channels": (32, 64, 128, 256),
        "layers_per_block": 1,
        "downsample_each_block": True,
        "sample_size": 65536,
        "out_channels": 14,
        "extra_in_channels": 0,
        "time_embedding_type": "positional",
        "use_timestep_embedding": True,
        "flip_sin_to_cos": False,
        "freq_shift": 1,
        "norm_num_groups": 8,
        "act_fn": "mish",
    }

    # NOTE(review): the loaded object is used directly as a mapping, so it is
    # presumably a plain state dict (unlike unet(), which calls .state_dict()) --
    # confirm. It is still unpickled, so only load trusted files.
    state_dict = torch.load(checkpoint_path)
    hf_value_function = UNet1DModel(**config)
    hf_state_dict = hf_value_function.state_dict()
    print(f"length of state dict: {len(state_dict.keys())}")
    print(f"length of value function dict: {len(hf_state_dict.keys())}")

    # Positional key matching: zip() would silently drop surplus entries, so
    # require the counts to agree before pairing keys up.
    if len(state_dict) != len(hf_state_dict):
        raise RuntimeError(
            f"state dict size mismatch: checkpoint has {len(state_dict)} entries, "
            f"UNet1DModel expects {len(hf_state_dict)}"
        )
    # A fresh dict avoids the in-place rename clobbering a source key that
    # happens to equal a target key still waiting to be renamed.
    renamed = {new_key: state_dict[old_key] for old_key, new_key in zip(state_dict, hf_state_dict)}
    hf_value_function.load_state_dict(renamed)

    out_dir = "hub/hopper-medium-v2/value_function"
    # Create the directory here so the function does not rely on module-level setup.
    os.makedirs(out_dir, exist_ok=True)
    torch.save(hf_value_function.state_dict(), f"{out_dir}/diffusion_pytorch_model.bin")
    with open(f"{out_dir}/config.json", "w") as f:
        json.dump(config, f)
|
|
|
|
def _main():
    """Run the default conversions: the horizon-32 UNet, then the value function."""
    unet(32)
    value_function()


if __name__ == "__main__":
    _main()
|
|