def init_pooling_state(
    engine_client: "EngineClient",
    state: "State",
    args: "Namespace",
    request_logger: RequestLogger | None,
    supported_tasks: tuple["SupportedTask", ...],
):
    """Populate ``state`` with the pooling-related OpenAI serving handlers.

    Each handler (pooling, embedding, classification, scores) is constructed
    only when the model advertises a task it can serve; otherwise the
    corresponding attribute on ``state`` is set to ``None``.
    """
    from vllm.entrypoints.chat_utils import load_chat_template
    from vllm.entrypoints.pooling.classify.serving import ServingClassification
    from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
    from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
    from vllm.entrypoints.pooling.score.serving import ServingScores
    from vllm.tasks import POOLING_TASKS

    resolved_chat_template = load_chat_template(args.chat_template)

    # Keyword arguments shared by the chat-template-aware handlers below.
    chat_handler_kwargs = dict(
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
        trust_request_chat_template=args.trust_request_chat_template,
        log_error_stack=args.log_error_stack,
    )

    if any(task in supported_tasks for task in POOLING_TASKS):
        state.openai_serving_pooling = OpenAIServingPooling(
            engine_client,
            state.openai_serving_models,
            **chat_handler_kwargs,
        )
    else:
        state.openai_serving_pooling = None

    if "embed" in supported_tasks:
        state.openai_serving_embedding = OpenAIServingEmbedding(
            engine_client,
            state.openai_serving_models,
            **chat_handler_kwargs,
        )
    else:
        state.openai_serving_embedding = None

    if "classify" in supported_tasks:
        state.openai_serving_classification = ServingClassification(
            engine_client,
            state.openai_serving_models,
            **chat_handler_kwargs,
        )
    else:
        state.openai_serving_classification = None

    # ServingScores handles score/rerank for:
    # - "score" task (cross-encoder models)
    # - "embed" task (bi-encoder models)
    # - "token_embed" task (late interaction models like ColBERT)
    if any(task in supported_tasks for task in ("embed", "score", "token_embed")):
        state.openai_serving_scores = ServingScores(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            score_template=resolved_chat_template,
            log_error_stack=args.log_error_stack,
        )
    else:
        state.openai_serving_scores = None