Skip to content

Commit

Permalink
Fix packet length validation, new cipher tests
Browse files Browse the repository at this point in the history
This is v2 CP of two commits from master:

Fix packet length validation
Missing packet length validations could cause crash.

New SRTP cipher test suite
Moved tests from srtp_cipher_aead_aes_gcm_test.go to new file, cleaned
up their code to make it more DRY. Added tests for AES CM ciphers there
too.
This new cipher test suite will make it easier to add tests for new
SRTP features without lots of copy/paste.
  • Loading branch information
sirzooro committed Jul 7, 2024
1 parent f598d3b commit a3ec2e6
Show file tree
Hide file tree
Showing 9 changed files with 256 additions and 169 deletions.
3 changes: 2 additions & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ var (
errNoConfig = errors.New("no config provided")
errNoConn = errors.New("no conn provided")
errFailedToVerifyAuthTag = errors.New("failed to verify auth tag")
errTooShortRTCP = errors.New("packet is too short to be rtcp packet")
errTooShortRTP = errors.New("packet is too short to be RTP packet")
errTooShortRTCP = errors.New("packet is too short to be RTCP packet")
errPayloadDiffers = errors.New("payload differs")
errStartedChannelUsedIncorrectly = errors.New("started channel used incorrectly, should only be closed")
errBadIVLength = errors.New("bad iv length in xorBytesCTR")
Expand Down
15 changes: 15 additions & 0 deletions protection_profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,18 @@ func (p ProtectionProfile) authKeyLen() (int, error) {
return 0, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, p)
}
}

func (p ProtectionProfile) String() string {
switch p {
case ProtectionProfileAes128CmHmacSha1_80:
return "SRTP_AES128_CM_HMAC_SHA1_80"
case ProtectionProfileAes128CmHmacSha1_32:
return "SRTP_AES128_CM_HMAC_SHA1_32"
case ProtectionProfileAeadAes128Gcm:
return "SRTP_AEAD_AES_128_GCM"
case ProtectionProfileAeadAes256Gcm:
return "SRTP_AEAD_AES_256_GCM"
default:
return fmt.Sprintf("Unknown SRTP profile: %#v", p)
}
}
4 changes: 4 additions & 0 deletions srtcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ func (c *Context) DecryptRTCP(dst, encrypted []byte, header *rtcp.Header) ([]byt
}

func (c *Context) encryptRTCP(dst, decrypted []byte) ([]byte, error) {
if len(decrypted) < 8 {
return nil, fmt.Errorf("%w: %d", errTooShortRTCP, len(decrypted))
}

ssrc := binary.BigEndian.Uint32(decrypted[4:])
s := c.getSRTCPSSRCState(ssrc)

Expand Down
24 changes: 24 additions & 0 deletions srtcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -594,3 +594,27 @@ func TestRTCPReplayDetectorFactory(t *testing.T) {
}
assert.Equal(1, cntFactory)
}

func TestDecryptInvalidSRTCP(t *testing.T) {
assert := assert.New(t)
key := []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01}
salt := []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01}
decryptContext, err := CreateContext(key, salt, ProtectionProfileAes128CmHmacSha1_80)
assert.NoError(err)

packet := []byte{0x8f, 0x48, 0xff, 0xff, 0xec, 0x77, 0xb0, 0x43, 0xf9, 0x04, 0x51, 0xff, 0xfb, 0xdf}
_, err = decryptContext.DecryptRTCP(nil, packet, nil)
assert.Error(err)
}

func TestEncryptInvalidRTCP(t *testing.T) {
assert := assert.New(t)
key := []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01}
salt := []byte{0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01}
decryptContext, err := CreateContext(key, salt, ProtectionProfileAes128CmHmacSha1_80)
assert.NoError(err)

packet := []byte{0xbb, 0xbb, 0x0a, 0x2f}
_, err = decryptContext.EncryptRTCP(nil, packet, nil)
assert.Error(err)
}
13 changes: 9 additions & 4 deletions srtp.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@ import (
)

func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int) ([]byte, error) {
authTagLen, err := c.cipher.rtpAuthTagLen()
if err != nil {
return nil, err
}

if len(ciphertext) < headerLen+authTagLen {
return nil, errTooShortRTP
}

s := c.getSRTPSSRCState(header.SSRC)

roc, diff, _ := s.nextRolloverCount(header.SequenceNumber)
Expand All @@ -21,10 +30,6 @@ func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerL
}
}

authTagLen, err := c.cipher.rtpAuthTagLen()
if err != nil {
return nil, err
}
dst = growBufferSize(dst, len(ciphertext)-authTagLen)

dst, err = c.cipher.decryptRTP(dst, ciphertext, header, headerLen, roc)
Expand Down
164 changes: 0 additions & 164 deletions srtp_cipher_aead_aes_gcm_test.go

This file was deleted.

3 changes: 3 additions & 0 deletions srtp_cipher_aes_cm_hmac_sha1.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ func (s *srtpCipherAesCmHmacSha1) decryptRTCP(out, encrypted []byte, index, ssrc
return nil, err
}
tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize)
if tailOffset < 8 {
return nil, errTooShortRTCP
}
out = out[0:tailOffset]

expectedTag, err := s.generateSrtcpAuthTag(encrypted[:len(encrypted)-authTagLen])
Expand Down
Loading

0 comments on commit a3ec2e6

Please sign in to comment.