diff options
Diffstat (limited to 'pkg/iptcp_utils/iptcp_utils.go')
-rw-r--r-- | pkg/iptcp_utils/iptcp_utils.go | 144 |
1 files changed, 144 insertions, 0 deletions
diff --git a/pkg/iptcp_utils/iptcp_utils.go b/pkg/iptcp_utils/iptcp_utils.go new file mode 100644 index 0000000..f8caf4d --- /dev/null +++ b/pkg/iptcp_utils/iptcp_utils.go @@ -0,0 +1,144 @@ +package iptcp_utils + +import ( + "encoding/binary" + "fmt" + "net/netip" + "strings" + + "github.com/google/netstack/tcpip/header" +) + +const ( + TcpHeaderLen = header.TCPMinimumSize + TcpPseudoHeaderLen = 12 + IpProtoTcp = header.TCPProtocolNumber + MaxVirtualPacketSize = 1400 +) + + + +// Build a TCPFields struct from the TCP byte array +// +// NOTE: the netstack package might have other options for parsing the header +// that you may like better--this example is most similar to our other class +// examples. Your mileage may vary! +func ParseTCPHeader(b []byte) header.TCPFields { + td := header.TCP(b) + return header.TCPFields{ + SrcPort: td.SourcePort(), + DstPort: td.DestinationPort(), + SeqNum: td.SequenceNumber(), + AckNum: td.AckNumber(), + DataOffset: td.DataOffset(), + Flags: td.Flags(), + WindowSize: td.WindowSize(), + Checksum: td.Checksum(), + } +} + +// The TCP checksum is computed based on a "pesudo-header" that +// combines the (virtual) IP source and destination address, protocol value, +// as well as the TCP header and payload +// +// This is one example one is way to combine all of this information +// and compute the checksum leveraging the netstack package. +// +// For more details, see the "Checksum" component of RFC9293 Section 3.1, +// https://www.rfc-editor.org/rfc/rfc9293.txt +func ComputeTCPChecksum(tcpHdr *header.TCPFields, + sourceIP netip.Addr, destIP netip.Addr, payload []byte) uint16 { + + // Fill in the pseudo header + pseudoHeaderBytes := make([]byte, TcpPseudoHeaderLen) + + // First are the source and dest IPs. This function only supports + // IPv4, so make sure the IPs are IPv4 addresses + copy(pseudoHeaderBytes[0:4], sourceIP.AsSlice()) + copy(pseudoHeaderBytes[4:8], destIP.AsSlice()) + + // Next, add the protocol number and header length + pseudoHeaderBytes[8] = uint8(0) + pseudoHeaderBytes[9] = uint8(IpProtoTcp) + + totalLength := TcpHeaderLen + len(payload) + binary.BigEndian.PutUint16(pseudoHeaderBytes[10:12], uint16(totalLength)) + + // Turn the TcpFields struct into a byte array + headerBytes := header.TCP(make([]byte, TcpHeaderLen)) + headerBytes.Encode(tcpHdr) + + // Compute the checksum for each individual part and combine To combine the + // checksums, we leverage the "initial value" argument of the netstack's + // checksum package to carry over the value from the previous part + pseudoHeaderChecksum := header.Checksum(pseudoHeaderBytes, 0) + headerChecksum := header.Checksum(headerBytes, pseudoHeaderChecksum) + fullChecksum := header.Checksum(payload, headerChecksum) + + // Return the inverse of the computed value, + // which seems to be the convention of the checksum algorithm + // in the netstack package's implementation + return fullChecksum ^ 0xffff +} + +// Compute the checksum using the netstack package +func ComputeIPChecksum(b []byte) uint16 { + checksum := header.Checksum(b, 0) + + // Invert the checksum value. Why is this necessary? + // This function returns the inverse of the checksum + // on an initial computation. While this may seem weird, + // it makes it easier to use this same function + // to validate the checksum on the receiving side. + // See ValidateChecksum in the receiver file for details. + checksumInv := checksum ^ 0xffff + + return checksumInv +} + +// Validate the checksum using the netstack package Here, we provide both the +// byte array for the header AND the initial checksum value that was stored in +// the header +// +// "Why don't we need to set the checksum value to 0 first?" +// +// Normally, the checksum is computed with the checksum field of the header set +// to 0. This library creatively avoids this step by instead subtracting the +// initial value from the computed checksum. If you use a different language or +// checksum function, you may need to handle this differently. +func ValidateIPChecksum(b []byte, fromHeader uint16) uint16 { + checksum := header.Checksum(b, fromHeader) + + return checksum +} + +// Pretty-print TCP flags value as a string +func TCPFlagsAsString(flags uint8) string { + strMap := map[uint8]string{ + header.TCPFlagAck: "ACK", + header.TCPFlagFin: "FIN", + header.TCPFlagPsh: "PSH", + header.TCPFlagRst: "RST", + header.TCPFlagSyn: "SYN", + header.TCPFlagUrg: "URG", + } + + matches := make([]string, 0) + + for b, str := range strMap { + if (b & flags) == b { + matches = append(matches, str) + } + } + + ret := strings.Join(matches, "+") + + return ret +} + +// Pretty-print a TCP header (with pretty-printed flags) +// Otherwise, using %+v in format strings is a good enough view in most cases +func TCPFieldsToString(hdr *header.TCPFields) string { + return fmt.Sprintf("{SrcPort:%d DstPort:%d, SeqNum:%d AckNum:%d DataOffset:%d Flags:%s WindowSize:%d Checksum:%x UrgentPointer:%d}", + hdr.SrcPort, hdr.DstPort, hdr.SeqNum, hdr.AckNum, hdr.DataOffset, TCPFlagsAsString(hdr.Flags), hdr.WindowSize, hdr.Checksum, hdr.UrgentPointer) +}
\ No newline at end of file |