1. 포켓몬 분류
* Train: https://www.kaggle.com/datasets/thedagger/pokemon-generation-one
* Validation: https://www.kaggle.com/hlrhegemony/pokemon-image-dataset
◼ Kaggle API를 사용하여 Pokemon 데이터셋을 다운로드하고 압축을 해제
import os
os.environ['KAGGLE_USERNAME'] = 'sarahlee721'
os.environ['KAGGLE_KEY'] = 'b5dc455974daae96540089d7bcdf062a'
!kaggle datasets download -d thedagger/pokemon-generation-one
!unzip -q pokemon-generation-one.zip
!kaggle datasets download -d hlrhegemony/pokemon-image-dataset
!unzip -q pokemon-image-dataset.zip
|
◼ 폴더의 이름을 변경, 이전의 폴더를 삭제
# Rename the extracted `dataset` directory to `train`.
!mv dataset train
# Delete the `train/dataset` entry inside the renamed directory
# (presumably a nested duplicate shipped in the archive — TODO confirm).
!rm -rf train/dataset
# NOTE(review): later cells read a `validation` directory, but no command
# shown here creates it — verify how the second archive's folder gets that name.
|
◼ 데이터셋 확인
# List the entries of each split directory, then print the entries followed
# by how many there are (one value per line).
train_labels = os.listdir('train')
print(train_labels, len(train_labels), sep='\n')
val_labels = os.listdir('validation')
print(val_labels, len(val_labels), sep='\n')
|
['Arcanine', 'Gyarados', 'Ponyta', 'Jigglypuff', 'Gloom', 'Lickitung', 'Hitmonlee', 'MrMime', 'Chansey', 'Weezing', 'Arbok', 'Kadabra', 'Exeggutor', 'Rhydon', 'Hypno', 'Oddish', 'Scyther', 'Abra', 'Graveler', 'Starmie', 'Ditto', 'Zapdos', 'Golem', 'Pinsir', 'Rhyhorn', 'Jolteon', 'Gengar', 'Articuno', 'Venomoth', 'Nidoqueen', 'Seaking', 'Onix', 'Persian', 'Aerodactyl', 'Tentacool', 'Charizard', 'Snorlax', 'Drowzee', 'Sandslash', 'Hitmonchan', 'Growlithe', 'Magneton', 'Flareon', 'Farfetchd', 'Clefable', 'Clefairy', 'Kabutops', 'Poliwhirl', 'Charmeleon', 'Spearow', 'Venonat', 'Slowpoke', 'Wigglytuff', 'Omanyte', 'Lapras', 'Wartortle', 'Cloyster', 'Gastly', 'Machamp', 'Nidoking', 'Haunter', 'Magmar', 'Krabby', 'Mewtwo', 'Porygon', 'Bellsprout', 'Pidgeotto', 'Geodude', 'Meowth', 'Kabuto', 'Golbat', 'Weedle', 'Seel', 'Dodrio', 'Dewgong', 'Butterfree', 'Pidgeot', 'Psyduck', 'Jynx', 'Victreebel', 'Slowbro', 'Grimer', 'Mew', 'Mankey', 'Shellder', 'Raticate', 'Fearow', 'Dragonair', 'Marowak', 'Parasect', 'Metapod', 'Venusaur', 'Muk', 'Tauros', 'Eevee', 'Exeggcute', 'Raichu', 'Voltorb', 'Magikarp', 'Pidgey', 'Pikachu', 'Poliwrath', 'Staryu', 'Tangela', 'Vileplume', 'Machoke', 'Koffing', 'Doduo', 'Kakuna', 'Electrode', 'Machop', 'Goldeen', 'Blastoise', 'Primeape', 'Alakazam', 'Seadra', 'Diglett', 'Tentacruel', 'Nidorina', 'Weepinbell', 'Ninetales', 'Vulpix', 'Rattata', 'Rapidash', 'Charmander', 'Golduck', 'Kingler', 'Moltres', 'Sandshrew', 'Caterpie', 'Dugtrio', 'Electabuzz', 'Nidorino', 'Squirtle', 'Kangaskhan', 'Vaporeon', 'Zubat', 'Magnemite', 'Paras', 'Dratini', 'Ekans', 'Ivysaur', 'Horsea', 'Beedrill', 'Omastar', 'Cubone', 'Dragonite', 'Bulbasaur', 'Poliwag'] 149 ------------------------------------------------------------------------------------------------------------------------------------------------------------- ['Araquanid', 'Litleo', 'Galvantula', 'Blipbug', 'Houndour', 'Turtwig', 'Dracozolt', 'Arcanine', 'Ambipom', 'Leafeon', 'Whiscash', 'Skiploom', 'Wooper', 
'Flab├йb├й', 'Gyarados', 'Type Null', 'Amaura', 'Ponyta', 'Torterra', 'Bisharp', 'Gourgeist', 'Jigglypuff', 'Rhyperior', 'Tornadus', 'Solgaleo', 'Lunatone', 'Lombre', 'Hippowdon', 'Minun', 'Marshadow', 'Drapion', 'Urshifu', 'Fraxure', 'Gloom', 'Lickitung', 'Ninjask', 'Shuppet', 'Groudon', 'Vigoroth', 'Hitmonlee', 'Cosmog', 'Diancie', 'Xerneas', 'Heatmor', 'Tropius', 'Magby', 'Yamask', 'Murkrow', 'Dragapult', 'Chansey', 'Weezing', 'Arbok', 'Lickilicky', 'Orbeetle', 'Magearna', 'Dewpider', 'Deino', 'Nosepass', 'Skarmory', ... 'Kangaskhan', 'Lopunny', 'Meowstic', 'Vaporeon', 'Zubat', 'Weavile', 'Leavanny', 'Magnemite', 'Cinderace', 'Aurorus', 'Reuniclus', 'Minccino', 'Comfey', 'Cottonee', 'Claydol', 'Croconaw', 'Panpour', 'Stunfisk', 'Litwick', 'Shiftry', 'Paras', 'Foongus', 'Dratini', 'Ledian', 'Stonjourner', 'Zygarde', 'Ekans', 'Rayquaza', 'Litten', 'Porygon2', 'Tsareena', 'Florges', 'Primarina', 'Fennekin', 'Glastrier', 'Pachirisu', 'Drakloak', 'Ivysaur', 'Indeedee', 'Corvisquire', 'Horsea', 'Froslass', 'Espeon', 'Morgrem', 'Beedrill', 'Barbaracle', 'Pansage', 'Darumaka', 'Remoraid', 'Omastar', 'Spiritomb', 'Cryogonal', 'Mienshao', 'Metagross', 'Dwebble', 'Swanna', 'Glalie', 'Toxel', 'Throh', 'Ho-oh', 'Slugma', 'Palossand', 'Roserade', 'Liepard', 'Spinarak', 'Scolipede', 'Darkrai', 'Yamper', 'Falinks', 'Brionne', 'Aromatisse', 'Cubone', 'Noibat', 'Mawile', 'Swirlix', 'Beheeyem', 'Deoxys', 'Calyrex', 'Sandaconda', 'Dracovish', 'Hoppip', 'Morelull', 'Dragonite', 'Pineco', 'Bulbasaur', 'Silicobra', 'Rufflet', 'Aggron', 'Croagunk', 'Palpitoad', 'Wailord', 'Poliwag', 'Steelix', 'Gorebyss', 'Crabrawler', 'Klefki', 'Surskit', 'Bewear', 'Mantyke', 'Zacian', 'Emboar'] 898 |
◼ train 데이터에 없는 클래스(폴더)를 validation 데이터에서 지우기
import shutil

