From acff157d1983bbc1f6e378fe13a274c449db2f3b Mon Sep 17 00:00:00 2001 From: lnash94 Date: Thu, 14 Nov 2024 00:58:26 +0530 Subject: [PATCH] Fix review suggestions --- .../parameter/AbstractParameterMapper.java | 15 ++- .../service/mapper/type/RecordTypeMapper.java | 78 +++++++++++++--- .../mapper/utils/MapperCommonUtils.java | 11 +++ .../expected_gen/record/included_record.yaml | 93 ++++++++++++++++++- .../record/included_record.bal | 38 +++++++- 5 files changed, 210 insertions(+), 25 deletions(-) diff --git a/ballerina-to-openapi/src/main/java/io/ballerina/openapi/service/mapper/parameter/AbstractParameterMapper.java b/ballerina-to-openapi/src/main/java/io/ballerina/openapi/service/mapper/parameter/AbstractParameterMapper.java index cb3f690e9..d1c5dac54 100644 --- a/ballerina-to-openapi/src/main/java/io/ballerina/openapi/service/mapper/parameter/AbstractParameterMapper.java +++ b/ballerina-to-openapi/src/main/java/io/ballerina/openapi/service/mapper/parameter/AbstractParameterMapper.java @@ -18,12 +18,11 @@ package io.ballerina.openapi.service.mapper.parameter; import io.ballerina.compiler.api.SemanticModel; -import io.ballerina.compiler.api.symbols.ConstantSymbol; import io.ballerina.compiler.api.symbols.Symbol; import io.ballerina.compiler.api.symbols.TypeSymbol; -import io.ballerina.compiler.api.values.ConstantValue; import io.ballerina.compiler.syntax.tree.DefaultableParameterNode; import io.ballerina.compiler.syntax.tree.ExpressionNode; +import io.ballerina.compiler.syntax.tree.Node; import io.ballerina.openapi.service.mapper.model.OperationInventory; import io.ballerina.openapi.service.mapper.utils.MapperCommonUtils; import io.swagger.v3.oas.models.parameters.Parameter; @@ -32,6 +31,8 @@ import java.util.Objects; import java.util.Optional; +import static io.ballerina.openapi.service.mapper.utils.MapperCommonUtils.getConstantValues; + /** * This {@link AbstractParameterMapper} class represents the abstract parameter mapper. * @@ -60,13 +61,11 @@ public void setParameter() throws ParameterMapperException { } static Object getDefaultValue(DefaultableParameterNode parameterNode, SemanticModel semanticModel) { - ExpressionNode defaultValueExpression = (ExpressionNode) parameterNode.expression(); + Node defaultValueExpression = parameterNode.expression(); Optional symbol = semanticModel.symbol(defaultValueExpression); - if (symbol.isPresent() && symbol.get() instanceof ConstantSymbol constantSymbol) { - Object constValue = constantSymbol.constValue(); - if (constValue instanceof ConstantValue value) { - return value.value(); - } + Optional constantValues = getConstantValues(symbol); + if (constantValues.isPresent()) { + return constantValues.get(); } if (MapperCommonUtils.isNotSimpleValueLiteralKind(defaultValueExpression.kind())) { return null; diff --git a/ballerina-to-openapi/src/main/java/io/ballerina/openapi/service/mapper/type/RecordTypeMapper.java b/ballerina-to-openapi/src/main/java/io/ballerina/openapi/service/mapper/type/RecordTypeMapper.java index 852e4c95f..c75fb76e6 100644 --- a/ballerina-to-openapi/src/main/java/io/ballerina/openapi/service/mapper/type/RecordTypeMapper.java +++ b/ballerina-to-openapi/src/main/java/io/ballerina/openapi/service/mapper/type/RecordTypeMapper.java @@ -18,7 +18,6 @@ package io.ballerina.openapi.service.mapper.type; import io.ballerina.compiler.api.SemanticModel; -import io.ballerina.compiler.api.symbols.ConstantSymbol; import io.ballerina.compiler.api.symbols.IntersectionTypeSymbol; import io.ballerina.compiler.api.symbols.RecordFieldSymbol; import io.ballerina.compiler.api.symbols.RecordTypeSymbol; @@ -27,7 +26,6 @@ import io.ballerina.compiler.api.symbols.TypeReferenceTypeSymbol; import io.ballerina.compiler.api.symbols.TypeSymbol; import io.ballerina.compiler.api.symbols.UnionTypeSymbol; -import io.ballerina.compiler.api.values.ConstantValue; import io.ballerina.compiler.syntax.tree.ExpressionNode; import io.ballerina.compiler.syntax.tree.Node; import io.ballerina.compiler.syntax.tree.NodeList; @@ -52,6 +50,7 @@ import java.util.Optional; import java.util.Set; +import static io.ballerina.openapi.service.mapper.utils.MapperCommonUtils.getConstantValues; import static io.ballerina.openapi.service.mapper.utils.MapperCommonUtils.getRecordFieldTypeDescription; import static io.ballerina.openapi.service.mapper.utils.MapperCommonUtils.getTypeName; @@ -79,7 +78,8 @@ public static Schema getSchema(RecordTypeSymbol typeSymbol, Components component Set requiredFields = new HashSet<>(); Map recordFieldMap = new LinkedHashMap<>(typeSymbol.fieldDescriptors()); - List allOfSchemaList = mapIncludedRecords(typeSymbol, components, recordFieldMap, additionalData); + List allOfSchemaList = mapIncludedRecords(typeSymbol, components, recordFieldMap, additionalData, + recordName); Map properties = mapRecordFields(recordFieldMap, components, requiredFields, recordName, false, additionalData); @@ -107,7 +107,7 @@ public static Schema getSchema(RecordTypeSymbol typeSymbol, Components component static List mapIncludedRecords(RecordTypeSymbol typeSymbol, Components components, Map recordFieldMap, - AdditionalData additionalData) { + AdditionalData additionalData, String recordName) { List allOfSchemaList = new ArrayList<>(); List typeInclusions = typeSymbol.typeInclusions(); for (TypeSymbol typeInclusion : typeInclusions) { @@ -126,9 +126,63 @@ static List mapIncludedRecords(RecordTypeSymbol typeSymbol, Components c for (Map.Entry includedRecordField : includedRecordFieldMap.entrySet()) { RecordFieldSymbol recordFieldSymbol = recordFieldMap.get(includedRecordField.getKey()); RecordFieldSymbol includedRecordFieldValue = includedRecordField.getValue(); - boolean isRemovableField = recordFieldSymbol != null && includedRecordFieldValue.typeDescriptor() - .equals(recordFieldSymbol.typeDescriptor()) && !recordFieldSymbol.hasDefaultValue(); - if (isRemovableField) { + if (recordFieldSymbol == null) { + continue; + } + if (!includedRecordFieldValue.typeDescriptor().equals(recordFieldSymbol.typeDescriptor())) { + continue; + } + boolean recordHasDefault = recordFieldSymbol.hasDefaultValue(); + boolean includedHasDefault = includedRecordFieldValue.hasDefaultValue(); + boolean hasTypeInclusionName = typeInclusion.getName().isPresent(); + if (recordHasDefault && includedHasDefault && hasTypeInclusionName) { + Optional recordFieldDefaultValueOpt = getRecordFieldDefaultValue(recordName, + includedRecordField.getKey(), additionalData.moduleMemberVisitor(), + additionalData.semanticModel()); + + Optional includedFieldDefaultValueOpt = getRecordFieldDefaultValue( + typeInclusion.getName().get(), includedRecordField.getKey(), + additionalData.moduleMemberVisitor(), additionalData.semanticModel()); + + /* + This check the scenarios + ex: + type RecA record {| + string a = "a"; + string aa; + |}; + type RecD record {| + *RecA; + string a = "aad"; + int d; + |}; + */ + boolean defaultsAreEqual = recordFieldDefaultValueOpt.isPresent() + && includedFieldDefaultValueOpt.isPresent() + && recordFieldDefaultValueOpt.get().toString() + .equals(includedFieldDefaultValueOpt.get().toString()); + + /* + This check the scenarios: RecA has deflatable field, in here both records + `.hasDefaultValue()` api return `true` but RecA gives the value of the default value + ex: + type RecA record {| + string a = "a"; + string aa; + |}; + type RecB record {| + *RecA; + int b; + |}; + */ + boolean onlyIncludedHasDefault = recordFieldDefaultValueOpt.isEmpty() && + includedFieldDefaultValueOpt.isPresent(); + + if (defaultsAreEqual || onlyIncludedHasDefault) { + recordFieldMap.remove(includedRecordField.getKey()); + } + + } else if (!recordHasDefault && !includedHasDefault) { recordFieldMap.remove(includedRecordField.getKey()); } } @@ -172,7 +226,7 @@ public static Map mapRecordFields(Map } public static Optional getRecordFieldDefaultValue(String recordName, String fieldName, - ModuleMemberVisitor moduleMemberVisitor, + ModuleMemberVisitor moduleMemberVisitor, SemanticModel semanticModel) { Optional recordDefNodeOpt = moduleMemberVisitor.getTypeDefinitionNode(recordName); if (recordDefNodeOpt.isPresent() && @@ -196,11 +250,9 @@ private static Optional getRecordFieldDefaultValue(String fieldName, } ExpressionNode defaultValueExpression = defaultValueNode.expression(); Optional symbol = semanticModel.symbol(defaultValueExpression); - if (symbol.isPresent() && symbol.get() instanceof ConstantSymbol constantSymbol) { - Object constValue = constantSymbol.constValue(); - if (constValue instanceof ConstantValue value) { - return Optional.of(value.value()); - } + Optional value = getConstantValues(symbol); + if (value.isPresent()) { + return value; } if (MapperCommonUtils.isNotSimpleValueLiteralKind(defaultValueExpression.kind())) { return Optional.empty(); diff --git a/ballerina-to-openapi/src/main/java/io/ballerina/openapi/service/mapper/utils/MapperCommonUtils.java b/ballerina-to-openapi/src/main/java/io/ballerina/openapi/service/mapper/utils/MapperCommonUtils.java index 808f48354..aa33462e1 100644 --- a/ballerina-to-openapi/src/main/java/io/ballerina/openapi/service/mapper/utils/MapperCommonUtils.java +++ b/ballerina-to-openapi/src/main/java/io/ballerina/openapi/service/mapper/utils/MapperCommonUtils.java @@ -34,6 +34,7 @@ import io.ballerina.compiler.api.symbols.TypeReferenceTypeSymbol; import io.ballerina.compiler.api.symbols.TypeSymbol; import io.ballerina.compiler.api.symbols.UnionTypeSymbol; +import io.ballerina.compiler.api.values.ConstantValue; import io.ballerina.compiler.syntax.tree.AnnotationNode; import io.ballerina.compiler.syntax.tree.BasicLiteralNode; import io.ballerina.compiler.syntax.tree.DefaultableParameterNode; @@ -555,6 +556,16 @@ public static Optional getResourceFunction(Node function) { return Optional.empty(); } + public static Optional getConstantValues(Optional symbol) { + if (symbol.isPresent() && symbol.get() instanceof ConstantSymbol constantSymbol) { + Object constValue = constantSymbol.constValue(); + if (constValue instanceof ConstantValue value) { + return Optional.of(value.value()); + } + } + return Optional.empty(); + } + public static Node getTypeDescriptor(TypeDefinitionNode typeDefinitionNode) { Node node = typeDefinitionNode.typeDescriptor(); if (node instanceof DistinctTypeDescriptorNode distinctTypeDescriptorNode) { diff --git a/openapi-cli/src/test/resources/ballerina-to-openapi/expected_gen/record/included_record.yaml b/openapi-cli/src/test/resources/ballerina-to-openapi/expected_gen/record/included_record.yaml index d7805374f..efd97b07c 100644 --- a/openapi-cli/src/test/resources/ballerina-to-openapi/expected_gen/record/included_record.yaml +++ b/openapi-cli/src/test/resources/ballerina-to-openapi/expected_gen/record/included_record.yaml @@ -10,7 +10,7 @@ servers: port: default: "7080" paths: - /Pods: + /pods: get: operationId: getPods responses: @@ -22,7 +22,7 @@ paths: type: array items: $ref: "#/components/schemas/Pod" - /Services: + /services: get: operationId: getServices responses: @@ -34,6 +34,42 @@ paths: type: array items: $ref: "#/components/schemas/Service" + /recB: + get: + operationId: getRecb + responses: + "200": + description: Ok + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/RecB" + /recC: + get: + operationId: getRecc + responses: + "200": + description: Ok + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/RecC" + /recD: + post: + operationId: postRecd + responses: + "201": + description: Created + content: + application/json: + schema: + type: array + items: + $ref: "#/components/schemas/RecD" components: schemas: Metadata: @@ -68,6 +104,59 @@ components: properties: nodeName: type: string + RecA: + required: + - aa + type: object + properties: + a: + type: string + default: a + aa: + type: string + additionalProperties: false + RecB: + type: object + allOf: + - $ref: "#/components/schemas/RecA" + - required: + - b + type: object + properties: + b: + type: integer + format: int64 + additionalProperties: false + RecC: + type: object + allOf: + - $ref: "#/components/schemas/RecA" + - required: + - c + type: object + properties: + aa: + type: string + default: aa + c: + type: integer + format: int64 + additionalProperties: false + RecD: + type: object + allOf: + - $ref: "#/components/schemas/RecA" + - required: + - d + type: object + properties: + a: + type: string + default: aad + d: + type: integer + format: int64 + additionalProperties: false Resource: type: object allOf: diff --git a/openapi-cli/src/test/resources/ballerina-to-openapi/record/included_record.bal b/openapi-cli/src/test/resources/ballerina-to-openapi/record/included_record.bal index 22729fb7c..159508de1 100644 --- a/openapi-cli/src/test/resources/ballerina-to-openapi/record/included_record.bal +++ b/openapi-cli/src/test/resources/ballerina-to-openapi/record/included_record.bal @@ -50,12 +50,46 @@ public type Pod record { Status status?; }; +type RecA record {| + string a = "a"; + string aa; +|}; + +type RecB record {| + *RecA; + int b; +|}; + +type RecC record {| + *RecA; + string aa = "aa"; + int c; +|}; + +type RecD record {| + *RecA; + string a = "aad"; + int d; +|}; + service /payloadV on new http:Listener(7080) { - resource function get Pods() returns Pod[] { + resource function get pods() returns Pod[] { + return []; + } + + resource function get services() returns Service[] { + return []; + } + + resource function get recB() returns RecB[] { + return []; + } + + resource function get recC() returns RecC[] { return []; } - resource function get Services() returns Service[] { + resource function post recD() returns RecD[] { return []; } }