From 5be09cd891bd28d5e10eb3aee67379f0c103c766 Mon Sep 17 00:00:00 2001 From: diamante0018 Date: Tue, 9 Dec 2025 11:17:03 +0100 Subject: [PATCH] chore: do not store text in memory --- bot/commands.py | 13 ++++++----- bot/config.py | 16 +------------- bot/mongodb/__init__.py | 2 +- bot/mongodb/load_db.py | 49 +++++++++++++++++++++++++++++++++++++++-- bot/tasks.py | 13 ++++++----- database/__init__.py | 6 ++--- 6 files changed, 67 insertions(+), 32 deletions(-) diff --git a/bot/commands.py b/bot/commands.py index ae18518..d8b2b15 100644 --- a/bot/commands.py +++ b/bot/commands.py @@ -4,13 +4,13 @@ from typing import Literal import discord from discord import app_commands -from bot.config import message_patterns, update_patterns from bot.log import logger from bot.utils import compile_stats, fetch_game_stats, perform_search from database import ( + get_meme_patterns, add_aka_response, search_aka, - add_pattern, + add_meme_pattern, add_user_to_blacklist, is_user_blacklisted, ) @@ -54,17 +54,17 @@ async def setup(bot): ) @bot.tree.command( - name="add_pattern", + name="add_meme_pattern", description="Add a new message pattern to the database.", guild=discord.Object(id=GUILD_ID), ) @app_commands.checks.has_permissions(administrator=True) - async def add_pattern_cmd( + async def add_meme_pattern_cmd( interaction: discord.Interaction, regex: str, response: str ): """Slash command to add a new message pattern to the database.""" - add_pattern(regex, response) - update_patterns(regex, response) + add_meme_pattern(regex, response) + logger.info(f"Pattern added in memory: {regex}") await interaction.response.send_message( f"Pattern added!\n**Regex:** `{regex}`\n**Response:** `{response}`" ) @@ -159,6 +159,7 @@ async def setup(bot): ) return + message_patterns = get_meme_patterns() # Check if any of the patterns match the input for pattern in message_patterns: if re.search(pattern["regex"], input, re.IGNORECASE): diff --git a/bot/config.py b/bot/config.py index 28c40d2..6911e2e 100644 --- a/bot/config.py +++ b/bot/config.py @@ -1,21 +1,7 @@ import os -from bot.log import logger -from bot.mongodb.load_db import load_chat_messages_from_db - -from database import get_patterns - MONGO_URI = os.getenv("MONGO_URI") - -def update_patterns(regex: str, response: str): - """update patterns in memory.""" - message_patterns.append({"regex": regex, "response": response}) - logger.info(f"Pattern added in memory: {regex}") - - # load global variables -message_patterns = get_patterns() - -schizo_messages = load_chat_messages_from_db() +# There are none ! diff --git a/bot/mongodb/__init__.py b/bot/mongodb/__init__.py index ce338ff..243e5e9 100644 --- a/bot/mongodb/__init__.py +++ b/bot/mongodb/__init__.py @@ -1 +1 @@ -from .load_db import load_chat_messages_from_db +from .load_db import load_chat_messages_from_db, read_random_message_from_collection diff --git a/bot/mongodb/load_db.py b/bot/mongodb/load_db.py index 9443e1c..8dc3f57 100644 --- a/bot/mongodb/load_db.py +++ b/bot/mongodb/load_db.py @@ -54,7 +54,7 @@ def write_deleted_message_to_collection( return [] -def read_message_from_collection( +def read_messages_from_collection( database="discord_bot", collection="messages", ): @@ -91,10 +91,55 @@ def read_message_from_collection( return [] +def read_random_message_from_collection( + database="discord_bot", + collection="messages", +): + """ + Loads a random chat message from MongoDB. + + Args: + database (str): Name of the MongoDB database + collection (str): Name of the collection + + Returns: + str or None: random message string, or None if collection is empty + """ + mongo_uri = get_mongodb_uri() + + try: + with MongoClient(mongo_uri) as client: + db = client[database] + col = db[collection] + + logger.debug( + f"Connecting to MongoDB at {mongo_uri}, DB='{database}', Collection='{collection}'" + ) + + # Use aggregation with $sample to get a random document + pipeline = [{"$sample": {"size": 1}}] + + cursor = col.aggregate(pipeline) + # almost random + random_docs = list(cursor) + + if random_docs and "message" in random_docs[0]: + message = random_docs[0]["message"] + logger.info(f"Loaded random message from MongoDB: {message[:100]}...") + return message + + logger.warning("No messages found in collection") + return None + + except Exception as e: + logger.error(f"Failed to load random message from MongoDB: {e}") + return None + + def load_chat_messages_from_db(): messages = [] - messages = read_message_from_collection() + messages = read_messages_from_collection() if not messages: logger.warning("messages collection is empty after loading from MongoDB!") diff --git a/bot/tasks.py b/bot/tasks.py index 2358953..df8669b 100644 --- a/bot/tasks.py +++ b/bot/tasks.py @@ -5,10 +5,10 @@ import discord import requests from discord.ext import commands, tasks -from bot.config import schizo_messages from bot.discourse.handle_request import combine_posts_text, fetch_cooked_posts from bot.log import logger from bot.utils import aware_utcnow, fetch_api_data +from bot.mongodb import read_random_message_from_collection from database import migrate_users_with_role TARGET_DATE = datetime(2036, 8, 12, tzinfo=timezone.utc) @@ -192,11 +192,14 @@ async def setup(bot): @tasks.loop(hours=5) async def shizo_message(): channel = bot.get_channel(OFFTOPIC_CHANNEL) - if channel and schizo_messages: - message = random.choice(schizo_messages) - await channel.send(message) + if channel: + message = read_random_message_from_collection() + if message: + await channel.send(message) + else: + logger.error("No funny messages were found.") else: - logger.error("Channel not found or schizo_messages is empty.") + logger.error("Channel not found. Check the OFFTOPIC_CHANNEL variable.") @tasks.loop(hours=24) async def share_dementia_image(): diff --git a/database/__init__.py b/database/__init__.py index f37952a..8924496 100644 --- a/database/__init__.py +++ b/database/__init__.py @@ -25,7 +25,7 @@ def initialize_db(): logger.info("Done loading database: %s", DB_PATH) -def add_pattern(regex: str, response: str): +def add_meme_pattern(regex: str, response: str): """Adds a new pattern to the database.""" conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() @@ -38,7 +38,7 @@ def add_pattern(regex: str, response: str): conn.close() -def get_patterns(): +def get_meme_patterns(): """Fetches all regex-response pairs from the database.""" conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() @@ -50,7 +50,7 @@ def get_patterns(): return [{"regex": row[0], "response": row[1]} for row in patterns] -def remove_pattern(pattern_id: int): +def remove_meme_pattern(pattern_id: int): """Removes a pattern by ID.""" conn = sqlite3.connect(DB_PATH) cursor = conn.cursor()