Files
COMP90044_a1/analy.ipynb
2025-04-14 21:47:52 +10:00

339 lines
11 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T09:34:18.359212Z",
"start_time": "2025-04-14T09:34:18.356046Z"
}
},
"cell_type": "code",
"source": [
"import sqlite3\n",
"import numpy as np\n",
"import io\n",
"from numpy.linalg import norm\n",
"from collections import defaultdict\n",
"import os"
],
"id": "6917288db44004ea",
"outputs": [],
"execution_count": 59
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T09:34:18.371739Z",
"start_time": "2025-04-14T09:34:18.367219Z"
}
},
"cell_type": "code",
"source": [
"\n",
"# Adapter / converter pair mapping numpy arrays <-> SQLite BLOBs\n",
"def adapt_array(arr):\n",
"    \"\"\"Serialize ``arr`` with ``np.save`` and wrap the bytes in ``sqlite3.Binary``.\"\"\"\n",
"    buffer = io.BytesIO()\n",
"    np.save(buffer, arr)\n",
"    # getvalue() yields the whole payload irrespective of the stream position.\n",
"    return sqlite3.Binary(buffer.getvalue())\n",
"\n",
"\n",
"def convert_array(text):\n",
"    \"\"\"Load a numpy array from the raw bytes ``text`` using ``np.load``.\"\"\"\n",
"    # A freshly constructed BytesIO is already positioned at offset 0.\n",
"    return np.load(io.BytesIO(text))\n",
"\n",
"\n",
"# Similarity function (cosine similarity)\n",
"def cosine_similarity(a, b):\n",
"    \"\"\"Return the cosine similarity of vectors ``a`` and ``b``.\n",
"\n",
"    If either vector has zero norm the similarity is defined as 0.0 instead\n",
"    of dividing by zero: the resulting NaN would make any later sort of the\n",
"    similarity scores unreliable.\n",
"    \"\"\"\n",
"    denominator = norm(a) * norm(b)\n",
"    if denominator == 0:\n",
"        return 0.0\n",
"    return np.dot(a, b) / denominator\n",
"\n"
],
"id": "dd9b84fa98c77e5a",
"outputs": [],
"execution_count": 60
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T09:34:18.388488Z",
"start_time": "2025-04-14T09:34:18.385501Z"
}
},
"cell_type": "code",
"source": [
"def get_video_id(path):\n",
"    \"\"\"Return the file name of ``path`` without its directory and final extension.\"\"\"\n",
"    stem, _extension = os.path.splitext(os.path.basename(path))\n",
"    return stem"
],
"id": "884f8c27d67a93d5",
"outputs": [],
"execution_count": 61
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T09:34:18.451821Z",
"start_time": "2025-04-14T09:34:18.443936Z"
}
},
"cell_type": "code",
"source": [
"\n",
"# Metric computation\n",
"def compute_metrics(org_feats, query_feats, sim_feats, K=5):\n",
"    \"\"\"Rank every query against originals + distractors and score the rankings.\n",
"\n",
"    The gallery is ``org_feats + sim_feats``. A gallery entry is relevant to a\n",
"    query only if it comes from ``org_feats`` and ``get_video_id`` of its path\n",
"    equals ``get_video_id`` of the query path; ``sim_feats`` entries are never\n",
"    relevant.\n",
"\n",
"    Args:\n",
"        org_feats: list of dicts with 'feature' and 'path' keys.\n",
"        query_feats: list of dicts with 'feature' and 'path' keys.\n",
"        sim_feats: list of dicts with a 'feature' key.\n",
"        K: cut-off used for Recall@K.\n",
"\n",
"    Returns:\n",
"        dict with the means over all queries of 'Top-1 Precision',\n",
"        'Recall@K' (K substituted) and 'MAP'.\n",
"    \"\"\"\n",
"    # Build the gallery once as (feature, id) pairs. Distractors get a None id\n",
"    # so they can never match a query. This replaces the per-pair\n",
"    # `target in org_feats` test, which compared dicts holding numpy arrays\n",
"    # and rescanned org_feats for every (query, target) pair.\n",
"    gallery = [(item['feature'], get_video_id(item['path'])) for item in org_feats]\n",
"    gallery += [(item['feature'], None) for item in sim_feats]\n",
"\n",
"    precision_top1 = []\n",
"    recall_at_k = []\n",
"    AP_list = []\n",
"\n",
"    for query in query_feats:\n",
"        query_feature = query['feature']\n",
"        query_id = get_video_id(query['path'])\n",
"\n",
"        # Similarity of the query to every gallery entry, with 1/0 relevance\n",
"        scored = [\n",
"            (cosine_similarity(query_feature, feature), int(video_id == query_id))\n",
"            for feature, video_id in gallery\n",
"        ]\n",
"\n",
"        # Sort by similarity, highest first\n",
"        scored.sort(key=lambda pair: pair[0], reverse=True)\n",
"        labels_sorted = [label for _, label in scored]\n",
"\n",
"        # Top-1 precision (counts as a miss when the gallery is empty)\n",
"        precision_top1.append(labels_sorted[0] if labels_sorted else 0)\n",
"\n",
"        # Recall@K\n",
"        recall_at_k.append(int(1 in labels_sorted[:K]))\n",
"\n",
"        # Average Precision (AP)\n",
"        hits, sum_precisions = 0, 0.0\n",
"        for rank, label in enumerate(labels_sorted, start=1):\n",
"            if label == 1:\n",
"                hits += 1\n",
"                sum_precisions += hits / rank\n",
"        # Normalise by the number of relevant items for this query. The whole\n",
"        # gallery is ranked, so that number equals `hits`. Dividing by\n",
"        # len(org_feats) instead would scale every AP down by the gallery size.\n",
"        AP_list.append(sum_precisions / hits if hits else 0.0)\n",
"\n",
"    # Return the mean of each metric\n",
"    return {\n",
"        'Top-1 Precision': np.mean(precision_top1),\n",
"        f'Recall@{K}': np.mean(recall_at_k),\n",
"        'MAP': np.mean(AP_list)\n",
"    }\n",
"\n",
"\n",
"# Main entry point: compute the metrics of every variant group\n",
"def evaluate_all_variants(K=5, groups=None):\n",
"    \"\"\"Run ``compute_metrics`` for every group other than 'org' and 'sim'.\n",
"\n",
"    Args:\n",
"        K: cut-off forwarded to ``compute_metrics``.\n",
"        groups: mapping of group name -> list of feature records. Defaults to\n",
"            the module-level ``data_by_group`` so existing calls keep working;\n",
"            passing it explicitly avoids relying on notebook global state.\n",
"\n",
"    Returns:\n",
"        dict mapping each variant group name to the dict returned by\n",
"        ``compute_metrics`` for that group.\n",
"    \"\"\"\n",
"    if groups is None:\n",
"        groups = data_by_group\n",
"\n",
"    # .get() instead of [] so that a defaultdict argument is not mutated when\n",
"    # one of the reference groups is missing.\n",
"    org_feats = groups.get('org', [])\n",
"    sim_feats = groups.get('sim', [])\n",
"\n",
"    # Everything except 'org' and 'sim' is a variant group\n",
"    return {\n",
"        name: compute_metrics(org_feats, feats, sim_feats, K)\n",
"        for name, feats in groups.items()\n",
"        if name not in ('org', 'sim')\n",
"    }\n",
"\n",
"\n",
"def results_to_latex(currdb, results, order=None):\n",
"    \"\"\"Render per-variant metrics as a LaTeX table ending with an average row.\n",
"\n",
"    Args:\n",
"        currdb: name inserted into the table caption.\n",
"        results: dict mapping variant id -> {metric name: value}. The metric\n",
"            names of the first entry define the columns for all rows.\n",
"        order: iterable of (variant id, display name) pairs that fixes row\n",
"            order and labels. Defaults to the module-level ``variant_order``.\n",
"\n",
"    Returns:\n",
"        The LaTeX source of the table. Pairs in ``order`` whose variant id is\n",
"        not in ``results`` are skipped; the 'avg' row averages the rendered rows.\n",
"\n",
"    Raises:\n",
"        ValueError: if ``results`` is empty, since no columns can be derived.\n",
"    \"\"\"\n",
"    if not results:\n",
"        raise ValueError(\"results is empty: no metrics to render\")\n",
"    if order is None:\n",
"        order = variant_order\n",
"\n",
"    metric_names = list(next(iter(results.values())).keys())  # all metric names\n",
"    header = \"Deformation & \" + \" & \".join(metric_names) + \" \\\\\\\\ \\\\hline\\n\"\n",
"\n",
"    row_lines = []\n",
"    metric_sums = {metric: 0.0 for metric in metric_names}\n",
"    count = 0\n",
"\n",
"    for variant_id, variant_display in order:\n",
"        if variant_id not in results:\n",
"            continue  # skip groups that were not evaluated\n",
"        metrics = results[variant_id]\n",
"        count += 1\n",
"        cells = [variant_display]\n",
"        for metric in metric_names:\n",
"            val = metrics[metric]\n",
"            metric_sums[metric] += val\n",
"            cells.append(f\"{val:.4f}\")\n",
"        row_lines.append(\" & \".join(cells) + \" \\\\\\\\\\n\")\n",
"\n",
"    # Append the average row\n",
"    avg_cells = [\"avg\"]\n",
"    for metric in metric_names:\n",
"        avg_val = metric_sums[metric] / count if count else 0.0\n",
"        avg_cells.append(f\"{avg_val:.4f}\")\n",
"    row_lines.append(\" & \".join(avg_cells) + \" \\\\\\\\\\n\")\n",
"\n",
"    return (\n",
"        \"\\\\begin{table}[htbp]\\n\"\n",
"        \"\\\\centering\\n\"\n",
"        f\"\\\\caption{{Top-1 Precision, Recall@K and MAP for {currdb}}}\\n\"\n",
"        + \"\\\\begin{tabular}{l\" + \"c\" * len(metric_names) + \"}\\n\"\n",
"        + \"\\\\hline\\n\"\n",
"        + header\n",
"        + \"\".join(row_lines)\n",
"        + \"\\\\hline\\n\"\n",
"        \"\\\\end{tabular}\\n\"\n",
"        \"\\\\end{table}\\n\"\n",
"    )"
],
"id": "282bfd74f5982041",
"outputs": [],
"execution_count": 62
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T09:34:18.460011Z",
"start_time": "2025-04-14T09:34:18.454827Z"
}
},
"cell_type": "code",
"source": [
"# (group name used as key in the results, LaTeX row label) pairs, in table order.\n",
"# Each variant is listed under its Chinese and its English group name; both\n",
"# spellings map to the same row label.\n",
"variant_order = [\n",
"    (r\"静态水印1\", r\"watermark1\"),  # static watermark 1\n",
"    (r\"动态水印2\", r\"watermark2\"),  # dynamic watermark 2\n",
"    (r\"水印1+2\", r\"watermark1\\&2\"),  # watermark 1+2\n",
"    (r\"滤镜\", r\"filters\"),  # filters\n",
"    (r\"水印1+2+滤镜\", r\"filters\\&watermark1\\&2\"),  # watermark 1+2 + filters\n",
"    (r\"watermark1\", r\"watermark1\"),\n",
"    (r\"watermark2\", r\"watermark2\"),\n",
"    (r\"watermark1+2\", r\"watermark1\\&2\"),\n",
"    (r\"filters\", r\"filters\"),\n",
"    (r\"filters+watermark1+2\", r\"filters\\&watermark1\\&2\"),\n",
"]"
],
"id": "12e194a6bf64b8ac",
"outputs": [],
"execution_count": 63
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T09:34:18.512805Z",
"start_time": "2025-04-14T09:34:18.473552Z"
}
},
"cell_type": "code",
"source": [
"\n",
"# The adapter/converter registration is global to the sqlite3 module, so it is\n",
"# done once instead of on every loop iteration.\n",
"sqlite3.register_adapter(np.ndarray, adapt_array)\n",
"sqlite3.register_converter(\"BLOB\", convert_array)\n",
"\n",
"for currdb in [\"LBPTOP\", \"i3d\", \"slowfast\"]:\n",
"    # Connect to the SQLite database\n",
"    conn = sqlite3.connect(currdb + \".db\", detect_types=sqlite3.PARSE_DECLTYPES)\n",
"    try:\n",
"        cursor = conn.cursor()\n",
"        # Read every row of the data table\n",
"        cursor.execute(\"SELECT id, array, group_path, full_path FROM data\")\n",
"        all_data = cursor.fetchall()\n",
"    finally:\n",
"        # Close the connection even if the query fails; fetchall() has already\n",
"        # materialised the rows, so nothing below needs the connection.\n",
"        conn.close()\n",
"\n",
"    # Group the feature records by group name\n",
"    data_by_group = defaultdict(list)\n",
"    for vid, arr, group, path in all_data:\n",
"        data_by_group[group].append({'id': vid, 'feature': arr, 'path': path})\n",
"\n",
"    # Run the evaluation (evaluate_all_variants reads the module-level data_by_group)\n",
"    results = evaluate_all_variants(K=3)\n",
"    # Print the LaTeX table\n",
"    print(results_to_latex(currdb, results))\n"
],
"id": "97f972940a003aa4",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\\begin{table}[htbp]\n",
"\\centering\n",
"\\caption{Top-1 Precision, Recall@K and MAP for LBPTOP}\n",
"\\begin{tabular}{lccc}\n",
"\\hline\n",
"Deformation & Top-1 Precision & Recall@3 & MAP \\\\ \\hline\n",
"watermark1 & 0.5000 & 0.6000 & 0.0608 \\\\\n",
"watermark2 & 0.1000 & 0.5000 & 0.0322 \\\\\n",
"watermark1\\&2 & 0.0000 & 0.3000 & 0.0259 \\\\\n",
"filters & 0.3000 & 0.4000 & 0.0461 \\\\\n",
"filters\\&watermark1\\&2 & 0.1000 & 0.3000 & 0.0310 \\\\\n",
"avg & 0.2000 & 0.4200 & 0.0392 \\\\\n",
"\\hline\n",
"\\end{tabular}\n",
"\\end{table}\n",
"\n",
"\\begin{table}[htbp]\n",
"\\centering\n",
"\\caption{Top-1 Precision, Recall@K and MAP for i3d}\n",
"\\begin{tabular}{lccc}\n",
"\\hline\n",
"Deformation & Top-1 Precision & Recall@3 & MAP \\\\ \\hline\n",
"watermark1 & 0.7000 & 0.9000 & 0.0798 \\\\\n",
"watermark2 & 0.6000 & 0.6000 & 0.0662 \\\\\n",
"watermark1\\&2 & 0.5000 & 0.8000 & 0.0634 \\\\\n",
"filters & 0.8000 & 0.8000 & 0.0850 \\\\\n",
"filters\\&watermark1\\&2 & 0.5000 & 0.8000 & 0.0634 \\\\\n",
"avg & 0.6200 & 0.7800 & 0.0715 \\\\\n",
"\\hline\n",
"\\end{tabular}\n",
"\\end{table}\n",
"\n",
"\\begin{table}[htbp]\n",
"\\centering\n",
"\\caption{Top-1 Precision, Recall@K and MAP for slowfast}\n",
"\\begin{tabular}{lccc}\n",
"\\hline\n",
"Deformation & Top-1 Precision & Recall@3 & MAP \\\\ \\hline\n",
"watermark1 & 0.7000 & 0.9000 & 0.0817 \\\\\n",
"watermark2 & 0.7000 & 1.0000 & 0.0800 \\\\\n",
"watermark1\\&2 & 0.3000 & 0.9000 & 0.0583 \\\\\n",
"filters & 0.6000 & 0.9000 & 0.0758 \\\\\n",
"filters\\&watermark1\\&2 & 0.3000 & 0.8000 & 0.0558 \\\\\n",
"avg & 0.5200 & 0.9000 & 0.0703 \\\\\n",
"\\hline\n",
"\\end{tabular}\n",
"\\end{table}\n",
"\n"
]
}
],
"execution_count": 64
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}