205 lines
6.3 KiB
Python
205 lines
6.3 KiB
Python
import os
import random
import sqlite3
import time
import zipfile
from contextlib import closing


class Backend:
    """Weighted-random question picker backed by a SQLite database.

    Reads the ``questions`` table (columns used here: ``id``, ``title``,
    ``count``) and appends to ``answers_history`` (columns used here: ``id``,
    ``time_used``, ``state``, ``time``). A question's ``count`` is its
    sampling weight; rows with ``count <= 0`` are never served.
    """

    # Multiplier turning an average per-question time (seconds) into the
    # projected session time shown against the "/120min" target.
    # NOTE(review): 1.667 ~= 100 / 60, i.e. presumably a 100-question session
    # converted to minutes (72 s * 1.667 ~= 120 min) -- confirm.
    _FINISH_FACTOR = 1.667

    # SQL fragment OR-ed into the question filter for each option label that
    # set_config() accepts. Only these fixed literals ever reach the SQL text.
    _STATUS_FILTERS = {
        # count still equal to 3; presumably the initial weight of a question
        # that has never been answered -- confirm against the data file.
        "未做过的题": "q.count = 3",
        # count above 3; update() raises count by 2 on a wrong answer.
        "错过的题": "q.count > 3",
    }

    def __init__(self, db="data.db"):
        """Open the question bank at *db*, unpacking ``<db>.zip`` first if needed.

        Args:
            db: Path to the SQLite database file.
        """
        self.db_path = db

        # Unpack the bundled archive on first run.
        # NOTE(review): members are extracted into the current working
        # directory, so this only produces ``db_path`` when the archive's
        # member path matches it relative to the CWD -- confirm for
        # non-default ``db`` values.
        archive = self.db_path + ".zip"
        if not os.path.exists(self.db_path) and os.path.exists(archive):
            with zipfile.ZipFile(archive, "r") as zip_ref:
                zip_ref.extractall(".")

        # Timestamp of the last served question; update() measures from it.
        self.time = time.time()
        # Titles a served question must belong to (defaults to every title).
        self.title_filter = self.get_source_type()
        # Extra SQL fragment appended to the WHERE clause by set_config().
        self.global_filter = ""

    def _connect(self):
        """Return a context manager yielding a connection that is closed on exit.

        sqlite3's own connection context manager only commits or rolls back;
        it does not close the connection, hence ``closing``.
        """
        return closing(sqlite3.connect(self.db_path))

    def _candidate_where(self, n_titles):
        """Return the WHERE clause shared by both get_question() queries."""
        placeholders = ",".join("?" * n_titles)
        return f"WHERE q.count > 0 AND title IN ({placeholders}) {self.global_filter}"

    def get_question(self):
        """Pick one question at random, weighted by its ``count``.

        Returns:
            The selected ``questions`` row as a list with its running
            ``cum_weight`` appended, or ``None`` when no question matches the
            current filters (or the matching weights sum to zero).
        """
        # Materialize once: sqlite3 needs a sequence of parameters, while
        # set_config() callers may hand over any iterable of titles.
        titles = tuple(self.title_filter)
        where = self._candidate_where(len(titles))

        with self._connect() as conn:
            cursor = conn.cursor()

            # Total weight of all candidate questions.
            cursor.execute(f"SELECT SUM(q.count) FROM questions q {where}", titles)
            total_weights = cursor.fetchone()[0]
            if not total_weights:
                return None

            # Random draw in [0, total_weights).
            random_num = random.uniform(0, total_weights)

            # Weighted pick: walk the running sum of weights in id order and
            # take the first row whose cumulative weight exceeds the draw.
            query = f"""
                SELECT *
                FROM (
                    SELECT *,
                           SUM(q.count) OVER (ORDER BY id ROWS UNBOUNDED PRECEDING) AS cum_weight
                    FROM questions q
                    {where}
                )
                WHERE ? < cum_weight
                ORDER BY cum_weight
                LIMIT 1;
            """
            cursor.execute(query, titles + (random_num,))
            row = cursor.fetchone()
            print(total_weights, row)

        # Start timing the answer to the question just served.
        self.time = time.time()
        return list(row) if row else None

    def update(self, id, state):
        """Record an answer and adjust the question's sampling weight.

        A correct answer (``state == 0``) lowers ``count`` by 1. Any other
        state raises it by 2, but only while ``count > 0``, so a question whose
        weight has dropped to 0 is not re-weighted. The seconds elapsed since
        the last served question are appended to ``answers_history``.

        Args:
            id: ``questions.id`` of the answered question.
            state: 0 for a correct answer, anything else for a wrong one.
        """
        if state != 0:
            weight_sql = "UPDATE questions SET count = count + 2 WHERE id = ? and count > 0"
        else:
            weight_sql = "UPDATE questions SET count = count - 1 WHERE id = ?"

        with self._connect() as conn:
            cursor = conn.cursor()
            cursor.execute(weight_sql, (id,))
            # Record how long this answer took.
            cursor.execute(
                "INSERT INTO answers_history (id, time_used, state) VALUES (?, ?, ?)",
                (id, time.time() - self.time, state),
            )
            conn.commit()

    def reset_time(self):
        """Restart the answer timer used by update()."""
        self.time = time.time()

    def _fetch_scalar(self, query, params=()):
        """Run *query* and return the first column of its first row.

        On ``sqlite3.Error`` the error is printed and ``None`` is returned.
        """
        with self._connect() as conn:
            try:
                return conn.execute(query, params).fetchone()[0]
            except sqlite3.Error as e:
                print(f"数据库错误:{e}")
                return None

    def get_acc(self, top_n=0):
        """Return answer accuracy as a percentage (0.0 when no data or on error).

        Args:
            top_n: If positive, only the ``top_n`` most recent answers
                (ordered by ``answers_history.time``) are counted.
        """
        if top_n > 0:
            source = "(SELECT state FROM answers_history ORDER BY time DESC LIMIT ?)"
            params = (top_n,)
        else:
            source = "answers_history"
            params = ()
        # state 0 is a correct answer (the same convention update() uses).
        result = self._fetch_scalar(
            "SELECT CAST(SUM(CASE WHEN state = 0 THEN 1 ELSE 0 END) AS FLOAT) / COUNT(*) "
            f"FROM {source}",
            params,
        )
        return result * 100 if result is not None else 0.0

    def get_avg_time(self, top_n=0):
        """Return the mean recorded ``time_used`` (0.0 when no data or on error).

        Rows with a NULL ``time_used`` are ignored.

        Args:
            top_n: If positive, only the ``top_n`` most recent timed answers
                (ordered by ``answers_history.time``) are averaged.
        """
        if top_n > 0:
            source = (
                "(SELECT time_used FROM answers_history WHERE time_used IS NOT NULL "
                "ORDER BY time DESC LIMIT ?)"
            )
            params = (top_n,)
        else:
            source = "answers_history WHERE time_used IS NOT NULL"
            params = ()
        result = self._fetch_scalar(f"SELECT AVG(time_used) FROM {source}", params)
        return result if result is not None else 0.0

    def get_statistics(self, windows=(50, 100)):
        """Build the statistics table as rows of ``[label, "value/target"]``.

        An empty row separates the overall accuracy from the per-window
        accuracy rows. The timing rows follow, each paired with a projected
        session time.

        Args:
            windows: Sizes of the "most recent N answers" windows to report.
                Labels are derived from these sizes, so a label always
                matches the number of answers it summarizes.
        """
        def finish_time(avg):
            return f"{avg * self._FINISH_FACTOR:.1f}/120min"

        rows = [["正确率", f"{self.get_acc():.1f}/60%"], []]
        rows += [[f"最近{n}题正确率", f"{self.get_acc(n):.1f}/60%"] for n in windows]

        avg_all = self.get_avg_time()
        rows += [["平均耗时", f"{avg_all:.1f}/72s"], ["预计做完用时", finish_time(avg_all)]]
        for n in windows:
            avg_n = self.get_avg_time(n)
            rows += [[f"{n}平均耗时", f"{avg_n:.1f}/72s"], ["预计做完用时", finish_time(avg_n)]]
        return rows

    def get_source_type(self):
        """Return the distinct ``title`` values present in ``questions``."""
        with self._connect() as conn:
            rows = conn.execute("SELECT DISTINCT title FROM questions").fetchall()
        return [title for (title,) in rows]

    def set_config(self, a, b):
        """Set which questions get_question() may serve.

        Args:
            a: Titles to draw from (any iterable of ``questions.title`` values).
            b: Status option labels (keys of ``_STATUS_FILTERS``). A question
                matches if it satisfies any selected option. Unknown labels
                are ignored. If ``b`` is non-empty but contains no known
                label, nothing matches. An empty ``b`` applies no status
                filter.
        """
        self.title_filter = a
        self.global_filter = ""
        if b:
            clauses = "".join(
                f" OR {self._STATUS_FILTERS[item]} "
                for item in b
                if item in self._STATUS_FILTERS
            )
            self.global_filter = " and ( FALSE " + clauses + ")"