339 lines
11 KiB
Plaintext
339 lines
11 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-04-14T09:34:18.359212Z",
|
||
"start_time": "2025-04-14T09:34:18.356046Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"import sqlite3\n",
|
||
"import numpy as np\n",
|
||
"import io\n",
|
||
"from numpy.linalg import norm\n",
|
||
"from collections import defaultdict\n",
|
||
"import os"
|
||
],
|
||
"id": "6917288db44004ea",
|
||
"outputs": [],
|
||
"execution_count": 59
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-04-14T09:34:18.371739Z",
|
||
"start_time": "2025-04-14T09:34:18.367219Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"\n",
|
||
"# 注册适配器与转换器:numpy数组 <-> BLOB\n",
|
||
"def adapt_array(arr):\n",
|
||
" out = io.BytesIO()\n",
|
||
" np.save(out, arr)\n",
|
||
" out.seek(0)\n",
|
||
" return sqlite3.Binary(out.read())\n",
|
||
"\n",
|
||
"\n",
|
||
"def convert_array(text):\n",
|
||
" out = io.BytesIO(text)\n",
|
||
" out.seek(0)\n",
|
||
" return np.load(out)\n",
|
||
"\n",
|
||
"\n",
|
||
"# 相似度函数 (余弦相似度)\n",
|
||
"def cosine_similarity(a, b):\n",
|
||
" return np.dot(a, b) / (norm(a) * norm(b))\n",
|
||
"\n"
|
||
],
|
||
"id": "dd9b84fa98c77e5a",
|
||
"outputs": [],
|
||
"execution_count": 60
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-04-14T09:34:18.388488Z",
|
||
"start_time": "2025-04-14T09:34:18.385501Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"def get_video_id(path):\n",
|
||
" filename = os.path.basename(path)\n",
|
||
" name = os.path.splitext(filename)[0] # 去掉扩展名\n",
|
||
" return name"
|
||
],
|
||
"id": "884f8c27d67a93d5",
|
||
"outputs": [],
|
||
"execution_count": 61
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-04-14T09:34:18.451821Z",
|
||
"start_time": "2025-04-14T09:34:18.443936Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"\n",
|
||
"# 计算指标函数\n",
|
||
"def compute_metrics(org_feats, query_feats, sim_feats, K=5):\n",
|
||
" precision_top1 = []\n",
|
||
" recall_at_k = []\n",
|
||
" AP_list = []\n",
|
||
"\n",
|
||
" # 构建搜索库 (原始视频 + 相似视频)\n",
|
||
" search_db = org_feats + sim_feats\n",
|
||
"\n",
|
||
" for query in query_feats:\n",
|
||
" similarities = []\n",
|
||
" query_feature = query['feature']\n",
|
||
"\n",
|
||
" # 计算query与所有库视频的相似度\n",
|
||
" for target in search_db:\n",
|
||
" sim = cosine_similarity(query_feature, target['feature'])\n",
|
||
" label = 1 if target in org_feats else 0 # 1代表匹配,0代表不匹配\n",
|
||
" if label == 1:\n",
|
||
" query_id = get_video_id(query['path'])\n",
|
||
" target_id = get_video_id(target['path'])\n",
|
||
" label = 1 if target_id == query_id else 0\n",
|
||
"\n",
|
||
" similarities.append((sim, label))\n",
|
||
"\n",
|
||
" # 相似度降序排序\n",
|
||
" similarities.sort(key=lambda x: x[0], reverse=True)\n",
|
||
" labels_sorted = [label for _, label in similarities]\n",
|
||
"\n",
|
||
" # Top-1 precision\n",
|
||
" precision_top1.append(labels_sorted[0])\n",
|
||
"\n",
|
||
" # Recall@K\n",
|
||
" recall = int(1 in labels_sorted[:K])\n",
|
||
" recall_at_k.append(recall)\n",
|
||
"\n",
|
||
" # Average Precision (AP)\n",
|
||
" hits, sum_precisions = 0, 0\n",
|
||
" for idx, label in enumerate(labels_sorted, start=1):\n",
|
||
" if label == 1:\n",
|
||
" hits += 1\n",
|
||
" sum_precisions += hits / idx\n",
|
||
" AP = sum_precisions / len(org_feats) if len(org_feats) else 0\n",
|
||
" AP_list.append(AP)\n",
|
||
"\n",
|
||
" # 返回指标均值\n",
|
||
" return {\n",
|
||
" 'Top-1 Precision': np.mean(precision_top1),\n",
|
||
" f'Recall@{K}': np.mean(recall_at_k),\n",
|
||
" 'MAP': np.mean(AP_list)\n",
|
||
" }\n",
|
||
"\n",
|
||
"\n",
|
||
"# 主函数,计算所有变形组的指标\n",
|
||
"def evaluate_all_variants(K=5):\n",
|
||
" org_feats = data_by_group['org']\n",
|
||
" sim_feats = data_by_group['sim']\n",
|
||
"\n",
|
||
" results = {}\n",
|
||
"\n",
|
||
" # 排除 'org' 和 'sim',其余都是变形组\n",
|
||
" variant_groups = [group for group in data_by_group if group not in ['org', 'sim']]\n",
|
||
"\n",
|
||
" for variant in variant_groups:\n",
|
||
" variant_feats = data_by_group[variant]\n",
|
||
" metrics = compute_metrics(org_feats, variant_feats, sim_feats, K)\n",
|
||
" results[variant] = metrics\n",
|
||
"\n",
|
||
" return results\n",
|
||
"\n",
|
||
"\n",
|
||
"def results_to_latex(currdb, results):\n",
|
||
" metric_names = list(next(iter(results.values())).keys()) # 获取所有指标名\n",
|
||
" header = \"Deformation & \" + \" & \".join(metric_names) + \" \\\\\\\\ \\\\hline\\n\"\n",
|
||
"\n",
|
||
" rows = \"\"\n",
|
||
" metric_sums = {metric: 0.0 for metric in metric_names}\n",
|
||
" count = 0\n",
|
||
"\n",
|
||
" for variant_id, variant_display in variant_order:\n",
|
||
" if variant_id not in results:\n",
|
||
" continue # 忽略不存在的组\n",
|
||
" metrics = results[variant_id]\n",
|
||
" count += 1\n",
|
||
" row = variant_display\n",
|
||
" for metric in metric_names:\n",
|
||
" val = metrics[metric]\n",
|
||
" metric_sums[metric] += val\n",
|
||
" row += f\" & {val:.4f}\"\n",
|
||
" rows += row + \" \\\\\\\\\\n\"\n",
|
||
"\n",
|
||
" # 添加平均值行\n",
|
||
" avg_row = \"avg\"\n",
|
||
" for metric in metric_names:\n",
|
||
" avg_val = metric_sums[metric] / count if count else 0.0\n",
|
||
" avg_row += f\" & {avg_val:.4f}\"\n",
|
||
" rows += avg_row + \" \\\\\\\\\\n\"\n",
|
||
"\n",
|
||
" latex_table = (\n",
|
||
" \"\\\\begin{table}[htbp]\\n\"\n",
|
||
" \"\\\\centering\\n\"\n",
|
||
" f\"\\\\caption{{Top-1 Precision, Recall@K and MAP for {currdb}}}\\n\"\n",
|
||
" \"\\\\begin{tabular}{l\" + \"c\" * len(metric_names) + \"}\\n\"\n",
|
||
" \"\\\\hline\\n\"\n",
|
||
" + header +\n",
|
||
" rows +\n",
|
||
" \"\\\\hline\\n\"\n",
|
||
" \"\\\\end{tabular}\\n\"\n",
|
||
" \"\\\\end{table}\\n\"\n",
|
||
" )\n",
|
||
" return latex_table"
|
||
],
|
||
"id": "282bfd74f5982041",
|
||
"outputs": [],
|
||
"execution_count": 62
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-04-14T09:34:18.460011Z",
|
||
"start_time": "2025-04-14T09:34:18.454827Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"variant_order = [\n",
|
||
" (r\"静态水印1\", r\"watermark1\"),\n",
|
||
" (r\"动态水印2\", r\"watermark2\"),\n",
|
||
" (r\"水印1+2\", r\"watermark1\\&2\"),\n",
|
||
" (r\"滤镜\", r\"filters\"),\n",
|
||
" (r\"水印1+2+滤镜\", r\"filters\\&watermark1\\&2\"),\n",
|
||
" (r\"watermark1\", r\"watermark1\"),\n",
|
||
" (r\"watermark2\", r\"watermark2\"),\n",
|
||
" (r\"watermark1+2\", r\"watermark11\\&2\"),\n",
|
||
" (r\"filters\", r\"filters\"),\n",
|
||
" (r\"filters+watermark1+2\", r\"filters\\&watermark1\\&2\"),\n",
|
||
"]"
|
||
],
|
||
"id": "12e194a6bf64b8ac",
|
||
"outputs": [],
|
||
"execution_count": 63
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-04-14T09:34:18.512805Z",
|
||
"start_time": "2025-04-14T09:34:18.473552Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"\n",
|
||
"for currdb in [\"LBPTOP\", \"i3d\", \"slowfast\"]:\n",
|
||
" sqlite3.register_adapter(np.ndarray, adapt_array)\n",
|
||
" sqlite3.register_converter(\"BLOB\", convert_array)\n",
|
||
"\n",
|
||
" # 连接SQLite数据库\n",
|
||
" conn = sqlite3.connect(currdb + \".db\", detect_types=sqlite3.PARSE_DECLTYPES)\n",
|
||
" cursor = conn.cursor()\n",
|
||
" # 读取数据库所有数据\n",
|
||
" cursor.execute(\"SELECT id, array, group_path, full_path FROM data\")\n",
|
||
" all_data = cursor.fetchall()\n",
|
||
" # 整理特征\n",
|
||
" data_by_group = defaultdict(list)\n",
|
||
" for vid, arr, group, path in all_data:\n",
|
||
" data_by_group[group].append({'id': vid, 'feature': arr, 'path': path})\n",
|
||
"\n",
|
||
" # 执行评估\n",
|
||
" results = evaluate_all_variants(K=3)\n",
|
||
" # 关闭数据库连接\n",
|
||
" conn.close()\n",
|
||
" # 打印 LaTeX 表格\n",
|
||
" print(results_to_latex(currdb, results))\n"
|
||
],
|
||
"id": "97f972940a003aa4",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\\begin{table}[htbp]\n",
|
||
"\\centering\n",
|
||
"\\caption{Top-1 Precision, Recall@K and MAP for LBPTOP}\n",
|
||
"\\begin{tabular}{lccc}\n",
|
||
"\\hline\n",
|
||
"Deformation & Top-1 Precision & Recall@3 & MAP \\\\ \\hline\n",
|
||
"watermark1 & 0.5000 & 0.6000 & 0.0608 \\\\\n",
|
||
"watermark2 & 0.1000 & 0.5000 & 0.0322 \\\\\n",
|
||
"watermark1\\&2 & 0.0000 & 0.3000 & 0.0259 \\\\\n",
|
||
"filters & 0.3000 & 0.4000 & 0.0461 \\\\\n",
|
||
"filters\\&watermark1\\&2 & 0.1000 & 0.3000 & 0.0310 \\\\\n",
|
||
"avg & 0.2000 & 0.4200 & 0.0392 \\\\\n",
|
||
"\\hline\n",
|
||
"\\end{tabular}\n",
|
||
"\\end{table}\n",
|
||
"\n",
|
||
"\\begin{table}[htbp]\n",
|
||
"\\centering\n",
|
||
"\\caption{Top-1 Precision, Recall@K and MAP for i3d}\n",
|
||
"\\begin{tabular}{lccc}\n",
|
||
"\\hline\n",
|
||
"Deformation & Top-1 Precision & Recall@3 & MAP \\\\ \\hline\n",
|
||
"watermark1 & 0.7000 & 0.9000 & 0.0798 \\\\\n",
|
||
"watermark2 & 0.6000 & 0.6000 & 0.0662 \\\\\n",
|
||
"watermark1\\&2 & 0.5000 & 0.8000 & 0.0634 \\\\\n",
|
||
"filters & 0.8000 & 0.8000 & 0.0850 \\\\\n",
|
||
"filters\\&watermark1\\&2 & 0.5000 & 0.8000 & 0.0634 \\\\\n",
|
||
"avg & 0.6200 & 0.7800 & 0.0715 \\\\\n",
|
||
"\\hline\n",
|
||
"\\end{tabular}\n",
|
||
"\\end{table}\n",
|
||
"\n",
|
||
"\\begin{table}[htbp]\n",
|
||
"\\centering\n",
|
||
"\\caption{Top-1 Precision, Recall@K and MAP for slowfast}\n",
|
||
"\\begin{tabular}{lccc}\n",
|
||
"\\hline\n",
|
||
"Deformation & Top-1 Precision & Recall@3 & MAP \\\\ \\hline\n",
|
||
"watermark1 & 0.7000 & 0.9000 & 0.0817 \\\\\n",
|
||
"watermark2 & 0.7000 & 1.0000 & 0.0800 \\\\\n",
|
||
"watermark1\\&2 & 0.3000 & 0.9000 & 0.0583 \\\\\n",
|
||
"filters & 0.6000 & 0.9000 & 0.0758 \\\\\n",
|
||
"filters\\&watermark1\\&2 & 0.3000 & 0.8000 & 0.0558 \\\\\n",
|
||
"avg & 0.5200 & 0.9000 & 0.0703 \\\\\n",
|
||
"\\hline\n",
|
||
"\\end{tabular}\n",
|
||
"\\end{table}\n",
|
||
"\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 64
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 2
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython2",
|
||
"version": "2.7.6"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|