# perefouras/rhymes.py
#
# 298 lines
# 11 KiB
# Python

# rhymes.py
import json
import random
import sqlite3
from contextlib import closing
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, Tuple
# File locations; these are relative paths, resolved against the process
# working directory by open() / sqlite3.connect().
RHYMES_FILE = "resources/rhymes.json"
DB_FILE = "data/poilau_state.db"
RHYME_LOG_FILE = "data/rhyme_log.csv"

# Rhyme groups loaded by load_rhymes(). find_rhyme() iterates this as a list of
# objects with "blacklist", "keys", "sound" and "rhymes" entries, so the empty
# default is a list rather than a dict.
loaded_rhymes: list = []
def _ensure_db() -> None:
    """Create the SQLite database and the ``guild_state`` table if missing.

    The parent directory of ``DB_FILE`` is created first. The connection is
    wrapped in ``closing`` because ``sqlite3.Connection``'s own context manager
    only commits/rolls back and never closes the connection.
    """
    Path(DB_FILE).parent.mkdir(parents=True, exist_ok=True)
    with closing(sqlite3.connect(DB_FILE)) as conn:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS guild_state (
                guild_id TEXT PRIMARY KEY,
                guild_name TEXT NOT NULL DEFAULT '',
                cooldown_until TEXT NOT NULL DEFAULT '1970-01-01T00:00:00',
                self_control REAL NOT NULL DEFAULT 1.0,
                last_updated TEXT NOT NULL DEFAULT (datetime('now'))
            )
        """)
        conn.commit()
def _ensure_log_file() -> None:
    """Make sure the rhyme CSV log exists, writing its header row when new.

    The log's parent directory is created if needed; an existing log file is
    left untouched.
    """
    log_path = Path(RHYME_LOG_FILE)
    log_path.parent.mkdir(parents=True, exist_ok=True)
    if log_path.exists():
        return
    with log_path.open("w", encoding="utf-8") as handle:
        handle.write("timestamp,last_word,rhyme_triggered\n")
def _get_connection() -> sqlite3.Connection:
    """Open a fresh connection to ``DB_FILE`` with name-addressable rows.

    Rows come back as ``sqlite3.Row`` so columns can be read as ``row["name"]``.
    The caller owns the returned connection and must close it.
    """
    connection = sqlite3.connect(DB_FILE)
    connection.row_factory = sqlite3.Row
    return connection
def load_rhymes() -> Tuple[bool, str]:
    """Load rhyme groups from ``RHYMES_FILE`` into the module-level ``loaded_rhymes``.

    The file must contain a JSON list (find_rhyme() iterates it). On any
    failure the previously loaded rhymes are kept unchanged.

    Returns:
        ``(success, message)`` where *message* is a human-readable status.
    """
    # Must precede the docstring-less body: the docstring above has to be the
    # first statement, so the global declaration comes after it.
    global loaded_rhymes
    try:
        with open(RHYMES_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        return False, f"No rhymes file found at \"{RHYMES_FILE}\""
    except json.JSONDecodeError as e:
        return False, f"Invalid JSON in \"{RHYMES_FILE}\": {e}"
    except (OSError, UnicodeDecodeError) as e:
        # e.g. permission denied or a non-UTF-8 file; report instead of crashing.
        return False, f"Could not read \"{RHYMES_FILE}\": {e}"
    if not isinstance(data, list):
        return False, f"Invalid rhymes file \"{RHYMES_FILE}\": expected a JSON list"
    loaded_rhymes = data
    return True, f"Loaded rhymes file \"{RHYMES_FILE}\""
def get_guild_state(guild_id: str) -> Dict[str, Any]:
    """Return the stored state of *guild_id*, or a default state if it has no row.

    The dict has keys ``guild_id``, ``guild_name``, ``cooldown_until`` (ISO-8601
    string), ``self_control`` and ``last_updated``. For an unknown guild the
    defaults are an empty name, a cooldown in 1970 (i.e. already expired),
    self_control 1.0 and the current local time as ``last_updated``.
    """
    # closing(): the Connection context manager alone would not close it.
    with closing(_get_connection()) as conn:
        row = conn.execute(
            "SELECT guild_id, guild_name, cooldown_until, self_control, last_updated "
            "FROM guild_state WHERE guild_id = ?",
            (guild_id,),
        ).fetchone()
        if row is not None:
            # sqlite3.Row supports keys(), so dict() maps column name -> value.
            return dict(row)
    return {
        "guild_id": guild_id,
        "guild_name": "",
        "cooldown_until": "1970-01-01T00:00:00",
        "self_control": 1.0,
        "last_updated": datetime.now().isoformat(),
    }
def update_guild_state(
    guild_id: str,
    guild_name: str,
    cooldown_until: str,
    self_control: float
) -> None:
    """Insert or overwrite the ``guild_state`` row for *guild_id*.

    ``last_updated`` is set to the current local time in ISO-8601 form. The
    change is committed and the connection closed before returning.
    """
    with closing(_get_connection()) as conn:
        conn.execute(
            """
            INSERT OR REPLACE INTO guild_state
                (guild_id, guild_name, cooldown_until, self_control, last_updated)
            VALUES (?, ?, ?, ?, ?)
            """,
            (guild_id, guild_name, cooldown_until, self_control, datetime.now().isoformat()),
        )
        conn.commit()
def delete_guild_state(guild_id: str) -> bool:
    """Delete the ``guild_state`` row for *guild_id*.

    Returns:
        True if a row was removed, False if the guild had no stored state.
    """
    with closing(_get_connection()) as conn:
        cursor = conn.execute("DELETE FROM guild_state WHERE guild_id = ?", (guild_id,))
        conn.commit()
        return cursor.rowcount > 0
def get_all_guild_states() -> Dict[str, Dict[str, Any]]:
    """Return every stored guild state keyed by guild id (for debug purposes).

    Each value holds ``guild_name``, ``cooldown_until``, ``self_control`` and
    ``last_updated``; the guild id appears only as the key.
    """
    with closing(_get_connection()) as conn:
        rows = conn.execute(
            "SELECT guild_id, guild_name, cooldown_until, self_control, last_updated "
            "FROM guild_state"
        ).fetchall()
        return {
            row["guild_id"]: {
                "guild_name": row["guild_name"],
                "cooldown_until": row["cooldown_until"],
                "self_control": row["self_control"],
                "last_updated": row["last_updated"],
            }
            for row in rows
        }
def log_rhyme(last_word: str, rhyme_triggered: str) -> None:
"""Log rhyme trigger to CSV file."""
timestamp = datetime.now().isoformat()
with open(RHYME_LOG_FILE, "a", encoding="utf-8") as f:
safe_rhyme = rhyme_triggered.replace(",", ";")
f.write(f"{timestamp},{last_word},{safe_rhyme}\n")
def get_last_word(text: str) -> str:
    """Return the last alphabetic word of *text*, or ``""`` if there is none.

    Trailing characters are dropped until the text ends with two letters, so
    trailing punctuation and spaces are ignored ("quoi ?!" -> "quoi"). The
    search gives up with ``""`` as soon as a digit is reached or fewer than two
    characters remain, so single-letter words are never returned. The final
    token is taken after splitting on any whitespace (space, tab, newline,
    NBSP) and is returned only if it is made entirely of letters
    ("c'est" -> "").
    """
    # Walk an end index instead of re-slicing the string on every step.
    end = len(text)
    while True:
        if end < 2 or text[end - 1].isnumeric():
            return ""
        if text[end - 1].isalpha() and text[end - 2].isalpha():
            break
        end -= 1
    # split() (not split(" ")) so "ok\nquoi" or "bon\xa0quoi" yields "quoi".
    last = text[:end].split()[-1]
    return last if last.isalpha() else ""
def find_rhyme(word: str) -> str:
    """Return a random rhyme for *word*, or ``""`` if none applies.

    Rhyme groups in ``loaded_rhymes`` are scanned in order. For each group:
    if *word* is in its ``blacklist`` the whole search stops with ``""``;
    otherwise, if *word* ends with any of its ``keys``, the match is logged via
    log_rhyme() with the group's ``sound`` and a random entry of its
    ``rhymes`` is returned.
    """
    # loaded_rhymes is only read here, so no `global` declaration is needed.
    for rhyme in loaded_rhymes:
        # A blacklist hit ends the search, not just this group.
        if word in rhyme["blacklist"]:
            return ""
        if word.endswith(tuple(rhyme["keys"])):
            log_rhyme(word, rhyme["sound"])
            return random.choice(rhyme["rhymes"])
    return ""
async def get_guild_name(guildId, client) -> str:
    """Fetch the guild via ``client.fetch_guild`` and return a ``[Server=<name>]`` label."""
    fetched = await client.fetch_guild(guildId)
    return f"[Server={fetched.name}]"
async def handle_debug_commands(message, client) -> bool:
"""Handle debug commands (debug, save, load). Returns True if handled."""
message_content = message.content.lower()
if message_content == "debug poilau":
if message.author.id == 151626081458192384:
all_states = get_all_guild_states()
dump = {}
for guild_id, state in all_states.items():
channel_name = await get_guild_name(guild_id, client)
cooldown_dt = datetime.fromisoformat(state["cooldown_until"])
time_remaining = max(0, (cooldown_dt - datetime.now()).total_seconds())
sleeping_time = "{:.2f} s".format(time_remaining)
dump[channel_name] = {
"cooldown_until": state["cooldown_until"],
"cooldown_remaining": sleeping_time,
"self-control": state["self_control"],
"last_updated": state["last_updated"]
}
await message.author.send(
"```json\n{0}```".format(json.dumps(dump, ensure_ascii=False, indent=2))
)
return True
if message_content == "save poilau":
if message.author.id == 151626081458192384:
all_states = get_all_guild_states()
json_str = "```json\n{0}```".format(json.dumps(all_states, ensure_ascii=False, indent=2))
await message.author.send("State persisted in SQLite database")
await message.author.send(json_str)
return True
if message_content == "load poilau":
if message.author.id == 151626081458192384:
success, msg = load_rhymes()
all_states = get_all_guild_states()
json_str = "```json\n{0}```".format(json.dumps(all_states, ensure_ascii=False, indent=2))
await message.author.send(msg)
await message.author.send(json_str)
return True
if message_content == "tg fouras" and message.guild:
# Disable cooldown for this server (set to far future)
cooldown_date = datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)
cooldown_date = cooldown_date.replace(day=cooldown_date.day + 10000)
update_guild_state(
str(message.guild.id),
guild_name=message.guild.name,
cooldown_until=cooldown_date.isoformat(),
self_control=2.0
)
await message.channel.send("ok :'(")
return True
return False
async def handle_rhyme_logic(message, client) -> bool:
"""Main rhyme detection logic. Returns True if rhyme was triggered."""
message_content = message.content.lower()
last_word = get_last_word(message_content)
if message.author != client.user and message.guild and last_word:
rhyme = find_rhyme(last_word)
guild_id = str(message.guild.id)
guild_name = message.guild.name
if rhyme:
guild_state = get_guild_state(guild_id)
# Update guild name if changed
if guild_state["guild_name"] != guild_name:
update_guild_state(
guild_id,
guild_name=guild_name,
cooldown_until=guild_state["cooldown_until"],
self_control=guild_state["self_control"]
)
# Check cooldown
cooldown_dt = datetime.fromisoformat(guild_state["cooldown_until"])
now_dt = datetime.now()
if now_dt >= cooldown_dt:
self_control = guild_state["self_control"]
# Probability check
if random.random() < self_control:
new_self_control = self_control * 0.9
update_guild_state(
guild_id,
guild_name=guild_name,
cooldown_until=now_dt.isoformat(),
self_control=new_self_control
)
return False
# Calculate new cooldown duration
wait_time = random.randint(0, 900)
if bool(random.getrandbits(1)):
wait_time = random.randint(900, 10800)
new_cooldown_dt = now_dt.replace(second=0, microsecond=0)
new_cooldown_dt = new_cooldown_dt.replace(minute=new_cooldown_dt.minute + wait_time // 60)
new_cooldown_dt = new_cooldown_dt.replace(hour=new_cooldown_dt.hour + wait_time // 3600)
update_guild_state(
guild_id,
guild_name=guild_name,
cooldown_until=new_cooldown_dt.isoformat(),
self_control=self_control + 1.0
)
await message.channel.send(rhyme)
return True
return False
# Set once the database and CSV log have been initialized in this process.
_storage_ready = False


async def handle_message(message, client) -> bool:
    """Entry point for every incoming message.

    Initializes the database and log file on the first call only, then gives
    debug/admin commands priority over rhyme detection.

    Returns:
        True if a command was handled or a rhyme was sent.
    """
    global _storage_ready
    if not _storage_ready:
        # No await between the check and the assignment, so concurrent
        # coroutines on the same event loop cannot both run the setup.
        _ensure_db()
        _ensure_log_file()
        _storage_ready = True
    if await handle_debug_commands(message, client):
        return True
    return await handle_rhyme_logic(message, client)