LG | CL | DC | Mar 17, 2024

JORA: JAX Tensor-Parallel LoRA Library for Retrieval Augmented Fine-Tuning

arXiv:2403.11366v2 | 27 citations | h-index: 9 | Has Code | ACL
AI Analysis

This work addresses scalability problems faced by researchers and practitioners who fine-tune LLMs on retrieval-based tasks, particularly on systems with limited GPU resources. The contribution is incremental: it builds on existing PEFT and distributed training methods.

The paper tackles the memory constraints of fine-tuning large language models for retrieval augmented generation by introducing a JAX-based tensor-parallel LoRA library. It reports over 12x faster runtime and less than half the VRAM usage per GPU compared to a Hugging Face/DeepSpeed implementation on four GPUs.

The scaling of Large Language Models (LLMs) for retrieval-based tasks, particularly in Retrieval Augmented Generation (RAG), faces significant memory constraints, especially when fine-tuning on extensive prompt sequences. Current open-source libraries support full-model inference and fine-tuning across multiple GPUs but fall short of accommodating the efficient parameter distribution required for retrieved context. Addressing this gap, we introduce a novel framework for PEFT-compatible fine-tuning of Llama-2 models, leveraging distributed training. Our framework uniquely utilizes JAX's just-in-time (JIT) compilation and tensor-sharding for efficient resource management, thereby enabling accelerated fine-tuning with reduced memory requirements. This advancement significantly improves the scalability and feasibility of fine-tuning LLMs for complex RAG applications, even on systems with limited GPU resources. Our experiments show more than 12x improvement in runtime compared to the Hugging Face/DeepSpeed implementation with four GPUs while consuming less than half the VRAM per GPU.
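The combination of JIT compilation and tensor sharding described in the abstract can be illustrated with a minimal sketch in plain JAX. This is not the JORA API. The mesh axis name `tp`, the toy dimensions, the `SCALE` constant, and the functions `lora_forward` and `loss_fn` are all assumptions made for illustration. The sketch shards a frozen weight and the LoRA up-projection column-wise across whatever devices are available, and takes gradients only with respect to the LoRA matrices.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional device mesh; the axis name "tp" is our own label.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("tp",))
n_dev = devices.size

# Toy sizes (hypothetical); d_out stays divisible by the device count.
d_in, rank = 32, 4
d_out = 16 * n_dev
SCALE = 2.0  # illustrative LoRA scaling factor (alpha / rank)

k_w, k_a, k_x = jax.random.split(jax.random.PRNGKey(0), 3)
W = jax.random.normal(k_w, (d_in, d_out))         # frozen base weight
A = jax.random.normal(k_a, (d_in, rank)) * 0.01   # LoRA down-projection
B = jnp.zeros((rank, d_out))                      # LoRA up-projection, zero-init
x = jax.random.normal(k_x, (8, d_in))             # a batch of activations

# Split W and B along their output columns; replicate the small tensors.
col_sharded = NamedSharding(mesh, P(None, "tp"))
replicated = NamedSharding(mesh, P())
W = jax.device_put(W, col_sharded)
B = jax.device_put(B, col_sharded)
A = jax.device_put(A, replicated)
x = jax.device_put(x, replicated)

@jax.jit
def lora_forward(x, W, A, B):
    # Base projection plus the low-rank update; XLA partitions the
    # matmuls according to the input shardings.
    return x @ W + SCALE * ((x @ A) @ B)

def loss_fn(lora_params, x, W):
    A, B = lora_params
    return jnp.mean(lora_forward(x, W, A, B) ** 2)

# Differentiate only with respect to (A, B); W receives no gradient.
grad_fn = jax.jit(jax.grad(loss_fn))
grads = grad_fn((A, B), x, W)
y = lora_forward(x, W, A, B)
print(y.shape, grads[0].shape, grads[1].shape)
```

With this layout each device holds only its slice of `W` and `B`, which is the general idea behind the per-GPU memory reduction the abstract describes. The actual partitioning scheme used by the library may differ.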

Code Implementations: 1 repo