de
                    array(2) {
  ["de"]=>
  array(13) {
    ["code"]=>
    string(2) "de"
    ["id"]=>
    string(1) "3"
    ["native_name"]=>
    string(7) "Deutsch"
    ["major"]=>
    string(1) "1"
    ["active"]=>
    string(1) "1"
    ["default_locale"]=>
    string(5) "de_DE"
    ["encode_url"]=>
    string(1) "0"
    ["tag"]=>
    string(2) "de"
    ["missing"]=>
    int(0)
    ["translated_name"]=>
    string(7) "Deutsch"
    ["url"]=>
    string(98) "https://www.statworx.com/content-hub/blog/car-model-classification-1-transfer-learning-mit-resnet/"
    ["country_flag_url"]=>
    string(87) "https://www.statworx.com/wp-content/plugins/sitepress-multilingual-cms/res/flags/de.png"
    ["language_code"]=>
    string(2) "de"
  }
  ["en"]=>
  array(13) {
    ["code"]=>
    string(2) "en"
    ["id"]=>
    string(1) "1"
    ["native_name"]=>
    string(7) "English"
    ["major"]=>
    string(1) "1"
    ["active"]=>
    int(0)
    ["default_locale"]=>
    string(5) "en_US"
    ["encode_url"]=>
    string(1) "0"
    ["tag"]=>
    string(2) "en"
    ["missing"]=>
    int(0)
    ["translated_name"]=>
    string(8) "Englisch"
    ["url"]=>
    string(102) "https://www.statworx.com/en/content-hub/blog/car-model-classification-1-transfer-learning-with-resnet/"
    ["country_flag_url"]=>
    string(87) "https://www.statworx.com/wp-content/plugins/sitepress-multilingual-cms/res/flags/en.png"
    ["language_code"]=>
    string(2) "en"
  }
}
                    
Kontakt
Content Hub
Blog Post

Car Model Classification I: Transfer Learning mit ResNet

  • Expert:innen Stephan Müller
  • Datum 05. Mai 2021
  • Thema CodingDeep Learning
  • Format Blog
  • Kategorie Technology
Car Model Classification I: Transfer Learning mit ResNet

Deep Learning ist eines der Themen im Bereich der künstlichen Intelligenz, die uns bei STATWORX besonders faszinieren. In dieser Blogserie möchten wir veranschaulichen, wie ein End-to-end Deep Learning Projekt implementiert werden kann. Dabei verwenden wir die TensorFlow 2.x Bibliothek für die Implementierung.

Die Themen der 4-teiligen Blogserie umfassen:

  • Transfer Learning für Computer Vision
  • Deployment über TensorFlow Serving
  • Interpretierbarkeit von Deep-Learning-Modellen mittels Grad-CAM
  • Integration des Modells in ein Dashboard

Im ersten Teil zeigen wir, wie man Transfer Learning nutzen kann, um die Marke eines Autos mittels Bildklassifizierung vorherzusagen. Wir beginnen mit einem kurzen Überblick über Transfer Learning und das ResNet und gehen dann auf die Details der Implementierung ein. Der vorgestellte Code ist in diesem Github Repository zu finden.

Einführung: Transfer Learning & ResNet

Was ist Transfer Learning?

Beim traditionellen (Machine) Learning entwickeln wir ein Modell und trainieren es auf neuen Daten für jede neue Aufgabe, die ansteht. Transfer Learning unterscheidet sich von diesem Ansatz dadurch, dass das gesammelte Wissen von einer Aufgabe auf eine andere übertragen wird. Dieser Ansatz ist besonders nützlich, wenn einem zu wenige Trainingsdaten zur Verfügung stehen. Modelle, die für ein ähnliches Problem vortrainiert wurden, können als Ausgangspunkt für das Training neuer Modelle verwendet werden. Die vortrainierten Modelle werden als Basismodelle bezeichnet.

