基于孪生网络的图像文本匹配实现

Charlie758 +0/-0 0 0 正常 2025-12-24T07:01:19

基于孪生网络的图像文本匹配实现

数据预处理流程

首先对图像数据进行标准化处理:

import torch
from torchvision import transforms

class ImagePreprocessor:
    def __init__(self):
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    
    def preprocess(self, image):
        return self.transform(image)

文本数据使用BERT tokenizer进行编码:

from transformers import BertTokenizer

class TextPreprocessor:
    def __init__(self):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    
    def preprocess(self, text):
        return self.tokenizer(text, padding=True, truncation=True, max_length=128)

孪生网络架构设计

import torch.nn as nn
import torch.nn.functional as F

class SiameseNetwork(nn.Module):
    def __init__(self, embedding_dim=768):
        super().__init__()
        # 图像分支
        self.image_branch = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((7, 7))
        )
        
        # 文本分支
        self.text_branch = nn.LSTM(embedding_dim, 512, batch_first=True)
        
        # 联合特征提取
        self.fusion_layer = nn.Sequential(
            nn.Linear(128 * 7 * 7 + 512, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256)
        )
        
    def forward(self, image, text):
        # 图像特征提取
        img_features = self.image_branch(image).view(image.size(0), -1)
        
        # 文本特征提取
        text_output, _ = self.text_branch(text)
        text_features = text_output[:, -1, :]  # 取最后一个时间步
        
        # 特征拼接
        combined_features = torch.cat([img_features, text_features], dim=1)
        output = self.fusion_layer(combined_features)
        
        return F.normalize(output, p=2, dim=1)

训练策略

使用对比损失函数进行联合训练:

import torch.optim as optim
from torch.nn import CosineEmbeddingLoss

# 损失函数
loss_fn = CosineEmbeddingLoss(margin=0.3)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# 训练循环
for epoch in range(10):
    for batch in dataloader:
        image, text, labels = batch['image'], batch['text'], batch['label']
        features = model(image, text)
        loss = loss_fn(features[0], features[1], labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()```

该架构通过孪生网络结构,实现了图像与文本的联合特征提取,为多模态匹配任务提供了可复现的解决方案。
推广
广告位招租

讨论

0/2000
Piper146
Piper146 · 2026-01-08T10:24:58
这代码看起来很标准,但别忘了数据增强和过拟合风险。图像分支的卷积层太简单了,建议加个ResNet backbone,不然训练效果可能不理想。
HotStar
HotStar · 2026-01-08T10:24:58
BERT文本编码没问题,但注意句子长度截断会丢失语义信息。建议用更长序列或引入注意力机制,否则匹配精度会打折扣。
Max514
Max514 · 2026-01-08T10:24:58
孪生网络结构确实适合这种任务,但损失函数得小心设计。对比损失容易收敛慢,推荐试试三元组损失或者交叉熵,提升训练效率