Sign In
Free Sign Up
  • English
  • Español
  • 简体中文
  • Deutsch
  • 日本語
Sign In
Free Sign Up
  • English
  • Español
  • 简体中文
  • Deutsch
  • 日本語

Ein umfassender Leitfaden zum kontrastiven Lernen: Techniken, Modelle und Anwendungen

Traditionell kann maschinelles Lernen (ML) (opens new window) grob in zwei Arten unterteilt werden: überwachtes und unüberwachtes Lernen (opens new window). Beim überwachten Lernen haben wir einige Daten mit Labels, und diese Labels werden verwendet, um das Modell zu trainieren. Zum Beispiel werden mit Hilfe von (beschrifteten) Bildern verschiedener Objekte ein Bildklassifizierer trainiert.

Beim unüberwachten Lernen hingegen sind keine Labels erforderlich, und hier erkunden wir die Muster ohne jegliche vorherige Informationen. Es ist zweifellos schwierig, aber auch faszinierend. Mit Tonnen von verfügbaren Daten (und neuen Daten, die jeden Tag produziert werden), ist es immer einfach, an die Daten zu gelangen, obwohl das Labeln viel Zeit und Geld erfordert.

Um die riesigen Datenmengen um uns herum zu nutzen, gibt es eine fortgeschrittene Lernstrategie, die als selbstüberwachtes Lernen bekannt ist. Beim selbstüberwachten Lernen nehmen wir die unbeschrifteten Daten und ahmen das überwachte Lernen nach.

# Selbstüberwachtes Lernen (SSL)

Beim selbstüberwachten Lernen teilen wir die Daten in positive und negative Beispiele auf - ähnlich wie bei der binären (überwachten) Klassifizierung - indem wir das betrachtete Objekt als positives Beispiel und alle anderen Beispiele als negativ behandeln.

Methoden des selbstüberwachten Lernens lernen gemeinsame Einbettungen und können grob in zwei Arten unterteilt werden:

  • Kontrastive Methoden
  • Nicht-kontrastive Methoden

Bei kontrastiven Methoden nehmen wir verschiedene Beispiele der gleichen Daten, wie verschiedene Ansichten des gleichen Bildes, und versuchen, ihre Ähnlichkeitswerte zu maximieren, während wir sie für die anderen Beispiele/Bilder minimieren. Nicht-kontrastive Methoden berücksichtigen hingegen keine negativen Beispiele. Einige der bekannten SSL-Methoden wie BYOL oder DINO sind gute Beispiele für nicht-kontrastive Methoden, während SimCLR und MoCo auch negative Beispiele verwenden und gute Beispiele für kontrastive Methoden sind.

# Kontrastives Lernen (CL)

Das kontrastive Lernen dreht sich um das einfache Konzept, eine Repräsentation zu wählen, die die Ähnlichkeiten zwischen positiven Datenpaaren maximiert, während sie für negative Paare minimiert wird. Zum Beispiel gebe ich ein Bild einer Mango ein, und nun sollte das Ziel sein, die Ähnlichkeit zwischen Mangobildern zu maximieren, während sie für andere Bilder minimiert wird.

Die einfachste Einstellung besteht darin, jeden Datenpunkt (bei seiner Betrachtung) als positiven Datenpunkt zu betrachten, während alle anderen als negativen Punkt betrachtet werden. Nehmen wir an, der Punkt ist , dann sollte das Paar als positiv betrachtet werden und als negativ (wobei einen anderen Datenpunkt bezeichnet). Hier repräsentiert das Bild der Mango und die Repräsentation ihres Encoders (und ähnlich für andere Bildrepräsentationen).

# Training

Für jedes Paar von Datenpunkten berechnen wir den Ähnlichkeitswert. Und wir verwenden diesen Wert, um unser Modell zu trainieren, indem wir diesen Wert für die negativen Datenpunkte minimieren und maximieren, wenn . Wie bei anderen Optimierungsproblemen wird dies auch durch Auswahl einer geeigneten Verlustfunktion, wie der Kreuzentropieverlust, durchgeführt.

Um das Training praktischer zu gestalten, führen wir es in Batches durch. Angenommen, wir haben einen Batch mit 32 Bildern, dann muss 1 Bild (es selbst) einen maximierten Ähnlichkeitswert haben, während die anderen 31 so weit wie möglich einen minimierten Ähnlichkeitswert haben müssen.

Da es jedoch schwierig ist, anhand eines einzelnen positiven Beispiels gegen eine Reihe von negativen Beispielen diskriminative Merkmale zu lernen, verwenden wir einige clevere Techniken wie Datenaugmentation, um mehrere Kopien desselben Datenpunkts zu erstellen.

# Beispiele für CL

