将人脸伪造判别分类网络(Face Forgery Detection Classification Network)设计为一种卷积神经网络(CNN)涉及以下几个关键步骤:

### 1. 数据收集与预处理

- **数据收集**:收集大量的人脸图像数据,包括真实人……

摘要:作者:SkyXZ CSDN:SkyXZ~-CSDN博客 博客园:SkyXZ - 博客园 一、获取数据集 FaceForensics++ 是一个取证数据集,由1000段原始视频序列组成,这些视频通过四种自动
作者:SkyXZ CSDN:SkyXZ~-CSDN博客 博客园:SkyXZ - 博客园 一、获取数据集 FaceForensics++ 是一个取证数据集,由1000段原始视频序列组成,这些视频通过四种自动人脸操纵方法进行处理:Deepfakes、Face2Face、FaceSwap 和 NeuralTextures。数据来自 977 段 YouTube 视频,所有视频中都包含一张可跟踪的、主要为正面且没有遮挡的人脸,使得自动篡改方法能够生成逼真的伪造视频。同时,由于该数据集提供了二值掩码,这些数据可以用于图像和视频分类以及分割。此外,官方还提供了 1000 个 Deepfakes 模型,用于生成和扩充新数据。 原始论文:https://arxiv.org/abs/1901.08971 GitHub链接:https://github.com/ondyari/FaceForensics FaceForensics++数据集无法直接下载,需要按照要求填写谷歌表单来申请获取https://docs.google.com/forms/d/e/1FAIpQLSdRRR3L5zAv6tQ_CKxmK4W96tAab_pfBu2EKAgQbeDVhmXagg/viewform 等待几天之后会收到如下邮件,里面会附上数据集的下载Code,直接使用下载脚本下载即可获取: #!/usr/bin/env python """ Downloads FaceForensics++ and Deep Fake Detection public data release Example usage: see -h or https://github.com/ondyari/FaceForensics """ # -*- coding: utf-8 -*- import argparse import os import urllib import urllib.request import tempfile import time import sys import json import random from tqdm import tqdm from os.path import join # URLs and filenames FILELIST_URL = 'misc/filelist.json' DEEPFEAKES_DETECTION_URL = 'misc/deepfake_detection_filenames.json' DEEPFAKES_MODEL_NAMES = ['decoder_A.h5', 'decoder_B.h5', 'encoder.h5',] # Parameters DATASETS = { 'original_youtube_videos': 'misc/downloaded_youtube_videos.zip', 'original_youtube_videos_info': 'misc/downloaded_youtube_videos_info.zip', 'original': 'original_sequences/youtube', 'DeepFakeDetection_original': 'original_sequences/actors', 'Deepfakes': 'manipulated_sequences/Deepfakes', 'DeepFakeDetection': 'manipulated_sequences/DeepFakeDetection', 'Face2Face': 'manipulated_sequences/Face2Face', 'FaceShifter': 'manipulated_sequences/FaceShifter', 'FaceSwap': 'manipulated_sequences/FaceSwap', 'NeuralTextures': 'manipulated_sequences/NeuralTextures' } ALL_DATASETS = ['original', 'DeepFakeDetection_original', 'Deepfakes', 'DeepFakeDetection', 'Face2Face', 'FaceShifter', 'FaceSwap', 'NeuralTextures'] COMPRESSION = ['raw', 'c23', 'c40'] TYPE = ['videos', 'masks', 'models'] SERVERS = ['EU', 'EU2', 'CA'] def 
parse_args(): parser = argparse.ArgumentParser( description='Downloads FaceForensics v2 public data release.', formatter_class=argparse.ArgumentDefaultsHelpFormatter ) parser.add_argument('output_path', type=str, help='Output directory.') parser.add_argument('-d', '--dataset', type=str, default='all', help='Which dataset to download, either pristine or ' 'manipulated data or the downloaded youtube ' 'videos.', choices=list(DATASETS.keys()) + ['all'] ) parser.add_argument('-c', '--compression', type=str, default='raw', help='Which compression degree. All videos ' 'have been generated with h264 with a varying ' 'codec. Raw (c0) videos are lossless compressed.', choices=COMPRESSION ) parser.add_argument('-t', '--type', type=str, default='videos', help='Which file type, i.e. videos, masks, for our ' 'manipulation methods, models, for Deepfakes.', choices=TYPE ) parser.add_argument('-n', '--num_videos', type=int, default=None, help='Select a number of videos number to ' "download if you don't want to download the full" ' dataset.') parser.add_argument('--server', type=str, default='EU', help='Server to download the data from. If you ' 'encounter a slow download speed, consider ' 'changing the server.', choices=SERVERS ) args = parser.parse_args() # URLs server = args.server if server == 'EU': server_url = 'http://canis.vc.in.tum.de:8100/' elif server == 'EU2': server_url = 'http://kaldir.vc.in.tum.de/faceforensics/' elif server == 'CA': server_url = 'http://falas.cmpt.sfu.ca:8100/' else: raise Exception('Wrong server name. 
Choices: {}'.format(str(SERVERS))) args.tos_url = server_url + 'webpage/FaceForensics_TOS.pdf' args.base_url = server_url + 'v3/' args.deepfakes_model_url = server_url + 'v3/manipulated_sequences/' + \ 'Deepfakes/models/' return args def download_files(filenames, base_url, output_path, report_progress=True): os.makedirs(output_path, exist_ok=True) if report_progress: filenames = tqdm(filenames) for filename in filenames: download_file(base_url + filename, join(output_path, filename)) def reporthook(count, block_size, total_size): global start_time if count == 0: start_time = time.time() return duration = time.time() - start_time progress_size = int(count * block_size) speed = int(progress_size / (1024 * duration)) percent = int(count * block_size * 100 / total_size) sys.stdout.write("\rProgress: %d%%, %d MB, %d KB/s, %d seconds passed" % (percent, progress_size / (1024 * 1024), speed, duration)) sys.stdout.flush() def download_file(url, out_file, report_progress=False): out_dir = os.path.dirname(out_file) if not os.path.isfile(out_file): fh, out_file_tmp = tempfile.mkstemp(dir=out_dir) f = os.fdopen(fh, 'w') f.close() if report_progress: urllib.request.urlretrieve(url, out_file_tmp, reporthook=reporthook) else: urllib.request.urlretrieve(url, out_file_tmp) os.rename(out_file_tmp, out_file) else: tqdm.write('WARNING: skipping download of existing file ' + out_file) def main(args): # TOS print('By pressing any key to continue you confirm that you have agreed '\ 'to the FaceForensics terms of use as described at:') print(args.tos_url) print('***') print('Press any key to continue, or CTRL-C to exit.') _ = input('') # Extract arguments c_datasets = [args.dataset] if args.dataset != 'all' else ALL_DATASETS c_type = args.type c_compression = args.compression num_videos = args.num_videos output_path = args.output_path os.makedirs(output_path, exist_ok=True) # Check for special dataset cases for dataset in c_datasets: dataset_path = DATASETS[dataset] # Special cases if 
'original_youtube_videos' in dataset: # Here we download the original youtube videos zip file print('Downloading original youtube videos.') if not 'info' in dataset_path: print('Please be patient, this may take a while (~40gb)') suffix = '' else: suffix = 'info' download_file(args.base_url + '/' + dataset_path, out_file=join(output_path, 'downloaded_videos{}.zip'.format( suffix)), report_progress=True) return # Else: regular datasets print('Downloading {} of dataset "{}"'.format( c_type, dataset_path )) # Get filelists and video lenghts list from server if 'DeepFakeDetection' in dataset_path or 'actors' in dataset_path: filepaths = json.loads(urllib.request.urlopen(args.base_url + '/' + DEEPFEAKES_DETECTION_URL).read().decode("utf-8")) if 'actors' in dataset_path: filelist = filepaths['actors'] else: filelist = filepaths['DeepFakesDetection'] elif 'original' in dataset_path: # Load filelist from server file_pairs = json.loads(urllib.request.urlopen(args.base_url + '/' + FILELIST_URL).read().decode("utf-8")) filelist = [] for pair in file_pairs: filelist += pair else: # Load filelist from server file_pairs = json.loads(urllib.request.urlopen(args.base_url + '/' + FILELIST_URL).read().decode("utf-8")) # Get filelist filelist = [] for pair in file_pairs: filelist.append('_'.join(pair)) if c_type != 'models': filelist.append('_'.join(pair[::-1])) # Maybe limit number of videos for download if num_videos is not None and num_videos > 0: print('Downloading the first {} videos'.format(num_videos)) filelist = filelist[:num_videos] # Server and local paths dataset_videos_url = args.base_url + '{}/{}/{}/'.format( dataset_path, c_compression, c_type) dataset_mask_url = args.base_url + '{}/{}/videos/'.format( dataset_path, 'masks', c_type) if c_type == 'videos': dataset_output_path = join(output_path, dataset_path, c_compression, c_type) print('Output path: {}'.format(dataset_output_path)) filelist = [filename + '.mp4' for filename in filelist] download_files(filelist, 
dataset_videos_url, dataset_output_path) elif c_type == 'masks': dataset_output_path = join(output_path, dataset_path, c_type, 'videos') print('Output path: {}'.format(dataset_output_path)) if 'original' in dataset: if args.dataset != 'all': print('Only videos available for original data. Aborting.') return else: print('Only videos available for original data. ' 'Skipping original.\n') continue if 'FaceShifter' in dataset: print('Masks not available for FaceShifter. Aborting.') return filelist = [filename + '.mp4' for filename in filelist] download_files(filelist, dataset_mask_url, dataset_output_path) # Else: models for deepfakes else: if dataset != 'Deepfakes' and c_type == 'models': print('Models only available for Deepfakes. Aborting') return dataset_output_path = join(output_path, dataset_path, c_type) print('Output path: {}'.format(dataset_output_path)) # Get Deepfakes models for folder in tqdm(filelist): folder_filelist = DEEPFAKES_MODEL_NAMES # Folder paths folder_base_url = args.deepfakes_model_url + folder + '/' folder_dataset_output_path = join(dataset_output_path, folder) download_files(folder_filelist, folder_base_url, folder_dataset_output_path, report_progress=False) # already done if __name__ == "__main__": args = parse_args() main(args) 接下来使用如下命令即可下载数据集 python download-FaceForensics.py <output path> -d <dataset type, e.g., Face2Face, original or all> -c <compression quality, e.g., c23 or raw> -t <file type, e.g., videos, masks or models> <output path> 表示数据集的保存路径,即下载后的 FaceForensics++ 或 DeepFakeDetection 数据将被存放的位置。例如,可以设置为当前项目下的 ./data/,也可以设置为单独的数据盘路径,如 /mnt/data2/qi.xiong/Dataset/FaceForensics/。下载脚本会在该目录下自动构建对应的数据集层级结构. 
-d 用于指定下载的数据类型(dataset type),脚本默认值为 all。常见可选项包括 original、DeepFakeDetection_original、Face2Face、Deepfakes、FaceSwap、NeuralTextures、FaceShifter、DeepFakeDetection 以及 all 等。其中,original 表示下载原始真实视频序列,对应 original_sequences/youtube;DeepFakeDetection_original 对应 original_sequences/actors;Face2Face、Deepfakes、FaceSwap 和 NeuralTextures 表示下载四种主要伪造方法生成的数据;DeepFakeDetection 表示下载 DeepFakeDetection 扩展数据;all 表示一次性下载上述全部数据集。若仅用于常规 deepfake 检测实验,通常优先选择 original 与四种主流伪造类型。
-c 用于指定压缩等级(compression quality),脚本默认值为 raw。常用选项为 raw、c23 和 c40。其中,raw 表示无损压缩版本(即 c0),数据体积最大,但保留了最完整的图像细节;c23 表示较高质量压缩版本,是目前较常见、也较平衡的一种设置,既能保留较好的视觉质量,又显著降低存储开销;c40 表示压缩更强、质量更低的数据版本,更适合做强压缩场景下的鲁棒性测试。实际使用中,如果只是复现主流实验或进行预处理,通常推荐优先下载 c23 视频版本;由于默认值是 raw,下载时需要显式指定 -c c23。
-t 用于指定文件类型(file type),脚本默认值为 videos。常见选项包括 videos、masks 和 models。其中,videos 表示下载视频文件,这是最常用的选项;masks 表示下载伪造区域的二值掩码,适用于伪造区域定位、分割或可解释性分析任务(原始真实视频和 FaceShifter 不提供掩码);models 仅适用于 Deepfakes,用于获取其对应的生成模型文件。对于大多数 deepfake 分类或人脸抽帧任务,仅下载 videos 即可。
下载完成的数据集格式如下:
(xq) qi.xiong@instance-ujccspas:/mnt/data2/qi.xiong/Dataset/FaceForensics$ tree -L 3
.
├── manipulated_sequences
│   ├── DeepFakeDetection
│   │   ├── c23
│   │   └── masks
│   ├── Deepfakes
│   │   ├── c23
│   │   └── masks
│   ├── Face2Face
│   │   ├── c23
│   │   └── masks
│   ├── FaceShifter
│   │   └── c23
│   ├── FaceSwap
│   │   └── c23
│   └── NeuralTextures
│       └── c23
└── original_sequences
    ├── actors
    │   └── c23
    └── youtube
        └── c23