# Remove every validation folder whose name does not appear among the
# training labels. A set makes each membership test O(1) instead of
# rescanning the whole label list for every folder.
train_label_set = set(train_labels)
for val_label in val_labels:
    if val_label not in train_label_set:  # not a training class -> delete
        shutil.rmtree(os.path.join('validation', val_label))
val_labels = os.listdir('validation')
print(len(val_labels))

# Check again: the two counts appear to differ by two, i.e. some training
# labels have no matching validation folder.
print(len(train_labels))
print(len(val_labels))
val_label_set = set(val_labels)
for train_label in train_labels:
    if train_label not in val_label_set:
        print(train_label)
        # Create an (empty) folder so both directories end up with the same
        # set of names.
        # NOTE(review): these folders contain no images; some torchvision
        # releases make ImageFolder raise for empty class folders — verify.
        os.makedirs(os.path.join('validation', train_label), exist_ok=True)
val_labels = os.listdir('validation')
print(len(val_labels))
|
해석:
|
147 ------------------------- 149 147 ------------------------
MrMime Farfetchd ------------------------ 149 |
◼ 모듈 import
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)
|
◼ 이미지 증강 기법
# Preprocessing pipelines keyed by split name.
# NOTE(review): neither pipeline applies transforms.Normalize; the pretrained
# ImageNet weights loaded later presumably expect normalized inputs — confirm.

