본문 바로가기
AI/딥러닝

10. 포켓몬 분류

by 사라리24 2024. 6. 21.
SMALL



1. 포켓몬 분류

 
 * Train:
https://www.kaggle.com/datasets/thedagger/pokemon-generation-one

 * Validation: https://www.kaggle.com/hlrhegemony/pokemon-image-dataset kaggle.comkaggle.com

 

\

 

Kaggle API를 사용하여 Pokemon 데이터셋을 다운로드하고 압축을 해제



      import os

      os.environ['KAGGLE_USERNAME'] = 'sarahlee721'
      os.environ['KAGGLE_KEY'] = 'b5dc455974daae96540089d7bcdf062a'

      !kaggle datasets download -d thedagger/pokemon-generation-one
      !unzip -q pokemon-generation-one.zip

      !kaggle datasets download -d hlrhegemony/pokemon-image-dataset
      !unzip -q pokemon-image-dataset.zip
       

 

 

 

폴더의 이름을 변경, 이전의 폴더를 삭제


       
      !mv dataset train

       !rm -rf train/dataset








 

데이터셋 확인


       

        train_labels = os.listdir('train')
        print(train_labels)
        print(len(train_labels))


        val_labels = os.listdir('validation')
        print(val_labels)
        print(len(val_labels))


['Arcanine', 'Gyarados', 'Ponyta', 'Jigglypuff', 'Gloom', 'Lickitung', 'Hitmonlee', 'MrMime', 'Chansey', 'Weezing', 'Arbok', 'Kadabra', 'Exeggutor', 'Rhydon', 'Hypno', 'Oddish', 'Scyther', 'Abra', 'Graveler', 'Starmie', 'Ditto', 'Zapdos', 'Golem', 'Pinsir', 'Rhyhorn', 'Jolteon', 'Gengar', 'Articuno', 'Venomoth', 'Nidoqueen', 'Seaking', 'Onix', 'Persian', 'Aerodactyl', 'Tentacool', 'Charizard', 'Snorlax', 'Drowzee', 'Sandslash', 'Hitmonchan', 'Growlithe', 'Magneton', 'Flareon', 'Farfetchd', 'Clefable', 'Clefairy', 'Kabutops', 'Poliwhirl', 'Charmeleon', 'Spearow', 'Venonat', 'Slowpoke', 'Wigglytuff', 'Omanyte', 'Lapras', 'Wartortle', 'Cloyster', 'Gastly', 'Machamp', 'Nidoking', 'Haunter', 'Magmar', 'Krabby', 'Mewtwo', 'Porygon', 'Bellsprout', 'Pidgeotto', 'Geodude', 'Meowth', 'Kabuto', 'Golbat', 'Weedle', 'Seel', 'Dodrio', 'Dewgong', 'Butterfree', 'Pidgeot', 'Psyduck', 'Jynx', 'Victreebel', 'Slowbro', 'Grimer', 'Mew', 'Mankey', 'Shellder', 'Raticate', 'Fearow', 'Dragonair', 'Marowak', 'Parasect', 'Metapod', 'Venusaur', 'Muk', 'Tauros', 'Eevee', 'Exeggcute', 'Raichu', 'Voltorb', 'Magikarp', 'Pidgey', 'Pikachu', 'Poliwrath', 'Staryu', 'Tangela', 'Vileplume', 'Machoke', 'Koffing', 'Doduo', 'Kakuna', 'Electrode', 'Machop', 'Goldeen', 'Blastoise', 'Primeape', 'Alakazam', 'Seadra', 'Diglett', 'Tentacruel', 'Nidorina', 'Weepinbell', 'Ninetales', 'Vulpix', 'Rattata', 'Rapidash', 'Charmander', 'Golduck', 'Kingler', 'Moltres', 'Sandshrew', 'Caterpie', 'Dugtrio', 'Electabuzz', 'Nidorino', 'Squirtle', 'Kangaskhan', 'Vaporeon', 'Zubat', 'Magnemite', 'Paras', 'Dratini', 'Ekans', 'Ivysaur', 'Horsea', 'Beedrill', 'Omastar', 'Cubone', 'Dragonite', 'Bulbasaur', 'Poliwag']
149

-------------------------------------------------------------------------------------------------------------------------------------------------------------

