feat: use the docs as input for the AI

This commit is contained in:
2025-04-15 21:24:13 +02:00
parent 727d717b60
commit 83ed3f0d86
8 changed files with 260 additions and 13 deletions

View File

@@ -6,12 +6,125 @@ from google import genai
API_KEY = os.getenv("GOOGLE_API_KEY")
async def forward_to_google_api(prompt, image_object=None):
class DiscourseSummarizer:
def __init__(self):
self.model = "gemini-2.0-flash"
self.display_name = "alterware"
self.cache = None
self.ttl = "21600s"
self.discourse_data = None
if not API_KEY:
print("Google API key is not set. Please contact the administrator.")
return
self.client = genai.Client(api_key=API_KEY)
def set_discourse_data(self, topic_data):
"""
Sets the discourse data for the summarizer.
Args:
topic_data (str): The combined text of discourse posts.
"""
self.discourse_data = topic_data
def summarize_discourse_topic(self, topic_data, system_instruction=None):
"""
Creates a cache for the discourse topic data.
Args:
topic_data (str): The combined text of discourse posts.
system_instruction (str, optional): Custom system instruction for the model.
"""
self.cache = self.client.caches.create(
model=self.model,
config=types.CreateCachedContentConfig(
display_name=self.display_name,
system_instruction=system_instruction
or (
"You are a Discord chat bot named 'AlterWare' who helps users. You should limit your answers to be less than 2000 characters."
),
contents=[topic_data],
ttl=self.ttl,
),
)
print(f"Cached content created: {self.cache.name}")
def update_cache(self):
"""
Updates the cache TTL.
"""
if not self.cache:
raise RuntimeError(
"Cache has not been created. Run summarize_discourse_topic first."
)
self.client.caches.update(
name=self.cache.name, config=types.UpdateCachedContentConfig(ttl="21600s")
)
print("Cache updated.")
def ask(self, prompt):
"""
Generates a response using the cached content.
Args:
prompt (str): The user prompt.
Returns:
str: The generated response.
"""
if not self.cache:
raise RuntimeError(
"Cache has not been created. Run summarize_discourse_topic first."
)
response = self.client.models.generate_content(
model=self.model,
contents=prompt,
config=types.GenerateContentConfig(
max_output_tokens=400,
system_instruction="You are a Discord chat bot named 'AlterWare' who helps users. You should limit your answers to be less than 2000 characters.",
cached_content=self.cache.name,
),
)
return response.text
def ask_without_cache(self, prompt):
"""
Generates a response without using cached content, including discourse data.
Args:
prompt (str): The user prompt.
Returns:
str: The generated response.
"""
if not self.discourse_data:
raise RuntimeError(
"Discourse data has not been set. Use set_discourse_data first."
)
prompt.insert(0, self.discourse_data)
response = self.client.models.generate_content(
model=self.model,
contents=prompt,
config=types.GenerateContentConfig(
max_output_tokens=400,
system_instruction="You are a Discord chat bot named 'AlterWare' who helps users. You should limit your answers to be less than 2000 characters.",
),
)
return response.text
async def forward_to_google_api(prompt, bot, image_object=None):
"""
Forwards the message content and optional image object to a Google API.
Args:
prompt (discord.Message): The message object to forward.
bot (discord.Client): The Discord bot instance.
image_object (tuple, optional): A tuple containing the image URL and its MIME type (e.g., ("url", "image/jpeg")).
"""
if not API_KEY:
@@ -21,8 +134,6 @@ async def forward_to_google_api(prompt, image_object=None):
)
return
client = genai.Client(api_key=API_KEY)
input = [prompt.content]
if image_object:
try:
@@ -34,16 +145,9 @@ async def forward_to_google_api(prompt, image_object=None):
await prompt.reply(f"Failed to fetch the image", mention_author=True)
return
response = client.models.generate_content(
model="gemini-2.0-flash",
contents=input,
config=types.GenerateContentConfig(
max_output_tokens=400,
system_instruction="You are a Discord chat bot named 'AlterWare' who helps users. You should limit your answers to be less than 2000 characters.",
),
)
response = bot.ai_helper.ask_without_cache(input)
await prompt.reply(
response.text,
response,
mention_author=True,
)

View File

@@ -0,0 +1 @@
from .handle_request import fetch_cooked_posts, get_topics_by_tag, get_topics_by_id

View File

