Skip to content

Commit

Permalink
Streamed replies are now handled by Request
Browse files Browse the repository at this point in the history
  • Loading branch information
cdevienne committed Sep 18, 2018
1 parent c59a3ad commit 9f435ef
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 156 deletions.
123 changes: 21 additions & 102 deletions examples/alloptions/alloptions.nrpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,70 +51,6 @@ func (h *SvcCustomSubjectHandler) MtNoRequestPublish(pkginstance string, msg Sim
return h.nc.Publish(subject, rawMsg)
}

func (h *SvcCustomSubjectHandler) MtStreamedReplyHandler(
ctx context.Context, request *nrpc.Request, req StringArg) {
ctx, cancel := context.WithCancel(ctx)

keepStreamAlive := nrpc.NewKeepStreamAlive(
request.Conn, request.ReplySubject, request.Encoding, cancel,
)

var msgCount uint32

_, nrpcErr := nrpc.CaptureErrors(func() (proto.Message, error) {
err := h.server.MtStreamedReply(ctx, req, func(rep SimpleStringReply){
if err := request.SendReply(&rep, nil); err != nil {
log.Printf("nrpc: error publishing response")
cancel()
return
}
msgCount++
})
return nil, err
})
keepStreamAlive.Stop()

if nrpcErr != nil {
request.SendReply(nil, nrpcErr)
} else {
request.SendReply(
nil, &nrpc.Error{Type: nrpc.Error_EOS, MsgCount: msgCount},
)
}
}

func (h *SvcCustomSubjectHandler) MtVoidReqStreamedReplyHandler(
ctx context.Context, request *nrpc.Request) {
ctx, cancel := context.WithCancel(ctx)

keepStreamAlive := nrpc.NewKeepStreamAlive(
request.Conn, request.ReplySubject, request.Encoding, cancel,
)

var msgCount uint32

_, nrpcErr := nrpc.CaptureErrors(func() (proto.Message, error) {
err := h.server.MtVoidReqStreamedReply(ctx, func(rep SimpleStringReply){
if err := request.SendReply(&rep, nil); err != nil {
log.Printf("nrpc: error publishing response")
cancel()
return
}
msgCount++
})
return nil, err
})
keepStreamAlive.Stop()

if nrpcErr != nil {
request.SendReply(nil, nrpcErr)
} else {
request.SendReply(
nil, &nrpc.Error{Type: nrpc.Error_EOS, MsgCount: msgCount},
)
}
}

func (h *SvcCustomSubjectHandler) Handler(msg *nats.Msg) {
request := nrpc.NewRequest(h.ctx, h.nc, msg.Subject, msg.Reply)
// extract method name & encoding from subject
Expand Down Expand Up @@ -194,8 +130,13 @@ func (h *SvcCustomSubjectHandler) Handler(msg *nats.Msg) {
Message: "bad request received: " + err.Error(),
}
} else {
h.MtStreamedReplyHandler(h.ctx, request, req)
return
request.SetupStreamedReply()
request.Handler = func(ctx context.Context)(proto.Message, error){
err := h.server.MtStreamedReply(ctx, req, func(rep SimpleStringReply){
request.SendStreamReply(&rep)
})
return nil, err
}
}
case "mtvoidreqstreamedreply":
_, request.Encoding, err = nrpc.ParseSubjectTail(0, request.SubjectTail)
Expand All @@ -211,8 +152,13 @@ func (h *SvcCustomSubjectHandler) Handler(msg *nats.Msg) {
Message: "bad request received: " + err.Error(),
}
} else {
h.MtVoidReqStreamedReplyHandler(h.ctx, request)
return
request.SetupStreamedReply()
request.Handler = func(ctx context.Context)(proto.Message, error){
err := h.server.MtVoidReqStreamedReply(ctx, func(rep SimpleStringReply){
request.SendStreamReply(&rep)
})
return nil, err
}
}
default:
log.Printf("SvcCustomSubjectHandler: unknown name %q", name)
Expand Down Expand Up @@ -428,38 +374,6 @@ func (h *SvcSubjectParamsHandler) Subject() string {
return "root.*.svcsubjectparams.*.>"
}

