Skip to content

Commit

Permalink
fix: fix deadlock issue in grpc_func and add cors support (#158)
Browse files Browse the repository at this point in the history
Signed-off-by: Zike Yang <zike@apache.org>
  • Loading branch information
RobertIndie authored Mar 8, 2024
1 parent a95c50a commit ffb4e90
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 37 deletions.
4 changes: 4 additions & 0 deletions fs/instance_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ func (instance *FunctionInstanceImpl) Run(runtimeFactory api.FunctionRuntimeFact
instance.log.ErrorContext(instance.ctx, "Error calling process function", slog.Any("error", err))
return
}
if output == nil {
instance.log.DebugContext(instance.ctx, "output is nil")
continue
}
select {
case sinkChan <- output:
case <-instance.ctx.Done():
Expand Down
44 changes: 27 additions & 17 deletions fs/runtime/grpc/grpc_func.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,17 @@ import (

type GRPCFuncRuntime struct {
api.FunctionRuntime
Name string
instance api.FunctionInstance
ctx context.Context
status *proto.FunctionStatus
readyOnce sync.Once
readyCh chan error
input chan string
output chan string
stopFunc func()
log *slog.Logger
Name string
instance api.FunctionInstance
ctx context.Context
status *proto.FunctionStatus
readyOnce sync.Once
readyCh chan error
input chan contube.Record
output chan contube.Record
stopFunc func()
processing atomic.Bool
log *slog.Logger
}

type Status int32
Expand Down Expand Up @@ -155,8 +156,8 @@ func (s *FSSReconcileServer) NewFunctionRuntime(instance api.FunctionInstance) (
Name: name,
instance: instance,
readyCh: make(chan error),
input: make(chan string),
output: make(chan string),
input: make(chan contube.Record),
output: make(chan contube.Record),
status: &proto.FunctionStatus{
Name: name,
Status: proto.FunctionStatus_CREATING,
Expand Down Expand Up @@ -227,9 +228,10 @@ func (f *GRPCFuncRuntime) Stop() {
}

func (f *GRPCFuncRuntime) Call(event contube.Record) (contube.Record, error) {
f.input <- string(event.GetPayload())
f.input <- event
out := <-f.output
return contube.NewRecordImpl([]byte(out), event.Commit), nil
f.processing.Store(false)
return out, nil
}

type FunctionServerImpl struct {
Expand All @@ -255,12 +257,20 @@ func (f *FunctionServerImpl) Process(req *proto.FunctionProcessRequest, stream p
})
errCh := make(chan error)

defer func() {
if runtime.processing.Load() {
runtime.output <- nil
runtime.processing.Store(false)
}
}()

logCounter := common.LogCounter()
for {
select {
case payload := <-runtime.input:
case event := <-runtime.input:
log.DebugContext(stream.Context(), "sending event", slog.Any("count", logCounter))
err := stream.Send(&proto.Event{Payload: payload})
runtime.processing.Store(true)
err := stream.Send(&proto.Event{Payload: string(event.GetPayload())}) // TODO: Change payload type to bytes
if err != nil {
log.Error("failed to send event", slog.Any("error", err))
return err
Expand Down Expand Up @@ -292,7 +302,7 @@ func (f *FunctionServerImpl) Output(ctx context.Context, e *proto.Event) (*proto
return nil, err
}
runtime.log.DebugContext(ctx, "received event")
runtime.output <- e.Payload
runtime.output <- contube.NewRecordImpl([]byte(e.Payload), func() {})
return &proto.Response{
Status: proto.Response_OK,
}, nil
Expand Down
45 changes: 33 additions & 12 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ type Server struct {
options *serverOptions
httpSvr atomic.Pointer[http.Server]
log *slog.Logger
manager *fs.FunctionManager
Manager *fs.FunctionManager
}

type serverOptions struct {
Expand All @@ -59,7 +59,7 @@ func (f serverOptionFunc) apply(c *serverOptions) (*serverOptions, error) {
return f(c)
}

// WithFunctionManager sets the function manager for the server.
// WithFunctionManager sets the function Manager for the server.
func WithFunctionManager(opts ...fs.ManagerOption) ServerOption {
return serverOptionFunc(func(o *serverOptions) (*serverOptions, error) {
o.managerOpts = append(o.managerOpts, opts...)
Expand Down Expand Up @@ -114,7 +114,7 @@ func NewServer(opts ...ServerOption) (*Server, error) {
}
return &Server{
options: options,
manager: manager,
Manager: manager,
log: slog.With(),
}, nil
}
Expand Down Expand Up @@ -167,8 +167,29 @@ func (s *Server) Run(context context.Context) {
}
}

func corsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")

if r.Method == "OPTIONS" {
w.Header().Set("Access-Control-Max-Age", "86400")
w.WriteHeader(http.StatusOK)
return
}
next.ServeHTTP(w, r)
})
}

func (s *Server) startRESTHandlers() error {
r := mux.NewRouter()

r.PathPrefix("/").Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
})).Methods("OPTIONS")

r.Use(corsMiddleware)

r.HandleFunc("/api/v1/status", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}).Methods("GET")
Expand Down Expand Up @@ -208,7 +229,7 @@ func (s *Server) startRESTHandlers() error {
return
}

err = s.manager.StartFunction(f)
err = s.Manager.StartFunction(f)
if err != nil {
log.Error("Failed to start function", "error", err)
http.Error(w, err.Error(), http.StatusBadRequest)
Expand All @@ -222,7 +243,7 @@ func (s *Server) startRESTHandlers() error {
functionName := vars["function_name"]
log := s.log.With(slog.String("name", functionName), slog.String("phase", "deleting"))

err := s.manager.DeleteFunction(functionName)
err := s.Manager.DeleteFunction(functionName)
if errors.Is(err, common.ErrorFunctionNotFound) {
log.Error("Function not found", "error", err)
http.Error(w, err.Error(), http.StatusNotFound)
Expand All @@ -234,7 +255,7 @@ func (s *Server) startRESTHandlers() error {
r.HandleFunc("/api/v1/functions", func(w http.ResponseWriter, r *http.Request) {
log := s.log.With()
log.Info("Listing functions")
functions := s.manager.ListFunctions()
functions := s.Manager.ListFunctions()
w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(functions)
if err != nil {
Expand All @@ -255,7 +276,7 @@ func (s *Server) startRESTHandlers() error {
http.Error(w, errors.Wrap(err, "Failed to read body").Error(), http.StatusBadRequest)
return
}
err = s.manager.ProduceEvent(queueName, contube.NewRecordImpl(content, func() {}))
err = s.Manager.ProduceEvent(queueName, contube.NewRecordImpl(content, func() {}))
if err != nil {
log.Error("Failed to produce event", "error", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand All @@ -269,7 +290,7 @@ func (s *Server) startRESTHandlers() error {
queueName := vars["queue_name"]
log := s.log.With(slog.String("queue_name", queueName))
log.Info("Consuming event from queue")
event, err := s.manager.ConsumeEvent(queueName)
event, err := s.Manager.ConsumeEvent(queueName)
if err != nil {
log.Error("Failed to consume event", "error", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down Expand Up @@ -304,7 +325,7 @@ func (s *Server) startRESTHandlers() error {
}
log := s.log.With(slog.String("key", key))
log.Info("Getting state")
state := s.manager.GetStateStore()
state := s.Manager.GetStateStore()
if state == nil {
log.Error("No state store configured")
http.Error(w, "No state store configured", http.StatusBadRequest)
Expand Down Expand Up @@ -333,7 +354,7 @@ func (s *Server) startRESTHandlers() error {
}
log := s.log.With(slog.String("key", key))
log.Info("Getting state")
state := s.manager.GetStateStore()
state := s.Manager.GetStateStore()
if state == nil {
log.Error("No state store configured")
http.Error(w, "No state store configured", http.StatusBadRequest)
Expand Down Expand Up @@ -411,8 +432,8 @@ func (s *Server) Close() error {
return err
}
}
if s.manager != nil {
err := s.manager.Close()
if s.Manager != nil {
err := s.Manager.Close()
if err != nil {
return err
}
Expand Down
16 changes: 8 additions & 8 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func TestStandaloneBasicFunction(t *testing.T) {
Name: "test-func",
Replicas: 1,
}
err := s.manager.StartFunction(funcConf)
err := s.Manager.StartFunction(funcConf)
if err != nil {
t.Fatal(err)
}
Expand All @@ -116,13 +116,13 @@ func TestStandaloneBasicFunction(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = s.manager.ProduceEvent(inputTopic, contube.NewRecordImpl(jsonBytes, func() {
err = s.Manager.ProduceEvent(inputTopic, contube.NewRecordImpl(jsonBytes, func() {
}))
if err != nil {
t.Fatal(err)
}

event, err := s.manager.ConsumeEvent(outputTopic)
event, err := s.Manager.ConsumeEvent(outputTopic)
if err != nil {
t.Error(err)
return
Expand Down Expand Up @@ -163,7 +163,7 @@ func TestHttpTube(t *testing.T) {
Replicas: 1,
}

err := s.manager.StartFunction(funcConf)
err := s.Manager.StartFunction(funcConf)
assert.Nil(t, err)

p := &tests.Person{
Expand All @@ -178,7 +178,7 @@ func TestHttpTube(t *testing.T) {
_, err = http.Post(httpAddr+"/api/v1/http-tube/"+endpoint, "application/json", bytes.NewBuffer(jsonBytes))
assert.Nil(t, err)

event, err := s.manager.ConsumeEvent(funcConf.Output)
event, err := s.Manager.ConsumeEvent(funcConf.Output)
if err != nil {
t.Error(err)
return
Expand Down Expand Up @@ -243,19 +243,19 @@ func TestStatefulFunction(t *testing.T) {
Output: "output",
Replicas: 1,
}
err := s.manager.StartFunction(funcConf)
err := s.Manager.StartFunction(funcConf)
if err != nil {
t.Fatal(err)
}

_, err = http.Post(httpAddr+"/api/v1/state/key", "text/plain; charset=utf-8", bytes.NewBuffer([]byte("hello")))
assert.Nil(t, err)

err = s.manager.ProduceEvent(funcConf.Inputs[0], contube.NewRecordImpl(nil, func() {
err = s.Manager.ProduceEvent(funcConf.Inputs[0], contube.NewRecordImpl(nil, func() {
}))
assert.Nil(t, err)

_, err = s.manager.ConsumeEvent(funcConf.Output)
_, err = s.Manager.ConsumeEvent(funcConf.Output)
assert.Nil(t, err)

resp, err := http.Get(httpAddr + "/api/v1/state/key")
Expand Down

0 comments on commit ffb4e90

Please sign in to comment.