This commit is contained in:
2025-04-14 21:47:52 +10:00
parent e90d776f53
commit 2b79a4869d

View File

@@ -3,8 +3,8 @@
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T06:58:32.199713Z",
"start_time": "2025-04-14T06:58:32.134049Z"
"end_time": "2025-04-14T09:34:18.359212Z",
"start_time": "2025-04-14T09:34:18.356046Z"
}
},
"cell_type": "code",
@@ -13,38 +13,23 @@
"import numpy as np\n",
"import io\n",
"from numpy.linalg import norm\n",
"from collections import defaultdict"
"from collections import defaultdict\n",
"import os"
],
"id": "6917288db44004ea",
"outputs": [],
"execution_count": 1
"execution_count": 59
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T06:58:34.129528Z",
"start_time": "2025-04-14T06:58:34.126022Z"
}
},
"cell_type": "code",
"source": [
"currdb = \"i3d.db\"\n",
"#currdb = \"LBPTOP.db\"\n",
"currdb = \"slowfast.db\""
],
"id": "dd9b84fa98c77e5a",
"outputs": [],
"execution_count": 2
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T06:58:36.801611Z",
"start_time": "2025-04-14T06:58:36.798114Z"
"end_time": "2025-04-14T09:34:18.371739Z",
"start_time": "2025-04-14T09:34:18.367219Z"
}
},
"cell_type": "code",
"source": [
"\n",
"# 注册适配器与转换器numpy数组 <-> BLOB\n",
"def adapt_array(arr):\n",
" out = io.BytesIO()\n",
@@ -52,71 +37,49 @@
" out.seek(0)\n",
" return sqlite3.Binary(out.read())\n",
"\n",
"\n",
"def convert_array(text):\n",
" out = io.BytesIO(text)\n",
" out.seek(0)\n",
" return np.load(out)"
],
"id": "fa2ab158cbe0aa98",
"outputs": [],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T07:01:06.063185Z",
"start_time": "2025-04-14T07:01:06.060008Z"
}
},
"cell_type": "code",
"source": [
" return np.load(out)\n",
"\n",
"sqlite3.register_adapter(np.ndarray, adapt_array)\n",
"sqlite3.register_converter(\"BLOB\", convert_array)\n",
"\n",
"# 连接SQLite数据库\n",
"conn = sqlite3.connect(currdb, detect_types=sqlite3.PARSE_DECLTYPES)\n",
"cursor = conn.cursor()"
],
"id": "629eb156f332c0bc",
"outputs": [],
"execution_count": 12
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T07:01:06.372809Z",
"start_time": "2025-04-14T07:01:06.362911Z"
}
},
"cell_type": "code",
"source": [
"# 读取数据库所有数据\n",
"cursor.execute(\"SELECT id, array, group_path, full_path FROM data\")\n",
"all_data = cursor.fetchall()"
],
"id": "ba7902156f1b6117",
"outputs": [],
"execution_count": 13
},
{
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2025-04-14T07:01:06.786016Z",
"start_time": "2025-04-14T07:01:06.779812Z"
}
},
"cell_type": "code",
"source": [
"# 整理特征\n",
"data_by_group = defaultdict(list)\n",
"for vid, arr, group, path in all_data:\n",
" data_by_group[group].append({'id': vid, 'feature': arr, 'path': path})\n",
"\n",
"# 相似度函数 (余弦相似度)\n",
"def cosine_similarity(a, b):\n",
" return np.dot(a, b) / (norm(a) * norm(b))\n",
"\n"
],
"id": "dd9b84fa98c77e5a",
"outputs": [],
"execution_count": 60
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T09:34:18.388488Z",
"start_time": "2025-04-14T09:34:18.385501Z"
}
},
"cell_type": "code",
"source": [
"def get_video_id(path):\n",
" filename = os.path.basename(path)\n",
" name = os.path.splitext(filename)[0] # 去掉扩展名\n",
" return name"
],
"id": "884f8c27d67a93d5",
"outputs": [],
"execution_count": 61
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T09:34:18.451821Z",
"start_time": "2025-04-14T09:34:18.443936Z"
}
},
"cell_type": "code",
"source": [
"\n",
"# 计算指标函数\n",
"def compute_metrics(org_feats, query_feats, sim_feats, K=5):\n",
@@ -135,6 +98,11 @@
" for target in search_db:\n",
" sim = cosine_similarity(query_feature, target['feature'])\n",
" label = 1 if target in org_feats else 0 # 1代表匹配0代表不匹配\n",
" if label == 1:\n",
" query_id = get_video_id(query['path'])\n",
" target_id = get_video_id(target['path'])\n",
" label = 1 if target_id == query_id else 0\n",
"\n",
" similarities.append((sim, label))\n",
"\n",
" # 相似度降序排序\n",
@@ -145,7 +113,7 @@
" precision_top1.append(labels_sorted[0])\n",
"\n",
" # Recall@K\n",
" recall = sum(labels_sorted[:K]) / len(org_feats)\n",
" recall = int(1 in labels_sorted[:K])\n",
" recall_at_k.append(recall)\n",
"\n",
" # Average Precision (AP)\n",
@@ -162,45 +130,9 @@
" 'Top-1 Precision': np.mean(precision_top1),\n",
" f'Recall@{K}': np.mean(recall_at_k),\n",
" 'MAP': np.mean(AP_list)\n",
" }"
],
"id": "initial_id",
"outputs": [],
"execution_count": 14
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T07:01:07.745386Z",
"start_time": "2025-04-14T07:01:07.740942Z"
}
},
"cell_type": "code",
"source": "data_by_group['org'][0]['feature'].shape",
"id": "55e57dc7ec56b732",
"outputs": [
{
"data": {
"text/plain": [
"(2304,)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 15
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T06:00:12.392784Z",
"start_time": "2025-04-14T06:00:12.382725Z"
}
},
"cell_type": "code",
"source": [
" }\n",
"\n",
"\n",
"# 主函数,计算所有变形组的指标\n",
"def evaluate_all_variants(K=5):\n",
" org_feats = data_by_group['org']\n",
@@ -218,54 +150,168 @@
"\n",
" return results\n",
"\n",
"# 执行评估\n",
"results = evaluate_all_variants(K=3)\n",
"\n",
"# 输出结果\n",
"for variant, metrics in results.items():\n",
" print(f\"变形组: {variant}\")\n",
" for metric_name, value in metrics.items():\n",
" print(f\" {metric_name}: {value:.4f}\")\n",
" print(\"-\" * 40)\n",
"def results_to_latex(currdb, results):\n",
" metric_names = list(next(iter(results.values())).keys()) # 获取所有指标名\n",
" header = \"Deformation & \" + \" & \".join(metric_names) + \" \\\\\\\\ \\\\hline\\n\"\n",
"\n",
"# 关闭数据库连接\n",
"conn.close()"
" rows = \"\"\n",
" metric_sums = {metric: 0.0 for metric in metric_names}\n",
" count = 0\n",
"\n",
" for variant_id, variant_display in variant_order:\n",
" if variant_id not in results:\n",
" continue # 忽略不存在的组\n",
" metrics = results[variant_id]\n",
" count += 1\n",
" row = variant_display\n",
" for metric in metric_names:\n",
" val = metrics[metric]\n",
" metric_sums[metric] += val\n",
" row += f\" & {val:.4f}\"\n",
" rows += row + \" \\\\\\\\\\n\"\n",
"\n",
" # 添加平均值行\n",
" avg_row = \"avg\"\n",
" for metric in metric_names:\n",
" avg_val = metric_sums[metric] / count if count else 0.0\n",
" avg_row += f\" & {avg_val:.4f}\"\n",
" rows += avg_row + \" \\\\\\\\\\n\"\n",
"\n",
" latex_table = (\n",
" \"\\\\begin{table}[htbp]\\n\"\n",
" \"\\\\centering\\n\"\n",
" f\"\\\\caption{{Top-1 Precision, Recall@K and MAP for {currdb}}}\\n\"\n",
" \"\\\\begin{tabular}{l\" + \"c\" * len(metric_names) + \"}\\n\"\n",
" \"\\\\hline\\n\"\n",
" + header +\n",
" rows +\n",
" \"\\\\hline\\n\"\n",
" \"\\\\end{tabular}\\n\"\n",
" \"\\\\end{table}\\n\"\n",
" )\n",
" return latex_table"
],
"id": "d9a0103dbc3f1bef",
"id": "282bfd74f5982041",
"outputs": [],
"execution_count": 62
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T09:34:18.460011Z",
"start_time": "2025-04-14T09:34:18.454827Z"
}
},
"cell_type": "code",
"source": [
"variant_order = [\n",
" (r\"静态水印1\", r\"watermark1\"),\n",
" (r\"动态水印2\", r\"watermark2\"),\n",
" (r\"水印1+2\", r\"watermark1\\&2\"),\n",
" (r\"滤镜\", r\"filters\"),\n",
" (r\"水印1+2+滤镜\", r\"filters\\&watermark1\\&2\"),\n",
" (r\"watermark1\", r\"watermark1\"),\n",
" (r\"watermark2\", r\"watermark2\"),\n",
" (r\"watermark1+2\", r\"watermark11\\&2\"),\n",
" (r\"filters\", r\"filters\"),\n",
" (r\"filters+watermark1+2\", r\"filters\\&watermark1\\&2\"),\n",
"]"
],
"id": "12e194a6bf64b8ac",
"outputs": [],
"execution_count": 63
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T09:34:18.512805Z",
"start_time": "2025-04-14T09:34:18.473552Z"
}
},
"cell_type": "code",
"source": [
"\n",
"for currdb in [\"LBPTOP\", \"i3d\", \"slowfast\"]:\n",
" sqlite3.register_adapter(np.ndarray, adapt_array)\n",
" sqlite3.register_converter(\"BLOB\", convert_array)\n",
"\n",
" # 连接SQLite数据库\n",
" conn = sqlite3.connect(currdb + \".db\", detect_types=sqlite3.PARSE_DECLTYPES)\n",
" cursor = conn.cursor()\n",
" # 读取数据库所有数据\n",
" cursor.execute(\"SELECT id, array, group_path, full_path FROM data\")\n",
" all_data = cursor.fetchall()\n",
" # 整理特征\n",
" data_by_group = defaultdict(list)\n",
" for vid, arr, group, path in all_data:\n",
" data_by_group[group].append({'id': vid, 'feature': arr, 'path': path})\n",
"\n",
" # 执行评估\n",
" results = evaluate_all_variants(K=3)\n",
" # 关闭数据库连接\n",
" conn.close()\n",
" # 打印 LaTeX 表格\n",
" print(results_to_latex(currdb, results))\n"
],
"id": "97f972940a003aa4",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"变形组: 动态水印2\n",
" Top-1 Precision: 1.0000\n",
" Recall@3: 0.1500\n",
" MAP: 0.5815\n",
"----------------------------------------\n",
"变形组: 水印1+2\n",
" Top-1 Precision: 0.8000\n",
" Recall@3: 0.1600\n",
" MAP: 0.5751\n",
"----------------------------------------\n",
"变形组: 水印1+2+滤镜\n",
" Top-1 Precision: 0.8000\n",
" Recall@3: 0.1500\n",
" MAP: 0.5685\n",
"----------------------------------------\n",
"变形组: 滤镜\n",
" Top-1 Precision: 0.9000\n",
" Recall@3: 0.1800\n",
" MAP: 0.6100\n",
"----------------------------------------\n",
"变形组: 静态水印1\n",
" Top-1 Precision: 0.9000\n",
" Recall@3: 0.1700\n",
" MAP: 0.6058\n",
"----------------------------------------\n"
"\\begin{table}[htbp]\n",
"\\centering\n",
"\\caption{Top-1 Precision, Recall@K and MAP for LBPTOP}\n",
"\\begin{tabular}{lccc}\n",
"\\hline\n",
"Deformation & Top-1 Precision & Recall@3 & MAP \\\\ \\hline\n",
"watermark1 & 0.5000 & 0.6000 & 0.0608 \\\\\n",
"watermark2 & 0.1000 & 0.5000 & 0.0322 \\\\\n",
"watermark1\\&2 & 0.0000 & 0.3000 & 0.0259 \\\\\n",
"filters & 0.3000 & 0.4000 & 0.0461 \\\\\n",
"filters\\&watermark1\\&2 & 0.1000 & 0.3000 & 0.0310 \\\\\n",
"avg & 0.2000 & 0.4200 & 0.0392 \\\\\n",
"\\hline\n",
"\\end{tabular}\n",
"\\end{table}\n",
"\n",
"\\begin{table}[htbp]\n",
"\\centering\n",
"\\caption{Top-1 Precision, Recall@K and MAP for i3d}\n",
"\\begin{tabular}{lccc}\n",
"\\hline\n",
"Deformation & Top-1 Precision & Recall@3 & MAP \\\\ \\hline\n",
"watermark1 & 0.7000 & 0.9000 & 0.0798 \\\\\n",
"watermark2 & 0.6000 & 0.6000 & 0.0662 \\\\\n",
"watermark1\\&2 & 0.5000 & 0.8000 & 0.0634 \\\\\n",
"filters & 0.8000 & 0.8000 & 0.0850 \\\\\n",
"filters\\&watermark1\\&2 & 0.5000 & 0.8000 & 0.0634 \\\\\n",
"avg & 0.6200 & 0.7800 & 0.0715 \\\\\n",
"\\hline\n",
"\\end{tabular}\n",
"\\end{table}\n",
"\n",
"\\begin{table}[htbp]\n",
"\\centering\n",
"\\caption{Top-1 Precision, Recall@K and MAP for slowfast}\n",
"\\begin{tabular}{lccc}\n",
"\\hline\n",
"Deformation & Top-1 Precision & Recall@3 & MAP \\\\ \\hline\n",
"watermark1 & 0.7000 & 0.9000 & 0.0817 \\\\\n",
"watermark2 & 0.7000 & 1.0000 & 0.0800 \\\\\n",
"watermark1\\&2 & 0.3000 & 0.9000 & 0.0583 \\\\\n",
"filters & 0.6000 & 0.9000 & 0.0758 \\\\\n",
"filters\\&watermark1\\&2 & 0.3000 & 0.8000 & 0.0558 \\\\\n",
"avg & 0.5200 & 0.9000 & 0.0703 \\\\\n",
"\\hline\n",
"\\end{tabular}\n",
"\\end{table}\n",
"\n"
]
}
],
"execution_count": 25
"execution_count": 64
}
],
"metadata": {