Skip to content

Commit

Permalink
Integer-typed UWDs (#921)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcohen1 authored Aug 30, 2024
1 parent d336c7b commit f7ea390
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 6 deletions.
16 changes: 11 additions & 5 deletions src/programs/RelationalPrograms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,10 @@ function parse_relation_diagram(head::Expr, body::Expr)
Expr(:where, expr, context) => (expr, parse_relation_context(context)...)
_ => (head, nothing, nothing)
end
var_types = if isnothing(all_types) # Untyped case.
var_types = if isnothing(all_types) # Untyped case
vars -> length(vars)
else # Typed case.
var_type_map = Dict{Symbol,Symbol}(zip(all_vars, all_types))
else
var_type_map = Dict(zip(all_vars, all_types))
vars -> getindex.(Ref(var_type_map), vars)
end

Expand Down Expand Up @@ -194,19 +194,24 @@ function parse_relation_context(context)
vars = map(terms) do term
@match term begin
Expr(:(::), var::Symbol, type::Symbol) => (var => type)
Expr(:(::), var::Symbol, type::Expr) => (var => type)
Expr(:(::), var::Symbol, type::Integer) => (var => type)
var::Symbol => var
_ => error("Invalid syntax in term $expr of context")
_ => error("Invalid syntax in term $term of context")
end
end

if vars isa AbstractVector{Symbol}
(vars, nothing)
elseif vars isa AbstractVector{Pair{Symbol,Symbol}}
elseif eltype(vars) <: Pair
(first.(vars), last.(vars))
else
error("Context $context mixes typed and untyped variables")
end

end


function parse_relation_call(call)
@match call begin
Expr(:call, name::Symbol, Expr(:parameters, args)) =>
Expand Down Expand Up @@ -242,6 +247,7 @@ function parse_relation_inferred_args(args)
Expr(:kw, name::Symbol, var::Symbol) => (name => var)
Expr(:(=), name::Symbol, var::Symbol) => (name => var)
var::Symbol => var
Expr(:(::), _, _) => error("All variable types must be included in the where clause and not in the argument list")
_ => error("Expected name as positional or keyword argument")
end
end
Expand Down
66 changes: 65 additions & 1 deletion test/programs/RelationalPrograms.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
module TestRelationalPrograms

using Test

using Catlab.CategoricalAlgebra.CSets
Expand Down Expand Up @@ -49,14 +50,16 @@ d = RelationDiagram(0)
add_box!(d, 0, name=:A)
@test parsed == d

# Typed

# Typed by Symbols
#------

parsed = @relation (x,y,z) where (x::X, y::Y, z::Z, w::W) begin
R(x,w)
S(y,w)
T(z,w)
end

d = RelationDiagram([:X,:Y,:Z])
add_box!(d, [:X,:W], name=:R)
add_box!(d, [:Y,:W], name=:S)
Expand All @@ -66,6 +69,67 @@ set_junction!(d, [1,4,2,4,3,4])
set_junction!(d, [1,2,3], outer=true)
@test parsed == d


# Typed by Integers
#------

parsed = @relation (x,y,z) where (x::1, y::2, z::3, w::4) begin
R(x,w)
S(y,w)
T(z,w)
end

d = RelationDiagram([:1,:2,:3])
add_box!(d, [:1,:4], name=:R)
add_box!(d, [:2,:4], name=:S)
add_box!(d, [:3,:4], name=:T)
add_junctions!(d, [:1,:2,:3,:4], variable=[:x,:y,:z,:w])
set_junction!(d, [1,4,2,4,3,4])
set_junction!(d, [1,2,3], outer=true)
@test parsed == d



# Typed by Expressions
#------

parsed = @relation (x,y,z) where (x::n(1), y::n(2), z::n(3), w::n(4)) begin
R(x,w)
S(y,w)
T(z,w)
end

d = RelationDiagram([:(n(1)), :(n(2)), :(n(3))])
add_box!(d, [:(n(1)),:(n(4))], name=:R)
add_box!(d, [:(n(2)),:(n(4))], name=:S)
add_box!(d, [:(n(3)),:(n(4))], name=:T)
add_junctions!(d, [:(n(1)),:(n(2)),:(n(3)),:(n(4))], variable=[:x,:y,:z,:w])
set_junction!(d, [1,4,2,4,3,4])
set_junction!(d, [1,2,3], outer=true)
@test parsed == d

# Mixed types
#------

parsed = @relation (x,y,z) where (x::n(1), y::2, z::C, w::nothing) begin
R(x,w)
S(y,w)
T(z,w)
end

d = RelationDiagram([:(n(1)), :2, :C])
add_box!(d, [:(n(1)),:nothing], name=:R)
add_box!(d, [:2,:nothing], name=:S)
add_box!(d, [:C,:nothing], name=:T)
add_junctions!(d, [:(n(1)),:2,:C,:nothing], variable=[:x,:y,:z,:w])
set_junction!(d, [1,4,2,4,3,4])
set_junction!(d, [1,2,3], outer=true)
@test parsed == d





# Special case: closed diagram.
sird_uwd = @relation () where (S::Pop, I::Pop, R::Pop, D::Pop) begin
infect(S,I,I,I) # inf
Expand Down

0 comments on commit f7ea390

Please sign in to comment.