Das kontrastive Lernen gibt es schon seit einiger Zeit. In den letzten 10 Jahren hat es sich von der normalen CNN-basierten (opens new window) Bildunterscheidung (die sich nur darüber freute, SIFT zu übertreffen) bis hin zu CLIP weiterentwickelt. Einige der modernen CL-Algorithmen sind:

  • Kontrastive Vorhersagekodierung, CPC
  • Ein einfacher Rahmen für das kontrastive Lernen von (visuellen) Repräsentationen, SimCLR
  • Momentum-Kontrast, MoCo
  • Kontrastives Sprach-Bild-Vortraining, CLIP

Lassen Sie uns sie kurz durchgehen, um eine bessere Vorstellung davon zu bekommen, wie kontrastives Lernen praktisch angewendet wird.

# Kontrastive Vorhersagekodierung

Inspiriert von der klassischen Datenkompressionstechnik der Vorhersagekodierung und ihrer Anpassung in der Neurowissenschaft versucht CPC (Kontrastive Vorhersagekodierung) (opens new window), sich auf die hochrangigen Informationen in den Daten zu konzentrieren und das Niederpegelrauschen zu ignorieren.

CPC funktioniert im Wesentlichen wie folgt:

  • Hochdimensionale Daten werden in einen geeigneten latenten Einbettungsraum komprimiert. Diese Kompression erleichtert die Modellierung der Daten und die entsprechenden Vorhersagen.
  • Vorhersagen werden im gewählten Einbettungsraum gemacht.
  • Das Modell wird mit der Noise-Contrastive Estimation (NCE)-Verlustfunktion trainiert.

CPC

# SimCLR

Ein einfacher Rahmen für das kontrastive Lernen visueller Repräsentationen, SimCLR ist eine fortschrittliche kontrastive Lernmethode für Computer Vision. Ohne dass eine Vor-Augmentation oder spezialisierte Architektur erforderlich ist, funktioniert SimCLR wie folgt:

  • Ein Bild wird zufällig ausgewählt, und seine Ansichten (zwei in der Originalimplementierung) werden mit verschiedenen Augmentationstechniken wie zufälligem Zuschneiden, zufälliger Farbverzerrung oder Gaußscher Unschärfe erzeugt.
  • Die Bildrepräsentation/Einbettung wird mithilfe eines auf ResNet basierenden CNN berechnet.
  • Diese Repräsentation wird weiter in eine (nichtlineare) Projektion mithilfe eines MLP transformiert.
  • Sowohl das CNN als auch das MLP werden trainiert, um die kontrastive Verlustfunktion zu minimieren.

Bisher haben wir über die Notwendigkeit des unüberwachten Lernens gesprochen, aber wir haben auch einige beschriftete Daten zur Verfügung. Wenn wir das CNN am Ende auf einigen beschrifteten Bildern feinabstimmen, hilft dies, die Leistung und Verallgemeinerung auf verschiedene andere (nachgelagerte) Aufgaben zu verbessern.

Einige Erkenntnisse darüber, wie kontrastives Lernen funktioniert

SimCLR hat nicht nur ein neues Modell mit sehr guter Leistung eingeführt (Sie können das Papier für eine detaillierte Ergebnisanalyse überprüfen), sondern die Autoren haben auch einige neue Erkenntnisse gegeben, die für nahezu jede kontrastive Lernmethode nützlich sein können. Daher halte ich es für sinnvoll, sie hier zu teilen:

  • Eine Kombination von Augmentationstechniken ist entscheidend: Zufälliges Zuschneiden und Farbverzerrung ergaben keine herausragenden Ergebnisse, wenn sie einzeln verwendet wurden, aber wenn sie in Kombination verwendet wurden, erzielten sie die besten Ergebnisse.

  • Nichtlineare Projektion ist wichtig: Die komplexe Natur von neuronalen Netzwerken und der kontrastiven Verlustfunktion bedeutet, dass es schwer zu verstehen ist, was hinter den Kulissen passiert, aber empirisch ist klar, dass die nichtlineare Projektion (durch MLP) nützlich ist, da sie die Leistung um bis zu 10% erhöht. Diese Tatsache wird auch in der Veröffentlichung von MoCov2 unabhängig davon beobachtet, wie wir gleich sehen werden.

  • Skalierung verbessert die Leistung: Während einige der Beobachtungen spezifisch für SimCLR sind, sind sie hauptsächlich allgemein für kontrastives Lernen. Eine Erhöhung der Kapazität des Modells (entweder Breite oder Tiefe), eine Erhöhung der Batch-Größe oder sogar die Anzahl der Epochen führen alle zu einer Leistungssteigerung.

# Momentum-Kontrast (MoCo)

