From 1ff724b0827a0c477d9b1196dd230b763e17a314 Mon Sep 17 00:00:00 2001
From: OrlandoCo <luisorlando.co@gmail.com>
Date: Wed, 9 Dec 2020 22:14:34 -0600
Subject: [PATCH] fix(sfu): remove audio tracks from twcc (#332)

---
 pkg/buffer/interceptor.go |  2 +-
 pkg/sfu/downtrack.go      | 11 +++++-----
 pkg/sfu/mediaengine.go    | 13 +++++-------
 pkg/sfu/router.go         |  2 +-
 pkg/sfu/subscriber.go     | 44 +++++++++++++++++++++++----------------
 5 files changed, 39 insertions(+), 33 deletions(-)

diff --git a/pkg/buffer/interceptor.go b/pkg/buffer/interceptor.go
index f336d7617..0e189490f 100644
--- a/pkg/buffer/interceptor.go
+++ b/pkg/buffer/interceptor.go
@@ -72,7 +72,7 @@ func (i *Interceptor) BindRTCPReader(reader interceptor.RTCPReader) interceptor.
 			case *rtcp.SenderReport:
 				buffer := i.getBuffer(pkt.SSRC)
 				if buffer == nil {
-					return pkts, attributes, nil
+					continue
 				}
 				buffer.setSenderReportData(pkt.RTPTime, pkt.NTPTime)
 			}
diff --git a/pkg/sfu/downtrack.go b/pkg/sfu/downtrack.go
index 84031c395..bcb5c907e 100644
--- a/pkg/sfu/downtrack.go
+++ b/pkg/sfu/downtrack.go
@@ -64,14 +64,14 @@ type DownTrack struct {
 }
 
 // NewDownTrack returns a DownTrack.
-func NewDownTrack(c webrtc.RTPCodecCapability, r Receiver, peerID, id, streamID string) (*DownTrack, error) {
+func NewDownTrack(c webrtc.RTPCodecCapability, r Receiver, peerID string) (*DownTrack, error) {
 	return &DownTrack{
-		id:       id,
+		id:       r.TrackID(),
 		peerID:   peerID,
+		streamID: r.StreamID(),
 		nList:    newNACKList(),
-		codec:    c,
 		receiver: r,
-		streamID: streamID,
+		codec:    c,
 	}, nil
 }
 
@@ -216,9 +216,10 @@ func (d *DownTrack) writeSimpleRTP(pkt rtp.Packet) error {
 	atomic.AddUint32(&d.octetCount, uint32(len(pkt.Payload)))
 	atomic.AddUint32(&d.packetCount, 1)
 
+	d.lastSSRC = pkt.SSRC
 	newSN := pkt.SequenceNumber - d.snOffset
 	newTS := pkt.Timestamp - d.tsOffset
-	if (newSN-d.lastSN)&0x8000 == 0 {
+	if (newSN-d.lastSN)&0x8000 == 0 || d.lastSN == 0 {
 		d.lastSN = newSN
 		atomic.StoreInt64(&d.lastPacketMs, time.Now().UnixNano()/1e6)
 		atomic.StoreUint32(&d.lastTS, newTS)
diff --git a/pkg/sfu/mediaengine.go b/pkg/sfu/mediaengine.go
index c88ffb547..556927fd8 100644
--- a/pkg/sfu/mediaengine.go
+++ b/pkg/sfu/mediaengine.go
@@ -66,10 +66,13 @@ func getPublisherMediaEngine() (*webrtc.MediaEngine, error) {
 		sdp.SDESRTPStreamIDURI,
 		sdp.TransportCCURI,
 	} {
-		if err := me.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{URI: extension}, webrtc.RTPCodecTypeAudio); err != nil {
+		if err := me.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{URI: extension}, webrtc.RTPCodecTypeVideo); err != nil {
 			return nil, err
 		}
-		if err := me.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{URI: extension}, webrtc.RTPCodecTypeVideo); err != nil {
+		if extension == sdp.TransportCCURI {
+			continue
+		}
+		if err := me.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{URI: extension}, webrtc.RTPCodecTypeAudio); err != nil {
 			return nil, err
 		}
 	}
@@ -79,11 +82,5 @@ func getPublisherMediaEngine() (*webrtc.MediaEngine, error) {
 
 func getSubscriberMediaEngine() (*webrtc.MediaEngine, error) {
 	me := &webrtc.MediaEngine{}
-	if err := me.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{URI: sdp.SDESMidURI}, webrtc.RTPCodecTypeVideo); err != nil {
-		return nil, err
-	}
-	if err := me.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{URI: sdp.SDESMidURI}, webrtc.RTPCodecTypeAudio); err != nil {
-		return nil, err
-	}
 	return me, nil
 }
diff --git a/pkg/sfu/router.go b/pkg/sfu/router.go
index b51fc0909..e01f81819 100644
--- a/pkg/sfu/router.go
+++ b/pkg/sfu/router.go
@@ -125,7 +125,7 @@ func (r *router) addDownTrack(sub *Subscriber, recv Receiver) error {
 		Channels:     codec.Channels,
 		SDPFmtpLine:  codec.SDPFmtpLine,
 		RTCPFeedback: []webrtc.RTCPFeedback{{"goog-remb", ""}, {"nack", ""}, {"nack", "pli"}},
-	}, recv, sub.id, recv.TrackID(), recv.StreamID())
+	}, recv, sub.id)
 	if err != nil {
 		return err
 	}
