Files
COMP90044_a1/slowFast.py
2025-04-14 17:15:38 +10:00

160 lines
4.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#%%
import torch
import numpy as np
import cv2
from decord import VideoReader, cpu
#%%
# Report the installed PyTorch version and whether CUDA is usable, one per line.
print(torch.__version__, torch.cuda.is_available(), sep="\n")
#%%
def load_first_480_frames(video_path, resize=(224, 224)):
    """Decode the first 480 frames of a video into a single numpy array.

    Args:
        video_path: Path handed to decord's ``VideoReader``.
        resize: Target size forwarded to ``cv2.resize`` for every frame.
            Any falsy value (e.g. ``None``) skips the resizing step.

    Returns:
        The frames with indices 0..479 as a numpy array; when ``resize`` is
        truthy, the array is rebuilt from the individually resized frames.

    Raises:
        ValueError: If the video contains fewer than 480 frames.
    """
    reader = VideoReader(video_path, ctx=cpu(0))
    if len(reader) < 480:
        raise ValueError(f"{video_path},视频帧数不足 480 帧")

    clip = reader.get_batch(list(range(480))).asnumpy()
    if not resize:
        return clip
    return np.array([cv2.resize(frame, resize) for frame in clip])
#%%
def preprocess_for_slowfast(frames, num_frames=32, alpha=4):
    """Turn a stack of frames into the two-pathway input list.

    Args:
        frames: 4-D numpy array indexed by frame along axis 0; the last axis
            must broadcast against the 3-element mean/std (assumed to be
            (T, H, W, 3) as returned by the frame loader -- confirm at caller).
        num_frames: Number of frames picked at evenly spaced positions over
            the whole clip.
        alpha: Temporal stride used to thin the fast pathway into the slow one.

    Returns:
        ``[slow_pathway, fast_pathway]`` as float32 tensors. The fast pathway
        holds all ``num_frames`` sampled frames with a leading batch axis of 1
        and the former last axis moved to position 1; the slow pathway keeps
        every ``alpha``-th frame of it.
    """
    mean = [0.45, 0.45, 0.45]
    std = [0.225, 0.225, 0.225]

    # Evenly spaced frame positions from the first to the last frame.
    picks = np.linspace(0, len(frames) - 1, num=num_frames, dtype=int)

    # Scale to [0, 1], then standardise per channel (done in float64, stored as float32).
    clip = frames[picks] / 255.0
    clip = ((clip - mean) / std).astype(np.float32)

    # Move the last axis in front of the frame axis and add a batch axis.
    fast_pathway = torch.from_numpy(clip).permute(3, 0, 1, 2).unsqueeze(0)
    slow_pathway = fast_pathway[:, :, ::alpha, :, :]
    return [slow_pathway, fast_pathway]
#%%
# Feature-extraction model (SlowFast backbone).
class SlowFastFeatureExtractor(torch.nn.Module):
    """Pretrained ``slowfast_r50`` without its last block, used as an embedder.

    The forward pass runs the retained blocks, pools globally, flattens to
    ``[B, C]`` and L2-normalises each row.
    """

    def __init__(self):
        """Load the pretrained model from the pytorchvideo hub and strip its head."""
        super().__init__()
        pretrained_net = torch.hub.load(
            "facebookresearch/pytorchvideo", "slowfast_r50", pretrained=True
        )
        # Keep everything except the final block (the classification head).
        self.blocks = pretrained_net.blocks[:-1]
        # Global average pooling down to one value per channel.
        self.pool = torch.nn.AdaptiveAvgPool3d(1)

    def forward(self, x):
        """Run ``x`` through the backbone and return normalised ``[B, C]`` features."""
        for stage in self.blocks:
            x = stage(x)
        pooled = self.pool(x)
        flat = pooled.flatten(1)  # [B, C]
        # Unit-length rows along the channel axis.
        return torch.nn.functional.normalize(flat, dim=1)
#%%
# Use CUDA when PyTorch reports it as available; otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Loading SlowFast backbone for feature extraction...")
# Build the shared extractor once, move it to the chosen device and
# switch it to evaluation mode (both calls modify the module in place).
model = SlowFastFeatureExtractor()
model.to(device)
model.eval()
def extract_video_feature(video_path):
    """Compute the SlowFast feature of one video file.

    Loads the first 480 frames, samples 32 of them for the two pathways
    (alpha=4), runs the module-level ``model`` without gradient tracking and
    returns the result as a squeezed numpy array.

    Args:
        video_path: Path of the video to embed.

    Returns:
        The model output moved to the CPU, converted to numpy, with
        singleton axes removed.
    """
    clip = load_first_480_frames(video_path)
    pathways = preprocess_for_slowfast(clip, num_frames=32, alpha=4)
    with torch.no_grad():
        on_device = [tensor.to(device) for tensor in pathways]
        embedding = model(on_device)
    return embedding.cpu().numpy().squeeze()
