mirror of
https://github.com/alterware/aw-bot.git
synced 2025-10-26 22:16:58 +00:00
feat: use the docs as input for the AI
This commit is contained in:
@@ -6,12 +6,125 @@ from google import genai
|
||||
API_KEY = os.getenv("GOOGLE_API_KEY")
|
||||
|
||||
|
||||
async def forward_to_google_api(prompt, image_object=None):
|
||||
class DiscourseSummarizer:
|
||||
def __init__(self):
|
||||
self.model = "gemini-2.0-flash"
|
||||
self.display_name = "alterware"
|
||||
self.cache = None
|
||||
self.ttl = "21600s"
|
||||
self.discourse_data = None
|
||||
|
||||
if not API_KEY:
|
||||
print("Google API key is not set. Please contact the administrator.")
|
||||
return
|
||||
|
||||
self.client = genai.Client(api_key=API_KEY)
|
||||
|
||||
def set_discourse_data(self, topic_data):
|
||||
"""
|
||||
Sets the discourse data for the summarizer.
|
||||
|
||||
Args:
|
||||
topic_data (str): The combined text of discourse posts.
|
||||
"""
|
||||
self.discourse_data = topic_data
|
||||
|
||||
def summarize_discourse_topic(self, topic_data, system_instruction=None):
|
||||
"""
|
||||
Creates a cache for the discourse topic data.
|
||||
|
||||
Args:
|
||||
topic_data (str): The combined text of discourse posts.
|
||||
system_instruction (str, optional): Custom system instruction for the model.
|
||||
"""
|
||||
self.cache = self.client.caches.create(
|
||||
model=self.model,
|
||||
config=types.CreateCachedContentConfig(
|
||||
display_name=self.display_name,
|
||||
system_instruction=system_instruction
|
||||
or (
|
||||
"You are a Discord chat bot named 'AlterWare' who helps users. You should limit your answers to be less than 2000 characters."
|
||||
),
|
||||
contents=[topic_data],
|
||||
ttl=self.ttl,
|
||||
),
|
||||
)
|
||||
print(f"Cached content created: {self.cache.name}")
|
||||
|
||||
def update_cache(self):
|
||||
"""
|
||||
Updates the cache TTL.
|
||||
"""
|
||||
if not self.cache:
|
||||
raise RuntimeError(
|
||||
"Cache has not been created. Run summarize_discourse_topic first."
|
||||
)
|
||||
|
||||
self.client.caches.update(
|
||||
name=self.cache.name, config=types.UpdateCachedContentConfig(ttl="21600s")
|
||||
)
|
||||
print("Cache updated.")
|
||||
|
||||
def ask(self, prompt):
|
||||
"""
|
||||
Generates a response using the cached content.
|
||||
|
||||
Args:
|
||||
prompt (str): The user prompt.
|
||||
|
||||
Returns:
|
||||
str: The generated response.
|
||||
"""
|
||||
if not self.cache:
|
||||
raise RuntimeError(
|
||||
"Cache has not been created. Run summarize_discourse_topic first."
|
||||
)
|
||||
|
||||
response = self.client.models.generate_content(
|
||||
model=self.model,
|
||||
contents=prompt,
|
||||
config=types.GenerateContentConfig(
|
||||
max_output_tokens=400,
|
||||
system_instruction="You are a Discord chat bot named 'AlterWare' who helps users. You should limit your answers to be less than 2000 characters.",
|
||||
cached_content=self.cache.name,
|
||||
),
|
||||
)
|
||||
return response.text
|
||||
|
||||
def ask_without_cache(self, prompt):
|
||||
"""
|
||||
Generates a response without using cached content, including discourse data.
|
||||
|
||||
Args:
|
||||
prompt (str): The user prompt.
|
||||
|
||||
Returns:
|
||||
str: The generated response.
|
||||
"""
|
||||
if not self.discourse_data:
|
||||
raise RuntimeError(
|
||||
"Discourse data has not been set. Use set_discourse_data first."
|
||||
)
|
||||
|
||||
prompt.insert(0, self.discourse_data)
|
||||
response = self.client.models.generate_content(
|
||||
model=self.model,
|
||||
contents=prompt,
|
||||
config=types.GenerateContentConfig(
|
||||
max_output_tokens=400,
|
||||
system_instruction="You are a Discord chat bot named 'AlterWare' who helps users. You should limit your answers to be less than 2000 characters.",
|
||||
),
|
||||
)
|
||||
return response.text
|
||||
|
||||
|
||||
async def forward_to_google_api(prompt, bot, image_object=None):
|
||||
"""
|
||||
Forwards the message content and optional image object to a Google API.
|
||||
|
||||
Args:
|
||||
prompt (discord.Message): The message object to forward.
|
||||
bot (discord.Client): The Discord bot instance.
|
||||
image_object (tuple, optional): A tuple containing the image URL and its MIME type (e.g., ("url", "image/jpeg")).
|
||||
"""
|
||||
if not API_KEY:
|
||||
@@ -21,8 +134,6 @@ async def forward_to_google_api(prompt, image_object=None):
|
||||
)
|
||||
return
|
||||
|
||||
client = genai.Client(api_key=API_KEY)
|
||||
|
||||
input = [prompt.content]
|
||||
if image_object:
|
||||
try:
|
||||
@@ -34,16 +145,9 @@ async def forward_to_google_api(prompt, image_object=None):
|
||||
await prompt.reply(f"Failed to fetch the image", mention_author=True)
|
||||
return
|
||||
|
||||
response = client.models.generate_content(
|
||||
model="gemini-2.0-flash",
|
||||
contents=input,
|
||||
config=types.GenerateContentConfig(
|
||||
max_output_tokens=400,
|
||||
system_instruction="You are a Discord chat bot named 'AlterWare' who helps users. You should limit your answers to be less than 2000 characters.",
|
||||
),
|
||||
)
|
||||
response = bot.ai_helper.ask_without_cache(input)
|
||||
|
||||
await prompt.reply(
|
||||
response.text,
|
||||
response,
|
||||
mention_author=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user