diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/CachedTool.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/CachedTool.kt index 3b1451333..ddb1e3352 100644 --- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/CachedTool.kt +++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/assistants/CachedTool.kt @@ -15,13 +15,44 @@ abstract class CachedTool( private val timeCachePolicy: Duration = 1.days ) : Tool { - override suspend fun invoke(input: Input): Output { - return cache(CachedToolKey(input, seed)) { onCacheMissed(input) } - } + /** + * Logic to be executed when the cache is missed. + * + * @return the output. + */ + abstract suspend fun onCacheMissed(input: Input): Output + + /** + * Criteria to check if the cache should be used for the given [input]. By default, it returns + * true, meaning always use the cache if available. + * + * @return true if the cache should be used. + */ + open suspend fun shouldUseCache(input: Input): Boolean = true + + /** + * Criteria to check if the result should be cached based on the given [input] and [output]. By + * default, it returns true, meaning always cache the result. + * + * @return true if the result should be cached. + */ + open suspend fun shouldCacheOutput(input: Input, output: Output): Boolean = true + + /** + * Caches the result of [onCacheMissed] if [shouldCacheOutput] returns true. Otherwise, returns + * the result of [onCacheMissed]. + * + * @return the output. + */ + override suspend fun invoke(input: Input): Output = + if (shouldUseCache(input)) cache(CachedToolKey(input, seed)) { onCacheMissed(input) } + else onCacheMissed(input) /** * Exposes the cache as a [Map] of [Input] to [Output] filtered by instance [seed] and * [timeCachePolicy]. Removes expired cache entries. + * + * @return the map of input to output. */ suspend fun getCache(): Map { val lastTimeInCache = timeInMillis() - timeCachePolicy.inWholeMilliseconds @@ -42,8 +73,6 @@ abstract class CachedTool( return withoutExpired.map { it.key.value to it.value.value }.toMap() } - abstract suspend fun onCacheMissed(input: Input): Output - private suspend fun cache(input: CachedToolKey, block: suspend () -> Output): Output { val cachedToolInfo = cache.get()[input] if (cachedToolInfo != null) { @@ -55,7 +84,9 @@ abstract class CachedTool( } } val response = block() - cache.get()[input] = CachedToolValue(response, timeInMillis()) + if (shouldCacheOutput(input.value, response)) { + cache.get()[input] = CachedToolValue(response, timeInMillis()) + } return response } }