Skip to content

Commit

Permalink
Adv/multi file serialization (#4067)
Browse files Browse the repository at this point in the history
* updates for serializers

* working proof of concept

* test multi file

* small fix for loading without attrs

* fixing merge

* Update src/Groups/GAPGroups.jl

* fixes errors introduced from merge

* simplify saving external lp

* fixes from discussion

* Update src/Serialization/PolyhedralGeometry.jl

* forgot to change type in signature

* typo

* fix for LP tests

* missed input type for one of the saves

* pass kw to load in setup tests

* fixes warning + polyhedral sets

* version number check without commit

* hack to now check dev version

* hack to not check dev version

* deal with conflicts

---------

Co-authored-by: Lars Göttgens <lars.goettgens@rwth-aachen.de>
  • Loading branch information
antonydellavecchia and lgoettgens authored Sep 12, 2024
1 parent 8d44d45 commit 97bc0a4
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 61 deletions.
18 changes: 18 additions & 0 deletions src/Serialization/PolyhedralGeometry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,17 @@ function save_object(s::SerializerState, lp::LinearProgram{QQFieldElem})
end
end

# Multi-file variant of `save_object` for linear programs: the program itself is
# written by `save_lp` to a separate file named `<basepath>-<objectid>.lp`, and
# only the name of that file is stored through the serializer state `s`.
function save_object(s::SerializerState{<: LPSerializer}, lp::LinearProgram{QQFieldElem})
  external_file = string(basepath(s.serializer), "-", objectid(lp), ".lp")
  save_lp(external_file, lp)

  # store only the file name, not the full path
  save_object(s, basename(external_file))
end

function load_object(s::DeserializerState, ::Type{<:LinearProgram}, field::QQField)
if s.obj isa String
error("Loading this file requires using the LPSerializer")
end
coeff_type = elem_type(field)
fr = load_object(s, Polyhedron, field, :feasible_region)
conv = load_object(s, String, :convention)
Expand All @@ -113,6 +123,14 @@ function load_object(s::DeserializerState, ::Type{<:LinearProgram}, field::QQFie
return LinearProgram{coeff_type}(fr, lp, Symbol(conv))
end

# Multi-file variant of `load_object` for linear programs: the current node
# holds the name of an external file, which is resolved against the directory
# of the serializer's base path and read with `load_lp`.
function load_object(s::DeserializerState{LPSerializer},
                     ::Type{<:LinearProgram}, field::QQField)
  load_node(s) do _
    # `dirname` yields "" for a base path without any directory part; `joinpath`
    # then gives a path relative to the working directory, whereas appending
    # "/name" by hand would point at the filesystem root.
    lp_filename = joinpath(dirname(basepath(s.serializer)), string(s.obj))
    load_lp(lp_filename)
  end
end

##############################################################################
@register_serialization_type MixedIntegerLinearProgram uses_params

Expand Down
54 changes: 24 additions & 30 deletions src/Serialization/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -356,10 +356,10 @@ function register_serialization_type(ex::Any, str::String, uses_id::Bool,
if !($ex <: Union{Number, String, Bool, Symbol, Vector, Tuple, Matrix, NamedTuple, Dict, Set})
function Oscar.serialize(s::Oscar.AbstractSerializer, obj::T) where T <: $ex
Oscar.serialize_type(s, T)
Oscar.save(s.io, obj; serializer_type=Oscar.IPCSerializer)
Oscar.save(s.io, obj; serializer=Oscar.IPCSerializer())
end
function Oscar.deserialize(s::Oscar.AbstractSerializer, ::Type{<:$ex})
Oscar.load(s.io; serializer_type=Oscar.IPCSerializer)
Oscar.load(s.io; serializer=Oscar.IPCSerializer())
end
end
end)
Expand Down Expand Up @@ -511,11 +511,10 @@ julia> load("/tmp/fourtitwo.mrdi")
"""
function save(io::IO, obj::T; metadata::Union{MetaData, Nothing}=nothing,
with_attrs::Bool=true,
serializer_type::Type{<: OscarSerializer} = JSONSerializer) where T

s = state(serializer_open(io, serializer_type,
with_attrs ? type_attr_map : Dict{String, Vector{Symbol}}()))
save_data_dict(s) do
serializer::OscarSerializer = JSONSerializer()) where T
s = serializer_open(io, serializer,
with_attrs ? type_attr_map : Dict{String, Vector{Symbol}}())
save_data_dict(s) do
# write out the namespace first
save_header(s, get_oscar_serialization_version(), :_ns)

Expand All @@ -528,21 +527,9 @@ function save(io::IO, obj::T; metadata::Union{MetaData, Nothing}=nothing,
global_serializer_state.id_to_obj[ref] = obj
end
save_object(s, string(ref), :id)

end

# this should be handled by serializers in a later commit / PR
if !isempty(s.refs) && serializer_type == JSONSerializer
save_data_dict(s, refs_key) do
for id in s.refs
ref_obj = global_serializer_state.id_to_obj[id]
s.key = Symbol(id)
save_data_dict(s) do
save_typed_object(s, ref_obj)
end
end
end
end
handle_refs(s)

if !isnothing(metadata)
save_json(s, JSON3.write(metadata), :meta)
Expand All @@ -552,13 +539,19 @@ function save(io::IO, obj::T; metadata::Union{MetaData, Nothing}=nothing,
return nothing
end

function save(filename::String, obj::Any; metadata::Union{MetaData, Nothing}=nothing,
function save(filename::String, obj::Any;
metadata::Union{MetaData, Nothing}=nothing,
serializer::OscarSerializer=JSONSerializer(),
with_attrs::Bool=true)
dir_name = dirname(filename)
# julia dirname does not return "." for plain filenames without any slashes
temp_file = tempname(isempty(dir_name) ? pwd() : dir_name)

open(temp_file, "w") do file
save(file, obj; metadata=metadata, with_attrs=with_attrs)
save(file, obj;
metadata=metadata,
with_attrs=with_attrs,
serializer=serializer)
end
Base.Filesystem.rename(temp_file, filename) # atomic "multi process safe"
return nothing
Expand Down Expand Up @@ -622,8 +615,8 @@ true
```
"""
function load(io::IO; params::Any = nothing, type::Any = nothing,
serializer_type=JSONSerializer, with_attrs::Bool=true)
s = state(deserializer_open(io, serializer_type, with_attrs))
serializer=JSONSerializer(), with_attrs::Bool=true)
s = deserializer_open(io, serializer, with_attrs)
if haskey(s.obj, :id)
id = s.obj[:id]
if haskey(global_serializer_state.id_to_obj, UUID(id))
Expand Down Expand Up @@ -658,9 +651,9 @@ function load(io::IO; params::Any = nothing, type::Any = nothing,
jsondict = copy(s.obj)
jsondict = upgrade(file_version, jsondict)
jsondict_str = JSON3.write(jsondict)
s = state(deserializer_open(IOBuffer(jsondict_str),
serializer_type,
with_attrs))
s = deserializer_open(IOBuffer(jsondict_str),
serializer,
with_attrs)
end

try
Expand Down Expand Up @@ -704,7 +697,7 @@ function load(io::IO; params::Any = nothing, type::Any = nothing,
end
return loaded
catch e
if file_version > VERSION_NUMBER
if VersionNumber(replace(String(file_version), r"DEV.+", "DEV")) > VERSION_NUMBER
@warn """
Attempted loading file stored with Oscar version $file_version
using Oscar version $VERSION_NUMBER
Expand All @@ -720,8 +713,9 @@ function load(io::IO; params::Any = nothing, type::Any = nothing,
end

# Load an object from the file at `filename`.
#
# The file is opened for reading and handed to the `IO` method of `load`. All
# keyword arguments, including `with_attrs` and `serializer`, are forwarded so
# that loading from a path behaves the same as loading from an open stream.
function load(filename::String; params::Any = nothing,
              type::Any = nothing, with_attrs::Bool=true,
              serializer::OscarSerializer=JSONSerializer())
  open(filename) do file
    return load(file; params=params, type=type,
                with_attrs=with_attrs, serializer=serializer)
  end
end
71 changes: 44 additions & 27 deletions src/Serialization/serializers.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
using JSON3
import Base.haskey

################################################################################
# Serializers
# Abstract supertype of all serializers; `save` and `load` accept an instance
# of a subtype through their `serializer` keyword.
abstract type OscarSerializer end

# Serializer used by default in `save` and `load`.
struct JSONSerializer <: OscarSerializer end

# Serializer passed to `save`/`load` by the `Oscar.serialize` and
# `Oscar.deserialize` methods generated in `register_serialization_type`.
struct IPCSerializer <: OscarSerializer end

# Abstract supertype for serializers that carry a `basepath`
# (see the `basepath` accessor).
abstract type MultiFileSerializer <: OscarSerializer end

# Multi-file serializer for linear programs: `save_object` writes the program
# to `<basepath>-<objectid>.lp` and stores only that file's name.
struct LPSerializer <: MultiFileSerializer
basepath::String
end

# Accessor for the base path stored in a multi-file serializer.
function basepath(serializer::MultiFileSerializer)
  return serializer.basepath
end

################################################################################
# (de)Serializer States

Expand All @@ -23,8 +39,8 @@ function reset_global_serializer_state()
end

# struct which tracks state for (de)serialization
mutable struct SerializerState
# dict to track already serialized objects
mutable struct SerializerState{T <: OscarSerializer}
serializer::T
new_level_entry::Bool
# UUIDs that point to the objs in the global state,
# ideally this would be an ordered set
Expand Down Expand Up @@ -137,9 +153,10 @@ function finish_writing(s::SerializerState)
# nothing to do here
end

mutable struct DeserializerState
mutable struct DeserializerState{T <: OscarSerializer}
# or perhaps Dict{Int,Any} to be resilient against corrupt/malicious files using huge ids
# the values of refs are objects to be deserialized
serializer::T
obj::Union{Dict{Symbol, Any}, Vector, JSON3.Object, JSON3.Array, BasicTypeUnion}
key::Union{Symbol, Int, Nothing}
refs::Union{Dict{Symbol, Any}, JSON3.Object, Nothing}
Expand Down Expand Up @@ -200,46 +217,46 @@ function load_params_node(s::DeserializerState)
end
end

################################################################################
# Serializers
abstract type OscarSerializer end

struct JSONSerializer <: OscarSerializer
state::S where S <: Union{SerializerState, DeserializerState}
end

struct IPCSerializer <: OscarSerializer
state::S where S <: Union{SerializerState, DeserializerState}
end

state(s::OscarSerializer) = s.state

# Create the `SerializerState` used to write to `io`.
#
# `serializer` is stored in the state so that `save_object` methods can dispatch
# on its type (e.g. `SerializerState{<: LPSerializer}`); `type_attr_map` is
# passed through unchanged.
function serializer_open(
  io::IO,
  serializer::OscarSerializer,
  type_attr_map::S) where S <: Union{Dict{String, Vector{Symbol}}, Nothing}

  # some level of handling should be done here at a later date
  return SerializerState(serializer, true, UUID[], io, nothing, type_attr_map)
end

# Read a JSON document from `io` and wrap it in a `DeserializerState` that
# carries `serializer`. The entry under `refs_key`, if present, is handed to the
# state as its reference table; otherwise `nothing` is used.
function deserializer_open(io::IO, serializer::OscarSerializer, with_attrs::Bool)
  obj = JSON3.read(io)
  refs = haskey(obj, refs_key) ? obj[refs_key] : nothing

  return DeserializerState(serializer, obj, nothing, refs, with_attrs)
end

# `deserializer_open` for the IPC serializer: parses `io` into a
# `Dict{Symbol, Any}` and starts without a reference table.
function deserializer_open(io::IO, serializer::IPCSerializer, with_attrs::Bool)
  # Using a JSON3.Object from JSON3 version 1.13.2 causes
  # @everywhere using Oscar
  # to hang. So we use a Dict here for now.

  # put_params to hang
  #obj = JSON3.read(io)
  obj = JSON.parse(io, dicttype=Dict{Symbol, Any})

  return DeserializerState(serializer, obj, nothing, nothing, with_attrs)
end

# Write the objects referenced during serialization under `refs_key`.
#
# Nothing is written when `s.refs` is empty. Otherwise each id in `s.refs` is
# looked up in the global serializer state and saved as a typed object in its
# own dict, keyed by the id.
function handle_refs(s::SerializerState)
  isempty(s.refs) && return nothing

  save_data_dict(s, refs_key) do
    for ref_id in s.refs
      referenced = global_serializer_state.id_to_obj[ref_id]
      # the id becomes the key of the entry written next
      s.key = Symbol(ref_id)
      save_data_dict(s) do
        save_typed_object(s, referenced)
      end
    end
  end
end

function attrs_list(s::SerializerState, T::Type)
Expand Down
6 changes: 6 additions & 0 deletions test/Serialization/PolyhedralGeometry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,12 @@ using Oscar: _integer_variables
@test objective_function(LP) == objective_function(loaded)
@test feasible_region(LP) == feasible_region(loaded)
end

serializer=Oscar.LPSerializer(joinpath(path, "original"))
test_save_load_roundtrip(path, LP; serializer=serializer) do loaded
@test objective_function(LP) == objective_function(loaded)
@test feasible_region(LP) == feasible_region(loaded)
end
end

@testset "MixedIntegerLinearProgram" begin
Expand Down
8 changes: 4 additions & 4 deletions test/Serialization/setup_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ if !isdefined(Main, :test_save_load_roundtrip) || isinteractive()
# save and load from a file
filename = joinpath(path, "original.json")
save(filename, original; kw...)
loaded = load(filename; params=params)
loaded = load(filename; params=params, kw...)

@test loaded isa T
func(loaded)
Expand All @@ -21,7 +21,7 @@ if !isdefined(Main, :test_save_load_roundtrip) || isinteractive()
io = IOBuffer()
save(io, original; kw...)
seekstart(io)
loaded = load(io; params=params)
loaded = load(io; params=params, kw...)

@test loaded isa T
func(loaded)
Expand All @@ -30,15 +30,15 @@ if !isdefined(Main, :test_save_load_roundtrip) || isinteractive()
io = IOBuffer()
save(io, original; kw...)
seekstart(io)
loaded = load(io; type=T, params=params)
loaded = load(io; type=T, params=params, kw...)

@test loaded isa T
func(loaded)

# test loading on a empty state
save(filename, original; kw...)
Oscar.reset_global_serializer_state()
loaded = load(filename; params=params)
loaded = load(filename; params=params, kw...)
@test loaded isa T

# test schema
Expand Down

0 comments on commit 97bc0a4

Please sign in to comment.