This commit is contained in:
2025-04-14 17:15:38 +10:00
parent 91dbc67c71
commit e90d776f53
20 changed files with 1590 additions and 0 deletions

357
slowFast.ipynb Normal file
View File

@@ -0,0 +1,357 @@
{
"cells": [
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2025-04-14T03:41:10.176744Z",
"start_time": "2025-04-14T03:41:08.344955Z"
}
},
"source": [
"import torch\n",
"import numpy as np\n",
"import cv2\n",
"from decord import VideoReader, cpu"
],
"outputs": [],
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T03:41:10.726870Z",
"start_time": "2025-04-14T03:41:10.181751Z"
}
},
"cell_type": "code",
"source": [
"print(torch.__version__)\n",
"print(torch.cuda.is_available())"
],
"id": "31d50e0f9e4ea204",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.6.0+cu124\n",
"True\n"
]
}
],
"execution_count": 2
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T03:41:10.875039Z",
"start_time": "2025-04-14T03:41:10.871492Z"
}
},
"cell_type": "code",
"source": [
"def load_first_480_frames(video_path, resize=(224, 224)):\n",
" vr = VideoReader(video_path, ctx=cpu(0))\n",
" total_frames = len(vr)\n",
"\n",
" if total_frames < 480:\n",
" raise ValueError(f\"{video_path},视频帧数不足 480 帧\")\n",
"\n",
" indices = list(range(480))\n",
" frames = vr.get_batch(indices).asnumpy()\n",
"\n",
" if resize:\n",
" frames = np.array([cv2.resize(f, resize) for f in frames])\n",
" return frames"
],
"id": "e351cfd29d48331a",
"outputs": [],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T03:41:10.884466Z",
"start_time": "2025-04-14T03:41:10.880516Z"
}
},
"cell_type": "code",
"source": [
"def preprocess_for_slowfast(frames, num_frames=32, alpha=4):\n",
" total = len(frames)\n",
" indices = np.linspace(0, total - 1, num=num_frames, dtype=int)\n",
" frames = frames[indices]\n",
"\n",
" frames = frames / 255.0\n",
" frames = (frames - [0.45, 0.45, 0.45]) / [0.225, 0.225, 0.225]\n",
" frames = frames.astype(np.float32)\n",
"\n",
" frames = torch.from_numpy(frames).permute(3, 0, 1, 2).unsqueeze(0)\n",
"\n",
" fast_pathway = frames\n",
" slow_pathway = frames[:, :, ::alpha, :, :]\n",
"\n",
" return [slow_pathway, fast_pathway]"
],
"id": "f8784f81a946176",
"outputs": [],
"execution_count": 4
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T03:41:10.893723Z",
"start_time": "2025-04-14T03:41:10.890080Z"
}
},
"cell_type": "code",
"source": [
"# 定义特征提取模型 (SlowFast Backbone)\n",
"class SlowFastFeatureExtractor(torch.nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" model = torch.hub.load(\"facebookresearch/pytorchvideo\", \"slowfast_r50\", pretrained=True)\n",
" # 移除分类头仅保留backbone部分\n",
" self.blocks = model.blocks[:-1] # 去掉最后的分类器 head\n",
" self.pool = torch.nn.AdaptiveAvgPool3d(1) # 全局池化\n",
"\n",
" def forward(self, x):\n",
" for block in self.blocks:\n",
" x = block(x)\n",
" x = self.pool(x)\n",
" x = torch.flatten(x, 1) # [B, C]\n",
" x = torch.nn.functional.normalize(x, dim=1) # 特征归一化\n",
" return x"
],
"id": "6c52b60b3399c5b7",
"outputs": [],
"execution_count": 5
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T03:41:13.527271Z",
"start_time": "2025-04-14T03:41:10.899167Z"
}
},
"cell_type": "code",
"source": [
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(\"Loading SlowFast backbone for feature extraction...\")\n",
"model = SlowFastFeatureExtractor().to(device).eval()\n",
"\n",
"\n",
"def extract_video_feature(video_path):\n",
"\n",
"\n",
" frames = load_first_480_frames(video_path)\n",
" inputs = preprocess_for_slowfast(frames, num_frames=32, alpha=4)\n",
"\n",
" with torch.no_grad():\n",
" inputs = [x.to(device) for x in inputs]\n",
" features = model(inputs)\n",
"\n",
" features = features.cpu().numpy().squeeze()\n",
" return features"
],
"id": "d841190cdd5ee920",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading SlowFast backbone for feature extraction...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using cache found in C:\\Users\\zikai/.cache\\torch\\hub\\facebookresearch_pytorchvideo_main\n"
]
}
],
"execution_count": 6
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T03:41:13.545015Z",
"start_time": "2025-04-14T03:41:13.541939Z"
}
},
"cell_type": "code",
"source": [
"def getFeature(video_path = r\"D:\\DESKTOP\\2025\\44\\a1\\dataset\\org\\1.mp4\"):\n",
"\n",
" feature = extract_video_feature(video_path)\n",
" # print(\"Video feature shape (for copyright):\", feature.shape)\n",
" return feature"
],
"id": "dce706080dfba5b6",
"outputs": [],
"execution_count": 7
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T03:41:13.607891Z",
"start_time": "2025-04-14T03:41:13.554129Z"
}
},
"cell_type": "code",
"source": [
"import sqlite3\n",
"import io\n",
"import os\n",
"import tqdm\n",
"\n",
"# 注册适配器与转换器numpy数组 <-> BLOB\n",
"def adapt_array(arr):\n",
" out = io.BytesIO()\n",
" np.save(out, arr)\n",
" out.seek(0)\n",
" return sqlite3.Binary(out.read())\n",
"\n",
"def convert_array(text):\n",
" out = io.BytesIO(text)\n",
" out.seek(0)\n",
" return np.load(out)\n",
"\n",
"# 注册自定义类型处理\n",
"sqlite3.register_adapter(np.ndarray, adapt_array)\n",
"sqlite3.register_converter(\"array\", convert_array)\n",
"\n",
"# 创建数据库连接(带类型检测)\n",
"conn = sqlite3.connect(\"slowfast.db\", detect_types=sqlite3.PARSE_DECLTYPES)\n",
"cursor = conn.cursor()\n",
"\n",
"# 创建表\n",
"cursor.execute(\"\"\"\n",
"CREATE TABLE IF NOT EXISTS data (\n",
" id INTEGER PRIMARY KEY,\n",
" array BLOB,\n",
" group_path TEXT,\n",
" full_path TEXT\n",
")\n",
"\"\"\")\n",
"\n",
"\n",
"def add_to_db(array_to_store,group_path,full_path):\n",
" cursor.execute(\"INSERT INTO data (array,group_path,full_path) VALUES (?,?,?)\", (array_to_store,group_path,full_path,))\n",
" conn.commit()\n",
"\n",
"# # 读取数组\n",
"# cursor.execute(\"SELECT array FROM data WHERE id=1\")\n",
"# fetched_array = cursor.fetchone()[0]\n",
"#\n",
"# print(\"原始数组:\\n\", array_to_store)\n",
"# print(\"读取的数组:\\n\", fetched_array)\n",
"\n",
"folder_path = r\"D:\\DESKTOP\\2025\\44\\a1\\dataset\"\n",
"\n",
"all_files = []\n",
"names = [str(x)+\".mp4\" for x in range(10)]\n",
"print(names)\n",
"\n",
"\n",
"# 遍历文件夹\n",
"for group_path in os.listdir(folder_path):\n",
" full_path = os.path.join(folder_path, group_path)\n",
" if os.path.isdir(full_path):\n",
" # 遍历子文件夹\n",
" for video_path in os.listdir(full_path):\n",
" if os.path.basename(video_path) in names:\n",
" video_full_path = os.path.join(full_path, video_path)\n",
" if os.path.isfile(video_full_path):\n",
" # 处理视频文件\n",
" all_files.append((group_path,video_full_path))\n",
"print(len(all_files))\n"
],
"id": "dcf5af026d672b41",
"outputs": [],
"execution_count": 8
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T03:42:00.157934Z",
"start_time": "2025-04-14T03:41:13.633220Z"
}
},
"cell_type": "code",
"source": [
"for group_path, video_full_path in tqdm.tqdm(all_files):\n",
" # 读取视频特征\n",
" feature = getFeature(video_full_path)\n",
" # 将特征存储到数据库\n",
" add_to_db(feature,group_path,video_full_path)\n",
" # print(f\"已处理并存储: {video_full_path}\")\n",
"\n",
"conn.close()\n"
],
"id": "63eed21338e8f26c",
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 70/70 [00:46<00:00, 1.50it/s]\n"
]
}
],
"execution_count": 10
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T03:42:00.170765Z",
"start_time": "2025-04-14T03:42:00.167606Z"
}
},
"cell_type": "code",
"source": "",
"id": "d8a84ff72bc12b33",
"outputs": [],
"execution_count": 11
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T03:42:00.184252Z",
"start_time": "2025-04-14T03:42:00.181880Z"
}
},
"cell_type": "code",
"source": "",
"id": "7343d8b3fcea327c",
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}