Tag Archives: julialang

Is case_when needed in DataFrames.jl?

By: Blog by Bogumił Kamiński

Re-posted from: https://bkamins.github.io/julialang/2020/12/18/casewhen.html

Introduction

Recently I received a very interesting question regarding transforming data
using DataFrames.jl. One of the users wanted to know if we have
functionality similar to the case_when function in dplyr. When trying
to answer it I came to the conclusion that we do not need it, because we can
reproduce it using the ⋅ ? ⋅ : ⋅ ternary operator.

In this post I will be reproducing selected examples from the documentation of
case_when.

Reproducing dplyr examples

In the examples I will first show the R code and then the Julia code.

In the R examples I assume that dplyr is loaded. The Julia examples were tested under
Julia 1.5.3, DataFrames.jl 0.22.2, DataFramesMeta.jl v0.6.0, HTTP.jl v0.8.19,
JSON3.jl v1.5.1, and Pipe.jl v1.3.0.

Example 1

This is the most basic case_when usage scenario.

R code:

> library(dplyr)
> x <- 1:10
> case_when(
+   x %% 35 == 0 ~ "fizz buzz",
+   x %% 5 == 0 ~ "fizz",
+   x %% 7 == 0 ~ "buzz",
+   TRUE ~ as.character(x)
+ )
 [1] "1"    "2"    "3"    "4"    "fizz" "6"    "buzz" "8"    "9"    "fizz"

Julia code:

julia> x = 1:10
1:10

julia> (x -> x % 6 == 0 ? "fizz buzz" :
             x % 2 == 0 ? "fizz" :
             x % 3 == 0 ? "buzz" :
             string(x)).(x)
10-element Array{String,1}:
 "1"
 "fizz"
 "buzz"
 "fizz"
 "5"
 "fizz buzz"
 "7"
 "fizz"
 "buzz"
 "fizz"

In this basic example note the following things:

  • in the Julia code we do not need to load any package; we are using the
    functionality built into the language;
  • we create an anonymous function that is then broadcasted over an input vector
    using the . operator;
  • both codes look almost the same, apart from slightly different punctuation.
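If a chain of ternaries is hard to scan, the same logic can be written as a named function with if/elseif and broadcast in the same way. This is only a sketch: the name fizzbuzz_label is made up for illustration and is not part of the original example.

```julia
# Hypothetical named version of the anonymous function used above.
function fizzbuzz_label(x)
    if x % 6 == 0
        "fizz buzz"
    elseif x % 2 == 0
        "fizz"
    elseif x % 3 == 0
        "buzz"
    else
        string(x)
    end
end

# Broadcasting with the dot applies the function to every element.
labels = fizzbuzz_label.(1:10)
```

The ternary form and the if/elseif form should behave the same; which one to use is a matter of taste.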

In summary – in my opinion the basic use case shows that the ternary operator
is as convenient as case_when from dplyr.

Example 2

In this example missing values are introduced. We are reusing the vector created
in the previous exercise.

R code:

> x[2:4] <- NA_real_
> case_when(
+   x %% 35 == 0 ~ "fizz buzz",
+   x %% 5 == 0 ~ "fizz",
+   x %% 7 == 0 ~ "buzz",
+   is.na(x) ~ "nope",
+   TRUE ~ as.character(x)
+ )
 [1] "1"    "nope" "nope" "nope" "fizz" "6"    "buzz" "8"    "9"    "fizz"

Julia code (two variants):

julia> x = [2 <= i <= 4 ? missing : x[i] for i in axes(x, 1)]

10-element Array{Union{Missing, Int64},1}:
  1
   missing
   missing
   missing
  5
  6
  7
  8
  9
 10

julia> (x -> isequal(x % 6, 0) ? "fizz buzz" :
             isequal(x % 2, 0) ? "fizz" :
             isequal(x % 3, 0) ? "buzz" :
             ismissing(x) ? "nope" :
             string(x)).(x)
10-element Array{String,1}:
 "1"
 "nope"
 "nope"
 "nope"
 "5"
 "fizz buzz"
 "7"
 "fizz"
 "buzz"
 "fizz"

julia> (x -> coalesce(x % 6 == 0, false) ? "fizz buzz" :
             coalesce(x % 2 == 0, false) ? "fizz" :
             coalesce(x % 3 == 0, false) ? "buzz" :
             ismissing(x) ? "nope" :
             string(x)).(x)
10-element Array{String,1}:
 "1"
 "nope"
 "nope"
 "nope"
 "5"
 "fizz buzz"
 "7"
 "fizz"
 "buzz"
 "fizz"

Additionally, note that the code can be simplified if we put the ismissing
condition first:

julia> (x -> ismissing(x) ? "nope" :
             x % 6 == 0 ? "fizz buzz" :
             x % 2 == 0 ? "fizz" :
             x % 3 == 0 ? "buzz" :
             string(x)).(x)
10-element Array{String,1}:
 "1"
 "nope"
 "nope"
 "nope"
 "5"
 "fizz buzz"
 "7"
 "fizz"
 "buzz"
 "fizz"

Note the following patterns in this example:

  • we had to materialize the vector in Julia in a slightly more complex way because
    the initial x vector was a 1:10 range, which is read-only;
  • in case_when a comparison with a missing value is treated as failing by default; on
    the other hand Julia is strict about boolean tests and one has to use either the
    isequal or the coalesce functions to handle missing values (or move the
    ismissing test to the top); this strictness introduces a bit more verbosity
    in the code, with the benefit of allowing the user to catch potential bugs in the
    logic of the code more easily.
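The strictness mentioned above can be checked directly in plain Julia. The short sketch below (variable names are illustrative only) shows that a comparison involving missing propagates missing, that using it as a condition is an error, and that isequal and coalesce both turn it into a plain false.

```julia
x = missing

# Three-valued logic: the comparison itself is `missing`, not `false`.
r = x % 6 == 0

# A `missing` condition in a ternary is rejected with a TypeError.
threw_type_error = try
    r ? "yes" : "no"
    false
catch err
    err isa TypeError
end

# Both workarounds used in the post give an ordinary Bool.
via_isequal = isequal(x % 6, 0)
via_coalesce = coalesce(x % 6 == 0, false)
```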

Example 3

In this example we use the starwars dataset that is shipped with dplyr. So we
first have to fetch it from the Internet in Julia. Here is the code that does
the trick:

julia> using DataFrames

julia> using HTTP

julia> using JSON3

julia> using Pipe

julia> starwars = @pipe HTTP.get("https://swapi.dev/api/people/").body |>
                  JSON3.read |> _.results |> DataFrame |>
                  transform(_,
                            :species => ByRow(x -> isempty(x) ? "Human" : "Droid"),
                            [:height, :mass] .=> ByRow(x -> parse(Int, x)),
                            renamecols=false)
10×16 DataFrame
 Row │ name                height  mass   hair_color     skin_color   eye_color ⋯
     │ String              Int64   Int64  String         String       String    ⋯
─────┼───────────────────────────────────────────────────────────────────────────
   1 │ Luke Skywalker         172     77  blond          fair         blue      ⋯
   2 │ C-3PO                  167     75  n/a            gold         yellow
   3 │ R2-D2                   96     32  n/a            white, blue  red
   4 │ Darth Vader            202    136  none           white        yellow
   5 │ Leia Organa            150     49  brown          light        brown     ⋯
   6 │ Owen Lars              178    120  brown, grey    light        blue
   7 │ Beru Whitesun lars     165     75  brown          light        blue
   8 │ R5-D4                   97     32  n/a            white, red   red
   9 │ Biggs Darklighter      183     84  black          light        brown     ⋯
  10 │ Obi-Wan Kenobi         182     77  auburn, white  fair         blue-gray
                                                               10 columns omitted

We have fetched only 10 rows of data for the analysis (this is the number of
observations that the exposed API produces), but it is enough for our purposes.

As a side note, observe how easy it is in the JuliaData ecosystem to fetch a JSON
file from the Internet, parse it, populate a DataFrame, and finally do some
column preprocessing to get the right column types for the data that we are
interested in analyzing later.

Let us move to the example. In this case we want to process more than one column
using the case_when function within a data transformation pipeline.

R code

> starwars %>%
+   select(name:mass, gender, species) %>%
+   mutate(
+     type = case_when(
+       height > 200 | mass > 200 ~ "large",
+       species == "Droid"        ~ "robot",
+       TRUE                      ~ "other"
+     )
+   )
# A tibble: 87 x 6
   name               height  mass gender    species type
   <chr>               <int> <dbl> <chr>     <chr>   <chr>
 1 Luke Skywalker        172    77 masculine Human   other
 2 C-3PO                 167    75 masculine Droid   robot
 3 R2-D2                  96    32 masculine Droid   robot
 4 Darth Vader           202   136 masculine Human   large
 5 Leia Organa           150    49 feminine  Human   other
 6 Owen Lars             178   120 masculine Human   other
 7 Beru Whitesun lars    165    75 feminine  Human   other
 8 R5-D4                  97    32 masculine Droid   robot
 9 Biggs Darklighter     183    84 masculine Human   other
10 Obi-Wan Kenobi        182    77 masculine Human   other
# … with 77 more rows

Julia code

julia> @pipe starwars |>
             select(_, Between(:name, :mass), :gender, :species) |>
             transform(_, [:height, :mass, :species] =>
                          ByRow((height, mass, species) ->
                                height > 200 || mass > 200 ? "large" :
                                species == "Droid" ? "robot" :
                                "other") =>
                          :type)
10×6 DataFrame
 Row │ name                height  mass   gender  species  type
     │ String              Int64   Int64  String  String   String
─────┼────────────────────────────────────────────────────────────
   1 │ Luke Skywalker         172     77  male    Human    other
   2 │ C-3PO                  167     75  n/a     Droid    robot
   3 │ R2-D2                   96     32  n/a     Droid    robot
   4 │ Darth Vader            202    136  male    Human    large
   5 │ Leia Organa            150     49  female  Human    other
   6 │ Owen Lars              178    120  male    Human    other
   7 │ Beru Whitesun lars     165     75  female  Human    other
   8 │ R5-D4                   97     32  n/a     Droid    robot
   9 │ Biggs Darklighter      183     84  male    Human    other
  10 │ Obi-Wan Kenobi         182     77  male    Human    other

or if you like using DataFramesMeta.jl:

julia> using DataFramesMeta

julia> @pipe starwars |>
             select(_, Between(:name, :mass), :gender, :species) |>
             @eachrow _ begin
                 @newcol type::Vector{String}
                 :type = :height > 200 || :mass > 200 ? "large" :
                         :species == "Droid" ? "robot" :
                         "other"
             end
10×6 DataFrame
 Row │ name                height  mass   gender  species  type
     │ String              Int64   Int64  String  String   String
─────┼────────────────────────────────────────────────────────────
   1 │ Luke Skywalker         172     77  male    Human    other
   2 │ C-3PO                  167     75  n/a     Droid    robot
   3 │ R2-D2                   96     32  n/a     Droid    robot
   4 │ Darth Vader            202    136  male    Human    large
   5 │ Leia Organa            150     49  female  Human    other
   6 │ Owen Lars              178    120  male    Human    other
   7 │ Beru Whitesun lars     165     75  female  Human    other
   8 │ R5-D4                   97     32  n/a     Droid    robot
   9 │ Biggs Darklighter      183     84  male    Human    other
  10 │ Obi-Wan Kenobi         182     77  male    Human    other

As you can see, it is also easy to use the ternary operator with
several variables. Using DataFrames.jl requires a bit of boilerplate syntax.
This limitation can be conveniently overcome using DataFramesMeta.jl; in the
above example I decided to use the @eachrow macro.

Conclusions

As you can see, using the ternary operator in Julia gives us functionality
and syntax very similar to the case_when function from dplyr.
Apart from the differences in how missing values are handled, which I have discussed
above, there are two features that make the solution in Julia more convenient in my
opinion:

  • in case_when all values on the right-hand side have to have the same type, while
    in Julia there is no such restriction;
  • case_when evaluates all right hand side expressions, while the ternary
    operator evaluates only what has to be evaluated to determine the result
    of the operation (this is often preferred when some operations may
    throw an error for certain values of their arguments).
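Both points can be seen in a tiny sketch. The function name describe_root is hypothetical; the example only relies on the standard behavior of sqrt, which throws a DomainError for negative real arguments.

```julia
# Only the selected branch is evaluated, so sqrt is never called
# with a negative number; the two branches also return different types.
describe_root(x) = x >= 0 ? sqrt(x) : "negative"

results = describe_root.([4.0, -1.0])
```

Here the first element of results is a Float64 and the second is a String, and no error is raised for the negative input.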

Binning your data with Julia

By: Blog by Bogumił Kamiński

Re-posted from: https://bkamins.github.io/julialang/2020/12/11/binning.html

Introduction

Cutting data into groups (binning) is one of the most common data preprocessing
tasks.

You can easily do binning into groups of equal sizes using the cut function
from CategoricalArrays.jl like this (here we bin a vector of values from 1 to 10
into 2 groups):

julia> using CategoricalArrays

julia> cut(1:10, 2)
10-element CategoricalArray{String,1,UInt32}:
 "Q1: [1.0, 5.5)"
 "Q1: [1.0, 5.5)"
 "Q1: [1.0, 5.5)"
 "Q1: [1.0, 5.5)"
 "Q1: [1.0, 5.5)"
 "Q2: [5.5, 10.0]"
 "Q2: [5.5, 10.0]"
 "Q2: [5.5, 10.0]"
 "Q2: [5.5, 10.0]"
 "Q2: [5.5, 10.0]"

However, the issue becomes more challenging when the number of bins is not a
divisor of vector length or if you have duplicates in the data.

The post is tested under Julia 1.5.3, DataFrames.jl 0.22.1, CategoricalArrays.jl
0.9.0, and FreqTables.jl 0.4.2.

The corner cases of binning

Let us first highlight some potential issues when binning data.

The first problem is when the number of groups is not a divisor of the vector
length. Let us check it out on some examples:

julia> cut(1:10, 3)
10-element CategoricalArray{String,1,UInt32}:
 "Q1: [1.0, 4.0)"
 "Q1: [1.0, 4.0)"
 "Q1: [1.0, 4.0)"
 "Q2: [4.0, 7.0)"
 "Q2: [4.0, 7.0)"
 "Q2: [4.0, 7.0)"
 "Q3: [7.0, 10.0]"
 "Q3: [7.0, 10.0]"
 "Q3: [7.0, 10.0]"
 "Q3: [7.0, 10.0]"

julia> cut(1:10, 4)
10-element CategoricalArray{String,1,UInt32}:
 "Q1: [1.0, 3.25)"
 "Q1: [1.0, 3.25)"
 "Q1: [1.0, 3.25)"
 "Q2: [3.25, 5.5)"
 "Q2: [3.25, 5.5)"
 "Q3: [5.5, 7.75)"
 "Q3: [5.5, 7.75)"
 "Q4: [7.75, 10.0]"
 "Q4: [7.75, 10.0]"
 "Q4: [7.75, 10.0]"

The cut function is deterministic and it uses the quantile function to find
the bin endpoints. This means that in the first example, cut(1:10, 3), the third
bin will always be larger than the first and second bins. Similarly, in
cut(1:10, 4) the first and the fourth bins are always going to be larger.
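To see where these endpoints come from, one can call quantile from the Statistics standard library on the same data. This is a sketch of the idea rather than the actual internals of cut; the variable names are mine.

```julia
using Statistics

x = collect(1:10)

# Break points for 3 and 4 equally spaced probability levels.
edges3 = quantile(x, [0, 1/3, 2/3, 1])
edges4 = quantile(x, [0, 0.25, 0.5, 0.75, 1])
```

The values should be approximately 1, 4, 7, 10 and 1, 3.25, 5.5, 7.75, 10, which agree with the interval labels printed by cut above. Because the data are integers, the intervals between these break points cannot all contain the same number of values.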

The other problem is duplicates in data. Consider the following scenario:

julia> cut([1; fill(2, 8); 3], 2)
10-element CategoricalArray{String,1,UInt32}:
 "Q1: [1.0, 2.0)"
 "Q2: [2.0, 3.0]"
 "Q2: [2.0, 3.0]"
 "Q2: [2.0, 3.0]"
 "Q2: [2.0, 3.0]"
 "Q2: [2.0, 3.0]"
 "Q2: [2.0, 3.0]"
 "Q2: [2.0, 3.0]"
 "Q2: [2.0, 3.0]"
 "Q2: [2.0, 3.0]"

We want two bins. Ideally both should have five elements, but since we have
duplicates in our data the first bin has size one and the second size nine.

In some cases you will like what cut produces, but in other cases
one might want to avoid these two problems, that is:

  • always make the bins of equal size and if it is not possible to do so, make
    the decision which bin should be larger and which smaller randomly;
  • allow duplicates to be split between two or more bins (this is then
    unavoidable in some cases), but in such a way that each duplicate has the same
    chance to fall into each bin.

Random binning

Here is the function that performs the binning that has the properties I have
described above:

using DataFrames
using FreqTables
using Random

function binvec(x::AbstractVector, n::Int,
                rng::AbstractRNG=Random.default_rng())
    n > 0 || throw(ArgumentError("number of bins must be positive"))
    l = length(x)

    # find bin sizes
    d, r = divrem(l, n)
    lens = fill(d, n)
    lens[1:r] .+= 1
    # randomly decide which bins should be larger
    shuffle!(rng, lens)

    # ensure that we have data sorted by x, but ties are ordered randomly
    df = DataFrame(id=axes(x, 1), x=x, r=rand(rng, l))
    sort!(df, [:x, :r])

    # assign bin ids to rows
    binids = reduce(vcat, [fill(i, v) for (i, v) in enumerate(lens)])
    df.binids = binids

    # recover original row order
    sort!(df, :id)
    return df.binids
end

Let us now test the binning on the following vector:

julia> Random.seed!(1234);

julia> x = repeat('a':'c', 3)
9-element Array{Char,1}:
 'a': ASCII/Unicode U+0061 (category Ll: Letter, lowercase)
 'b': ASCII/Unicode U+0062 (category Ll: Letter, lowercase)
 'c': ASCII/Unicode U+0063 (category Ll: Letter, lowercase)
 'a': ASCII/Unicode U+0061 (category Ll: Letter, lowercase)
 'b': ASCII/Unicode U+0062 (category Ll: Letter, lowercase)
 'c': ASCII/Unicode U+0063 (category Ll: Letter, lowercase)
 'a': ASCII/Unicode U+0061 (category Ll: Letter, lowercase)
 'b': ASCII/Unicode U+0062 (category Ll: Letter, lowercase)
 'c': ASCII/Unicode U+0063 (category Ll: Letter, lowercase)

julia> binvec(x, 2)
9-element Array{Int64,1}:
 1
 2
 2
 1
 2
 2
 1
 1
 2

As you can see 'b's are split between bins 1 and 2 to make them have almost
equal size.

Let us make sure that binvec does the right job in deciding on bin sizes
and splitting 'b's between both bins.

julia> df = reduce(vcat, [DataFrame(x=x, run_id=i, row_id=axes(x, 1),
                                    group_id=binvec(x, 2)) for i in 1:10_000]);

julia> freqtable(df, :group_id, :row_id, :x)
2×9×3 Named Array{Int64,3}

[:, :, x='a'] =
group_id ╲ row_id │     1      2      3      4      5      6      7      8      9
──────────────────┼──────────────────────────────────────────────────────────────
1                 │ 10000      0      0  10000      0      0  10000      0      0
2                 │     0      0      0      0      0      0      0      0      0

[:, :, x='b'] =
group_id ╲ row_id │    1     2     3     4     5     6     7     8     9
──────────────────┼─────────────────────────────────────────────────────
1                 │    0  4941     0     0  5005     0     0  5027     0
2                 │    0  5059     0     0  4995     0     0  4973     0

[:, :, x='c'] =
group_id ╲ row_id │     1      2      3      4      5      6      7      8      9
──────────────────┼──────────────────────────────────────────────────────────────
1                 │     0      0      0      0      0      0      0      0      0
2                 │     0      0  10000      0      0  10000      0      0  10000

Indeed we see that each 'b' falls into group 1 and group 2 with 50%
probability. Also the expected size of group 1 and group 2 is 4.5 as
desired. (Both calculations are approximate because we used simulation.)

Conclusions

First – let me comment in what scenario the binning I described is desirable.
Assume you have a set of patients you want to vaccinate against COVID-19. Now
let each of them have assigned a discrete urgency level (typically there will be
3 or 4 such urgency levels). You then have to split them into several batches
of equal size so that each batch gets a vaccine in a different period. If you
want to be fair in assigning people to batches you get exactly the setting I
have described.

Second, as usual I wanted to showcase some features of the JuliaData ecosystem. In
particular you have seen the reduce(vcat, ...) combo twice (for vectors and for data
frames) and the integration of FreqTables.jl with DataFrames.jl that I am very fond
of. Of course this is not the fastest way to get the desired results. I
encourage you to write a faster function that does the same thing as the binvec
function does.
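As one possible starting point for that exercise, here is a sketch that avoids building a DataFrame and uses sortperm on (value, random key) pairs instead. The name binvec_fast is hypothetical, the code assumes a 1-based vector, and I have not benchmarked it, so treat any speed-up as an expectation rather than a measured result.

```julia
using Random

# Hypothetical DataFrame-free variant of binvec (assumes 1-based indexing).
function binvec_fast(x::AbstractVector, n::Int,
                     rng::AbstractRNG=Random.default_rng())
    n > 0 || throw(ArgumentError("number of bins must be positive"))
    l = length(x)

    # bin sizes, with the larger bins chosen at random
    d, r = divrem(l, n)
    lens = fill(d, n)
    lens[1:r] .+= 1
    shuffle!(rng, lens)

    # order by x, breaking ties with a random key
    tiebreak = rand(rng, l)
    p = sortperm(collect(zip(x, tiebreak)))

    # walk the sorted order and write bin ids back to the original positions
    binids = Vector{Int}(undef, l)
    pos = 1
    for (i, len) in enumerate(lens)
        for _ in 1:len
            binids[p[pos]] = i
            pos += 1
        end
    end
    return binids
end

bins = binvec_fast(repeat('a':'c', 3), 2, MersenneTwister(42))
```

As with binvec, every 'a' should land in bin 1 and every 'c' in bin 2, with the 'b's split between the two bins.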

E-Graph Pattern Matching (Part II)

By: Philip Zucker

Re-posted from: https://www.philipzucker.com/egraph-2/

Last time we spent a bit of effort building an e-graph data structure in Julia. The e-graph is a data structure that compactly stores and maintains equivalence classes of terms.

We can use the data structure for equational rewriting by iteratively pattern matching terms inside of it and adding new equalities found by instantiating those patterns. For example, we could use the pattern ~x + 0 to instantiate the rule ~x + 0 == ~x, find that ~x binds to foo23(a,b,c) in the e-graph and then insert into the e-graph the equality foo23(a,b,c) + 0 == foo23(a,b,c) and do this over and over. The advantage of the e-graph vs just rewriting terms is that we never lose information while doing this or rewrite ourselves into a corner. The disadvantage is that we maintain this huge data structure.

The rewrite rule problem can be viewed as something like a graph. The vertices of the graph are every possible term. A rewrite rule that is applicable to that node describes a directed edge in the graph.

This graph is very large, often infinite, so it can only be described implicitly.

This graph perspective fails to capture some useful properties though. Treating each term as an indivisible vertex fails to capture that there can be rewriting happening in only a small part of the term, the vast majority of it left unchanged. There is a lot of opportunity for shared computations.

The EGraph from this perspective is a data structure for holding the already seen set of vertices efficiently and sharing these computations.

You can use this procedure for theorem proving. Insert the two terms you want to prove equal into the e-graph. Iteratively rewrite using your equalities. At each iteration, check whether two nodes in the e-graph have become equivalent. If so, congrats, problem proved! This is the analog of finding a path between two nodes in a graph by doing a bidirectional search from the start and endpoint. This is roughly the main method by which Z3 reasons about syntactic terms under equality in the presence of quantified equations. There are also methods by which to reconstruct the proof found if you want it.

Another application of basically the same algorithm is finding optimized rewrites. Apply rewrites until you’re sick of it, then extract from the equivalence class of your query term your favorite equivalent term. This is a very intriguing application to me as it produces more than just a yes/no answer.

Pattern Matching

Pattern matching takes a pattern with named holes in it and tries to find a way to fill the holes to match a term. The package SymbolicUtils.jl is useful for pattern matching against single terms, which is what I’ll describe in this section, but not quite set up to pattern match in an e-graph.

Pattern matching against a term is a very straightforward process once you sit down to do it. Where it gets complicated is trying to be efficient. But one should walk before jogging. I sometimes find myself so overwhelmed by performance considerations that I never get started.

First, one needs a data type for describing patterns (well, one doesn’t need it, but it’s nice. You could directly use pattern matching combinators).

struct PatVar
    id::Symbol
end

struct PatTerm
    head::Symbol
    args::Array{Union{PatTerm, PatVar}}
end

Pattern = Union{PatTerm,PatVar}

There is an analogous data structure in SymbolicUtils which I probably should’ve just used. For example Slot is basically PatVar.

The pattern matching function takes a Term and a Pattern and returns a dictionary Dict{PatVar,Term} of bindings of pattern variables to terms, or nothing if it fails.

Matching a pattern variable builds a dictionary of bindings from that variable to a term.

match(t::Term,  p::PatVar ) = Dict(p => t)

Matching a PatTerm requires we check if anything is obviously wrong (wrong head, or wrong number of args) and then recurse on the args.
The resulting binding dictionaries are checked for consistency of their bindings and merged.


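function merge_consistent(ds)
    newd = Dict()
    for d in ds
        # a failed sub-match yields `nothing`, which makes the whole match fail
        d === nothing && return nothing
        for (k,v) in d
            if haskey(newd,k)
                # the same variable must be bound to the same term everywhere
                if newd[k] != v
                    return nothing
                end
            else
                newd[k] = v
            end
        end
    end
    return newd
end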

function match(t::Term, p::PatTerm) 
    if t.head != p.head || length(t.args) != length(p.args)
        return nothing
    else
        merge_consistent( [ match(t1,p1) for (t1,p1) in zip(t.args, p.args) ])
    end
end

There are other ways of arranging this computation, such as passing a dictionary parameter along and modifying it, but I find the above purely functional definition the most clear.
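To try the matcher in isolation, here is a self-contained sketch. It makes several assumptions that are not in the post: a minimal Term struct (the real one comes from Part I), a structural == for Term, and the name match_term, chosen to avoid clashing with Base.match. The binding merge is done inline rather than through merge_consistent.

```julia
# Assumed minimal term type; the actual definition lives in Part I.
struct Term
    head::Symbol
    args::Vector{Term}
end
# Structural equality, so that two separately built copies of `a` compare equal.
Base.:(==)(s::Term, t::Term) = s.head == t.head && s.args == t.args

struct PatVar
    id::Symbol
end

struct PatTerm
    head::Symbol
    args::Vector{Union{PatTerm, PatVar}}
end

const Pattern = Union{PatTerm, PatVar}

# Hypothetical name; returns a Dict of bindings or `nothing` on failure.
match_term(t::Term, p::PatVar) = Dict{PatVar,Term}(p => t)

function match_term(t::Term, p::PatTerm)
    (t.head == p.head && length(t.args) == length(p.args)) || return nothing
    bindings = Dict{PatVar,Term}()
    for (t1, p1) in zip(t.args, p.args)
        d = match_term(t1, p1)
        d === nothing && return nothing
        for (k, v) in d
            if haskey(bindings, k)
                bindings[k] == v || return nothing   # inconsistent binding
            else
                bindings[k] = v
            end
        end
    end
    return bindings
end

a() = Term(:a, Term[])
b() = Term(:b, Term[])
xvar = PatVar(:x)
pat = PatTerm(:+, Pattern[xvar, xvar])           # the pattern ~x + ~x

ok = match_term(Term(:+, [a(), a()]), pat)       # binds x to a
bad = match_term(Term(:+, [a(), b()]), pat)      # x cannot be both a and b
```

The second call fails because the two occurrences of ~x would need different bindings.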

Pattern Matching in E-Graphs

The twist that E-Graphs throw into the pattern matching process is that instead of checking whether a child of a term matches the pattern, we need to check for any possible child matching in the equivalence class of children. This means our pattern matching procedure contains a backtracking search.

The de Moura and Bjorner e-matching paper describes how to do this efficiently via an interpreted virtual machine. An explicit virtual machine is important for them because the patterns change during the Z3 search process, but it seems less relevant for the use case here or in egg. It may be better to use partial evaluation techniques to remove the existence of the virtual machine at runtime like in A Typed, Algebraic Approach to Parsing but I haven’t figured out how to do it yet.

The Simplify paper describes a much simpler inefficient pattern matching algorithm in terms of generators. This can be directly translated to Julia using Channels. This is a nice blog post describing how to build generators in Julia that work like python generators. Basically, as soon as you define your generator function, wrap the whole thing in a Channel() do c and when you want to yield myvalue call put!(c, myvalue).
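As a minimal illustration of that recipe, the sketch below wraps a loop in a Channel and yields values with put!. The function name countdown is made up for the example.

```julia
# A generator in the style described above: wrap the body in `Channel() do c`
# and call put! wherever a Python generator would `yield`.
function countdown(n::Int)
    Channel() do c
        for i in n:-1:1
            put!(c, i)
        end
    end
end

# A Channel is iterable, so it can be collected or used in a for loop.
collected = collect(countdown(3))
```

The channel is closed automatically when the do-block finishes, which is what ends the iteration.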

# https://www.hpl.hp.com/techreports/2003/HPL-2003-148.pdf
function ematchlist(e::EGraph, t::Array{Union{PatTerm, PatVar}} , v::Array{Id}, sub)
    Channel() do c
        if length(t) == 0
            put!(c, sub)
        else
            for sub1 in ematch(e, t[1], v[1], sub)
                for sub2 in ematchlist(e, t[2:end], v[2:end], sub1)
                    put!(c, sub2)
                end
            end
        end
    end
end
# sub should be a map from pattern variables to Id
function ematch(e::EGraph, t::PatVar, v::Id, sub)
    Channel() do c
        if haskey(sub, t)
            if find_root!(e, sub[t]) == find_root!(e, v)
                put!(c, sub)
            end
        else
            put!(c,  Base.ImmutableDict(sub, t => find_root!(e, v)))
        end
    end
end

    
function ematch(e::EGraph, t::PatTerm, v::Id, sub)
    Channel() do c
        for n in e.classes[find_root!(e,v)].nodes
            if n.head == t.head
                for sub1 in ematchlist(e, t.args , n.args , sub)
                    put!(c,sub1)
                end
            end
        end
    end
end

You can then instantiate patterns with the returned dictionaries via


function instantiate(e::EGraph, p::PatVar , sub)
    sub[p]
end

function instantiate(e::EGraph, p::PatTerm , sub)
    push!( e, Term(p.head, [ instantiate(e,a,sub) for a in p.args ] ))
end

And build rewrite rules

struct Rule
    lhs::Pattern
    rhs::Pattern
end

function rewrite!(e::EGraph, r::Rule)
    matches = []
    EMPTY_DICT2 = Base.ImmutableDict{PatVar, Id}(PatVar(:____),  Id(-1))
    for (n, cls) in e.classes
        for sub in ematch(e, r.lhs, n, EMPTY_DICT2)
            push!( matches, ( instantiate(e, r.lhs ,sub)  , instantiate(e, r.rhs ,sub)))
        end
    end
    for (l,r) in matches
        union!(e,l,r)
    end
    rebuild!(e)
end

Here’s a very simple equation proving function that takes in a pile of rules

function prove_eq(e::EGraph, t1::Id, t2::Id , rules)
    for i in 1:3
        if in_same_set(e,t1,t2)
            return true
        end
        for r in rules
            rewrite!(e,r) # I should split this stuff up. We only need to rebuild at the end
        end
    end
    return nothing
end

As sometimes happens, I’m losing steam on this and would like to work on something different for a bit. But I loaded up the WIP at https://github.com/philzook58/EGraphs.jl.

Bits and Bobbles

Catlab

You can find piles of equational axioms in Catlab like here for example

A tricky thing is that some equational axioms of categories seem to produce variables out of thin air.
Consider the rewrite rule ~f => id(~A) ⋅ ~f. Where does the A come from? It’s because we’ve highly suppressed all the typing information. The A should come from the type of f.

I think the simplest way to fix this is the “type tag” approach I mentioned here https://www.philipzucker.com/theorem-proving-for-catlab-2-lets-try-z3-this-time-nope/ and can be read about here https://people.mpi-inf.mpg.de/~jblanche/mono-trans.pdf. You wrap every term in a tagging function so that f becomes tag(f, Hom(a,b)). This approach makes sense because it turns GATs purely equational without the implicit typing context, I think. It is a bummer that it will probably choke the e-graph with junk though.

It may be best to build a TypedTerm to use rather than Term that has an explicit field for the type. This brings it close to the representation that Catlab already uses for syntax, so maybe I should just directly use that. I like having control over my own type definitions though and Catlab plays some monkey business with the Julia level types. 🙁

struct TypedTerm
    type
    head
    args
end

As I have written it so far, I don’t allow the patterns to match on the heads of terms, although there isn’t any technical reason to not allow it. The natural way of using a TypedTerm would probably want to do so though, since you’d want to match on the type sometimes while keeping the head a pattern. Another possible way to do this that avoids all these issues is to actually make terms a bit like how regular Julia Expr are made, which usually has :call in the head position. Then by convention the first arg could be the actual head, the second arg the type, and the rest of the arguments the regular arguments.

A confusing example sketch:

TypedTerm(:mcopy, typedterm(Hom(a,otimes(a,a))) , [TypedTerm(a,Ob,[])])
would become
Term(:call, [:mcopy, term!(Hom(a,otimes(a,a))), term!(a)])
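For comparison, this is how Julia's own Expr lays out a function call, which is the convention the paragraph above alludes to. The quoted expressions are just illustrative symbols; nothing here calls Catlab.

```julia
ex = :(mcopy(a))
ex.head            # :call
ex.args            # the function name comes first, then the arguments

nested = :(compose(f, id(B)))
inner = nested.args[3]   # the nested call id(B) is itself an Expr with head :call
```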

Catlab already does some simplifications on its syntax, like collecting up associative operations into lists. It is confusing how to integrate this with the e-graph structure so I think the first step is to just not. That stuff isn’t on by default, so you can get rid of it by defining your own versions using @syntax

using Catlab
using Catlab.Theories
import Catlab.Theories: compose, Ob, Hom
@syntax MyFreeCartesianCategory{ObExpr,HomExpr} CartesianCategory begin
  #otimes(A::Ob, B::Ob) = associate_unit(new(A,B), munit)
  #otimes(f::Hom, g::Hom) = associate(new(f,g))
  #compose(f::Hom, g::Hom) = associate(new(f,g; strict=true))

  #pair(f::Hom, g::Hom) = compose(mcopy(dom(f)), otimes(f,g))
  #proj1(A::Ob, B::Ob) = otimes(id(A), delete(B))
  #proj2(A::Ob, B::Ob) = otimes(delete(A), id(B))
end

# we can translate the axiom sets programmatically.
names(Main.MyFreeCartesianCategory)
#A = Ob(MyFreeCartesianCategory, :A)

A = Ob(MyFreeCartesianCategory, :A)
B = Ob(MyFreeCartesianCategory, :B)
f = Hom(:f, A,B)
g = Hom(:g, B, A)
h = compose(compose(f,g),f)
# dump(h)
dump(f)

#=
Main.MyFreeCartesianCategory.Hom{:generator}
  args: Array{Any}((3,))
    1: Symbol f
    2: Main.MyFreeCartesianCategory.Ob{:generator}
      args: Array{Symbol}((1,))
        1: Symbol A
      type_args: Array{GATExpr}((0,))
    3: Main.MyFreeCartesianCategory.Ob{:generator}
      args: Array{Symbol}((1,))
        1: Symbol B
      type_args: Array{GATExpr}((0,))
  type_args: Array{GATExpr}((2,))
    1: Main.MyFreeCartesianCategory.Ob{:generator}
      args: Array{Symbol}((1,))
        1: Symbol A
      type_args: Array{GATExpr}((0,))
    2: Main.MyFreeCartesianCategory.Ob{:generator}
      args: Array{Symbol}((1,))
        1: Symbol B
      type_args: Array{GATExpr}((0,))
=#

Rando thought: Can an EGraph data structure itself be expressed in Catlab? I note that IntDisjointSet does show up in a Catlab definition of pushouts. The difficulty and scariness of the EGraph is maintaining its invariants, which may be expressible as some kind of composition condition on the various maps that make up the EGraph.

The rewrite rule problem can be viewed as something like a graph. The vertices of the graph are every possible term. A rewrite rule that is applicable to that node describes a directed edge in the graph.

This graph is very large, often infinite, so it can only be described implicitly.

Proving the equivalence of two terms is like finding a path in this graph. You could attempt a greedy approach or a breadth first search approach, or any variety of your favorite search algorithm like A*.

This graph perspective fails to capture some useful properties though. Treating each term as an indivisible vertex fails to capture that there can be rewriting happening in only a small part of the term, the vast majority of it left unchanged. There is a lot of opportunity for shared computations.

The EGraph from this perspective is a data structure for holding the already seen vertices.

I suspect some heuristics to be helpful like applying the “simplifying” direction of the equations more often than the “complicating” direction. In a 5:1 ratio let’s say.

A natural algorithm to consider for optimizing terms is to take the best expression found so far, destroy the e-graph and place just that expression in it.
Try rewrite rules. If the best is still the old query, don’t destroy the e-graph and apply a new round of rewrites to widen the radius of the e-graph. This gives you kind of a combo of greedy + complete.

Partial Evaluation

Partial evaluation seems like a good technique for optimizing pattern matches. The ideal pattern match code expands out to a nicely nested set of if statements. One alternative to passing a dictionary might be to assign each pattern variable to an integer at compile time and instead pass an array, which would be a bit better. However, by using metaprogramming we can insert local variables into generated code and avoid the need for a dictionary being passed around at runtime. Then the Julia compiler can register allocate and whatnot like it usually does (and quite efficiently I’d expect).
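
To illustrate the target, here is a hand-written comparison for the single pattern f(x, x). The Term struct is a hypothetical minimal stand-in (the term type used elsewhere in this post may differ), and both matcher names are made up for this sketch. The first version keeps bindings in a Dict built at runtime; the second is roughly what we would like generated code to look like, with the pattern variable as an ordinary local.

```julia
# Hypothetical minimal term type, only for this illustration.
struct Term
    head::Symbol
    args::Vector{Term}
end
Base.:(==)(a::Term, b::Term) = a.head == b.head && a.args == b.args

# Interpreter style: bindings live in a Dict that is built up at runtime.
function match_fxx_dict(t::Term)
    env = Dict{Symbol,Term}()
    (t.head == :f && length(t.args) == 2) || return nothing
    for arg in t.args
        if haskey(env, :x)
            env[:x] == arg || return nothing
        else
            env[:x] = arg
        end
    end
    return env
end

# Specialized style: nested ifs, and x is just a local variable.
function match_fxx_local(t::Term)
    if t.head == :f && length(t.args) == 2
        x = t.args[1]
        if t.args[2] == x
            return (x = x,)   # bindings returned as a NamedTuple, no Dict needed
        end
    end
    return nothing
end

a = Term(:a, Term[])
b = Term(:b, Term[])
println(match_fxx_local(Term(:f, [a, a])))
println(match_fxx_local(Term(:f, [a, b])))
```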

See this post (in particular the reference links) https://www.philipzucker.com/metaocaml-style-partial-evaluation-in-coq/ for more on partial evaluation.

A first thing to consider is that we’re building up and destroying dictionaries at a shocking rate.

Secondly, dictionaries themselves are (relatively) expensive things.

I experimented a bit with using a curried form for match to see if maybe the Julia compiler was smart enough to sort of evaluate away my use of dictionaries, but that did not seem to be the case.

I found examining the @code_llvm and @code_native of the following simple experiments illuminating as to what Julia can and cannot get rid of when the dictionary is obviously known at compile time to human eyes. Mostly it needs to be very obvious. I suspect Julia does not have any built-in special compiler magic for reasoning about dictionaries, and trying to automatically infer what’s safely optimizable by actually expanding the definition of a dictionary op sounds impossible.

function foo() #bad
    x = Dict(:a => 4)
    return x[:a]
end
function foo2()
    return 4
end


function foo3() # bad
    x = Dict(:a => 4)
    () -> x[:a]
end

function foo4() # much better. But we do pay a closure cost.
    x = Dict(:a => 4)
    r = x[:a]
    () -> r
end

function foo5()
    x = Dict(:a => 4)
    :( () -> $(x[:a]) )
end

function foo10() # thin wrapper around foo; a separate foo7 is defined below
    return foo()
end


function foo6() # julia does however do constant propagation and deletion of unnecessary bullcrap
    x = 4
    y = 7
    y += x
    return x
end

function foo7() # this compiles well
    x = Base.ImmutableDict( :a => 4)
    return x[:a]
end

function foo8()
    x = Base.ImmutableDict( :a => 4)
    r = x[:a]
    z -> z == r # still has a closure indirection. It's pretty good though
end

function foo9()
    x = Base.ImmutableDict( :a => 4)
    r = x[:a]
    z -> z == 4 
end


#@code_native foo7()

z = foo4()
#@code_llvm z()

z = eval(foo5())
@code_native z()
@code_native foo2()
@code_llvm foo6()
z = foo8()
@code_native z(3)
z = foo9()
@code_native z(3)
@code_native foo7()

# so it's possible that using an ImmutableDict is sufficient for julia to inline almost everything itself.
# You want the indexing to happen outside the closure.

A possibly useful technique is partial evaluation. We have in a sense built an interpreter for the pattern language. Specializing this interpreter gives a fast pattern matcher. We can explicitly build up the code Expr that will do the pattern matching by interpreting the pattern.

So here’s a version that takes in a pattern and builds code to perform the pattern match in terms of if-then-else statements and runtime bindings.

Here’s a sketch. This version only returns true or false rather than returning the bindings. I probably need to CPS it (rewrite it in continuation-passing style) to get actual bindings out. I think that’s what I was starting to do with the c parameter.

function prematch(p::PatVar, env, t, c )
    if haskey(env, p)
       :(  $(env[p]) != $t  && return false )
    else
       env[p] = gensym(p.id) #make fresh variable
       :(  $(env[p]) = $t  ) #bind it at runtime to the current t
    end
end


function sequence(code)
    foldl(
        (s1,s2) -> quote
                        $s1
                        $s2
                    end
        , code)
end

function prematch(p::PatTerm, env, t, c)
    if length(p.args) == 0
        # a zero-argument pattern fails when the head differs or the term has arguments
        :( ($(t).head != $( QuoteNode(p.head) ) || length($(t).args) != 0) && return false )
    else
        quote
            if $(t).head != $( QuoteNode(p.head) ) || $( length(p.args) ) != length($(t).args)
                return false
            else
                $( sequence([prematch(a, env, :( $(t).args[$n] ), c)   for (n,a) in enumerate(p.args) ]))
            end
        end 
    end
end


println( prematch( PatTerm(:f, []) , Dict(),  :w , nothing) )
println( prematch( PatTerm(:f, [PatVar(:a)]) , Dict(),  :t , nothing) )
println( prematch( PatTerm(:f, [PatVar(:a), PatVar(:a)]) , Dict(),  :t , nothing) )
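
The printed expressions are only code fragments. One way they might be turned into a callable matcher is to wrap the generated body in an anonymous function and eval it. The following is a self-contained sketch under some assumptions: the Term, PatVar and PatTerm structs are stand-ins written for this example (the definitions used earlier in the post may differ), and genmatch and compile_pattern are hypothetical names. genmatch follows the same idea as prematch but drops the unused c parameter.

```julia
# Assumed stand-ins for the term and pattern types.
struct Term
    head::Symbol
    args::Vector{Term}
end
Base.:(==)(a::Term, b::Term) = a.head == b.head && a.args == b.args

struct PatVar
    id::Symbol
end
struct PatTerm
    head::Symbol
    args::Vector{Any}
end

# Pattern variable: bind a fresh local on first sight, compare on later occurrences.
function genmatch(p::PatVar, env, t)
    if haskey(env, p)
        :( $(env[p]) != $t && return false )
    else
        env[p] = gensym(p.id)
        :( $(env[p]) = $t )
    end
end

# Pattern term: check head and arity, then emit the checks for each argument.
function genmatch(p::PatTerm, env, t)
    checks = [genmatch(a, env, :( $(t).args[$n] )) for (n, a) in enumerate(p.args)]
    quote
        if $(t).head != $(QuoteNode(p.head)) || length($(t).args) != $(length(p.args))
            return false
        end
        $(checks...)
    end
end

# Wrap the generated checks in `t -> ...` and evaluate to get a real function.
function compile_pattern(p)
    body = genmatch(p, Dict(), :t)
    eval(:( t -> begin
        $body
        true
    end ))
end

# Compile at top level, then call afterwards (avoids world-age problems with eval).
match_fxx = compile_pattern(PatTerm(:f, [PatVar(:x), PatVar(:x)]))

a = Term(:a, Term[])
b = Term(:b, Term[])
println(match_fxx(Term(:f, [a, a])))
println(match_fxx(Term(:f, [a, b])))
```

Like the sketch above, this still only answers yes or no. Returning the bindings would mean emitting a final expression that collects the gensym'd locals instead of `true`.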