TensorFlow Multiclass Image Classification Using ImageDataGenerator
Python
An example of using TensorFlow for multiclass image classification, with image augmentation done through ImageDataGenerator. A pretrained EfficientNetB0 is used as a frozen base model. The image filenames and their labels are stored in CSV files that were already split into train, validation, and test sets.
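The code assumes each CSV has an `image_id` column (the filename of an image inside `data/images`) and a `label` column holding an integer class id. The sketch below uses made-up rows in place of the real files to show how such a pre-split dataset can be loaded with pandas, how the labels are converted to strings, and one cheap sanity check that no filename appears in more than one split. The inline CSV strings and the helper `check_disjoint` are hypothetical, not part of the original project.

```python
import io

import pandas as pd

# Hypothetical stand-ins for data/train.csv, data/validation.csv and data/test.csv
TRAIN_CSV = "image_id,label\na.jpg,0\nb.jpg,3\nc.jpg,1\n"
VAL_CSV = "image_id,label\nd.jpg,2\n"
TEST_CSV = "image_id,label\ne.jpg,4\nf.jpg,0\n"


def check_disjoint(*frames, col="image_id"):
    """Return the set of filenames that occur in more than one split."""
    seen, dupes = set(), set()
    for df in frames:
        ids = set(df[col])
        dupes |= seen & ids  # already seen in an earlier split
        seen |= ids
    return dupes


train = pd.read_csv(io.StringIO(TRAIN_CSV))
val = pd.read_csv(io.StringIO(VAL_CSV))
test = pd.read_csv(io.StringIO(TEST_CSV))

# flow_from_dataframe with class_mode='sparse' expects string labels
for df in (train, val, test):
    df["label"] = df["label"].astype(str)

print(check_disjoint(train, val, test))  # set() when the splits do not overlap
```

An empty set means no image leaks from one split into another; a non-empty result lists the offending filenames.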
```python
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from efficientnet.tfkeras import EfficientNetB0  # third-party `efficientnet` package
from sklearn.metrics import accuracy_score

tf.random.set_seed(101)

CLASSES = ['0', '1', '2', '3', '4']  # label values, as strings
EPOCHS = 25

IMG_DIM = 255
BATCH_SIZE = 64

# UNCOMMENT AND EDIT FOR ADDITIONAL AUGMENTATIONS
train_datagen = ImageDataGenerator(
    rescale=1./255,
    #horizontal_flip=True,
    #vertical_flip=True,
    #featurewise_std_normalization=True,  # requires train_datagen.fit(sample_images) first
    #rotation_range=180,
    #shear_range=2.0,
    #brightness_range=(0.8, 1.2),  # multiplicative factors; values near 0 give black images
    #zoom_range=0.25,              # zoom factor sampled from [0.75, 1.25]
    #channel_shift_range=30,
    #width_shift_range=0.2,
    #height_shift_range=0.2,
)

# Validation and test images are only rescaled, never augmented
test_datagen = ImageDataGenerator(
    rescale=1./255)

def build_model():
    # Pretrained EfficientNetB0 without its classification head, frozen as a feature extractor
    base_model = EfficientNetB0(input_shape=(IMG_DIM, IMG_DIM, 3),
                                include_top=False,
                                weights='imagenet')
    base_model.trainable = False

    model = Sequential([
        base_model,
        layers.GlobalAveragePooling2D(),  # (batch, h, w, c) -> (batch, c), so no Flatten is needed
        Dense(512, activation='relu'),
        Dropout(0.25),
        Dense(256, activation='relu'),
        Dropout(0.25),
        Dense(128, activation='relu'),
        Dropout(0.25),
        Dense(len(CLASSES), activation='softmax'),
    ])
    return model

train = pd.read_csv('data/train.csv')
val = pd.read_csv('data/validation.csv')
test = pd.read_csv('data/test.csv')
train = train.sample(frac=1).reset_index(drop=True)
val = val.sample(frac=1).reset_index(drop=True)

# flow_from_dataframe requires string labels when class_mode='sparse'
train['label'] = train['label'].astype(str)
val['label'] = val['label'].astype(str)
test['label'] = test['label'].astype(str)

train_generator = train_datagen.flow_from_dataframe(
    train,
    directory='data/images',
    x_col='image_id',   # IMAGE FILENAME IN TRAIN DF
    y_col='label',      # IMAGE TARGET IN TRAIN DF
    classes=CLASSES,    # same label-to-index mapping in every split
    target_size=(IMG_DIM, IMG_DIM),
    batch_size=BATCH_SIZE,
    class_mode='sparse')

val_generator = test_datagen.flow_from_dataframe(
    val,
    directory='data/images',
    x_col='image_id',   # IMAGE FILENAME IN VAL DF
    y_col='label',      # IMAGE TARGET IN VAL DF
    classes=CLASSES,
    shuffle=False,
    target_size=(IMG_DIM, IMG_DIM),
    batch_size=BATCH_SIZE,
    class_mode='sparse')

test_generator = test_datagen.flow_from_dataframe(
    test,
    directory='data/images',
    x_col='image_id',   # IMAGE FILENAME IN TEST DF
    y_col='label',      # IMAGE TARGET IN TEST DF
    classes=CLASSES,
    shuffle=False,      # keep predictions in dataframe order
    target_size=(IMG_DIM, IMG_DIM),
    batch_size=BATCH_SIZE,
    class_mode='sparse')

tf.keras.backend.clear_session()

model = build_model()

STEP_SIZE_TRAIN = train_generator.n // train_generator.batch_size
STEP_SIZE_VALID = val_generator.n // val_generator.batch_size

model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
              loss='sparse_categorical_crossentropy',
              metrics=['sparse_categorical_accuracy'])
# Stop when val_loss has not improved for 10 epochs and keep the best weights
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10,
                                              restore_best_weights=True)
history = model.fit(train_generator,
                    steps_per_epoch=STEP_SIZE_TRAIN,
                    validation_data=val_generator,
                    validation_steps=STEP_SIZE_VALID,
                    callbacks=[early_stop],
                    epochs=EPOCHS)

# MAKE PREDICTIONS FOR THE TEST SET AND CALCULATE ACCURACY
predictions = model.predict(test_generator, verbose=1)
predicted_idx = np.argmax(predictions, axis=-1)
# test_generator.classes holds the true class index of every test image, in
# dataframe order (shuffle=False) and with the same mapping as the model outputs
score = accuracy_score(test_generator.classes, predicted_idx)
print('Accuracy:', score)
```
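The evaluation step turns each row of softmax probabilities into a class index with `argmax` and compares it with the true index. The NumPy sketch below runs that computation on made-up probabilities, then reproduces (with a plain dict, not Keras itself) how `flow_from_dataframe` builds `class_indices` by sorting the label strings. With the five labels used here the index happens to equal the integer label, but with ten or more classes it does not, which is why comparing against the generator's own `classes` array is the safer choice. All arrays below are hypothetical.

```python
import numpy as np

# Hypothetical softmax output for 4 images and 3 classes (each row sums to 1)
probs = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.2, 0.3, 0.5],
    [0.6, 0.3, 0.1],
])
true_idx = np.array([0, 1, 1, 0])

pred_idx = np.argmax(probs, axis=-1)      # most probable class per image
accuracy = np.mean(pred_idx == true_idx)  # fraction of correct predictions

# Indices follow the sorted label *strings*: '10' sorts before '2'
labels = [str(i) for i in range(12)]
class_indices = dict(zip(sorted(labels), range(len(labels))))
index_to_label = {idx: int(lbl) for lbl, idx in class_indices.items()}

print(pred_idx.tolist(), accuracy)              # [0, 1, 2, 0] 0.75
print(class_indices["10"], class_indices["2"])  # 2 4
```

Inverting `class_indices` as in `index_to_label` maps predicted indices back to the original integer labels whenever the labels themselves are needed.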