mirror of
https://github.com/alterware/aw-bot.git
synced 2025-12-28 12:21:48 +00:00
bot: remove AI features
This commit is contained in:
@@ -17,10 +17,6 @@ COPY aw.py .
|
||||
COPY LICENSE .
|
||||
|
||||
ENV BOT_TOKEN=""
|
||||
ENV GOOGLE_API_KEY=""
|
||||
ENV DISCOURSE_API_KEY=""
|
||||
ENV DISCOURSE_BASE_URL=""
|
||||
ENV DISCOURSE_USERNAME=""
|
||||
|
||||
# Where the database will be stored
|
||||
ENV BOT_DATA_DIR=""
|
||||
|
||||
3
aw.py
3
aw.py
@@ -4,7 +4,6 @@ import discord
|
||||
from discord.ext import commands
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from bot.ai.handle_request import DiscourseSummarizer
|
||||
from bot.log import logger
|
||||
from database import initialize_db
|
||||
|
||||
@@ -24,8 +23,6 @@ git_tag = os.getenv("GIT_TAG")
|
||||
|
||||
initialize_db()
|
||||
|
||||
bot.ai_helper = DiscourseSummarizer()
|
||||
|
||||
|
||||
@bot.event
|
||||
async def on_ready():
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
from .handle_request import forward_to_google_api
|
||||
@@ -1,183 +0,0 @@
|
||||
import os
|
||||
|
||||
import requests
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
from bot.log import logger
|
||||
|
||||
API_KEY = os.getenv("GOOGLE_API_KEY")
|
||||
|
||||
GENERIC_INSTRUCTION = "You are a Discord chatbot named 'AlterWare' who helps users with all kinds of topics across various subjects. You should limit your answers to fewer than 2000 characters."
|
||||
SPECIFIC_INSTRUCTION = "You are a Discord chatbot named 'AlterWare' who helps users. You should limit your answers to fewer than 2000 characters."
|
||||
|
||||
|
||||
class DiscourseSummarizer:
|
||||
def __init__(self):
|
||||
self.model = "gemini-2.0-flash"
|
||||
self.display_name = "alterware"
|
||||
self.cache = None
|
||||
self.ttl = "21600s"
|
||||
self.discourse_data = None
|
||||
|
||||
if not API_KEY:
|
||||
logger.error("Google API key is not set. Please contact the administrator.")
|
||||
return
|
||||
|
||||
self.client = genai.Client(api_key=API_KEY)
|
||||
|
||||
def set_discourse_data(self, topic_data):
|
||||
"""
|
||||
Sets the discourse data for the summarizer.
|
||||
|
||||
Args:
|
||||
topic_data (str): The combined text of discourse posts.
|
||||
"""
|
||||
self.discourse_data = topic_data
|
||||
|
||||
def summarize_discourse_topic(self, topic_data, system_instruction=None):
|
||||
"""
|
||||
Creates a cache for the discourse topic data.
|
||||
|
||||
Args:
|
||||
topic_data (str): The combined text of discourse posts.
|
||||
system_instruction (str, optional): Custom system instruction for the model.
|
||||
"""
|
||||
self.cache = self.client.caches.create(
|
||||
model=self.model,
|
||||
config=types.CreateCachedContentConfig(
|
||||
display_name=self.display_name,
|
||||
system_instruction=system_instruction or (SPECIFIC_INSTRUCTION),
|
||||
contents=[topic_data],
|
||||
ttl=self.ttl,
|
||||
),
|
||||
)
|
||||
logger.info("Cached content created: %s", self.cache.name)
|
||||
|
||||
def update_cache(self):
|
||||
"""
|
||||
Updates the cache TTL.
|
||||
"""
|
||||
if not self.cache:
|
||||
raise RuntimeError(
|
||||
"Cache has not been created. Run summarize_discourse_topic first."
|
||||
)
|
||||
|
||||
self.client.caches.update(
|
||||
name=self.cache.name, config=types.UpdateCachedContentConfig(ttl="21600s")
|
||||
)
|
||||
logger.info("Cache updated.")
|
||||
|
||||
def ask(self, prompt):
|
||||
"""
|
||||
Generates a response using the cached content.
|
||||
|
||||
Args:
|
||||
prompt (str): The user prompt.
|
||||
|
||||
Returns:
|
||||
str: The generated response.
|
||||
"""
|
||||
if not self.cache:
|
||||
raise RuntimeError(
|
||||
"Cache has not been created. Run summarize_discourse_topic first."
|
||||
)
|
||||
|
||||
response = self.client.models.generate_content(
|
||||
model=self.model,
|
||||
contents=prompt,
|
||||
config=types.GenerateContentConfig(
|
||||
max_output_tokens=400,
|
||||
system_instruction=SPECIFIC_INSTRUCTION,
|
||||
cached_content=self.cache.name,
|
||||
),
|
||||
)
|
||||
return response.text
|
||||
|
||||
def ask_without_cache(self, prompt):
|
||||
"""
|
||||
Generates a response without using cached content, including discourse data.
|
||||
|
||||
Args:
|
||||
prompt (str): The user prompt.
|
||||
|
||||
Returns:
|
||||
str: The generated response.
|
||||
"""
|
||||
if not self.discourse_data:
|
||||
return "Discourse data has not been set."
|
||||
|
||||
prompt.insert(0, self.discourse_data)
|
||||
response = self.client.models.generate_content(
|
||||
model=self.model,
|
||||
contents=prompt,
|
||||
config=types.GenerateContentConfig(
|
||||
max_output_tokens=400,
|
||||
system_instruction=SPECIFIC_INSTRUCTION,
|
||||
),
|
||||
)
|
||||
return response.text
|
||||
|
||||
def ask_without_context(self, prompt):
|
||||
response = self.client.models.generate_content(
|
||||
model=self.model,
|
||||
contents=prompt,
|
||||
config=types.GenerateContentConfig(
|
||||
max_output_tokens=400,
|
||||
system_instruction=GENERIC_INSTRUCTION,
|
||||
),
|
||||
)
|
||||
return response.text
|
||||
|
||||
|
||||
async def forward_to_google_api(
|
||||
prompt, bot, image_object=None, reply=None, no_context=False
|
||||
):
|
||||
"""
|
||||
Forwards the message content and optional image object to a Google API.
|
||||
|
||||
Args:
|
||||
prompt (discord.Message): The message object to forward.
|
||||
bot (discord.Client): The Discord bot instance.
|
||||
image_object (tuple, optional): A tuple containing the image URL and its MIME type (e.g., ("url", "image/jpeg")).
|
||||
reply (discord.Message, optional): The message that was referenced by prompt.
|
||||
no_context (bool, optional): If True, the bot will not use any cached content or context.
|
||||
"""
|
||||
if not API_KEY:
|
||||
await prompt.reply(
|
||||
"Google API key is not set. Please contact the administrator.",
|
||||
mention_author=True,
|
||||
)
|
||||
return
|
||||
|
||||
input = [prompt.content]
|
||||
|
||||
# Have the reply come first in the prompt
|
||||
if reply:
|
||||
input.insert(0, reply.content)
|
||||
|
||||
if image_object:
|
||||
try:
|
||||
image_url, mime_type = image_object
|
||||
image = requests.get(image_url)
|
||||
image.raise_for_status()
|
||||
|
||||
# If there is an image, add it to the input before anything else
|
||||
input.insert(
|
||||
0, types.Part.from_bytes(data=image.content, mime_type=mime_type)
|
||||
)
|
||||
except requests.RequestException:
|
||||
await prompt.reply(f"Failed to fetch the image", mention_author=True)
|
||||
return
|
||||
|
||||
response = None
|
||||
|
||||
if no_context:
|
||||
response = bot.ai_helper.ask_without_context(input)
|
||||
else:
|
||||
response = bot.ai_helper.ask_without_cache(input)
|
||||
|
||||
reply_message = await prompt.reply(
|
||||
response,
|
||||
mention_author=True,
|
||||
)
|
||||
@@ -1 +0,0 @@
|
||||
from .handle_request import fetch_cooked_posts, get_topics_by_id, get_topics_by_tag
|
||||
@@ -1,134 +0,0 @@
|
||||
import os
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
from bot.log import logger
|
||||
|
||||
DISCOURSE_BASE_URL = os.getenv("DISCOURSE_BASE_URL")
|
||||
API_KEY = os.getenv("DISCOURSE_API_KEY")
|
||||
API_USERNAME = os.getenv("DISCOURSE_API_USERNAME")
|
||||
|
||||
headers = {"Api-Key": API_KEY, "Api-Username": API_USERNAME}
|
||||
|
||||
|
||||
async def get_topics_by_id(topic_id):
|
||||
"""
|
||||
Async: Fetches a topic by its ID and returns the topic data.
|
||||
|
||||
Args:
|
||||
topic_id (int): The ID of the topic to fetch.
|
||||
|
||||
Returns:
|
||||
dict or None: The topic data if successful, otherwise None.
|
||||
"""
|
||||
url = f"{DISCOURSE_BASE_URL}/t/{topic_id}.json"
|
||||
timeout = aiohttp.ClientTimeout(total=5)
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=headers, timeout=timeout) as response:
|
||||
if response.status == 200:
|
||||
return await response.json()
|
||||
elif response.status == 403:
|
||||
logger.error(
|
||||
f"Access forbidden for topic {topic_id}: {response.status}"
|
||||
)
|
||||
return None
|
||||
else:
|
||||
text = await response.text()
|
||||
logger.error(
|
||||
f"Error fetching topic {topic_id}: {response.status} - {text}"
|
||||
)
|
||||
return None
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Timeout while fetching topic {topic_id}")
|
||||
return None
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"Request failed for topic {topic_id}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_topics_by_tag(tag_name):
|
||||
"""
|
||||
Async: Fetches all topics with a specific tag and retrieves the cooked string from each post.
|
||||
|
||||
Args:
|
||||
tag_name (str): The name of the tag to filter topics.
|
||||
|
||||
Returns:
|
||||
list: A list of cooked strings from all posts in the topics.
|
||||
"""
|
||||
url = f"{DISCOURSE_BASE_URL}/tag/{tag_name}.json"
|
||||
timeout = aiohttp.ClientTimeout(total=5)
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, headers=headers, timeout=timeout) as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
topics = data.get("topic_list", {}).get("topics", [])
|
||||
cooked_strings = []
|
||||
for topic in topics:
|
||||
topic_id = topic["id"]
|
||||
topic_data = await get_topics_by_id(topic_id)
|
||||
if topic_data:
|
||||
posts = topic_data.get("post_stream", {}).get("posts", [])
|
||||
for post in posts:
|
||||
cooked_strings.append(post.get("cooked", ""))
|
||||
return cooked_strings
|
||||
elif response.status == 403:
|
||||
logger.error(
|
||||
f"Access forbidden for tag '{tag_name}': {response.status}"
|
||||
)
|
||||
return None
|
||||
else:
|
||||
text = await response.text()
|
||||
logger.error(
|
||||
f"Error fetching topics with tag '{tag_name}': {response.status} - {text}"
|
||||
)
|
||||
return []
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Timeout while fetching topics with tag '{tag_name}'")
|
||||
return []
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"Request failed for topics with tag {tag_name}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def fetch_cooked_posts(tag_name):
|
||||
"""
|
||||
Async: Fetches cooked strings from posts with a specific tag.
|
||||
|
||||
Args:
|
||||
tag_name (str): The name of the tag to filter topics.
|
||||
|
||||
Returns:
|
||||
list: A list of cooked strings from posts with the specified tag.
|
||||
"""
|
||||
return await get_topics_by_tag(tag_name)
|
||||
|
||||
|
||||
def html_to_text(html_content):
|
||||
"""
|
||||
Cleans the provided HTML content and converts it to plain text.
|
||||
|
||||
Args:
|
||||
html_content (str): The HTML content to clean.
|
||||
|
||||
Returns:
|
||||
str: The cleaned plain text.
|
||||
"""
|
||||
soup = BeautifulSoup(html_content, "html.parser")
|
||||
return soup.get_text(separator="\n").strip()
|
||||
|
||||
|
||||
def combine_posts_text(posts):
|
||||
"""
|
||||
Combines the cooked content of all posts into a single plain text block.
|
||||
|
||||
Args:
|
||||
posts (list): A list of posts, each containing a "cooked" HTML string.
|
||||
|
||||
Returns:
|
||||
str: The combined plain text of all posts.
|
||||
"""
|
||||
return "\n\n".join([html_to_text(post["cooked"]) for post in posts])
|
||||
@@ -4,10 +4,9 @@ from datetime import timedelta
|
||||
|
||||
import discord
|
||||
|
||||
from bot.ai.handle_request import forward_to_google_api
|
||||
from bot.log import logger
|
||||
from bot.utils import aware_utcnow, timeout_member, safe_truncate
|
||||
from database import add_user_to_role, is_user_blacklisted
|
||||
from database import add_user_to_role
|
||||
from bot.mongodb.load_db import DeletedMessage
|
||||
from bot.mongodb.load_db import write_deleted_message_to_collection
|
||||
|
||||
@@ -59,79 +58,6 @@ def fetch_image_from_message(message):
|
||||
return image_object
|
||||
|
||||
|
||||
async def handle_bot_mention(message, bot, no_context=False):
|
||||
staff_role = message.guild.get_role(ADMIN_ROLE_ID)
|
||||
member = message.guild.get_member(message.author.id)
|
||||
|
||||
# Check if the message is in an allowed channel
|
||||
if message.channel.id not in ALLOWED_CHANNELS:
|
||||
logger.debug(
|
||||
f"User {message.author} attempted to use AI in non-allowed channel: {message.channel.name}"
|
||||
)
|
||||
await message.reply(
|
||||
"The AI cannot used in this channel.",
|
||||
mention_author=True,
|
||||
)
|
||||
return True
|
||||
|
||||
if is_user_blacklisted(message.author.id):
|
||||
logger.warning(
|
||||
f"Blacklisted user {message.author} (ID: {message.author.id}) attempted to use AI"
|
||||
)
|
||||
await message.reply(
|
||||
"**Time Travel Required!**\n"
|
||||
"You'll gain access to this feature on **August 12th, 2036**.\n",
|
||||
mention_author=True,
|
||||
)
|
||||
return True
|
||||
|
||||
# Cooldown logic: max 1 use per minute per user
|
||||
now = time.time()
|
||||
user_id = message.author.id
|
||||
timestamps = MENTION_COOLDOWNS.get(user_id, [])
|
||||
# Remove timestamps older than 60 seconds
|
||||
timestamps = [t for t in timestamps if now - t < 60]
|
||||
if len(timestamps) >= 1 and not staff_role in member.roles:
|
||||
await message.reply(
|
||||
"You are using this feature too quickly. Please wait before trying again.",
|
||||
mention_author=True,
|
||||
)
|
||||
return True
|
||||
timestamps.append(now)
|
||||
MENTION_COOLDOWNS[user_id] = timestamps
|
||||
|
||||
# Prioritize the image object from the first message
|
||||
image_object = fetch_image_from_message(message)
|
||||
|
||||
# Check if the message is a reply to another message
|
||||
reply_content = None
|
||||
if message.reference:
|
||||
try:
|
||||
referenced_message = await message.channel.fetch_message(
|
||||
message.reference.message_id
|
||||
)
|
||||
reply_content = referenced_message
|
||||
|
||||
# Check if the referenced message has an image object (if not already set)
|
||||
if image_object is None:
|
||||
image_object = fetch_image_from_message(referenced_message)
|
||||
|
||||
except discord.NotFound:
|
||||
logger.error("Referenced message not found.")
|
||||
except discord.Forbidden:
|
||||
logger.error(
|
||||
"Bot does not have permission to fetch the referenced message."
|
||||
)
|
||||
except discord.HTTPException as e:
|
||||
logger.error(
|
||||
"An error occurred while fetching the referenced message: %s", e
|
||||
)
|
||||
|
||||
# Pass the reply content to forward_to_google_api
|
||||
await forward_to_google_api(message, bot, image_object, reply_content, no_context)
|
||||
return True
|
||||
|
||||
|
||||
async def handle_dm(message):
|
||||
await message.channel.send(
|
||||
"If you DM this bot again, I will carpet-bomb your house."
|
||||
@@ -442,15 +368,6 @@ async def handle_message(message, bot):
|
||||
await handle_dm(message)
|
||||
return
|
||||
|
||||
grok_role = message.guild.get_role(GROK_ROLE_ID)
|
||||
if grok_role in message.role_mentions:
|
||||
if await handle_bot_mention(message, bot, True):
|
||||
return
|
||||
|
||||
if bot.user in message.mentions:
|
||||
if await handle_bot_mention(message, bot):
|
||||
return
|
||||
|
||||
# Too many mentions
|
||||
if len(message.mentions) >= 3:
|
||||
member = message.guild.get_member(message.author.id)
|
||||
|
||||
32
bot/tasks.py
32
bot/tasks.py
@@ -5,7 +5,6 @@ import discord
|
||||
import requests
|
||||
from discord.ext import commands, tasks
|
||||
|
||||
from bot.discourse.handle_request import combine_posts_text, fetch_cooked_posts
|
||||
from bot.log import logger
|
||||
from bot.utils import aware_utcnow, fetch_api_data
|
||||
from bot.mongodb import read_random_message_from_collection
|
||||
@@ -125,36 +124,6 @@ class SteamSaleChecker(commands.Cog):
|
||||
await self.bot.wait_until_ready()
|
||||
|
||||
|
||||
class DiscourseUpdater(commands.Cog):
|
||||
def __init__(self, bot):
|
||||
self.bot = bot
|
||||
self.update_discourse_data.start() # Start the task when the cog is loaded
|
||||
|
||||
def cog_unload(self):
|
||||
self.update_discourse_data.cancel() # Stop the task when the cog is unloaded
|
||||
|
||||
@tasks.loop(hours=6)
|
||||
async def update_discourse_data(self):
|
||||
"""
|
||||
Periodically fetches and updates Discourse data for the bot.
|
||||
"""
|
||||
tag_name = "docs"
|
||||
logger.info("Fetching Discourse data...")
|
||||
cooked_posts = await fetch_cooked_posts(tag_name)
|
||||
if cooked_posts:
|
||||
combined_text = combine_posts_text(
|
||||
[{"cooked": post} for post in cooked_posts]
|
||||
)
|
||||
self.bot.ai_helper.set_discourse_data(combined_text)
|
||||
logger.info("Discourse data updated successfully.")
|
||||
else:
|
||||
logger.warning(f"No posts found for tag '{tag_name}'.")
|
||||
|
||||
@update_discourse_data.before_loop
|
||||
async def before_update_discourse_data(self):
|
||||
await self.bot.wait_until_ready()
|
||||
|
||||
|
||||
async def setup(bot):
|
||||
@tasks.loop(minutes=10)
|
||||
async def update_status():
|
||||
@@ -218,6 +187,5 @@ async def setup(bot):
|
||||
share_dementia_image.start()
|
||||
|
||||
await bot.add_cog(SteamSaleChecker(bot))
|
||||
await bot.add_cog(DiscourseUpdater(bot))
|
||||
|
||||
logger.info("Tasks extension loaded!")
|
||||
|
||||
@@ -3,6 +3,4 @@ requests
|
||||
audioop-lts
|
||||
python-dotenv
|
||||
pynacl
|
||||
google-genai
|
||||
beautifulsoup4
|
||||
pymongo
|
||||
|
||||
Reference in New Issue
Block a user