diff --git a/data_psi.go b/data_psi.go index e0c3e78..7b306e6 100644 --- a/data_psi.go +++ b/data_psi.go @@ -119,6 +119,8 @@ func parsePSIData(i *astikit.BytesIterator) (d *PSIData, err error) { if s, stop, err = parsePSISection(i); err != nil { err = fmt.Errorf("astits: parsing PSI table failed: %w", err) return + } else if stop { + break } d.Sections = append(d.Sections, s) } @@ -132,14 +134,10 @@ func parsePSISection(i *astikit.BytesIterator) (s *PSISection, stop bool, err er // Parse header var offsetStart, offsetSectionsEnd, offsetEnd int - if s.Header, offsetStart, _, offsetSectionsEnd, offsetEnd, err = parsePSISectionHeader(i); err != nil { + if s.Header, offsetStart, _, offsetSectionsEnd, offsetEnd, stop, err = parsePSISectionHeader(i); err != nil { err = fmt.Errorf("astits: parsing PSI section header failed: %w", err) return - } - - // Check whether we need to stop the parsing - if shouldStopPSIParsing(s.Header.TableID) { - stop = true + } else if stop { return } @@ -199,12 +197,11 @@ func parseCRC32(i *astikit.BytesIterator) (c uint32, err error) { // shouldStopPSIParsing checks whether the PSI parsing should be stopped func shouldStopPSIParsing(tableID PSITableID) bool { - return tableID == PSITableIDNull || - tableID.isUnknown() + return tableID == PSITableIDNull } // parsePSISectionHeader parses a PSI section header -func parsePSISectionHeader(i *astikit.BytesIterator) (h *PSISectionHeader, offsetStart, offsetSectionsStart, offsetSectionsEnd, offsetEnd int, err error) { +func parsePSISectionHeader(i *astikit.BytesIterator) (h *PSISectionHeader, offsetStart, offsetSectionsStart, offsetSectionsEnd, offsetEnd int, stop bool, err error) { // Init h = &PSISectionHeader{} offsetStart = i.Offset() @@ -223,7 +220,8 @@ func parsePSISectionHeader(i *astikit.BytesIterator) (h *PSISectionHeader, offse h.TableType = h.TableID.Type() // Check whether we need to stop the parsing - if shouldStopPSIParsing(h.TableID) { + if h.TableID == PSITableIDNull { + stop = true return } @@ -241,7 +239,7 @@ func parsePSISectionHeader(i *astikit.BytesIterator) (h *PSISectionHeader, offse h.PrivateBit = bs[0]&0x40 > 0 // Section length - h.SectionLength = uint16(bs[0]&0xf)<<8 | uint16(bs[1]) + h.SectionLength = uint16(bs[0]&3)<<8 | uint16(bs[1]) // Offsets offsetSectionsStart = i.Offset() diff --git a/data_psi_test.go b/data_psi_test.go index a452222..2e7ea07 100644 --- a/data_psi_test.go +++ b/data_psi_test.go @@ -2,12 +2,22 @@ package astits import ( "bytes" + "encoding/hex" + "strings" "testing" "github.com/asticode/go-astikit" "github.com/stretchr/testify/assert" ) +func hexToBytes(in string) []byte { + o, err := hex.DecodeString(strings.ReplaceAll(in, "\n", "")) + if err != nil { + panic(err) + } + return o +} + var psi = &PSIData{ PointerField: 4, Sections: []*PSISection{ @@ -202,7 +212,7 @@ func TestParsePSISectionHeader(t *testing.T) { w.Write(uint8(254)) // Table ID w.Write("1") // Syntax section indicator w.Write("0000000") // Finish the byte - d, _, _, _, _, err := parsePSISectionHeader(astikit.NewBytesIterator(buf.Bytes())) + d, _, _, _, _, _, err := parsePSISectionHeader(astikit.NewBytesIterator(buf.Bytes())) assert.Equal(t, d, &PSISectionHeader{ TableID: 254, TableType: PSITableTypeUnknown, @@ -210,12 +220,13 @@ func TestParsePSISectionHeader(t *testing.T) { assert.NoError(t, err) // Valid table type - d, offsetStart, offsetSectionsStart, offsetSectionsEnd, offsetEnd, err := parsePSISectionHeader(astikit.NewBytesIterator(psiSectionHeaderBytes())) + d, offsetStart, offsetSectionsStart, offsetSectionsEnd, offsetEnd, stop, err := parsePSISectionHeader(astikit.NewBytesIterator(psiSectionHeaderBytes())) assert.Equal(t, d, psiSectionHeader) assert.Equal(t, 0, offsetStart) assert.Equal(t, 3, offsetSectionsStart) assert.Equal(t, 2729, offsetSectionsEnd) assert.Equal(t, 2733, offsetEnd) + assert.Equal(t, false, stop) assert.NoError(t, err) } @@ -356,6 +367,27 @@ var psiDataTestCases = []psiDataTestCase{ }, } +func TestParsePSIDataPMTMultipleSections(t *testing.T) { + pmt := hexToBytes(`00C0001500000100FF000000 +000000010000000000038D646B02B07B +0001C90000EF9BF02109044749E10B05 +04474139348713C1010100F30D01656E +670100000554562D504702EF9BF00E11 +01FF1006C0BD62C0080006010281EF9C +F018050441432D33810A083805FF0F01 +BF656E670A04656E670081EF9DF01805 +0441432D33810A082885FF0001BF7370 +610A0473706100082F08E3FFFFFFFFFF +FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF +FFFFFFFFFFFFFFFFFFFFFFFF`) + d, err := parsePSIData(astikit.NewBytesIterator(pmt)) + assert.NoError(t, err) + assert.NotNil(t, d) + assert.Len(t, d.Sections, 2) + assert.Equal(t, PSITableID(0xc0), d.Sections[0].Header.TableID) + assert.Equal(t, PSITableID(0x02), d.Sections[1].Header.TableID) +} + func TestWritePSIData(t *testing.T) { for _, tc := range psiDataTestCases { t.Run(tc.name, func(t *testing.T) { diff --git a/demuxer_test.go b/demuxer_test.go index b7e5b6b..77712a6 100644 --- a/demuxer_test.go +++ b/demuxer_test.go @@ -84,6 +84,37 @@ func TestDemuxerNextData(t *testing.T) { assert.EqualError(t, err, ErrNoMorePackets.Error()) } +func TestDemuxerNextDataPMTComplex(t *testing.T) { + // complex pmt with two tables (0xc0 and 0x2) split across two packets + pmt := hexToBytes(`47403b1e00c0001500000100610000000000000100000000 +0035e3e2d702b0b20001c50000eefdf01809044749e10b05 +0441432d330504454143330504435545491beefdf0102a02 +7e1f9700e9080c001f418507d04181eefef00f810706380f +ff1f003f0a04656e670081eefff00f8107061003ff1f003f +0a047370610086ef00f00f8a01009700e9080c001f418507 +d041c0ef01f012050445545631a100e9080c001f418507d0 +41c0ef02f013050445545631a20100e9080c001f47003b1f +418507d041c0ef03f008bf06496e76696469a5cff3afffff +ffffffffffffffffffffffffffffffffffffffffffffffff +ffffffffffffffffffffffffffffffffffffffffffffffff +ffffffffffffffffffffffffffffffffffffffffffffffff +ffffffffffffffffffffffffffffffffffffffffffffffff +ffffffffffffffffffffffffffffffffffffffffffffffff +ffffffffffffffffffffffffffffffffffffffffffffffff +ffffffffffffffffffffffffffffffff`) + r := bytes.NewReader(pmt) + assert.Equal(t, 188*2, r.Len()) + + dmx := NewDemuxer(context.Background(), r, DemuxerOptPacketSize(188)) + dmx.programMap.set(59, 1) + + d, err := dmx.NextData() + assert.NoError(t, err) + assert.NotNil(t, d) + assert.Equal(t, uint16(59), d.FirstPacket.Header.PID) + assert.NotNil(t, d.PMT) +} + func TestDemuxerRewind(t *testing.T) { r := bytes.NewReader([]byte("content")) dmx := NewDemuxer(context.Background(), r)