# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
r"""
Validates the loading of a model saved with the sharded_state format.
This script demonstrates how to load a model that was previously saved
using save_sharded_state.py and validates it by running inference.

Example usage:
(First you need to save a model in the sharded_state format.)

python save_sharded_state.py \
    --model /path/to/load \
    --quantization deepspeedfp \
    --tensor-parallel-size 8 \
    --output /path/to/save/sharded/model

python load_sharded_state.py \
    --model /path/to/saved/sharded/model \
    --load-format sharded_state \
    --quantization deepspeedfp \
    --tensor-parallel-size 8 \
    --prompt "Hello, my name is" \
    --max-tokens 50
"""

import dataclasses

from vllm import LLM, EngineArgs, SamplingParams
from vllm.utils.argparse_utils import FlexibleArgumentParser


def parse_args():
    """Build the CLI parser and parse the command line.

    The parser carries every ``EngineArgs`` option plus a few flags that
    control the single validation generation run by ``main``.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    parser = FlexibleArgumentParser()
    # Add engine arguments
    EngineArgs.add_cli_args(parser)

    # Default to the format this script exists to validate, so users do not
    # have to pass --load-format explicitly (an explicit flag still wins).
    parser.set_defaults(load_format="sharded_state")

    # Add validation arguments
    parser.add_argument(
        "--prompt", type=str, default="Hello, world!", help="Prompt for validation"
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=100,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.7, help="Sampling temperature"
    )
    parser.add_argument(
        "--top-p", type=float, default=1.0, help="Top-p sampling parameter"
    )

    return parser.parse_args()


def main():
    """Load the model described by the CLI arguments and generate once.

    Builds ``EngineArgs`` from the parsed namespace, constructs an ``LLM``
    from them, runs ``llm.generate`` on ``--prompt``, and prints the prompt
    followed by the generated continuation.
    """
    args = parse_args()
    engine_args = EngineArgs.from_cli_args(args)

    print(
        f"Loading model from {engine_args.model} "
        f"using format {engine_args.load_format}"
    )
    print(f"Tensor parallel size: {engine_args.tensor_parallel_size}")

    # Load the model using engine args
    llm = LLM(**dataclasses.asdict(engine_args))

    # Prepare sampling parameters
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
    )

    print("\nRunning inference:")
    print(f"Prompt: {args.prompt}")

    # Generate completion
    outputs = llm.generate(args.prompt, sampling_params)

    # Display generated text
    print("\nGenerated outputs:")
    for output in outputs:
        # Only the first completion of each request is shown.
        generated_text = output.outputs[0].text
        print("-" * 50)
        print(f"Full output: {args.prompt}{generated_text}")
        print("-" * 50)


if __name__ == "__main__":
    main()