Skip to content

Commit

Permalink
Merge pull request #185 from Visionaid-International-Ltd/validation
Browse files Browse the repository at this point in the history
Improve stream read validation
  • Loading branch information
ironfede authored Oct 8, 2024
2 parents b488e15 + e3a6d59 commit b895a73
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 70 deletions.
12 changes: 6 additions & 6 deletions sources/OpenMcdf.Extensions/StreamDecorator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ public override long Position
/// <inheritdoc/>
public override int Read(byte[] buffer, int offset, int count)
{
if (count > buffer.Length)
throw new ArgumentException("Count parameter exceeds buffer size");

if (buffer == null)
throw new ArgumentNullException("Buffer cannot be null");
throw new ArgumentNullException(nameof(buffer));

if (offset < 0)
throw new ArgumentOutOfRangeException(nameof(offset), "Offset must be a non-negative number");

if (offset < 0 || count < 0)
throw new ArgumentOutOfRangeException("Offset and Count parameters must be non-negative numbers");
if ((uint)count > buffer.Length - offset)
throw new ArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection");

if (position >= cfStream.Size)
return 0;
Expand Down
34 changes: 15 additions & 19 deletions sources/OpenMcdf/CompoundFile.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1283,14 +1283,13 @@ List<Sector> result
null,
sourceStream);

byte[] nextDIFATSectorBuffer = new byte[4];
StreamRW difatStreamRW = new(difatStream);

int i = 0;

while (result.Count < header.FATSectorsNumber)
{
difatStream.Read(nextDIFATSectorBuffer, 0, 4);
nextSecID = BitConverter.ToInt32(nextDIFATSectorBuffer, 0);
nextSecID = difatStreamRW.ReadInt32();

EnsureUniqueSectorIndex(nextSecID, processedSectors);

Expand All @@ -1308,20 +1307,14 @@ List<Sector> result

result.Add(s);

//difatStream.Read(nextDIFATSectorBuffer, 0, 4);
//nextSecID = BitConverter.ToInt32(nextDIFATSectorBuffer, 0);

if (difatStream.Position == (SectorSize - 4 + i * SectorSize))
{
// Skip DIFAT chain fields considering the possibility that the last FAT entry has been already read
difatStream.Read(nextDIFATSectorBuffer, 0, 4);
if (BitConverter.ToInt32(nextDIFATSectorBuffer, 0) == Sector.ENDOFCHAIN)
if (difatStreamRW.ReadInt32() == Sector.ENDOFCHAIN)
break;
else
{
i++;
continue;
}

i++;
continue;
}
}
}
Expand Down Expand Up @@ -1576,13 +1569,14 @@ List<Sector> directoryChain
using StreamView dirReader
= new StreamView(directoryChain, SectorSize, directoryChain.Count * SectorSize, null, sourceStream);

StreamRW dirReaderRW = new(dirReader);

while (dirReader.Position < directoryChain.Count * SectorSize)
{
IDirectoryEntry de
= DirectoryEntry.New(string.Empty, StgType.StgInvalid, directoryEntries);
IDirectoryEntry de = DirectoryEntry.New(string.Empty, StgType.StgInvalid, directoryEntries);

//We are not inserting dirs. Do not use 'InsertNewDirectoryEntry'
de.Read(dirReader, Version);
// We are not inserting dirs. Do not use 'InsertNewDirectoryEntry'
de.Read(dirReaderRW, Version);
}
}

Expand All @@ -1598,17 +1592,19 @@ List<Sector> directorySectors

using StreamView sv = new StreamView(directorySectors, SectorSize, 0, null, sourceStream);

StreamRW svRW = new(sv);

foreach (IDirectoryEntry di in directoryEntries)
{
di.Write(sv);
di.Write(svRW);
}

int delta = directoryEntries.Count;

while (delta % (SectorSize / DIRECTORY_SIZE) != 0)
{
IDirectoryEntry dummy = DirectoryEntry.New(string.Empty, StgType.StgInvalid, directoryEntries);
dummy.Write(sv);
dummy.Write(svRW);
delta++;
}

