Visualizing Forney Factor Graphs when using RxInfer (Part 4)
Automating most of the layout configuration
Bayesian Inference
Active Inference
TikZ
RxInfer
Julia
Author
Kobus Esterhuysen
Published
August 15, 2024
Modified
August 15, 2024
In Part 3 we added decorations for Constrained Forney Factor Graphs (CFFG) as well as for delta factors. In this part we:
add a control example
refactor the code to increase reuse
make the configuration of FFGs mostly automatic
a new graph (called for now the kobus_graph to distinguish it from the meta_graph) is added
in addition to adding some useful properties, it extends the MetaGraphsNext graph to also include neighbours for variable nodes (i.e. not just for factor nodes)
the neighbours of variable nodes are stored under the British spelling, :neighbours, to distinguish them from the neighbors field that GraphPPL already provides for factor nodes
only the layer names and a few key variable names have to be specified; the rest happens automatically
there are a few exceptions related to more complicated FFGs
these exceptions might be addressed in the future
make appending of global_counter values switchable
this helps greatly when encountering a new model
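The central idea behind the kobus_graph, recording every link at both of its end nodes so that variable nodes get a neighbour list too, can be illustrated without GraphPPL. The sketch below is a simplification that assumes the links are already available as pairs of sanitized label strings; the helper `build_neighbours` and the node names `T_1`, `s_1`, `s_2` and `A_1` are hypothetical.

```julia
"""
Hypothetical helper: given links as (source, destination) label pairs,
return a Dict mapping every node label to the labels of its neighbours.
Each link is recorded at both ends, mirroring how `create_kobus_graph`
fills its `:neighbours` key.
"""
function build_neighbours(links::Vector{Tuple{String, String}})
    neighbours = Dict{String, Vector{String}}()
    for (src, dst) in links
        push!(get!(neighbours, src, String[]), dst)   # record at the source node
        push!(get!(neighbours, dst, String[]), src)   # and at the destination node
    end
    return neighbours
end

# A tiny hypothetical chain: s_1 -- T_1 -- s_2, with parameter A_1 also attached to T_1
links = [("T_1", "s_1"), ("T_1", "s_2"), ("T_1", "A_1")]
nbrs = build_neighbours(links)
println(nbrs["T_1"])   # factor node: ["s_1", "s_2", "A_1"]
println(nbrs["s_2"])   # variable node now also has neighbours: ["T_1"]
```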
One of the distinguishing features of the RxInfer Julia package is that it uses reactive message passing (RMP). This means that node updates in the factor graph happen asynchronously. Because of the added complexity, users usually gloss over these details. When node updates happen in the more traditional, synchronous way (as, for example, in the PyMDP Python package), users can more easily keep track of which node is currently being updated. Understanding how factor graphs work is valuable in general, and a visualization of the factor graph itself often helps. The existing visualization code is rather limited, which is why this project was undertaken. The ultimate goal is to reach a visualization that conforms to the Constrained Forney Factor Graph (CFFG) as defined in
GraphPPL.jl exports the @model macro for model specification. It accepts a regular Julia function and builds a Forney Factor Graph (FFG) under the hood. This graph is bipartite: each node belongs to one of two sets, one for factors and one for variables. Note that we use the terms node and link rather than vertex and edge to align with RxInfer terminology.
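Because the graph is bipartite, every link must join a factor node to a variable node. A minimal check of that property, assuming each node has already been classified as `:fac` or `:var` (the function `is_bipartite_ffg`, the classification dictionary and the node names are hypothetical):

```julia
"""
Return true if every link joins one factor node and one variable node.
`kind` maps each node label to either :fac or :var.
"""
function is_bipartite_ffg(links::Vector{Tuple{String, String}}, kind::Dict{String, Symbol})
    return all(links) do (a, b)
        # the two end nodes must be of different kinds: one :fac and one :var
        Set([kind[a], kind[b]]) == Set([:fac, :var])
    end
end

kind = Dict("T_1" => :fac, "s_1" => :var, "s_2" => :var)
println(is_bipartite_ffg([("T_1", "s_1"), ("T_1", "s_2")], kind))  # true
println(is_bipartite_ffg([("s_1", "s_2")], kind))                  # false: var-var link
```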
Node Taxonomy
nodes
facs (FFG/RxInfer factor nodes)
vars (FFG/RxInfer variable nodes)
parvars (parameter variables, e.g. \(\alpha, a, A_t, B, \gamma_t, \mu\))
sysvars (system variables, e.g. \(x_t, s_t, u_t\))
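The node taxonomy maps naturally onto a small Julia type hierarchy. This is only an illustrative sketch; the project itself works with label strings, and the types `Node`, `Fac`, `Var`, `ParVar` and `SysVar` below are hypothetical.

```julia
# Hypothetical type hierarchy mirroring the node taxonomy above
abstract type Node end
struct Fac <: Node            # FFG/RxInfer factor node
    label::String
end
abstract type Var <: Node end # FFG/RxInfer variable node
struct ParVar <: Var          # parameter variable, e.g. A, B, γ_t
    label::String
end
struct SysVar <: Var          # system variable, e.g. x_t, s_t, u_t
    label::String
end

nodes = Node[Fac("T_1"), ParVar("B"), SysVar("s_1"), SysVar("u_1")]
n_vars = count(n -> n isa Var, nodes)        # parvars and sysvars are both vars
n_sysvars = count(n -> n isa SysVar, nodes)
println("vars: $n_vars, sysvars: $n_sysvars")  # vars: 3, sysvars: 2
```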
Node Visualization
a fac node is visualized by a
square (style fac) if it is a factor
small black square (style dlt) if it is a delta (\(\delta\)) factor
delta factors connect to their associated parvars and sysvars
a var node is visualized by a
point (style var) if connected to a single fac
equals-square (style eql) if connected to multiple facs
rather than converting the bipartite graph from RxInfer to a graph that only contains a single kind of node, we keep the bipartite graph and visualize the nodes differently
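The visualization rules above amount to a small decision function. A sketch, assuming the caller already knows whether a node is a factor, whether it is a delta factor, and how many factor nodes a variable connects to (the function `tikz_style` is hypothetical; the style names `fac`, `dlt`, `var` and `eql` are the ones listed above):

```julia
"""
Hypothetical helper returning the TikZ style name for a node:
factors get `dlt` if they are delta factors, otherwise `fac`;
variables get `var` (a point) when linked to one factor and
`eql` (an equals-square) when linked to several.
"""
function tikz_style(is_factor::Bool; is_delta::Bool = false, n_facs::Int = 1)
    if is_factor
        return is_delta ? "dlt" : "fac"
    else
        return n_facs > 1 ? "eql" : "var"
    end
end

println(tikz_style(true))                 # fac
println(tikz_style(true; is_delta=true))  # dlt
println(tikz_style(false; n_facs=1))      # var
println(tikz_style(false; n_facs=3))      # eql
```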
For each of the examples in this project the following steps happen:
A @model is created
The @model is conditioned on some data
usually a few points, so that the visualization does not become unwieldy
A RxInfer model is created (rxi_model)
A GraphPPL model is obtained (gppl_model)
A meta graph is obtained (meta_graph)
The graph is rendered by the existing implementation (GraphPlot.gplot())
A few vertex labels are inspected out of interest
Set up global parameters (indicated by a leading underscore character) to influence the behavior of generate_tikz()
_san_node_ids is provided for node names that need to be sanitized
_lup_enabled is set to true/false to enable lookup of math names
_lup_agc is set to true/false to enable appending of global_counter values
_raw_links_enabled is set to true/false to enable the drawing of raw links for reference
_CFFG is set to true/false to enable the drawing of beads to turn a FFG into a CFFG
_deltas_enabled is set to true/false to enable the drawing of delta decorations
_lup_dict maps variable names to their canonical LaTeX (math) equivalents
when _lup_agc is true, global_counter values are appended
when _lup_agc is false, global_counter values are not appended
_tsep specifies the horizontal separation between consecutive time steps
The TikZ string is generated (generate_tikz())
heading of the FFG
GraphPPL model
layers for the FFG
layer names (if applicable)
cntrl (control layer)
paramB1 (parameter B1 layer)
paramB2 (parameter B2 layer)
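stateM4 (state minus 4 layer, i.e. 4 layers above state0 layer)
stateM3 (state minus 3 layer, i.e. 3 layers above state0 layer)
stateM2 (state minus 2 layer, i.e. 2 layers above state0 layer)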
stateM1 (state minus 1 layer, i.e. 1 layer above state0 layer)
state0 (state layer)
stateP1 (state plus 1 layer, i.e. 1 layer below state0 layer)
stateP2 (state plus 2 layer, i.e. 2 layers below state0 layer)
obser (observation layer)
prefr (preference layer, for setpoints)
paramA (parameter A layer)
The TikZ string is printed
The TikZ code is written to a .tex file (show_tikz()), which can then be compiled to a PDF file
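The vertical placement of the layers follows the order listed above: generate_tikz() starts at a top y value (_ytop) and moves down by _ysep for each layer that is present in the layer definition. The sketch below reproduces only that arithmetic, using the values _ytop = 30 and _ysep = 3 from the code further down; the helper `layer_y_positions` is hypothetical.

```julia
# Top-to-bottom order of the layers handled by generate_tikz()
const LAYER_ORDER = ["cntrl", "paramB1", "paramB2", "stateM4", "stateM3", "stateM2", "stateM1",
                     "state0", "stateP1", "stateP2", "obser", "prefr", "paramA"]

"""
Hypothetical helper: return (layer, y) pairs for the layers present, moving
down by `ysep` for each present layer, as generate_tikz() does for nodes.
"""
function layer_y_positions(present::Set{String}; ytop = 30, ysep = 3)
    y = ytop
    positions = Tuple{String, Int}[]
    for name in LAYER_ORDER
        if name in present
            y -= ysep
            push!(positions, (name, y))
        end
    end
    return positions
end

println(layer_y_positions(Set(["paramA", "state0", "obser", "paramB1"])))
# [("paramB1", 27), ("state0", 24), ("obser", 21), ("paramA", 18)]
```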
Setup notes
If additional Ubuntu packages are needed for a Julia package
sudo apt-get install package
If you want to run bash on the devcontainer in a local terminal:
kobus@Kobuss-Mac-mini ~ %
Activating project at `~/.julia/environments/v1.10`
Resolving package versions...
No Changes to `~/.julia/environments/v1.10/Project.toml`
No Changes to `~/.julia/environments/v1.10/Manifest.toml`
Resolving package versions...
No Changes to `~/.julia/environments/v1.10/Project.toml`
No Changes to `~/.julia/environments/v1.10/Manifest.toml`
Resolving package versions...
No Changes to `~/.julia/environments/v1.10/Project.toml`
No Changes to `~/.julia/environments/v1.10/Manifest.toml`
Status `~/.julia/environments/v1.10/Project.toml`
⌃ [a2cc645c] GraphPlot v0.5.2
[fa8bd995] MetaGraphsNext v0.7.0
⌃ [86711068] RxInfer v3.0.0
Info Packages marked with ⌃ have new versions available and may be upgradable.
"""
Takes the tikz string given by generate_tikz() and writes it to a .tex file to produce a TikZ visualisation.
"""
function show_tikz(tikz_code_graph::String)
    ## eval(Meta.parse(dot_code_graph)) ## for GraphViz
    ## see problem in ^v1 under TikzPictures section
    ## tikz_picture = TikzPicture(tikz_code_graph)
    ## return tikz_picture
    file_path = "myoutput.tex" ## write to file rather than show in notebook
    open(file_path, "w") do file
        write(file, tikz_code_graph)
    end
end
show_tikz
# # Attempt to display the generated PDF (containing the TikZ picture) inline in a Julia notebook:
# import Pkg; Pkg.add("PGFPlotsX")
# using PGFPlotsX
# # Create a struct that represents a PDF file
# struct PDF
#     file::String
# end
# # Define a function to show the PDF as inline SVG
# function Base.show(io::IO, ::MIME"image/svg+xml", pdf::PDF)
#     svg = first(splitext(pdf.file)) * ".svg"
#     PGFPlotsX.convert_pdf_to_svg(pdf.file, svg)
#     write(io, read(svg))
#     try; rm(svg; force=true); catch e; end
#     return nothing
# end
# # Display the PDF file:
# display(PDF("myoutput.pdf"))
# # This currently fails with:
# # Internal Error: cairo context error: invalid matrix (not invertible)
# # cairo error: invalid matrix (not invertible)
"""
Given an unsanitized string, provides a sanitized equivalent.
"""
function sanitized(vertex::String)
    if haskey(_san_node_ids, vertex)
        return _san_node_ids[vertex]
    else
        return vertex
    end
end

"""
Given a label string, provides its looked-up LaTeX value.
"""
function lup(label::String)
    if _lup_enabled
        return get(_lup_dict, label, replace(label, "_" => "\\_"))
    else
        return replace(label, "_" => "\\_")
    end
end

function print_meta_graph(meta_graph) ## just for info
    println("============== NODES ==============")
    for node in MetaGraphsNext.vertices(meta_graph)
        label = MetaGraphsNext.label_for(meta_graph, node)
        println("$node: $label")
    end
    println("\n============== LINKS ==============")
    for link in MetaGraphsNext.edges(meta_graph)
        source_node = MetaGraphsNext.label_for(meta_graph, link.src)
        dest_node = MetaGraphsNext.label_for(meta_graph, link.dst)
        println("$(source_node) -> $(dest_node)")
    end
end
print_meta_graph (generic function with 1 method)
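The lookup performed by lup() can be exercised in isolation. In this sketch the lookup dictionary and the enabled flag are passed as arguments instead of being read from the globals _lup_dict and _lup_enabled; the name `lup_sketch` and the dictionary entries are hypothetical.

```julia
"""
Sketch of lup() with its globals turned into arguments: return the LaTeX
name for `label` if lookup is enabled and the label is in `dict`; otherwise
escape underscores so the raw label is typeset literally rather than as a subscript.
"""
function lup_sketch(label::String, dict::Dict{String, String}; enabled::Bool = true)
    escaped = replace(label, "_" => "\\_")
    return enabled ? get(dict, label, escaped) : escaped
end

lup_dict = Dict("s_0" => "s_0", "A" => "\\mathbf{A}")  # hypothetical entries
println(lup_sketch("A", lup_dict))                  # \mathbf{A}
println(lup_sketch("x_1", lup_dict))                # x\_1 (not in dict, so escaped)
println(lup_sketch("A", lup_dict; enabled=false))   # A
```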
function retain_can(s::String) ## retain canonical chars of a label string
    m = match(r"^[a-zA-Zγθs₀\*]*", s)
    return m.match
end

"""
Given a user-specified canonical factor label string, returns all the GraphPPL
associated labels from the context's factor_nodes.
"""
function str_facs(gppl_model, factor::String)
    context = GraphPPL.getcontext(gppl_model)
    vec = []
    for (factor_ID, factor_label) in pairs(context.factor_nodes)
        ## println("$(factor_ID), $(factor_label)")
        node_data = gppl_model[factor_label]
        ## println("$(node_data.properties.fform)\n")
        if retain_can(string(node_data.properties.fform)) == factor
            append!(vec, factor_label)
        end
    end
    result = string.(vec)
    return result
end

"""
Given a user-specified canonical variable label string, returns all the GraphPPL
associated labels from the context's individual_variables.
"""
function str_individual_vars(gppl_model, var::String)
    context = GraphPPL.getcontext(gppl_model)
    vec = []
    for (node_ID, node_label) in pairs(context.individual_variables)
        ## println("$(node_ID), $(node_label)")
        if occursin(var, string(node_ID))
            append!(vec, node_label)
        end
    end
    result = string.(vec)
    return result
end

"""
Given a user-specified canonical variable label string, returns all the GraphPPL
associated labels from the context's vector_variables.
"""
function str_vector_vars(gppl_model, var::String)
    context = GraphPPL.getcontext(gppl_model)
    vec = []
    for (node_ID, node_label) in pairs(context.vector_variables)
        ## println("$(node_ID), $(node_label)")
        if string(node_ID) == var
            append!(vec, node_label)
        end
    end
    result = string.(vec)
    return result
end
str_vector_vars
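retain_can() keeps only the leading run of "canonical" characters (ASCII letters plus γ, θ, ₀ and *), which is how str_facs() compares a factor's functional form with the user-supplied name. A few calls show what survives; the input strings are illustrative only.

```julia
# Same regex as retain_can(): the longest prefix made of canonical characters.
# The pattern can match zero characters, so match() never returns nothing here.
retain_can(s::String) = match(r"^[a-zA-Zγθs₀\*]*", s).match

println(retain_can("DiscreteTransition"))  # DiscreteTransition
println(retain_can("A_2"))                 # A  (stops at the underscore)
println(retain_can("γ₀_1"))                # γ₀
println(repr(retain_can("_x")))            # "" (empty prefix)
```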
"""
Given a GraphPPL model, provides a `kobus_graph` which also contains neighbours
for variable nodes (i.e. not only for factor nodes). It also contains some other
useful properties.
"""
function create_kobus_graph(gppl_model)
    kobus_graph = Dict{String, Dict{Symbol, Any}}()
    meta_graph = gppl_model.graph
    for (i, vertex) in enumerate(MetaGraphsNext.vertices(meta_graph))
        node_label = MetaGraphsNext.label_for(meta_graph, vertex)
        node_label_str = sanitized(string(node_label))
        println("$(i): $(node_label_str)")
        if !haskey(kobus_graph, node_label_str)
            kobus_graph[node_label_str] = Dict{Symbol, Any}()
            kobus_graph[node_label_str][:neighbours] = []
            kobus_graph[node_label_str][:node_label] = node_label
            node_code = code_for(meta_graph, node_label)
            kobus_graph[node_label_str][:node_code] = node_code
            node_lup = lup(node_label_str)
            kobus_graph[node_label_str][:node_lup] = node_lup
            node_global_counter = node_label.global_counter
            kobus_graph[node_label_str][:node_global_counter] = node_global_counter
            node_data = gppl_model[node_label] ## GraphPPL.NodeData
            kobus_graph[node_label_str][:node_data] = node_data
            properties = node_data.properties ## GraphPPL.Factor|VariableNodeProperties
            kobus_graph[node_label_str][:properties] = properties
            fields = fieldnames(typeof(properties)) ## Tuple{Symbol, Symbol}
            kobus_graph[node_label_str][:fields] = Dict{Symbol, Any}()
            for field in fields
                value = getfield(properties, field)
                kobus_graph[node_label_str][:fields][field] = value
            end
        end
    end
    ## populate the :neighbours key
    for (i, edge) in enumerate(MetaGraphsNext.edges(meta_graph))
        src_node = MetaGraphsNext.label_for(meta_graph, edge.src)
        dst_node = MetaGraphsNext.label_for(meta_graph, edge.dst)
        san_src_node = sanitized(string(src_node))
        san_dst_node = sanitized(string(dst_node))
        ## println("$(i): $(san_src_node) -- $(san_dst_node)")
        push!(kobus_graph[san_src_node][:neighbours], san_dst_node)
        push!(kobus_graph[san_dst_node][:neighbours], san_src_node)
    end
    return kobus_graph
end

"""
Given the user name of a variable node, provides the `inifac` node connected to it.
"""
function find_inifac_node(kobus_graph, out_neighbor_uname, gppl_model)
    inifac_out_neighbor_str = str_individual_vars(gppl_model, out_neighbor_uname)[1]
    var_neighbours = kobus_graph[inifac_out_neighbor_str][:neighbours]
    for i in 1:length(var_neighbours)
        neighbors = kobus_graph[var_neighbours[i]][:properties].neighbors
        println(" var_neighbours[$(i)]: $(var_neighbours[i])")
        in_neighbor_label = neighbors[2][1]; println(" in_neighbor_label: $(in_neighbor_label)")
        out_neighbor_label = neighbors[1][1]; println(" out_neighbor_label: $(out_neighbor_label)")
        if string(out_neighbor_label) == inifac_out_neighbor_str
            println("inifac_out_neighbor_str: $(inifac_out_neighbor_str)")
            println("inifac_str: $(var_neighbours[i])")
            return var_neighbours[i]
        end
    end
end

"""
Given the user name of a `sysvar` variable node, provides the `cntrl` nodes connected to it.
"""
function find_cntrl_nodes(kobus_graph, uname, gppl_model)
    cntrl_var_label_strs = str_vector_vars(gppl_model, uname)
    println("cntrl_var_label_strs: $(cntrl_var_label_strs)")
    sysvars = Vector{String}()
    facs = Vector{String}()
    parvars = Vector{Vector{String}}()
    for i in 1:length(cntrl_var_label_strs)
        var_neighbours = kobus_graph[cntrl_var_label_strs[i]][:neighbours]
        println("cntrl_var_label_strs[i]: $(cntrl_var_label_strs[i])")
        for j in 1:length(var_neighbours)
            neighbors = kobus_graph[var_neighbours[j]][:properties].neighbors
            println(" var_neighbours[j]: $(var_neighbours[j])")
            in_neighbor_label = neighbors[2][1]; println(" in_neighbor_label: $(in_neighbor_label)")
            out_neighbor_label = neighbors[1][1]; println(" out_neighbor_label: $(out_neighbor_label)")
            if string(out_neighbor_label) in cntrl_var_label_strs
                push!(facs, var_neighbours[j])
                push!(sysvars, cntrl_var_label_strs[i])
                parvar_nodes = find_parvar_nodes(kobus_graph, var_neighbours[j], gppl_model)
                append!(parvars, [parvar_nodes])
            end
        end
    end
    return sysvars, facs, parvars
end

"""
Given an `inifac` label string, provides the `iniparvar` nodes connected to it.
"""
function find_iniparvar_nodes(kobus_graph, inifac_label_str, gppl_model)
    neighbors = kobus_graph[inifac_label_str][:properties].neighbors
    println("neighbors: $(neighbors)")
    out_var = string(neighbors[1][1]); println(" out_var: $(out_var)")
    iniparvars = Vector{String}()
    for i in 2:length(neighbors)
        push!(iniparvars, string(neighbors[i][1]))
    end
    return out_var, iniparvars
end

"""
Given a `fac` label string, provides the `parvar` nodes connected to it.
"""
function find_parvar_nodes(kobus_graph, fac_label_str, gppl_model)
    state_var_label_strs = str_vector_vars(gppl_model, "s")
    neighbors = kobus_graph[fac_label_str][:properties].neighbors
    println("neighbors: $(neighbors)")
    parvars = Vector{String}()
    for i in 2:length(neighbors)
        in_neighbor_label = neighbors[i][1]
        if !(string(in_neighbor_label) in state_var_label_strs)
            push!(parvars, string(in_neighbor_label))
        end
    end
    return parvars
end

"""
Given a `fac` label string, provides the next variable in the `state0` sequence.
"""
function next_var(kobus_graph, fac_label_str)
    neighbors = kobus_graph[fac_label_str][:properties].neighbors
    var = neighbors[1]
    println("var: $(var)")
    var_label = var[1]
    println("=== NEXT var node is: $(var_label)")
    return var_label
end

"""
Given a `var` label, provides the next factor in the `state0` sequence.
"""
function next_fac(kobus_graph, var_label, state_var_label_strs, gppl_model)
    println("var_label: $(var_label)")
    try
        var_index = gppl_model[var_label].properties.index
    catch
        var_index = 0
        println("Exception: Setting var_index to 0")
    end
    if var_index == nothing ## for s_0
        var_index = 0
    end
    println("var_index: $(var_index)")
    var_neighbours = kobus_graph[string(var_label)][:neighbours]
    println("var_neighbours: $(var_neighbours)")
    for i in 1:length(var_neighbours)
        neighbors = kobus_graph[var_neighbours[i]][:properties].neighbors
        println(" var_neighbours[$(i)]: $(var_neighbours[i])")
        ## in_neighbor is this var node && out_neighbor is the next s
        in_neighbor_label = neighbors[2][1]; println(" in_neighbor_label: $(in_neighbor_label)")
        out_neighbor_label = neighbors[1][1]; println(" out_neighbor_label: $(out_neighbor_label)")
        if in_neighbor_label == var_label
            if var_index + 1 <= length(state_var_label_strs)
                if string(out_neighbor_label) == state_var_label_strs[var_index + 1]
                    fac_label_str = var_neighbours[i]
                    println(" === NEXT fac node is: $(fac_label_str)")
                    return fac_label_str
                elseif length(var_neighbours) == 2
                    println(" Take this var_label as next fac because we ONLY have 2 var_neighbours")
                    fac_label_str = var_neighbours[i]
                    println(" === NEXT fac node is: $(fac_label_str)")
                    return fac_label_str
                end
            else
                return "END"
            end
        end
    end
end

"""
Given an `inifac` label string, provides the `state0` nodes.
"""
function find_state0_nodes(kobus_graph, inifac_label_str, gppl_model)
    state_var_label_strs = str_vector_vars(gppl_model, "s")
    vars = Vector{String}() ## may contain non-sysvars too
    facs = Vector{String}()
    node_lab_strs = [inifac_label_str]
    fac_str = node_lab_strs[1]
    while fac_str != "END"
        var = next_var(kobus_graph, fac_str)
        push!(node_lab_strs, string(var))
        push!(vars, string(var))
        fac_str = next_fac(kobus_graph, var, state_var_label_strs, gppl_model)
        if fac_str == "END" || fac_str == nothing
            return node_lab_strs, vars, facs
        else
            push!(node_lab_strs, fac_str)
            push!(facs, fac_str)
        end
    end
end

"""
Given a user name, provides the `obser` nodes.
"""
function find_obser_nodes(kobus_graph, uname, gppl_model)
    obser_var_label_strs = str_vector_vars(gppl_model, uname)
    state_var_label_strs = str_vector_vars(gppl_model, "s")
    println("obser_var_label_strs: $(obser_var_label_strs)")
    sysvars = Vector{String}()
    facs = Vector{String}()
    parvars = Vector{Vector{String}}()
    for i in 1:length(obser_var_label_strs)
        var_neighbours = kobus_graph[obser_var_label_strs[i]][:neighbours]
        println("obser_var_label_strs[i]: $(obser_var_label_strs[i])")
        for j in 1:length(var_neighbours)
            neighbors = kobus_graph[var_neighbours[j]][:properties].neighbors
            println(" var_neighbours[j]: $(var_neighbours[j])")
            in_neighbor_label = neighbors[2][1]; println(" in_neighbor_label: $(in_neighbor_label)")
            out_neighbor_label = neighbors[1][1]; println(" out_neighbor_label: $(out_neighbor_label)")
            if length(state_var_label_strs) > 0
                if (string(out_neighbor_label) in obser_var_label_strs) && (string(in_neighbor_label) in state_var_label_strs)
                    push!(facs, var_neighbours[j])
                    push!(sysvars, obser_var_label_strs[i])
                    parvar_nodes = find_parvar_nodes(kobus_graph, var_neighbours[j], gppl_model)
                    append!(parvars, [parvar_nodes])
                end
            elseif isempty(state_var_label_strs) ## no state vars available
                if string(out_neighbor_label) in obser_var_label_strs
                    push!(facs, var_neighbours[j])
                    push!(sysvars, obser_var_label_strs[i])
                    parvar_nodes = find_parvar_nodes(kobus_graph, var_neighbours[j], gppl_model)
                    append!(parvars, [parvar_nodes])
                end
            else
                println("find_obser_nodes(): UNHANDLED CASE")
            end
        end
    end
    return sysvars, facs, parvars
end

"""
Given a user name, provides the `prefr` nodes.
"""
function find_prefr_nodes(kobus_graph, uname, gppl_model)
    obser_var_label_strs = str_vector_vars(gppl_model, uname)
    state_var_label_strs = str_vector_vars(gppl_model, "s")
    println("obser_var_label_strs: $(obser_var_label_strs)")
    facs = Vector{String}()
    parvars = Vector{Vector{String}}()
    for i in 1:length(obser_var_label_strs)
        var_neighbours = kobus_graph[obser_var_label_strs[i]][:neighbours]
        println("obser_var_label_strs[i]: $(obser_var_label_strs[i])")
        for j in 1:length(var_neighbours)
            neighbors = kobus_graph[var_neighbours[j]][:properties].neighbors
            println(" var_neighbours[j]: $(var_neighbours[j])")
            in_neighbor_label = neighbors[2][1]; println(" in_neighbor_label: $(in_neighbor_label)")
            out_neighbor_label = neighbors[1][1]; println(" out_neighbor_label: $(out_neighbor_label)")
            if (string(out_neighbor_label) in obser_var_label_strs) && !(string(in_neighbor_label) in state_var_label_strs)
                push!(facs, var_neighbours[j])
                parvar_nodes = find_parvar_nodes(kobus_graph, var_neighbours[j], gppl_model)
                append!(parvars, [parvar_nodes])
            end
        end
    end
    return facs, parvars
end
find_prefr_nodes
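find_state0_nodes() walks the state-transition chain by alternating between next_var() (the output variable of a factor) and next_fac() (the factor that takes that variable as input), stopping when no further factor is found. A stripped-down sketch of that walk, assuming both lookups are already available as plain dictionaries (the function `walk_state_chain` and all node names are hypothetical):

```julia
"""
Walk a chain fac -> var -> fac -> ... starting at `inifac`.
`out_var[f]` is the output variable of factor f; `fac_of_input[v]` is the
factor that takes v as input (absent for the last variable).
Returns the full node sequence, the variables, and the factors after `inifac`.
"""
function walk_state_chain(inifac::String, out_var::Dict{String, String},
                          fac_of_input::Dict{String, String})
    seq, vars, facs = [inifac], String[], String[]
    fac = inifac
    while true
        var = out_var[fac]
        push!(seq, var); push!(vars, var)
        if !haskey(fac_of_input, var)
            return seq, vars, facs   # plays the role of "END"
        end
        fac = fac_of_input[var]
        push!(seq, fac); push!(facs, fac)
    end
end

out_var = Dict("prior" => "s_0", "T_1" => "s_1", "T_2" => "s_2")
fac_of_input = Dict("s_0" => "T_1", "s_1" => "T_2")
seq, vars, facs = walk_state_chain("prior", out_var, fac_of_input)
println(seq)  # ["prior", "s_0", "T_1", "s_1", "T_2", "s_2"]
```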
"""
Provides equal spaces to be used on the sides of factor nodes.
"""
function spaces(n::Int)
    d = 1/(n + 1)
    return [i*d for i in 1:n]
end

"""
Set up a table of equal spaces to be used on the sides of factor nodes.
"""
function setup_divn(n)
    divn = Dict()
    for i in 1:n
        divn[i] = spaces(i)
    end
    return divn
end

"""
Calculate the arc radius to be used for `CFFG`s.
"""
function radius(constrained::Bool)
    if constrained
        return "\\ar" ## arc radius of a fac node
    else
        return "\\nr" ## node radius of a fac node
    end
end
radius
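setup_divn() precomputes, for 1 to 15 attached nodes, the fractions at which those nodes meet the side of a factor box; the drawing code multiplies these fractions by the node size \ns and measures them up from the box's south west corner. A worked example for three parameter variables on a 10 mm box (the size set by _node_size below):

```julia
# Same fractions as spaces() above
function spaces(n::Int)
    d = 1 / (n + 1)
    return [i * d for i in 1:n]
end

node_size_mm = 10.0
offsets_mm = spaces(3) .* node_size_mm
println(offsets_mm)  # [2.5, 5.0, 7.5]: equal 2.5 mm gaps below, between and above the three links
```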
_ytop = 30 #18 ## grid y-value at the top of picture
_show_grid = false ## displays a grid with the FFG if true
_node_size = "10mm" #"15mm" ## the length of the sides of factor node boxes (in millimeters)
## _show_var_points = false ## draws the var circles if true
_xsep = 2
# _xsep = 2.5
_tsep = 6*_xsep
# _tsep = 7*_xsep
# _tsep = 7.5*_xsep
# _tsep = 8*_xsep
# _tsep = 9*_xsep
_ysep = 3
_divn = setup_divn(15)
get_preamble_and_postamble (generic function with 1 method)
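The horizontal separations combine as follows in draw_sysvar_layer(): factor group k starts at anchor + (k-1)*_xsep, and each subsequent time step is _tsep further right. With _xsep = 2 and therefore _tsep = 12, a sketch of the resulting x coordinates (the helper `fac_x` and the anchor value 16 are arbitrary examples):

```julia
const XSEP = 2
const TSEP = 6 * XSEP   # as in _tsep = 6*_xsep

"""x coordinate of time step t in factor group k, following draw_sysvar_layer()."""
fac_x(anchor, k, t) = anchor + (k - 1) * XSEP + (t - 1) * TSEP

println([fac_x(16, 1, t) for t in 1:3])  # [16, 28, 40]
println([fac_x(16, 2, t) for t in 1:3])  # [18, 30, 42]
```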
"""
Given the name of a layer, draws the nodes in that layer. Used for all layers
containing system variables (`sysvar`s). These are typically:
`cntrl`, `stateM4`, `stateM3`, `stateM2`, `stateM1`, `stateP1`, `stateP2`, `obser`, `prefr`
"""
function draw_sysvar_layer(layers2, name, divn, y, iob)
    anchor = layers2[name].anchor
    y_cntrl = y
    facs = layers2[name].facs
    for k in 1:length(facs)
        if length(facs[k]) > 0
            x = anchor + (k-1)*_xsep ## anchor for this fac_grp
            for (i, v) in enumerate(facs[k]) ## for each fac in this fac_grp
                write(iob, "\\node($(v))[fac] at ($(x), $(y)) {\$$(lup(v))\$};\n")
                if _CFFG
                    ## write(iob, "\\draw ($(v))++(\\ar,0mm) arc[start angle=0, end angle=180, radius=\\ar];\n") ##TOP
                    write(iob, "\\draw ($(v))++(\\ar,0mm) arc[start angle=0, end angle=-180, radius=\\ar];\n") ##BOTTOM
                end
                x += _tsep ## move to next t
            end
        end
    end
    parvars = layers2[name].parvars
    if length(parvars) > 0
        for k in 1:length(parvars)
            if length(parvars[k]) > 0
                for (i, v) in enumerate(parvars[k])
                    for (j, w) in enumerate(v)
                        write(iob, "\\node($(w))[var] at (\$ ($(facs[k][i]).south west) + (-.5*\\ns, $(divn[length(v)][j])*\\ns) \$) {};\n")
                        write(iob, "\\node[above] at ($(w)) {\$$(lup(w))\$};\n") ## label
                        if _deltas_enabled
                            write(iob, "\\node($(w))[dlt] at (\$ ($(w).center) + (-.4*\\ns, 0) \$) {};\n")
                        end
                    end
                end
            end
        end
    end
    x = anchor
    sysvars = layers2[name].sysvars
    for (i, v) in enumerate(sysvars)
        write(iob, "\\node($(v))[var] at ($(x), $(y-1.5)) {};\n")
        write(iob, "\\node[above right] at ($(v)) {\$$(lup(v))\$};\n") ## label
        ## write(iob, "\\node($(v))[dlt] at (\$ ($(obser_facs[i]).center) + (0, -2.0*\\ns) \$) {};\n")
        if _deltas_enabled
            write(iob, "\\node($(v))[dlt] at (\$ ($(v).center) + (0, -.4*\\ns) \$) {};\n")
        end
        x += _tsep
    end
end

"""
Given the name of a layer, draws the nodes in that layer. Used for all layers
containing parameter variables (`parvar`s). These are typically:
`paramB1`, `paramB2`, `paramA`
"""
function draw_param_layer(layers2, name, divn, y, iob)
    anchor = layers2[name].anchor
    y_paramB1 = y
    inifac = layers2[name].inifac
    if !isempty(inifac)
        x = anchor
        write(iob, "\\node($(inifac))[fac] at ($(x), $(y)) {\$$(lup(inifac))\$};\n")
    end
    iniparvars = layers2[name].iniparvars
    if length(iniparvars) > 0
        n = length(iniparvars)
        for (i, v) in enumerate(iniparvars)
            write(iob, "\\node($(v))[var] at (\$ ($(inifac).south west) + (-.5*\\ns, $(divn[n][i])*\\ns) \$) {};\n")
            write(iob, "\\node[above] at ($(v)) {\$$(lup(v))\$};\n") ## label
            if _deltas_enabled
                write(iob, "\\node($(v))[dlt] at (\$ ($(v).center) + (-.4*\\ns, 0) \$) {};\n")
            end
        end
    end
    parvar = layers2[name].parvar
    if !isempty(parvar)
        x = anchor + 1*_xsep
        write(iob, "\\node($(parvar))[eql] at ($(x), $(y)) {\$=\$};\n")
        write(iob, "\\node[xshift=.4*\\ns, yshift=.4*\\ns] at ($(parvar)) {\$$(lup(parvar))\$};\n") ## label
    end
end

"""
Given the name of a layer, draws the nodes in that layer. Used for the `state0` layer.
This is the layer containing factors (sometimes interleaved with variables) associated
with the state transition sequence.
"""
function draw_state_layer(layers2, name, divn, y, iob)
    anchor = layers2[name].anchor
    y_state0 = y
    inifac = layers2[name].inifac
    if !isempty(inifac)
        x = anchor - 2*_xsep
        write(iob, "\\node($(inifac))[fac] at ($(x), $(y)) {\$$(lup(inifac))\$};\n")
    end
    iniparvars = layers2[name].iniparvars
    if length(iniparvars) > 0
        n = length(iniparvars)
        for (i, v) in enumerate(iniparvars)
            write(iob, "\\node($(v))[var] at (\$ ($(inifac).south west) + (-.5*\\ns, $(divn[n][i])*\\ns) \$) {};\n")
            write(iob, "\\node[above] at ($(v)) {\$$(lup(v))\$};\n") ## label
            if _deltas_enabled
                write(iob, "\\node($(v))[dlt] at (\$ ($(v).center) + (-.4*\\ns, 0) \$) {};\n")
            end
        end
    end
    sysvars = layers2[name].sysvars
    if length(sysvars) > 0
        x = anchor - 1*_xsep
        write(iob, "\\node($(sysvars[1]))[var] at ($(x), $(y)) {};\n")
        write(iob, "\\node[above] at ($(sysvars[1])) {\$$(lup(sysvars[1]))\$};\n") ## label
        x += _tsep
        for (i, v) in enumerate(sysvars[2:end])
            if i < length(sysvars) - 1
                write(iob, "\\node($(v))[eql] at ($(x), $(y)) {\$=\$};\n")
                write(iob, "\\node[xshift=.4*\\ns, yshift=.4*\\ns] at ($(v)) {\$$(lup(v))\$};\n") ## label
            else
                write(iob, "\\node($(v))[var] at ($(x), $(y)) {};\n")
                write(iob, "\\node[above] at ($(v)) {\$$(lup(v))\$};\n") ## label
            end
            x += _tsep
        end
    end
    facs = layers2[name].facs
    for k in 1:length(facs)
        if length(facs[k]) > 0
            x = anchor + (k-1)*_xsep
            for (i, v) in enumerate(facs[k])
                if k % 2 == 1 ## for facs
                    write(iob, "\\node($(v))[fac] at ($(x), $(y)) {\$$(lup(v))\$};\n")
                    if _CFFG
                        ## write(iob, "\\draw ($(v))++(\\ar,0mm) arc[start angle=0, end angle=180, radius=\\ar];\n") ##TOP
                        write(iob, "\\draw ($(v))++(\\ar,0mm) arc[start angle=0, end angle=-180, radius=\\ar];\n") ##BOTTOM
                    end
                else ## for vars
                    write(iob, "\\node($(v))[var] at ($(x), $(y)) {};\n")
                    write(iob, "\\node[above] at ($(v)) {\$$(lup(v))\$};\n") ## label
                end
                x += _tsep
            end
        end
    end
    ## to have parvars in state layer at north, this needs to be adjusted
    parvars = layers2[name].parvars
    if length(parvars) > 0
        for k in 1:length(parvars)
            if length(parvars[k]) > 0
                for (i, v) in enumerate(parvars[k])
                    for (j, w) in enumerate(v)
                        write(iob, "\\node($(w))[var] at (\$ ($(facs[k][i]).south west) + (-.5*\\ns, $(divn[length(v)][j])*\\ns) \$) {};\n")
                        write(iob, "\\node[above] at ($(w)) {\$$(lup(w))\$};\n") ## label
                        if _deltas_enabled
                            write(iob, "\\node($(w))[dlt] at (\$ ($(w).center) + (-.4*\\ns, 0) \$) {};\n")
                        end
                    end
                end
            end
        end
    end
end
draw_state_layer
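Every node-drawing function above follows the same pattern: interpolate a name, a style, coordinates and a looked-up label into a TikZ \node command and write it to an IOBuffer. A stand-alone sketch of that pattern (`tikz_node` is a hypothetical helper; `\$` and `\\` are the Julia string escapes for TikZ's `$` and `\`):

```julia
# Write one TikZ node line to `iob`, in the same format as the draw_*_layer functions
function tikz_node(iob::IO, name::String, style::String, x, y, label::String)
    write(iob, "\\node($(name))[$(style)] at ($(x), $(y)) {\$$(label)\$};\n")
end

iob = IOBuffer()
tikz_node(iob, "T_1", "fac", 16, 27, "T\\_1")   # label already escaped, as lup() would do
out = String(take!(iob))
print(out)   # \node(T_1)[fac] at (16, 27) {$T\_1$};
```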
"""
Given the name of a layer, draws the links for that layer. Links are usually:
    from `iniparvar`s or `parvar`s (parameter variable nodes) to `fac`s (factor nodes)
    from the `sysvar`s (system variable nodes) to the linked `fac`s (factor nodes) in the layer below
Used for all layers containing system variables (`sysvar`s). These are typically:
`cntrl`, `stateM4`, `stateM3`, `stateM2`, `stateM1`, `stateP1`, `stateP2`, `obser`, `prefr`
"""
function draw_sysvar_links(layers2, name, divn, y, iob) ## for cntrl & obser
    parvars = layers2[name].parvars
    if length(parvars) > 0
        for k in 1:length(parvars)
            if length(parvars[k]) > 0
                facs = layers2[name].facs
                for (i, v) in enumerate(parvars[k])
                    for (j, w) in enumerate(v)
                        ## write(iob, "\\draw ($(w)) -- ([shift={(0, $(divn[length(v)][j])*\\ns)}]$(facs[1][i]).south west);\n")
                        ## can also have CFFG when using .center:
                        write(iob, "\\draw ($(w)) -- ([shift={(-$(radius(_CFFG)), ($(divn[length(v)][j])-.5)*\\ns)}]$(facs[1][i]).center);\n")
                    end
                end
            end
        end
    end
    sysvars = layers2[name].sysvars
    facs = layers2[name].facs
    for (i, v) in enumerate(sysvars)
        write(iob, "\\draw ($(facs[1][i])) -- ($(sysvars[i]));\n")
    end
    link_facs = layers2[name].links
    if link_facs !== nothing
        for (i, v) in enumerate(link_facs[1])
            write(iob, "\\draw ($(sysvars[i])) -- ($(v));\n") ## ignore s_0
            # write(iob, "\\draw ($(sysvars[i+1][1])) -- ($(v));\n") ## ignore s_0
            # write(iob, "\\draw ($(sysvars[2][i])) -- ($(v));\n") ## ignore s_0
        end
    end
end

"""
Given the name of a layer, draws the links for that layer. Links are usually:
    from `iniparvar`s or `parvar`s (parameter variable nodes) to `fac`s (factor nodes)
    from the `parvar` (parameter variable node) to the linked `fac`s (factor nodes)
Used for all layers containing parameter variables (`parvar`s). These are typically:
`paramB1`, `paramB2`, `paramA`
"""
function draw_param_links(layers2, name, divn, y, iob)
    iniparvars = layers2[name].iniparvars
    n = length(iniparvars)
    inifac = layers2[name].inifac
    for (i, v) in enumerate(iniparvars)
        write(iob, "\\draw ($(v)) -- ([shift={(0, $(divn[n][i])*\\ns)}]$(inifac).south west);\n")
    end
    parvar = layers2[name].parvar
    write(iob, "\\draw ($(inifac)) -| ($(parvar));\n")
    linked_facs = layers2[name].links
    link_fac_side = layers2[name].link_fac_side
    link_fac_offset = layers2[name].link_fac_offset
    for k in 1:length(linked_facs)
        if length(linked_facs[k]) > 0
            if link_fac_side == "west"
                for i in linked_facs[k]
                    write(iob, "\\draw ($(parvar)) -| ([shift={(-\\nr,0)}]$(i).west) -- ([shift={(-$(radius(_CFFG)),0)}]$(i).center);\n")
                    ## write(iob, "\\draw ($(parvar)) -| ([shift={(-.5*\\ns,0)}]$(i).west) -- ([shift={(-$(radius(_CFFG)),$(link_fac_offset))}]$(i).center);\n")
                end
            elseif link_fac_side == "north"
                for i in linked_facs[k]
                    write(iob, "\\draw ($(parvar)) -| ([shift={(-$(link_fac_offset)*\\ns, 0)}]$(i).north east);\n")
                end
            else
                println("ERROR: Invalid link_fac_side $link_fac_side")
            end
        end
    end
end

"""
Given the name of a layer, draws the links for that layer. Links are usually:
    from `iniparvar`s or `parvar`s (parameter variable nodes) to `fac`s (factor nodes)
    from each `fac`/`var` to the next one in the `state0` sequence
Used for the `state0` layer. This is the layer containing factors (sometimes interleaved
with variables) associated with the state transition sequence.
"""
function draw_state_links(layers2, name, divn, y, iob)
    inifac = layers2[name].inifac
    sysvars = layers2[name].sysvars
    facs = layers2[name].facs
    if !isempty(inifac)
        iniparvars = layers2[name].iniparvars
        n = length(iniparvars)
        for (i, v) in enumerate(iniparvars)
            write(iob, "\\draw ($(v)) -- ([shift={(0, $(divn[n][i])*\\ns)}]$(inifac).south west);\n")
        end
        write(iob, "\\draw ($(inifac)) -| ($(sysvars[1]));\n")
    end
    if length(sysvars) > 0
        for k in 1:length(facs[1]) ## for each column
            seq = [row[k] for row in facs] ## seq k, col k
            println("seq: $seq")
            if length(seq) > 1
                write(iob, "\\draw ($(sysvars[k])) -- ([shift={(-$(radius(_CFFG)),0)}]$(seq[1]).center);\n")
                for i in 1:length(seq) - 1
                    write(iob, "\\draw ([shift={($(radius(_CFFG)),0)}]$(seq[i]).center) -- ([shift={(-$(radius(_CFFG)),0)}]$(seq[i+1]).center);\n")
                end
                write(iob, "\\draw ([shift={($(radius(_CFFG)),0)}]$(seq[end]).center) -- ($(sysvars[k+1]));\n")
            else
                println("... doing the else ...")
                for (i, v) in enumerate(facs[1])
                    write(iob, "\\draw ([shift={($(radius(_CFFG)),0)}]$(v).center) -- ($(sysvars[i+1]).west);\n")
                end
                for (i, v) in enumerate(facs[1])
                    write(iob, "\\draw ([shift={(-$(radius(_CFFG)),0)}]$(v).center) -- ($(sysvars[i]).east);\n")
                end
            end
        end
    end
    link_facs = layers2[name].links
    for (i, v) in enumerate(link_facs[1])
        write(iob, "\\draw ($(sysvars[i+1])) -- ($(v));\n") ## ignore s_0
    end
end
draw_state_links
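Links into a factor end a radius away from its center: the node radius \nr for a plain FFG, or the arc radius \ar when _CFFG is true. That is what the `[shift={(±radius, 0)}]` coordinates above achieve. A sketch of the link string for one state-transition step, reusing the logic of radius() (the helper `link` and the node names are hypothetical):

```julia
radius(constrained::Bool) = constrained ? "\\ar" : "\\nr"   # as in radius() above

# Link from the right edge of factor `a` to the left edge of factor `b`
link(a, b, constrained) =
    "\\draw ([shift={($(radius(constrained)),0)}]$(a).center) -- " *
    "([shift={(-$(radius(constrained)),0)}]$(b).center);\n"

print(link("T_1", "T_2", true))
# \draw ([shift={(\ar,0)}]T_1.center) -- ([shift={(-\ar,0)}]T_2.center);
```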
"""Given a GraphPPL.Model and a definition of layers, generates the `tikz` code."""
function generate_tikz(; heading::String, Model::GraphPPL.Model, layers2::Dict)
    preamble, postamble = get_preamble_and_postamble(heading)
    tikz = preamble
    divn = setup_divn(15)
    iob = IOBuffer()

    ## graph
    ytop = _ytop
    y = ytop
    write(iob, "%% --------------------------- NODES ---------------------------\n")
    if haskey(layers2, "cntrl")
        y -= _ysep
        draw_sysvar_layer(layers2, "cntrl", divn, y, iob)
    end
    if haskey(layers2, "paramB1")
        y -= _ysep
        draw_param_layer(layers2, "paramB1", divn, y, iob)
    end
    if haskey(layers2, "paramB2")
        y -= _ysep
        draw_param_layer(layers2, "paramB2", divn, y, iob)
    end
    if haskey(layers2, "stateM4")
        y -= _ysep
        draw_sysvar_layer(layers2, "stateM4", divn, y, iob)
    end
    if haskey(layers2, "stateM3")
        y -= _ysep
        draw_sysvar_layer(layers2, "stateM3", divn, y, iob)
    end
    if haskey(layers2, "stateM2")
        y -= _ysep
        draw_sysvar_layer(layers2, "stateM2", divn, y, iob)
    end
    if haskey(layers2, "stateM1")
        y -= _ysep
        draw_sysvar_layer(layers2, "stateM1", divn, y, iob)
    end
    if haskey(layers2, "state0")
        y -= _ysep
        draw_state_layer(layers2, "state0", divn, y, iob)
    end
    if haskey(layers2, "stateP1")
        y -= _ysep
        draw_sysvar_layer(layers2, "stateP1", divn, y, iob)
    end
    if haskey(layers2, "stateP2")
        y -= _ysep
        draw_sysvar_layer(layers2, "stateP2", divn, y, iob)
    end
    if haskey(layers2, "obser")
        y -= _ysep
        draw_sysvar_layer(layers2, "obser", divn, y, iob)
    end
    if haskey(layers2, "prefr")
        y -= _ysep
        draw_sysvar_layer(layers2, "prefr", divn, y, iob)
    end
    if haskey(layers2, "paramA")
        y -= _ysep
        draw_param_layer(layers2, "paramA", divn, y, iob)
    end

    write(iob, "%% --------------------------- LINKS ---------------------------\n")
    y = ytop
    if haskey(layers2, "cntrl")
        draw_sysvar_links(layers2, "cntrl", divn, y, iob)
    end
    if haskey(layers2, "paramB1")
        draw_param_links(layers2, "paramB1", divn, y, iob)
    end
    if haskey(layers2, "paramB2")
        y -= _ysep
        draw_param_links(layers2, "paramB2", divn, y, iob)
    end
    if haskey(layers2, "stateM4")
        draw_sysvar_links(layers2, "stateM4", divn, y, iob)
    end
    if haskey(layers2, "stateM3")
        draw_sysvar_links(layers2, "stateM3", divn, y, iob)
    end
    if haskey(layers2, "stateM2")
        draw_sysvar_links(layers2, "stateM2", divn, y, iob)
    end
    if haskey(layers2, "stateM1")
        draw_sysvar_links(layers2, "stateM1", divn, y, iob)
    end
    if haskey(layers2, "state0")
        draw_state_links(layers2, "state0", divn, y, iob)
    end
    if haskey(layers2, "stateP1")
        draw_sysvar_links(layers2, "stateP1", divn, y, iob)
    end
    if haskey(layers2, "stateP2")
        draw_sysvar_links(layers2, "stateP2", divn, y, iob)
    end
    if haskey(layers2, "obser")
        draw_sysvar_links(layers2, "obser", divn, y, iob)
    end
    if haskey(layers2, "prefr")
        draw_sysvar_links(layers2, "prefr", divn, y, iob)
    end
    if haskey(layers2, "paramA")
        draw_param_links(layers2, "paramA", divn, y, iob)
    end

    ## DRAW ALL RAW LINKS FOR REFERENCE
    if _raw_links_enabled
        meta_graph = Model.graph
        for edge in MetaGraphsNext.edges(meta_graph)
            source_vertex = MetaGraphsNext.label_for(meta_graph, edge.src)
            dest_vertex = MetaGraphsNext.label_for(meta_graph, edge.dst)
            write(iob, "\\draw[dotted, red, line width=1pt] ($(sanitized(string(source_vertex)))) -- ($(sanitized(string(dest_vertex))));\n")
        end
    end

    ## WRITE EQUAL NODES AGAIN
    if haskey(layers2, "paramB1")
        write(iob, "\\node($(layers2["paramB1"].parvar))[eql] at ($(layers2["paramB1"].parvar)) {\$=\$};\n")
    end
    if haskey(layers2, "paramB2")
        write(iob, "\\node($(layers2["paramB2"].parvar))[eql] at ($(layers2["paramB2"].parvar)) {\$=\$};\n")
    end
    if haskey(layers2, "paramA")
        write(iob, "\\node($(layers2["paramA"].parvar))[eql] at ($(layers2["paramA"].parvar)) {\$=\$};\n")
    end

    ## OVERLAY BEADS
    if _CFFG
        if haskey(layers2, "paramB1")
            write(iob, "\\draw ($(layers2["paramB1"].inifac).east) circle(\\br)[fill=white];\n")
        end
        if haskey(layers2, "paramB2")
            write(iob, "\\draw ($(layers2["paramB2"].inifac).east) circle(\\br)[fill=white];\n")
        end
        if haskey(layers2, "state0")
            write(iob, "\\draw ($(layers2["state0"].inifac).east) circle(\\br)[fill=white];\n")
        end
        if haskey(layers2, "paramA")
            write(iob, "\\draw ($(layers2["paramA"].inifac).east) circle(\\br)[fill=white];\n")
        end
        anchor = 8*_xsep
        obser_facs = layers2["obser"].facs
        for k in 1:length(obser_facs)
            if length(obser_facs[k]) > 0
                x = anchor + (k-1)*_xsep
                for (i, v) in enumerate(obser_facs[k])
                    if _CFFG
                        write(iob, "\\draw ($(v).north) circle(\\br)[fill=white];\n")
                        write(iob, "\\draw ($(v).south)++(-1*\\br,-1*\\br) rectangle([shift={(\\br,\\br)}]$(v).south)[fill=white];\n")
                    end
                    x += _tsep
                end
            end
        end
    end

    graph = String(take!(iob))
    tikz *= graph
    tikz *= postamble
    return tikz
end
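The NODES section of generate_tikz repeats the same pattern once per layer: if the layer exists, move y down by one separation and call the matching drawer. The same logic can also be written with a single ordered table. This is a minimal, hypothetical sketch, not the author's code. LAYER_SPECS, draw_layers! and the stub drawers are illustrative names, and the stubs only record which layer was drawn and where, standing in for draw_sysvar_layer, draw_param_layer and draw_state_layer.

```julia
# Hypothetical sketch: LAYER_SPECS, draw_layers! and the stub drawers are
# illustrative names, not part of the post's code.
const LAYER_SPECS = [
    ("cntrl", :sysvar), ("paramB1", :param), ("paramB2", :param),
    ("stateM4", :sysvar), ("stateM3", :sysvar), ("stateM2", :sysvar),
    ("stateM1", :sysvar), ("state0", :state), ("stateP1", :sysvar),
    ("stateP2", :sysvar), ("obser", :sysvar), ("prefr", :sysvar),
    ("paramA", :param),
]

# Walk the layers top to bottom; every layer present in `layers` moves y down
# by `ysep` before it is drawn, mirroring the NODES section of generate_tikz.
function draw_layers!(iob::IO, layers::AbstractDict, drawers::AbstractDict, divn;
                      ytop::Real = 0.0, ysep::Real = 1.0)
    y = ytop
    for (name, kind) in LAYER_SPECS
        haskey(layers, name) || continue   # skip layers this model does not use
        y -= ysep
        drawers[kind](layers, name, divn, y, iob)
    end
    return y
end

# Stub drawers (hypothetical) that only log the layer kind, name and y position
stub(kind) = (layers, name, divn, y, iob) -> write(iob, "%% $(kind) layer $(name) at y=$(y)\n")
drawers = Dict(:sysvar => stub(:sysvar), :param => stub(:param), :state => stub(:state))

layers = Dict("paramB1" => nothing, "state0" => nothing, "obser" => nothing)
iob = IOBuffer()
ylast = draw_layers!(iob, layers, drawers, nothing; ytop = 0.0, ysep = 1.5)
out = String(take!(iob))
print(out)
```

With this structure, adding a layer only needs one new entry in the table, and the LINKS section could reuse the table with a second set of drawers. The sketch does not reproduce the extra `y -= _ysep` that generate_tikz applies to paramB2 in its LINKS section.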
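The model is \(p(y, \theta) = \mathrm{Beta}(\theta \mid a, b)\prod_{i=1}^{N} \mathrm{Ber}(y_i \mid \theta)\), where \(y_i \in \{0, 1\}\) is a binary observation with a Bernoulli likelihood and \(\theta\) is the Bernoulli parameter, which has a Beta prior. We are interested in inferring the posterior distribution of \(\theta\).
See Fig 4.1 in 2023_Bagaev_RPPfSBI_Thesis. The graph therefore contains \(N\) \(\delta\) factors (one clamping each observation \(y_i\)), 1 Beta factor and \(N\) Bernoulli factors.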
@model function coin_model(y, a, b)
    ## We endow θ parameter of our model with some prior
    θ ~ Beta(a, b)
    ## We assume that outcome of each coin flip is governed by the Bernoulli distribution
    for i in eachindex(y)
        y[i] ~ Bernoulli(θ)
    end
end
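The Beta prior is conjugate to the Bernoulli likelihood, so the posterior that RxInfer computes by message passing on this graph has a closed form: \(\mathrm{Beta}(a + \sum_i y_i,\; b + N - \sum_i y_i)\). The plain-Julia sketch below (no RxInfer) gives a reference value for checking the inference result. coin_posterior is a hypothetical helper, not an RxInfer function.

```julia
# Conjugate Beta-Bernoulli update in plain Julia (no RxInfer needed).
# coin_posterior is a hypothetical helper, not part of RxInfer.
function coin_posterior(y::AbstractVector{<:Integer}, a::Real, b::Real)
    all(yi -> yi == 0 || yi == 1, y) ||
        throw(ArgumentError("y must contain only 0s and 1s"))
    heads = sum(y)                  # number of y_i equal to 1
    tails = length(y) - heads       # number of y_i equal to 0
    return (a + heads, b + tails)   # parameters of the posterior Beta
end

α, β = coin_posterior([1, 0, 1], 2.0, 2.0)   # Beta(4.0, 3.0)
posterior_mean = α / (α + β)                  # 4/7
```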
meta_graph = coin_meta_graph
## Shorten some labels to make graph more readable
# str_labels = [string(lab) for lab in labels(meta_graph)]
# replacements = 
#     Pair("MvNormalMeanCovariance", "Nmc"), 
#     Pair("MvNormalMeanPrecision", "Nmp"), 
#     Pair("constvar", "cv"),
#     Pair("x", "XXXXX") ## make more obvious
# short_labels = [replace(s, replacements...) for s in str_labels]
GraphPlot.gplot( ## existing plotting functionality
    meta_graph,
    layout=spring_layout,
    nodelabel=collect(labels(meta_graph)),
    ## nodelabel=short_labels,
    nodelabelsize=0.1,
    NODESIZE=0.02, ## diameter of the nodes
    NODELABELSIZE=1.5,
    # nodelabelc="white",
    nodelabelc="green",
    nodelabeldist=0.0,
    nodefillc=nothing, ## "cyan"
    edgestrokec="red",
    ## ImageSize = (800, 800) ##- does not work
)
obser_var_label_strs: ["y_7", "y_9", "y_11"]
obser_var_label_strs[i]: y_7
var_neighbours[j]: Bernoulli_8
in_neighbor_label: θ_1
out_neighbor_label: y_7
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(y_7, out, NodeData in context with properties name = y, index = 1), (θ_1, p, NodeData in context with properties name = θ, index = nothing)]
obser_var_label_strs[i]: y_9
var_neighbours[j]: Bernoulli_10
in_neighbor_label: θ_1
out_neighbor_label: y_9
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(y_9, out, NodeData in context with properties name = y, index = 2), (θ_1, p, NodeData in context with properties name = θ, index = nothing)]
obser_var_label_strs[i]: y_11
var_neighbours[j]: Bernoulli_12
in_neighbor_label: θ_1
out_neighbor_label: y_11
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(y_11, out, NodeData in context with properties name = y, index = 3), (θ_1, p, NodeData in context with properties name = θ, index = nothing)]
var_neighbours[1]: Beta_6
in_neighbor_label: constvar_2_3
out_neighbor_label: θ_1
inifac_out_neighbor_str: θ_1
inifac_str: Beta_6
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(θ_1, out, NodeData in context with properties name = θ, index = nothing), (constvar_2_3, a, NodeData in context with properties name = constvar_2, index = nothing), (constvar_4_5, b, NodeData in context with properties name = constvar_4, index = nothing)]
out_var: θ_1
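Each neighbors line in the trace above lists the variables attached to one factor. The kobus_graph also stores neighbours for variable nodes. Under the assumption that only factor-to-variable adjacency is available, one way to obtain them is to invert those factor lists. This is a minimal, hypothetical sketch with illustrative names, not the post's implementation, fed with the coin-toss labels from the trace.

```julia
# Hypothetical sketch: invert factor -> variable adjacency into
# variable -> factor "neighbours", keeping the given factor order.
function variable_neighbours(factor_vars::AbstractDict{String,<:AbstractVector{String}},
                             factor_order::AbstractVector{String})
    neighbours = Dict{String,Vector{String}}()
    for fac in factor_order
        for var in factor_vars[fac]
            # create an empty list the first time a variable is seen
            push!(get!(neighbours, var, String[]), fac)
        end
    end
    return neighbours
end

# Factor neighbour lists copied from the coin-toss trace above
coin_factor_vars = Dict(
    "Beta_6"       => ["θ_1", "constvar_2_3", "constvar_4_5"],
    "Bernoulli_8"  => ["y_7", "θ_1"],
    "Bernoulli_10" => ["y_9", "θ_1"],
    "Bernoulli_12" => ["y_11", "θ_1"],
)
nb = variable_neighbours(coin_factor_vars,
                         ["Beta_6", "Bernoulli_8", "Bernoulli_10", "Bernoulli_12"])
nb["θ_1"]   # ["Beta_6", "Bernoulli_8", "Bernoulli_10", "Bernoulli_12"]
```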
coin_tikz = generate_tikz(
    heading = "Coin Toss",
    Model = coin_gppl_model,
    layers2 = layers2) ## fac layers, i.e. not var (vars are handled with the associated layer)
## print(coin_tikz)
show_tikz(coin_tikz); ## write to file rather than show in notebook
@model function hidden_markov_model(x)
    B ~ MatrixDirichlet(ones(3, 3))
    A ~ MatrixDirichlet([10.0 1.0 1.0;
                         1.0 10.0 1.0;
                         1.0 1.0 10.0])
    d = fill(1.0/3.0, 3)
    s₀ ~ Categorical(d)
    sₖ₋₁ = s₀
    for k in eachindex(x)
        s[k] ~ Transition(sₖ₋₁, B)
        x[k] ~ Transition(s[k], A)
        sₖ₋₁ = s[k]
    end
end
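Read term by term, the model above factorizes as shown below. This is our transcription of the code, with N the length of x and Transition(s, A) read as a categorical distribution whose probability vector is \(\mathbf{A}\mathbf{s}\). A and B additionally receive the MatrixDirichlet priors given in the code.

\[
\begin{aligned}
p(\mathbf{x}, \mathbf{s}, \mathbf{A}, \mathbf{B}) &= p(\mathbf{A})\, p(\mathbf{B})\, p(\mathbf{s}_0) \prod_{k=1}^{N} p(\mathbf{s}_k \mid \mathbf{s}_{k-1}, \mathbf{B})\, p(\mathbf{x}_k \mid \mathbf{s}_k, \mathbf{A}),\\
p(\mathbf{s}_0) &= \mathrm{Cat}\big(\mathbf{s}_0 \mid \tfrac{1}{3}\mathbf{1}_3\big),\\
p(\mathbf{s}_k \mid \mathbf{s}_{k-1}, \mathbf{B}) &= \mathrm{Cat}(\mathbf{s}_k \mid \mathbf{B}\,\mathbf{s}_{k-1}),\\
p(\mathbf{x}_k \mid \mathbf{s}_k, \mathbf{A}) &= \mathrm{Cat}(\mathbf{x}_k \mid \mathbf{A}\,\mathbf{s}_k).
\end{aligned}
\]

The graph therefore has one Categorical factor, two MatrixDirichlet factors and \(2N\) Transition factors. With \(N = 2\) in the traces below, these are Categorical_12, MatrixDirichlet_4, MatrixDirichlet_8, and Transition_14 through Transition_20.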
meta_graph = hmm_meta_graph
## Shorten some labels to make graph more readable
# str_labels = [string(lab) for lab in labels(meta_graph)]
# replacements = 
#     Pair("MvNormalMeanCovariance", "Nmc"), 
#     Pair("MvNormalMeanPrecision", "Nmp"), 
#     Pair("constvar", "cv"),
#     Pair("x", "XXXXX") ## make more obvious
# short_labels = [replace(s, replacements...) for s in str_labels]
GraphPlot.gplot( ## existing plotting functionality
    meta_graph,
    layout=spring_layout,
    nodelabel=collect(labels(meta_graph)),
    ## nodelabel=short_labels,
    nodelabelsize=0.1,
    NODESIZE=0.02, ## diameter of the nodes
    NODELABELSIZE=1.5,
    # nodelabelc="white",
    nodelabelc="green",
    nodelabeldist=0.0,
    nodefillc=nothing, ## "cyan"
    edgestrokec="red",
    ## ImageSize = (800, 800) ##- does not work
)
var_neighbours[1]: MatrixDirichlet_4
in_neighbor_label: constvar_2_3
out_neighbor_label: B_1
inifac_out_neighbor_str: B_1
inifac_str: MatrixDirichlet_4
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(B_1, out, NodeData in context with properties name = B, index = nothing), (constvar_2_3, a, NodeData in context with properties name = constvar_2, index = nothing)]
out_var: B_1
var_neighbours[1]: Categorical_12
in_neighbor_label: constvar_10_11
out_neighbor_label: s₀_9
inifac_out_neighbor_str: s₀_9
inifac_str: Categorical_12
var: (s₀_9, out, NodeData in context with properties name = s₀, index = nothing)
=== NEXT var node is: s₀_9
var_label: s₀_9
var_index: 0
var_neighbours: Any["Categorical_12", "Transition_14"]
var_neighbours[1]: Categorical_12
in_neighbor_label: constvar_10_11
out_neighbor_label: s₀_9
var_neighbours[2]: Transition_14
in_neighbor_label: s₀_9
out_neighbor_label: s_13
=== NEXT fac node is: Transition_14
var: (s_13, out, NodeData in context with properties name = s, index = 1)
=== NEXT var node is: s_13
var_label: s_13
var_index: 1
var_neighbours: Any["Transition_14", "Transition_16", "Transition_18"]
var_neighbours[1]: Transition_14
in_neighbor_label: s₀_9
out_neighbor_label: s_13
var_neighbours[2]: Transition_16
in_neighbor_label: s_13
out_neighbor_label: x_15
var_neighbours[3]: Transition_18
in_neighbor_label: s_13
out_neighbor_label: s_17
=== NEXT fac node is: Transition_18
var: (s_17, out, NodeData in context with properties name = s, index = 2)
=== NEXT var node is: s_17
var_label: s_17
var_index: 2
var_neighbours: Any["Transition_18", "Transition_20"]
var_neighbours[1]: Transition_18
in_neighbor_label: s_13
out_neighbor_label: s_17
var_neighbours[2]: Transition_20
in_neighbor_label: s_17
out_neighbor_label: x_19
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(s₀_9, out, NodeData in context with properties name = s₀, index = nothing), (constvar_10_11, p, NodeData in context with properties name = constvar_10, index = nothing)]
out_var: s₀_9
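The trace above shows the state chain s₀_9 → s_13 → s_17 being discovered one variable at a time. The following is a hypothetical rule, offered only to make the trace easier to follow and not the post's algorithm: from each state variable, follow the factor whose in is that variable and whose out is another state variable, identified here by an assumed label prefix. The function name and arguments are illustrative. The rule would need extending for chains that pass through intermediate variables, such as gIsI and ghSum in the drone model.

```julia
# Hypothetical sketch: follow a state chain through a factor graph.
#   var_neighbours: variable label => factor labels
#   fac_io:         factor label   => (in, out) variable labels
function walk_state_chain(start::String, var_neighbours::AbstractDict,
                          fac_io::AbstractDict, state_prefix::String)
    chain = [start]
    current = start
    while true
        nxt = nothing
        for fac in get(var_neighbours, current, String[])
            fac_in, fac_out = fac_io[fac]
            # the next state is the output of a factor fed by the current state
            if fac_in == current && startswith(fac_out, state_prefix)
                nxt = fac_out
                break
            end
        end
        (nxt === nothing || nxt in chain) && break   # end of chain (or a cycle)
        push!(chain, nxt)
        current = nxt
    end
    return chain
end

# Labels copied from the HMM trace above
hmm_var_neighbours = Dict(
    "s₀_9" => ["Categorical_12", "Transition_14"],
    "s_13" => ["Transition_14", "Transition_16", "Transition_18"],
    "s_17" => ["Transition_18", "Transition_20"],
)
hmm_fac_io = Dict(
    "Categorical_12" => ("constvar_10_11", "s₀_9"),
    "Transition_14"  => ("s₀_9", "s_13"),
    "Transition_16"  => ("s_13", "x_15"),
    "Transition_18"  => ("s_13", "s_17"),
    "Transition_20"  => ("s_17", "x_19"),
)
chain = walk_state_chain("s₀_9", hmm_var_neighbours, hmm_fac_io, "s_")
```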
obser_var_label_strs: ["x_15", "x_19"]
obser_var_label_strs[i]: x_15
var_neighbours[j]: Transition_16
in_neighbor_label: s_13
out_neighbor_label: x_15
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(x_15, out, NodeData in context with properties name = x, index = 1), (s_13, in, NodeData in context with properties name = s, index = 1), (A_5, a, NodeData in context with properties name = A, index = nothing)]
obser_var_label_strs[i]: x_19
var_neighbours[j]: Transition_20
in_neighbor_label: s_17
out_neighbor_label: x_19
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(x_19, out, NodeData in context with properties name = x, index = 2), (s_17, in, NodeData in context with properties name = s, index = 2), (A_5, a, NodeData in context with properties name = A, index = nothing)]
var_neighbours[1]: MatrixDirichlet_8
in_neighbor_label: constvar_6_7
out_neighbor_label: A_5
inifac_out_neighbor_str: A_5
inifac_str: MatrixDirichlet_8
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(A_5, out, NodeData in context with properties name = A, index = nothing), (constvar_6_7, a, NodeData in context with properties name = constvar_6, index = nothing)]
out_var: A_5
hmm_tikz = generate_tikz(
    heading = "Hidden Markov Model",
    Model = hmm_gppl_model,
    layers2 = layers2) ## fac layers, i.e. not var (vars are handled with the associated layer)
## print(hmm_tikz)
show_tikz(hmm_tikz); ## write to file rather than show in notebook
meta_graph = lar_meta_graph
## Shorten some labels to make graph more readable
# str_labels = [string(lab) for lab in labels(meta_graph)]
# replacements = 
#     Pair("MvNormalMeanCovariance", "Nmc"), 
#     Pair("MvNormalMeanPrecision", "Nmp"), 
#     Pair("constvar", "cv"),
#     Pair("x", "XXXXX") ## make more obvious
# short_labels = [replace(s, replacements...) for s in str_labels]
GraphPlot.gplot( ## existing plotting functionality
    meta_graph,
    layout=spring_layout,
    nodelabel=collect(labels(meta_graph)),
    ## nodelabel=short_labels,
    nodelabelsize=0.1,
    NODESIZE=0.02, ## diameter of the nodes
    NODELABELSIZE=1.5,
    # nodelabelc="white",
    nodelabelc="green",
    nodelabeldist=0.0,
    nodefillc=nothing, ## "cyan"
    edgestrokec="red",
    ## ImageSize = (800, 800) ##- does not work
)
var_neighbours[1]: NormalMeanPrecision_12
in_neighbor_label: constvar_8_9
out_neighbor_label: θ_7
inifac_out_neighbor_str: θ_7
inifac_str: NormalMeanPrecision_12
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(θ_7, out, NodeData in context with properties name = θ, index = nothing), (constvar_8_9, μ, NodeData in context with properties name = constvar_8, index = nothing), (constvar_10_11, τ, NodeData in context with properties name = constvar_10, index = nothing)]
out_var: θ_7
var_neighbours[1]: GammaShapeRate_6
in_neighbor_label: constvar_2_3
out_neighbor_label: γ_1
inifac_out_neighbor_str: γ_1
inifac_str: GammaShapeRate_6
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(γ_1, out, NodeData in context with properties name = γ, index = nothing), (constvar_2_3, α, NodeData in context with properties name = constvar_2, index = nothing), (constvar_4_5, β, NodeData in context with properties name = constvar_4, index = nothing)]
out_var: γ_1
var_neighbours[1]: NormalMeanPrecision_18
in_neighbor_label: constvar_14_15
out_neighbor_label: s₀_13
inifac_out_neighbor_str: s₀_13
inifac_str: NormalMeanPrecision_18
var: (s₀_13, out, NodeData in context with properties name = s₀, index = nothing)
=== NEXT var node is: s₀_13
var_label: s₀_13
var_index: 0
var_neighbours: Any["NormalMeanPrecision_18", "AR_20"]
var_neighbours[1]: NormalMeanPrecision_18
in_neighbor_label: constvar_14_15
out_neighbor_label: s₀_13
var_neighbours[2]: AR_20
in_neighbor_label: s₀_13
out_neighbor_label: s_19
=== NEXT fac node is: AR_20
var: (s_19, y, NodeData in context with properties name = s, index = 1)
=== NEXT var node is: s_19
var_label: s_19
var_index: 1
var_neighbours: Any["AR_20", "*_24", "AR_30"]
var_neighbours[1]: AR_20
in_neighbor_label: s₀_13
out_neighbor_label: s_19
var_neighbours[2]: *_24
in_neighbor_label: constvar_22_23
out_neighbor_label: anonymous_var_graphppl_21
var_neighbours[3]: AR_30
in_neighbor_label: s_19
out_neighbor_label: s_29
=== NEXT fac node is: AR_30
var: (s_29, y, NodeData in context with properties name = s, index = 2)
=== NEXT var node is: s_29
var_label: s_29
var_index: 2
var_neighbours: Any["AR_30", "*_34", "AR_40"]
var_neighbours[1]: AR_30
in_neighbor_label: s_19
out_neighbor_label: s_29
var_neighbours[2]: *_34
in_neighbor_label: constvar_32_33
out_neighbor_label: anonymous_var_graphppl_31
var_neighbours[3]: AR_40
in_neighbor_label: s_29
out_neighbor_label: s_39
=== NEXT fac node is: AR_40
var: (s_39, y, NodeData in context with properties name = s, index = 3)
=== NEXT var node is: s_39
var_label: s_39
var_index: 3
var_neighbours: Any["AR_40", "*_44"]
var_neighbours[1]: AR_40
in_neighbor_label: s_29
out_neighbor_label: s_39
var_neighbours[2]: *_44
in_neighbor_label: constvar_42_43
out_neighbor_label: anonymous_var_graphppl_41
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(s₀_13, out, NodeData in context with properties name = s₀, index = nothing), (constvar_14_15, μ, NodeData in context with properties name = constvar_14, index = nothing), (constvar_16_17, τ, NodeData in context with properties name = constvar_16, index = nothing)]
out_var: s₀_13
obser_var_label_strs: ["x_25", "x_35", "x_45"]
obser_var_label_strs[i]: x_25
var_neighbours[j]: NormalMeanPrecision_28
in_neighbor_label: anonymous_var_graphppl_21
out_neighbor_label: x_25
obser_var_label_strs[i]: x_35
var_neighbours[j]: NormalMeanPrecision_38
in_neighbor_label: anonymous_var_graphppl_31
out_neighbor_label: x_35
obser_var_label_strs[i]: x_45
var_neighbours[j]: NormalMeanPrecision_48
in_neighbor_label: anonymous_var_graphppl_41
out_neighbor_label: x_45
lar_tikz = generate_tikz(
    heading = "Time-Varying Autoregressive Model (Univariate)",
    Model = lar_gppl_model,
    layers2 = layers2) ## fac layers, i.e. not var (vars are handled with the associated layer)
## print(lar_tikz)
show_tikz(lar_tikz); ## write to file rather than show in notebook
To infer goal-driven (i.e. purposeful) behavior, we add prior beliefs \(p^+(\mathbf{x})\) about desired future observations. This leads to an extended agent model:
\[p(\mathbf{x}_k \mid \mathbf{s}_k) = \mathcal{N}(\mathbf{x}_k \mid \mathbf{s}_k,\,\mathbf\Theta)\] where \(\mathbf{x}_k = (\chi_{1k}, \chi_{2k}, ...)\) denotes observations of the agent after interacting with the environment.
This means we set a vague prior for the initial state.
4.5.3.1 Generative Model for the Drone
The code in the next block defines the agent’s internal beliefs over the external dynamics and its probabilistic model of the environment. These match the environment exactly because the model reuses the functions defined above. We use the @model macro from RxInfer to define the probabilistic model, and where { meta=DeltaMeta(method=Unscented()) } clauses to specify the approximation method (the unscented transform) for the nonlinear state-transition and control functions.
In the model specification, in addition to the current state of the agent, we include beliefs over its future states (up to T steps ahead):
@model function dronenav_model(x, mᵤ, Vᵤ, mₓ, Vₓ, mₛ₍ₜ₋₁₎, Vₛ₍ₜ₋₁₎, T, Rᵃ)
    ## Transition function
    g = (sₜ₋₁::AbstractVector) -> begin
        sₜ = similar(sₜ₋₁) ## Next state
        sₜ = Aᵃ(sₜ₋₁, 1.0) + sₜ₋₁
        return sₜ
    end
    ## Function for modeling turn/yaw control
    h = (u::AbstractVector) -> Rᵃ(u[1])
    Γ = _γ*diageye(4) ## Transition precision
    𝚯 = _ϑ*diageye(4) ## Observation variance
    ## sₜ₋₁ ~ MvNormal(mean=mₛ₍ₜ₋₁₎, cov=Vₛ₍ₜ₋₁₎)
    s₀ ~ MvNormal(mean=mₛ₍ₜ₋₁₎, cov=Vₛ₍ₜ₋₁₎)
    ## sₖ₋₁ = sₜ₋₁
    sₖ₋₁ = s₀
    local s
    for k in 1:T
        ## Control
        u[k] ~ MvNormal(mean=mᵤ[k], cov=Vᵤ[k])
        hIuI[k] ~ h(u[k]) where { meta=DeltaMeta(method=Unscented()) }
        ## State transition
        gIsI[k] ~ g(sₖ₋₁) where { meta=DeltaMeta(method=Unscented()) }
        ghSum[k] ~ gIsI[k] + hIuI[k]
        s[k] ~ MvNormal(mean=ghSum[k], precision=Γ)
        ## Likelihood of future observations
        x[k] ~ MvNormal(mean=s[k], cov=𝚯)
        ## Target/Goal prior
        x[k] ~ MvNormal(mean=mₓ[k], cov=Vₓ[k])
        sₖ₋₁ = s[k]
    end
    return (s, )
end
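Reading the model above term by term gives the extended agent model, with the goal prior \(p^+(\mathbf{x}_k)\), in explicit form. This is our transcription of the code rather than an equation taken from another source. Γ enters as a precision (hence \(\mathbf{\Gamma}^{-1}\)), 𝚯 and the V's enter as covariances, and g and h are the functions behind the delta factors, whose messages are approximated with the unscented transform.

\[
\begin{aligned}
p(\mathbf{x}, \mathbf{s}, \mathbf{u}) &= p(\mathbf{s}_0) \prod_{k=1}^{T} p(\mathbf{u}_k)\, p(\mathbf{s}_k \mid \mathbf{s}_{k-1}, \mathbf{u}_k)\, p(\mathbf{x}_k \mid \mathbf{s}_k)\, p^+(\mathbf{x}_k),\\
p(\mathbf{s}_0) &= \mathcal{N}(\mathbf{s}_0 \mid \mathbf{m}_{s,t-1},\, \mathbf{V}_{s,t-1}),\\
p(\mathbf{u}_k) &= \mathcal{N}(\mathbf{u}_k \mid \mathbf{m}_{u,k},\, \mathbf{V}_{u,k}),\\
p(\mathbf{s}_k \mid \mathbf{s}_{k-1}, \mathbf{u}_k) &= \mathcal{N}\big(\mathbf{s}_k \mid g(\mathbf{s}_{k-1}) + h(\mathbf{u}_k),\, \mathbf{\Gamma}^{-1}\big),\\
p(\mathbf{x}_k \mid \mathbf{s}_k) &= \mathcal{N}(\mathbf{x}_k \mid \mathbf{s}_k,\, \mathbf{\Theta}),\\
p^+(\mathbf{x}_k) &= \mathcal{N}(\mathbf{x}_k \mid \mathbf{m}_{x,k},\, \mathbf{V}_{x,k}).
\end{aligned}
\]

Each step k therefore contributes the following factors: one prior factor for \(\mathbf{u}_k\), two delta factors (CH for h and SH for g in the traces), the + factor, one MvNormalMeanPrecision transition factor, and two MvNormalMeanCovariance factors on \(\mathbf{x}_k\). For x_23, the traces list the last two as MvNormalMeanCovariance_26 (fed by s_19) and MvNormalMeanCovariance_31 (fed by a constant).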
meta_graph = drone_meta_graph
## Shorten some labels to make graph more readable
str_labels = [string(lab) for lab in labels(meta_graph)]
replacements = (
    Pair("MvNormalMeanCovariance", "Nmc"),
    Pair("MvNormalMeanPrecision", "Nmp"),
    Pair("constvar", "cv"),
    Pair("x", "XXXXX")) ## make more obvious
short_labels = [replace(s, replacements...) for s in str_labels]
GraphPlot.gplot( ## existing plotting functionality
    meta_graph,
    layout=spring_layout,
    ## nodelabel=collect(labels(meta_graph)),
    nodelabel=short_labels,
    nodelabelsize=0.1,
    NODESIZE=0.02, ## diameter of the nodes
    NODELABELSIZE=1.5,
    # nodelabelc="white",
    nodelabelc="green",
    nodelabeldist=0.0,
    nodefillc=nothing, ## "cyan"
    edgestrokec="red",
    ## ImageSize = (800, 800) ##- does not work
)
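The label shortening above depends on replace accepting several pairs at once, which requires Julia 1.7 or later. In that form, all pairs are applied in a single left-to-right pass, so a replacement is never itself re-replaced. The small sketch below shows what the replacements do to labels from the drone trace. It also shows a side effect of the "x" pair: it matches inside other names too. shorten and LABEL_REPLACEMENTS are illustrative names, and `a => b` is the same as `Pair(a, b)`.

```julia
# Requires Julia >= 1.7 for multi-pair replace.
# shorten and LABEL_REPLACEMENTS are illustrative names.
const LABEL_REPLACEMENTS = (
    "MvNormalMeanCovariance" => "Nmc",
    "MvNormalMeanPrecision"  => "Nmp",
    "constvar"               => "cv",
    "x"                      => "XXXXX",   # deliberately loud, as in the cell above
)

# All pairs are applied simultaneously in one left-to-right scan of the label
shorten(label::AbstractString) = replace(label, LABEL_REPLACEMENTS...)

shorten("MvNormalMeanCovariance_26")   # "Nmc_26"
shorten("constvar_24_25")              # "cv_24_25"
shorten("x_23")                        # "XXXXX_23"
shorten("MatrixDirichlet_4")           # "MatriXXXXXDirichlet_4" ("x" also matches inside names)
```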
cntrl_var_label_strs: ["u_7", "u_32", "u_57"]
cntrl_var_label_strs[i]: u_7
var_neighbours[j]: MvNormalMeanCovariance_12
in_neighbor_label: constvar_8_9
out_neighbor_label: u_7
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(u_7, out, NodeData in context with properties name = u, index = 1), (constvar_8_9, μ, NodeData in context with properties name = constvar_8, index = nothing), (constvar_10_11, Σ, NodeData in context with properties name = constvar_10, index = nothing)]
var_neighbours[j]: CH_14
in_neighbor_label: u_7
out_neighbor_label: hIuI_13
cntrl_var_label_strs[i]: u_32
var_neighbours[j]: MvNormalMeanCovariance_37
in_neighbor_label: constvar_33_34
out_neighbor_label: u_32
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(u_32, out, NodeData in context with properties name = u, index = 2), (constvar_33_34, μ, NodeData in context with properties name = constvar_33, index = nothing), (constvar_35_36, Σ, NodeData in context with properties name = constvar_35, index = nothing)]
var_neighbours[j]: CH_39
in_neighbor_label: u_32
out_neighbor_label: hIuI_38
cntrl_var_label_strs[i]: u_57
var_neighbours[j]: MvNormalMeanCovariance_62
in_neighbor_label: constvar_58_59
out_neighbor_label: u_57
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(u_57, out, NodeData in context with properties name = u, index = 3), (constvar_58_59, μ, NodeData in context with properties name = constvar_58, index = nothing), (constvar_60_61, Σ, NodeData in context with properties name = constvar_60, index = nothing)]
var_neighbours[j]: CH_64
in_neighbor_label: u_57
out_neighbor_label: hIuI_63
var_neighbours[1]: MvNormalMeanCovariance_6
in_neighbor_label: constvar_2_3
out_neighbor_label: s₀_1
inifac_out_neighbor_str: s₀_1
inifac_str: MvNormalMeanCovariance_6
var: (s₀_1, out, NodeData in context with properties name = s₀, index = nothing)
=== NEXT var node is: s₀_1
var_label: s₀_1
var_index: 0
var_neighbours: Any["MvNormalMeanCovariance_6", "SH_16"]
var_neighbours[1]: MvNormalMeanCovariance_6
in_neighbor_label: constvar_2_3
out_neighbor_label: s₀_1
var_neighbours[2]: SH_16
in_neighbor_label: s₀_1
out_neighbor_label: gIsI_15
Take this var_label as next fac because we ONLY have 2 var_neighbours
=== NEXT fac node is: SH_16
var: (gIsI_15, out, NodeData in context with properties name = gIsI, index = 1)
=== NEXT var node is: gIsI_15
var_label: gIsI_15
var_index: 1
var_neighbours: Any["SH_16", "+_18"]
var_neighbours[1]: SH_16
in_neighbor_label: s₀_1
out_neighbor_label: gIsI_15
var_neighbours[2]: +_18
in_neighbor_label: gIsI_15
out_neighbor_label: ghSum_17
Take this var_label as next fac because we ONLY have 2 var_neighbours
=== NEXT fac node is: +_18
var: (ghSum_17, out, NodeData in context with properties name = ghSum, index = 1)
=== NEXT var node is: ghSum_17
var_label: ghSum_17
var_index: 1
var_neighbours: Any["+_18", "MvNormalMeanPrecision_22"]
var_neighbours[1]: +_18
in_neighbor_label: gIsI_15
out_neighbor_label: ghSum_17
var_neighbours[2]: MvNormalMeanPrecision_22
in_neighbor_label: ghSum_17
out_neighbor_label: s_19
Take this var_label as next fac because we ONLY have 2 var_neighbours
=== NEXT fac node is: MvNormalMeanPrecision_22
var: (s_19, out, NodeData in context with properties name = s, index = 1)
=== NEXT var node is: s_19
var_label: s_19
var_index: 1
var_neighbours: Any["MvNormalMeanPrecision_22", "MvNormalMeanCovariance_26", "SH_41"]
var_neighbours[1]: MvNormalMeanPrecision_22
in_neighbor_label: ghSum_17
out_neighbor_label: s_19
var_neighbours[2]: MvNormalMeanCovariance_26
in_neighbor_label: s_19
out_neighbor_label: x_23
var_neighbours[3]: SH_41
in_neighbor_label: s_19
out_neighbor_label: gIsI_40
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(s₀_1, out, NodeData in context with properties name = s₀, index = nothing), (constvar_2_3, μ, NodeData in context with properties name = constvar_2, index = nothing), (constvar_4_5, Σ, NodeData in context with properties name = constvar_4, index = nothing)]
out_var: s₀_1
obser_var_label_strs: ["x_23", "x_48", "x_73"]
obser_var_label_strs[i]: x_23
var_neighbours[j]: MvNormalMeanCovariance_26
in_neighbor_label: s_19
out_neighbor_label: x_23
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(x_23, out, NodeData in context with properties name = x, index = 1), (s_19, μ, NodeData in context with properties name = s, index = 1), (constvar_24_25, Σ, NodeData in context with properties name = constvar_24, index = nothing)]
var_neighbours[j]: MvNormalMeanCovariance_31
in_neighbor_label: constvar_27_28
out_neighbor_label: x_23
obser_var_label_strs[i]: x_48
var_neighbours[j]: MvNormalMeanCovariance_51
in_neighbor_label: s_44
out_neighbor_label: x_48
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(x_48, out, NodeData in context with properties name = x, index = 2), (s_44, μ, NodeData in context with properties name = s, index = 2), (constvar_49_50, Σ, NodeData in context with properties name = constvar_49, index = nothing)]
var_neighbours[j]: MvNormalMeanCovariance_56
in_neighbor_label: constvar_52_53
out_neighbor_label: x_48
obser_var_label_strs[i]: x_73
var_neighbours[j]: MvNormalMeanCovariance_76
in_neighbor_label: s_69
out_neighbor_label: x_73
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(x_73, out, NodeData in context with properties name = x, index = 3), (s_69, μ, NodeData in context with properties name = s, index = 3), (constvar_74_75, Σ, NodeData in context with properties name = constvar_74, index = nothing)]
var_neighbours[j]: MvNormalMeanCovariance_81
in_neighbor_label: constvar_77_78
out_neighbor_label: x_73
obser_var_label_strs: ["x_23", "x_48", "x_73"]
obser_var_label_strs[i]: x_23
var_neighbours[j]: MvNormalMeanCovariance_26
in_neighbor_label: s_19
out_neighbor_label: x_23
var_neighbours[j]: MvNormalMeanCovariance_31
in_neighbor_label: constvar_27_28
out_neighbor_label: x_23
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(x_23, out, NodeData in context with properties name = x, index = 1), (constvar_27_28, μ, NodeData in context with properties name = constvar_27, index = nothing), (constvar_29_30, Σ, NodeData in context with properties name = constvar_29, index = nothing)]
obser_var_label_strs[i]: x_48
var_neighbours[j]: MvNormalMeanCovariance_51
in_neighbor_label: s_44
out_neighbor_label: x_48
var_neighbours[j]: MvNormalMeanCovariance_56
in_neighbor_label: constvar_52_53
out_neighbor_label: x_48
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(x_48, out, NodeData in context with properties name = x, index = 2), (constvar_52_53, μ, NodeData in context with properties name = constvar_52, index = nothing), (constvar_54_55, Σ, NodeData in context with properties name = constvar_54, index = nothing)]
obser_var_label_strs[i]: x_73
var_neighbours[j]: MvNormalMeanCovariance_76
in_neighbor_label: s_69
out_neighbor_label: x_73
var_neighbours[j]: MvNormalMeanCovariance_81
in_neighbor_label: constvar_77_78
out_neighbor_label: x_73
neighbors: Tuple{GraphPPL.NodeLabel, GraphPPL.EdgeLabel, GraphPPL.NodeData}[(x_73, out, NodeData in context with properties name = x, index = 3), (constvar_77_78, μ, NodeData in context with properties name = constvar_77, index = nothing), (constvar_79_80, Σ, NodeData in context with properties name = constvar_79, index = nothing)]
drone_tikz = generate_tikz(
    heading = "Drone Flying to Target",
    Model = drone_gppl_model,
    layers2 = layers2) ## fac layers, i.e. not var (vars are handled with the associated layer)
## print(drone_tikz)
show_tikz(drone_tikz); ## write to file rather than show in notebook