
End-to-end inference on Llama-7b #306

Closed
deevashwer opened this issue Aug 16, 2023 · 15 comments

Comments

@deevashwer

deevashwer commented Aug 16, 2023

Hi,

Thanks for building this framework!

I'm trying to run end-to-end secure inference with the Llama-7b model here, but as mentioned in the README, the model uses operations that SecretFlow does not currently support.
The README mentions that the np.complex64 instruction is the problem, but it seems from this documentation that sine and cosine operations are also unsupported (are those all of the missing functions?). Are there any plans to integrate these missing functions into the XLA pipeline? If not, can you please guide me through the process of doing so and provide an estimate of the required effort?

The documentation here says that the lack of complex number support is a general limitation of SecretFlow. Does that limitation apply here as well?

Thanks for your time,
Devesh

@deevashwer
Author

Follow-up Q: while the missing functions are being implemented, is it possible to run just these functions in cleartext, let's say by revealing their inputs, performing the computation locally, and then secret-sharing the outputs?
Perhaps it's easier to do with the Simulator and the REF2K backend?
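For concreteness, something like the following is what I have in mind (a rough sketch reusing the Simulator API; unsupported_fn and supported_fn are just placeholders, not real model code):

import jax.numpy as jnp
import spu
import spu.utils.simulation as pps

config = spu.RuntimeConfig(protocol=spu.ProtocolKind.REF2K, field=spu.FieldType.FM64)
simulator = pps.Simulator(1, config)

# Placeholder for an op SPU does not support yet (e.g., a trig function).
def unsupported_fn(x):
    return jnp.sin(x)

# Supported part of the computation, run under SPU simulation.
def supported_fn(x, sin_x):
    return x * sin_x + 1.0

x = jnp.linspace(0.0, 1.0, 8)

# "Reveal": compute the unsupported piece in cleartext JAX.
sin_x = unsupported_fn(x)

# "Re-share": feed the cleartext result back in as an ordinary input
# of the SPU-simulated function.
spu_supported = pps.sim_jax(simulator, supported_fn)
y = spu_supported(x, sin_x)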

@anakinxc
Contributor

Hi @deevashwer

To add sine and cosine, the first step is to propose a reasonable approximation, like this.
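
For example, a rough sketch of the kind of approximation I mean, in plain JAX (range reduction plus a low-degree odd polynomial; the coefficients below are just the Taylor ones for illustration, not what SPU would actually ship):

import jax.numpy as jnp

def sine_approx(x):
    # Reduce x into [-pi, pi]; MPC backends prefer fixed, data-independent steps.
    x = x - 2.0 * jnp.pi * jnp.round(x / (2.0 * jnp.pi))
    # Odd polynomial: x - x^3/3! + x^5/5! - x^7/7!; degree vs. accuracy is the trade-off.
    x2 = x * x
    return x * (1.0 - x2 / 6.0 + x2 * x2 / 120.0 - x2 * x2 * x2 / 5040.0)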

Once such approximations are implemented, the remaining compiler/dispatch work should be straightforward.

For general complex support, I need to check the amount of work involved.

Best
~Yancheng

@fionser
Contributor

fionser commented Aug 16, 2023

@deevashwer Hi.

  • The trigonometric functions do not seem to be supported in SPU yet.
  • Moreover, the dispatch of those functions from the front end to the MPC layer is still missing,
    so even simulation (or an ad hoc reveal-then-reshare) needs some code to dispatch the MLIR down
    to the MPC layer.
  • We can still run the Llama-7b example by simply ignoring these trig functions (i.e., rewriting EasyLM's source code to return early).
    BTW, the EasyLM code has been updated recently, so you should check out
    commit '690cba24a2da6711097425a94ea0b5295a82144b' for an older version of EasyLM.
  • Note that the current Llama-7b example cannot guarantee correctness, because the conversion of the trained weights from PyTorch to JAX does not seem to work properly. We only check consistency between the SPU plaintext JAX execution and the ABY3 execution.

@anakinxc
Contributor

Follow-up Q: while the missing functions are being implemented, is it possible to run just these functions in cleartext, let's say by revealing their inputs, performing the computation locally, and then secret-sharing the outputs? Perhaps it's easier to do with the Simulator and the REF2K backend?

It is definitely possible to do this, and for experimental purposes I think it's fine.

@Ye-D
Collaborator

Ye-D commented Aug 16, 2023

Follow-up Q: while the missing functions are being implemented, is it possible to run just these functions in cleartext, let's say by revealing their inputs, performing the computation locally, and then secret-sharing the outputs? Perhaps it's easier to do with the Simulator and the REF2K backend?

I think it might be fine to do this in experiments, since np.complex64 is only needed to compute freqs_cis (this computation does not involve any private inputs).
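
For example, a minimal sketch of precomputing it in cleartext and keeping only real arrays (following the usual LLaMA rotary setup; the function name and arguments here are just illustrative):

import numpy as np

def precompute_freqs_cos_sin(dim, end, theta=10000.0):
    # Same angles as the usual complex freqs_cis, but returned as two real arrays,
    # so no np.complex64 value ever has to enter the SPU computation.
    freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: dim // 2] / dim))
    t = np.arange(end)
    angles = np.outer(t, freqs)            # shape (end, dim // 2)
    return np.cos(angles), np.sin(angles)  # both public, computed in cleartext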

@deevashwer
Author

Hi @fionser and @anakinxc,

Thanks for the quick response!

Can you please guide me on how to introduce this makeshift solution of running the unsupported functions in cleartext, so that I can test the end-to-end accuracy of my solution?
As for the Llama JAX port being buggy, EasyLM is under active development, and I can ask the authors for help with debugging that part.

@anakinxc
Contributor

Hi @fionser and @anakinxc,

Thanks for the quick response!

Can you please guide me on how to introduce this makeshift solution of running the unsupported functions in cleartext, so that I can test the end-to-end accuracy of my solution? As for the Llama JAX port being buggy, EasyLM is under active development, and I can ask the authors for help with debugging that part.

I'll prepare an example :D

@deevashwer
Author

@Ye-D, good point! sine and cosine are indeed operating on public inputs, but it seems from the code here that there is still some complex arithmetic happening on secret inputs. Perhaps that is easy to emulate with real arithmetic. :)

@anakinxc, thank you so much! 😄

@anakinxc
Contributor

@Ye-D, good point! sine and cosine are indeed operating on public inputs, but it seems from the code here that there is still some complex arithmetic happening on secret inputs. Perhaps that is easy to emulate with real arithmetic. :)

@anakinxc, thank you so much! 😄

There is a set of transforms to lower complex ops.

I'll check how good/bad these are.
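
Roughly speaking, such a lowering replaces each complex op with real arithmetic on the (real, imag) parts, e.g. for a multiply it is just the usual identity (a sketch of the math, not the actual transform):

def complex_mul(ar, ai, br, bi):
    # Works elementwise on JAX/NumPy arrays or plain floats.
    # (ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br)
    return ar * br - ai * bi, ar * bi + ai * br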

@anakinxc
Contributor

Hi @deevashwer

I added a simple hacked sine here.

It should give you an idea of how to support an instruction end to end.

@fionser
Contributor

fionser commented Aug 16, 2023

Actually, I am considering just using a Python hijack (like the gelu hack in the llama-7b example)
to work around the np.complex usage.
This should give us workable code just for verification (and maybe avoid the front-end-to-back-end dispatch work).

But I have no idea how the apply_rotary_emb function works.
It seems to compute the attention on complex numbers and stack the imaginary/real parts, like here.
Then I would like to just rewrite some of EasyLM's code to explicitly perform the multiplication on real numbers
(assuming the weight matrix is real, not complex?).
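
Concretely, the rewrite I have in mind is just the real-arithmetic form of the complex rotation, something like the sketch below (apply_rotary_emb_real is a made-up name, and it assumes the head dimension is split into two halves holding the would-be real/imaginary parts; EasyLM's actual layout may interleave pairs instead):

import jax.numpy as jnp

def apply_rotary_emb_real(x, cos, sin):
    # x: (..., seq, dim); cos/sin: (seq, dim // 2), precomputed in cleartext.
    x1, x2 = jnp.split(x, 2, axis=-1)
    # (x1 + i*x2) * (cos + i*sin), written with real ops only.
    out1 = x1 * cos - x2 * sin
    out2 = x1 * sin + x2 * cos
    return jnp.concatenate([out1, out2], axis=-1)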

@deevashwer
Author

Hi @anakinxc @fionser,

I am trying to work around this issue by partitioning the LLaMA model into private and public parts, where all the private parts (most of the computation) will be run within SPU and the public parts will be run with JAX (float).

To do so, I need to simulate parts of the NN computation within SPU; an example of how this would work is as follows:

import jax
from jax import random, numpy as jnp
from flax import linen as nn
import spu.utils.simulation as pps
import spu

protocol = spu.ProtocolKind.REF2K
field = spu.FieldType.FM64
config = spu.RuntimeConfig(protocol=protocol, field=field)
simulator = pps.Simulator(1, config)

# pure function
def dense(params, x):
    return jax.lax.dot_general(x, params['params']['kernel'], (((x.ndim - 1,), (0,)), ((), ())),)

class LinearModel(nn.Module):
    features: int
    dtype: jnp.dtype = jnp.float32
    
    def setup(self):
        self.layer = nn.Dense(self.features, use_bias=False, dtype=self.dtype)

    def __call__(self, x):
        params = {"params": self.layer.variables['params']}

        # FLAX - Works
        y = self.layer.apply(params, x)
        print("y", y)

        # SPU with pure function - Works
        spu_dense = pps.sim_jax(simulator, dense)
        spu_y = spu_dense(params, x)
        print("spu_y (pure function)", spu_y)

        # SPU with nn.Module.apply - Doesn't work
        spu_dense = pps.sim_jax(simulator, self.layer.apply)
        spu_y = spu_dense(params, x)
        print("spu_y (nn.Module.apply)", spu_y)

        return y

model = LinearModel(features=5)

key = random.PRNGKey(0)
params = {'params': {'layer': {'kernel': random.normal(key, (10, 5))}}}
x = random.normal(key, (10,))

y = model.apply(params, x)

In this example, I'm trying to run one of the linear layers within SPU, but it doesn't work with the class method apply (see the error in the following code block). I wrote a pure function emulating the same functionality, and that works without any problems. However, writing pure functions for each layer of a complex NN like LLaMA would be infeasible, so for this approach to work, I have to make the simulation work with the layer.apply method.
The way I've written this example is the standard way of implementing models in Flax. Can you please help me understand what the issue is here and suggest a fix? Thanks!

Traceback (most recent call last):
  File "/Users/deevashwer/MPC-auto/spu-scripts/test.py", line 48, in <module>
    y = model.apply(params, x)
  File "/Users/deevashwer/MPC-auto/spu-scripts/test.py", line 37, in __call__
    spu_y = spu_dense(params, x)
  File "/usr/local/anaconda3/lib/python3.10/site-packages/spu/utils/simulation.py", line 151, in wrapper
    executable, output = spu_fe.compile(
  File "/usr/local/anaconda3/lib/python3.10/site-packages/spu/utils/frontend.py", line 148, in compile
    ir_text, output = _jax_compilation(
  File "/usr/local/anaconda3/lib/python3.10/site-packages/cachetools/__init__.py", line 732, in wrapper
    k = key(*args, **kwargs)
  File "/usr/local/anaconda3/lib/python3.10/site-packages/spu/utils/frontend.py", line 34, in _jax_compilation_key
    f'{hash(cloudpickle.dumps(fn))}-{static_argnums}-{static_argnames}-{types}'
  File "/usr/local/anaconda3/lib/python3.10/site-packages/cloudpickle/cloudpickle_fast.py", line 73, in dumps
    cp.dump(obj)
  File "/usr/local/anaconda3/lib/python3.10/site-packages/cloudpickle/cloudpickle_fast.py", line 602, in dump
    return Pickler.dump(self, obj)
TypeError: cannot pickle 'weakref.ReferenceType' object

@anakinxc
Contributor

Hi @deevashwer

This looks like a bug in our compilation cache. I'll take a look
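
In the meantime, a possible workaround (just a sketch, assuming the weakref only sneaks in through the bound self.layer.apply method) is to hand sim_jax a module-level pure function instead, so no live module instance ends up in the closure that gets pickled:

from flax import linen as nn
import spu.utils.simulation as pps

# Module-level function: constructs a fresh Dense on each call, so nothing bound to
# a live module instance (and its weakrefs) needs to be pickled for the cache key.
def dense_apply(params, x, features=5):
    return nn.Dense(features, use_bias=False).apply(params, x)

# Inside LinearModel.__call__ this would replace the failing lines:
#   spu_dense = pps.sim_jax(simulator, dense_apply)
#   spu_y = spu_dense({"params": self.layer.variables['params']}, x)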

anakinxc added a commit that referenced this issue Aug 22, 2023
# Pull Request

## What problem does this PR solve?

Issue Number: #306 

Fixed compilation cache error when Callable contains weakref

## Possible side effects?

- Performance: N/A

- Backward compatibility: N/A
@anakinxc
Contributor

Hi @anakinxc @fionser,

I am trying to work around this issue by partitioning the LLaMA model into private and public parts, where all the private parts (most of the computation) will be run within SPU and the public parts will be run with JAX (float). [...]

TypeError: cannot pickle 'weakref.ReferenceType' object

Hi @deevashwer

Fixed. :P

@anakinxc
Contributor

Hi @deevashwer

Complex and sin/cos support have been added in the latest version, and the puma README has been updated.

Please take a look.

Best
~Anakin
