-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
apps/ai_reports: connect with external API
- add a celery task which sends a comment to the XAI server and saves the response as AiReport - connect comment post_save signal with celery task - rename category to label to be in line with the XAI response - change AiReport fields to JSONField for now - add tests
- Loading branch information
Showing
15 changed files
with
205 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
33 changes: 33 additions & 0 deletions
33
apps/ai_reports/migrations/0003_remove_aireport_category_aireport_label_and_more.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Generated by Django 4.2.13 on 2024-05-30 09:19 | ||
|
||
from django.db import migrations, models | ||
|
||
|
||
class Migration(migrations.Migration): | ||
|
||
dependencies = [ | ||
("ai_reports", "0002_aireport_show_in_discussion"), | ||
] | ||
|
||
operations = [ | ||
migrations.RemoveField( | ||
model_name="aireport", | ||
name="category", | ||
), | ||
migrations.AddField( | ||
model_name="aireport", | ||
name="label", | ||
field=models.JSONField(default=""), | ||
preserve_default=False, | ||
), | ||
migrations.AlterField( | ||
model_name="aireport", | ||
name="confidence", | ||
field=models.JSONField(), | ||
), | ||
migrations.AlterField( | ||
model_name="aireport", | ||
name="explanation", | ||
field=models.JSONField(), | ||
), | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,24 @@ | ||
import json | ||
|
||
from rest_framework import serializers | ||
|
||
from apps.ai_reports.models import AiReport | ||
|
||
|
||
class AiReportSerializer(serializers.ModelSerializer): | ||
explanation = serializers.SerializerMethodField() | ||
|
||
class Meta: | ||
model = AiReport | ||
fields = ( | ||
"category", | ||
"label", | ||
"confidence", | ||
"explanation", | ||
"is_pending", | ||
"comment", | ||
"show_in_discussion", | ||
) | ||
|
||
# FIXME: remove once frontend knows what to do with this | ||
def get_explanation(self, ai_report: AiReport) -> str: | ||
return json.dumps(ai_report.explanation) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import logging | ||
|
||
import httpx | ||
from django.conf import settings | ||
from django.db.models import signals | ||
from django.dispatch import receiver | ||
|
||
from adhocracy4.comments.models import Comment | ||
from apps.ai_reports import tasks | ||
|
||
client = httpx.Client() | ||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@receiver(signals.post_save, sender=Comment) | ||
def get_ai_classification(sender, instance, created, update_fields, **kwargs): | ||
if getattr(settings, "XAI_ENABLE", False): | ||
if getattr(settings, "XAI_API_URL", None): | ||
comment_text_changed = getattr(instance, "_former_comment") != getattr( | ||
instance, "comment" | ||
) | ||
if created or comment_text_changed: | ||
# FIXME: use delay_on_commit() once updated to celery 5.x | ||
tasks.get_classification_for_comment.delay(instance.pk) | ||
else: | ||
logger.error("no xai api url provided") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
import ast | ||
import logging | ||
|
||
import backoff | ||
import httpx | ||
from celery import shared_task | ||
from django.conf import settings | ||
|
||
from adhocracy4.comments.models import Comment | ||
from apps.ai_reports.models import AiReport | ||
|
||
client = httpx.Client() | ||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@shared_task | ||
def get_classification_for_comment(comment_pk: int) -> None: | ||
try: | ||
comment = Comment.objects.get(pk=comment_pk) | ||
response = call_ai_api(comment=comment) | ||
if response.status_code == 200: | ||
extract_and_save_ai_classifications(comment=comment, report=response.json()) | ||
except httpx.HTTPError as e: | ||
logger.error("Error connecting to %s: %s", settings.AI_API_URL, str(e)) | ||
|
||
|
||
def skip_retry(e: Exception) -> bool: | ||
if isinstance(e, httpx.HTTPStatusError): | ||
return 400 <= e.response.status_code < 500 | ||
return False | ||
|
||
|
||
@backoff.on_exception( | ||
backoff.expo, httpx.HTTPError, max_tries=4, factor=2, giveup=skip_retry | ||
) | ||
def call_ai_api(comment: Comment) -> httpx.Response: | ||
response = client.post( | ||
settings.XAI_API_URL, | ||
json={"comment": comment.comment}, | ||
headers={"Accept": "application/json", "Content-Type": "application/json"}, | ||
timeout=25.0, | ||
) | ||
response.raise_for_status() | ||
return response | ||
|
||
|
||
def extract_and_save_ai_classifications(comment: Comment, report: dict) -> None: | ||
# FIXME: the data returned from the api is not actually valid json, so we need | ||
# to use ast to explicitly convert it. This should be fixed on their side. | ||
confidence = ast.literal_eval((report["confidence"])) | ||
label = ast.literal_eval(report["label"]) | ||
explanation = ast.literal_eval(report["explanation"]) | ||
|
||
ai_report = AiReport( | ||
comment=comment, confidence=confidence, label=label, explanation=explanation | ||
) | ||
ai_report.save() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,3 +13,4 @@ pytest==7.3.2 | |
pytest-cov==4.1.0 | ||
pytest-django==4.5.2 | ||
pytest-factoryboy==2.5.1 | ||
pytest-mock==3.14.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
# requirements needed in this fork, but not a+ | ||
backoff==2.2.1 | ||
httpx==0.27.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
from pytest_factoryboy import register | ||
|
||
from tests.ideas import factories as idea_factories | ||
|
||
from .factories import AiReportFactory | ||
|
||
register(AiReportFactory) | ||
register(idea_factories.IdeaFactory) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import pytest | ||
from django.db.models import signals | ||
from django.test import override_settings | ||
|
||
from adhocracy4.comments.models import Comment | ||
from apps.ai_reports.signals import get_ai_classification | ||
|
||
|
||
@pytest.mark.django_db | ||
def test_comment_not_sent_to_xai_no_url(mocker, idea, comment_factory, caplog): | ||
task_mock = mocker.patch( | ||
"apps.ai_reports.tasks.get_classification_for_comment" ".delay" | ||
) | ||
assert get_ai_classification in signals.post_save._live_receivers(Comment) | ||
comment_factory(content_object=idea, comment="lala") | ||
task_mock.assert_not_called() | ||
assert len(caplog.records) == 1 | ||
assert "no xai api url provided" in str(caplog.records[-1]) | ||
|
||
|
||
@override_settings(XAI_API_URL="https://liqd.net") | ||
@pytest.mark.django_db | ||
def test_comment_sent_to_xai_on_comment_text_change(mocker, idea, comment_factory): | ||
task_mock = mocker.patch( | ||
"apps.ai_reports.tasks.get_classification_for_comment" ".delay" | ||
) | ||
|
||
assert get_ai_classification in signals.post_save._live_receivers(Comment) | ||
comment = comment_factory(content_object=idea, comment="lala") | ||
task_mock.assert_called_once_with(comment.pk) | ||
task_mock.reset_mock() | ||
|
||
comment.comment = "modified comment" | ||
comment.save() | ||
task_mock.assert_called_once_with(comment.pk) | ||
|
||
|
||
@override_settings(XAI_API_URL="https://liqd.net") | ||
@pytest.mark.django_db | ||
def test_comment_not_sent_to_xai_without_comment_text_change( | ||
mocker, idea, comment_factory, caplog | ||
): | ||
task_mock = mocker.patch( | ||
"apps.ai_reports.tasks.get_classification_for_comment" ".delay" | ||
) | ||
assert get_ai_classification in signals.post_save._live_receivers(Comment) | ||
comment = comment_factory(content_object=idea, comment="lala") | ||
task_mock.assert_called_once_with(comment.pk) | ||
task_mock.reset_mock() | ||
|
||
comment.is_blocked = True | ||
comment.save() | ||
task_mock.assert_not_called() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters