Planting trees on-chain

Note: This is a repost of Dante’s HackMD from September 2023.

Full disclosure: this post is not about gardening but about implementing ZK versions of machine learning algorithms with botanical nomenclature (decision trees, gradient-boosted trees, and random forests). If you’re a keen gardener, check this out.

Lingering GitHub issues give us heart palpitations, particularly those that have been open for months on end, sitting like mildew in an otherwise pristine home.

Here’s one we’ve had open since January of this year:

EZKL (for those not in the know) is a library for converting common computational graphs, in the (quasi-)universal .onnx format, into zero-knowledge (ZK) circuits. This allows, for example:

  • you to prove that you’ve run a particular neural network on a private dataset, or
  • to prove you’ve developed a super secret new architecture that achieves 99% accuracy on a publicly available dataset, without revealing the model parameters.

Though our library has steadily expanded its scope of supported models, including transformer-based models (see here for a writeup), GANs, and LSTMs, implementing Kaggle-crushing models like random forests and gradient-boosted trees has remained challenging.

Part of the issue stems from the way sklearn, xgboost, and lightgbm models are exported to .onnx. Instead of decomposing tree-based models into individual operations, like matrix multiplication, addition, and comparisons, the whole model is exported as a single node (see image above)!

Given our library’s focus on modularity and composability, this has been a bit of an anti-pattern, a proverbial thorn in our side.

This weekend, after yet another call with users asking for a timeline for the implementation of such models, we decided to roll up our sleeves and get it done in 48 hours. Check out a colab example here.

Here’s what it took.

Decomposing ugly single-node onnx models into individual operations

As noted above, having single-node onnx graphs is an anti-pattern, something that might wreck our library’s clean architecture if we tried to accommodate it. A much better approach is to convert the single-node graph into its individual operations. Luckily, we are not the only folks in history to have been keen to do this, and we landed on the beautifully coded sk2torch library, which takes a graph of this form:

And turns it into something like this:

So much nicer!

For more complex models like random forests we can simply extract the individual trees / estimators, run them through sk2torch, and recreate the forest as a pytorch module:

import sk2torch
import torch.nn as nn

# clr is the fitted sklearn ensemble; wrap each of its trees / estimators
# as a torch module
trees = []
for tree in clr.estimators_:
    trees.append(sk2torch.wrap(tree))

print(trees)


class RandomForest(nn.Module):
    def __init__(self, trees):
        super(RandomForest, self).__init__()
        self.trees = nn.ModuleList(trees)

    def forward(self, x):
        # average the predictions of the individual trees
        out = self.trees[0](x)
        for tree in self.trees[1:]:
            out += tree(x)
        return out / len(self.trees)
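From here the forest is just another torch module, so it can be exported to .onnx in the usual way before being handed to ezkl. A minimal sketch, where the input shape and file name are placeholders (the colab notebooks show the end-to-end flow):

import torch

model = RandomForest(trees)
# dummy input: batch of 1 with a placeholder feature dimension
sample_input = torch.randn(1, 10)
torch.onnx.export(
    model,
    sample_input,
    "network.onnx",
    input_names=["input"],
    output_names=["output"],
)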

For xgboost and lightgbm we leveraged hummingbird, a Microsoft library for converting such models into torch / tensor computations. A converted XGBoost classifier looks like this when exported to onnx:

An observant reader will note that some operations, like ArgMax or Gather, don’t have particularly obvious implementations in zero-knowledge circuits. This was the second leg of our sprint.
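For reference, the hummingbird conversion itself is roughly a one-liner. A minimal sketch, assuming a fitted xgboost classifier clf (our placeholder name):

from hummingbird.ml import convert

# compile the fitted tree ensemble into a torch computation graph;
# torch_model.predict(X) should mirror clf.predict(X)
torch_model = convert(clf, "pytorch")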

Dynamic / private indexing of values in ZK

In python, a simple and innocent indexing operation over a one-dimensional tensor x, such as z = x[m], is trivial. But how do we enforce this sort of indexing in ZK circuits, especially when indices like m might be private (advice, in plonk parlance) values?

The first argument we constructed allows us to implement zk-circuit equivalents of the Gather operation, which essentially just indexes a tensor x at a given set of indices. To allow these indices to be advice values, we need a new kind of argument for indexing over vectors / tensors in a zk-circuit.

an argument for private indexing
  1. We generate a claimed output, z in the example above, which (should) correspond to the value of the tensor x at index m.
  2. We assign fixed public values for the potential indices. In the example above, i is in the range $1 \dots N$.
  3. We use the equals argument (see appendix below for a full description of this argument) to generate the following constraint: $$b = [(i == m)]_{i=1}^{N}$$
    • Note that we want b to be 0 at indices not equal to m and to be 1 at m. This is a boolean operation, and should be distinguished from the typical zk-circuit operation of constraining two elements to be equal (i.e. arguments of the form x-y=0).

  4. We use an element-wise multiplication argument to constrain: $$\hat{x} = x \odot b$$
  5. We constrain $\hat{z}$ to be the sum of the elements of $\hat{x}$ (see the appendix for a description of the summation argument): $$\hat{z} = \sum_{i=1}^{N} \hat{x}_i$$
  6. Finally we constrain $$\hat{z} = z.$$

Altogether this set of arguments and constraints allows us to constrain the claimed output z to be the mth element of x.
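To make these steps concrete, here is a minimal out-of-circuit sketch of the same computation in plain torch. The function name gather_argument is ours, for illustration only; inside ezkl each step is an actual circuit constraint rather than a runtime check:

import torch

def gather_argument(x, m, z):
    # x: one-dimensional tensor, m: claimed (private) index, z: claimed output
    n = x.shape[0]
    i = torch.arange(n)               # step 2: fixed public indices
    b = (i == m).to(x.dtype)          # step 3: equals argument, 1 only at index m
    x_hat = x * b                     # step 4: element-wise multiplication
    z_hat = x_hat.sum()               # step 5: summation argument
    return torch.equal(z_hat, z)      # step 6: final equality constraint

# an honest claim passes the checks
x = torch.tensor([3.0, 7.0, 5.0])
assert gather_argument(x, 1, x[1])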

The construction of argmax and argmin is very similar to the private indexing argument (and in fact leverages it). We add one additional constraint, which is that, for a claimed $$m = \text{argmax}(x),$$ we should have $$x[m] = \text{max}(x).$$

an argument for private argmax / argmin

Say we want to calculate $$m = \text{argmax}(x),$$ where x is of length N.

  1. We generate a claimed output, m in the example above.
  2. Using the indexing argument above, we constrain: $$z = x[m]$$
  3. As an additional step we constrain $$z = \text{max}(x)$$ (see the appendix for a description of the max argument).

For argmin, you only need to replace the above max operations with min 😁.
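The same idea as a sketch, reusing gather_argument from the indexing sketch above (again, illustrative names, not ezkl’s API):

def argmax_argument(x, m):
    # m: claimed (private) index of the maximum
    z = x[m]                               # claimed output
    ok_index = gather_argument(x, m, z)    # step 2: z is really x[m]
    ok_max = torch.equal(z, x.max())       # step 3: z equals max(x)
    return ok_index and ok_max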

Conclusion

You can try out colab notebooks for the new tree-based models at:

All these models (when properly calibrated using ezkl) output predictions that are less than 0.1% away from the original sklearn, xgboost, and lightgbm predictions.


References


Appendix

equals argument

  1. Generate a lookup table a which corresponds to the “is zero” boolean operation (i.e. returns 1 if an element is 0, else it returns 0).
  2. Use an element-wise subtraction argument to constrain: $$d = x-y$$
  3. Apply lookup a to d.
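In witness terms this amounts to the following (a sketch; in the circuit the boolean is produced by the lookup table a):

import torch

def equals_argument(x, y):
    d = x - y                       # step 2: element-wise subtraction
    return (d == 0).to(x.dtype)     # steps 1 & 3: "is zero" lookup applied to d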

max argument

  1. Calculate the claimed $$m=\text{max}(x),$$ and instantiate a lookup table a which corresponds to the ReLU element-wise operation.
  2. Constrain $$w = x - (m - 1)$$
  3. Use lookup a on w; this is equivalent to clipping negative values: $$y=a(w).$$
  4. Constrain the values y to be equal to 0 or 1, i.e. assert that $$y_i*(y_i - 1)=0, \qquad \forall i \in 1\dots N.$$
    • Any non-clipped value should be equal to 1, as at most we are subtracting (m - 1) from the maximum itself.
    • This demonstrates that there is no value of x that is larger than the claimed maximum of x.
  5. We have now demonstrated that m is at least as large as any value of x; we must now demonstrate that at least one value of x is equal to m, i.e. that m is an element of x.
    • We do this by constructing the argument $$z = a(1 - \sum_i y_i) = 0.$$
    • Note that $$\sum_i y_i = 0 \iff z = 1,$$ which would mean that no value of x is equal to the claimed $$\text{max}(x),$$ and the constraint z = 0 would fail.
    • Conversely, $$\sum_i y_i \geq 1 \iff z = 0,$$ and thus at least one $$y_i$$ is equal to 1, i.e. at least one element of x is equal to the claimed maximum.
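A sketch of the whole check, using torch.clamp in place of the ReLU lookup table a (names are ours, not ezkl’s):

import torch

def max_argument(x, m):
    # x: one-dimensional integer-valued tensor, m: the claimed max(x)
    w = x - (m - 1)                                  # step 2
    y = torch.clamp(w, min=0)                        # step 3: ReLU lookup clips negatives
    booleanity = bool(torch.all(y * (y - 1) == 0))   # step 4: each y_i is 0 or 1
    z = max(1 - int(y.sum()), 0)                     # step 5: a(1 - sum(y)) must be 0
    return booleanity and z == 0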

sum argument

Consider the following plonk columns:

| a0  | a1  | m       | s_sum |
|-----|-----|---------|-------|
| a_i | b_i | m_{i-1} | 1     |
|     |     | m_i     |       |

The sum of the vectors a and b is then enforced (on rows where the s_sum selector is active) using the following constraint: $$a_i + b_i + m_{i-1} = m_i, \qquad \forall i \in 1 \dots N,$$ with $m_0 = 0$, so that the final accumulator value $m_N$ holds the total sum.
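In witness terms the column m is just a running sum (sketch only, with our own function name):

import torch

def running_sum(a, b):
    # m_i = a_i + b_i + m_{i-1} with m_0 = 0; the last entry is the total sum
    return torch.cumsum(a + b, dim=0)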