diff --git a/main.go b/main.go index 1e7238b..5c6c249 100644 --- a/main.go +++ b/main.go @@ -63,48 +63,6 @@ func proxyRequest(fullSubdomain, path string, buffer *bytes.Buffer, r *http.Requ return resp.StatusCode, headers, nil } -func proxyGrpcRequest(ctx context.Context, fullSubdomain, method string, r *http.Request) (codes.Code, map[string]string, []byte, error) { - target := "https://" + fullSubdomain + ".lunaroasis.net" + method - conn, err := grpc.Dial(target, grpc.WithInsecure()) - if err != nil { - return 0, nil, nil, err - } - defer conn.Close() - - stream, err := conn.NewStream(ctx, &grpc.StreamDesc{ServerStreams: true}, method) - if err != nil { - return 0, nil, nil, err - } - - // Assume r.Body is a io.Reader containing the serialized request message - if err := stream.SendMsg(r.Body); err != nil { - return 0, nil, nil, err - } - - // Create a buffer to hold the serialized response message - buffer := new(bytes.Buffer) - if err := stream.RecvMsg(buffer); err != nil { - return 0, nil, nil, err - } - - headers, _ := metadata.FromIncomingContext(ctx) - headerMap := make(map[string]string) - for key, values := range headers { - for _, value := range values { - headerMap[key] = value - } - } - - if err := stream.RecvMsg(buffer); err != nil { - st, ok := status.FromError(err) - if ok { - return st.Code(), nil, buffer.Bytes(), err - } - return codes.Unknown, nil, nil, err - } - return codes.Unknown, nil, nil, fmt.Errorf("unexpected error") -} - func handleHttpRequest(w http.ResponseWriter, r *http.Request) { infoLog.Printf("Received request from %s", r.Host) hostParts := strings.Split(r.Host, ".") @@ -146,6 +104,59 @@ func handleHttpRequest(w http.ResponseWriter, r *http.Request) { io.Copy(w, buffer) } +func proxyGrpcRequest(ctx context.Context, fullSubdomain, method string, r *http.Request) (codes.Code, map[string]string, []byte, error) { + target := "https://" + fullSubdomain + ".lunaroasis.net" + method + conn, err := grpc.Dial(target, grpc.WithInsecure()) + if err != nil { + return 0, nil, nil, err + } + defer conn.Close() + + // Convert http.Header to map[string]string + headerMap := make(map[string]string) + for key, values := range r.Header { + headerMap[key] = strings.Join(values, ",") + } + + // Create metadata from the header map + md := metadata.New(headerMap) + + // Create a new outgoing context with the metadata + ctx = metadata.NewOutgoingContext(ctx, md) + + stream, err := conn.NewStream(ctx, &grpc.StreamDesc{ServerStreams: true}, method) + if err != nil { + return 0, nil, nil, err + } + + // Assume r.Body is a io.Reader containing the serialized request message + if err := stream.SendMsg(r.Body); err != nil { + return 0, nil, nil, err + } + + // Create a buffer to hold the serialized response message + buffer := new(bytes.Buffer) + if err := stream.RecvMsg(buffer); err != nil { + return 0, nil, nil, err + } + + // Collect headers from the gRPC response + headers, _ := metadata.FromIncomingContext(ctx) + responseHeaderMap := make(map[string]string) + for key, values := range headers { + responseHeaderMap[key] = strings.Join(values, ",") + } + + if err := stream.RecvMsg(buffer); err != nil { + st, ok := status.FromError(err) + if ok { + return st.Code(), nil, buffer.Bytes(), err + } + return codes.Unknown, nil, nil, err + } + return codes.Unknown, nil, nil, fmt.Errorf("unexpected error") +} + func handleGrpcRequest(w http.ResponseWriter, r *http.Request) { infoLog.Printf("Received gRPC request from %s", r.Host) hostParts := strings.Split(r.Host, ".")