In unserem Beispiel kann ein Deep Learning-Modell, das auf dem ImageNet-Datensatz trainiert wurde, als Ausgangspunkt für die Erstellung eines Klassifikationsnetzwerks für Automodelle verwendet werden. Die Hauptidee hinter dem Transfer Learning für Deep Learning-Modelle ist, dass die ersten Layer eines Netzwerks verwendet werden, um wichtige High-Level-Features zu extrahieren, die für die jeweilige Art der behandelten Daten ähnlich bleiben. Die finalen Layer, auch „head“ genannt, des ursprünglichen Netzwerks werden durch einen benutzerdefinierten head ersetzt, der für das vorliegende Problem geeignet ist. Die Gewichte im head werden zufällig initialisiert, und das resultierende Netz kann für die spezifische Aufgabe trainiert werden.

Es gibt verschiedene Möglichkeiten, wie das Basismodell beim Training behandelt werden kann. Im ersten Schritt können seine Gewichte fixiert werden. Wenn der Lernfortschritt darauf schließen lässt, dass das Modell nicht flexibel genug ist, können bestimmte Layer oder das gesamte Basismodell auch mit trainiert werden. Ein weiterer wichtiger Aspekt, den es zu beachten gilt, ist, dass der Input die gleiche Dimensionalität haben muss wie die Daten, auf denen das Basismodell initial trainiert wurde – sofern die ersten Layer des Basismodells festgehalten werden sollen.

image-20200319174208670

Als nächstes stellen wir kurz das ResNet vor, eine beliebte und leistungsfähige CNN-Architektur für Bilddaten. Anschließend zeigen wir, wie wir Transfer Learning mit ResNet zur Klassifizierung von Automodellen eingesetzt haben.

Was ist ResNet?

Das Training von Deep Neural Networks kann aufgrund des sogenannten Vanishing Gradient-Problems schnell zur Herausforderung werden. Aber was sind Vanishing Gradients? Neuronale Netze werden in der Regel mit Back-Propagation trainiert. Dieser Algorithmus nutzt die Kettenregel der Differentialrechnung, um Gradienten in tieferen Layern des Netzes abzuleiten, indem Gradienten aus früheren Layern multipliziert werden. Da Gradienten in Deep Networks wiederholt multipliziert werden, können sie sich während der Backpropagation schnell infinitesimal kleinen Werten annähern.

ResNet ist ein CNN-Netz, welches das Problem des Vanishing Gradients mit sogenannten Residualblöcken löst (eine gute Erklärung, warum sie ‚Residual‘ heißen, findest du hier). Im Residualblock wird die unmodifizierte Eingabe an das nächste Layer weitergereicht, indem sie zum Ausgang eines Layers addiert wird (siehe Abbildung rechts). Diese Modifikation sorgt dafür, dass ein besserer Informationsfluss von der Eingabe zu den tieferen Layers möglich ist. Die gesamte ResNet-Architektur ist im rechten Netzwerk in der linken Abbildung unten dargestellt. Weiter sind daneben ein klassisches CNN und das VGG-19-Netzwerk, eine weitere Standard-CNN-Architektur, abgebildet.

Resnet-Architecture_Residual-Block

ResNet hat sich als leistungsfähige Netzarchitektur für Bildklassifikationsprobleme erwiesen. Zum Beispiel hat ein Ensemble von ResNets mit 152 Layern den ILSVRC 2015 Bildklassifikationswettbewerb gewonnen. Im Modul tensorflow.keras.application sind vortrainierte ResNet-Modelle unterschiedlicher Größe verfügbar, nämlich ResNet50, ResNet101, ResNet152 und die entsprechenden zweiten Versionen (ResNet50V2, …). Die Zahl hinter dem Modellnamen gibt die Anzahl der Layer an, über die die Netze verfügen. Die verfügbaren Gewichte sind auf dem ImageNet-Datensatz vortrainiert. Die Modelle wurden auf großen Rechenclustern unter Verwendung von spezialisierter Hardware (z.B. TPU) über signifikante Zeiträume trainiert. Transfer Learning ermöglicht es uns daher, diese Trainingsergebnisse zu nutzen und die erhaltenen Gewichte als Ausgangspunkt zu verwenden.

Klassifizierung von Automodellen

Als anschauliches Beispiel für die Anwendung von Transfer Learning behandeln wir das Problem der Klassifizierung des Automodells anhand eines Bildes des Autos. Wir beginnen mit der Beschreibung des verwendeten Datensatzes und wie wir unerwünschte Beispiele aus dem Datensatz herausfiltern können. Anschließend gehen wir darauf ein, wie eine Datenpipeline mit tensorflow.data eingerichtet werden kann. Im zweiten Abschnitt werden wir die Modellimplementierung durchgehen und aufzeigen, auf welche Aspekte ihr beim Training und bei der Inferenz besonders achten müsst.

