You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
(gradcam)_____@server3090-X570-AORUS-PRO-WIFI:~/Grad-CAM.pytorch-master$ python main.py
feature shape:torch.Size([1, 512, 7, 7])
/home/____/.conda/envs/gradcam/lib/python3.8/site-packages/torch/nn/modules/module.py:1033: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes "
feature shape:torch.Size([1, 512, 7, 7])
PS:最终确实能够生成图,但明显不是基于我自己的模型生成的。
以下是自己修改后的main.py:
# -*- coding: utf-8 -*-
import argparse
import os
import re
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18
import cv2
import numpy as np
import torch
from skimage import io
from torch import nn
from torchvision import models
from interpretability.grad_cam import GradCAM, GradCamPlusPlus
from interpretability.guided_back_propagation import GuidedBackPropagation
def get_net(net_name, weight_path=None):
    """
    Build a torchvision network by name and optionally load weights into it.

    :param net_name: network name: 'vgg'/'vgg16', 'resnet'/'resnet18' or
        'densenet'/'densenet121'
    :param weight_path: path to a saved state_dict. When None, torchvision's
        pretrained weights are requested instead (``pretrained=True``).
    :return: the constructed network
    :raises ValueError: if net_name is not one of the supported names
    """
    import warnings

    # Without a checkpoint fall back to torchvision's pretrained weights;
    # with one, start from random init and load the checkpoint below.
    pretrain = weight_path is None
    if net_name in ['vgg', 'vgg16']:
        net = models.vgg16(pretrained=pretrain)
    elif net_name in ['resnet', 'resnet18']:
        net = models.resnet18(pretrained=pretrain)
    elif net_name in ['densenet', 'densenet121']:
        # Previously the densenet loading branch below was unreachable because
        # no densenet name was accepted here.
        net = models.densenet121(pretrained=pretrain)
    else:
        raise ValueError('invalid network name:{}'.format(net_name))

    if weight_path is None:
        return net

    # map_location='cpu': the network is built on the CPU, so a checkpoint
    # saved from a GPU run must not require CUDA just to be read.
    state_dict = torch.load(weight_path, map_location='cpu')

    if net_name.startswith('densenet'):
        # Rename keys of the form '...denselayerN.norm.1.weight' to
        # '...denselayerN.norm1.weight' before a strict load.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        net.load_state_dict(state_dict)
        return net

    # Strip a 'resnet18.' prefix (checkpoints that wrap the backbone in an
    # attribute of that name) so keys line up with the bare torchvision model.
    state_dict = {k.replace('resnet18.', ''): v for k, v in state_dict.items()}
    # strict=False is kept for compatibility, but the mismatch is no longer
    # silent: any parameter that is not loaded keeps its random initialisation,
    # which makes the resulting CAMs unrelated to the trained model.
    incompatible = net.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            'checkpoint {} does not fully match {}: missing keys {}, unexpected keys {}. '
            'Missing parameters keep their random initialisation.'.format(
                weight_path, net_name, incompatible.missing_keys, incompatible.unexpected_keys))
    return net
def get_last_conv_name(net):
    """
    Find the name of the network's final Conv2d layer.

    Modules are scanned in ``net.named_modules()`` order and the name of the
    last ``nn.Conv2d`` encountered wins.

    :param net: torch.nn.Module to search
    :return: dotted module name, or None if the network has no Conv2d layer
    """
    conv_layer_names = [
        module_name
        for module_name, module in net.named_modules()
        if isinstance(module, nn.Conv2d)
    ]
    return conv_layer_names[-1] if conv_layer_names else None
