mirror of
https://github.com/alterware/aw-bot.git
synced 2025-12-11 12:07:50 +00:00
chore: do not store text in memory
This commit is contained in:
@@ -4,13 +4,13 @@ from typing import Literal
|
|||||||
import discord
|
import discord
|
||||||
from discord import app_commands
|
from discord import app_commands
|
||||||
|
|
||||||
from bot.config import message_patterns, update_patterns
|
|
||||||
from bot.log import logger
|
from bot.log import logger
|
||||||
from bot.utils import compile_stats, fetch_game_stats, perform_search
|
from bot.utils import compile_stats, fetch_game_stats, perform_search
|
||||||
from database import (
|
from database import (
|
||||||
|
get_meme_patterns,
|
||||||
add_aka_response,
|
add_aka_response,
|
||||||
search_aka,
|
search_aka,
|
||||||
add_pattern,
|
add_meme_pattern,
|
||||||
add_user_to_blacklist,
|
add_user_to_blacklist,
|
||||||
is_user_blacklisted,
|
is_user_blacklisted,
|
||||||
)
|
)
|
||||||
@@ -54,17 +54,17 @@ async def setup(bot):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@bot.tree.command(
|
@bot.tree.command(
|
||||||
name="add_pattern",
|
name="add_meme_pattern",
|
||||||
description="Add a new message pattern to the database.",
|
description="Add a new message pattern to the database.",
|
||||||
guild=discord.Object(id=GUILD_ID),
|
guild=discord.Object(id=GUILD_ID),
|
||||||
)
|
)
|
||||||
@app_commands.checks.has_permissions(administrator=True)
|
@app_commands.checks.has_permissions(administrator=True)
|
||||||
async def add_pattern_cmd(
|
async def add_meme_pattern_cmd(
|
||||||
interaction: discord.Interaction, regex: str, response: str
|
interaction: discord.Interaction, regex: str, response: str
|
||||||
):
|
):
|
||||||
"""Slash command to add a new message pattern to the database."""
|
"""Slash command to add a new message pattern to the database."""
|
||||||
add_pattern(regex, response)
|
add_meme_pattern(regex, response)
|
||||||
update_patterns(regex, response)
|
logger.info(f"Pattern added in memory: {regex}")
|
||||||
await interaction.response.send_message(
|
await interaction.response.send_message(
|
||||||
f"Pattern added!\n**Regex:** `{regex}`\n**Response:** `{response}`"
|
f"Pattern added!\n**Regex:** `{regex}`\n**Response:** `{response}`"
|
||||||
)
|
)
|
||||||
@@ -159,6 +159,7 @@ async def setup(bot):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
message_patterns = get_meme_patterns()
|
||||||
# Check if any of the patterns match the input
|
# Check if any of the patterns match the input
|
||||||
for pattern in message_patterns:
|
for pattern in message_patterns:
|
||||||
if re.search(pattern["regex"], input, re.IGNORECASE):
|
if re.search(pattern["regex"], input, re.IGNORECASE):
|
||||||
|
|||||||
@@ -1,21 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from bot.log import logger
|
|
||||||
from bot.mongodb.load_db import load_chat_messages_from_db
|
|
||||||
|
|
||||||
from database import get_patterns
|
|
||||||
|
|
||||||
MONGO_URI = os.getenv("MONGO_URI")
|
MONGO_URI = os.getenv("MONGO_URI")
|
||||||
|
|
||||||
|
|
||||||
def update_patterns(regex: str, response: str):
|
|
||||||
"""update patterns in memory."""
|
|
||||||
message_patterns.append({"regex": regex, "response": response})
|
|
||||||
logger.info(f"Pattern added in memory: {regex}")
|
|
||||||
|
|
||||||
|
|
||||||
# load global variables
|
# load global variables
|
||||||
|
|
||||||
message_patterns = get_patterns()
|
# There are none !
|
||||||
|
|
||||||
schizo_messages = load_chat_messages_from_db()
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
from .load_db import load_chat_messages_from_db
|
from .load_db import load_chat_messages_from_db, read_random_message_from_collection
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ def write_deleted_message_to_collection(
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def read_message_from_collection(
|
def read_messages_from_collection(
|
||||||
database="discord_bot",
|
database="discord_bot",
|
||||||
collection="messages",
|
collection="messages",
|
||||||
):
|
):
|
||||||
@@ -91,10 +91,55 @@ def read_message_from_collection(
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def read_random_message_from_collection(
|
||||||
|
database="discord_bot",
|
||||||
|
collection="messages",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Loads a random chat message from MongoDB.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database (str): Name of the MongoDB database
|
||||||
|
collection (str): Name of the collection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str or None: random message string, or None if collection is empty
|
||||||
|
"""
|
||||||
|
mongo_uri = get_mongodb_uri()
|
||||||
|
|
||||||
|
try:
|
||||||
|
with MongoClient(mongo_uri) as client:
|
||||||
|
db = client[database]
|
||||||
|
col = db[collection]
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Connecting to MongoDB at {mongo_uri}, DB='{database}', Collection='{collection}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Use aggregation with $sample to get a random document
|
||||||
|
pipeline = [{"$sample": {"size": 1}}]
|
||||||
|
|
||||||
|
cursor = col.aggregate(pipeline)
|
||||||
|
# almost random
|
||||||
|
random_docs = list(cursor)
|
||||||
|
|
||||||
|
if random_docs and "message" in random_docs[0]:
|
||||||
|
message = random_docs[0]["message"]
|
||||||
|
logger.info(f"Loaded random message from MongoDB: {message[:100]}...")
|
||||||
|
return message
|
||||||
|
|
||||||
|
logger.warning("No messages found in collection")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load random message from MongoDB: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def load_chat_messages_from_db():
|
def load_chat_messages_from_db():
|
||||||
messages = []
|
messages = []
|
||||||
|
|
||||||
messages = read_message_from_collection()
|
messages = read_messages_from_collection()
|
||||||
if not messages:
|
if not messages:
|
||||||
logger.warning("messages collection is empty after loading from MongoDB!")
|
logger.warning("messages collection is empty after loading from MongoDB!")
|
||||||
|
|
||||||
|
|||||||
13
bot/tasks.py
13
bot/tasks.py
@@ -5,10 +5,10 @@ import discord
|
|||||||
import requests
|
import requests
|
||||||
from discord.ext import commands, tasks
|
from discord.ext import commands, tasks
|
||||||
|
|
||||||
from bot.config import schizo_messages
|
|
||||||
from bot.discourse.handle_request import combine_posts_text, fetch_cooked_posts
|
from bot.discourse.handle_request import combine_posts_text, fetch_cooked_posts
|
||||||
from bot.log import logger
|
from bot.log import logger
|
||||||
from bot.utils import aware_utcnow, fetch_api_data
|
from bot.utils import aware_utcnow, fetch_api_data
|
||||||
|
from bot.mongodb import read_random_message_from_collection
|
||||||
from database import migrate_users_with_role
|
from database import migrate_users_with_role
|
||||||
|
|
||||||
TARGET_DATE = datetime(2036, 8, 12, tzinfo=timezone.utc)
|
TARGET_DATE = datetime(2036, 8, 12, tzinfo=timezone.utc)
|
||||||
@@ -192,11 +192,14 @@ async def setup(bot):
|
|||||||
@tasks.loop(hours=5)
|
@tasks.loop(hours=5)
|
||||||
async def shizo_message():
|
async def shizo_message():
|
||||||
channel = bot.get_channel(OFFTOPIC_CHANNEL)
|
channel = bot.get_channel(OFFTOPIC_CHANNEL)
|
||||||
if channel and schizo_messages:
|
if channel:
|
||||||
message = random.choice(schizo_messages)
|
message = read_random_message_from_collection()
|
||||||
await channel.send(message)
|
if message:
|
||||||
|
await channel.send(message)
|
||||||
|
else:
|
||||||
|
logger.error("No funny messages were found.")
|
||||||
else:
|
else:
|
||||||
logger.error("Channel not found or schizo_messages is empty.")
|
logger.error("Channel not found. Check the OFFTOPIC_CHANNEL variable.")
|
||||||
|
|
||||||
@tasks.loop(hours=24)
|
@tasks.loop(hours=24)
|
||||||
async def share_dementia_image():
|
async def share_dementia_image():
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ def initialize_db():
|
|||||||
logger.info("Done loading database: %s", DB_PATH)
|
logger.info("Done loading database: %s", DB_PATH)
|
||||||
|
|
||||||
|
|
||||||
def add_pattern(regex: str, response: str):
|
def add_meme_pattern(regex: str, response: str):
|
||||||
"""Adds a new pattern to the database."""
|
"""Adds a new pattern to the database."""
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = sqlite3.connect(DB_PATH)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
@@ -38,7 +38,7 @@ def add_pattern(regex: str, response: str):
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
def get_patterns():
|
def get_meme_patterns():
|
||||||
"""Fetches all regex-response pairs from the database."""
|
"""Fetches all regex-response pairs from the database."""
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = sqlite3.connect(DB_PATH)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
@@ -50,7 +50,7 @@ def get_patterns():
|
|||||||
return [{"regex": row[0], "response": row[1]} for row in patterns]
|
return [{"regex": row[0], "response": row[1]} for row in patterns]
|
||||||
|
|
||||||
|
|
||||||
def remove_pattern(pattern_id: int):
|
def remove_meme_pattern(pattern_id: int):
|
||||||
"""Removes a pattern by ID."""
|
"""Removes a pattern by ID."""
|
||||||
conn = sqlite3.connect(DB_PATH)
|
conn = sqlite3.connect(DB_PATH)
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|||||||
Reference in New Issue
Block a user