293 lines
8.0 KiB
Plaintext
293 lines
8.0 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-04-14T06:58:32.199713Z",
|
||
"start_time": "2025-04-14T06:58:32.134049Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"import sqlite3\n",
|
||
"import numpy as np\n",
|
||
"import io\n",
|
||
"from numpy.linalg import norm\n",
|
||
"from collections import defaultdict"
|
||
],
|
||
"id": "6917288db44004ea",
|
||
"outputs": [],
|
||
"execution_count": 1
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-04-14T06:58:34.129528Z",
|
||
"start_time": "2025-04-14T06:58:34.126022Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"# Feature database to evaluate; keep exactly one assignment active.\n",
"# currdb = \"i3d.db\"\n",
"# currdb = \"LBPTOP.db\"\n",
"currdb = \"slowfast.db\""
|
||
],
|
||
"id": "dd9b84fa98c77e5a",
|
||
"outputs": [],
|
||
"execution_count": 2
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-04-14T06:58:36.801611Z",
|
||
"start_time": "2025-04-14T06:58:36.798114Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"# Adapter/converter pair: numpy array <-> SQLite BLOB\n",
"def adapt_array(arr):\n",
"    \"\"\"Serialize ``arr`` in .npy format and wrap the bytes as a sqlite3 binary value.\"\"\"\n",
"    buffer = io.BytesIO()\n",
"    np.save(buffer, arr)\n",
"    # getvalue() returns the whole buffer regardless of the stream position\n",
"    return sqlite3.Binary(buffer.getvalue())\n",
|
||
"\n",
|
||
"def convert_array(text):\n",
"    \"\"\"Decode .npy-formatted bytes (as produced by ``np.save``) into a numpy array.\"\"\"\n",
"    # A freshly created BytesIO already starts at position 0\n",
"    return np.load(io.BytesIO(text))"
|
||
],
|
||
"id": "fa2ab158cbe0aa98",
|
||
"outputs": [],
|
||
"execution_count": 3
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-04-14T07:01:06.063185Z",
|
||
"start_time": "2025-04-14T07:01:06.060008Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"\n",
|
||
"# Register the numpy <-> BLOB adapter/converter pair with sqlite3\n",
"sqlite3.register_adapter(np.ndarray, adapt_array)\n",
"sqlite3.register_converter(\"BLOB\", convert_array)\n",
"\n",
"# Open the SQLite database; PARSE_DECLTYPES routes columns declared as BLOB\n",
"# through the registered converter when rows are fetched\n",
"conn = sqlite3.connect(\n",
"    currdb,\n",
"    detect_types=sqlite3.PARSE_DECLTYPES,\n",
")\n",
"cursor = conn.cursor()"
|
||
],
|
||
"id": "629eb156f332c0bc",
|
||
"outputs": [],
|
||
"execution_count": 12
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-04-14T07:01:06.372809Z",
|
||
"start_time": "2025-04-14T07:01:06.362911Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"# Load every row of the data table\n",
"all_data = cursor.execute(\n",
"    \"SELECT id, array, group_path, full_path FROM data\"\n",
").fetchall()"
|
||
],
|
||
"id": "ba7902156f1b6117",
|
||
"outputs": [],
|
||
"execution_count": 13
|
||
},
|
||
{
|
||
"metadata": {
|
||
"collapsed": true,
|
||
"ExecuteTime": {
|
||
"end_time": "2025-04-14T07:01:06.786016Z",
|
||
"start_time": "2025-04-14T07:01:06.779812Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"# Bucket the feature records by their group label\n",
"data_by_group = defaultdict(list)\n",
"for row_id, feature, group_name, full_path in all_data:\n",
"    record = {'id': row_id, 'feature': feature, 'path': full_path}\n",
"    data_by_group[group_name].append(record)\n",
|
||
"\n",
|
||
"# Similarity function (cosine similarity)\n",
"def cosine_similarity(a, b):\n",
"    \"\"\"Return the cosine similarity between vectors ``a`` and ``b``.\n",
"\n",
"    A zero-norm input has no direction, so the similarity is defined as 0.0\n",
"    instead of dividing by zero, which would yield NaN and make any ranking\n",
"    by similarity unreliable.\n",
"    \"\"\"\n",
"    denom = norm(a) * norm(b)\n",
"    if denom == 0:\n",
"        return 0.0\n",
"    return np.dot(a, b) / denom\n",
|
||
"\n",
|
||
"# Metric computation\n",
"def compute_metrics(org_feats, query_feats, sim_feats, K=5):\n",
"    \"\"\"Rank the gallery for every query and average the retrieval metrics.\n",
"\n",
"    The gallery is ``org_feats + sim_feats``. Every entry of ``org_feats`` is\n",
"    labelled 1 (match) and every entry of ``sim_feats`` is labelled 0, for\n",
"    every query.\n",
"\n",
"    Args:\n",
"        org_feats: list of dicts carrying a 'feature' vector (label 1).\n",
"        query_feats: list of dicts carrying a 'feature' vector (queries).\n",
"        sim_feats: list of dicts carrying a 'feature' vector (label 0).\n",
"        K: cut-off used for Recall@K.\n",
"\n",
"    Returns:\n",
"        dict with 'Top-1 Precision', f'Recall@{K}' and 'MAP', each averaged\n",
"        over the queries (NaN when ``query_feats`` is empty).\n",
"    \"\"\"\n",
"    # NOTE(review): relevance is not query-specific -- any original counts as\n",
"    # a hit, so Recall@K can never exceed K / len(org_feats). Confirm this is\n",
"    # the intended ground truth.\n",
"\n",
"    # Build the search gallery (original videos + similar videos)\n",
"    search_db = org_feats + sim_feats\n",
"    # Label by position instead of `target in org_feats`: that membership\n",
"    # test falls back to dict equality, which is linear in len(org_feats) per\n",
"    # target and can raise ValueError once two records tie on 'id' and the\n",
"    # comparison reaches the feature arrays.\n",
"    labels = [1] * len(org_feats) + [0] * len(sim_feats)\n",
"    n_relevant = len(org_feats)\n",
"\n",
"    precision_top1 = []\n",
"    recall_at_k = []\n",
"    AP_list = []\n",
"\n",
"    for query in query_feats:\n",
"        query_feature = query['feature']\n",
"\n",
"        # Similarity between the query and every gallery item\n",
"        scored = [\n",
"            (cosine_similarity(query_feature, target['feature']), label)\n",
"            for target, label in zip(search_db, labels)\n",
"        ]\n",
"\n",
"        # Sort by similarity, descending (stable: ties keep gallery order)\n",
"        scored.sort(key=lambda x: x[0], reverse=True)\n",
"        labels_sorted = [label for _, label in scored]\n",
"\n",
"        # Top-1 precision (0 when the gallery is empty)\n",
"        precision_top1.append(labels_sorted[0] if labels_sorted else 0)\n",
"\n",
"        # Recall@K, guarded against an empty org_feats like AP below\n",
"        recall_at_k.append(sum(labels_sorted[:K]) / n_relevant if n_relevant else 0)\n",
"\n",
"        # Average Precision (AP)\n",
"        hits, sum_precisions = 0, 0\n",
"        for idx, label in enumerate(labels_sorted, start=1):\n",
"            if label == 1:\n",
"                hits += 1\n",
"                sum_precisions += hits / idx\n",
"        AP_list.append(sum_precisions / n_relevant if n_relevant else 0)\n",
"\n",
"    def _mean(values):\n",
"        # np.mean([]) warns and returns NaN; return NaN without the warning\n",
"        return np.mean(values) if values else float('nan')\n",
"\n",
"    # Return the mean of each metric\n",
"    return {\n",
"        'Top-1 Precision': _mean(precision_top1),\n",
"        f'Recall@{K}': _mean(recall_at_k),\n",
"        'MAP': _mean(AP_list)\n",
"    }"
|
||
],
|
||
"id": "initial_id",
|
||
"outputs": [],
|
||
"execution_count": 14
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-04-14T07:01:07.745386Z",
|
||
"start_time": "2025-04-14T07:01:07.740942Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": "data_by_group['org'][0]['feature'].shape",
|
||
"id": "55e57dc7ec56b732",
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"(2304,)"
|
||
]
|
||
},
|
||
"execution_count": 15,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"execution_count": 15
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-04-14T06:00:12.392784Z",
|
||
"start_time": "2025-04-14T06:00:12.382725Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
"# Entry point: compute the metrics for every variant group\n",
"def evaluate_all_variants(K=5, groups=None):\n",
"    \"\"\"Evaluate every variant group against the 'org' + 'sim' gallery.\n",
"\n",
"    Args:\n",
"        K: cut-off forwarded to compute_metrics for Recall@K.\n",
"        groups: mapping of group name -> list of feature records. Defaults\n",
"            to the notebook-level ``data_by_group``.\n",
"\n",
"    Returns:\n",
"        dict mapping each variant group name to its metrics dict.\n",
"    \"\"\"\n",
"    if groups is None:\n",
"        groups = data_by_group\n",
"\n",
"    # .get() instead of [] so that a missing group is not silently inserted\n",
"    # into the defaultdict as an empty entry.\n",
"    org_feats = groups.get('org', [])\n",
"    sim_feats = groups.get('sim', [])\n",
"\n",
"    # Every group other than 'org' and 'sim' is a variant group\n",
"    return {\n",
"        variant: compute_metrics(org_feats, variant_feats, sim_feats, K)\n",
"        for variant, variant_feats in groups.items()\n",
"        if variant not in ('org', 'sim')\n",
"    }\n",
|
||
"\n",
|
||
"# Run the evaluation\n",
"results = evaluate_all_variants(K=3)\n",
"\n",
"# Print the metrics of each variant group\n",
"for variant in results:\n",
"    print(f\"变形组: {variant}\")\n",
"    for metric_name, value in results[variant].items():\n",
"        print(f\" {metric_name}: {value:.4f}\")\n",
"    print(\"-\" * 40)\n",
"\n",
"# Close the database connection\n",
"conn.close()"
|
||
],
|
||
"id": "d9a0103dbc3f1bef",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"变形组: 动态水印2\n",
|
||
" Top-1 Precision: 1.0000\n",
|
||
" Recall@3: 0.1500\n",
|
||
" MAP: 0.5815\n",
|
||
"----------------------------------------\n",
|
||
"变形组: 水印1+2\n",
|
||
" Top-1 Precision: 0.8000\n",
|
||
" Recall@3: 0.1600\n",
|
||
" MAP: 0.5751\n",
|
||
"----------------------------------------\n",
|
||
"变形组: 水印1+2+滤镜\n",
|
||
" Top-1 Precision: 0.8000\n",
|
||
" Recall@3: 0.1500\n",
|
||
" MAP: 0.5685\n",
|
||
"----------------------------------------\n",
|
||
"变形组: 滤镜\n",
|
||
" Top-1 Precision: 0.9000\n",
|
||
" Recall@3: 0.1800\n",
|
||
" MAP: 0.6100\n",
|
||
"----------------------------------------\n",
|
||
"变形组: 静态水印1\n",
|
||
" Top-1 Precision: 0.9000\n",
|
||
" Recall@3: 0.1700\n",
|
||
" MAP: 0.6058\n",
|
||
"----------------------------------------\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 25
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "2.7.6"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|