Skip to content

Commit

Permalink
grpc service (wip)
Browse files Browse the repository at this point in the history
  • Loading branch information
amirylm committed Sep 26, 2023
1 parent 3f1740a commit 60ec4d6
Show file tree
Hide file tree
Showing 12 changed files with 661 additions and 81 deletions.
46 changes: 46 additions & 0 deletions api/grpc/control.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package grpcapi

import (
"context"

"github.com/amirylm/p2pmq/core"
"github.com/amirylm/p2pmq/proto"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

// ControlServiceImpl is an implementation of ControlServiceServer.
type ControlServiceImpl struct {
proto.ControlServiceServer

pubsub core.Pubsuber
}

// NewControlServiceServer creates a new ControlServiceServer instance.
func NewControlServiceServer(ps core.Pubsuber) *ControlServiceImpl {
return &ControlServiceImpl{pubsub: ps}
}

// Publish implements the Publish RPC method.
func (s *ControlServiceImpl) Publish(ctx context.Context, req *proto.PublishRequest) (*proto.PublishResponse, error) {
if err := s.pubsub.Publish(ctx, req.GetTopic(), req.GetData()); err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
return &proto.PublishResponse{}, nil
}

// Subscribe implements the Subscribe RPC method.
func (s *ControlServiceImpl) Subscribe(ctx context.Context, req *proto.SubscribeRequest) (*proto.SubscribeResponse, error) {
if err := s.pubsub.Subscribe(ctx, req.GetTopic()); err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
return &proto.SubscribeResponse{}, nil
}

// Unsubscribe implements the Unsubscribe RPC method.
func (s *ControlServiceImpl) Unsubscribe(_ context.Context, req *proto.SubscribeRequest) (*proto.SubscribeResponse, error) {
if err := s.pubsub.Unsubscribe(req.GetTopic()); err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
return &proto.SubscribeResponse{}, nil
}
50 changes: 50 additions & 0 deletions api/grpc/msg_router.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package grpcapi

import (
"fmt"

"github.com/amirylm/p2pmq/core"
"github.com/amirylm/p2pmq/proto"
)

// MsgRouterImpl is an implementation of MsgRouterServer.
type MsgRouterImpl struct {
proto.MsgRouterServer

q chan *proto.Message
}

// NewMsgRouterServer creates a new MsgRouterServer instance.
func NewMsgRouterServer(qSize int) *MsgRouterImpl {
return &MsgRouterImpl{q: make(chan *proto.Message, qSize)}
}

func (r *MsgRouterImpl) Push(next *core.MsgWrapper[error]) error {
select {
case r.q <- &proto.Message{
MessageId: next.Msg.ID,
Topic: next.Msg.GetTopic(),
Data: next.Msg.GetData(),
}:
default:
return fmt.Errorf("queue is full")
}
return nil
}

// Listen implements the Listen RPC method.
func (r *MsgRouterImpl) Listen(req *proto.ListenRequest, stream proto.MsgRouter_ListenServer) error {
for {
select {
case <-stream.Context().Done():
return nil
case next := <-r.q:
if next == nil {
return nil
}
if err := stream.Send(next); err != nil {
return streamErr(err)
}
}
}
}
41 changes: 41 additions & 0 deletions api/grpc/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package grpcapi

import (
"fmt"
"net"

"github.com/amirylm/p2pmq/core"
"github.com/amirylm/p2pmq/proto"
"github.com/pkg/errors"
"google.golang.org/grpc"
)

func NewServices(ps core.Pubsuber, qSize int) (*ControlServiceImpl, *MsgRouterImpl, *ValidationRouterImpl) {
controlServiceServer := NewControlServiceServer(ps)
msgRouterServer := NewMsgRouterServer(qSize)
valRouterServer := NewValidationRouterServer(ps, qSize)

return controlServiceServer, msgRouterServer, valRouterServer
}

func NewGrpcServer(controlService *ControlServiceImpl, msgRouter *MsgRouterImpl, valRouter *ValidationRouterImpl) *grpc.Server {
grpcServer := grpc.NewServer()

proto.RegisterControlServiceServer(grpcServer, controlService)
proto.RegisterMsgRouterServer(grpcServer, msgRouter)
proto.RegisterValidationRouterServer(grpcServer, valRouter)

return grpcServer
}

func ListenGrpc(s *grpc.Server, grpcPort int) error {
lis, err := net.Listen("tcp", fmt.Sprintf(":%d", grpcPort))
if err != nil {
return errors.Wrap(err, "could not create TCP listener")
}
defer s.Stop()
if err := s.Serve(lis); err != nil {
return errors.Wrap(err, "could not serve grpc")
}
return nil
}
240 changes: 240 additions & 0 deletions api/grpc/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
package grpcapi

import (
"context"
"fmt"
"io"
"math/rand"
"sync"
"sync/atomic"
"testing"
"time"

logging "github.com/ipfs/go-log"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"

"github.com/amirylm/p2pmq/commons/utils"
"github.com/amirylm/p2pmq/core"
"github.com/amirylm/p2pmq/proto"
)

func TestGrpc_Network(t *testing.T) {
t.Skip()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

n := 4
rounds := 5

require.NoError(t, logging.SetLogLevelRegex("p2pmq", "debug"))

valHitMap := map[string]*atomic.Int32{}
msgHitMap := map[string]*atomic.Int32{}
for i := 0; i < n; i++ {
topic := fmt.Sprintf("test-%d", i+1)
msgHitMap[topic] = &atomic.Int32{}
valHitMap[topic] = &atomic.Int32{}
}

controllers, _, _, done := core.SetupTestControllers(ctx, t, n, func(msg *pubsub.Message) {
msgHitMap[msg.GetTopic()].Add(1)
// lggr.Infow("got pubsub message", "topic", m.GetTopic(), "from", m.GetFrom(), "data", string(m.GetData()))
}, func(p peer.ID, msg *pubsub.Message) pubsub.ValidationResult {
valHitMap[msg.GetTopic()].Add(1)
return pubsub.ValidationAccept
})
defer done()
require.Equal(t, n, len(controllers))

grpcServers := make([]*grpc.Server, n)
for i := 0; i < n; i++ {
ctrl := controllers[i]
control, msgR, valR := NewServices(ctrl, 128)
ctrl.RefreshRouters(func(mw *core.MsgWrapper[error]) {
require.NoError(t, msgR.Push(mw))
}, func(mw *core.MsgWrapper[pubsub.ValidationResult]) {
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
mw.Result = valR.PushWait(ctx, mw)
})
grpcServers[i] = NewGrpcServer(control, msgR, valR)
}

threadC := utils.NewThreadControl()
defer threadC.Close()

ports := make([]int, n)
for i, s := range grpcServers {
{
srv := s
port := randPort()
ports[i] = port
threadC.Go(func(ctx context.Context) {
err := ListenGrpc(srv, port)
if ctx.Err() == nil {
require.NoError(t, err)
}
})
}
}

<-time.After(time.Second * 5) // TODO: avoid timeout

conns := make([]*grpc.ClientConn, n)
for i := range grpcServers {
conn, err := grpc.Dial(fmt.Sprintf(":%d", ports[i]), grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
conns[i] = conn
}

for i := range grpcServers {
conn, err := grpc.Dial(fmt.Sprintf(":%d", ports[i]), grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
threadC.Go(func(ctx context.Context) {
val := proto.NewValidationRouterClient(conn)
valClient, err := val.Handle(ctx)
require.NoError(t, err)

for ctx.Err() == nil {
msg, err := valClient.Recv()
if err == io.EOF || err == context.Canceled || ctx.Err() != nil || msg == nil { // stream closed
return
}
require.NoError(t, err)
valHitMap[msg.GetTopic()].Add(1)
if len(msg.Data) > 48 {
require.NoError(t, valClient.Send(&proto.ValidatedMessage{
Result: proto.ValidationResult_REJECT,
Msg: msg,
}))
} else if len(msg.Data) > 32 {
require.NoError(t, valClient.Send(&proto.ValidatedMessage{
Result: proto.ValidationResult_IGNORE,
Msg: msg,
}))
} else {
require.NoError(t, valClient.Send(&proto.ValidatedMessage{
Result: proto.ValidationResult_ACCEPT,
Msg: msg,
}))
}
}
})
}

for i := range grpcServers {
conn, err := grpc.Dial(fmt.Sprintf(":%d", ports[i]), grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
threadC.Go(func(ctx context.Context) {
msgRouter := proto.NewMsgRouterClient(conn)
client, err := msgRouter.Listen(ctx, &proto.ListenRequest{})
require.NoError(t, err)

for ctx.Err() == nil {
msg, err := client.Recv()
if err == io.EOF || err == context.Canceled || ctx.Err() != nil || msg == nil { // stream closed
return
}
require.NoError(t, err)
msgHitMap[msg.GetTopic()].Add(1)
require.LessOrEqualf(t, len(msg.Data), 32, "should see only valid messages: %s", msg.Data)
}
})
}

var wg sync.WaitGroup
for i := range grpcServers {
control := proto.NewControlServiceClient(conns[i])
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < n; i++ {
_, err := control.Subscribe(ctx, &proto.SubscribeRequest{
Topic: fmt.Sprintf("test-%d", i+1),
})
require.NoError(t, err)
}
}()
}
wg.Wait()

<-time.After(time.Second * 5) // TODO: avoid timeout
t.Log("Publishing")
for r := 0; r < rounds; r++ {
for i := range grpcServers {
control := proto.NewControlServiceClient(conns[i])
req := &proto.PublishRequest{
Topic: fmt.Sprintf("test-%d", i+1),
Data: []byte(fmt.Sprintf("round-%d-test-data-%d", r+1, i+1)),
}
wg.Add(1)
go func() {
defer wg.Done()

c, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
_, err := control.Publish(c, req)
require.NoError(t, err)
}()
}
}
wg.Wait()

// invalid messages
for i := range grpcServers {
control := proto.NewControlServiceClient(conns[i])
data := []byte(fmt.Sprintf("%d-test-data-%d", rand.Int31n(1e3), i+1))
for len(data)+1 < 48 {
data = append(data, []byte(fmt.Sprintf("%d", 1e5+rand.Int31n(1e9)))...)
}
req := &proto.PublishRequest{
Topic: fmt.Sprintf("test-%d", i+1),
Data: data,
}
wg.Add(1)
go func() {
defer wg.Done()
_, _ = control.Publish(ctx, req)
}()
}

// ignored messages
for i := range grpcServers {
control := proto.NewControlServiceClient(conns[i])
data := []byte(fmt.Sprintf("%d-test-data-%d", rand.Int31n(1e3), i+1))
for len(data)+1 < 32 {
data = append(data, []byte(fmt.Sprintf("%d", rand.Int31n(1e3)))...)
}
req := &proto.PublishRequest{
Topic: fmt.Sprintf("test-%d", i+1),
Data: data,
}
wg.Add(1)
go func() {
defer wg.Done()
_, _ = control.Publish(ctx, req)
}()
}
wg.Wait()

<-time.After(time.Second * 2) // TODO: avoid timeout

for _, s := range grpcServers {
s.Stop()
}

t.Log("Asserting")
for topic, counter := range msgHitMap {
count := int(counter.Load()) / n // per node
require.Equal(t, rounds, count, "should get %d messages on topic %s", rounds, topic)
}
}

func randPort() int {
return 5000 + rand.Intn(5000)
}
Loading

0 comments on commit 60ec4d6

Please sign in to comment.