MPC-Minimized Secure LLM Inference
This work addresses the privacy concerns that LLM inference services raise for both users and providers by making secure inference more practical. The contribution is incremental, however, since it builds on existing MPC and fine-tuning methods.
The paper tackles the high overhead of secure multi-party computation (MPC) in large language model (LLM) inference. It proposes Marill, a fine-tuning framework that minimizes MPC usage. Compared to standard fine-tuning, Marill achieves 3.6-11.3x better runtime and 2.4-6.9x better communication during secure inference, while typically preserving over 90% of performance on downstream tasks.
Many inference services based on large language models (LLMs) pose a privacy concern, either revealing user prompts to the service or the proprietary weights to the user. Secure inference offers a solution to this problem through secure multi-party computation (MPC); however, it is still impractical for modern LLM workloads due to the large overhead imposed by MPC. To address this overhead, we propose Marill, a framework that adapts LLM fine-tuning to minimize MPC usage during secure inference. Marill introduces high-level architectural changes during fine-tuning that significantly reduce the number of expensive operations needed within MPC during inference, removing some and relocating others outside MPC without compromising security. As a result, Marill-generated models are more efficient across all secure inference protocols, and our approach complements MPC-friendly approximations for such operations. Compared to standard fine-tuning, Marill results in 3.6-11.3x better runtime and 2.4-6.9x better communication during secure inference across various MPC settings, while typically preserving over 90% of performance across downstream tasks.
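The core accounting idea, that removing some expensive operations and relocating others outside MPC shrinks the MPC workload, can be illustrated with a toy cost model. This is a minimal sketch under loudly stated assumptions, not Marill's method. The operation names, the per-operation costs (arbitrary units), and the `Placement` tags are all hypothetical. The sketch does not model which operations can be safely relocated; establishing that security argument is the paper's contribution. It also treats any work done outside MPC as free.

```python
from dataclasses import dataclass
from enum import Enum


class Placement(Enum):
    """Hypothetical tags describing where an operation ends up after fine-tuning."""
    MPC = "mpc"                  # must still execute inside MPC
    REMOVABLE = "removable"      # eliminated by an architectural change
    RELOCATABLE = "relocatable"  # assumed safe to run outside MPC


@dataclass(frozen=True)
class Op:
    name: str             # hypothetical operation label
    mpc_cost: float       # cost in arbitrary units if executed inside MPC
    placement: Placement  # where the op lands in the minimized model


def mpc_cost(ops: list[Op], minimized: bool) -> float:
    """Total cost of the operations that execute inside MPC.

    Baseline (minimized=False): every op runs inside MPC.
    Minimized (minimized=True): REMOVABLE ops are dropped and RELOCATABLE ops
    run outside MPC, so neither contributes (work outside MPC is treated as free).
    """
    if not minimized:
        return sum(op.mpc_cost for op in ops)
    return sum(op.mpc_cost for op in ops if op.placement is Placement.MPC)


# Hypothetical pipeline; names and costs are illustrative only, not measurements.
pipeline = [
    Op("op_a", 10.0, Placement.MPC),
    Op("op_b", 30.0, Placement.RELOCATABLE),
    Op("op_c", 20.0, Placement.REMOVABLE),
    Op("op_d", 40.0, Placement.MPC),
]

baseline = mpc_cost(pipeline, minimized=False)
reduced = mpc_cost(pipeline, minimized=True)
print(f"MPC cost reduced {baseline / reduced:.1f}x")  # prints "MPC cost reduced 2.0x"
```

In this toy model, an MPC-friendly approximation would lower an individual `mpc_cost` value, whereas minimizing MPC usage reduces how many operations carry that cost at all. That is one way to read the claim that the two approaches are complementary.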