@@ -0,0 +1,103 @@
import requests
import os
from bs4 import BeautifulSoup
DISCOURSE_BASE_URL = os.getenv("DISCOURSE_BASE_URL")
API_KEY = os.getenv("DISCOURSE_API_KEY")
API_USERNAME = os.getenv("DISCOURSE_API_USERNAME")
headers = {"Api-Key": API_KEY, "Api-Username": API_USERNAME}
def get_topics_by_id(topic_id):
"""
Fetches a topic by its ID and returns the topic data.
Args:
topic_id (int): The ID of the topic to fetch.
Returns:
dict or None: The topic data if successful, otherwise None.
"""
response = requests.get(f"{DISCOURSE_BASE_URL}/t/{topic_id}.json", headers=headers)
if response.status_code == 200:
return response.json()
else:
print(
f"Error fetching topic {topic_id}: {response.status_code} - {response.text}"
)
return None
def get_topics_by_tag(tag_name):
"""
Fetches all topics with a specific tag and retrieves the cooked string from each post.
Args:
tag_name (str): The name of the tag to filter topics.
Returns:
list: A list of cooked strings from all posts in the topics.
"""
response = requests.get(
f"{DISCOURSE_BASE_URL}/tag/{tag_name}.json", headers=headers
)
if response.status_code == 200:
data = response.json()
topics = data.get("topic_list", {}).get("topics", [])
cooked_strings = []
for topic in topics:
topic_id = topic["id"]
topic_data = get_topics_by_id(topic_id)
if topic_data:
posts = topic_data.get("post_stream", {}).get("posts", [])
for post in posts:
cooked_strings.append(post.get("cooked", ""))
return cooked_strings
else:
print(
f"Error fetching topics with tag '{tag_name}': {response.status_code} - {response.text}"
)
return []
def fetch_cooked_posts(tag_name):
"""
Fetches cooked strings from posts with a specific tag.
Args:
tag_name (str): The name of the tag to filter topics.
Returns:
list: A list of cooked strings from posts with the specified tag.
"""
return get_topics_by_tag(tag_name)
def html_to_text(html_content):
"""
Cleans the provided HTML content and converts it to plain text.
Args:
html_content (str): The HTML content to clean.
Returns:
str: The cleaned plain text.
"""
soup = BeautifulSoup(html_content, "html.parser")
return soup.get_text(separator="\n").strip()
def combine_posts_text(posts):
"""
Combines the cooked content of all posts into a single plain text block.
Args:
posts (list): A list of posts, each containing a "cooked" HTML string.
Returns:
str: The combined plain text of all posts.
"""
return "\n\n".join([html_to_text(post["cooked"]) for post in posts])

View File

@@ -329,7 +329,7 @@ async def handle_message(message, bot):
image_object = (attachment.url, "image/png")
break
await forward_to_google_api(message, image_object)
await forward_to_google_api(message, bot, image_object)
return
# Too many mentions

View File

@@ -5,6 +5,7 @@ import discord
from discord.ext import tasks, commands
from bot.utils import aware_utcnow, fetch_api_data
from bot.discourse.handle_request import fetch_cooked_posts, combine_posts_text
from database import migrate_users_with_role
@@ -116,6 +117,36 @@ class SteamSaleChecker(commands.Cog):
await self.bot.wait_until_ready()
class DiscourseUpdater(commands.Cog):
def __init__(self, bot):
self.bot = bot
self.update_discourse_data.start() # Start the task when the cog is loaded
def cog_unload(self):
self.update_discourse_data.cancel() # Stop the task when the cog is unloaded
@tasks.loop(hours=6)
async def update_discourse_data(self):
"""
Periodically fetches and updates Discourse data for the bot.
"""
tag_name = "docs"
print("Fetching Discourse data...")
cooked_posts = fetch_cooked_posts(tag_name)
if cooked_posts:
combined_text = combine_posts_text(
[{"cooked": post} for post in cooked_posts]
)
self.bot.ai_helper.set_discourse_data(combined_text)
print("Discourse data updated successfully.")
else:
print(f"No posts found for tag '{tag_name}'.")
@update_discourse_data.before_loop
async def before_update_discourse_data(self):
await self.bot.wait_until_ready()
async def setup(bot):
@tasks.loop(minutes=10)
async def update_status():
@@ -152,5 +183,6 @@ async def setup(bot):
heat_death.start()
await bot.add_cog(SteamSaleChecker(bot))
await bot.add_cog(DiscourseUpdater(bot))
print("Tasks extension loaded!")