Add question: SimCLR Contrastive Loss (NT-Xent)
BARALLL committed Dec 13, 2025
commit 860463835d6eefa9180d54207dbf2163706a514a
90 changes: 90 additions & 0 deletions build/238.json
@@ -0,0 +1,90 @@
{
"id": "238",
"title": "SimCLR Contrastive Loss (NT-Xent)",
"difficulty": "medium",
"category": "Deep Learning",
"video": "",
"likes": "0",
"dislikes": "0",
"contributor": [
{
"profile_link": "https://github.com/BARALLL",
"name": "baralm"
}
],
"description": "\\## NT-Xent Loss for Self-Supervised Contrastive Learning\n\n\n\nIn self-supervised contrastive learning frameworks like \\*\\*SimCLR\\*\\*, we learn meaningful representations without labels by:\n\n1\\. Creating two augmented \"views\" of each image\n\n2\\. Training the model to recognize that views of the \\*\\*same\\*\\* image should have similar embeddings\n\n3\\. While views of \\*\\*different\\*\\* images should have dissimilar embeddings\n\n\n\n\\### The Problem\n\nYou are given a batch of $N$ images. For each image, we generate 2 augmented views, resulting in a batch size of $2N$. The embeddings are organized in an \\*\\*interleaved\\*\\* fashion:\n\n\\- Rows $2k$ and $2k+1$ are two views of the same image $k$ (a positive pair).\n\n\\- Any other pair of rows constitutes a negative pair.\n\n\n\nFor a specific sample $i$, let $j$ be its positive pair. The \\*\\*NT-Xent (Normalized Temperature-scaled Cross-Entropy)\\*\\* loss for sample $i$ is defined as:\n\n\n\n$$\n\n\\\\ell\\_i = -\\\\log \\\\frac{\\\\exp(\\\\text{sim}(z\\_i, z\\_j) / \\\\tau)}{\\\\sum\\_{k=1}^{2N} \\\\mathbb{1}\\_{\\[k \\\\neq i]} \\\\exp(\\\\text{sim}(z\\_i, z\\_k) / \\\\tau)}\n\n$$\n\n\n\nWhere:\n\n\\- $z$ is the batch of L2-normalized embeddings.\n\n\\- $\\\\text{sim}(u, v) = u^\\\\top v$ (Cosine similarity, since $u, v$ are normalized).\n\n\\- $\\\\mathbb{1}\\_{\\[k \\\\neq i]}$ is an indicator function (returns 1 if $k \\\\neq i$, else 0). Effectively, we sum over all samples except the sample itself.\n\n\\- $\\\\tau$ is the temperature parameter.\n\n\n\nThe total loss is the arithmetic mean over all $2N$ samples: $L = \\\\frac{1}{2N} \\\\sum\\_{i=0}^{2N-1} \\\\ell\\_i$.\n\n\n\n\\### Your Task\n\nImplement the function `nt\\_xent\\_loss(z, temperature)` that computes the NT-Xent loss using vectorized NumPy operations.\n\n\n\n\\*\\*Input Format\\*\\*\n\n\\- `z`: A numpy array of shape `(2N, embedding\\_dim)` containing \\*\\*L2-normalized\\*\\* embeddings.\n\n  - \\*\\*Structure\\*\\*: The rows are interleaved such that `z\\[2k]` and `z\\[2k+1]` form a positive pair (two views of image $k$).\n\n  - Visually: `\\[View1\\_Img1, View2\\_Img1, View1\\_Img2, View2\\_Img2, ...]`.\n\n  - All other interactions `z\\[i]` and `z\\[j]` (where `j` is not the pair of `i`) are considered negatives.\n\n\\- `temperature`: A float scaling parameter ($\\\\tau > 0$).\n\n\n\n\\### Output Format\n\n\\- Returns `float`: The average NT-Xent loss over all $2N$ samples.\n\n\n\n\\### Note on Stability\n\n\\- You should implement the \\*\\*Log-Sum-Exp trick\\*\\* (subtracting the maximum value before exponentiation) to ensure numerical stability.\n\n\n\n\\### Constraints\n\n\\- $N \\\\geq 1$ (at least 1 image, so batch size $\\\\geq 2$)\n\n\\- `embedding\\_dim` $\\\\geq 1$\n\n\\- `temperature` $> 0$\n\n\\- Input embeddings are guaranteed to be L2-normalized\n\n\\- \\*\\*Performance:\\*\\* Avoid explicit `for` loops. Use matrix operations and broadcasting.",
"learn_section": "\n# Learn Section\n\n# Understanding NT-Xent Loss (Normalized Temperature-scaled Cross Entropy)\n\n### 1. The Intuition: \"Find Your Partner\"\nAt its core, Self-Supervised Learning (SSL) creates a \"pretext task\" from unlabeled data. \nImagine a crowded room of people (embeddings). Everyone has a generic twin. The goal of NT-Xent is to make you stand as close as possible to your twin (alignment), while pushing everyone else away (uniformity).\n\n* **Positive pairs**: Different views (augmentations) of the SAME image → should be **SIMILAR**.\n* **Negative pairs**: Views of DIFFERENT images → should be **DISSIMILAR**.\n\n### 2. Generating Views\nWe take a batch of $N$ images and generate 2 views for each, resulting in a batch size of $2N$.\n\n```text\nImage A ───────── Augment ──→ View A₁ ──┐\n │ ├── Should be near (Positive) ✓\n └───── Augment ───→ View A₂ ──┘\n\nImage B ───────── Augment ──→ View B₁ ──┐\n ├── Should be far (Negative) ✗\nImage A ───────── Augment ──→ View A₁ ──┘\n```\n\n### 3. The Math: A Classification Problem\nThe NT-Xent loss is essentially a **Softmax Cross-Entropy** loss. We treat the positive pair as the \"correct class\" and all other images in the batch as \"negative classes.\"\n\n$$\\ell_i = -\\log \\frac{\\exp(\\text{sim}(z_i, z_j) / \\tau)}{\\sum_{k \\neq i} \\exp(\\text{sim}(z_i, z_k) / \\tau)}$$\n\n**Breakdown:**\n* **Numerator**: The score of the positive pair ($z_i, z_j$). We want this high.\n* **Denominator**: The sum of scores of $z_i$ against *all* other samples (negatives + positive).\n* **Goal**: By maximizing the numerator relative to the denominator, we force the model to learn unique features that distinguish sample $i$ from the crowd.\n\n### 4. Visualizing the Similarity Matrix\nThe implementation relies on a $(2N \\times 2N)$ similarity matrix. If we organize our batch as `[Cat1, Cat2, Dog1, Dog2]`, the ideal matrix looks like this:\n\n$$\n\\begin{bmatrix}\n\\text{Mask} & \\mathbf{\\text{High}} & \\text{Low} & \\text{Low} \\\\\n\\mathbf{\\text{High}} & \\text{Mask} & \\text{Low} & \\text{Low} \\\\\n\\text{Low} & \\text{Low} & \\text{Mask} & \\mathbf{\\text{High}} \\\\\n\\text{Low} & \\text{Low} & \\mathbf{\\text{High}} & \\text{Mask}\n\\end{bmatrix}\n$$\n\n1. **The Diagonal (Masked)**: We explicitly ignore comparing an image to itself ($k \\neq i$).\n2. **The Off-diagonals**: These are the **Positive Pairs**. We want to maximize these values.\n3. **Everything else**: These are **Negatives**. We want to minimize these values.\n\n### 5. The Critical Role of Temperature ($\\tau$)\nThe temperature $\\tau$ scales the dot products before the softmax. It controls how much the model focuses on difficult examples.\n\n* **High $\\tau$ (e.g., 1.0)**: The distribution is smoother. The model treats all negatives roughly equally.\n* **Low $\\tau$ (e.g., 0.1)**: The distribution becomes sharp/peaky. The model ignores easy negatives and focuses heavily on **\"Hard Negatives\"** (images that look similar to the anchor but aren't).\n\n$$\\text{As } \\tau \\to 0: \\text{Loss approaches argmax (winner-take-all)}$$\n\n### 6. Why It Works: Alignment & Uniformity\nResearch shows this loss optimizes two specific geometric properties on the embedding hypersphere:\n1. **Alignment**: Two views of the same image map to nearby points.\n2. **Uniformity**: Feature vectors spread roughly uniformly across the sphere. This prevents **feature collapse**, where the model maps all images to the same constant vector to cheat the loss.\n\n### 7. 
Implementation Steps\n1. **Forward Pass**: Get normalized embeddings $z$ (shape $2N \\times D$).\n2. **Similarity**: Compute matrix $S = z \\cdot z^T$ (shape $2N \\times 2N$).\n3. **Scale**: Divide $S$ by $\\tau$.\n4. **Mask**: Set diagonal values to $-\\infty$ (so exp() becomes 0).\n5. **Labels**: Create target labels. If batch is organized as `[View1_A, View1_B, ..., View2_A, View2_B...]`, then $i$ matches with $i + N$.\n6. **Loss**: Apply Standard Cross Entropy.\n\n### 6. Numerical Stability (Log-Sum-Exp Trick)\nComputers struggle with large exponents. If $\\text{sim}=1.0$ and $\\tau=0.01$, then $e^{100}$ is huge. To prevent overflow, we use the identity:\n\n$$ \\log \\left( \\sum e^{x_i} \\right) = a + \\log \\left( \\sum e^{x_i - a} \\right) $$\n\nwhere $a = \\max(x)$.\nBy subtracting the maximum value from the logits before exponentiating, the largest term becomes $e^0 = 1$, preventing overflow while keeping the probabilities mathematically identical.\n\n### 8. Connection to InfoNCE\nNT-Xent is a specific form of InfoNCE (Noise Contrastive Estimation) loss, which has theoretical connections to maximizing mutual information between views.\n\n### References\n* **SimCLR Paper**: [Chen et al., 2020](https://arxiv.org/abs/2002.05709)\n ",
"starter_code": "def nt_xent_loss(z: np.ndarray, temperature: float) -> float:\n \"\"\"\n Compute the NT-Xent loss for contrastive learning.\n \n Args:\n z: L2-normalized embeddings, shape (2N, embedding_dim)\n Positive pairs: z[2k] and z[2k+1] are views of image k\n temperature: Temperature scaling parameter (τ > 0)\n \n Returns:\n The scalar NT-Xent loss value\n \"\"\"\n pass",
"solution": "def nt_xent_loss(z: np.ndarray, temperature: float) -> float:\n N = z.shape[0]\n sim = (z @ z.T) / temperature\n \n sim_exp = np.exp(sim - np.max(sim, axis=1, keepdims=True))\n mask_diag = ~np.eye(N, dtype=bool)\n denominator = np.sum(sim_exp * mask_diag, axis=1)\n \n indices = np.arange(N)\n pos_indices = indices + 1 - 2 * (indices % 2)\n numerator = sim_exp[indices, pos_indices]\n \n losses = -np.log(numerator / denominator)\n \n return float(np.mean(losses))",
"example": {
"input": "# N=2 (Total batch 4).\n# 0 and 1 are views of Cat. 2 and 3 are views of Dog.\n# 0 matches 1 (Positive). 0 mismatches 2 and 3 (Negatives).\n\nz = np.array([\n [1.0, 0.0], # 0: Cat View A\n [1.0, 0.0], # 1: Cat View B (Perfect match with 0)\n [0.0, 1.0], # 2: Dog View A (Orthogonal to 0)\n [0.0, 1.0] # 3: Dog View B (Orthogonal to 0)\n])\ntemperature = 0.5",
"output": "2.2395447662218846",
"reasoning": "Let's calculate loss for index 0 (Cat A):\n\n1. **Positive Pair**: Index 1. Sim = 1.0.\n\n2. **Negatives**: Index 2, Index 3. Sim = 0.0.\n\n3. **Terms**:\n\n - Numerator (Positive): $\\exp(1.0 / 0.5) = \\exp(2) \\approx 7.389$\n\n - Denominator (All $k \\neq 0$):\n\n - Term 1 (Pos): $\\exp(1.0/0.5) \\approx 7.389$\n\n - Term 2 (Neg): $\\exp(0.0/0.5) = 1$\n\n - Term 3 (Neg): $\\exp(0.0/0.5) = 1$\n\n - Denominator Sum: $7.389 + 1 + 1 = 9.389$\n\n4. **Loss for index 0**: $-\\log(7.389 / 9.389) \\approx 0.239$\n\n\n\nSince the setup is symmetric, all 4 indices have the same loss."
},
"test_cases": [
{
"test": "print(nt_xent_loss([[1, 0], [1, 0], [0, 1], [0, 1]], 1.0))",
"expected_output": "0.5514447139320511"
},
{
"test": "print(nt_xent_loss([[1, 0], [1, 0]], 1.0))",
"expected_output": "0.0"
},
{
"test": "print(nt_xent_loss([[1, 0], [1, 0], [1, 0], [1, 0]], 1.0))",
"expected_output": "1.0986122886681098"
},
{
"test": "print(nt_xent_loss([[1, 0], [1, 0], [-1, 0], [-1, 0]], 0.5))",
"expected_output": "0.035976299748193295"
},
{
"test": "print(nt_xent_loss([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], 1.0))",
"expected_output": "1.0986122886681098"
},
{
"test": "print(nt_xent_loss([[1, 0], [1, 0], [0, 1], [0, 1]], 100.0))",
"expected_output": "1.0919567454272663"
},
{
"test": "print(nt_xent_loss([[1, 0], [1, 0], [0, 1], [0, 1]], 0.5))",
"expected_output": "0.23954476622188456"
},
{
"test": "print(nt_xent_loss([[1, 0], [1, 0], [0, 1], [0, 1]], 2.0))",
"expected_output": "0.7943767694176431"
},
{
"test": "print(nt_xent_loss([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]], 1.0))",
"expected_output": "0.904832441554448"
},
{
"test": "print(nt_xent_loss([[1, 0], [0.8, 0.6], [0, 1], [-0.6, 0.8]], 1.0))",
"expected_output": "0.6735767888870939"
},
{
"test": "print(nt_xent_loss([[1, 0], [-1, 0], [0, 1], [0, -1]], 1.0))",
"expected_output": "1.861994804058251"
},
{
"test": "print(nt_xent_loss([[1, 0], [0, 1], [0.99, 0.141], [0.141, 0.99]], 1.0))",
"expected_output": "1.4700659232878173"
},
{
"test": "print(nt_xent_loss([[1], [1]], 1.0))",
"expected_output": "0.0"
},
{
"test": "print(nt_xent_loss([[1.0, 0.0], [0.707, 0.707], [-1.0, 0.0], [-0.707, 0.707]], 1.0))",
"expected_output": "0.4528130954640332"
},
{
"test": "import numpy as np; np.random.seed(42); N = 1000; dim = 64; temperature = 0.1; embeddings = np.random.randn(2 * N, dim); z = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True); print(nt_xent_loss(z, temperature))",
"expected_output": "8.37825890387809"
},
{
"test": "import numpy as np; np.random.seed(42); N = 8; dim = 8192; temperature = 0.5; embeddings = np.random.randn(2 * N, dim); z = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True); print(nt_xent_loss(z, temperature))",
"expected_output": "2.7080214605287156"
}
]
}
96 changes: 96 additions & 0 deletions questions/238_simclr_nt_xent/description.md
@@ -0,0 +1,96 @@
## NT-Xent Loss for Self-Supervised Contrastive Learning

In self-supervised contrastive learning frameworks like **SimCLR**, we learn meaningful representations without labels by:

1. Creating two augmented "views" of each image
2. Training the model to recognize that views of the **same** image should have similar embeddings
3. Ensuring that views of **different** images have dissimilar embeddings



### The Problem

You are given a batch of $N$ images. For each image, we generate 2 augmented views, resulting in a batch size of $2N$. The embeddings are organized in an **interleaved** fashion:

- Rows $2k$ and $2k+1$ are two views of the same image $k$ (a positive pair).
- Any other pair of rows constitutes a negative pair.



For a specific sample $i$, let $j$ be its positive pair. The **NT-Xent (Normalized Temperature-scaled Cross-Entropy)** loss for sample $i$ is defined as:

$$
\ell_i = -\log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp(\text{sim}(z_i, z_k) / \tau)}
$$

Where:

- $z$ is the batch of L2-normalized embeddings.
- $\text{sim}(u, v) = u^\top v$ (cosine similarity, since $u, v$ are normalized).
- $\mathbb{1}_{[k \neq i]}$ is an indicator function (returns 1 if $k \neq i$, else 0); effectively, we sum over all samples except the sample itself.
- $\tau$ is the temperature parameter.

The total loss is the arithmetic mean over all $2N$ samples: $L = \frac{1}{2N} \sum_{i=0}^{2N-1} \ell_i$.
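
As a quick sanity check of the formula (not part of the task statement): when $N = 1$ the batch contains only the two views $z_0$ and $z_1$, so the denominator for $i = 0$ reduces to the single positive term and the loss vanishes:

$$
\ell_0 = -\log \frac{\exp(\text{sim}(z_0, z_1) / \tau)}{\exp(\text{sim}(z_0, z_1) / \tau)} = -\log 1 = 0.
$$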



### Your Task

Implement the function `nt_xent_loss(z, temperature)` that computes the NT-Xent loss using vectorized NumPy operations.

**Input Format**

- `z`: A numpy array of shape `(2N, embedding_dim)` containing **L2-normalized** embeddings.
  - **Structure**: The rows are interleaved such that `z[2k]` and `z[2k+1]` form a positive pair (two views of image $k$); see the short sketch after this list.
  - Visually: `[View1_Img1, View2_Img1, View1_Img2, View2_Img2, ...]`.
  - All other pairs `z[i]` and `z[j]` (where `j` is not the partner of `i`) are considered negatives.
- `temperature`: A float scaling parameter ($\tau > 0$).
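
To make the interleaved layout concrete, here is a minimal sketch of how each row index maps to its positive partner (illustration only, with a hypothetical batch size; it is not a required part of your implementation):

```python
import numpy as np

# For rows [View1_Img1, View2_Img1, View1_Img2, View2_Img2, ...]:
# an even row i pairs with i + 1, and an odd row i pairs with i - 1.
indices = np.arange(6)                      # a hypothetical batch with 2N = 6 rows
partners = indices + 1 - 2 * (indices % 2)
print(partners)                             # [1 0 3 2 5 4]
```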



### Output Format

- Returns `float`: The average NT-Xent loss over all $2N$ samples.

### Note on Stability

- You should implement the **Log-Sum-Exp trick** (subtracting the maximum value before exponentiation) to ensure numerical stability; a short illustration follows below.
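
For intuition, here is a small standalone sketch of the trick. The `scores` values are made up purely for demonstration (they mimic `sim / temperature` with a very small temperature) and are not tied to this problem's data:

```python
import numpy as np

# A hypothetical row of scaled similarities (sim / tau) with a tiny tau.
scores = np.array([1000.0, 998.0, 995.0])

# Naive evaluation overflows: np.exp(1000.0) is inf in float64.
naive = np.log(np.sum(np.exp(scores)))            # inf (with an overflow warning)

# Log-sum-exp trick: subtract the maximum before exponentiating.
m = np.max(scores)
stable = m + np.log(np.sum(np.exp(scores - m)))   # ~1000.13, finite and correct

print(naive, stable)
```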



### Constraints

- $N \geq 1$ (at least 1 image, so batch size $\geq 2$)
- `embedding_dim` $\geq 1$
- `temperature` $> 0$
- Input embeddings are guaranteed to be L2-normalized
- **Performance:** Avoid explicit `for` loops. Use matrix operations and broadcasting.

5 changes: 5 additions & 0 deletions questions/238_simclr_nt_xent/example.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"input": "# N=2 (Total batch 4).\n# 0 and 1 are views of Cat. 2 and 3 are views of Dog.\n# 0 matches 1 (Positive). 0 mismatches 2 and 3 (Negatives).\n\nz = np.array([\n [1.0, 0.0], # 0: Cat View A\n [1.0, 0.0], # 1: Cat View B (Perfect match with 0)\n [0.0, 1.0], # 2: Dog View A (Orthogonal to 0)\n [0.0, 1.0] # 3: Dog View B (Orthogonal to 0)\n])\ntemperature = 0.5",
"output": "2.2395447662218846",
"reasoning": "Let's calculate loss for index 0 (Cat A):\n\n1. **Positive Pair**: Index 1. Sim = 1.0.\n\n2. **Negatives**: Index 2, Index 3. Sim = 0.0.\n\n3. **Terms**:\n\n - Numerator (Positive): $\\exp(1.0 / 0.5) = \\exp(2) \\approx 7.389$\n\n - Denominator (All $k \\neq 0$):\n\n - Term 1 (Pos): $\\exp(1.0/0.5) \\approx 7.389$\n\n - Term 2 (Neg): $\\exp(0.0/0.5) = 1$\n\n - Term 3 (Neg): $\\exp(0.0/0.5) = 1$\n\n - Denominator Sum: $7.389 + 1 + 1 = 9.389$\n\n4. **Loss for index 0**: $-\\log(7.389 / 9.389) \\approx 0.239$\n\n\n\nSince the setup is symmetric, all 4 indices have the same loss."
}