#%%
def getFeature(video_path = r"D:\DESKTOP\2025\44\a1\dataset\org\1.mp4"):
    """Return the feature array of ``video_path``.

    Thin wrapper around ``extract_video_feature``; the default argument
    points at a sample clip on the author's machine.
    """
    return extract_video_feature(video_path)
#%%
import sqlite3
import io
import os
import tqdm
# Adapter and converter pair: numpy array <-> BLOB.
def adapt_array(arr):
    """Serialise ``arr`` with ``np.save`` and wrap the bytes for SQLite.

    Args:
        arr: The numpy array to store.

    Returns:
        A ``sqlite3.Binary`` holding the complete ``.npy`` byte stream.
    """
    buffer = io.BytesIO()
    np.save(buffer, arr)
    # getvalue() yields the whole buffer regardless of the stream position.
    return sqlite3.Binary(buffer.getvalue())
def convert_array(text):
    """Rebuild a numpy array from the ``.npy`` bytes read out of SQLite.

    Args:
        text: Bytes-like payload as written by ``np.save``.

    Returns:
        The array loaded by ``np.load``.
    """
    # A fresh BytesIO already starts at offset 0, so it can be read directly.
    return np.load(io.BytesIO(text))
# Register the custom type handlers: ndarray values are adapted to .npy bytes
# on insert, and values of columns typed "array" are converted back on read.
sqlite3.register_adapter(np.ndarray, adapt_array)
sqlite3.register_converter("array", convert_array)
# Open the database with type detection enabled.
# PARSE_DECLTYPES looks the converter up by the column's *declared type*, so
# the column below must be declared as "array" (a "BLOB" declaration would
# never trigger the "array" converter and reads would return raw bytes).
# PARSE_COLNAMES additionally lets a query opt in explicitly, e.g.
#   SELECT array AS "array [array]" FROM data
# which is the way to get ndarrays back from a table that was already
# created with the old BLOB declaration (CREATE TABLE IF NOT EXISTS does not
# alter an existing table).
conn = sqlite3.connect(
    "slowfast.db",
    detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES,
)
cursor = conn.cursor()
# Create the table. The declared type "array" has no special meaning to
# SQLite itself; the stored value is still the binary payload produced by
# the adapter.
cursor.execute("""
CREATE TABLE IF NOT EXISTS data (
    id INTEGER PRIMARY KEY,
    array array,
    group_path TEXT,
    full_path TEXT
)
""")
def add_to_db(array_to_store,group_path,full_path):
    """Insert one row into ``data`` and commit it immediately.

    Args:
        array_to_store: Value for the ``array`` column.
        group_path: Value for the ``group_path`` column.
        full_path: Value for the ``full_path`` column.
    """
    row = (array_to_store, group_path, full_path)
    cursor.execute("INSERT INTO data (array,group_path,full_path) VALUES (?,?,?)", row)
    conn.commit()
# # Read an array back
# cursor.execute("SELECT array FROM data WHERE id=1")
# fetched_array = cursor.fetchone()[0]
#
# print("Original array:\n", array_to_store)
# print("Fetched array:\n", fetched_array)
#%%
# Dataset root; each sub-folder is treated as one group of videos.
folder_path = r"D:\DESKTOP\2025\44\a1\dataset"
# Only files named 0.mp4 .. 9.mp4 are picked up.
names = [f"{i}.mp4" for i in range(10)]
print(names)
all_files = []
# Walk the group folders directly under the dataset root.
for group_path in os.listdir(folder_path):
    full_path = os.path.join(folder_path, group_path)
    if not os.path.isdir(full_path):
        continue
    # Walk the entries of this group folder.
    for video_path in os.listdir(full_path):
        if os.path.basename(video_path) not in names:
            continue
        video_full_path = os.path.join(full_path, video_path)
        if os.path.isfile(video_full_path):
            # Remember the group name together with the video's full path.
            all_files.append((group_path, video_full_path))
print(len(all_files))
#%%
# Extract the feature of every collected video and store it in the database,
# with a tqdm progress bar over the file list.
for group_path, video_full_path in tqdm.tqdm(all_files):
    feature = getFeature(video_full_path)
    add_to_db(feature, group_path, video_full_path)
#%%
# Close the SQLite connection opened for slowfast.db.
conn.close()
#%%