diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_coherence/_coherence.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_coherence/_coherence.py index b6cb803021e8..a07754b69d56 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_coherence/_coherence.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_coherence/_coherence.py @@ -2,14 +2,15 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- import os -from typing import Optional +from typing import Dict, Union, List -from typing_extensions import override +from typing_extensions import overload, override from azure.ai.evaluation._evaluators._common import PromptyEvaluatorBase +from azure.ai.evaluation._model_configurations import Conversation -class CoherenceEvaluator(PromptyEvaluatorBase): +class CoherenceEvaluator(PromptyEvaluatorBase[Union[str, float]]): """ Initialize a coherence evaluator configured for a specific Azure OpenAI model. @@ -49,13 +50,43 @@ def __init__(self, model_config): prompty_path = os.path.join(current_dir, self._PROMPTY_FILE) super().__init__(model_config=model_config, prompty_file=prompty_path, result_key=self._RESULT_KEY) - @override + @overload + def __call__( + self, + *, + query: str, + response: str, + ) -> Dict[str, Union[str, float]]: + """Evaluate coherence for given input of query, response + + :keyword query: The query to be evaluated. + :paramtype query: str + :keyword response: The response to be evaluated. + :paramtype response: str + :return: The coherence score. + :rtype: Dict[str, float] + """ + + @overload def __call__( self, *, - query: Optional[str] = None, - response: Optional[str] = None, - conversation=None, + conversation: Conversation, + ) -> Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]: + """Evaluate coherence for a conversation + + :keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the + key "messages", and potentially a global context under the key "context". Conversation turns are expected + to be dictionaries with keys "content", "role", and possibly "context". + :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] + :return: The coherence score. + :rtype: Dict[str, Union[float, Dict[str, List[float]]]] + """ + + @override + def __call__( # pylint: disable=docstring-missing-param + self, + *args, **kwargs, ): """Evaluate coherence. Accepts either a query and response for a single evaluation, @@ -73,4 +104,4 @@ def __call__( :return: The relevance score. 
:rtype: Union[Dict[str, float], Dict[str, Union[float, Dict[str, List[float]]]]] """ - return super().__call__(query=query, response=response, conversation=conversation, **kwargs) + return super().__call__(*args, **kwargs) diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/_base_eval.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/_base_eval.py index e46b2dea11b5..8969d4dae9a9 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/_base_eval.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/_base_eval.py @@ -7,7 +7,7 @@ from typing import Any, Callable, Dict, Generic, List, TypedDict, TypeVar, Union, cast, final from promptflow._utils.async_utils import async_run_allowing_running_loop -from typing_extensions import ParamSpec, TypeAlias +from typing_extensions import ParamSpec, TypeAlias, get_overloads from azure.ai.evaluation._common.math import list_mean from azure.ai.evaluation._exceptions import ErrorBlame, ErrorCategory, ErrorTarget, EvaluationException @@ -88,7 +88,11 @@ def __init__( # This needs to be overridden just to change the function header into something more informative, # and to be able to add a more specific docstring. The actual function contents should just be # super().__call__() - def __call__(self, **kwargs) -> Union[DoEvalResult[T_EvalValue], AggregateResult[T_EvalValue]]: + def __call__( # pylint: disable=docstring-missing-param + self, + *args, + **kwargs, + ) -> Union[DoEvalResult[T_EvalValue], AggregateResult[T_EvalValue]]: """Evaluate a given input. This method serves as a wrapper and is meant to be overridden by child classes for one main reason - to overwrite the method headers and docstring to include additional inputs as needed. 
The actual behavior of this function shouldn't change beyond adding more inputs to the @@ -127,11 +131,19 @@ def _derive_singleton_inputs(self) -> List[str]: :rtype: List[str] """ + overloads = get_overloads(self.__call__) + if not overloads: + call_signatures = [inspect.signature(self.__call__)] + else: + call_signatures = [inspect.signature(overload) for overload in overloads] call_signature = inspect.signature(self.__call__) singletons = [] - for param in call_signature.parameters: - if param not in self._not_singleton_inputs: - singletons.append(param) + for call_signature in call_signatures: + params = call_signature.parameters + if any(not_singleton_input in params for not_singleton_input in self._not_singleton_inputs): + continue + # exclude self since it is not a singleton input + singletons.extend([p for p in params if p != "self"]) return singletons def _derive_conversation_converter(self) -> Callable[[Dict], List[DerivedEvalInput]]: diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py index e02f29ad0def..e851e8499260 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/_base_prompty_eval.py @@ -4,7 +4,7 @@ import math import re -from typing import Dict, Union +from typing import Dict, TypeVar, Union from promptflow.core import AsyncPrompty from typing_extensions import override @@ -18,8 +18,10 @@ except ImportError: USER_AGENT = "None" +T = TypeVar("T") -class PromptyEvaluatorBase(EvaluatorBase[float]): + +class PromptyEvaluatorBase(EvaluatorBase[T]): """Base class for all evaluators that make use of context as an input. It's also assumed that such evaluators make use of a prompty file, and return their results as a dictionary, with a single key-value pair linking the result name to a float value (unless multi-turn evaluation occurs, in which case the diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py index 28d00f7977b6..cb687d23e695 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_common/_base_rai_svc_eval.py @@ -1,7 +1,7 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -from typing import Dict, Optional, Union +from typing import Dict, TypeVar, Union from typing_extensions import override @@ -18,7 +18,7 @@ from . import EvaluatorBase -T = Union[str, float] +T = TypeVar("T") class RaiServiceEvaluatorBase(EvaluatorBase[T]): @@ -50,12 +50,9 @@ def __init__( self._credential = credential @override - def __call__( + def __call__( # pylint: disable=docstring-missing-param self, - *, - query: Optional[str] = None, - response: Optional[str] = None, - conversation=None, + *args, **kwargs, ): """Evaluate either a query and response or a conversation. 
Must supply either a query AND response, @@ -71,7 +68,7 @@ def __call__( :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] :rtype: Union[Dict[str, T], Dict[str, Union[float, Dict[str, List[T]]]]] """ - return super().__call__(query=query, response=response, conversation=conversation, **kwargs) + return super().__call__(*args, **kwargs) @override async def _do_eval(self, eval_input: Dict) -> Dict[str, T]: @@ -108,7 +105,7 @@ async def _do_eval(self, eval_input: Dict) -> Dict[str, T]: ) input_data["context"] = context - return await evaluate_with_rai_service( + return await evaluate_with_rai_service( # type: ignore metric_name=self._eval_metric, data=input_data, project_scope=self._azure_ai_project, diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py index 8384024fc813..05ad3c7539d8 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_content_safety.py @@ -2,13 +2,14 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- from concurrent.futures import as_completed -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, List, Union from promptflow.tracing import ThreadPoolExecutorWithContext as ThreadPoolExecutor -from typing_extensions import override +from typing_extensions import overload, override from azure.ai.evaluation._common._experimental import experimental from azure.ai.evaluation._evaluators._common import EvaluatorBase +from azure.ai.evaluation._model_configurations import Conversation from ._hate_unfairness import HateUnfairnessEvaluator from ._self_harm import SelfHarmEvaluator @@ -17,7 +18,7 @@ @experimental -class ContentSafetyEvaluator(EvaluatorBase): +class ContentSafetyEvaluator(EvaluatorBase[Union[str, float]]): """ Initialize a content safety evaluator configured to evaluate content safetry metrics for QA scenario. @@ -31,7 +32,6 @@ class ContentSafetyEvaluator(EvaluatorBase): :param kwargs: Additional arguments to pass to the evaluator. :type kwargs: Any :return: A function that evaluates content-safety metrics for "question-answering" scenario. - :rtype: Callable **Usage** @@ -79,13 +79,43 @@ def __init__(self, credential, azure_ai_project, eval_last_turn: bool = False, * HateUnfairnessEvaluator(credential, azure_ai_project), ] - @override + @overload + def __call__( + self, + *, + query: str, + response: str, + ) -> Dict[str, Union[str, float]]: + """Evaluate a collection of content safety metrics for the given query/response pair + + :keyword query: The query to be evaluated. + :paramtype query: str + :keyword response: The response to be evaluated. + :paramtype response: str + :return: The content safety scores. + :rtype: Dict[str, Union[str, float]] + """ + + @overload def __call__( self, *, - query: Optional[str] = None, - response: Optional[str] = None, - conversation=None, + conversation: Conversation, + ) -> Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]: + """Evaluate a collection of content safety metrics for a conversation + + :keyword conversation: The conversation to evaluate. 
Expected to contain a list of conversation turns under the + key "messages", and potentially a global context under the key "context". Conversation turns are expected + to be dictionaries with keys "content", "role", and possibly "context". + :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] + :return: The content safety scores. + :rtype: Dict[str, Union[float, Dict[str, List[Union[str, float]]]]] + """ + + @override + def __call__( # pylint: disable=docstring-missing-param + self, + *args, **kwargs, ): """Evaluate a collection of content safety metrics for the given query/response pair or conversation. @@ -100,9 +130,9 @@ def __call__( to be dictionaries with keys "content", "role", and possibly "context". :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] :return: The evaluation result. - :rtype: Union[Dict[str, Union[str, float]], Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]] + :rtype: Union[Dict[str, Union[str, float]], Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]] """ - return super().__call__(query=query, response=response, conversation=conversation, **kwargs) + return super().__call__(*args, **kwargs) @override async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[str, float]]: diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py index 551624101661..afb645d45768 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_hate_unfairness.py @@ -1,17 +1,18 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -from typing import Optional +from typing import Dict, List, Union -from typing_extensions import override +from typing_extensions import overload, override from azure.ai.evaluation._common._experimental import experimental from azure.ai.evaluation._common.constants import EvaluationMetrics from azure.ai.evaluation._evaluators._common import RaiServiceEvaluatorBase +from azure.ai.evaluation._model_configurations import Conversation @experimental -class HateUnfairnessEvaluator(RaiServiceEvaluatorBase): +class HateUnfairnessEvaluator(RaiServiceEvaluatorBase[Union[str, float]]): """ Initialize a hate-unfairness evaluator for hate unfairness score. @@ -58,27 +59,57 @@ def __init__( eval_last_turn=eval_last_turn, ) - @override + @overload + def __call__( + self, + *, + query: str, + response: str, + ) -> Dict[str, Union[str, float]]: + """Evaluate the given query/response pair for hateful content + + :keyword query: The query to be evaluated. + :paramtype query: str + :keyword response: The response to be evaluated. + :paramtype response: str + :return: The hate score + :rtype: Dict[str, Union[str, float]] + """ + + @overload def __call__( self, *, - query: Optional[str] = None, - response: Optional[str] = None, - conversation=None, + conversation: Conversation, + ) -> Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]: + """Evaluate a conversation for hateful content + + :keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the + key "messages", and potentially a global context under the key "context". 
Conversation turns are expected + to be dictionaries with keys "content", "role", and possibly "context". + :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] + :return: The hate score + :rtype: Dict[str, Union[float, Dict[str, List[Union[str, float]]]]] + """ + + @override + def __call__( # pylint: disable=docstring-missing-param + self, + *args, **kwargs, ): """ Evaluate whether hateful content is present in your AI system's response. :keyword query: The query to be evaluated. - :paramtype query: str + :paramtype query: Optional[str] :keyword response: The response to be evaluated. - :paramtype response: str + :paramtype response: Optional[str] :keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the key "messages". Conversation turns are expected to be dictionaries with keys "content" and "role". :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] :return: The fluency score. - :rtype: Union[Dict[str, Union[str, float]], Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]] + :rtype: Union[Dict[str, Union[str, float]], Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]] """ - return super().__call__(query=query, response=response, conversation=conversation, **kwargs) + return super().__call__(*args, **kwargs) diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py index 39e1d311063c..66cc70280737 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_self_harm.py @@ -1,17 +1,18 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -from typing import Optional +from typing import Dict, List, Union -from typing_extensions import override +from typing_extensions import overload, override from azure.ai.evaluation._common._experimental import experimental from azure.ai.evaluation._common.constants import EvaluationMetrics from azure.ai.evaluation._evaluators._common import RaiServiceEvaluatorBase +from azure.ai.evaluation._model_configurations import Conversation @experimental -class SelfHarmEvaluator(RaiServiceEvaluatorBase): +class SelfHarmEvaluator(RaiServiceEvaluatorBase[Union[str, float]]): """ Initialize a self harm evaluator for self harm score. @@ -58,27 +59,57 @@ def __init__( eval_last_turn=eval_last_turn, ) - @override + @overload + def __call__( + self, + *, + query: str, + response: str, + ) -> Dict[str, Union[str, float]]: + """Evaluate a given query/response pair for self-harm content + + :keyword query: The query to be evaluated. + :paramtype query: str + :keyword response: The response to be evaluated. + :paramtype response: str + :return: The self-harm score + :rtype: Dict[str, Union[str, float]] + """ + + @overload def __call__( self, *, - query: Optional[str] = None, - response: Optional[str] = None, - conversation=None, + conversation: Conversation, + ) -> Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]: + """Evaluate a conversation for self-harm content + + :keyword conversation: The conversation to evaluate. 
Expected to contain a list of conversation turns under the + key "messages", and potentially a global context under the key "context". Conversation turns are expected + to be dictionaries with keys "content", "role", and possibly "context". + :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] + :return: The self-harm score + :rtype: Dict[str, Union[float, Dict[str, List[Union[str, float]]]]] + """ + + @override + def __call__( # pylint: disable=docstring-missing-param + self, + *args, **kwargs, ): """ Evaluate whether self-harm content is present in your AI system's response. :keyword query: The query to be evaluated. - :paramtype query: str + :paramtype query: Optional[str] :keyword response: The response to be evaluated. - :paramtype response: str + :paramtype response: Optional[str] :keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the key "messages". Conversation turns are expected to be dictionaries with keys "content" and "role". :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] :return: The fluency score. - :rtype: Union[Dict[str, Union[str, float]], Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]] + :rtype: Union[Dict[str, Union[str, float]], Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]] """ - return super().__call__(query=query, response=response, conversation=conversation, **kwargs) + return super().__call__(*args, **kwargs) diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_sexual.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_sexual.py index ad7e875b3910..dbf7a2a0ae12 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_sexual.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_sexual.py @@ -1,17 +1,18 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -from typing import Optional +from typing import Dict, List, Union -from typing_extensions import override +from typing_extensions import overload, override from azure.ai.evaluation._common._experimental import experimental from azure.ai.evaluation._common.constants import EvaluationMetrics from azure.ai.evaluation._evaluators._common import RaiServiceEvaluatorBase +from azure.ai.evaluation._model_configurations import Conversation @experimental -class SexualEvaluator(RaiServiceEvaluatorBase): +class SexualEvaluator(RaiServiceEvaluatorBase[Union[str, float]]): """ Initialize a sexual evaluator for sexual score. @@ -58,22 +59,52 @@ def __init__( eval_last_turn=eval_last_turn, ) - @override + @overload + def __call__( + self, + *, + query: str, + response: str, + ) -> Dict[str, Union[str, float]]: + """Evaluate a given query/response pair for sexual content + + :keyword query: The query to be evaluated. + :paramtype query: str + :keyword response: The response to be evaluated. + :paramtype response: str + :return: The sexual score + :rtype: Dict[str, Union[str, float]] + """ + + @overload def __call__( self, *, - query: Optional[str] = None, - response: Optional[str] = None, - conversation=None, + conversation: Conversation, + ) -> Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]: + """Evaluate a conversation for sexual content + + :keyword conversation: The conversation to evaluate. 
Expected to contain a list of conversation turns under the + key "messages", and potentially a global context under the key "context". Conversation turns are expected + to be dictionaries with keys "content", "role", and possibly "context". + :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] + :return: The sexual score + :rtype: Dict[str, Union[float, Dict[str, List[Union[str, float]]]]] + """ + + @override + def __call__( # pylint: disable=docstring-missing-param + self, + *args, **kwargs, ): """ Evaluate whether sexual content is present in your AI system's response. :keyword query: The query to be evaluated. - :paramtype query: str + :paramtype query: Optional[str] :keyword response: The response to be evaluated. - :paramtype response: str + :paramtype response: Optional[str] :keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the key "messages". Conversation turns are expected to be dictionaries with keys "content" and "role". @@ -81,4 +112,4 @@ def __call__( :return: The fluency score. :rtype: Union[Dict[str, Union[str, float]], Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]] """ - return super().__call__(query=query, response=response, conversation=conversation, **kwargs) + return super().__call__(*args, **kwargs) diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_violence.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_violence.py index ab03d551d214..f43c08726dcd 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_violence.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_content_safety/_violence.py @@ -1,17 +1,18 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -from typing import Optional +from typing import Dict, List, Union -from typing_extensions import override +from typing_extensions import overload, override from azure.ai.evaluation._common._experimental import experimental from azure.ai.evaluation._common.constants import EvaluationMetrics from azure.ai.evaluation._evaluators._common import RaiServiceEvaluatorBase +from azure.ai.evaluation._model_configurations import Conversation @experimental -class ViolenceEvaluator(RaiServiceEvaluatorBase): +class ViolenceEvaluator(RaiServiceEvaluatorBase[Union[str, float]]): """ Initialize a violence evaluator for violence score. @@ -58,27 +59,58 @@ def __init__( eval_last_turn=eval_last_turn, ) - @override + @overload + def __call__( + self, + *, + query: str, + response: str, + ) -> Dict[str, Union[str, float]]: + """Evaluate a given query/response pair for violent content + + :keyword query: The query to be evaluated. + :paramtype query: str + :keyword response: The response to be evaluated. + :paramtype response: str + :return: The content safety score. + :rtype: Dict[str, Union[str, float]] + """ + + @overload def __call__( self, *, - query: Optional[str] = None, - response: Optional[str] = None, - conversation=None, + conversation: Conversation, + ) -> Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]: + """Evaluate a conversation for violent content + + :keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the + key "messages", and potentially a global context under the key "context". 
Conversation turns are expected + to be dictionaries with keys "content", "role", and possibly "context". + :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] + :return: The violence score. + :rtype: Dict[str, Union[float, Dict[str, List[Union[str, float]]]]] + """ + + @override + def __call__( # pylint: disable=docstring-missing-param + self, + *args, **kwargs, ): """ Evaluate whether violent content is present in your AI system's response. :keyword query: The query to be evaluated. - :paramtype query: str + :paramtype query: Optional[str] :keyword response: The response to be evaluated. - :paramtype response: str + :paramtype response: Optional[str] :keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the key "messages". Conversation turns are expected to be dictionaries with keys "content" and "role". :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] :return: The fluency score. - :rtype: Union[Dict[str, Union[str, float]], Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]] + :rtype: Union[Dict[str, Union[str, float]], Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]] """ - return super().__call__(query=query, response=response, conversation=conversation, **kwargs) + + return super().__call__(*args, **kwargs) diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_eci/_eci.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_eci/_eci.py index 3e6e420e9305..c89df72fb13a 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_eci/_eci.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_eci/_eci.py @@ -1,11 +1,12 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -from typing_extensions import override +from typing_extensions import overload, override from azure.ai.evaluation._common._experimental import experimental from azure.ai.evaluation._common.constants import _InternalEvaluationMetrics from azure.ai.evaluation._evaluators._common import RaiServiceEvaluatorBase +from azure.ai.evaluation._model_configurations import Conversation @experimental @@ -62,3 +63,26 @@ def __init__( credential=credential, eval_last_turn=eval_last_turn, ) + + @overload + def __call__( + self, + *, + query: str, + response: str, + ): ... + + @overload + def __call__( + self, + *, + conversation: Conversation, + ): ... 
+ + @override + def __call__( # pylint: disable=docstring-missing-param + self, + *args, + **kwargs, + ): + return super().__call__(*args, **kwargs) diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_fluency/_fluency.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_fluency/_fluency.py index 45d04ab58616..66c162a03993 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_fluency/_fluency.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_fluency/_fluency.py @@ -3,14 +3,15 @@ # --------------------------------------------------------- import os -from typing import Optional +from typing import Dict, List, Union -from typing_extensions import override +from typing_extensions import overload, override from azure.ai.evaluation._evaluators._common import PromptyEvaluatorBase +from azure.ai.evaluation._model_configurations import Conversation -class FluencyEvaluator(PromptyEvaluatorBase): +class FluencyEvaluator(PromptyEvaluatorBase[Union[str, float]]): """ Initialize a fluency evaluator configured for a specific Azure OpenAI model. @@ -48,12 +49,40 @@ def __init__(self, model_config): prompty_path = os.path.join(current_dir, self._PROMPTY_FILE) super().__init__(model_config=model_config, prompty_file=prompty_path, result_key=self._RESULT_KEY) - @override + @overload def __call__( self, *, - response: Optional[str] = None, - conversation=None, + response: str, + ) -> Dict[str, Union[str, float]]: + """Evaluate fluency in given query/response + + :keyword response: The response to be evaluated. + :paramtype response: str + :return: The fluency score + :rtype: Dict[str, float] + """ + + @overload + def __call__( + self, + *, + conversation: Conversation, + ) -> Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]: + """Evaluate fluency for a conversation + + :keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the + key "messages", and potentially a global context under the key "context". Conversation turns are expected + to be dictionaries with keys "content", "role", and possibly "context". + :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] + :return: The fluency score + :rtype: Dict[str, Union[float, Dict[str, List[float]]]] + """ + + @override + def __call__( # pylint: disable=docstring-missing-param + self, + *args, **kwargs, ): """ @@ -62,12 +91,11 @@ def __call__( the evaluator will aggregate the results of each turn. :keyword response: The response to be evaluated. Mutually exclusive with the "conversation" parameter. - :paramtype response: str + :paramtype response: Optional[str] :keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the key "messages". Conversation turns are expected to be dictionaries with keys "content" and "role". :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] :return: The fluency score. 
:rtype: Union[Dict[str, float], Dict[str, Union[float, Dict[str, List[float]]]]] """ - - return super().__call__(response=response, conversation=conversation, **kwargs) + return super().__call__(*args, **kwargs) diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py index 21a1a96f9f9c..5215396aa9d2 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_groundedness/_groundedness.py @@ -2,12 +2,13 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- import os -from typing import Optional +from typing import Dict, List, Optional, Union -from typing_extensions import override +from typing_extensions import overload, override from promptflow.core import AsyncPrompty from azure.ai.evaluation._evaluators._common import PromptyEvaluatorBase +from azure.ai.evaluation._model_configurations import Conversation from ..._common.utils import construct_prompty_model_config, validate_model_config try: @@ -16,7 +17,7 @@ USER_AGENT = "None" -class GroundednessEvaluator(PromptyEvaluatorBase): +class GroundednessEvaluator(PromptyEvaluatorBase[Union[str, float]]): """ Initialize a groundedness evaluator configured for a specific Azure OpenAI model. @@ -62,14 +63,47 @@ def __init__(self, model_config): self._model_config = model_config # Needs to be set because it's used in call method to re-validate prompt if `query` is provided - @override + @overload def __call__( self, *, + response: str, + context: str, query: Optional[str] = None, - response: Optional[str] = None, - context: Optional[str] = None, - conversation=None, + ) -> Dict[str, Union[str, float]]: + """Evaluate groundedness for given input of response, context + + :keyword response: The response to be evaluated. + :paramtype response: str + :keyword context: The context to be evaluated. + :paramtype context: str + :keyword query: The query to be evaluated. Optional parameter for use with the `response` + and `context` parameters. If provided, a different prompt template will be used for evaluation. + :paramtype query: Optional[str] + :return: The groundedness score. + :rtype: Dict[str, float] + """ + + @overload + def __call__( + self, + *, + conversation: Conversation, + ) -> Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]: + """Evaluate groundedness for a conversation + + :keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the + key "messages", and potentially a global context under the key "context". Conversation turns are expected + to be dictionaries with keys "content", "role", and possibly "context". + :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] + :return: The groundedness score. + :rtype: Dict[str, Union[float, Dict[str, List[float]]]] + """ + + @override + def __call__( # pylint: disable=docstring-missing-param + self, + *args, **kwargs, ): """Evaluate groundedness. Accepts either a query, response, and context for a single evaluation, @@ -89,10 +123,10 @@ def __call__( to be dictionaries with keys "content", "role", and possibly "context". :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] :return: The relevance score. 
- :rtype: Union[Dict[str, float], Dict[str, Union[float, Dict[str, List[float]]]]] + :rtype: Union[Dict[str, Union[str, float]], Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]] """ - if query: + if kwargs.get("query", None): current_dir = os.path.dirname(__file__) prompty_path = os.path.join(current_dir, self._PROMPTY_FILE_WITH_QUERY) self._prompty_file = prompty_path @@ -103,4 +137,4 @@ def __call__( ) self._flow = AsyncPrompty.load(source=self._prompty_file, model=prompty_model_config) - return super().__call__(query=query, response=response, context=context, conversation=conversation, **kwargs) + return super().__call__(*args, **kwargs) diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py index 9c9351037da0..fb7dc8aefcb3 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_protected_material/_protected_material.py @@ -2,17 +2,18 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -from typing import Optional +from typing import Dict, List, Optional, Union -from typing_extensions import override +from typing_extensions import overload, override from azure.ai.evaluation._common._experimental import experimental from azure.ai.evaluation._common.constants import EvaluationMetrics from azure.ai.evaluation._evaluators._common import RaiServiceEvaluatorBase +from azure.ai.evaluation._model_configurations import Conversation @experimental -class ProtectedMaterialEvaluator(RaiServiceEvaluatorBase): +class ProtectedMaterialEvaluator(RaiServiceEvaluatorBase[Union[str, bool]]): """ Initialize a protected material evaluator to detect whether protected material is present in the AI system's response. The evaluator outputs a Boolean label (`True` or `False`) @@ -64,6 +65,39 @@ def __init__( eval_last_turn=eval_last_turn, ) + @overload + def __call__( + self, + *, + query: str, + response: str, + ) -> Dict[str, Union[str, bool]]: + """Evaluate a given query/response pair for protected material + + :keyword query: The query to be evaluated. + :paramtype query: str + :keyword response: The response to be evaluated. + :paramtype response: str + :return: The protected material score. + :rtype: Dict[str, Union[str, bool]] + """ + + @overload + def __call__( + self, + *, + conversation: Conversation, + ) -> Dict[str, Union[float, Dict[str, List[Union[str, bool]]]]]: + """Evaluate a conversation for protected material + + :keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the + key "messages", and potentially a global context under the key "context". Conversation turns are expected + to be dictionaries with keys "content", "role", and possibly "context". + :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] + :return: The protected material score. + :rtype: Dict[str, Union[str, bool, Dict[str, List[Union[str, bool]]]]] + """ + @override def __call__( self, @@ -77,14 +111,14 @@ def __call__( Evaluate if protected material is present in your AI system's response. :keyword query: The query to be evaluated. - :paramtype query: str + :paramtype query: Optional[str] :keyword response: The response to be evaluated. 
- :paramtype response: str + :paramtype response: Optional[str] :keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the key "messages". Conversation turns are expected to be dictionaries with keys "content" and "role". :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] :return: The fluency score. - :rtype: Union[Dict[str, Union[str, bool]], Dict[str, Union[str, bool, Dict[str, List[Union[str, bool]]]]]] + :rtype: Union[Dict[str, Union[str, bool]], Dict[str, Union[float, Dict[str, List[Union[str, bool]]]]]] """ return super().__call__(query=query, response=response, conversation=conversation, **kwargs) diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_qa/_qa.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_qa/_qa.py index e8198ff85e89..b5f3ac810eff 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_qa/_qa.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_qa/_qa.py @@ -3,7 +3,7 @@ # --------------------------------------------------------- from concurrent.futures import as_completed -from typing import Callable, Dict, List +from typing import Callable, Dict, List, Union from promptflow.tracing import ThreadPoolExecutorWithContext as ThreadPoolExecutor @@ -58,7 +58,7 @@ class QAEvaluator: def __init__(self, model_config, parallel: bool = True): self._parallel = parallel - self._evaluators: List[Callable[..., Dict[str, float]]] = [ + self._evaluators: List[Union[Callable[..., Dict[str, Union[str, float]]], Callable[..., Dict[str, float]]]] = [ GroundednessEvaluator(model_config), RelevanceEvaluator(model_config), CoherenceEvaluator(model_config), @@ -82,9 +82,9 @@ def __call__(self, *, query: str, response: str, context: str, ground_truth: str :keyword parallel: Whether to evaluate in parallel. Defaults to True. :paramtype parallel: bool :return: The scores for QA scenario. 
- :rtype: Dict[str, float] + :rtype: Dict[str, Union[str, float]] """ - results: Dict[str, float] = {} + results: Dict[str, Union[str, float]] = {} if self._parallel: with ThreadPoolExecutor() as executor: futures = { diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_relevance/_relevance.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_relevance/_relevance.py index abfa485b5c6d..f5fb2d96360b 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_relevance/_relevance.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_relevance/_relevance.py @@ -3,10 +3,11 @@ # --------------------------------------------------------- import os -from typing import Optional +from typing import Dict, Union, List -from typing_extensions import override +from typing_extensions import overload, override +from azure.ai.evaluation._model_configurations import Conversation from azure.ai.evaluation._evaluators._common import PromptyEvaluatorBase @@ -52,13 +53,43 @@ def __init__(self, model_config): prompty_path = os.path.join(current_dir, self._PROMPTY_FILE) super().__init__(model_config=model_config, prompty_file=prompty_path, result_key=self._RESULT_KEY) - @override + @overload def __call__( self, *, - query: Optional[str] = None, - response: Optional[str] = None, - conversation=None, + query: str, + response: str, + ) -> Dict[str, Union[str, float]]: + """Evaluate relevance for a given query/response pair + + :keyword query: The query to be evaluated. + :paramtype query: str + :keyword response: The response to be evaluated. + :paramtype response: str + :return: The relevance score. + :rtype: Dict[str, float] + """ + + @overload + def __call__( + self, + *, + conversation: Conversation, + ) -> Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]: + """Evaluate relevance for a conversation + + :keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the + key "messages", and potentially a global context under the key "context". Conversation turns are expected + to be dictionaries with keys "content", "role", and possibly "context". + :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] + :return: The relevance score. + :rtype: Dict[str, Union[float, Dict[str, List[float]]]] + """ + + @override + def __call__( # pylint: disable=docstring-missing-param + self, + *args, **kwargs, ): """Evaluate relevance. Accepts either a query and response for a single evaluation, @@ -74,7 +105,6 @@ def __call__( to be dictionaries with keys "content", "role", and possibly "context". :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] :return: The relevance score. 
- :rtype: Union[Dict[str, float], Dict[str, Union[float, Dict[str, List[float]]]]] + :rtype: Union[Dict[str, Union[str, float]], Dict[str, Union[float, Dict[str, List[Union[str, float]]]]]] """ - - return super().__call__(query=query, response=response, conversation=conversation, **kwargs) + return super().__call__(*args, **kwargs) diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_retrieval/_retrieval.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_retrieval/_retrieval.py index 748c4e4904b0..b23cf62b10be 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_retrieval/_retrieval.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_retrieval/_retrieval.py @@ -6,12 +6,15 @@ import logging import math import os -from typing import Optional +from typing import Dict, List, Union +from typing_extensions import overload from promptflow._utils.async_utils import async_run_allowing_running_loop from promptflow.core import AsyncPrompty +from azure.ai.evaluation._evaluators._common._base_prompty_eval import PromptyEvaluatorBase from azure.ai.evaluation._exceptions import EvaluationException, ErrorBlame, ErrorCategory, ErrorTarget +from azure.ai.evaluation._model_configurations import Conversation from ..._common.math import list_mean_nan_safe from ..._common.utils import construct_prompty_model_config, validate_model_config, parse_quality_evaluator_reason_score @@ -107,7 +110,7 @@ async def __call__(self, *, query, context, conversation, **kwargs): } -class RetrievalEvaluator: +class RetrievalEvaluator(PromptyEvaluatorBase[Union[str, float]]): """ Initialize an evaluator configured for a specific Azure OpenAI model. @@ -152,10 +155,42 @@ class RetrievalEvaluator: however, it is recommended to use the new key moving forward as the old key will be deprecated in the future. """ - def __init__(self, model_config): + def __init__(self, model_config): # pylint: disable=super-init-not-called self._async_evaluator = _AsyncRetrievalScoreEvaluator(model_config) - def __call__(self, *, query: Optional[str] = None, context: Optional[str] = None, conversation=None, **kwargs): + @overload + def __call__( + self, + *, + query: str, + context: str, + ) -> Dict[str, Union[str, float]]: + """Evaluates retrieval for a given query and context + + :keyword query: The query to be evaluated. Mutually exclusive with `conversation` parameter. + :paramtype query: str + :keyword context: The context to be evaluated. Mutually exclusive with `conversation` parameter. + :paramtype context: str + :return: The scores for Chat scenario. + :rtype: Dict[str, Union[str, float]] + """ + + @overload + def __call__( + self, + *, + conversation: Conversation, + ) -> Dict[str, Union[float, Dict[str, List[float]]]]: + """Evaluates retrieval for a multi-turn evaluation. If the conversation has more than one turn, + the evaluator will aggregate the results of each turn. + + :keyword conversation: The conversation to be evaluated. + :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] + :return: The scores for Chat scenario. + :rtype: Dict[str, Union[float, Dict[str, List[float]]]] + """ + + def __call__(self, *args, **kwargs): # pylint: disable=docstring-missing-param """Evaluates retrieval score chat scenario. Accepts either a query and context for a single evaluation, or a conversation for a multi-turn evaluation. 
If the conversation has more than one turn, the evaluator will aggregate the results of each turn. @@ -169,6 +204,10 @@ def __call__(self, *, query: Optional[str] = None, context: Optional[str] = None :return: The scores for Chat scenario. :rtype: :rtype: Dict[str, Union[float, Dict[str, List[float]]]] """ + query = kwargs.pop("query", None) + context = kwargs.pop("context", None) + conversation = kwargs.pop("conversation", None) + if (query is None or context is None) and conversation is None: msg = "Either a pair of 'query'/'context' or 'conversation' must be provided." raise EvaluationException( @@ -192,6 +231,3 @@ def __call__(self, *, query: Optional[str] = None, context: Optional[str] = None return async_run_allowing_running_loop( self._async_evaluator, query=query, context=context, conversation=conversation, **kwargs ) - - def _to_async(self): - return self._async_evaluator diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py index 83780f6506ef..be0d249c99b3 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_service_groundedness/_service_groundedness.py @@ -1,16 +1,17 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -from typing import Optional, Dict -from typing_extensions import override +from typing import List, Optional, Union, Dict +from typing_extensions import overload, override from azure.ai.evaluation._common._experimental import experimental from azure.ai.evaluation._common.constants import EvaluationMetrics from azure.ai.evaluation._evaluators._common import RaiServiceEvaluatorBase +from azure.ai.evaluation._model_configurations import Conversation @experimental -class GroundednessProEvaluator(RaiServiceEvaluatorBase): +class GroundednessProEvaluator(RaiServiceEvaluatorBase[Union[str, bool]]): """ Initialize a Groundedness Pro evaluator for determine if the response is grounded in the query and context. @@ -100,14 +101,48 @@ def __init__( **kwargs, ) - @override + @overload def __call__( self, *, query: Optional[str] = None, response: Optional[str] = None, context: Optional[str] = None, - conversation=None, + ) -> Dict[str, Union[str, bool]]: + """Evaluate groundedness for a given query/response/context + + :keyword query: The query to be evaluated. + :paramtype query: Optional[str] + :keyword response: The response to be evaluated. + :paramtype response: Optional[str] + :keyword context: The context to be evaluated. + :paramtype context: Optional[str] + :return: The groundedness score. + :rtype: Dict[str, Union[str, bool]] + """ + + @overload + def __call__( + self, + *, + conversation: Conversation, + ) -> Dict[str, Union[float, Dict[str, List[Union[str, bool]]]]]: + """Evaluate groundedness for a conversation for a multi-turn evaluation. If the conversation has + more than one turn, the evaluator will aggregate the results of each turn, with the per-turn results + available in the output under the "evaluation_per_turn" key. + + :keyword conversation: The conversation to evaluate. 
Expected to contain a list of conversation turns under the + key "messages", and potentially a global context under the key "context". Conversation turns are expected + to be dictionaries with keys "content", "role", and possibly "context". + :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] + :return: The groundedness score. + :rtype: Dict[str, Union[float, Dict[str, List[Union[str, bool]]]]] + """ + + @override + def __call__( # pylint: disable=docstring-missing-param + self, + *args, **kwargs, ): """Evaluate groundedness. Accepts either a query, response and context for a single-turn evaluation, or a @@ -128,7 +163,7 @@ def __call__( :return: The relevance score. :rtype: Union[Dict[str, Union[str, bool]], Dict[str, Union[float, Dict[str, List[Union[str, bool]]]]]] """ - return super().__call__(query=query, response=response, context=context, conversation=conversation, **kwargs) + return super().__call__(*args, **kwargs) @override async def _do_eval(self, eval_input: Dict): diff --git a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_xpia/xpia.py b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_xpia/xpia.py index 703fa31cf70b..9d591f8d75b7 100644 --- a/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_xpia/xpia.py +++ b/sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_evaluators/_xpia/xpia.py @@ -2,19 +2,20 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- import logging -from typing import Optional +from typing import Dict, List, Union -from typing_extensions import override +from typing_extensions import overload, override from azure.ai.evaluation._common._experimental import experimental from azure.ai.evaluation._common.constants import EvaluationMetrics from azure.ai.evaluation._evaluators._common import RaiServiceEvaluatorBase +from azure.ai.evaluation._model_configurations import Conversation logger = logging.getLogger(__name__) @experimental -class IndirectAttackEvaluator(RaiServiceEvaluatorBase): +class IndirectAttackEvaluator(RaiServiceEvaluatorBase[Union[str, bool]]): """A Cross-Domain Prompt Injection Attack (XPIA) jailbreak evaluator. Detect whether cross domain injected attacks are present in your AI system's response. @@ -65,27 +66,57 @@ def __init__( eval_last_turn=eval_last_turn, ) - @override + @overload + def __call__( + self, + *, + query: str, + response: str, + ) -> Dict[str, Union[str, bool]]: + """Evaluate whether cross domain injected attacks are present in a given query/response pair + + :keyword query: The query to be evaluated. + :paramtype query: str + :keyword response: The response to be evaluated. + :paramtype response: str + :return: The cross domain injection attack score + :rtype: Dict[str, Union[str, bool]] + """ + + @overload def __call__( self, *, - query: Optional[str] = None, - response: Optional[str] = None, - conversation=None, + conversation: Conversation, + ) -> Dict[str, Union[float, Dict[str, List[Union[str, bool]]]]]: + """Evaluate whether cross domain injected attacks are present in a conversation + + :keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the + key "messages", and potentially a global context under the key "context". Conversation turns are expected + to be dictionaries with keys "content", "role", and possibly "context". 
+ :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] + :return: The cross domain injection attack score + :rtype: Dict[str, Union[str, bool, Dict[str, List[Union[str, bool]]]]] + """ + + @override + def __call__( # pylint: disable=docstring-missing-param + self, + *args, **kwargs, ): """ Evaluate whether cross domain injected attacks are present in your AI system's response. :keyword query: The query to be evaluated. - :paramtype query: str + :paramtype query: Optional[str] :keyword response: The response to be evaluated. - :paramtype response: str + :paramtype response: Optional[str] :keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the key "messages". Conversation turns are expected to be dictionaries with keys "content" and "role". :paramtype conversation: Optional[~azure.ai.evaluation.Conversation] - :return: The fluency score. - :rtype: Union[Dict[str, Union[str, bool]], Dict[str, Union[str, bool, Dict[str, List[Union[str, bool]]]]]] + :return: The cross domain injection attack score + :rtype: Union[Dict[str, Union[str, bool]], Dict[str, Union[float, Dict[str, List[Union[str, bool]]]]]] """ - return super().__call__(query=query, response=response, conversation=conversation, **kwargs) + return super().__call__(*args, **kwargs)
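A minimal usage sketch of the overload-based __call__ surface introduced in this change, using CoherenceEvaluator as a representative example. The endpoint, key, deployment, and sample strings are placeholders, and the model_config dict shape assumes the package's AzureOpenAIModelConfiguration convention; adjust to your environment.

# Usage sketch (assumptions noted above); either pass query/response, or a conversation.
from azure.ai.evaluation import CoherenceEvaluator

# Placeholder Azure OpenAI model configuration.
model_config = {
    "azure_endpoint": "https://<your-endpoint>.openai.azure.com",
    "api_key": "<your-api-key>",
    "azure_deployment": "<your-deployment>",
}

evaluator = CoherenceEvaluator(model_config)

# Single-turn overload: keyword-only query and response.
single_turn_result = evaluator(
    query="What is the capital of France?",
    response="Paris is the capital of France.",
)

# Conversation overload: turns live under the "messages" key,
# each turn a dict with "content" and "role" (and optionally "context").
conversation = {
    "messages": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "Paris is the capital of France."},
    ]
}
conversation_result = evaluator(conversation=conversation)

Because the base class now derives its singleton inputs from the typing overloads via get_overloads, the query/response form and the conversation form are resolved from whichever keyword arguments are supplied, and positional arguments are simply forwarded to the base __call__.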