Skip to content

Commit

Permalink
Fixes #108
Browse files Browse the repository at this point in the history
  • Loading branch information
cnuernber committed Sep 22, 2024
1 parent 9899297 commit 9a47a50
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 59 deletions.
178 changes: 119 additions & 59 deletions src/tech/v3/datatype/struct.clj
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@ user> *2
[clj-commons.primitive-math :as pmath])
(:import [tech.v3.datatype BinaryBuffer ObjectBuffer BooleanBuffer
LongBuffer DoubleBuffer]
[ham_fisted Casts ITypedReduce]
[ham_fisted Casts ITypedReduce ChunkedList]
[java.util.concurrent ConcurrentHashMap]
[java.util RandomAccess List Map LinkedHashSet Collection
LinkedHashMap]
[clojure.lang MapEntry IObj IFn ILookup]))
[clojure.lang MapEntry IObj IFn ILookup
IFn$OOLO IFn$OOLOO IFn$OOLL IFn$OOLD
IFn$OOLLO IFn$OOLDO]))


(set! *warn-on-reflection* true)
Expand Down Expand Up @@ -206,6 +208,7 @@ user> *2

(declare struct->buffer)
(declare inplace-new-struct)
(declare inplace-new-array-of-structs)


(defn- host-flatten
Expand All @@ -215,71 +218,128 @@ user> *2
(casting/host-flatten dt)))



