Exception Handling for TensorFlow Deep Learning Training: End-to-End Monitoring from Data Preprocessing to Model Evaluation

Felicity967 · 2026-03-13

Introduction

In deep learning projects, model training is a complex and error-prone process. TensorFlow, as a mainstream deep learning framework, offers powerful functionality, yet all kinds of exceptions can still occur in practice. They may appear at any stage: data preprocessing, the training loop itself, model saving, or evaluation.

Building a robust deep learning system requires not only a good model architecture and algorithm design, but also a solid exception-handling mechanism that keeps the training process stable and reliable. This article walks through exception-handling strategies for a TensorFlow project end to end, covering every key stage from data loading to model evaluation, and provides practical monitoring and fault-tolerance implementations.

Exception Handling in the Data Preprocessing Stage

Detecting and Handling Data Loading Exceptions

Data is the foundation of deep learning training, but in real projects data quality issues frequently cause training to fail. Common data loading exceptions include wrong file paths, mismatched data formats, and out-of-memory conditions.

import tensorflow as tf
import logging
import os
from pathlib import Path

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def safe_data_loading(data_path, batch_size=32):
    """
    Load data safely, with exception handling built in.
    """
    try:
        # Check that the data path exists
        if not os.path.exists(data_path):
            raise FileNotFoundError(f"Data path does not exist: {data_path}")
        
        # Validate the file extension
        if not data_path.endswith(('.tfrecord', '.csv', '.h5')):
            logger.warning(f"Data format may be unsupported: {data_path}")
        
        # Build the dataset pipeline
        dataset = tf.data.TFRecordDataset(data_path)
        dataset = dataset.map(parse_function, num_parallel_calls=tf.data.AUTOTUNE)
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)
        
        # Note: len() is undefined for a TFRecord-backed dataset whose
        # cardinality is unknown, so log the source path instead of a count
        logger.info(f"Dataset loaded successfully from: {data_path}")
        return dataset
        
    except FileNotFoundError as e:
        logger.error(f"File not found: {e}")
        raise
    except tf.errors.InvalidArgumentError as e:
        logger.error(f"Invalid data format: {e}")
        raise
    except Exception as e:
        logger.error(f"Data loading failed: {e}")
        raise

def parse_function(example_proto):
    """
    Parse a single TFRecord example. Note that dataset.map traces this
    function into a graph, so the try/except below only catches errors
    raised at trace time, not per-record runtime failures.
    """
    try:
        # Feature schema of each serialized example
        feature_description = {
            'image': tf.io.FixedLenFeature([], tf.string),
            'label': tf.io.FixedLenFeature([], tf.int64)
        }
        
        # Parse the serialized example
        parsed_example = tf.io.parse_single_example(example_proto, feature_description)
        
        # Decode and normalize the image; expand_animations=False keeps the
        # output rank at 3 even for animated formats such as GIF
        image = tf.io.decode_image(parsed_example['image'], channels=3,
                                   expand_animations=False)
        image = tf.cast(image, tf.float32) / 255.0
        
        label = parsed_example['label']
        
        return image, label
        
    except Exception as e:
        logger.error(f"Failed to parse example: {e}")
        # Re-raise; corrupt records can also be dropped at the pipeline
        # level, as sketched below
        raise
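
Because exceptions raised inside a traced map function cannot be caught with an ordinary Python try/except at runtime, a more reliable way to survive corrupt records is to drop them at the pipeline level. A minimal sketch using tf.data.experimental.ignore_errors(), which silently discards any element whose transformation raised an error (so pair it with monitoring of the surviving sample count):

# Sketch: skip records that fail to parse instead of aborting the pipeline
dataset = tf.data.TFRecordDataset(data_path)
dataset = dataset.map(parse_function, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.apply(tf.data.experimental.ignore_errors())
dataset = dataset.batch(32).prefetch(tf.data.AUTOTUNE)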

Data Validation and Cleaning

Data quality directly affects training results, so strict validation is needed during preprocessing:

def validate_dataset(dataset, expected_shape=(224, 224, 3)):
    """
    Validate dataset quality on a sample batch.
    """
    try:
        # Inspect the first batch only
        for batch in dataset.take(1):
            images, labels = batch
            
            # Check the rank (batch, height, width, channels)
            if len(images.shape) != 4:
                raise ValueError(f"Wrong image rank: expected 4, got {len(images.shape)}")
            
            # Check the per-image shape against the expected shape
            if tuple(images.shape[1:]) != tuple(expected_shape):
                raise ValueError(f"Unexpected image shape: expected {expected_shape}, got {tuple(images.shape[1:])}")
            
            # Check the dtype
            if images.dtype != tf.float32:
                logger.warning(f"Image dtype is not float32: {images.dtype}")
            
            # Check for NaN or infinite values
            has_nan = tf.reduce_any(tf.math.is_nan(images))
            has_inf = tf.reduce_any(tf.math.is_inf(images))
            
            if has_nan or has_inf:
                raise ValueError("NaN or infinite values detected in the data")
            
            logger.info(f"Data validation passed, image shape: {images.shape}")
            return True
            
    except Exception as e:
        logger.error(f"Data validation failed: {e}")
        raise

def clean_dataset(dataset):
    """
    Clean the dataset by filtering out invalid samples. The predicate is
    traced into a graph, so it must return a boolean tensor; a Python
    try/except inside it would not catch per-sample runtime errors.
    """
    def filter_function(image, label):
        # Check the shape
        image_shape = tf.shape(image)
        valid_shape = tf.reduce_all(tf.equal(image_shape, [224, 224, 3]))
        
        # Check the value range
        valid_range = tf.reduce_all(tf.logical_and(
            tf.greater_equal(image, 0.0),
            tf.less_equal(image, 1.0)
        ))
        
        return tf.logical_and(valid_shape, valid_range)
    
    return dataset.filter(filter_function)
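
Putting the three helpers together, a hypothetical preprocessing entry point might look like the following; "train.tfrecord" is a placeholder path. Note that filter_function operates per sample, so the batched dataset is unbatched before cleaning and rebatched afterwards:

# Hypothetical usage of the helpers above; "train.tfrecord" is a placeholder
try:
    dataset = safe_data_loading("train.tfrecord", batch_size=32)
    dataset = clean_dataset(dataset.unbatch()).batch(32)  # filter works per sample
    validate_dataset(dataset)
except Exception as exc:
    logger.error(f"Preprocessing pipeline failed: {exc}")
    raise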

Exception Handling During Model Training

Designing a Robust Training Loop

During training, the loop itself can fail for many reasons, so designing a robust training loop is essential:

import time
from datetime import datetime

class RobustTrainer:
    def __init__(self, model, optimizer, loss_fn, metrics):
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn
        self.metrics = metrics
        self.checkpoint_manager = None
        
    def train_step(self, x_batch, y_batch):
        """
        Run a single training step, with exception handling.
        """
        try:
            with tf.GradientTape() as tape:
                predictions = self.model(x_batch, training=True)
                loss = self.loss_fn(y_batch, predictions)
                
                # Check that the loss is finite
                if tf.math.is_nan(loss) or tf.math.is_inf(loss):
                    raise ValueError(f"Abnormal training loss: {loss}")
                
            gradients = tape.gradient(loss, self.model.trainable_variables)
            
            # Clip gradients to prevent explosion; keep None entries so the
            # gradient list stays aligned with trainable_variables
            gradients = [tf.clip_by_norm(g, 1.0) if g is not None else None
                         for g in gradients]
            
            self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
            
            # Update metrics
            for metric in self.metrics:
                metric.update_state(y_batch, predictions)
                
            return loss
            
        except Exception as e:
            logger.error(f"Training step failed: {e}")
            # Log information about the offending batch
            logger.debug(f"Offending batch shape: {x_batch.shape if hasattr(x_batch, 'shape') else 'Unknown'}")
            raise
    
    def train_epoch(self, dataset, epoch_num):
        """
        Train a single epoch, with full error handling and monitoring.
        """
        try:
            start_time = time.time()
            total_loss = 0.0
            num_batches = 0
            
            logger.info(f"Starting epoch {epoch_num}")
            
            for batch_idx, (x_batch, y_batch) in enumerate(dataset):
                try:
                    loss = self.train_step(x_batch, y_batch)
                    total_loss += loss
                    num_batches += 1
                    
                    # Report progress every 100 batches
                    if batch_idx % 100 == 0:
                        logger.info(f"Epoch {epoch_num}, Batch {batch_idx}, Loss: {float(loss):.4f}")
                        
                except Exception as batch_error:
                    logger.error(f"Batch {batch_idx} failed: {batch_error}")
                    # Skip the batch and keep training; aborting is the other option
                    continue
                    
            avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
            
            end_time = time.time()
            logger.info(f"Epoch {epoch_num} finished, avg loss: {float(avg_loss):.4f}, time: {end_time - start_time:.2f}s")
            
            return avg_loss
            
        except KeyboardInterrupt:
            logger.warning("Training interrupted by user")
            raise
        except Exception as e:
            logger.error(f"Epoch {epoch_num} failed: {e}")
            raise
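
One design note: clipping each gradient tensor independently, as train_step does, changes the relative scale between layers. If that matters, tf.clip_by_global_norm rescales all gradients jointly (preserving their direction) and passes None entries through unchanged, keeping variable alignment. A minimal drop-in sketch for the corresponding lines of train_step:

gradients = tape.gradient(loss, self.model.trainable_variables)
# Rescale all gradients jointly so their global norm is at most 1.0
gradients, global_norm = tf.clip_by_global_norm(gradients, clip_norm=1.0)
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))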

def robust_training_loop(model, train_dataset, val_dataset, epochs=10):
    """
    A robust training loop.
    """
    # Optimizer and loss function
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
    
    # Metrics
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    val_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
    
    # Trainer
    trainer = RobustTrainer(model, optimizer, loss_fn, [train_accuracy])
    
    # Training history
    training_history = {
        'epoch': [],
        'train_loss': [],
        'val_loss': [],
        'train_acc': [],
        'val_acc': []
    }
    
    try:
        for epoch in range(epochs):
            try:
                # Training phase
                train_loss = trainer.train_epoch(train_dataset, epoch + 1)
                
                # Validation phase
                val_loss, val_acc = validate_model(model, val_dataset, loss_fn, [val_accuracy])
                
                # Record history
                training_history['epoch'].append(epoch + 1)
                training_history['train_loss'].append(float(train_loss))
                training_history['val_loss'].append(float(val_loss))
                training_history['train_acc'].append(float(train_accuracy.result()))
                training_history['val_acc'].append(float(val_acc))
                
                # Reset the training metric so each epoch is reported independently
                train_accuracy.reset_state()
                
                logger.info(f"Epoch {epoch + 1} done - train loss: {float(train_loss):.4f}, "
                          f"val loss: {float(val_loss):.4f}, val accuracy: {float(val_acc):.4f}")
                
                # Early-stopping check (one possible helper is sketched below)
                if check_early_stopping(training_history):
                    logger.info("Stopping condition met, terminating training early")
                    break
                    
            except Exception as e:
                logger.error(f"Epoch {epoch + 1} failed: {e}")
                # Decide whether training should continue
                # (continue_training_after_error is a user-provided policy hook)
                if not continue_training_after_error(e):
                    raise
                else:
                    logger.info("Continuing training...")
                    
    except KeyboardInterrupt:
        logger.warning("Training interrupted by user, saving current model state")
        save_checkpoint(model, "interrupted_checkpoint")
        raise
    except Exception as e:
        logger.error(f"Fatal error during training: {e}")
        # Persist the failure state (save_error_state is a user-provided hook)
        save_error_state(model, training_history, str(e))
        raise
        
    return training_history
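
The loop above calls check_early_stopping without defining it. A minimal patience-based sketch (the patience value of 3 is an assumption, not part of the original code):

def check_early_stopping(history, patience=3):
    """
    One possible early-stopping rule: stop when the validation loss has
    not improved for `patience` consecutive epochs.
    """
    val_losses = history['val_loss']
    if len(val_losses) <= patience:
        return False
    best_before = min(val_losses[:-patience])
    # Stop if none of the last `patience` epochs beat the earlier best
    return all(loss >= best_before for loss in val_losses[-patience:])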

def validate_model(model, dataset, loss_fn, metrics):
    """
    Validate the model, with exception handling.
    """
    try:
        total_loss = 0.0
        num_batches = 0
        
        for x_batch, y_batch in dataset:
            predictions = model(x_batch, training=False)
            loss = loss_fn(y_batch, predictions)
            
            total_loss += loss
            num_batches += 1
            
            # Update metrics
            for metric in metrics:
                metric.update_state(y_batch, predictions)
                
        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        
        # Read the final metric values
        final_metrics = [metric.result() for metric in metrics]
        
        # Reset the metric states
        for metric in metrics:
            metric.reset_state()
            
        return avg_loss, final_metrics[0] if final_metrics else 0.0
        
    except Exception as e:
        logger.error(f"Model validation failed: {e}")
        raise

Monitoring Gradients and Numerical Stability

Exploding and vanishing gradients are common failure modes in deep learning training:

class GradientMonitor:
    """
    Gradient monitor for detecting numerical stability problems during training.
    """
    
    def __init__(self, model):
        self.model = model
        self.gradient_history = []
        
    def monitor_gradients(self, gradients, variables):
        """
        Track gradient norms per variable.
        """
        gradient_norms = []
        
        for grad, var in zip(gradients, variables):
            if grad is not None:
                norm = tf.norm(grad)
                gradient_norms.append(norm)
                
                # Check for abnormal gradients
                if tf.math.is_nan(norm) or tf.math.is_inf(norm):
                    logger.error(f"Abnormal gradient detected: {var.name}, norm: {norm}")
                    raise ValueError(f"Abnormal gradient: {var.name}")
                    
                # Record the gradient (note: this list grows without bound;
                # cap or rotate it for long training runs)
                self.gradient_history.append({
                    'variable': var.name,
                    'norm': float(norm),
                    'timestamp': datetime.now()
                })
                
        return gradient_norms
    
    def check_gradient_stability(self, epoch):
        """
        Check gradient stability over the most recent records.
        """
        if len(self.gradient_history) < 10:
            return True
            
        # Mean and standard deviation of the 10 most recent gradient norms
        recent_gradients = [item['norm'] for item in self.gradient_history[-10:]]
        avg_grad = sum(recent_gradients) / len(recent_gradients)
        std_grad = (sum((x - avg_grad)**2 for x in recent_gradients) / len(recent_gradients))**0.5
        
        # A very large standard deviation suggests exploding gradients
        if std_grad > 10 * avg_grad and avg_grad > 1e-6:
            logger.warning(f"Unstable gradients detected - mean: {avg_grad:.6f}, std: {std_grad:.6f}")
            return False
            
        return True
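
Beyond log messages, the recorded norms can also be written to TensorBoard for visual inspection. A minimal sketch (the log directory "./logs/gradients" and the per-index naming scheme are arbitrary choices):

writer = tf.summary.create_file_writer("./logs/gradients")

def log_gradient_norms(gradient_norms, step):
    # Write one scalar per variable; view with `tensorboard --logdir ./logs`
    with writer.as_default():
        for i, norm in enumerate(gradient_norms):
            tf.summary.scalar(f"grad_norm/var_{i}", norm, step=step)
        writer.flush()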

def advanced_training_with_monitoring(model, dataset, epochs=10):
    """
    Training function with gradient monitoring.
    """
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
    
    # Set up the monitor
    gradient_monitor = GradientMonitor(model)
    
    for epoch in range(epochs):
        logger.info(f"Starting epoch {epoch + 1}")
        
        epoch_loss = 0.0
        num_batches = 0
        avg_loss = 0.0  # defined up front so the except block below can use it
        
        try:
            for batch_idx, (x_batch, y_batch) in enumerate(dataset):
                with tf.GradientTape() as tape:
                    predictions = model(x_batch, training=True)
                    loss = loss_fn(y_batch, predictions)
                    
                    # Skip the batch if the loss is not finite
                    if tf.math.is_nan(loss) or tf.math.is_inf(loss):
                        logger.error(f"Abnormal loss value: {loss}")
                        continue
                        
                gradients = tape.gradient(loss, model.trainable_variables)
                
                # Monitor gradients
                grad_norms = gradient_monitor.monitor_gradients(gradients, model.trainable_variables)
                
                # Clip gradients, keeping None entries aligned with the variables
                gradients = [tf.clip_by_norm(g, 1.0) if g is not None else None
                             for g in gradients]
                
                optimizer.apply_gradients(zip(gradients, model.trainable_variables))
                
                epoch_loss += loss
                num_batches += 1
                
                # Check stability every 100 batches
                if batch_idx % 100 == 0:
                    stability_ok = gradient_monitor.check_gradient_stability(epoch)
                    if not stability_ok:
                        logger.warning("Unstable gradients; consider lowering the learning rate or adding regularization")
                        
            avg_loss = epoch_loss / num_batches if num_batches > 0 else 0.0
            logger.info(f"Epoch {epoch + 1} finished, avg loss: {float(avg_loss):.4f}")
            
        except Exception as e:
            logger.error(f"Exception during training: {e}")
            # Save the current state before propagating
            # (one possible save_training_state is sketched below)
            save_training_state(model, epoch, avg_loss)
            raise
            
    return model
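
save_training_state is referenced above but not defined. A minimal placeholder implementation (the directory layout under ./error_states is an assumption):

import json  # also imported in the checkpointing section below

def save_training_state(model, epoch, loss):
    """
    Placeholder state-saving hook: dump the weights plus a small JSON
    description of where training stopped.
    """
    state_dir = Path(f"./error_states/epoch_{epoch}")
    state_dir.mkdir(parents=True, exist_ok=True)
    model.save_weights(str(state_dir / "model_weights.h5"))
    with open(state_dir / "state.json", "w") as f:
        json.dump({'epoch': epoch, 'loss': float(loss),
                   'timestamp': datetime.now().isoformat()}, f)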

Model Saving and Recovery

A Complete Checkpointing Strategy

Saving the model is a critical step in training, and several failure modes need to be considered:

import pickle
import json

class ModelCheckpointManager:
    """
    Checkpoint manager providing complete save and restore mechanics.
    """
    
    def __init__(self, model, checkpoint_dir="./checkpoints", save_freq="epoch"):
        self.model = model
        self.checkpoint_dir = Path(checkpoint_dir)
        self.save_freq = save_freq
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        
        # Checkpoint bookkeeping
        self.checkpoint_info = {
            'checkpoints': [],
            'best_loss': float('inf'),
            'best_checkpoint': None,
            'last_saved': None
        }
        
    def save_checkpoint(self, epoch, loss, accuracy=None):
        """
        Save a model checkpoint.
        """
        try:
            # Create the checkpoint directory
            checkpoint_path = self.checkpoint_dir / f"epoch_{epoch}"
            checkpoint_path.mkdir(exist_ok=True)
            
            # Save the model weights
            model_path = checkpoint_path / "model_weights.h5"
            self.model.save_weights(str(model_path))
            
            # Save the model config
            config_path = checkpoint_path / "model_config.json"
            model_config = self.model.get_config()
            with open(config_path, 'w') as f:
                json.dump(model_config, f)
                
            # Save the training state
            state_path = checkpoint_path / "training_state.pkl"
            training_state = {
                'epoch': epoch,
                'loss': float(loss),
                'accuracy': float(accuracy) if accuracy is not None else None,
                'timestamp': datetime.now().isoformat()
            }
            
            with open(state_path, 'wb') as f:
                pickle.dump(training_state, f)
                
            # Update the bookkeeping
            checkpoint_info = {
                'path': str(checkpoint_path),
                'epoch': epoch,
                'loss': float(loss),
                'accuracy': float(accuracy) if accuracy is not None else None,
                'timestamp': datetime.now().isoformat()
            }
            
            self.checkpoint_info['checkpoints'].append(checkpoint_info)
            
            # Track the best checkpoint
            if loss < self.checkpoint_info['best_loss']:
                self.checkpoint_info['best_loss'] = loss
                self.checkpoint_info['best_checkpoint'] = checkpoint_info
                
            self.checkpoint_info['last_saved'] = checkpoint_info
            
            logger.info(f"Checkpoint saved: {checkpoint_path}")
            
        except Exception as e:
            logger.error(f"Checkpoint save failed: {e}")
            raise
            
    def load_checkpoint(self, checkpoint_path):
        """
        Load a specific checkpoint.
        """
        try:
            # Load the model weights
            weights_path = Path(checkpoint_path) / "model_weights.h5"
            if weights_path.exists():
                self.model.load_weights(str(weights_path))
                logger.info(f"Weights loaded: {weights_path}")
            else:
                raise FileNotFoundError(f"Weights file missing: {weights_path}")
                
            # Load the training state
            state_path = Path(checkpoint_path) / "training_state.pkl"
            if state_path.exists():
                with open(state_path, 'rb') as f:
                    training_state = pickle.load(f)
                logger.info(f"Training state loaded: {training_state}")
                
            return True
            
        except Exception as e:
            logger.error(f"Checkpoint load failed: {e}")
            raise
            
    def save_best_model(self, model_path="./best_model"):
        """
        Export the best model seen so far.
        """
        try:
            if self.checkpoint_info['best_checkpoint']:
                best_path = Path(self.checkpoint_info['best_checkpoint']['path'])
                
                # Copy the best checkpoint's weights to the export location
                final_weights_path = Path(model_path) / "model_weights.h5"
                final_weights_path.parent.mkdir(parents=True, exist_ok=True)
                
                import shutil
                shutil.copy2(
                    str(best_path / "model_weights.h5"),
                    str(final_weights_path)
                )
                
                logger.info(f"Best model exported: {final_weights_path}")
                return True
                
        except Exception as e:
            logger.error(f"Best model export failed: {e}")
            raise
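
For comparison, TensorFlow also ships a built-in checkpointing facility that additionally captures optimizer state, which the manual manager above does not. A minimal sketch (the directory and max_to_keep values are arbitrary):

# tf.train.Checkpoint tracks both model and optimizer variables
checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
manager = tf.train.CheckpointManager(checkpoint, directory="./tf_checkpoints",
                                     max_to_keep=3)

save_path = manager.save()                     # call once per epoch
checkpoint.restore(manager.latest_checkpoint)  # restore the newest save, if any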

def safe_model_saving(model, checkpoint_manager, epoch, loss, accuracy=None):
    """
    Save the model safely and verify the result.
    """
    try:
        # Verify model integrity
        if not model.built:
            logger.warning("Model is not built yet; build it (e.g. by calling it on a batch) before saving")
            
        # Save the checkpoint
        checkpoint_manager.save_checkpoint(epoch, loss, accuracy)
        
        # Verify the saved files
        latest_checkpoint = checkpoint_manager.checkpoint_info['last_saved']
        if latest_checkpoint:
            checkpoint_path = Path(latest_checkpoint['path'])
            required_files = [
                checkpoint_path / "model_weights.h5",
                checkpoint_path / "model_config.json",
                checkpoint_path / "training_state.pkl"
            ]
            
            for file_path in required_files:
                if not file_path.exists():
                    raise FileNotFoundError(f"Checkpoint file missing: {file_path}")
                    
        logger.info(f"Checkpoint verified - Epoch {epoch}")
        
    except Exception as e:
        logger.error(f"Safe model save failed: {e}")
        # Retrying the save or raising an alert are both options here
        raise

Fault-Tolerant Training Resumption

Being able to resume correctly after an interruption is a key requirement in production:

def resume_training_from_checkpoint(model, checkpoint_manager, start_epoch=0):
    """
    Resume training from the most recent checkpoint.
    """
    try:
        # Find the most recent checkpoint
        latest_checkpoint = find_latest_checkpoint(checkpoint_manager.checkpoint_info)
        
        if latest_checkpoint:
            logger.info(f"Latest checkpoint found: {latest_checkpoint['path']}")
            
            # Load the checkpoint
            checkpoint_path = Path(latest_checkpoint['path'])
            model.load_weights(str(checkpoint_path / "model_weights.h5"))
            
            # Restore the training state
            training_state_path = checkpoint_path / "training_state.pkl"
            if training_state_path.exists():
                with open(training_state_path, 'rb') as f:
                    training_state = pickle.load(f)
                
                logger.info(f"Training state restored - Epoch {training_state['epoch']}")
                return training_state['epoch'] + 1
                
        else:
            logger.info("No usable checkpoint found, training from scratch")
            return start_epoch
            
    except Exception as e:
        logger.error(f"Resume failed: {e}")
        raise

def find_latest_checkpoint(checkpoint_info):
    """
    Find the most recent checkpoint by timestamp.
    """
    if not checkpoint_info['checkpoints']:
        return None
        
    # Sort by timestamp, newest first
    checkpoints = sorted(
        checkpoint_info['checkpoints'],
        key=lambda x: datetime.fromisoformat(x['timestamp']),
        reverse=True
    )
    
    return checkpoints[0]
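
One caveat: checkpoint_info lives only in the manager's memory, so find_latest_checkpoint cannot see checkpoints written by a previous process. A small sketch of persisting the index to disk closes that gap (the file name checkpoint_index.json is an assumption):

def persist_checkpoint_info(manager):
    # Write the in-memory index next to the checkpoints themselves
    index_path = manager.checkpoint_dir / "checkpoint_index.json"
    with open(index_path, 'w') as f:
        json.dump(manager.checkpoint_info, f)

def load_checkpoint_info(manager):
    # Reload the index left behind by a previous run, if any
    index_path = manager.checkpoint_dir / "checkpoint_index.json"
    if index_path.exists():
        with open(index_path) as f:
            manager.checkpoint_info = json.load(f)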

def robust_training_with_recovery(model, dataset, epochs=10):
    """
    A robust training flow with recovery support.
    """
    # Set up the checkpoint manager and optimizer; the optimizer must be
    # created once, outside the training loop, so its state accumulates
    checkpoint_manager = ModelCheckpointManager(model, "./checkpoints")
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
    
    try:
        # Check whether we can resume from a checkpoint
        start_epoch = resume_training_from_checkpoint(model, checkpoint_manager)
        
        if start_epoch == 0:
            logger.info("Starting a fresh training run")
        else:
            logger.info(f"Resuming training from epoch {start_epoch}")
            
        # Training loop
        for epoch in range(start_epoch, epochs):
            try:
                epoch_loss = 0.0
                num_batches = 0
                
                for x_batch, y_batch in dataset:
                    with tf.GradientTape() as tape:
                        predictions = model(x_batch, training=True)
                        loss = tf.keras.losses.sparse_categorical_crossentropy(y_batch, predictions)
                        
                    gradients = tape.gradient(loss, model.trainable_variables)
                    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
                    
                    epoch_loss += tf.reduce_mean(loss)
                    num_batches += 1
                    
                avg_loss = epoch_loss / num_batches
                
                # Save a checkpoint
                safe_model_saving(model, checkpoint_manager, epoch + 1, avg_loss)
                
                logger.info(f"Epoch {epoch + 1} finished, loss: {float(avg_loss):.4f}")
                
            except KeyboardInterrupt:
                logger.warning("Training interrupted by user")
                save_interrupted_state(model, checkpoint_manager, epoch)
                raise
            except Exception as e:
                logger.error(f"Epoch {epoch + 1} failed: {e}")
                # Attempt recovery (one possible helper is sketched below)
                if not attempt_recovery(model, checkpoint_manager):
                    raise
                    
        return model
        
    except Exception as e:
        logger.error(f"Fatal error during training: {e}")
        raise
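
attempt_recovery (like save_interrupted_state) is left as a hook above. One possible recovery policy, which is an assumption rather than part of the original code, rolls the model back to the best checkpoint recorded so far:

def attempt_recovery(model, checkpoint_manager):
    """
    Roll the model back to the best checkpoint seen so far and signal
    that training may continue; return False if there is nothing to restore.
    """
    best = checkpoint_manager.checkpoint_info.get('best_checkpoint')
    if best is None:
        return False
    try:
        checkpoint_manager.load_checkpoint(best['path'])
        logger.info(f"Rolled back to best checkpoint: {best['path']}")
        return True
    except Exception as exc:
        logger.error(f"Recovery attempt failed: {exc}")
        return False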

Exception Handling in the Model Evaluation Stage

Designing a Robust Evaluation Pass

Evaluation validates the training results and needs an equally thorough exception-handling mechanism:

def robust_model_evaluation(model, dataset, metrics=None):
    """
    Evaluate the model robustly.
    """
    try:
        # Default metrics
        if metrics is None:
            metrics = [
                tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                tf.keras.metrics.SparseCategoricalCrossentropy(name="loss")
            ]
        
        # Reset metric states
        for metric in metrics:
            metric.reset_state()
            
        total_loss = 0.0
        num_batches = 0
        
        logger.info("Starting model evaluation...")
        
        # Evaluation loop
        for batch_idx, (x_batch, y_batch) in enumerate(dataset):
            try:
                # Validate the input batch (helper defined below)
                if not validate_input_data(x_batch, y_batch):
                    logger.warning(f"Skipping invalid batch: {batch_idx}")
                    continue
                    
                # Forward pass
                predictions = model(x_batch, training=False)
                
                # Compute the loss
                loss = tf.keras.losses.sparse_categorical_crossentropy(y_batch, predictions)
                total_loss += tf.reduce_mean(loss)
                
                # Update all metrics; both the accuracy and crossentropy
                # metrics take (y_true, y_pred), so no special-casing is needed
                for metric in metrics:
                    metric.update_state(y_batch, predictions)
                        
                num_batches += 1
                
                if batch_idx % 50 == 0:
                    logger.info(f"Evaluation progress: batch {batch_idx}")
                    
            except Exception as batch_error:
                logger.error(f"Batch {batch_idx} evaluation failed: {batch_error}")
                # Skip the batch; aborting is the other option
                continue
                
        if num_batches == 0:
            raise ValueError("No valid evaluation batches")
            
        avg_loss = total_loss / num_batches
        
        # Read the final metric values
        results = {}
        for metric in metrics:
            results[metric.name] = metric.result().numpy()
            
        logger.info(f"Evaluation finished - loss: {float(avg_loss):.4f}")
        
        return results
        
    except Exception as e:
        logger.error(f"Model evaluation failed: {e}")
        raise

def validate_input_data(x_batch, y_batch):
    """
    Check that an input batch is usable: a 4-D image tensor, a 1-D label
    tensor with a matching batch size, and finite values throughout.
    """
    try:
        # Check ranks: 4-D images, 1-D integer labels
        if len(x_batch.shape) != 4 or len(y_batch.shape) != 1:
            return False
        
        # Batch sizes must match
        if x_batch.shape[0] != y_batch.shape[0]:
            return False
        
        # Reject NaN or infinite values
        if tf.reduce_any(tf.math.is_nan(x_batch)) or tf.reduce_any(tf.math.is_inf(x_batch)):
            return False
        
        return True
        
    except Exception:
        return False
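
Finally, a hypothetical end-to-end call of the evaluation helper; `model` and `test_dataset` are placeholders assumed to exist already:

# Hypothetical usage of robust_model_evaluation
results = robust_model_evaluation(model, test_dataset)
for name, value in results.items():
    logger.info(f"{name}: {value:.4f}")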