End-to-end inference on Llama-7b #306
Comments
Follow-up Q: while the missing functions are being implemented, is it possible to run just these functions in cleartext, let's say by revealing their inputs, performing the computation locally, and then secret-sharing the outputs? |
Hi @deevashwer To add sine and cosine, the first step is to propose a reasonable approximation like this. Once such approximations are implemented, the rest of the compiler/dispatch work should be straightforward. For general complex support, I need to check the amount of work. Best |
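(For illustration: a minimal sketch, not the actual SPU kernel, of what such a sine approximation could look like: range reduction to [-π, π] followed by a truncated Taylor polynomial, so that only add/mul and a modulo by a public constant are needed. `sine_poly_approx` is a hypothetical name.)

```python
import jax.numpy as jnp

def sine_poly_approx(x):
    # Reduce the argument to [-pi, pi]; the modulus is a public constant.
    x = jnp.mod(x + jnp.pi, 2 * jnp.pi) - jnp.pi
    # Truncated Taylor series in Horner form:
    # sin(x) ~= x - x^3/3! + x^5/5! - x^7/7!
    x2 = x * x
    return x * (1 - x2 / 6 * (1 - x2 / 20 * (1 - x2 / 42)))
```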
@deevashwer Hi.
|
It is definitely possible to do this, and for experimental purposes, I think it's fine. |
I think it might be fine to do this in experiments, as np.complex64 is only needed to compute freqs_cis (this step does not involve any private inputs). |
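(A minimal sketch of this split, assuming the usual Llama-style freqs_cis precomputation; the function names and the single-party REF2K simulator setup mirror the example later in this thread and are only for local testing, not a definitive recipe.)

```python
import numpy as np
import spu
import spu.utils.simulation as pps

config = spu.RuntimeConfig(protocol=spu.ProtocolKind.REF2K, field=spu.FieldType.FM64)
sim = pps.Simulator(1, config)

def precompute_freqs_cis(dim, seqlen, theta=10000.0):
    # Public precomputation in cleartext: no private inputs are involved,
    # so nothing is leaked by running it outside SPU.
    freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: dim // 2] / dim))
    angles = np.outer(np.arange(seqlen), freqs)
    # Keep real/imag parts separate so np.complex64 never enters SPU.
    return np.cos(angles).astype(np.float32), np.sin(angles).astype(np.float32)

def private_part(x, cos, sin):
    # Hypothetical stand-in for the secret-shared part of the model;
    # cos/sin are fed to SPU as ordinary (public) float arrays.
    return x * cos + x * sin

cos, sin = precompute_freqs_cis(dim=8, seqlen=4)
x = np.ones_like(cos)
print(pps.sim_jax(sim, private_part)(x, cos, sin))
```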
Thanks for the quick response! Can you please guide me on how to introduce this makeshift solution of running the unsupported functions in cleartext, so that I can test the end-to-end accuracy of my solution? |
I'll prepare an example :D |
There is a set of transforms to lower complex ops. I'll check how good/bad these are. |
Hi @deevashwer I added a simple hacked sine here. It should give you an idea of how to support an instruction end-to-end. |
Actually, I am considering just using the Python hijack (as with the GeLU hack in the llama-7b example), but I have no idea how. |
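(A minimal sketch of what such a hijack could look like, assuming the polynomial sine from the earlier sketch; `hack_sin` and `hijack_sin` are illustrative names, and the real GeLU hack in the llama-7b example may be structured differently.)

```python
import jax.numpy as jnp
from contextlib import contextmanager

def hack_sin(x):
    # SPU-friendly polynomial stand-in for sine (see the earlier sketch).
    x = jnp.mod(x + jnp.pi, 2 * jnp.pi) - jnp.pi
    x2 = x * x
    return x * (1 - x2 / 6 * (1 - x2 / 20 * (1 - x2 / 42)))

@contextmanager
def hijack_sin():
    # Temporarily replace jnp.sin while the model function is traced,
    # then restore the original implementation.
    orig = jnp.sin
    jnp.sin = hack_sin
    try:
        yield
    finally:
        jnp.sin = orig

# Usage sketch: trace/compile the model inside the context manager, e.g.
#   with hijack_sin():
#       spu_fn = pps.sim_jax(simulator, model_fn)
```

Note that only call sites that look up `jnp.sin` through the module attribute at call time will pick up the replacement.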
I am trying to work around this issue by partitioning the LLaMA model into private and public parts, where all the private parts (most of the computation) will be run within SPU and the public parts will be run with JAX (float). To do the same, I need to simulate parts of the NN computation within SPU, and an example of how this would work is as follows:

```python
import jax
from jax import random, numpy as jnp
from flax import linen as nn
import spu.utils.simulation as pps
import spu

protocol = spu.ProtocolKind.REF2K
field = spu.FieldType.FM64
config = spu.RuntimeConfig(protocol=protocol, field=field)
simulator = pps.Simulator(1, config)

# pure function
def dense(params, x):
    return jax.lax.dot_general(x, params['params']['kernel'], (((x.ndim - 1,), (0,)), ((), ())))

class LinearModel(nn.Module):
    features: int
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layer = nn.Dense(self.features, use_bias=False, dtype=self.dtype)

    def __call__(self, x):
        params = {"params": self.layer.variables['params']}

        # FLAX - Works
        y = self.layer.apply(params, x)
        print("y", y)

        # SPU with pure function - Works
        spu_dense = pps.sim_jax(simulator, dense)
        spu_y = spu_dense(params, x)
        print("spu_y (pure function)", spu_y)

        # SPU with nn.Module.apply - Doesn't work
        spu_dense = pps.sim_jax(simulator, self.layer.apply)
        spu_y = spu_dense(params, x)
        print("spu_y (nn.Module.apply)", spu_y)

        return y

model = LinearModel(features=5)
key = random.PRNGKey(0)
params = {'params': {'layer': {'kernel': random.normal(key, (10, 5))}}}
x = random.normal(key, (10,))
y = model.apply(params, x)
```

In this example, I'm trying to run one of the linear layers within SPU, but it doesn't work with the class method `self.layer.apply`.
|
Hi @deevashwer This looks like a bug in our compilation cache. I'll take a look |
# Pull Request

## What problem does this PR solve?

Issue Number: #306

Fixed compilation cache error when Callable contains weakref

## Possible side effects?

- Performance: N/A
- Backward compatibility: N/A |
Hi @deevashwer Fixed. :P |
Hi @deevashwer Complex and sin/cos support were added in the latest version, and the puma README has been updated. Please take a look. Best |
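(A quick local check, reusing the simulator setup from earlier in this thread and assuming the updated SPU version is installed.)

```python
import jax.numpy as jnp
import numpy as np
import spu
import spu.utils.simulation as pps

config = spu.RuntimeConfig(protocol=spu.ProtocolKind.REF2K, field=spu.FieldType.FM64)
sim = pps.Simulator(1, config)

def rotate(x):
    # sin/cos should now be lowered by the compiler instead of failing.
    return jnp.sin(x) + jnp.cos(x)

print(pps.sim_jax(sim, rotate)(np.linspace(-3.0, 3.0, 8, dtype=np.float32)))
```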
Hi,
Thanks for building this framework!
I'm trying to run end-to-end secure inference with the Llama-7b model here, but as mentioned in the README, the model uses operations that SecretFlow does not currently support.
The README mentions that the np.complex64 instruction is the problem, but it seems from this documentation that sine and cosine operations are also unsupported (are those all the missing functions?). Are there any plans to integrate these missing functions into the XLA support? If not, can you please guide me on the process to do so and provide an estimate for the required effort? The documentation here says that it is a general limitation of SecretFlow that it can't support complex numbers. Does this limitation apply here?
Thanks for your time,
Devesh