Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[2.1.x] Fix OpenAPI spec generation issues for U10 patch #1781

Merged
merged 9 commits into from
Nov 18, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@
package io.ballerina.openapi.service.mapper.parameter;

import io.ballerina.compiler.api.SemanticModel;
import io.ballerina.compiler.api.symbols.ConstantSymbol;
import io.ballerina.compiler.api.symbols.Symbol;
import io.ballerina.compiler.api.symbols.TypeSymbol;
import io.ballerina.compiler.api.values.ConstantValue;
import io.ballerina.compiler.syntax.tree.DefaultableParameterNode;
import io.ballerina.compiler.syntax.tree.ExpressionNode;
import io.ballerina.compiler.syntax.tree.Node;
import io.ballerina.openapi.service.mapper.model.OperationInventory;
import io.ballerina.openapi.service.mapper.utils.MapperCommonUtils;
import io.swagger.v3.oas.models.parameters.Parameter;
Expand All @@ -32,6 +30,8 @@
import java.util.Objects;
import java.util.Optional;

import static io.ballerina.openapi.service.mapper.utils.MapperCommonUtils.getConstantValues;

/**
* This {@link AbstractParameterMapper} class represents the abstract parameter mapper.
*
Expand Down Expand Up @@ -60,13 +60,11 @@ public void setParameter() throws ParameterMapperException {
}

static Object getDefaultValue(DefaultableParameterNode parameterNode, SemanticModel semanticModel) {
ExpressionNode defaultValueExpression = (ExpressionNode) parameterNode.expression();
Node defaultValueExpression = parameterNode.expression();
Optional<Symbol> symbol = semanticModel.symbol(defaultValueExpression);
if (symbol.isPresent() && symbol.get() instanceof ConstantSymbol constantSymbol) {
Object constValue = constantSymbol.constValue();
if (constValue instanceof ConstantValue value) {
return value.value();
}
Optional<Object> constantValues = getConstantValues(symbol);
if (constantValues.isPresent()) {
return constantValues.get();
}
if (MapperCommonUtils.isNotSimpleValueLiteralKind(defaultValueExpression.kind())) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package io.ballerina.openapi.service.mapper.type;

import io.ballerina.compiler.api.SemanticModel;
import io.ballerina.compiler.api.symbols.ConstantSymbol;
import io.ballerina.compiler.api.symbols.IntersectionTypeSymbol;
import io.ballerina.compiler.api.symbols.RecordFieldSymbol;
import io.ballerina.compiler.api.symbols.RecordTypeSymbol;
Expand All @@ -27,7 +26,6 @@
import io.ballerina.compiler.api.symbols.TypeReferenceTypeSymbol;
import io.ballerina.compiler.api.symbols.TypeSymbol;
import io.ballerina.compiler.api.symbols.UnionTypeSymbol;
import io.ballerina.compiler.api.values.ConstantValue;
import io.ballerina.compiler.syntax.tree.ExpressionNode;
import io.ballerina.compiler.syntax.tree.Node;
import io.ballerina.compiler.syntax.tree.NodeList;
Expand All @@ -52,6 +50,7 @@
import java.util.Optional;
import java.util.Set;

import static io.ballerina.openapi.service.mapper.utils.MapperCommonUtils.getConstantValues;
import static io.ballerina.openapi.service.mapper.utils.MapperCommonUtils.getRecordFieldTypeDescription;
import static io.ballerina.openapi.service.mapper.utils.MapperCommonUtils.getTypeName;

Expand Down Expand Up @@ -79,7 +78,8 @@ public static Schema getSchema(RecordTypeSymbol typeSymbol, Components component
Set<String> requiredFields = new HashSet<>();

Map<String, RecordFieldSymbol> recordFieldMap = new LinkedHashMap<>(typeSymbol.fieldDescriptors());
List<Schema> allOfSchemaList = mapIncludedRecords(typeSymbol, components, recordFieldMap, additionalData);
List<Schema> allOfSchemaList = mapIncludedRecords(typeSymbol, components, recordFieldMap, additionalData,
recordName);

Map<String, Schema> properties = mapRecordFields(recordFieldMap, components, requiredFields,
recordName, false, additionalData);
Expand Down Expand Up @@ -107,7 +107,7 @@ public static Schema getSchema(RecordTypeSymbol typeSymbol, Components component

static List<Schema> mapIncludedRecords(RecordTypeSymbol typeSymbol, Components components,
Map<String, RecordFieldSymbol> recordFieldMap,
AdditionalData additionalData) {
AdditionalData additionalData, String recordName) {
List<Schema> allOfSchemaList = new ArrayList<>();
List<TypeSymbol> typeInclusions = typeSymbol.typeInclusions();
for (TypeSymbol typeInclusion : typeInclusions) {
Expand All @@ -126,17 +126,81 @@ static List<Schema> mapIncludedRecords(RecordTypeSymbol typeSymbol, Components c
for (Map.Entry<String, RecordFieldSymbol> includedRecordField : includedRecordFieldMap.entrySet()) {
RecordFieldSymbol recordFieldSymbol = recordFieldMap.get(includedRecordField.getKey());
RecordFieldSymbol includedRecordFieldValue = includedRecordField.getValue();
boolean isRemovableField = recordFieldSymbol != null && includedRecordFieldValue.typeDescriptor()
.equals(recordFieldSymbol.typeDescriptor()) && !recordFieldSymbol.hasDefaultValue();
if (isRemovableField) {
recordFieldMap.remove(includedRecordField.getKey());

if (recordFieldSymbol == null
lnash94 marked this conversation as resolved.
Show resolved Hide resolved
|| !includedRecordFieldValue.typeDescriptor().equals(recordFieldSymbol.typeDescriptor())) {
continue;
}
eliminateRedundantFields(recordFieldMap, additionalData, recordName, typeInclusion,
includedRecordField, recordFieldSymbol, includedRecordFieldValue);
}
}
}
return allOfSchemaList;
}

private static void eliminateRedundantFields(Map<String, RecordFieldSymbol> recordFieldMap,
AdditionalData additionalData, String recordName,
TypeSymbol typeInclusion,
Map.Entry<String, RecordFieldSymbol> includedRecordField,
RecordFieldSymbol recordFieldSymbol,
RecordFieldSymbol includedRecordFieldValue) {

boolean recordHasDefault = recordFieldSymbol.hasDefaultValue();
boolean includedHasDefault = includedRecordFieldValue.hasDefaultValue();
boolean hasTypeInclusionName = typeInclusion.getName().isPresent();

if (recordHasDefault && includedHasDefault && hasTypeInclusionName) {
Optional<Object> recordFieldDefaultValueOpt = getRecordFieldDefaultValue(recordName,
includedRecordField.getKey(), additionalData.moduleMemberVisitor(),
additionalData.semanticModel());

Optional<Object> includedFieldDefaultValueOpt = getRecordFieldDefaultValue(
typeInclusion.getName().get(), includedRecordField.getKey(),
additionalData.moduleMemberVisitor(), additionalData.semanticModel());

/*
This check the scenarios
ex:
type RecA record {|
string a = "a";
string aa;
|};
type RecD record {|
*RecA;
string a = "aad";
int d;
|};
*/
boolean defaultsAreEqual = recordFieldDefaultValueOpt.isPresent()
&& includedFieldDefaultValueOpt.isPresent()
&& recordFieldDefaultValueOpt.get().toString()
.equals(includedFieldDefaultValueOpt.get().toString());

