diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/aot/JdbcRuntimeHints.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/aot/JdbcRuntimeHints.java index ca9c551cb9..46daafb46d 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/aot/JdbcRuntimeHints.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/aot/JdbcRuntimeHints.java @@ -16,11 +16,13 @@ package org.springframework.data.jdbc.aot; import java.util.Arrays; +import java.util.UUID; import org.springframework.aot.hint.MemberCategory; import org.springframework.aot.hint.RuntimeHints; import org.springframework.aot.hint.RuntimeHintsRegistrar; import org.springframework.aot.hint.TypeReference; +import org.springframework.data.jdbc.core.dialect.JdbcPostgresDialect; import org.springframework.data.jdbc.repository.support.SimpleJdbcRepository; import org.springframework.data.relational.auditing.RelationalAuditingCallback; import org.springframework.data.relational.core.mapping.event.AfterConvertCallback; @@ -54,5 +56,14 @@ public void registerHints(RuntimeHints hints, @Nullable ClassLoader classLoader) TypeReference.of("org.springframework.aop.SpringProxy"), TypeReference.of("org.springframework.aop.framework.Advised"), TypeReference.of("org.springframework.core.DecoratingProxy")); + + hints.reflection().registerType(TypeReference.of("org.postgresql.jdbc.TypeInfoCache"), + MemberCategory.PUBLIC_CLASSES); + + for (Class simpleType : JdbcPostgresDialect.INSTANCE.simpleTypes()) { + hints.reflection().registerType(TypeReference.of(simpleType), MemberCategory.PUBLIC_CLASSES); + } + + hints.reflection().registerType(TypeReference.of(UUID.class.getName()), MemberCategory.PUBLIC_CLASSES); } } diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultJdbcTypeFactory.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultJdbcTypeFactory.java index dd39e9568b..eff492803f 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultJdbcTypeFactory.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/DefaultJdbcTypeFactory.java @@ -18,7 +18,6 @@ import java.sql.Array; import java.sql.SQLType; -import org.springframework.data.jdbc.support.JdbcUtil; import org.springframework.jdbc.core.ConnectionCallback; import org.springframework.jdbc.core.JdbcOperations; import org.springframework.util.Assert; @@ -66,9 +65,9 @@ public Array createArray(Object[] value) { Assert.notNull(value, "Value must not be null"); Class componentType = arrayColumns.getArrayType(value.getClass()); + SQLType jdbcType = arrayColumns.getSqlType(componentType); - SQLType jdbcType = JdbcUtil.targetSqlTypeFor(componentType); - Assert.notNull(jdbcType, () -> String.format("Couldn't determine JDBCType for %s", componentType)); + Assert.notNull(jdbcType, () -> String.format("Couldn't determine SQLType for %s", componentType)); String typeName = arrayColumns.getArrayTypeName(jdbcType); return operations.execute((ConnectionCallback) c -> c.createArrayOf(typeName, value)); diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/JdbcArrayColumns.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/JdbcArrayColumns.java index fa12638e11..146ef51c04 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/JdbcArrayColumns.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/convert/JdbcArrayColumns.java @@ -17,6 +17,7 @@ import java.sql.SQLType; +import org.springframework.data.jdbc.support.JdbcUtil; import org.springframework.data.relational.core.dialect.ArrayColumns; /** @@ -33,6 +34,17 @@ default Class getArrayType(Class userType) { return ArrayColumns.unwrapComponentType(userType); } + /** + * Determine the {@link SQLType} for a given {@link Class array component type}. + * + * @param componentType component type of the array. + * @return the dialect-supported array type. + * @since 3.1.3 + */ + default SQLType getSqlType(Class componentType) { + return JdbcUtil.targetSqlTypeFor(getArrayType(componentType)); + } + /** * The appropriate SQL type as a String which should be used to represent the given {@link SQLType} in an * {@link java.sql.Array}. Defaults to the name of the argument. diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/dialect/JdbcPostgresDialect.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/dialect/JdbcPostgresDialect.java index 138ab5873c..03008725aa 100644 --- a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/dialect/JdbcPostgresDialect.java +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/dialect/JdbcPostgresDialect.java @@ -15,16 +15,29 @@ */ package org.springframework.data.jdbc.core.dialect; +import java.sql.Array; import java.sql.JDBCType; +import java.sql.SQLException; import java.sql.SQLType; +import java.sql.Types; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; +import java.util.UUID; +import org.postgresql.core.Oid; +import org.postgresql.jdbc.TypeInfoCache; import org.springframework.data.jdbc.core.convert.JdbcArrayColumns; import org.springframework.data.relational.core.dialect.PostgresDialect; +import org.springframework.util.ClassUtils; /** * JDBC specific Postgres Dialect. * * @author Jens Schauder + * @author Mark Paluch * @since 2.3 */ public class JdbcPostgresDialect extends PostgresDialect implements JdbcDialect { @@ -40,11 +53,31 @@ public JdbcArrayColumns getArraySupport() { static class JdbcPostgresArrayColumns implements JdbcArrayColumns { + private static final boolean TYPE_INFO_PRESENT = ClassUtils.isPresent("org.postgresql.jdbc.TypeInfoCache", + JdbcPostgresDialect.class.getClassLoader()); + + private static final TypeInfoWrapper TYPE_INFO_WRAPPER; + + static { + TYPE_INFO_WRAPPER = TYPE_INFO_PRESENT ? new TypeInfoCacheWrapper() : new TypeInfoWrapper(); + } + @Override public boolean isSupported() { return true; } + @Override + public SQLType getSqlType(Class componentType) { + + SQLType sqlType = TYPE_INFO_WRAPPER.getArrayTypeMap().get(componentType); + if (sqlType != null) { + return sqlType; + } + + return JdbcArrayColumns.super.getSqlType(componentType); + } + @Override public String getArrayTypeName(SQLType jdbcType) { @@ -58,4 +91,92 @@ public String getArrayTypeName(SQLType jdbcType) { return jdbcType.getName(); } } + + /** + * Wrapper for Postgres types. Defaults to no-op to guard runtimes against absent TypeInfoCache. + * + * @since 3.1.3 + */ + static class TypeInfoWrapper { + + /** + * @return a type map between a Java array component type and its Postgres type. + */ + Map, SQLType> getArrayTypeMap() { + return Collections.emptyMap(); + } + } + + /** + * {@link TypeInfoWrapper} backed by {@link TypeInfoCache}. + * + * @since 3.1.3 + */ + static class TypeInfoCacheWrapper extends TypeInfoWrapper { + + private final Map, SQLType> arrayTypes = new HashMap<>(); + + public TypeInfoCacheWrapper() { + + TypeInfoCache cache = new TypeInfoCache(null, 0); + addWellKnownTypes(cache); + + Iterator it = cache.getPGTypeNamesWithSQLTypes(); + + try { + + while (it.hasNext()) { + + String pgTypeName = it.next(); + int oid = cache.getPGType(pgTypeName); + String javaClassName = cache.getJavaClass(oid); + int arrayOid = cache.getJavaArrayType(pgTypeName); + + if (!ClassUtils.isPresent(javaClassName, getClass().getClassLoader())) { + continue; + } + + Class javaClass = ClassUtils.forName(javaClassName, getClass().getClassLoader()); + + // avoid accidental usage of smaller database types that map to the same Java type or generic-typed SQL + // arrays. + if (javaClass == Array.class || javaClass == String.class || javaClass == Integer.class || oid == Oid.OID + || oid == Oid.MONEY) { + continue; + } + + arrayTypes.put(javaClass, new PGSQLType(pgTypeName, arrayOid)); + } + } catch (SQLException | ClassNotFoundException e) { + throw new IllegalStateException("Cannot create type info mapping", e); + } + } + + private static void addWellKnownTypes(TypeInfoCache cache) { + cache.addCoreType("uuid", Oid.UUID, Types.OTHER, UUID.class.getName(), Oid.UUID_ARRAY); + } + + @Override + Map, SQLType> getArrayTypeMap() { + return arrayTypes; + } + + record PGSQLType(String name, int oid) implements SQLType { + + @Override + public String getName() { + return name; + } + + @Override + public String getVendor() { + return "Postgres"; + } + + @Override + public Integer getVendorTypeNumber() { + return oid; + } + } + } } diff --git a/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/dialect/package-info.java b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/dialect/package-info.java new file mode 100644 index 0000000000..645c30d7c6 --- /dev/null +++ b/spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/dialect/package-info.java @@ -0,0 +1,7 @@ +/** + * JDBC-specific Dialect implementations. + */ +@NonNullApi +package org.springframework.data.jdbc.core.dialect; + +import org.springframework.lang.NonNullApi; diff --git a/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/DefaultJdbcTypeFactoryTest.java b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/DefaultJdbcTypeFactoryTest.java new file mode 100644 index 0000000000..f549a93ab5 --- /dev/null +++ b/spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/convert/DefaultJdbcTypeFactoryTest.java @@ -0,0 +1,64 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.jdbc.core.convert; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +import java.sql.Array; +import java.sql.SQLException; +import java.util.UUID; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.postgresql.core.BaseConnection; +import org.springframework.data.jdbc.core.dialect.JdbcPostgresDialect; +import org.springframework.jdbc.core.ConnectionCallback; +import org.springframework.jdbc.core.JdbcOperations; + +/** + * Unit tests for {@link DefaultJdbcTypeFactory}. + * + * @author Mark Paluch + */ +@ExtendWith(MockitoExtension.class) +class DefaultJdbcTypeFactoryTest { + + @Mock JdbcOperations operations; + @Mock BaseConnection connection; + + @Test // GH-1567 + void shouldProvidePostgresArrayType() throws SQLException { + + DefaultJdbcTypeFactory sut = new DefaultJdbcTypeFactory(operations, JdbcPostgresDialect.INSTANCE.getArraySupport()); + + when(operations.execute(any(ConnectionCallback.class))).thenAnswer(invocation -> { + + ConnectionCallback callback = invocation.getArgument(0, ConnectionCallback.class); + return callback.doInConnection(connection); + }); + + UUID uuids[] = new UUID[] { UUID.randomUUID(), UUID.randomUUID() }; + when(connection.createArrayOf("uuid", uuids)).thenReturn(mock(Array.class)); + Array array = sut.createArray(uuids); + + assertThat(array).isNotNull(); + } + +}