(defn- create-accessors
[struct-def]
(let [accessors (LinkedHashMap.)
layout (get struct-def :data-layout)]
(reduce (fn [acc layout-entry]
(let [dtype (get layout-entry :datatype)
offset (long (get layout-entry :offset))]
offset (long (get layout-entry :offset))
^Accessor scalar-acc
(if (struct-datatype? dtype)
(let [sdef (get-struct-def dtype)
offset (long (get layout-entry :offset))
dsize (long (get sdef :datatype-size))]
(Accessor. (fn [buffer bin-buffer ^long idx]
(->> (dtype-proto/sub-buffer buffer (+ offset (* dsize idx)) dsize)
(inplace-new-struct dtype)))
(fn [buffer bin-buffer ^long idx val]
(dtype-cmc/copy! (struct->buffer val)
(dtype-proto/sub-buffer buffer (+ offset (* dsize idx)) dsize))
nil)))
(let [host-dtype (host-flatten dtype)
unsigned? (casting/unsigned-integer-type? dtype)]
(if unsigned?
(case host-dtype
:int8 (Accessor. (fn ^long [buffer ^BinaryBuffer reader ^long idx]
(unchecked-long (Byte/toUnsignedInt (.readBinByte reader (+ offset idx)))))
(fn [buffer ^BinaryBuffer writer ^long idx ^long val]
(.writeBinByte writer (+ idx offset) (unchecked-byte (Casts/longCast val)))))
:int16 (Accessor. (fn ^long [buffer ^BinaryBuffer reader ^long idx]
(unchecked-long (Short/toUnsignedInt (.readBinShort reader (+ offset (* idx 2))))))
(fn [buffer ^BinaryBuffer writer ^long idx ^long val]
(.writeBinShort writer (+ offset (* idx 2)) (unchecked-short (Casts/longCast val)))))
:int32 (Accessor. (fn ^long [buffer ^BinaryBuffer reader ^long idx]
(unchecked-long (Integer/toUnsignedLong (.readBinInt reader (+ offset (* idx 4))))))
(fn [buffer ^BinaryBuffer writer ^long idx ^long val]
(.writeBinInt writer (+ offset (* idx 4)) (unchecked-int (Casts/longCast val)))))
:int64 (Accessor. (fn ^long [buffer ^BinaryBuffer reader ^long idx]
(.readBinLong reader (+ offset (* idx 8))))
(fn [buffer ^BinaryBuffer writer ^long idx ^long val]
(.writeBinLong writer (+ offset (* idx 8)) (Casts/longCast val)))))
(case host-dtype
:int8 (Accessor. (fn ^long [buffer ^BinaryBuffer reader ^long idx]
(unchecked-long (.readBinByte reader (+ offset idx))))
(fn [buffer ^BinaryBuffer writer ^long idx ^long ^long val]
(.writeBinByte writer (+ offset idx) (byte (Casts/longCast val)))))
:int16 (Accessor. (fn ^long [buffer ^BinaryBuffer reader ^long idx]
(unchecked-long (.readBinShort reader (+ offset (* 2 idx)))))
(fn [buffer ^BinaryBuffer writer ^long idx ^long val]
(.writeBinShort writer (+ offset (* 2 idx)) (short (Casts/longCast val)))))
:int32 (Accessor. (fn ^long [buffer ^BinaryBuffer reader ^long idx]
(unchecked-long (.readBinInt reader (+ offset (* 4 idx)))))
(fn [buffer ^BinaryBuffer writer ^long idx ^long val]
(.writeBinInt writer (+ offset (* 4 idx)) (int (Casts/longCast val)))))
:int64 (Accessor. (fn ^long [buffer ^BinaryBuffer reader ^long idx]
(.readBinLong reader (+ offset (* 8 idx))))
(fn [buffer ^BinaryBuffer writer ^long idx ^long val]
(.writeBinLong writer (+ offset (* 8 idx)) (Casts/longCast val))))
:float32 (Accessor. (fn ^double [buffer ^BinaryBuffer reader ^long idx]
(double (.readBinFloat reader (+ offset (* 4 idx)))))
(fn [buffer ^BinaryBuffer writer ^long idx ^double val]
(.writeBinFloat writer (+ offset (* 4 idx)) (float (Casts/doubleCast val)))))
:float64 (Accessor. (fn ^double [buffer ^BinaryBuffer reader ^long idx]
(.readBinDouble reader (+ offset (* 8 idx))))
(fn [buffer ^BinaryBuffer writer ^long idx ^double val]
(.writeBinDouble writer (+ offset (* 8 idx)) (Casts/doubleCast val))))))))
n-elems (long (get layout-entry :n-elems))]
(.put accessors
(get layout-entry :name)
(if (struct-datatype? dtype)
(let [sdef (get-struct-def dtype)
offset (long (get layout-entry :offset))
dsize (long (get sdef :datatype-size))]
(Accessor. (fn [buffer bin-buffer]
(->> (dtype-proto/sub-buffer buffer offset dsize)
(inplace-new-struct dtype)))
(fn [buffer bin-buffer val]
(dtype-cmc/copy! (struct->buffer val)
(dtype-proto/sub-buffer buffer offset dsize))
nil)))
(let [host-dtype (host-flatten dtype)
unsigned? (casting/unsigned-integer-type? dtype)]
(if unsigned?
(case host-dtype
:int8 (Accessor. (fn [buffer ^BinaryBuffer reader]
(unchecked-short (Byte/toUnsignedInt (.readBinByte reader offset))))
(fn [buffer ^BinaryBuffer writer val]
(.writeBinByte writer offset (unchecked-byte (Casts/longCast val)))))
:int16 (Accessor. (fn [buffer ^BinaryBuffer reader]
(Short/toUnsignedInt (.readBinShort reader offset)))
(fn [buffer ^BinaryBuffer writer val]
(.writeBinShort writer offset (unchecked-short (Casts/longCast val)))))
:int32 (Accessor. (fn [buffer ^BinaryBuffer reader]
(Integer/toUnsignedLong (.readBinInt reader offset)))
(fn [buffer ^BinaryBuffer writer val]
(.writeBinInt writer offset (unchecked-int (Casts/longCast val)))))
:int64 (Accessor. (fn [buffer ^BinaryBuffer reader]
(.readBinLong reader offset))
(fn [buffer ^BinaryBuffer writer val]
(.writeBinLong writer offset (Casts/longCast val)))))
(case host-dtype
:int8 (Accessor. (fn [buffer ^BinaryBuffer reader]
(.readBinByte reader offset))
(fn [buffer ^BinaryBuffer writer val]
(.writeBinByte writer offset (byte (Casts/longCast val)))))
:int16 (Accessor. (fn [buffer ^BinaryBuffer reader]
(.readBinShort reader offset))
(fn [buffer ^BinaryBuffer writer val]
(.writeBinShort writer offset (short (Casts/longCast val)))))
:int32 (Accessor. (fn [buffer ^BinaryBuffer reader]
(.readBinInt reader offset))
(fn [buffer ^BinaryBuffer writer val]
(.writeBinInt writer offset (int (Casts/longCast val)))))
:int64 (Accessor. (fn [buffer ^BinaryBuffer reader]
(.readBinLong reader offset))
(fn [buffer ^BinaryBuffer writer val]
(.writeBinLong writer offset (Casts/longCast val))))
:float32 (Accessor. (fn [buffer ^BinaryBuffer reader]
(.readBinFloat reader offset))
(fn [buffer ^BinaryBuffer writer val]
(.writeBinFloat writer offset (float (Casts/doubleCast val)))))
:float64 (Accessor. (fn [buffer ^BinaryBuffer reader]
(.readBinDouble reader offset))
(fn [buffer ^BinaryBuffer writer val]
(.writeBinDouble writer offset (Casts/doubleCast val)))))))))))
(if (== 1 n-elems)
(let [read-fn (.reader scalar-acc)
write-fn (.writer scalar-acc)]
(Accessor. #(read-fn %1 %2 0)
#(write-fn %1 %2 0 %3)))
(case (casting/simple-operation-space dtype)
:int64
(let [read-fn ^IFn$OOLL (.reader scalar-acc)
write-fn ^IFn$OOLLO (.writer scalar-acc)
_ (assert (and (instance? IFn$OOLL read-fn)
(instance? IFn$OOLLO write-fn))
(str "Datatype " dtype " created invalid acccessr"))]
(Accessor. #(reify LongBuffer
(elemwiseDatatype [_] dtype)
(lsize [_] n-elems)
(readLong [_ idx]
(.invokePrim read-fn %1 %2 (ChunkedList/indexCheck 0 n-elems idx)))
(writeLong [_ idx v]
(.invokePrim write-fn %1 %2 (ChunkedList/indexCheck 0 n-elems idx) v)))
#(throw (Exception. "Bulk set of array properties not supported yet - use read to get writable list"))))
:float64
(let [read-fn ^IFn$OOLD (.reader scalar-acc)
write-fn ^IFn$OOLDO (.writer scalar-acc)
_ (assert (and (instance? IFn$OOLD read-fn)
(instance? IFn$OOLDO write-fn))
(str "Datatype " dtype " created invalid acccessr"))]
(Accessor. #(reify DoubleBuffer
(elemwiseDatatype [_] dtype)
(lsize [_] n-elems)
(readDouble [_ idx]
(.invokePrim read-fn %1 %2 (ChunkedList/indexCheck 0 n-elems idx)))
(writeDouble [_ idx v]
(.invokePrim write-fn %1 %2 (ChunkedList/indexCheck 0 n-elems idx) v)))
#(throw (Exception. "Bulk set of array properties not supported yet - use read to get writable list"))))
;;Then return a new array of structs
(if (struct-datatype? dtype)
(let [sdef (get-struct-def dtype)
dsize (long (get sdef :datatype-size))]
(Accessor. (fn [buffer bin-buffer]
(inplace-new-array-of-structs dtype (dtype-proto/sub-buffer buffer offset (* n-elems dsize))))
#(throw (Exception. "Bulk set of array properties not supported yet - use read to get writable list"))))
(let [read-fn ^IFn$OOLO (.reader scalar-acc)
write-fn ^IFn$OOLOO (.writer scalar-acc)
_ (assert (and (instance? IFn$OOLO read-fn)
(instance? IFn$OOLOO write-fn))
(str "Datatype " dtype " created invalid acccessr"))]
(Accessor. #(reify ObjectBuffer
(elemwiseDatatype [_] dtype)
(lsize [_] n-elems)
(readObject [_ idx]
(.invokePrim read-fn %1 %2 (ChunkedList/indexCheck 0 n-elems idx)))
(writeObject [_ idx v]
(.invokePrim write-fn %1 %2 (ChunkedList/indexCheck 0 n-elems idx) v)))
#(throw (Exception. "Bulk set of array properties not supported yet - use read to get writable list"))))))))))
nil
layout)
(assoc struct-def :accessors accessors)))
Expand Down
19 changes: 19 additions & 0 deletions test/tech/v3/datatype/struct_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,22 @@
sarray (dt-struct/new-array-of-structs :ptr-types 10)]
(is (= (vec (repeat 10 0))
(vec (dt-struct/array-of-structs->column sarray :a))))))


(deftest get-set-array-members
(let [sdef (dt-struct/define-datatype! :array-member [{:name :x :datatype :int32 :n-elems 10}])
vec-type (dt-struct/define-datatype! :vec3 [{:name :x :datatype :float32}
{:name :y :datatype :float32}
{:name :z :datatype :float32}])
ttype (dt-struct/define-datatype! :triangle [{:name :pts :datatype :vec3 :n-elems 3}])
sdata (dt-struct/new-struct :array-member)
tdata (dt-struct/new-struct :triangle)
x (get sdata :x)
_ (do
;;sdata'x X member is a reified LongBuffer so you can set the value via getting it then dtype/copy!
(dtype/copy! (range 10) x)
(is (= (vec (range 10)) (get sdata :x))))
ys (dt-struct/array-of-structs->column (:pts tdata) :y)
_ (do
(dtype/copy! (range 1 4) ys)
(is (= 2.0 (get-in tdata [:pts 1 :y]))))]))

0 comments on commit 9a47a50

Please sign in to comment.