['Araquanid', 'Litleo', 'Galvantula', 'Blipbug', 'Houndour', 'Turtwig', 'Dracozolt', 'Arcanine', 'Ambipom', 'Leafeon', 'Whiscash', 'Skiploom', 'Wooper', 'Flab├йb├й', 'Gyarados', 'Type Null', 'Amaura', 'Ponyta', 'Torterra', 'Bisharp', 'Gourgeist', 'Jigglypuff', 'Rhyperior', 'Tornadus', 'Solgaleo', 'Lunatone', 'Lombre', 'Hippowdon', 'Minun', 'Marshadow', 'Drapion', 'Urshifu', 'Fraxure', 'Gloom', 'Lickitung', 'Ninjask', 'Shuppet', 'Groudon', 'Vigoroth', 'Hitmonlee', 'Cosmog', 'Diancie', 'Xerneas', 'Heatmor', 'Tropius', 'Magby', 'Yamask', 'Murkrow', 'Dragapult', 'Chansey', 'Weezing', 'Arbok', 'Lickilicky', 'Orbeetle', 'Magearna', 'Dewpider', 'Deino', 'Nosepass', 'Skarmory',

... 

'Kangaskhan', 'Lopunny', 'Meowstic', 'Vaporeon', 'Zubat', 'Weavile', 'Leavanny', 'Magnemite', 'Cinderace', 'Aurorus', 'Reuniclus', 'Minccino', 'Comfey', 'Cottonee', 'Claydol', 'Croconaw', 'Panpour', 'Stunfisk', 'Litwick', 'Shiftry', 'Paras', 'Foongus', 'Dratini', 'Ledian', 'Stonjourner', 'Zygarde', 'Ekans', 'Rayquaza', 'Litten', 'Porygon2', 'Tsareena', 'Florges', 'Primarina', 'Fennekin', 'Glastrier', 'Pachirisu', 'Drakloak', 'Ivysaur', 'Indeedee', 'Corvisquire', 'Horsea', 'Froslass', 'Espeon', 'Morgrem', 'Beedrill', 'Barbaracle', 'Pansage', 'Darumaka', 'Remoraid', 'Omastar', 'Spiritomb', 'Cryogonal', 'Mienshao', 'Metagross', 'Dwebble', 'Swanna', 'Glalie', 'Toxel', 'Throh', 'Ho-oh', 'Slugma', 'Palossand', 'Roserade', 'Liepard', 'Spinarak', 'Scolipede', 'Darkrai', 'Yamper', 'Falinks', 'Brionne', 'Aromatisse', 'Cubone', 'Noibat', 'Mawile', 'Swirlix', 'Beheeyem', 'Deoxys', 'Calyrex', 'Sandaconda', 'Dracovish', 'Hoppip', 'Morelull', 'Dragonite', 'Pineco', 'Bulbasaur', 'Silicobra', 'Rufflet', 'Aggron', 'Croagunk', 'Palpitoad', 'Wailord', 'Poliwag', 'Steelix', 'Gorebyss', 'Crabrawler', 'Klefki', 'Surskit', 'Bewear', 'Mantyke', 'Zacian', 'Emboar']
898

 

 

train데이터에서 없는 속성을 validation데이터에서 지우기


       

          import shutil

          for val_label in val_labels:
              if val_label not in train_labels:  # 존재하지 않으면 삭제
                  shutil.rmtree(os.path.join('validation', val_label))
           
          val_labels = os.listdir('validation')
          print(len(val_labels))

 
          # 다시 확인 => 2개 오차 있는것 같음
          val_labels = os.listdir('validation')
          print(len(train_labels))
          print(len(val_labels))

          for train_label in train_labels:
            if train_label not in val_labels:
              print(train_label)
              os.makedirs(os.path.join('validation', train_label), exist_ok = True)

            val_labels = os.listdir('validation')
            print(len(val_labels))

  1. 라이브러리 임포트:
    • Python의 기본 라이브러리 중 하나인 shutil을 임포트합니다.
      shutil은 파일 및 디렉토리 작업을 수행하는 데 사용됩니다.
  2. 반복문을 통한 디렉토리 삭제:
    • val_labels 리스트에 있는 각 항목(val_label)에 대해 반복합니다.
    • if val_label not in train_labels: 조건문을 사용하여 현재 검증 레이블(val_label)이 학습 레이블(train_labels)에 없는 경우를 확인합니다.
    • os.path.join('validation', val_label)을 사용하여 'validation' 디렉토리 내에 있는 해당 레이블의 하위 디렉토리 경로를 생성합니다.
    • shutil.rmtree() 함수를 사용하여 해당 경로에 있는 디렉토리를 재귀적으로 삭제합니다. 이 함수는 디렉토리와 그 내용물을 모두 삭제합니다.
  3. 결과 출력:
    • train_labels와 val_labels의 길이를 출력하여 학습 데이터셋과 검증 데이터셋의 레이블 수를 확인합니다.
    • 이를 통해 얼마나 많은 레이블이 유지되었는지와 디렉토리 삭제 작업의 결과를 확인할 수 있습니다.

