mirror of
				https://github.com/alterware/aw-bot.git
				synced 2025-10-26 14:15:54 +00:00 
			
		
		
		
	feat: use the docs as input for the AI
This commit is contained in:
		| @@ -6,12 +6,125 @@ from google import genai | ||||
| API_KEY = os.getenv("GOOGLE_API_KEY") | ||||
|  | ||||
|  | ||||
| async def forward_to_google_api(prompt, image_object=None): | ||||
class DiscourseSummarizer:
    """
    Answers prompts with a Gemini model, using Discourse posts as context.

    The Discourse text can be supplied in two ways: uploaded once as a
    server-side cache (summarize_discourse_topic + ask), or stored locally
    and sent with every request (set_discourse_data + ask_without_cache).
    """

    # System prompt used for cache creation and for uncached generation.
    DEFAULT_SYSTEM_INSTRUCTION = "You are a Discord chat bot named 'AlterWare' who helps users. You should limit your answers to be less than 2000 characters."

    def __init__(self):
        self.model = "gemini-2.0-flash"
        self.display_name = "alterware"
        self.cache = None
        self.ttl = "21600s"
        self.discourse_data = None
        # Always define the attribute so a missing key surfaces as a clear
        # RuntimeError later instead of an AttributeError.
        self.client = None

        if not API_KEY:
            print("Google API key is not set. Please contact the administrator.")
            return

        self.client = genai.Client(api_key=API_KEY)

    def _require_client(self):
        """
        Returns the API client, or raises if it was never configured.

        Raises:
            RuntimeError: If no client exists (the API key was not set).
        """
        client = getattr(self, "client", None)
        if client is None:
            raise RuntimeError(
                "Google API client is not configured. Set GOOGLE_API_KEY first."
            )
        return client

    def _build_contents(self, prompt):
        """
        Builds the request contents with the discourse data placed first.

        A new list is always returned so the caller's prompt is never mutated.

        Args:
            prompt (str or list): A single prompt item or a list of items.

        Returns:
            list: The discourse data followed by the prompt item(s).
        """
        if isinstance(prompt, (list, tuple)):
            parts = list(prompt)
        else:
            parts = [prompt]
        return [self.discourse_data, *parts]

    def set_discourse_data(self, topic_data):
        """
        Sets the discourse data for the summarizer.

        Args:
            topic_data (str): The combined text of discourse posts.
        """
        self.discourse_data = topic_data

    def summarize_discourse_topic(self, topic_data, system_instruction=None):
        """
        Creates a cache for the discourse topic data.

        Args:
            topic_data (str): The combined text of discourse posts.
            system_instruction (str, optional): Custom system instruction for the model.

        Raises:
            RuntimeError: If the API client is not configured.
        """
        client = self._require_client()
        self.cache = client.caches.create(
            model=self.model,
            config=types.CreateCachedContentConfig(
                display_name=self.display_name,
                system_instruction=system_instruction
                or self.DEFAULT_SYSTEM_INSTRUCTION,
                contents=[topic_data],
                ttl=self.ttl,
            ),
        )
        print(f"Cached content created: {self.cache.name}")

    def update_cache(self):
        """
        Updates the cache TTL.

        Raises:
            RuntimeError: If the cache has not been created yet, or the API
                client is not configured.
        """
        if not self.cache:
            raise RuntimeError(
                "Cache has not been created. Run summarize_discourse_topic first."
            )

        client = self._require_client()
        # Reuse the instance TTL so creation and refresh cannot drift apart.
        client.caches.update(
            name=self.cache.name, config=types.UpdateCachedContentConfig(ttl=self.ttl)
        )
        print("Cache updated.")

    def ask(self, prompt):
        """
        Generates a response using the cached content.

        Args:
            prompt (str): The user prompt.

        Returns:
            str: The generated response.

        Raises:
            RuntimeError: If the cache has not been created yet, or the API
                client is not configured.
        """
        if not self.cache:
            raise RuntimeError(
                "Cache has not been created. Run summarize_discourse_topic first."
            )

        client = self._require_client()
        # The system instruction is stored in the cache at creation time; the
        # API rejects requests that set it again alongside cached_content.
        response = client.models.generate_content(
            model=self.model,
            contents=prompt,
            config=types.GenerateContentConfig(
                max_output_tokens=400,
                cached_content=self.cache.name,
            ),
        )
        return response.text

    def ask_without_cache(self, prompt):
        """
        Generates a response without using cached content, including discourse data.

        Args:
            prompt (str or list): The user prompt, either a single item or a
                list of content items. It is not modified.

        Returns:
            str: The generated response.

        Raises:
            RuntimeError: If the discourse data has not been set, or the API
                client is not configured.
        """
        if not self.discourse_data:
            raise RuntimeError(
                "Discourse data has not been set. Use set_discourse_data first."
            )

        client = self._require_client()
        response = client.models.generate_content(
            model=self.model,
            contents=self._build_contents(prompt),
            config=types.GenerateContentConfig(
                max_output_tokens=400,
                system_instruction=self.DEFAULT_SYSTEM_INSTRUCTION,
            ),
        )
        return response.text
