Skip to content

Commit

Permalink
Fix review suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
lnash94 committed Nov 13, 2024
1 parent 6d672f4 commit acff157
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
package io.ballerina.openapi.service.mapper.parameter;

import io.ballerina.compiler.api.SemanticModel;
import io.ballerina.compiler.api.symbols.ConstantSymbol;
import io.ballerina.compiler.api.symbols.Symbol;
import io.ballerina.compiler.api.symbols.TypeSymbol;
import io.ballerina.compiler.api.values.ConstantValue;
import io.ballerina.compiler.syntax.tree.DefaultableParameterNode;
import io.ballerina.compiler.syntax.tree.ExpressionNode;
import io.ballerina.compiler.syntax.tree.Node;
import io.ballerina.openapi.service.mapper.model.OperationInventory;
import io.ballerina.openapi.service.mapper.utils.MapperCommonUtils;
import io.swagger.v3.oas.models.parameters.Parameter;
Expand All @@ -32,6 +31,8 @@
import java.util.Objects;
import java.util.Optional;

import static io.ballerina.openapi.service.mapper.utils.MapperCommonUtils.getConstantValues;

/**
* This {@link AbstractParameterMapper} class represents the abstract parameter mapper.
*
Expand Down Expand Up @@ -60,13 +61,11 @@ public void setParameter() throws ParameterMapperException {
}

static Object getDefaultValue(DefaultableParameterNode parameterNode, SemanticModel semanticModel) {
ExpressionNode defaultValueExpression = (ExpressionNode) parameterNode.expression();
Node defaultValueExpression = parameterNode.expression();
Optional<Symbol> symbol = semanticModel.symbol(defaultValueExpression);
if (symbol.isPresent() && symbol.get() instanceof ConstantSymbol constantSymbol) {
Object constValue = constantSymbol.constValue();
if (constValue instanceof ConstantValue value) {
return value.value();
}
Optional<Object> constantValues = getConstantValues(symbol);
if (constantValues.isPresent()) {
return constantValues.get();
}
if (MapperCommonUtils.isNotSimpleValueLiteralKind(defaultValueExpression.kind())) {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
package io.ballerina.openapi.service.mapper.type;

import io.ballerina.compiler.api.SemanticModel;
import io.ballerina.compiler.api.symbols.ConstantSymbol;
import io.ballerina.compiler.api.symbols.IntersectionTypeSymbol;
import io.ballerina.compiler.api.symbols.RecordFieldSymbol;
import io.ballerina.compiler.api.symbols.RecordTypeSymbol;
Expand All @@ -27,7 +26,6 @@
import io.ballerina.compiler.api.symbols.TypeReferenceTypeSymbol;
import io.ballerina.compiler.api.symbols.TypeSymbol;
import io.ballerina.compiler.api.symbols.UnionTypeSymbol;
import io.ballerina.compiler.api.values.ConstantValue;
import io.ballerina.compiler.syntax.tree.ExpressionNode;
import io.ballerina.compiler.syntax.tree.Node;
import io.ballerina.compiler.syntax.tree.NodeList;
Expand All @@ -52,6 +50,7 @@
import java.util.Optional;
import java.util.Set;

import static io.ballerina.openapi.service.mapper.utils.MapperCommonUtils.getConstantValues;
import static io.ballerina.openapi.service.mapper.utils.MapperCommonUtils.getRecordFieldTypeDescription;
import static io.ballerina.openapi.service.mapper.utils.MapperCommonUtils.getTypeName;

Expand Down Expand Up @@ -79,7 +78,8 @@ public static Schema getSchema(RecordTypeSymbol typeSymbol, Components component
Set<String> requiredFields = new HashSet<>();

Map<String, RecordFieldSymbol> recordFieldMap = new LinkedHashMap<>(typeSymbol.fieldDescriptors());
List<Schema> allOfSchemaList = mapIncludedRecords(typeSymbol, components, recordFieldMap, additionalData);
List<Schema> allOfSchemaList = mapIncludedRecords(typeSymbol, components, recordFieldMap, additionalData,
recordName);

Map<String, Schema> properties = mapRecordFields(recordFieldMap, components, requiredFields,
recordName, false, additionalData);
Expand Down Expand Up @@ -107,7 +107,7 @@ public static Schema getSchema(RecordTypeSymbol typeSymbol, Components component

static List<Schema> mapIncludedRecords(RecordTypeSymbol typeSymbol, Components components,
Map<String, RecordFieldSymbol> recordFieldMap,
AdditionalData additionalData) {
AdditionalData additionalData, String recordName) {
List<Schema> allOfSchemaList = new ArrayList<>();
List<TypeSymbol> typeInclusions = typeSymbol.typeInclusions();
for (TypeSymbol typeInclusion : typeInclusions) {
Expand All @@ -126,9 +126,63 @@ static List<Schema> mapIncludedRecords(RecordTypeSymbol typeSymbol, Components c
for (Map.Entry<String, RecordFieldSymbol> includedRecordField : includedRecordFieldMap.entrySet()) {
RecordFieldSymbol recordFieldSymbol = recordFieldMap.get(includedRecordField.getKey());
RecordFieldSymbol includedRecordFieldValue = includedRecordField.getValue();
boolean isRemovableField = recordFieldSymbol != null && includedRecordFieldValue.typeDescriptor()
.equals(recordFieldSymbol.typeDescriptor()) && !recordFieldSymbol.hasDefaultValue();
if (isRemovableField) {
if (recordFieldSymbol == null) {
continue;
}
if (!includedRecordFieldValue.typeDescriptor().equals(recordFieldSymbol.typeDescriptor())) {
continue;
}
boolean recordHasDefault = recordFieldSymbol.hasDefaultValue();
boolean includedHasDefault = includedRecordFieldValue.hasDefaultValue();
boolean hasTypeInclusionName = typeInclusion.getName().isPresent();
if (recordHasDefault && includedHasDefault && hasTypeInclusionName) {
Optional<Object> recordFieldDefaultValueOpt = getRecordFieldDefaultValue(recordName,
includedRecordField.getKey(), additionalData.moduleMemberVisitor(),
additionalData.semanticModel());

Optional<Object> includedFieldDefaultValueOpt = getRecordFieldDefaultValue(
typeInclusion.getName().get(), includedRecordField.getKey(),
additionalData.moduleMemberVisitor(), additionalData.semanticModel());

/*
This check the scenarios
ex:
type RecA record {|
string a = "a";
string aa;
|};
type RecD record {|
*RecA;
string a = "aad";
int d;
|};
*/
boolean defaultsAreEqual = recordFieldDefaultValueOpt.isPresent()
&& includedFieldDefaultValueOpt.isPresent()
&& recordFieldDefaultValueOpt.get().toString()
.equals(includedFieldDefaultValueOpt.get().toString());