Datenvorbereitung

Wir haben den Datensatz aus diesem GitHub Repo verwendet – dort könnt ihr den gesamten Datensatz herunterladen. Der Autor hat einen Datascraper gebaut, um alle Autobilder von der Car Connection Website zu scrapen. Er erklärt, dass viele Bilder aus dem Innenraum der Autos stammen. Da sie im Datensatz nicht erwünscht sind, filtern wir sie anhand der Pixelfarbe heraus. Der Datensatz enthält 64’467 jpg-Bilder, wobei die Dateinamen Informationen über die Automarke, das Modell, das Baujahr usw. enthalten. Für einen detaillierteren Einblick in den Datensatz empfehlen wir euch, das originale GitHub Repo zu konsultieren. Hier sind drei Beispielbilder:

Car Collage 01

Bei der Betrachtung der Daten haben wir festgestellt, dass im Datensatz noch viele unerwünschte Bilder enthalten waren, z.B. Bilder von Außenspiegeln, Türgriffen, GPS-Panels oder Leuchten. Beispiele für unerwünschte Bilder sind hier zu sehen:

Car Collage 02

Daher ist es von Vorteil, die Daten zusätzlich vorzufiltern, um mehr unerwünschte Bilder zu entfernen.

Filtern unerwünschter Bilder aus dem Datensatz

Es gibt mehrere mögliche Ansätze, um Nicht-Auto-Bilder aus dem Datensatz herauszufiltern:

  1. Verwendung eines vortrainierten Modells
  2. Ein anderes Modell trainieren, um Auto/Nicht-Auto zu klassifizieren
  3. Trainieren eines Generative Networks auf einem Auto-Datensatz und Verwendung des Diskriminatorteil des Netzwerks

Wir haben uns für den ersten Ansatz entschieden, da er der direkteste ist und ausgezeichnete, vortrainierte Modelle leicht verfügbar sind. Wenn ihr den zweiten oder dritten Ansatz verfolgen wollt, könnt ihr z. B. diesen Datensatz verwenden, um das Modell zu trainieren. Dieser Datensatz enthält nur Bilder von Autos, ist aber deutlich kleiner als der von uns verwendete Datensatz.

Unsere Wahl fiel auf das ResNet50V2 im Modul tensorflow.keras.applications mit den vortrainierten „imagenet“-Gewichten. In einem ersten Schritt müssen wir jetzt die Indizes und Klassennamen der imagenet-Labels herausfinden, die den Autobildern entsprechen.

# Class labels in imagenet corresponding to cars
CAR_IDX = [656, 627, 817, 511, 468, 751, 705, 757, 717, 734, 654, 675, 864, 609, 436]

CAR_CLASSES = ['minivan', 'limousine', 'sports_car', 'convertible', 'cab', 'racer', 'passenger_car', 'recreational_vehicle', 'pickup', 'police_van', 'minibus', 'moving_van', 'tow_truck', 'jeep', 'landrover', 'beach_wagon']

Als nächstes laden wir das vortrainierte ResNet50V2-Modell.

from tensorflow.keras.applications import ResNet50V2

model = ResNet50V2(weights='imagenet')

Wir können dieses Modell nun verwenden, um die Bilder zu klassifizieren. Die Bilder, die der Vorhersagemethode zugeführt werden, müssen identisch skaliert sein wie die Bilder, die zum Training verwendet wurden. Die verschiedenen ResNet-Modelle werden auf unterschiedlich skalierten Bilddaten trainiert. Es ist daher wichtig, das richtige Preprocessing anzuwenden.

from tensorflow.keras.applications.resnet_v2 import preprocess_input

image = tf.io.read_file(filename)
image = tf.image.decode_jpeg(image)
image = tf.cast(image, tf.float32)
image = tf.image.resize_with_crop_or_pad(image, target_height=224, target_width=224)
image = preprocess_input(image)
predictions = model.predict(image)

