Skip to content
This repository has been archived by the owner on Dec 17, 2022. It is now read-only.

Commit

Permalink
Merge pull request #12 from dinccey/mem_fix
Browse files Browse the repository at this point in the history
Fix big response / request body memory leaks (downloading / uploading large files)
  • Loading branch information
jerson authored Oct 25, 2021
2 parents 6b39139 + 9fcca86 commit 6f94c8c
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 8 deletions.
16 changes: 16 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,22 @@ fmt:
client: deps
go build -o pgrok ./cmd/pgrok

compile-all:
GOOS=linux GOARCH=386 go build -o pgrok_linux_i386 ./cmd/pgrok
GOOS=windows GOARCH=386 go build -o pgrok_windows_i386 ./cmd/pgrok
GOOS=linux GOARCH=arm64 go build -o pgrok_linux_arm64 ./cmd/pgrok
GOOS=windows GOARCH=arm64 go build -o pgrok_windows_arm64 ./cmd/pgrok
GOOS=linux GOARCH=amd64 go build -o pgrok_linux_amd64 ./cmd/pgrok
GOOS=windows GOARCH=amd64 go build -o pgrok_windows_amd64 ./cmd/pgrok

GOOS=linux GOARCH=386 go build -o pgrokd_linux_i386 ./cmd/pgrokd
GOOS=windows GOARCH=386 go build -o pgrokd_windows_i386 ./cmd/pgrokd
GOOS=linux GOARCH=arm64 go build -o pgrokd_linux_arm64 ./cmd/pgrokd
GOOS=windows GOARCH=arm64 go build -o pgrokd_windows_arm64 ./cmd/pgrokd
GOOS=linux GOARCH=amd64 go build -o pgrokd_linux_amd64 ./cmd/pgrokd
GOOS=windows GOARCH=amd64 go build -o pgrokd_windows_amd64 ./cmd/pgrokd


assets: client-assets server-assets

go-bindata:
Expand Down
16 changes: 15 additions & 1 deletion client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,26 @@ package client

import (
"fmt"
"github.com/inconshreveable/mousetrap"
"math/rand"
"net/http"
"os"
"pgrok/log"
"pgrok/util"
"runtime"
"time"

"github.com/inconshreveable/mousetrap"

_ "net/http/pprof"
)

//debug memory profiler $ go tool pprof http://localhost:6060/debug/pprof/heap
func pprof() {
go func() {
http.ListenAndServe("localhost:6060", nil)
}()
}