|  | ||||
|  | ||||
| async def forward_to_google_api(prompt, bot, image_object=None): | ||||
|     """ | ||||
|     Forwards the message content and optional image object to a Google API. | ||||
|  | ||||
|     Args: | ||||
|         prompt (discord.Message): The message object to forward. | ||||
|         bot (discord.Client): The Discord bot instance. | ||||
|         image_object (tuple, optional): A tuple containing the image URL and its MIME type (e.g., ("url", "image/jpeg")). | ||||
|     """ | ||||
|     if not API_KEY: | ||||
| @@ -21,8 +134,6 @@ async def forward_to_google_api(prompt, image_object=None): | ||||
|         ) | ||||
|         return | ||||
|  | ||||
|     client = genai.Client(api_key=API_KEY) | ||||
|  | ||||
|     input = [prompt.content] | ||||
|     if image_object: | ||||
|         try: | ||||
| @@ -34,16 +145,9 @@ async def forward_to_google_api(prompt, image_object=None): | ||||
|             await prompt.reply(f"Failed to fetch the image", mention_author=True) | ||||
|             return | ||||
|  | ||||
|     response = client.models.generate_content( | ||||
|         model="gemini-2.0-flash", | ||||
|         contents=input, | ||||
|         config=types.GenerateContentConfig( | ||||
|             max_output_tokens=400, | ||||
|             system_instruction="You are a Discord chat bot named 'AlterWare' who helps users. You should limit your answers to be less than 2000 characters.", | ||||
|         ), | ||||
|     ) | ||||
|     response = bot.ai_helper.ask_without_cache(input) | ||||
|  | ||||
|     await prompt.reply( | ||||
|         response.text, | ||||
|         response, | ||||
|         mention_author=True, | ||||
|     ) | ||||
|   | ||||
							
								
								
									
										1
									
								
								bot/discourse/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								bot/discourse/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| from .handle_request import fetch_cooked_posts, get_topics_by_tag, get_topics_by_id | ||||
							
								
								
									
										103
									
								
								bot/discourse/handle_request.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										103
									
								
								bot/discourse/handle_request.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,103 @@ | ||||
| import requests | ||||
| import os | ||||
|  | ||||
| from bs4 import BeautifulSoup | ||||
|  | ||||
| DISCOURSE_BASE_URL = os.getenv("DISCOURSE_BASE_URL") | ||||
| API_KEY = os.getenv("DISCOURSE_API_KEY") | ||||
| API_USERNAME = os.getenv("DISCOURSE_API_USERNAME") | ||||
|  | ||||
| headers = {"Api-Key": API_KEY, "Api-Username": API_USERNAME} | ||||
|  | ||||
|  | ||||
def get_topics_by_id(topic_id, timeout=10):
    """
    Fetches a topic by its ID and returns the topic data.

    Args:
        topic_id (int): The ID of the topic to fetch.
        timeout (float, optional): Seconds to wait for the server before
            giving up. Without a timeout the request could block forever.

    Returns:
        dict or None: The topic data if successful, otherwise None.
    """
    url = f"{DISCOURSE_BASE_URL}/t/{topic_id}.json"
    try:
        response = requests.get(url, headers=headers, timeout=timeout)
    except requests.RequestException as error:
        # Connection failures and timeouts follow the same "return None"
        # contract as a non-200 response.
        print(f"Error fetching topic {topic_id}: {error}")
        return None

    if response.status_code != 200:
        print(
            f"Error fetching topic {topic_id}: {response.status_code} - {response.text}"
        )
        return None

    return response.json()
|  | ||||
|  | ||||
def get_topics_by_tag(tag_name, timeout=10):
    """
    Fetches all topics with a specific tag and retrieves the cooked string from each post.

    Args:
        tag_name (str): The name of the tag to filter topics.
        timeout (float, optional): Seconds to wait for the tag listing
            request before giving up.

    Returns:
        list: A list of cooked strings from all posts in the topics. Empty
            if the tag listing could not be fetched.
    """
    url = f"{DISCOURSE_BASE_URL}/tag/{tag_name}.json"
    try:
        response = requests.get(url, headers=headers, timeout=timeout)
    except requests.RequestException as error:
        # Treat network failures like any other failed fetch: report and
        # return the documented empty result.
        print(f"Error fetching topics with tag '{tag_name}': {error}")
        return []

    if response.status_code != 200:
        print(
            f"Error fetching topics with tag '{tag_name}': {response.status_code} - {response.text}"
        )
        return []

    topics = response.json().get("topic_list", {}).get("topics", [])
    cooked_strings = []

    for topic in topics:
        topic_data = get_topics_by_id(topic["id"])
        if not topic_data:
            # Skip topics that failed to load; the rest are still useful.
            continue
        posts = topic_data.get("post_stream", {}).get("posts", [])
        cooked_strings.extend(post.get("cooked", "") for post in posts)

    return cooked_strings
