From 83ed3f0d8644767c8055a61bea08d6b7fcfbd4ac Mon Sep 17 00:00:00 2001 From: diamante0018 Date: Tue, 15 Apr 2025 21:24:13 +0200 Subject: [PATCH] feat: use the docs as input for the AI --- Dockerfile | 3 + aw.py | 3 + bot/ai/handle_request.py | 128 +++++++++++++++++++++++--- bot/discourse/__init__.py | 1 + bot/discourse/handle_request.py | 103 +++++++++++++++++++++ bot/events_handlers/message_events.py | 2 +- bot/tasks.py | 32 +++++++ requirements.txt | 1 + 8 files changed, 260 insertions(+), 13 deletions(-) create mode 100644 bot/discourse/__init__.py create mode 100644 bot/discourse/handle_request.py diff --git a/Dockerfile b/Dockerfile index b65846c..3e4b05f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -18,6 +18,9 @@ COPY LICENSE . ENV BOT_TOKEN="" ENV GOOGLE_API_KEY="" +ENV DISCOURSE_API_KEY="" +ENV DISCOURSE_BASE_URL="" +ENV DISCOURSE_USERNAME="" # Where the database will be stored ENV BOT_DATA_DIR="" diff --git a/aw.py b/aw.py index 4774806..42219f1 100644 --- a/aw.py +++ b/aw.py @@ -5,6 +5,7 @@ import discord from discord.ext import commands from database import initialize_db +from bot.ai.handle_request import DiscourseSummarizer GUILD_ID = 1110531063161299074 BOT_LOG = 1112049391482703873 @@ -19,6 +20,8 @@ load_dotenv(override=True) initialize_db() +bot.ai_helper = DiscourseSummarizer() + @bot.event async def on_ready(): diff --git a/bot/ai/handle_request.py b/bot/ai/handle_request.py index acdfe32..351771b 100644 --- a/bot/ai/handle_request.py +++ b/bot/ai/handle_request.py @@ -6,12 +6,125 @@ from google import genai API_KEY = os.getenv("GOOGLE_API_KEY") -async def forward_to_google_api(prompt, image_object=None): +class DiscourseSummarizer: + def __init__(self): + self.model = "gemini-2.0-flash" + self.display_name = "alterware" + self.cache = None + self.ttl = "21600s" + self.discourse_data = None + + if not API_KEY: + print("Google API key is not set. Please contact the administrator.") + return + + self.client = genai.Client(api_key=API_KEY) + + def set_discourse_data(self, topic_data): + """ + Sets the discourse data for the summarizer. + + Args: + topic_data (str): The combined text of discourse posts. + """ + self.discourse_data = topic_data + + def summarize_discourse_topic(self, topic_data, system_instruction=None): + """ + Creates a cache for the discourse topic data. + + Args: + topic_data (str): The combined text of discourse posts. + system_instruction (str, optional): Custom system instruction for the model. + """ + self.cache = self.client.caches.create( + model=self.model, + config=types.CreateCachedContentConfig( + display_name=self.display_name, + system_instruction=system_instruction + or ( + "You are a Discord chat bot named 'AlterWare' who helps users. You should limit your answers to be less than 2000 characters." + ), + contents=[topic_data], + ttl=self.ttl, + ), + ) + print(f"Cached content created: {self.cache.name}") + + def update_cache(self): + """ + Updates the cache TTL. + """ + if not self.cache: + raise RuntimeError( + "Cache has not been created. Run summarize_discourse_topic first." + ) + + self.client.caches.update( + name=self.cache.name, config=types.UpdateCachedContentConfig(ttl="21600s") + ) + print("Cache updated.") + + def ask(self, prompt): + """ + Generates a response using the cached content. + + Args: + prompt (str): The user prompt. + + Returns: + str: The generated response. + """ + if not self.cache: + raise RuntimeError( + "Cache has not been created. Run summarize_discourse_topic first." + ) + + response = self.client.models.generate_content( + model=self.model, + contents=prompt, + config=types.GenerateContentConfig( + max_output_tokens=400, + system_instruction="You are a Discord chat bot named 'AlterWare' who helps users. You should limit your answers to be less than 2000 characters.", + cached_content=self.cache.name, + ), + ) + return response.text + + def ask_without_cache(self, prompt): + """ + Generates a response without using cached content, including discourse data. + + Args: + prompt (str): The user prompt. + + Returns: + str: The generated response. + """ + if not self.discourse_data: + raise RuntimeError( + "Discourse data has not been set. Use set_discourse_data first." + ) + + prompt.insert(0, self.discourse_data) + response = self.client.models.generate_content( + model=self.model, + contents=prompt, + config=types.GenerateContentConfig( + max_output_tokens=400, + system_instruction="You are a Discord chat bot named 'AlterWare' who helps users. You should limit your answers to be less than 2000 characters.", + ), + ) + return response.text + + +async def forward_to_google_api(prompt, bot, image_object=None): """ Forwards the message content and optional image object to a Google API. Args: prompt (discord.Message): The message object to forward. + bot (discord.Client): The Discord bot instance. image_object (tuple, optional): A tuple containing the image URL and its MIME type (e.g., ("url", "image/jpeg")). """ if not API_KEY: @@ -21,8 +134,6 @@ async def forward_to_google_api(prompt, image_object=None): ) return - client = genai.Client(api_key=API_KEY) - input = [prompt.content] if image_object: try: @@ -34,16 +145,9 @@ async def forward_to_google_api(prompt, image_object=None): await prompt.reply(f"Failed to fetch the image", mention_author=True) return - response = client.models.generate_content( - model="gemini-2.0-flash", - contents=input, - config=types.GenerateContentConfig( - max_output_tokens=400, - system_instruction="You are a Discord chat bot named 'AlterWare' who helps users. You should limit your answers to be less than 2000 characters.", - ), - ) + response = bot.ai_helper.ask_without_cache(input) await prompt.reply( - response.text, + response, mention_author=True, ) diff --git a/bot/discourse/__init__.py b/bot/discourse/__init__.py new file mode 100644 index 0000000..f08ff3d --- /dev/null +++ b/bot/discourse/__init__.py @@ -0,0 +1 @@ +from .handle_request import fetch_cooked_posts, get_topics_by_tag, get_topics_by_id diff --git a/bot/discourse/handle_request.py b/bot/discourse/handle_request.py new file mode 100644 index 0000000..51fd166 --- /dev/null +++ b/bot/discourse/handle_request.py @@ -0,0 +1,103 @@ +import requests +import os + +from bs4 import BeautifulSoup + +DISCOURSE_BASE_URL = os.getenv("DISCOURSE_BASE_URL") +API_KEY = os.getenv("DISCOURSE_API_KEY") +API_USERNAME = os.getenv("DISCOURSE_API_USERNAME") + +headers = {"Api-Key": API_KEY, "Api-Username": API_USERNAME} + + +def get_topics_by_id(topic_id): + """ + Fetches a topic by its ID and returns the topic data. + + Args: + topic_id (int): The ID of the topic to fetch. + + Returns: + dict or None: The topic data if successful, otherwise None. + """ + response = requests.get(f"{DISCOURSE_BASE_URL}/t/{topic_id}.json", headers=headers) + if response.status_code == 200: + return response.json() + else: + print( + f"Error fetching topic {topic_id}: {response.status_code} - {response.text}" + ) + return None + + +def get_topics_by_tag(tag_name): + """ + Fetches all topics with a specific tag and retrieves the cooked string from each post. + + Args: + tag_name (str): The name of the tag to filter topics. + + Returns: + list: A list of cooked strings from all posts in the topics. + """ + response = requests.get( + f"{DISCOURSE_BASE_URL}/tag/{tag_name}.json", headers=headers + ) + if response.status_code == 200: + data = response.json() + topics = data.get("topic_list", {}).get("topics", []) + cooked_strings = [] + + for topic in topics: + topic_id = topic["id"] + topic_data = get_topics_by_id(topic_id) + if topic_data: + posts = topic_data.get("post_stream", {}).get("posts", []) + for post in posts: + cooked_strings.append(post.get("cooked", "")) + return cooked_strings + else: + print( + f"Error fetching topics with tag '{tag_name}': {response.status_code} - {response.text}" + ) + return [] + + +def fetch_cooked_posts(tag_name): + """ + Fetches cooked strings from posts with a specific tag. + + Args: + tag_name (str): The name of the tag to filter topics. + + Returns: + list: A list of cooked strings from posts with the specified tag. + """ + return get_topics_by_tag(tag_name) + + +def html_to_text(html_content): + """ + Cleans the provided HTML content and converts it to plain text. + + Args: + html_content (str): The HTML content to clean. + + Returns: + str: The cleaned plain text. + """ + soup = BeautifulSoup(html_content, "html.parser") + return soup.get_text(separator="\n").strip() + + +def combine_posts_text(posts): + """ + Combines the cooked content of all posts into a single plain text block. + + Args: + posts (list): A list of posts, each containing a "cooked" HTML string. + + Returns: + str: The combined plain text of all posts. + """ + return "\n\n".join([html_to_text(post["cooked"]) for post in posts]) diff --git a/bot/events_handlers/message_events.py b/bot/events_handlers/message_events.py index 6fbc13b..9e55c19 100644 --- a/bot/events_handlers/message_events.py +++ b/bot/events_handlers/message_events.py @@ -329,7 +329,7 @@ async def handle_message(message, bot): image_object = (attachment.url, "image/png") break - await forward_to_google_api(message, image_object) + await forward_to_google_api(message, bot, image_object) return # Too many mentions diff --git a/bot/tasks.py b/bot/tasks.py index 7b123d4..a1e14b1 100644 --- a/bot/tasks.py +++ b/bot/tasks.py @@ -5,6 +5,7 @@ import discord from discord.ext import tasks, commands from bot.utils import aware_utcnow, fetch_api_data +from bot.discourse.handle_request import fetch_cooked_posts, combine_posts_text from database import migrate_users_with_role @@ -116,6 +117,36 @@ class SteamSaleChecker(commands.Cog): await self.bot.wait_until_ready() +class DiscourseUpdater(commands.Cog): + def __init__(self, bot): + self.bot = bot + self.update_discourse_data.start() # Start the task when the cog is loaded + + def cog_unload(self): + self.update_discourse_data.cancel() # Stop the task when the cog is unloaded + + @tasks.loop(hours=6) + async def update_discourse_data(self): + """ + Periodically fetches and updates Discourse data for the bot. + """ + tag_name = "docs" + print("Fetching Discourse data...") + cooked_posts = fetch_cooked_posts(tag_name) + if cooked_posts: + combined_text = combine_posts_text( + [{"cooked": post} for post in cooked_posts] + ) + self.bot.ai_helper.set_discourse_data(combined_text) + print("Discourse data updated successfully.") + else: + print(f"No posts found for tag '{tag_name}'.") + + @update_discourse_data.before_loop + async def before_update_discourse_data(self): + await self.bot.wait_until_ready() + + async def setup(bot): @tasks.loop(minutes=10) async def update_status(): @@ -152,5 +183,6 @@ async def setup(bot): heat_death.start() await bot.add_cog(SteamSaleChecker(bot)) + await bot.add_cog(DiscourseUpdater(bot)) print("Tasks extension loaded!") diff --git a/requirements.txt b/requirements.txt index 43e325d..9c3c1b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ audioop-lts python-dotenv pynacl google-genai +beautifulsoup4