作者您好,
首先感谢您的代码贡献,代码非常简洁,关键注释也非常清晰!我已按照 README 成功跑通了示例。目前希望依托您的代码框架,进一步尝试引入自己的预训练网络来生成 Grad-CAM,用于图像异常检测。我现有一个基于 ResNet18 预训练模型的网络:前面添加了 head,后面添加了几层额外的卷积层和全连接层(最终输出 0:正常,1:异常),在自己的正常数据集上进行无监督训练得到权重,用于对异常图片进行异常检测。然后希望利用 Grad-CAM 在异常图像上标注出异常的位置。
现在的问题是:如何将上面提到的自己的模型权重引入该框架?我尝试之后得到如下输出(其中包含一条 UserWarning),请问可能是什么问题?
(gradcam)_____@server3090-X570-AORUS-PRO-WIFI:~/Grad-CAM.pytorch-master$ python main.py
feature shape:torch.Size([1, 512, 7, 7])
/home/____/.conda/envs/gradcam/lib/python3.8/site-packages/torch/nn/modules/module.py:1033: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes "
feature shape:torch.Size([1, 512, 7, 7])
PS:最终确实能够生成图,但明显不是基于我自己的模型生成的。
以下是自己修改后的main.py:
# -*- coding: utf-8 -*-
import argparse
import os
import re
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18
import cv2
import numpy as np
import torch
from skimage import io
from torch import nn
from torchvision import models
from interpretability.grad_cam import GradCAM, GradCamPlusPlus
from interpretability.guided_back_propagation import GuidedBackPropagation
def get_net(net_name, weight_path=None):
    """
    Build a torchvision network by name and optionally load weights into it.

    :param net_name: network name: 'vgg'/'vgg16', 'resnet'/'resnet18' or
        'densenet'/'densenet121'
    :param weight_path: path to a saved state_dict. When None, torchvision's
        pretrained weights are requested instead (``pretrained=True``).
    :return: the constructed network
    :raises ValueError: if net_name is not one of the supported names
    """
    import warnings

    # Without a checkpoint fall back to torchvision's pretrained weights;
    # with one, start from random init and load the checkpoint below.
    pretrain = weight_path is None
    if net_name in ['vgg', 'vgg16']:
        net = models.vgg16(pretrained=pretrain)
    elif net_name in ['resnet', 'resnet18']:
        net = models.resnet18(pretrained=pretrain)
    elif net_name in ['densenet', 'densenet121']:
        # Previously the densenet loading branch below was unreachable because
        # no densenet name was accepted here.
        net = models.densenet121(pretrained=pretrain)
    else:
        raise ValueError('invalid network name:{}'.format(net_name))

    if weight_path is None:
        return net

    # map_location='cpu': the network is built on the CPU, so a checkpoint
    # saved from a GPU run must not require CUDA just to be read.
    state_dict = torch.load(weight_path, map_location='cpu')

    if net_name.startswith('densenet'):
        # Rename keys of the form '...denselayerN.norm.1.weight' to
        # '...denselayerN.norm1.weight' before a strict load.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        net.load_state_dict(state_dict)
        return net

    # Strip a 'resnet18.' prefix (checkpoints that wrap the backbone in an
    # attribute of that name) so keys line up with the bare torchvision model.
    state_dict = {k.replace('resnet18.', ''): v for k, v in state_dict.items()}
    # strict=False is kept for compatibility, but the mismatch is no longer
    # silent: any parameter that is not loaded keeps its random initialisation,
    # which makes the resulting CAMs unrelated to the trained model.
    incompatible = net.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            'checkpoint {} does not fully match {}: missing keys {}, unexpected keys {}. '
            'Missing parameters keep their random initialisation.'.format(
                weight_path, net_name, incompatible.missing_keys, incompatible.unexpected_keys))
    return net
def get_last_conv_name(net):
    """
    Find the name of the network's final Conv2d layer.

    Modules are scanned in ``net.named_modules()`` order and the name of the
    last ``nn.Conv2d`` encountered wins.

    :param net: torch.nn.Module to search
    :return: dotted module name, or None if the network has no Conv2d layer
    """
    conv_layer_names = [
        module_name
        for module_name, module in net.named_modules()
        if isinstance(module, nn.Conv2d)
    ]
    return conv_layer_names[-1] if conv_layer_names else None
def prepare_input(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """
    Turn an [H,W,C] float image into a normalised [1,C,H,W] input tensor.

    Previously this function returned None, so the network received no input.

    :param image: [H,W,C] image; main() passes float32 values in [0, 1].
        The caller's array is not modified.
    :param mean: per-channel mean subtracted from the image (defaults are the
        ImageNet statistics used by torchvision's pretrained models)
    :param std: per-channel standard deviation the image is divided by
    :return: float32 tensor of shape [1,C,H,W] with requires_grad=True, so
        gradients with respect to the input can be computed
    """
    # Work on a float32 copy: the models expect float32 and the caller's
    # array must stay untouched.
    image = np.array(image, dtype=np.float32, copy=True)
    image -= np.asarray(mean, dtype=np.float32)
    image /= np.asarray(std, dtype=np.float32)
    # HWC -> CHW, then add the batch dimension.
    image = np.ascontiguousarray(np.transpose(image, (2, 0, 1)))
    image = image[np.newaxis, ...]
    return torch.tensor(image, requires_grad=True)
def gen_cam(image, mask):
    """
    Overlay a class-activation mask on an image.

    Previously this function built the heatmap but returned None, although
    callers unpack a (cam, heatmap) pair.

    :param image: [H,W,C] original image; main() passes float32 values in [0, 1]
    :param mask: [H,W] activation mask, expected in the range 0~1
    :return: tuple(cam, heatmap), both uint8 [H,W,3]: cam is heatmap + image
        min-max normalised to 0..255 via norm_image; heatmap is the RGB JET
        colour map of the mask
    """
    # Clip first: values marginally outside [0, 1] would otherwise wrap
    # around when cast to uint8.
    mask = np.clip(mask, 0, 1)
    # applyColorMap yields BGR; flip to RGB and keep a contiguous uint8 copy.
    heatmap_rgb = np.ascontiguousarray(
        cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)[..., ::-1])
    heatmap = np.float32(heatmap_rgb) / 255
    # Blend the heatmap onto the image, then rescale the sum to 0..255.
    cam = heatmap + np.float32(image)
    return norm_image(cam), heatmap_rgb