해석:

  • 이 코드는 검증 데이터셋(validation 디렉토리)에 있는 각 레이블에 대해 해당 레이블이 학습 데이터셋(train_labels)에 포함되지 않으면 해당 레이블의 디렉토리를 삭제합니다.
  • 이 작업은 학습과 검증 데이터셋 간의 일관성을 유지하고, 필요 없는 데이터를 정리하는 데 사용될 수 있습니다.
  • 마지막으로 출력된 레이블 수는 각 데이터셋에 남아 있는 레이블의 수를 나타내며, 삭제 작업의 효과를 확인하는 데 도움을 줍니다.
 
147
-------------------------
149
147
------------------------
MrMime
Farfetchd
------------------------
149

 

 

모듈 import 


       
        import torch
        import torch.nn as nn
        import torch.optim as optim
        import matplotlib.pyplot as plt
        from torchvision import datasets, models, transforms
        from torch.utils.data import DataLoader

        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(device)


 

 

 이미지 증강 기법


       
      data_transform = {
          'train': transforms.Compose([
              transforms.Resize((224, 224)),
              transforms.RandomAffine(0, shear = 10, scale = (0.8, 1.2)),
              transforms.RandomHorizontalFlip(),
              transforms.ToTensor()
          ]),
          'validation': transforms.Compose([
              transforms.Resize((224, 224)),
              transforms.ToTensor()
          ])
      }


  • 'train' 데이터셋 변환:
    • transforms.Resize((224, 224)): 입력 이미지의 크기를 224x224 픽셀로 조정합니다.
    • transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)): 이미지에 임의의 기울임(shear)을 적용하고, 크기를 임의로 조정합니다. shear는 0도에서 10도 사이에서 랜덤하게 설정되며, 크기(scale)는 0.8배에서 1.2배 사이에서 랜덤하게 조정됩니다.
    • transforms.RandomHorizontalFlip(): 이미지를 랜덤하게 좌우로 반전시킵니다.
    • transforms.ToTensor(): 이미지를 PyTorch Tensor 형식으로 변환합니다. 이 과정에서 이미지의 픽셀 값은 0에서 1 사이로 정규화됩니다.
  • 'validation' 데이터셋 변환:
    • transforms.Resize((224, 224)): 입력 이미지의 크기를 224x224 픽셀로 조정합니다.
    • transforms.ToTensor(): 이미지를 PyTorch Tensor 형식으로 변환합니다. 이 과정에서 이미지의 픽셀 값은 0에서 1 사이로 정규화됩니다.

 

이터셋 객체 생성


       
        image_datasets = {
            'train': datasets.ImageFolder('train', data_transform['train']),
            'validation': datasets.ImageFolder('validation', data_transform['validation'])
        }

 
 

 

 데이터 로더 만들기


       
        dataloaders = {
            'train': DataLoader(
                image_datasets['train'],
                batch_size = 32,
                shuffle = True
            ),
            'validation': DataLoader(
                image_datasets['validation'],
                batch_size = 32,
                shuffle = False
            )
        }


    print(len(image_datasets['train']))
    print(len(image_datasets['validation']))


10657
661

 

이미지 시각화


       
        imgs, labels = next(iter(dataloaders['train']))

        _, axes = plt.subplots(4,8,figsize=(20,10))

        for ax, img, label in zip(axes.flatten(), imgs, labels):
            ax.imshow(img.permute(1,2,0))
            ax.set_title(label)
            ax.axis('off')


 

 클래스 이름 확인


       
        image_datasets['train'].classes[101]

 
