Prateek Jain

LG
h-index 117
120 papers
14,942 citations
Novelty 59%
AI Score 62

120 Papers

LG · May 26, 2022 · Code
Matryoshka Representation Learning

Aditya Kusupati, Gantavya Bhatt, Aniket Rege et al. · uw

Learned representations are a central component in modern ML systems, serving a multitude of downstream tasks. When training such representations, it is often the case that computational and statistical constraints for each downstream task are unknown. In this context, rigid, fixed-capacity representations can be either over or under-accommodating to the task at hand. This leads us to ask: can we design a flexible representation that can adapt to multiple downstream tasks with varying computational resources? Our main contribution is Matryoshka Representation Learning (MRL) which encodes information at different granularities and allows a single embedding to adapt to the computational constraints of downstream tasks. MRL minimally modifies existing representation learning pipelines and imposes no additional cost during inference and deployment. MRL learns coarse-to-fine representations that are at least as accurate and rich as independently trained low-dimensional representations. The flexibility within the learned Matryoshka Representations offers: (a) up to 14x smaller embedding size for ImageNet-1K classification at the same level of accuracy; (b) up to 14x real-world speed-ups for large-scale retrieval on ImageNet-1K and 4K; and (c) up to 2% accuracy improvements for long-tail few-shot classification, all while being as robust as the original representations. Finally, we show that MRL extends seamlessly to web-scale datasets (ImageNet, JFT) across various modalities -- vision (ViT, ResNet), vision + language (ALIGN) and language (BERT). MRL code and pretrained models are open-sourced at https://github.com/RAIVNLab/MRL.
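
The idea of one embedding serving several compute budgets can be illustrated with a minimal sketch. It assumes, hypothetically, that an MRL-style embedding is already available as a plain list of floats and that each leading prefix is itself a usable lower-dimensional embedding. The helper names (`truncate`, `retrieve`) and the toy vectors are illustrative and are not taken from the MRL repository.

```python
import math

def truncate(embedding, m):
    """Keep the first m coordinates and L2-normalize the resulting prefix."""
    prefix = embedding[:m]
    norm = math.sqrt(sum(x * x for x in prefix))
    if norm == 0.0:
        raise ValueError("prefix has zero norm")
    return [x / norm for x in prefix]

def cosine(a, b):
    """Dot product of two already-normalized vectors of equal length."""
    return sum(x * y for x, y in zip(a, b))

def retrieve(query, database, m):
    """Index of the database item most similar to the query using only m dims."""
    q = truncate(query, m)
    scores = [cosine(q, truncate(d, m)) for d in database]
    return max(range(len(database)), key=lambda i: scores[i])

# Toy 8-dimensional embeddings (made-up numbers, for illustration only).
database = [
    [0.9, 0.1, 0.0, 0.2, 0.1, 0.0, 0.3, 0.1],
    [0.1, 0.8, 0.3, 0.0, 0.2, 0.1, 0.0, 0.4],
    [0.0, 0.2, 0.9, 0.1, 0.0, 0.3, 0.1, 0.0],
]
query = [0.8, 0.2, 0.1, 0.1, 0.0, 0.1, 0.2, 0.0]

coarse = retrieve(query, database, m=2)  # cheap, low-dimensional pass
fine = retrieve(query, database, m=8)    # full-dimensional pass
```

In this toy setup the cheap 2-dimensional pass and the full 8-dimensional pass return the same neighbour; whether that holds on real data depends on how the embedding was trained.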

LG · Oct 11, 2023
MatFormer: Nested Transformer for Elastic Inference

Devvrit, Sneha Kudugunta, Aditya Kusupati et al. · uw

Foundation models are applied in a broad spectrum of settings with different inference constraints, from massive multi-accelerator clusters to resource-constrained standalone mobile devices. However, the substantial costs associated with training these models often limit the number of unique model sizes that can be offered. Consequently, practitioners are compelled to select a model that may not be optimally aligned with their specific latency and cost requirements. We present MatFormer, a novel Transformer architecture designed to provide elastic inference across diverse deployment constraints. MatFormer achieves this by incorporating a nested Feed Forward Network (FFN) block structure within a standard Transformer model. During training, we optimize the parameters of multiple nested FFN blocks with varying sizes, enabling the extraction of hundreds of accurate smaller models without incurring additional computational costs. We empirically validate the efficacy of MatFormer across different model classes (decoders and encoders) and modalities (language and vision), demonstrating its potential for real-world deployment. We show that a 850M decoder-only MatFormer language model (MatLM) allows us to extract multiple smaller models spanning from 582M to 850M parameters, each exhibiting better validation loss and one-shot downstream evaluations than independently trained counterparts. Furthermore, we observe that smaller encoders extracted from a universal MatFormer-based ViT (MatViT) encoder preserve the metric-space structure for adaptive large-scale retrieval. Finally, we showcase that speculative decoding with the accurate and consistent submodels extracted from MatFormer can lead to significant reduction in inference latency. Project website: https://devvrit.github.io/matformer/

LG · Oct 11, 2022
Multi-User Reinforcement Learning with Low Rank Rewards

Naman Agarwal, Prateek Jain, Suhas Kowshik et al. · deepmind

In this work, we consider the problem of collaborative multi-user reinforcement learning. In this setting there are multiple users with the same state-action space and transition probabilities but with different rewards. Under the assumption that the reward matrix of the $N$ users has a low-rank structure -- a standard and practically successful assumption in the offline collaborative filtering setting -- the question is can we design algorithms with significantly lower sample complexity compared to the ones that learn the MDP individually for each user. Our main contribution is an algorithm which explores rewards collaboratively with $N$ user-specific MDPs and can learn rewards efficiently in two key settings: tabular MDPs and linear MDPs. When $N$ is large and the rank is constant, the sample complexity per MDP depends logarithmically over the size of the state-space, which represents an exponential reduction (in the state-space size) when compared to the standard ``non-collaborative'' algorithms.

LG · Oct 16, 2023 · Code
Dual-Encoders for Extreme Multi-Label Classification

Nilesh Gupta, Devvrit Khatri, Ankit S Rawat et al.

Dual-encoder (DE) models are widely used in retrieval tasks, most commonly studied on open QA benchmarks that are often characterized by multi-class and limited training data. In contrast, their performance in multi-label and data-rich retrieval settings like extreme multi-label classification (XMC), remains under-explored. Current empirical evidence indicates that DE models fall significantly short on XMC benchmarks, where SOTA methods linearly scale the number of learnable parameters with the total number of classes (documents in the corpus) by employing per-class classification head. To this end, we first study and highlight that existing multi-label contrastive training losses are not appropriate for training DE models on XMC tasks. We propose decoupled softmax loss - a simple modification to the InfoNCE loss - that overcomes the limitations of existing contrastive losses. We further extend our loss design to a soft top-k operator-based loss which is tailored to optimize top-k prediction performance. When trained with our proposed loss functions, standard DE models alone can match or outperform SOTA methods by up to 2% at Precision@1 even on the largest XMC datasets while being 20x smaller in terms of the number of trainable parameters. This leads to more parameter-efficient and universally applicable solutions for retrieval tasks. Our code and models are publicly available at https://github.com/nilesh2797/dexml.
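
The difference between a standard softmax-style contrastive loss and a "decoupled" variant can be sketched for a single query. The form used here is an assumption about what decoupling means (each positive's softmax denominator contains that positive and the negatives, but not the other positives); it is not taken from the released dexml code, and the function names are hypothetical.

```python
import math

def log_sum_exp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def infonce_multilabel(scores, positives):
    """Softmax over all labels; sum of negative log-probabilities of the positives."""
    lse = log_sum_exp(scores)
    return -sum(scores[p] - lse for p in positives)

def decoupled_softmax(scores, positives):
    """Assumed decoupled form: other positives are excluded from each denominator."""
    pos = set(positives)
    negatives = [s for i, s in enumerate(scores) if i not in pos]
    loss = 0.0
    for p in positives:
        lse = log_sum_exp([scores[p]] + negatives)
        loss -= scores[p] - lse
    return loss

# One query with two positive labels scored highly and two negatives scored low.
scores = [10.0, 10.0, -10.0, -10.0]
positives = [0, 1]
coupled = infonce_multilabel(scores, positives)
decoupled = decoupled_softmax(scores, positives)
```

With two equally scored positives the coupled loss cannot fall below 2 log 2, because the positives compete for probability mass, whereas the decoupled loss approaches zero once the negatives are pushed down.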

CV · Jul 29, 2024
Mixture of Nested Experts: Adaptive Processing of Visual Tokens

Gagan Jain, Nidhi Hegde, Aditya Kusupati et al. · uw

The visual medium (images and videos) naturally contains a large amount of information redundancy, thereby providing a great opportunity for leveraging efficiency in processing. While Vision Transformer (ViT) based models scale effectively to large data regimes, they fail to capitalize on this inherent redundancy, leading to higher computational costs. Mixture of Experts (MoE) networks demonstrate scalability while maintaining the same inference-time costs, but they come with a larger parameter footprint. We present Mixture of Nested Experts (MoNE), which utilizes a nested structure for experts, wherein individual experts fall on an increasing compute-accuracy curve. Given a compute budget, MoNE learns to dynamically choose tokens in a priority order, and thus redundant tokens are processed through cheaper nested experts. Using this framework, we achieve performance equivalent to the baseline models, while reducing inference-time compute by over two-fold. We validate our approach on standard image and video datasets - ImageNet-21K, Kinetics400, and Something-Something-v2. We further highlight MoNE's adaptability by showcasing its ability to maintain strong performance across different inference-time compute budgets on videos, using only a single trained model.

LG · May 27, 2022
DP-PCA: Statistically Optimal and Differentially Private PCA

Xiyang Liu, Weihao Kong, Prateek Jain et al.

We study the canonical statistical task of computing the principal component from $n$ i.i.d.~data in $d$ dimensions under $(\varepsilon,δ)$-differential privacy. Although extensively studied in literature, existing solutions fall short on two key aspects: ($i$) even for Gaussian data, existing private algorithms require the number of samples $n$ to scale super-linearly with $d$, i.e., $n=Ω(d^{3/2})$, to obtain non-trivial results while non-private PCA requires only $n=O(d)$, and ($ii$) existing techniques suffer from a non-vanishing error even when the randomness in each data point is arbitrarily small. We propose DP-PCA, which is a single-pass algorithm that overcomes both limitations. It is based on a private minibatch gradient ascent method that relies on {\em private mean estimation}, which adds minimal noise required to ensure privacy by adapting to the variance of a given minibatch of gradients. For sub-Gaussian data, we provide nearly optimal statistical error rates even for $n=\tilde O(d)$. Furthermore, we provide a lower bound showing that sub-Gaussian style assumption is necessary in obtaining the optimal error rate.

LG · Feb 1, 2023
Simplicity Bias in 1-Hidden Layer Neural Networks

Depen Morwani, Jatin Batra, Prateek Jain et al.

Recent works have demonstrated that neural networks exhibit extreme simplicity bias (SB). That is, they learn only the simplest features to solve a task at hand, even in the presence of other, more robust but more complex features. Due to the lack of a general and rigorous definition of features, these works showcase SB on semi-synthetic datasets such as Color-MNIST, MNIST-CIFAR where defining features is relatively easier. In this work, we rigorously define as well as thoroughly establish SB for one hidden layer neural networks. More concretely, (i) we define SB as the network essentially being a function of a low dimensional projection of the inputs, (ii) theoretically, we show that when the data is linearly separable, the network primarily depends on only the linearly separable ($1$-dimensional) subspace even in the presence of an arbitrarily large number of other, more complex features which could have led to a significantly more robust classifier, (iii) empirically, we show that models trained on real datasets such as Imagenette and Waterbirds-Landbirds indeed depend on a low dimensional projection of the inputs, thereby demonstrating SB on these datasets, and (iv) finally, we present a natural ensemble approach that encourages diversity in models by training successive models on features not used by earlier models, and demonstrate that it yields models that are significantly more robust to Gaussian noise.

QUANT-PH · Nov 15, 2022
Variational Quantum Algorithms for Chemical Simulation and Drug Discovery

Hasan Mustafa, Sai Nandan Morapakula, Prateek Jain et al.

Quantum computing has gained a lot of attention recently, and scientists have seen potential applications in this field using quantum computing for Cryptography and Communication to Machine Learning and Healthcare. Protein folding has been one of the most interesting areas to study, and it is also one of the biggest problems of biochemistry. Each protein folds distinctively, and the difficulty of finding its stable shape rapidly increases with an increase in the number of amino acids in the chain. A moderate protein has about 100 amino acids, and the number of combinations one needs to verify to find the stable structure is enormous. At some point, the number of these combinations will be so vast that classical computers cannot even attempt to solve them. In this paper, we examine how this problem can be solved with the help of quantum computing using two different algorithms, Variational Quantum Eigensolver (VQE) and Quantum Approximate Optimization Algorithm (QAOA), using Qiskit Nature. We compare the results of different quantum hardware and simulators and check how error mitigation affects the performance. Further, we make comparisons with SoTA algorithms and evaluate the reliability of the method.

LG · Sep 8, 2022
Online Low Rank Matrix Completion

Prateek Jain, Soumyabrata Pal

We study the problem of {\em online} low-rank matrix completion with $\mathsf{M}$ users, $\mathsf{N}$ items and $\mathsf{T}$ rounds. In each round, the algorithm recommends one item per user, for which it gets a (noisy) reward sampled from a low-rank user-item preference matrix. The goal is to design a method with sub-linear regret (in $\mathsf{T}$) and nearly optimal dependence on $\mathsf{M}$ and $\mathsf{N}$. The problem can be easily mapped to the standard multi-armed bandit problem where each item is an {\em independent} arm, but that leads to poor regret as the correlation between arms and users is not exploited. On the other hand, exploiting the low-rank structure of reward matrix is challenging due to non-convexity of the low-rank manifold. We first demonstrate that the low-rank structure can be exploited using a simple explore-then-commit (ETC) approach that ensures a regret of $O(\mathsf{polylog} (\mathsf{M}+\mathsf{N}) \mathsf{T}^{2/3})$. That is, roughly only $\mathsf{polylog} (\mathsf{M}+\mathsf{N})$ item recommendations are required per user to get a non-trivial solution. We then improve our result for the rank-$1$ setting which in itself is quite challenging and encapsulates some of the key issues. Here, we propose \textsc{OCTAL} (Online Collaborative filTering using iterAtive user cLustering) that guarantees nearly optimal regret of $O(\mathsf{polylog} (\mathsf{M}+\mathsf{N}) \mathsf{T}^{1/2})$. OCTAL is based on a novel technique of clustering users that allows iterative elimination of items and leads to a nearly optimal minimax rate.
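
The explore-then-commit (ETC) pattern mentioned above can be sketched in its simplest form. This is plain single-user ETC for a multi-armed bandit: it does not exploit any low-rank structure and is not the paper's algorithm or OCTAL. The reward means, the noise level, and the function name are hypothetical choices made for the illustration.

```python
import random

def explore_then_commit(pull, n_arms, horizon, explore_per_arm):
    """Pull every arm explore_per_arm times, then commit to the best empirical mean."""
    if explore_per_arm < 1 or n_arms * explore_per_arm > horizon:
        raise ValueError("exploration budget must be positive and fit in the horizon")
    totals = [0.0] * n_arms
    history = []
    # Exploration phase: sample each arm the same number of times.
    for arm in range(n_arms):
        for _ in range(explore_per_arm):
            reward = pull(arm)
            totals[arm] += reward
            history.append((arm, reward))
    # Commit phase: play the empirically best arm for all remaining rounds.
    best = max(range(n_arms), key=lambda a: totals[a] / explore_per_arm)
    for _ in range(horizon - len(history)):
        history.append((best, pull(best)))
    return best, history

rng = random.Random(0)            # fixed seed so the run is reproducible
true_means = [0.1, 0.9, 0.5]      # made-up arm means

def pull(arm):
    """Noisy reward: the arm's mean plus Gaussian noise."""
    return true_means[arm] + rng.gauss(0.0, 0.1)

best_arm, history = explore_then_commit(pull, n_arms=3, horizon=200, explore_per_arm=20)
```

The length of the exploration phase is the main design choice: too little exploration risks committing to a wrong arm, and too much wastes rounds on arms already known to be poor.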

LG · Jan 30, 2023
Near Optimal Private and Robust Linear Regression

Xiyang Liu, Prateek Jain, Weihao Kong et al.

We study the canonical statistical estimation problem of linear regression from $n$ i.i.d.~examples under $(\varepsilon,δ)$-differential privacy when some response variables are adversarially corrupted. We propose a variant of the popular differentially private stochastic gradient descent (DP-SGD) algorithm with two innovations: a full-batch gradient descent to improve sample complexity and a novel adaptive clipping to guarantee robustness. When there is no adversarial corruption, this algorithm improves upon the existing state-of-the-art approach and achieves a near optimal sample complexity. Under label-corruption, this is the first efficient linear regression algorithm to guarantee both $(\varepsilon,δ)$-DP and robustness. Synthetic experiments confirm the superiority of our approach.

LG · Jul 11, 2022
(Nearly) Optimal Private Linear Regression via Adaptive Clipping

Prateek Varshney, Abhradeep Thakurta, Prateek Jain

We study the problem of differentially private linear regression where each data point is sampled from a fixed sub-Gaussian style distribution. We propose and analyze a one-pass mini-batch stochastic gradient descent method (DP-AMBSSGD) where points in each iteration are sampled without replacement. Noise is added for DP but the noise standard deviation is estimated online. Compared to existing $(ε, δ)$-DP techniques which have sub-optimal error bounds, DP-AMBSSGD is able to provide nearly optimal error bounds in terms of key parameters like dimensionality $d$, number of points $N$, and the standard deviation $σ$ of the noise in observations. For example, when the $d$-dimensional covariates are sampled i.i.d. from the normal distribution, then the excess error of DP-AMBSSGD due to privacy is $\frac{σ^2 d}{N}(1+\frac{d}{ε^2 N})$, i.e., the error is meaningful when number of samples $N= Ω(d \log d)$ which is the standard operative regime for linear regression. In contrast, error bounds for existing efficient methods in this setting are: $\mathcal{O}\big(\frac{d^3}{ε^2 N^2}\big)$, even for $σ=0$. That is, for constant $ε$, the existing techniques require $N=Ω(d\sqrt{d})$ to provide a non-trivial result.
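
The basic mechanism behind private mini-batch gradient methods (clip each per-example gradient, sum, add Gaussian noise) can be sketched for squared loss. This is a generic fixed-threshold step, not DP-AMBSSGD: the clipping threshold and noise scale are fixed inputs here, whereas the abstract describes estimating the noise standard deviation online. No privacy accounting is performed, and all names are hypothetical.

```python
import math
import random

def clip(vec, max_norm):
    """Scale vec down so that its L2 norm is at most max_norm."""
    norm = math.sqrt(sum(v * v for v in vec))
    if norm <= max_norm:
        return list(vec)
    scale = max_norm / norm
    return [v * scale for v in vec]

def noisy_clipped_gradient(w, batch, clip_norm, noise_multiplier, rng):
    """Average of clipped per-example squared-loss gradients plus Gaussian noise."""
    total = [0.0] * len(w)
    for x, y in batch:
        residual = sum(wi * xi for wi, xi in zip(w, x)) - y
        grad = clip([residual * xi for xi in x], clip_norm)
        total = [t + g for t, g in zip(total, grad)]
    sigma = noise_multiplier * clip_norm  # noise scale tied to the clipping threshold
    return [(t + rng.gauss(0.0, sigma)) / len(batch) for t in total]

rng = random.Random(0)
w = [0.0, 0.0]
batch = [([3.0, 4.0], -1.0)]  # one made-up example (x, y)
# With noise_multiplier=0 the step reduces to the clipped gradient itself.
step = noisy_clipped_gradient(w, batch, clip_norm=1.0, noise_multiplier=0.0, rng=rng)
```

Clipping bounds how much any single example can move the update, which is what lets the added noise be calibrated to the threshold rather than to the raw gradient magnitudes.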

CL · Aug 18, 2022
Treeformer: Dense Gradient Trees for Efficient Attention Computation

Lovish Madaan, Srinadh Bhojanapalli, Himanshu Jain et al.

Standard inference and training with transformer-based architectures scale quadratically with input sequence length. This is prohibitively large for a variety of applications, especially web-page translation, query answering, etc. Consequently, several approaches have been developed recently to speed up attention computation by enforcing different attention structures such as sparsity, low-rank structure, and kernel-based approximation of attention. In this work, we view attention computation as that of nearest neighbor retrieval, and use decision tree based hierarchical navigation to reduce the retrieval cost per query token from linear in sequence length to nearly logarithmic. Based on such hierarchical navigation, we design Treeformer which can use one of two efficient attention layers -- TF-Attention and TC-Attention. TF-Attention computes the attention in a fine-grained style, while TC-Attention is a coarse attention layer which also ensures that the gradients are "dense". To optimize such challenging discrete layers, we propose a two-level bootstrapped training method. Using extensive experiments on standard NLP benchmarks, especially for long sequences, we demonstrate that our Treeformer architecture can be almost as accurate as the baseline Transformer while using 30x fewer FLOPs in the attention layer. Compared to Linformer, the accuracy can be as much as 12% higher while using similar FLOPs in the attention layer.

LG · Feb 15, 2023
Multi-Task Differential Privacy Under Distribution Skew

Walid Krichene, Prateek Jain, Shuang Song et al.

We study the problem of multi-task learning under user-level differential privacy, in which $n$ users contribute data to $m$ tasks, each involving a subset of users. One important aspect of the problem, that can significantly impact quality, is the distribution skew among tasks. Certain tasks may have much fewer data samples than others, making them more susceptible to the noise added for privacy. It is natural to ask whether algorithms can adapt to this skew to improve the overall utility. We give a systematic analysis of the problem, by studying how to optimally allocate a user's privacy budget among tasks. We propose a generic algorithm, based on an adaptive reweighting of the empirical loss, and show that when there is task distribution skew, this gives a quantifiable improvement of excess empirical risk. Experimental studies on recommendation problems that exhibit a long tail of small tasks, demonstrate that our methods significantly improve utility, achieving the state of the art on two standard benchmarks.

LG · Oct 4, 2022
Learning an Invertible Output Mapping Can Mitigate Simplicity Bias in Neural Networks

Sravanti Addepalli, Anshul Nasery, R. Venkatesh Babu et al.

Deep Neural Networks are known to be brittle to even minor distribution shifts compared to the training distribution. While one line of work has demonstrated that Simplicity Bias (SB) of DNNs - bias towards learning only the simplest features - is a key reason for this brittleness, another recent line of work has surprisingly found that diverse/ complex features are indeed learned by the backbone, and their brittleness is due to the linear classification head relying primarily on the simplest features. To bridge the gap between these two lines of work, we first hypothesize and verify that while SB may not altogether preclude learning complex features, it amplifies simpler features over complex ones. Namely, simple features are replicated several times in the learned representations while complex features might not be replicated. This phenomenon, we term Feature Replication Hypothesis, coupled with the Implicit Bias of SGD to converge to maximum margin solutions in the feature space, leads the models to rely mostly on the simple features for classification. To mitigate this bias, we propose Feature Reconstruction Regularizer (FRR) to ensure that the learned features can be reconstructed back from the logits. The use of {\em FRR} in linear layer training (FRR-L) encourages the use of more diverse features for classification. We further propose to finetune the full network by freezing the weights of the linear layer trained using FRR-L, to refine the learned features, making them more suitable for classification. Using this simple solution, we demonstrate up to 15% gains in OOD accuracy on the recently introduced semi-synthetic datasets with extreme distribution shifts. Moreover, we demonstrate noteworthy gains over existing SOTA methods on the standard OOD benchmark DomainBed as well.

AI · Feb 1, 2023
iPAL: A Machine Learning Based Smart Healthcare Framework For Automatic Diagnosis Of Attention Deficit/Hyperactivity Disorder (ADHD)

Abhishek Sharma, Arpit Jain, Shubhangi Sharma et al.

ADHD is a prevalent disorder among the younger population. Standard evaluation techniques currently use evaluation forms, interviews with the patient, and more. However, its symptoms are similar to those of many other disorders like depression, conduct disorder, and oppositional defiant disorder, and these current diagnosis techniques are not very effective. Thus, a sophisticated computing model holds the potential to provide a promising diagnosis solution to this problem. This work explores methods to diagnose ADHD using combinations of multiple established machine learning techniques like neural networks and SVM models on the ADHD200 dataset, and explores the field of neuroscience. In this work, multiclass classification is performed on phenotypic data using an SVM model, which gives better results on the phenotypic data than other supervised learning techniques like logistic regression, KNN, AdaBoost, etc. In addition, neural networks have been implemented on functional connectivity from the MRI data of a sample of 40 subjects to achieve high accuracy without prior knowledge of neuroscience. This model is combined with the phenotypic classifier using the ensemble technique to get a binary classifier. It is further trained and tested on 400 out of 824 subjects from the ADHD200 dataset and achieves an accuracy of 92.5% for binary classification. Training and testing accuracy of up to 99% has been achieved using the ensemble classifier.

LG · Aug 19, 2022
DAFT: Distilling Adversarially Fine-tuned Models for Better OOD Generalization

Anshul Nasery, Sravanti Addepalli, Praneeth Netrapalli et al.

We consider the problem of OOD generalization, where the goal is to train a model that performs well on test distributions that are different from the training distribution. Deep learning models are known to be fragile to such shifts and can suffer large accuracy drops even for slightly different test distributions. We propose a new method - DAFT - based on the intuition that adversarially robust combination of a large number of rich features should provide OOD robustness. Our method carefully distills the knowledge from a powerful teacher that learns several discriminative features using standard training while combining them using adversarial training. The standard adversarial training procedure is modified to produce teachers which can guide the student better. We evaluate DAFT on standard benchmarks in the DomainBed framework, and demonstrate that DAFT achieves significant improvements over the current state-of-the-art OOD generalization methods. DAFT consistently out-performs well-tuned ERM and distillation baselines by up to 6%, with more pronounced gains for smaller networks.

LG · Oct 13, 2023
EHI: End-to-end Learning of Hierarchical Index for Efficient Dense Retrieval

Ramnath Kumar, Anshul Mittal, Nilesh Gupta et al. · uw

Dense embedding-based retrieval is widely used for semantic search and ranking. However, conventional two-stage approaches, involving contrastive embedding learning followed by approximate nearest neighbor search (ANNS), can suffer from misalignment between these stages. This mismatch degrades retrieval performance. We propose End-to-end Hierarchical Indexing (EHI), a novel method that directly addresses this issue by jointly optimizing embedding generation and ANNS structure. EHI leverages a dual encoder for embedding queries and documents while simultaneously learning an inverted file index (IVF)-style tree structure. To facilitate the effective learning of this discrete structure, EHI introduces dense path embeddings that encode the path traversed by queries and documents within the tree. Extensive evaluations on standard benchmarks, including MS MARCO (Dev set) and TREC DL19, demonstrate EHI's superiority over a traditional ANNS index. Under the same computational constraints, EHI outperforms existing state-of-the-art methods by +1.45% in MRR@10 on MS MARCO (Dev) and +8.2% in nDCG@10 on TREC DL19, highlighting the benefits of our end-to-end approach.

LG · Oct 7, 2022
Sample-Efficient Personalization: Modeling User Parameters as Low Rank Plus Sparse Components

Soumyabrata Pal, Prateek Varshney, Prateek Jain et al.

Personalization of machine learning (ML) predictions for individual users/domains/enterprises is critical for practical recommendation systems. Standard personalization approaches involve learning a user/domain specific embedding that is fed into a fixed global model which can be limiting. On the other hand, personalizing/fine-tuning model itself for each user/domain -- a.k.a meta-learning -- has high storage/infrastructure cost. Moreover, rigorous theoretical studies of scalable personalization approaches have been very limited. To address the above issues, we propose a novel meta-learning style approach that models network weights as a sum of low-rank and sparse components. This captures common information from multiple individuals/users together in the low-rank part while sparse part captures user-specific idiosyncrasies. We then study the framework in the linear setting, where the problem reduces to that of estimating the sum of a rank-$r$ and a $k$-column sparse matrix using a small number of linear measurements. We propose a computationally efficient alternating minimization method with iterative hard thresholding -- AMHT-LRS -- to learn the low-rank and sparse part. Theoretically, for the realizable Gaussian data setting, we show that AMHT-LRS solves the problem efficiently with nearly optimal sample complexity. Finally, a significant challenge in personalization is ensuring privacy of each user's sensitive data. We alleviate this problem by proposing a differentially private variant of our method that also is equipped with strong generalization guarantees.
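
The "low rank plus sparse" parameterization can be sketched in the linear setting. The sketch assumes each user's weight vector is the sum of a shared basis applied to a small user-specific code and a sparse user-specific offset, and it shows the hard-thresholding operator that keeps only the largest-magnitude entries. It does not implement AMHT-LRS; the shapes, numbers, and names are hypothetical.

```python
def hard_threshold(vec, k):
    """Keep the k largest-magnitude entries of vec and set the rest to zero."""
    if k <= 0:
        return [0.0] * len(vec)
    order = sorted(range(len(vec)), key=lambda i: abs(vec[i]), reverse=True)
    keep = set(order[:k])
    return [v if i in keep else 0.0 for i, v in enumerate(vec)]

def user_weights(U, v, s):
    """Compose w = U v + s from shared basis U (d x r), user code v (r), offset s (d)."""
    return [sum(u * c for u, c in zip(row, v)) + offset for row, offset in zip(U, s)]

# Shared rank-2 basis for a 4-dimensional weight vector (made-up values).
U = [
    [1.0, 0.0],
    [0.0, 1.0],
    [1.0, 1.0],
    [0.5, -0.5],
]
v = [2.0, 1.0]                                    # low-dimensional user code
s = hard_threshold([0.1, -3.0, 2.0, 0.5], k=2)    # sparse user-specific part
w = user_weights(U, v, s)
```

Storing only `v` and the nonzero entries of `s` per user, with `U` shared across users, is what keeps the per-user storage small in this parameterization.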

LG · Jun 17, 2022
MET: Masked Encoding for Tabular Data

Kushal Majmundar, Sachin Goyal, Praneeth Netrapalli et al.

We consider the task of self-supervised representation learning (SSL) for tabular data: tabular-SSL. Typical contrastive learning based SSL methods require instance-wise data augmentations which are difficult to design for unstructured tabular data. Existing tabular-SSL methods design such augmentations in a relatively ad-hoc fashion and can fail to capture the underlying data manifold. Instead of augmentations based approaches for tabular-SSL, we propose a new reconstruction based method, called Masked Encoding for Tabular Data (MET), that does not require augmentations. MET is based on the popular MAE approach for vision-SSL [He et al., 2021] and uses two key ideas: (i) since each coordinate in a tabular dataset has a distinct meaning, we need to use separate representations for all coordinates, and (ii) using an adversarial reconstruction loss in addition to the standard one. Empirical results on five diverse tabular datasets show that MET achieves a new state of the art (SOTA) on all of these datasets and improves up to 9% over current SOTA methods. We shed more light on the working of MET via experiments on carefully designed simple datasets.

IR · Jan 8
Succeeding at Scale: Automated Dataset Construction and Query-Side Adaptation for Multi-Tenant Search

Prateek Jain, Shabari S Nair, Ritesh Goru et al.

Large-scale multi-tenant retrieval systems generate extensive query logs but lack curated relevance labels for effective domain adaptation, resulting in substantial underutilized "dark data". This challenge is compounded by the high cost of model updates, as jointly fine-tuning query and document encoders requires full corpus re-indexing, which is impractical in multi-tenant settings with thousands of isolated indices. We introduce DevRev-Search, a passage retrieval benchmark for technical customer support built via a fully automated pipeline. Candidate generation uses fusion across diverse sparse and dense retrievers, followed by an LLM-as-a-Judge for consistency filtering and relevance labeling. We further propose an Index-Preserving Adaptation strategy that fine-tunes only the query encoder, achieving strong performance gains while keeping document indices fixed. Experiments on DevRev-Search, SciFact, and FiQA-2018 show that Parameter-Efficient Fine-Tuning (PEFT) of the query encoder delivers a remarkable quality-efficiency trade-off, enabling scalable and practical enterprise search adaptation.

LG · Jan 17, 2023
Optimal Algorithms for Latent Bandits with Cluster Structure

Soumyabrata Pal, Arun Sai Suggala, Karthikeyan Shanmugam et al.

We consider the problem of latent bandits with cluster structure where there are multiple users, each with an associated multi-armed bandit problem. These users are grouped into \emph{latent} clusters such that the mean reward vectors of users within the same cluster are identical. At each round, a user, selected uniformly at random, pulls an arm and observes a corresponding noisy reward. The goal of the users is to maximize their cumulative rewards. This problem is central to practical recommendation systems and has received wide attention of late \cite{gentile2014online, maillard2014latent}. Now, if each user acts independently, then they would have to explore each arm independently and a regret of $Ω(\sqrt{\mathsf{MNT}})$ is unavoidable, where $\mathsf{M}, \mathsf{N}$ are the number of arms and users, respectively. Instead, we propose LATTICE (Latent bAndiTs via maTrIx ComplEtion) which allows exploitation of the latent cluster structure to provide the minimax optimal regret of $\widetilde{O}(\sqrt{(\mathsf{M}+\mathsf{N})\mathsf{T}})$, when the number of clusters is $\widetilde{O}(1)$. This is the first algorithm to guarantee such strong regret bound. LATTICE is based on a careful exploitation of arm information within a cluster while simultaneously clustering users. Furthermore, it is computationally efficient and requires only $O(\log{\mathsf{T}})$ calls to an offline matrix completion oracle across all $\mathsf{T}$ rounds.

CV · Jul 17, 2024
LookupViT: Compressing visual information to a limited number of tokens

Rajat Koner, Gagan Jain, Prateek Jain et al.

Vision Transformers (ViT) have emerged as the de-facto choice for numerous industry grade vision solutions. But their inference cost can be prohibitive for many settings, as they compute self-attention in each layer which suffers from quadratic computational complexity in the number of tokens. On the other hand, spatial information in images and spatio-temporal information in videos is usually sparse and redundant. In this work, we introduce LookupViT, that aims to exploit this information sparsity to reduce ViT inference cost. LookupViT provides a novel general purpose vision transformer block that operates by compressing information from higher resolution tokens to a fixed number of tokens. These few compressed tokens undergo meticulous processing, while the higher-resolution tokens are passed through computationally cheaper layers. Information sharing between these two token sets is enabled through a bidirectional cross-attention mechanism. The approach offers multiple advantages - (a) easy to implement on standard ML accelerators (GPUs/TPUs) via standard high-level operators, (b) applicable to standard ViT and its variants, thus generalizes to various tasks, (c) can handle different tokenization and attention approaches. LookupViT also offers flexibility for the compressed tokens, enabling performance-computation trade-offs in a single trained model. We show LookupViT's effectiveness on multiple domains - (a) for image-classification (ImageNet-1K and ImageNet-21K), (b) video classification (Kinetics400 and Something-Something V2), (c) image captioning (COCO-Captions) with a frozen encoder. LookupViT provides $2\times$ reduction in FLOPs while upholding or improving accuracy across these domains. In addition, LookupViT also demonstrates out-of-the-box robustness and generalization on image classification (ImageNet-C,R,A,O), improving by up to $4\%$ over ViT.

IROct 31, 2023
Blocked Collaborative Bandits: Online Collaborative Filtering with Per-Item Budget Constraints

Soumyabrata Pal, Arun Sai Suggala, Karthikeyan Shanmugam et al.

We consider the problem of \emph{blocked} collaborative bandits where there are multiple users, each with an associated multi-armed bandit problem. These users are grouped into \emph{latent} clusters such that the mean reward vectors of users within the same cluster are identical. Our goal is to design algorithms that maximize the cumulative reward accrued by all the users over time, under the \emph{constraint} that no arm of a user is pulled more than $\mathsf{B}$ times. This problem has been originally considered by \cite{Bresler:2014}, and designing regret-optimal algorithms for it has since remained an open problem. In this work, we propose an algorithm called \texttt{B-LATTICE} (Blocked Latent bAndiTs via maTrIx ComplEtion) that collaborates across users, while simultaneously satisfying the budget constraints, to maximize their cumulative rewards. Theoretically, under certain reasonable assumptions on the latent structure, with $\mathsf{M}$ users, $\mathsf{N}$ arms, $\mathsf{T}$ rounds per user, and $\mathsf{C}=O(1)$ latent clusters, \texttt{B-LATTICE} achieves a per-user regret of $\widetilde{O}\left(\sqrt{\mathsf{T}(1 + \mathsf{N}\mathsf{M}^{-1})}\right)$ under a budget constraint of $\mathsf{B}=\Theta(\log \mathsf{T})$. These are the first sub-linear regret bounds for this problem, and match the minimax regret bounds when $\mathsf{B}=\mathsf{T}$. Empirically, we demonstrate that our algorithm has superior performance over baselines even when $\mathsf{B}=1$. \texttt{B-LATTICE} runs in phases where in each phase it clusters users into groups and collaborates across users within a group to quickly learn their reward models.

QUANT-PHDec 15, 2022
Hybrid Quantum Generative Adversarial Networks for Molecular Simulation and Drug Discovery

Prateek Jain, Srinjoy Ganguly

In molecular research, the simulation and design of molecules are key areas with significant implications for drug development, materials science, and other fields. Current classical computational power is inadequate to simulate anything more than small molecules, let alone protein chains of hundreds of peptides. These experiments are therefore performed physically in the wet lab, but this takes a lot of time, and it is not possible to examine every molecule because of the size of the search space; tens of billions of dollars are spent every year on these experiments. Molecular simulation and design have lately been advanced significantly by machine learning models. Deep generative models for graph-structured data provide a fresh perspective on the problem of chemical synthesis. By optimising differentiable models that produce molecular graphs directly, it is feasible to avoid costly search techniques in the discrete and huge space of chemical structures. However, these models also suffer from computational limitations when dimensions become large, and they consume huge amounts of resources. In recent years, quantum generative machine learning has shown some empirical results promising significant advantages over classical counterparts.

LGJun 9, 2023
End-to-End Neural Network Compression via $\frac{\ell_1}{\ell_2}$ Regularized Latency Surrogates

Anshul Nasery, Hardik Shah, Arun Sai Suggala et al.

Neural network (NN) compression via techniques such as pruning and quantization requires setting compression hyperparameters (e.g., number of channels to be pruned, bitwidths for quantization) for each layer either manually or via neural architecture search (NAS), which can be computationally expensive. We address this problem by providing an end-to-end technique that optimizes for a model's Floating Point Operations (FLOPs) or for on-device latency via a novel $\frac{\ell_1}{\ell_2}$ latency surrogate. Our algorithm is versatile and can be used with many popular compression methods including pruning, low-rank factorization, and quantization. Crucially, it is fast and runs in almost the same amount of time as single model training, which is a significant training speed-up over standard NAS methods. For BERT compression on GLUE fine-tuning tasks, we achieve $50\%$ reduction in FLOPs with only $1\%$ drop in performance. For compressing MobileNetV3 on ImageNet-1K, we achieve $15\%$ reduction in FLOPs, and $11\%$ reduction in on-device latency without drop in accuracy, while still requiring $3\times$ less training compute than SOTA compression techniques. Finally, for transfer learning on smaller datasets, our technique identifies $1.2\times$-$1.4\times$ cheaper architectures than the standard MobileNetV3 and EfficientNet suites of architectures at almost the same training cost and accuracy.
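To make the intuition behind an $\ell_1/\ell_2$ ratio concrete, here is a minimal sketch in plain Python. It only illustrates the well-known property that the ratio equals 1 for a vector with a single non-zero entry and $\sqrt{n}$ for a vector of $n$ equal entries, so it behaves like a scale-invariant proxy for the number of active entries. The function name is hypothetical, and this is not the paper's full latency surrogate.

```python
import math

def l1_over_l2(v):
    """Return ||v||_1 / ||v||_2 for a non-zero vector v."""
    l2 = math.sqrt(sum(x * x for x in v))
    if l2 == 0.0:
        raise ValueError("ratio is undefined for the all-zero vector")
    return sum(abs(x) for x in v) / l2

one_active = [0.0, 3.0, 0.0, 0.0]   # a single non-zero entry
all_active = [1.0, 1.0, 1.0, 1.0]   # four equal entries

ratio_sparse = l1_over_l2(one_active)   # smallest possible value
ratio_dense = l1_over_l2(all_active)    # largest value for length 4
```

Rescaling a vector by any non-zero constant leaves the ratio unchanged, which is one reason such a ratio can be preferable to a plain $\ell_1$ penalty as a proxy for sparsity.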

CLMar 8, 2024
Gemini 1.5: Unlocking multimodal understanding across millions of tokens of context

Gemini Team, Petko Georgiev, Ving Ian Lei et al. · deepmind, mila

In this report, we introduce the Gemini 1.5 family of models, representing the next generation of highly compute-efficient multimodal models capable of recalling and reasoning over fine-grained information from millions of tokens of context, including multiple long documents and hours of video and audio. The family includes two new models: (1) an updated Gemini 1.5 Pro, which exceeds the February version on the great majority of capabilities and benchmarks; (2) Gemini 1.5 Flash, a more lightweight variant designed for efficiency with minimal regression in quality. Gemini 1.5 models achieve near-perfect recall on long-context retrieval tasks across modalities, improve the state-of-the-art in long-document QA, long-video QA and long-context ASR, and match or surpass Gemini 1.0 Ultra's state-of-the-art performance across a broad set of benchmarks. Studying the limits of Gemini 1.5's long-context ability, we find continued improvement in next-token prediction and near-perfect retrieval (>99%) up to at least 10M tokens, a generational leap over existing models such as Claude 3.0 (200k) and GPT-4 Turbo (128k). Finally, we highlight real-world use cases, such as Gemini 1.5 collaborating with professionals on completing their tasks achieving 26 to 75% time savings across 10 different job categories, as well as surprising new capabilities of large language models at the frontier; when given a grammar manual for Kalamang, a language with fewer than 200 speakers worldwide, the model learns to translate English to Kalamang at a similar level to a person who learned from the same content.

CLJul 7, 2025
Gemini 2.5: Pushing the Frontier with Advanced Reasoning, Multimodality, Long Context, and Next Generation Agentic Capabilities

Gheorghe Comanici, Eric Bieber, Mike Schaekermann et al. · amazon-science, baidu

In this report, we introduce the Gemini 2.X model family: Gemini 2.5 Pro and Gemini 2.5 Flash, as well as our earlier Gemini 2.0 Flash and Flash-Lite models. Gemini 2.5 Pro is our most capable model yet, achieving SoTA performance on frontier coding and reasoning benchmarks. In addition to its incredible coding and reasoning skills, Gemini 2.5 Pro is a thinking model that excels at multimodal understanding and it is now able to process up to 3 hours of video content. Its unique combination of long context, multimodal and reasoning capabilities can be combined to unlock new agentic workflows. Gemini 2.5 Flash provides excellent reasoning abilities at a fraction of the compute and latency requirements and Gemini 2.0 Flash and Flash-Lite provide high performance at low latency and cost. Taken together, the Gemini 2.X model generation spans the full Pareto frontier of model capability vs cost, allowing users to explore the boundaries of what is possible with complex agentic problem solving.

CVSep 6, 2022
Surya Namaskar: real-time advanced yoga pose recognition and correction for smart healthcare

Abhishek Sharma, Pranjal Sharma, Darshan Pincha et al.

Nowadays, yoga has gained worldwide attention because of increasing levels of stress in the modern way of life, and there are many ways or resources to learn yoga. The word yoga means a deep connection between the mind and body. Today there is substantial medical and scientific evidence to show that the very fundamentals of the activity of our brain, our chemistry, and even our genetic content can be changed by practicing different systems of yoga. Suryanamaskar, also known as salute to the sun, is a yoga practice that combines eight different forms and 12 asanas (4 asanas are repeated) devoted to the Hindu Sun God, Surya. Suryanamaskar offers a number of health benefits such as strengthening muscles and helping to control blood sugar levels. Here the MediaPipe library is used to analyze Surya Namaskar poses. Poses are detected in real time as one performs Surya Namaskar in front of the camera. The classifier identifies the form as one of the following: Pranamasana, Hasta Padasana, Hasta Uttanasana, Ashwa - Sanchalan asana, Ashtanga Namaskar, Dandasana, or Bhujangasana and Svanasana. Deep learning-based techniques (CNN) are used to develop this model, with a model accuracy of 98.68 percent and an accuracy score of 0.75 for detecting the correct yoga (Surya Namaskar) posture. With this method, users can practice the desired pose and check whether the pose they are doing is correct. It will help in doing all the different poses of Surya Namaskar correctly and increase the efficiency of the yoga practitioner. This paper describes the whole framework which is to be implemented in the model.

CVApr 10
ELT: Elastic Looped Transformers for Visual Generation

Sahil Goyal, Swayam Agrawal, Gautham Govind Anil et al.

We introduce Elastic Looped Transformers (ELT), a highly parameter-efficient class of visual generative models based on a recurrent transformer architecture. While conventional generative models rely on deep stacks of unique transformer layers, our approach employs iterative, weight-shared transformer blocks to drastically reduce parameter counts while maintaining high synthesis quality. To effectively train these models for image and video generation, we propose the idea of Intra-Loop Self Distillation (ILSD), where student configurations (intermediate loops) are distilled from the teacher configuration (maximum training loops) to ensure consistency across the model's depth in a single training step. Our framework yields a family of elastic models from a single training run, enabling Any-Time inference capability with dynamic trade-offs between computational cost and generation quality, with the same parameter count. ELT significantly shifts the efficiency frontier for visual synthesis. With $4\times$ reduction in parameter count under iso-inference-compute settings, ELT achieves a competitive FID of $2.0$ on class-conditional ImageNet $256 \times 256$ and FVD of $72.8$ on class-conditional UCF-101.

CLApr 25, 2025Code
One-Pass to Reason: Token Duplication and Block-Sparse Mask for Efficient Fine-Tuning on Multi-Turn Reasoning

Ritesh Goru, Shanay Mehta, Prateek Jain

Fine-tuning Large Language Models (LLMs) on multi-turn reasoning datasets requires N (number of turns) separate forward passes per conversation due to reasoning token visibility constraints, as reasoning tokens for a turn are discarded in subsequent turns. We propose duplicating response tokens along with a custom attention mask to enable single-pass processing of entire conversations. We prove our method produces identical losses to the N-pass approach while reducing time complexity from $O\bigl(N^{3}\bigr)$ to $O\bigl(N^{2}\bigr)$ and maintaining the same memory complexity for a transformer-based model. Our approach achieves significant training speedup while preserving accuracy. Our implementation is available online (https://github.com/devrev/One-Pass-to-Reason).

LGMay 30, 2023Code
AdANNS: A Framework for Adaptive Semantic Search

Aniket Rege, Aditya Kusupati, Sharan Ranjit S et al.

Web-scale search systems learn an encoder to embed a given query which is then hooked into an approximate nearest neighbor search (ANNS) pipeline to retrieve similar data points. To accurately capture tail queries and data points, learned representations typically are rigid, high-dimensional vectors that are generally used as-is in the entire ANNS pipeline and can lead to computationally expensive retrieval. In this paper, we argue that instead of rigid representations, different stages of ANNS can leverage adaptive representations of varying capacities to achieve significantly better accuracy-compute trade-offs, i.e., stages of ANNS that can get away with more approximate computation should use a lower-capacity representation of the same data point. To this end, we introduce AdANNS, a novel ANNS design framework that explicitly leverages the flexibility of Matryoshka Representations. We demonstrate state-of-the-art accuracy-compute trade-offs using novel AdANNS-based key ANNS building blocks like search data structures (AdANNS-IVF) and quantization (AdANNS-OPQ). For example on ImageNet retrieval, AdANNS-IVF is up to 1.5% more accurate than the rigid representations-based IVF at the same compute budget; and matches accuracy while being up to 90x faster in wall-clock time. For Natural Questions, 32-byte AdANNS-OPQ matches the accuracy of the 64-byte OPQ baseline constructed using rigid representations -- same accuracy at half the cost! We further show that the gains from AdANNS translate to modern-day composite ANNS indices that combine search structures and quantization. Finally, we demonstrate that AdANNS can enable inference-time adaptivity for compute-aware search on ANNS indices built non-adaptively on matryoshka representations. Code is open-sourced at https://github.com/RAIVNLab/AdANNS.

LGJun 2, 2021Code
LLC: Accurate, Multi-purpose Learnt Low-dimensional Binary Codes

Aditya Kusupati, Matthew Wallingford, Vivek Ramanujan et al.

Learning binary representations of instances and classes is a classical problem with several high potential applications. In modern settings, the compression of high-dimensional neural representations to low-dimensional binary codes is a challenging task and often requires large bit-codes to be accurate. In this work, we propose a novel method for Learning Low-dimensional binary Codes (LLC) for instances as well as classes. Our method does not require any side-information, like annotated attributes or label meta-data, and learns extremely low-dimensional binary codes (~20 bits for ImageNet-1K). The learnt codes are super-efficient while still ensuring nearly optimal classification accuracy for ResNet50 on ImageNet-1K. We demonstrate that the learnt codes capture intrinsically important features in the data, by discovering an intuitive taxonomy over classes. We further quantitatively measure the quality of our codes by applying them to efficient image retrieval as well as out-of-distribution (OOD) detection problems. For the ImageNet-100 retrieval problem, our learnt binary codes outperform 16 bit HashNet using only 10 bits and also are as accurate as 10 dimensional real representations. Finally, our learnt binary codes can perform OOD detection, out-of-the-box, as accurately as a baseline that needs ~3000 samples to tune its threshold, while we require none. Code is open-sourced at https://github.com/RAIVNLab/LLC.

LGFeb 25, 2021Code
Do Input Gradients Highlight Discriminative Features?

Harshay Shah, Prateek Jain, Praneeth Netrapalli

Post-hoc gradient-based interpretability methods [Simonyan et al., 2013, Smilkov et al., 2017] that provide instance-specific explanations of model predictions are often based on assumption (A): magnitude of input gradients -- gradients of logits with respect to input -- noisily highlight discriminative task-relevant features. In this work, we test the validity of assumption (A) using a three-pronged approach. First, we develop an evaluation framework, DiffROAR, to test assumption (A) on four image classification benchmarks. Our results suggest that (i) input gradients of standard models (i.e., trained on original data) may grossly violate (A), whereas (ii) input gradients of adversarially robust models satisfy (A). Second, we introduce BlockMNIST, an MNIST-based semi-real dataset, that by design encodes a priori knowledge of discriminative features. Our analysis on BlockMNIST leverages this information to validate as well as characterize differences between input gradient attributions of standard and robust models. Finally, we theoretically prove that our empirical findings hold on a simplified version of the BlockMNIST dataset. Specifically, we prove that input gradients of standard one-hidden-layer MLPs trained on this dataset do not highlight instance-specific signal coordinates, thus grossly violating assumption (A). Our findings motivate the need to formalize and test common assumptions in interpretability in a falsifiable manner [Leavitt and Morcos, 2020]. We believe that the DiffROAR evaluation framework and BlockMNIST-based datasets can serve as sanity checks to audit instance-specific interpretability methods; code and data available at https://github.com/harshays/inputgradients.

LGFeb 28, 2020Code
DROCC: Deep Robust One-Class Classification

Sachin Goyal, Aditi Raghunathan, Moksh Jain et al.

Classical approaches for one-class problems such as one-class SVM and isolation forest require careful feature engineering when applied to structured domains like images. State-of-the-art methods aim to leverage deep learning to learn appropriate features via two main approaches. The first approach based on predicting transformations (Golan & El-Yaniv, 2018; Hendrycks et al., 2019a) while successful in some domains, crucially depends on an appropriate domain-specific set of transformations that are hard to obtain in general. The second approach of minimizing a classical one-class loss on the learned final layer representations, e.g., DeepSVDD (Ruff et al., 2018) suffers from the fundamental drawback of representation collapse. In this work, we propose Deep Robust One-Class Classification (DROCC) that is both applicable to most standard domains without requiring any side-information and robust to representation collapse. DROCC is based on the assumption that the points from the class of interest lie on a well-sampled, locally linear low dimensional manifold. Empirical evaluation demonstrates that DROCC is highly effective in two different one-class problem settings and on a range of real-world datasets across different domains: tabular data, images (CIFAR and ImageNet), audio, and time-series, offering up to 20% increase in accuracy over the state-of-the-art in anomaly detection. Code is available at https://github.com/microsoft/EdgeML.

CVFeb 27, 2020Code
RNNPool: Efficient Non-linear Pooling for RAM Constrained Inference

Oindrila Saha, Aditya Kusupati, Harsha Vardhan Simhadri et al.

Standard Convolutional Neural Networks (CNNs) designed for computer vision tasks tend to have large intermediate activation maps. These require large working memory and are thus unsuitable for deployment on resource-constrained devices typically used for inference on the edge. Aggressively downsampling the images via pooling or strided convolutions can address the problem but leads to a significant decrease in accuracy due to gross aggregation of the feature map by standard pooling operators. In this paper, we introduce RNNPool, a novel pooling operator based on Recurrent Neural Networks (RNNs), that efficiently aggregates features over large patches of an image and rapidly downsamples activation maps. Empirical evaluation indicates that an RNNPool layer can effectively replace multiple blocks in a variety of architectures such as MobileNets, DenseNet when applied to standard vision tasks like image classification and face detection. That is, RNNPool can significantly decrease computational complexity and peak memory usage for inference while retaining comparable accuracy. We use RNNPool with the standard S3FD architecture to construct a face detection method that achieves state-of-the-art MAP for tiny ARM Cortex-M4 class microcontrollers with under 256 KB of RAM. Code is released at https://github.com/Microsoft/EdgeML.

LGFeb 8, 2020Code
Soft Threshold Weight Reparameterization for Learnable Sparsity

Aditya Kusupati, Vivek Ramanujan, Raghav Somani et al.

Sparsity in Deep Neural Networks (DNNs) is studied extensively with the focus of maximizing prediction accuracy given an overall parameter budget. Existing methods rely on uniform or heuristic non-uniform sparsity budgets which have sub-optimal layer-wise parameter allocation resulting in a) lower prediction accuracy or b) higher inference cost (FLOPs). This work proposes Soft Threshold Reparameterization (STR), a novel use of the soft-threshold operator on DNN weights. STR smoothly induces sparsity while learning pruning thresholds thereby obtaining a non-uniform sparsity budget. Our method achieves state-of-the-art accuracy for unstructured sparsity in CNNs (ResNet50 and MobileNetV1 on ImageNet-1K), and, additionally, learns non-uniform budgets that empirically reduce the FLOPs by up to 50%. Notably, STR boosts the accuracy over existing results by up to 10% in the ultra sparse (99%) regime and can also be used to induce low-rank (structured sparsity) in RNNs. In short, STR is a simple mechanism which learns effective sparsity budgets that contrast with popular heuristics. Code, pretrained models and sparsity budgets are at https://github.com/RAIVNLab/STR.
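The soft-threshold operator mentioned above can be sketched in a few lines of plain Python. This is a minimal illustration, assuming the standard form sign(w) * max(|w| - t, 0) with the threshold t obtained as a sigmoid of a learnable scalar s; the sigmoid mapping and all names here are illustrative assumptions rather than details stated in the abstract.

```python
import math

def soft_threshold(w, s):
    """Shrink one weight toward zero by t = sigmoid(s); small weights become exactly 0."""
    t = 1.0 / (1.0 + math.exp(-s))      # assumed mapping from parameter s to threshold
    magnitude = max(abs(w) - t, 0.0)
    return math.copysign(magnitude, w) if magnitude > 0.0 else 0.0

def sparsify(weights, s):
    """Apply the soft threshold to every weight of a (flattened) layer."""
    return [soft_threshold(w, s) for w in weights]

def sparsity(weights):
    """Fraction of weights that are exactly zero."""
    return sum(1 for w in weights if w == 0.0) / len(weights)

layer = [0.9, -0.05, 0.3, -0.7, 0.01, 0.55]
pruned = sparsify(layer, s=0.0)         # sigmoid(0) gives a threshold of 0.5
```

Because each layer can carry its own s, layers may end up with different thresholds and therefore different sparsity levels, which is the sense in which the budget is non-uniform.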

LGJan 8, 2019Code
FastGRNN: A Fast, Accurate, Stable and Tiny Kilobyte Sized Gated Recurrent Neural Network

Aditya Kusupati, Manish Singh, Kush Bhatia et al.

This paper develops the FastRNN and FastGRNN algorithms to address the twin RNN limitations of inaccurate training and inefficient prediction. Previous approaches have improved accuracy at the expense of prediction costs making them infeasible for resource-constrained and real-time applications. Unitary RNNs have increased accuracy somewhat by restricting the range of the state transition matrix's singular values but have also increased the model size as they require a larger number of hidden units to make up for the loss in expressive power. Gated RNNs have obtained state-of-the-art accuracies by adding extra parameters thereby resulting in even larger models. FastRNN addresses these limitations by adding a residual connection that does not constrain the range of the singular values explicitly and has only two extra scalar parameters. FastGRNN then extends the residual connection to a gate by reusing the RNN matrices to match state-of-the-art gated RNN accuracies but with a 2-4x smaller model. Enforcing FastGRNN's matrices to be low-rank, sparse and quantized resulted in accurate models that could be up to 35x smaller than leading gated and unitary RNNs. This allowed FastGRNN to accurately recognize the "Hey Cortana" wakeword with a 1 KB model and to be deployed on severely resource-constrained IoT microcontrollers too tiny to store other RNN models. FastGRNN's code is available at https://github.com/Microsoft/EdgeML/.
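The residual connection with two extra scalars can be illustrated with a short plain-Python sketch. It assumes the update h_t = alpha * tanh(W x_t + U h_{t-1} + b) + beta * h_{t-1}, which is our reading of the FastRNN cell; the function names and the choice of tanh are assumptions for illustration, not the EdgeML implementation.

```python
import math

def matvec(matrix, vector):
    """Multiply a list-of-rows matrix by a vector."""
    return [sum(a * b for a, b in zip(row, vector)) for row in matrix]

def fastrnn_step(x, h_prev, W, U, b, alpha, beta):
    """One residual RNN step: alpha scales the candidate state, beta the previous state."""
    pre = [wx + uh + bi
           for wx, uh, bi in zip(matvec(W, x), matvec(U, h_prev), b)]
    candidate = [math.tanh(p) for p in pre]
    return [alpha * c + beta * h for c, h in zip(candidate, h_prev)]

# Tiny example: 1 input feature, 2 hidden units, all-zero weights.
W = [[0.0], [0.0]]
U = [[0.0, 0.0], [0.0, 0.0]]
b = [0.0, 0.0]
h_next = fastrnn_step([1.0], [0.5, -0.5], W, U, b, alpha=0.1, beta=0.9)
```

With zero weights the candidate state is zero, so the step simply scales the previous state by beta, which makes the role of the residual path easy to see.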

CLMar 29, 2024
Gecko: Versatile Text Embeddings Distilled from Large Language Models

Jinhyuk Lee, Zhuyun Dai, Xiaoqi Ren et al. · uw

We present Gecko, a compact and versatile text embedding model. Gecko achieves strong retrieval performance by leveraging a key idea: distilling knowledge from large language models (LLMs) into a retriever. Our two-step distillation process begins with generating diverse, synthetic paired data using an LLM. Next, we further refine the data quality by retrieving a set of candidate passages for each query, and relabeling the positive and hard negative passages using the same LLM. The effectiveness of our approach is demonstrated by the compactness of the Gecko. On the Massive Text Embedding Benchmark (MTEB), Gecko with 256 embedding dimensions outperforms all existing entries with 768 embedding size. Gecko with 768 embedding dimensions achieves an average score of 66.31, competing with 7x larger models and 5x higher dimensional embeddings.

LGJan 4, 2024
LLM Augmented LLMs: Expanding Capabilities through Composition

Rachit Bansal, Bidisha Samanta, Siddharth Dalmia et al. · cmu, deepmind

Foundational models with billions of parameters which have been trained on large corpora of data have demonstrated non-trivial skills in a variety of domains. However, due to their monolithic structure, it is challenging and expensive to augment them or impart new skills. On the other hand, due to their adaptation abilities, several new instances of these models are being trained towards new domains and tasks. In this work, we study the problem of efficient and practical composition of existing foundation models with more specific models to enable newer capabilities. To this end, we propose CALM -- Composition to Augment Language Models -- which introduces cross-attention between models to compose their representations and enable new capabilities. Salient features of CALM are: (i) Scales up LLMs on new tasks by 're-using' existing LLMs along with a few additional parameters and data, (ii) Existing model weights are kept intact, and hence preserves existing capabilities, and (iii) Applies to diverse domains and settings. We illustrate that augmenting PaLM2-S with a smaller model trained on low-resource languages results in an absolute improvement of up to 13\% on tasks like translation into English and arithmetic reasoning for low-resource languages. Similarly, when PaLM2-S is augmented with a code-specific model, we see a relative improvement of 40\% over the base model for code generation and explanation tasks -- on-par with fully fine-tuned counterparts.

LGJan 8
Estimating Causal Effects in Gaussian Linear SCMs with Finite Data

Aurghya Maiti, Prateek Jain

Estimating causal effects from observational data remains a fundamental challenge in causal inference, especially in the presence of latent confounders. This paper focuses on estimating causal effects in Gaussian Linear Structural Causal Models (GL-SCMs), which are widely used due to their analytical tractability. However, parameter estimation in GL-SCMs is often infeasible with finite data, primarily due to overparameterization. To address this, we introduce the class of Centralized Gaussian Linear SCMs (CGL-SCMs), a simplified yet expressive subclass where exogenous variables follow standardized distributions. We show that CGL-SCMs are equally expressive in terms of causal effect identifiability from observational distributions and present a novel EM-based estimation algorithm that can learn CGL-SCM parameters and estimate identifiable causal effects from finite observational samples. Our theoretical analysis is validated through experiments on synthetic data and benchmark causal graphs, demonstrating that the learned models accurately recover causal distributions.

CVNov 15, 2025
AGGRNet: Selective Feature Extraction and Aggregation for Enhanced Medical Image Classification

Ansh Makwe, Akansh Agrawal, Prateek Jain et al.

Medical image analysis for complex tasks such as severity grading and disease subtype classification poses significant challenges due to intricate and similar visual patterns among classes, scarcity of labeled data, and variability in expert interpretations. Despite the usefulness of existing attention-based models in capturing complex visual patterns for medical image classification, underlying architectures often face challenges in effectively distinguishing subtle classes since they struggle to capture inter-class similarity and intra-class variability, resulting in incorrect diagnosis. To address this, we propose AGGRNet framework to extract informative and non-informative features to effectively understand fine-grained visual patterns and improve classification for complex medical image analysis tasks. Experimental results show that our model achieves state-of-the-art performance on various medical imaging datasets, with the best improvement up to 5% over SOTA models on the Kvasir dataset.

AIFeb 13, 2024
Tandem Transformers for Inference Efficient LLMs

Aishwarya P S, Pranav Ajit Nair, Yashas Samaga et al.

The autoregressive nature of conventional large language models (LLMs) inherently limits inference speed, as tokens are generated sequentially. While speculative and parallel decoding techniques attempt to mitigate this, they face limitations: either relying on less accurate smaller models for generation or failing to fully leverage the base LLM's representations. We introduce a novel architecture, Tandem transformers, to address these issues. This architecture uniquely combines (1) a small autoregressive model and (2) a large model operating in block mode (processing multiple tokens simultaneously). The small model's predictive accuracy is substantially enhanced by granting it attention to the large model's richer representations. On the PaLM2 pretraining dataset, a tandem of PaLM2-Bison and PaLM2-Gecko demonstrates a 3.3% improvement in next-token prediction accuracy over a standalone PaLM2-Gecko, offering a 1.16x speedup compared to a PaLM2-Otter model with comparable downstream performance. We further incorporate the tandem model within the speculative decoding (SPEED) framework where the large model validates tokens from the small model. This ensures that the Tandem of PaLM2-Bison and PaLM2-Gecko achieves substantial speedup (around 1.14x faster than using vanilla PaLM2-Gecko in SPEED) while maintaining identical downstream task accuracy.

CLDec 4, 2024
Does Safety Training of LLMs Generalize to Semantically Related Natural Prompts?

Sravanti Addepalli, Yerram Varun, Arun Suggala et al.

Large Language Models (LLMs) are known to be susceptible to crafted adversarial attacks or jailbreaks that lead to the generation of objectionable content despite being aligned to human preferences using safety fine-tuning methods. While the large dimensionality of input token space makes it inevitable to find adversarial prompts that can jailbreak these models, we aim to evaluate whether safety fine-tuned LLMs are safe against natural prompts which are semantically related to toxic seed prompts that elicit safe responses after alignment. We surprisingly find that popular aligned LLMs such as GPT-4 can be compromised using naive prompts that are NOT even crafted with an objective of jailbreaking the model. Furthermore, we empirically show that given a seed prompt that elicits a toxic response from an unaligned model, one can systematically generate several semantically related natural prompts that can jailbreak aligned LLMs. Towards this, we propose a method of Response Guided Question Augmentation (ReG-QA) to evaluate the generalization of safety aligned LLMs to natural prompts, that first generates several toxic answers given a seed question using an unaligned LLM (Q to A), and further leverages an LLM to generate questions that are likely to produce these answers (A to Q). We interestingly find that safety fine-tuned LLMs such as GPT-4o are vulnerable to producing natural jailbreak questions from unsafe content (without denial) and can thus be used for the latter (A to Q) step. We obtain attack success rates that are comparable to or better than leading adversarial attack methods on the JailbreakBench leaderboard, while being significantly more stable against defenses such as Smooth-LLM and Synonym Substitution, which are effective against all existing attacks on the leaderboard.

LGFeb 14, 2024
HiRE: High Recall Approximate Top-$k$ Estimation for Efficient LLM Inference

Yashas Samaga B L, Varun Yerram, Chong You et al.

Autoregressive decoding with generative Large Language Models (LLMs) on accelerators (GPUs/TPUs) is often memory-bound where most of the time is spent on transferring model parameters from high bandwidth memory (HBM) to cache. On the other hand, recent works show that LLMs can maintain quality with significant sparsity/redundancy in the feedforward (FFN) layers by appropriately training the model to operate on a top-$k$ fraction of rows/columns (where $k \approx 0.05$), thereby suggesting a way to reduce the transfer of model parameters, and hence latency. However, exploiting this sparsity for improving latency is hindered by the fact that identifying top rows/columns is data-dependent and is usually performed using full matrix operations, severely limiting potential gains. To address these issues, we introduce HiRE (High Recall Approximate Top-$k$ Estimation). HiRE comprises two novel components: (i) a compression scheme to cheaply predict top-$k$ rows/columns with high recall, followed by full computation restricted to the predicted subset, and (ii) DA-TOP-$k$: an efficient multi-device approximate top-$k$ operator. We demonstrate that on a one billion parameter model, HiRE applied to both the softmax as well as feedforward layers, achieves almost matching pretraining and downstream accuracy, and speeds up inference latency by $1.47\times$ on a single TPUv5e device.
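The "cheap prediction with high recall, then exact computation on the predicted subset" pattern in component (i) can be sketched as follows. This is a toy illustration in plain Python: the cheap predictor here simply uses the first `d_small` coordinates of each row, which is a hypothetical stand-in for the paper's compression scheme, and every name below is an assumption.

```python
def dot(a, b):
    """Inner product of two equal-length sequences."""
    return sum(x * y for x, y in zip(a, b))

def approx_topk(x, rows, k, k_prime, d_small):
    """Shortlist k_prime rows with a cheap score, then rank them with exact scores."""
    # Cheap proxy score: only the first d_small coordinates are read.
    cheap = [dot(row[:d_small], x[:d_small]) for row in rows]
    candidates = sorted(range(len(rows)), key=lambda i: cheap[i], reverse=True)[:k_prime]
    # Exact scores are computed only for the shortlisted rows.
    exact = {i: dot(rows[i], x) for i in candidates}
    top = sorted(exact, key=exact.get, reverse=True)[:k]
    return top, exact

rows = [
    [1.0, 0.0, 0.0],
    [0.9, 0.0, 1.0],
    [0.0, 1.0, 0.0],
    [0.8, 0.0, -5.0],
]
x = [1.0, 0.0, 1.0]
top, exact = approx_topk(x, rows, k=2, k_prime=3, d_small=1)
```

Choosing `k_prime` larger than `k` is what buys recall: a row whose cheap score is slightly off can still make the shortlist, and the exact pass then orders the shortlist correctly.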

LGFeb 10, 2025
Matryoshka Quantization

Pranav Nair, Puranjay Datta, Jeff Dean et al. · uw

Quantizing model weights is critical for reducing the communication and inference costs of large models. However, quantizing models -- especially to low precisions like int4 or int2 -- requires a trade-off in model quality; int2, in particular, is known to severely degrade model quality. Consequently, practitioners are often forced to maintain multiple models with different quantization levels or serve a single model that best satisfies the quality-latency trade-off. On the other hand, integer data types, such as int8, inherently possess a nested (Matryoshka) structure where smaller bit-width integers, like int4 or int2, are nested within the most significant bits. Leveraging this insight, we propose Matryoshka Quantization (MatQuant), a novel multi-scale quantization technique that alleviates the aforementioned challenge. This technique allows us to train and maintain a single quantized model but serve it with the precision demanded by the deployment. Furthermore, leveraging MatQuant's co-training and co-distillation regularization, int2 precision models extracted by MatQuant outperform standard int2 quantization by up to 4% and 7% with OmniQuant and QAT as base algorithms, respectively. Finally, we demonstrate that by using an extra bit to represent outliers, a model with an effective precision of 2.05-bit gives an additional 6% improvement with OmniQuant as the base algorithm.
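The nested structure of integer types mentioned above can be shown with a small sketch. It assumes unsigned 8-bit codes and plain truncation of the low bits; the helper names `slice_msbs` and `rescale` are hypothetical, and the paper's exact slicing operator may differ (for example, by rounding rather than truncating).

```python
def slice_msbs(code: int, bits: int, total_bits: int = 8) -> int:
    """Keep the `bits` most significant bits of an unsigned `total_bits`-bit code."""
    if not 0 <= code < (1 << total_bits):
        raise ValueError("code does not fit in total_bits")
    if not 1 <= bits <= total_bits:
        raise ValueError("bits must be between 1 and total_bits")
    return code >> (total_bits - bits)


def rescale(sliced: int, bits: int, total_bits: int = 8) -> int:
    """Map a sliced code back onto the original total_bits-wide grid."""
    return sliced << (total_bits - bits)


code = 0b10110110                 # an int8-style code (182)
int4_code = slice_msbs(code, 4)   # top 4 bits: 0b1011
int2_code = slice_msbs(code, 2)   # top 2 bits: 0b10

# Nesting: slicing the int4 code again gives the same int2 code.
nested_int2 = slice_msbs(int4_code, 2, total_bits=4)
print(int4_code, int2_code, nested_int2, rescale(int4_code, 4), rescale(int2_code, 2))
```

Because each lower precision is a prefix of the higher one, a single stored int8 tensor can in principle serve int4 or int2 deployments without keeping separate copies.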

CVFeb 1, 2025
Masked Generative Nested Transformers with Decode Time Scaling

Sahil Goyal, Debapriya Tula, Gagan Jain et al.

Recent advances in visual generation have made significant strides in producing content of exceptional quality. However, most methods suffer from a fundamental problem: a bottleneck in inference computational efficiency. Most of these algorithms involve multiple passes over a transformer model to generate tokens or denoise inputs. However, the model size is kept constant throughout all iterations, which makes generation computationally expensive. In this work, we aim to address this issue primarily through two key ideas: (a) not all parts of the generation process need equal compute, so we design a decode-time model scaling schedule to utilize compute effectively, and (b) some of the computation can be cached and reused. Combining these two ideas leads to using smaller models to process more tokens while large models process fewer tokens. These different-sized models do not increase the parameter count, as they share parameters. We rigorously experiment with ImageNet 256$\times$256, UCF101, and Kinetics600 to showcase the efficacy of the proposed method for image/video generation and frame prediction. Our experiments show that with almost $3\times$ less compute than the baseline, our model obtains competitive performance.

CLDec 3, 2024
Time-Reversal Provides Unsupervised Feedback to LLMs

Yerram Varun, Rahul Madhavan, Sravanti Addepalli et al.

Large Language Models (LLMs) are typically trained to predict in the forward direction of time. However, recent works have shown that prompting these models to look back and critique their own generations can produce useful feedback. Motivated by this, we explore the question of whether LLMs can be empowered to think (predict and score) backwards to provide unsupervised feedback that complements forward LLMs. Towards this, we introduce Time Reversed Language Models (TRLMs), which can score and generate queries when conditioned on responses, effectively functioning in the reverse direction of time. Further, to effectively infer in the response to query direction, we pre-train and fine-tune a language model (TRLM-Ba) in the reverse token order from scratch. We show empirically (and theoretically in a stylized setting) that time-reversed models can indeed complement forward model predictions when used to score the query given response for re-ranking multiple forward generations. We obtain up to 5% improvement on the widely used AlpacaEval Leaderboard over the competent baseline of best-of-N re-ranking using self log-perplexity scores. We further show that TRLM scoring outperforms conventional forward scoring of response given query, resulting in significant gains in applications such as citation generation and passage retrieval. We next leverage the generative ability of TRLM to augment or provide unsupervised feedback to input safety filters of LLMs, demonstrating a drastic reduction in false negative rate with negligible impact on false positive rates against several attacks published on the popular JailbreakBench leaderboard.

LGOct 17, 2025
Compressing Many-Shots in In-Context Learning

Devvrit Khatri, Pranamya Kulkarni, Nilesh Gupta et al.

Large Language Models (LLMs) have been shown to be able to learn different tasks without explicit finetuning when given many input-output examples / demonstrations through In-Context Learning (ICL). Increasing the number of examples, called "shots", improves downstream task performance but incurs higher memory and computational costs. In this work, we study an approach to improve the memory and computational efficiency of ICL inference by compressing the many-shot prompts. Given many shots comprising t tokens, our goal is to generate an m soft-token summary, where m < t. We first show that existing prompt compression methods are ineffective for many-shot compression, and that simply using fewer shots is a surprisingly strong baseline. To achieve effective compression, we find that: (a) a stronger compressor model with more trainable parameters is necessary, and (b) compressing many-shot representations at each transformer layer enables more fine-grained compression by providing each layer with its own compressed representation. Based on these insights, we propose MemCom, a layer-wise compression method. We systematically evaluate various compressor models and training approaches across different model sizes (2B and 7B), architectures (Gemma and Mistral), many-shot sequence lengths (3k-6k tokens), and compression ratios (3x to 8x). MemCom outperforms strong baselines across all compression ratios on multiple classification tasks with large label sets. Notably, while baseline performance degrades sharply at higher compression ratios, often by over 20-30%, MemCom maintains high accuracy with minimal degradation, typically dropping by less than 10%.

LGJun 4, 2025
Faster Approx. Top-K: Harnessing the Full Power of Two Stages

Yashas Samaga, Varun Yerram, Spandana Raj Babbula et al.

We consider the Top-$K$ selection problem, which aims to identify the largest-$K$ elements from an array. Top-$K$ selection arises in many machine learning algorithms and often becomes a bottleneck on accelerators, which are optimized for dense matrix multiplications. To address this problem, Chern et al. (2022) proposed a fast two-stage approximate Top-$K$ algorithm: (i) partition the input array and select the top-$1$ element from each partition, (ii) sort this smaller subset and return the top $K$ elements. In this paper, we consider a generalized version of this algorithm, where the first stage selects top-$K'$ elements, for some $1 \leq K' \leq K$, from each partition. Our contributions are as follows: (i) we derive an expression for the expected recall of this generalized algorithm and show that choosing $K' > 1$ with fewer partitions in the first stage reduces the input size to the second stage more effectively while maintaining the same expected recall as the original algorithm, (ii) we derive a bound on the expected recall for the original algorithm of Chern et al. (2022) that is provably tighter by a factor of $2$ than the one in that paper, and (iii) we implement our algorithm on Cloud TPUv5e and achieve around an order-of-magnitude speedup over the original algorithm without sacrificing recall on real-world tasks.
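The generalized two-stage procedure described above can be sketched in a few lines. This is a minimal CPU illustration, not the paper's TPU implementation; the function name `two_stage_top_k` and the strided partitioning are assumptions, since the abstract does not say how the array is partitioned.

```python
import heapq
from typing import List, Sequence


def two_stage_top_k(
    values: Sequence[float], k: int, num_partitions: int, k_prime: int = 1
) -> List[float]:
    """Approximate top-k: top-k' per partition, then exact top-k of the survivors."""
    if k <= 0 or num_partitions <= 0 or k_prime <= 0:
        raise ValueError("k, num_partitions and k_prime must be positive")
    candidates: List[float] = []
    # Stage 1: strided partitions (an assumption); keep the top-k' of each.
    for p in range(num_partitions):
        bucket = values[p::num_partitions]
        candidates.extend(heapq.nlargest(k_prime, bucket))
    # Stage 2: exact top-k over the much smaller candidate set.
    return heapq.nlargest(k, candidates)


values = [9, 1, 8, 2, 7, 3, 6, 4]
exact = heapq.nlargest(2, values)
# k' = 1 with two partitions misses 8, which shares a partition with 9.
lossy = two_stage_top_k(values, k=2, num_partitions=2, k_prime=1)
# k' = 2 with the same two partitions recovers the exact answer.
recovered = two_stage_top_k(values, k=2, num_partitions=2, k_prime=2)
print(exact, lossy, recovered)
```

The toy input shows why the first-stage setting matters: with $K' = 1$ two of the largest elements can collide in one partition, while $K' = 2$ on the same partitions keeps both.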

LGMay 29, 2025
Matryoshka Model Learning for Improved Elastic Student Models

Chetan Verma, Aditya Srinivas Timmaraju, Cho-Jui Hsieh et al.

Industry-grade ML models are carefully designed to meet rapidly evolving serving constraints, which requires significant resources for model development. In this paper, we propose MatTA, a framework for training multiple accurate Student models using a novel Teacher-TA-Student recipe. TA models are larger versions of the Student models with higher capacity, and thus allow Student models to better relate to the Teacher model and also bring in more domain-specific expertise. Furthermore, multiple accurate Student models can be extracted from the TA model. Therefore, despite only one training run, our methodology provides multiple servable options to trade off accuracy for lower serving cost. We demonstrate the proposed method, MatTA, on proprietary datasets and models. Its practical efficacy is underscored by live A/B tests within a production ML system, demonstrating 20% improvement on a key metric. We also demonstrate our method on GPT-2 Medium, a public model, and achieve relative improvements of over 24% on SAT Math and over 10% on the LAMBADA benchmark.