Learning Distance Metrics¶
Real-World Scenario: Drug Compound Similarity from Molecular Descriptors¶
In drug discovery, we screen thousands of compounds and characterize them with molecular descriptors (e.g., molecular weight, LogP, number of hydrogen bond donors/acceptors, topological polar surface area). Compounds with the same mechanism of action (MoA) should be "close" in descriptor space, but Euclidean distance over raw descriptors is often misleading — irrelevant features dominate, and the truly discriminative dimensions are buried.
Metric learning solves this by learning a distance function that pulls same-MoA compounds together and pushes different-MoA compounds apart. We explore:
- Mahalanobis distance and why raw Euclidean distance fails
- Large Margin Nearest Neighbors (LMNN) — a linear metric learning method
- Neighborhood Components Analysis (NCA) — a differentiable linear approach
- Deep metric learning — contrastive (Siamese), triplet, and N-pairs losses
- Mining strategies — hard vs. semi-hard negatives
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.patches import Ellipse
from scipy.spatial.distance import cdist
from scipy.optimize import minimize
from collections import Counter
np.random.seed(42)
plt.style.use('seaborn-v0_8-whitegrid')
mpl.rcParams['font.family'] = 'DejaVu Sans'
Key Formulas from PML Chapter 16.2¶
Mahalanobis distance (Eq. 16.3):
$$d_M(\mathbf{x}, \mathbf{x}') = \sqrt{(\mathbf{x} - \mathbf{x}')^\top M (\mathbf{x} - \mathbf{x}')}$$
where $M \succeq 0$ is a positive semi-definite matrix. Setting $M = W^\top W$ gives $d_M(\mathbf{x}, \mathbf{x}') = \|W\mathbf{x} - W\mathbf{x}'\|_2$, i.e., Euclidean distance in a linearly transformed space.
LMNN objective (Eq. 16.4–16.5): minimize a weighted combination of
$$\mathcal{L}_{\text{pull}}(M) = \sum_{i} \sum_{j \in \mathcal{N}_i} d_M(\mathbf{x}_i, \mathbf{x}_j)^2 \qquad \mathcal{L}_{\text{push}}(M) = \sum_{i} \sum_{j \in \mathcal{N}_i} \sum_{l: y_l \neq y_i} \big[m + d_M(\mathbf{x}_i, \mathbf{x}_j)^2 - d_M(\mathbf{x}_i, \mathbf{x}_l)^2\big]_+$$
NCA probability (Eq. 16.6): the probability that point $i$ selects $j$ as its nearest neighbor:
$$p_{ij}^W = \frac{\exp(-\|W\mathbf{x}_i - W\mathbf{x}_j\|^2)}{\sum_{l \neq i} \exp(-\|W\mathbf{x}_i - W\mathbf{x}_l\|^2)}$$
Contrastive loss (Eq. 16.10):
$$\mathcal{L}(\theta; \mathbf{x}_i, \mathbf{x}_j) = \mathbb{1}(y_i = y_j)\, d^2 + \mathbb{1}(y_i \neq y_j)\,[m - d]_+^2$$
Triplet loss (Eq. 16.11):
$$\mathcal{L}(\theta; \mathbf{x}_i, \mathbf{x}_i^+, \mathbf{x}_i^-) = \big[d(\mathbf{x}_i, \mathbf{x}_i^+)^2 - d(\mathbf{x}_i, \mathbf{x}_i^-)^2 + m\big]_+$$
N-pairs loss (Eq. 16.12):
$$\mathcal{L} = -\log \frac{\exp(\hat{\mathbf{e}}^\top \hat{\mathbf{e}}^+)}{\exp(\hat{\mathbf{e}}^\top \hat{\mathbf{e}}^+) + \sum_k \exp(\hat{\mathbf{e}}^\top \hat{\mathbf{e}}_k^-)}$$
1. Generate Synthetic Drug Compound Data¶
We simulate molecular descriptors for compounds from 4 mechanisms of action (MoA): kinase inhibitors, HDAC inhibitors, proteasome inhibitors, and DNA alkylators. Each compound is described by 12 molecular descriptors, but only a few dimensions truly separate the MoA classes.
N_per_class = 50
C = 4
D = 12 # 3 discriminative + 9 noisy dimensions
N = N_per_class * C
moa_names = ['Kinase Inh.', 'HDAC Inh.', 'Proteasome Inh.', 'DNA Alkylator']
colors = ['#e41a1c', '#377eb8', '#4daf4a', '#984ea3']
# Class centers: only dimensions 0-2 are discriminative, rest are shared/noisy
centers = np.zeros((C, D))
centers[:, :3] = np.array([
[ 2.0, 1.5, 0.5], # Kinase inhibitors
[-1.5, 2.0, -1.0], # HDAC inhibitors
[-1.0, -2.0, 2.0], # Proteasome inhibitors
[ 1.0, -1.5, -1.5], # DNA alkylators
])
# Covariance: discriminative dims have moderate spread, noisy dims have large spread
cov_diag = np.concatenate([np.full(3, 0.8), np.full(9, 5.0)])
cov = np.diag(cov_diag)
X_all = []
y_all = []
for c in range(C):
X_c = np.random.multivariate_normal(centers[c], cov, size=N_per_class)
X_all.append(X_c)
y_all.append(np.full(N_per_class, c))
X = np.vstack(X_all)
y = np.concatenate(y_all)
# Standardize
X_mean, X_std = X.mean(0), X.std(0)
X = (X - X_mean) / X_std
# Train/test split
idx = np.random.permutation(N)
n_train = int(0.7 * N)
X_train, y_train = X[idx[:n_train]], y[idx[:n_train]]
X_test, y_test = X[idx[n_train:]], y[idx[n_train:]]
print(f'Training set: {X_train.shape[0]} compounds, {D} descriptors, {C} MoA classes')
print(f'Test set: {X_test.shape[0]} compounds')
Training set: 140 compounds, 12 descriptors, 4 MoA classes
Test set: 60 compounds
2. Why Euclidean Distance Fails¶
Euclidean distance treats all 12 descriptor dimensions equally. Since dimensions 3–11 are high-variance noise, they dominate the distance computation and drown out the discriminative signal in dimensions 0–2.
def knn_accuracy(X_train, y_train, X_test, y_test, K=5, W=None):
"""KNN classifier with optional linear transform W (Mahalanobis: M = W^T W)."""
if W is not None:
X_tr = X_train @ W.T
X_te = X_test @ W.T
else:
X_tr, X_te = X_train, X_test
dists = cdist(X_te, X_tr, metric='euclidean')
preds = []
    for i in range(len(X_te)):  # classify each test point
nn_idx = np.argsort(dists[i])[:K]
nn_labels = y_train[nn_idx]
# Majority vote: pick the most frequent label among the K neighbors
counts = Counter(nn_labels)
preds.append(counts.most_common(1)[0][0])
return np.mean(np.array(preds) == y_test)
acc_euclidean = knn_accuracy(X_train, y_train, X_test, y_test, K=5)
print(f'KNN (K=5) with Euclidean distance: {acc_euclidean:.1%}')
# Oracle: only use the 3 discriminative dimensions
acc_oracle = knn_accuracy(X_train[:, :3], y_train, X_test[:, :3], y_test, K=5)
print(f'KNN (K=5) with oracle (dims 0-2 only): {acc_oracle:.1%}')
KNN (K=5) with Euclidean distance: 83.3%
KNN (K=5) with oracle (dims 0-2 only): 96.7%
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
# Left: dims 0 vs 1 (discriminative)
for c in range(C):
mask = y_train == c
axes[0].scatter(X_train[mask, 0], X_train[mask, 1], c=colors[c],
label=moa_names[c], alpha=0.6, s=30, edgecolors='w', linewidths=0.3)
axes[0].set_xlabel('Descriptor 0 (discriminative)')
axes[0].set_ylabel('Descriptor 1 (discriminative)')
axes[0].set_title('Discriminative dimensions (0 vs 1)')
axes[0].legend(fontsize=8, loc='upper left')
# Right: dims 3 vs 4 (noisy)
for c in range(C):
mask = y_train == c
axes[1].scatter(X_train[mask, 3], X_train[mask, 4], c=colors[c],
label=moa_names[c], alpha=0.6, s=30, edgecolors='w', linewidths=0.3)
axes[1].set_xlabel('Descriptor 3 (noisy)')
axes[1].set_ylabel('Descriptor 4 (noisy)')
axes[1].set_title('Noisy dimensions (3 vs 4)')
axes[1].legend(fontsize=8, loc='upper left')
plt.tight_layout()
plt.show()
3. Mahalanobis Distance and LMNN¶
3.1 The Mahalanobis Distance¶
The Mahalanobis distance (Eq. 16.3) generalizes Euclidean distance by introducing a PSD matrix $M$:
$$d_M(\mathbf{x}, \mathbf{x}') = \sqrt{(\mathbf{x} - \mathbf{x}')^\top M (\mathbf{x} - \mathbf{x}')}$$
Writing $M = W^\top W$, this becomes Euclidean distance after the linear transform $W$: the matrix $W$ stretches space along discriminative directions and compresses it along noisy ones.
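This equivalence is easy to confirm numerically. The sketch below uses an arbitrary random $W$ and points, unrelated to the compound data:

```python
import numpy as np

# Verify d_M(x, x') = ||Wx - Wx'|| when M = W^T W.
# W, x, x_prime are arbitrary random values for illustration only.
rng = np.random.default_rng(0)
W = rng.standard_normal((3, 12))   # L x D linear map
M = W.T @ W                        # induced PSD matrix
x = rng.standard_normal(12)
x_prime = rng.standard_normal(12)

delta = x - x_prime
d_mahalanobis = np.sqrt(delta @ M @ delta)
d_transformed = np.linalg.norm(W @ x - W @ x_prime)
assert np.isclose(d_mahalanobis, d_transformed)
```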
3.2 Large Margin Nearest Neighbors (LMNN)¶
LMNN learns $M$ by combining two objectives:
- Pull loss $\mathcal{L}_{\text{pull}}$: minimize distances between each point and its target neighbors (same-class nearest neighbors in Euclidean space)
- Push loss $\mathcal{L}_{\text{push}}$: ensure that impostors (different-class points) are farther away by at least a margin $m$
We parametrize $M = W^\top W$ and optimize $W$ with gradient descent.
Gradient derivation¶
The total loss is:
$$\mathcal{L} = (1 - \lambda) \sum_{i,j \in \text{targets}} d_{ij}^2 + \lambda \sum_{i,j,l} \big[m + d_{ij}^2 - d_{il}^2\big]_+$$
where $d_{ij}^2 = \|W(\mathbf{x}_i - \mathbf{x}_j)\|^2 = (W \boldsymbol{\delta}_{ij})^\top (W \boldsymbol{\delta}_{ij})$ with $\boldsymbol{\delta}_{ij} = \mathbf{x}_i - \mathbf{x}_j$.
Step 1 — derivative of a squared distance w.r.t. $W$. Let $\mathbf{z}_i = W\mathbf{x}_i$, so $d_{ij}^2 = \|\mathbf{z}_i - \mathbf{z}_j\|^2$. Using the chain rule:
$$\frac{\partial\, d_{ij}^2}{\partial W} = \frac{\partial}{\partial W} (W \boldsymbol{\delta}_{ij})^\top (W \boldsymbol{\delta}_{ij}) = 2\,(W \boldsymbol{\delta}_{ij})\, \boldsymbol{\delta}_{ij}^\top$$
This is the outer product $2\,(\mathbf{z}_i - \mathbf{z}_j)\,(\mathbf{x}_i - \mathbf{x}_j)^\top$, which has the same shape as $W$ (an $L \times D$ matrix).
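This outer-product identity can be confirmed with a quick finite-difference check (shapes and values here are arbitrary):

```python
import numpy as np

# Finite-difference check of Step 1: d/dW ||W delta||^2 = 2 (W delta) delta^T.
rng = np.random.default_rng(1)
L_dim, D_dim = 3, 5
W = rng.standard_normal((L_dim, D_dim))
delta = rng.standard_normal(D_dim)

analytic = 2 * np.outer(W @ delta, delta)   # L x D, same shape as W

eps = 1e-6
numeric = np.zeros_like(W)
for r in range(L_dim):
    for c in range(D_dim):
        W_plus, W_minus = W.copy(), W.copy()
        W_plus[r, c] += eps
        W_minus[r, c] -= eps
        numeric[r, c] = (np.sum((W_plus @ delta)**2)
                         - np.sum((W_minus @ delta)**2)) / (2 * eps)

assert np.allclose(analytic, numeric, atol=1e-5)
```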
Step 2 — pull gradient. Each pull term contributes:
$$\frac{\partial\, \mathcal{L}_{\text{pull}}}{\partial W} = 2(1 - \lambda) \sum_{i,j \in \text{targets}} (\mathbf{z}_i - \mathbf{z}_j)(\mathbf{x}_i - \mathbf{x}_j)^\top$$
This pushes $W$ to shrink the projected distance between target neighbors.
Step 3 — push gradient. Each active hinge ($m + d_{ij}^2 - d_{il}^2 > 0$) contributes:
$$\frac{\partial\, \mathcal{L}_{\text{push}}}{\partial W} = 2\lambda \sum_{\text{active}} \Big[(\mathbf{z}_i - \mathbf{z}_j)(\mathbf{x}_i - \mathbf{x}_j)^\top - (\mathbf{z}_i - \mathbf{z}_l)(\mathbf{x}_i - \mathbf{x}_l)^\top\Big]$$
Descending the first term shrinks the target-neighbor distance (the same direction as the pull loss), while descending the second term grows the impostor distance, pushing $W$ to move impostor $l$ farther away. Inactive hinges contribute zero gradient.
def find_target_neighbors(X, y, K=3):
"""For each point, find K nearest same-class neighbors (in Euclidean space)."""
N = len(X)
targets = []
for i in range(N):
        # Indices of all other points sharing this point's class
        same_class = np.where((y == y[i]) & (np.arange(N) != i))[0]
        # Squared Euclidean distances to those same-class candidates
        dists = np.sum((X[same_class] - X[i])**2, axis=1)
        # Keep the K closest as target neighbors
        nn = same_class[np.argsort(dists)[:K]]
        targets.append(nn)
return targets
def lmnn_loss_and_grad(W_flat, X, y, targets, lam=0.5, margin=1.0):
"""Compute LMNN loss and gradient w.r.t. W (where M = W^T W)."""
D = X.shape[1]
L = W_flat.shape[0] // D
W = W_flat.reshape(L, D)
Z = X @ W.T # project to L-dim space
N = len(X)
loss_pull = 0.0
loss_push = 0.0
grad_W = np.zeros_like(W)
for i in range(N):
for j in targets[i]:
diff_ij = Z[i] - Z[j]
d_ij_sq = np.dot(diff_ij, diff_ij)
# Pull term
loss_pull += d_ij_sq
x_diff_ij = X[i] - X[j]
grad_W += 2 * (1 - lam) * np.outer(diff_ij, x_diff_ij)
# Push term: find impostors
diff_class = np.where(y != y[i])[0]
for l in diff_class:
diff_il = Z[i] - Z[l]
d_il_sq = np.dot(diff_il, diff_il)
hinge = margin + d_ij_sq - d_il_sq
if hinge > 0:
loss_push += hinge
x_diff_il = X[i] - X[l]
grad_W += 2 * lam * (np.outer(diff_ij, x_diff_ij) - np.outer(diff_il, x_diff_il))
loss = (1 - lam) * loss_pull + lam * loss_push
return loss, grad_W.ravel()
# Use a subset for speed
n_sub = 80
idx_sub = np.random.choice(len(X_train), n_sub, replace=False)
X_sub, y_sub = X_train[idx_sub], y_train[idx_sub]
targets = find_target_neighbors(X_sub, y_sub, K=3)
# Initialize W: small random entries with an identity warm start on the
# discriminative block (project to 3D for visualization)
L_dim = 3
W0 = np.random.randn(L_dim, D) * 0.3
W0[:3, :3] = np.eye(3) # warm start on discriminative dims
print('Optimizing LMNN...')
result = minimize(
lambda w: lmnn_loss_and_grad(w, X_sub, y_sub, targets, lam=0.5, margin=1.0),
W0.ravel(), jac=True, method='L-BFGS-B',
    options={'maxiter': 200}
)
W_lmnn = result.x.reshape(L_dim, D)
acc_lmnn = knn_accuracy(X_train, y_train, X_test, y_test, K=5, W=W_lmnn)
print(f'LMNN converged: {result.success}')
print(f'KNN (K=5) with LMNN metric: {acc_lmnn:.1%}')
print(f' vs. Euclidean: {acc_euclidean:.1%}')
Optimizing LMNN...
LMNN converged: True
KNN (K=5) with LMNN metric: 98.3%
  vs. Euclidean: 83.3%
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
# Left: Euclidean (first 2 dims of raw space)
for c in range(C):
mask = y_train == c
axes[0].scatter(X_train[mask, 0], X_train[mask, 1], c=colors[c],
label=moa_names[c], alpha=0.6, s=30, edgecolors='w', linewidths=0.3)
axes[0].set_title('Raw space (dims 0–1)')
axes[0].set_xlabel('Dim 0'); axes[0].set_ylabel('Dim 1')
axes[0].legend(fontsize=8)
# Right: LMNN-transformed space
Z_lmnn = X_train @ W_lmnn.T
for c in range(C):
mask = y_train == c
axes[1].scatter(Z_lmnn[mask, 0], Z_lmnn[mask, 1], c=colors[c],
label=moa_names[c], alpha=0.6, s=30, edgecolors='w', linewidths=0.3)
axes[1].set_title('LMNN-transformed space (dims 0–1)')
axes[1].set_xlabel('LMNN Dim 0'); axes[1].set_ylabel('LMNN Dim 1')
axes[1].legend(fontsize=8)
plt.tight_layout()
plt.show()
4. Neighborhood Components Analysis (NCA)¶
NCA (Eq. 16.6) defines a soft nearest-neighbor probability using a linear softmax:
$$p_{ij}^W = \frac{\exp(-\|W\mathbf{x}_i - W\mathbf{x}_j\|^2)}{\sum_{l \neq i} \exp(-\|W\mathbf{x}_i - W\mathbf{x}_l\|^2)}$$
The expected leave-one-out accuracy is $J(W) = \sum_i \sum_{j: y_j = y_i} p_{ij}^W$. We maximize $J$ (minimize $-J$) w.r.t. $W$.
Unlike LMNN, NCA uses a differentiable soft assignment rather than hard nearest neighbors, which makes the objective smooth.
def nca_loss_and_grad(W_flat, X, y, L):
"""NCA loss (negative expected LOO accuracy) and gradient."""
D = X.shape[1]
W = W_flat.reshape(L, D)
Z = X @ W.T
N = len(X)
# Pairwise squared distances in transformed space
dists_sq = np.sum((Z[:, None, :] - Z[None, :, :]) ** 2, axis=2) # (N, N)
# Softmax probabilities (exclude self)
np.fill_diagonal(dists_sq, np.inf) # exclude self
exp_neg = np.exp(-dists_sq)
P = exp_neg / exp_neg.sum(axis=1, keepdims=True) # (N, N)
# Expected accuracy: sum P_ij for same-class pairs
same_class = (y[:, None] == y[None, :]) # (N, N)
np.fill_diagonal(same_class, False)
p_correct = (P * same_class).sum(axis=1) # (N,)
loss = -p_correct.sum()
# Gradient
grad_W = np.zeros_like(W)
for i in range(N):
# Weighted mean of all neighbors
diff = Z[i] - Z # (N, L)
# Term 1: sum_j p_ij * (z_i - z_j)(x_i - x_j)^T for same class
x_diff = X[i] - X # (N, D)
weights_same = P[i] * same_class[i]
weights_all = P[i]
term1 = (weights_same[:, None] * diff).T @ x_diff # (L, D)
term2 = p_correct[i] * (weights_all[:, None] * diff).T @ x_diff # (L, D)
grad_W += 2 * (term1 - term2)
return loss, grad_W.ravel()
W0_nca = np.random.randn(L_dim, D) * 0.3
W0_nca[:3, :3] = np.eye(3) * 0.5
print('Optimizing NCA...')
result_nca = minimize(
lambda w: nca_loss_and_grad(w, X_sub, y_sub, L_dim),
W0_nca.ravel(), jac=True, method='L-BFGS-B',
    options={'maxiter': 200}
)
W_nca = result_nca.x.reshape(L_dim, D)
acc_nca = knn_accuracy(X_train, y_train, X_test, y_test, K=5, W=W_nca)
print(f'NCA converged: {result_nca.success}')
print(f'KNN (K=5) with NCA metric: {acc_nca:.1%}')
print(f' vs. Euclidean: {acc_euclidean:.1%}')
Optimizing NCA...
NCA converged: True
KNN (K=5) with NCA metric: 98.3%
  vs. Euclidean: 83.3%
# Visualize NCA embedding
Z_nca = X_train @ W_nca.T
fig, axes = plt.subplots(1, 3, figsize=(16, 4.5))
titles = ['Euclidean (raw)', 'LMNN', 'NCA']
embeddings = [X_train[:, :2], Z_lmnn[:, :2], Z_nca[:, :2]]
for ax, Z_emb, title in zip(axes, embeddings, titles):
for c in range(C):
mask = y_train == c
ax.scatter(Z_emb[mask, 0], Z_emb[mask, 1], c=colors[c],
label=moa_names[c], alpha=0.6, s=30, edgecolors='w', linewidths=0.3)
ax.set_title(title)
ax.legend(fontsize=7, loc='best')
plt.suptitle('Learned embeddings: Euclidean vs. LMNN vs. NCA', fontsize=13, y=1.02)
plt.tight_layout()
plt.show()
5. Deep Metric Learning Losses¶
When the input is high-dimensional or structured (images, sequences), we learn a nonlinear embedding $\mathbf{e} = f(\mathbf{x}; \theta)$ using a neural network, and compute distances in embedding space.
Here we implement the key loss functions from scratch using our drug compound data to build intuition. In practice, these losses are applied to deep networks trained on images or molecular graphs.
5.1 Contrastive Loss (Siamese Networks)¶
The contrastive loss (Eq. 16.10) operates on pairs — pulling same-class pairs together and pushing different-class pairs apart beyond a margin $m$:
$$\mathcal{L}(\mathbf{x}_i, \mathbf{x}_j) = \begin{cases} d^2 & \text{if } y_i = y_j \\ [m - d]_+^2 & \text{if } y_i \neq y_j \end{cases}$$
def contrastive_loss(d, is_same, margin=1.0):
"""Contrastive loss for a pair at distance d."""
if is_same:
return d ** 2
else:
return max(0, margin - d) ** 2
# Visualize contrastive loss as a function of distance
d_range = np.linspace(0, 2.5, 200)
margin = 1.0
loss_pos = d_range ** 2
loss_neg = np.maximum(0, margin - d_range) ** 2
fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(d_range, loss_pos, 'b-', lw=2, label='Same class (positive pair)')
ax.plot(d_range, loss_neg, 'r-', lw=2, label='Different class (negative pair)')
ax.axvline(margin, color='gray', ls='--', alpha=0.5, label=f'Margin m = {margin}')
ax.set_xlabel('Distance d')
ax.set_ylabel('Loss')
ax.set_title('Contrastive Loss (Eq. 16.10)')
ax.legend()
ax.set_ylim(-0.1, 4)
plt.tight_layout()
plt.show()
5.2 Triplet Loss¶
The triplet loss (Eq. 16.11) operates on (anchor, positive, negative) triplets. It ensures the anchor is closer to its positive than to its negative by at least a margin $m$:
$$\mathcal{L} = \big[d(\mathbf{x}_a, \mathbf{x}_p)^2 - d(\mathbf{x}_a, \mathbf{x}_n)^2 + m\big]_+$$
Unlike the contrastive loss, which imposes independent objectives on positive and negative pairs, the triplet loss constrains only their relative ordering: absolute distances are free to vary as long as the positive stays closer than the negative by the margin.
def triplet_loss(d_pos, d_neg, margin=1.0):
"""Triplet loss given distances to positive and negative."""
return max(0, d_pos**2 - d_neg**2 + margin)
# Visualize triplet loss as a function of d_pos for a fixed d_neg
d_neg_fixed = 1.5
d_pos_range = np.linspace(0, 3, 300)
losses_triplet = np.maximum(0, d_pos_range**2 - d_neg_fixed**2 + margin)
fig, axes = plt.subplots(1, 2, figsize=(13, 4.5))
# Left: triplet loss vs d_pos for fixed d_neg
axes[0].plot(d_pos_range, losses_triplet, 'b-', lw=2)
axes[0].axvline(np.sqrt(d_neg_fixed**2 - margin), color='gray', ls='--', alpha=0.5,
                label='Zero-loss boundary')
axes[0].set_xlabel('d(anchor, positive)')
axes[0].set_ylabel('Triplet Loss')
axes[0].set_title(f'Triplet Loss (d_neg = {d_neg_fixed:.1f}, m = {margin})')
axes[0].legend()
axes[0].set_ylim(-0.2, 8)
# Right: 2D heatmap of triplet loss
d_p = np.linspace(0, 3, 100)
d_n = np.linspace(0, 3, 100)
D_P, D_N = np.meshgrid(d_p, d_n)
L_triplet = np.maximum(0, D_P**2 - D_N**2 + margin)
im = axes[1].contourf(D_P, D_N, L_triplet, levels=20, cmap='RdYlBu_r')
axes[1].plot([0, 3], [0, 3], 'k--', alpha=0.3, label='d_pos = d_neg')
axes[1].contour(D_P, D_N, L_triplet, levels=[0.001], colors='white', linewidths=2)
axes[1].set_xlabel('d(anchor, positive)')
axes[1].set_ylabel('d(anchor, negative)')
axes[1].set_title('Triplet Loss Landscape')
plt.colorbar(im, ax=axes[1], label='Loss')
axes[1].legend(fontsize=8)
plt.tight_layout()
plt.show()
5.3 N-Pairs Loss¶
N-pairs loss (Eq. 16.12) generalizes triplet loss by comparing the anchor against multiple negatives simultaneously. It uses cosine similarity (via normalized embeddings) and softmax:
$$\mathcal{L} = -\log \frac{\exp(\hat{\mathbf{e}}^\top \hat{\mathbf{e}}^+)}{\exp(\hat{\mathbf{e}}^\top \hat{\mathbf{e}}^+) + \sum_{k} \exp(\hat{\mathbf{e}}^\top \hat{\mathbf{e}}_k^-)}$$
This is equivalent to a softmax cross-entropy where the positive is the correct "class". When $N=2$ (one negative), it reduces to the logistic loss (Eq. 16.14).
def n_pairs_loss(sim_pos, sims_neg):
"""N-pairs loss given cosine similarity to positive and array of negatives."""
logits = np.concatenate([[sim_pos], sims_neg])
log_softmax = sim_pos - np.log(np.sum(np.exp(logits)))
return -log_softmax
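To confirm the $N=2$ reduction to the logistic loss (Eq. 16.14) mentioned above, here is a self-contained check (the loss is re-defined locally so the cell runs on its own):

```python
import numpy as np

# Check the N = 2 reduction: with a single negative, the n-pairs loss equals
# the logistic loss log(1 + exp(-(s_pos - s_neg))).
def n_pairs(sim_pos, sims_neg):
    logits = np.concatenate([[sim_pos], np.atleast_1d(sims_neg)])
    return -(sim_pos - np.log(np.sum(np.exp(logits))))

s_pos, s_neg = 0.8, 0.3
loss_npairs = n_pairs(s_pos, [s_neg])
loss_logistic = np.log1p(np.exp(-(s_pos - s_neg)))
assert np.isclose(loss_npairs, loss_logistic)
```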
# Compare triplet vs N-pairs loss as function of sim_pos - sim_neg
sim_diff = np.linspace(-3, 3, 300)
# Triplet (hinge): max(0, -diff + m)
m_trip = 1.0
loss_hinge = np.maximum(0, -sim_diff + m_trip)
# Logistic (N=2 case of N-pairs): log(1 + exp(-diff))
loss_logistic = np.log(1 + np.exp(-sim_diff))
fig, ax = plt.subplots(figsize=(7, 4))
ax.plot(sim_diff, loss_hinge, 'b-', lw=2, label='Triplet (hinge, m=1)')
ax.plot(sim_diff, loss_logistic, 'r-', lw=2, label='N-pairs (logistic, N=2)')
ax.set_xlabel('sim(anchor, pos) − sim(anchor, neg)')
ax.set_ylabel('Loss')
ax.set_title('Triplet (Hinge) vs. N-pairs (Logistic) Loss')
ax.legend()
ax.set_ylim(-0.1, 4)
ax.axhline(0, color='gray', lw=0.5)
ax.axvline(0, color='gray', lw=0.5)
plt.tight_layout()
plt.show()
The logistic loss (N-pairs) is smooth everywhere — it always provides a gradient signal, even for well-separated pairs. In contrast, the hinge loss has a flat region where the gradient is exactly zero, which can slow learning.
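A quick look at the gradients makes this concrete, writing $x$ for the similarity gap $\text{sim}(a, p) - \text{sim}(a, n)$:

```python
import numpy as np

# Gradients w.r.t. the similarity gap x = sim_pos - sim_neg.
# Hinge max(0, m - x): gradient is -1 for x < m and exactly 0 beyond it.
# Logistic log(1 + e^{-x}): gradient is -sigmoid(-x), nonzero for all finite x.
m = 1.0
x = np.array([-2.0, 0.0, 0.5, 1.5, 3.0])
grad_hinge = np.where(x < m, -1.0, 0.0)
grad_logistic = -1.0 / (1.0 + np.exp(x))

assert np.all(grad_hinge[x > m] == 0.0)       # flat region: no learning signal
assert np.all(np.abs(grad_logistic) > 0.0)    # always a (shrinking) signal
```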
6. Hard Negative Mining¶
Most triplets are uninformative (the negative is already far from the anchor, giving zero loss). Training is dominated by the rare hard and semi-hard negatives (Fig. 16.6a):
| Type | Condition |
|---|---|
| Hard negative | $d(a, n) < d(a, p)$ — closer than the positive |
| Semi-hard negative | $d(a, p) < d(a, n) < d(a, p) + m$ — within the margin |
| Easy negative | $d(a, n) > d(a, p) + m$ — already beyond the margin (zero loss) |
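The three cases in the table can be expressed as a small helper (the sampling cells below inline the same logic; this function is illustrative only):

```python
def classify_triplet(d_ap, d_an, margin=1.0):
    """Classify a triplet by where d(anchor, negative) falls relative to d(anchor, positive)."""
    if d_an < d_ap:
        return 'hard'
    elif d_an < d_ap + margin:
        return 'semi-hard'
    else:
        return 'easy'

assert classify_triplet(1.0, 0.5) == 'hard'       # negative closer than positive
assert classify_triplet(1.0, 1.5) == 'semi-hard'  # within the margin
assert classify_triplet(1.0, 2.5) == 'easy'       # beyond the margin
```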
We demonstrate this by sampling triplets from our drug compound data and classifying them.
# Sample triplets and classify as hard / semi-hard / easy
n_triplets = 2000
margin = 1.0
d_pos_list, d_neg_list = [], []
triplet_type = []
for _ in range(n_triplets):
# Pick random anchor
a = np.random.randint(len(X_train))
# Random positive (same class)
same = np.where((y_train == y_train[a]) & (np.arange(len(X_train)) != a))[0]
p = np.random.choice(same)
# Random negative (different class)
diff = np.where(y_train != y_train[a])[0]
n = np.random.choice(diff)
d_ap = np.linalg.norm(X_train[a] - X_train[p])
d_an = np.linalg.norm(X_train[a] - X_train[n])
d_pos_list.append(d_ap)
d_neg_list.append(d_an)
if d_an < d_ap:
triplet_type.append('hard')
elif d_an < d_ap + margin:
triplet_type.append('semi-hard')
else:
triplet_type.append('easy')
d_pos_arr = np.array(d_pos_list)
d_neg_arr = np.array(d_neg_list)
triplet_type = np.array(triplet_type)
print('Triplet distribution (Euclidean):')
for t in ['hard', 'semi-hard', 'easy']:
print(f' {t:10s}: {np.sum(triplet_type == t):5d} ({np.mean(triplet_type == t):.1%})')
Triplet distribution (Euclidean):
  hard      :   573 (28.6%)
  semi-hard :   646 (32.3%)
  easy      :   781 (39.1%)
fig, axes = plt.subplots(1, 2, figsize=(13, 5))
type_colors = {'hard': '#e41a1c', 'semi-hard': '#ff7f00', 'easy': '#4daf4a'}
# Left: Euclidean space
for t in ['easy', 'semi-hard', 'hard']: # plot easy first so hard is on top
mask = triplet_type == t
axes[0].scatter(d_pos_arr[mask], d_neg_arr[mask], c=type_colors[t],
alpha=0.3, s=10, label=f'{t} ({mask.sum()})')
lim = max(d_pos_arr.max(), d_neg_arr.max()) * 1.05
axes[0].plot([0, lim], [0, lim], 'k--', alpha=0.3, label='d_pos = d_neg')
axes[0].plot([0, lim], [margin, lim + margin], 'k:', alpha=0.3, label='d_neg = d_pos + m')
axes[0].set_xlabel('d(anchor, positive)')
axes[0].set_ylabel('d(anchor, negative)')
axes[0].set_title('Triplet types — Euclidean distance')
axes[0].legend(fontsize=8)
axes[0].set_xlim(0, lim); axes[0].set_ylim(0, lim)
# Right: after LMNN transform
Z_lmnn_train = X_train @ W_lmnn.T
d_pos_lmnn, d_neg_lmnn = [], []
triplet_type_lmnn = []
np.random.seed(42)
for _ in range(n_triplets):
a = np.random.randint(len(X_train))
same = np.where((y_train == y_train[a]) & (np.arange(len(X_train)) != a))[0]
p = np.random.choice(same)
diff = np.where(y_train != y_train[a])[0]
n = np.random.choice(diff)
d_ap = np.linalg.norm(Z_lmnn_train[a] - Z_lmnn_train[p])
d_an = np.linalg.norm(Z_lmnn_train[a] - Z_lmnn_train[n])
d_pos_lmnn.append(d_ap)
d_neg_lmnn.append(d_an)
if d_an < d_ap:
triplet_type_lmnn.append('hard')
elif d_an < d_ap + margin:
triplet_type_lmnn.append('semi-hard')
else:
triplet_type_lmnn.append('easy')
d_pos_lmnn = np.array(d_pos_lmnn)
d_neg_lmnn = np.array(d_neg_lmnn)
triplet_type_lmnn = np.array(triplet_type_lmnn)
for t in ['easy', 'semi-hard', 'hard']:
mask = triplet_type_lmnn == t
axes[1].scatter(d_pos_lmnn[mask], d_neg_lmnn[mask], c=type_colors[t],
alpha=0.3, s=10, label=f'{t} ({mask.sum()})')
lim2 = max(d_pos_lmnn.max(), d_neg_lmnn.max()) * 1.05
axes[1].plot([0, lim2], [0, lim2], 'k--', alpha=0.3)
axes[1].plot([0, lim2], [margin, lim2 + margin], 'k:', alpha=0.3)
axes[1].set_xlabel('d(anchor, positive)')
axes[1].set_ylabel('d(anchor, negative)')
axes[1].set_title('Triplet types — LMNN metric')
axes[1].legend(fontsize=8)
axes[1].set_xlim(0, lim2); axes[1].set_ylim(0, lim2)
plt.tight_layout()
plt.show()
print('\nAfter LMNN:')
for t in ['hard', 'semi-hard', 'easy']:
print(f' {t:10s}: {np.sum(triplet_type_lmnn == t):5d} ({np.mean(triplet_type_lmnn == t):.1%})')
After LMNN:
  hard      :   100 (5.0%)
  semi-hard :   582 (29.1%)
  easy      :  1318 (65.9%)
7. Deep Metric Learning with a Simple MLP¶
We now implement a minimal deep metric learning pipeline: a 2-layer MLP trained with triplet loss using semi-hard negative mining. This demonstrates the full DML workflow from Section 16.2.2–16.2.5.
class SimpleEmbeddingNet:
"""2-layer MLP embedding network with L2-normalized output."""
def __init__(self, D_in, D_hidden, D_embed):
# Xavier initialization
self.W1 = np.random.randn(D_in, D_hidden) * np.sqrt(2.0 / D_in)
self.b1 = np.zeros(D_hidden)
self.W2 = np.random.randn(D_hidden, D_embed) * np.sqrt(2.0 / D_hidden)
self.b2 = np.zeros(D_embed)
def forward(self, X):
"""Forward pass returning normalized embeddings and cached activations."""
self.X = X
self.h1 = X @ self.W1 + self.b1
self.a1 = np.maximum(0, self.h1) # ReLU
self.h2 = self.a1 @ self.W2 + self.b2
# L2 normalize
norms = np.linalg.norm(self.h2, axis=1, keepdims=True)
norms = np.maximum(norms, 1e-8)
self.e_hat = self.h2 / norms
return self.e_hat
def backward_triplet(self, idx_a, idx_p, idx_n, margin=0.2, lr=0.01):
"""Backprop through triplet loss for one batch of triplets."""
e_a = self.e_hat[idx_a] # (B, D_embed)
e_p = self.e_hat[idx_p]
e_n = self.e_hat[idx_n]
d_pos_sq = np.sum((e_a - e_p)**2, axis=1) # (B,)
d_neg_sq = np.sum((e_a - e_n)**2, axis=1)
losses = np.maximum(0, d_pos_sq - d_neg_sq + margin)
active = losses > 0 # (B,)
if not np.any(active):
return 0.0
# Gradient of loss w.r.t. embeddings
B = len(idx_a)
N = len(self.e_hat)
D_embed = self.e_hat.shape[1]
# d(loss)/d(e_hat) accumulated for all points
grad_e = np.zeros((N, D_embed))
for b in range(B):
if not active[b]:
continue
grad_e[idx_a[b]] += 2 * (e_a[b] - e_p[b]) - 2 * (e_a[b] - e_n[b])
grad_e[idx_p[b]] += -2 * (e_a[b] - e_p[b])
grad_e[idx_n[b]] += 2 * (e_a[b] - e_n[b])
grad_e /= B
# Backprop through L2 normalization
norms = np.linalg.norm(self.h2, axis=1, keepdims=True)
norms = np.maximum(norms, 1e-8)
grad_h2 = (grad_e - self.e_hat * np.sum(grad_e * self.e_hat, axis=1, keepdims=True)) / norms
# Backprop through second layer
grad_W2 = self.a1.T @ grad_h2
grad_b2 = grad_h2.sum(axis=0)
# Backprop through ReLU
grad_a1 = grad_h2 @ self.W2.T
grad_h1 = grad_a1 * (self.h1 > 0)
# Backprop through first layer
grad_W1 = self.X.T @ grad_h1
grad_b1 = grad_h1.sum(axis=0)
# Update
self.W1 -= lr * grad_W1
self.b1 -= lr * grad_b1
self.W2 -= lr * grad_W2
self.b2 -= lr * grad_b2
return losses[active].mean()
def mine_semi_hard_triplets(embeddings, labels, margin=0.2, n_triplets=32):
"""Mine semi-hard triplets from a batch of embeddings."""
N = len(embeddings)
dists = cdist(embeddings, embeddings, 'sqeuclidean')
anchors, positives, negatives = [], [], []
for _ in range(n_triplets * 5): # oversample then trim
a = np.random.randint(N)
same = np.where((labels == labels[a]) & (np.arange(N) != a))[0]
if len(same) == 0:
continue
p = np.random.choice(same)
d_ap = dists[a, p]
diff = np.where(labels != labels[a])[0]
# Semi-hard: d_ap < d_an < d_ap + margin
semi_hard = diff[(dists[a, diff] > d_ap) & (dists[a, diff] < d_ap + margin)]
if len(semi_hard) > 0:
n = np.random.choice(semi_hard)
else:
# Fall back to hardest negative
n = diff[np.argmin(dists[a, diff])]
anchors.append(a)
positives.append(p)
negatives.append(n)
if len(anchors) >= n_triplets:
break
return np.array(anchors), np.array(positives), np.array(negatives)
# Train the embedding network
np.random.seed(123)
net = SimpleEmbeddingNet(D_in=D, D_hidden=32, D_embed=2)
n_epochs = 300
margin_dml = 0.5
lr = 0.01
losses_history = []
for epoch in range(n_epochs):
embeddings = net.forward(X_train)
idx_a, idx_p, idx_n = mine_semi_hard_triplets(
embeddings, y_train, margin=margin_dml, n_triplets=64
)
if len(idx_a) == 0:
losses_history.append(0)
continue
loss = net.backward_triplet(idx_a, idx_p, idx_n, margin=margin_dml, lr=lr)
losses_history.append(loss)
print(f'Final triplet loss: {losses_history[-1]:.4f}')
# KNN in embedding space
emb_train = net.forward(X_train)
emb_test = net.forward(X_test)
dists_dml = cdist(emb_test, emb_train, 'euclidean')
preds_dml = []
for i in range(len(emb_test)):
nn_idx = np.argsort(dists_dml[i])[:5]
counts = Counter(y_train[nn_idx])
preds_dml.append(counts.most_common(1)[0][0])
acc_dml = np.mean(np.array(preds_dml) == y_test)
print(f'\nKNN (K=5) accuracy comparison:')
print(f' Euclidean: {acc_euclidean:.1%}')
print(f' LMNN: {acc_lmnn:.1%}')
print(f' NCA: {acc_nca:.1%}')
print(f' DML (MLP): {acc_dml:.1%}')
Final triplet loss: 0.2341

KNN (K=5) accuracy comparison:
  Euclidean: 83.3%
  LMNN: 98.3%
  NCA: 98.3%
  DML (MLP): 86.7%
fig, axes = plt.subplots(1, 2, figsize=(13, 5))
# Left: training loss
axes[0].plot(losses_history, 'b-', alpha=0.6, lw=1)
# Smoothed
window = 10
smoothed = np.convolve(losses_history, np.ones(window)/window, mode='valid')
axes[0].plot(np.arange(window-1, n_epochs), smoothed, 'r-', lw=2, label='Smoothed')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Triplet Loss')
axes[0].set_title('Training Loss')
axes[0].legend()
# Right: 2D embedding space
emb_train = net.forward(X_train)
for c in range(C):
mask = y_train == c
axes[1].scatter(emb_train[mask, 0], emb_train[mask, 1], c=colors[c],
label=moa_names[c], alpha=0.6, s=30, edgecolors='w', linewidths=0.3)
axes[1].set_xlabel('Embedding dim 0')
axes[1].set_ylabel('Embedding dim 1')
axes[1].set_title('DML Embedding Space (2D)')
axes[1].legend(fontsize=8)
plt.tight_layout()
plt.show()
8. Comparison of All Methods¶
We compare all learned metrics by visualizing the pairwise distance distributions for same-class vs. different-class pairs.
def compute_pair_distances(X_emb, y):
"""Compute distances for same-class and different-class pairs."""
dists = cdist(X_emb, X_emb, 'euclidean')
same_mask = (y[:, None] == y[None, :]) & ~np.eye(len(y), dtype=bool)
diff_mask = (y[:, None] != y[None, :])
return dists[same_mask], dists[diff_mask]
methods = {
'Euclidean': X_train,
'LMNN': X_train @ W_lmnn.T,
'NCA': X_train @ W_nca.T,
'DML (Triplet)': net.forward(X_train),
}
fig, axes = plt.subplots(1, 4, figsize=(18, 4))
for ax, (name, emb) in zip(axes, methods.items()):
d_same, d_diff = compute_pair_distances(emb, y_train)
ax.hist(d_same, bins=40, alpha=0.6, density=True, color='#377eb8', label='Same MoA')
ax.hist(d_diff, bins=40, alpha=0.6, density=True, color='#e41a1c', label='Diff MoA')
ax.set_title(name)
ax.set_xlabel('Pairwise distance')
ax.legend(fontsize=8)
axes[0].set_ylabel('Density')
plt.suptitle('Pairwise Distance Distributions: Same-MoA vs. Different-MoA',
fontsize=13, y=1.02)
plt.tight_layout()
plt.show()
print('\nFinal KNN (K=5) accuracy summary:')
print(f' Euclidean: {acc_euclidean:.1%}')
print(f' Oracle (3 dims): {acc_oracle:.1%}')
print(f' LMNN: {acc_lmnn:.1%}')
print(f' NCA: {acc_nca:.1%}')
print(f' DML (Triplet MLP): {acc_dml:.1%}')
Final KNN (K=5) accuracy summary:
  Euclidean: 83.3%
  Oracle (3 dims): 96.7%
  LMNN: 98.3%
  NCA: 98.3%
  DML (Triplet MLP): 86.7%
Summary¶
| Method | Key Idea | Complexity |
|---|---|---|
| Euclidean | Equal weight to all dimensions | $O(ND)$ |
| LMNN | Learn $M$ via pull/push losses with margin | Convex in $M$, $O(N^3)$ naively |
| NCA | Soft nearest-neighbor probability, maximize LOO accuracy | Smooth, $O(N^2)$ |
| Contrastive | Pair loss: pull same, push different beyond margin | $O(N^2)$ pairs |
| Triplet | Anchor-positive-negative: relative ordering with margin | $O(N^3)$ triplets |
| N-pairs | Softmax over 1 positive + $N-1$ negatives (= InfoNCE) | $O(NB)$ per batch |
All methods learn to separate same-class from different-class pairs in distance/similarity space. The shift from linear (Mahalanobis) to deep (DNN embedding) methods allows learning nonlinear transformations that capture complex semantic similarity — critical for high-dimensional inputs like molecular structures, cell images, or protein sequences.