{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "md-overview",
   "metadata": {},
   "source": [
    "# SlowFast video feature extraction\n",
    "\n",
    "Extracts one L2-normalised feature vector per video with the pretrained `slowfast_r50` backbone from PyTorchVideo and stores it in a SQLite database (`slowfast.db`).\n",
    "\n",
    "Pipeline: decode the first 480 frames → sample 32 frames → normalise → SlowFast backbone → global average pool → L2 normalise → insert into SQLite.\n",
    "\n",
    "Last recorded run: torch `2.6.0+cu124` with CUDA available; 70 videos processed in about 46 s (1.50 it/s)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "import os\n",
    "import sqlite3\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import cv2\n",
    "import tqdm\n",
    "from decord import VideoReader, cpu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cfg-constants",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Root folder holding one sub-folder of videos per group. Override it with the\n",
    "# SLOWFAST_DATASET_DIR environment variable instead of editing the notebook.\n",
    "DATASET_DIR = os.environ.get(\"SLOWFAST_DATASET_DIR\", r\"D:\\DESKTOP\\2025\\44\\a1\\dataset\")\n",
    "\n",
    "# Video used by getFeature() when no path is given.\n",
    "DEFAULT_VIDEO_PATH = os.path.join(DATASET_DIR, \"org\", \"1.mp4\")\n",
    "\n",
    "# SQLite file that receives the extracted features.\n",
    "DB_PATH = \"slowfast.db\"\n",
    "\n",
    "# Only files with these names are processed in each group folder (0.mp4 ... 9.mp4).\n",
    "VIDEO_NAMES = [f\"{i}.mp4\" for i in range(10)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31d50e0f9e4ea204",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(torch.__version__)\n",
    "print(torch.cuda.is_available())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-frames",
   "metadata": {},
   "source": [
    "## Load and preprocess frames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e351cfd29d48331a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_first_480_frames(video_path, resize=(224, 224)):\n",
    "    \"\"\"Decode the first 480 frames of a video into one NumPy array.\n",
    "\n",
    "    Args:\n",
    "        video_path: Path of the video file, opened with decord on the CPU.\n",
    "        resize: Target size passed to ``cv2.resize`` for every frame, or a\n",
    "            falsy value to keep the decoded resolution.\n",
    "\n",
    "    Returns:\n",
    "        The 480 frames stacked along axis 0.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the video has fewer than 480 frames.\n",
    "    \"\"\"\n",
    "    vr = VideoReader(video_path, ctx=cpu(0))\n",
    "    total_frames = len(vr)\n",
    "\n",
    "    if total_frames < 480:\n",
    "        raise ValueError(f\"{video_path},视频帧数不足 480 帧\")\n",
    "\n",
    "    frames = vr.get_batch(list(range(480))).asnumpy()\n",
    "\n",
    "    if resize:\n",
    "        # cv2.resize takes its target size as (width, height).\n",
    "        frames = np.array([cv2.resize(frame, resize) for frame in frames])\n",
    "    return frames"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8784f81a946176",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_for_slowfast(frames, num_frames=32, alpha=4):\n",
    "    \"\"\"Turn a stack of frames into the [slow, fast] input pair for SlowFast.\n",
    "\n",
    "    Args:\n",
    "        frames: Array indexed as (time, height, width, 3) with values in 0-255.\n",
    "        num_frames: Number of frames sampled evenly over the whole stack.\n",
    "        alpha: Temporal stride between the fast and the slow pathway.\n",
    "\n",
    "    Returns:\n",
    "        ``[slow_pathway, fast_pathway]``: float32 tensors laid out as\n",
    "        (1, 3, time, height, width). The fast pathway holds all sampled\n",
    "        frames, the slow pathway every ``alpha``-th one.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``frames`` is empty.\n",
    "    \"\"\"\n",
    "    total = len(frames)\n",
    "    if total == 0:\n",
    "        raise ValueError(\"frames is empty\")\n",
    "\n",
    "    # Evenly spaced frame indices covering the first to the last frame.\n",
    "    indices = np.linspace(0, total - 1, num=num_frames, dtype=int)\n",
    "    frames = frames[indices]\n",
    "\n",
    "    # Scale to [0, 1], then normalise each channel with mean 0.45 / std 0.225.\n",
    "    frames = frames / 255.0\n",
    "    frames = (frames - [0.45, 0.45, 0.45]) / [0.225, 0.225, 0.225]\n",
    "    frames = frames.astype(np.float32)\n",
    "\n",
    "    # (T, H, W, C) -> (C, T, H, W), plus a leading batch axis.\n",
    "    frames = torch.from_numpy(frames).permute(3, 0, 1, 2).unsqueeze(0)\n",
    "\n",
    "    fast_pathway = frames\n",
    "    slow_pathway = frames[:, :, ::alpha, :, :]\n",
    "\n",
    "    return [slow_pathway, fast_pathway]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-model",
   "metadata": {},
   "source": [
    "## SlowFast backbone"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c52b60b3399c5b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SlowFastFeatureExtractor(torch.nn.Module):\n",
    "    \"\"\"Feature extraction model built on the pretrained SlowFast backbone.\n",
    "\n",
    "    Loads ``slowfast_r50`` through ``torch.hub`` and keeps every block except\n",
    "    the last one (the classification head), followed by global average\n",
    "    pooling and L2 normalisation.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        model = torch.hub.load(\"facebookresearch/pytorchvideo\", \"slowfast_r50\", pretrained=True)\n",
    "        # Drop the classification head and keep only the backbone blocks.\n",
    "        self.blocks = model.blocks[:-1]\n",
    "        # Global pooling down to a single value per channel.\n",
    "        self.pool = torch.nn.AdaptiveAvgPool3d(1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Run the backbone on ``x`` and return unit-length features of shape [B, C].\"\"\"\n",
    "        for block in self.blocks:\n",
    "            x = block(x)\n",
    "        x = self.pool(x)\n",
    "        x = torch.flatten(x, 1)  # [B, C]\n",
    "        x = torch.nn.functional.normalize(x, dim=1)  # L2-normalise each feature vector\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d841190cdd5ee920",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(\"Loading SlowFast backbone for feature extraction...\")\n",
    "model = SlowFastFeatureExtractor().to(device).eval()\n",
    "\n",
    "\n",
    "def extract_video_feature(video_path):\n",
    "    \"\"\"Return the SlowFast feature of one video as a NumPy array.\n",
    "\n",
    "    Loads the first 480 frames, samples 32 of them, runs the module-level\n",
    "    ``model`` without gradients and squeezes the batch axis away.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the video has fewer than 480 frames.\n",
    "    \"\"\"\n",
    "    frames = load_first_480_frames(video_path)\n",
    "    inputs = preprocess_for_slowfast(frames, num_frames=32, alpha=4)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        inputs = [pathway.to(device) for pathway in inputs]\n",
    "        features = model(inputs)\n",
    "\n",
    "    return features.cpu().numpy().squeeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dce706080dfba5b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def getFeature(video_path=DEFAULT_VIDEO_PATH):\n",
    "    \"\"\"Return the feature vector of ``video_path`` (see ``extract_video_feature``).\"\"\"\n",
    "    return extract_video_feature(video_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-storage",
   "metadata": {},
   "source": [
    "## Store features in SQLite\n",
    "\n",
    "NumPy arrays are serialised with `np.save` on insert and restored with `np.load` on select. The converter is looked up by the column's *declared type*, so the feature column is declared as `ARRAY` (a column declared `BLOB` would come back as raw bytes).\n",
    "\n",
    "`CREATE TABLE IF NOT EXISTS` does not alter an existing table: if `slowfast.db` already exists with the column declared as `BLOB`, decode values with `convert_array(...)` or recreate the file.\n",
    "\n",
    "Reading a feature back:\n",
    "\n",
    "```python\n",
    "cursor.execute(\"SELECT array FROM data WHERE id=1\")\n",
    "fetched_array = cursor.fetchone()[0]\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcf5af026d672b41",
   "metadata": {},
   "outputs": [],
   "source": [
    "def adapt_array(arr):\n",
    "    \"\"\"Serialise a NumPy array to the ``.npy`` byte format for storage.\"\"\"\n",
    "    out = io.BytesIO()\n",
    "    np.save(out, arr)\n",
    "    return sqlite3.Binary(out.getvalue())\n",
    "\n",
    "\n",
    "def convert_array(text):\n",
    "    \"\"\"Rebuild a NumPy array from the bytes written by ``adapt_array``.\"\"\"\n",
    "    return np.load(io.BytesIO(text))\n",
    "\n",
    "\n",
    "# Register the custom type handling: ndarray -> bytes on write, \"array\" columns -> ndarray on read.\n",
    "sqlite3.register_adapter(np.ndarray, adapt_array)\n",
    "sqlite3.register_converter(\"array\", convert_array)\n",
    "\n",
    "# Open the database with declared-type detection so the converter is applied.\n",
    "conn = sqlite3.connect(DB_PATH, detect_types=sqlite3.PARSE_DECLTYPES)\n",
    "cursor = conn.cursor()\n",
    "\n",
    "# Create the table. The feature column must be declared ARRAY for the converter to match.\n",
    "cursor.execute(\"\"\"\n",
    "CREATE TABLE IF NOT EXISTS data (\n",
    "    id INTEGER PRIMARY KEY,\n",
    "    array ARRAY,\n",
    "    group_path TEXT,\n",
    "    full_path TEXT\n",
    ")\n",
    "\"\"\")\n",
    "\n",
    "\n",
    "def add_to_db(array_to_store, group_path, full_path):\n",
    "    \"\"\"Insert one feature row and commit it immediately.\"\"\"\n",
    "    cursor.execute(\n",
    "        \"INSERT INTO data (array,group_path,full_path) VALUES (?,?,?)\",\n",
    "        (array_to_store, group_path, full_path),\n",
    "    )\n",
    "    conn.commit()\n",
    "\n",
    "\n",
    "def is_in_db(full_path):\n",
    "    \"\"\"Return True if a row with this ``full_path`` is already stored.\"\"\"\n",
    "    cursor.execute(\"SELECT 1 FROM data WHERE full_path = ? LIMIT 1\", (full_path,))\n",
    "    return cursor.fetchone() is not None"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "md-extract",
   "metadata": {},
   "source": [
    "## Collect videos and extract features\n",
    "\n",
    "Videos whose `full_path` is already stored are skipped, so re-running the notebook does not insert duplicate rows. Delete `slowfast.db` to recompute everything."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "scan-dataset",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(VIDEO_NAMES)\n",
    "\n",
    "# Walk every group folder; sorting keeps the processing order reproducible.\n",
    "all_files = []\n",
    "for group_path in sorted(os.listdir(DATASET_DIR)):\n",
    "    group_dir = os.path.join(DATASET_DIR, group_path)\n",
    "    if not os.path.isdir(group_dir):\n",
    "        continue\n",
    "    for video_name in sorted(os.listdir(group_dir)):\n",
    "        video_full_path = os.path.join(group_dir, video_name)\n",
    "        if video_name in VIDEO_NAMES and os.path.isfile(video_full_path):\n",
    "            all_files.append((group_path, video_full_path))\n",
    "print(len(all_files))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63eed21338e8f26c",
   "metadata": {},
   "outputs": [],
   "source": [
    "for group_path, video_full_path in tqdm.tqdm(all_files):\n",
    "    if is_in_db(video_full_path):\n",
    "        continue\n",
    "    # Extract the video feature and store it in the database.\n",
    "    feature = getFeature(video_full_path)\n",
    "    add_to_db(feature, group_path, video_full_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "close-db",
   "metadata": {},
   "outputs": [],
   "source": [
    "conn.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}