On Device LLM Processing In Android

Posted on Jun 24, 2024 • 15 min read

Meet Prajapati

Sr. Mobile Software Engineer

Summary

As large language models (LLMs) continue to advance, integrating them into mobile applications has become increasingly feasible and beneficial. On-device processing of LLMs offers several advantages, such as reduced latency, enhanced privacy, and offline capabilities. 

By running LLMs directly on the device, applications can provide real-time responses without relying on a constant internet connection or exposing sensitive data to external servers. 

This blog explores the concept of on-device LLM processing in Android, demonstrating how to implement such a feature using Kotlin. We'll walk through the essential components of an Android application that leverages LLMs for real-time text generation and processing, providing an efficient and secure way to handle language models directly on the device.

Get Started

We are using Gemma 2B. Gemma is a family of lightweight, open models built from the research and technology that Google used to create the Gemini models. Download the model from the provided link and extract it; it will then be ready to use.

To get started, create a new Android project. We will build the UI with Jetpack Compose and use Google's MediaPipe to interact with the model. MediaPipe Solutions offers a collection of libraries and tools designed to help you swiftly integrate artificial intelligence (AI) and machine learning (ML) capabilities into your applications.

Copy Model To Device

  1. Extract the downloaded model in the Downloads folder on your computer and connect your mobile device.

  2. Run the following commands using adb in the terminal:

adb shell rm -r /data/local/tmp/llm/
adb shell mkdir -p /data/local/tmp/llm/   
adb push gemma2b.bin /data/local/tmp/llm/gemma2b.bin

These commands remove any previous copy, recreate the /data/local/tmp/llm/ directory, and push the gemma2b.bin model file onto the device. Now that our model is in the right place, we can start coding.
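
Optionally, you can verify that the file landed in the expected location by listing the directory on the device:

adb shell ls -l /data/local/tmp/llm/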

Let’s Code

Add this to the AndroidManifest.xml file to support native libraries:

<uses-native-library
    android:name="libOpenCL.so"
    android:required="false" />
<uses-native-library
    android:name="libOpenCL-car.so"
    android:required="false" />
<uses-native-library
    android:name="libOpenCL-pixel.so"
    android:required="false" />

Add this dependency to the build.gradle file of your Android app:

dependencies {
    implementation 'com.google.mediapipe:tasks-genai:0.10.14'
}

The LLMTask Class

The core component of our implementation is the LLMTask class. This class handles the initialization and execution of the LLM inference, ensuring the model runs efficiently on the device. Let's break down the key elements of this class:

class LLMTask(context: Context) {
    private val _partialResults = MutableSharedFlow<Pair<String, Boolean>>(
        extraBufferCapacity = 1,
        onBufferOverflow = BufferOverflow.DROP_OLDEST
    )
    val partialResults: SharedFlow<Pair<String, Boolean>> = _partialResults.asSharedFlow()
    private var llmInference: LlmInference

    init {
        val options = LlmInference.LlmInferenceOptions.builder()
            .setModelPath(MODEL_PATH)
            .setMaxTokens(2048)
            .setTopK(50)
            .setTemperature(0.7f)
            .setRandomSeed(1)
            .setResultListener { partialResult, done ->
                _partialResults.tryEmit(partialResult to done)
            }
            .build()

        llmInference = LlmInference.createFromOptions(
            context,
            options
        )
    }

    fun generateResponse(prompt: String) {
        llmInference.generateResponseAsync(prompt)
    }

    companion object {
        private const val MODEL_PATH = "/data/local/tmp/llm/gemma2b.bin"
        private var instance: LLMTask? = null
        fun getInstance(context: Context): LLMTask {
            return if (instance != null) {
                instance!!
            } else {
                LLMTask(context).also { instance = it }
            }
        }
    }
}

Key Components

  1. MutableSharedFlow and SharedFlow: These are used to manage the flow of partial results from the LLM inference. The MutableSharedFlow allows us to emit new results, while the SharedFlow exposes these results to other parts of the application.

  2. LlmInference Initialization: The LlmInference instance is initialized with options, including the model path, maximum tokens, and a result listener that handles partial results.

We can use the following configuration options to initialize the LlmInference; a short sketch using some of them follows the list:

  1. modelPath: The path to where the model is stored within the project directory.

  2. maxTokens: The maximum number of tokens (input tokens + output tokens) the model handles. The default value is 512.

  3. topK: The number of tokens the model considers at each step of generation. Limits predictions to the top k most-probable tokens. The default value is 40.

  4. temperature: The amount of randomness introduced during generation. A higher temperature results in more creativity in the generated text, while a lower temperature produces more predictable generation. The default value is 0.8.

  5. randomSeed: The random seed used during text generation. The default value is 0.

  6. loraPath: The absolute path to the LoRA model locally on the device. Note: this is only compatible with GPU models.

  7. resultListener: Sets the result listener to receive the results asynchronously. Only applicable when using the async generation method.

  8. errorListener: Sets an optional error listener.
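
The LLMTask class above does not use errorListener or loraPath. As a rough sketch (not part of the project code, with placeholder values and a hypothetical LoRA path), they would slot into the same builder like this:

// Illustrative only: the same builder with the optional error listener and
// a LoRA adapter path added. The LoRA path below is hypothetical and only
// applies when running a GPU-compatible model.
val options = LlmInference.LlmInferenceOptions.builder()
    .setModelPath("/data/local/tmp/llm/gemma2b.bin")
    .setMaxTokens(2048)
    .setResultListener { partialResult, done ->
        Log.d("LLMTask", "Partial result received, done=$done")
    }
    .setErrorListener { error ->
        Log.e("LLMTask", "LLM inference failed", error)
    }
    .setLoraPath("/data/local/tmp/llm/lora_adapter.bin")
    .build()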

Managing State with LLMState

sealed class LLMState {
    data object LLMModelLoading : LLMState()
    data object LLMModelLoaded : LLMState()
    data object LLMResponseLoading : LLMState()
    data object LLMResponseLoaded : LLMState()

    val isLLMModelLoading get() = this is LLMModelLoading
    val isLLMResponseLoading get() = this is LLMResponseLoading
}

This sealed class helps manage and react to different states, such as when the model is loading, when it's loaded, and when a response is generated.
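
For illustration, a composable can branch on these helper properties to drive the UI. The snippet below is a sketch, not part of the project code; LoadingIndicator and ChatContent are hypothetical composables:

// A rough sketch of reacting to LLMState in Compose.
// LoadingIndicator and ChatContent are hypothetical composables.
@Composable
fun LLMStateContent(llmState: LLMState) {
    when {
        llmState.isLLMModelLoading -> LoadingIndicator(text = "Loading model…")
        llmState.isLLMResponseLoading -> ChatContent(showTypingIndicator = true)
        else -> ChatContent(showTypingIndicator = false)
    }
}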

ChatState Class

The ChatState class is responsible for maintaining the state of the chat, including user messages and LLM responses.

class ChatState(
    messages: List<ChatDataModel> = emptyList()
) {
    private val _chatMessages: MutableList<ChatDataModel> = messages.toMutableStateList()
    val chatMessages: List<ChatDataModel>
        get() = _chatMessages.map { model ->
            val isUser = model.isUser
            val prefixToRemove =
                if (isUser) USER_PREFIX else MODEL_PREFIX
            model.copy(
                chatMessage = model.chatMessage
                    .replace(
                        START_TURN + prefixToRemove + "\n",
                        ""
                    )
                    .replace(
                        END_TURN,
                        ""
                    )
            )
        }.reversed()

    val fullPrompt
        get() =
            _chatMessages.takeLast(5).joinToString("\n") { it.chatMessage }

    fun createLLMLoadingMessage(): String {
        val chatMessage = ChatDataModel(
            chatMessage = "",
            isUser = false
        )
        _chatMessages.add(chatMessage)
        return chatMessage.id
    }

    fun appendFirstLLMResponse(
        id: String,
        message: String,
    ) {
        appendLLMResponse(
            id,
            "$START_TURN$MODEL_PREFIX\n$message",
            false
        )
    }

    fun appendLLMResponse(
        id: String,
        message: String,
        done: Boolean
    ) {
        val index = _chatMessages.indexOfFirst { it.id == id }
        if (index != -1) {
            val newText = if (done) {
                _chatMessages[index].chatMessage + message + END_TURN
            } else {
                _chatMessages[index].chatMessage + message
            }
            _chatMessages[index] = _chatMessages[index].copy(chatMessage = newText)
        }
    }

    fun appendUserMessage(
        message: String,
    ) {
        val chatMessage = ChatDataModel(
            chatMessage = "$START_TURN$USER_PREFIX\n$message$END_TURN",
            isUser = true
        )
        _chatMessages.add(chatMessage)
    }

    fun addErrorLLMResponse(e: Exception) {
        _chatMessages.add(
            ChatDataModel(
                chatMessage = e.localizedMessage ?: "Error generating message",
                isUser = false
            )
        )
    }

    companion object {
        private const val MODEL_PREFIX = "model"
        private const val USER_PREFIX = "user"
        private const val START_TURN = "<start_of_turn>"
        private const val END_TURN = "<end_of_turn>"
    }
}

Key Methods:

  1. createLLMLoadingMessage: Adds a new loading message to the chat state and returns its ID.

  2. appendFirstLLMResponse and appendLLMResponse: These methods handle appending partial and complete LLM responses to the chat messages.

  3. appendUserMessage: Adds user messages to the chat state.

  4. addErrorLLMResponse: Adds an error message if there's an issue during LLM processing.

  5. fullPrompt: Joins the last five messages to provide better context to the LLM.
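
The ChatDataModel class used throughout ChatState is not shown above. A minimal sketch, with fields inferred from how ChatState uses it (the actual project class may differ), could look like this:

// Minimal sketch of ChatDataModel, inferred from its usage in ChatState;
// the actual class in the project may differ.
data class ChatDataModel(
    val id: String = UUID.randomUUID().toString(),
    val chatMessage: String,
    val isUser: Boolean
)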

ChatViewModel Class

The ChatViewModel class manages the interaction between the UI and the LLM processing logic. It uses Kotlin coroutines to manage asynchronous tasks and updates the UI state accordingly.

@HiltViewModel
class ChatViewModel @Inject constructor(@ApplicationContext private val context: Context) :
    ViewModel() {
    private val _llmState = MutableStateFlow<LLMState>(LLMState.LLMModelLoading)
    val llmState = _llmState.asStateFlow()
    private val _chatState: MutableStateFlow<ChatState> = MutableStateFlow(ChatState())
    val chatState: StateFlow<ChatState> = _chatState.asStateFlow()

    fun initLLMModel() {
        viewModelScope.launch(Dispatchers.IO) {
            _llmState.emit(LLMState.LLMModelLoading)
            LLMTask.getInstance(context)
        }.invokeOnCompletion {
            _llmState.value = LLMState.LLMModelLoaded
        }
    }

    fun sendMessage(message: String) {
        viewModelScope.launch(Dispatchers.IO) {
            _chatState.value.appendUserMessage(message)
            try {
                _llmState.emit(LLMState.LLMResponseLoading)
                var currentLLMResponseId: String? = _chatState.value.createLLMLoadingMessage()
                LLMTask.getInstance(context).generateResponse(_chatState.value.fullPrompt)
                LLMTask.getInstance(context).partialResults
                    .collectIndexed { index, (partialResult, done) ->
                        currentLLMResponseId?.let { id ->
                            if (index == 0) {
                                _chatState.value.appendFirstLLMResponse(id, partialResult)
                            } else {
                                _chatState.value.appendLLMResponse(id, partialResult, done)
                            }
                            if (done) {
                                _llmState.emit(LLMState.LLMResponseLoaded)
                                currentLLMResponseId = null
                            }
                        }
                    }
            } catch (e: Exception) {
                _chatState.value.addErrorLLMResponse(e)
            }
        }
    }
}

Key Functions:

  1. initLLMModel: Initializes the LLM model and updates the state accordingly.

  2. sendMessage: Handles user messages, generates LLM responses, and updates the chat state with partial and final results.

Bringing it Together In a Chat Interface

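As a rough sketch of how the pieces connect (this composable is illustrative, not the project's actual UI; ChatBubble and MessageInput are hypothetical composables), the screen loads the model once, renders chatMessages, and forwards user input to sendMessage:

// Illustrative wiring of the ViewModel into a chat screen.
// ChatBubble and MessageInput are hypothetical composables.
@Composable
fun ChatScreen(viewModel: ChatViewModel = hiltViewModel()) {
    val llmState by viewModel.llmState.collectAsState()
    val chatState by viewModel.chatState.collectAsState()

    // Load the model once when the screen enters composition.
    LaunchedEffect(Unit) { viewModel.initLLMModel() }

    Column(modifier = Modifier.fillMaxSize()) {
        // chatMessages is already reversed (newest first), so reverseLayout
        // keeps the latest message pinned to the bottom.
        LazyColumn(
            modifier = Modifier.weight(1f),
            reverseLayout = true
        ) {
            items(chatState.chatMessages) { message ->
                ChatBubble(message = message)
            }
        }
        MessageInput(
            enabled = !llmState.isLLMModelLoading && !llmState.isLLMResponseLoading,
            onSend = { text -> viewModel.sendMessage(text) }
        )
    }
}

The enabled flag simply blocks new input while the model is still loading or a response is being generated, mirroring the LLMState values emitted by the ViewModel.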

Pros

Privacy: By processing data locally on the device, on-device LLMs reduce the need to send sensitive information over the internet, enhancing user privacy.

Offline Capabilities: On-device LLMs can function without an internet connection, enabling users to access language processing features even in offline environments.

Low Latency: Processing data locally reduces the latency associated with sending data to remote servers for processing, resulting in faster response times.

Reduced Data Costs: Users can avoid data charges associated with sending data to remote servers for processing.

Customization: On-device LLMs can be customized and optimized for specific use cases or devices, allowing for greater flexibility and performance optimization.

Cost-Effectiveness: On-device processing can be more cost-effective in the long run, as it reduces the need for expensive server infrastructure and data transfer costs.

Cons

Model Size and Complexity: On-device processing requires the LLM model to be stored locally on the device, which can be challenging due to the size and complexity of modern LLMs. Larger models require more storage space and computational resources, which can strain lower-end devices.

Resource Intensive: Running LLMs on-device can be resource-intensive, especially for complex models or long sequences. This can lead to increased battery consumption and slower performance, particularly on older or less powerful devices.

Model Updates: Keeping the LLM model up to date with the latest advancements and improvements can be challenging. Updating the model requires updating the application, which may not always be feasible or practical for users.

GitHub Project Link - Here