|  | ||||
|  | ||||
def fetch_cooked_posts(tag_name):
    """
    Returns the cooked strings of every post in topics carrying the given tag.

    Args:
        tag_name (str): The tag used to select topics.

    Returns:
        list: Whatever get_topics_by_tag returns for this tag.
    """
    cooked_strings = get_topics_by_tag(tag_name)
    return cooked_strings
|  | ||||
|  | ||||
def html_to_text(html_content):
    """
    Converts an HTML fragment to plain text.

    Args:
        html_content (str): The HTML markup to convert.

    Returns:
        str: The text content, with text nodes joined by newlines and
            surrounding whitespace stripped.
    """
    parsed = BeautifulSoup(html_content, "html.parser")
    plain_text = parsed.get_text(separator="\n")
    return plain_text.strip()
|  | ||||
|  | ||||
def combine_posts_text(posts):
    """
    Merges the cooked HTML of several posts into one plain text block.

    Args:
        posts (list): Posts, each a mapping with a "cooked" HTML string.

    Returns:
        str: The plain text of every post, separated by blank lines.
    """
    plain_texts = []
    for entry in posts:
        plain_texts.append(html_to_text(entry["cooked"]))
    return "\n\n".join(plain_texts)
| @@ -329,7 +329,7 @@ async def handle_message(message, bot): | ||||
|                     image_object = (attachment.url, "image/png") | ||||
|                     break | ||||
|  | ||||
|             await forward_to_google_api(message, image_object) | ||||
|             await forward_to_google_api(message, bot, image_object) | ||||
|             return | ||||
|  | ||||
|     # Too many mentions | ||||
|   | ||||
							
								
								
									
										32
									
								
								bot/tasks.py
									
									
									
									
									
								
							
							
						
						
									
										32
									
								
								bot/tasks.py
									
									
									
									
									
								
							| @@ -5,6 +5,7 @@ import discord | ||||
| from discord.ext import tasks, commands | ||||
|  | ||||
| from bot.utils import aware_utcnow, fetch_api_data | ||||
| from bot.discourse.handle_request import fetch_cooked_posts, combine_posts_text | ||||
|  | ||||
| from database import migrate_users_with_role | ||||
|  | ||||
| @@ -116,6 +117,36 @@ class SteamSaleChecker(commands.Cog): | ||||
|         await self.bot.wait_until_ready() | ||||
|  | ||||
|  | ||||
class DiscourseUpdater(commands.Cog):
    """
    Cog that refreshes the bot's Discourse data on a fixed schedule.

    Every six hours it fetches the posts tagged "docs", converts them to
    plain text and hands the result to bot.ai_helper.
    """

    def __init__(self, bot):
        self.bot = bot
        self.update_discourse_data.start()  # Start the task when the cog is loaded

    def cog_unload(self):
        self.update_discourse_data.cancel()  # Stop the task when the cog is unloaded

    @tasks.loop(hours=6)
    async def update_discourse_data(self):
        """
        Periodically fetches and updates Discourse data for the bot.
        """
        # Local import keeps this block self-contained; the module's own
        # import list is not fully visible from here.
        import asyncio

        tag_name = "docs"
        print("Fetching Discourse data...")

        # fetch_cooked_posts performs blocking HTTP requests. Run it in the
        # default executor so the event loop stays responsive meanwhile.
        loop = asyncio.get_running_loop()
        cooked_posts = await loop.run_in_executor(None, fetch_cooked_posts, tag_name)

        if not cooked_posts:
            print(f"No posts found for tag '{tag_name}'.")
            return

        # combine_posts_text expects mappings with a "cooked" key.
        combined_text = combine_posts_text([{"cooked": post} for post in cooked_posts])
        self.bot.ai_helper.set_discourse_data(combined_text)
        print("Discourse data updated successfully.")

    @update_discourse_data.before_loop
    async def before_update_discourse_data(self):
        # Do not run the first iteration until the bot is ready.
        await self.bot.wait_until_ready()
|  | ||||
|  | ||||
| async def setup(bot): | ||||
|     @tasks.loop(minutes=10) | ||||
|     async def update_status(): | ||||
| @@ -152,5 +183,6 @@ async def setup(bot): | ||||
|     heat_death.start() | ||||
|  | ||||
|     await bot.add_cog(SteamSaleChecker(bot)) | ||||
|     await bot.add_cog(DiscourseUpdater(bot)) | ||||
|  | ||||
|     print("Tasks extension loaded!") | ||||
|   | ||||
		Reference in New Issue
	
	Block a user