func (h *SvcSubjectParamsHandler) MtStreamedReplyWithSubjectParamsHandler(
ctx context.Context, request *nrpc.Request, mtParams []string) {
ctx, cancel := context.WithCancel(ctx)

keepStreamAlive := nrpc.NewKeepStreamAlive(
request.Conn, request.ReplySubject, request.Encoding, cancel,
)

var msgCount uint32

_, nrpcErr := nrpc.CaptureErrors(func() (proto.Message, error) {
err := h.server.MtStreamedReplyWithSubjectParams(ctx, mtParams[0], mtParams[1], func(rep SimpleStringReply){
if err := request.SendReply(&rep, nil); err != nil {
log.Printf("nrpc: error publishing response")
cancel()
return
}
msgCount++
})
return nil, err
})
keepStreamAlive.Stop()

if nrpcErr != nil {
request.SendReply(nil, nrpcErr)
} else {
request.SendReply(
nil, &nrpc.Error{Type: nrpc.Error_EOS, MsgCount: msgCount},
)
}
}

func (h *SvcSubjectParamsHandler) MtNoRequestWParamsPublish(pkginstance string, svcclientid string, mtmp1 string, msg SimpleStringReply) error {
rawMsg, err := nrpc.Marshal("protobuf", &msg)
if err != nil {
Expand Down Expand Up @@ -526,8 +440,13 @@ func (h *SvcSubjectParamsHandler) Handler(msg *nats.Msg) {
Message: "bad request received: " + err.Error(),
}
} else {
h.MtStreamedReplyWithSubjectParamsHandler(h.ctx, request, mtParams)
return
request.SetupStreamedReply()
request.Handler = func(ctx context.Context)(proto.Message, error){
err := h.server.MtStreamedReplyWithSubjectParams(ctx, mtParams[0], mtParams[1], func(rep SimpleStringReply){
request.SendStreamReply(&rep)
})
return nil, err
}
}
case "mtnoreply":
request.NoReply = true
Expand Down
1 change: 1 addition & 0 deletions examples/alloptions/alloptions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ func TestAll(t *testing.T) {
context.Background(),
StringArg{Arg1: "arg"},
func(ctx context.Context, rep SimpleStringReply) {
fmt.Println("received", rep)
resList = append(resList, rep.GetReply())
})
if err != nil {
Expand Down
47 changes: 46 additions & 1 deletion nrpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,11 @@ type Request struct {
Context context.Context
Conn NatsConn

KeepStreamAlive *KeepStreamAlive
StreamContext context.Context
StreamCancel func()
StreamMsgCount uint32

Subject string
MethodName string
SubjectTail []string
Expand All @@ -286,7 +291,11 @@ func (r Request) Elapsed() time.Duration {
// that should be returned to the caller
func (r Request) Run() (msg proto.Message, replyError *Error) {
r.StartedAt = time.Now()
ctx := context.WithValue(r.Context, RequestContextKey, &r)
ctx := r.Context
if r.StreamedReply() {
ctx = r.StreamContext
}
ctx = context.WithValue(ctx, RequestContextKey, &r)
msg, replyError = CaptureErrors(
func() (proto.Message, error) {
return r.Handler(ctx)
Expand Down Expand Up @@ -326,8 +335,44 @@ func (r *Request) SetServiceParam(key, value string) {
r.ServiceParams[key] = value
}

// SetupStreamedReply initializes the reply stream
func (r *Request) SetupStreamedReply() {
r.StreamContext, r.StreamCancel = context.WithCancel(r.Context)
r.KeepStreamAlive = NewKeepStreamAlive(
r.Conn, r.ReplySubject, r.Encoding, r.StreamCancel)
}

// StreamedReply returns true if the request reply is streamed
func (r Request) StreamedReply() bool {
return r.KeepStreamAlive != nil
}

// SendStreamReply send a reply a part of a stream
func (r *Request) SendStreamReply(msg proto.Message) {
log.Printf("nrpc: SendStreamReply")
if err := r.sendReply(msg, nil); err != nil {
log.Printf("nrpc: error publishing response")
r.StreamCancel()
return
}
r.StreamMsgCount++
}

// SendReply sends a reply to the caller
func (r *Request) SendReply(resp proto.Message, withError *Error) error {
if r.StreamedReply() {
r.KeepStreamAlive.Stop()
if withError == nil {
return r.sendReply(
nil, &Error{Type: Error_EOS, MsgCount: r.StreamMsgCount},
)
}
}
return r.sendReply(resp, withError)
}

// sendReply sends a reply to the caller
func (r *Request) sendReply(resp proto.Message, withError *Error) error {
return Publish(resp, withError, r.Conn, r.ReplySubject, r.Encoding)
}

Expand Down
63 changes: 10 additions & 53 deletions protoc-gen-nrpc/tmpl.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,54 +155,6 @@ func (h *{{$serviceName}}Handler) {{.GetName}}Publish(
return h.nc.Publish(subject, rawMsg)
}
{{- end}}
{{- if HasStreamedReply .}}
func (h *{{$serviceName}}Handler) {{.GetName}}Handler(
ctx context.Context, request *nrpc.Request
{{- if GetMethodSubjectParams . -}}
, mtParams []string
{{- end -}}
{{- if ne .GetInputType ".nrpc.Void" -}}
, req {{GoType .GetInputType}}
{{- end -}}
) {
ctx, cancel := context.WithCancel(ctx)
keepStreamAlive := nrpc.NewKeepStreamAlive(
request.Conn, request.ReplySubject, request.Encoding, cancel,
)
var msgCount uint32
_, nrpcErr := nrpc.CaptureErrors(func() (proto.Message, error) {
err := h.server.{{.GetName}}(ctx
{{- range $i, $p := GetMethodSubjectParams . -}}
, mtParams[{{ $i }}]
{{- end -}}
{{- if ne .GetInputType ".nrpc.Void" -}}
, req
{{- end -}}
, func(rep {{GoType .GetOutputType}}){
if err := request.SendReply(&rep, nil); err != nil {
log.Printf("nrpc: error publishing response")
cancel()
return
}
msgCount++
})
return nil, err
})
keepStreamAlive.Stop()
if nrpcErr != nil {
request.SendReply(nil, nrpcErr)
} else {
request.SendReply(
nil, &nrpc.Error{Type: nrpc.Error_EOS, MsgCount: msgCount},
)
}
}
{{- end}}
{{- end}}
{{- if ServiceNeedsHandler .}}
Expand Down Expand Up @@ -262,15 +214,20 @@ func (h *{{.GetName}}Handler) Handler(msg *nats.Msg) {
{{- end}}
} else {
{{- if HasStreamedReply .}}
h.{{.GetName}}Handler(h.ctx, request
{{- if GetMethodSubjectParams . -}}
, mtParams
request.SetupStreamedReply()
request.Handler = func(ctx context.Context)(proto.Message, error){
err := h.server.{{.GetName}}(ctx
{{- range $i, $p := GetMethodSubjectParams . -}}
, mtParams[{{ $i }}]
{{- end -}}
{{- if ne .GetInputType ".nrpc.Void" -}}
, req
{{- end -}}
)
return
, func(rep {{GoType .GetOutputType}}){
request.SendStreamReply(&rep)
})
return nil, err
}
{{- else }}
request.Handler = func(ctx context.Context)(proto.Message, error){
{{- if eq .GetOutputType ".nrpc.NoReply" -}}
Expand Down

0 comments on commit 9f435ef

Please sign in to comment.