Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[2.1.x] Fix OpenAPI spec generation issues for U10 patch #1781

Merged
merged 9 commits into from
Nov 18, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package io.ballerina.openapi.service.mapper.parameter;

import io.ballerina.compiler.api.SemanticModel;
import io.ballerina.compiler.api.symbols.Symbol;
import io.ballerina.compiler.api.symbols.TypeSymbol;
import io.ballerina.compiler.syntax.tree.DefaultableParameterNode;
import io.ballerina.compiler.syntax.tree.Node;
Expand All @@ -27,6 +28,9 @@

import java.util.List;
import java.util.Objects;
import java.util.Optional;

import static io.ballerina.openapi.service.mapper.utils.MapperCommonUtils.getConstantValues;

/**
* This {@link AbstractParameterMapper} class represents the abstract parameter mapper.
Expand Down Expand Up @@ -55,8 +59,13 @@ public void setParameter() throws ParameterMapperException {
parameterList.forEach(operationInventory::setParameter);
}

static Object getDefaultValue(DefaultableParameterNode parameterNode) {
static Object getDefaultValue(DefaultableParameterNode parameterNode, SemanticModel semanticModel) {
Node defaultValueExpression = parameterNode.expression();
Optional<Symbol> symbol = semanticModel.symbol(defaultValueExpression);
Optional<Object> constantValues = getConstantValues(symbol);
if (constantValues.isPresent()) {
return constantValues.get();
}
if (MapperCommonUtils.isNotSimpleValueLiteralKind(defaultValueExpression.kind())) {
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ public HeaderParameterMapper(ParameterNode parameterNode, Map<String, String> ap
this.description = apiDocs.get(headerParameter.getName().get());
this.treatNilableAsOptional = treatNilableAsOptional;
if (parameterNode instanceof DefaultableParameterNode defaultableHeaderParam) {
this.defaultValue = AbstractParameterMapper.getDefaultValue(defaultableHeaderParam);
this.defaultValue = AbstractParameterMapper.getDefaultValue(defaultableHeaderParam,
additionalData.semanticModel());
}
this.typeMapper = typeMapper;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ public QueryParameterMapper(ParameterNode parameterNode, Map<String, String> api
this.semanticModel = additionalData.semanticModel();
this.typeMapper = typeMapper;
if (parameterNode instanceof DefaultableParameterNode defaultableQueryParam) {
this.defaultValue = AbstractParameterMapper.getDefaultValue(defaultableQueryParam);
this.defaultValue = AbstractParameterMapper.getDefaultValue(defaultableQueryParam,
additionalData.semanticModel());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
*/
package io.ballerina.openapi.service.mapper.type;

import io.ballerina.compiler.api.SemanticModel;
import io.ballerina.compiler.api.symbols.IntersectionTypeSymbol;
import io.ballerina.compiler.api.symbols.RecordFieldSymbol;
import io.ballerina.compiler.api.symbols.RecordTypeSymbol;
import io.ballerina.compiler.api.symbols.Symbol;
import io.ballerina.compiler.api.symbols.TypeDescKind;
import io.ballerina.compiler.api.symbols.TypeReferenceTypeSymbol;
import io.ballerina.compiler.api.symbols.TypeSymbol;
Expand Down Expand Up @@ -48,6 +50,7 @@
import java.util.Optional;
import java.util.Set;

import static io.ballerina.openapi.service.mapper.utils.MapperCommonUtils.getConstantValues;
import static io.ballerina.openapi.service.mapper.utils.MapperCommonUtils.getRecordFieldTypeDescription;
import static io.ballerina.openapi.service.mapper.utils.MapperCommonUtils.getTypeName;

