HexGen: Generative inference of foundation model over heterogeneous decentralized environment

Meta Info

Presented in arXiv:2311.11514.

Understanding the paper

TL;DR

  • Formulate the scheduling of concurrently serving multiple replicas of the same foundation model over a heterogeneous set of GPU devices as a constrained optimization problem

    • Each pipeline stage can use a different tensor model parallel degree

    • Propose a heuristic-based evolutionary algorithm to search for the optimal layout

  • HexGen — a distributed inference engine

    • Support asymmetric tensor model parallelism and pipeline parallelism under the heterogeneous setting

    • Select a leader GPU node in a pipeline stage

      • Manage the peer-to-peer communication between pipeline stages

      • Manage the broadcast operation of the received activations within its tensor model parallel group
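To give a feel for the evolutionary search mentioned above, here is a toy sketch in Python. It is not HexGen's algorithm: it only evolves how many layers each pipeline stage receives, given a fixed and made-up relative speed per stage, and it minimizes the time of the slowest stage. The names `bottleneck`, `mutate`, and `evolve`, the cost model, and all numbers are assumptions for illustration.

```python
import random

def bottleneck(layers, speeds):
    """Toy cost: time of the slowest stage (layers / relative speed)."""
    return max(l / s for l, s in zip(layers, speeds))

def mutate(layers, rng):
    """Move one layer from a random stage to another, keeping >= 1 per stage."""
    child = list(layers)
    src, dst = rng.sample(range(len(child)), 2)
    if child[src] > 1:
        child[src] -= 1
        child[dst] += 1
    return child

def evolve(total_layers, speeds, generations=200, population=16, seed=0):
    """Elitist evolutionary loop over layer splits (needs >= 2 stages)."""
    rng = random.Random(seed)
    k = len(speeds)
    even = [total_layers // k] * k
    even[0] += total_layers - sum(even)  # give the remainder to stage 0
    pop = [even] * population
    for _ in range(generations):
        children = [mutate(rng.choice(pop), rng) for _ in range(population)]
        # Keep the best `population` candidates among parents and children.
        pop = sorted(pop + children, key=lambda c: bottleneck(c, speeds))[:population]
    return pop[0]

speeds = [4.0, 2.0, 1.0]  # hypothetical relative stage speeds
best = evolve(80, speeds)
print(best, bottleneck(best, speeds))
```

Because the best candidates always survive into the next generation, the returned split is never worse than the even starting split under this toy cost. A real search would also have to vary device grouping and tensor model parallel degree.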

Formulation

  • $\textbf{D} = \{ d_1, \cdots, d_N \}$ — a set of $N$ GPU devices

    • $M_d$ — GPU memory limit

    • $m_d$ — GPU memory bandwidth

    • $c_d$ — Tensor core computation power

    • $\textbf{A} \in \mathbb{R}_{+}^{N \times N}$ — The communication delay matrix between these devices

      • $\alpha_{d, d'}$ — The delay between devices $d$ and $d'$

    • $\textbf{B} \in \mathbb{R}_{+}^{N \times N}$ — The communication bandwidth matrix between these devices

      • $\beta_{d, d'}$ — The bandwidth between devices $d$ and $d'$

    • $L$ — The total number of layers in the model

  • $\sigma$ — An assignment $\textbf{D} \rightarrow \{ \textbf{d}_{i,j}, l_{i,j} \}$

    • $\textbf{d}_{i,j}$ — A subset of GPU devices

      • Serve the $i$-th model replica as an independent pipeline

      • Serve the $j$-th stage in the $i$-th pipeline

      • $l_{i,j}$ — The number of transformer layers served by this stage

      • $\left\| \textbf{d}_{i,j} \right\| > 1$ → Run tensor model parallelism

    • An optimal assignment $\sigma^{*} = \mathop{\arg\max}\limits_{\sigma \in \Sigma} \mathbb{E}_{\textbf{T} \sim \varphi}[\text{SLO}(C_{\text{comm}}(\sigma) + C_{\text{comp}}(\sigma))]$

      • s.t. $C_{\text{mem}}^{d}(\sigma) \le M_d, \forall d \in \textbf{D}$

      • $C_{\text{comm}}(\sigma)$ — The communication cost

      • $C_{\text{comp}}(\sigma)$ — The computation cost

      • $C_{\text{mem}}^{d}(\sigma)$ — Memory consumption for the device $d$

      • Objective: Find an optimal assignment that partitions the device set to represent multiple independent inference pipeline groups that can maximize the inference service SLO considering the computation cost, communication cost, and memory consumption constraints
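The formulation above can be made concrete with a small Python sketch of the data involved and of the memory constraint. This is only an illustration under stated assumptions: the class and function names are hypothetical, the memory model (layer weights split evenly across a tensor model parallel group, nothing else counted) is a simplification, and the device figures are illustrative rather than taken from the paper.

```python
from __future__ import annotations
from dataclasses import dataclass

@dataclass(frozen=True)
class Device:
    name: str
    mem_limit_gb: float  # M_d

@dataclass
class Stage:
    devices: list[Device]  # d_{i,j}
    num_layers: int        # l_{i,j}

    @property
    def tp_degree(self) -> int:
        return len(self.devices)

def memory_per_device_gb(stage: Stage, layer_gb: float) -> float:
    """Assumed C_mem: stage weights split evenly across the group."""
    return stage.num_layers * layer_gb / stage.tp_degree

def is_feasible(assignment, total_layers: int, layer_gb: float) -> bool:
    """assignment: list of pipelines, each a list of Stage objects."""
    used = set()
    for pipeline in assignment:
        # Every pipeline must hold a full replica of the model.
        if sum(s.num_layers for s in pipeline) != total_layers:
            return False
        for stage in pipeline:
            need = memory_per_device_gb(stage, layer_gb)
            for d in stage.devices:
                # A device serves one stage only and must fit its shard.
                if d.name in used or need > d.mem_limit_gb:
                    return False
                used.add(d.name)
    return True

big = [Device(f"big-{k}", 80.0) for k in range(2)]
small = [Device(f"small-{k}", 24.0) for k in range(4)]
assignment = [[Stage(big, 48), Stage(small, 32)]]  # one pipeline, TP 2 then TP 4
print(is_feasible(assignment, total_layers=80, layer_gb=1.75))
```

A search procedure would call a check like `is_feasible` to discard candidate assignments before estimating their communication and computation costs.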

Implementation

  • Essential change: each pipeline parallel stage can be assigned a different number of layers and a different tensor model parallel degree

  • Steps

    • Each stage selects a leader GPU to initialize an independent tensor model parallel group

    • Only the leader GPU in each stage (i.e., tensor model parallel group) sends the activation to the leader GPU in the next stage

    • After receiving the activation, the leader GPU broadcasts this activation among its tensor model parallel group to execute the tensor model parallel computation
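The steps above can be sketched as a pure-Python simulation that only records which messages would be exchanged. It is a minimal sketch, not HexGen's code: `run_pipeline` is a hypothetical name, the first device in each group is assumed to be the leader, the first stage's leader is assumed to receive the request input, and the stage computation is replaced by a placeholder.

```python
def run_pipeline(stages, activation):
    """Simulate one forward pass; returns (output, message log).

    `stages` is a list of tensor-parallel groups (lists of device names).
    Assumption: the first device listed in each group acts as its leader.
    """
    log = []
    for idx, group in enumerate(stages):
        leader = group[0]
        if idx > 0:
            # Point-to-point: only leader-to-leader traffic crosses stages.
            log.append(("send", stages[idx - 1][0], leader))
        for peer in group[1:]:
            # The leader shares the received activation with its group.
            log.append(("broadcast", leader, peer))
        activation += 1  # placeholder for the stage's tensor-parallel compute
    return activation, log

stages = [["g0", "g1", "g2", "g3"], ["g4", "g5"], ["g6"]]  # TP degrees 4, 2, 1
out, log = run_pipeline(stages, 0)
for msg in log:
    print(msg)
```

The point of the design is visible in the log: stages with different tensor model parallel degrees never need to match ranks one-to-one, because a single leader-to-leader transfer connects each pair of adjacent stages.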

Evaluation

  • Compared to Petals

    • Petals depends on dynamic adjustment of the collective learning paradigm to ensure elasticity → this dynamic design compromises inference service performance

    • HexGen carefully designs static scheduling of the inference workflow

  • Metrics

    • SLO attainment

      • Generate the inference workload from a Poisson process parameterized by the request rate

      • For a target SLO goal (e.g., 99%)

        • The minimum latency deadline to achieve the desired attainment

        • The system’s resilience to peak request rate

      • Llama 2 70B model

      • Real-world prompts: https://huggingface.co/datasets/lmsys/chatbot_arena_conversations

      • Output sequence length: 32, 64, 128

      • Request rates vary between 0.125 and 10 requests per second

      • $\text{SLO Scale} = T$ → the default SLO is set as tight as $T \times$ the inference latency
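As a rough illustration of how SLO attainment can be measured under a Poisson workload, the sketch below generates arrival times and counts the fraction of requests that finish within a deadline of SLO scale times the inference latency. The serving model is an assumption made for brevity (a single FIFO server with a fixed service time), and the function names and parameter values are hypothetical, not the paper's evaluation harness.

```python
import random

def poisson_arrivals(rate, n, seed=0):
    """Arrival times of a Poisson process: exponential inter-arrival gaps."""
    rng = random.Random(seed)
    t, arrivals = 0.0, []
    for _ in range(n):
        t += rng.expovariate(rate)
        arrivals.append(t)
    return arrivals

def slo_attainment(arrivals, service_time, slo_scale):
    """Fraction of requests finishing within slo_scale * service_time.

    Assumption: one FIFO server with a fixed per-request service time.
    """
    deadline = slo_scale * service_time
    free_at, met = 0.0, 0
    for arrival in arrivals:
        start = max(arrival, free_at)  # wait if the server is busy
        free_at = start + service_time
        if free_at - arrival <= deadline:
            met += 1
    return met / len(arrivals)

arrivals = poisson_arrivals(rate=2.0, n=1000)
for scale in (1.5, 5.0, 20.0):
    print(scale, slo_attainment(arrivals, service_time=0.4, slo_scale=scale))
```

Sweeping `slo_scale` at a fixed rate gives the minimum deadline that reaches a target attainment, and sweeping `rate` at a fixed scale shows how much load the system tolerates before attainment drops.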
