Files
cyzg/backend.py
2025-08-20 16:25:51 +08:00

196 lines
5.9 KiB
Python

import os
import random
import sqlite3
import time
import zipfile
from contextlib import closing
class Backend:
    """Weighted-random quiz backend over a SQLite question bank.

    Questions live in a ``questions`` table whose ``count`` column is the
    selection weight: only rows with ``count > 0`` can be drawn, with
    probability proportional to ``count``. Every answer is logged to
    ``answers_history`` together with the time elapsed since the question
    was served.
    """

    # Questions in a full paper. The targets shown by get_statistics()
    # (72 s per question, 120 min total) imply 100 questions.
    _EXAM_QUESTIONS = 100

    def __init__(self, db="data.db"):
        """Open the question bank at *db*, unpacking ``<db>.zip`` if needed.

        If *db* does not exist but ``<db>.zip`` does, the archive is extracted
        into the directory containing *db*, so the extracted file lands at
        the path that is opened afterwards.
        NOTE(review): assumes the archive stores the database under its bare
        file name - confirm against the shipped zip.
        """
        self.db_path = db
        archive = self.db_path + ".zip"
        if not os.path.exists(self.db_path) and os.path.exists(archive):
            target_dir = os.path.dirname(self.db_path) or "."
            with zipfile.ZipFile(archive, "r") as zip_ref:
                zip_ref.extractall(target_dir)
        # Moment the current question was served; update() measures from it.
        self.time = time.time()
        # By default every title present in the bank is eligible.
        self.title_filter = self.get_source_type()
        # Set by set_config(); not read anywhere in this class.
        self.global_filter = []

    def _query_scalar(self, query, params=()):
        """Run *query* and return the first column of its first row.

        sqlite3 errors (including failure to open the database) are printed
        and reported as ``None`` rather than propagated.
        """
        try:
            with closing(sqlite3.connect(self.db_path)) as conn:
                return conn.execute(query, params).fetchone()[0]
        except sqlite3.Error as e:
            print(f"数据库错误:{e}")
            return None

    def get_question(self):
        """Draw one question at random, weighted by its ``count`` column.

        Candidates are rows with ``count > 0`` whose ``title`` is in
        ``self.title_filter`` (any iterable of titles). On a successful query
        the answer timer (``self.time``) is restarted.

        Returns:
            The chosen row as a list: every ``questions`` column followed by
            the running ``cum_weight``. Returns ``None`` when no candidate
            exists.
        """
        # Materialize once: sqlite3 needs a real sequence, and both queries
        # must bind the titles in the same order as their placeholders.
        titles = tuple(self.title_filter)
        if not titles:
            return None
        placeholders = ",".join("?" * len(titles))

        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute(
                f"SELECT SUM(count) FROM questions "
                f"WHERE count > 0 AND title IN ({placeholders})",
                titles,
            )
            total_weights = cursor.fetchone()[0]
            if not total_weights:
                return None

            # Pick a point in [0, total_weights). The first row (by id) whose
            # cumulative weight exceeds it wins, i.e. P(row) = count / total.
            random_num = random.uniform(0, total_weights)
            cursor.execute(
                f"""
                SELECT *
                FROM (
                    SELECT *,
                           SUM(count) OVER (ORDER BY id ROWS UNBOUNDED PRECEDING) AS cum_weight
                    FROM questions
                    WHERE count > 0 AND title IN ({placeholders})
                )
                WHERE ? < cum_weight
                ORDER BY cum_weight
                LIMIT 1;
                """,
                titles + (random_num,),
            )
            row = cursor.fetchone()
            # Debug trace of each draw.
            print(total_weights, row)

        self.time = time.time()
        return list(row) if row else None

    def update(self, id, state):
        """Log an answer to question *id* and adjust the question's weight.

        Args:
            id: Primary key of the answered question.
            state: 0 means answered correctly, and ``count`` is decremented
                by 1. Any other value means answered wrongly, and ``count`` is
                incremented by 2, but only while the question is still active
                (``count > 0``).

        The seconds elapsed since ``self.time`` are stored in
        ``answers_history`` along with *id* and *state*. Both statements are
        committed together; if either fails, nothing is committed.
        """
        # Correct: -1, wrong: +2.
        if state != 0:
            adjust = (
                "UPDATE questions SET count = count + 2 "
                "WHERE id = ? AND count > 0"
            )
        else:
            adjust = "UPDATE questions SET count = count - 1 WHERE id = ?"

        with closing(sqlite3.connect(self.db_path)) as conn:
            cursor = conn.cursor()
            cursor.execute(adjust, (id,))
            cursor.execute(
                "INSERT INTO answers_history (id, time_used, state) VALUES (?, ?, ?)",
                (id, time.time() - self.time, state),
            )
            conn.commit()

    def reset_time(self):
        """Restart the answer timer used by update()."""
        self.time = time.time()

    def get_acc(self, top_n=0):
        """Return the percentage of logged answers with ``state == 0``.

        Args:
            top_n: If positive, only the *top_n* most recent answers (by the
                ``time`` column) are considered; otherwise all answers are.

        Returns:
            Accuracy in percent (0-100). Returns 0.0 when there is no history
            or a database error occurred (the error is printed).
        """
        if top_n > 0:
            source = (
                "(SELECT state FROM answers_history "
                "ORDER BY time DESC LIMIT ?)"
            )
            params = (top_n,)
        else:
            source, params = "answers_history", ()
        result = self._query_scalar(
            f"SELECT AVG(CASE WHEN state = 0 THEN 1.0 ELSE 0.0 END) FROM {source}",
            params,
        )
        return result * 100 if result is not None else 0.0

    def get_avg_time(self, top_n=0):
        """Return the mean ``time_used`` (seconds) over logged answers.

        Rows with a NULL ``time_used`` are ignored.

        Args:
            top_n: If positive, only the *top_n* most recent answers with a
                recorded time (by the ``time`` column) are averaged.

        Returns:
            The average time. Returns 0.0 when there is no data or a database
            error occurred (the error is printed).
        """
        if top_n > 0:
            source = (
                "(SELECT time_used FROM answers_history "
                "WHERE time_used IS NOT NULL ORDER BY time DESC LIMIT ?)"
            )
            params = (top_n,)
        else:
            source = (
                "(SELECT time_used FROM answers_history "
                "WHERE time_used IS NOT NULL)"
            )
            params = ()
        result = self._query_scalar(
            f"SELECT AVG(time_used) AS avg_time FROM {source}", params
        )
        return result if result is not None else 0.0

    def get_statistics(self):
        """Summarize performance as ``[label, "value/target"]`` rows.

        Targets are 60% accuracy and 72 s per question. The projected-time
        rows scale a per-question average (seconds) to minutes for a full
        paper of ``_EXAM_QUESTIONS`` questions (100 x 72 s = 120 min). The
        second entry is an empty list, presumably a separator for the
        caller's display - confirm.
        """
        # Seconds per question -> minutes for the whole paper.
        exam_factor = self._EXAM_QUESTIONS / 60
        avg_all = self.get_avg_time()
        avg_50 = self.get_avg_time(50)
        avg_100 = self.get_avg_time(100)
        return [
            ["正确率", f"{self.get_acc():.1f}/60%"],
            [],
            ["最近50题正确率", f"{self.get_acc(50):.1f}/60%"],
            ["最近100题正确率", f"{self.get_acc(100):.1f}/60%"],
            ["平均耗时", f"{avg_all:.1f}/72s"],
            ["预计做完用时", f"{avg_all * exam_factor:.1f}/120min"],
            ["50平均耗时", f"{avg_50:.1f}/72s"],
            ["预计做完用时", f"{avg_50 * exam_factor:.1f}/120min"],
            ["100平均耗时", f"{avg_100:.1f}/72s"],
            ["预计做完用时", f"{avg_100 * exam_factor:.1f}/120min"],
        ]

    def get_source_type(self):
        """Return the distinct ``title`` values of ``questions`` as a list."""
        with closing(sqlite3.connect(self.db_path)) as conn:
            rows = conn.execute("SELECT DISTINCT title FROM questions").fetchall()
        return [title for (title,) in rows]

    def set_config(self, a, b):
        """Set the title filter (*a*) and the global filter (*b*).

        *a* may be any iterable of titles; get_question() materializes it.
        *b* is stored as-is.
        """
        self.title_filter = a
        self.global_filter = b