160 lines
4.5 KiB
Python
160 lines
4.5 KiB
Python
#%%
|
||
import torch
|
||
import numpy as np
|
||
import cv2
|
||
from decord import VideoReader, cpu
|
||
#%%
|
||
# Report the installed PyTorch version and whether a CUDA device is usable,
# each on its own line.
print(torch.__version__, torch.cuda.is_available(), sep="\n")
|
||
#%%
|
||
def load_first_480_frames(video_path, resize=(224, 224), num_frames=480):
    """Decode the first ``num_frames`` frames of a video, optionally resized.

    Args:
        video_path: Video file, passed unchanged to decord's ``VideoReader``
            (decoded on ``cpu(0)``).
        resize: Target size given to ``cv2.resize`` as ``dsize``, i.e.
            ``(width, height)``. A falsy value (``None``, ``()``) skips
            resizing and returns the frames as decoded.
        num_frames: Number of leading frames to decode. Defaults to 480, the
            count the function name refers to.

    Returns:
        A numpy array stacking the decoded frames along axis 0.

    Raises:
        ValueError: If ``num_frames`` is not positive, or the video has fewer
            than ``num_frames`` frames.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be positive, got {num_frames}")

    vr = VideoReader(video_path, ctx=cpu(0))
    total_frames = len(vr)
    if total_frames < num_frames:
        raise ValueError(f"{video_path},视频帧数不足 {num_frames} 帧")

    # NOTE(review): all frames are decoded at full resolution before resizing;
    # for HD video this batch is large. decord's VideoReader also accepts
    # width/height arguments to resize during decoding — worth evaluating.
    frames = vr.get_batch(list(range(num_frames))).asnumpy()

    if resize:
        frames = np.array([cv2.resize(frame, resize) for frame in frames])
    return frames
|
||
#%%
|
||
def preprocess_for_slowfast(frames, num_frames=32, alpha=4,
                            mean=(0.45, 0.45, 0.45), std=(0.225, 0.225, 0.225)):
    """Turn a stack of frames into the ``[slow, fast]`` input of a SlowFast model.

    Args:
        frames: Array-like of frames indexed as ``(T, H, W, C)`` with pixel
            values in 0..255 (the channel axis is moved to position 1 below).
        num_frames: Frames sampled evenly over the whole clip for the fast
            pathway. Frames repeat when the clip is shorter than this.
        alpha: Temporal stride from the fast to the slow pathway; the slow
            pathway keeps every ``alpha``-th sampled frame.
        mean, std: Per-channel normalization applied after scaling to [0, 1].

    Returns:
        ``[slow_pathway, fast_pathway]``: float32 tensors of shape
        ``(1, C, T, H, W)``, with ``T == num_frames`` for the fast pathway.
        The slow pathway is a strided view of the fast one.

    Raises:
        ValueError: If ``frames`` is empty, or ``num_frames`` / ``alpha`` are
            not positive.
    """
    total = len(frames)
    if total == 0:
        raise ValueError("preprocess_for_slowfast: no frames to sample from")
    if num_frames < 1:
        raise ValueError(f"num_frames must be positive, got {num_frames}")
    if alpha < 1:
        raise ValueError(f"alpha must be positive, got {alpha}")

    # Evenly spaced frame indices covering the first through the last frame.
    indices = np.linspace(0, total - 1, num=num_frames, dtype=int)
    clip = np.asarray(frames)[indices]

    # Scale to [0, 1], normalize per channel (broadcast over the last axis),
    # then cast to float32 for the model.
    clip = (clip / 255.0 - np.asarray(mean)) / np.asarray(std)
    clip = clip.astype(np.float32)

    # (T, H, W, C) -> (C, T, H, W), then add a batch dimension.
    fast_pathway = torch.from_numpy(clip).permute(3, 0, 1, 2).unsqueeze(0)
    slow_pathway = fast_pathway[:, :, ::alpha, :, :]
    return [slow_pathway, fast_pathway]
|
||
#%%
|
||
# Feature-extraction model built on the SlowFast backbone.
class SlowFastFeatureExtractor(torch.nn.Module):
    """PyTorchVideo ``slowfast_r50`` (pretrained) with its final block removed.

    The forward pass runs the remaining blocks in order, global-average-pools
    the result, flattens it to ``[B, C]`` and L2-normalizes each row.
    """

    def __init__(self):
        super().__init__()
        backbone = torch.hub.load("facebookresearch/pytorchvideo", "slowfast_r50", pretrained=True)
        # Drop the last block (the classification head); keep only the backbone.
        self.blocks = backbone.blocks[:-1]
        # Global average pooling down to a single value per channel.
        self.pool = torch.nn.AdaptiveAvgPool3d(1)

    def forward(self, x):
        """Return the L2-normalized ``[B, C]`` features for input ``x``.

        ``x`` is whatever the first retained block accepts. In this file it is
        the ``[slow, fast]`` list built by ``preprocess_for_slowfast``.
        """
        hidden = x
        for stage in self.blocks:
            hidden = stage(hidden)
        pooled = torch.flatten(self.pool(hidden), start_dim=1)  # [B, C]
        return torch.nn.functional.normalize(pooled, dim=1)
|
||
#%%
|
||
|
||
# Use the GPU when PyTorch reports one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print("Loading SlowFast backbone for feature extraction...")
model = SlowFastFeatureExtractor()
model = model.to(device)
# Evaluation mode (affects layers such as dropout and batch norm); returns the same module.
model.eval()
|
||
|
||
|
||
def extract_video_feature(video_path, num_frames=32, alpha=4):
    """Compute the L2-normalized SlowFast feature vector of one video.

    Decodes the leading frames with ``load_first_480_frames``. It then samples
    ``num_frames`` of them evenly for the fast pathway; the slow pathway keeps
    every ``alpha``-th of those. Finally it runs the module-level ``model`` on
    ``device`` without tracking gradients.

    Args:
        video_path: Video file to read.
        num_frames: Frames sampled for the fast pathway (default 32).
        alpha: Temporal stride between fast and slow pathways (default 4).

    Returns:
        1-D numpy array: the feature vector of the single clip in the batch.
    """
    frames = load_first_480_frames(video_path)
    inputs = preprocess_for_slowfast(frames, num_frames=num_frames, alpha=alpha)

    with torch.no_grad():
        inputs = [pathway.to(device) for pathway in inputs]
        features = model(inputs)

    # The model returns [B, C] with B == 1. Index the batch entry explicitly;
    # squeeze() would also drop any other size-1 dimension.
    return features[0].cpu().numpy()
|
||
#%%
|
||
def getFeature(video_path=r"D:\DESKTOP\2025\44\a1\dataset\org\1.mp4"):
    """Return the SlowFast feature vector for ``video_path``.

    Thin wrapper around ``extract_video_feature``. The default argument is a
    hard-coded local Windows path.
    """
    return extract_video_feature(video_path)
|
||
#%%
|
||
import sqlite3
|
||
import io
|
||
import os
|
||
import tqdm
|
||
|
||
# Adapter/converter pair: numpy array <-> SQLite BLOB (via the .npy format).

def adapt_array(arr):
    """Serialize ``arr`` with ``np.save`` and wrap the bytes as ``sqlite3.Binary``."""
    buffer = io.BytesIO()
    np.save(buffer, arr)
    return sqlite3.Binary(buffer.getvalue())
|
||
|
||
def convert_array(text):
    """Rebuild a numpy array from ``.npy`` bytes read out of SQLite."""
    return np.load(io.BytesIO(text))
|
||
|
||
# Register the numpy <-> BLOB handlers with the sqlite3 module.
# The adapter is keyed on the Python type, so any np.ndarray bound as a query
# parameter is stored via adapt_array. The converter is keyed on a declared
# *type name*: with PARSE_DECLTYPES it runs only for columns whose declared
# type (first word, case-insensitive) is ARRAY.
sqlite3.register_adapter(np.ndarray, adapt_array)
sqlite3.register_converter("array", convert_array)

# Open the database with type detection.
# PARSE_DECLTYPES converts columns declared as ARRAY on read.
# PARSE_COLNAMES additionally allows a per-query opt-in, e.g.
#   SELECT array AS "array [array]" FROM data
# which is needed for databases created by older versions of this script,
# where the column was declared BLOB and the converter therefore never fired.
conn = sqlite3.connect(
    "slowfast.db",
    detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES,
)
cursor = conn.cursor()

# Create the table. The `array` column is declared ARRAY (not BLOB) so the
# "array" converter matches it on read. SQLite gives this declared type NUMERIC
# affinity, which stores BLOB values unchanged.
cursor.execute("""
CREATE TABLE IF NOT EXISTS data (
    id INTEGER PRIMARY KEY,
    array ARRAY,
    group_path TEXT,
    full_path TEXT
)
""")
|
||
|
||
|
||
def add_to_db(array_to_store, group_path, full_path):
    """Insert one ``(array, group_path, full_path)`` row into ``data`` and commit.

    Uses the module-level ``cursor`` and ``conn``. An ``np.ndarray`` value is
    serialized by the adapter registered with sqlite3.
    """
    row = (array_to_store, group_path, full_path)
    cursor.execute(
        "INSERT INTO data (array,group_path,full_path) VALUES (?,?,?)",
        row,
    )
    # Commit after every row, so rows already inserted persist even if a later step fails.
    conn.commit()
|
||
|
||
# # 读取数组
|
||
# cursor.execute("SELECT array FROM data WHERE id=1")
|
||
# fetched_array = cursor.fetchone()[0]
|
||
#
|
||
# print("原始数组:\n", array_to_store)
|
||
# print("读取的数组:\n", fetched_array)
|
||
|
||
|
||
#%%
|
||
# Dataset root: each immediate sub-directory is treated as one group of videos.
folder_path = r"D:\DESKTOP\2025\44\a1\dataset"

all_files = []
# Only files named 0.mp4 ... 9.mp4 are collected from each group.
names = [str(x)+".mp4" for x in range(10)]
print(names)
wanted = set(names)  # constant-time membership test

# Walk groups and files in sorted order: os.listdir returns entries in
# arbitrary order, so sorting keeps insertion order (and row ids) reproducible
# across runs.
for group_path in sorted(os.listdir(folder_path)):
    full_path = os.path.join(folder_path, group_path)
    if not os.path.isdir(full_path):
        continue
    # listdir yields bare entry names, so they can be compared to `wanted` directly.
    for video_path in sorted(os.listdir(full_path)):
        if video_path not in wanted:
            continue
        video_full_path = os.path.join(full_path, video_path)
        if os.path.isfile(video_full_path):
            all_files.append((group_path, video_full_path))
print(len(all_files))
|
||
#%%
|
||
# Extract a feature for every collected video and store it with its group and path.
# A video rejected with ValueError (e.g. too few frames for load_first_480_frames)
# is reported and skipped, so one short clip does not abort the whole batch.
# add_to_db commits per row, so earlier rows are kept either way.
skipped = []
for group_path, video_full_path in tqdm.tqdm(all_files):
    try:
        feature = getFeature(video_full_path)
    except ValueError as err:
        # tqdm.write prints without breaking the progress bar.
        tqdm.tqdm.write(f"skipped {video_full_path}: {err}")
        skipped.append((video_full_path, str(err)))
        continue
    add_to_db(feature, group_path, video_full_path)

if skipped:
    print(f"skipped {len(skipped)} of {len(all_files)} videos")
|
||
#%%
|
||
# Close the SQLite connection; every inserted row was already committed by add_to_db.
conn.close()
|
||
#%%
|