From b367f2ec423fc2da2fa5fc4142c643317c2437e6 Mon Sep 17 00:00:00 2001 From: Svelte Date: Sun, 30 Jun 2024 18:47:43 +0800 Subject: [PATCH] Changed from torch.cuda.amp.autocast to torch.amp.autocast torch.cuda.amp.autocast to be deprecated --- references/detection/engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/detection/engine.py b/references/detection/engine.py index 0e9bfffdf8a..fa0f4fe01db 100644 --- a/references/detection/engine.py +++ b/references/detection/engine.py @@ -27,7 +27,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, sc for images, targets in metric_logger.log_every(data_loader, print_freq, header): images = list(image.to(device) for image in images) targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets] - with torch.cuda.amp.autocast(enabled=scaler is not None): + with torch.amp.autocast("cuda" if torch.cuda.is_available() else "cpu", enabled=scaler is not None): loss_dict = model(images, targets) losses = sum(loss for loss in loss_dict.values())