import random

import torch
import torchvision.transforms as transforms


def compute_train_transform(seed=123456):
    random.seed(seed)
    torch.random.manual_seed(seed)

    # Transformation that applies color jitter with brightness=0.4,
    # contrast=0.4, saturation=0.4, and hue=0.1.
    color_jitter = transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
    train_transform = transforms.Compose([
        # Step 1: Randomly resize and crop to 32x32.
        transforms.RandomResizedCrop(32),
        # Step 2: Horizontally flip the image with probability 0.5.
        transforms.RandomHorizontalFlip(p=0.5),
        # Step 3: With a probability of 0.8, apply color jitter
        # (using "color_jitter" defined above).
        transforms.RandomApply([color_jitter], p=0.8),
        # Step 4: With a probability of 0.2, convert the image to grayscale.
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        # Normalize with the CIFAR-10 per-channel means and standard deviations.
        transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
    ])
    return train_transform
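# A minimal usage sketch: the 32x32 crop and the normalization statistics above
# match CIFAR-10, so the transform is presumably intended for that dataset.
# The dataset choice and root path here are illustrative assumptions.
import torchvision

train_transform = compute_train_transform(seed=123456)
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                         download=True, transform=train_transform)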
def sim(z_i, z_j):
    """Normalized dot product (cosine similarity) between vectors z_i and z_j."""
    norm_dot_product = torch.dot(z_i, z_j) / (torch.linalg.norm(z_i) * torch.linalg.norm(z_j))
    return norm_dot_product


def simclr_loss_naive(out_left, out_right, tau):
    """SimCLR loss computed with explicit loops.

    For each positive pair (k, k+N), the per-pair loss is
    l(k, k+N) = -log( exp(sim(z_k, z_{k+N}) / tau)
                      / sum_{i != k} exp(sim(z_k, out[i]) / tau) ).
    """
    N = out_left.shape[0]  # number of images; the batch holds 2*N augmented views
    # Concatenate out_left and out_right into a 2*N x D tensor.
    out = torch.cat([out_left, out_right], dim=0)  # [2*N, D]
    total_loss = 0
    for k in range(N):  # loop through each positive pair (k, k+N)
        z_k, z_k_N = out[k], out[k + N]
        exp_sum1, exp_sum2 = 0, 0
        for i in range(2 * N):
            if i != k:
                exp_sum1 += torch.exp(sim(z_k, out[i]) / tau)
            if i != k + N:
                exp_sum2 += torch.exp(sim(z_k_N, out[i]) / tau)
        total_loss += -torch.log(torch.exp(sim(z_k, z_k_N) / tau) / exp_sum1)
        total_loss += -torch.log(torch.exp(sim(z_k_N, z_k) / tau) / exp_sum2)
    # In the end, divide the total loss by 2*N, the number of augmented samples in the batch.
    total_loss = total_loss / (2 * N)
    return total_loss
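# compute_sim_matrix and sim_positive_pairs are called by the vectorized loss
# below but are not defined in this section. The following are minimal sketches,
# assuming both compute cosine similarities consistent with sim() above; the
# actual implementations may differ.

def compute_sim_matrix(out):
    # [2*N, 2*N] matrix of pairwise cosine similarities between all rows of out.
    out_norm = out / torch.linalg.norm(out, dim=1, keepdim=True)
    return out_norm @ out_norm.T


def sim_positive_pairs(out_left, out_right):
    # [N, 1] column of cosine similarities between each positive pair
    # (row i of out_left and row i of out_right).
    left_norm = out_left / torch.linalg.norm(out_left, dim=1, keepdim=True)
    right_norm = out_right / torch.linalg.norm(out_right, dim=1, keepdim=True)
    return torch.sum(left_norm * right_norm, dim=1, keepdim=True)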
def simclr_loss_vectorized(out_left, out_right, tau, device='cuda'):
    N = out_left.shape[0]

    # Concatenate out_left and out_right into a 2*N x D tensor.
    out = torch.cat([out_left, out_right], dim=0)  # [2*N, D]

    # Compute the similarity matrix between all pairs of augmented examples in the batch.
    sim_matrix = compute_sim_matrix(out)  # [2*N, 2*N]

    # Step 1: Use sim_matrix to compute the denominator value for all augmented samples.
    # Compute e^{sim / tau} and store it in exponential, which has shape [2*N, 2*N].
    exponential = torch.exp(sim_matrix / tau)
    # This binary mask zeros out the terms where k = i.
    mask = (torch.ones_like(exponential, device=device) - torch.eye(2 * N, device=device)).bool()
    # Apply the binary mask, dropping the diagonal entries.
    exponential = exponential.masked_select(mask).view(2 * N, -1)  # [2*N, 2*N-1]
    # The denominator values for all augmented samples, as a [2*N, 1] vector.
    denom = torch.sum(exponential, dim=1, keepdim=True)

    # Step 2: Compute the similarity between the positive pairs.
    # You can do this in two ways:
    # Option 1: Extract the corresponding indices from sim_matrix.
    # Option 2: Use sim_positive_pairs().
    # Option 2 is used below.
    sim_pairs = sim_positive_pairs(out_left, out_right)  # [N, 1]
    # Each pair appears twice among the 2*N samples, as (k, k+N) and (k+N, k);
    # cosine similarity is symmetric, so the pair similarities are duplicated.
    sim_pairs = torch.cat([sim_pairs, sim_pairs], dim=0)  # [2*N, 1]

    # Step 3: Compute the numerator value for all augmented samples.
    numerator = torch.exp(sim_pairs / tau)

    # Step 4: With the numerator and denominator for all augmented samples,
    # compute the total loss as the mean over the 2*N samples.
    loss = torch.mean(-torch.log(numerator / denom))
    return loss
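# Sanity check (illustrative, not part of the original): on random embeddings
# the naive and vectorized losses should agree to within numerical precision.
torch.manual_seed(0)
left, right = torch.randn(4, 8), torch.randn(4, 8)
loss_naive = simclr_loss_naive(left, right, tau=0.5)
loss_vec = simclr_loss_vectorized(left, right, tau=0.5, device='cpu')
print(loss_naive.item(), loss_vec.item())  # the two values should match closely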