Leveraging the fold operator in Kotlin Flow to optimize Room DB upserts

Hey everyone, today we will look into an interesting topic: optimizing Room DB upserts using the fold operator in Kotlin Flow.

The case?

We are working on a chat application where we fetch a list of chats in pages from the network and upsert it into the DB, as it is an offline-first application. We use Room here as an ORM. We fetch 45 chats in a single page. Considering we have 900 total chats, 20 API calls will be made to fetch all chats, and 20 upsert operations will be performed on the DB. You may think we could solve this with pagination, but the problem with pagination is that any chat from any page may move to the top at any time, and we may end up missing those moves.
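To make the scenario concrete, here is a minimal sketch of that fetch-and-upsert loop. ChatApi and fetchChatsPage are hypothetical names used only for illustration; the actual network layer is not shown in this article.

// Hypothetical network interface; the real API layer will differ.
interface ChatApi {
    suspend fun fetchChatsPage(page: Int, pageSize: Int): List<Chat>
}

// One upsert per page: for 900 chats and pages of 45, insertChats runs 20 times.
suspend fun syncAllChats(db: AppDatabase, api: ChatApi) {
    var page = 0
    while (true) {
        val pageOfChats = api.fetchChatsPage(page, pageSize = 45)
        if (pageOfChats.isEmpty()) break
        db.chatsDao().insertChats(pageOfChats)
        page++
    }
}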

The Problem?

Since the upserts into the DB are this frequent, each new upsert is queued until the previous one completes: SQLite allows only one writer at a time, and every batch commits in its own transaction. This gradually increases the time it takes to show chats on the screen to the user.

Let's calculate numbers by writing some code

Old Implementation

Here we are simulating the behaviour by mocking some data and upserting it into the DB. Let's create a mock JSON array of 1000 chats. Based on the case above, we will upsert the data in batches of 45. Let's see some code. You will find the below snippet in the code sample link.

lifecycleScope.launch(Dispatchers.IO) {
    // Parse the mock JSON into a list of Chat objects to build our test data.
    val jsonArray = loadJsonFromAsset(this@MainActivity, "chats.json") ?: return@launch
    val chats = mutableListOf<Chat>()
    for (i in 0 until jsonArray.length()) {
        val channelJson = jsonArray.optJSONObject(i) ?: continue
        val id = channelJson.optString("id")
        val name = channelJson.optString("name")
        val lastMsg = channelJson.optString("last_message")
        val lastSeen = channelJson.optString("last_seen")
        val type = channelJson.optString("type")
        chats.add(createChat(id, name, lastMsg, lastSeen, type))
    }
    val db = AppDatabase.getInstance(this@MainActivity)
    // Simulate the paged sync: one upsert per 45-chat batch.
    val time = measureTimeMillis {
        chats.chunked(45).asFlow().collect {
            db.chatsDao().insertChats(it)
        }
    }
    withContext(Dispatchers.Main) {
        tvResult.text = "time taken: $time ms"
    }
}
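For completeness, here is a minimal sketch of the Chat entity and the createChat helper the snippet relies on. The field names are inferred from the JSON keys above and may differ from the actual definitions in the linked sample project.

@Entity(tableName = "chats")
data class Chat(
    @PrimaryKey val id: String,
    val name: String,
    val lastMessage: String,
    val lastSeen: String,
    val type: String
)

fun createChat(id: String, name: String, lastMsg: String, lastSeen: String, type: String): Chat =
    Chat(id, name, lastMsg, lastSeen, type)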

Let's dry-run this code. From the above snippet, you can see that we parse a JSON array of 1000 chats, iterating over it, parsing each chat, and adding it to a mutable list to create our test data. After creating our sample list of chats, we simulate the behaviour of upserting it in batches of 45.

We will take the numbers on a cold start with no data in the table. Let's measure by running the above code. On a Nothing Phone (1), the numbers are as below. We ran 10 iterations and will average them out.

1. time taken: 284 ms
2. time taken: 234 ms
3. time taken: 280 ms
4. time taken: 226 ms
5. time taken: 268 ms
6. time taken: 256 ms
7. time taken: 303 ms
8. time taken: 235 ms
9. time taken: 252 ms
10. time taken: 259 ms

The average of the above for the old implementation is around 259.7 ms.

Optimized Implementation Using the Fold Operator

Let's understand how fold works in Kotlin Flow.

import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.fold
import kotlinx.coroutines.runBlocking

suspend fun sumNumbers(flow: Flow<Int>): Int {
    return flow.fold(0) { accumulator, value ->
        accumulator + value
    }
}

fun main() = runBlocking {
    val numbersFlow = flowOf(1, 2, 3, 4, 5)
    val sum = sumNumbers(numbersFlow)
    println("Sum: $sum") // Output: Sum: 15
}

//How it works. Let's dry-run it:
//1. fold needs an initial value; we have provided 0.
//2. The lambda has 2 parameters: accumulator & value.
//3. accumulator -> the running result, starting at the initial value.
//4. value -> the next value from the flow.
//5. Let's run the code.
//6. The first time, accumulator -> 0, value -> 1.
//   Their sum, 1, is returned and becomes the new accumulator.
//7. Now accumulator -> 1, value -> 2 = 3
//8. accumulator -> 3, value -> 3 = 6
//9. accumulator -> 6, value -> 4 = 10
//10. accumulator -> 10, value -> 5 = 15
//11. Hence, the output here is 15. Hope you now know how it works.

Here we are using @Transaction on a suspend function to upsert the chats data. As per our findings, using batch upserts this frequently may increase the total time to insert data into our DB.

Why so? As per the definition of @Transaction, a transaction is a unit of work that ensures the atomicity, consistency, isolation, and durability (ACID) properties of database operations. Each upsert therefore commits in its own transaction, and committing so frequently adds overhead, since @Transaction promises the ACID properties for every single batch.
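For reference, here is a minimal sketch of the DAO the snippets assume; the method name is taken from the code above, and the rest is an assumption that may differ from the sample project. Note that Room already runs @Insert methods inside a transaction, so the @Transaction annotation is effectively redundant here; either way, every insertChats call commits its own transaction.

@Dao
interface ChatsDao {
    // REPLACE makes the insert behave as an upsert: rows whose primary key
    // already exists are overwritten. Each call commits its own transaction.
    @Transaction
    @Insert(onConflict = OnConflictStrategy.REPLACE)
    suspend fun insertChats(chats: List<Chat>)
}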

Let's use this insight: instead of doing frequent upserts, we will accumulate the chats in memory for a while and do one bulk upsert instead. So rather than upserting 45 chats at a time, we will flush batches of roughly 150 or more at a time.

Want to see some code on how it's done?

lifecycleScope.launch(Dispatchers.IO) {
    val jsonArray = loadJsonFromAsset(this@MainActivity, "chats.json") ?: return@launch
    val chats = mutableListOf<Chat>()
    for (i in 0 until jsonArray.length()) {
        val channelJson = jsonArray.optJSONObject(i) ?: continue
        val id = channelJson.optString("id")
        val name = channelJson.optString("name")
        val lastMsg = channelJson.optString("last_message")
        val lastSeen = channelJson.optString("last_seen")
        val type = channelJson.optString("type")
        chats.add(createChat(id, name, lastMsg, lastSeen, type))
    }
    val db = AppDatabase.getInstance(this@MainActivity)
    var isLastItem = false
    var emittedItems = 0
    val time = measureTimeMillis {
        chats.chunked(45).asFlow().onEach {
            // Track how many chats have been emitted so we can flush
            // whatever has accumulated when the last batch arrives.
            emittedItems += it.size
            if (emittedItems >= chats.size) {
                isLastItem = true
            }
        }.fold(ChatAcc.Initial as ChatAcc) { acc, value ->
            when (acc) {
                ChatAcc.Initial -> {
                    // Upsert the first page immediately so the user sees chats fast.
                    db.chatsDao().insertChats(value)
                    ChatAcc.NextChats(emptyList())
                }
                is ChatAcc.NextChats -> {
                    if (acc.chats.size >= 150 || isLastItem) {
                        // Threshold crossed (or last page arrived): flush everything at once.
                        db.chatsDao().insertChats(acc.chats + value)
                        ChatAcc.NextChats(emptyList())
                    } else {
                        // Keep accumulating in memory.
                        ChatAcc.NextChats(acc.chats + value)
                    }
                }
            }
        }
    }
    withContext(Dispatchers.Main) {
        tvResult.text = "time taken: $time ms"
    }
}

sealed class ChatAcc {
    data object Initial : ChatAcc()
    class NextChats(val chats: List<Chat>) : ChatAcc()
}

So we have created a sealed class with two states: Initial and NextChats. Initial takes care of the first page, which is inserted instantly so chats can be shown to the user right away; meanwhile, NextChats accumulates the following pages so they can be upserted into the DB in bulk.

As per the fold example we saw above, there is nothing too fancy here. We accumulate chats until the accumulator holds 150 or more (or we have fetched the last page) and then upsert the whole batch into the DB in one go, as the sketch below also shows.
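To make the buffering logic easy to play with outside Room, here is a self-contained sketch of the same idea on plain Ints, with println standing in for the bulk upsert; all names here are illustrative. It also shows an alternative to the isLastItem flag: fold returns the final accumulator, so whatever is left over can simply be flushed after collection.

import kotlinx.coroutines.flow.asFlow
import kotlinx.coroutines.flow.fold
import kotlinx.coroutines.runBlocking

fun main() = runBlocking {
    val threshold = 150
    val leftover = (1..1000).chunked(45).asFlow()
        .fold(emptyList<Int>()) { acc, chunk ->
            val pending = acc + chunk
            if (pending.size >= threshold) {
                println("flushing batch of ${pending.size}") // stand-in for the bulk upsert
                emptyList()
            } else {
                pending // keep accumulating in memory
            }
        }
    // fold returns the final accumulator, so the tail end is never lost.
    if (leftover.isNotEmpty()) println("flushing final batch of ${leftover.size}")
}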

Similarly, let's take the numbers for the optimized flow.

1. time taken: 187 ms
2. time taken: 181 ms
3. time taken: 182 ms
4. time taken: 185 ms
5. time taken: 180 ms
6. time taken: 180 ms
7. time taken: 196 ms
8. time taken: 192 ms
9. time taken: 182 ms
10. time taken: 178 ms

The average of the above for the new implementation is around 184.3 ms.

Comparing the above numbers, we get a decrease of around 29% ((259.7 − 184.3) / 259.7 ≈ 29.03%) in the time taken to upsert the data into the DB by using the fold operator.

Thanks for reading! Have a good day.

GitHub link: https://github.com/tusharpingale04/DBOptimisation

Special thanks to Nishant Pardamwar (https://github.com/nishantpardamwar) for identifying this optimization and for his insightful input on the code.