# Training: resize to 224x224, then random shear/scale and a random
# horizontal flip as augmentation, then convert to a tensor.
_train_steps = [
    transforms.Resize((224, 224)),
    transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
]

# Validation: resize and tensor conversion only, with no random augmentation.
_validation_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]

data_transform = dict(
    train=transforms.Compose(_train_steps),
    validation=transforms.Compose(_validation_steps),
)
|
|
◼ 데이터셋 객체 생성
# One ImageFolder per split: the split name is both the directory to read
# and the key used to pick the matching transform pipeline.
image_datasets = {
    split: datasets.ImageFolder(split, data_transform[split])
    for split in ('train', 'validation')
}
|
◼ 데이터 로더 만들기
# Batches of 32 for both splits; only the training split is shuffled.
dataloaders = {
    split: DataLoader(
        image_datasets[split],
        batch_size=32,
        shuffle=(split == 'train'),
    )
    for split in ('train', 'validation')
}
# Number of samples in each split, one per line.
print(len(image_datasets['train']), len(image_datasets['validation']), sep='\n')
|
10657 661 |
◼ 이미지 시각화
# Preview one training batch on a 4x8 grid (32 panels, one per image).
imgs, labels = next(iter(dataloaders['train']))
class_names = image_datasets['train'].classes
_, axes = plt.subplots(4, 8, figsize=(20, 10))
for ax, img, label in zip(axes.flatten(), imgs, labels):
    # permute(1, 2, 0) moves the first dimension last, which is the
    # channel-last layout imshow expects for colour images.
    ax.imshow(img.permute(1, 2, 0))
    # Title with the class name; the raw label tensor would be rendered
    # as "tensor(N)".
    ax.set_title(class_names[int(label)])
    ax.axis('off')
|
◼ 클래스 이름 확인
# Look up the class name stored at index 101 of the training dataset's class list.
image_datasets['train'].classes[101]
|
Pikachu |
2. 사전 학습된 EfficientNet 모델
* 구글의 연구팀이 개발한 이미지 분류, 객체 검출 등 컴퓨터 비전 작업에서 높은 성능을 보여주는 신경망 모델
* 신경망의 깊이, 너비, 해상도를 동시에 확장하는 방법을 통해 효율성과 성능을 극대화한 것이 특징
* EfficientNet-B4는 EfficientNet 시리즈의 중간 크기 모델
- import
from torchvision.models import efficientnet_b4, EfficientNet_B4_Weights
from torchvision.models._api import WeightsEnum
from torch.hub import load_state_dict_from_url
|
◼ EfficientNetB4 모델을 사전 학습된 가중치로 초기화하는 과정
def get_state_dict(self, *args, **kwargs):
    """Download and return the state dict found at ``self.url``.

    Drop-in replacement for ``WeightsEnum.get_state_dict`` that discards the
    ``check_hash`` keyword before delegating to ``load_state_dict_from_url``,
    so the downloaded file is not verified against a hash.

    Args:
        self: Object exposing a ``url`` attribute that points at the checkpoint.
        *args: Positional arguments forwarded to ``load_state_dict_from_url``.
        **kwargs: Keyword arguments forwarded to ``load_state_dict_from_url``;
            ``check_hash`` is removed if present.

    Returns:
        Whatever ``load_state_dict_from_url`` returns for ``self.url``.
    """
    # Use a default so callers that never pass check_hash do not hit a KeyError.
    kwargs.pop("check_hash", None)
    return load_state_dict_from_url(self.url, *args, **kwargs)