Pikachu

 



2. 사전 학습된 EfficientNet 모델

* 구글의 연구팀이 개발한 이미지 분류, 객체 검출 등 컴퓨터 비전 작업에서 높은 성능을 보여주는 신경망 모델
* 신경망의 깊이, 너비, 해상도를 동시에 확장하는 방법을 통해 효율성과 성능을 극대화한 것이 특징
* EfficientnetB4는 EfficientNet 시리즈의 중간 크기 모델

 

 

  • import

       
        from torchvision.models import efficientnet_b4, EfficientNet_B4_Weights
        from torchvision.models._api import WeightsEnum
        from torch.hub import load_state_dict_from_url

 
 

 

EfficientNetB4 모델을 사전 학습된 가중치로 초기화하는 과정


     
        def get_state_dict(self, *args, **kwargs):
            kwargs.pop("check_hash")
            return load_state_dict_from_url(self.url, *args, **kwargs)
        WeightsEnum.get_state_dict = get_state_dict

        # 사전 학습된 EfficientNetB4 모델
        model = efficientnet_b4(weights=EfficientNet_B4_Weights.IMAGENET1K_V1).to(device)
        # model = efficientnet_b4(weights="DEFAULT").to(device)


 
Downloading: "https://download.pytorch.org/models/efficientnet_b4_rwightman-23ab8bcd.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b4_rwightman-23ab8bcd.pth
100%|██████████| 74.5M/74.5M [00:00<00:00, 155MB/s]

 

 학습하기
: 모든 파라미터를 동결하고, 새로운 분류기(classifier)를 추가하여 모델을 수정


     
        for param in model.parameters():
          param.requires_grad = False

        model.classifier = nn.Sequential(
            nn.Linear(1792, 512),
            nn.ReLU(),

            nn.Linear(512, 149)
        ).to(device)

        print(model)


  1. 모든 파라미터 동결:
    • for param in model.parameters(): param.requires_grad = False: 모델의 모든 파라미터를 순회하면서 requires_grad를 False로 설정합니다.
    • requires_grad=False는 해당 파라미터가 역전파(backpropagation) 과정에서 경사도(gradient)를 계산하지 않도록 하여, 이 파라미터들의 가중치가 학습되지 않도록 만듭니다. 이는 미세 조정이나 새로운 분류기 학습에 유리합니다.
  2. 새로운 분류기 추가:
    • model.classifier = nn.Sequential(...): 기존 모델의 classifier 부분을 새로운 nn.Sequential로 대체합니다.
    • nn.Linear(1792, 512): 입력 크기가 1792이고 출력 크기가 512인 fully connected 레이어를 추가합니다.
    • nn.ReLU(): ReLU 활성화 함수를 추가합니다. 이는 비선형성을 도입하여 모델이 복잡한 데이터에 대해 더 복잡한 결정 경계를 학습할 수 있게 합니다.
    • nn.Linear(512, 149): 입력 크기가 512이고 출력 크기가 149인 fully connected 레이어를 추가합니다. 여기서 149는 새로운 클래스 수를 나타냅니다.
  3. 모델 출력:
    • print(model): 모델의 현재 구조를 출력합니다. 이를 통해 모델이 어떻게 수정되었는지, 각 레이어의 구성이 어떤지 확인할 수 있습니다.

사용 목적:

  • 기존 모델의 일부 레이어를 고정하고, 새로운 분류기를 추가하여 전이 학습(transfer learning)을 수행합니다.
  • 기존의 사전 학습된 모델의 특성 추출 능력을 그대로 유지하면서, 새로운 데이터셋에 맞는 새로운 분류 작업을 학습할 수 있습니다.
  • nn.Linear와 nn.ReLU를 이용하여 fully connected 레이어와 활성화 함수를 추가하여 모델을 새로운 분류 작업에 맞게 조정합니다.

 

 

