summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Sotoudeh <masotoudeh@ucdavis.edu>2020-06-11 18:47:30 -0700
committerGitHub <noreply@github.com>2020-06-11 18:47:30 -0700
commit38c22236f1c460f16c27443e7d37c8d33a81a5f8 (patch)
tree3f18205d3d7708abf10e557176f7a59f98df9959
parentf435165057e8995bce67ca0770b736574bb8af13 (diff)
Add KerasToSyrenn script, update pip installer (#4)
This commit adds a script for converting some Keras models to SyReNN and updates the pip-install script to ensure it always uses the sandboxed Python installation.
-rw-r--r--BUILD14
-rw-r--r--scripts/README.md22
-rw-r--r--scripts/keras_to_syrenn.py113
-rw-r--r--scripts/keras_to_syrenn_example.py29
4 files changed, 174 insertions, 4 deletions
diff --git a/BUILD b/BUILD
index 7c4fe99..8a75ead 100644
--- a/BUILD
+++ b/BUILD
@@ -60,25 +60,31 @@ genrule(
DUMMY_HOME=/tmp/$$(head /dev/urandom | tr -dc A-Za-z0-9 | head -c 8)
rm -rf $$DUMMY_HOME
- export HOME=$$DUMMY_HOME
+ export HOME=$$DUMMY_HOME
PIP_INSTALL="$$PIP \
install --no-cache-dir --disable-pip-version-check \
--target=$@"
- # Install the correct version of Torch
+ # Setup the environment to point to the right Python installation.
mkdir -p $$DUMMY_HOME
+ INSTALLDIR=$$PWD/$$(find **/** -name "installdir" | head -n 1)
+ export LD_LIBRARY_PATH=$$INSTALLDIR/lib/:$$INSTALLDIR/lib64/
+ ln -s $$INSTALLDIR/python3.7 $$DUMMY_HOME/python
+ export PYTHONPATH=$$PWD/$@
+ export PATH=$$DUMMY_HOME:$$PATH
+
+ # Install the correct version of Torch
$$PIP_INSTALL torch==1.2.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
# Install the other requirements.
- mkdir -p $$DUMMY_HOME
+ $$PIP_INSTALL --upgrade setuptools
$$PIP_INSTALL -r requirements.txt
# The custom typing package installed as a dependency doesn't seem to work
# well.
rm -rf $@/typing-*
rm -rf $@/typing.py
-
rm -rf $$DUMMY_HOME
""",
tools = [
diff --git a/scripts/README.md b/scripts/README.md
new file mode 100644
index 0000000..43c3616
--- /dev/null
+++ b/scripts/README.md
@@ -0,0 +1,22 @@
+# Miscellaneous Scripts
+This directory contains miscellaneous scripts related to SyReNN which we do not
+yet want to make a part of the main packages and/or build system.
+
+There is currently only one script, described below.
+
+### Keras To SyReNN
+`keras_to_syrenn.py` contains a `keras_to_syrenn(...)` method which shows how
+to convert a sequential Keras Model into an equivalent SyReNN Network.
+
+There is an example script converting a simple model in
+`keras_to_syrenn_example.py`.
+
+Please note that this script makes a lot of assumptions about the model,
+including that it is sequential and uses HWIO layout for images. You should
+always validate that the resultant SyReNN model has the same (within an
+epsilon) output as the original Keras model.
+
+We do not plan to move this code into the main SyReNN code in the near future,
+as it does not currently support all networks that the current ONNX importer
+does (e.g. with Concatenate layers), and it adds a heavy Tensorflow dependency
+which we otherwise do not need.
diff --git a/scripts/keras_to_syrenn.py b/scripts/keras_to_syrenn.py
new file mode 100644
index 0000000..459967a
--- /dev/null
+++ b/scripts/keras_to_syrenn.py
@@ -0,0 +1,113 @@
+"""Script for converting sequential Keras models to SyReNN Networks.
+"""
+import tensorflow.keras as keras
+import numpy as np
+import pysyrenn
+
+def keras_to_syrenn(model):
+ """Converts a sequential Keras model to a SyReNN Network.
+
+ Note that this conversion code makes a number of not-always-valid
+ assumptions about the model; you should *always* manually verify that the
+ returned SyReNN network has the same (within a small epsilon) output as the
+ Keras model.
+ """
+ syrenn_layers = []
+ def append_activation(function):
+ """Adds activation function @function to the SyReNN layers.
+ """
+ if function is None or function is keras.activations.linear:
+ # Identity: https://github.com/keras-team/keras/blob/bd024a1fc1cd6d88e8bc5da148968ff5e079caeb/keras/activations.py#L187
+ pass
+ elif function is keras.activations.relu:
+ syrenn_layers.append(pysyrenn.ReluLayer())
+ else:
+ print(function)
+ raise NotImplementedError
+
+ for layer in model.layers:
+ if isinstance(layer, keras.layers.InputLayer):
+ continue
+ elif isinstance(layer, keras.layers.Conv2D):
+ # filters: Height, Width, InChannels, OutChannels
+ # biases: OutChannels
+ filters, biases = map(to_numpy, layer.weights)
+ if layer.padding == "same":
+ pad_height = compute_same_padding(
+ filters.shape[0], layer.input_shape[1], layer.strides[0])
+ pad_width = compute_same_padding(
+ filters.shape[1], layer.input_shape[2], layer.strides[1])
+
+ assert pad_height % 2 == 0
+ assert pad_width % 2 == 0
+
+ padding = [pad_height // 2, pad_width // 2]
+ elif layer.padding == "valid":
+ padding = [0, 0]
+ else:
+ raise NotImplementedError
+
+ window_data = pysyrenn.StridedWindowData(
+ layer.input_shape[1:], # HWC
+ filters.shape[:2], # HW
+ layer.strides, # HW
+ padding, # HW
+ filters.shape[3])
+
+ # Note that SyReNN *assumes* the HWIO format and transforms it
+ # internally to the Pytorch OIHW format.
+ syrenn_layers.append(
+ pysyrenn.Conv2DLayer(window_data, filters, biases))
+ append_activation(layer.activation)
+ elif isinstance(layer, keras.layers.Activation):
+ append_activation(layer.activation)
+ elif isinstance(layer, keras.layers.BatchNormalization):
+ gamma, beta, mean, var = map(to_numpy, layer.weights)
+ # See https://github.com/keras-team/keras/blob/cb96315a291a8515544c6dd807500073958f8928/keras/backend/numpy_backend.py#L531
+ # ((x - mean) / sqrt(var + epsilon)) * gamma + beta
+ # = ((x - (mean - (d*beta))) / d) where
+ # d := sqrt(var + epsilon) / gamma
+ std = np.sqrt(var + 0.001) / gamma
+ mean = mean - (std * beta)
+ syrenn_layers.append(pysyrenn.NormalizeLayer(mean, std))
+ elif isinstance(layer, keras.layers.MaxPooling2D):
+ assert layer.padding == "valid"
+ window_data = pysyrenn.StridedWindowData(
+ layer.input_shape[1:], # HWC
+ layer.pool_size, # HW
+ layer.strides, # HW
+ [0, 0], # HW
+ layer.input_shape[3])
+
+ # Note that SyReNN *assumes* the HWIO format and transforms it
+ # internally to the Pytorch OIHW format.
+ syrenn_layers.append(pysyrenn.MaxPoolLayer(window_data))
+ elif isinstance(layer, keras.layers.Dropout):
+ # Not needed for inference.
+ pass
+ elif isinstance(layer, keras.layers.Flatten):
+ # By default, SyReNN passes data around in NHWC format to match with
+ # ERAN/TF.
+ assert layer.data_format == "channels_last"
+ elif isinstance(layer, keras.layers.Dense):
+ # weights: (from, to)
+ # biases: (to,)
+ weights, biases = map(to_numpy, layer.weights)
+ syrenn_layers.append(pysyrenn.FullyConnectedLayer(weights, biases))
+ append_activation(layer.activation)
+ else:
+ raise NotImplementedError
+ return pysyrenn.Network(syrenn_layers)
+
+def to_numpy(x):
+ """Helper to convert TensorFlow tensors to Numpy.
+ """
+ return x.numpy()
+
+def compute_same_padding(filter_size, in_size, stride):
+ """Helper to compute the amount of padding used by a convolution.
+
+ Computation based on https://stackoverflow.com/a/44242277
+ """
+ out_size = (in_size + (stride - 1)) // stride
+ return max((out_size - 1) * stride + filter_size - in_size, 0)
diff --git a/scripts/keras_to_syrenn_example.py b/scripts/keras_to_syrenn_example.py
new file mode 100644
index 0000000..bc31d92
--- /dev/null
+++ b/scripts/keras_to_syrenn_example.py
@@ -0,0 +1,29 @@
+"""Example of converting a sequential Keras model to SyReNN.
+"""
+import tensorflow.keras as keras
+import numpy as np
+from keras_to_syrenn import keras_to_syrenn
+import pysyrenn
+
+# https://keras.io/guides/sequential_model/
+model = keras.Sequential(
+ [
+ keras.layers.Dense(2, activation="relu", name="layer1"),
+ keras.layers.Dense(3, activation=None, name="layer2"),
+ keras.layers.Dense(4, name="layer3"),
+ ]
+)
+# We need to evaluate the model at least once before converting to SyReNN. I
+# think this is because Keras doesn't actually initialize the parameters until
+# this.
+model(np.ones((1, 3)))
+
+syrenn_network = keras_to_syrenn(model)
+
+x = np.ones((3, 3))
+
+print("Keras model output:")
+print(model(x))
+
+print("SyReNN network output:")
+print(syrenn_network.compute(x))
generated by cgit on debian on lair
contact matthew@masot.net with questions or feedback