# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
r"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.

Example usage:

python save_sharded_state.py \
    --model /path/to/load \
    --quantization deepspeedfp \
    --tensor-parallel-size 8 \
    --output /path/to/save

Then, the model can be loaded with

llm = LLM(
    model="/path/to/save",
    load_format="sharded_state",
    quantization="deepspeedfp",
    tensor_parallel_size=8,
)
"""

import argparse
import dataclasses
import os
import shutil
from pathlib import Path

from vllm import LLM, EngineArgs
from vllm.model_executor.model_loader import ShardedStateLoader
from vllm.utils.argparse_utils import FlexibleArgumentParser

# Entries in the source model directory with these extensions are NOT copied
# to the output; the output's weights are written by `save_sharded_state`.
_WEIGHT_FILE_SUFFIXES = (".bin", ".pt", ".safetensors")


def parse_args() -> argparse.Namespace:
    """Parse the vLLM engine arguments plus this script's output options."""
    parser = FlexibleArgumentParser()
    EngineArgs.add_cli_args(parser)
    parser.add_argument(
        "--output", "-o", required=True, type=str, help="path to output checkpoint"
    )
    parser.add_argument(
        "--file-pattern",
        type=str,
        default=ShardedStateLoader.DEFAULT_PATTERN,
        help="string pattern of saved filenames",
    )
    parser.add_argument(
        "--max-file-size",
        type=int,
        default=5 * 1024**3,
        help="max size (in bytes) of each safetensors file",
    )
    return parser.parse_args()


def _copy_metadata_files(model_path: str, output: str) -> None:
    """Copy every entry of ``model_path`` except weight files into ``output``.

    Subdirectories are copied recursively. Entries whose extension is in
    ``_WEIGHT_FILE_SUFFIXES`` are skipped (this check is applied to
    directories as well as files).
    """
    for name in os.listdir(model_path):
        if os.path.splitext(name)[1] in _WEIGHT_FILE_SUFFIXES:
            continue
        src = os.path.join(model_path, name)
        if os.path.isdir(src):
            # dirs_exist_ok=True lets a re-run into an existing output
            # directory succeed (the output dir itself is created with
            # exist_ok=True). Without it copytree raises FileExistsError.
            shutil.copytree(src, os.path.join(output, name), dirs_exist_ok=True)
        else:
            # shutil.copy into a directory overwrites an existing file.
            shutil.copy(src, output)


def main(args: argparse.Namespace) -> None:
    """Load the model, dump each worker's shard, then copy metadata files.

    Args:
        args: Parsed command-line arguments from ``parse_args``.

    Raises:
        ValueError: If LoRA is enabled or the model path is not a local
            directory.
    """
    engine_args = EngineArgs.from_cli_args(args)
    if engine_args.enable_lora:
        raise ValueError("Saving with enable_lora=True is not supported!")
    model_path = engine_args.model
    # Metadata files are copied out of this directory below, so it must be a
    # local directory (checked before the expensive model load).
    if not Path(model_path).is_dir():
        raise ValueError("model path must be a local directory")

    # Create LLM instance from arguments
    llm = LLM(**dataclasses.asdict(engine_args))

    # Prepare output directory; parents=True so nested output paths work.
    Path(args.output).mkdir(parents=True, exist_ok=True)

    # Dump worker states to output directory
    llm.llm_engine.engine_core.save_sharded_state(
        path=args.output, pattern=args.file_pattern, max_size=args.max_file_size
    )

    # Copy metadata (config, tokenizer, etc.) files to output directory
    _copy_metadata_files(model_path, args.output)


if __name__ == "__main__":
    args = parse_args()
    main(args)