| --- |
| license: mit |
| --- |
|  |
| # Usage |
|
|
| **Instantiate the Base Model** |
| ```python |
| import torch |
| from braindecode.models import SignalJEPA |
| from huggingface_hub import hf_hub_download |
| |
| weights_path = hf_hub_download(repo_id="braindecode/SignalJEPA", filename="signal-jepa_16s-60_adeuwv4s.pth") |
| model_state_dict = torch.load(weights_path, weights_only=True) |
| |
| # Signal-related arguments |
| # raw: mne.io.BaseRaw |
| chs_info = raw.info["chs"] |
| sfreq = raw.info["sfreq"] |
| |
| model = SignalJEPA( |
| sfreq=sfreq, |
| input_window_seconds=2, |
| chs_info=chs_info, |
| ) |
| missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False) |
| assert unexpected_keys == [] |
| # The spatial positional encoder is initialized using the `chs_info`: |
| assert set(missing_keys) == {"pos_encoder.pos_encoder_spat.weight"} |
| ``` |
|
|
| **Instantiate the Downstream Architectures** |
|
|
| Unlike the base model, the downstream architectures are equipped with a classification head which is not pre-trained. |
| Guetschel et al. (2024) [arXiv:2403.11772](https://arxiv.org/abs/2403.11772) introduce three downstream architectures: |
| - a) Contextual downstream architecture |
| - b) Post-local downstream architecture |
| - c) Pre-local downstream architecture |
|
|
| ```python |
| import torch |
| from braindecode.models import ( |
| SignalJEPA_Contextual, |
| SignalJEPA_PostLocal, |
| SignalJEPA_PreLocal, |
| ) |
| from huggingface_hub import hf_hub_download |
| |
| weights_path = hf_hub_download(repo_id="braindecode/SignalJEPA", filename="signal-jepa_16s-60_adeuwv4s.pth") |
| model_state_dict = torch.load(weights_path, weights_only=True) |
| |
| # Signal-related arguments |
| # raw: mne.io.BaseRaw |
| chs_info = raw.info["chs"] |
| sfreq = raw.info["sfreq"] |
| |
| # The downstream architectures are equipped with an additional classification head |
| # which was not pre-trained. It has the following new parameters: |
| final_layer_keys = { |
| "final_layer.spat_conv.weight", |
| "final_layer.spat_conv.bias", |
| "final_layer.linear.weight", |
| "final_layer.linear.bias", |
| } |
| |
| |
| # a) Contextual downstream architecture |
| # ---------------------------------- |
| model = SignalJEPA_Contextual( |
| sfreq=sfreq, |
| input_window_seconds=2, |
| chs_info=chs_info, |
| n_outputs=1, |
| ) |
| missing_keys, unexpected_keys = model.load_state_dict(model_state_dict, strict=False) |
| assert unexpected_keys == [] |
| # The spatial positional encoder is initialized using the `chs_info`: |
| assert set(missing_keys) == final_layer_keys | {"pos_encoder.pos_encoder_spat.weight"} |
| |
| # In the post-local (b) and pre-local (c) architectures, the transformer is discarded: |
| FILTERED_model_state_dict = { |
| k: v for k, v in model_state_dict.items() if not any(k.startswith(pre) for pre in ["transformer.", "pos_encoder."]) |
| } |
| |
| |
| # b) Post-local downstream architecture |
| # ---------------------------------- |
| model = SignalJEPA_PostLocal( |
| sfreq=sfreq, |
| input_window_seconds=2, |
| n_chans=len(chs_info), # detailed channel info is not needed for this model |
| n_outputs=1, |
| ) |
| missing_keys, unexpected_keys = model.load_state_dict(FILTERED_model_state_dict, strict=False) |
| assert unexpected_keys == [] |
| assert set(missing_keys) == final_layer_keys |
| |
| |
| # c) Pre-local architecture |
| # ---------------------- |
| model = SignalJEPA_PreLocal( |
| sfreq=sfreq, |
| input_window_seconds=2, |
| n_chans=len(chs_info), # detailed channel info is not needed for this model |
| n_outputs=1, |
| ) |
| missing_keys, unexpected_keys = model.load_state_dict(FILTERED_model_state_dict, strict=False) |
| assert unexpected_keys == [] |
| assert set(missing_keys) == { |
| "spatial_conv.1.weight", |
| "spatial_conv.1.bias", |
| "final_layer.1.weight", |
| "final_layer.1.bias", |
| } |
| ``` |