◼ 데이터셋을 사용하여 EfficientNet 모델 학습시키고 평가


       

          # optimizer를 Adam으로 설정하고, classifier의 파라미터만 업데이트합니다.

          optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)

          epochs = 10

          for epoch in range(epochs):
              for phase in ['train', 'validation']:
                  if phase == 'train':
                      model.train() # 모델을 학습 모드로 설정
                  else:
                      model.eval() # 모델을 평가 모드로 설정

                  sum_losses = 0 # 손실 합 
                  sum_accs = 0 # 정확도 합 

                  for x_batch, y_batch in dataloaders[phase]:
                      x_batch = x_batch.to(device) # 입력 데이터를 GPU로 옮김
                      y_batch = y_batch.to(device) # 타겟 데이터를 GPU로 옮김

                      y_pred = model(x_batch) # 예측값 계산
                      loss = nn.CrossEntropyLoss()(y_pred, y_batch) # Cross Entropy Loss를 계산하여 손실값 얻음

                      if phase == 'train': # 학습 단계
                          optimizer.zero_grad() # 기울기 초기화
                          loss.backward() # 역전파 수행
                          optimizer.step() # 가중치 업데이트

                      sum_losses = sum_losses + loss   # 배치의 손실값을 합산

                      y_prob = nn.Softmax(1)(y_pred) # 모델의 출력을 softmax 함수를 사용하여 확률로 변환
                      y_pred_index = torch.argmax(y_prob, axis=1) # 확률이 가장 높은 클래스의 인덱스를 예측값으로 설정
                      acc = (y_batch == y_pred_index).float().sum() / len(y_batch) * 100  # 정확도 계산
                      sum_accs += acc.item()  # 배치의 정확도를 합산
                      sum_accs = sum_accs + acc # 배치의 정확도를 합산
                     
                  avg_loss = sum_losses / len(dataloaders[phase]) # epoch의 평균 손실값 계산
                  avg_acc = sum_accs / len(dataloaders[phase]) # epoch의 평균 정확도 계산
                  print(f'{phase:10s}: Epoch {epoch+1:4d}/{epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_acc:.2f}%')




 

 

학습된 모델 파일 저장 


       
        # 학습된 모델 파일 저장
        torch.save(model.state_dict(), 'model.pth') # model.h5 -> tensorflow


 

 

저장된 모델 확인


       
        # 저장된 모델 확인
        model = models.efficientnet_b4.to(device)

        model.classifier = nn.Sequential(
            nn.Linear(1792, 512),
            nn.ReLU(),

            nn.Linear(512, 149)
        ).to(device)

        model.load_state_dict(torch.load('model.pth'))
        model.eval()


EfficientNet-B4 모델을 로드하고, 평가 모드로 설정
  1. 모델 불러오기 및 설정:
    • EfficientNet-B4 모델을 models에서 불러와 GPU(device)로 옮깁니다.
  2. 분류기 재설정:
    • 기존 모델의 분류기를 새로운 Sequential 모듈로 교체합니다.
    • 입력 차원은 EfficientNet-B4의 출력 크기에 맞추어 1792로 설정되었습니다.
    • ReLU 활성화 함수를 거쳐 512 크기의 중간 레이어를 추가하고, 최종 출력 크기는 149로 설정합니다.
  3. 저장된 모델의 상태 복원:
    • 'model.pth' 파일에서 저장된 모델의 상태를 복원합니다.
    • torch.load() 함수를 사용하여 모델의 가중치와 매개변수를 메모리로 로드합니다.
  4. 모델 평가 모드 설정:
    • 모델을 평가 모드로 설정합니다.
    • 평가 모드에서는 배치 정규화와 드롭아웃을 비활성화하여 예측을 일관되게 만듭니다.

사용 목적:

  • 저장된 EfficientNet-B4 모델을 로드하고, 새로운 분류기를 추가하여 특정 문제에 맞게 재사용할 준비를 합니다.
  • 모델의 평가 모드 설정은 모델을 사용하여 예측을 수행할 때, 일관된 결과를 얻기 위해 필요합니다.
  • 이 코드를 사용하면 이전에 학습한 모델의 성능을 평가하거나 실제 데이터에 적용할 수 있습니다.

 

 

 테스트


       

        from PIL import Image

        img1 = Image.open('/content/validation/Snorlax/4.jpg')
        img2 = Image.open('/content/validation/Rattata/0.jpg')

        fig, axes = plt.subplots(1, 2, figsize = (12, 6))

        axes[0].imshow(img1)
        axes[0].axis('off')
        axes[1].imshow(img2)
        axes[1].axis('off')

        plt.show()



 

 

