diff --git a/src/tech/v3/datatype/struct.clj b/src/tech/v3/datatype/struct.clj index 855be1ea..41b01efa 100644 --- a/src/tech/v3/datatype/struct.clj +++ b/src/tech/v3/datatype/struct.clj @@ -42,11 +42,13 @@ user> *2 [clj-commons.primitive-math :as pmath]) (:import [tech.v3.datatype BinaryBuffer ObjectBuffer BooleanBuffer LongBuffer DoubleBuffer] - [ham_fisted Casts ITypedReduce] + [ham_fisted Casts ITypedReduce ChunkedList] [java.util.concurrent ConcurrentHashMap] [java.util RandomAccess List Map LinkedHashSet Collection LinkedHashMap] - [clojure.lang MapEntry IObj IFn ILookup])) + [clojure.lang MapEntry IObj IFn ILookup + IFn$OOLO IFn$OOLOO IFn$OOLL IFn$OOLD + IFn$OOLLO IFn$OOLDO])) (set! *warn-on-reflection* true) @@ -206,6 +208,7 @@ user> *2 (declare struct->buffer) (declare inplace-new-struct) +(declare inplace-new-array-of-structs) (defn- host-flatten @@ -215,71 +218,128 @@ user> *2 (casting/host-flatten dt))) + (defn- create-accessors [struct-def] (let [accessors (LinkedHashMap.) layout (get struct-def :data-layout)] (reduce (fn [acc layout-entry] (let [dtype (get layout-entry :datatype) - offset (long (get layout-entry :offset))] + offset (long (get layout-entry :offset)) + ^Accessor scalar-acc + (if (struct-datatype? dtype) + (let [sdef (get-struct-def dtype) + offset (long (get layout-entry :offset)) + dsize (long (get sdef :datatype-size))] + (Accessor. (fn [buffer bin-buffer ^long idx] + (->> (dtype-proto/sub-buffer buffer (+ offset (* dsize idx)) dsize) + (inplace-new-struct dtype))) + (fn [buffer bin-buffer ^long idx val] + (dtype-cmc/copy! (struct->buffer val) + (dtype-proto/sub-buffer buffer (+ offset (* dsize idx)) dsize)) + nil))) + (let [host-dtype (host-flatten dtype) + unsigned? (casting/unsigned-integer-type? dtype)] + (if unsigned? + (case host-dtype + :int8 (Accessor. (fn ^long [buffer ^BinaryBuffer reader ^long idx] + (unchecked-long (Byte/toUnsignedInt (.readBinByte reader (+ offset idx))))) + (fn [buffer ^BinaryBuffer writer ^long idx ^long val] + (.writeBinByte writer (+ idx offset) (unchecked-byte (Casts/longCast val))))) + :int16 (Accessor. (fn ^long [buffer ^BinaryBuffer reader ^long idx] + (unchecked-long (Short/toUnsignedInt (.readBinShort reader (+ offset (* idx 2)))))) + (fn [buffer ^BinaryBuffer writer ^long idx ^long val] + (.writeBinShort writer (+ offset (* idx 2)) (unchecked-short (Casts/longCast val))))) + :int32 (Accessor. (fn ^long [buffer ^BinaryBuffer reader ^long idx] + (unchecked-long (Integer/toUnsignedLong (.readBinInt reader (+ offset (* idx 4)))))) + (fn [buffer ^BinaryBuffer writer ^long idx ^long val] + (.writeBinInt writer (+ offset (* idx 4)) (unchecked-int (Casts/longCast val))))) + :int64 (Accessor. (fn ^long [buffer ^BinaryBuffer reader ^long idx] + (.readBinLong reader (+ offset (* idx 8)))) + (fn [buffer ^BinaryBuffer writer ^long idx ^long val] + (.writeBinLong writer (+ offset (* idx 8)) (Casts/longCast val))))) + (case host-dtype + :int8 (Accessor. (fn ^long [buffer ^BinaryBuffer reader ^long idx] + (unchecked-long (.readBinByte reader (+ offset idx)))) + (fn [buffer ^BinaryBuffer writer ^long idx ^long ^long val] + (.writeBinByte writer (+ offset idx) (byte (Casts/longCast val))))) + :int16 (Accessor. (fn ^long [buffer ^BinaryBuffer reader ^long idx] + (unchecked-long (.readBinShort reader (+ offset (* 2 idx))))) + (fn [buffer ^BinaryBuffer writer ^long idx ^long val] + (.writeBinShort writer (+ offset (* 2 idx)) (short (Casts/longCast val))))) + :int32 (Accessor. (fn ^long [buffer ^BinaryBuffer reader ^long idx] + (unchecked-long (.readBinInt reader (+ offset (* 4 idx))))) + (fn [buffer ^BinaryBuffer writer ^long idx ^long val] + (.writeBinInt writer (+ offset (* 4 idx)) (int (Casts/longCast val))))) + :int64 (Accessor. (fn ^long [buffer ^BinaryBuffer reader ^long idx] + (.readBinLong reader (+ offset (* 8 idx)))) + (fn [buffer ^BinaryBuffer writer ^long idx ^long val] + (.writeBinLong writer (+ offset (* 8 idx)) (Casts/longCast val)))) + :float32 (Accessor. (fn ^double [buffer ^BinaryBuffer reader ^long idx] + (double (.readBinFloat reader (+ offset (* 4 idx))))) + (fn [buffer ^BinaryBuffer writer ^long idx ^double val] + (.writeBinFloat writer (+ offset (* 4 idx)) (float (Casts/doubleCast val))))) + :float64 (Accessor. (fn ^double [buffer ^BinaryBuffer reader ^long idx] + (.readBinDouble reader (+ offset (* 8 idx)))) + (fn [buffer ^BinaryBuffer writer ^long idx ^double val] + (.writeBinDouble writer (+ offset (* 8 idx)) (Casts/doubleCast val)))))))) + n-elems (long (get layout-entry :n-elems))] (.put accessors (get layout-entry :name) - (if (struct-datatype? dtype) - (let [sdef (get-struct-def dtype) - offset (long (get layout-entry :offset)) - dsize (long (get sdef :datatype-size))] - (Accessor. (fn [buffer bin-buffer] - (->> (dtype-proto/sub-buffer buffer offset dsize) - (inplace-new-struct dtype))) - (fn [buffer bin-buffer val] - (dtype-cmc/copy! (struct->buffer val) - (dtype-proto/sub-buffer buffer offset dsize)) - nil))) - (let [host-dtype (host-flatten dtype) - unsigned? (casting/unsigned-integer-type? dtype)] - (if unsigned? - (case host-dtype - :int8 (Accessor. (fn [buffer ^BinaryBuffer reader] - (unchecked-short (Byte/toUnsignedInt (.readBinByte reader offset)))) - (fn [buffer ^BinaryBuffer writer val] - (.writeBinByte writer offset (unchecked-byte (Casts/longCast val))))) - :int16 (Accessor. (fn [buffer ^BinaryBuffer reader] - (Short/toUnsignedInt (.readBinShort reader offset))) - (fn [buffer ^BinaryBuffer writer val] - (.writeBinShort writer offset (unchecked-short (Casts/longCast val))))) - :int32 (Accessor. (fn [buffer ^BinaryBuffer reader] - (Integer/toUnsignedLong (.readBinInt reader offset))) - (fn [buffer ^BinaryBuffer writer val] - (.writeBinInt writer offset (unchecked-int (Casts/longCast val))))) - :int64 (Accessor. (fn [buffer ^BinaryBuffer reader] - (.readBinLong reader offset)) - (fn [buffer ^BinaryBuffer writer val] - (.writeBinLong writer offset (Casts/longCast val))))) - (case host-dtype - :int8 (Accessor. (fn [buffer ^BinaryBuffer reader] - (.readBinByte reader offset)) - (fn [buffer ^BinaryBuffer writer val] - (.writeBinByte writer offset (byte (Casts/longCast val))))) - :int16 (Accessor. (fn [buffer ^BinaryBuffer reader] - (.readBinShort reader offset)) - (fn [buffer ^BinaryBuffer writer val] - (.writeBinShort writer offset (short (Casts/longCast val))))) - :int32 (Accessor. (fn [buffer ^BinaryBuffer reader] - (.readBinInt reader offset)) - (fn [buffer ^BinaryBuffer writer val] - (.writeBinInt writer offset (int (Casts/longCast val))))) - :int64 (Accessor. (fn [buffer ^BinaryBuffer reader] - (.readBinLong reader offset)) - (fn [buffer ^BinaryBuffer writer val] - (.writeBinLong writer offset (Casts/longCast val)))) - :float32 (Accessor. (fn [buffer ^BinaryBuffer reader] - (.readBinFloat reader offset)) - (fn [buffer ^BinaryBuffer writer val] - (.writeBinFloat writer offset (float (Casts/doubleCast val))))) - :float64 (Accessor. (fn [buffer ^BinaryBuffer reader] - (.readBinDouble reader offset)) - (fn [buffer ^BinaryBuffer writer val] - (.writeBinDouble writer offset (Casts/doubleCast val))))))))))) + (if (== 1 n-elems) + (let [read-fn (.reader scalar-acc) + write-fn (.writer scalar-acc)] + (Accessor. #(read-fn %1 %2 0) + #(write-fn %1 %2 0 %3))) + (case (casting/simple-operation-space dtype) + :int64 + (let [read-fn ^IFn$OOLL (.reader scalar-acc) + write-fn ^IFn$OOLLO (.writer scalar-acc) + _ (assert (and (instance? IFn$OOLL read-fn) + (instance? IFn$OOLLO write-fn)) + (str "Datatype " dtype " created invalid acccessr"))] + (Accessor. #(reify LongBuffer + (elemwiseDatatype [_] dtype) + (lsize [_] n-elems) + (readLong [_ idx] + (.invokePrim read-fn %1 %2 (ChunkedList/indexCheck 0 n-elems idx))) + (writeLong [_ idx v] + (.invokePrim write-fn %1 %2 (ChunkedList/indexCheck 0 n-elems idx) v))) + #(throw (Exception. "Bulk set of array properties not supported yet - use read to get writable list")))) + :float64 + (let [read-fn ^IFn$OOLD (.reader scalar-acc) + write-fn ^IFn$OOLDO (.writer scalar-acc) + _ (assert (and (instance? IFn$OOLD read-fn) + (instance? IFn$OOLDO write-fn)) + (str "Datatype " dtype " created invalid acccessr"))] + (Accessor. #(reify DoubleBuffer + (elemwiseDatatype [_] dtype) + (lsize [_] n-elems) + (readDouble [_ idx] + (.invokePrim read-fn %1 %2 (ChunkedList/indexCheck 0 n-elems idx))) + (writeDouble [_ idx v] + (.invokePrim write-fn %1 %2 (ChunkedList/indexCheck 0 n-elems idx) v))) + #(throw (Exception. "Bulk set of array properties not supported yet - use read to get writable list")))) + ;;Then return a new array of structs + (if (struct-datatype? dtype) + (let [sdef (get-struct-def dtype) + dsize (long (get sdef :datatype-size))] + (Accessor. (fn [buffer bin-buffer] + (inplace-new-array-of-structs dtype (dtype-proto/sub-buffer buffer offset (* n-elems dsize)))) + #(throw (Exception. "Bulk set of array properties not supported yet - use read to get writable list")))) + (let [read-fn ^IFn$OOLO (.reader scalar-acc) + write-fn ^IFn$OOLOO (.writer scalar-acc) + _ (assert (and (instance? IFn$OOLO read-fn) + (instance? IFn$OOLOO write-fn)) + (str "Datatype " dtype " created invalid acccessr"))] + (Accessor. #(reify ObjectBuffer + (elemwiseDatatype [_] dtype) + (lsize [_] n-elems) + (readObject [_ idx] + (.invokePrim read-fn %1 %2 (ChunkedList/indexCheck 0 n-elems idx))) + (writeObject [_ idx v] + (.invokePrim write-fn %1 %2 (ChunkedList/indexCheck 0 n-elems idx) v))) + #(throw (Exception. "Bulk set of array properties not supported yet - use read to get writable list")))))))))) nil layout) (assoc struct-def :accessors accessors))) diff --git a/test/tech/v3/datatype/struct_test.clj b/test/tech/v3/datatype/struct_test.clj index 1f95ec9b..7399b723 100644 --- a/test/tech/v3/datatype/struct_test.clj +++ b/test/tech/v3/datatype/struct_test.clj @@ -30,3 +30,22 @@ sarray (dt-struct/new-array-of-structs :ptr-types 10)] (is (= (vec (repeat 10 0)) (vec (dt-struct/array-of-structs->column sarray :a)))))) + + +(deftest get-set-array-members + (let [sdef (dt-struct/define-datatype! :array-member [{:name :x :datatype :int32 :n-elems 10}]) + vec-type (dt-struct/define-datatype! :vec3 [{:name :x :datatype :float32} + {:name :y :datatype :float32} + {:name :z :datatype :float32}]) + ttype (dt-struct/define-datatype! :triangle [{:name :pts :datatype :vec3 :n-elems 3}]) + sdata (dt-struct/new-struct :array-member) + tdata (dt-struct/new-struct :triangle) + x (get sdata :x) + _ (do + ;;sdata'x X member is a reified LongBuffer so you can set the value via getting it then dtype/copy! + (dtype/copy! (range 10) x) + (is (= (vec (range 10)) (get sdata :x)))) + ys (dt-struct/array-of-structs->column (:pts tdata) :y) + _ (do + (dtype/copy! (range 1 4) ys) + (is (= 2.0 (get-in tdata [:pts 1 :y]))))]))