Es gibt verschiedene Ideen, wie die erhaltenen Vorhersagen für die Autoerkennung verwendet werden können.

  • Ist eine der CAR_CLASSES unter den Top-k-Vorhersagen?
  • Ist die kumulierte Wahrscheinlichkeit der CAR_CLASSES in den Vorhersagen größer als ein definierter Schwellenwert?
  • Spezielle Behandlung unerwünschter Bilder (z. B. Erkennen und Herausfiltern von Rädern)?

Wir zeigen den Code für den Vergleich der kumulierten Wahrscheinlichkeitsmaße über die CAR_CLASSES.

def is_car_acc_prob(predictions, thresh=THRESH, car_idx=CAR_IDX):
    """
    Determine if car on image by accumulating probabilities of car prediction and comparing to threshold

    Args:
        predictions: (?, 1000) matrix of probability predictions resulting from ResNet with                                              imagenet weights
        thresh: threshold accumulative probability over which an image is considered a car
        car_idx: indices corresponding to cars

    Returns:
        np.array of booleans describing if car or not
    """
    predictions = np.array(predictions, dtype=float)
    car_probs = predictions[:, car_idx]
    car_probs_acc = car_probs.sum(axis=1)
    return car_probs_acc > thresh

Je höher der Schwellenwert eingestellt ist, desto strenger ist das Filterverfahren. Ein Wert für den Schwellenwert, der gute Ergebnisse liefert, ist THRESH = 0.1. Damit wird sichergestellt, dass nicht zu viele echte Bilder von Autos verloren gehen. Die Wahl eines geeigneten Schwellenwerts bleibt jedoch eine subjektive Angelegenheit.

Das Colab-Notebook, in dem die Funktion is_car_acc_prob zum Filtern des Datensatzes verwendet wird, ist im GitHub Repository verfügbar.

Bei der Abstimmung der Vorfilterung haben wir Folgendes beobachtet:

  • Viele der Autobilder mit hellem Hintergrund wurden als „Strandwagen“ klassifiziert. Wir haben daher entschieden, auch die Klasse „Strandwagen“ in imagenet als eine der CAR_CLASSES zu berücksichtigen.
  • Bilder, die die Front eines Autos zeigen, bekommen oft eine hohe Wahrscheinlichkeit der Klasse „Kühlergrill“ („grille“) zugeordnet, d.h. dem Gitter an der Front eines Autos, das zur Kühlung dient. Diese Zuordnung ist korrekt, führt aber dazu, dass die oben gezeigte Prozedur bestimmte Bilder von Autos nicht als Autos betrachtet, da wir „grille“ nicht in die CAR_CLASSES aufgenommen haben. Dieses Problem führt zu dem Kompromiss, entweder viele Nahaufnahmen von Autokühlergrills im Datensatz zu belassen oder einige Autobilder herauszufiltern. Wir haben uns für den zweiten Ansatz entschieden, da er einen saubereren Datensatz ergibt.

Nach der Vorfilterung der Bilder mit dem vorgeschlagenen Verfahren verbleiben zunächst 53’738 von 64’467 im Datensatz.

Übersicht über die endgültigen Datensätze

Der vorgefilterte Datensatz enthält Bilder von 323 Automodellen. Wir haben uns dazu entschieden, unsere Aufmerksamkeit auf die 300 häufigsten Klassen im Datensatz zu reduzieren. Das ist deshalb sinnvoll, da einige der am wenigsten häufigen Klassen weniger als zehn Repräsentanten haben und somit nicht sinnvoll in ein Trainings-, Validierungs- und Testset aufgeteilt werden können. Reduziert man den Datensatz auf die Bilder der 300 häufigsten Klassen, erhält man einen Datensatz mit 53.536 beschrifteten Bildern. Die Klassenvorkommen sind wie folgt verteilt:

Histogram

Die Anzahl der Bilder pro Klasse (Automodell) reicht von 24 bis knapp unter 500. Wir können sehen, dass der Datensatz sehr unausgewogen ist. Dies muss beim Training und bei der Auswertung des Modells unbedingt beachtet werden.

Aufbau von Datenpipelines mit tf.data

