Skip to content

Commit

Permalink
Comprehend Moderation 0.2 (langchain-ai#11730)
Browse files Browse the repository at this point in the history
This PR replaces the previous `Intent` check with the new `Prompt
Safety` check. The logic and steps to enable chain moderation via the
Amazon Comprehend service, allowing you to detect and redact PII, Toxic,
and Prompt Safety information in the LLM prompt or answer remains
unchanged.
This implementation updates the code and configuration types with
respect to `Prompt Safety`.


### Usage sample

```python
from langchain_experimental.comprehend_moderation import (BaseModerationConfig, 
                                 ModerationPromptSafetyConfig, 
                                 ModerationPiiConfig, 
                                 ModerationToxicityConfig
)

pii_config = ModerationPiiConfig(
    labels=["SSN"],
    redact=True,
    mask_character="X"
)

toxicity_config = ModerationToxicityConfig(
    threshold=0.5
)

prompt_safety_config = ModerationPromptSafetyConfig(
    threshold=0.5
)

moderation_config = BaseModerationConfig(
    filters=[pii_config, toxicity_config, prompt_safety_config]
)

comp_moderation_with_config = AmazonComprehendModerationChain(
    moderation_config=moderation_config, #specify the configuration
    client=comprehend_client,            #optionally pass the Boto3 Client
    verbose=True
)

template = """Question: {question}

Answer:"""

prompt = PromptTemplate(template=template, input_variables=["question"])

responses = [
    "Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like 323-22-9980. John Doe's phone number is (999)253-9876.", 
    "Final Answer: This is a really shitty way of constructing a birdhouse. This is fucking insane to think that any birds would actually create their motherfucking nests here."
]
llm = FakeListLLM(responses=responses)

llm_chain = LLMChain(prompt=prompt, llm=llm)

chain = ( 
    prompt 
    | comp_moderation_with_config 
    | {llm_chain.input_keys[0]: lambda x: x['output'] }  
    | llm_chain 
    | { "input": lambda x: x['text'] } 
    | comp_moderation_with_config 
)

try:
    response = chain.invoke({"question": "A sample SSN number looks like this 123-456-7890. Can you give me some more samples?"})
except Exception as e:
    print(str(e))
else:
    print(response['output'])

```

### Output

```python
> Entering new AmazonComprehendModerationChain chain...
Running AmazonComprehendModerationChain...
Running pii Validation...
Running toxicity Validation...
Running prompt safety Validation...

> Finished chain.


> Entering new AmazonComprehendModerationChain chain...
Running AmazonComprehendModerationChain...
Running pii Validation...
Running toxicity Validation...
Running prompt safety Validation...

> Finished chain.
Final Answer: A credit card number looks like 1289-2321-1123-2387. A fake SSN number looks like XXXXXXXXXXXX John Doe's phone number is (999)253-9876.
```

---------

Co-authored-by: Jha <nikjha@amazon.com>
Co-authored-by: Anjan Biswas <anjanavb@amazon.com>
Co-authored-by: Anjan Biswas <84933469+anjanvb@users.noreply.github.com>
  • Loading branch information
4 people authored Oct 26, 2023
1 parent b9410f2 commit dff2428
Show file tree
Hide file tree
Showing 7 changed files with 282 additions and 205 deletions.
388 changes: 230 additions & 158 deletions docs/docs/guides/safety/amazon_comprehend_chain.ipynb

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,25 @@
)
from langchain_experimental.comprehend_moderation.base_moderation_config import (
BaseModerationConfig,
ModerationIntentConfig,
ModerationPiiConfig,
ModerationPromptSafetyConfig,
ModerationToxicityConfig,
)
from langchain_experimental.comprehend_moderation.intent import ComprehendIntent
from langchain_experimental.comprehend_moderation.pii import ComprehendPII
from langchain_experimental.comprehend_moderation.prompt_safety import (
ComprehendPromptSafety,
)
from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxicity

__all__ = [
"BaseModeration",
"ComprehendPII",
"ComprehendIntent",
"ComprehendPromptSafety",
"ComprehendToxicity",
"BaseModerationConfig",
"ModerationPiiConfig",
"ModerationToxicityConfig",
"ModerationIntentConfig",
"ModerationPromptSafetyConfig",
"BaseModerationCallbackHandler",
"AmazonComprehendModerationChain",
]
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import AIMessage, HumanMessage

from langchain_experimental.comprehend_moderation.intent import ComprehendIntent
from langchain_experimental.comprehend_moderation.pii import ComprehendPII
from langchain_experimental.comprehend_moderation.prompt_safety import (
ComprehendPromptSafety,
)
from langchain_experimental.comprehend_moderation.toxicity import ComprehendToxicity


Expand Down Expand Up @@ -109,13 +111,13 @@ def _log_message_for_verbose(self, message: str) -> None:

def moderate(self, prompt: Any) -> str:
from langchain_experimental.comprehend_moderation.base_moderation_config import ( # noqa: E501
ModerationIntentConfig,
ModerationPiiConfig,
ModerationPromptSafetyConfig,
ModerationToxicityConfig,
)
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import ( # noqa: E501
ModerationIntentionError,
ModerationPiiError,
ModerationPromptSafetyError,
ModerationToxicityError,
)

Expand All @@ -128,7 +130,7 @@ def moderate(self, prompt: Any) -> str:
filter_functions = {
"pii": ComprehendPII,
"toxicity": ComprehendToxicity,
"intent": ComprehendIntent,
"prompt_safety": ComprehendPromptSafety,
}

filters = self.config.filters # type: ignore
Expand All @@ -141,8 +143,8 @@ def moderate(self, prompt: Any) -> str:
"toxicity"
if isinstance(_filter, ModerationToxicityConfig)
else (
"intent"
if isinstance(_filter, ModerationIntentConfig)
"prompt_safety"
if isinstance(_filter, ModerationPromptSafetyConfig)
else None
)
)
Expand Down Expand Up @@ -171,7 +173,7 @@ def moderate(self, prompt: Any) -> str:
f"Found Toxic content..stopping..\n{str(e)}\n"
)
raise e
except ModerationIntentionError as e:
except ModerationPromptSafetyError as e:
self._log_message_for_verbose(
f"Found Harmful intention..stopping..\n{str(e)}\n"
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ def __init__(self) -> None:
BaseModerationCallbackHandler.on_after_toxicity, self.on_after_toxicity
)
and self._is_method_unchanged(
BaseModerationCallbackHandler.on_after_intent, self.on_after_intent
BaseModerationCallbackHandler.on_after_prompt_safety,
self.on_after_prompt_safety,
)
):
raise NotImplementedError(
"Subclasses must override at least one of on_after_pii(), "
"on_after_toxicity(), or on_after_intent() functions."
"on_after_toxicity(), or on_after_prompt_safety() functions."
)

def _is_method_unchanged(
Expand All @@ -36,10 +37,10 @@ async def on_after_toxicity(
"""Run after Toxicity validation is complete."""
pass

async def on_after_intent(
async def on_after_prompt_safety(
self, moderation_beacon: Dict[str, Any], unique_id: str, **kwargs: Any
) -> None:
"""Run after Toxicity validation is complete."""
"""Run after Prompt Safety validation is complete."""
pass

@property
Expand All @@ -57,8 +58,8 @@ def toxicity_callback(self) -> bool:
)

@property
def intent_callback(self) -> bool:
def prompt_safety_callback(self) -> bool:
return (
self.on_after_intent.__func__ # type: ignore
is not BaseModerationCallbackHandler.on_after_intent
self.on_after_prompt_safety.__func__ # type: ignore
is not BaseModerationCallbackHandler.on_after_prompt_safety
)
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,26 @@ class ModerationToxicityConfig(BaseModel):
"""List of toxic labels, defaults to `list[]`"""


class ModerationIntentConfig(BaseModel):
class ModerationPromptSafetyConfig(BaseModel):
threshold: float = 0.5
"""
Threshold for Intent classification
Threshold for Prompt Safety classification
confidence score, defaults to 0.5 i.e. 50%
"""


class BaseModerationConfig(BaseModel):
filters: List[
Union[ModerationPiiConfig, ModerationToxicityConfig, ModerationIntentConfig]
Union[
ModerationPiiConfig, ModerationToxicityConfig, ModerationPromptSafetyConfig
]
] = [
ModerationPiiConfig(),
ModerationToxicityConfig(),
ModerationIntentConfig(),
ModerationPromptSafetyConfig(),
]
"""
Filters applied to the moderation chain, defaults to
`[ModerationPiiConfig(), ModerationToxicityConfig(),
ModerationIntentConfig()]`
ModerationPromptSafetyConfig()]`
"""
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(
super().__init__(self.message)


class ModerationIntentionError(Exception):
class ModerationPromptSafetyError(Exception):
"""Exception raised if Intention entities are detected.
Attributes:
Expand All @@ -35,9 +35,7 @@ class ModerationIntentionError(Exception):

def __init__(
self,
message: str = (
"The prompt indicates an un-desired intent and " "cannot be processed"
),
message: str = ("The prompt is unsafe and cannot be processed"),
):
self.message = message
super().__init__(self.message)
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from typing import Any, Optional

from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
ModerationIntentionError,
ModerationPromptSafetyError,
)


class ComprehendIntent:
class ComprehendPromptSafety:
def __init__(
self,
client: Any,
Expand All @@ -17,7 +17,7 @@ def __init__(
self.client = client
self.moderation_beacon = {
"moderation_chain_id": chain_id,
"moderation_type": "Intent",
"moderation_type": "PromptSafety",
"moderation_status": "LABELS_NOT_FOUND",
}
self.callback = callback
Expand All @@ -26,62 +26,62 @@ def __init__(
def _get_arn(self) -> str:
region_name = self.client.meta.region_name
service = "comprehend"
intent_endpoint = "document-classifier-endpoint/prompt-intent"
return f"arn:aws:{service}:{region_name}:aws:{intent_endpoint}"
prompt_safety_endpoint = "document-classifier-endpoint/prompt-safety"
return f"arn:aws:{service}:{region_name}:aws:{prompt_safety_endpoint}"

def validate(self, prompt_value: str, config: Any = None) -> str:
"""
Check and validate the intent of the given prompt text.
Check and validate the safety of the given prompt text.
Args:
prompt_value (str): The input text to be checked for unintended intent.
config (Dict[str, Any]): Configuration settings for intent checks.
prompt_value (str): The input text to be checked for unsafe text.
config (Dict[str, Any]): Configuration settings for prompt safety checks.
Raises:
ValueError: If unintended intent is found in the prompt text based
ValueError: If unsafe prompt is found in the prompt text based
on the specified threshold.
Returns:
str: The input prompt_value.
Note:
This function checks the intent of the provided prompt text using
Comprehend's classify_document API and raises an error if unintended
intent is detected with a score above the specified threshold.
This function checks the safety of the provided prompt text using
Comprehend's classify_document API and raises an error if unsafe
text is detected with a score above the specified threshold.
Example:
comprehend_client = boto3.client('comprehend')
prompt_text = "Please tell me your credit card information."
config = {"threshold": 0.7}
checked_prompt = check_intent(comprehend_client, prompt_text, config)
checked_prompt = check_prompt_safety(comprehend_client, prompt_text, config)
"""

threshold = config.get("threshold")
intent_found = False
unsafe_prompt = False

endpoint_arn = self._get_arn()
response = self.client.classify_document(
Text=prompt_value, EndpointArn=endpoint_arn
)

if self.callback and self.callback.intent_callback:
if self.callback and self.callback.prompt_safety_callback:
self.moderation_beacon["moderation_input"] = prompt_value
self.moderation_beacon["moderation_output"] = response

for class_result in response["Classes"]:
if (
class_result["Score"] >= threshold
and class_result["Name"] == "UNDESIRED_PROMPT"
and class_result["Name"] == "UNSAFE_PROMPT"
):
intent_found = True
unsafe_prompt = True
break

if self.callback and self.callback.intent_callback:
if intent_found:
if unsafe_prompt:
self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
asyncio.create_task(
self.callback.on_after_intent(self.moderation_beacon, self.unique_id)
)
if intent_found:
raise ModerationIntentionError
if unsafe_prompt:
raise ModerationPromptSafetyError
return prompt_value

0 comments on commit dff2428

Please sign in to comment.