diff --git a/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/pom.xml b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/pom.xml new file mode 100644 index 000000000000..a52ec2887fc8 --- /dev/null +++ b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/pom.xml @@ -0,0 +1,211 @@ + + + + + 4.0.0 + + org.wso2.carbon.identity.framework + ai-services-mgt + 7.7.93-SNAPSHOT + ../pom.xml + + + org.wso2.carbon.identity.ai.service.mgt + bundle + WSO2 Identity - AI Service Management Bundle + This represents the AI Service Management Bundle. + http://wso2.org + + + + org.ops4j.pax.logging + pax-logging-api + + + org.wso2.carbon.identity.framework + org.wso2.carbon.identity.core + + + org.apache.httpcomponents.wso2 + httpcore + provided + + + com.fasterxml.jackson.core + jackson-databind + provided + + + org.mockito + mockito-core + test + + + org.testng + testng + test + + + org.slf4j + slf4j-api + provided + + + org.apache.logging.log4j + log4j-core + test + + + org.wiremock + wiremock + test + + + org.ops4j.pax.logging + pax-logging-api + provided + + + + + + + org.apache.felix + maven-bundle-plugin + true + + + + ${project.artifactId} + + ${project.artifactId} + + org.osgi.framework; version="${osgi.framework.imp.pkg.version.range}", + org.osgi.service.component; version="${osgi.service.component.imp.pkg.version.range}", + com.google.gson;version="${com.google.code.gson.osgi.version.range}", + org.wso2.carbon.identity.core.util; version="${carbon.identity.package.import.version.range}", + org.apache.commons.lang; version="${commons-lang.wso2.osgi.version.range}", + org.apache.commons.logging; version="${import.package.version.commons.logging}", + com.fasterxml.jackson.databind.*; version="${com.fasterxml.jackson.annotation.version.range}", + org.wso2.carbon.context; version="${carbon.kernel.package.import.version.range}", + + org.apache.http; version="${httpcore.version.osgi.import.range}", + org.apache.http.client; version="${httpcomponents-httpclient.imp.pkg.version.range}", + org.apache.http.client.methods; version="${httpcomponents-httpclient.imp.pkg.version.range}", + org.apache.http.client.config; version="${httpcomponents-httpclient.imp.pkg.version.range}", + org.apache.http.entity; version="${httpcore.version.osgi.import.range}", + org.apache.http.message; version="${httpcore.version.osgi.import.range}", + org.apache.http.protocol; version="${httpcore.version.osgi.import.range}", + org.apache.http.util; version="${httpcore.version.osgi.import.range}", + org.apache.http.impl.client; version="${httpcomponents-httpclient.imp.pkg.version.range}", + org.apache.http.concurrent; version="${httpcore.version.osgi.import.range}", + + + org.wso2.carbon.identity.ai.service.mgt.*; version="${carbon.identity.package.export.version}" + + + + + + org.apache.maven.plugins + maven-surefire-plugin + ${maven.surefire.plugin.version} + + + + ${argLine} + --add-opens=java.base/java.lang=ALL-UNNAMED + --add-opens=java.base/java.util=ALL-UNNAMED + --add-opens java.xml/jdk.xml.internal=ALL-UNNAMED + --add-opens=java.base/java.io=ALL-UNNAMED + --add-opens=java.base/sun.nio.fs=ALL-UNNAMED + + + src/test/resources/testng.xml + + + + + org.jacoco + jacoco-maven-plugin + ${jacoco.version} + + + **/*Exception.class + **/*Constants*.class + + + + + default-prepare-agent + + prepare-agent + + + + default-prepare-agent-integration + + prepare-agent-integration + + + + default-report + + report + + + + default-report-integration + + report-integration + + + + default-check + + check + + + + + BUNDLE + + + COMPLEXITY + COVEREDRATIO + 0.80 + + + + + + + + + + com.github.spotbugs + spotbugs-maven-plugin + + High + + + + + + diff --git a/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/main/java/org/wso2/carbon/identity/ai/service/mgt/constants/AIConstants.java b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/main/java/org/wso2/carbon/identity/ai/service/mgt/constants/AIConstants.java new file mode 100644 index 000000000000..9621165b6646 --- /dev/null +++ b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/main/java/org/wso2/carbon/identity/ai/service/mgt/constants/AIConstants.java @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.wso2.carbon.identity.ai.service.mgt.constants; + +/** + * Constants for the LoginFlowAI module. + */ +public class AIConstants { + + private AIConstants () { + + } + + public static final String AI_SERVICE_KEY_PROPERTY_NAME = "AIServices.Key"; + public static final String AI_TOKEN_ENDPOINT_PROPERTY_NAME = "AIServices.TokenEndpoint"; + public static final String AI_TOKEN_SERVICE_MAX_RETRIES_PROPERTY_NAME = "AIServices.TokenRequestMaxRetries"; + public static final String AI_TOKEN_CONNECTION_TIMEOUT_PROPERTY_NAME = "AIServices.TokenConnectionTimeout"; + public static final String AI_TOKEN_CONNECTION_REQUEST_TIMEOUT_PROPERTY_NAME = "AIServices" + + ".TokenConnectionRequestTimeout"; + public static final String AI_TOKEN_SOCKET_TIMEOUT_PROPERTY_NAME = "AIServices.TokenConnectionSocketTimeout"; + + public static final String HTTP_CONNECTION_POOL_SIZE_PROPERTY_NAME = "AIServices.HTTPConnectionPoolSize"; + public static final String HTTP_CONNECTION_TIMEOUT_PROPERTY_NAME = "AIServices.HTTPConnectionTimeout"; + public static final String HTTP_CONNECTION_REQUEST_TIMEOUT_PROPERTY_NAME = "AIServices" + + ".HTTPConnectionRequestTimeout"; + public static final String HTTP_SOCKET_TIMEOUT_PROPERTY_NAME = "AIServices.HTTPSocketTimeout"; + + // Http constants. + public static final String HTTP_BASIC = "Basic"; + public static final String HTTP_BEARER = "Bearer"; + public static final String CONTENT_TYPE_FORM_URLENCODED = "application/x-www-form-urlencoded"; + public static final String CONTENT_TYPE_JSON = "application/json"; + + // Access Token response constants. + public static final String ACCESS_TOKEN_KEY = "access_token"; + + public static final String TENANT_CONTEXT_PREFIX = "/t/"; + + // Default Property values. + public static final int DEFAULT_TOKEN_REQUEST_MAX_RETRIES = 3; + public static final int DEFAULT_TOKEN_CONNECTION_TIMEOUT = 3000; + public static final int DEFAULT_TOKEN_CONNECTION_REQUEST_TIMEOUT = 3000; + public static final int DEFAULT_TOKEN_SOCKET_TIMEOUT = 3000; + + public static final int DEFAULT_HTTP_CONNECTION_POOL_SIZE = 20; + public static final int DEFAULT_HTTP_CONNECTION_TIMEOUT = 3000; + public static final int DEFAULT_HTTP_CONNECTION_REQUEST_TIMEOUT = 3000; + public static final int DEFAULT_HTTP_SOCKET_TIMEOUT = 3000; + + /** + * Enums for error messages. + */ + public enum ErrorMessages { + + MAXIMUM_RETRIES_EXCEEDED("AI-10000", "Maximum retries exceeded to retrieve the access token."), + UNABLE_TO_ACCESS_AI_SERVICE_WITH_RENEW_ACCESS_TOKEN("AI-10003", "Unable to access the " + + "AI service with the renewed access token."), + REQUEST_TIMEOUT("AI-10004", "Request to the AI service timed out."), + ERROR_RETRIEVING_ACCESS_TOKEN("AI-10007", "Error occurred while retrieving the " + + "access token."), + CLIENT_ERROR_WHILE_CONNECTING_TO_AI_SERVICE("AI-10008", "Client error occurred " + + "for %s tenant while connecting to AI service."), + SERVER_ERROR_WHILE_CONNECTING_TO_AI_SERVICE("AI-10009", "Server error occurred " + + "for %s tenant while connecting to AI service."); + + private final String code; + private final String message; + + ErrorMessages(String code, String message) { + + this.code = code; + this.message = message; + } + + public String getCode() { + + return code; + } + + public String getMessage() { + + return message; + } + + @Override + public String toString() { + + return code + ":" + message; + } + } +} diff --git a/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/main/java/org/wso2/carbon/identity/ai/service/mgt/exceptions/AIClientException.java b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/main/java/org/wso2/carbon/identity/ai/service/mgt/exceptions/AIClientException.java new file mode 100644 index 000000000000..e13a05c3e838 --- /dev/null +++ b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/main/java/org/wso2/carbon/identity/ai/service/mgt/exceptions/AIClientException.java @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.wso2.carbon.identity.ai.service.mgt.exceptions; + +/** + * Client Exception class for AI service. + */ +public class AIClientException extends AIException { + + public AIClientException(String message, String errorCode) { + + super(message, errorCode); + } + + public AIClientException(String message, String errorCode, int serverStatusCode, String serverMessage) { + + super(message, errorCode, serverStatusCode, serverMessage); + } + + public AIClientException(String message, String errorCode, Throwable cause) { + + super(message, errorCode, cause); + } +} diff --git a/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/main/java/org/wso2/carbon/identity/ai/service/mgt/exceptions/AIException.java b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/main/java/org/wso2/carbon/identity/ai/service/mgt/exceptions/AIException.java new file mode 100644 index 000000000000..a7b718f234b2 --- /dev/null +++ b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/main/java/org/wso2/carbon/identity/ai/service/mgt/exceptions/AIException.java @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.wso2.carbon.identity.ai.service.mgt.exceptions; + +/** + * Generic Exception class for AI service. + */ +public class AIException extends Exception { + + private String errorCode; + // This is the error message that comes from the server. + private String serverMessage; + // This is the status code that comes from the server. + private int serverStatusCode; + + public AIException(String message, String errorCode) { + + super(message); + this.errorCode = errorCode; + } + + public AIException(String message, Throwable cause) { + + super(message, cause); + } + + public AIException(String message, String errorCode, int serverStatusCode, String serverMessage) { + + super(message); + this.errorCode = errorCode; + this.serverStatusCode = serverStatusCode; + this.serverMessage = serverMessage; + } + + public AIException(String message, String errorCode, Throwable cause) { + + super(message, cause); + this.errorCode = errorCode; + } + + public String getErrorCode() { + + return errorCode; + } + + public String getServerMessage() { + + return serverMessage; + } + + public int getServerStatusCode() { + + return serverStatusCode; + } +} diff --git a/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/main/java/org/wso2/carbon/identity/ai/service/mgt/exceptions/AIServerException.java b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/main/java/org/wso2/carbon/identity/ai/service/mgt/exceptions/AIServerException.java new file mode 100644 index 000000000000..7e5722cf2d6f --- /dev/null +++ b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/main/java/org/wso2/carbon/identity/ai/service/mgt/exceptions/AIServerException.java @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.wso2.carbon.identity.ai.service.mgt.exceptions; + +/** + * Client Exception class for AI service. + */ +public class AIServerException extends AIException { + + public AIServerException(String message, String errorCode) { + + super(message, errorCode); + } + + public AIServerException(String message, Throwable e) { + + super(message, e); + } + + public AIServerException(String message, String errorCode, int serverStatusCode, String serverMessage) { + + super(message, errorCode, serverStatusCode, serverMessage); + } + + public AIServerException(String message, String errorCode, Throwable cause) { + + super(message, errorCode, cause); + } +} diff --git a/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/main/java/org/wso2/carbon/identity/ai/service/mgt/token/AIAccessTokenManager.java b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/main/java/org/wso2/carbon/identity/ai/service/mgt/token/AIAccessTokenManager.java new file mode 100644 index 000000000000..bfe2b48f5cfb --- /dev/null +++ b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/main/java/org/wso2/carbon/identity/ai/service/mgt/token/AIAccessTokenManager.java @@ -0,0 +1,216 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.wso2.carbon.identity.ai.service.mgt.token; + +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import org.apache.commons.lang.StringUtils; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.http.HttpResponse; +import org.apache.http.HttpStatus; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.StringEntity; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClientBuilder; +import org.apache.http.message.BasicHeader; +import org.apache.http.util.EntityUtils; +import org.wso2.carbon.identity.ai.service.mgt.exceptions.AIServerException; +import org.wso2.carbon.identity.core.util.IdentityUtil; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Map; +import java.util.UUID; + +import static org.apache.axis2.transport.http.HTTPConstants.HEADER_CONTENT_TYPE; +import static org.apache.http.HttpHeaders.AUTHORIZATION; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.ACCESS_TOKEN_KEY; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.AI_SERVICE_KEY_PROPERTY_NAME; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.AI_TOKEN_CONNECTION_REQUEST_TIMEOUT_PROPERTY_NAME; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.AI_TOKEN_CONNECTION_TIMEOUT_PROPERTY_NAME; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.AI_TOKEN_ENDPOINT_PROPERTY_NAME; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.AI_TOKEN_SERVICE_MAX_RETRIES_PROPERTY_NAME; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.AI_TOKEN_SOCKET_TIMEOUT_PROPERTY_NAME; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.CONTENT_TYPE_FORM_URLENCODED; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.DEFAULT_TOKEN_CONNECTION_REQUEST_TIMEOUT; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.DEFAULT_TOKEN_CONNECTION_TIMEOUT; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.DEFAULT_TOKEN_REQUEST_MAX_RETRIES; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.DEFAULT_TOKEN_SOCKET_TIMEOUT; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.ErrorMessages.MAXIMUM_RETRIES_EXCEEDED; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.HTTP_BASIC; +import static org.wso2.carbon.identity.core.util.IdentityTenantUtil.getTenantDomainFromContext; + +/** + * The purpose of this class is to retrieve an active token to access the AI service. + */ +public class AIAccessTokenManager { + + private static volatile AIAccessTokenManager instance; + private static final Object lock = new Object(); // Lock for synchronization. + + private static final Log LOG = LogFactory.getLog(AIAccessTokenManager.class); + + private static final String AI_KEY = IdentityUtil.getProperty(AI_SERVICE_KEY_PROPERTY_NAME); + private static final String AI_TOKEN_ENDPOINT = IdentityUtil.getProperty(AI_TOKEN_ENDPOINT_PROPERTY_NAME); + + private final AccessTokenRequestHelper accessTokenRequestHelper; + + private String accessToken; + private final String clientId; + + private AIAccessTokenManager() { + + byte[] decodedBytes = Base64.getDecoder().decode(AI_KEY); + String decodedString = new String(decodedBytes, StandardCharsets.UTF_8); + String[] parts = decodedString.split(":"); + if (parts.length == 2) { + this.clientId = parts[0]; + } else { + throw new IllegalArgumentException("Invalid AI service key."); + } + this.accessTokenRequestHelper = new AccessTokenRequestHelper(AI_KEY, AI_TOKEN_ENDPOINT); + } + + /** + * Get the singleton instance of the AIAccessTokenManager. + * + * @return The singleton instance. + */ + public static AIAccessTokenManager getInstance() { + + if (instance == null) { + synchronized (lock) { + if (instance == null) { + instance = new AIAccessTokenManager(); + } + } + } + return instance; + } + + /** + * Get the access token. + * + * @param renewAccessToken Whether to renew the access token. + * @return The access token. + * @throws AIServerException If an error occurs while obtaining the access token. + */ + public String getAccessToken(boolean renewAccessToken) throws AIServerException { + + if (StringUtils.isEmpty(accessToken) || renewAccessToken) { + synchronized (AIAccessTokenManager.class) { + if (StringUtils.isEmpty(accessToken) || renewAccessToken) { + this.accessToken = accessTokenRequestHelper.requestAccessToken(); + } + } + } + return this.accessToken; + } + + /** + * Get the client ID. + * + * @return The client ID. + */ + public String getClientId() { + + return this.clientId; + } + + /** + * Helper class to request access token from the AI services. + */ + private static class AccessTokenRequestHelper { + + private final CloseableHttpClient client; + private final Gson gson; + private final String key; + private final HttpPost tokenRequest; + private static final int MAX_RETRIES = readIntProperty(AI_TOKEN_SERVICE_MAX_RETRIES_PROPERTY_NAME, + DEFAULT_TOKEN_REQUEST_MAX_RETRIES); + private static final int CONNECTION_TIMEOUT = readIntProperty(AI_TOKEN_CONNECTION_TIMEOUT_PROPERTY_NAME, + DEFAULT_TOKEN_CONNECTION_TIMEOUT); + private static final int CONNECTION_REQUEST_TIMEOUT = readIntProperty( + AI_TOKEN_CONNECTION_REQUEST_TIMEOUT_PROPERTY_NAME, DEFAULT_TOKEN_CONNECTION_REQUEST_TIMEOUT); + private static final int SOCKET_TIMEOUT = readIntProperty(AI_TOKEN_SOCKET_TIMEOUT_PROPERTY_NAME, + DEFAULT_TOKEN_SOCKET_TIMEOUT); + + AccessTokenRequestHelper(String key, String tokenEndpoint) { + + RequestConfig requestConfig = RequestConfig.custom() + .setConnectTimeout(CONNECTION_TIMEOUT) + .setConnectionRequestTimeout(CONNECTION_REQUEST_TIMEOUT) + .setSocketTimeout(SOCKET_TIMEOUT) + .build(); + this.client = HttpClientBuilder.create() + .setDefaultRequestConfig(requestConfig).build(); + this.gson = new GsonBuilder().create(); + this.key = key; + this.tokenRequest = new HttpPost(tokenEndpoint); + } + + /** + * Request access token to access the AI services. + * + * @return the JWT access token. + * @throws AIServerException If an error occurs while requesting the access token. + */ + public String requestAccessToken() throws AIServerException { + + String tenantDomain = getTenantDomainFromContext(); + LOG.info("Initiating access token request for AI services from tenant: " + tenantDomain); + for (int attempt = 0; attempt < MAX_RETRIES; attempt++) { + try { + tokenRequest.setHeader(AUTHORIZATION, HTTP_BASIC + " " + key); + tokenRequest.setHeader(HEADER_CONTENT_TYPE, CONTENT_TYPE_FORM_URLENCODED); + + StringEntity entity = new StringEntity("grant_type=client_credentials&tokenBindingId=" + + UUID.randomUUID()); + entity.setContentType(new BasicHeader(HEADER_CONTENT_TYPE, CONTENT_TYPE_FORM_URLENCODED)); + tokenRequest.setEntity(entity); + + HttpResponse response = client.execute(tokenRequest); + if (response.getStatusLine().getStatusCode() == HttpStatus.SC_OK) { + String responseBody = EntityUtils.toString(response.getEntity()); + Map responseMap = gson.fromJson(responseBody, Map.class); + return (String) responseMap.get(ACCESS_TOKEN_KEY); + } else { + LOG.error("Token request failed with status code: " + + response.getStatusLine().getStatusCode()); + } + } catch (IOException e) { + throw new AIServerException("Error executing token request: " + e.getMessage(), e); + } finally { + tokenRequest.releaseConnection(); + } + } + throw new AIServerException("Failed to obtain access token after " + MAX_RETRIES + + " attempts.", MAXIMUM_RETRIES_EXCEEDED.getCode()); + } + + private static int readIntProperty(String key, int defaultValue) { + + String value = IdentityUtil.getProperty(key); + return value != null ? Integer.parseInt(value) : defaultValue; + } + } +} diff --git a/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/main/java/org/wso2/carbon/identity/ai/service/mgt/util/AIHttpClientUtil.java b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/main/java/org/wso2/carbon/identity/ai/service/mgt/util/AIHttpClientUtil.java new file mode 100644 index 000000000000..337cf19837db --- /dev/null +++ b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/main/java/org/wso2/carbon/identity/ai/service/mgt/util/AIHttpClientUtil.java @@ -0,0 +1,234 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.wso2.carbon.identity.ai.service.mgt.util; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.commons.httpclient.HttpStatus; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.http.HttpResponse; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpUriRequest; +import org.apache.http.entity.StringEntity; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClients; +import org.apache.http.util.EntityUtils; +import org.wso2.carbon.context.PrivilegedCarbonContext; +import org.wso2.carbon.identity.ai.service.mgt.exceptions.AIClientException; +import org.wso2.carbon.identity.ai.service.mgt.exceptions.AIServerException; +import org.wso2.carbon.identity.ai.service.mgt.token.AIAccessTokenManager; +import org.wso2.carbon.identity.core.util.IdentityUtil; + +import java.io.IOException; +import java.util.Map; + +import static org.apache.axis2.transport.http.HTTPConstants.HEADER_CONTENT_TYPE; +import static org.apache.http.HttpHeaders.AUTHORIZATION; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.CONTENT_TYPE_JSON; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.DEFAULT_HTTP_CONNECTION_POOL_SIZE; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.DEFAULT_HTTP_CONNECTION_REQUEST_TIMEOUT; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.DEFAULT_HTTP_CONNECTION_TIMEOUT; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.DEFAULT_HTTP_SOCKET_TIMEOUT; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.ErrorMessages.CLIENT_ERROR_WHILE_CONNECTING_TO_AI_SERVICE; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.ErrorMessages.ERROR_RETRIEVING_ACCESS_TOKEN; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.ErrorMessages.SERVER_ERROR_WHILE_CONNECTING_TO_AI_SERVICE; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.ErrorMessages.UNABLE_TO_ACCESS_AI_SERVICE_WITH_RENEW_ACCESS_TOKEN; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.HTTP_BEARER; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.HTTP_CONNECTION_POOL_SIZE_PROPERTY_NAME; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.HTTP_CONNECTION_REQUEST_TIMEOUT_PROPERTY_NAME; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.HTTP_CONNECTION_TIMEOUT_PROPERTY_NAME; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.HTTP_SOCKET_TIMEOUT_PROPERTY_NAME; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.TENANT_CONTEXT_PREFIX; + +/** + * Utility class for AI Services to send HTTP requests. + */ +public class AIHttpClientUtil { + + private static final Log LOG = LogFactory.getLog(AIHttpClientUtil.class); + private static final ObjectMapper objectMapper = new ObjectMapper(); + + private static final int HTTP_CONNECTION_POOL_SIZE = readIntProperty(HTTP_CONNECTION_POOL_SIZE_PROPERTY_NAME, + DEFAULT_HTTP_CONNECTION_POOL_SIZE); + private static final int HTTP_CONNECTION_TIMEOUT = readIntProperty(HTTP_CONNECTION_TIMEOUT_PROPERTY_NAME, + DEFAULT_HTTP_CONNECTION_TIMEOUT); + private static final int HTTP_CONNECTION_REQUEST_TIMEOUT = readIntProperty( + HTTP_CONNECTION_REQUEST_TIMEOUT_PROPERTY_NAME, DEFAULT_HTTP_CONNECTION_REQUEST_TIMEOUT); + private static final int HTTP_SOCKET_TIMEOUT = readIntProperty(HTTP_SOCKET_TIMEOUT_PROPERTY_NAME, + DEFAULT_HTTP_SOCKET_TIMEOUT); + + // Singleton instance of CloseableHttpClient with connection pooling. + private static final CloseableHttpClient httpClient = HttpClients.custom() + .setMaxConnTotal(HTTP_CONNECTION_POOL_SIZE) + .setDefaultRequestConfig( + org.apache.http.client.config.RequestConfig.custom() + .setSocketTimeout(HTTP_SOCKET_TIMEOUT) + .setConnectTimeout(HTTP_CONNECTION_TIMEOUT) + .setConnectionRequestTimeout(HTTP_CONNECTION_REQUEST_TIMEOUT) + .build() + ).build(); + + private AIHttpClientUtil() { + + } + + /** + * Execute a request to the AI service. + * + * @param path The endpoint to which the request should be sent. + * @param requestType The type of the request (GET, POST). + * @param requestBody The request body(Only for POST requests). + * @param aiServiceEndpoint The endpoint of the AI service. + * @return The response from the AI service as a map. + * @throws AIServerException If a server error occurred while accessing the AI service. + * @throws AIClientException If a client error occurred while accessing the AI service. + */ + public static Map executeRequest(String aiServiceEndpoint, String path, + Class requestType, Object requestBody) + throws AIServerException, AIClientException { + + String tenantDomain = PrivilegedCarbonContext.getThreadLocalCarbonContext().getTenantDomain(); + + try { + String accessToken = AIAccessTokenManager.getInstance().getAccessToken(false); + String clientId = AIAccessTokenManager.getInstance().getClientId(); + + HttpUriRequest request = createRequest(aiServiceEndpoint + TENANT_CONTEXT_PREFIX + clientId + path, + requestType, accessToken, requestBody); + HttpResponseWrapper aiServiceResponse = executeRequestWithRetry(request); + + int statusCode = aiServiceResponse.getStatusCode(); + String responseBody = aiServiceResponse.getResponseBody(); + + if (statusCode >= 400) { + handleErrorResponse(statusCode, responseBody, tenantDomain); + } + return convertJsonStringToMap(responseBody); + } catch (IOException e) { + throw new AIServerException("An error occurred while connecting to the AI Service.", + SERVER_ERROR_WHILE_CONNECTING_TO_AI_SERVICE.getCode(), e); + } + } + + private static HttpUriRequest createRequest(String url, Class requestType, + String accessToken, Object requestBody) throws IOException { + + HttpUriRequest request; + if (requestType == HttpPost.class) { + HttpPost post = new HttpPost(url); + if (requestBody != null) { + post.setEntity(new StringEntity(objectMapper.writeValueAsString(requestBody))); + } + request = post; + } else if (requestType == HttpGet.class) { + request = new HttpGet(url); + } else { + throw new IllegalArgumentException("Unsupported request type: " + requestType.getName()); + } + + request.setHeader(AUTHORIZATION, HTTP_BEARER + " " + accessToken); + request.setHeader(HEADER_CONTENT_TYPE, CONTENT_TYPE_JSON); + return request; + } + + private static HttpResponseWrapper executeRequestWithRetry(HttpUriRequest request) + throws IOException, AIServerException { + + HttpResponseWrapper response = executeHttpRequest(request); + + if (response.getStatusCode() == HttpStatus.SC_UNAUTHORIZED) { + String newAccessToken = AIAccessTokenManager.getInstance().getAccessToken(true); + if (newAccessToken == null) { + throw new AIServerException("Failed to renew access token.", ERROR_RETRIEVING_ACCESS_TOKEN.getCode()); + } + request.setHeader(AUTHORIZATION, HTTP_BEARER + " " + newAccessToken); + response = executeHttpRequest(request); + } + return response; + } + + private static void handleErrorResponse(int statusCode, String responseBody, String tenantDomain) + throws AIServerException, AIClientException { + + if (statusCode == HttpStatus.SC_UNAUTHORIZED) { + throw new AIServerException("Failed to access AI service with renewed access token for " + + "the tenant domain: " + tenantDomain, + UNABLE_TO_ACCESS_AI_SERVICE_WITH_RENEW_ACCESS_TOKEN.getCode()); + } else if (statusCode >= 400 && statusCode < 500) { + throw new AIClientException("Client error occurred from tenant: " + tenantDomain + " with status code: '" + + statusCode + "' while accessing AI service.", + CLIENT_ERROR_WHILE_CONNECTING_TO_AI_SERVICE.getCode(), statusCode, responseBody); + } else if (statusCode >= 500) { + throw new AIServerException("Server error occurred from tenant: " + tenantDomain + " with status code: '" + + statusCode + "' while accessing AI service.", + SERVER_ERROR_WHILE_CONNECTING_TO_AI_SERVICE.getCode(), statusCode, responseBody); + } + } + + private static Map convertJsonStringToMap(String jsonString) throws AIServerException { + + try { + return objectMapper.readValue(jsonString, Map.class); + } catch (IOException e) { + throw new AIServerException("Error occurred while parsing the JSON response from the AI service.", e); + } + } + + private static HttpResponseWrapper executeHttpRequest(HttpUriRequest httpRequest) + throws IOException { + + // Here we don't close the client connection since we are using a connection pool. + HttpResponse httpResponse = httpClient.execute(httpRequest); + int status = httpResponse.getStatusLine().getStatusCode(); + String response = EntityUtils.toString(httpResponse.getEntity()); + return new HttpResponseWrapper(status, response); + } + + private static int readIntProperty(String key, int defaultValue) { + + String value = IdentityUtil.getProperty(key); + return value != null ? Integer.parseInt(value) : defaultValue; + } + + /** + * Wrapper class for HTTP response. + */ + public static class HttpResponseWrapper { + + private final int statusCode; + private final String responseBody; + + public HttpResponseWrapper(int statusCode, String responseBody) { + + this.statusCode = statusCode; + this.responseBody = responseBody; + } + + public int getStatusCode() { + + return statusCode; + } + + public String getResponseBody() { + + return responseBody; + } + } +} diff --git a/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/test/java/org/wso2/carbon/identity/ai/service/mgt/token/AIAccessTokenManagerTest.java b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/test/java/org/wso2/carbon/identity/ai/service/mgt/token/AIAccessTokenManagerTest.java new file mode 100644 index 000000000000..77c9bd19976b --- /dev/null +++ b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/test/java/org/wso2/carbon/identity/ai/service/mgt/token/AIAccessTokenManagerTest.java @@ -0,0 +1,293 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.wso2.carbon.identity.ai.service.mgt.token; + +import com.github.tomakehurst.wiremock.WireMockServer; +import com.github.tomakehurst.wiremock.http.Fault; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; +import org.wso2.carbon.identity.ai.service.mgt.exceptions.AIServerException; +import org.wso2.carbon.identity.core.util.IdentityUtil; + +import java.util.Base64; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.AI_TOKEN_CONNECTION_REQUEST_TIMEOUT_PROPERTY_NAME; + +/** + * Test class for AIAccessTokenManager. + */ +public class AIAccessTokenManagerTest { + + private WireMockServer wireMockServer; + private AIAccessTokenManager tokenManager; + private MockedStatic mockedStatic; + + @BeforeClass + public void init() { + + mockedStatic = Mockito.mockStatic(IdentityUtil.class); + mockedStatic.when(() -> IdentityUtil.getProperty(AI_TOKEN_CONNECTION_REQUEST_TIMEOUT_PROPERTY_NAME)) + .thenReturn("2000"); + } + + @BeforeMethod + public void setUp() throws Exception { + + // Reset the singleton instance. + resetSingletonInstance(AIAccessTokenManager.class, "instance"); + } + + private void startWireMockServer() throws Exception { + + // Start WireMock server. + wireMockServer = new WireMockServer(wireMockConfig().dynamicPort()); + wireMockServer.start(); + + // Set the AI_TOKEN_ENDPOINT to WireMock's base URL. + setStaticField(AIAccessTokenManager.class, "AI_TOKEN_ENDPOINT", wireMockServer.baseUrl() + "/token"); + } + + private void setAiServiceKey(String key) throws Exception { + + String aiServiceKey = Base64.getEncoder().encodeToString((key).getBytes()); + setStaticField(AIAccessTokenManager.class, "AI_KEY", aiServiceKey); + } + + private void resetSingletonInstance(Class clazz, String fieldName) throws Exception { + + java.lang.reflect.Field instanceField = clazz.getDeclaredField(fieldName); + instanceField.setAccessible(true); + instanceField.set(null, null); // Reset the static field to null + } + + @Test + public void testGetAccessTokenSuccess() throws Exception { + + startWireMockServer(); + setAiServiceKey("testClientId:testClientSecret"); + tokenManager = AIAccessTokenManager.getInstance(); + + // Mock a successful token response + String expectedAccessToken = "mockedAccessToken"; + wireMockServer.stubFor(post(urlEqualTo("/token")) + .willReturn(aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody("{\"access_token\":\"" + expectedAccessToken + "\"}"))); + + String accessToken = tokenManager.getAccessToken(false); + + Assert.assertEquals(accessToken, expectedAccessToken, "Access token should match the mocked value."); + Assert.assertEquals(tokenManager.getClientId(), "testClientId", "Client ID should match the mocked value."); + } + + @Test(expectedExceptions = AIServerException.class) + public void testGetAccessTokenUnauthorized() throws Exception { + + startWireMockServer(); + setAiServiceKey("testClientId:testClientSecret"); + tokenManager = AIAccessTokenManager.getInstance(); + + wireMockServer.stubFor(post(urlEqualTo("/token")) + .willReturn(aResponse() + .withStatus(401) + .withHeader("Content-Type", "application/json") + .withBody("{\"error\":\"Unauthorized\"}"))); + + tokenManager.getAccessToken(false); + } + + @Test(expectedExceptions = AIServerException.class) + public void testGetAccessTokenServerError() throws Exception { + + startWireMockServer(); + setAiServiceKey("testClientId:testClientSecret"); + tokenManager = AIAccessTokenManager.getInstance(); + + // Arrange: Mock a 500 Internal Server Error response + wireMockServer.stubFor(post(urlEqualTo("/token")) + .willReturn(aResponse() + .withStatus(500) + .withHeader("Content-Type", "application/json") + .withBody("{\"error\":\"Internal Server Error\"}"))); + + tokenManager.getAccessToken(false); + } + + @Test + public void testGetAccessTokenRenewal() throws Exception { + + startWireMockServer(); + setAiServiceKey("testClientId:testClientSecret"); + tokenManager = AIAccessTokenManager.getInstance(); + + // Arrange: Mock a successful token response for renewal. + String newAccessToken = "newMockedAccessToken"; + wireMockServer.stubFor(post(urlEqualTo("/token")) + .willReturn(aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody("{\"access_token\":\"" + newAccessToken + "\"}"))); + + String accessToken = tokenManager.getAccessToken(true); + + Assert.assertEquals(accessToken, newAccessToken, "Access token should match the renewed mocked value."); + } + + @Test + public void testGetAccessTokenExistingTokenReturnsWithoutRenewal() throws Exception { + + startWireMockServer(); + setAiServiceKey("testClientId:testClientSecret"); + tokenManager = AIAccessTokenManager.getInstance(); + + // Arrange: Mock a successful token response. + String existingAccessToken = "existingMockedAccessToken"; + wireMockServer.stubFor(post(urlEqualTo("/token")) + .willReturn(aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody("{\"access_token\":\"" + existingAccessToken + "\"}"))); + + // Act (1): First call to getAccessToken. This should fetch the token from the server. + String firstCallToken = tokenManager.getAccessToken(false); + Assert.assertEquals(firstCallToken, existingAccessToken, + "First call should retrieve the newly obtained token."); + + // Reset WireMock’s request history to track subsequent calls. + wireMockServer.resetRequests(); + + // Act (2): Second call with renewAccessToken = false and an existing token. + // This should NOT call the token endpoint again; it should return the cached token. + String secondCallToken = tokenManager.getAccessToken(false); + + Assert.assertEquals(secondCallToken, existingAccessToken, + "Second call should return the same token without making a new request."); + // Verify that no new requests to the token endpoint were made after the first call. + wireMockServer.verify(0, postRequestedFor(urlEqualTo("/token"))); + } + + @Test(expectedExceptions = AIServerException.class, + expectedExceptionsMessageRegExp = "Failed to obtain access token after.*attempts.*") + public void testGetAccessTokenMaxRetriesExceeded() throws Exception { + + startWireMockServer(); + setAiServiceKey("testClientId:testClientSecret"); + tokenManager = AIAccessTokenManager.getInstance(); + + // Stub the /token endpoint to always return a non-200 status (e.g., 500). This simulates repeated failures. + wireMockServer.stubFor(post(urlEqualTo("/token")) + .willReturn(aResponse() + .withStatus(500) + .withHeader("Content-Type", "application/json") + .withBody("{\"error\":\"Internal Server Error\"}"))); + tokenManager.getAccessToken(false); + } + + @Test(expectedExceptions = AIServerException.class, + expectedExceptionsMessageRegExp = "Error executing token request:.*") + public void testGetAccessTokenIOException() throws Exception { + + startWireMockServer(); + setAiServiceKey("testClientId:testClientSecret"); + tokenManager = AIAccessTokenManager.getInstance(); + + // Configure WireMock to cause a network-level fault that should trigger an IOException. + wireMockServer.stubFor(post(urlEqualTo("/token")) + .willReturn(aResponse() + // This simulates a situation where the connection is abruptly reset. + .withFault(Fault.CONNECTION_RESET_BY_PEER))); + + // Act: When getAccessToken calls the endpoint, the client should throw an IOException, + // causing the catch block to throw an AIServerException with "Error executing token request: ...". + tokenManager.getAccessToken(false); + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testInvalidAIServiceKey() throws Exception { + + setAiServiceKey("invalidKey"); + + // Act: Attempt to get the instance, which should throw an exception. + AIAccessTokenManager.getInstance(); + } + + @Test + public void testGetInstanceFirstTimeCreation() throws Exception { + + setAiServiceKey("testClientId:testClientSecret"); + setStaticField(AIAccessTokenManager.class, "AI_TOKEN_ENDPOINT", "http://localhost.com/token"); + AIAccessTokenManager firstCallInstance = AIAccessTokenManager.getInstance(); + Assert.assertNotNull(firstCallInstance, "First call to getInstance() should create a new instance."); + } + + @Test + public void testGetInstanceSubsequentCallsReturnSameInstance() throws Exception { + + setAiServiceKey("testClientId:testClientSecret"); + setStaticField(AIAccessTokenManager.class, "AI_TOKEN_ENDPOINT", "http://localhost.com/token"); + + // Reset the singleton to ensure it's null, then create it once. + AIAccessTokenManager firstCallInstance = AIAccessTokenManager.getInstance(); + + AIAccessTokenManager secondCallInstance = AIAccessTokenManager.getInstance(); + + // Verify that the second call did NOT re-create the object. + Assert.assertNotNull(secondCallInstance, "Second call should still return an instance."); + Assert.assertEquals(secondCallInstance, firstCallInstance, + "Both calls should return the exact same singleton instance."); + } + + private void setStaticField(Class clazz, String fieldName, String value) throws Exception { + + java.lang.reflect.Field field = clazz.getDeclaredField(fieldName); + field.setAccessible(true); + + java.lang.reflect.Field modifiersField = java.lang.reflect.Field.class.getDeclaredField("modifiers"); + modifiersField.setAccessible(true); + modifiersField.setInt(field, field.getModifiers() & ~java.lang.reflect.Modifier.FINAL); + field.set(null, value); + } + + @AfterMethod + public void tearDown() { + + if (wireMockServer != null) { + wireMockServer.stop(); + wireMockServer = null; + } + } + + @AfterClass + public void destroy() { + + mockedStatic.close(); + } +} diff --git a/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/test/java/org/wso2/carbon/identity/ai/service/mgt/util/AIHttpClientUtilTest.java b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/test/java/org/wso2/carbon/identity/ai/service/mgt/util/AIHttpClientUtilTest.java new file mode 100644 index 000000000000..4d561e056c33 --- /dev/null +++ b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/test/java/org/wso2/carbon/identity/ai/service/mgt/util/AIHttpClientUtilTest.java @@ -0,0 +1,469 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. +*/ + +package org.wso2.carbon.identity.ai.service.mgt.util; + +import com.github.tomakehurst.wiremock.WireMockServer; +import com.github.tomakehurst.wiremock.http.Fault; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.client.methods.HttpUriRequest; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.testng.Assert; +import org.testng.annotations.AfterClass; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; +import org.wso2.carbon.base.CarbonBaseConstants; +import org.wso2.carbon.context.PrivilegedCarbonContext; +import org.wso2.carbon.identity.ai.service.mgt.exceptions.AIClientException; +import org.wso2.carbon.identity.ai.service.mgt.exceptions.AIServerException; +import org.wso2.carbon.identity.ai.service.mgt.token.AIAccessTokenManager; +import org.wso2.carbon.identity.core.util.IdentityUtil; + +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.equalTo; +import static com.github.tomakehurst.wiremock.client.WireMock.equalToJson; +import static com.github.tomakehurst.wiremock.client.WireMock.get; +import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor; +import static com.github.tomakehurst.wiremock.client.WireMock.post; +import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig; +import static com.github.tomakehurst.wiremock.stubbing.Scenario.STARTED; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.mockito.MockitoAnnotations.openMocks; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.HTTP_CONNECTION_POOL_SIZE_PROPERTY_NAME; +import static org.wso2.carbon.identity.ai.service.mgt.constants.AIConstants.TENANT_CONTEXT_PREFIX; + +/** + * Test class for AIHttpClientUtil. + */ +public class AIHttpClientUtilTest { + + private WireMockServer wireMockServer; + private final String clientId = "testClientId"; + + @Mock + private AIAccessTokenManager mockTokenManager; + + private MockedStatic aiAccessTokenManagerMockedStatic; + private MockedStatic identityUtilMockedStatic; + + @BeforeClass + public void init() { + + identityUtilMockedStatic = Mockito.mockStatic(IdentityUtil.class); + identityUtilMockedStatic.when(() -> IdentityUtil.getProperty(HTTP_CONNECTION_POOL_SIZE_PROPERTY_NAME)) + .thenReturn("10"); + } + + @BeforeMethod + public void setUp() throws Exception { + + openMocks(this); + setCarbonHome(); + setCarbonContextForTenant(); + + aiAccessTokenManagerMockedStatic = mockStatic(AIAccessTokenManager.class); + when(AIAccessTokenManager.getInstance()).thenReturn(mockTokenManager); + when(mockTokenManager.getAccessToken(false)).thenReturn("testToken"); + when(mockTokenManager.getClientId()).thenReturn(clientId); + + // Start WireMock server on a random port. + wireMockServer = new WireMockServer(wireMockConfig().dynamicPort()); + wireMockServer.start(); + + // Reset WireMock state for each test. + wireMockServer.resetAll(); + } + + @Test + public void testExecuteRequestSuccess() throws Exception { + + String expectedResponse = "{\"result\":\"SUCCESS\"}"; + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .willReturn(aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody(expectedResponse))); + + String baseUrl = wireMockServer.baseUrl(); + Map resultMap = AIHttpClientUtil.executeRequest( + baseUrl, + path, + HttpGet.class, + null + ); + + Assert.assertEquals(resultMap.get("result"), "SUCCESS"); + wireMockServer.verify(getRequestedFor(urlEqualTo(fullPath))); + } + + @Test + public void testExecuteRequestPostSuccess() throws Exception { + + String expectedResponse = "{\"result\":\"POST_SUCCESS\"}"; + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + + String requestBody = "{\"key\":\"value\"}"; + + // Stub the POST request with the expected response. + wireMockServer.stubFor(post(urlEqualTo(fullPath)) + .withHeader("Content-Type", equalTo("application/json")) + .withRequestBody(equalToJson(requestBody)) // Ensure the request body matches. + .willReturn(aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody(expectedResponse))); + + String baseUrl = wireMockServer.baseUrl(); + Map requestBodyMap = new HashMap<>(); + requestBodyMap.put("key", "value"); + Map resultMap = AIHttpClientUtil.executeRequest( + baseUrl, + path, + HttpPost.class, + requestBodyMap + ); + + Assert.assertEquals(resultMap.get("result"), "POST_SUCCESS"); + + // Verify that the POST request was made with the correct path and body. + wireMockServer.verify(postRequestedFor(urlEqualTo(fullPath)) + .withHeader("Content-Type", equalTo("application/json")) + .withRequestBody(equalToJson(requestBody))); + } + + @Test(expectedExceptions = AIClientException.class) + public void testExecuteRequestClientError() throws Exception { + + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .willReturn(aResponse() + .withStatus(400) + .withHeader("Content-Type", "application/json") + .withBody("Bad Request"))); + + String baseUrl = wireMockServer.baseUrl(); + AIHttpClientUtil.executeRequest( + baseUrl, + path, + HttpGet.class, + null + ); + } + + @Test(expectedExceptions = AIServerException.class) + public void testExecuteRequestServerError() throws Exception { + + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .willReturn(aResponse() + .withStatus(500) + .withHeader("Content-Type", "text/plain") + .withBody("Internal Server Error"))); + + // Act & Assert: Execute the HTTP request and expect AIServerException. + String baseUrl = wireMockServer.baseUrl(); + AIHttpClientUtil.executeRequest( + baseUrl, + path, + HttpGet.class, + null + ); + } + + @Test + public void testExecuteRequestTokenRenewal() throws Exception { + + // Mock the AccessTokenManager to simulate token renewal. + when(mockTokenManager.getAccessToken(true)).thenReturn("newToken"); + + // Arrange: Mock token renewal flow. + String expectedResponse = "{\"result\":\"SUCCESS\"}"; + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + + // First response: 401 Unauthorized. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .inScenario("Token Renewal") + .whenScenarioStateIs(STARTED) + .willReturn(aResponse() + .withStatus(401) + .withHeader("Content-Type", "application/json") + .withBody("Unauthorized")) + .willSetStateTo("Token Renewed")); // Transition to the next state. + + // Second response: 200 OK. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .inScenario("Token Renewal") + .whenScenarioStateIs("Token Renewed") + .willReturn(aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody(expectedResponse))); + + // Act: Execute the HTTP request. + String baseUrl = wireMockServer.baseUrl(); + Map resultMap = AIHttpClientUtil.executeRequest( + baseUrl, + path, + HttpGet.class, + null + ); + + // Assert: Verify the response. + Assert.assertEquals(resultMap.get("result"), "SUCCESS"); + + // Verify the requests were made twice: once for 401 and once for 200. + wireMockServer.verify(2, getRequestedFor(urlEqualTo(fullPath))); + + // Verify token renewal was called once. + verify(mockTokenManager, times(1)).getAccessToken(true); + } + + @Test(expectedExceptions = AIClientException.class) + public void testExecuteRequestTokenRenewalErrorAfterRenewal() throws Exception { + // Mock the AccessTokenManager to simulate token renewal. + when(mockTokenManager.getAccessToken(true)).thenReturn("newToken"); + + // Arrange: Define paths and mock token renewal flow. + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + + // First response: 401 Unauthorized. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .inScenario("Token Renewal with Error") + .whenScenarioStateIs(STARTED) + .willReturn(aResponse() + .withStatus(401) + .withHeader("Content-Type", "application/json") + .withBody("Unauthorized")) + .willSetStateTo("Token Renewed")); // Transition to the next state. + + // Second response: 400 Bad Request (or you can use 500 for Internal Server Error). + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .inScenario("Token Renewal with Error") + .whenScenarioStateIs("Token Renewed") + .willReturn(aResponse() + .withStatus(400) + .withHeader("Content-Type", "application/json") + .withBody("{\"error\":\"Bad Request\"}"))); // Error response body. + + // Act: Execute the HTTP request. + String baseUrl = wireMockServer.baseUrl(); + AIHttpClientUtil.executeRequest( + baseUrl, + path, + HttpGet.class, + null + ); + } + + @Test(expectedExceptions = AIServerException.class) + public void testExecuteRequestIOException() throws Exception { + + // Arrange: Mock a server that simulates a connection reset. + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + + // Simulate a connection reset. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .willReturn(aResponse() + .withFault(Fault.CONNECTION_RESET_BY_PEER))); // Simulates a connection reset. + + // Act & Assert: Expect AIServerException due to simulated IOException (connection reset). + String baseUrl = wireMockServer.baseUrl(); + AIHttpClientUtil.executeRequest( + baseUrl, + path, + HttpGet.class, + null + ); + } + + @Test(expectedExceptions = AIServerException.class) + public void testExecuteRequestExecutionException() throws Exception { + + // Arrange: Mock a server that simulates an unexpected response. + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + + // Simulate an unexpected response that triggers an ExecutionException. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .willReturn(aResponse() + .withFault(Fault.MALFORMED_RESPONSE_CHUNK))); // Simulates a malformed response + + // Act & Assert: Expect AIServerException due to simulated ExecutionException. + String baseUrl = wireMockServer.baseUrl(); + AIHttpClientUtil.executeRequest( + baseUrl, + path, + HttpGet.class, + null + ); + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testExecuteRequestUnsupportedRequestType() throws Exception { + + // Arrange: Define the path and base URL. + String path = "/test-endpoint"; + String baseUrl = "https://ai-service.example.com"; + + // Act & Assert: Pass an unsupported request type and expect IllegalArgumentException. + AIHttpClientUtil.executeRequest( + baseUrl, + path, + HttpUriRequest.class, // Unsupported request type. + null + ); + } + + @Test(expectedExceptions = AIServerException.class) + public void testExecuteRequestUnauthorizedAfterTokenRenewal() throws Exception { + + // Mock the AccessTokenManager for token renewal. + when(mockTokenManager.getAccessToken(true)).thenReturn("newToken"); + + // Arrange: Define paths. + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + + // First response: 401 Unauthorized. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .inScenario("Token Renewal Fails") + .whenScenarioStateIs(STARTED) + .willReturn(aResponse() + .withStatus(401) + .withHeader("Content-Type", "application/json") + .withBody("Unauthorized")) + .willSetStateTo("Retry")); + + // Second response: 401 Unauthorized again. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .inScenario("Token Renewal Fails") + .whenScenarioStateIs("Retry") + .willReturn(aResponse() + .withStatus(401) // Still Unauthorized. + .withHeader("Content-Type", "application/json") + .withBody("Still Unauthorized"))); + + String baseUrl = wireMockServer.baseUrl(); + AIHttpClientUtil.executeRequest( + baseUrl, + path, + HttpGet.class, + null + ); + } + + @Test(expectedExceptions = AIServerException.class) + public void testExecuteRequestJsonParsingError() throws Exception { + + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + + // Mock the server to return invalid JSON. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .willReturn(aResponse() + .withStatus(200) + .withHeader("Content-Type", "application/json") + .withBody("{ invalid json }"))); // Invalid JSON. + + // Act: Execute the HTTP request, expecting AIServerException due to JSON parsing error. + String baseUrl = wireMockServer.baseUrl(); + AIHttpClientUtil.executeRequest( + baseUrl, + path, + HttpGet.class, + null + ); + } + + @Test(expectedExceptions = AIServerException.class) + public void testExecuteRequestFailedTokenRenewal() throws Exception { + + // Mock the AccessTokenManager to simulate failed token renewal. + when(mockTokenManager.getAccessToken(false)).thenReturn("oldToken"); + when(mockTokenManager.getAccessToken(true)).thenReturn(null); // Simulate failed token renewal. + + String path = "/test-endpoint"; + String fullPath = TENANT_CONTEXT_PREFIX + clientId + path; // This is the path that AIHttpClientUtil will use. + + // Mock the server to return 401 Unauthorized. + wireMockServer.stubFor(get(urlEqualTo(fullPath)) + .willReturn(aResponse() + .withStatus(401) + .withHeader("Content-Type", "application/json") + .withBody("Unauthorized"))); + + // Act: Execute the HTTP request, expecting AIServerException due to failed token renewal. + String baseUrl = wireMockServer.baseUrl(); + AIHttpClientUtil.executeRequest( + baseUrl, + path, + HttpGet.class, + null + ); + } + + private void setCarbonHome() { + + String carbonHome = Paths.get(System.getProperty("user.dir"), "target", "test-classes").toString(); + System.setProperty(CarbonBaseConstants.CARBON_HOME, carbonHome); + System.setProperty(CarbonBaseConstants.CARBON_CONFIG_DIR_PATH, Paths.get(carbonHome, "conf").toString()); + } + + private void setCarbonContextForTenant() { + + PrivilegedCarbonContext.startTenantFlow(); + PrivilegedCarbonContext.getThreadLocalCarbonContext().setTenantDomain( + org.wso2.carbon.base.MultitenantConstants.SUPER_TENANT_DOMAIN_NAME); + } + + @AfterMethod + public void tearDown() { + + aiAccessTokenManagerMockedStatic.close(); + PrivilegedCarbonContext.endTenantFlow(); + wireMockServer.stop(); + } + + @AfterClass + public void destroy() { + + identityUtilMockedStatic.close(); + } +} diff --git a/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/test/resources/log4j.properties b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/test/resources/log4j.properties new file mode 100644 index 000000000000..b2ef4da808f5 --- /dev/null +++ b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/test/resources/log4j.properties @@ -0,0 +1,26 @@ +# +# Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). +# +# WSO2 LLC. licenses this file to you under the Apache License, +# Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +# Root logger option +log4j.rootLogger=INFO, stdout + +# Direct log messages to stdout +log4j.appender.stdout=org.apache.log4j.ConsoleAppender +log4j.appender.stdout.Target=System.out +log4j.appender.stdout.layout=org.apache.log4j.PatternLayout +log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n diff --git a/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/test/resources/testng.xml b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/test/resources/testng.xml new file mode 100644 index 000000000000..aa1aa752ae1c --- /dev/null +++ b/components/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt/src/test/resources/testng.xml @@ -0,0 +1,28 @@ + + + + + + + + + + + + diff --git a/components/ai-services-mgt/pom.xml b/components/ai-services-mgt/pom.xml new file mode 100644 index 000000000000..f2a1556aa301 --- /dev/null +++ b/components/ai-services-mgt/pom.xml @@ -0,0 +1,44 @@ + + + + + 4.0.0 + + + org.wso2.carbon.identity.framework + identity-framework + 7.7.93-SNAPSHOT + ../../pom.xml + + + ai-services-mgt + pom + WSO2 Identity - AI Management Aggregator Module + + This is a Carbon bundle that represent the AI Management Aggregator Module. + + http://wso2.org + + + org.wso2.carbon.identity.ai.service.mgt + + + diff --git a/components/application-mgt/org.wso2.carbon.identity.application.mgt/pom.xml b/components/application-mgt/org.wso2.carbon.identity.application.mgt/pom.xml index 194fff38ac76..2c13097221d1 100644 --- a/components/application-mgt/org.wso2.carbon.identity.application.mgt/pom.xml +++ b/components/application-mgt/org.wso2.carbon.identity.application.mgt/pom.xml @@ -133,6 +133,10 @@ org.wso2.carbon.identity.framework org.wso2.carbon.identity.claim.metadata.mgt + + org.wso2.carbon.identity.framework + org.wso2.carbon.identity.ai.service.mgt + org.wso2.carbon.identity.framework org.wso2.carbon.identity.api.resource.mgt @@ -264,9 +268,11 @@ version="${org.wso2.carbon.identity.organization.management.core.version.range}", org.wso2.carbon.identity.api.resource.mgt.model; version="${carbon.identity.package.import.version.range}", org.wso2.carbon.identity.api.resource.mgt.util; version="${carbon.identity.package.import.version.range}", + org.wso2.carbon.identity.ai.service.mgt.*; version="${carbon.identity.package.import.version.range}", org.wso2.carbon.identity.certificate.management.service; version="${carbon.identity.package.import.version.range}", org.wso2.carbon.identity.certificate.management.exception; version="${carbon.identity.package.import.version.range}", org.wso2.carbon.identity.certificate.management.model; version="${carbon.identity.package.import.version.range}", + org.apache.http.client.methods; version="${httpcomponents-httpclient.imp.pkg.version.range}", !org.wso2.carbon.identity.application.mgt.internal, diff --git a/components/application-mgt/org.wso2.carbon.identity.application.mgt/src/main/java/org/wso2/carbon/identity/application/mgt/ai/LoginFlowAIManager.java b/components/application-mgt/org.wso2.carbon.identity.application.mgt/src/main/java/org/wso2/carbon/identity/application/mgt/ai/LoginFlowAIManager.java new file mode 100644 index 000000000000..876d5246dbec --- /dev/null +++ b/components/application-mgt/org.wso2.carbon.identity.application.mgt/src/main/java/org/wso2/carbon/identity/application/mgt/ai/LoginFlowAIManager.java @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.wso2.carbon.identity.application.mgt.ai; + +import org.json.JSONArray; +import org.json.JSONObject; +import org.wso2.carbon.identity.ai.service.mgt.exceptions.AIClientException; +import org.wso2.carbon.identity.ai.service.mgt.exceptions.AIServerException; + +import java.util.Map; + +/** + * AI Manager interface for the LoginFlowAI module. + */ +public interface LoginFlowAIManager { + + /** + * Generates an authentication sequence using the LoginFlow AI service. + * + * @param userQuery The user query. This is a string that contain the requested authentication + * flow by the user. + * @param userClaimsMetaData The user claims metadata. This is a JSON array that contains the user + * claims available + * for that organization. + * @param availableAuthenticators The available authenticators of the organization. + * @return Operation ID of the generated authentication sequence. + * @throws AIServerException When a server error occurs while connecting to the LoginFlow AI service. + * @throws AIClientException When a client error occurs while generating the authentication sequence. + */ + String generateAuthenticationSequence(String userQuery, JSONArray userClaimsMetaData, + JSONObject availableAuthenticators) + throws AIServerException, AIClientException; + + /** + * Retrieves the status of the authentication sequence generation operation. + * + * @param operationId The operation ID of the authentication sequence generation operation. + * @return A Json representation of the status' that are completed, pending, or failed. + * @throws AIServerException When a server error occurs while connecting to the LoginFlow AI service. + * @throws AIClientException When a client error occurs while retrieving the authentication sequence + */ + Map getAuthenticationSequenceGenerationStatus(String operationId) throws AIServerException, + AIClientException; + + /** + * Retrieves the result of the authentication sequence generation operation. + * + * @param operationId The operation ID of the authentication sequence generation operation. + * @return The result of the authentication sequence generation operation. + * @throws AIServerException When a server error occurs while connecting to the LoginFlow AI service. + * @throws AIClientException When a client error occurs while retrieving the authentication sequence + */ + Map getAuthenticationSequenceGenerationResult(String operationId) throws AIServerException, + AIClientException; +} diff --git a/components/application-mgt/org.wso2.carbon.identity.application.mgt/src/main/java/org/wso2/carbon/identity/application/mgt/ai/LoginFlowAIManagerImpl.java b/components/application-mgt/org.wso2.carbon.identity.application.mgt/src/main/java/org/wso2/carbon/identity/application/mgt/ai/LoginFlowAIManagerImpl.java new file mode 100644 index 000000000000..05c413b84076 --- /dev/null +++ b/components/application-mgt/org.wso2.carbon.identity.application.mgt/src/main/java/org/wso2/carbon/identity/application/mgt/ai/LoginFlowAIManagerImpl.java @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.wso2.carbon.identity.application.mgt.ai; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.gson.JsonSyntaxException; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.http.client.methods.HttpGet; +import org.apache.http.client.methods.HttpPost; +import org.json.JSONArray; +import org.json.JSONObject; +import org.wso2.carbon.identity.ai.service.mgt.exceptions.AIClientException; +import org.wso2.carbon.identity.ai.service.mgt.exceptions.AIServerException; +import org.wso2.carbon.identity.core.util.IdentityUtil; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.wso2.carbon.identity.ai.service.mgt.util.AIHttpClientUtil.executeRequest; +import static org.wso2.carbon.identity.application.mgt.ai.constant.LoginFlowAIConstants.AUTHENTICATORS_PROPERTY; +import static org.wso2.carbon.identity.application.mgt.ai.constant.LoginFlowAIConstants.ErrorMessages.CLIENT_ERROR_WHILE_CONNECTING_TO_LOGINFLOW_AI_SERVICE; +import static org.wso2.carbon.identity.application.mgt.ai.constant.LoginFlowAIConstants.OPERATION_ID_PROPERTY; +import static org.wso2.carbon.identity.application.mgt.ai.constant.LoginFlowAIConstants.USER_CLAIM_PROPERTY; +import static org.wso2.carbon.identity.application.mgt.ai.constant.LoginFlowAIConstants.USER_QUERY_PROPERTY; +import static org.wso2.carbon.registry.core.RegistryConstants.PATH_SEPARATOR; + +/** + * Implementation of the LoginFlowAIManager interface to communicate with the LoginFlowAI service. + */ +public class LoginFlowAIManagerImpl implements LoginFlowAIManager { + + private static final String LOGINFLOW_AI_ENDPOINT = IdentityUtil.getProperty( + "AIServices.LoginFlowAI.LoginFlowAIEndpoint"); + private static final String LOGINFLOW_AI_GENERATE_PATH = "/api/server/v1/applications/loginflow/generate"; + private static final String LOGINFLOW_AI_STATUS_PATH = "/api/server/v1/applications/loginflow/status"; + private static final String LOGINFLOW_AI_RESULT_PATH = "/api/server/v1/applications/loginflow/result"; + + private static final Log LOG = LogFactory.getLog(LoginFlowAIManagerImpl.class); + + @Override + public String generateAuthenticationSequence(String userQuery, JSONArray userClaimsMetaData, + JSONObject availableAuthenticators) throws AIServerException, + AIClientException { + + ObjectMapper objectMapper = new ObjectMapper(); + Map requestBody = new HashMap<>(); + requestBody.put(USER_QUERY_PROPERTY, userQuery); + try { + List userClaimsMetadataList = objectMapper.readValue(userClaimsMetaData.toString(), List.class); + requestBody.put(USER_CLAIM_PROPERTY, userClaimsMetadataList); + Map authenticatorsMap = objectMapper.readValue(availableAuthenticators.toString(), + Map.class); + requestBody.put(AUTHENTICATORS_PROPERTY, authenticatorsMap); + } catch (JsonSyntaxException | IOException e) { + throw new AIClientException("Error occurred while parsing the user claims or available " + + "authenticators.", CLIENT_ERROR_WHILE_CONNECTING_TO_LOGINFLOW_AI_SERVICE.getCode(), e); + } + + Map stringObjectMap = executeRequest(LOGINFLOW_AI_ENDPOINT, LOGINFLOW_AI_GENERATE_PATH, + HttpPost.class, requestBody); + return (String) stringObjectMap.get(OPERATION_ID_PROPERTY); + } + + @Override + public Map getAuthenticationSequenceGenerationStatus(String operationId) throws AIServerException, + AIClientException { + + return executeRequest(LOGINFLOW_AI_ENDPOINT, LOGINFLOW_AI_STATUS_PATH + PATH_SEPARATOR + operationId, + HttpGet.class, null); + } + + @Override + public Map getAuthenticationSequenceGenerationResult(String operationId) throws AIServerException, + AIClientException { + + return executeRequest(LOGINFLOW_AI_ENDPOINT, LOGINFLOW_AI_RESULT_PATH + PATH_SEPARATOR + operationId, + HttpGet.class, null); + } +} diff --git a/components/application-mgt/org.wso2.carbon.identity.application.mgt/src/main/java/org/wso2/carbon/identity/application/mgt/ai/constant/LoginFlowAIConstants.java b/components/application-mgt/org.wso2.carbon.identity.application.mgt/src/main/java/org/wso2/carbon/identity/application/mgt/ai/constant/LoginFlowAIConstants.java new file mode 100644 index 000000000000..f9b5aed008c4 --- /dev/null +++ b/components/application-mgt/org.wso2.carbon.identity.application.mgt/src/main/java/org/wso2/carbon/identity/application/mgt/ai/constant/LoginFlowAIConstants.java @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.wso2.carbon.identity.application.mgt.ai.constant; + +/** + * Constants for the LoginFlowAI module. + */ +public class LoginFlowAIConstants { + + private LoginFlowAIConstants() { + + } + + public static final String OPERATION_ID_PROPERTY = "operation_id"; + public static final String USER_CLAIM_PROPERTY = "user_claims"; + public static final String USER_QUERY_PROPERTY = "user_query"; + public static final String AUTHENTICATORS_PROPERTY = "available_authenticators"; + + /** + * Enums for error messages. + */ + public enum ErrorMessages { + + CLIENT_ERROR_WHILE_CONNECTING_TO_LOGINFLOW_AI_SERVICE("AILF-10008", "Client error occurred " + + "for %s tenant while generating authentication sequence."), + SERVER_ERROR_WHILE_CONNECTING_TO_LOGINFLOW_AI_SERVICE("AILF-10009", "Server error occurred " + + "for %s tenant while generating authentication sequence."); + + private final String code; + private final String message; + + ErrorMessages(String code, String message) { + + this.code = code; + this.message = message; + } + + public String getCode() { + + return code; + } + + public String getMessage() { + + return message; + } + + @Override + public String toString() { + + return code + ":" + message; + } + } +} diff --git a/components/application-mgt/org.wso2.carbon.identity.application.mgt/src/main/java/org/wso2/carbon/identity/application/mgt/internal/ApplicationManagementServiceComponent.java b/components/application-mgt/org.wso2.carbon.identity.application.mgt/src/main/java/org/wso2/carbon/identity/application/mgt/internal/ApplicationManagementServiceComponent.java index 986b2374dad9..f2a057ab251c 100644 --- a/components/application-mgt/org.wso2.carbon.identity.application.mgt/src/main/java/org/wso2/carbon/identity/application/mgt/internal/ApplicationManagementServiceComponent.java +++ b/components/application-mgt/org.wso2.carbon.identity.application.mgt/src/main/java/org/wso2/carbon/identity/application/mgt/internal/ApplicationManagementServiceComponent.java @@ -48,6 +48,8 @@ import org.wso2.carbon.identity.application.mgt.AuthorizedAPIManagementService; import org.wso2.carbon.identity.application.mgt.AuthorizedAPIManagementServiceImpl; import org.wso2.carbon.identity.application.mgt.DiscoverableApplicationManager; +import org.wso2.carbon.identity.application.mgt.ai.LoginFlowAIManager; +import org.wso2.carbon.identity.application.mgt.ai.LoginFlowAIManagerImpl; import org.wso2.carbon.identity.application.mgt.defaultsequence.DefaultAuthSeqMgtService; import org.wso2.carbon.identity.application.mgt.defaultsequence.DefaultAuthSeqMgtServiceImpl; import org.wso2.carbon.identity.application.mgt.inbound.protocol.ApplicationInboundAuthConfigHandler; @@ -143,6 +145,8 @@ protected void activate(ComponentContext context) { bundleContext.registerService(AuthorizedAPIManagementService.class, new AuthorizedAPIManagementServiceImpl(), null); + bundleContext.registerService(LoginFlowAIManager.class, new LoginFlowAIManagerImpl(), null); + bundleContext.registerService(RoleManagementListener.class, new DefaultRoleManagementListener(), null); bundleContext.registerService(ApplicationMgtListener.class, new DefaultRoleManagementListener(), null); diff --git a/components/application-mgt/org.wso2.carbon.identity.application.mgt/src/test/java/org/wso2/carbon/identity/application/mgt/ai/LoginFlowAIManagerTest.java b/components/application-mgt/org.wso2.carbon.identity.application.mgt/src/test/java/org/wso2/carbon/identity/application/mgt/ai/LoginFlowAIManagerTest.java new file mode 100644 index 000000000000..c5cdce546d92 --- /dev/null +++ b/components/application-mgt/org.wso2.carbon.identity.application.mgt/src/test/java/org/wso2/carbon/identity/application/mgt/ai/LoginFlowAIManagerTest.java @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (http://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.wso2.carbon.identity.application.mgt.ai; + +import org.json.JSONArray; +import org.json.JSONObject; +import org.mockito.InjectMocks; +import org.mockito.MockedStatic; +import org.testng.Assert; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; +import org.wso2.carbon.base.CarbonBaseConstants; +import org.wso2.carbon.context.PrivilegedCarbonContext; +import org.wso2.carbon.identity.ai.service.mgt.util.AIHttpClientUtil; +import org.wso2.carbon.identity.common.testng.realm.InMemoryRealmService; +import org.wso2.carbon.identity.core.util.IdentityTenantUtil; +import org.wso2.carbon.user.core.UserStoreException; + +import java.nio.file.Paths; +import java.util.HashMap; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.MockitoAnnotations.openMocks; +import static org.wso2.carbon.base.MultitenantConstants.SUPER_TENANT_DOMAIN_NAME; +import static org.wso2.carbon.base.MultitenantConstants.SUPER_TENANT_ID; + +/** + * Test class for LoginFlowAIManager. + */ +public class LoginFlowAIManagerTest { + + private MockedStatic aiHttpClientUtilMockedStatic; + + @InjectMocks + private LoginFlowAIManagerImpl loginFlowAIManager; + + @BeforeMethod + public void setUp() throws UserStoreException { + + openMocks(this); + setCarbonHome(); + setCarbonContextForTenant(SUPER_TENANT_DOMAIN_NAME, SUPER_TENANT_ID); + aiHttpClientUtilMockedStatic = mockStatic(AIHttpClientUtil.class); + } + + @Test + public void testGenerateAuthenticationSequenceSuccess() throws Exception { + + Map response = new HashMap<>(); + response.put("operation_id", "12345"); + mockSuccessfulResponse(response); + String result = loginFlowAIManager.generateAuthenticationSequence("Need username and password as " + + "the first step", new JSONArray(), new JSONObject()); + Assert.assertEquals(result, "12345"); + } + + @Test + public void testGetAuthenticationSequenceGenerationStatusSuccess() throws Exception { + + Map response = new HashMap<>(); + response.put("status", "COMPLETED"); + mockSuccessfulResponse(response); + Object result = loginFlowAIManager.getAuthenticationSequenceGenerationStatus("operation123"); + + Assert.assertTrue(result instanceof Map); + Map resultMap = (Map) result; + Assert.assertEquals(resultMap.get("status"), "COMPLETED"); + } + + @Test + public void testGetAuthenticationSequenceGenerationResultSuccess() throws Exception { + + Map response = new HashMap<>(); + response.put("result", "SUCCESS"); + mockSuccessfulResponse(response); + Object result = loginFlowAIManager.getAuthenticationSequenceGenerationResult("operation123"); + + Assert.assertTrue(result instanceof Map); + Map resultMap = (Map) result; + Assert.assertEquals(resultMap.get("result"), "SUCCESS"); + } + + private void mockSuccessfulResponse(Map responseBody) { + + aiHttpClientUtilMockedStatic.when(() -> AIHttpClientUtil.executeRequest( + any(), any(), any(), any() + )).thenReturn(responseBody); + } + + + private void setCarbonHome() { + + String carbonHome = Paths.get(System.getProperty("user.dir"), "target", "test-classes", + "repository").toString(); + System.setProperty(CarbonBaseConstants.CARBON_HOME, carbonHome); + System.setProperty(CarbonBaseConstants.CARBON_CONFIG_DIR_PATH, Paths.get(carbonHome, "conf").toString()); + } + + private void setCarbonContextForTenant(String tenantDomain, int tenantId) throws UserStoreException { + + PrivilegedCarbonContext.startTenantFlow(); + PrivilegedCarbonContext.getThreadLocalCarbonContext().setTenantDomain(tenantDomain); + PrivilegedCarbonContext.getThreadLocalCarbonContext().setTenantId(tenantId); + InMemoryRealmService testSessionRealmService = new InMemoryRealmService(tenantId); + IdentityTenantUtil.setRealmService(testSessionRealmService); + } + + @AfterMethod + public void tearDown() { + + aiHttpClientUtilMockedStatic.close(); + PrivilegedCarbonContext.endTenantFlow(); + } +} diff --git a/features/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt.server.feature/pom.xml b/features/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt.server.feature/pom.xml new file mode 100644 index 000000000000..658532a8018a --- /dev/null +++ b/features/ai-services-mgt/org.wso2.carbon.identity.ai.service.mgt.server.feature/pom.xml @@ -0,0 +1,98 @@ + + + + 4.0.0 + + + org.wso2.carbon.identity.framework + ai-services-mgt-feature + 7.7.93-SNAPSHOT + ../pom.xml + + + org.wso2.carbon.identity.ai.service.mgt.server.feature + pom + AI Service Management Feature + http://wso2.org + This feature contains the core bundles required for AI Service Management Framework + + + + org.wso2.carbon.identity.framework + org.wso2.carbon.identity.ai.service.mgt + + + + + + + org.wso2.maven + carbon-p2-plugin + ${carbon.p2.plugin.version} + + + 4-p2-feature-generation + package + + p2-feature-gen + + + org.wso2.carbon.identity.ai.service.mgt.server + ../../etc/feature.properties + + + org.wso2.carbon.p2.category.type:server + + + + + org.wso2.carbon.identity.framework:org.wso2.carbon.identity.ai.service.mgt + + + + + + + + org.apache.maven.plugins + maven-antrun-plugin + 1.1 + + + clean_target + install + + + + + + + + + run + + + + + + + + diff --git a/features/ai-services-mgt/pom.xml b/features/ai-services-mgt/pom.xml new file mode 100644 index 000000000000..b8cf4a76921e --- /dev/null +++ b/features/ai-services-mgt/pom.xml @@ -0,0 +1,40 @@ + + + + 4.0.0 + + + org.wso2.carbon.identity.framework + identity-framework + 7.7.93-SNAPSHOT + ../../pom.xml + + + ai-services-mgt-feature + pom + WSO2 Identity - AI Service Management Feature + http://wso2.org + + + org.wso2.carbon.identity.ai.service.mgt.server.feature + + + diff --git a/features/identity-core/org.wso2.carbon.identity.core.server.feature/resources/identity.xml.j2 b/features/identity-core/org.wso2.carbon.identity.core.server.feature/resources/identity.xml.j2 index 129905580e2e..b0312b605a0a 100644 --- a/features/identity-core/org.wso2.carbon.identity.core.server.feature/resources/identity.xml.j2 +++ b/features/identity-core/org.wso2.carbon.identity.core.server.feature/resources/identity.xml.j2 @@ -4468,4 +4468,18 @@ {% endif %} + + {{ai_services.token_endpoint}} + {{ai_services.key}} + {{ai_services.token_request_retry_count}} + {{ai_services.token_connection_timeout}} + {{ai_services.token_connection_request_timeout}} + {{ai_services.token_connection_socket_timeout}} + {{ai_services.http_connection_pool_size}} + {{ai_services.http_connection_timeout}} + + {{ai_services.login_flow_ai.endpoint}} + + + diff --git a/features/identity-core/org.wso2.carbon.identity.core.server.feature/resources/resource-access-control-v2.xml.j2 b/features/identity-core/org.wso2.carbon.identity.core.server.feature/resources/resource-access-control-v2.xml.j2 index a50e7751d1c6..220491b45991 100644 --- a/features/identity-core/org.wso2.carbon.identity.core.server.feature/resources/resource-access-control-v2.xml.j2 +++ b/features/identity-core/org.wso2.carbon.identity.core.server.feature/resources/resource-access-control-v2.xml.j2 @@ -804,6 +804,11 @@ internal_application_mgt_update + + + internal_application_mgt_update + + internal_organization_admin diff --git a/pom.xml b/pom.xml index 5dbb1187c77b..3f925919e105 100644 --- a/pom.xml +++ b/pom.xml @@ -73,6 +73,7 @@ components/trusted-app-mgt components/rule-mgt components/action-mgt + components/ai-services-mgt components/certificate-mgt features/extension-mgt components/consent-server-configs-mgt @@ -109,6 +110,7 @@ features/trusted-app-mgt features/rule-mgt features/action-mgt + features/ai-services-mgt features/certificate-mgt @@ -721,6 +723,11 @@ org.wso2.carbon.identity.api.resource.mgt.server.feature ${project.version} + + org.wso2.carbon.identity.framework + org.wso2.carbon.identity.ai.service.mgt.server.feature + ${project.version} + com.google.api-client google-api-client @@ -1718,6 +1725,11 @@ org.wso2.carbon.identity.central.log.mgt ${project.version} + + org.wso2.carbon.identity.framework + org.wso2.carbon.identity.ai.service.mgt + ${project.version} + org.wso2.orbit.org.apache.commons commons-compress @@ -1735,6 +1747,13 @@ ${org.wso2.carbon.multitenancy.version} test + + org.wiremock + wiremock + ${wiremock.version} + test + + org.wso2.orbit.com.google.api-services-playintegrity @@ -2006,8 +2025,9 @@ 1.14.0.wso2v1 [1.4.0,2.0.0) - 4.3.3.wso2v1 + 4.4.14.wso2v1 [4.3.0, 5.0.0) + 2.8.9 [2.3.1,3.0.0) 1.3.9 @@ -2023,7 +2043,7 @@ [1.3.0,2.0.0) 1.47.0.wso2v1 [1.47.0.wso2v1,2.0.0) - 4.3.6.wso2v2 + 4.5.13.wso2v1 [4.3.1.wso2v2,5.0.0) 2.6.0.wso2v1 1.1.10 @@ -2106,6 +2126,7 @@ 3.2.5 5.3.1 0.5.2 + 3.9.1 1.8 1.8