Keras를 사용하여 간단한 CNN 기반 이미지 분류기를 구축하는 방법


  • 1 Introduction
  • 1.1 What are Convolutional neural networks (CNN)?
  • 1.2 The classification task
  • 1.3 Problem statement
  • 1.4 Approach


  • 2 The Code
  • 2.1 Import data manipulation packages
  • 2.2 Load in the Data
  • 2.3 Build the CNN
  • 2.4 Training the CNN
  • 2.4 Evaluating the performance of CNN
  • 2.5 Saving the model
  • 2.6 Load the model
  • 2.7 Create Sample Submission

  • Related articles
  • References

  • 1. 소개



    1.1 합성곱 신경망(CNN)이란 무엇입니까?



    합성곱 신경망(CNN)은 다양한 컴퓨터 비전 작업에서 지배적인 인공 신경망의 한 종류로 다양한 영역에서 관심을 끌고 있습니다.

    컨볼루션 신경망은 컨볼루션 레이어, 풀링 레이어, 완전 연결 레이어와 같은 여러 빌딩 블록으로 구성되며 역전파 알고리즘을 통해 피처의 공간 계층을 자동으로 적응적으로 학습하도록 설계되었습니다.

    CNN은 이미지 분류, 객체 감지, 이미지 인식 등과 같은 컴퓨터 비전 작업에서 잘 작동합니다. 유사한 작업에 사용되는 다른 신경망에는 순환 신경망(RNN), 장단기 기억(LSTM), 인공 신경망(ANN) 등이 있습니다. .,

    1.2 분류 작업



    이 기사에서는 HP Unlocked challenge 문제를 해결하려고 합니다. 4번째 도전입니다.

    위 웹사이트에서 데이터를 가져오거나 내GitHub repository - Unlocked_Challenge_4에서 파일을 포크하여 직접 작업을 시작할 수도 있습니다.

    1.3 문제 설명



    이것은 이진 분류 작업입니다. 과제는 꽃의 일종인 "La Eterna"의 이미지를 분류하는 기계 학습 모델을 구축하는 것입니다.

    1.4 접근



    CNN 모델을 사용하여 기준 점수를 얻습니다. 초기 분석에서 데이터 세트는 딥 러닝 작업에 비해 매우 작습니다. 데이터 세트를 늘리기 위해 몇 가지image augmentations를 수행하겠습니다.

    또한 다음 기사에서 VGG19를 사용한 하이퍼파라미터 튜닝 및 전이 학습에 대해서도 살펴볼 것입니다.

    2 코드



    코드를 실행하기 위해 로컬 Jupyter 랩 인스턴스를 사용했습니다. 최종 코드는 here에서 찾을 수 있습니다.

    2.1 데이터 조작 패키지 가져오기




    import pandas as pd
    import numpy as np
    import os
    import cv2
    import matplotlib.pyplot as plt
    import warnings
    



    import tensorflow as tf
    from tensorflow import keras
    from tensorflow.keras import layers
    from tensorflow.keras.preprocessing.image import ImageDataGenerator
    import keras_tuner as kt
    



    print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
    



    # Set the seed value for experiment reproducibility.
    seed = 1842
    tf.random.set_seed(seed)
    np.random.seed(seed)
    # Turn off warnings for cleaner looking notebook
    warnings.simplefilter('ignore')
    


    2.2 데이터 로드




    #define image dataset
    # Data Augmentation
    image_generator = ImageDataGenerator(
            rescale=1/255,
            rotation_range=10, # rotation
            width_shift_range=0.2, # horizontal shift
            height_shift_range=0.2, # vertical shift
            zoom_range=0.2, # zoom
            horizontal_flip=True, # horizontal flip
            brightness_range=[0.2,1.2],# brightness
            validation_split=0.2,)
    
    #Train & Validation Split
    train_dataset = image_generator.flow_from_directory(batch_size=32,
                                                     directory='data_cleaned/Train',
                                                     shuffle=True,
                                                     target_size=(224, 224),
                                                     subset="training",
                                                     class_mode='categorical')
    
    validation_dataset = image_generator.flow_from_directory(batch_size=32,
                                                     directory='data_cleaned/Train',
                                                     shuffle=True,
                                                     target_size=(224, 224),
                                                     subset="validation",
                                                     class_mode='categorical')
    
    #Organize data for our predictions
    image_generator_submission = ImageDataGenerator(rescale=1/255)
    submission = image_generator_submission.flow_from_directory(
                                                     directory='data_cleaned/scraped_images',
                                                     shuffle=False,
                                                     target_size=(224, 224),
                                                     class_mode=None)
    


    2.3 CNN 구축



    네트워크 생성 방법에 대해 걱정하지 마십시오. 하이퍼파라미터 튜닝을 사용하여 레이어를 더 잘 튜닝하고 안정적인 네트워크를 얻을 수 있습니다. 다음 기사에서 다루겠습니다. 이에 대해 자세히 알아보려면 Keras documentation을 확인하십시오.

    입력 및 출력 모양을 엉망으로 만들지 마십시오. 여기서 입력 형태는 (224, 224, 3)입니다. 이미지의 높이, 너비 및 채널이 각각 224, 224 및 3임을 의미합니다(3은 컬러 이미지의 빨강, 녹색 및 파랑 채널임).

    model = keras.models.Sequential([
        keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape = [224, 224,3]),
        keras.layers.MaxPooling2D(),
        keras.layers.Conv2D(64, (2, 2), activation='relu'),
        keras.layers.MaxPooling2D(),
        keras.layers.Conv2D(64, (2, 2), activation='relu'),
        keras.layers.Flatten(),
        keras.layers.Dense(100, activation='relu'),
        keras.layers.Dense(2, activation ='softmax')
    ])
    


    이제 준비된 모델을 컴파일할 수 있습니다. 훈련을 조기에 중지하기 위해 콜백도 사용하고 있습니다. 이 경우 유효성 검사 손실이 동일하거나 3 epoch 이상 증가하면 콜백이 트리거됩니다.

    model.compile(optimizer='adam',
                 loss = 'binary_crossentropy',
                 metrics=['accuracy'])
    
    callback = keras.callbacks.EarlyStopping(monitor='val_loss',
                                                patience=3,
                                                restore_best_weights=True)
    


    2.4 CNN 교육




    model.fit(train_dataset, epochs=20, validation_data=validation_dataset, callbacks=callback)
    


    2.4 CNN 성능 평가




    loss, accuracy = model.evaluate(validation_dataset)
    print("Loss: ", loss)
    print("Accuracy: ", accuracy)
    


    2.5 모델 저장




    model.save('cnn-model')
    


    2.6 모델 불러오기




    model = keras.models.load_model('cnn-model')
    


    2.7 샘플 제출물 만들기




    onlyfiles = [f.split('.')[0] for f in os.listdir(os.path.join('data_cleaned/scraped_images/image_files')) if os.path.isfile(os.path.join(os.path.join('data_cleaned/scraped_images/image_files'), f))]
    submission_df = pd.DataFrame(onlyfiles, columns =['images'])
    submission_df[['la_eterna', 'other_flower']] = model.predict(submission)
    submission_df.head()
    



    submission_df.to_csv('submission_file.csv', index = False)
    


    관련 기사


  • Installing TensorFlow on M1 MacBook Air with GPU (Metal)
  • Run Ubuntu on M1 Macbook Air using UTM
  • Mushroom dataset analysis and classification in python
  • How to open sublime text from the windows command line

  • 참조


  • https://www.tensorflow.org/tutorials/images/data_augmentation
  • https://www.analyticsvidhya.com/blog/2020/02/learn-image-classification-cnn-convolutional-neural-networks-3-datasets/

  • 좋은 웹페이지 즐겨찾기