Selbst nach der Vorfilterung und der Reduktion auf die besten 300 Klassen bleiben immer noch zahlreiche Bilder übrig. Dies stellt ein potenzielles Problem dar, da wir nicht einfach alle Bilder auf einmal in den Speicher unserer GPU laden können. Um dieses Problem zu lösen, werden wir tf.data verwenden.

Mit tf.data und insbesondere der tf.data.Dataset API lassen sich elegante und gleichzeitig sehr effiziente Eingabe-Pipelines erstellen. Die API enthält viele allgemeine Methoden, die zum Laden und Transformieren potenziell großer Datensätze verwendet werden können. Die Methode tf.data.Dataset ist besonders nützlich, wenn Modelle auf GPU(s) trainiert werden. Es ermöglicht das Laden von Daten von der Festplatte, wendet on-the-fly Transformationen an und erstellt Batches, die dann an die GPU gesendet werden. Und das alles geschieht so, dass die GPU nie auf neue Daten warten muss.

Die folgenden Funktionen erstellen eine <code>tf.data.Dataset-Instanz für unseren konkreten Anwendungsfall:

def construct_ds(input_files: list,
                 batch_size: int,
                 classes: list,
                 label_type: str,
                 input_size: tuple = (212, 320),
                 prefetch_size: int = 10,
                 shuffle_size: int = 32,
                 shuffle: bool = True,
                 augment: bool = False):
    """
    Function to construct a tf.data.Dataset set from list of files

    Args:
        input_files: list of files
        batch_size: number of observations in batch
        classes: list with all class labels
        input_size: size of images (output size)
        prefetch_size: buffer size (number of batches to prefetch)
        shuffle_size: shuffle size (size of buffer to shuffle from)
        shuffle: boolean specifying whether to shuffle dataset
        augment: boolean if image augmentation should be applied
        label_type: 'make' or 'model'

    Returns:
        buffered and prefetched tf.data.Dataset object with (image, label) tuple
    """
    # Create tf.data.Dataset from list of files
    ds = tf.data.Dataset.from_tensor_slices(input_files)

    # Shuffle files
    if shuffle:
        ds = ds.shuffle(buffer_size=shuffle_size)

    # Load image/labels
    ds = ds.map(lambda x: parse_file(x, classes=classes, input_size=input_size,                                                                                                                                        label_type=label_type))

    # Image augmentation
    if augment and tf.random.uniform((), minval=0, maxval=1, dtype=tf.dtypes.float32, seed=None, name=None) < 0.7:
        ds = ds.map(image_augment)

    # Batch and prefetch data
    ds = ds.batch(batch_size=batch_size)
    ds = ds.prefetch(buffer_size=prefetch_size)

    return ds

Wir werden nun die verwendeten tf.data-Methoden beschreiben:

  • from_tensor_slices() ist eine der verfügbaren Methoden für die Erstellung eines Datensatzes. Der erzeugte Datensatz enthält Slices des angegebenen Tensors, in diesem Fall die Dateinamen.
  • Als nächstes betrachtet die Methode shuffle() jeweils buffer_size-Elemente separat und mischt diese Elemente isoliert vom Rest des Datensatzes. Wenn das Mischen des gesamten Datensatzes erforderlich ist, muss buffer_size größer sein als die Anzahl der Einträge im Datensatz. Das Mischen wird nur durchgeführt, wenn shuffle=True gesetzt ist.
  • Mit map() lassen sich beliebige Funktionen auf den Datensatz anwenden. Wir haben eine Funktion parse_file() erstellt, die im GitHub Repo zu finden ist. Sie ist verantwortlich für das Lesen und die Größenänderung der Bilder, das Ableiten der Beschriftungen aus dem Dateinamen und die Kodierung der Beschriftungen mit einem One-Hot-Encoder. Wenn die Flag „augment“ gesetzt ist, wird das Verfahren zur Datenerweiterung aktiviert. Die Augmentierung wird nur in 70 % der Fälle angewendet, da es von Vorteil ist, das Modell auch auf nicht modifizierten Bildern zu trainieren. Die in image_augment verwendeten Augmentierungstechniken sind Flipping, Helligkeits- und Kontrastanpassungen.
  • Schließlich wird die Methode batch() verwendet, um den Datensatz in Batches der Größe batch_size zu gruppieren, und die Methode prefetch() ermöglicht die Vorbereitung späterer Batches, während der aktuelle Batch verarbeitet wird, und verbessert so die Leistung. Wenn die Methode nach einem Aufruf von batch() verwendet wird, werden prefetch_size-Batches vorab geholt.