/*
This checks the scenario where RecA has `a` defaultable field. In this case, the
.hasDefaultValue() API returns true for both records, but RecA provides the value of the default.
ex:
type RecA record {|
string a = "a";
string aa;
|};
type RecB record {|
*RecA;
int b;
|};
*/
boolean onlyIncludedHasDefault = recordFieldDefaultValueOpt.isEmpty() &&
includedFieldDefaultValueOpt.isPresent();

if (defaultsAreEqual || onlyIncludedHasDefault) {
recordFieldMap.remove(includedRecordField.getKey());
}
} else if (!recordHasDefault && !includedHasDefault) {
recordFieldMap.remove(includedRecordField.getKey());
}
}

public static Map<String, Schema> mapRecordFields(Map<String, RecordFieldSymbol> recordFieldMap,
Components components, Set<String> requiredFields,
String recordName, boolean treatNilableAsOptional,
Expand Down Expand Up @@ -172,7 +236,7 @@ public static Map<String, Schema> mapRecordFields(Map<String, RecordFieldSymbol>
}

public static Optional<Object> getRecordFieldDefaultValue(String recordName, String fieldName,
ModuleMemberVisitor moduleMemberVisitor,
ModuleMemberVisitor moduleMemberVisitor,
SemanticModel semanticModel) {
Optional<TypeDefinitionNode> recordDefNodeOpt = moduleMemberVisitor.getTypeDefinitionNode(recordName);
if (recordDefNodeOpt.isPresent() &&
Expand All @@ -196,11 +260,9 @@ private static Optional<Object> getRecordFieldDefaultValue(String fieldName,
}
ExpressionNode defaultValueExpression = defaultValueNode.expression();
Optional<Symbol> symbol = semanticModel.symbol(defaultValueExpression);
if (symbol.isPresent() && symbol.get() instanceof ConstantSymbol constantSymbol) {
Object constValue = constantSymbol.constValue();
if (constValue instanceof ConstantValue value) {
return Optional.of(value.value());
}
Optional<Object> value = getConstantValues(symbol);
if (value.isPresent()) {
return value;
}
if (MapperCommonUtils.isNotSimpleValueLiteralKind(defaultValueExpression.kind())) {
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import io.ballerina.compiler.api.symbols.TypeReferenceTypeSymbol;
import io.ballerina.compiler.api.symbols.TypeSymbol;
import io.ballerina.compiler.api.symbols.UnionTypeSymbol;
import io.ballerina.compiler.api.values.ConstantValue;
import io.ballerina.compiler.syntax.tree.AnnotationNode;
import io.ballerina.compiler.syntax.tree.BasicLiteralNode;
import io.ballerina.compiler.syntax.tree.DefaultableParameterNode;
Expand Down Expand Up @@ -555,6 +556,16 @@ public static Optional<ResourceFunction> getResourceFunction(Node function) {
return Optional.empty();
}

public static Optional<Object> getConstantValues(Optional<Symbol> symbol) {
if (symbol.isPresent() && symbol.get() instanceof ConstantSymbol constantSymbol) {
Object constValue = constantSymbol.constValue();
if (constValue instanceof ConstantValue value) {
return Optional.of(value.value());
}
}
return Optional.empty();
}

public static Node getTypeDescriptor(TypeDefinitionNode typeDefinitionNode) {
Node node = typeDefinitionNode.typeDescriptor();
if (node instanceof DistinctTypeDescriptorNode distinctTypeDescriptorNode) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ servers:
port:
default: "7080"
paths:
/Pods:
/pods:
get:
operationId: getPods
responses:
Expand All @@ -22,7 +22,7 @@ paths:
type: array
items:
$ref: "#/components/schemas/Pod"
/Services:
/services:
get:
operationId: getServices
responses:
Expand All @@ -34,6 +34,42 @@ paths:
type: array
items:
$ref: "#/components/schemas/Service"
/recB:
get:
operationId: getRecb
responses:
"200":
description: Ok
content:
application/json:
schema:
type: array
items:
$ref: "#/components/schemas/RecB"
/recC:
get:
operationId: getRecc
responses:
"200":
description: Ok
content:
application/json:
schema:
type: array
items:
$ref: "#/components/schemas/RecC"
/recD:
post:
operationId: postRecd
responses:
"201":
description: Created
content:
application/json:
schema:
type: array
items:
$ref: "#/components/schemas/RecD"
components:
schemas:
Metadata:
Expand Down Expand Up @@ -68,6 +104,59 @@ components:
properties:
nodeName:
type: string
RecA:
required:
- aa
type: object
properties:
a:
type: string
default: a
aa:
type: string
additionalProperties: false
RecB:
type: object
allOf:
- $ref: "#/components/schemas/RecA"
- required:
- b
type: object
properties:
b:
type: integer
format: int64
additionalProperties: false
RecC:
type: object
allOf:
- $ref: "#/components/schemas/RecA"
- required:
- c
type: object
properties:
aa:
type: string
default: aa
c:
type: integer
format: int64
additionalProperties: false
RecD:
type: object
allOf:
- $ref: "#/components/schemas/RecA"
- required:
- d
type: object
properties:
a:
type: string
default: aad
d:
type: integer
format: int64
additionalProperties: false
Resource:
type: object
allOf:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,46 @@ public type Pod record {
Status status?;
};

type RecA record {|
string a = "a";
string aa;
|};

type RecB record {|
*RecA;
int b;
|};

type RecC record {|
*RecA;
string aa = "aa";
int c;
|};

type RecD record {|
*RecA;
string a = "aad";
int d;
|};

service /payloadV on new http:Listener(7080) {
resource function get Pods() returns Pod[] {
resource function get pods() returns Pod[] {
return [];
}

resource function get services() returns Service[] {
return [];
}

resource function get recB() returns RecB[] {
return [];
}

resource function get recC() returns RecC[] {
return [];
}

resource function get Services() returns Service[] {
resource function post recD() returns RecD[] {
return [];
}
}
Loading