Expand Down
64 changes: 30 additions & 34 deletions sources/OpenMcdf/DirectoryEntry.cs
Original file line number Diff line number Diff line change
Expand Up @@ -210,23 +210,21 @@ public override int GetHashCode()
return (int)fnv_hash(EntryName);
}

public void Write(Stream stream)
public void Write(StreamRW streamRW)
{
StreamRW rw = new StreamRW(stream);

rw.Write(EntryName);
rw.Write(nameLength);
rw.Write((byte)StgType);
rw.Write((byte)StgColor);
rw.Write(LeftSibling);
rw.Write(RightSibling);
rw.Write(Child);
rw.Write(storageCLSID);
rw.Write(StateBits);
rw.Write(CreationDate);
rw.Write(ModifyDate);
rw.Write(StartSect);
rw.Write(Size);
streamRW.Write(EntryName);
streamRW.Write(nameLength);
streamRW.Write((byte)StgType);
streamRW.Write((byte)StgColor);
streamRW.Write(LeftSibling);
streamRW.Write(RightSibling);
streamRW.Write(Child);
streamRW.Write(storageCLSID);
streamRW.Write(StateBits);
streamRW.Write(CreationDate);
streamRW.Write(ModifyDate);
streamRW.Write(StartSect);
streamRW.Write(Size);
}

//public Byte[] ToByteArray()
Expand Down Expand Up @@ -256,18 +254,16 @@ public void Write(Stream stream)
// return ms.ToArray();
//}

public void Read(Stream stream, CFSVersion ver = CFSVersion.Ver_3)
public void Read(StreamRW streamRW, CFSVersion ver = CFSVersion.Ver_3)
{
StreamRW rw = new StreamRW(stream);

rw.ReadBytes(EntryName);
nameLength = rw.ReadUInt16();
StgType = (StgType)rw.ReadByte();
streamRW.ReadBytes(EntryName);
nameLength = streamRW.ReadUInt16();
StgType = (StgType)streamRW.ReadByte();
//rw.ReadByte();//Ignore color, only black tree
StgColor = (StgColor)rw.ReadByte();
LeftSibling = rw.ReadInt32();
RightSibling = rw.ReadInt32();
Child = rw.ReadInt32();
StgColor = (StgColor)streamRW.ReadByte();
LeftSibling = streamRW.ReadInt32();
RightSibling = streamRW.ReadInt32();
Child = streamRW.ReadInt32();

// Thanks to bugaccount (BugTrack id 3519554)
if (StgType == StgType.StgInvalid)
Expand All @@ -277,23 +273,23 @@ public void Read(Stream stream, CFSVersion ver = CFSVersion.Ver_3)
Child = NOSTREAM;
}

storageCLSID = rw.ReadGuid();
StateBits = rw.ReadInt32();
rw.ReadBytes(CreationDate);
rw.ReadBytes(ModifyDate);
StartSect = rw.ReadInt32();
storageCLSID = streamRW.ReadGuid();
StateBits = streamRW.ReadInt32();
streamRW.ReadBytes(CreationDate);
streamRW.ReadBytes(ModifyDate);
StartSect = streamRW.ReadInt32();

if (ver == CFSVersion.Ver_3)
{
// avoid dirty read for version 3 files (max size: 32bit integer)
// where most significant bits are not initialized to zero

Size = rw.ReadInt32();
rw.Seek(4, SeekOrigin.Current); // discard most significant 4 (possibly) dirty bytes
Size = streamRW.ReadInt32();
streamRW.Seek(4, SeekOrigin.Current); // discard most significant 4 (possibly) dirty bytes
}
else
{
Size = rw.ReadInt64();
Size = streamRW.ReadInt64();
}
}

Expand Down
4 changes: 2 additions & 2 deletions sources/OpenMcdf/IDirectoryEntry.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ internal interface IDirectoryEntry : IComparable, IRBNode
byte[] ModifyDate { get; set; }
string Name { get; }
ushort NameLength { get; set; }
void Read(System.IO.Stream stream, CFSVersion ver = CFSVersion.Ver_3);
void Read(StreamRW streamRW, CFSVersion ver = CFSVersion.Ver_3);
int RightSibling { get; set; }
void SetEntryName(string entryName);
int SID { get; set; }
Expand All @@ -31,7 +31,7 @@ internal interface IDirectoryEntry : IComparable, IRBNode
StgColor StgColor { get; set; }
StgType StgType { get; set; }
Guid StorageCLSID { get; set; }
void Write(System.IO.Stream stream);
void Write(StreamRW stream);
void Reset();
}
}
30 changes: 21 additions & 9 deletions sources/OpenMcdf/StreamRW.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,58 +21,70 @@ public StreamRW(Stream stream)
this.stream = stream;
}

