Skip to content

[Draft] Support PIL Image in llm.chat #17919

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions tests/entrypoints/llm/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import weakref

import pytest
from PIL import Image

from vllm import LLM
from vllm.distributed import cleanup_dist_env_and_memory
Expand Down Expand Up @@ -118,6 +119,29 @@ def test_chat_multi_image(vision_llm, image_urls: list[str]):
assert len(outputs) >= 0


@pytest.mark.parametrize("image_urls",
                         [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
def test_chat_multi_pil_image(vision_llm, image_urls: list[str]):
    """LLM.chat() should accept in-memory PIL images as content parts.

    Unlike test_chat_multi_image, which sends ``image_url`` parts, this
    passes already-decoded ``PIL.Image.Image`` objects via
    ``{"type": "image", "image": <PIL image>}`` parts.
    """
    # PIL's Image.open() accepts a path or a file object, not an HTTP URL,
    # and TEST_IMAGE_URLS are URLs — fetch the bytes first.
    from io import BytesIO
    from urllib.request import urlopen

    images = [
        Image.open(BytesIO(urlopen(image_url).read()))
        for image_url in image_urls
    ]

    messages = [{
        "role":
        "user",
        "content": [
            *({
                "type": "image",
                "image": image
            } for image in images),
            {
                "type": "text",
                "text": "What's in this image?"
            },
        ],
    }]
    outputs = vision_llm.chat(messages)
    # One conversation in -> exactly one RequestOutput back.
    # (len(outputs) >= 0 is vacuously true and asserts nothing.)
    assert len(outputs) == 1


def test_llm_chat_tokenization_no_double_bos(text_llm):
"""
LLM.chat() should not add special tokens when using chat templates.
Expand Down
38 changes: 37 additions & 1 deletion vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
ChatCompletionToolMessageParam)
from openai.types.chat.chat_completion_content_part_input_audio_param import (
InputAudio)
from PIL import Image
from pydantic import TypeAdapter
# yapf: enable
from transformers import (PreTrainedTokenizer, PreTrainedTokenizerFast,
Expand Down Expand Up @@ -87,6 +88,20 @@ class ChatCompletionContentPartVideoParam(TypedDict, total=False):
"""The type of the content part."""


class PILImage(TypedDict, total=False):
    """Wrapper dict holding a single in-memory, already-decoded PIL image."""

    # NOTE(review): validating a field of an arbitrary class such as
    # PIL.Image.Image through a pydantic TypeAdapter may require
    # arbitrary-types support — confirm the adapter builds at import time.
    image: Required[Image.Image]
    """
    A PIL.Image.Image object.
    """


class ChatCompletionContentPartPILImageParam(TypedDict, total=False):
    """Content part carrying an in-memory PIL image.

    ``image`` is the ``PIL.Image.Image`` object itself, not a nested
    ``PILImage`` dict: the extraction path does
    ``_PILImageParser(part).get("image")`` followed by
    ``cast(Image.Image, content)``, and callers pass
    ``{"type": "image", "image": pil_image}`` directly (see the
    accompanying test), so the declared type must match.
    """

    image: Required[Image.Image]
    """The in-memory PIL image object."""

    type: Required[Literal["image"]]
    """The type of the content part."""


class CustomChatCompletionContentSimpleImageParam(TypedDict, total=False):
"""A simpler version of the param that only accepts a plain image_url.
This is supported by OpenAI API, although it is not documented.
Expand Down Expand Up @@ -124,6 +139,7 @@ class CustomChatCompletionContentSimpleVideoParam(TypedDict, total=False):
ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
ChatCompletionContentPartInputAudioParam,
ChatCompletionContentPartPILImageParam,
ChatCompletionContentPartVideoParam, ChatCompletionContentPartRefusalParam,
CustomChatCompletionContentSimpleImageParam,
ChatCompletionContentPartImageEmbedsParam,
Expand Down Expand Up @@ -680,6 +696,10 @@ def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
raise NotImplementedError

@abstractmethod
def parse_pil_image(self, image: Image.Image) -> None:
    """Register an already-decoded in-memory PIL image with this parser.

    Concrete subclasses track the image and insert its placeholder token
    into the prompt.
    """
    raise NotImplementedError

@abstractmethod
def parse_audio(self, audio_url: str) -> None:
raise NotImplementedError
Expand Down Expand Up @@ -710,6 +730,10 @@ def parse_image(self, image_url: str) -> None:
placeholder = self._tracker.add("image", image)
self._add_placeholder(placeholder)

def parse_pil_image(self, image: Image.Image) -> None:
    """Track an in-memory PIL image and record its prompt placeholder."""
    # The image needs no fetching/decoding, so it is tracked directly.
    self._add_placeholder(self._tracker.add("image", image))

def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
if isinstance(image_embeds, dict):
Expand Down Expand Up @@ -761,6 +785,10 @@ def parse_image(self, image_url: str) -> None:
placeholder = self._tracker.add("image", image_coro)
self._add_placeholder(placeholder)

def parse_pil_image(self, image: Image.Image) -> None:
    """Track an in-memory PIL image and record its prompt placeholder.

    Unlike parse_image, which tracks a fetch coroutine, the PIL image is
    already decoded in memory, so it is added to the tracker directly.
    """
    placeholder = self._tracker.add("image", image)
    self._add_placeholder(placeholder)

def parse_image_embeds(self,
image_embeds: Union[str, dict[str, str]]) -> None:
future: asyncio.Future[Union[str, dict[str, str]]] = asyncio.Future()
Expand Down Expand Up @@ -902,6 +930,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
# Pydantic adapters that validate each supported content-part dict shape.
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam).validate_python
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam).validate_python
_VideoParser = TypeAdapter(ChatCompletionContentPartVideoParam).validate_python
# Parser for supporting raw multimodal data format
# NOTE(review): TypeAdapter schema generation for a TypedDict containing a
# PIL.Image.Image field may require arbitrary-types support — confirm this
# line does not raise at import time.
_PILImageParser = TypeAdapter(ChatCompletionContentPartPILImageParam).validate_python # noqa: E501

# Normalized value a part parser extracts from a content part.
_ContentPart: TypeAlias = Union[str, dict[str, str], InputAudio]

Expand All @@ -912,6 +942,8 @@ def _get_full_multimodal_text_prompt(placeholder_counts: dict[str, int],
] = {
"text":
lambda part: _TextParser(part).get("text", None),
"image":
lambda part: _PILImageParser(part).get("image", None),
"image_url":
lambda part: _ImageParser(part).get("image_url", {}).get("url", None),
"image_embeds":
Expand Down Expand Up @@ -985,7 +1017,7 @@ def _parse_chat_message_content_mm_part(


# Content-part "type" values the chat parser recognizes; "image" is the
# in-memory PIL variant, "image_url" the URL/base64 reference variant.
# (The scrape carried both diff sides of one line, duplicating
# "image_embeds" — this is the post-change tuple only.)
VALID_MESSAGE_CONTENT_MM_PART_TYPES = ("text", "refusal", "image_url",
                                       "image_embeds", "image",
                                       "audio_url", "input_audio", "video_url")


Expand Down Expand Up @@ -1056,6 +1088,10 @@ def _parse_chat_message_content_part(
else:
return str_content

if part_type == "image":
image = cast(Image.Image, content)
mm_parser.parse_pil_image(image)
return {'type': 'image'} if wrap_dicts else None
if part_type == "image_url":
str_content = cast(str, content)
mm_parser.parse_image(str_content)
Expand Down