Visualizing Forney Factor Graphs when using RxInfer (Part 3)
Getting closer to the final format
Bayesian Inference
Active Inference
TikZ
RxInfer
Julia
Author: Kobus Esterhuysen
Published: July 22, 2024
Modified: July 23, 2024
In Part 2 we provided a configurable layout scheme suitable for Forney Factor Graphs (FFGs). We also moved from GraphViz to TikZ (see Part 2 for an overview of GraphViz's limitations for our purposes). In this part we:
add decorations for a Constrained Forney Factor Graph (CFFG)
currently, constraint decorations are added blindly, rather than making their addition dependent on information from the RxInfer-derived graph
in a future part, we will rectify this
rename some variables in generate_tikz() to improve consistency and simplicity
remove the seqlup() function
remove the str_facs() function
remove the str_individual_vars() function
remove the str_vector_vars() function
add decorations for delta (also known as clamped) factors
One of the distinguishing features of the RxInfer Julia package is that it uses reactive message passing (RMP). This means that node updates in the factor graph happen asynchronously. Because of the added complexity, users tend to gloss over these details. When node updates happen in the more traditional, synchronous way (as in the PyMDP Python package, for example), users can more easily keep track of which node is currently being updated. It helps to conceptualize how factor graphs work in general, and a visualization of the factor graph itself is often useful. The existing visualization code is rather limited, which is why this project was undertaken. The ultimate goal is to reach a visualization that conforms to the Constrained Forney Factor Graph (CFFG) as defined in
GraphPPL.jl exports the @model macro for model specification. The macro accepts a regular Julia function and builds a Forney Factor Graph (FFG) under the hood. This graph is bipartite: every node belongs to one of two sets, one for factors and one for variables. Note that we use the terms node and link rather than vertex and edge to align with RxInfer terminology.
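To illustrate the bipartite structure, the following Base-Julia sketch represents a tiny coin-toss graph with two observations. It stores the two node sets plus a list of links and checks that every link joins a factor to a variable. The node names (`Beta_1`, `Ber_1`, `θ`, `y_1`, ...) and the helpers `is_bipartite_ffg` and `degree` are hypothetical. They do not reflect how GraphPPL.jl labels or stores its nodes, and the Beta hyperparameters are left out for brevity.

```julia
# Hypothetical node sets for a coin-toss model with two observations.
# These names are illustrative only; GraphPPL.jl uses its own labels.
facs = Set(["Beta_1", "Ber_1", "Ber_2"])
vars = Set(["θ", "y_1", "y_2"])

# Each link joins one factor node (first) to one variable node (second).
links = [
    ("Beta_1", "θ"),
    ("Ber_1", "θ"), ("Ber_1", "y_1"),
    ("Ber_2", "θ"), ("Ber_2", "y_2"),
]

# True when the node sets are disjoint and every link connects
# a factor to a variable (the bipartite property).
function is_bipartite_ffg(facs, vars, links)
    isempty(intersect(facs, vars)) || return false
    return all(l -> (l[1] in facs && l[2] in vars), links)
end

# Degree of a node = number of links it participates in.
degree(node, links) = count(l -> node == l[1] || node == l[2], links)

println(is_bipartite_ffg(facs, vars, links))  # true
println(degree("θ", links))                    # 3
```

The degree of a variable node matters for the visualization rules that follow. Here `θ` links to three factors, while each `y_i` links to only one.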
Node Taxonomy
nodes
facs (FFG/RxInfer factor nodes)
vars (FFG/RxInfer variable nodes)
parvars (parameter variables, e.g. \(\alpha, a, A_t, B, \gamma_t, \mu\))
sysvars (system variables, e.g. \(x_t, s_t, u_t\))
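One way to mirror this taxonomy in Julia is with a small abstract type hierarchy. This is a hypothetical sketch: the type names `FFGNode`, `FacNode`, `VarNode`, `ParVarNode`, and `SysVarNode` and the `category` function are our own and are not types exported by RxInfer or GraphPPL.

```julia
# Hypothetical Julia types mirroring the node taxonomy above.
abstract type FFGNode end

struct FacNode <: FFGNode          # FFG/RxInfer factor node
    name::String
    is_delta::Bool                 # true for delta (clamped) factors
end

abstract type VarNode <: FFGNode end   # FFG/RxInfer variable node

struct ParVarNode <: VarNode       # parameter variable, e.g. B, μ
    name::String
end

struct SysVarNode <: VarNode       # system variable, e.g. x_t, s_t, u_t
    name::String
end

# Short category tag for each kind of node (dispatch on the concrete type).
category(::FacNode) = "fac"
category(::ParVarNode) = "parvar"
category(::SysVarNode) = "sysvar"

nodes = FFGNode[FacNode("Beta", false), ParVarNode("B"), SysVarNode("s_t")]
println(map(category, nodes))
```

Multiple dispatch on the node type keeps the per-category logic in separate small methods rather than in one long `if` chain.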
Node Visualization
a fac node is visualized by a
square (style fac) if it is a factor
small black square (style dlt) if it is a delta (\(\delta\)) factor
delta factors connect to their associated parvars and sysvars
a var node is visualized by a
point (style var) if connected to a single fac
equals-square (style eql) if connected to multiple facs
rather than converting the bipartite graph from RxInfer to a graph that only contains a single kind of node, we keep the bipartite graph and visualize the two kinds of nodes differently
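The visualization rules above can be summarized as a small style-selection function. This sketch is built on our own assumptions. The hypothetical `node_style` takes a node kind (`:fac` or `:var`), a delta flag, and the number of factors the node links to. It returns one of the TikZ style names listed above.

```julia
# Hypothetical mapping from node properties to the TikZ style names
# fac, dlt, var, and eql described above.
function node_style(kind::Symbol; is_delta::Bool=false, n_facs::Int=1)
    if kind == :fac
        return is_delta ? "dlt" : "fac"   # delta factors get the small black square
    elseif kind == :var
        # a var linked to a single fac is a point; otherwise an equals-square
        return n_facs > 1 ? "eql" : "var"
    else
        throw(ArgumentError("unknown node kind: $kind"))
    end
end

println(node_style(:fac))                   # fac
println(node_style(:fac; is_delta=true))    # dlt
println(node_style(:var; n_facs=1))         # var
println(node_style(:var; n_facs=3))         # eql
```

Because the bipartite graph is kept, the style of each var node follows from its degree alone, without rewriting the graph.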
For each of the examples in this project the following steps happen:
A @model is created
The @model is conditioned on some data
usually a few points so that the visualization does not become unwieldy
An RxInfer model is created (rxi_model)
A GraphPPL model is obtained (gppl_model)
A meta graph is obtained (meta_graph)
The graph is rendered by the existing implementation (GraphPlot.gplot())
Some vertex labels are inspected out of interest
Global parameters (indicated by a leading underscore character) are set up to influence the behavior of generate_tikz()
_san_node_ids is provided for node names that need to be sanitized
_lup_enabled is set to true/false to enable lookup of math names
_raw_links_enabled is set to true/false to enable the drawing of raw links for reference
_CFFG is set to true/false to enable the drawing of beads to turn an FFG into a CFFG
_lup_dict is provided for all variable names needing math equivalents
The TikZ string is generated (generate_tikz())
heading of the FFG
GraphPPL model
layers for the FFG
layer names (if applicable)
cntrl (control layer)
paramB1 (parameter B1 layer)
paramB2 (parameter B2 layer)
state1 (state 1 layer)
state2 (state 2 layer)
state3 (state 3 layer)
obser (observation layer)
prefr (preference layer, for setpoints)
paramA (parameter A layer)
The TikZ string is printed
The FFG is saved to a pdf file (show_tikz())
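The global parameters listed above might be set up as follows. All values here are hypothetical and only illustrate the shape of each setting. The keys of `_san_node_ids` and `_lup_dict` depend on the labels GraphPPL assigns in a particular model, and we do not reproduce those labels here.

```julia
# Hypothetical global settings for generate_tikz(); the keys below are
# illustrative and must be replaced by the labels of the actual meta graph.
_san_node_ids = Dict(
    "θ_1" => "theta_1",          # map a node name to a TikZ-friendly node ID
)
_lup_enabled = true              # look up math names in _lup_dict
_raw_links_enabled = false       # do not draw the raw links for reference
_CFFG = true                     # draw beads to turn the FFG into a CFFG
_lup_dict = Dict(
    "θ_1" => "\\theta",          # math equivalent typeset by LaTeX
    "y_1" => "y_1",
)
```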
Setup notes
If additional Ubuntu packages are needed for a Julia package
sudo apt-get install package
If you want to run bash on the devcontainer from a local terminal: kobus@Kobuss-Mac-mini ~ %
Activating project at `~/.julia/environments/v1.10`
Resolving package versions...
No Changes to `~/.julia/environments/v1.10/Project.toml`
No Changes to `~/.julia/environments/v1.10/Manifest.toml`
Resolving package versions...
No Changes to `~/.julia/environments/v1.10/Project.toml`
No Changes to `~/.julia/environments/v1.10/Manifest.toml`
Status `~/.julia/environments/v1.10/Project.toml`
[fa8bd995] MetaGraphsNext v0.7.0
⌃ [86711068] RxInfer v3.0.0
Info Packages marked with ⌃ have new versions available and may be upgradable.
"""
Takes the tikz string given by generate_tikz() and writes it to a .tex file to produce a TikZ visualisation.
"""
function show_tikz(tikz_code_graph::String)
    ## eval(Meta.parse(dot_code_graph)) ## for GraphViz
    ## see problem in ^v1 under TikzPictures section
    ## tikz_picture = TikzPicture(tikz_code_graph)
    ## return tikz_picture
    file_path = "myoutput.tex" ## write to file rather than show in notebook
    open(file_path, "w") do file
        write(file, tikz_code_graph)
    end
end
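`show_tikz()` only writes the TikZ string to `myoutput.tex`, so producing a PDF still requires a LaTeX run. The following is a minimal sketch that rests on two assumptions: the string holds a bare `tikzpicture` environment rather than a complete document, and `pdflatex` is installed. Under those assumptions the picture can be wrapped in a `standalone` document before it is written. The helpers `wrap_standalone` and `compile_pdf` are hypothetical.

```julia
# Hypothetical helper: wrap a bare tikzpicture in a standalone LaTeX document
# so that the written .tex file can be compiled directly to a cropped PDF.
function wrap_standalone(tikz::String)
    return """
    \\documentclass[tikz]{standalone}
    \\begin{document}
    $(tikz)
    \\end{document}
    """
end

# Hypothetical helper: compile with pdflatex (assumes a TeX installation).
function compile_pdf(file_path::String)
    Sys.which("pdflatex") === nothing && error("pdflatex not found on PATH")
    run(`pdflatex -interaction=nonstopmode $file_path`)
end

tikz = "\\begin{tikzpicture}\n\\node[draw] (f) {\$f\$};\n\\end{tikzpicture}"
doc = wrap_standalone(tikz)
write("myoutput.tex", doc)
```

Calling `compile_pdf("myoutput.tex")` afterwards would then produce `myoutput.pdf` next to the `.tex` file.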
###
# # If your PDF contains a TikZ picture (a vector graphic created using LaTeX),
# # you can still use the method I mentioned earlier to display it in a Julia notebook.
# # Here’s the method again for your reference:
# import Pkg; Pkg.add("PGFPlotsX")
# using PGFPlotsX
# # Create a struct that represents a PDF file
# struct PDF
#     file::String
# end
# # Define a function to show the PDF as inline SVG
# function Base.show(io::IO, ::MIME"image/svg+xml", pdf::PDF)
#     svg = first(splitext(pdf.file)) * ".svg"
#     PGFPlotsX.convert_pdf_to_svg(pdf.file, svg)
#     write(io, read(svg))
#     try; rm(svg; force=true); catch e; end
#     return nothing
# end
# # Now you can display a PDF file like this:
# display(PDF("myoutput.pdf"))
# # Internal Error: cairo context error: invalid matrix (not invertible)
# # cairo error: invalid matrix (not invertible)
import MetaGraphsNext

## return the sanitized ID of a node name if one is registered in _san_node_ids
function sanitized(vertex::String)
    if haskey(_san_node_ids, vertex)
        return _san_node_ids[vertex]
    else
        return vertex
    end
end

## look up the math name of a label; otherwise escape underscores for LaTeX
function lup(label::String)
    if _lup_enabled
        return get(_lup_dict, label, replace(label, "_" => "\\_"))
    else
        return replace(label, "_" => "\\_")
    end
end

function show_meta_graph(meta_graph) ## just for info
    str = ""
    for i in MetaGraphsNext.vertices(meta_graph)
        label = MetaGraphsNext.label_for(meta_graph, i)
        str *= " $(label);\n"
    end
    for edge in MetaGraphsNext.edges(meta_graph)
        source_vertex = MetaGraphsNext.label_for(meta_graph, edge.src)
        dest_vertex = MetaGraphsNext.label_for(meta_graph, edge.dst)
        str *= " $(source_vertex) -- $(dest_vertex);\n"
    end
    print(str)
end
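The following self-contained sketch shows how `sanitized()` and `lup()` behave. The dictionary entries are hypothetical and are not the ones used for the models in this post. The sketch restates both functions compactly: `get` with a default is equivalent to the `haskey` branches, and `lup()` returns its result explicitly. Labels missing from `_lup_dict` fall back to the raw label with underscores escaped for LaTeX.

```julia
# Hypothetical settings, for illustration only.
_san_node_ids = Dict("θ_1" => "theta_1")
_lup_enabled = true
_lup_dict = Dict("θ_1" => "\\theta")

# Replace a node name by its sanitized ID if one is registered.
sanitized(vertex::String) = get(_san_node_ids, vertex, vertex)

# Look up the math name of a label; otherwise escape underscores for LaTeX.
function lup(label::String)
    fallback = replace(label, "_" => "\\_")
    return _lup_enabled ? get(_lup_dict, label, fallback) : fallback
end

println(sanitized("θ_1"))   # theta_1
println(sanitized("y_1"))   # y_1
println(lup("θ_1"))         # \theta
println(lup("y_1"))         # y\_1
```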
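## n evenly spaced fractions in (0, 1), e.g. spaces(3) == [0.25, 0.5, 0.75]
function spaces(n::Int)
    d = 1/(n + 1)
    return [i*d for i in 1:n]
end

## cache spaces(i) for every i from 1 to n
function setup_divn(n)
    divn = Dict()
    for i in 1:n
        divn[i] = spaces(i)
    end
    return divn
end

function radius(constrained::Bool)
    if constrained
        return "\\ar" ## arc radius of a fac node
    else
        return "\\nr" ## node radius of a fac node
    end
end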
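`spaces(n)` returns `n` fractions strictly between 0 and 1 with spacing \(1/(n+1)\), and `setup_divn(n)` caches them for link counts 1 to `n`. A plausible use is to distribute the attachment points of links along one side of a fac node. That use is our assumption and is not stated in the post. The arithmetic for a 15 mm side with three links, where `side` and `offsets` are hypothetical names:

```julia
# Same definitions as above: n evenly spaced fractions in (0, 1).
function spaces(n::Int)
    d = 1/(n + 1)
    return [i*d for i in 1:n]
end

# Cache the fractions for every link count from 1 to n.
function setup_divn(n)
    divn = Dict()
    for i in 1:n
        divn[i] = spaces(i)
    end
    return divn
end

divn = setup_divn(3)
println(divn[1])   # [0.5]
println(divn[3])   # [0.25, 0.5, 0.75]

# Hypothetical use: offsets (in mm) of 3 links along a 15 mm side of a fac node.
side = 15.0
offsets = side .* divn[3]
println(offsets)   # [3.75, 7.5, 11.25]
```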
_ytop = 18 ## grid y-value at the top of picture
_show_grid = false ## displays a grid with the FFG if true
_node_size = "15mm" ## the length of the sides of factor node boxes (in millimeters)
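The following sketch shows one way such globals could feed into the TikZ output. The hypothetical `tikz_styles` function opens a `tikzpicture` with style definitions for the node styles fac, dlt, var, and eql, sized by `_node_size`, and optionally adds a help grid up to `_ytop`. Several details are our own assumptions rather than the definitions used by `generate_tikz()`: the fills, the 3mm delta size, and the grid width `xmax`.

```julia
_ytop = 18           # grid y-value at the top of the picture
_show_grid = false   # display a grid with the FFG if true
_node_size = "15mm"  # side length of factor node boxes

# Hypothetical: build the opening of a tikzpicture with the node styles
# used in this post; style details are illustrative assumptions.
function tikz_styles(node_size::String, show_grid::Bool, ytop::Int; xmax::Int=20)
    lines = [
        "\\begin{tikzpicture}[",
        "  fac/.style={draw, minimum size=$(node_size)},",
        "  dlt/.style={draw, fill=black, minimum size=3mm},",
        "  var/.style={circle, fill=black, inner sep=1pt},",
        "  eql/.style={draw, minimum size=$(node_size)}",
        "]",
    ]
    if show_grid
        push!(lines, "\\draw[help lines] (0,0) grid ($(xmax),$(ytop));")
    end
    return join(lines, "\n")
end

println(tikz_styles(_node_size, _show_grid, _ytop))
```

The caller would then append the nodes and links and close the picture with `\end{tikzpicture}`.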
where \(y_i \in \{0, 1\}\) is a binary observation induced by a Bernoulli likelihood, while \(\theta\) is the parameter of the Bernoulli distribution, endowed with a Beta prior. We are interested in inferring the posterior distribution of \(\theta\).
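For reference, the following is the standard Beta-Bernoulli formulation of the model described here and specified in `coin_model()`, where \(a\) and \(b\) are the Beta hyperparameters. Each factor on the right-hand side of the first line corresponds to one factor node: one Beta factor and \(N\) Ber factors.

```latex
\begin{aligned}
p(y_{1:N}, \theta) &= p(\theta) \prod_{i=1}^{N} p(y_i \mid \theta), \\
p(\theta) &= \mathrm{Beta}(\theta \mid a, b), \\
p(y_i \mid \theta) &= \mathrm{Ber}(y_i \mid \theta) = \theta^{y_i} (1 - \theta)^{1 - y_i}.
\end{aligned}
```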
See Fig. 4.1 in 2023_Bagaev_RPPfSBI_Thesis. The graph therefore contains \(N\) \(\delta\) factors, 1 Beta factor, and \(N\) Ber factors.
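The total number of factor nodes for \(N\) observations is therefore \(N + 1 + N = 2N + 1\). The hypothetical helper `coin_factor_counts` below only encodes the counts stated above. It does not query RxInfer.

```julia
# Factor counts for the coin model with N observations, as stated above.
function coin_factor_counts(N::Int)
    counts = Dict("delta" => N, "Beta" => 1, "Ber" => N)
    return counts, sum(values(counts))
end

counts, total = coin_factor_counts(4)
println(counts["Ber"], " Ber factors, ", total, " factors in total")  # 4 Ber factors, 9 factors in total
```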
using RxInfer

@model function coin_model(y, a, b)
    ## We endow the θ parameter of our model with a Beta prior
    θ ~ Beta(a, b)
    ## We assume that the outcome of each coin flip is governed by a Bernoulli distribution
    for i in eachindex(y)
        y[i] ~ Bernoulli(θ)
    end
end
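The Beta prior is conjugate to the Bernoulli likelihood, so the exact posterior for this model has a closed form: \(p(\theta \mid y_{1:N}) = \mathrm{Beta}(a + \sum_i y_i,\ b + N - \sum_i y_i)\). The Base-Julia sketch below computes it, which can serve as a sanity check on an RxInfer inference result. The helper names, the data vector, and the hyperparameters are made up for illustration.

```julia
# Closed-form posterior of the Beta-Bernoulli model (conjugate update).
function beta_bernoulli_posterior(y::AbstractVector{<:Integer}, a::Real, b::Real)
    all(v -> v == 0 || v == 1, y) || throw(ArgumentError("observations must be 0 or 1"))
    n_heads = sum(y)
    n_tails = length(y) - n_heads
    return (a + n_heads, b + n_tails)
end

# Mean of a Beta(α, β) distribution.
beta_mean(α, β) = α / (α + β)

y = [1, 0, 1, 1]                       # illustrative coin flips
α, β = beta_bernoulli_posterior(y, 1.0, 1.0)
println((α, β))                        # (4.0, 2.0)
println(beta_mean(α, β))
```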
@model function hidden_markov_model(x)
    B ~ MatrixDirichlet(ones(3, 3))
    A ~ MatrixDirichlet([10.0 1.0 1.0;
                         1.0 10.0 1.0;
                         1.0 1.0 10.0])
    d = fill(1.0/3.0, 3)
    s₀ ~ Categorical(d)
    sₖ₋₁ = s₀
    for k in eachindex(x)
        s[k] ~ Transition(sₖ₋₁, B)
        x[k] ~ Transition(s[k], A)
        sₖ₋₁ = s[k]
    end
end
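To get a feel for the data this model describes, its generative process can be simulated in Base Julia. This sketch rests on two assumptions. First, we assume a column-stochastic convention for `Transition`, where column \(j\) of a matrix holds the distribution of the next value given state \(j\). Second, fixed made-up matrices stand in for draws from the `MatrixDirichlet` priors. The helpers `sample_categorical` and `simulate_hmm` are hypothetical.

```julia
using Random

# Draw an index from a probability vector p by inverting its cumulative sum.
function sample_categorical(rng::AbstractRNG, p::AbstractVector{<:Real})
    u = rand(rng)
    idx = findfirst(>=(u), cumsum(p))
    return idx === nothing ? length(p) : idx   # guard against rounding at the end
end

# Simulate states s and observations x for T steps.
# Assumption: column j of B (resp. A) is the distribution given state j.
function simulate_hmm(rng::AbstractRNG, B::AbstractMatrix, A::AbstractMatrix,
                      d::AbstractVector, T::Int)
    s = zeros(Int, T)
    x = zeros(Int, T)
    s_prev = sample_categorical(rng, d)                  # s₀ ~ Categorical(d)
    for k in 1:T
        s[k] = sample_categorical(rng, B[:, s_prev])     # s[k] given s[k-1]
        x[k] = sample_categorical(rng, A[:, s[k]])       # x[k] given s[k]
        s_prev = s[k]
    end
    return s, x
end

# Made-up column-stochastic matrices standing in for draws from the priors.
B = [0.8 0.1 0.1;
     0.1 0.8 0.1;
     0.1 0.1 0.8]
A = [0.9 0.05 0.05;
     0.05 0.9 0.05;
     0.05 0.05 0.9]
d = fill(1.0/3.0, 3)

s, x = simulate_hmm(MersenneTwister(42), B, A, d, 5)
println("states: ", s)
println("observations: ", x)
```

Conditioning `hidden_markov_model` on a short `x` like this keeps the number of Transition factors (two per time step) small enough for a readable FFG.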