Client LuaCsForBarotrauma
SegmentTable.cs
1 #nullable enable
2 using System;
3 using System.Collections.Generic;
4 using System.Linq;
5 using Microsoft.Xna.Framework;
6 
7 namespace Barotrauma.Networking;
8 
9 /*
10  * What are segment tables for?
11  *
12  * Segment tables help make our networking packet reading code more robust by
13  * clearly stating where part of a message begins. Previously we would've done
14  * something like:
15  *
16  * msg.WriteByte(SegmentType.A);
17  * ...
18  * msg.WriteByte(SegmentType.B);
19  * ...
20  * msg.WriteByte(SegmentType.EndOfMessage);
21  *
22  * The problem with this design is that it's hard to debug when the writing and reading
23  * code do not align for whatever reason. INetSerializableStruct is an awesome way
24  * of avoiding that problem, but deploying it on a broad scale means rewriting most
25  * of the netcode. That isn't going to happen any time soon, so this exists as an easier
26  * way of increasing robustness.
27  *
28  * A segment table is laid out as follows:
29  *
30  * [TablePointer: Int32]
31  * [Segment: arbitrary]
32  * ...
33  * [Segment: arbitrary]
34  * [NumberOfSegments: UInt16]
35  * [(Identifier, SegmentPointer): (T, Int32)]
36  * ...
37  * [(Identifier, SegmentPointer): (T, Int32)]
38  *
39  * A pointer in this context is an offset, in bits, relative to the BitPosition at which the TablePointer is written.
40  *
41  * It is used as follows:
42  *
43  * using (var segmentTable = SegmentTableWriter<T>.StartWriting(outMsg))
44  * {
45  * segmentTable.StartNewSegment(T.A);
46  * ... write segment to outMsg ...
47  * segmentTable.StartNewSegment(T.B);
48  * ... write segment to outMsg ...
49  * }
50  * peer.SendMessage(outMsg);
51  *
52  * ...
53  *
54  * SegmentTableReader<T>.Read(inc,
55  * segmentDataReader: (segment, inc) =>
56  * {
57  * switch (segment)
58  * {
59  * ... read segments ...
60  * }
61  * return SegmentTableReader<T>.BreakSegmentReading.No;
62  * });
63  *
64  * The advantages of this approach are:
65  * - If a message is truncated or corrupted near the end, it becomes far more obvious because the table
66  * would not be read properly and look like garbage when printed to the console.
67  * - If the reading and writing code for a segment disagree on something, issues will be isolated to that
68  * one segment.
69  * - The code no longer has to fiddle with padding and temporary buffers because the segment table is able
70  * to handle content that is not byte-aligned just fine.
71  * - Exception handling is far easier when using a segment table: when combined with a using statement,
72  * any uncaught exception will result in the entire table being skipped, allowing the remainder of the
73  * message to still be read.
74  * - It's harder to make mistakes in the implementation of segments themselves with this approach. By using
75  * the SegmentTableWriter and SegmentTableReader types, you get a type-safe way of delimiting segments
76  * and it's harder to forget to finalize a packet.
77  */
78 
/// <summary>
/// One entry of a segment table: the segment's identifier and the bit offset at which the
/// segment's data starts, measured from the position where the table pointer was written.
/// </summary>
/// <typeparam name="T">Segment identifier type (NOTE(review): presumably an enum — confirm).</typeparam>
79 [NetworkSerialize]
80 public readonly record struct Segment<T>(T Identifier, int Pointer) : INetSerializableStruct where T : struct;
81 
/// <summary>
/// Builds a segment table in an outgoing message. Obtain one with <see cref="StartWriting"/>,
/// call <see cref="StartNewSegment"/> before writing each segment's data, then dispose it
/// (normally through a <c>using</c> statement) to append the table and patch the table pointer.
/// </summary>
/// <typeparam name="T">Type used to identify segments.</typeparam>
readonly ref struct SegmentTableWriter<T> where T : struct
{
    private readonly IWriteMessage message;
    private readonly List<Segment<T>> segments;

    /// <summary>
    /// Bit position at which the Int32 table pointer was reserved. Segment pointers and the
    /// table pointer are bit offsets relative to this position.
    /// </summary>
    public readonly int PointerLocation;

    private SegmentTableWriter(IWriteMessage message, int pointerLocation)
    {
        this.message = message;
        this.PointerLocation = pointerLocation;
        this.segments = new List<Segment<T>>();
    }

    /// <summary>
    /// Starts a segment table at the current position of <paramref name="msg"/> by reserving
    /// an Int32 placeholder for the table pointer.
    /// </summary>
    /// <param name="msg">Message the table is written into.</param>
    /// <returns>A writer whose <see cref="Dispose"/> must be called once all segments are written.</returns>
    public static SegmentTableWriter<T> StartWriting(IWriteMessage msg)
    {
        var retVal = new SegmentTableWriter<T>(msg, msg.BitPosition);
        msg.WriteInt32(0); //reserve space for the table pointer; Dispose overwrites it
        return retVal;
    }

    /// <summary>
    /// Throws if another segment cannot be added. The segment count is serialized as a UInt16,
    /// so at most <see cref="UInt16.MaxValue"/> segments fit in a table.
    /// </summary>
    /// <exception cref="InvalidOperationException">The table is already full.</exception>
    private void ThrowOnInvalidState()
    {
        if (segments.Count >= UInt16.MaxValue)
        {
            throw new InvalidOperationException($"Too many segments in SegmentTable<{typeof(T).Name}>");
        }
    }

    /// <summary>
    /// Records the message's current bit position as the start of a new segment identified by
    /// <paramref name="value"/>. The previous segment (if any) ends here.
    /// </summary>
    /// <exception cref="InvalidOperationException">The table already holds <see cref="UInt16.MaxValue"/> segments.</exception>
    public void StartNewSegment(T value)
    {
        ThrowOnInvalidState();
        segments.Add(new Segment<T>(value, message.BitPosition - PointerLocation));
    }

    /// <summary>
    /// Finishes the table: overwrites the reserved table pointer with the offset of the current
    /// position, then appends the segment count and every segment entry there. The message is
    /// left positioned after the last entry.
    /// </summary>
    public void Dispose()
    {
        // StartNewSegment guarantees Count <= UInt16.MaxValue, which fits in the count field, so
        // this cast cannot overflow. Re-checking with ThrowOnInvalidState here would wrongly reject
        // a table holding exactly UInt16.MaxValue segments (and throw from Dispose).
        UInt16 segmentCount = checked((UInt16)segments.Count);
        int tablePosition = message.BitPosition;

        //rewrite the table pointer now that we know where the table starts
        message.BitPosition = PointerLocation;
        message.WriteInt32(tablePosition - PointerLocation);

        //write the table
        message.BitPosition = tablePosition;
        message.WriteUInt16(segmentCount);
        foreach (var segment in segments)
        {
            message.WriteNetSerializableStruct(segment);
        }
    }
}
133 
/// <summary>
/// Reads a segment table produced by <see cref="SegmentTableWriter{T}"/> and hands each segment
/// to a caller-supplied delegate. The entry point is <see cref="Read"/>.
/// </summary>
/// <typeparam name="T">Type used to identify segments.</typeparam>
readonly ref struct SegmentTableReader<T> where T : struct
{
    /// <summary>
    /// View of a single segment of the underlying message. Positions and lengths are relative to
    /// the start of the segment, and each read throws if it ended past the end of the segment.
    /// </summary>
    private class SegmentReadMsg : IReadMessage
    {
        private readonly IReadMessage underlyingMsg;
        private readonly int offset;
        private readonly int lengthBits;

        /// <param name="underlyingMsg">Message containing the segment.</param>
        /// <param name="offset">Bit position in <paramref name="underlyingMsg"/> where the segment starts.</param>
        /// <param name="lengthBits">Length of the segment in bits.</param>
        /// <exception cref="Exception">The segment bounds are inconsistent with the message (corrupt table).</exception>
        public SegmentReadMsg(IReadMessage underlyingMsg, int offset, int lengthBits)
        {
            this.underlyingMsg = underlyingMsg;
            this.offset = offset;
            this.lengthBits = lengthBits;

            // Lengths are differences between pointers read from the wire, so out-of-order
            // pointers in a corrupt table show up as a negative length.
            if (lengthBits < 0)
            {
                throw new Exception(
                    $"Segment table is corrupt, segment length is negative: {lengthBits}");
            }

            // The table always follows the last segment, so every valid segment ends strictly
            // before the end of the message.
            if (offset + lengthBits >= underlyingMsg.LengthBits)
            {
                throw new Exception(
                    $"Segment table is corrupt, segment length is invalid: {offset} + {lengthBits} >= {underlyingMsg.LengthBits}");
            }
        }

        /// <summary>Throws if the read position has moved past the end of the segment.</summary>
        private void Check()
        {
            if (BitPosition > lengthBits)
            {
                throw new Exception($"Tried to read too much data from segment.");
            }
        }

        /// <summary>Runs the bounds check after a read and passes the read value through.</summary>
        private TRead Check<TRead>(TRead v)
        {
            Check();
            return v;
        }

        public bool ReadBoolean() => Check(underlyingMsg.ReadBoolean());

        public void ReadPadBits()
        {
            // NOTE(review): unlike the other reads, the bounds check runs before the padding is
            // consumed, so padding that crosses the segment end is not reported.
            Check(); underlyingMsg.ReadPadBits();
        }

        public byte ReadByte() => Check(underlyingMsg.ReadByte());

        public byte PeekByte() => Check(underlyingMsg.PeekByte());

        public ushort ReadUInt16() => Check(underlyingMsg.ReadUInt16());

        public short ReadInt16() => Check(underlyingMsg.ReadInt16());

        public uint ReadUInt32() => Check(underlyingMsg.ReadUInt32());

        public int ReadInt32() => Check(underlyingMsg.ReadInt32());

        public ulong ReadUInt64() => Check(underlyingMsg.ReadUInt64());

        public long ReadInt64() => Check(underlyingMsg.ReadInt64());

        public float ReadSingle() => Check(underlyingMsg.ReadSingle());

        public double ReadDouble() => Check(underlyingMsg.ReadDouble());

        public uint ReadVariableUInt32() => Check(underlyingMsg.ReadVariableUInt32());

        public string ReadString() => Check(underlyingMsg.ReadString());

        public Identifier ReadIdentifier() => Check(underlyingMsg.ReadIdentifier());

        public Color ReadColorR8G8B8() => Check(underlyingMsg.ReadColorR8G8B8());

        public Color ReadColorR8G8B8A8() => Check(underlyingMsg.ReadColorR8G8B8A8());

        public int ReadRangedInteger(int min, int max) => Check(underlyingMsg.ReadRangedInteger(min, max));

        public float ReadRangedSingle(float min, float max, int bitCount) => Check(underlyingMsg.ReadRangedSingle(min, max, bitCount));

        public byte[] ReadBytes(int numberOfBytes) => Check(underlyingMsg.ReadBytes(numberOfBytes));

        /// <summary>Read position in bits, relative to the start of the segment.</summary>
        public int BitPosition
        {
            get => underlyingMsg.BitPosition - offset;
            set => Check(underlyingMsg.BitPosition = value + offset);
        }

        public int BytePosition => BitPosition / 8;

        /// <summary>
        /// The whole underlying buffer. Unlike <see cref="BitPosition"/>, it is NOT offset to the
        /// start of the segment.
        /// </summary>
        public byte[] Buffer => underlyingMsg.Buffer;

        /// <summary>Segment length in bits; the segment cannot be resized.</summary>
        public int LengthBits
        {
            get => lengthBits;
            set => throw new InvalidOperationException($"Cannot resize {nameof(SegmentReadMsg)}");
        }

        // NOTE(review): rounds down, so trailing bits of a non-byte-aligned segment are not
        // counted; confirm this matches what callers expect from LengthBytes.
        public int LengthBytes => lengthBits / 8;

        public NetworkConnection Sender => underlyingMsg.Sender;
    }

    private readonly IReadMessage message;
    private readonly List<Segment<T>> segments;
    private readonly int exitLocation;

    /// <summary>Bit position of the table pointer; segment pointers are relative to it.</summary>
    public readonly int PointerLocation;

    private SegmentTableReader(IReadMessage message, List<Segment<T>> segments, int pointerLocation, int exitLocation)
    {
        this.message = message;
        this.segments = segments;
        this.PointerLocation = pointerLocation;
        this.exitLocation = exitLocation;
    }

    /// <summary>Table entries in the order they appear in the message.</summary>
    public IReadOnlyList<Segment<T>> Segments => segments;

    /// <summary>Returned by a <see cref="SegmentDataReader"/> to say whether to stop reading further segments.</summary>
    public enum BreakSegmentReading
    {
        No,
        Yes
    }

    /// <summary>Reads one segment; <paramref name="incMsg"/> is bounded to that segment.</summary>
    public delegate BreakSegmentReading SegmentDataReader(
        T segmentHeader,
        IReadMessage incMsg);

    /// <summary>Receives an exception thrown while reading a segment; reading then continues with the next segment.</summary>
    public delegate void ExceptionHandler(
        Segment<T> segmentWithError,
        Segment<T>[] previousSegments,
        Exception exceptionThrown);

    /// <summary>
    /// Reads the segment table that starts at the current position of <paramref name="msg"/> and
    /// invokes <paramref name="segmentDataReader"/> for each segment in table order.
    /// </summary>
    /// <param name="msg">Message positioned at a table pointer written by <see cref="SegmentTableWriter{T}.StartWriting"/>.</param>
    /// <param name="segmentDataReader">
    /// Called once per segment with a message view bounded to that segment; returning
    /// <see cref="BreakSegmentReading.Yes"/> skips the remaining segments.
    /// </param>
    /// <param name="exceptionHandler">
    /// Receives exceptions thrown while reading a segment, after which reading continues with the
    /// next segment. When null, the exception is wrapped and rethrown.
    /// </param>
    /// <remarks>
    /// Once the table itself has been read, <paramref name="msg"/> is left positioned just past the
    /// table whether reading completes, is stopped early, or throws.
    /// </remarks>
    /// <exception cref="Exception">
    /// The table pointer is out of range, or a segment failed to read and no
    /// <paramref name="exceptionHandler"/> was supplied.
    /// </exception>
    public static void Read(
        IReadMessage msg,
        SegmentDataReader segmentDataReader,
        ExceptionHandler? exceptionHandler = null)
    {
        int pointerLocation = msg.BitPosition;
        int tablePointer = msg.ReadInt32();
        int returnPosition = msg.BitPosition;

        var segments = ReadTable(msg, pointerLocation + tablePointer, returnPosition);

        //store the exit location and go back to the top
        int exitLocation = msg.BitPosition;
        msg.BitPosition = returnPosition;

        // Disposing segmentTable moves msg to exitLocation, even if a segment throws.
        using var segmentTable = new SegmentTableReader<T>(msg, segments, pointerLocation, exitLocation);

        for (int i = 0; i < segments.Count; i++)
        {
            var segment = segments[i];
            int segmentStart = pointerLocation + segment.Pointer;
            try
            {
                // Seek inside the try so that a corrupt segment pointer is reported through the
                // same path as any other error in this segment.
                msg.BitPosition = segmentStart;
                var segmentMsg = new SegmentReadMsg(
                    msg,
                    offset: segmentStart,
                    lengthBits: GetSegmentLengthBits(segments, i, tablePointer));

                if (segmentDataReader(segment.Identifier, segmentMsg) is BreakSegmentReading.Yes)
                {
                    break;
                }
            }
            catch (Exception e)
            {
                var prevSegments = segments.Take(i).ToArray();
                if (exceptionHandler is not null)
                {
                    exceptionHandler(segment, prevSegments, e);
                }
                else
                {
                    throw new Exception(
                        $"Exception thrown while reading segment {segment.Identifier} at position {segment.Pointer}." +
                        (prevSegments.Length > 0 ? $" Previous segments: {string.Join(", ", prevSegments)}." : ""),
                        e);
                }
            }
        }
    }

    /// <summary>
    /// Seeks to <paramref name="tableLocation"/> and reads the segment count followed by that many
    /// entries, leaving <paramref name="msg"/> positioned just past the last entry.
    /// </summary>
    /// <param name="earliestTableLocation">First bit after the table pointer; a valid table cannot start before it.</param>
    /// <exception cref="Exception">The table location lies outside the message.</exception>
    private static List<Segment<T>> ReadTable(IReadMessage msg, int tableLocation, int earliestTableLocation)
    {
        // The table must not overlap its own pointer and must leave room for the 16-bit segment
        // count. Comparing against LengthBits - 16 avoids overflow for garbage pointers.
        if (tableLocation < earliestTableLocation || tableLocation > msg.LengthBits - 16)
        {
            throw new Exception(
                $"Segment table is corrupt, table location is invalid: {tableLocation} (expected {earliestTableLocation}..{msg.LengthBits - 16})");
        }

        msg.BitPosition = tableLocation;
        int numSegments = msg.ReadUInt16();
        var segments = new List<Segment<T>>(numSegments);
        for (int i = 0; i < numSegments; i++)
        {
            segments.Add(INetSerializableStruct.Read<Segment<T>>(msg));
        }
        return segments;
    }

    /// <summary>
    /// Length in bits of segment <paramref name="index"/>: the distance from its pointer to the next
    /// segment's pointer, or to the table (<paramref name="tablePointer"/>) for the last segment.
    /// </summary>
    private static int GetSegmentLengthBits(List<Segment<T>> segments, int index, int tablePointer)
    {
        int end = index < segments.Count - 1 ? segments[index + 1].Pointer : tablePointer;
        return end - segments[index].Pointer;
    }

    /// <summary>Moves the message to just past the table.</summary>
    public void Dispose()
    {
        message.BitPosition = exitLocation;
    }
}
static void Read(IReadMessage msg, SegmentDataReader segmentDataReader, ExceptionHandler? exceptionHandler=null)
readonly int PointerLocation
delegate BreakSegmentReading SegmentDataReader(T segmentHeader, IReadMessage incMsg)
delegate void ExceptionHandler(Segment< T > segmentWithError, Segment< T >[] previousSegments, Exception exceptionThrown)
IReadOnlyList< Segment< T > > Segments
void StartNewSegment(T value)
static SegmentTableWriter< T > StartWriting(IWriteMessage msg)
Definition: SegmentTable.cs:94
readonly int PointerLocation
Definition: SegmentTable.cs:86