◼ 이미지  데이터를 변환하는 data_transform을 사용하여 img1과 img2를 처리한 후 각각의 shape를 출력


       
        img1_input = data_transform['validation'](img1)
        img2_input = data_transform['validation'](img2)
        print(img1_input.shape)
        print(img2_input.shape)
 

torch.Size([3, 224, 224])
torch.Size([3, 224, 224])

 

 

두 개의 이미지 데이터를 하나의 배치 텐서인 test_batch로 스택(stack)하여 만들기
 test_batch를 GPU(device)로 옮기고 그 shape를 출력


       
        test_batch = torch.stack([img1_input, img2_input])
        test_batch = test_batch.to(device)
        test_batch.shape


torch.Size([2, 3, 224, 224])

 

 예측하기


       
          y_pred = model(test_batch)
          y_pred


tensor([[-10.1381, -18.2518, -16.0891, -13.0126, -18.1286, -17.1851, -15.6250,
         -13.1667, -13.5227,  -6.7796, -13.8471,  -9.7336,  -4.9369, -13.6744,
          -9.3085,  -8.9335, -15.6262, -10.8516, -13.3686, -11.2549,  -8.4466,
          -6.4885,  -6.5071, -16.3503, -14.5905, -11.8353,  -9.9806,  -8.7323,
          -9.7071,  -8.2985, -13.0898,  -8.7130, -19.6752, -13.4941, -11.8880,
          -9.1840, -13.9102, -17.0178, -13.1194, -11.5825,  -9.4896, -17.9141,
          -9.5629, -16.5170, -13.3171, -12.2587, -16.4166, -18.3653,  -9.4935,
         -13.2763, -17.6866, -13.7130, -14.7392, -14.6677, -11.6655, -10.5128,
         -12.3555,  -7.9406, -14.8172,  -6.5830, -10.8367, -12.2928, -12.8076,
         -17.0396, -10.1280, -13.5942,  -6.7088, -11.5467,  -7.6642,  -5.2480,
         -15.8393, -15.2643, -11.4280, -10.1169, -13.1539, -14.3218, -15.0127,
         -15.6476, -11.6592, -10.1635, -16.6752, -11.6984,  -8.3909, -15.6058,
          -7.2380, -10.5447,  -9.6789,  -7.1109, -10.7635, -12.4095, -14.6815,
         -10.0865,  -6.6754,  -6.9112, -17.0186, -13.2917,  -8.1588, -10.9359,
         -11.8030, -11.2323, -11.5271,  -9.3658, -12.8177, -13.1180, -12.1684,
         -10.8933, -15.0315, -12.2008, -15.7332,  -7.6037,  -7.8880, -15.7726,
         -12.9173,  -9.6909, -11.2190, -13.0147, -11.9097, -15.3206, -14.9645,
         -10.8827, -13.4233,  -8.2616,  -6.5764,  -9.2068,  -1.1722,   4.6309,
         -13.8240,  -8.8144, -17.5671, -16.4612,  -8.7952, -13.9231, -13.7120,
         -12.9466,  -9.8612, -14.6987, -13.3159,  -9.9359,  -7.8293,  -7.9961,
         -14.1366, -12.0898, -12.5732,  -6.7362,  -9.3175, -12.3335,  -7.4365,
         -24.8273, -15.5383],
        [-13.3796, -18.8920, -19.0383,  -9.2321, -14.1503, -15.0578, -13.0850,
          -9.3543, -13.7065,  -7.2419, -10.7370,  -6.9521,  -5.2882, -14.8258,
          -5.8944, -12.0897, -18.6506, -12.8319, -11.5861,  -9.4032,  -5.4389,
           8.1403,  -5.5555, -12.4493, -10.4225,  -6.8148, -11.7280, -10.2866,
         -11.1347,  -1.1614, -13.3300,  -3.8216, -18.1594, -11.4256, -13.6750,
          -6.9662, -13.0039, -11.6008, -10.9354,  -9.6054,  -9.9303,  -8.1292,
         -12.4365, -15.6898, -10.1333, -11.0957, -12.9833, -11.9187,  -5.8141,
         -13.3899, -11.8386, -14.2317,  -9.3236,  -8.3747,  -8.4686, -12.7803,
          -9.0560,  -7.8696, -14.1533,  -6.9840,  -3.8049,  -8.8678, -14.3514,
          -7.8895, -11.4716, -12.4721,  -8.6673,  -6.8360,  -7.3675,  -6.8125,
         -11.4952, -10.9705,  -6.5105,  -9.7348, -14.7709, -11.3191, -14.1895,
         -12.8245,  -9.9535, -10.8555, -11.0447,  -8.4326,  -9.1383, -13.7896,
         -11.9255,  -7.1484, -14.5495,  -9.2247, -13.1285, -14.1737,  -8.4069,
          -7.5965,  -3.4051,  -7.6866, -10.3734,  -7.7260,  -5.2773,  -8.5044,
         -13.3983, -12.9986, -10.5447, -10.1205, -15.6450,  -9.7553, -11.2016,
         -11.3979, -11.5908, -10.2518, -12.1675,  -4.6960, -11.4753, -12.7572,
         -13.6541, -11.3222, -11.3799, -11.1025,  -9.1730, -13.7514, -15.8559,
          -8.6237, -12.6485,  -4.2751,  -7.7025, -12.9300,  -7.0018,  -8.2598,
          -8.8410,  -5.0144, -15.6619,  -9.5564,  -5.5501, -10.5789,  -5.6651,
          -5.5720, -11.6338, -12.4811,  -8.4071, -11.0972,  -9.6168,  -6.4310,
          -8.6936, -10.1972, -12.1732,  -0.8298,  -7.8869, -15.4384,  -9.2720,
         -21.3815, -12.4833]], device='cuda:0', grad_fn=<AddmmBackward0>)

 

 예측결과를 확률 값으로 변환


       
      y_prob = nn.Softmax(1)(y_pred)
      y_prob