Fine Tuning des Modells

Nachdem wir unsere Eingabe-Pipeline definiert haben, wenden wir uns nun dem Trainingsteil des Modells zu. Der Code unten zeigt auf, wie ein Modell basierend auf dem vortrainierten ResNet instanziiert werden kann:

from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D


class TransferModel:

    def __init__(self, shape: tuple, classes: list):
        """
        Class for transfer learning from ResNet

        Args:
            shape: Input shape as tuple (height, width, channels)
            classes: List of class labels
        """
        self.shape = shape
        self.classes = classes
        self.history = None
        self.model = None

        # Use pre-trained ResNet model
        self.base_model = ResNet50V2(include_top=False,
                                     input_shape=self.shape,
                                     weights='imagenet')

        # Allow parameter updates for all layers
        self.base_model.trainable = True

        # Add a new pooling layer on the original output
        add_to_base = self.base_model.output
        add_to_base = GlobalAveragePooling2D(data_format='channels_last', name='head_gap')(add_to_base)

        # Add new output layer as head
        new_output = Dense(len(self.classes), activation='softmax', name='head_pred')(add_to_base)

        # Define model
        self.model = Model(self.base_model.input, new_output)

Ein paar weitere Details zum oben stehenden Code:

  • Wir erzeugen zunächst eine Instanz der Klasse tf.keras.applications.ResNet50V2. Mit include_top=False weisen wir das vortrainierte Modell an, den ursprünglichen head des Modells (in diesem Fall für die Klassifikation von 1000 Klassen auf ImageNet ausgelegt) wegzulassen.
  • Mit base_model.trainable = True werden alle Layer trainierbar.
  • Mit der funktionalen API tf.keras stapeln wir dann ein neues Pooling-Layer auf den letzten Faltungsblock des ursprünglichen ResNet-Modells. Dies ist ein notwendiger Zwischenschritt, bevor die Ausgabe an die endgültigen Klassifizierungs-Layer weitergeleitet wird.
  • Die endgültigen Klassifizierungs-Layer wird dann mit „tf.keras.layers.Dense“ definiert. Wir definieren die Anzahl der Neuronen so, dass sie gleich der Anzahl der gewünschten Klassen ist. Und die Softmax-Aktivierungsfunktion sorgt dafür, dass die Ausgabe eine Pseudowahrscheinlichkeit im Bereich von (0,1] ist.

Die Vollversion von TransferModel (s. GitHub) enthält auch die Option, das Basismodell durch ein VGG16-Netzwerk zu ersetzen, ein weiteres Standard-CNN für die Bildklassifikation. Außerdem erlaubt es, nur bestimmte Layer freizugeben, d.h. wir können die entsprechenden Parameter trainierbar machen, während wir die anderen festgehalten werden. Standardmässig haben wir hier alle Parameter trainierbar gemacht.

Nachdem wir das Modell definiert haben, müssen wir es für das Training konfigurieren. Dies kann mit der compile()-Methode von tf.keras.Model gemacht werden:

def compile(self, **kwargs):
      """
    Compile method
    """
    self.model.compile(**kwargs)

Wir übergeben dann die folgenden Keyword-Argumente an unsere Methode:

  • loss = "categorical_crossentropy" für die Mehrklassen-Klassifikation,
  • optimizer = Adam(0.0001) für die Verwendung des Adam-Optimierers aus tf.keras.optimizers mit einer relativ kleinen Lernrate (mehr zur Lernrate weiter unten), und
  • metrics = ["categorical_accuracy"] für die Trainings- und Validierungsüberwachung.

Als Nächstes wollen wir uns das Trainingsverfahren ansehen. Dazu definieren wir eine train-Methode für unsere oben vorgestellte TransferModel-Klasse:

from tensorflow.keras.callbacks import EarlyStopping