# Replace the get_state_dict method on the WeightsEnum class with the function
# defined above, so weight downloads go through it (and skip the hash check).
WeightsEnum.get_state_dict = get_state_dict
# Pretrained EfficientNet-B4 model (IMAGENET1K_V1 weights), moved to the selected device
model = efficientnet_b4(weights=EfficientNet_B4_Weights.IMAGENET1K_V1).to(device)
# model = efficientnet_b4(weights="DEFAULT").to(device)
|
Downloading: "https://download.pytorch.org/models/efficientnet_b4_rwightman-23ab8bcd.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b4_rwightman-23ab8bcd.pth 100%|██████████| 74.5M/74.5M [00:00<00:00, 155MB/s] |
◼ 학습하기
: 모든 파라미터를 동결하고, 새로운 분류기(classifier)를 추가하여 모델을 수정
# Freeze every parameter currently in the model so none of them receives
# gradients during training.
for frozen in model.parameters():
    frozen.requires_grad = False

# Swap in a new classification head. It is created after the freeze, so its
# parameters keep the default requires_grad=True.
# 1792 inputs -> 512 hidden units -> 149 outputs (149 matches the number of
# training labels printed earlier; 1792 is presumably the backbone's feature
# width — confirm against the printed model).
head_layers = [
    nn.Linear(1792, 512),
    nn.ReLU(),
    nn.Linear(512, 149),
]
model.classifier = nn.Sequential(*head_layers).to(device)
print(model)
|
사용 목적:
|
|
◼ 데이터셋을 사용하여 EfficientNet 모델을 학습시키고 평가
# Use Adam as the optimizer and update only the classifier's parameters.
optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)
# Build the loss module once instead of re-creating it for every batch.
criterion = nn.CrossEntropyLoss()
epochs = 10
for epoch in range(epochs):
    for phase in ['train', 'validation']:
        is_train = phase == 'train'
        if is_train:
            model.train()  # training mode
        else:
            model.eval()  # evaluation mode
        total_loss = 0.0   # sum of per-sample losses
        total_correct = 0  # number of correctly classified samples
        total_seen = 0     # number of samples processed
        for x_batch, y_batch in dataloaders[phase]:
            x_batch = x_batch.to(device)  # move inputs to the device
            y_batch = y_batch.to(device)  # move targets to the device
            # Record an autograd graph only while training; validation does
            # not need one.
            with torch.set_grad_enabled(is_train):
                y_pred = model(x_batch)            # forward pass (logits)
                loss = criterion(y_pred, y_batch)  # cross-entropy loss
            if is_train:
                optimizer.zero_grad()  # reset gradients
                loss.backward()        # backpropagate
                optimizer.step()       # update weights
            batch_size = y_batch.size(0)
            # .item() yields a plain float, so the running sum does not keep
            # each batch's computation graph alive. Weighting by batch size
            # keeps a smaller final batch from skewing the epoch average.
            total_loss += loss.item() * batch_size
            # argmax of the logits equals argmax of their softmax, so no
            # softmax is needed. Each batch is counted exactly once.
            total_correct += (y_pred.argmax(dim=1) == y_batch).sum().item()
            total_seen += batch_size
        seen = max(total_seen, 1)  # avoid dividing by zero on an empty loader
        avg_loss = total_loss / seen          # mean loss per sample for the epoch
        avg_acc = total_correct / seen * 100  # accuracy in percent for the epoch
        print(f'{phase:10s}: Epoch {epoch+1:4d}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.2f}%')
|
◼ 학습된 모델 파일 저장
# Save the trained model's state_dict to 'model.pth'
torch.save(model.state_dict(), 'model.pth') # a model.h5 file would be the TensorFlow counterpart
|
◼ 저장된 모델 확인
# Check the saved model: rebuild the architecture that was trained, then
# restore the saved weights into it.
# efficientnet_b4 is a builder function and has to be called; calling .to()
# on the function object itself raises AttributeError. No pretrained weights
# are requested because the checkpoint overwrites them anyway.
model = models.efficientnet_b4().to(device)
# The head must match the one used at training time so the state_dict keys
# and shapes line up.
model.classifier = nn.Sequential(
    nn.Linear(1792, 512),
    nn.ReLU(),
    nn.Linear(512, 149)
).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load('model.pth', map_location=device))
model.eval()
|
EfficientNet-B4 모델을 로드하고, 평가 모드로 설정
사용 목적:
|
◼ 테스트
from PIL import Image

