Files
COMP90044_a1/analy.ipynb
2025-04-14 17:15:38 +10:00

293 lines
8.0 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T06:58:32.199713Z",
"start_time": "2025-04-14T06:58:32.134049Z"
}
},
"cell_type": "code",
"source": [
"import sqlite3\n",
"import numpy as np\n",
"import io\n",
"from numpy.linalg import norm\n",
"from collections import defaultdict"
],
"id": "6917288db44004ea",
"outputs": [],
"execution_count": 1
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T06:58:34.129528Z",
"start_time": "2025-04-14T06:58:34.126022Z"
}
},
"cell_type": "code",
"source": [
"# Feature database to evaluate (other extracted feature sets: \"i3d.db\", \"LBPTOP.db\").\n",
"currdb = \"slowfast.db\""
],
"id": "dd9b84fa98c77e5a",
"outputs": [],
"execution_count": 2
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T06:58:36.801611Z",
"start_time": "2025-04-14T06:58:36.798114Z"
}
},
"cell_type": "code",
"source": [
"# sqlite3 hooks so numpy arrays round-trip through BLOB columns as .npy bytes.\n",
"def adapt_array(arr):\n",
"    \"\"\"Serialise a numpy array with np.save and wrap the bytes for sqlite3.\"\"\"\n",
"    buffer = io.BytesIO()\n",
"    np.save(buffer, arr)\n",
"    return sqlite3.Binary(buffer.getvalue())\n",
"\n",
"def convert_array(text):\n",
"    \"\"\"Decode .npy bytes read back from sqlite3 into a numpy array.\"\"\"\n",
"    return np.load(io.BytesIO(text))"
],
"id": "fa2ab158cbe0aa98",
"outputs": [],
"execution_count": 3
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T07:01:06.063185Z",
"start_time": "2025-04-14T07:01:06.060008Z"
}
},
"cell_type": "code",
"source": [
"# Register the ndarray <-> BLOB hooks, then open the selected feature database.\n",
"sqlite3.register_converter(\"BLOB\", convert_array)\n",
"sqlite3.register_adapter(np.ndarray, adapt_array)\n",
"\n",
"# PARSE_DECLTYPES applies registered converters according to each column's declared type.\n",
"conn = sqlite3.connect(currdb, detect_types=sqlite3.PARSE_DECLTYPES)\n",
"cursor = conn.cursor()"
],
"id": "629eb156f332c0bc",
"outputs": [],
"execution_count": 12
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T07:01:06.372809Z",
"start_time": "2025-04-14T07:01:06.362911Z"
}
},
"cell_type": "code",
"source": [
"# Load every row of the data table as (id, array, group_path, full_path) tuples.\n",
"all_data = cursor.execute(\"SELECT id, array, group_path, full_path FROM data\").fetchall()"
],
"id": "ba7902156f1b6117",
"outputs": [],
"execution_count": 13
},
{
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2025-04-14T07:01:06.786016Z",
"start_time": "2025-04-14T07:01:06.779812Z"
}
},
"cell_type": "code",
"source": [
"# Group rows by group_path: 'org' = original videos, 'sim' = similar videos;\n",
"# any other group is a distortion (variant) set evaluated as queries.\n",
"data_by_group = defaultdict(list)\n",
"for vid, arr, group, path in all_data:\n",
"    data_by_group[group].append({'id': vid, 'feature': arr, 'path': path})\n",
"\n",
"\n",
"def cosine_similarity(a, b):\n",
"    \"\"\"Cosine similarity of two feature vectors.\n",
"\n",
"    Returns 0.0 when either vector has zero norm instead of the NaN that 0/0\n",
"    would produce; NaN keys leave the ranking sort order undefined.\n",
"    \"\"\"\n",
"    denom = norm(a) * norm(b)\n",
"    if denom == 0:\n",
"        return 0.0\n",
"    return np.dot(a, b) / denom\n",
"\n",
"\n",
"def _mean_or_nan(values):\n",
"    \"\"\"np.mean of a list, or NaN (without a RuntimeWarning) when it is empty.\"\"\"\n",
"    return np.mean(values) if values else np.nan\n",
"\n",
"\n",
"def compute_metrics(org_feats, query_feats, sim_feats, K=5):\n",
"    \"\"\"Retrieval metrics for one query group against the org + sim gallery.\n",
"\n",
"    The gallery is org_feats followed by sim_feats; an entry is relevant\n",
"    (label 1) iff it comes from org_feats.\n",
"\n",
"    NOTE(review): every original counts as relevant to every query, not only\n",
"    the original a query was derived from. If per-video matching is intended,\n",
"    labels need a key shared by query and original (e.g. derived from 'path');\n",
"    TODO confirm.\n",
"\n",
"    Args:\n",
"        org_feats: list of dicts with a 'feature' vector (relevant entries).\n",
"        query_feats: list of dicts with a 'feature' vector (the queries).\n",
"        sim_feats: list of dicts with a 'feature' vector (non-relevant entries).\n",
"        K: rank cut-off for Recall@K.\n",
"\n",
"    Returns:\n",
"        dict with 'Top-1 Precision', f'Recall@{K}' and 'MAP', each averaged\n",
"        over queries (NaN when query_feats is empty).\n",
"    \"\"\"\n",
"    search_db = org_feats + sim_feats\n",
"    # Relevance is positional (originals come first). This avoids an\n",
"    # O(len(org_feats)) membership scan per entry and never falls back to\n",
"    # dict ==, which would compare numpy arrays if two ids coincided.\n",
"    db_labels = [1] * len(org_feats) + [0] * len(sim_feats)\n",
"    n_relevant = len(org_feats)\n",
"\n",
"    precision_top1 = []\n",
"    recall_at_k = []\n",
"    AP_list = []\n",
"\n",
"    for query in query_feats:\n",
"        query_feature = query['feature']\n",
"        similarities = [\n",
"            (cosine_similarity(query_feature, target['feature']), label)\n",
"            for target, label in zip(search_db, db_labels)\n",
"        ]\n",
"\n",
"        # Descending by similarity; Python's sort stays stable with\n",
"        # reverse=True, so ties keep gallery order (originals first).\n",
"        similarities.sort(key=lambda x: x[0], reverse=True)\n",
"        labels_sorted = [label for _, label in similarities]\n",
"\n",
"        # Top-1 precision: is the best match relevant? (0 for an empty gallery)\n",
"        precision_top1.append(labels_sorted[0] if labels_sorted else 0)\n",
"\n",
"        # Recall@K: fraction of all relevant entries ranked within the top K\n",
"        recall_at_k.append(sum(labels_sorted[:K]) / n_relevant if n_relevant else 0)\n",
"\n",
"        # Average precision: precision@rank summed over relevant ranks / #relevant\n",
"        hits, sum_precisions = 0, 0\n",
"        for rank, label in enumerate(labels_sorted, start=1):\n",
"            if label == 1:\n",
"                hits += 1\n",
"                sum_precisions += hits / rank\n",
"        AP_list.append(sum_precisions / n_relevant if n_relevant else 0)\n",
"\n",
"    return {\n",
"        'Top-1 Precision': _mean_or_nan(precision_top1),\n",
"        f'Recall@{K}': _mean_or_nan(recall_at_k),\n",
"        'MAP': _mean_or_nan(AP_list),\n",
"    }"
],
"id": "initial_id",
"outputs": [],
"execution_count": 14
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T07:01:07.745386Z",
"start_time": "2025-04-14T07:01:07.740942Z"
}
},
"cell_type": "code",
"source": [
"# Sanity check: dimensionality of the first 'org' feature vector.\n",
"np.shape(data_by_group['org'][0]['feature'])"
],
"id": "55e57dc7ec56b732",
"outputs": [
{
"data": {
"text/plain": [
"(2304,)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 15
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2025-04-14T06:00:12.392784Z",
"start_time": "2025-04-14T06:00:12.382725Z"
}
},
"cell_type": "code",
"source": [
"# Evaluate every distortion (variant) group against the org + sim gallery.\n",
"def evaluate_all_variants(K=5):\n",
"    \"\"\"Run compute_metrics for each group other than 'org' and 'sim'.\n",
"\n",
"    Args:\n",
"        K: rank cut-off for Recall@K.\n",
"\n",
"    Returns:\n",
"        dict mapping variant group name -> metrics dict from compute_metrics.\n",
"\n",
"    Raises:\n",
"        ValueError: if no 'org' rows were loaded, since no metric is defined\n",
"            without relevant gallery entries.\n",
"    \"\"\"\n",
"    # .get() avoids the defaultdict silently inserting empty groups.\n",
"    org_feats = data_by_group.get('org', [])\n",
"    sim_feats = data_by_group.get('sim', [])\n",
"    if not org_feats:\n",
"        raise ValueError(\"no 'org' group in data_by_group; check the group_path column\")\n",
"\n",
"    variant_groups = [group for group in data_by_group if group not in ('org', 'sim')]\n",
"    return {\n",
"        variant: compute_metrics(org_feats, data_by_group[variant], sim_feats, K)\n",
"        for variant in variant_groups\n",
"    }\n",
"\n",
"\n",
"# Run the evaluation with Recall@3.\n",
"results = evaluate_all_variants(K=3)\n",
"\n",
"# Print the metrics of each variant group.\n",
"for variant, metrics in results.items():\n",
"    print(f\"变形组: {variant}\")\n",
"    for metric_name, value in metrics.items():\n",
"        print(f\" {metric_name}: {value:.4f}\")\n",
"    print(\"-\" * 40)\n",
"\n",
"# All rows are already in memory; after this the connection cell must be\n",
"# re-run before querying the database again.\n",
"conn.close()"
],
"id": "d9a0103dbc3f1bef",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"变形组: 动态水印2\n",
" Top-1 Precision: 1.0000\n",
" Recall@3: 0.1500\n",
" MAP: 0.5815\n",
"----------------------------------------\n",
"变形组: 水印1+2\n",
" Top-1 Precision: 0.8000\n",
" Recall@3: 0.1600\n",
" MAP: 0.5751\n",
"----------------------------------------\n",
"变形组: 水印1+2+滤镜\n",
" Top-1 Precision: 0.8000\n",
" Recall@3: 0.1500\n",
" MAP: 0.5685\n",
"----------------------------------------\n",
"变形组: 滤镜\n",
" Top-1 Precision: 0.9000\n",
" Recall@3: 0.1800\n",
" MAP: 0.6100\n",
"----------------------------------------\n",
"变形组: 静态水印1\n",
" Top-1 Precision: 0.9000\n",
" Recall@3: 0.1700\n",
" MAP: 0.6058\n",
"----------------------------------------\n"
]
}
],
"execution_count": 25
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}