Visualizing Forney Factor Graphs when using RxInfer (Part 2)
Exploring possibilities using TikZ
Bayesian Inference
Active Inference
TikZ
RxInfer
Julia
Author
Kobus Esterhuysen
Published
July 6, 2024
Modified
August 15, 2024
In Part 1 we laid a foundation for the visualization of Forney Factor Graphs (FFGs) when using RxInfer. In the current part (Part 2) we provide a satisfactory way to lay out the graph in a grid format. We can also switch between using the raw variable names from RxInfer and the equivalent mathematical symbols.
In this notebook, we abandon the use of GraphViz in favor of TikZ. We encountered certain limitations in the GraphViz package that turn out to be less than ideal for our purpose. Here is a summary of the limitations:
GraphViz dot layout: allows ortho edges; does not allow node placement at coordinates with pos
GraphViz neato layout: does not allow ortho edges; allows node placement at coordinates with pos
GraphViz fdp layout: allows ortho edges; allows node placement at coordinates with pos
Although fdp allows both, nodes at the same x or y coordinate do not always line up properly. This is due to force-directed placement, in which nodes repel and edges act as springs.
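To make the contrast concrete, here is a minimal sketch of what TikZ offers: explicit node coordinates together with axis-aligned edges. The helper `grid_tikz` and the node names are hypothetical and are not part of the notebook's `generate_tikz()`; they only illustrate the two capabilities discussed above.

```julia
# Hypothetical helper (not from the notebook): build a TikZ picture in which
# every node sits at an explicit grid coordinate and every edge is axis-aligned.
function grid_tikz(nodes::Vector{Tuple{String,Int,Int}},
                   edges::Vector{Tuple{String,String}})
    io = IOBuffer()
    println(io, "\\begin{tikzpicture}")
    for (name, col, row) in nodes
        # Explicit placement, the counterpart of GraphViz's `pos` attribute
        println(io, "  \\node[draw, rectangle] ($name) at ($col, $row) {$name};")
    end
    for (src, dst) in edges
        # `-|` draws a horizontal segment followed by a vertical one
        println(io, "  \\draw ($src) -| ($dst);")
    end
    println(io, "\\end{tikzpicture}")
    return String(take!(io))
end

nodes = [("f1", 0, 0), ("f2", 2, 0), ("f3", 2, 2)]
edges = [("f1", "f2"), ("f1", "f3")]
tikz = grid_tikz(nodes, edges)
print(tikz)
```

Because the coordinates are written out literally, nodes that share an x or y value line up exactly, which is the property fdp could not guarantee.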
One of the distinguishing features of the RxInfer Julia package is that it uses reactive message passing (RMP). This means node updates in the factor graph happen asynchronously. Users tend to gloss over these details because of the added complexity. When node updates happen in the more traditional, synchronous way (as with the PyMDP Python package, for example), users can usually tell which node is currently being updated. Conceptualizing how factor graphs work is quite beneficial, and a visualization of the factor graph itself often helps. The existing visualization code is rather limited, which is why this project was undertaken. The ultimate goal is to reach a visualization that conforms to the Constrained Forney Factor Graph (CFFG) as defined in
GraphPPL.jl exports the @model macro for model specification. The macro accepts a regular Julia function and builds an FFG under the hood. This graph is bipartite: each vertex belongs to one of two sets, one for factors and one for variables. We will render these vertices as follows:
factor vertex is rendered by a rectangle
variable vertex is rendered by a
point if connected to a single factor
equals-rectangle if connected to multiple factors
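The rendering rule above can be sketched as a small function. The name `vertex_shape` and the symbols it returns are illustrative assumptions, not identifiers from the notebook, and treating an unconnected variable as a point is also an assumption.

```julia
# Hypothetical sketch of the rendering rule; names are illustrative only.
function vertex_shape(kind::Symbol, degree::Int)
    if kind == :factor
        return :rectangle
    elseif kind == :variable
        # More than one neighbouring factor: draw an equality node.
        # Otherwise (assumption: also for degree 0) draw a point.
        return degree > 1 ? :equals_rectangle : :point
    else
        throw(ArgumentError("unknown vertex kind: $kind"))
    end
end

println(vertex_shape(:factor, 3))
println(vertex_shape(:variable, 1))
println(vertex_shape(:variable, 2))
```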
For each of the examples in this project the following steps happen:
A @model is created
The @model is conditioned on some data
usually a few points so that the visualization does not become unwieldy
A RxInfer model is created (rxi_model)
A GraphPPL model is obtained (gppl_model)
A meta graph is obtained (meta_graph)
The graph is rendered by the existing implementation (GraphPlot.gplot())
Some vertex labels are inspected for the sake of interest
Set up global parameters to influence the behavior of generate_tikz()
_san_node_ids is provided for node names that need to be sanitized
_lup_enabled is set to true/false to enable lookup of math names
_seqlup_enabled is set to true/false to enable lookup/sequential lookup of math names
_raw_links_enabled is set to true/false to enable the drawing of raw links for reference
_lup_dict is provided for all variable names needing math equivalents
The TikZ string is generated (generate_tikz())
heading of the FFG
GraphPPL model
layers for the FFG
variable names for each layer (if applicable)
“cntrl” (control layer)
“paramB1” (parameter B1 layer)
“paramB2” (parameter B2 layer)
“state1” (state 1 layer)
“state2” (state 2 layer)
“state3” (state 3 layer)
“obser” (observation layer)
“prefr” (preference layer, for setpoints)
“paramA” (parameter A layer)
The TikZ string is printed
The FFG is written to a .tex file (show_tikz()), from which a PDF can be compiled
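As an illustration of the global parameters step, the settings might look like the sketch below. The parameter names come from the list above, but every key and value shown is a made-up placeholder; the real entries depend on the model being visualized.

```julia
# Illustrative values only: all keys and values below are assumptions.
_san_node_ids = Dict("s₀" => "s0")     # node names that need sanitizing
_lup_enabled = true                    # look up math names
_seqlup_enabled = false                # sequential lookup of math names
_raw_links_enabled = false             # draw raw links for reference
_lup_dict = Dict("θ" => "\$\\theta\$", "y" => "\$y\$")

# With lookup enabled, a raw name resolves to its math equivalent,
# and a name without an entry falls back to the raw name.
label = "θ"
shown = _lup_enabled ? get(_lup_dict, label, label) : label
println(shown)
```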
versioninfo() ## Julia version
Julia Version 1.10.0
Commit 3120989f39b (2023-12-25 18:01 UTC)
Build Info:
Official https://julialang.org/ release
Platform Info:
OS: Linux (x86_64-linux-gnu)
CPU: 12 × Intel(R) Core(TM) i7-8700B CPU @ 3.20GHz
WORD_SIZE: 64
LIBM: libopenlibm
LLVM: libLLVM-15.0.7 (ORCJIT, skylake)
Threads: 1 on 12 virtual cores
Environment:
JULIA_NUM_THREADS =
Activating project at `~/.julia/environments/v1.10`
Resolving package versions...
No Changes to `~/.julia/environments/v1.10/Project.toml`
No Changes to `~/.julia/environments/v1.10/Manifest.toml`
Resolving package versions...
No Changes to `~/.julia/environments/v1.10/Project.toml`
No Changes to `~/.julia/environments/v1.10/Manifest.toml`
using RxInfer, MetaGraphsNext
Pkg.status()
Status `~/.julia/environments/v1.10/Project.toml`
[86223c79] Graphs v1.11.2
[7073ff75] IJulia v1.25.0
[fa8bd995] MetaGraphsNext v0.7.0
[8314cec4] PGFPlotsX v1.6.1
⌃ [86711068] RxInfer v3.0.0
[b4f28e30] TikzGraphs v1.4.0
[37f6aa50] TikzPictures v3.5.0
Info Packages marked with ⌃ have new versions available and may be upgradable.
"""
Takes the tikz string given by generate_tikz() and writes it to a .tex file to produce a TikZ visualisation.
"""
function show_tikz(tikz_code_graph::String)
    ## eval(Meta.parse(dot_code_graph)) ## for GraphViz
    ## see problem in ^v1 under TikzPictures section
    ## tikz_picture = TikzPicture(tikz_code_graph)
    ## return tikz_picture
    file_path = "myoutput.tex" ## write to file rather than show in notebook
    open(file_path, "w") do file
        write(file, tikz_code_graph)
    end
end
show_tikz
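Since show_tikz() only writes a .tex file, a separate compilation step is needed to obtain a PDF. The sketch below wraps a bare tikzpicture in a standalone document and writes it to a temporary directory. Whether generate_tikz() already emits its own preamble is not shown in this section, so the wrapper `wrap_standalone` is a hypothetical helper, and the pdflatex call is left commented out because it depends on a local LaTeX installation.

```julia
# Hypothetical helper: wrap a bare tikzpicture in a standalone LaTeX document.
function wrap_standalone(tikz_body::String)
    return """
    \\documentclass[tikz, border=2pt]{standalone}
    \\begin{document}
    $(tikz_body)
    \\end{document}
    """
end

# A tiny stand-in for the string that generate_tikz() would produce
tikz_body = "\\begin{tikzpicture}\n  \\node[draw] (f) at (0, 0) {f};\n\\end{tikzpicture}"

tex_path = joinpath(mktempdir(), "myoutput.tex")
write(tex_path, wrap_standalone(tikz_body))
println("wrote ", tex_path)

# Compiling requires pdflatex on the PATH; uncomment to try it locally.
# run(`pdflatex -output-directory=$(dirname(tex_path)) $tex_path`)
```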
###
# If your PDF contains a TikZ picture (a vector graphic created using LaTeX),
# # you can still use the method I mentioned earlier to display it in a Julia notebook.
# # Here’s the method again for your reference:
# import Pkg; Pkg.add("PGFPlotsX")
# using PGFPlotsX
# # Create a struct that represents a PDF file
# struct PDF
#     file::String
# end
# # Define a function to show the PDF as inline SVG
# function Base.show(io::IO, ::MIME"image/svg+xml", pdf::PDF)
#     svg = first(splitext(pdf.file)) * ".svg"
#     PGFPlotsX.convert_pdf_to_svg(pdf.file, svg)
#     write(io, read(svg))
#     try; rm(svg; force=true); catch e; end
#     return nothing
# end
# # Now you can display a PDF file like this:
# display(PDF("myoutput.pdf"))
# # Internal Error: cairo context error: invalid matrix (not invertible)
# # cairo error: invalid matrix (not invertible)
function retain_can(s::String) ## retain canonical chars of a label string
    m = match(r"^[a-zA-Zγθs₀\*]*", s)
    return m.match
end

function sanitized(vertex::String)
    if haskey(_san_node_ids, vertex)
        return _san_node_ids[vertex]
    else
        return vertex
    end
end

function lup(label::String)
    if _lup_enabled
        lup_label = get(_lup_dict, label, replace(label, "_" => "\\_"))
        return lup_label
    else
        return replace(label, "_" => "\\_")
    end
end

function seqlup(label::String, i::Integer)
    if _seqlup_enabled
        lup_label = get(_lup_dict, label, split(label, '_')[1]*"_"*string(i))
        return lup_label
    else
        return replace(label, "_" => "\\_")
    end
end

function show_meta_graph(meta_graph) ## just for info
    str = ""
    for i in MetaGraphsNext.vertices(meta_graph)
        label = MetaGraphsNext.label_for(meta_graph, i)
        str *= " $(label);\n"
    end
    for edge in MetaGraphsNext.edges(meta_graph)
        source_vertex = MetaGraphsNext.label_for(meta_graph, edge.src)
        dest_vertex = MetaGraphsNext.label_for(meta_graph, edge.dst)
        str *= " $(source_vertex) -- $(dest_vertex);\n"
    end
    print(str)
end
show_meta_graph (generic function with 1 method)
"""
Given a user specified canonical factor label string, returns all the GraphPPL associated labels from the context's factor_nodes.
"""
function str_facs(gppl_model, factor::String)
    context = GraphPPL.getcontext(gppl_model)
    vec = []
    for (factor_ID, factor_label) in pairs(context.factor_nodes)
        ## println("$(factor_ID), $(factor_label)")
        ## println("$(typeof(factor_ID)), $(typeof(factor_label))")
        node_data = gppl_model[factor_label]
        ## println("$(node_data.properties.fform)\n")
        if retain_can(string(node_data.properties.fform)) == factor
            append!(vec, factor_label)
        end
    end
    result = string.(vec)
    return result
end

"""
Given a user specified canonical variable label string, returns all the GraphPPL associated labels from the context's individual_variables.
"""
function str_individual_vars(gppl_model, var::String)
    context = GraphPPL.getcontext(gppl_model)
    vec = []
    for (node_ID, node_label) in pairs(context.individual_variables)
        ## println("$(node_ID), $(node_label)")
        ## println("$(typeof(node_ID)), $(typeof(node_label))")
        if occursin(var, string(node_ID))
            append!(vec, node_label)
        end
    end
    result = string.(vec)
    return result
end

"""
Given a user specified canonical variable label string, returns all the GraphPPL associated labels from the context's vector_variables.
"""
function str_vector_vars(gppl_model, var::String)
    context = GraphPPL.getcontext(gppl_model)
    vec = []
    for (node_ID, node_label) in pairs(context.vector_variables)
        ## println("$(node_ID), $(node_label)")
        ## println("$(typeof(node_ID)), $(typeof(node_label))")
        if string(node_ID) == var
            append!(vec, node_label)
        end
    end
    result = string.(vec)
    return result
end
str_vector_vars
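To see how the label helpers retain_can, lup and seqlup behave, the following sketch repeats compact versions of them so that it runs on its own. The label strings and the single dictionary entry are hypothetical examples; the labels that GraphPPL actually produces may be formatted differently.

```julia
# Compact copies of the helpers, repeated so the snippet is self-contained
retain_can(s::String) = match(r"^[a-zA-Zγθs₀\*]*", s).match

_lup_enabled = true
_seqlup_enabled = true
_lup_dict = Dict("θ" => "\$\\theta\$")   # hypothetical entry

function lup(label::String)
    _lup_enabled || return replace(label, "_" => "\\_")
    return get(_lup_dict, label, replace(label, "_" => "\\_"))
end

function seqlup(label::String, i::Integer)
    _seqlup_enabled || return replace(label, "_" => "\\_")
    return get(_lup_dict, label, split(label, '_')[1] * "_" * string(i))
end

println(retain_can("Bernoulli_12"))   # keeps only the leading canonical characters
println(lup("θ"))                     # found in the dictionary
println(lup("y_3"))                   # not found: underscore escaped for LaTeX
println(seqlup("y_3", 1))             # not found: suffix replaced by the sequence index
```

The fallback in lup escapes underscores so that a raw name stays valid in LaTeX text, while the fallback in seqlup replaces the suffix after the first underscore with a sequential index.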
function spaces(n::Int)
    d = 1/(n + 1)
    return [i*d for i in 1:n]
end

function setup_divn(n)
    divn = Dict()
    for i in 1:n
        divn[i] = spaces(i)
    end
    return divn
end
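A short usage sketch shows what spaces() returns: n evenly spaced fractions strictly between 0 and 1. Scaling them by a width is one plausible way to place n nodes along a layer; the value of `layer_width` is a made-up assumption, since this section does not show how generate_tikz() consumes the result.

```julia
# Copy of spaces() so the snippet runs on its own
function spaces(n::Int)
    d = 1 / (n + 1)
    return [i * d for i in 1:n]
end

println(spaces(3))                # three interior fractions of the unit interval

layer_width = 12.0                # hypothetical layer width in TikZ units
xs = layer_width .* spaces(3)     # x-coordinates for three nodes in that layer
println(xs)
```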
Here \(y_i \in \{0, 1\}\) is a binary observation generated by a Bernoulli likelihood, while \(\theta\), the parameter of the Bernoulli, has a Beta prior distribution. We are interested in inferring the posterior distribution of \(\theta\).
See Fig 4.1 in 2023_Bagaev_RPPfSBI_Thesis. So there are \(N\) \(\delta\) factors, 1 Beta factor, and \(N\) Ber factors.
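For reference, the model just described can be written out as follows. This is a reconstruction from the coin_model definition (with \(N\) observations and prior parameters \(a\), \(b\)), not a formula copied from the cited thesis.

```latex
\begin{aligned}
p(\theta) &= \mathrm{Beta}(\theta \mid a, b), \\
p(y_i \mid \theta) &= \mathrm{Ber}(y_i \mid \theta), \qquad i = 1, \dots, N, \\
p(y_{1:N}, \theta) &= \mathrm{Beta}(\theta \mid a, b) \prod_{i=1}^{N} \mathrm{Ber}(y_i \mid \theta).
\end{aligned}
```

The factor count follows directly from the last line: one Beta factor and \(N\) Ber factors, plus one \(\delta\) factor for each observed \(y_i\).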
@model function coin_model(y, a, b)
    ## We endow θ parameter of our model with some prior
    θ ~ Beta(a, b)
    ## We assume that outcome of each coin flip is governed by the Bernoulli distribution
    for i in eachindex(y)
        y[i] ~ Bernoulli(θ)
    end
end
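Because the Beta prior is conjugate to the Bernoulli likelihood, the posterior of θ for this model is available in closed form, which gives a handy reference when checking any inference result. The sketch below is plain Julia and does not use RxInfer; the helper `beta_posterior`, the observations, and the prior parameters are all made-up examples.

```julia
# Conjugate Beta-Bernoulli update: the posterior is Beta(a + heads, b + tails)
function beta_posterior(a::Real, b::Real, y::AbstractVector{<:Real})
    heads = sum(y)
    tails = length(y) - heads
    return (a + heads, b + tails)
end

y = [1, 0, 1, 1, 0]                        # hypothetical coin flips
a_post, b_post = beta_posterior(2.0, 7.0, y)
println((a_post, b_post))                  # posterior Beta parameters
println(a_post / (a_post + b_post))        # posterior mean of θ
```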
@model function hidden_markov_model(x)
    B ~ MatrixDirichlet(ones(3, 3))
    A ~ MatrixDirichlet([10.0 1.0 1.0;
                         1.0 10.0 1.0;
                         1.0 1.0 10.0 ])
    d = fill(1.0/3.0, 3)
    s₀ ~ Categorical(d)
    sₖ₋₁ = s₀
    for k in eachindex(x)
        s[k] ~ Transition(sₖ₋₁, B)
        x[k] ~ Transition(s[k], A)
        sₖ₋₁ = s[k]
    end
end
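To get a feel for what the prior on A encodes, we can normalize its concentration matrix. This sketch assumes that each column of a MatrixDirichlet-distributed matrix follows a Dirichlet distribution, so that the prior mean is the column-normalized concentration matrix. Since the matrix here is symmetric, normalizing by rows gives the same numbers. The variable names are illustrative.

```julia
# Concentration parameters of the prior on A, as in the model above
A_conc = [10.0 1.0 1.0;
          1.0 10.0 1.0;
          1.0 1.0 10.0]

# Prior mean under the stated assumption: divide each column by its sum
A_mean = A_conc ./ sum(A_conc, dims = 1)
display(A_mean)
```

Each diagonal entry of the mean is 10/12, so the prior expresses a fairly strong belief that each hidden state mostly produces its matching observation. The prior on B, built from ones(3, 3), is uniform by comparison.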