# Open the two sample images used for the prediction demo.
img1 = Image.open('/content/validation/Snorlax/4.jpg')
img2 = Image.open('/content/validation/Rattata/0.jpg')

# Draw them side by side with the axes hidden.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
for panel, picture in zip(axes, (img1, img2)):
    panel.imshow(picture)
    panel.axis('off')
plt.show()
|
◼ 이미지 데이터를 변환하는 data_transform을 사용하여 img1과 img2를 처리한 후 각각의 shape를 출력
# Run both images through the validation pipeline and print each resulting
# tensor shape on its own line.
to_model_input = data_transform['validation']
img1_input, img2_input = to_model_input(img1), to_model_input(img2)
print(img1_input.shape, img2_input.shape, sep='\n')
|
torch.Size([3, 224, 224]) torch.Size([3, 224, 224]) |
◼ 두 개의 이미지 데이터를 하나의 배치 텐서인 test_batch로 스택(stack)하여 만들기
test_batch를 GPU(device)로 옮기고 그 shape를 출력
# Stack the two image tensors along a new leading dimension and move the
# resulting batch to the selected device in one expression.
test_batch = torch.stack((img1_input, img2_input)).to(device)
test_batch.shape
|
torch.Size([2, 3, 224, 224]) |
◼ 예측하기
# Inference only: disable autograd so no graph is recorded for the forward pass.
with torch.no_grad():
    y_pred = model(test_batch)
