Description
Hello, my question is as follows.
I downloaded `model.pth` for the `breast_density_classification` bundle from its download address.
Code:
`import glob
import os
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from torchvision.models import inception_v3
from monai.bundle import download, load
data_dir = 'breast_density_classification/sample_data'
test_images = sorted(glob.glob(os.path.join(data_dir, "A", "*.jpg")))
preprocess = transforms.Compose([
transforms.Resize((299, 299)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_path = "breast_density_classification/models/model.pth"
model = inception_v3(pretrained=False, aux_logits=False, num_classes=4).to(device)
state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()
for img_path in test_images:
img = Image.open(img_path).convert('RGB')
img_tensor = preprocess(img).unsqueeze(0).to(device)
with torch.no_grad():
outputs = model(img_tensor)
probs = torch.nn.functional.softmax(outputs, dim=1)
pred_class = torch.argmax(probs, dim=1).item()
plt.figure(figsize=(6, 6))
plt.imshow(img)
plt.title(f'Predicted Class: {pred_class}')
plt.axis('off')
plt.show()`
Error:
```
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for Inception3:
Missing key(s) in state_dict: "Conv2d_1a_3x3.conv.weight", "Conv2d_1a_3x3.bn.weight",
```
This error indicates that the loaded state_dict does not match the InceptionV3 model I defined. Did I do something wrong?
This is an urgent problem. Thank you!