tensor([[-10.1381, -18.2518, -16.0891, -13.0126, -18.1286, -17.1851, -15.6250,
         -13.1667, -13.5227,  -6.7796, -13.8471,  -9.7336,  -4.9369, -13.6744,
          -9.3085,  -8.9335, -15.6262, -10.8516, -13.3686, -11.2549,  -8.4466,
          -6.4885,  -6.5071, -16.3503, -14.5905, -11.8353,  -9.9806,  -8.7323,
          -9.7071,  -8.2985, -13.0898,  -8.7130, -19.6752, -13.4941, -11.8880,
          -9.1840, -13.9102, -17.0178, -13.1194, -11.5825,  -9.4896, -17.9141,
          -9.5629, -16.5170, -13.3171, -12.2587, -16.4166, -18.3653,  -9.4935,
         -13.2763, -17.6866, -13.7130, -14.7392, -14.6677, -11.6655, -10.5128,
         -12.3555,  -7.9406, -14.8172,  -6.5830, -10.8367, -12.2928, -12.8076,
         -17.0396, -10.1280, -13.5942,  -6.7088, -11.5467,  -7.6642,  -5.2480,
         -15.8393, -15.2643, -11.4280, -10.1169, -13.1539, -14.3218, -15.0127,
         -15.6476, -11.6592, -10.1635, -16.6752, -11.6984,  -8.3909, -15.6058,
          -7.2380, -10.5447,  -9.6789,  -7.1109, -10.7635, -12.4095, -14.6815,
         -10.0865,  -6.6754,  -6.9112, -17.0186, -13.2917,  -8.1588, -10.9359,
         -11.8030, -11.2323, -11.5271,  -9.3658, -12.8177, -13.1180, -12.1684,
         -10.8933, -15.0315, -12.2008, -15.7332,  -7.6037,  -7.8880, -15.7726,
         -12.9173,  -9.6909, -11.2190, -13.0147, -11.9097, -15.3206, -14.9645,
         -10.8827, -13.4233,  -8.2616,  -6.5764,  -9.2068,  -1.1722,   4.6309,
         -13.8240,  -8.8144, -17.5671, -16.4612,  -8.7952, -13.9231, -13.7120,
         -12.9466,  -9.8612, -14.6987, -13.3159,  -9.9359,  -7.8293,  -7.9961,
         -14.1366, -12.0898, -12.5732,  -6.7362,  -9.3175, -12.3335,  -7.4365,
         -24.8273, -15.5383],
        [-13.3796, -18.8920, -19.0383,  -9.2321, -14.1503, -15.0578, -13.0850,
          -9.3543, -13.7065,  -7.2419, -10.7370,  -6.9521,  -5.2882, -14.8258,
          -5.8944, -12.0897, -18.6506, -12.8319, -11.5861,  -9.4032,  -5.4389,
           8.1403,  -5.5555, -12.4493, -10.4225,  -6.8148, -11.7280, -10.2866,
         -11.1347,  -1.1614, -13.3300,  -3.8216, -18.1594, -11.4256, -13.6750,
          -6.9662, -13.0039, -11.6008, -10.9354,  -9.6054,  -9.9303,  -8.1292,
         -12.4365, -15.6898, -10.1333, -11.0957, -12.9833, -11.9187,  -5.8141,
         -13.3899, -11.8386, -14.2317,  -9.3236,  -8.3747,  -8.4686, -12.7803,
          -9.0560,  -7.8696, -14.1533,  -6.9840,  -3.8049,  -8.8678, -14.3514,
          -7.8895, -11.4716, -12.4721,  -8.6673,  -6.8360,  -7.3675,  -6.8125,
         -11.4952, -10.9705,  -6.5105,  -9.7348, -14.7709, -11.3191, -14.1895,
         -12.8245,  -9.9535, -10.8555, -11.0447,  -8.4326,  -9.1383, -13.7896,
         -11.9255,  -7.1484, -14.5495,  -9.2247, -13.1285, -14.1737,  -8.4069,
          -7.5965,  -3.4051,  -7.6866, -10.3734,  -7.7260,  -5.2773,  -8.5044,
         -13.3983, -12.9986, -10.5447, -10.1205, -15.6450,  -9.7553, -11.2016,
         -11.3979, -11.5908, -10.2518, -12.1675,  -4.6960, -11.4753, -12.7572,
         -13.6541, -11.3222, -11.3799, -11.1025,  -9.1730, -13.7514, -15.8559,
          -8.6237, -12.6485,  -4.2751,  -7.7025, -12.9300,  -7.0018,  -8.2598,
          -8.8410,  -5.0144, -15.6619,  -9.5564,  -5.5501, -10.5789,  -5.6651,
          -5.5720, -11.6338, -12.4811,  -8.4071, -11.0972,  -9.6168,  -6.4310,
          -8.6936, -10.1972, -12.1732,  -0.8298,  -7.8869, -15.4384,  -9.2720,
         -21.3815, -12.4833]], device='cuda:0', grad_fn=<AddmmBackward0>)

 

