Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion fastchat/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
If you have any changes in mind, please contribute back so the community can benefit collectively and continue to maintain these valuable templates.
"""

import base64
import dataclasses
from enum import auto, IntEnum
from io import BytesIO
from typing import List, Any, Dict, Union, Tuple


Expand Down Expand Up @@ -34,6 +36,9 @@ class SeparatorStyle(IntEnum):
YUAN2 = auto()


# Placeholder token for an attached image; get_prompt prepends it to a
# message's text once per image when the message is a (text, images) tuple.
IMAGE_PLACEHOLDER_STR = "$$<image>$$"


@dataclasses.dataclass
class Conversation:
"""A class that manages prompt templates and keeps all conversation history."""
Expand All @@ -47,6 +52,7 @@ class Conversation:
# The names of two roles
roles: Tuple[str] = ("USER", "ASSISTANT")
# All messages. Each item is (role, message).
# Each message is either a string or a tuple of (string, List[image_url]).
messages: List[List[str]] = ()
# The number of few shot examples
offset: int = 0
Expand Down Expand Up @@ -77,6 +83,7 @@ def get_prompt(self) -> str:
if message:
if type(message) is tuple:
message, images = message
message = IMAGE_PLACEHOLDER_STR * len(images) + message
ret += role + ": " + message + seps[i % 2]
else:
ret += role + ":"
Expand Down Expand Up @@ -289,11 +296,52 @@ def update_last_message(self, message: str):
"""
self.messages[-1][1] = message

def convert_image_to_base64(self, image):
    """Return the given image as a base64-encoded string.

    Args:
        image: Either an already-loaded image object exposing the PIL
            ``size`` / ``resize`` / ``save`` interface, or a string that is
            one of: an ``http://`` / ``https://`` URL, a data URI of the form
            ``data:image/jpeg;base64,<payload>``, or a local file path.

    Returns:
        The base64 text. For a data-URI input the embedded payload is
        returned as-is (it is neither decoded nor resized). Any other input
        is scaled down if needed and re-encoded as PNG before being
        base64-encoded.

    Raises:
        requests.RequestException: If a URL cannot be fetched in time or the
            server answers with an error status.
    """
    # Seconds to wait for a remote image before giving up; without a timeout
    # requests.get can block forever on an unresponsive host.
    request_timeout = 30

    # Load image if it has not been loaded in yet
    if isinstance(image, str):
        if image.startswith(("http://", "https://")):
            # Third-party imports stay local so the data-URI path below
            # works without PIL or requests being importable.
            from PIL import Image
            import requests

            response = requests.get(image, timeout=request_timeout)
            # Fail with a clear HTTP error instead of handing an error page
            # to the image decoder.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert("RGB")
        else:
            # OpenAI format is: data:image/jpeg;base64,{base64_encoded_image_str}
            # Only the part before the first comma is inspected, so a file
            # path that merely contains "base64" is still opened as a file.
            header, comma, _ = image.partition(",")
            if comma and "base64" in header:
                return image.split(",")[1]

            from PIL import Image

            image = Image.open(image).convert("RGB")

    max_hw, min_hw = max(image.size), min(image.size)
    aspect_ratio = max_hw / min_hw
    max_len, min_len = 2048, 2048
    # Clamp to 1 pixel: for very elongated images max_len / aspect_ratio
    # truncates to 0, which is not a valid resize target.
    shortest_edge = max(1, int(min(max_len / aspect_ratio, min_len, min_hw)))
    longest_edge = int(shortest_edge * aspect_ratio)
    W, H = image.size
    if longest_edge != max_hw:
        # Keep the original orientation when assigning the new edges.
        if H > W:
            H, W = longest_edge, shortest_edge
        else:
            H, W = shortest_edge, longest_edge
        image = image.resize((W, H))

    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()

def to_gradio_chatbot(self):
"""Convert the conversation to gradio chatbot format."""
ret = []
for i, (role, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
if type(msg) is tuple:
msg, image = msg
img_b64_str = image[0] # Only one image on gradio at one time
img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
msg = img_str + msg.replace("<image>\n", "").strip()

ret.append([msg, None])
else:
ret[-1][-1] = msg
Expand All @@ -314,6 +362,12 @@ def to_openai_api_messages(self):
ret.append({"role": "assistant", "content": msg})
return ret

def extract_text_from_messages(self):
    """Return the stored messages as (role, text) pairs.

    A message stored as a tuple contributes only its first element; any
    other message is passed through unchanged.
    """
    text_only = []
    for role, message in self.messages:
        if type(message) is tuple:
            # Tuple messages carry extra data after the text; keep the text.
            text = message[0]
        else:
            text = message
        text_only.append((role, text))
    return text_only

def copy(self):
return Conversation(
name=self.name,
Expand All @@ -334,7 +388,7 @@ def dict(self):
"template_name": self.name,
"system_message": self.system_message,
"roles": self.roles,
"messages": self.messages,
"messages": self.extract_text_from_messages(),
"offset": self.offset,
}

Expand Down
Loading