Expand All @@ -58,7 +61,6 @@
* @since 1.9.0
*/
public class RecordTypeMapper extends AbstractTypeMapper {

public RecordTypeMapper(TypeReferenceTypeSymbol typeSymbol, AdditionalData additionalData) {
super(typeSymbol, additionalData);
}
Expand All @@ -71,14 +73,19 @@ public Schema getReferenceSchema(Components components) {

public static Schema getSchema(RecordTypeSymbol typeSymbol, Components components, String recordName,
AdditionalData additionalData) {
Set<String> fieldsOnlyForRequiredList = new HashSet<>();
ObjectSchema schema = new ObjectSchema();
Set<String> requiredFields = new HashSet<>();

Map<String, RecordFieldSymbol> recordFieldMap = new LinkedHashMap<>(typeSymbol.fieldDescriptors());
List<Schema> allOfSchemaList = mapIncludedRecords(typeSymbol, components, recordFieldMap, additionalData);
List<Schema> allOfSchemaList = mapIncludedRecords(typeSymbol, components, recordFieldMap, additionalData,
recordName, fieldsOnlyForRequiredList);

RecordFieldMappingContext mappingContext = new RecordFieldMappingContext(
recordFieldMap, components, requiredFields, recordName, false, additionalData,
fieldsOnlyForRequiredList);

Map<String, Schema> properties = mapRecordFields(recordFieldMap, components, requiredFields,
recordName, false, additionalData);
Map<String, Schema> properties = mapRecordFields(mappingContext);

Optional<TypeSymbol> restFieldType = typeSymbol.restTypeDescriptor();
if (restFieldType.isPresent()) {
Expand All @@ -90,8 +97,8 @@ public static Schema getSchema(RecordTypeSymbol typeSymbol, Components component
schema.additionalProperties(false);
}

schema.setProperties(properties);
schema.setRequired(requiredFields.stream().toList());
schema.setProperties(properties);
if (!allOfSchemaList.isEmpty()) {
ObjectSchema schemaWithAllOf = new ObjectSchema();
allOfSchemaList.add(schema);
Expand All @@ -103,7 +110,8 @@ public static Schema getSchema(RecordTypeSymbol typeSymbol, Components component

static List<Schema> mapIncludedRecords(RecordTypeSymbol typeSymbol, Components components,
Map<String, RecordFieldSymbol> recordFieldMap,
AdditionalData additionalData) {
AdditionalData additionalData, String recordName,
Set<String> fieldsOnlyForRequiredList) {
List<Schema> allOfSchemaList = new ArrayList<>();
List<TypeSymbol> typeInclusions = typeSymbol.typeInclusions();
for (TypeSymbol typeInclusion : typeInclusions) {
Expand All @@ -120,25 +128,156 @@ static List<Schema> mapIncludedRecords(RecordTypeSymbol typeSymbol, Components c
.typeDescriptor();
Map<String, RecordFieldSymbol> includedRecordFieldMap = includedRecordTypeSymbol.fieldDescriptors();
for (Map.Entry<String, RecordFieldSymbol> includedRecordField : includedRecordFieldMap.entrySet()) {
recordFieldMap.remove(includedRecordField.getKey());
if (!recordFieldMap.containsKey(includedRecordField.getKey())) {
continue;
}
RecordFieldSymbol recordFieldSymbol = recordFieldMap.get(includedRecordField.getKey());
RecordFieldSymbol includedRecordFieldValue = includedRecordField.getValue();

if (!includedRecordFieldValue.typeDescriptor().equals(recordFieldSymbol.typeDescriptor())) {
continue;
}
IncludedFieldContext context = new IncludedFieldContext(recordFieldMap, recordName,
typeInclusion, includedRecordField, recordFieldSymbol, includedRecordFieldValue
);
eliminateRedundantFields(context, additionalData, fieldsOnlyForRequiredList);
}
}
}
return allOfSchemaList;
}

public static Map<String, Schema> mapRecordFields(Map<String, RecordFieldSymbol> recordFieldMap,
Components components, Set<String> requiredFields,
String recordName, boolean treatNilableAsOptional,
AdditionalData additionalData) {
private static void eliminateRedundantFields(IncludedFieldContext context, AdditionalData additionalData,
Set<String> fieldsOnlyForRequiredList) {
Map<String, RecordFieldSymbol> recordFieldMap = context.recordFieldMap();
String recordName = context.recordName();
TypeSymbol typeInclusion = context.typeInclusion();
Map.Entry<String, RecordFieldSymbol> includedRecordField = context.includedRecordField();
RecordFieldSymbol recordFieldSymbol = context.recordFieldSymbol();
RecordFieldSymbol includedRecordFieldValue = context.includedRecordFieldValue();

boolean recordHasDefault = recordFieldSymbol.hasDefaultValue();
boolean includedHasDefault = includedRecordFieldValue.hasDefaultValue();
boolean hasTypeInclusionName = typeInclusion.getName().isPresent();
boolean isIncludedOptional = includedRecordFieldValue.isOptional();
boolean isRecordFieldOptional = recordFieldSymbol.isOptional();
boolean recordFieldName = recordFieldSymbol.getName().isPresent();

if (recordHasDefault && includedHasDefault && hasTypeInclusionName) {
Optional<Object> recordFieldDefaultValueOpt = getRecordFieldDefaultValue(recordName,
includedRecordField.getKey(), additionalData.moduleMemberVisitor(),
additionalData.semanticModel());

Optional<Object> includedFieldDefaultValueOpt = getRecordFieldDefaultValue(
typeInclusion.getName().get(), includedRecordField.getKey(),
additionalData.moduleMemberVisitor(), additionalData.semanticModel());

/*
This check the scenarios
ex:
type RecA record {|
string a = "a";
string aa;
|};
type RecD record {|
*RecA;
string a = "aad";
int d;
|};
*/
boolean defaultsAreEqual = recordFieldDefaultValueOpt.isPresent()
&& includedFieldDefaultValueOpt.isPresent()
&& recordFieldDefaultValueOpt.get().toString()
.equals(includedFieldDefaultValueOpt.get().toString());

/*
This checks the scenario where RecA has `a` defaultable field. In this case, the
.hasDefaultValue() API returns true for both records, but RecA provides the value of the default.
ex:
type RecA record {|
string a = "a";
string aa;
|};
type RecB record {|
*RecA;
int b;
|};
*/
boolean onlyIncludedHasDefault = recordFieldDefaultValueOpt.isEmpty() &&
includedFieldDefaultValueOpt.isPresent();

if (defaultsAreEqual || onlyIncludedHasDefault) {
recordFieldMap.remove(includedRecordField.getKey());
}
} else if (!recordHasDefault && !includedHasDefault) {
recordFieldMap.remove(includedRecordField.getKey());
}
if (!isRecordFieldOptional && isIncludedOptional && !recordHasDefault && recordFieldName) {
fieldsOnlyForRequiredList.add(MapperCommonUtils.unescapeIdentifier(recordFieldSymbol.getName().get()));
recordFieldMap.remove(includedRecordField.getKey());
}
}

/**
* Encapsulates the context of included fields in a record for processing.
*
* @param recordFieldMap A map containing record field symbols.
* @param recordName The name of the record being processed.
* @param typeInclusion The type symbol representing type inclusions in the record.
* @param includedRecordField An entry representing the included record field and its symbol.
* @param recordFieldSymbol The symbol of the current record field being processed.
* @param includedRecordFieldValue The symbol of the field in the included record.
*/
public record IncludedFieldContext(
Map<String, RecordFieldSymbol> recordFieldMap,
String recordName,
TypeSymbol typeInclusion,
Map.Entry<String, RecordFieldSymbol> includedRecordField,
RecordFieldSymbol recordFieldSymbol,
RecordFieldSymbol includedRecordFieldValue) {
}

/**
* Encapsulates the context needed for mapping record fields to schemas.
*
* @param recordFieldMap A map containing record field symbols.
* @param components Components used for managing and storing schemas during mapping.
* @param requiredFields A set of field names that are required in the mapped schema.
* @param recordName The name of the record being processed.
* @param treatNilableAsOptional Flag indicating whether nilable fields should be treated as optional.
* @param additionalData Additional data required for schema generation and field processing.
* @param fieldsOnlyForRequiredList A set of fields that should be exclusively marked as required.
*/
public record RecordFieldMappingContext(
Map<String, RecordFieldSymbol> recordFieldMap,
Components components,
Set<String> requiredFields,
String recordName,
boolean treatNilableAsOptional,
AdditionalData additionalData,
Set<String> fieldsOnlyForRequiredList) {
}

public static Map<String, Schema> mapRecordFields(RecordFieldMappingContext context) {
Map<String, RecordFieldSymbol> recordFieldMap = context.recordFieldMap();
Components components = context.components();
Set<String> requiredFields = context.requiredFields();
String recordName = context.recordName();
boolean treatNilableAsOptional = context.treatNilableAsOptional();
AdditionalData additionalData = context.additionalData();
Set<String> fieldsOnlyForRequiredList = context.fieldsOnlyForRequiredList();
Map<String, Schema> properties = new LinkedHashMap<>();

for (Map.Entry<String, RecordFieldSymbol> recordField : recordFieldMap.entrySet()) {
RecordFieldSymbol recordFieldSymbol = recordField.getValue();
String recordFieldName = MapperCommonUtils.unescapeIdentifier(recordField.getKey().trim());
if (!recordFieldSymbol.isOptional() && !recordFieldSymbol.hasDefaultValue() &&
(!treatNilableAsOptional || !UnionTypeMapper.hasNilableType(recordFieldSymbol.typeDescriptor()))) {
requiredFields.add(recordFieldName);
}
if (!fieldsOnlyForRequiredList.isEmpty()) {
requiredFields.addAll(fieldsOnlyForRequiredList);
}
String recordFieldDescription = getRecordFieldTypeDescription(recordFieldSymbol);
Schema recordFieldSchema = TypeMapperImpl.getTypeSchema(recordFieldSymbol.typeDescriptor(),
components, additionalData);
Expand All @@ -147,7 +286,7 @@ public static Map<String, Schema> mapRecordFields(Map<String, RecordFieldSymbol>
}
if (recordFieldSymbol.hasDefaultValue()) {
Optional<Object> recordFieldDefaultValueOpt = getRecordFieldDefaultValue(recordName, recordFieldName,
additionalData.moduleMemberVisitor());
additionalData.moduleMemberVisitor(), additionalData.semanticModel());
if (recordFieldDefaultValueOpt.isPresent()) {
TypeMapper.setDefaultValue(recordFieldSchema, recordFieldDefaultValueOpt.get());
} else {
Expand All @@ -162,17 +301,19 @@ public static Map<String, Schema> mapRecordFields(Map<String, RecordFieldSymbol>
}

public static Optional<Object> getRecordFieldDefaultValue(String recordName, String fieldName,
ModuleMemberVisitor moduleMemberVisitor) {
ModuleMemberVisitor moduleMemberVisitor,
SemanticModel semanticModel) {
Optional<TypeDefinitionNode> recordDefNodeOpt = moduleMemberVisitor.getTypeDefinitionNode(recordName);
if (recordDefNodeOpt.isPresent() &&
recordDefNodeOpt.get().typeDescriptor() instanceof RecordTypeDescriptorNode recordDefNode) {
return getRecordFieldDefaultValue(fieldName, recordDefNode);
return getRecordFieldDefaultValue(fieldName, recordDefNode, semanticModel);
}
return Optional.empty();
}

private static Optional<Object> getRecordFieldDefaultValue(String fieldName,
RecordTypeDescriptorNode recordDefNode) {
RecordTypeDescriptorNode recordDefNode,
SemanticModel semanticModel) {
NodeList<Node> recordFields = recordDefNode.fields();
RecordFieldWithDefaultValueNode defaultValueNode = recordFields.stream()
.filter(field -> field instanceof RecordFieldWithDefaultValueNode)
Expand All @@ -183,6 +324,11 @@ private static Optional<Object> getRecordFieldDefaultValue(String fieldName,
return Optional.empty();
}
ExpressionNode defaultValueExpression = defaultValueNode.expression();
Optional<Symbol> symbol = semanticModel.symbol(defaultValueExpression);
Optional<Object> value = getConstantValues(symbol);
if (value.isPresent()) {
return value;
}
if (MapperCommonUtils.isNotSimpleValueLiteralKind(defaultValueExpression.kind())) {
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import io.swagger.v3.oas.models.Components;
import io.swagger.v3.oas.models.media.Schema;

import java.util.HashSet;
import java.util.Map;
import java.util.Set;

Expand Down Expand Up @@ -113,8 +114,10 @@ protected static void createComponentMapping(TypeReferenceTypeSymbol typeSymbol,
public Map<String, Schema> getSchemaForRecordFields(Map<String, RecordFieldSymbol> recordFieldMap,
Set<String> requiredFields, String recordName,
boolean treatNilableAsOptional) {
return RecordTypeMapper.mapRecordFields(recordFieldMap, components, requiredFields, recordName,
treatNilableAsOptional, componentMapperData);
RecordTypeMapper.RecordFieldMappingContext context = new RecordTypeMapper.RecordFieldMappingContext(
recordFieldMap, components, requiredFields, recordName, treatNilableAsOptional, componentMapperData,
new HashSet<>());
return RecordTypeMapper.mapRecordFields(context);
}

public TypeSymbol getReferredType(TypeSymbol typeSymbol) {
Expand Down
Loading
Loading