y_pred
|
tensor([[-10.1381, -18.2518, -16.0891, -13.0126, -18.1286, -17.1851, -15.6250, -13.1667, -13.5227, -6.7796, -13.8471, -9.7336, -4.9369, -13.6744, -9.3085, -8.9335, -15.6262, -10.8516, -13.3686, -11.2549, -8.4466, -6.4885, -6.5071, -16.3503, -14.5905, -11.8353, -9.9806, -8.7323, -9.7071, -8.2985, -13.0898, -8.7130, -19.6752, -13.4941, -11.8880, -9.1840, -13.9102, -17.0178, -13.1194, -11.5825, -9.4896, -17.9141, -9.5629, -16.5170, -13.3171, -12.2587, -16.4166, -18.3653, -9.4935, -13.2763, -17.6866, -13.7130, -14.7392, -14.6677, -11.6655, -10.5128, -12.3555, -7.9406, -14.8172, -6.5830, -10.8367, -12.2928, -12.8076, -17.0396, -10.1280, -13.5942, -6.7088, -11.5467, -7.6642, -5.2480, -15.8393, -15.2643, -11.4280, -10.1169, -13.1539, -14.3218, -15.0127, -15.6476, -11.6592, -10.1635, -16.6752, -11.6984, -8.3909, -15.6058, -7.2380, -10.5447, -9.6789, -7.1109, -10.7635, -12.4095, -14.6815, -10.0865, -6.6754, -6.9112, -17.0186, -13.2917, -8.1588, -10.9359, -11.8030, -11.2323, -11.5271, -9.3658, -12.8177, -13.1180, -12.1684, -10.8933, -15.0315, -12.2008, -15.7332, -7.6037, -7.8880, -15.7726, -12.9173, -9.6909, -11.2190, -13.0147, -11.9097, -15.3206, -14.9645, -10.8827, -13.4233, -8.2616, -6.5764, -9.2068, -1.1722, 4.6309, -13.8240, -8.8144, -17.5671, -16.4612, -8.7952, -13.9231, -13.7120, -12.9466, -9.8612, -14.6987, -13.3159, -9.9359, -7.8293, -7.9961, -14.1366, -12.0898, -12.5732, -6.7362, -9.3175, -12.3335, -7.4365, -24.8273, -15.5383], [-13.3796, -18.8920, -19.0383, -9.2321, -14.1503, -15.0578, -13.0850, -9.3543, -13.7065, -7.2419, -10.7370, -6.9521, -5.2882, -14.8258, -5.8944, -12.0897, -18.6506, -12.8319, -11.5861, -9.4032, -5.4389, 8.1403, -5.5555, -12.4493, -10.4225, -6.8148, -11.7280, -10.2866, -11.1347, -1.1614, -13.3300, -3.8216, -18.1594, -11.4256, -13.6750, -6.9662, -13.0039, -11.6008, -10.9354, -9.6054, -9.9303, -8.1292, -12.4365, -15.6898, -10.1333, -11.0957, -12.9833, -11.9187, -5.8141, -13.3899, -11.8386, -14.2317, -9.3236, -8.3747, -8.4686, -12.7803, -9.0560, 
-7.8696, -14.1533, -6.9840, -3.8049, -8.8678, -14.3514, -7.8895, -11.4716, -12.4721, -8.6673, -6.8360, -7.3675, -6.8125, -11.4952, -10.9705, -6.5105, -9.7348, -14.7709, -11.3191, -14.1895, -12.8245, -9.9535, -10.8555, -11.0447, -8.4326, -9.1383, -13.7896, -11.9255, -7.1484, -14.5495, -9.2247, -13.1285, -14.1737, -8.4069, -7.5965, -3.4051, -7.6866, -10.3734, -7.7260, -5.2773, -8.5044, -13.3983, -12.9986, -10.5447, -10.1205, -15.6450, -9.7553, -11.2016, -11.3979, -11.5908, -10.2518, -12.1675, -4.6960, -11.4753, -12.7572, -13.6541, -11.3222, -11.3799, -11.1025, -9.1730, -13.7514, -15.8559, -8.6237, -12.6485, -4.2751, -7.7025, -12.9300, -7.0018, -8.2598, -8.8410, -5.0144, -15.6619, -9.5564, -5.5501, -10.5789, -5.6651, -5.5720, -11.6338, -12.4811, -8.4071, -11.0972, -9.6168, -6.4310, -8.6936, -10.1972, -12.1732, -0.8298, -7.8869, -15.4384, -9.2720, -21.3815, -12.4833]], device='cuda:0', grad_fn=<AddmmBackward0>) |
◼ 예측결과를 확률 값으로 변환
# Apply softmax along dim 1 so each row of model outputs becomes a
# probability distribution.
to_probabilities = nn.Softmax(dim=1)
y_prob = to_probabilities(y_pred)
y_prob
|
tensor([[-10.1381, -18.2518, -16.0891, -13.0126, -18.1286, -17.1851, -15.6250, -13.1667, -13.5227, -6.7796, -13.8471, -9.7336, -4.9369, -13.6744, -9.3085, -8.9335, -15.6262, -10.8516, -13.3686, -11.2549, -8.4466, -6.4885, -6.5071, -16.3503, -14.5905, -11.8353, -9.9806, -8.7323, -9.7071, -8.2985, -13.0898, -8.7130, -19.6752, -13.4941, -11.8880, -9.1840, -13.9102, -17.0178, -13.1194, -11.5825, -9.4896, -17.9141, -9.5629, -16.5170, -13.3171, -12.2587, -16.4166, -18.3653, -9.4935, -13.2763, -17.6866, -13.7130, -14.7392, -14.6677, -11.6655, -10.5128, -12.3555, -7.9406, -14.8172, -6.5830, -10.8367, -12.2928, -12.8076, -17.0396, -10.1280, -13.5942, -6.7088, -11.5467, -7.6642, -5.2480, -15.8393, -15.2643, -11.4280, -10.1169, -13.1539, -14.3218, -15.0127, -15.6476, -11.6592, -10.1635, -16.6752, -11.6984, -8.3909, -15.6058, -7.2380, -10.5447, -9.6789, -7.1109, -10.7635, -12.4095, -14.6815, -10.0865, -6.6754, -6.9112, -17.0186, -13.2917, -8.1588, -10.9359, -11.8030, -11.2323, -11.5271, -9.3658, -12.8177, -13.1180, -12.1684, -10.8933, -15.0315, -12.2008, -15.7332, -7.6037, -7.8880, -15.7726, -12.9173, -9.6909, -11.2190, -13.0147, -11.9097, -15.3206, -14.9645, -10.8827, -13.4233, -8.2616, -6.5764, -9.2068, -1.1722, 4.6309, -13.8240, -8.8144, -17.5671, -16.4612, -8.7952, -13.9231, -13.7120, -12.9466, -9.8612, -14.6987, -13.3159, -9.9359, -7.8293, -7.9961, -14.1366, -12.0898, -12.5732, -6.7362, -9.3175, -12.3335, -7.4365, -24.8273, -15.5383], [-13.3796, -18.8920, -19.0383, -9.2321, -14.1503, -15.0578, -13.0850, -9.3543, -13.7065, -7.2419, -10.7370, -6.9521, -5.2882, -14.8258, -5.8944, -12.0897, -18.6506, -12.8319, -11.5861, -9.4032, -5.4389, 8.1403, -5.5555, -12.4493, -10.4225, -6.8148, -11.7280, -10.2866, -11.1347, -1.1614, -13.3300, -3.8216, -18.1594, -11.4256, -13.6750, -6.9662, -13.0039, -11.6008, -10.9354, -9.6054, -9.9303, -8.1292, -12.4365, -15.6898, -10.1333, -11.0957, -12.9833, -11.9187, -5.8141, -13.3899, -11.8386, -14.2317, -9.3236, -8.3747, -8.4686, -12.7803, -9.0560, 
-7.8696, -14.1533, -6.9840, -3.8049, -8.8678, -14.3514, -7.8895, -11.4716, -12.4721, -8.6673, -6.8360, -7.3675, -6.8125, -11.4952, -10.9705, -6.5105, -9.7348, -14.7709, -11.3191, -14.1895, -12.8245, -9.9535, -10.8555, -11.0447, -8.4326, -9.1383, -13.7896, -11.9255, -7.1484, -14.5495, -9.2247, -13.1285, -14.1737, -8.4069, -7.5965, -3.4051, -7.6866, -10.3734, -7.7260, -5.2773, -8.5044, -13.3983, -12.9986, -10.5447, -10.1205, -15.6450, -9.7553, -11.2016, -11.3979, -11.5908, -10.2518, -12.1675, -4.6960, -11.4753, -12.7572, -13.6541, -11.3222, -11.3799, -11.1025, -9.1730, -13.7514, -15.8559, -8.6237, -12.6485, -4.2751, -7.7025, -12.9300, -7.0018, -8.2598, -8.8410, -5.0144, -15.6619, -9.5564, -5.5501, -10.5789, -5.6651, -5.5720, -11.6338, -12.4811, -8.4071, -11.0972, -9.6168, -6.4310, -8.6936, -10.1972, -12.1732, -0.8298, -7.8869, -15.4384, -9.2720, -21.3815, -12.4833]], device='cuda:0', grad_fn=<AddmmBackward0>) |
◼ 모델을 사용하여 예측을 수행하고, 예측 결과에서 상위 K개의 클래스 확률과 인덱스를 출력
# Take the three largest probabilities (and their class indices) per row.
# y_prob already holds the softmax of the model output, so it must not be
# passed through the model again — the model expects image batches.
probs, idx = torch.topk(y_prob, k=3)
print(probs)
print(idx)
|
tensor([[9.9668e-01, 3.0082e-03, 6.9708e-05], [9.9973e-01, 1.2713e-04, 9.1248e-05]], device='cuda:0', grad_fn=<TopkBackward0>) tensor([[125, 124, 12], [ 21, 143, 29]], device='cuda:0') |
◼ 예측결과 시각화
# Show each test image with its top-3 predictions in the title, formatted as
# "<probability>% <class name>" three times.
title_template = '{:.2f}% {}, {:.2f}% {}, {:.2f}% {}'
class_names = image_datasets['validation'].classes
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
for row, (panel, picture) in enumerate(zip(axes, (img1, img2))):
    title_fields = []
    for rank in range(3):
        # Probability as a percentage, followed by the matching class name.
        title_fields.append(probs[row, rank] * 100)
        title_fields.append(class_names[idx[row, rank]])
    panel.set_title(title_template.format(*title_fields))
    panel.imshow(picture)
    panel.axis('off')
|
'AI > 딥러닝' 카테고리의 다른 글
09. 전이학습 (0) | 2024.06.21 |
---|---|
08. 간단한 CNN 모델 만들기 (0) | 2024.06.20 |
07. CNN 기초 (0) | 2024.06.20 |
06. 비선형 활성화 함수 (0) | 2024.06.20 |
05. 딥러닝 (0) | 2024.06.20 |