Momentum-Kontrast (MoCo) betrachtet das kontrastive Lernen als ein Wörterbuch-Lookup aus einer alternativen Perspektive. Dieser interessante Standpunkt, der einige Ähnlichkeiten mit den Transformer-Modellen aufweist, funktioniert wie folgt:

  • Datenaugmentation wird angewendet, um zwei Kopien, und , zu erzeugen.
  • Der Query-Encoder (links im Bild) nimmt auf und generiert Einbettungen.
  • Der Momentum-Encoder nimmt die andere augmentierte Kopie, , und generiert dynamisch ein Wörterbuch von Schlüsseln, .

Um es relevant zu machen, wird es als Warteschlange implementiert, die aktuelle Minibatches-Daten aufnimmt und das älteste (sobald die Minibatches größer als sind) aus der Warteschlange entfernt. Da es als momentumbasierte gleitende Durchschnitt des Schlüsselencoders implementiert ist, hat es diesen Namen.

  • Der codierte Query, , passt zu dem Wörterbuch der Schlüssel, , und der kontrastive Verlust (infoNCE) wird berechnet.
  • Beide Encoder werden gemeinsam trainiert, um diesen kontrastiven Verlust zu minimieren.

Moco

Wenn Sie mit den Transformern vertraut sind, können Sie sehen, dass der InfoNCE-Verlust dem Vorgehen ähnelt, wie wir die Aufmerksamkeit in den Transformern berechnen.

# CLIP

CLIP, das 2021 eingeführt wurde, erhöht den Einsatz, indem es sowohl Bilder als auch ihre Bildunterschriften kombiniert. Obwohl es nicht auf einer momentumbasierten Durchschnittsberechnung basiert, arbeitet es auch mit zwei Encodern, einem für Text und einem für Bilder. Hier ist sein kurzer Workflow:

  • Das Bild wird in den Bildencoder eingegeben und seine Bildunterschrift wird in den Textencoder eingegeben.
  • Der Bildencoder, basierend auf einem ViT (Vision Transformer), erhält die Bild-Einbettung, während der Textencoder die Bildunterschrift tokenisiert, um die Textmerkmale zu erhalten. Diese Merkmale werden als Paar im Einbettungsraum zusammengeführt.
  • Sowohl der Text- als auch der Bildencoder werden so trainiert, dass der Abstand eines gegebenen Paares zu den anderen maximiert wird.
  • Beim Testen stellen wir ein Wörterbuch von Bildunterschriften bereit (nicht zu verwechseln mit dem dynamischen Wörterbuch in MoCo) und das gewünschte Bild. Basierend auf dem Bild gibt es die Bildunterschrift mit der höchsten Wahrscheinlichkeit zurück.

CLIP

Hinweis: Für weitere Details zu CLIP lesen Sie bitte diesen Artikel (opens new window).

Boost Your AI App Efficiency now
Sign up for free to benefit from 150+ QPS with 5,000,000 vectors
Free Trial
Explore our product

# Codebeispiel

Lassen Sie uns die Dinge mit einem Codebeispiel abschließen, um das Bild zu vervollständigen. Hier verwende ich die offizielle Implementierung von MoCo (opens new window) von Meta Research. Dieser Code konzentriert sich hauptsächlich auf die Klasse MoCo.

# Konstruktor

Der Konstruktor der Klasse MoCo initialisiert die Attribute wie K, m und T. Wie wir sehen können, verwendet er die Standardwerte für die Merkmalsdimension (dim) als 128, die Warteschlangengröße (K) als 16-Bit (65.536), während der Momentum-Koeffizient μ 0,999 (ziemlich langsame gleitende Durchschnitt) ist. Die Softmax-Temperatur τ, wie bereits im Papier angegeben, beträgt 0,07.

Außerdem können wir die Implementierung des MLP sehen, das wir zuerst in SimCLR und später in MoCov2 gesehen haben (diese Implementierung umfasst beide Versionen).

def __init__(self, base_encoder, dim=128, K=65536, m=0.999, T=0.07, mlp=False):
        
        super(MoCo, self).__init__()

        self.K = K
        self.m = m
        self.T = T

        # create the encoders
        # num_classes is the output fc dimension
        self.encoder_q = base_encoder(num_classes=dim)
        self.encoder_k = base_encoder(num_classes=dim)

        if mlp:  # hack: brute-force replacement
            dim_mlp = self.encoder_q.fc.weight.shape[1]
            self.encoder_q.fc = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc
            )
            self.encoder_k.fc = nn.Sequential(
                nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc
            )

        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

# Momentum-Encoder

Der Kernbereich ist einfach, aber sehr effektiv. Im Momentum-Encoder implementieren wir einfach diese Gleichung für den momentumbasierten Durchschnitt:

μμ

@torch.no_grad()
    def _momentum_update_key_encoder(self):
    
        for param_q, param_k in zip(
            self.encoder_q.parameters(), self.encoder_k.parameters()
        ):
            param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m)

# Warteschlangenverwaltung

Schließlich wird die Warteschlange aktualisiert, um die älteren Batch-Daten zu entfernen. Es funktioniert wie folgt:

  • Wir sammeln zuerst die Schlüssel.
  • Der Zeiger der Warteschlange zeigt auf die Transposition der keys. Warum dies genau für die Spalten und nicht für die Zeilen erfolgt, ist etwas, das mich interessiert.
  • Am Ende bewegt sich der Zeiger zum nächsten Batch (hier ist das Modulo für die zirkuläre Warteschlange), und dies spiegelt sich auch im Attribut (self.queue_ptr) wider.
 def _dequeue_and_enqueue(self, keys):
        
        keys = concat_all_gather(keys)
        batch_size = keys.shape[0]
        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0

        self.queue[:, ptr : ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K 
        self.queue_ptr[0] = ptr

# Training

Im Trainingsschritt kombinieren wir dies alles wie folgt:

  • Berechnung der Query-Merkmale
  • Berechnung der Key-Merkmale (unter Verwendung des Momentum-Updates)
  • Berechnung der Logits (sowohl positiv als auch negativ)
  • Dequeue und Enqueue unter Verwendung der oben diskutierten Funktion
def forward(self, im_q, im_k):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
        Output:
            logits, targets
        """

        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            # shuffle for making use of BN
            im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            k = self._batch_unshuffle_ddp(k, idx_unshuffle)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()

        # dequeue and enqueue
        self._dequeue_and_enqueue(k)

        return logits, labels

Übrigens sind all diese Kommentare von den Autoren selbst geschrieben und zeigen, dass sie nicht nur ein klares Papier geschrieben haben, sondern auch einen guten, leicht verständlichen Code.

Join Our Newsletter

# Kontrastives Lernen und Vektordatenbanken

Da sich das kontrastive Lernen um die Einbettungen der Eingabedaten und deren Abstand dreht, sind sie für Vektordatenbanken von großer Bedeutung. Mit einem trainierten CL-Modell können wir einfach eine Metrik wie den euklidischen Abstand, die Kosinusähnlichkeit oder den -Abstand anwenden, um Vorhersagen für die neuen Proben zu treffen.

Daher gibt es unzählige Anwendungen des kontrastiven Lernens in Verbindung mit Vektordatenbanken, zum Beispiel:

  • Zero-Shot-Erkennung (kann selbst eine Vielzahl von Anwendungen haben) mit CLIP
  • Empfehlungssystem für einen Online-Benutzer - basierend auf der Historie des Benutzers schlägt es Produkte vor, die der Historie am nächsten kommen.
  • Dokumentenabfrage - Abrufen des Dokuments aus der Datenbank, das der Abfrageeinbettung des Benutzers am nächsten kommt.
  • Anomalieerkennung - Ob eine neue Transaktion mit dieser Kreditkarte normal oder betrügerisch ist, kann mithilfe der Einbettungen früherer Transaktionen und der Ermittlung der Ähnlichkeit (oder Nicht-Ähnlichkeit) zwischen dieser Transaktion und dem "Bereich" der Transaktionshistorie im Einbettungsraum überprüft werden.

# Referenzen

  1. Dosovitskiy, et al., Discriminative Unsupervised Feature Learning with Exemplar Convolutional Neural Networks, IEEE PAMI, 2016.
  2. Hjelm, et al., Learning Deep Representations by Mutual Information Estimation and Maximization, ICLR, 2019.
  3. Oord, et al., Representation Learning with Contrastive Predictive Coding, arXiv 2018.
  4. Chen, et al., A Simple Framework for Contrastive Learning of Visual Representations, ICML 2020.
  5. Radford, et al., Learning Transferable Visual Models From Natural Language Supervision, arXiv 2020
  6. He, et al., Momentum Contrast for Unsupervised Visual Representation Learning, arXiv 2020
  7. He, et al., Improved Baselines with Momentum Contrastive Learning, arXiv 2020
Keep Reading
images
End-to-End RAG-Pipeline-Entwicklung mit MyScale und LlamaIndex

LlamaIndex ist ein Datenframework, das für die Implementierung von Anwendungen mit Large Language Models (LLMs) entwickelt wurde. Es vereinfacht das Parsen, Speichern und Abrufen verschiedener Arten von Dokumentendaten und fügt den Fähigkeiten von LLM-Anwendungen einen immensen Mehrwert hinzu ...

Start building your Al projects with MyScale today

Free Trial
Contact Us