/*
This check the scenarios: RecA has deflatable field, in here both records
`.hasDefaultValue()` api return `true` but RecA gives the value of the default value
ex:
type RecA record {|
string a = "a";
string aa;
|};
type RecB record {|
*RecA;
int b;
|};
*/
boolean onlyIncludedHasDefault = recordFieldDefaultValueOpt.isEmpty() &&
includedFieldDefaultValueOpt.isPresent();

if (defaultsAreEqual || onlyIncludedHasDefault) {
recordFieldMap.remove(includedRecordField.getKey());
}

} else if (!recordHasDefault && !includedHasDefault) {
recordFieldMap.remove(includedRecordField.getKey());
}
}
Expand Down Expand Up @@ -172,7 +226,7 @@ public static Map<String, Schema> mapRecordFields(Map<String, RecordFieldSymbol>
}

public static Optional<Object> getRecordFieldDefaultValue(String recordName, String fieldName,
ModuleMemberVisitor moduleMemberVisitor,
ModuleMemberVisitor moduleMemberVisitor,
SemanticModel semanticModel) {
Optional<TypeDefinitionNode> recordDefNodeOpt = moduleMemberVisitor.getTypeDefinitionNode(recordName);
if (recordDefNodeOpt.isPresent() &&
Expand All @@ -196,11 +250,9 @@ private static Optional<Object> getRecordFieldDefaultValue(String fieldName,
}
ExpressionNode defaultValueExpression = defaultValueNode.expression();
Optional<Symbol> symbol = semanticModel.symbol(defaultValueExpression);
if (symbol.isPresent() && symbol.get() instanceof ConstantSymbol constantSymbol) {
Object constValue = constantSymbol.constValue();
if (constValue instanceof ConstantValue value) {
return Optional.of(value.value());
}
Optional<Object> value = getConstantValues(symbol);
if (value.isPresent()) {
return value;
}
if (MapperCommonUtils.isNotSimpleValueLiteralKind(defaultValueExpression.kind())) {
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import io.ballerina.compiler.api.symbols.TypeReferenceTypeSymbol;
import io.ballerina.compiler.api.symbols.TypeSymbol;
import io.ballerina.compiler.api.symbols.UnionTypeSymbol;
import io.ballerina.compiler.api.values.ConstantValue;
import io.ballerina.compiler.syntax.tree.AnnotationNode;
import io.ballerina.compiler.syntax.tree.BasicLiteralNode;
import io.ballerina.compiler.syntax.tree.DefaultableParameterNode;
Expand Down Expand Up @@ -555,6 +556,16 @@ public static Optional<ResourceFunction> getResourceFunction(Node function) {
return Optional.empty();
}

public static Optional<Object> getConstantValues(Optional<Symbol> symbol) {
if (symbol.isPresent() && symbol.get() instanceof ConstantSymbol constantSymbol) {
Object constValue = constantSymbol.constValue();
if (constValue instanceof ConstantValue value) {
return Optional.of(value.value());
}
}
return Optional.empty();
}

public static Node getTypeDescriptor(TypeDefinitionNode typeDefinitionNode) {
Node node = typeDefinitionNode.typeDescriptor();
if (node instanceof DistinctTypeDescriptorNode distinctTypeDescriptorNode) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ servers:
port:
default: "7080"
paths:
/Pods:
/pods:
get:
operationId: getPods
responses:
Expand All @@ -22,7 +22,7 @@ paths:
type: array
items:
$ref: "#/components/schemas/Pod"
/Services:
/services:
get:
operationId: getServices
responses:
Expand All @@ -34,6 +34,42 @@ paths:
type: array
items:
$ref: "#/components/schemas/Service"
/recB:
get:
operationId: getRecb
responses:
"200":
description: Ok
content:
application/json:
schema:
type: array
items:
$ref: "#/components/schemas/RecB"
/recC:
get:
operationId: getRecc
responses:
"200":
description: Ok
content:
application/json:
schema:
type: array
items:
$ref: "#/components/schemas/RecC"
/recD:
post:
operationId: postRecd
responses:
"201":
description: Created
content:
application/json:
schema:
type: array
items:
$ref: "#/components/schemas/RecD"
components:
schemas:
Metadata:
Expand Down Expand Up @@ -68,6 +104,59 @@ components:
properties:
nodeName:
type: string
RecA:
required:
- aa
type: object
properties:
a:
type: string
default: a
aa:
type: string
additionalProperties: false
RecB:
type: object
allOf:
- $ref: "#/components/schemas/RecA"
- required:
- b
type: object
properties:
b:
type: integer
format: int64
additionalProperties: false
RecC:
type: object
allOf:
- $ref: "#/components/schemas/RecA"
- required:
- c
type: object
properties:
aa:
type: string
default: aa
c:
type: integer
format: int64
additionalProperties: false
RecD:
type: object
allOf:
- $ref: "#/components/schemas/RecA"
- required:
- d
type: object
properties:
a:
type: string
default: aad
d:
type: integer
format: int64
additionalProperties: false
Resource:
type: object
allOf:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,46 @@ public type Pod record {
Status status?;
};

type RecA record {|
string a = "a";
string aa;
|};

type RecB record {|
*RecA;
int b;
|};

type RecC record {|
*RecA;
string aa = "aa";
int c;
|};

type RecD record {|
*RecA;
string a = "aad";
int d;
|};

service /payloadV on new http:Listener(7080) {
resource function get Pods() returns Pod[] {
resource function get pods() returns Pod[] {
return [];
}

resource function get services() returns Service[] {
return [];
}

resource function get recB() returns RecB[] {
return [];
}

resource function get recC() returns RecC[] {
return [];
}

resource function get Services() returns Service[] {
resource function post recD() returns RecD[] {
return [];
}
}

0 comments on commit acff157

Please sign in to comment.