Skip to content

vllm.distributed.elastic_ep.elastic_state

ElasticEPScalingState

Source code in vllm/distributed/elastic_ep/elastic_state.py
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
class ElasticEPScalingState:
    """State machine that drives one elastic-EP scaling operation on a DP engine.

    An instance is created on every DP engine core taking part in a scale-up
    or a scale-down. ``scale_type`` and ``worker_type`` fix its role when it
    is constructed:

    * ``"scale_up"`` with ``worker_type == "new"``: a newly launched engine,
      driven through ``ScaleUpNewEngineState``.
    * ``"scale_up"`` with any other worker type: a pre-existing engine,
      driven through ``ScaleUpExistingEngineState``.
    * ``"scale_down"`` with ``worker_type == "removing"``: an engine that is
      removed and shuts itself down at the end
      (``ScaleDownRemovingEngineState``).
    * ``"scale_down"`` with any other worker type: an engine that stays
      (``ScaleDownRemainingEngineState``).

    ``progress()`` runs the work of the current state if its preconditions
    are met and reports whether it did. It is meant to be polled until
    ``is_complete()`` returns True.

    Stages that every engine must run together are gated in two steps:
    1. A shared TCPStore counter (``_ENGINE_COUNT_KEY``) must reach the DP
       group size.
    2. Every engine must pass the two-stage barrier (``_staged_barrier``).
    """

    # TCPStore counter that each engine increments (``store.add``) once it is
    # ready for the next synchronized stage. A stage only runs after the
    # counter has reached the DP group size. Rank 0 of that group then
    # deletes the counter so that it can be reused by the next stage.
    _ENGINE_COUNT_KEY = "eep_barrier_engine_count"

    def __init__(
        self,
        model_executor: "Executor",
        engine_core: "DPEngineCoreProc",
        vllm_config: "VllmConfig",
        new_parallel_config: ParallelConfig,
        worker_type: WorkerType,
        scale_type: Literal["scale_up", "scale_down"],
        reconfig_request: ReconfigureDistributedRequest | None = None,
    ):
        """
        Args:
            model_executor: Executor whose workers receive the
                ``elastic_ep_execute`` collective RPCs. Held by weak reference.
            engine_core: The DP engine core being reconfigured. Held by weak
                reference.
            vllm_config: Config whose ``parallel_config`` is updated in place
                once scaling completes (see ``_update_parallel_config``).
            new_parallel_config: Parallel config for the post-scaling
                topology, used to create the new DP group.
            worker_type: Role of this engine: ``"new"``, ``"removing"``, or
                another value for engines that already existed or remain.
            scale_type: Whether this is a scale-up or a scale-down.
            reconfig_request: The reconfiguration request. Stages that read
                the new DP size or ports assert that it is not None.
        """
        # Weak references so this state object does not keep the executor or
        # the engine core alive. The properties below raise if they are gone.
        self.model_executor_ref = weakref.ref(model_executor)
        self.engine_core_ref = weakref.ref(engine_core)
        self.vllm_config = vllm_config
        # For a new engine, the engine core's current DP group/store already
        # is the new one. Every other engine starts in the old group and
        # creates the new one in _create_standby_groups().
        self.old_dp_group = self.engine_core.dp_group if worker_type != "new" else None
        self.old_dp_store = self.engine_core.dp_store if worker_type != "new" else None
        self.new_parallel_config: ParallelConfig = new_parallel_config
        self.new_dp_group: torch.distributed.ProcessGroup | None = (
            self.engine_core.dp_group if worker_type == "new" else None
        )
        self.new_dp_store = self.engine_core.dp_store if worker_type == "new" else None
        self.worker_type = worker_type
        self.scale_type = scale_type
        self.reconfig_request = reconfig_request

        if scale_type == "scale_up":
            self.state = (
                ScaleUpNewEngineState.PREPARE
                if worker_type == "new"
                else ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
            )
        else:
            self.state = (
                ScaleDownRemovingEngineState.PREPARE
                if worker_type == "removing"
                else ScaleDownRemainingEngineState.PREPARE
            )

    @property
    def model_executor(self) -> "Executor":
        """The executor, or RuntimeError if it has been garbage collected."""
        model_executor = self.model_executor_ref()
        if model_executor is None:
            raise RuntimeError("Model executor has been garbage collected")
        return model_executor

    @property
    def engine_core(self) -> "DPEngineCoreProc":
        """The engine core, or RuntimeError if it has been garbage collected."""
        engine_core = self.engine_core_ref()
        if engine_core is None:
            raise RuntimeError("Engine core has been garbage collected")
        return engine_core

    def progress(self) -> bool:
        """Run the current stage for this engine's role.

        Returns:
            False if the current stage is still waiting: on a notification,
            on other engines, or because the timed first barrier stage
            expired. True otherwise, including when scaling has already
            completed.
        """
        if self.scale_type == "scale_up":
            return (
                self._progress_new_engine()
                if self.worker_type == "new"
                else self._progress_existing_engine()
            )
        return (
            self._progress_removing_engine()
            if self.worker_type == "removing"
            else self._progress_remaining_engine()
        )

    @staticmethod
    def _arrival_key(barrier_id: str, rank: int) -> str:
        """Store key that ``rank`` sets when it arrives at ``barrier_id``."""
        return f"arrival_{barrier_id}_{rank}"

    def _dp_store_and_group(self, use_new_group: bool):
        """Return ``(dp_store, dp_group)`` for the new or the old DP group."""
        if use_new_group:
            dp_store, dp_group = self.new_dp_store, self.new_dp_group
        else:
            dp_store, dp_group = self.old_dp_store, self.old_dp_group
        assert dp_store is not None and dp_group is not None
        return dp_store, dp_group

    def _execute_tcp_store_barrier(
        self,
        dp_store,
        group_rank: int,
        group_size: int,
        barrier_id: str,
        timeout: timedelta | None = None,
    ):
        """Block until all ``group_size`` ranks have arrived at ``barrier_id``.

        This rank first publishes its own arrival key. It then polls
        ``dp_store`` for the arrival keys of all ranks, calling
        ``sched_yield()`` between polls.

        Raises:
            _BarrierTimeoutError: If ``timeout`` elapses before every rank
                has arrived. With ``timeout=None`` the wait is unbounded.
        """
        dp_store.set(self._arrival_key(barrier_id, group_rank), b"1")

        timeout_s = None if timeout is None else timeout.total_seconds()
        # Monotonic clock: elapsed-time measurement must not be affected by
        # wall-clock adjustments.
        start_time = time.monotonic()
        processes_arrived: set[int] = set()

        while len(processes_arrived) < group_size:
            if timeout_s is not None and time.monotonic() - start_time > timeout_s:
                raise _BarrierTimeoutError(
                    f"Barrier timed out after {timeout_s} seconds"
                )

            for i in range(group_size):
                if i not in processes_arrived and dp_store.check(
                    [self._arrival_key(barrier_id, i)]
                ):
                    processes_arrived.add(i)

            if len(processes_arrived) < group_size:
                sched_yield()

    def _staged_barrier(self, use_new_group: bool, barrier_name: str) -> bool:
        """Execute a two-stage barrier to synchronize all engines in the DP group.

        Some DP EngineCores may receive the reconfiguration notification later
        than others and may already have proceeded to the engine step (model
        forward) in the busy loop. In that case, the EngineCores that have
        already entered reconfiguration should skip it and run model forward
        for one more step. All EngineCores are then synchronized at the next
        step.

        The first time each EngineCore executes the barrier, it waits with a
        timeout. If the timeout expires before the barrier completes, some
        EngineCores must already be in the engine step. The EngineCores that
        timed out then:
        1. set ``<barrier_id>_sync`` in the store,
        2. return False,
        3. proceed to the engine step as well.
        On the next attempt the sync key is present, so every EngineCore
        waits without a timeout.

        Args:
            use_new_group: Synchronize over the new DP group/store instead of
                the old one.
            barrier_name: Name that keys this barrier in the store.

        Returns:
            True if all engines passed the barrier. False if the timed first
            stage expired and the caller should retry on a later step.

        Raises:
            RuntimeError: If the untimed second stage reports a timeout.
        """
        dp_store, dp_group = self._dp_store_and_group(use_new_group)

        group_rank = dp_group.rank()
        group_size = dp_group.size()
        barrier_id = f"eep_barrier_{barrier_name}"
        sync_key = f"{barrier_id}_sync"

        # TODO(yongji): figure out appropriate timeout for the barrier
        # Only the first attempt is timed. Once any engine has timed out and
        # set the sync key, every later attempt waits indefinitely.
        timeout = None if dp_store.check([sync_key]) else timedelta(seconds=5)

        try:
            self._execute_tcp_store_barrier(
                dp_store, group_rank, group_size, barrier_id, timeout=timeout
            )
        except _BarrierTimeoutError as e:
            if timeout is None:
                raise RuntimeError("Unexpected timeout encountered") from e
            # Create the sync key if it is absent so that later attempts by
            # any engine use the untimed second stage.
            dp_store.compare_set(sync_key, "", b"1")
            return False

        torch.distributed.barrier(dp_group)
        # Every rank has left the store-based barrier at this point, so rank 0
        # can safely remove the keys it used.
        if group_rank == 0:
            dp_store.delete_key(sync_key)
            for i in range(group_size):
                dp_store.delete_key(self._arrival_key(barrier_id, i))
        return True

    def _all_engines_ready(self, use_new_group: bool, barrier_name: str) -> bool:
        """Gate a synchronized stage on the engine counter and the staged barrier.

        Returns False without entering the barrier while fewer engines than
        the group size have incremented ``_ENGINE_COUNT_KEY``. Otherwise
        returns the result of ``_staged_barrier``.
        """
        dp_store, dp_group = self._dp_store_and_group(use_new_group)
        if int(dp_store.get(self._ENGINE_COUNT_KEY)) < dp_group.size():
            return False
        return self._staged_barrier(
            use_new_group=use_new_group, barrier_name=barrier_name
        )

    def _clear_engine_count(self, use_new_group: bool) -> None:
        """On rank 0 of the chosen group, delete the engine counter for reuse."""
        dp_store, dp_group = self._dp_store_and_group(use_new_group)
        if dp_group.rank() == 0:
            dp_store.delete_key(self._ENGINE_COUNT_KEY)

    def _progress_existing_engine(self) -> bool:
        """Advance a pre-existing engine through a scale-up."""
        state = self.state

        if state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT:
            # Left via handle_notification(NEW_CORE_ENGINES_INIT_READY).
            return False

        elif state == ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS:
            # NOTE(yongji): wait for all existing workers to receive the request
            if not self._all_engines_ready(
                use_new_group=False, barrier_name="create_standby_groups"
            ):
                return False
            self._clear_engine_count(use_new_group=False)
            self._create_standby_groups()
            self.state = ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING
            return True

        elif state == ScaleUpExistingEngineState.TRANSFER_EXPERT_MAPPING:
            self._transfer_expert_mapping()
            self.state = ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
            return True

        elif state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT:
            # Left via handle_notification(NEW_CORE_ENGINES_WEIGHTS_INIT_READY).
            return False

        elif state == ScaleUpExistingEngineState.TRANSFER_WEIGHTS:
            if not self._all_engines_ready(
                use_new_group=False, barrier_name="transfer_weights"
            ):
                return False
            self._clear_engine_count(use_new_group=False)
            self._transfer_weights()
            self.state = ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE
            return True

        elif state == ScaleUpExistingEngineState.SYNC_KV_CACHE_MEMORY_SIZE:
            self._sync_kv_cache_memory_size()
            self.state = ScaleUpExistingEngineState.SWITCH_AND_PREPARE
            return True

        elif state == ScaleUpExistingEngineState.SWITCH_AND_PREPARE:
            self._switch_and_prepare()
            self.state = ScaleUpExistingEngineState.EPLB_RESHUFFLE
            # Check in on the *new* group's counter for the reshuffle stage.
            self.new_dp_store.add(self._ENGINE_COUNT_KEY, 1)
            return True

        elif state == ScaleUpExistingEngineState.EPLB_RESHUFFLE:
            if not self._all_engines_ready(
                use_new_group=True, barrier_name="eplb_reshuffle"
            ):
                return False
            self._clear_engine_count(use_new_group=True)
            self._eplb_reshuffle()
            self.state = ScaleUpExistingEngineState.COMPLETE
            self._update_parallel_config()
            return True

        else:
            assert self.state == ScaleUpExistingEngineState.COMPLETE
            return True

    def _progress_new_engine(self) -> bool:
        """Advance a newly launched engine through a scale-up."""
        state = self.state
        assert self.new_dp_group is not None

        if state == ScaleUpNewEngineState.PREPARE:
            # Contribute zeros to the MAX all-reduce that existing engines run
            # in _switch_and_prepare(). This engine thereby adopts the group's
            # largest engines_running / current_wave / step_counter.
            tensor = torch.tensor([0, 0, 0], dtype=torch.int32, device="cpu")
            torch.distributed.all_reduce(
                tensor,
                op=torch.distributed.ReduceOp.MAX,
                group=self.new_dp_group,
            )
            data = tensor.tolist()
            self.engine_core.engines_running = bool(data[0])
            self.engine_core.current_wave = int(data[1])
            self.engine_core.step_counter = int(data[2])
            self.state = ScaleUpNewEngineState.EPLB_RESHUFFLE
            self.new_dp_store.add(self._ENGINE_COUNT_KEY, 1)
            return True

        elif state == ScaleUpNewEngineState.EPLB_RESHUFFLE:
            if not self._all_engines_ready(
                use_new_group=True, barrier_name="eplb_reshuffle"
            ):
                return False
            # A new engine is never rank 0, so it leaves counter cleanup to
            # an existing engine.
            assert self.new_dp_group.rank() > 0
            self._eplb_reshuffle()
            self.state = ScaleUpNewEngineState.COMPLETE
            return True

        else:
            assert self.state == ScaleUpNewEngineState.COMPLETE
            return True

    def _progress_remaining_engine(self) -> bool:
        """Advance an engine that stays in the DP group through a scale-down."""
        state = self.state

        if state == ScaleDownRemainingEngineState.PREPARE:
            self.state = ScaleDownRemainingEngineState.EPLB_RESHUFFLE
            self.old_dp_store.add(self._ENGINE_COUNT_KEY, 1)
            return True

        elif state == ScaleDownRemainingEngineState.EPLB_RESHUFFLE:
            if not self._all_engines_ready(
                use_new_group=False, barrier_name="eplb_reshuffle"
            ):
                return False
            self._clear_engine_count(use_new_group=False)
            self._eplb_reshuffle_before_scale_down()
            # Intermediate state; replaced by COMPLETE below once the switch
            # to the new group has finished.
            self.state = ScaleDownRemainingEngineState.SWITCH_AND_PREPARE
            # NOTE(yongji): currently, after EPLB reshuffle
            # that redistributes experts to remaining workers, workers
            # to be removed will immediately initiate shutdown;
            # existing workers can no longer execute forward steps using
            # the old setup. In the future, we may keep
            # the removing workers alive a bit longer,
            # e.g., to drain in-batch requests.
            self._create_standby_groups()
            self._switch_and_prepare()
            self._update_parallel_config()
            self.state = ScaleDownRemainingEngineState.COMPLETE
            return True

        else:
            assert self.state == ScaleDownRemainingEngineState.COMPLETE
            return True

    def _progress_removing_engine(self) -> bool:
        """Advance an engine that is being removed through a scale-down.

        After the reshuffle this engine switches its workers out, notifies
        with SHUTDOWN_COMPLETE and shuts the engine core down.
        """
        state = self.state

        if state == ScaleDownRemovingEngineState.PREPARE:
            self.state = ScaleDownRemovingEngineState.EPLB_RESHUFFLE
            self.old_dp_store.add(self._ENGINE_COUNT_KEY, 1)
            return True

        elif state == ScaleDownRemovingEngineState.EPLB_RESHUFFLE:
            if not self._all_engines_ready(
                use_new_group=False, barrier_name="eplb_reshuffle"
            ):
                return False
            # A removing engine is never rank 0, so it leaves counter cleanup
            # to a remaining engine.
            assert self.old_dp_group.rank() > 0
            self._eplb_reshuffle_before_scale_down()
            self._switch_and_remove()
            self.state = ScaleDownRemovingEngineState.COMPLETE
            self.engine_core._eep_send_engine_core_notification(
                EEPNotificationType.SHUTDOWN_COMPLETE
            )
            self.engine_core.shutdown()
            return True

        else:
            assert self.state == ScaleDownRemovingEngineState.COMPLETE
            return True

    def handle_notification(self, notification_type: EEPNotificationType):
        """React to a notification on a pre-existing engine.

        Moves an existing engine out of one of its two WAIT states and checks
        in on the old group's engine counter. Notifications that do not match
        the current state are ignored.
        """
        assert self.worker_type != "new"
        if (
            notification_type == EEPNotificationType.NEW_CORE_ENGINES_INIT_READY
            and self.state == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_INIT
        ):
            self.old_dp_store.add(self._ENGINE_COUNT_KEY, 1)
            self.state = ScaleUpExistingEngineState.CREATE_STANDBY_GROUPS
        elif (
            notification_type == EEPNotificationType.NEW_CORE_ENGINES_WEIGHTS_INIT_READY
            and self.state
            == ScaleUpExistingEngineState.WAIT_NEW_CORE_ENGINES_WEIGHTS_INIT
        ):
            self.old_dp_store.add(self._ENGINE_COUNT_KEY, 1)
            self.state = ScaleUpExistingEngineState.TRANSFER_WEIGHTS

    def is_complete(self) -> bool:
        """Whether this engine has reached the COMPLETE state for its role."""
        if self.scale_type == "scale_up":
            return (
                self.state == ScaleUpNewEngineState.COMPLETE
                if self.worker_type == "new"
                else self.state == ScaleUpExistingEngineState.COMPLETE
            )
        return (
            self.state == ScaleDownRemovingEngineState.COMPLETE
            if self.worker_type == "removing"
            else self.state == ScaleDownRemainingEngineState.COMPLETE
        )

    def _create_standby_groups(self):
        """Create the new DP group/store and have workers build standby groups."""
        self.new_dp_group, self.new_dp_store = (
            self.new_parallel_config.stateless_init_dp_group(return_store=True)
        )
        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("create_standby_groups", self.reconfig_request)
        )
        if self.old_dp_group.rank() == 0:
            logger.info("[Elastic EP] Created standby communication groups")

    def _transfer_weights(self):
        """Ask workers to transfer weights, passing the old and new DP sizes."""
        assert self.reconfig_request is not None
        old_dp_size = self.old_dp_group.size()
        new_dp_size = self.reconfig_request.new_data_parallel_size

        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("transfer_weights", old_dp_size, new_dp_size)
        )
        if self.old_dp_group.rank() == 0:
            logger.info("[Elastic EP] Transferred weights to new workers")

    def _transfer_expert_mapping(self):
        """Ask workers to broadcast the expert mapping."""
        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("broadcast_expert_mapping",)
        )
        if self.old_dp_group.rank() == 0:
            logger.info("[Elastic EP] Broadcasted expert mapping to new workers")

    def _sync_kv_cache_memory_size(self):
        """Sync this engine's KV-cache memory budget over the new DP group."""
        assert self.engine_core.available_gpu_memory_for_kv_cache > 0
        assert self.new_dp_group is not None
        ParallelConfig.sync_kv_cache_memory_size(
            self.new_dp_group,
            self.engine_core.available_gpu_memory_for_kv_cache,
        )
        if self.old_dp_group.rank() == 0:
            logger.info("[Elastic EP] Synced KV cache memory size to new workers")

    def _switch_and_prepare(self):
        """Switch this engine (and its workers) from the old to the new DP group.

        This method:
        1. destroys the old DP group,
        2. points the engine core at the new group, rank and store,
        3. MAX-all-reduces ``[engines_running, current_wave, step_counter]``
           over the new group so that all engines agree on them,
        4. sends RECONFIGURE_FINISHED if this engine is the new rank 0.
        """
        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("switch_and_prepare",)
        )
        old_dp_group = self.old_dp_group
        stateless_destroy_torch_distributed_process_group(old_dp_group)
        assert self.new_dp_group is not None
        new_dp_group = self.new_dp_group
        self.engine_core.dp_group = new_dp_group
        self.engine_core.dp_rank = new_dp_group.rank()
        self.engine_core.dp_store = self.new_dp_store
        engines_running = int(self.engine_core.engines_running)
        current_wave = self.engine_core.current_wave
        step_counter = self.engine_core.step_counter
        tensor = torch.tensor(
            [engines_running, current_wave, step_counter],
            dtype=torch.int32,
            device="cpu",
        )
        # During a scale-up, new engines join this collective from their
        # PREPARE state.
        torch.distributed.all_reduce(
            tensor, op=torch.distributed.ReduceOp.MAX, group=new_dp_group
        )
        data = tensor.tolist()
        self.engine_core.engines_running = bool(data[0])
        self.engine_core.current_wave = int(data[1])
        self.engine_core.step_counter = int(data[2])
        if new_dp_group.rank() == 0:
            self.engine_core._eep_send_engine_core_notification(
                EEPNotificationType.RECONFIGURE_FINISHED
            )
            logger.info("[Elastic EP] Switched to new setup")

    def _eplb_reshuffle(self):
        """Ask workers to perform the EPLB reshuffle (scale-up path)."""
        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("perform_eplb_reshuffle",)
        )
        assert self.new_dp_group is not None
        if self.new_dp_group.rank() == 0:
            logger.info("[Elastic EP] EPLB reshuffle completed")

    def _eplb_reshuffle_before_scale_down(self):
        """Ask workers to perform the EPLB reshuffle for the new, smaller DP size."""
        assert self.reconfig_request is not None
        self.model_executor.collective_rpc(
            "elastic_ep_execute",
            args=(
                "perform_eplb_reshuffle",
                self.reconfig_request.new_data_parallel_size,
            ),
        )
        if self.old_dp_group.rank() == 0:
            logger.info("[Elastic EP] EPLB reshuffle completed")

    def _switch_and_remove(self):
        """Ask workers of a removing engine to run ``switch_and_remove``."""
        self.model_executor.collective_rpc(
            "elastic_ep_execute", args=("switch_and_remove",)
        )

    def _update_parallel_config(self):
        """Copy the post-scaling DP topology from the request into vllm_config.

        The DP size, master address/port and stateless group port lists are
        always overwritten. The DP rank and local rank are only overwritten
        when the request does not say ``KEEP_CURRENT_RANK``.
        """
        assert self.reconfig_request is not None
        reconfig_request = self.reconfig_request
        parallel_config = self.vllm_config.parallel_config
        parallel_config.data_parallel_size = reconfig_request.new_data_parallel_size
        if (
            reconfig_request.new_data_parallel_rank
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank = reconfig_request.new_data_parallel_rank
        if (
            reconfig_request.new_data_parallel_rank_local
            != ReconfigureRankType.KEEP_CURRENT_RANK
        ):
            parallel_config.data_parallel_rank_local = (
                reconfig_request.new_data_parallel_rank_local
            )
        parallel_config.data_parallel_master_ip = (
            reconfig_request.new_data_parallel_master_ip
        )
        parallel_config.data_parallel_master_port = (
            reconfig_request.new_data_parallel_master_port
        )
        parallel_config._data_parallel_master_port_list = (
            reconfig_request.new_data_parallel_master_port_list
        )
        parallel_config._stateless_world_group_port_list = (
            reconfig_request.new_stateless_world_group_port_list
        )
        parallel_config._stateless_dp_group_port_list = (
            reconfig_request.new_stateless_dp_group_port_list
        )
        parallel_config._stateless_ep_group_port_list = (
            reconfig_request.new_stateless_ep_group_port_list
        )
        parallel_config._stateless_eplb_group_port_list = (
            reconfig_request.new_stateless_eplb_group_port_list
        )

_staged_barrier

_staged_barrier(
    use_new_group: bool, barrier_name: str
) -> bool

Execute a two-staged barrier to synchronize all engines in the DP group.

Some DP EngineCores may receive the reconfiguration notification later than others and may already have proceeded to the engine step (model forward) in the busy loop. In that case, the EngineCores that have already entered reconfiguration should skip it and run model forward for one more step, so that all EngineCores are synchronized at the next step. A two-stage barrier achieves this. The first time each EngineCore executes the barrier, it waits with a timeout. If the timeout expires before the barrier completes, some EngineCores must already have entered the engine step. The EngineCores that timed out then also proceed to the engine step, and all EngineCores synchronize at the next step using a barrier without a timeout.

Source code in vllm/distributed/elastic_ep/elastic_state.py
def _staged_barrier(self, use_new_group: bool, barrier_name: str) -> bool:
    """Execute a two-stage barrier to synchronize all engines in the DP group.

    Some DP EngineCores may receive the reconfiguration notification later
    than others and may already have proceeded to the engine step (model
    forward) in the busy loop. In that case, the EngineCores that have
    already entered reconfiguration should skip it and run model forward for
    one more step. All EngineCores are then synchronized at the next step.

    The first time each EngineCore executes the barrier, it waits with a
    timeout. If the timeout expires before the barrier completes, some
    EngineCores must already be in the engine step. The EngineCores that
    timed out then:
    1. set ``<barrier_id>_sync`` in the store,
    2. return False,
    3. proceed to the engine step as well.
    On the next attempt the sync key is present, so every EngineCore waits
    without a timeout.

    Args:
        use_new_group: Synchronize over the new DP group/store instead of the
            old one.
        barrier_name: Name that keys this barrier in the store.

    Returns:
        True if all engines passed the barrier. False if the timed first
        stage expired and the caller should retry on a later step.

    Raises:
        RuntimeError: If the untimed second stage reports a timeout.
    """
    if use_new_group:
        dp_store, dp_group = self.new_dp_store, self.new_dp_group
    else:
        dp_store, dp_group = self.old_dp_store, self.old_dp_group
    # Both are used below; fail clearly instead of with an AttributeError.
    assert dp_store is not None and dp_group is not None

    group_rank = dp_group.rank()
    group_size = dp_group.size()
    barrier_id = f"eep_barrier_{barrier_name}"
    sync_key = f"{barrier_id}_sync"

    # TODO(yongji): figure out appropriate timeout for the barrier
    # Only the first attempt is timed. Once any engine has timed out and set
    # the sync key, every later attempt waits indefinitely.
    timeout = None if dp_store.check([sync_key]) else timedelta(seconds=5)

    try:
        self._execute_tcp_store_barrier(
            dp_store, group_rank, group_size, barrier_id, timeout=timeout
        )
    except _BarrierTimeoutError as e:
        if timeout is None:
            raise RuntimeError("Unexpected timeout encountered") from e
        # Create the sync key if it is absent so that later attempts by any
        # engine use the untimed second stage.
        dp_store.compare_set(sync_key, "", b"1")
        return False

    torch.distributed.barrier(dp_group)
    # Every rank has left the store-based barrier at this point, so rank 0
    # can safely remove the keys it used.
    if group_rank == 0:
        dp_store.delete_key(sync_key)
        for i in range(group_size):
            dp_store.delete_key(f"arrival_{barrier_id}_{i}")
    return True

_BarrierTimeoutError

Bases: RuntimeError

Exception raised when the first stage of the two-stage, TCPStore-based barrier (used to synchronize the execution of all engines in the DP group) times out.

Source code in vllm/distributed/elastic_ep/elastic_state.py
class _BarrierTimeoutError(RuntimeError):
    """
    Exception raised for timeout
    in the first stage of our two-staged
    TCPStore based barrier to synchronize the
    execution of all engines in the DP group.
    """