22 directories, 0 files

二、数据集预处理
我们前面下载得到的数据集仍然是视频格式,因此在正式用于 deepfake 检测之前,还需要先进行预处理。通常来说,这类任务不会直接将整段视频输入模型,而是先从视频中抽取若干具有代表性的帧,再从每一帧中提取对应的人脸区域。这样做一方面可以明显降低后续数据处理和模型训练的开销,另一方面也能让模型更聚焦于真正有用的面部伪造信息。FaceForensics++ 官方文档中也提到,通常更推荐先下载压缩后的视频,再自行完成帧提取。本文这里采用一种比较简化且实用的处理方式:从每个视频中均匀抽取固定数量的帧,然后使用 RetinaFace 对这些帧进行人脸检测,并将检测到的人脸区域裁剪保存。相比一些传统方法,RetinaFace 在检测精度和鲁棒性方面通常更有优势,尤其是在侧脸、光照变化较大或者人脸尺度变化明显的情况下,检测结果往往更加稳定。需要说明的是,本文这里的预处理目标比较明确,即只做人脸抽帧和人脸裁剪,不额外涉及关键点对齐、伪造区域掩码生成等更复杂的步骤,因此整个流程会更加清晰,也更适合作为 FaceForensics++ 数据预处理的基础版本。
RetinaFace 这里使用 ternaus/retinaface 的实现,按如下命令安装:
git clone https://github.com/ternaus/retinaface.git
cd retinaface
pip install -v -e .
我们配置好了retinaface之后,即可使用如下脚本继续转换: from glob import glob import os import cv2 from tqdm import tqdm import numpy as np import argparse from retinaface.pre_trained_models import get_model import torch def facecrop(model, org_path, save_path, num_frames=10): cap_org = cv2.VideoCapture(org_path) frame_count_org = int(cap_org.get(cv2.CAP_PROP_FRAME_COUNT)) if frame_count_org <= 0: print(f"Invalid video: {org_path}") cap_org.release() return frame_idxs = np.linspace(0, frame_count_org - 1, num_frames, endpoint=True, dtype=int) frame_idxs = set(frame_idxs.tolist()) for cnt_frame in range(frame_count_org): ret_org, frame_org = cap_org.read() if not ret_org or frame_org is None: continue if cnt_frame not in frame_idxs: continue frame = cv2.cvtColor(frame_org, cv2.COLOR_BGR2RGB) faces = model.predict_jsons(frame) if len(faces) == 0: continue save_path_frames = os.path.join( save_path, 'frames_retina', os.path.basename(org_path).replace('.mp4', '') ) os.makedirs(save_path_frames, exist_ok=True) for face_idx, face in enumerate(faces): bbox = face.get('bbox', None) if bbox is None or len(bbox) < 4: continue x0, y0, x1, y1 = map(int, bbox[:4]) x0 = max(0, x0) y0 = max(0, y0) x1 = min(frame_org.shape[1], x1) y1 = min(frame_org.shape[0], y1) if x1 <= x0 or y1 <= y0: continue cropped_face = frame_org[y0:y1, x0:x1] face_image_path = os.path.join( save_path_frames, f'frame_{cnt_frame}_face_{face_idx}.png' ) cv2.imwrite(face_image_path, cropped_face) cap_org.release() if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( '-d', dest='dataset', choices=[ 'Original', 'DeepFakeDetection_original', 'DeepFakeDetection', 'Deepfakes', 'Face2Face', 'FaceShifter', 'FaceSwap', 'NeuralTextures' ] ) parser.add_argument('-c', dest='comp', choices=['raw', 'c23', 'c40'], default='raw') parser.add_argument('-n', dest='num_frames', type=int, default=20) args = parser.parse_args() if args.dataset == 'Original': dataset_path = 
'data/FaceForensics++/original_sequences/youtube/{}/'.format(args.comp) elif args.dataset == 'DeepFakeDetection_original': dataset_path = 'data/FaceForensics++/original_sequences/actors/{}/'.format(args.comp) elif args.dataset in ['DeepFakeDetection', 'FaceShifter', 'Face2Face', 'Deepfakes', 'FaceSwap', 'NeuralTextures']: dataset_path = 'data/FaceForensics++/manipulated_sequences/{}/{}/'.format(args.dataset, args.comp) else: raise NotImplementedError device = torch.device('cpu') model = get_model("resnet50_2020-07-20", max_size=2048, device=device) model.eval() movies_path = dataset_path + 'videos/' movies_path_list = sorted(glob(movies_path + '*.mp4')) print("{} : videos are exist in {}".format(len(movies_path_list), args.dataset)) for i in tqdm(range(len(movies_path_list))): facecrop(model, movies_path_list[i], save_path=dataset_path, num_frames=args.num_frames) 在具体使用时,我们主要关心三个参数:-d 用于指定要处理的数据子集,例如 Original 表示原始真实视频,Deepfakes、Face2Face、FaceSwap 和 NeuralTextures 表示不同伪造方法生成的视频;-c 用于指定压缩等级,常见取值包括 raw、c23 和 c40,其中 c23 是较为常用的一种设置;-n 表示每个视频需要抽取的帧数,例如 -n 20 表示从一个视频中均匀抽取 20 帧进行处理。 如果只以 FaceForensics++ 中的原始真实视频为例,并采用 c23 压缩版本,那么待处理的视频通常位于如下目录中: data/FaceForensics++/original_sequences/youtube/c23/videos/ 当脚本运行完成后,处理结果会保存在对应目录下新生成的 frames_retina 文件夹中。例如,如果处理的是 Original 的 c23 数据,那么输出目录通常为: data/FaceForensics++/original_sequences/youtube/c23/frames_retina/ 接下来我们需要划分train、val和test数据集,我们按照官方的比例来划分,使用如下代码即可: #!/usr/bin/env python # -*- coding: utf-8 -*- """ 将 frames_retina 组织为 fakefacecls 所需结构 用法: python setup_ffpp_dataset.py python setup_ffpp_dataset.py --data_root /path/to/data/FaceForensics++ 输出: data/FaceForensics++/ffpp/ ├── train.json, val.json, test.json ├── Origin/c23/larger_images/ -> symlinks to frames_retina ├── Deepfakes/c23/larger_images/ ├── Face2Face/c23/larger_images/ ├── FaceSwap/c23/larger_images/ └── NeuralTextures/c23/larger_images/ """ import argparse import json import os from pathlib import Path # FF++ 官方划分 (来自 https://github.com/ondyari/FaceForensics) 
TRAIN_JSON = [ ["071", "054"], ["087", "081"], ["881", "856"], ["187", "234"], ["645", "688"], ["754", "758"], ["811", "920"], ["710", "788"], ["628", "568"], ["312", "021"], ["950", "836"], ["059", "050"], ["524", "580"], ["751", "752"], ["918", "934"], ["604", "703"], ["296", "293"], ["518", "131"], ["536", "540"], ["969", "897"], ["372", "413"], ["357", "432"], ["809", "799"], ["092", "098"], ["302", "323"], ["981", "985"], ["512", "495"], ["088", "060"], ["795", "907"], ["535", "587"], ["297", "270"], ["838", "810"], ["850", "764"], ["476", "400"], ["268", "269"], ["033", "097"], ["226", "491"], ["784", "769"], ["195", "442"], ["678", "460"], ["320", "328"], ["451", "449"], ["409", "382"], ["556", "588"], ["027", "009"], ["196", "310"], ["241", "210"], ["295", "099"], ["043", "110"], ["753", "789"], ["716", "712"], ["508", "831"], ["005", "010"], ["276", "185"], ["498", "433"], ["294", "292"], ["105", "180"], ["984", "967"], ["318", "334"], ["356", "324"], ["344", "020"], ["289", "228"], ["022", "489"], ["137", "165"], ["095", "053"], ["999", "960"], ["481", "469"], ["534", "490"], ["543", "559"], ["150", "153"], ["598", "178"], ["475", "265"], ["671", "677"], ["204", "230"], ["863", "853"], ["561", "998"], ["163", "031"], ["655", "444"], ["038", "125"], ["735", "774"], ["184", "205"], ["499", "539"], ["717", "684"], ["878", "866"], ["127", "129"], ["286", "267"], ["032", "944"], ["681", "711"], ["236", "237"], ["989", "993"], ["537", "563"], ["814", "871"], ["509", "525"], ["221", "206"], ["808", "829"], ["696", "686"], ["431", "447"], ["737", "719"], ["609", "596"], ["408", "424"], ["976", "954"], ["156", "243"], ["434", "438"], ["627", "658"], ["025", "067"], ["635", "642"], ["523", "541"], ["572", "554"], ["215", "208"], ["651", "835"], ["975", "978"], ["792", "903"], ["931", "936"], ["846", "845"], ["899", "914"], ["209", "016"], ["398", "457"], ["797", "844"], ["360", "437"], ["738", "804"], ["694", "767"], ["790", "014"], ["657", "644"], ["374", "407"], 
["728", "673"], ["193", "030"], ["876", "891"], ["553", "545"], ["331", "260"], ["873", "872"], ["109", "107"], ["121", "093"], ["143", "140"], ["778", "798"], ["983", "113"], ["504", "502"], ["709", "390"], ["940", "941"], ["894", "848"], ["311", "387"], ["562", "626"], ["330", "162"], ["112", "892"], ["765", "867"], ["124", "085"], ["665", "679"], ["414", "385"], ["555", "516"], ["072", "037"], ["086", "090"], ["202", "348"], ["341", "340"], ["333", "377"], ["082", "103"], ["569", "921"], ["750", "743"], ["211", "177"], ["770", "791"], ["329", "327"], ["613", "685"], ["007", "132"], ["304", "300"], ["860", "905"], ["986", "994"], ["378", "368"], ["761", "766"], ["232", "248"], ["136", "285"], ["601", "653"], ["693", "698"], ["359", "317"], ["246", "258"], ["500", "592"], ["776", "676"], ["262", "301"], ["307", "365"], ["600", "505"], ["833", "826"], ["361", "448"], ["473", "366"], ["885", "802"], ["277", "335"], ["667", "446"], ["522", "337"], ["018", "019"], ["430", "459"], ["886", "877"], ["456", "435"], ["239", "218"], ["771", "849"], ["065", "089"], ["654", "648"], ["151", "225"], ["152", "149"], ["229", "247"], ["624", "570"], ["290", "240"], ["011", "805"], ["461", "250"], ["251", "375"], ["639", "841"], ["602", "397"], ["028", "068"], ["338", "336"], ["964", "174"], ["782", "787"], ["478", "506"], ["313", "283"], ["659", "749"], ["690", "689"], ["893", "913"], ["197", "224"], ["253", "183"], ["373", "394"], ["803", "017"], ["305", "513"], ["051", "332"], ["238", "282"], ["621", "546"], ["401", "395"], ["510", "528"], ["410", "411"], ["049", "946"], ["663", "231"], ["477", "487"], ["252", "266"], ["952", "882"], ["315", "322"], ["216", "164"], ["061", "080"], ["603", "575"], ["828", "830"], ["723", "704"], ["870", "001"], ["201", "203"], ["652", "773"], ["108", "052"], ["272", "396"], ["040", "997"], ["988", "966"], ["281", "474"], ["077", "100"], ["146", "256"], ["972", "718"], ["303", "309"], ["582", "172"], ["222", "168"], ["884", "968"], ["217", "117"], 
["118", "120"], ["242", "182"], ["858", "861"], ["101", "096"], ["697", "581"], ["763", "930"], ["839", "864"], ["542", "520"], ["122", "144"], ["687", "615"], ["544", "532"], ["721", "715"], ["179", "212"], ["591", "605"], ["275", "887"], ["996", "056"], ["825", "074"], ["530", "594"], ["757", "573"], ["611", "760"], ["189", "200"], ["392", "339"], ["734", "699"], ["977", "075"], ["879", "963"], ["910", "911"], ["889", "045"], ["962", "929"], ["515", "519"], ["062", "066"], ["937", "888"], ["199", "181"], ["785", "736"], ["079", "076"], ["155", "576"], ["748", "355"], ["819", "786"], ["577", "593"], ["464", "463"], ["439", "441"], ["574", "547"], ["747", "854"], ["403", "497"], ["965", "948"], ["726", "713"], ["943", "942"], ["160", "928"], ["496", "417"], ["700", "813"], ["756", "503"], ["213", "083"], ["039", "058"], ["781", "806"], ["620", "619"], ["351", "346"], ["959", "957"], ["264", "271"], ["006", "002"], ["391", "406"], ["631", "551"], ["501", "326"], ["412", "274"], ["641", "662"], ["111", "094"], ["166", "167"], ["130", "139"], ["938", "987"], ["055", "147"], ["990", "008"], ["013", "883"], ["614", "616"], ["772", "708"], ["840", "800"], ["415", "484"], ["287", "426"], ["680", "486"], ["057", "070"], ["590", "034"], ["194", "235"], ["291", "874"], ["902", "901"], ["343", "363"], ["279", "298"], ["393", "405"], ["674", "744"], ["244", "822"], ["133", "148"], ["636", "578"], ["637", "427"], ["041", "063"], ["869", "780"], ["733", "935"], ["259", "345"], ["069", "961"], ["783", "916"], ["191", "188"], ["526", "436"], ["123", "119"], ["207", "908"], ["796", "740"], ["815", "730"], ["173", "171"], ["383", "353"], ["458", "722"], ["533", "450"], ["618", "629"], ["646", "643"], ["531", "549"], ["428", "466"], ["859", "843"], ["692", "610"], ] VAL_JSON = [ ["720", "672"], ["939", "115"], ["284", "263"], ["402", "453"], ["820", "818"], ["762", "832"], ["834", "852"], ["922", "898"], ["104", "126"], ["106", "198"], ["159", "175"], ["416", "342"], ["857", "909"], 
["599", "585"], ["443", "514"], ["566", "617"], ["472", "511"], ["325", "492"], ["816", "649"], ["583", "558"], ["933", "925"], ["419", "824"], ["465", "482"], ["565", "589"], ["261", "254"], ["992", "980"], ["157", "245"], ["571", "746"], ["947", "951"], ["926", "900"], ["493", "538"], ["468", "470"], ["915", "895"], ["362", "354"], ["440", "364"], ["640", "638"], ["827", "817"], ["793", "768"], ["837", "890"], ["004", "982"], ["192", "134"], ["745", "777"], ["299", "145"], ["742", "775"], ["586", "223"], ["483", "370"], ["779", "794"], ["971", "564"], ["273", "807"], ["991", "064"], ["664", "668"], ["823", "584"], ["656", "666"], ["557", "560"], ["471", "455"], ["042", "084"], ["979", "875"], ["316", "369"], ["091", "116"], ["023", "923"], ["702", "612"], ["904", "046"], ["647", "622"], ["958", "956"], ["606", "567"], ["632", "548"], ["927", "912"], ["350", "349"], ["595", "597"], ["727", "729"], ] TEST_JSON = [ ["953", "974"], ["012", "026"], ["078", "955"], ["623", "630"], ["919", "015"], ["367", "371"], ["847", "906"], ["529", "633"], ["418", "507"], ["227", "169"], ["389", "480"], ["821", "812"], ["670", "661"], ["158", "379"], ["423", "421"], ["352", "319"], ["579", "701"], ["488", "399"], ["695", "422"], ["288", "321"], ["705", "707"], ["306", "278"], ["865", "739"], ["995", "233"], ["755", "759"], ["467", "462"], ["314", "347"], ["741", "731"], ["970", "973"], ["634", "660"], ["494", "445"], ["706", "479"], ["186", "170"], ["176", "190"], ["380", "358"], ["214", "255"], ["454", "527"], ["425", "485"], ["388", "308"], ["384", "932"], ["035", "036"], ["257", "420"], ["924", "917"], ["114", "102"], ["732", "691"], ["550", "452"], ["280", "249"], ["842", "714"], ["625", "650"], ["024", "073"], ["044", "945"], ["896", "128"], ["862", "047"], ["607", "683"], ["517", "521"], ["682", "669"], ["138", "142"], ["552", "851"], ["376", "381"], ["000", "003"], ["048", "029"], ["724", "725"], ["608", "675"], ["386", "154"], ["220", "219"], ["801", "855"], ["161", "141"], 
["949", "868"], ["880", "135"], ["429", "404"], ] # 路径映射: (method, codec) -> (frames_retina 相对路径) ORIGIN_FRAMES = "original_sequences/youtube/{codec}/frames_retina" MANIPULATED_FRAMES = "manipulated_sequences/{method}/{codec}/frames_retina" METHODS = ["Deepfakes", "Face2Face", "FaceSwap", "NeuralTextures", "FaceShifter"] # 可选 DeepFakeDetection def main(): root = Path(__file__).resolve().parent / "FaceForensics++" parser = argparse.ArgumentParser() parser.add_argument("--data_root", default=str(root), help="FaceForensics++ 根目录") parser.add_argument("--codec", default="c23") parser.add_argument("--methods", nargs="+", default=METHODS) args = parser.parse_args() data_root = Path(args.data_root) codec = args.codec ffpp = data_root / "ffpp" ffpp.mkdir(parents=True, exist_ok=True) # 1. 保存 JSON for name, pairs in [("train", TRAIN_JSON), ("val", VAL_JSON), ("test", TEST_JSON)]: f = ffpp / f"{name}.json" with open(f, "w") as fp: json.dump(pairs, fp, indent=2) print(f" {f}") # 2. Origin: larger_images/{id} -> symlink to frames_retina/xxx origin_frames = data_root / ORIGIN_FRAMES.format(codec=codec) origin_larger = ffpp / "Origin" / codec / "larger_images" origin_larger.mkdir(parents=True, exist_ok=True) if origin_frames.exists(): for vid in sorted(origin_frames.iterdir()): if vid.is_dir(): dst = origin_larger / vid.name if not dst.exists(): dst.symlink_to(vid.resolve()) print(f" Origin: {origin_larger} ({len(list(origin_larger.iterdir()))} videos)") else: print(f" [skip] Origin {origin_frames} not found") # 3. 
Manipulated: larger_images/{id1_id2} -> symlink to frames_retina/xxx for method in args.methods: man_frames = data_root / MANIPULATED_FRAMES.format(method=method, codec=codec) man_larger = ffpp / method / codec / "larger_images" man_larger.mkdir(parents=True, exist_ok=True) if man_frames.exists(): n = 0 for vid in sorted(man_frames.iterdir()): if vid.is_dir(): dst = man_larger / vid.name if not dst.exists(): dst.symlink_to(vid.resolve()) n += 1 print(f" {method}: {man_larger} ({n} videos)") else: print(f" [skip] {method} {man_frames} not found") print(f"\n完成: ffpp 目录 -> {ffpp}") print("\n使用方式:") print(" 1. fakefacecls: export FFPP_ROOT=" + str(ffpp.resolve())) print(" 2. multiple-attention: 在 datasets/data.py 中设置 ffpproot = '" + str(ffpp.resolve()) + "/'") if __name__ == "__main__": main() 三、人脸分类网络 我们接下来直接使用Timm库来验证CNN和Transformer作为Backbone对人脸伪造分类的识别性能,我们将支持两种分类方式,分别是二分类和五分类,二分类即单纯的True/False,五分类则在正确区分的基础上额外实现分类人脸伪造的方式 所有代码已上传至GitHub:https://github.com/xiongqi123123/fakefaceclsnet 数据集加载及数据增强代码如下: import os import random import torch import cv2 from torch.utils.data import Dataset import albumentations as A from albumentations import Compose from .augmentations import augmentations from . 
import data class DeepfakeDataset(Dataset): def __init__( self, phase='train', datalabel='', resize=(224, 224), imgs_per_video=30, min_frames=0, normalize=None, frame_interval=10, max_frames=300, augment='augment0', ): assert phase in ['train', 'val', 'test'] normalize = normalize or dict(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) self.datalabel = datalabel self.phase = phase self.imgs_per_video = imgs_per_video self.frame_interval = frame_interval self.epoch = 0 self.max_frames = max_frames self.min_frames = min_frames if min_frames else max_frames * 0.3 self.aug = augmentations.get(augment, augmentations['augment0']) self.resize = resize self.trans = Compose([ A.Resize(resize[0], resize[1]), # 小图(如19x14)需先 resize,CenterCrop 会报错 A.Normalize(mean=normalize['mean'], std=normalize['std']), A.ToTensorV2(), ]) self.dataset = self._build_dataset() self._frame_cache = {} # 缓存 os.listdir,避免每帧重复读目录 def _build_dataset(self): if isinstance(self.datalabel, (list, tuple)): return self.datalabel if 'ff-5' in self.datalabel: codec = self.datalabel.split('-')[2] out = [] for idx, tag in enumerate(['Origin', 'Deepfakes', 'NeuralTextures', 'FaceSwap', 'Face2Face']): for item in data.FF_dataset(tag, codec, self.phase): out.append([item[0], idx]) return out if 'ff-all' in self.datalabel: codec = self.datalabel.split('-')[2] out = [] for tag in ['Origin', 'Deepfakes', 'NeuralTextures', 'FaceSwap', 'Face2Face']: out.extend(data.FF_dataset(tag, codec, self.phase)) if self.phase != 'test': out = data.make_balance(out) return out if 'ff' in self.datalabel: parts = self.datalabel.split('-') codec = parts[2] tag = parts[1] return data.FF_dataset(tag, codec, self.phase) + data.FF_dataset('Origin', codec, self.phase) if 'celeb' in self.datalabel: return data.Celeb_test if 'deeper' in self.datalabel: codec = self.datalabel.split('-')[1] return data.deeperforensics_dataset(self.phase) + data.FF_dataset('Origin', codec, self.phase) if 'dfdc' in self.datalabel: return 
data.dfdc_dataset(self.phase) raise ValueError(f'Unknown datalabel: {self.datalabel}') def next_epoch(self): self.epoch += 1 def __getitem__(self, item): for _ in range(len(self.dataset)): # 避免无限递归 try: vid = self.dataset[item // self.imgs_per_video] vid_path = vid[0] if vid_path not in self._frame_cache: self._frame_cache[vid_path] = sorted(os.listdir(vid_path)) vd = self._frame_cache[vid_path] if len(vd) < self.min_frames: raise ValueError(f"frames {len(vd)} < min_frames {self.min_frames}") idx = (item % self.imgs_per_video * self.frame_interval + self.epoch) % min(len(vd), self.max_frames) fname = vd[idx] img = cv2.imread(os.path.join(vid[0], fname)) if img is None: raise ValueError(f"cv2.imread failed: {fname}") img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if self.phase == 'train': img = self.aug(image=img)['image'] return self.trans(image=img)['image'], vid[1] except Exception as e: if os.environ.get('DEBUG_DATASET') == '1' and not getattr(self, '_debug_printed', False): import traceback vp = self.dataset[item // self.imgs_per_video][0] if item < len(self) else '?' 
print(f'[DEBUG] item={item} path={vp} err={e}') traceback.print_exc() self._debug_printed = True # 只打印第一次 if self.phase == 'test': return torch.zeros(3, self.resize[0], self.resize[1]), -1 item = (item + self.imgs_per_video) % len(self) return torch.zeros(3, self.resize[0], self.resize[1]), -1 # 全部失败时返回占位 def __len__(self): return len(self.dataset) * self.imgs_per_video import os import json import random # 数据根目录:FFPP_ROOT 或默认 FFDeepFake/data/FaceForensics++/ffpp _SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) _FFDEEPFAKE_ROOT = os.path.dirname(os.path.dirname(_SCRIPT_DIR)) # fakefacecls/ -> FFDeepFake _FFDEEPFAKE_ROOT = os.path.dirname(_FFDEEPFAKE_ROOT) # FFDeepFake _data_root = os.path.join(_FFDEEPFAKE_ROOT, 'data') _DEFAULT_FFPP = os.path.join(_data_root, 'FaceForensics++', 'ffpp') ffpproot = os.environ.get('FFPP_ROOT', _DEFAULT_FFPP) if ffpproot and not ffpproot.endswith(os.sep): ffpproot += os.sep dfdcroot = os.path.join(_data_root, 'dfdc') celebroot = os.path.join(_data_root, 'celebDF') deeperforensics_root = os.path.join(_data_root, 'deeper') def load_json(name): with open(name) as f: return json.load(f) def FF_dataset(tag='Origin', codec='c0', part='train'): assert tag in ['Origin', 'Deepfakes', 'NeuralTextures', 'FaceSwap', 'Face2Face', 'FaceShifter'] assert codec in ['c0', 'c23', 'c40', 'all'] assert part in ['train', 'val', 'test', 'all'] if part == 'all': return FF_dataset(tag, codec, 'train') + FF_dataset(tag, codec, 'val') + FF_dataset(tag, codec, 'test') if codec == 'all': return FF_dataset(tag, 'c0', part) + FF_dataset(tag, 'c23', part) + FF_dataset(tag, 'c40', part) path = os.path.join(ffpproot, tag, codec, 'larger_images') metafile = load_json(os.path.join(ffpproot, part + '.json')) files = [] if tag == 'Origin': for i in metafile: files.append([os.path.join(path, i[0]), 0]) files.append([os.path.join(path, i[1]), 0]) else: for i in metafile: files.append([os.path.join(path, i[0] + '_' + i[1]), 1]) files.append([os.path.join(path, i[1] + 
'_' + i[0]), 1]) return files def make_balance(data): tr = [x for x in data if x[1] == 0] tf = [x for x in data if x[1] == 1] if len(tr) > len(tf): tr, tf = tf, tr rate = len(tf) // len(tr) res = len(tf) - rate * len(tr) tr = tr * rate + random.sample(tr, res) return tr + tf def dfdc_dataset(part='train'): assert part in ['train', 'val', 'test'] lf = load_json(os.path.join(dfdcroot, 'DFDC.json')) if part == 'train': path = os.path.join(dfdcroot, 'dfdc') files = make_balance(lf['train']) elif part == 'test': path = os.path.join(dfdcroot, 'dfdc-test') files = lf['test'] else: path = os.path.join(dfdcroot, 'dfdc-val') files = lf['val'] return [[os.path.join(path, i[0]), i[1]] for i in files] def deeperforensics_dataset(part='train'): a = os.listdir(deeperforensics_root) d = {i.split('_')[0]: i for i in a} metafile = load_json(os.path.join(ffpproot, part + '.json')) files = [] for i in metafile: p = os.path.join(deeperforensics_root, d[i[0]]) files.append([p, 1]) p = os.path.join(deeperforensics_root, d[i[1]]) files.append([p, 1]) return files try: Celeb_test = list(map(lambda x: [os.path.join(celebroot, x[0]), 1 - x[1]], load_json(os.path.join(celebroot, 'celeb.json')))) except Exception: Celeb_test = [] import albumentations as A augment0 = A.Compose([A.HorizontalFlip()], p=1) augment1 = A.Compose([ A.HorizontalFlip(), A.HueSaturationValue(p=0.5), A.RandomBrightnessContrast(p=0.5), ], p=1) augment2 = A.Compose([ A.HorizontalFlip(), A.HueSaturationValue(p=0.5), A.RandomBrightnessContrast(p=0.5), A.OneOf([A.GaussNoise()], p=0.3), A.OneOf([ A.MotionBlur(), A.GaussianBlur(), A.ImageCompression(quality_range=(65, 80)), ], p=0.3), A.ToGray(p=0.1), ], p=1) augmentations = {'augment0': augment0, 'augment1': augment1, 'augment2': augment2} 网络直接使用Timm的预置模型: ####### __init__.py from .cnn import build_cnn_model def build_model(backbone='resnet50', num_classes=2, pretrained=True, dropout=0.3, **kwargs): """任意 timm 模型名均可,如 resnet50, vit_base_patch16_224, deit_small_patch16_224""" 
return build_cnn_model(backbone, num_classes, pretrained, dropout, **kwargs) ####### backbone.py """ 基于 timm 的 backbone + 分类头 支持 CNN (resnet50, efficientnet_b0, ...) 和 ViT (vit_base_patch16_224, deit_base_patch16_224, ...) """ import torch.nn as nn import timm def build_cnn_model(backbone='resnet50', num_classes=2, pretrained=True, dropout=0.3, **kwargs): """ Args: backbone: timm 模型名,如 resnet50, efficientnet_b0, convnext_tiny num_classes: 2 (真/假) 或 5 (Origin/Deepfakes/NeuralTextures/FaceSwap/Face2Face) pretrained: 是否加载 ImageNet 预训练 dropout: 分类头 dropout """ return CNNClassifier(backbone, num_classes, pretrained, dropout, **kwargs) class CNNClassifier(nn.Module): """timm backbone + 可替换分类头""" def __init__(self, backbone='resnet50', num_classes=2, pretrained=True, dropout=0.3, in_chans=3, **kwargs): super().__init__() self.num_classes = num_classes weights = 'imagenet' if pretrained else None self.backbone = timm.create_model( backbone, pretrained=weights, num_classes=0, # 移除原分类头 global_pool='avg', in_chans=in_chans, **kwargs ) feat_dim = self.backbone.num_features self.head = nn.Sequential( nn.Dropout(p=dropout), nn.Linear(feat_dim, num_classes), ) def forward(self, x): feat = self.backbone(x) return self.head(feat) 然后就是训练的代码 """ 训练与验证脚本 用法: python train.py python train.py --backbone resnet50 --num_classes 5 --datalabel ff-5-c23 """ import argparse import logging import os import torch import torch.nn.functional as F from torch.utils.data import DataLoader from tqdm import tqdm from config import TrainConfig from datasets import DeepfakeDataset from models import build_model def get_args(): p = argparse.ArgumentParser() p.add_argument('--backbone', default='resnet50', help='timm backbone') p.add_argument('--num_classes', type=int, default=2, choices=[2, 5]) p.add_argument('--datalabel', default='ff-all-c23', help='ff-all-c23 | ff-5-c23') p.add_argument('--epochs', type=int, default=20) p.add_argument('--batch_size', type=int, default=64) p.add_argument('--lr', 
type=float, default=1e-5) p.add_argument('--name', default='') p.add_argument('--no_pretrained', action='store_true') p.add_argument('--resume', default='') p.add_argument('--save_every', type=int, default=5, help='每隔多少轮保存一次 ckpt') p.add_argument('--log_interval', type=int, default=100, help='每隔多少 batch 打印一次 log') p.add_argument('--val_every_steps', type=int, default=500, help='每隔多少 step 在 val 上验证一次,0=仅每 epoch 结束') p.add_argument('--test_every_epoch', action='store_true', help='每个 epoch 结束在 test 上评估(仅监控,不参与选 best)') p.add_argument('--workers', type=int, default=0, help='DataLoader workers,0=用 config 默认') p.add_argument('--no_amp', action='store_true', help='禁用混合精度') return p.parse_args() def main(): args = get_args() cfg = TrainConfig( backbone=args.backbone, num_classes=args.num_classes, pretrained=not args.no_pretrained, datalabel=args.datalabel, epochs=args.epochs, batch_size=args.batch_size, lr=args.lr, name=args.name if args.name else None, ) run_dir = os.path.join('runs', cfg.name) os.makedirs(run_dir, exist_ok=True) log_path = os.path.join(run_dir, 'train.log') logging.basicConfig( level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s', handlers=[logging.FileHandler(log_path), logging.StreamHandler()], force=True, ) logging.info(f'Run dir: {run_dir}') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = build_model( backbone=cfg.backbone, num_classes=cfg.num_classes, pretrained=cfg.pretrained, dropout=cfg.dropout, ).to(device) workers = args.workers if args.workers > 0 else cfg.workers use_amp = not args.no_amp and torch.cuda.is_available() scaler = torch.amp.GradScaler('cuda') if use_amp else None if use_amp: logging.info('Using AMP (mixed precision)') train_ds = DeepfakeDataset(**cfg.dataset_kwargs('train')) val_ds = DeepfakeDataset(**cfg.dataset_kwargs('val')) test_ds = DeepfakeDataset(**cfg.dataset_kwargs('test')) dl_kw = dict(batch_size=cfg.batch_size, pin_memory=torch.cuda.is_available()) if workers > 0: 
dl_kw.update(num_workers=workers, persistent_workers=True, prefetch_factor=4) else: dl_kw['num_workers'] = 0 train_loader = DataLoader(train_ds, shuffle=True, **dl_kw) val_loader = DataLoader(val_ds, shuffle=False, **dl_kw) test_loader = DataLoader(test_ds, shuffle=False, **dl_kw) optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.epochs) start_epoch = 0 best_acc = 0.0 if args.resume: ckpt = torch.load(args.resume, map_location=device) model.load_state_dict(ckpt.get('model', ckpt), strict=False) start_epoch = ckpt.get('epoch', 0) + 1 best_acc = ckpt.get('best_acc', 0.0) if 'optimizer' in ckpt: optimizer.load_state_dict(ckpt['optimizer']) if 'scheduler' in ckpt: scheduler.load_state_dict(ckpt['scheduler']) logging.info(f'Resumed from epoch {start_epoch}, best_acc={best_acc:.4f}') global_step = start_epoch * len(train_loader) for epoch in range(start_epoch, cfg.epochs): train_ds.next_epoch() train_loss, train_acc = train_epoch( model, train_loader, optimizer, device, log_interval=args.log_interval, scaler=scaler, val_loader=val_loader if args.val_every_steps > 0 else None, val_every_steps=args.val_every_steps, global_step=global_step, ) global_step += len(train_loader) val_loss, val_acc = validate(model, val_loader, device) scheduler.step() logging.info(f'E{epoch} train loss={train_loss:.4f} acc={train_acc:.4f} | val loss={val_loss:.4f} acc={val_acc:.4f}') ckpt = { 'model': model.state_dict(), 'epoch': epoch, 'best_acc': best_acc, 'optimizer': optimizer.state_dict(), 'scheduler': scheduler.state_dict(), } torch.save(ckpt, os.path.join(run_dir, 'latest.pth')) if val_acc > best_acc: best_acc = val_acc torch.save({'model': model.state_dict(), 'epoch': epoch, 'best_acc': best_acc}, os.path.join(run_dir, 'best.pth')) logging.info(f' -> best acc={best_acc:.4f} saved') if args.save_every > 0 and (epoch + 1) % args.save_every == 0: torch.save({'model': 
model.state_dict(), 'epoch': epoch}, os.path.join(run_dir, f'ep{epoch}.pth')) if args.test_every_epoch: test_loss, test_acc = validate(model, test_loader, device, desc='test') logging.info(f'E{epoch} test loss={test_loss:.4f} acc={test_acc:.4f}') # 训练结束:用 best 在 test 上做最终评估(无 best 则用 latest) for ckpt_name in ('best.pth', 'latest.pth'): ckpt_path = os.path.join(run_dir, ckpt_name) if os.path.exists(ckpt_path): ckpt = torch.load(ckpt_path, map_location=device) model.load_state_dict(ckpt.get('model', ckpt), strict=False) test_loss, test_acc = validate(model, test_loader, device, desc='test') logging.info(f'[FINAL] {ckpt_name} on test: loss={test_loss:.4f} acc={test_acc:.4f}') break def train_epoch(model, loader, optimizer, device, log_interval=50, scaler=None, val_loader=None, val_every_steps=0, global_step=0): model.train() total_loss, total_acc, n = 0.0, 0.0, 0 pbar = tqdm(loader, desc='train', leave=False) for i, (x, y) in enumerate(pbar): step = global_step + i + 1 x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True) optimizer.zero_grad() if scaler is not None: with torch.amp.autocast('cuda'): logits = model(x) loss = F.cross_entropy(logits, y) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() else: logits = model(x) loss = F.cross_entropy(logits, y) loss.backward() optimizer.step() acc = (logits.argmax(1) == y).float().mean().item() total_loss += loss.item() * x.size(0) total_acc += acc * x.size(0) n += x.size(0) pbar.set_postfix(loss=f'{loss.item():.4f}', acc=f'{acc:.4f}') if log_interval > 0 and (i + 1) % log_interval == 0: logging.info(f' batch {i+1}/{len(loader)} loss={loss.item():.4f} acc={acc:.4f}') if val_loader is not None and val_every_steps > 0 and step % val_every_steps == 0: val_loss, val_acc = validate(model, val_loader, device) logging.info(f' [step {step}] val loss={val_loss:.4f} acc={val_acc:.4f}') model.train() return total_loss / n, total_acc / n def validate(model, loader, device, desc='val'): model.eval() 
total_loss, total_acc, n = 0.0, 0.0, 0 with torch.no_grad(): pbar = tqdm(loader, desc=desc, leave=False) for x, y in pbar: x, y = x.to(device), y.to(device) # 过滤无效标签(dataset 加载失败时返回 -1) valid = y >= 0 if valid.sum() == 0: continue x, y = x[valid], y[valid] logits = model(x) loss = F.cross_entropy(logits, y) acc = (logits.argmax(1) == y).float().mean().item() total_loss += loss.item() * x.size(0) total_acc += acc * x.size(0) n += x.size(0) pbar.set_postfix(loss=f'{loss.item():.4f}', acc=f'{acc:.4f}') return total_loss / n, total_acc / n if n > 0 else (0.0, 0.0) if __name__ == '__main__': main() 可以看到训练的效果非常的好,基本一个Epoch就可以在Test验证集上达到0.8以上的正确率,且可以观察发现Transformer作为Backbone的效果远比CNN的效果好