298 lines
11 KiB
Python
298 lines
11 KiB
Python
# rhymes.py
|
|
import csv
import json
import random
import sqlite3
from contextlib import closing
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Tuple
|
|
|
|
# JSON file holding the rhyme groups, read by load_rhymes().
RHYMES_FILE = "resources/rhymes.json"
# SQLite database holding one row of cooldown state per guild.
DB_FILE = "data/poilau_state.db"
# CSV log appended to (via log_rhyme) whenever find_rhyme() matches a group.
RHYME_LOG_FILE = "data/rhyme_log.csv"

# Rhyme groups parsed from RHYMES_FILE; replaced wholesale by load_rhymes().
# find_rhyme() iterates it as a sequence of group dicts, so it starts as an
# empty list (not a dict) to match that shape.
loaded_rhymes: List[Dict[str, Any]] = []
|
|
|
|
def _ensure_db() -> None:
    """Create the database directory and the ``guild_state`` table if missing.

    Safe to call repeatedly: the directory is created with ``exist_ok=True``
    and the table with ``CREATE TABLE IF NOT EXISTS``.
    """
    Path(DB_FILE).parent.mkdir(parents=True, exist_ok=True)

    # closing() is what actually closes the connection; the connection's own
    # context manager (the second ``conn``) only commits or rolls back.
    with closing(sqlite3.connect(DB_FILE)) as conn, conn:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS guild_state (
                guild_id TEXT PRIMARY KEY,
                guild_name TEXT NOT NULL DEFAULT '',
                cooldown_until TEXT NOT NULL DEFAULT '1970-01-01T00:00:00',
                self_control REAL NOT NULL DEFAULT 1.0,
                last_updated TEXT NOT NULL DEFAULT (datetime('now'))
            )
        """)
|
|
|
|
|
|
def _ensure_log_file() -> None:
    """Make sure the rhyme CSV log exists, writing its header row when new.

    The parent directory is created as needed; an existing log is left as is.
    """
    log_path = Path(RHYME_LOG_FILE)
    log_path.parent.mkdir(parents=True, exist_ok=True)

    if log_path.exists():
        return

    log_path.write_text("timestamp,last_word,rhyme_triggered\n", encoding="utf-8")
|
|
|
|
|
|
def _get_connection() -> sqlite3.Connection:
    """Open a fresh connection to ``DB_FILE`` whose rows are readable by name.

    The caller owns the returned connection and is responsible for closing
    it; using it in a ``with`` statement only scopes a transaction.
    """
    db = sqlite3.connect(DB_FILE)
    # sqlite3.Row lets callers write row["guild_id"] instead of row[0].
    db.row_factory = sqlite3.Row
    return db
|
|
|
|
|
|
def load_rhymes() -> Tuple[bool, str]:
    """Load the rhyme groups from ``RHYMES_FILE`` into ``loaded_rhymes``.

    The module-level ``loaded_rhymes`` is replaced only when the file is read
    and parsed successfully; on any failure the previously loaded rhymes are
    kept untouched.

    Returns:
        ``(success, message)`` where *message* is a human-readable status
        line (it is sent back to the bot owner by the "load poilau" command).
    """
    # The docstring must come before this statement, otherwise it is just a
    # discarded string expression and __doc__ is None.
    global loaded_rhymes
    try:
        with open(RHYMES_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return False, f"No rhymes file found at \"{RHYMES_FILE}\""
    except json.JSONDecodeError as e:
        return False, f"Invalid JSON in \"{RHYMES_FILE}\": {e}"
    except UnicodeDecodeError as e:
        # Not a JSONDecodeError: raised while decoding the file as UTF-8.
        return False, f"Rhymes file \"{RHYMES_FILE}\" is not valid UTF-8: {e}"
    except OSError as e:
        # Permission errors, RHYMES_FILE being a directory, etc.
        return False, f"Could not read rhymes file \"{RHYMES_FILE}\": {e}"

    loaded_rhymes = data
    return True, f"Loaded rhymes file \"{RHYMES_FILE}\""
|
|
|
|
|
|
def get_guild_state(guild_id: str) -> Dict[str, Any]:
    """Return the stored state of *guild_id*, or a default state if absent.

    The returned dict has the keys ``guild_id``, ``guild_name``,
    ``cooldown_until``, ``self_control`` and ``last_updated``. When the guild
    has no row, the defaults (empty name, an already-expired 1970 cooldown,
    ``self_control`` 1.0, ``last_updated`` = now) are returned without being
    written to the database.
    """
    # closing() releases the connection; a plain ``with conn:`` would not.
    with closing(_get_connection()) as conn:
        row = conn.execute(
            """
            SELECT guild_id, guild_name, cooldown_until, self_control, last_updated
            FROM guild_state WHERE guild_id = ?
            """,
            (guild_id,),
        ).fetchone()

    if row is not None:
        # sqlite3.Row maps the selected column names to their values.
        return dict(row)

    return {
        "guild_id": guild_id,
        "guild_name": "",
        "cooldown_until": "1970-01-01T00:00:00",
        "self_control": 1.0,
        "last_updated": datetime.now().isoformat(),
    }
|
|
|
|
|
|
def update_guild_state(
    guild_id: str,
    guild_name: str,
    cooldown_until: str,
    self_control: float
) -> None:
    """Insert or overwrite the state row of *guild_id*.

    ``last_updated`` is set to the current local time in ISO 8601 format.
    """
    # closing() closes the connection; the inner ``conn`` commits on success
    # and rolls back if the statement raises.
    with closing(_get_connection()) as conn, conn:
        conn.execute(
            """
            INSERT OR REPLACE INTO guild_state (guild_id, guild_name, cooldown_until, self_control, last_updated)
            VALUES (?, ?, ?, ?, ?)
            """,
            (guild_id, guild_name, cooldown_until, self_control, datetime.now().isoformat()),
        )
|
|
|
|
|
|
def delete_guild_state(guild_id: str) -> bool:
    """Delete the state row of *guild_id*.

    Returns:
        True if a row was deleted, False if the guild had no stored state.
    """
    # closing() closes the connection; the inner ``conn`` commits the DELETE.
    with closing(_get_connection()) as conn, conn:
        deleted_rows = conn.execute(
            "DELETE FROM guild_state WHERE guild_id = ?", (guild_id,)
        ).rowcount
    return deleted_rows > 0
|
|
|
|
|
|
def get_all_guild_states() -> Dict[str, Dict[str, Any]]:
    """Return every stored guild state keyed by guild ID (used for debugging).

    Each value holds ``guild_name``, ``cooldown_until``, ``self_control`` and
    ``last_updated``.
    """
    # closing() releases the connection once the rows are fetched.
    with closing(_get_connection()) as conn:
        rows = conn.execute(
            "SELECT guild_id, guild_name, cooldown_until, self_control, last_updated FROM guild_state"
        ).fetchall()

    return {
        row["guild_id"]: {
            "guild_name": row["guild_name"],
            "cooldown_until": row["cooldown_until"],
            "self_control": row["self_control"],
            "last_updated": row["last_updated"],
        }
        for row in rows
    }
|
|
|
|
|
|
def log_rhyme(last_word: str, rhyme_triggered: str) -> None:
    """Append one ``timestamp,last_word,rhyme_triggered`` row to the CSV log.

    The timestamp is the current local time in ISO 8601 format. Commas in
    *rhyme_triggered* are still replaced with semicolons to keep the existing
    log format. The row is written with :mod:`csv`, so a quote or line break
    in any field is quoted instead of breaking the row structure.

    The log's directory must already exist (``_ensure_log_file`` creates it).
    If the file itself is missing, it is created without a header row.
    """
    timestamp = datetime.now().isoformat()
    safe_rhyme = rhyme_triggered.replace(",", ";")

    # newline="" as recommended for csv writers; lineterminator="\n" matches
    # the "\n" line endings the log has always used.
    with open(RHYME_LOG_FILE, "a", encoding="utf-8", newline="") as f:
        csv.writer(f, lineterminator="\n").writerow([timestamp, last_word, safe_rhyme])
|
|
|
|
|
|
def get_last_word(text: str) -> str:
    """Return the last alphabetic word of *text*, or ``""`` if there is none.

    Trailing characters (punctuation, spaces, emoji, ...) are skipped until the
    text ends with two letters. If a digit is reached first, or fewer than two
    characters remain, ``""`` is returned. As a result, a single-letter trailing
    word is skipped, and a message ending in a number yields no word.

    The last whitespace-separated token is returned only if it is entirely
    alphabetic. For example, ``"l'homme"`` yields ``""``.

    >>> get_last_word("c'est quoi ?")
    'quoi'
    >>> get_last_word("salut\\nquoi")
    'quoi'
    """
    # Scan with an end index instead of re-slicing the string on every step.
    end = len(text)
    while True:
        if end < 2 or text[end - 1].isnumeric():
            return ""
        if text[end - 1].isalpha() and text[end - 2].isalpha():
            break
        end -= 1

    # Split on any whitespace, not just " ", so the last word after a newline
    # or tab (common in multi-line messages) is still found.
    candidate = text[:end].split()[-1]
    return candidate if candidate.isalpha() else ""
|
|
|
|
|
|
def find_rhyme(word: str) -> str:
    """Return a random rhyme for *word*, or ``""`` when none applies.

    Rhyme groups in ``loaded_rhymes`` are checked in order. Each group is read
    through its ``"blacklist"``, ``"keys"``, ``"sound"`` and ``"rhymes"``
    entries.

    * If *word* is in a group's blacklist, the search stops and ``""`` is
      returned, even if a later group would have matched.
    * Otherwise, the first group with a key that *word* ends with is logged via
      ``log_rhyme(word, sound)``, and a random entry of its ``"rhymes"`` is
      returned.
    """
    # loaded_rhymes is only read here, so no ``global`` statement is needed.
    for group in loaded_rhymes:
        if word in group["blacklist"]:
            return ""
        if word.endswith(tuple(group["keys"])):
            log_rhyme(word, group["sound"])
            return random.choice(group["rhymes"])
    return ""
|
|
|
|
async def get_guild_name(guildId, client) -> str:
    """Look up a guild through ``client.fetch_guild`` and return a label.

    The label has the form ``"[Server=<guild name>]"``.
    """
    fetched = await client.fetch_guild(guildId)
    return f"[Server={fetched.name}]"
|
|
|
|
# Discord user ID of the bot owner: the only author allowed to run the
# "debug poilau", "save poilau" and "load poilau" commands.
_OWNER_ID = 151626081458192384


def _json_code_block(obj: Any) -> str:
    """Render *obj* as indented JSON inside a Markdown ```json fence."""
    return "```json\n{0}```".format(json.dumps(obj, ensure_ascii=False, indent=2))


async def handle_debug_commands(message, client) -> bool:
    """Handle the bot's text commands. Returns True if the message was handled.

    Commands are matched against the whole lower-cased message:

    * ``debug poilau`` (owner only): send the author every stored guild state,
      keyed by ``get_guild_name`` label, plus the remaining cooldown.
    * ``save poilau`` (owner only): send a confirmation and a JSON dump of all
      guild states. Nothing is written here, because state changes go to
      SQLite immediately.
    * ``load poilau`` (owner only): reload the rhymes file, then send its status
      message and a JSON dump of all guild states.
    * ``tg fouras`` (anyone, in a guild): silence the bot in this guild by
      pushing its cooldown about 10000 days into the future.

    An owner-only command from anyone else is not handled (False is returned).
    """
    message_content = message.content.lower()

    if message_content == "debug poilau" and message.author.id == _OWNER_ID:
        dump = {}
        for guild_id, state in get_all_guild_states().items():
            # One get_guild_name() lookup per stored guild.
            channel_name = await get_guild_name(guild_id, client)
            cooldown_dt = datetime.fromisoformat(state["cooldown_until"])
            time_remaining = max(0, (cooldown_dt - datetime.now()).total_seconds())
            dump[channel_name] = {
                "cooldown_until": state["cooldown_until"],
                "cooldown_remaining": "{:.2f} s".format(time_remaining),
                "self-control": state["self_control"],
                "last_updated": state["last_updated"],
            }
        await message.author.send(_json_code_block(dump))
        return True

    if message_content == "save poilau" and message.author.id == _OWNER_ID:
        json_str = _json_code_block(get_all_guild_states())
        await message.author.send("State persisted in SQLite database")
        await message.author.send(json_str)
        return True

    if message_content == "load poilau" and message.author.id == _OWNER_ID:
        _, msg = load_rhymes()
        json_str = _json_code_block(get_all_guild_states())
        await message.author.send(msg)
        await message.author.send(json_str)
        return True

    if message_content == "tg fouras" and message.guild:
        # Silence the bot here by setting the cooldown to midnight today plus
        # 10000 days. timedelta is required: replace(day=day + 10000) always
        # raised ValueError (day out of range for month).
        midnight = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
        update_guild_state(
            str(message.guild.id),
            guild_name=message.guild.name,
            cooldown_until=(midnight + timedelta(days=10000)).isoformat(),
            # >= 1.0 means the next "self-control" rolls in handle_rhyme_logic
            # (random.random() < self_control) hold back.
            self_control=2.0,
        )
        await message.channel.send("ok :'(")
        return True

    return False
|
|
|
|
|
|
def _next_cooldown_end(now_dt: datetime) -> datetime:
    """Return when a cooldown starting at *now_dt* ends (minute precision).

    A wait of 0-900 s is drawn, then, on a coin flip, redrawn as 900-10800 s.
    The whole minutes of that wait are added to *now_dt* truncated to the
    minute.
    """
    wait_time = random.randint(0, 900)
    if random.getrandbits(1):
        wait_time = random.randint(900, 10800)

    # timedelta rolls minutes/hours/days over correctly. The previous
    # replace(minute=minute + wait // 60) / replace(hour=hour + wait // 3600)
    # raised ValueError whenever the sum left 0-59 / 0-23, and it also
    # counted whole hours twice.
    return now_dt.replace(second=0, microsecond=0) + timedelta(minutes=wait_time // 60)


async def handle_rhyme_logic(message, client) -> bool:
    """Possibly answer *message* with a rhyme for its last word.

    A rhyme is sent only when all of the following hold:

    * the author is not the bot itself and the message is in a guild;
    * the message's lower-cased last word matches a rhyme group;
    * the guild's cooldown has expired;
    * the guild's "self-control" roll fails.

    ``self_control`` is the chance that the bot holds back, since
    ``random.random() < self_control``; any value >= 1 always holds back.
    Holding back multiplies it by 0.9. Rhyming adds 1.0 and starts a new
    random cooldown.

    Returns:
        True if a rhyme was sent, False otherwise.
    """
    if message.author == client.user or not message.guild:
        return False

    last_word = get_last_word(message.content.lower())
    if not last_word:
        return False

    # find_rhyme() also logs the match, even if cooldown/self-control later
    # prevent the rhyme from being sent.
    rhyme = find_rhyme(last_word)
    if not rhyme:
        return False

    guild_id = str(message.guild.id)
    guild_name = message.guild.name
    guild_state = get_guild_state(guild_id)

    # Keep the stored guild name in sync with the current one.
    if guild_state["guild_name"] != guild_name:
        update_guild_state(
            guild_id,
            guild_name=guild_name,
            cooldown_until=guild_state["cooldown_until"],
            self_control=guild_state["self_control"],
        )

    cooldown_dt = datetime.fromisoformat(guild_state["cooldown_until"])
    now_dt = datetime.now()
    if now_dt < cooldown_dt:
        return False  # still cooling down

    self_control = guild_state["self_control"]
    if random.random() < self_control:
        # Held back: weaken self-control so the next match is more likely.
        update_guild_state(
            guild_id,
            guild_name=guild_name,
            cooldown_until=now_dt.isoformat(),
            self_control=self_control * 0.9,
        )
        return False

    update_guild_state(
        guild_id,
        guild_name=guild_name,
        cooldown_until=_next_cooldown_end(now_dt).isoformat(),
        self_control=self_control + 1.0,
    )
    await message.channel.send(rhyme)
    return True
|
|
|
|
|
|
async def handle_message(message, client) -> bool:
    """Entry point for every incoming message.

    Returns True when the message was consumed by a debug command or
    triggered a rhyme, False otherwise.
    """
    # Idempotent setup, run on every message: creates the SQLite table and
    # the CSV log (with its header) only if they are missing.
    _ensure_db()
    _ensure_log_file()

    # Debug commands take precedence; rhyme detection runs only if none matched.
    return await handle_debug_commands(message, client) or await handle_rhyme_logic(message, client)