func init() {
if runtime.GOOS == "windows" {
if mousetrap.StartedByExplorer() {
Expand All @@ -23,6 +34,9 @@ func init() {
}

func Main() {
//run profiler
//pprof()

// parse options
opts, err := ParseArgs()
if err != nil {
Expand Down
145 changes: 138 additions & 7 deletions proto/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package proto
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
Expand Down Expand Up @@ -84,7 +86,7 @@ func (h *Http) readRequests(tee *conn.Tee, lastTxn chan *HttpTxn, connCtx interf

// make sure we read the body of the request so that
// we don't block the writer
_, err = httputil.DumpRequest(req, true)
_, err = DumpRequest(req, true)

h.reqMeter.Mark(1)
if err != nil {
Expand All @@ -109,6 +111,126 @@ func (h *Http) readRequests(tee *conn.Tee, lastTxn chan *HttpTxn, connCtx interf
}
}

//from httputil, here to use custom drainBody func
func DumpRequest(req *http.Request, body bool) ([]byte, error) {
var err error
save := req.Body
if !body || req.Body == nil {
req.Body = nil
} else {
//save, req.Body, err = drainBody(req.Body)
io.Copy(ioutil.Discard, req.Body)
if err != nil {
return nil, err
}
}

var b bytes.Buffer

// By default, print out the unmodified req.RequestURI, which
// is always set for incoming server requests. But because we
// previously used req.URL.RequestURI and the docs weren't
// always so clear about when to use DumpRequest vs
// DumpRequestOut, fall back to the old way if the caller
// provides a non-server Request.
reqURI := req.RequestURI
if reqURI == "" {
reqURI = req.URL.RequestURI()
}

fmt.Fprintf(&b, "%s %s HTTP/%d.%d\r\n", valueOrDefault(req.Method, "GET"),
reqURI, req.ProtoMajor, req.ProtoMinor)

absRequestURI := strings.HasPrefix(req.RequestURI, "http://") || strings.HasPrefix(req.RequestURI, "https://")
if !absRequestURI {
host := req.Host
if host == "" && req.URL != nil {
host = req.URL.Host
}
if host != "" {
fmt.Fprintf(&b, "Host: %s\r\n", host)
}
}

chunked := len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked"
if len(req.TransferEncoding) > 0 {
fmt.Fprintf(&b, "Transfer-Encoding: %s\r\n", strings.Join(req.TransferEncoding, ","))
}
if req.Close {
fmt.Fprintf(&b, "Connection: close\r\n")
}

err = req.Header.WriteSubset(&b, reqWriteExcludeHeaderDump)
if err != nil {
return nil, err
}

io.WriteString(&b, "\r\n")

if req.Body != nil {
var dest io.Writer = &b
if chunked {
dest = httputil.NewChunkedWriter(dest)
}
_, err = io.Copy(dest, req.Body)
if chunked {
dest.(io.Closer).Close()
io.WriteString(&b, "\r\n")
}
}

req.Body = save
if err != nil {
return nil, err
}
return b.Bytes(), nil
}

//from httputil, here to use custom drainBody func
var errNoBody = errors.New("sentinel error value")
var emptyBody = io.NopCloser(strings.NewReader(""))

type failureToReadBody struct{}

func (failureToReadBody) Read([]byte) (int, error) { return 0, errNoBody }
func (failureToReadBody) Close() error { return nil }

// DumpResponse is like DumpRequest but dumps a response.
func DumpResponse(resp *http.Response, body bool) ([]byte, error) {
var b bytes.Buffer
var err error
save := resp.Body
savecl := resp.ContentLength

if !body {
// For content length of zero. Make sure the body is an empty
// reader, instead of returning error through failureToReadBody{}.
if resp.ContentLength == 0 {
resp.Body = emptyBody
} else {
resp.Body = failureToReadBody{}
}
} else if resp.Body == nil {
resp.Body = emptyBody
} else {
io.Copy(ioutil.Discard, resp.Body)
//save, resp.Body, err = drainBody(resp.Body)
if err != nil {
return nil, err
}
}
err = resp.Write(&b)
if err == errNoBody {
err = nil
}
resp.Body = save
resp.ContentLength = savecl
if err != nil {
return nil, err
}
return b.Bytes(), nil
}

func (h *Http) readResponses(tee *conn.Tee, lastTxn chan *HttpTxn) {
for txn := range lastTxn {
resp, err := http.ReadResponse(tee.ReadBuffer(), txn.Req.Request)
Expand All @@ -121,7 +243,10 @@ func (h *Http) readResponses(tee *conn.Tee, lastTxn chan *HttpTxn) {
}
// make sure we read the body of the response so that
// we don't block the reader
_, _ = httputil.DumpResponse(resp, true)

// Drain and close the body to let the Transport reuse the connection

_, _ = DumpResponse(resp, true)

txn.Resp = &HttpResponse{Response: resp}
// apparently, Body can be nil in some cases
Expand Down Expand Up @@ -172,16 +297,17 @@ func (h *Http) readResponses(tee *conn.Tee, lastTxn chan *HttpTxn) {
// elaborate trick where the other copy is made during Request/Response.Write.
// This would complicate things too much, given that these functions are for
// debugging only.
func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) {
/*func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) {
var buf bytes.Buffer
if _, err = buf.ReadFrom(b); err != nil {
if buf.Reset(); err != nil {
return nil, nil, err
}
if err = b.Close(); err != nil {
return nil, nil, err
}
return ioutil.NopCloser(&buf), ioutil.NopCloser(bytes.NewReader(buf.Bytes())), nil
}
}*/

// dumpConn is a net.Conn which writes to Writer and reads from Reader
type dumpConn struct {
Expand Down Expand Up @@ -219,7 +345,8 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) {
}
} else {
var err error
save, req.Body, err = drainBody(req.Body)
io.Copy(ioutil.Discard, req.Body)
//save, req.Body, err = drainBody(req.Body)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -251,7 +378,11 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) {
req, _ := http.ReadRequest(bufio.NewReader(pr))
// THIS IS THE PART THAT'S BROKEN IN THE STDLIB (as of Go 1.3)
if req != nil && req.Body != nil {
ioutil.ReadAll(req.Body)
//this part consumes memory from requests, doesn't appear to be needed
//ioutil.ReadAll(req.Body)
//better way
//io.Copy(ioutil.Discard, req.Body)

}
dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\n\r\n")
}()
Expand Down

0 comments on commit 6f94c8c

Please sign in to comment.