import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from torch.optim import Adam

from optimizer_schedule import ScheduledOptim
|
|
|
|
| class BERTTrainer: |
| def __init__( |
| self, |
| model, |
| train_dataloader, |
| test_dataloader=None, |
| lr= 1e-4, |
| weight_decay=0.01, |
| betas=(0.9, 0.999), |
| warmup_steps=10000, |
| log_freq=10, |
| device='cuda' |
| ): |
|
|
| self.device = device |
| self.model = model |
| self.train_data = train_dataloader |
| self.test_data = test_dataloader |
|
|
| |
| self.optim = Adam(self.model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay) |
| self.optim_schedule = ScheduledOptim( |
| self.optim, self.model.bert.d_model, n_warmup_steps=warmup_steps |
| ) |
|
|
| |
| self.criterion = torch.nn.NLLLoss(ignore_index=0) |
| self.log_freq = log_freq |
| print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()])) |
| |
| def train(self, epoch): |
| self.iteration(epoch, self.train_data) |
|
|
| def test(self, epoch): |
| self.iteration(epoch, self.test_data, train=False) |
|
|
| def iteration(self, epoch, data_loader, train=True): |
| |
| avg_loss = 0.0 |
| total_correct = 0 |
| total_element = 0 |
| |
| mode = "train" if train else "test" |
|
|
| |
| data_iter = tqdm.tqdm( |
| enumerate(data_loader), |
| desc="EP_%s:%d" % (mode, epoch), |
| total=len(data_loader), |
| bar_format="{l_bar}{r_bar}" |
| ) |
|
|
| for i, data in data_iter: |
|
|
| |
| data = {key: value.to(self.device) for key, value in data.items()} |
|
|
| |
| next_sent_output, mask_lm_output = self.model.forward(data["bert_input"], data["segment_label"]) |
|
|
| |
| next_loss = self.criterion(next_sent_output, data["is_next"]) |
|
|
| |
| |
| |
| mask_loss = self.criterion(mask_lm_output.transpose(1, 2), data["bert_label"]) |
|
|
| |
| loss = next_loss + mask_loss |
|
|
| |
| if train: |
| self.optim_schedule.zero_grad() |
| loss.backward() |
| self.optim_schedule.step_and_update_lr() |
|
|
| |
| correct = next_sent_output.argmax(dim=-1).eq(data["is_next"]).sum().item() |
| avg_loss += loss.item() |
| total_correct += correct |
| total_element += data["is_next"].nelement() |
|
|
| post_fix = { |
| "epoch": epoch, |
| "iter": i, |
| "avg_loss": avg_loss / (i + 1), |
| "avg_acc": total_correct / total_element * 100, |
| "loss": loss.item() |
| } |
|
|
| if i % self.log_freq == 0: |
| data_iter.write(str(post_fix)) |
| print( |
| f"EP{epoch}, {mode}: \ |
| avg_loss={avg_loss / len(data_iter)}, \ |
| total_acc={total_correct * 100.0 / total_element}" |
| ) |