diff --git a/pkg/sfu/subscriber.go b/pkg/sfu/subscriber.go
index d4ab1a5c3..8fe81920b 100644
--- a/pkg/sfu/subscriber.go
+++ b/pkg/sfu/subscriber.go
@@ -200,6 +200,12 @@ func (s *Subscriber) downTracksReports() {
 						Type: rtcp.SDESCNAME,
 						Text: dt.streamID,
 					}},
+				}, rtcp.SourceDescriptionChunk{
+					Source: dt.ssrc,
+					Items: []rtcp.SourceDescriptionItem{{
+						Type: rtcp.SDESType(15),
+						Text: dt.transceiver.Mid(),
+					}},
 				})
 			}
 		}
@@ -226,32 +232,34 @@ func (s *Subscriber) sendStreamDownTracksReports(streamID string) {
 		if !dt.bound.get() {
 			continue
 		}
-		now := time.Now().UnixNano()
-		nowNTP := timeToNtp(now)
-		lastPktMs := atomic.LoadInt64(&dt.lastPacketMs)
-		maxPktTs := atomic.LoadUint32(&dt.lastTS)
-		diffTs := uint32((now/1e6)-lastPktMs) * dt.codec.ClockRate / 1000
-		octets, packets := dt.getSRStats()
-		r = append(r, &rtcp.SenderReport{
-			SSRC:        dt.ssrc,
-			NTPTime:     nowNTP,
-			RTPTime:     maxPktTs + diffTs,
-			PacketCount: packets,
-			OctetCount:  octets,
-		})
 		sd = append(sd, rtcp.SourceDescriptionChunk{
 			Source: dt.ssrc,
 			Items: []rtcp.SourceDescriptionItem{{
 				Type: rtcp.SDESCNAME,
 				Text: dt.streamID,
 			}},
+		}, rtcp.SourceDescriptionChunk{
+			Source: dt.ssrc,
+			Items: []rtcp.SourceDescriptionItem{{
+				Type: rtcp.SDESType(15),
+				Text: dt.transceiver.Mid(),
+			}},
 		})
 	}
 	s.RUnlock()
-	if len(r) > 0 {
-		r = append(r, &rtcp.SourceDescription{Chunks: sd})
-		if err := s.pc.WriteRTCP(r); err != nil {
-			log.Errorf("Sending track binding reports err:%v", err)
+	r = append(r, &rtcp.SourceDescription{Chunks: sd})
+	go func() {
+		r := r
+		i := 0
+		for {
+			if err := s.pc.WriteRTCP(r); err != nil {
+				log.Errorf("Sending track binding reports err:%v", err)
+			}
+			if i > 5 {
+				return
+			}
+			i++
+			time.Sleep(20 * time.Millisecond)
 		}
-	}
+	}()
 }