def train(self,
          ds_train: tf.data.Dataset,
          epochs: int,
          ds_valid: tf.data.Dataset = None,
          class_weights: np.array = None):
    """
    Trains model in ds_train with for epochs rounds

    Args:
        ds_train: training data as tf.data.Dataset
        epochs: number of epochs to train
        ds_valid: optional validation data as tf.data.Dataset
        class_weights: optional class weights to treat unbalanced classes

    Returns
        Training history from self.history
    """

    # Define early stopping as callback
    early_stopping = EarlyStopping(monitor='val_loss',
                                   min_delta=0,
                                   patience=12,
                                   restore_best_weights=True)

    callbacks = [early_stopping]

    # Fitting
    self.history = self.model.fit(ds_train,
                                  epochs=epochs,
                                  validation_data=ds_valid,
                                  callbacks=callbacks,
                                  class_weight=class_weights)

    return self.history

Da unser Modell eine Instanz von tensorflow.keras.Model ist, können wir es mit der Methode fit trainieren. Um Overfitting zu verhindern, wird Early Stopping verwendet, indem es als Callback-Funktion an die fit-Methode übergeben wird. Der patience-Parameter kann eingestellt werden, um festzulegen, wie schnell das Early Stopping angewendet werden soll. Der Parameter steht für die Anzahl der Epochen, nach denen, wenn keine Abnahme des Validierungsverlustes registriert wird, das Training abgebrochen wird. Weiterhin können Klassengewichte an die Methode fit übergeben werden. Klassengewichte erlauben es, unausgewogene Daten zu behandeln, indem den verschiedenen Klassen unterschiedliche Gewichte zugewiesen werden, wodurch die Wirkung von Klassen mit weniger Trainingsbeispielen erhöht werden kann.

Wir können den Trainingsprozess mit einem vortrainierten Modell wie folgt beschreiben: Da die Gewichte im head zufällig initialisiert werden und die Gewichte des Basismodells vortrainiert sind, setzt sich das Training aus dem Training des heads von Grund auf und der Feinabstimmung der Gewichte des vortrainierten Modells zusammen. Es wird generell für Transfer Learning empfohlen, eine kleine Lernrate zu verwenden (z. B. 1e-4), da eine zu große Lernrate die nahezu optimalen vortrainierten Gewichte des Basismodells zerstören kann.

Der Trainingsvorgang kann beschleunigt werden, indem zunächst einige Epochen lang trainiert wird, ohne dass das Basismodell trainierbar ist. Der Zweck dieser ersten Epochen ist es, die Gewichte des heads an das Problem anzupassen. Dies beschleunigt das Training, da wenn nur der head trainiert wird, viel weniger Parameter trainierbar sind und somit für jeden Batch aktualisiert werden. Die resultierenden Modellgewichte können dann als Ausgangspunkt für das Training des gesamten Modells verwendet werden, wobei das Basismodell trainierbar ist. Für das hier betrachtete Autoklassifizierungsproblem führte die Anwendung dieses zweistufigen Trainings zu keiner nennenswerten Leistungsverbesserung.

Evaluation/Vorhersage der Modell Performance

Bei der Verwendung der API tf.data.Dataset muss man auf die Art der verwendeten Methoden achten. Die folgende Methode in unserer Klasse TransferModel kann als Vorhersagemethode verwendet werden.

def predict(self, ds_new: tf.data.Dataset, proba: bool = True):
    """
    Predict class probs or labels on ds_new
    Labels are obtained by taking the most likely class given the predicted probs

    Args:
        ds_new: New data as tf.data.Dataset
        proba: Boolean if probabilities should be returned

    Returns:
        class labels or probabilities
    """

    p = self.model.predict(ds_new)

    if proba:
        return p
    else:
        return [np.argmax(x) for x in p]

Es ist wichtig, dass der Datensatz ds_new nicht gemischt wird, sonst stimmen die erhaltenen Vorhersagen nicht mit den erhaltenen Bildern überein, wenn ein zweites Mal über den Datensatz iteriert wird. Dies ist der Fall, da die Flag reshuffle_each_iteration in der Implementierung der Methode shuffle standardmäßig auf True gesetzt ist. Ein weiterer Effekt des Shufflens ist, dass mehrere Aufrufe der Methode take nicht die gleichen Daten zurückgeben. Dies ist wichtig, wenn z. B. Vorhersagen für nur eine Charge überprüft werden sollen. Ein einfaches Beispiel, an dem dies zu sehen ist, ist:

# Use construct_ds method from above to create a shuffled dataset
ds = construct_ds(..., shuffle=True)

# Take 1 batch (e.g. 32 images) of dataset: This returns a new dataset
ds_batch = ds.take(1)

# Predict labels for one batch
predictions = model.predict(ds_batch)

# Predict labels again: The result will not be the same as predictions above due to shuffling
predictions_2 = model.predict(ds_batch)

Eine Funktion zum Plotten von Bildern, die mit den entsprechenden Vorhersagen beschriftet sind, könnte wie folgt aussehen:

def show_batch_with_pred(model, ds, classes, rescale=True, size=(10, 10), title=None):
      for image, label in ds.take(1):
        image_array = image.numpy()
        label_array = label.numpy()
        batch_size = image_array.shape[0]
        pred = model.predict(image, proba=False)
        for idx in range(batch_size):
            label = classes[np.argmax(label_array[idx])]
            ax = plt.subplot(np.ceil(batch_size / 4), 4, idx + 1)
            if rescale:
                plt.imshow(image_array[idx] / 255)
            else:
                plt.imshow(image_array[idx])
            plt.title("label: " + label + "n" 
                      + "prediction: " + classes[pred[idx]], fontsize=10)
            plt.axis('off')

Die Methode show_batch_with_pred funktioniert auch für gemischte Datensätze, da image und label demselben Aufruf der Methode take entsprechen.

Die Auswertung der Model-Performance kann mit der Methode evaluate von keras.Model durchgeführt werden.

Wie akkurat ist unser finales Modell?

Das Modell erreicht eine kategoriale Genauigkeit von etwas über 70 % für die Vorhersage des Automodells für Bilder aus 300 Modellklassen. Um die Vorhersagen des Modells besser zu verstehen, ist es hilfreich, die Konfusionsmatrix zu betrachten. Unten ist die Heatmap der Vorhersagen des Modells für den Validierungsdatensatz abgebildet.

heatmap

Wir haben die Heatmap auf Einträge der Konfusionsmatrix in [0, 5] beschränkt, da das Zulassen einer weiteren Spanne keine Region außerhalb der Diagonalen signifikant hervorgehoben hat. Wie in der Heatmap zu sehen ist, wird eine Klasse den Beispielen fast aller Klassen zugeordnet. Das ist an der dunkelroten vertikalen Linie zwei Drittel rechts in der Abbildung oben zu erkennen.

Abgesehen von der zuvor erwähnten Klasse gibt es keine offensichtlichen Verzerrungen in den Vorhersagen. Wir möchten an dieser Stelle betonen, dass die Accuracy im Allgemeinen nicht ausreicht, um die Leistung eines Modells zufriedenstellend zu beurteilen, insbesondere im Fall unausgewogener Klassen.

Fazit und nächste Schritte

In diesem Blog-Beitrag haben wir Transfer Learning mit dem ResNet50V2 angewendet, um das Fahrzeugmodell anhand von Bildern von Autos zu klassifizieren. Unser Modell erreicht 70% kategoriale Genauigkeit über 300 Klassen.

Wir haben festgestellt, dass das Trainieren des gesamten Basismodells und die Verwendung einer kleinen Lernrate die besten Ergebnisse erzielen. Ein cooles Auto-Klassifikationsmodell entwickelt zu haben ist großartig, aber wie können wir unser Modell in einer produktiven Umgebung einsetzen? Natürlich könnten wir unsere eigene Modell-API mit Flask oder FastAPI bauen… Aber gibt es vielleicht sogar einen einfacheren, standardisierten Weg?

Im zweiten Beitrag unserer Blog-Serie, „Deployment von TensorFlow-Modellen in Docker mit TensorFlow Serving“ zeigen wir Euch, wie dieses Modell mit TensorFlow Serving bereitgestellt werden kann.

Stephan Müller Stephan Müller

Mehr erfahren!

Als eines der führenden Beratungs- und Entwicklungs­unternehmen für Data Science und KI begleiten wir Unternehmen in die datengetriebene Zukunft. Erfahre mehr über statworx und darüber, was uns antreibt.
Über uns