diff --git a/core/src/main/scala/net/nmoncho/helenus/api/cql/ScalaPreparedStatement.scala b/core/src/main/scala/net/nmoncho/helenus/api/cql/ScalaPreparedStatement.scala index a8df8d4..7772868 100644 --- a/core/src/main/scala/net/nmoncho/helenus/api/cql/ScalaPreparedStatement.scala +++ b/core/src/main/scala/net/nmoncho/helenus/api/cql/ScalaPreparedStatement.scala @@ -28,12 +28,10 @@ import scala.concurrent.ExecutionContext import scala.concurrent.Future import com.datastax.oss.driver.api.core.CqlSession -import com.datastax.oss.driver.api.core.`type`.UserDefinedType import com.datastax.oss.driver.api.core.`type`.codec.TypeCodec import com.datastax.oss.driver.api.core.cql._ import net.nmoncho.helenus.api.RowMapper import net.nmoncho.helenus.api.cql.ScalaPreparedStatement.ScalaBoundStatement -import net.nmoncho.helenus.internal.codec.udt.UDTCodec import net.nmoncho.helenus.internal.cql.AdaptedScalaPreparedStatement import org.slf4j.LoggerFactory @@ -110,18 +108,7 @@ abstract class ScalaPreparedStatement[In, Out](pstmt: PreparedStatement, mapper: } actualParams.iterator().asScala.zip(codecs.iterator).zipWithIndex.foreach { case ((param, codec), idx) => - val check = param.getType == codec.getCqlType - - val areEquals = (param.getType, codec.getCqlType) match { - // Give this type another chance of being equals by only considering types - case (paramType: UserDefinedType, codecType: UserDefinedType) if !check && codec.isInstanceOf[UDTCodec[_]] => - paramType.getFieldTypes == codecType.getFieldTypes - - case _ => - check - } - - if (!areEquals) { + if (!codec.accepts(param)) { log.warn("Invalid PreparedStatement expected parameter with type {} at index {} but got type {}", param.getType.toString, idx.toString, codec.getCqlType.toString) } } diff --git a/core/src/main/scala/net/nmoncho/helenus/internal/codec/udt/IdenticalUDTCodec.scala b/core/src/main/scala/net/nmoncho/helenus/internal/codec/udt/IdenticalUDTCodec.scala index c6e6256..1272bbf 100644 --- a/core/src/main/scala/net/nmoncho/helenus/internal/codec/udt/IdenticalUDTCodec.scala +++ b/core/src/main/scala/net/nmoncho/helenus/internal/codec/udt/IdenticalUDTCodec.scala @@ -31,6 +31,7 @@ import scala.reflect.ClassTag import com.datastax.oss.driver.api.core.CqlIdentifier import com.datastax.oss.driver.api.core.ProtocolVersion import com.datastax.oss.driver.api.core.`type`.DataType +import com.datastax.oss.driver.api.core.`type`.UserDefinedType import com.datastax.oss.driver.api.core.`type`.codec.TypeCodec import com.datastax.oss.driver.api.core.`type`.reflect.GenericType import com.datastax.oss.driver.internal.core.`type`.DefaultUserDefinedType @@ -149,6 +150,13 @@ object IdenticalUDTCodec { ) } + override def accepts(cqlType: DataType): Boolean = cqlType match { + case udt: UserDefinedType => + udt.getFieldTypes == getCqlType.asInstanceOf[UserDefinedType].getFieldTypes + + case _ => false + } + override def encode(value: A, protocolVersion: ProtocolVersion): ByteBuffer = if (value == null) null else {