def norm_image(image):
    """
    Min-max normalise an image to uint8.

    :param image: [H,W,C] image (any shape works); the input is not modified
    :return: uint8 array of the same shape with the minimum mapped to 0 and
        the maximum to 255; a constant image maps to all zeros
    """
    image = np.array(image, copy=True)
    if not np.issubdtype(image.dtype, np.floating):
        # In-place true division below is not allowed on integer arrays.
        image = image.astype(np.float32)
    # The old code wrote np.max(np.min(image), 0), where 0 is an *axis*
    # argument, so it already just subtracted the minimum.
    image -= image.min()
    peak = image.max()
    if peak > 0:
        # Skipped for constant images, which would otherwise divide 0 by 0
        # and produce NaN.
        image /= peak
    image *= 255.
    return np.uint8(image)
def gen_gb(grad):
    """
    Convert a guided-back-propagation input gradient to channel-last layout.

    :param grad: tensor, [3,H,W] gradient with respect to the input image;
        may live on any device and may require grad
    :return: numpy array [H,W,3] with the same values; no normalisation is
        applied here
    """
    # detach() drops the autograd graph; cpu() lets CUDA tensors be converted
    # too (the previous .data.numpy() only worked on CPU tensors).
    grad = grad.detach().cpu().numpy()
    return np.transpose(grad, (1, 2, 0))
def save_image(image_dicts, input_image_name, network, output_dir):
    """
    Write every image in image_dicts to output_dir as a JPEG file.

    Files are named '<prefix>-<network>-<key>.jpg', where prefix is
    input_image_name with its extension removed.

    :param image_dicts: dict mapping a short tag (e.g. 'cam', 'heatmap') to an image array
    :param input_image_name: file name of the source image
    :param network: network name, embedded in the output file names
    :param output_dir: destination directory; created if it does not exist
    """
    prefix = os.path.splitext(input_image_name)[0]
    # Ensure the destination exists before writing; the CLI default
    # ('results') may not exist yet in a fresh checkout.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    for key, image in image_dicts.items():
        io.imsave(os.path.join(output_dir, '{}-{}-{}.jpg'.format(prefix, network, key)), image)
def main(args):
    """
    Compute Grad-CAM and Grad-CAM++ visualisations for one image and save them.

    :param args: argparse namespace providing image_path, network, weight_path,
        layer_name, class_id and output_dir
    """
    # Input: force exactly three channels. Grayscale is replicated to three
    # channels; an alpha channel is dropped. The previous unconditional
    # COLOR_GRAY2BGR call failed on RGB/RGBA images.
    img = io.imread(args.image_path)
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    elif img.shape[-1] == 4:
        img = img[..., :3]
    img = np.float32(cv2.resize(img, (224, 224))) / 255
    inputs = prepare_input(img)
    # Output images, keyed by the suffix used in the saved file names
    image_dict = {}
    # Network. Switch to inference mode explicitly so BatchNorm/Dropout behave
    # deterministically; harmless if GradCAM already does this itself.
    net = get_net(args.network, args.weight_path)
    net.eval()
    # Grad-CAM
    layer_name = get_last_conv_name(net) if args.layer_name is None else args.layer_name
    grad_cam = GradCAM(net, layer_name)
    try:
        mask = grad_cam(inputs, args.class_id)  # cam mask
    finally:
        # Always detach the hooks so they do not linger on the shared network.
        grad_cam.remove_handlers()
    image_dict['cam'], image_dict['heatmap'] = gen_cam(img, mask)
    # Grad-CAM++
    grad_cam_plus_plus = GradCamPlusPlus(net, layer_name)
    try:
        mask_plus_plus = grad_cam_plus_plus(inputs, args.class_id)  # cam mask
    finally:
        grad_cam_plus_plus.remove_handlers()
    image_dict['cam++'], image_dict['heatmap++'] = gen_cam(img, mask_plus_plus)
    # Previously the computed images were discarded; write them out.
    save_image(image_dict, os.path.basename(args.image_path), args.network, args.output_dir)
# Script entry point. The previous guard `if name == 'main':` raised NameError,
# and main() was never called after parsing the arguments.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--network', type=str, default='resnet18',
                        help='ImageNet classification network')
    parser.add_argument('--image-path', type=str, default='./Cutpaste_examples/icecream/XGQK_test.jpg',
                        help='input image path')
    parser.add_argument('--weight-path', type=str, default='./Cutpaste_examples/icecream/model-icecream-cutpaste-normal.pth',
                        help='weight path of the model')
    parser.add_argument('--layer-name', type=str, default=None,
                        help='last convolutional layer name')
    parser.add_argument('--class-id', type=int, default=None,
                        help='class id')
    parser.add_argument('--output-dir', type=str, default='results',
                        help='output directory to save results')
    arguments = parser.parse_args()
    main(arguments)
The text was updated successfully, but these errors were encountered: