Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
lint and run black
  • Loading branch information
needuv committed Oct 25, 2024
commit ab5a3d80fef79eed44e449b3e4e5738a4f7080da
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@


@overload
def experimental(wrapped: Type[T]) -> Type[T]: ...
def experimental(wrapped: Type[T]) -> Type[T]:
...


@overload
def experimental(wrapped: Callable[P, T]) -> Callable[P, T]: ...
def experimental(wrapped: Callable[P, T]) -> Callable[P, T]:
...


def experimental(wrapped: Union[Type[T], Callable[P, T]]) -> Union[Type[T], Callable[P, T]]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __call__(
response: str,
) -> Dict[str, float]:
"""Evaluate coherence for given input of query, response

:keyword query: The query to be evaluated.
:paramtype query: str
:keyword response: The response to be evaluated.
Expand All @@ -76,7 +76,7 @@ def __call__(
**kwargs,
) -> Dict[str, Union[float, Dict[str, List[float]]]]:
"""Evaluate coherence for a conversation

:keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
key "messages", and potentially a global context under the key "context". Conversation turns are expected
to be dictionaries with keys "content", "role", and possibly "context".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __call__(
response: str,
) -> Dict[str, Union[str, float]]:
"""Evaluate a collection of content safety metrics for the given query/response pair

:keyword query: The query to be evaluated.
:paramtype query: str
:keyword response: The response to be evaluated.
Expand All @@ -105,7 +105,7 @@ def __call__(
**kwargs,
) -> Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]:
"""Evaluate a collection of content safety metrics for a conversation

:keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
key "messages", and potentially a global context under the key "context". Conversation turns are expected
to be dictionaries with keys "content", "role", and possibly "context".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
credential=credential,
eval_last_turn=eval_last_turn,
)

@overload
def __call__(
self,
Expand All @@ -67,7 +67,7 @@ def __call__(
response: str,
) -> Dict[str, Union[str, float]]:
"""Evaluate the given query/response pair for hateful content

:keyword query: The query to be evaluated.
:paramtype query: str
:keyword response: The response to be evaluated.
Expand All @@ -85,7 +85,7 @@ def __call__(
**kwargs,
) -> Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]:
"""Evaluate a conversation for hateful content

:keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
key "messages", and potentially a global context under the key "context". Conversation turns are expected
to be dictionaries with keys "content", "role", and possibly "context".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __call__(
response: str,
) -> Dict[str, Union[str, float]]:
"""Evaluate a given query/response pair for self-harm content

:keyword query: The query to be evaluated.
:paramtype query: str
:keyword response: The response to be evaluated.
Expand All @@ -85,7 +85,7 @@ def __call__(
**kwargs,
) -> Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]:
"""Evaluate a conversation for self-harm content

:keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
key "messages", and potentially a global context under the key "context". Conversation turns are expected
to be dictionaries with keys "content", "role", and possibly "context".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __call__(
response: str,
) -> Dict[str, Union[str, float]]:
"""Evaluate a given query/response pair for sexual content

:keyword query: The query to be evaluated.
:paramtype query: str
:keyword response: The response to be evaluated.
Expand All @@ -83,9 +83,9 @@ def __call__(
*,
conversation: Conversation,
**kwargs,
) -> Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]:
) -> Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]:
"""Evaluate a conversation for sexual content

:keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
key "messages", and potentially a global context under the key "context". Conversation turns are expected
to be dictionaries with keys "content", "role", and possibly "context".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __call__(
response: str,
) -> Dict[str, Union[str, float]]:
"""Evaluate a given query/response pair for violent content

:keyword query: The query to be evaluated.
:paramtype query: str
:keyword response: The response to be evaluated.
Expand All @@ -85,7 +85,7 @@ def __call__(
**kwargs,
) -> Dict[str, Union[str, float, Dict[str, List[Union[str, float]]]]]:
"""Evaluate a conversation for violent content

:keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
key "messages", and potentially a global context under the key "context". Conversation turns are expected
to be dictionaries with keys "content", "role", and possibly "context".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __call__(
response: str,
) -> Dict[str, float]:
"""Evaluate fluency in given query/response

:keyword query: The query to be evaluated.
:paramtype query: str
:keyword response: The response to be evaluated.
Expand All @@ -76,8 +76,8 @@ def __call__(
conversation: Conversation,
**kwargs,
) -> Dict[str, Union[float, Dict[str, List[float]]]]:
"""Evaluate fluency for a conversation
"""Evaluate fluency for a conversation

:keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
key "messages", and potentially a global context under the key "context". Conversation turns are expected
to be dictionaries with keys "content", "role", and possibly "context".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __call__(
context: str,
) -> Dict[str, float]:
"""Evaluate groundedness for given input of response, context

:keyword response: The response to be evaluated.
:paramtype response: str
:keyword context: The context to be evaluated.
Expand All @@ -77,7 +77,7 @@ def __call__(
**kwargs,
) -> Dict[str, Union[float, Dict[str, List[float]]]]:
"""Evaluate groundedness for a conversation

:keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
key "messages", and potentially a global context under the key "context". Conversation turns are expected
to be dictionaries with keys "content", "role", and possibly "context".
Expand All @@ -86,7 +86,6 @@ def __call__(
:rtype: Dict[str, Union[float, Dict[str, List[float]]]]
"""
...


@override
def __call__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __call__(
response: str,
) -> Dict[str, Union[str, bool]]:
"""Evaluate a given query/response pair for protected material

:keyword query: The query to be evaluated.
:paramtype query: str
:keyword response: The response to be evaluated.
Expand All @@ -86,7 +86,7 @@ def __call__(
**kwargs,
) -> Dict[str, Union[str, bool, Dict[str, List[Union[str, bool]]]]]:
"""Evaluate a conversation for protected material

:keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
key "messages", and potentially a global context under the key "context". Conversation turns are expected
to be dictionaries with keys "content", "role", and possibly "context".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __call__(
context: str,
) -> Dict[str, float]:
"""Evaluate groundedness for given input of query, response, context

:keyword query: The query to be evaluated.
:paramtype query: str
:keyword response: The response to be evaluated.
Expand All @@ -73,7 +73,7 @@ def __call__(
:return: The relevance score.
:rtype: Dict[str, float]
"""

...

@overload
Expand All @@ -84,7 +84,7 @@ def __call__(
**kwargs,
) -> Dict[str, Union[float, Dict[str, List[float]]]]:
"""Evaluate relevance for a conversation

:keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
key "messages", and potentially a global context under the key "context". Conversation turns are expected
to be dictionaries with keys "content", "role", and possibly "context".
Expand All @@ -93,7 +93,6 @@ def __call__(
:rtype: Dict[str, Union[float, Dict[str, List[float]]]]
"""
...


@override
def __call__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def __init__(
eval_last_turn=eval_last_turn,
)


@overload
def __call__(
self,
Expand All @@ -75,7 +74,7 @@ def __call__(
response: str,
) -> Dict[str, Union[str, bool]]:
"""Evaluate whether cross domain injected attacks are present in given query/response

:keyword query: The query to be evaluated.
:paramtype query: str
:keyword response: The response to be evaluated.
Expand All @@ -92,8 +91,8 @@ def __call__(
conversation: Conversation,
**kwargs,
) -> Dict[str, Union[str, bool, Dict[str, List[Union[str, bool]]]]]:
"""Evaluate cross domain injected attacks are present in a conversation
"""Evaluate cross domain injected attacks are present in a conversation

:keyword conversation: The conversation to evaluate. Expected to contain a list of conversation turns under the
key "messages", and potentially a global context under the key "context". Conversation turns are expected
to be dictionaries with keys "content", "role", and possibly "context".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,6 @@ async def _simulate_with_predefined_turns(
semaphore = asyncio.Semaphore(concurrent_async_tasks)
progress_bar_lock = asyncio.Lock()


async def run_simulation(simulation: List[Union[str, Dict[str, Any]]]) -> JsonLineChatProtocol:
async with semaphore:
current_simulation = ConversationHistory()
Expand Down