diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverter.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverter.java index 4bfba1ca9f..0921801876 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverter.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/convert/MappingR2dbcConverter.java @@ -22,15 +22,6 @@ import io.r2dbc.spi.ReadableMetadata; import io.r2dbc.spi.Row; import io.r2dbc.spi.RowMetadata; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; -import java.util.Optional; -import java.util.function.BiFunction; - import org.springframework.core.convert.ConversionService; import org.springframework.dao.InvalidDataAccessApiUsageException; import org.springframework.data.convert.CustomConversions; @@ -52,6 +43,16 @@ import org.springframework.util.ClassUtils; import org.springframework.util.CollectionUtils; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; +import java.util.function.BiFunction; + +import static org.springframework.data.r2dbc.convert.RowMetadataUtils.findColumnMetadata; + /** * Converter for R2DBC. * @@ -140,21 +141,24 @@ private static void captureRowValues(Readable row, @Nullable Iterable propertyType = property.getType(); + String column = columnMetadata == null ? identifier : columnMetadata.getName(); if (propertyType.equals(Clob.class)) { - value = row.get(identifier, Clob.class); + value = row.get(column, Clob.class); } else if (propertyType.equals(Blob.class)) { - value = row.get(identifier, Blob.class); + value = row.get(column, Blob.class); } else { - value = row.get(identifier); + value = row.get(column); } document.put(identifier, value); @@ -475,9 +479,9 @@ private boolean potentiallySetId(Row row, RowMetadata metadata, PersistentProper @Nullable private Object extractGeneratedIdentifier(Row row, RowMetadata metadata, String idColumnName) { - - if (RowMetadataUtils.containsColumn(metadata, idColumnName)) { - return row.get(idColumnName); + ReadableMetadata columnMetadata = findColumnMetadata(RowMetadataUtils.getColumnMetadata(metadata), idColumnName); + if (columnMetadata != null) { + return row.get(columnMetadata.getName()); } Iterable columns = RowMetadataUtils.getColumnMetadata(metadata); diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/convert/RowMetadataUtils.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/convert/RowMetadataUtils.java index b4ea7dc1f4..0ab714352a 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/convert/RowMetadataUtils.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/convert/RowMetadataUtils.java @@ -18,15 +18,17 @@ import io.r2dbc.spi.ColumnMetadata; import io.r2dbc.spi.ReadableMetadata; import io.r2dbc.spi.RowMetadata; +import org.springframework.data.util.ParsingUtils; +import org.springframework.lang.Nullable; /** * Utility methods for {@link io.r2dbc.spi.RowMetadata} * * @author Mark Paluch + * @author kfyty725 * @since 1.3.7 */ class RowMetadataUtils { - /** * Check whether the column {@code name} is contained in {@link RowMetadata}. The check happens case-insensitive. * @@ -46,14 +48,32 @@ public static boolean containsColumn(RowMetadata metadata, String name) { * @return {@code true} if the metadata contains the column {@code name}. */ public static boolean containsColumn(Iterable columns, String name) { + return findColumnMetadata(columns, name) != null; + } + /** + * Query matching {@link ColumnMetadata} from name + *

+ * This method will check the column name of property and the name of property. + * Because when use alias in sql, the name of the property maybe equals to alias in sql, and the column name of property + * are not equals to alias in sql. + * + * @param columns the metadata to inspect. + * @param name column name. + * @return the column metadata. + */ + @Nullable + public static ReadableMetadata findColumnMetadata(Iterable columns, String name) { for (ReadableMetadata columnMetadata : columns) { if (name.equalsIgnoreCase(columnMetadata.getName())) { - return true; + return columnMetadata; + } + String columnName = ParsingUtils.reconcatenateCamelCase(columnMetadata.getName(), "_"); + if (name.equalsIgnoreCase(columnName)) { + return columnMetadata; } } - - return false; + return null; } /** @@ -63,7 +83,6 @@ public static boolean containsColumn(Iterable column * @return * @since 1.4.1 */ - @SuppressWarnings("unchecked") public static Iterable getColumnMetadata(RowMetadata metadata) { return metadata.getColumnMetadatas(); } diff --git a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/convert/RowPropertyAccessor.java b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/convert/RowPropertyAccessor.java index b52c5a6f43..fbaac34cdd 100644 --- a/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/convert/RowPropertyAccessor.java +++ b/spring-data-r2dbc/src/main/java/org/springframework/data/r2dbc/convert/RowPropertyAccessor.java @@ -15,9 +15,9 @@ */ package org.springframework.data.r2dbc.convert; +import io.r2dbc.spi.ReadableMetadata; import io.r2dbc.spi.Row; import io.r2dbc.spi.RowMetadata; - import org.springframework.expression.EvaluationContext; import org.springframework.expression.PropertyAccessor; import org.springframework.expression.TypedValue; @@ -54,7 +54,16 @@ public TypedValue read(EvaluationContext context, @Nullable Object target, Strin return TypedValue.NULL; } - Object value = ((Row) target).get(name); + String column = name; + + if (rowMetadata != null) { + ReadableMetadata columnMetadata = RowMetadataUtils.findColumnMetadata(RowMetadataUtils.getColumnMetadata(rowMetadata), name); + if (columnMetadata != null) { + column = columnMetadata.getName(); + } + } + + Object value = ((Row) target).get(column); if (value == null) { return TypedValue.NULL; diff --git a/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/query/alias/StringBasedAliasQueryUnitTests.java b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/query/alias/StringBasedAliasQueryUnitTests.java new file mode 100644 index 0000000000..74f095887d --- /dev/null +++ b/spring-data-r2dbc/src/test/java/org/springframework/data/r2dbc/repository/query/alias/StringBasedAliasQueryUnitTests.java @@ -0,0 +1,114 @@ +package org.springframework.data.r2dbc.repository.query.alias; + +import io.r2dbc.h2.H2ConnectionFactory; +import io.r2dbc.spi.ConnectionFactory; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.data.annotation.Id; +import org.springframework.data.annotation.Transient; +import org.springframework.data.domain.Persistable; +import org.springframework.data.r2dbc.core.R2dbcEntityTemplate; +import org.springframework.data.r2dbc.repository.Query; +import org.springframework.data.r2dbc.repository.R2dbcRepository; +import org.springframework.data.r2dbc.repository.config.EnableR2dbcRepositories; +import org.springframework.data.relational.core.mapping.Table; +import org.springframework.stereotype.Repository; +import reactor.core.publisher.Mono; + +/** + * 描述: Unit test for {@link org.springframework.data.r2dbc.repository.Query} + * + * @author kfyty725 + * @date 2023/12/13 14:55 + * @email kfyty725@hotmail.com + */ +@Configuration +@EnableR2dbcRepositories(considerNestedRepositories = true) +public class StringBasedAliasQueryUnitTests { + private AnnotationConfigApplicationContext context; + private UserRepository userRepository; + + @Before + public void before() { + context = new AnnotationConfigApplicationContext(StringBasedAliasQueryUnitTests.class); + + userRepository = context.getBean(UserRepository.class); + + init(); + } + + public void init() { + R2dbcEntityTemplate r2dbcEntityTemplate = context.getBean(R2dbcEntityTemplate.class); + r2dbcEntityTemplate.getDatabaseClient().sql("create table person (id bigint not null, fans_num int not null,primary key(id));").fetch().one().block(); + } + + @Test + public void aliasQueryTest() { + User insert = this.userRepository.save(new User(1L, 1, true)).block(); + User get = this.userRepository.getById(1L).block(); + Assert.assertEquals(insert.getFansNum(), get.getFansNum()); // equals + } + + @Bean + public ConnectionFactory connectionFactory() { + System.setProperty("h2.caseInsensitiveIdentifiers", "true"); + System.setProperty("h2.databaseToUpper", "false"); + System.setProperty("h2.databaseToLower", "false"); + return H2ConnectionFactory.inMemory("test"); + } + + @Bean + public R2dbcEntityTemplate r2dbcEntityTemplate(ConnectionFactory connectionFactory) { + return new R2dbcEntityTemplate(connectionFactory); + } + + @Repository + private interface UserRepository extends R2dbcRepository { + @Query("select id, fans_num as fansNum from person where id = :id") + Mono getById(Long id); + } + + @Table("person") + static class User implements Persistable { + @Id + private Long id; + private Integer fansNum; + + @Transient + private boolean isNew; + + public User() { + } + + public User(Long id, Integer fansNum, boolean isNew) { + this.id = id; + this.fansNum = fansNum; + this.isNew = isNew; + } + + @Override + public boolean isNew() { + return isNew; + } + + public Long getId() { + return id; + } + + public void setId(Long id) { + this.id = id; + } + + public Integer getFansNum() { + return fansNum; + } + + public void setFansNum(Integer fansNum) { + this.fansNum = fansNum; + } + } +}