모델을 사용하여 예측을 수행하고, 예측 결과에서 상위 K개의 클래스 확률과 인덱스를 출력


       
        probs, idx = torch.topk(model(y_prob),k=3)
        print(probs)
        print(idx)


tensor([[9.9668e-01, 3.0082e-03, 6.9708e-05],
        [9.9973e-01, 1.2713e-04, 9.1248e-05]], device='cuda:0',
       grad_fn=<TopkBackward0>)
tensor([[125, 124,  12],
        [ 21, 143,  29]], device='cuda:0')

 

예측결과 시각화


       


          fig, axes = plt.subplots(1, 2, figsize=(15, 6))

          axes[0].set_title('{:.2f}% {}, {:.2f}% {}, {:.2f}% {}'.format(
              probs[0, 0] * 100,
              image_datasets['validation'].classes[idx[0, 0]],
              probs[0, 1] * 100,
              image_datasets['validation'].classes[idx[0, 1]],
              probs[0, 2] * 100,
              image_datasets['validation'].classes[idx[0, 2]],
          ))
          axes[0].imshow(img1)
          axes[0].axis('off')

          axes[1].set_title('{:.2f}% {}, {:.2f}% {}, {:.2f}% {}'.format(
              probs[1, 0] * 100,
              image_datasets['validation'].classes[idx[1, 0]],
              probs[1, 1] * 100,
              image_datasets['validation'].classes[idx[1, 1]],
              probs[1, 2] * 100,
              image_datasets['validation'].classes[idx[1, 2]],
          ))
          axes[1].imshow(img2)
          axes[1].axis('off')



 

'AI > 딥러닝' 카테고리의 다른 글

09. 전이학습  (0) 2024.06.21
08. 간단한 CNN 모델 만들기  (0) 2024.06.20
07. CNN 기초  (0) 2024.06.20
06. 비선형 활성화 함수  (0) 2024.06.20
05. 딥러닝  (0) 2024.06.20