void ReadExactly(byte[] buffer, int offset, int count)
{
int totalRead = 0;
do
{
int read = stream.Read(buffer, offset + totalRead, count - totalRead);
if (read == 0)
throw new EndOfStreamException();

totalRead += read;
} while (totalRead < count);
}

public long Seek(long count, SeekOrigin origin)
{
return stream.Seek(count, origin);
}

public byte ReadByte()
{
stream.Read(buffer, 0, 1);
ReadExactly(buffer, 0, sizeof(byte));
return buffer[0];
}

public ushort ReadUInt16()
{
stream.Read(buffer, 0, 2);
ReadExactly(buffer, 0, sizeof(ushort));
return (ushort)(buffer[0] | (buffer[1] << 8));
}

public int ReadInt32()
{
stream.Read(buffer, 0, 4);
ReadExactly(buffer, 0, sizeof(int));
return buffer[0] | (buffer[1] << 8) | (buffer[2] << 16) | (buffer[3] << 24);
}

public uint ReadUInt32()
{
stream.Read(buffer, 0, 4);
ReadExactly(buffer, 0, sizeof(uint));
return (uint)(buffer[0] | (buffer[1] << 8) | (buffer[2] << 16) | (buffer[3] << 24));
}

public long ReadInt64()
{
stream.Read(buffer, 0, 8);
ReadExactly(buffer, 0, sizeof(long));
uint ls = (uint)(buffer[0] | (buffer[1] << 8) | (buffer[2] << 16) | (buffer[3] << 24));
uint ms = (uint)((buffer[4]) | (buffer[5] << 8) | (buffer[6] << 16) | (buffer[7] << 24));
return (long)(((ulong)ms << 32) | ls);
}

public ulong ReadUInt64()
{
stream.Read(buffer, 0, 8);
ReadExactly(buffer, 0, sizeof(ulong));
return (ulong)(buffer[0] | (buffer[1] << 8) | (buffer[2] << 16) | (buffer[3] << 24) | (buffer[4] << 32) | (buffer[5] << 40) | (buffer[6] << 48) | (buffer[7] << 56));
}

public void ReadBytes(byte[] result)
{
// TODO: Check if the expected number of bytes were read
stream.Read(result, 0, result.Length);
ReadExactly(result, 0, result.Length);
}

public Guid ReadGuid()
{
stream.Read(buffer, 0, 16);
ReadExactly(buffer, 0, 16);
return new Guid(buffer);
}

Expand Down
15 changes: 15 additions & 0 deletions sources/OpenMcdf/StreamView.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,15 @@ public override long Position

public override int Read(byte[] buffer, int offset, int count)
{
if (buffer is null)
throw new ArgumentNullException(nameof(buffer));

if (offset < 0)
throw new ArgumentOutOfRangeException(nameof(offset), "Offset must be a non-negative number");

if ((uint)count > buffer.Length - offset)
throw new ArgumentException("Offset and length were out of bounds for the array or count is greater than the number of elements from index to the end of the source collection");

int nRead = 0;

// Don't try to read more bytes than this stream contains.
Expand Down Expand Up @@ -157,14 +166,20 @@ public override long Seek(long offset, SeekOrigin origin)
switch (origin)
{
case SeekOrigin.Begin:
if (offset < 0)
throw new IOException("Seek before origin");
position = offset;
break;

case SeekOrigin.Current:
if (position + offset < 0)
throw new IOException("Seek before origin");
position += offset;
break;

case SeekOrigin.End:
if (Length - offset < 0)
throw new IOException("Seek before origin");
position = Length - offset;
break;
}
Expand Down

0 comments on commit b895a73

Please sign in to comment.