Files
COMP90044_a1/i3d.ipynb
2025-04-14 17:15:38 +10:00

280 lines
7.8 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T03:58:08.619189Z",
"start_time": "2025-04-14T03:58:08.616451Z"
}
},
"cell_type": "code",
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"import torchvision.transforms as transforms\n",
"from pytorchvideo.models.hub import i3d_r50\n",
"import cv2\n",
"import numpy as np\n",
"from PIL import Image"
],
"id": "79af3e0bb61c3290",
"outputs": [],
"execution_count": 5
},
{
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2025-04-14T03:58:09.075199Z",
"start_time": "2025-04-14T03:58:08.635303Z"
}
},
"cell_type": "code",
"source": [
"# 初始化预训练I3D模型\n",
"model = i3d_r50(pretrained=True).eval()\n",
"feature_extractor = torch.nn.Sequential(*model.blocks[:-1])"
],
"id": "initial_id",
"outputs": [],
"execution_count": 6
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T03:58:09.118627Z",
"start_time": "2025-04-14T03:58:09.113363Z"
}
},
"cell_type": "code",
"source": [
"def preprocess_video(video_path, num_frames=32, size=224):\n",
" cap = cv2.VideoCapture(video_path)\n",
" frames = []\n",
" transform = transforms.Compose([\n",
" transforms.Resize((size, size)),\n",
" transforms.ToTensor(),\n",
" transforms.Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225])\n",
" ])\n",
"\n",
" total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
" frame_indices = np.linspace(0, total_frames - 1, num_frames).astype(int)\n",
"\n",
" for i in frame_indices:\n",
" cap.set(cv2.CAP_PROP_POS_FRAMES, i)\n",
" ret, frame = cap.read()\n",
" if not ret:\n",
" break\n",
" frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
" frame = Image.fromarray(frame)\n",
" frame = transform(frame)\n",
" frames.append(frame)\n",
"\n",
" cap.release()\n",
" video_tensor = torch.stack(frames).permute(1, 0, 2, 3).unsqueeze(0)\n",
" return video_tensor\n"
],
"id": "5fd2f80d5947967e",
"outputs": [],
"execution_count": 7
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T03:58:09.128752Z",
"start_time": "2025-04-14T03:58:09.125636Z"
}
},
"cell_type": "code",
"source": [
"def extract_features(video_tensor):\n",
" with torch.no_grad():\n",
" features = feature_extractor(video_tensor)\n",
" features = F.adaptive_avg_pool3d(features, 1)\n",
" features = features.flatten()\n",
" return features.numpy()"
],
"id": "728a0b9ece5bdc06",
"outputs": [],
"execution_count": 8
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T03:58:09.142606Z",
"start_time": "2025-04-14T03:58:09.138924Z"
}
},
"cell_type": "code",
"source": [
"def video_features(video_path):\n",
" video_tensor = preprocess_video(video_path)\n",
" features = extract_features(video_tensor)\n",
" return features"
],
"id": "60ca6ade121d00af",
"outputs": [],
"execution_count": 9
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T03:58:09.156458Z",
"start_time": "2025-04-14T03:58:09.152546Z"
}
},
"cell_type": "code",
"source": [
"def getFeature(video_path = r\"D:\\DESKTOP\\2025\\44\\a1\\dataset\\org\\1.mp4\"):\n",
"\n",
" feature = video_features(video_path)\n",
" # print(\"Video feature shape (for copyright):\", feature.shape)\n",
" return feature"
],
"id": "5c3f6ede68d0f22f",
"outputs": [],
"execution_count": 10
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T03:58:09.206798Z",
"start_time": "2025-04-14T03:58:09.166673Z"
}
},
"cell_type": "code",
"source": [
"import sqlite3\n",
"import io\n",
"import os\n",
"import tqdm\n",
"\n",
"# 注册适配器与转换器numpy数组 <-> BLOB\n",
"def adapt_array(arr):\n",
" out = io.BytesIO()\n",
" np.save(out, arr)\n",
" out.seek(0)\n",
" return sqlite3.Binary(out.read())\n",
"\n",
"def convert_array(text):\n",
" out = io.BytesIO(text)\n",
" out.seek(0)\n",
" return np.load(out)\n",
"\n",
"# 注册自定义类型处理\n",
"sqlite3.register_adapter(np.ndarray, adapt_array)\n",
"sqlite3.register_converter(\"array\", convert_array)\n",
"\n",
"# 创建数据库连接(带类型检测)\n",
"conn = sqlite3.connect(\"i3d.db\", detect_types=sqlite3.PARSE_DECLTYPES)\n",
"cursor = conn.cursor()\n",
"\n",
"# 创建表\n",
"cursor.execute(\"\"\"\n",
"CREATE TABLE IF NOT EXISTS data (\n",
" id INTEGER PRIMARY KEY,\n",
" array BLOB,\n",
" group_path TEXT,\n",
" full_path TEXT\n",
")\n",
"\"\"\")\n",
"\n",
"\n",
"def add_to_db(array_to_store,group_path,full_path):\n",
" cursor.execute(\"INSERT INTO data (array,group_path,full_path) VALUES (?,?,?)\", (array_to_store,group_path,full_path,))\n",
" conn.commit()\n",
"\n",
"# # 读取数组\n",
"# cursor.execute(\"SELECT array FROM data WHERE id=1\")\n",
"# fetched_array = cursor.fetchone()[0]\n",
"#\n",
"# print(\"原始数组:\\n\", array_to_store)\n",
"# print(\"读取的数组:\\n\", fetched_array)\n",
"\n",
"folder_path = r\"D:\\DESKTOP\\2025\\44\\a1\\dataset\"\n",
"\n",
"all_files = []\n",
"names = [str(x)+\".mp4\" for x in range(10)]\n",
"print(names)\n",
"\n",
"\n",
"# 遍历文件夹\n",
"for group_path in os.listdir(folder_path):\n",
" full_path = os.path.join(folder_path, group_path)\n",
" if os.path.isdir(full_path):\n",
" # 遍历子文件夹\n",
" for video_path in os.listdir(full_path):\n",
" if os.path.basename(video_path) in names:\n",
" video_full_path = os.path.join(full_path, video_path)\n",
" if os.path.isfile(video_full_path):\n",
" # 处理视频文件\n",
" all_files.append((group_path,video_full_path))\n",
"print(len(all_files))\n"
],
"id": "517f4d3e7e8d4402",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['0.mp4', '1.mp4', '2.mp4', '3.mp4', '4.mp4', '5.mp4', '6.mp4', '7.mp4', '8.mp4', '9.mp4']\n",
"70\n"
]
}
],
"execution_count": 11
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T03:59:48.339352Z",
"start_time": "2025-04-14T03:58:09.217915Z"
}
},
"cell_type": "code",
"source": [
"for group_path, video_full_path in tqdm.tqdm(all_files):\n",
" # 读取视频特征\n",
" feature = getFeature(video_full_path)\n",
" # 将特征存储到数据库\n",
" add_to_db(feature,group_path,video_full_path)\n",
" # print(f\"已处理并存储: {video_full_path}\")\n",
"\n",
"conn.close()\n"
],
"id": "aa7bdb735dbc1e1e",
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 70/70 [01:39<00:00, 1.42s/it]\n"
]
}
],
"execution_count": 12
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}