aboutsummaryrefslogtreecommitdiff
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/ipstack/ipstack.go1565
-rw-r--r--pkg/ipstack/ipstack_test.go301
-rw-r--r--pkg/iptcp_utils/iptcp_utils.go144
-rw-r--r--pkg/lnxconfig/lnxconfig.go355
-rw-r--r--pkg/routingtable/routingtable.go90
-rw-r--r--pkg/tcpstack/tcpstack.go232
6 files changed, 2687 insertions, 0 deletions
diff --git a/pkg/ipstack/ipstack.go b/pkg/ipstack/ipstack.go
new file mode 100644
index 0000000..0d317c2
--- /dev/null
+++ b/pkg/ipstack/ipstack.go
@@ -0,0 +1,1565 @@
+package ipstack
+
+// code begins on line 97 after imports, constants, and structs definitions
+// This class is divided as follows:
+// 1) INIT FUNCTIONS
+// 2) DOWN/UP FUNCTIONS
+// 3) SEND/RECV FUNCTIONS
+// 4) CHECKSUM FUNCTIONS
+// 5) RIP FUNCTIONS
+// 6) PROTOCOL HANDLERS
+// 7) HELPER FUNCTIONS
+// 8) GETTER FUNCTIONS
+// 9) PRINT FUNCTIONS
+// 10) CLEANUP FUNCTION
+
+import (
+ "encoding/binary"
+ "fmt"
+ ipv4header "github.com/brown-csci1680/iptcp-headers"
+ "github.com/google/netstack/tcpip/header"
+ "github.com/pkg/errors"
+ "iptcp/pkg/lnxconfig"
+ "iptcp/pkg/iptcp_utils"
+ "log"
+ "net"
+ "net/netip"
+ "sync"
+ "time"
+ "strings"
+ "math/rand"
+ // "github.com/google/netstack/tcpip/header"
+)
+
+const (
+ MAX_IP_PACKET_SIZE = 1400
+ LOCAL_COST uint32 = 0
+ STATIC_COST uint32 = 4294967295 // 2^32 - 1
+ INFINITY = 16
+ SIZE_OF_RIP_ENTRY = 12
+ RIP_PROTOCOL = 200
+ TEST_PROTOCOL = 0
+ TCP_PROTOCOL = 6
+ SIZE_OF_RIP_HEADER = 4
+ MAX_TIMEOUT = 12
+)
+
+// STRUCTS ---------------------------------------------------------------------
+type Interface struct {
+ Name string
+ IpPrefix netip.Prefix
+ UdpAddr netip.AddrPort
+
+ Socket net.UDPConn
+ SocketChannel chan bool
+ State bool
+}
+
+type Neighbor struct {
+ Name string
+ VipAddr netip.Addr
+ UdpAddr netip.AddrPort
+}
+
+type RIPHeader struct {
+ command uint16
+ numEntries uint16
+}
+
+type RIPEntry struct {
+ prefix netip.Prefix
+ cost uint32
+}
+
+type Hop struct {
+ Cost uint32
+ Type string
+
+ Interface *Interface
+ VIP netip.Addr
+}
+
+// GLOBAL VARIABLES (data structures) ------------------------------------------
+var myInterfaces []*Interface
+var myNeighbors = make(map[string][]*Neighbor)
+
+var myRIPNeighbors = make(map[string]*Neighbor)
+
+type HandlerFunc func(src *Interface, message []byte, hdr *ipv4header.IPv4Header) error
+
+var protocolHandlers = make(map[int]HandlerFunc)
+
+var routingTable = make(map[netip.Prefix]Hop)
+
+var timeoutTableMu sync.Mutex
+var timeoutTable = make(map[netip.Prefix]int)
+
+// ************************************** INIT FUNCTIONS **********************************************************
+// reference: https://github.com/brown-csci1680/lecture-examples/blob/main/ip-demo/cmd/udp-ip-recv/main.go
+
+// createUDPListener creates a UDP listener on the given UDP address.
+// It SETS the conn parameter to the created UDP socket.
+func createUDPListener(UdpAddr netip.AddrPort, conn *net.UDPConn) error {
+ listenString := UdpAddr.String()
+ listenAddr, err := net.ResolveUDPAddr("udp4", listenString)
+ if err != nil {
+ return errors.WithMessage(err, "Error resolving address->\t"+listenString)
+ }
+ tmpConn, err := net.ListenUDP("udp4", listenAddr)
+ if err != nil {
+ return errors.WithMessage(err, "Could not bind to UDP port->\t"+listenString)
+ }
+ *conn = *tmpConn
+
+ return nil
+}
+
+// Initialize initializes the data structures and creates the UDP sockets.
+//
+// It will return an error if the lnx file is not valid or if a socket fails to be created.
+//
+// After parsing the lnx file, it does the following:
+// 1. adds each local interface to the routing table, as dictated by its subnet
+// 2. adds neighbors to interface->neighbors[] map
+// 3. adds RIP neighbors to RIP neighbor list
+// 4. adds static routes to routing table
+func Initialize(lnxFilePath string) error {
+ // Parse the file
+ lnxConfig, err := lnxconfig.ParseConfig(lnxFilePath)
+ if err != nil {
+ return errors.WithMessage(err, "Error parsing config file->\t"+lnxFilePath)
+ }
+
+ // 1) add each local "if" to the routing table, as dictated by its subnet
+ for _, iface := range lnxConfig.Interfaces {
+ prefix := netip.PrefixFrom(iface.AssignedIP, iface.AssignedPrefix.Bits())
+ i := &Interface{
+ Name: iface.Name,
+ IpPrefix: prefix,
+ UdpAddr: iface.UDPAddr,
+ Socket: net.UDPConn{},
+ SocketChannel: make(chan bool),
+ State: true,
+ }
+
+ // create the UDP listener
+ err := createUDPListener(iface.UDPAddr, &i.Socket)
+ if err != nil {
+ return errors.WithMessage(err, "Error creating UDP socket for interface->\t"+iface.Name)
+ }
+
+ // start the listener routine
+ go InterfaceListenerRoutine(i)
+
+ // add to the list of interfaces
+ myInterfaces = append(myInterfaces, i)
+
+ // add to the routing table
+ routingTable[prefix.Masked()] = Hop{LOCAL_COST, "L", i, prefix.Addr()}
+ }
+
+ // 2) add neighbors to if->neighbors map
+ for _, neighbor := range lnxConfig.Neighbors {
+ n := &Neighbor{
+ Name: neighbor.InterfaceName,
+ VipAddr: neighbor.DestAddr,
+ UdpAddr: neighbor.UDPAddr,
+ }
+
+ myNeighbors[neighbor.InterfaceName] = append(myNeighbors[neighbor.InterfaceName], n)
+ }
+
+ // 3) add RIP neighbors to RIP neighbor list
+ for _, route := range lnxConfig.RipNeighbors {
+ // add to RIP neighbors
+ for _, iface := range myInterfaces {
+ for _, neighbor := range myNeighbors[iface.Name] {
+ if neighbor.VipAddr == route {
+ myRIPNeighbors[neighbor.VipAddr.String()] = neighbor
+ break
+ }
+ }
+ }
+ }
+
+ // 4) add static routes to routing table
+ for prefix, addr := range lnxConfig.StaticRoutes {
+ // need loops to find the interface that matches the neighbor to send static to
+ // hops needs this interface
+ for _, iface := range myInterfaces {
+ for _, neighbor := range myNeighbors[iface.Name] {
+ if neighbor.VipAddr == addr {
+ routingTable[prefix] = Hop{STATIC_COST, "S", iface, addr}
+ break
+ }
+ }
+ }
+ }
+
+ return nil
+}
+
+// InterfaceListenerRoutine is a go routine for interfaces to listen on a UDP port.
+//
+// It is composed two go routines:
+// 1. a go routine that hangs on the recv and calls RecvIP() when a packet is received
+// 2. a go routine that listens on the channel for a signal to start/stop listening
+//
+// TODO: (performance) remove isUp and use the interface's value instead
+func InterfaceListenerRoutine(i *Interface) {
+ // decompose the interface
+ socket := i.Socket
+ signal := i.SocketChannel
+
+ // booleans to control listening routine
+ isUp := true
+ closed := false
+
+ // fmt.Println("MAKING GO ROUTINE TO LISTEN:\t", socket.LocalAddr().String())
+
+ // go routine that hangs on the recv
+ go func() {
+ defer func() {
+ fmt.Println("exiting go routine that listens on ", socket.LocalAddr().String())
+ }()
+
+ for {
+ if closed { // stop this go routine if channel is closed
+ return
+ }
+ err := RecvIP(i, &isUp)
+ if err != nil {
+ continue
+ }
+ }
+ }()
+
+ for {
+ select {
+ // if the channel is closed, exit
+ case sig, ok := <-signal:
+ if !ok {
+ fmt.Println("channel closed, exiting")
+ closed = true
+ return
+ }
+ // fmt.Println("received isUP SIGNAL with value", sig)
+ isUp = sig
+ // if the channel is not closed, continue
+ default:
+ continue
+ }
+ }
+}
+
+// ************************************** DOWN/UP FUNCTIONS ******************************************************
+
+// InterfaceUp brings up the link layer
+//
+// It does the following:
+// 1. tells the listener (through a channel) to start listening
+// 2. updates the interface state to up
+// 3. sends RIP request to all neighbors of this iface to quickly update the routing table
+func InterfaceUp(iface *Interface) {
+ // set the state to up and send the signal
+ iface.State = true
+ iface.SocketChannel <- true
+
+ // if were a router, send triggered updates on up
+ if _, ok := protocolHandlers[RIP_PROTOCOL]; ok {
+ ripEntries := make([]RIPEntry, 0)
+ ripEntries = append(ripEntries, RIPEntry{iface.IpPrefix.Masked(), LOCAL_COST})
+ SendTriggeredUpdates(ripEntries)
+
+ // send a request to all neighbors of this iface to get info ASAP
+ for _, neighbor := range myNeighbors[iface.Name] {
+ message := MakeRipMessage(1, nil)
+ addr := iface.IpPrefix.Addr()
+ _, err := SendIP(&addr, neighbor, RIP_PROTOCOL, message, neighbor.VipAddr.String(), nil)
+ if err != nil {
+ fmt.Println("Error sending RIP request to neighbor on interfaceup", err)
+ }
+ }
+ }
+
+}
+
+func InterfaceUpREPL(ifaceName string) {
+ iface, err := GetInterfaceByName(ifaceName)
+ if err != nil {
+ fmt.Println("Error getting interface by name", err)
+ return
+ }
+ // set the state to up and send the signal
+ InterfaceUp(iface)
+}
+
+// InterfaceDown cuts off the link layer.
+//
+// It does the following:
+// 1. tells the listener (through a channel) to stop listening
+// 2. updates the interface state to down
+// 3. updates the routing table by removing the routes those neighbors connected to, sending triggered updates.
+func InterfaceDown(iface *Interface) {
+ // set the state to down and send the signal
+ iface.SocketChannel <- false
+ iface.State = false
+
+ // if were a router, send triggered updates on down
+ if _, ok := protocolHandlers[RIP_PROTOCOL]; ok {
+ ripEntries := make([]RIPEntry, 0)
+ ripEntries = append(ripEntries, RIPEntry{iface.IpPrefix.Masked(), INFINITY})
+ SendTriggeredUpdates(ripEntries)
+ }
+}
+
+func InterfaceDownREPL(ifaceName string) {
+ iface, err := GetInterfaceByName(ifaceName)
+ if err != nil {
+ fmt.Println("Error getting interface by name", err)
+ return
+ }
+ // set the state to down and send the signal
+ InterfaceDown(iface)
+}
+
+// ************************************** SEND/RECV FUNCTIONS *******************************************************
+
+// SendIP sends an IP packet to a destination
+//
+// If the header is nil, then a new header is created
+// If the header is not nil, then it will use that header after decrementing TTL & recomputing checksum
+//
+// TODO: (performance) have this take in an interface instead of src for performance
+func SendIP(src *netip.Addr, dest *Neighbor, protocolNum int, message []byte, destIP string, hdr *ipv4header.IPv4Header) (int, error) {
+ // check if the interface is up
+ iface, err := GetInterfaceByName(dest.Name)
+ if !iface.State {
+ return 0, errors.Errorf("error SEND: %s is down", iface.Name)
+ }
+ // if the header is nil, create a new one
+ if hdr == nil {
+ hdr = &ipv4header.IPv4Header{
+ Version: 4,
+ Len: 20, // Header length is always 20 when no IP options
+ TOS: 0,
+ TotalLen: ipv4header.HeaderLen + len(message),
+ ID: 0,
+ Flags: 0,
+ FragOff: 0,
+ TTL: 32,
+ Protocol: protocolNum,
+ Checksum: 0, // Should be 0 until checksum is computed
+ Src: *src,
+ Dst: netip.MustParseAddr(destIP),
+ Options: []byte{},
+ }
+ } else {
+ // if the header is not nil, decrement the TTL
+ hdr = &ipv4header.IPv4Header{
+ Version: 4,
+ Len: 20, // Header length is always 20 when no IP options
+ TOS: 0,
+ TotalLen: ipv4header.HeaderLen + len(message),
+ ID: 0,
+ Flags: 0,
+ FragOff: 0,
+ TTL: hdr.TTL - 1,
+ Protocol: protocolNum,
+ Checksum: 0, // Should be 0 until checksum is computed
+ Src: *src,
+ Dst: netip.MustParseAddr(destIP),
+ Options: []byte{},
+ }
+ }
+
+ // Assemble the header into a byte array
+ headerBytes, err := hdr.Marshal()
+ if err != nil {
+ return 0, err
+ }
+
+ // Compute the checksum (see below)
+ // Cast back to an int, which is what the Header structure expects
+ hdr.Checksum = int(ComputeChecksum(headerBytes))
+
+ headerBytes, err = hdr.Marshal()
+ if err != nil {
+ log.Fatalln("Error marshalling header: ", err)
+ }
+
+ // Combine the header and the message into a single byte array
+ bytesToSend := make([]byte, 0, len(headerBytes)+len(message))
+ bytesToSend = append(bytesToSend, headerBytes...)
+ bytesToSend = append(bytesToSend, []byte(message)...)
+
+ sendAddr, err := net.ResolveUDPAddr("udp4", dest.UdpAddr.String())
+ if err != nil {
+ return -1, errors.WithMessage(err, "Could not bind to UDP port->\t"+dest.UdpAddr.String())
+ }
+
+ // send the packet
+ bytesWritten, err := iface.Socket.WriteToUDP(bytesToSend, sendAddr)
+ if err != nil {
+ fmt.Println("Error writing to UDP socket")
+ return 0, errors.WithMessage(err, "Error writing to UDP socket")
+ }
+
+ return bytesWritten, nil
+}
+
+// RecvIP receives an IP packet from the interface
+// To be called by the listener routine, representing one interface
+// Upon receiving a packet, this function:
+// 1. determines if packet is valid (checksum, TTL)
+// 2. determines if the packet is for me. if so, SENDUP (call correct handler)
+// 3. the packet is not SENTUP, then checks the routing table
+// 4. if there is no route in the routing table, then prints an error and DROPS the packet
+func RecvIP(iface *Interface, isOpen *bool) error {
+ buffer := make([]byte, MAX_IP_PACKET_SIZE)
+
+ // Read on the UDP port
+ // fmt.Println("wating to read from UDP socket")
+ _, _, err := iface.Socket.ReadFromUDP(buffer)
+ if err != nil {
+ return err
+ }
+
+ // check if the interface is up
+ if !*isOpen {
+ return errors.Errorf("error RECV: %s is down", iface.Name)
+ }
+
+ // Marshal the received byte array into a UDP header
+ hdr, err := ipv4header.ParseHeader(buffer)
+ if err != nil {
+ fmt.Println("Error parsing header", err)
+ return err
+ }
+
+ // checksum validation
+ headerSize := hdr.Len
+ headerBytes := buffer[:headerSize]
+ checksumFromHeader := uint16(hdr.Checksum)
+ computedChecksum := ValidateChecksum(headerBytes, checksumFromHeader)
+
+ var checksumState string
+ if computedChecksum == checksumFromHeader {
+ checksumState = "OK"
+ } else {
+ checksumState = "FAIL"
+ }
+
+ // Next, get the message, which starts after the header
+ messageLen := hdr.TotalLen - hdr.Len
+ message := buffer[headerSize : messageLen+headerSize]
+
+ // 1) check if the TTL & checksum is valid
+ TTL := hdr.TTL
+ if TTL == 0 {
+ // drop the packet
+ return nil
+ }
+
+ // check if the checksum is valid
+ if checksumState == "FAIL" {
+ // drop the packet
+ // fmt.Println("checksum failed, dropping packet")
+ return nil
+ }
+
+ //if hdr.Protocol != RIP_PROTOCOL {
+ // fmt.Println("I see a non-rip packet")
+ //}
+
+ // at this point, the packet is valid. next steps consider the forwarding of the packet
+
+ // 2) check if the message is for me, if so, sendUP (aka call the correct handler)
+ for _, myIface := range myInterfaces {
+ if hdr.Dst == myIface.IpPrefix.Addr() {
+ // see if there is a handler for this protocol
+ if handler, ok := protocolHandlers[hdr.Protocol]; ok {
+ if hdr.Protocol != RIP_PROTOCOL {
+ // fmt.Println("this test packet is exactly for me")
+ }
+ err := handler(myIface, message, hdr)
+ if err != nil {
+ fmt.Println(err)
+ }
+ }
+ return nil
+ }
+ }
+
+ // 3) check forwarding table.
+ // - if it's a local hop, send to that iface
+ // - if it's a RIP hop, send to the neighbor with that VIP
+ // fmt.Println("checking routing table")
+ hop, err := Route(hdr.Dst)
+ if err == nil { // on no err, found a match
+ // fmt.Println("found route", hop.VIP)
+ if hop.Type == "S" {
+ // default, static route
+ // drop in this case
+ return nil
+ }
+
+ // - local hop
+ if hop.Type == "L" {
+ // if it's a local route, then the name is the interface name
+ for _, neighbor := range myNeighbors[hop.Interface.Name] {
+ if neighbor.VipAddr == hdr.Dst {
+ _, err2 := SendIP(&hdr.Src, neighbor, hdr.Protocol, message, hdr.Dst.String(), hdr)
+ if err2 != nil {
+ return err2
+ }
+ }
+ }
+ }
+
+ // - rip hop
+ if hop.Type == "R" {
+ // if it's a rip route, then the check is against the hop vip
+ for _, neighbor := range myNeighbors[hop.Interface.Name] {
+ if neighbor.VipAddr == hop.VIP {
+ _, err2 := SendIP(&hdr.Src, neighbor, hdr.Protocol, message, hdr.Dst.String(), hdr)
+ if err2 != nil {
+ return err2
+ }
+ }
+ }
+ }
+ }
+
+ // if not in table, drop packet
+ return nil
+}
+
+// ************************************** CHECKSUM FUNCTIONS ******************************************************
+// reference: https://github.com/brown-csci1680/lecture-examples/blob/main/ip-demo/cmd/udp-ip-recv/main.go
+func ComputeChecksum(b []byte) uint16 {
+ checksum := header.Checksum(b, 0)
+ checksumInv := checksum ^ 0xffff
+
+ return checksumInv
+}
+
+func ValidateChecksum(b []byte, fromHeader uint16) uint16 {
+ checksum := header.Checksum(b, fromHeader)
+
+ return checksum
+}
+
+// ************************************** RIP FUNCTIONS *******************************************************
+
+// PeriodicUpdateRoutine sends RIP updates to neighbors every 5 seconds
+// TODO: (performace) consider making this multithreaded and loops above more efficient
+func PeriodicUpdateRoutine() {
+ for {
+ // for each periodic update, we want to send our nodes in the table
+ for _, iface := range myInterfaces {
+ for _, n := range myNeighbors[iface.Name] {
+ _, in := myRIPNeighbors[n.VipAddr.String()]
+ // if the neighbor is not a RIP neighbor, skip it
+ if !in {
+ continue
+ }
+
+ // Sending to a rip neighbor
+ // create the entries
+ entries := make([]RIPEntry, 0)
+ for prefix, hop := range routingTable {
+ // implement split horizon + poison reverse at entry level
+ var cost uint32
+ if hop.VIP == n.VipAddr {
+ cost = INFINITY
+ } else {
+ cost = hop.Cost
+ }
+ entries = append(entries,
+ RIPEntry{
+ prefix: prefix,
+ cost: cost,
+ })
+ }
+
+ // make the message and send it
+ message := MakeRipMessage(2, entries)
+ addr := iface.IpPrefix.Addr()
+ _, err := SendIP(&addr, n, RIP_PROTOCOL, message, n.VipAddr.String(), nil)
+ if err != nil {
+ // fmt.Printf("Error sending RIP message to %s\n", n.VipAddr.String())
+ continue
+ }
+ }
+ }
+
+ // wait 5 sec and repeat
+ time.Sleep(5 * time.Second)
+ }
+}
+
+// SendTriggeredUpdates sends the entries consumed to ALL neighbors
+func SendTriggeredUpdates(newEntries []RIPEntry) {
+ for _, iface := range myInterfaces {
+ for _, n := range myNeighbors[iface.Name] {
+ // only send to RIP neighbors, else skip
+ _, in := myRIPNeighbors[n.VipAddr.String()]
+ if !in {
+ continue
+ }
+
+ // send the made entries to the neighbor
+ message := MakeRipMessage(2, newEntries)
+ addr := iface.IpPrefix.Addr()
+ _, err := SendIP(&addr, n, RIP_PROTOCOL, message, n.VipAddr.String(), nil)
+ if err != nil {
+ // fmt.Printf("Error sending RIP triggered update to %s\n", n.VipAddr.String())
+ continue
+ }
+ }
+ }
+}
+
+// ManageTimeoutsRoutine manages the timeout table by incrementing the timeouts every second.
+// If a timeout reaches MAX_TIMEOUT, then the entry is deleted from the routing table and a triggered update is sent.
+func ManageTimeoutsRoutine() {
+ for {
+ time.Sleep(time.Second)
+
+ timeoutTableMu.Lock()
+ // check if any timeouts have occurred
+ for key, _ := range timeoutTable {
+ timeoutTable[key]++
+ // if the timeout is MAX_TIMEOUT, delete the entry
+ if timeoutTable[key] == MAX_TIMEOUT {
+ delete(timeoutTable, key)
+
+ newEntries := make([]RIPEntry, 0)
+ delete(routingTable, key)
+ newEntries = append(newEntries, RIPEntry{key, INFINITY})
+
+ // send triggered update on timeout
+ if len(newEntries) > 0 {
+ SendTriggeredUpdates(newEntries)
+ }
+ }
+ }
+ timeoutTableMu.Unlock()
+ //fmt.Println("Timeout table: ", timeoutTable)
+ }
+}
+
+// StartRipRoutines handles all the routines for RIP
+// 1. sends a RIP request to every neighbor
+// 2. starts the routine that sends periodic updates every 5 seconds
+// 3. starts the routine that manages the timeout table
+func StartRipRoutines() {
+ // send a request to every neighbor
+ go func() {
+ for _, iface := range myInterfaces {
+ for _, neighbor := range myNeighbors[iface.Name] {
+ // only send to RIP neighbors, else skip
+ _, in := myRIPNeighbors[neighbor.VipAddr.String()]
+ if !in {
+ continue
+ }
+ // send a request
+ message := MakeRipMessage(1, nil)
+ addr := iface.IpPrefix.Addr()
+ _, err := SendIP(&addr, neighbor, RIP_PROTOCOL, message, neighbor.VipAddr.String(), nil)
+ if err != nil {
+ return
+ }
+ }
+ }
+ }()
+
+ // start a routine that sends updates every 5 seconds
+ go PeriodicUpdateRoutine()
+
+ // make a "timeout" table, for each response we add to the table via rip
+ go ManageTimeoutsRoutine()
+}
+
+// ************************************** PROTOCOL HANDLERS *******************************************************
+
+// RegisterProtocolHandler registers a protocol handler for a given protocol number
+// Returns true if the protocol number is valid, false otherwise
+func RegisterProtocolHandler(protocolNum int) bool {
+ switch protocolNum {
+ case RIP_PROTOCOL:
+ protocolHandlers[protocolNum] = HandleRIP
+ go StartRipRoutines()
+ return true
+ case TEST_PROTOCOL:
+ protocolHandlers[protocolNum] = HandleTestPackets
+ return true
+ case TCP_PROTOCOL:
+ protocolHandlers[protocolNum] = HandleTCP
+ return true
+ default:
+ return false
+ }
+}
+
+// HandleRIP handles incoming RIP packets in the following way:
+// 1. if the command is a request, send a RIP response only to that requestor
+// 2. if the command is a response, parse the entries, update the routing table from them,
+// and send applicable triggered updates (see implementation for how to update)
+func HandleRIP(src *Interface, message []byte, hdr *ipv4header.IPv4Header) error {
+ // parse the RIP message
+ command := int(binary.BigEndian.Uint16(message[0:2]))
+ switch command {
+ // request message
+ case 1:
+ //fmt.Println("Received RIP command for specific info")
+
+ // only send if the person asking is a RIP neighbor
+ neighbor, in := myRIPNeighbors[hdr.Src.String()]
+ if !in {
+ break
+ }
+
+ // build the entries
+ entries := make([]RIPEntry, 0)
+ for prefix, hop := range routingTable {
+ // implement split horizon + poison reverse at entry level
+ var cost uint32
+ if hop.VIP == hdr.Src {
+ cost = INFINITY
+ } else {
+ cost = hop.Cost
+ }
+ entries = append(entries,
+ RIPEntry{
+ prefix: prefix,
+ cost: cost,
+ })
+ }
+ // send the entries
+ res := MakeRipMessage(2, entries)
+ _, err := SendIP(&hdr.Dst, neighbor, RIP_PROTOCOL, res, hdr.Src.String(), nil)
+ if err != nil {
+ return err
+ }
+ break
+ // response message
+ case 2:
+ // fmt.Println("Received RIP response with", numEntries, "entries")
+ numEntries := int(binary.BigEndian.Uint16(message[2:4]))
+
+ // parse the entries
+ entries := make([]RIPEntry, 0)
+ for i := 0; i < numEntries; i++ {
+ offset := SIZE_OF_RIP_HEADER + i*SIZE_OF_RIP_ENTRY
+
+ // each field is 4 bytes
+ cost := binary.BigEndian.Uint32(message[offset : offset+4])
+ address, _ := netip.AddrFromSlice(message[offset+4 : offset+8])
+ mask := net.IPv4Mask(message[offset+8], message[offset+9], message[offset+10], message[offset+11])
+
+ // make the prefix
+ bits, _ := mask.Size()
+ prefix := netip.PrefixFrom(address, bits)
+
+ entries = append(entries, RIPEntry{prefix, cost})
+ }
+
+ // update the routing table
+ triggeredEntries := make([]RIPEntry, 0)
+ for _, entry := range entries {
+ destination := entry.prefix.Masked()
+
+ // make upperbound for cost infinity
+ var newCost uint32
+ if entry.cost == INFINITY {
+ newCost = INFINITY
+ } else {
+ newCost = entry.cost + 1
+ }
+
+ hop, isin := routingTable[destination]
+ // if prefix not in table, add it (as long as it's not infinity)
+ if !isin {
+ if newCost != INFINITY {
+ // given an update to table, this is now a triggeredUpdate
+ // triggeredEntries = append(triggeredEntries, RIPEntry{destination, entry.cost + 1})
+
+ routingTable[destination] = Hop{newCost, "R", src, hdr.Src}
+ timeoutTable[destination] = 0
+ }
+ continue
+ }
+
+ // if the entry is in the table, only two cases affect the table:
+ // 1) the entry SRC is updating (or confirming) the hop to itself
+ // in this case, only update if the cost is different
+ // if it's infinity, then the route has expired.
+ // we must set the cost to INF then delete the entry after 12 seconds
+ //
+ // 2) a different entry SRC reveals a shorter path to the destination
+ // in this case, update the routing table to use this new path
+ //
+ // all other cases don't meaningfully change the route
+
+ // first, upon an update from this prefix, reset its timeout
+ if hop.Type == "R" {
+ timeoutTableMu.Lock()
+ _, in := timeoutTable[destination]
+ if in {
+ if routingTable[destination].VIP == hdr.Src {
+ timeoutTable[destination] = 0
+ }
+ }
+ timeoutTableMu.Unlock()
+ }
+
+ // case 1) the entry SRC == the hop to itself
+ if hop.VIP == hdr.Src &&
+ newCost != hop.Cost {
+ // given an update to table, this is now a triggeredUpdate
+ triggeredEntries = append(triggeredEntries, RIPEntry{destination, newCost})
+ routingTable[destination] = Hop{newCost, "R", src, hop.VIP}
+
+ // if we receive infinity from the same neighbor, then delete the route after 12 sec
+ if entry.cost == INFINITY {
+ // remove after GC time if the COST is still INFINITY
+ go func() {
+ time.Sleep(time.Second * time.Duration(MAX_TIMEOUT))
+ if routingTable[destination].Cost == INFINITY {
+ delete(routingTable, destination)
+ timeoutTableMu.Lock()
+ delete(timeoutTable, destination)
+ timeoutTableMu.Unlock()
+ }
+ }()
+ }
+ continue
+ }
+
+ // case 2) a shorter route for this destination is revealed from a different neighbor
+ if newCost < hop.Cost && newCost != INFINITY {
+ triggeredEntries = append(triggeredEntries, RIPEntry{destination, entry.cost + 1})
+ routingTable[destination] = Hop{entry.cost + 1, "R", src, hdr.Src}
+ continue
+ }
+ }
+
+ // send out triggered updates
+ if len(triggeredEntries) > 0 {
+ SendTriggeredUpdates(triggeredEntries)
+ }
+ }
+
+ return nil
+}
+
+// prints the test packet as per the spec
+func HandleTestPackets(src *Interface, message []byte, hdr *ipv4header.IPv4Header) error {
+ fmt.Printf("Received test packet: Src: %s, Dst: %s, TTL: %d, Data: %s\n",
+ hdr.Src.String(), hdr.Dst.String(), hdr.TTL, string(message))
+ return nil
+}
+
+func HandleTCP(src *Interface, message []byte, hdr *ipv4header.IPv4Header) error {
+ fmt.Println("I see a TCP packet")
+
+ tcpHeaderAndData := message
+ tcpHdr := iptcp_utils.ParseTCPHeader(tcpHeaderAndData)
+ tcpPayload := tcpHeaderAndData[tcpHdr.DataOffset:]
+ tcpChecksumFromHeader := tcpHdr.Checksum
+ tcpHdr.Checksum = 0
+ tcpComputedChecksum := iptcp_utils.ComputeTCPChecksum(&tcpHdr, hdr.Src, hdr.Dst, tcpPayload)
+
+ var tcpChecksumState string
+ if tcpComputedChecksum == tcpChecksumFromHeader {
+ tcpChecksumState = "OK"
+ } else {
+ tcpChecksumState = "FAIL"
+ }
+
+ if tcpChecksumState == "FAIL" {
+ // drop the packet
+ fmt.Println("checksum failed, dropping packet")
+ return nil
+ }
+
+ switch tcpHdr.Flags {
+ case header.TCPFlagSyn:
+ fmt.Println("I see a SYN flag")
+ // if the SYN flag is set, then send a SYNACK
+ available := false
+
+ socketEntry, in := VHostSocketMaps[SocketKey{hdr.Dst.String(), tcpHdr.DstPort, hdr.Src.String(), tcpHdr.SrcPort}]
+ if !in {
+ fmt.Println("no socket entry found")
+ } else if socketEntry.State == Established {
+ fmt.Println("socket entry found")
+
+ // make ack header
+ tcpHdr := &header.TCPFields{
+ SrcPort: tcpHdr.DstPort,
+ DstPort: tcpHdr.SrcPort,
+ SeqNum: tcpHdr.SeqNum,
+ AckNum: tcpHdr.SeqNum + 1,
+ DataOffset: 20,
+ Flags: 0x10,
+ WindowSize: MAX_WINDOW_SIZE,
+ Checksum: 0,
+ UrgentPointer: 0,
+ }
+ // make the payload
+ err := SendTCP(tcpHdr, message, hdr.Dst, hdr.Src)
+ if err != nil {
+ fmt.Println(err)
+ }
+ socketEntry.Conn.RecvBuffer.buffer = append(socketEntry.Conn.RecvBuffer.buffer, tcpPayload...)
+ socketEntry.Conn.RecvBuffer.recvNext += uint32(len(tcpPayload))
+ break
+ }
+ // add to table if available
+ mapMutex.Lock()
+ for _, socketEntry := range VHostSocketMaps {
+ // todo: check between all 4 field in tuple
+ if socketEntry.LocalPort == tcpHdr.DstPort && socketEntry.LocalIP == hdr.Dst.String() && socketEntry.State == Listening{
+ // add a new socketEntry to the map
+ newEntry := &SocketEntry{
+ LocalPort: tcpHdr.DstPort,
+ RemotePort: tcpHdr.SrcPort,
+ LocalIP: hdr.Dst.String(),
+ RemoteIP: hdr.Src.String(),
+ State: SYNRECIEVED,
+ Socket: socketsMade,
+ }
+ // add the entry to the map
+ key := SocketKey{hdr.Dst.String(), tcpHdr.DstPort, hdr.Src.String(), tcpHdr.SrcPort}
+ VHostSocketMaps[key] = newEntry
+ socketsMade += 1
+ // add the entry to the map
+ available = true
+ break
+ }
+ }
+ mapMutex.Unlock()
+
+ // if no socket is available, then drop the packet
+ if !available {
+ fmt.Println("no socket available")
+ return nil
+ }
+ // make the header
+ tcpHdr := &header.TCPFields{
+ SrcPort: tcpHdr.DstPort,
+ DstPort: tcpHdr.SrcPort,
+ SeqNum: tcpHdr.SeqNum,
+ AckNum: tcpHdr.SeqNum + 1,
+ DataOffset: 20,
+ Flags: 0x12,
+ WindowSize: MAX_WINDOW_SIZE,
+ Checksum: 0,
+ UrgentPointer: 0,
+ }
+ // make the payload
+ synAckPayload := []byte{}
+ err := SendTCP(tcpHdr, synAckPayload, hdr.Dst, hdr.Src)
+ if err != nil {
+ fmt.Println(err)
+ }
+ break
+ case header.TCPFlagAck | header.TCPFlagSyn:
+ fmt.Println("I see a SYNACK flag")
+ // lookup for socket entry and update its state
+ mapMutex.Lock()
+ for _, socketEntry := range VHostSocketMaps {
+ if socketEntry.LocalPort == tcpHdr.DstPort && socketEntry.LocalIP == hdr.Dst.String() && socketEntry.State == SYNSENT {
+ socketEntry.State = Established
+ break
+ }
+ }
+ mapMutex.Unlock()
+
+ // send an ACK
+ // make the header
+ tcpHdr := &header.TCPFields{
+ SrcPort: tcpHdr.DstPort,
+ DstPort: tcpHdr.SrcPort,
+ SeqNum: tcpHdr.SeqNum + 1,
+ AckNum: tcpHdr.SeqNum,
+ DataOffset: 20,
+ Flags: 0x10,
+ WindowSize: MAX_WINDOW_SIZE,
+ Checksum: 0,
+ UrgentPointer: 0,
+ }
+ // make the payload
+ ackPayload := []byte{}
+ err := SendTCP(tcpHdr, ackPayload, hdr.Dst, hdr.Src)
+ if err != nil {
+ fmt.Println(err)
+ }
+ break
+ case header.TCPFlagAck:
+ fmt.Println("I see an ACK flag")
+ // lookup for socket entry and update its state
+ // set synChan to true (TODO)
+ key := SocketKey{hdr.Dst.String(), tcpHdr.DstPort, hdr.Src.String(), tcpHdr.SrcPort}
+ socketEntry, in := VHostSocketMaps[key]
+ if !in {
+ fmt.Println("no socket entry found")
+ } else if (socketEntry.State == Established) {
+ fmt.Println("socket entry found")
+ // socketEntry.Conn.RecvBuffer.buffer = append(socketEntry.Conn.RecvBuffer.buffer, tcpPayload...)
+ socketEntry.Conn.SendBuffer.una += uint32(len(tcpPayload))
+ break
+ }
+
+ mapMutex.Lock()
+ for _, socketEntry := range VHostSocketMaps {
+ if socketEntry.LocalPort == tcpHdr.DstPort && socketEntry.LocalIP == hdr.Dst.String() && socketEntry.State == SYNRECIEVED {
+ socketEntry.State = Established
+ break
+ }
+ }
+ mapMutex.Unlock()
+ break
+ default:
+ fmt.Println("I see a non TCP packet")
+ break
+ }
+
+
+ return nil
+}
+
+// *********************************************** HELPERS **********************************************************
+
+// Route returns the next HOP, based on longest prefix match for a given ip
+// TODO: revisit how to do this at the bit level, not hardcoded for 32 & 24
+func Route(src netip.Addr) (Hop, error) {
+ possibleBits := [2]int{32, 24}
+ for _, bits := range possibleBits {
+ cmpPrefix := netip.PrefixFrom(src, bits)
+ for prefix, hop := range routingTable {
+ if cmpPrefix.Overlaps(prefix) {
+ return hop, nil
+ }
+ }
+ }
+ return Hop{}, errors.Errorf("error ROUTE: destination %s does not exist on routing table.", src)
+}
+
+// MakeRipMessage returns the byte array to be used in SendIp for a RIP packet
+func MakeRipMessage(command uint16, entries []RIPEntry) []byte {
+ if command == 1 { // request message
+ buf := make([]byte, SIZE_OF_RIP_HEADER)
+ binary.BigEndian.PutUint16(buf[0:2], command)
+ binary.BigEndian.PutUint16(buf[2:4], uint16(0))
+ return buf
+ }
+
+ // command == 2, response message
+
+ // create the buffer
+ bufLen := SIZE_OF_RIP_HEADER + // sizeof uint16 is 2, we have two of them
+ len(entries)*SIZE_OF_RIP_ENTRY // each entry is 12
+
+ buf := make([]byte, bufLen)
+
+ // fill in the header
+ binary.BigEndian.PutUint16(buf[0:2], command)
+ binary.BigEndian.PutUint16(buf[2:4], uint16(len(entries)))
+
+ // fill in the entries
+ for i, entry := range entries {
+ offset := SIZE_OF_RIP_HEADER + i*SIZE_OF_RIP_ENTRY
+ binary.BigEndian.PutUint32(buf[offset:offset+4], entry.cost) // 0-3 = 4 bytes
+ copy(buf[offset+4:offset+8], entry.prefix.Addr().AsSlice()) // 4-7 = 4 bytes
+
+ // convert the prefix to a uint32
+ ipv4Netmask := uint32(0xffffffff)
+ ipv4Netmask <<= 32 - entry.prefix.Bits()
+ binary.BigEndian.PutUint32(buf[offset+8:offset+12], ipv4Netmask)
+ }
+
+ return buf
+}
+
+// ************************************** GETTER FUNCTIONS **********************************************************
+func GetInterfaceByName(ifaceName string) (*Interface, error) {
+ // iterate through the interfaces and return the one with the same name
+ for _, iface := range myInterfaces {
+ if iface.Name == ifaceName {
+ return iface, nil
+ }
+ }
+ return nil, errors.Errorf("No interface with name %s", ifaceName)
+}
+
+func GetInterfaces() []*Interface {
+ return myInterfaces
+}
+
+func GetNeighbors() map[string][]*Neighbor {
+ return myNeighbors
+}
+
+func GetRoutes() map[netip.Prefix]Hop {
+ return routingTable
+}
+
+// ************************************** PRINT FUNCTIONS **********************************************************
+
+// SprintInterfaces returns a string representation of the interfaces data structure
+func SprintInterfaces() string {
+ tmp := ""
+ for _, iface := range myInterfaces {
+ if iface.State {
+ // if the state is up, print UP
+ tmp += fmt.Sprintf("%s\t%s\t%s\n", iface.Name, iface.IpPrefix.String(), "UP")
+ } else {
+ // if the state is down, print DOWN
+ tmp += fmt.Sprintf("%s\t%s\t%s\n", iface.Name, iface.IpPrefix.String(), "DOWN")
+ }
+ }
+ return tmp
+}
+
+// SprintNeighbors returns a string representation of the neighbors data structure
+func SprintNeighbors() string {
+ tmp := ""
+ for _, iface := range myInterfaces {
+ if !iface.State {
+ // if the interface is down, skip it
+ continue
+ }
+ for _, n := range myNeighbors[iface.Name] {
+ tmp += fmt.Sprintf("%s\t%s\t%s\n", iface.Name, n.VipAddr.String(), n.UdpAddr.String())
+ }
+ }
+ return tmp
+}
+
+// SprintRoutingTable returns a string representation of the routing table
+func SprintRoutingTable() string {
+ tmp := ""
+ for prefix, hop := range routingTable {
+ if hop.Type == "L" {
+ // if the hop is local, print LOCAL
+ tmp += fmt.Sprintf("%s\t%s\tLOCAL:%s\t%d\n", hop.Type, prefix.String(), hop.Interface.Name, hop.Cost)
+ } else if hop.Type == "S" {
+ // if the hop is static, don't print the cost
+ tmp += fmt.Sprintf("%s\t%s\t%s\t%s\n", hop.Type, prefix.String(), hop.VIP.String(), "-")
+ } else {
+ tmp += fmt.Sprintf("%s\t%s\t%s\t%d\n", hop.Type, prefix.String(), hop.VIP.String(), hop.Cost)
+ }
+ }
+ return tmp
+}
+
+// ************************************** CLEANUP FUNCTIONS **********************************************************
+
+// CleanUp cleans up the data structures and closes the UDP sockets
+func CleanUp() {
+ fmt.Print("Cleaning up...\n")
+
+ // go through the interfaces, pop thread & close the UDP FDs
+ for _, iface := range myInterfaces {
+ // close the channel
+ if iface.SocketChannel != nil {
+ close(iface.SocketChannel)
+ }
+ // close the UDP FD
+ err := iface.Socket.Close()
+ if err != nil {
+ continue
+ }
+ }
+
+ // delete all the neighbors
+ myNeighbors = make(map[string][]*Neighbor)
+ // delete all the interfaces
+ myInterfaces = nil
+ // delete the routing table
+ routingTable = make(map[netip.Prefix]Hop)
+
+ time.Sleep(5 * time.Millisecond)
+}
+
+// ************************************** TCP FUNCTIONS **********************************************************
+
+type ConnectionState string
+const (
+ Established ConnectionState = "ESTABLISHED"
+ Listening ConnectionState = "LISTENING"
+ Closed ConnectionState = "CLOSED"
+ SYNSENT ConnectionState = "SYNSENT"
+ SYNRECIEVED ConnectionState = "SYNRECIEVED"
+ MAX_WINDOW_SIZE = 65535
+)
+
+// VTCPListener represents a listener socket (similar to Go’s net.TCPListener)
+type VTCPListener struct {
+ LocalAddr string
+ LocalPort uint16
+ RemoteAddr string
+ RemotePort uint16
+ Socket int
+ State ConnectionState
+}
+
+// // VTCPConn represents a “normal” socket for a TCP connection between two endpoints (similar to Go’s net.TCPConn)
+type VTCPConn struct {
+ LocalAddr string
+ LocalPort uint16
+ RemoteAddr string
+ RemotePort uint16
+ Socket int
+ State ConnectionState
+ SendBuffer *SendBuffer
+ RecvBuffer *RecvBuffer
+}
+
+type SocketEntry struct {
+ Socket int
+ LocalIP string
+ LocalPort uint16
+ RemoteIP string
+ RemotePort uint16
+ State ConnectionState
+ Conn *VTCPConn
+}
+
+type SocketKey struct {
+ LocalIP string
+ LocalPort uint16
+ RemoteIP string
+ RemotePort uint16
+}
+
+type RecvBuffer struct {
+ recvNext uint32
+ lbr uint32
+ buffer []byte
+}
+
+type SendBuffer struct {
+ una uint32
+ nxt uint32
+ lbr uint32
+ buffer []byte
+}
+
+// create a socket map
+// var VHostSocketMaps = make(map[int]*SocketEntry)
+var VHostSocketMaps = make(map[SocketKey]*SocketEntry)
+// create a channel map
+var VHostChannelMaps = make(map[int]chan []byte)
+var mapMutex = &sync.Mutex{}
+var socketsMade = 0
+var startingSeqNum = rand.Uint32()
+
+// Listen Sockets
+func VListen(port uint16) (*VTCPListener, error) {
+ myIP := GetInterfaces()[0].IpPrefix.Addr()
+ listener := &VTCPListener{
+ Socket: socketsMade,
+ State: Listening,
+ LocalPort: port,
+ LocalAddr: myIP.String(),
+ }
+
+ // add the socket to the socket map
+ mapMutex.Lock()
+
+ key := SocketKey{myIP.String(), port, "", 0}
+ VHostSocketMaps[key] = &SocketEntry{
+ Socket: socketsMade,
+ LocalIP: myIP.String(),
+ LocalPort: port,
+ RemoteIP: "0.0.0.0",
+ RemotePort: 0,
+ State: Listening,
+ }
+ mapMutex.Unlock()
+ socketsMade += 1
+ return listener, nil
+
+}
+
+func (l *VTCPListener) VAccept() (*VTCPConn, error) {
+ // synChan = make(chan bool)
+ for {
+ // wait for a SYN request
+ mapMutex.Lock()
+ for _, socketEntry := range VHostSocketMaps {
+ if socketEntry.State == Established {
+ // create a new VTCPConn
+ conn := &VTCPConn{
+ LocalAddr: socketEntry.LocalIP,
+ LocalPort: socketEntry.LocalPort,
+ RemoteAddr: socketEntry.RemoteIP,
+ RemotePort: socketEntry.RemotePort,
+ Socket: socketEntry.Socket,
+ State: Established,
+ SendBuffer: &SendBuffer{
+ una: 0,
+ nxt: 0,
+ lbr: 0,
+ buffer: make([]byte, MAX_WINDOW_SIZE),
+ },
+ RecvBuffer: &RecvBuffer{
+ recvNext: 0,
+ lbr: 0,
+ buffer: make([]byte, MAX_WINDOW_SIZE),
+ },
+ }
+ socketEntry.Conn = conn
+ mapMutex.Unlock()
+ return conn, nil
+ }
+ }
+ mapMutex.Unlock()
+
+ }
+}
+
+func GetRandomPort() uint16 {
+ const (
+ minDynamicPort = 49152
+ maxDynamicPort = 65535
+ )
+ return uint16(rand.Intn(maxDynamicPort - minDynamicPort) + minDynamicPort)
+}
+
+func VConnect(ip string, port uint16) (*VTCPConn, error) {
+ // get my ip address
+ myIP := GetInterfaces()[0].IpPrefix.Addr()
+ // get random port
+ portRand := GetRandomPort()
+
+ tcpHdr := &header.TCPFields{
+ SrcPort: portRand,
+ DstPort: port,
+ SeqNum: startingSeqNum,
+ AckNum: 0,
+ DataOffset: 20,
+ Flags: header.TCPFlagSyn,
+ WindowSize: MAX_WINDOW_SIZE,
+ Checksum: 0,
+ UrgentPointer: 0,
+ }
+ payload := []byte{}
+ ipParsed, err := netip.ParseAddr(ip)
+ if err != nil {
+ return nil, err
+ }
+
+ err = SendTCP(tcpHdr, payload, myIP, ipParsed)
+ if err != nil {
+ return nil, err
+ }
+
+ conn := &VTCPConn{
+ LocalAddr: myIP.String(),
+ LocalPort: portRand,
+ RemoteAddr: ip,
+ RemotePort: port,
+ Socket: socketsMade,
+ State: Established,
+ SendBuffer: &SendBuffer{
+ una: 0,
+ nxt: 0,
+ lbr: 0,
+ buffer: make([]byte, MAX_WINDOW_SIZE),
+ },
+ RecvBuffer: &RecvBuffer{
+ recvNext: 0,
+ lbr: 0,
+ buffer: make([]byte, MAX_WINDOW_SIZE),
+ },
+ }
+
+ // add the socket to the socket map
+ key := SocketKey{myIP.String(), portRand, ip, port}
+ mapMutex.Lock()
+ VHostSocketMaps[key] = &SocketEntry{
+ Socket: socketsMade,
+ LocalIP: myIP.String(),
+ LocalPort: portRand,
+ RemoteIP: ip,
+ RemotePort: port,
+ State: SYNSENT,
+ Conn: conn,
+ }
+ mapMutex.Unlock()
+ socketsMade += 1
+
+ return conn, nil
+}
+
+func SendTCP(tcpHdr *header.TCPFields, payload []byte, myIP netip.Addr, ipParsed netip.Addr) error {
+ checksum := iptcp_utils.ComputeTCPChecksum(tcpHdr, myIP, ipParsed, payload)
+ tcpHdr.Checksum = checksum
+
+ tcpHeaderBytes := make(header.TCP, iptcp_utils.TcpHeaderLen)
+ tcpHeaderBytes.Encode(tcpHdr)
+
+ ipPacketPayload := make([]byte, 0, len(tcpHeaderBytes)+len(payload))
+ ipPacketPayload = append(ipPacketPayload, tcpHeaderBytes...)
+ ipPacketPayload = append(ipPacketPayload, []byte(payload)...)
+
+ // lookup neighbor
+ address := ipParsed
+ hop, err := Route(address)
+ if err != nil {
+ fmt.Println(err)
+ return err
+ }
+ myAddr := hop.Interface.IpPrefix.Addr()
+
+ for _, neighbor := range GetNeighbors()[hop.Interface.Name] {
+ if neighbor.VipAddr == address ||
+ neighbor.VipAddr == hop.VIP && hop.Type == "S" {
+ bytesWritten, err := SendIP(&myAddr, neighbor, TCP_PROTOCOL, ipPacketPayload, ipParsed.String(), nil)
+ fmt.Printf("Sent %d bytes to %s\n", bytesWritten, neighbor.VipAddr.String())
+ if err != nil {
+ fmt.Println(err)
+ }
+ }
+ }
+ return nil
+}
+
+func SprintSockets() string {
+ tmp := ""
+ for _, socket := range VHostSocketMaps {
+ // remove the spaces of the local and remote ip variables
+ socket.LocalIP = strings.ReplaceAll(socket.LocalIP, " ", "")
+ socket.RemoteIP = strings.ReplaceAll(socket.RemoteIP, " ", "")
+ if socket.RemotePort == 0 {
+ tmp += fmt.Sprintf("%d\t%s\t%d\t%s\t\t%d\t%s\n", socket.Socket, socket.LocalIP, socket.LocalPort, socket.RemoteIP, socket.RemotePort, socket.State)
+ continue
+ }
+ tmp += fmt.Sprintf("%d\t%s\t%d\t%s\t%d\t%s\n", socket.Socket, socket.LocalIP, socket.LocalPort, socket.RemoteIP, socket.RemotePort, socket.State)
+ }
+ return tmp
+}
+
+// MILESTONE 2
+func (c *VTCPConn) VClose() error {
+ // check if the socket is in the map
+ key := SocketKey{c.LocalAddr, c.LocalPort, c.RemoteAddr, c.RemotePort}
+ mapMutex.Lock()
+ socketEntry, in := VHostSocketMaps[key]
+ mapMutex.Unlock()
+ if !in {
+ return errors.Errorf("error VClose: socket %d does not exist", c.Socket)
+ }
+
+ // change the state to closed
+ socketEntry.State = Closed
+ return nil
+}
+
+
+// advertise window = max window size - (next - 1 - lbr)
+
+// early arrivals queue
+var earlyArrivals = make([][]byte, 0)
+
+// retranmission queue
+var retransmissionQueue = make([][]byte, 0)
+
+func (c *VTCPConn) VWrite(payload []byte) (int, error) {
+ // check if the socket is in the map
+ key := SocketKey{c.LocalAddr, c.LocalPort, c.RemoteAddr, c.RemotePort}
+ mapMutex.Lock()
+ socketEntry, in := VHostSocketMaps[key]
+ mapMutex.Unlock()
+ if !in {
+ return 0, errors.Errorf("error VWrite: socket %d does not exist", c.Socket)
+ }
+
+ // check if the state is established
+ if socketEntry.State != Established {
+ return 0, errors.Errorf("error VWrite: socket %d is not in established state", c.Socket)
+ }
+
+ // check if the payload is empty
+ if len(payload) == 0 {
+ return 0, nil
+ }
+
+ // check if the payload is larger than the window size
+ if len(payload) > MAX_WINDOW_SIZE {
+ return 0, errors.Errorf("error VWrite: payload is larger than the window size")
+ }
+
+ // check if the payload is larger than the available window size
+ if len(payload) > int(MAX_WINDOW_SIZE - (c.SendBuffer.nxt - 1 - c.SendBuffer.lbr)) {
+ return 0, errors.Errorf("error VWrite: payload is larger than the available window size")
+ }
+
+ // make the header
+ advertisedWindow := MAX_WINDOW_SIZE - (c.SendBuffer.nxt - 1 - c.SendBuffer.lbr)
+ tcpHdr := &header.TCPFields{
+ SrcPort: c.LocalPort,
+ DstPort: c.RemotePort,
+ SeqNum: c.SendBuffer.nxt,
+ AckNum: c.SendBuffer.una,
+ DataOffset: 20,
+ Flags: header.TCPFlagSyn,
+ WindowSize: uint16(advertisedWindow),
+ Checksum: 0,
+ UrgentPointer: 0,
+ }
+
+ myIP := GetInterfaces()[0].IpPrefix.Addr()
+ ipParsed, err := netip.ParseAddr(c.RemoteAddr)
+ if err != nil {
+ return 0, err
+ }
+
+ err = SendTCP(tcpHdr, payload, myIP, ipParsed)
+ if err != nil {
+ return 0, err
+ }
+ // update the next sequence number
+ // c.SendBuffer.nxt += uint32(len(payload))
+
+
+ c.SendBuffer.lbr += uint32(len(payload))
+ return len(payload), nil
+}
+
+
+func (c *VTCPConn) VRead(numBytesToRead int) (int, string, error) {
+ // check if the socket is in the map
+ key := SocketKey{c.LocalAddr, c.LocalPort, c.RemoteAddr, c.RemotePort}
+ // mapMutex.Lock()
+ socketEntry, in := VHostSocketMaps[key]
+ // mapMutex.Unlock()
+ // check if the socket is in the map
+ if !in {
+ return 0, "", errors.Errorf("error VRead: socket %d does not exist", c.Socket)
+ }
+
+ // check if the state is established
+ if socketEntry.State != Established {
+ return 0, "", errors.Errorf("error VRead: socket %d is not in established state", c.Socket)
+ }
+ fmt.Println("I am in VRead")
+ fmt.Println("I have", c.RecvBuffer.recvNext - c.RecvBuffer.lbr, "bytes to read")
+ fmt.Println(c.RecvBuffer.recvNext, c.RecvBuffer.lbr)
+ if (c.RecvBuffer.lbr < c.RecvBuffer.recvNext && c.RecvBuffer.recvNext - c.RecvBuffer.lbr >= uint32(numBytesToRead)) {
+ fmt.Println("I have enough data to read")
+ toReturn := string(socketEntry.Conn.RecvBuffer.buffer[c.RecvBuffer.lbr:c.RecvBuffer.lbr+uint32(numBytesToRead)])
+ // update the last byte read
+ c.RecvBuffer.lbr += uint32(numBytesToRead)
+ // return the data
+ return numBytesToRead, toReturn, nil
+ }
+
+ return 0, "", nil
+} \ No newline at end of file
diff --git a/pkg/ipstack/ipstack_test.go b/pkg/ipstack/ipstack_test.go
new file mode 100644
index 0000000..e782f67
--- /dev/null
+++ b/pkg/ipstack/ipstack_test.go
@@ -0,0 +1,301 @@
+package ipstack
+
+import (
+ "fmt"
+ "net/netip"
+ "testing"
+)
+
+//func TestInitialize(t *testing.T) {
+// lnxFilePath := "../../doc-example/r2.lnx"
+// err := Initialize(lnxFilePath)
+// if err != nil {
+// t.Error(err)
+// }
+// fmt.Printf("Interfaces:\n%s\n\n", SprintInterfaces())
+// fmt.Printf("Neighbors:\n%s\n", SprintNeighbors())
+// fmt.Printf("RoutingTable:\n%s\n", SprintRoutingTable())
+//
+// fmt.Println("TestInitialize successful")
+// t.Cleanup(func() { CleanUp() })
+//}
+//
+//func TestInterfaceUpThenDown(t *testing.T) {
+// lnxFilePath := "../../doc-example/r2.lnx"
+// err := Initialize(lnxFilePath)
+// if err != nil {
+// t.Error(err)
+// }
+//
+// iface, err := GetInterfaceByName("if0")
+// if err != nil {
+// t.Error(err)
+// }
+//
+// InterfaceUp(iface)
+// if iface.State == false {
+// t.Error("iface state should be true")
+// }
+//
+// fmt.Printf("Interfaces:\n%s\n", SprintInterfaces())
+//
+// time.Sleep(5 * time.Millisecond) // allow time to print
+//
+// InterfaceDown(iface)
+// if iface.State == true {
+// t.Error("iface state should be false")
+// }
+//
+// time.Sleep(5 * time.Millisecond) // allow time to print
+//
+// fmt.Printf("Interfaces:\n%s\n", SprintInterfaces())
+//
+// fmt.Println("TestInterfaceUpThenDown successful")
+// t.Cleanup(func() { CleanUp() })
+//}
+//
+//func TestInterfaceUpThenDownTwice(t *testing.T) {
+// lnxFilePath := "../../doc-example/r2.lnx"
+// err := Initialize(lnxFilePath)
+// if err != nil {
+// t.Error(err)
+// }
+//
+// iface, err := GetInterfaceByName("if0")
+// if err != nil {
+// t.Error(err)
+// }
+//
+// InterfaceUp(iface)
+// if iface.State == false {
+// t.Error("iface state should be true")
+// }
+//
+// fmt.Printf("Interfaces:\n%s\n", SprintInterfaces())
+//
+// time.Sleep(5 * time.Millisecond) // allow time to print
+//
+// fmt.Println("putting interface down")
+// InterfaceDown(iface)
+// if iface.State == true {
+// t.Error("iface state should be false")
+// }
+//
+// time.Sleep(3 * time.Millisecond)
+//
+// fmt.Println("putting interface back up for 3 iterations")
+// InterfaceUp(iface)
+// if iface.State == false {
+// t.Error("iface state should be true")
+// }
+// time.Sleep(3 * time.Millisecond) // allow time to print
+//
+// fmt.Println("putting interface down")
+// InterfaceDown(iface)
+// if iface.State == true {
+// t.Error("iface state should be false")
+// }
+//
+// time.Sleep(5 * time.Millisecond) // allow time to print
+//
+// fmt.Printf("Interfaces:\n%s\n", SprintInterfaces())
+//
+// fmt.Println("TestInterfaceUpThenDownTwice successful")
+// t.Cleanup(func() { CleanUp() })
+//}
+//
+//func TestSendIPToNeighbor(t *testing.T) {
+// lnxFilePath := "../../doc-example/r2.lnx"
+// err := Initialize(lnxFilePath)
+// if err != nil {
+// t.Error(err)
+// }
+//
+// // get the first neighbor of this interface
+// iface, err := GetInterfaceByName("if0")
+// if err != nil {
+// t.Error(err)
+// }
+// neighbors, err := GetNeighborsToInterface("if0")
+// if err != nil {
+// t.Error(err)
+// }
+//
+// // setup a neighbor listener socket
+// testNeighbor := neighbors[0]
+// // close the socket so we can listen on it
+// err = testNeighbor.SendSocket.Close()
+// if err != nil {
+// t.Error(err)
+// }
+//
+// fmt.Printf("Interfaces:\n%s\n", SprintInterfaces())
+// fmt.Printf("Neighbors:\n%s\n", SprintNeighbors())
+//
+// listenString := testNeighbor.UdpAddr.String()
+// fmt.Println("listening on " + listenString)
+// listenAddr, err := net.ResolveUDPAddr("udp4", listenString)
+// if err != nil {
+// t.Error(err)
+// }
+// recvSocket, err := net.ListenUDP("udp4", listenAddr)
+// if err != nil {
+// t.Error(err)
+// }
+// testNeighbor.SendSocket = *recvSocket
+//
+// sent := false
+// go func() {
+// buffer := make([]byte, MAX_IP_PACKET_SIZE)
+// fmt.Println("wating to read from UDP socket")
+// _, sourceAddr, err := recvSocket.ReadFromUDP(buffer)
+// if err != nil {
+// t.Error(err)
+// }
+// fmt.Println("read from UDP socket")
+// hdr, err := ipv4header.ParseHeader(buffer)
+// if err != nil {
+// t.Error(err)
+// }
+// headerSize := hdr.Len
+// headerBytes := buffer[:headerSize]
+// checksumFromHeader := uint16(hdr.Checksum)
+// computedChecksum := ValidateChecksum(headerBytes, checksumFromHeader)
+//
+// var checksumState string
+// if computedChecksum == checksumFromHeader {
+// checksumState = "OK"
+// } else {
+// checksumState = "FAIL"
+// }
+// message := buffer[headerSize:]
+// fmt.Printf("Received IP packet from %s\nHeader: %v\nChecksum: %s\nMessage: %s\n",
+// sourceAddr.String(), hdr, checksumState, string(message))
+// if err != nil {
+// t.Error(err)
+// }
+//
+// sent = true
+// }()
+//
+// time.Sleep(10 * time.Millisecond)
+//
+// // send a message to the neighbor
+// fmt.Printf("sending message to neighbor\t%t\n", sent)
+// err = SendIP(*iface, *testNeighbor, 0, []byte("You are my firest neighbor!"))
+// if err != nil {
+// t.Error(err)
+// }
+//
+// fmt.Printf("SENT message to neighbor\t%t\n", sent)
+// // give a little time for the message to be sent
+// time.Sleep(1000 * time.Millisecond)
+// if !sent {
+// t.Error("Message not sent")
+// t.Fail()
+// }
+//
+// fmt.Println("TestSendIPToNeighbor successful")
+// t.Cleanup(func() { CleanUp() })
+//}
+//
+//func TestRecvIP(t *testing.T) {
+// lnxFilePath := "../../doc-example/r2.lnx"
+// err := Initialize(lnxFilePath)
+// if err != nil {
+// t.Error(err)
+// }
+//
+// // get the first neighbor of this interface to RecvIP from
+// iface, err := GetInterfaceByName("if0")
+// if err != nil {
+// t.Error(err)
+// }
+// InterfaceUp(iface)
+//
+// // setup a random socket to send an ip packet from
+// listenAddr, err := net.ResolveUDPAddr("udp4", "127.0.0.1:6969")
+// sendSocket, err := net.ListenUDP("udp4", listenAddr)
+//
+// // send a message to the neighbor
+// ifaceAsNeighbor := Neighbor{
+// VipAddr: iface.IpPrefix.Addr(),
+// UdpAddr: iface.UdpAddr,
+// SendSocket: iface.RecvSocket,
+// SocketChannel: iface.SocketChannel,
+// }
+// fakeIface := Interface{
+// Name: "if69",
+// IpPrefix: netip.MustParsePrefix("10.69.0.1/24"),
+// UdpAddr: netip.MustParseAddrPort("127.0.0.1:6969"),
+// RecvSocket: net.UDPConn{},
+// SocketChannel: nil,
+// State: true,
+// }
+// err = SendIP(fakeIface, ifaceAsNeighbor, 0, []byte("hello"))
+// if err != nil {
+// return
+// }
+//
+// time.Sleep(10 * time.Millisecond)
+//
+// // TODO: potenially make this a channel, so it actually checks values.
+// // For now, you must read the message from the console.
+//
+// err = sendSocket.Close()
+// if err != nil {
+// t.Error(err)
+// }
+// t.Cleanup(func() { CleanUp() })
+//}
+
+func TestIntersect(t *testing.T) {
+ net1 := netip.MustParsePrefix("10.0.0.0/24")
+ net2 := netip.MustParsePrefix("1.1.1.2/24")
+ net3 := netip.MustParsePrefix("1.0.0.1/24")
+ net4 := netip.MustParsePrefix("0.0.0.0/0") // default route
+ net5 := netip.MustParsePrefix("10.2.0.3/32")
+
+ res00 := intersect(net5, net1)
+ if res00 {
+ t.Error("net5 -> net1 should not intersect")
+ t.Fail()
+ }
+ res01 := intersect(net5, net4)
+ if !res01 {
+ t.Error("net5 -> net4 should intersect")
+ t.Fail()
+ }
+
+ res6 := intersect(net1, net3)
+ if res6 {
+ t.Error("net1 and net3 should not intersect")
+ t.Fail()
+ }
+ res2 := intersect(net2, net3)
+ if res2 {
+ t.Error("net2 and net3 should not intersect")
+ t.Fail()
+ }
+ res3 := intersect(net1, net4)
+ if !res3 {
+ t.Error("net1 and net4 should intersect")
+ t.Fail()
+ }
+ res4 := intersect(net2, net4)
+ if !res4 {
+ t.Error("net2 and net4 should intersect")
+ t.Fail()
+ }
+ res5 := intersect(net3, net4)
+ if !res5 {
+ t.Error("net3 and net4 should intersect")
+ t.Fail()
+ }
+
+ fmt.Println("TestIntersect successful")
+}
+
+func intersect(n1, n2 netip.Prefix) bool {
+ return n1.Overlaps(n2)
+}
diff --git a/pkg/iptcp_utils/iptcp_utils.go b/pkg/iptcp_utils/iptcp_utils.go
new file mode 100644
index 0000000..f8caf4d
--- /dev/null
+++ b/pkg/iptcp_utils/iptcp_utils.go
@@ -0,0 +1,144 @@
+package iptcp_utils
+
+import (
+ "encoding/binary"
+ "fmt"
+ "net/netip"
+ "strings"
+
+ "github.com/google/netstack/tcpip/header"
+)
+
+const (
+ TcpHeaderLen = header.TCPMinimumSize
+ TcpPseudoHeaderLen = 12
+ IpProtoTcp = header.TCPProtocolNumber
+ MaxVirtualPacketSize = 1400
+)
+
+
+
+// Build a TCPFields struct from the TCP byte array
+//
+// NOTE: the netstack package might have other options for parsing the header
+// that you may like better--this example is most similar to our other class
+// examples. Your mileage may vary!
+func ParseTCPHeader(b []byte) header.TCPFields {
+ td := header.TCP(b)
+ return header.TCPFields{
+ SrcPort: td.SourcePort(),
+ DstPort: td.DestinationPort(),
+ SeqNum: td.SequenceNumber(),
+ AckNum: td.AckNumber(),
+ DataOffset: td.DataOffset(),
+ Flags: td.Flags(),
+ WindowSize: td.WindowSize(),
+ Checksum: td.Checksum(),
+ }
+}
+
+// The TCP checksum is computed based on a "pesudo-header" that
+// combines the (virtual) IP source and destination address, protocol value,
+// as well as the TCP header and payload
+//
+// This is one example one is way to combine all of this information
+// and compute the checksum leveraging the netstack package.
+//
+// For more details, see the "Checksum" component of RFC9293 Section 3.1,
+// https://www.rfc-editor.org/rfc/rfc9293.txt
+func ComputeTCPChecksum(tcpHdr *header.TCPFields,
+ sourceIP netip.Addr, destIP netip.Addr, payload []byte) uint16 {
+
+ // Fill in the pseudo header
+ pseudoHeaderBytes := make([]byte, TcpPseudoHeaderLen)
+
+ // First are the source and dest IPs. This function only supports
+ // IPv4, so make sure the IPs are IPv4 addresses
+ copy(pseudoHeaderBytes[0:4], sourceIP.AsSlice())
+ copy(pseudoHeaderBytes[4:8], destIP.AsSlice())
+
+ // Next, add the protocol number and header length
+ pseudoHeaderBytes[8] = uint8(0)
+ pseudoHeaderBytes[9] = uint8(IpProtoTcp)
+
+ totalLength := TcpHeaderLen + len(payload)
+ binary.BigEndian.PutUint16(pseudoHeaderBytes[10:12], uint16(totalLength))
+
+ // Turn the TcpFields struct into a byte array
+ headerBytes := header.TCP(make([]byte, TcpHeaderLen))
+ headerBytes.Encode(tcpHdr)
+
+ // Compute the checksum for each individual part and combine To combine the
+ // checksums, we leverage the "initial value" argument of the netstack's
+ // checksum package to carry over the value from the previous part
+ pseudoHeaderChecksum := header.Checksum(pseudoHeaderBytes, 0)
+ headerChecksum := header.Checksum(headerBytes, pseudoHeaderChecksum)
+ fullChecksum := header.Checksum(payload, headerChecksum)
+
+ // Return the inverse of the computed value,
+ // which seems to be the convention of the checksum algorithm
+ // in the netstack package's implementation
+ return fullChecksum ^ 0xffff
+}
+
+// Compute the checksum using the netstack package
+func ComputeIPChecksum(b []byte) uint16 {
+ checksum := header.Checksum(b, 0)
+
+ // Invert the checksum value. Why is this necessary?
+ // This function returns the inverse of the checksum
+ // on an initial computation. While this may seem weird,
+ // it makes it easier to use this same function
+ // to validate the checksum on the receiving side.
+ // See ValidateChecksum in the receiver file for details.
+ checksumInv := checksum ^ 0xffff
+
+ return checksumInv
+}
+
+// Validate the checksum using the netstack package Here, we provide both the
+// byte array for the header AND the initial checksum value that was stored in
+// the header
+//
+// "Why don't we need to set the checksum value to 0 first?"
+//
+// Normally, the checksum is computed with the checksum field of the header set
+// to 0. This library creatively avoids this step by instead subtracting the
+// initial value from the computed checksum. If you use a different language or
+// checksum function, you may need to handle this differently.
+func ValidateIPChecksum(b []byte, fromHeader uint16) uint16 {
+ checksum := header.Checksum(b, fromHeader)
+
+ return checksum
+}
+
+// Pretty-print TCP flags value as a string
+func TCPFlagsAsString(flags uint8) string {
+ strMap := map[uint8]string{
+ header.TCPFlagAck: "ACK",
+ header.TCPFlagFin: "FIN",
+ header.TCPFlagPsh: "PSH",
+ header.TCPFlagRst: "RST",
+ header.TCPFlagSyn: "SYN",
+ header.TCPFlagUrg: "URG",
+ }
+
+ matches := make([]string, 0)
+
+ for b, str := range strMap {
+ if (b & flags) == b {
+ matches = append(matches, str)
+ }
+ }
+
+ ret := strings.Join(matches, "+")
+
+ return ret
+}
+
+// Pretty-print a TCP header (with pretty-printed flags)
+// Otherwise, using %+v in format strings is a good enough view in most cases
+func TCPFieldsToString(hdr *header.TCPFields) string {
+ return fmt.Sprintf("{SrcPort:%d DstPort:%d, SeqNum:%d AckNum:%d DataOffset:%d Flags:%s WindowSize:%d Checksum:%x UrgentPointer:%d}",
+ hdr.SrcPort, hdr.DstPort, hdr.SeqNum, hdr.AckNum, hdr.DataOffset, TCPFlagsAsString(hdr.Flags), hdr.WindowSize, hdr.Checksum, hdr.UrgentPointer)
+} \ No newline at end of file
diff --git a/pkg/lnxconfig/lnxconfig.go b/pkg/lnxconfig/lnxconfig.go
new file mode 100644
index 0000000..8e43613
--- /dev/null
+++ b/pkg/lnxconfig/lnxconfig.go
@@ -0,0 +1,355 @@
+package lnxconfig
+
+import (
+ "bufio"
+ "fmt"
+ "github.com/pkg/errors"
+ "net/netip"
+ "os"
+ "strings"
+)
+
+type RoutingMode int
+
+const (
+ RoutingTypeNone RoutingMode = 0
+ RoutingTypeStatic RoutingMode = 1
+ RoutingTypeRIP RoutingMode = 2
+)
+
+type IPConfig struct {
+ Interfaces []InterfaceConfig
+ Neighbors []NeighborConfig
+
+ OriginatingPrefixes []netip.Prefix // Unused in F23, ignore.
+
+ RoutingMode RoutingMode
+
+ // ROUTERS ONLY: Neighbors to send RIP packets
+ RipNeighbors []netip.Addr
+
+ // Manually-added routes ("route" directive, usually just for default on hosts)
+ StaticRoutes map[netip.Prefix]netip.Addr
+}
+
+type InterfaceConfig struct {
+ Name string
+ AssignedIP netip.Addr
+ AssignedPrefix netip.Prefix
+
+ UDPAddr netip.AddrPort
+}
+
+type NeighborConfig struct {
+ DestAddr netip.Addr
+ UDPAddr netip.AddrPort
+
+ InterfaceName string
+}
+
+// Static config for testing
+var LnxConfig = IPConfig{
+ Interfaces: []InterfaceConfig{
+ {
+ Name: "if0",
+ AssignedIP: netip.MustParseAddr("10.1.0.1"),
+ AssignedPrefix: netip.MustParsePrefix("10.1.0.1/24"),
+ UDPAddr: netip.MustParseAddrPort("127.0.0.1:5000"),
+ },
+ {
+ Name: "if1",
+ AssignedIP: netip.MustParseAddr("10.10.1.1"),
+ AssignedPrefix: netip.MustParsePrefix("10.10.1.1/24"),
+ UDPAddr: netip.MustParseAddrPort("127.0.0.1:5001"),
+ },
+ },
+
+ Neighbors: []NeighborConfig{
+ {
+ DestAddr: netip.MustParseAddr("10.1.0.10"),
+ UDPAddr: netip.MustParseAddrPort("127.0.0.1:6001"),
+ InterfaceName: "if0",
+ },
+ {
+ DestAddr: netip.MustParseAddr("10.10.1.2"),
+ UDPAddr: netip.MustParseAddrPort("127.0.0.1:5100"),
+ InterfaceName: "if1",
+ },
+ },
+
+ OriginatingPrefixes: []netip.Prefix{
+ netip.MustParsePrefix("10.1.0.1/24"),
+ },
+
+ RoutingMode: RoutingTypeStatic,
+
+ RipNeighbors: []netip.Addr{
+ netip.MustParseAddr("10.10.1.2"),
+ },
+}
+
+// ******************** END PUBLIC INTERFACE *********************************************
+// (You shouldn't need to worry about what's below, unless you want to modify the parser.)
+
+type ParseFunc func(int, string, *IPConfig) error
+
+var parseCommands = map[string]ParseFunc{
+ "interface": parseInterface,
+ "neighbor": parseNeighbor,
+ "routing": parseRouting,
+ "route": parseRoute,
+ "rip": parseRip,
+}
+
+func parseRip(ln int, line string, config *IPConfig) error {
+ tokens := strings.Split(line, " ")
+
+ if len(tokens) < 2 {
+ return newErrString(ln, "Usage: rip [cmd] ...")
+ }
+ cmd := tokens[1]
+ ripTokens := tokens[2:]
+
+ switch cmd {
+ case "originate":
+ if len(ripTokens) < 2 && ripTokens[0] != "prefix" {
+ return newErrString(ln, "Usage: rip originate prefix <prefix>")
+ }
+ ripPrefix, err := netip.ParsePrefix(ripTokens[1])
+ if err != nil {
+ return newErr(ln, err)
+ }
+ ripPrefix = ripPrefix.Masked()
+
+ // Check if prefix is in config
+ err = addOriginatingPrefix(config, ripPrefix)
+ if err != nil {
+ return err
+ }
+ case "advertise-to":
+ if len(ripTokens) < 1 {
+ return newErrString(ln, "Usage: rip advertise-to <neighbor IP>")
+ }
+ addr, err := netip.ParseAddr(ripTokens[0])
+ if err != nil {
+ return newErr(ln, err)
+ }
+ err = addRipNeighbor(config, addr)
+ if err != nil {
+ return err
+ }
+ default:
+ return newErrString(ln, "Unrecognized RIP command %s", cmd)
+ }
+
+ return nil
+}
+
+func addOriginatingPrefix(config *IPConfig, prefix netip.Prefix) error {
+ for _, iface := range config.Interfaces {
+ if iface.AssignedPrefix == prefix {
+ config.OriginatingPrefixes = append(config.OriginatingPrefixes, prefix)
+ return nil
+ }
+ }
+
+ return errors.Errorf("No matching prefix %s in config", prefix.String())
+}
+
+func addRipNeighbor(config *IPConfig, neighbor netip.Addr) error {
+ for _, iface := range config.Neighbors {
+ if iface.DestAddr == neighbor {
+ config.RipNeighbors = append(config.RipNeighbors, neighbor)
+ return nil
+ }
+ }
+
+ return errors.Errorf("RIP neighbor %s is not a neighbor IP", neighbor.String())
+}
+
+func parseRouting(ln int, line string, config *IPConfig) error {
+ tokens := strings.Split(line, " ")
+
+ if len(tokens) < 2 {
+ return newErrString(ln, "routing directive must have format: routing <type>")
+ }
+ rt := tokens[1]
+
+ switch rt {
+ case "static":
+ config.RoutingMode = RoutingTypeStatic
+ case "rip":
+ config.RoutingMode = RoutingTypeRIP
+ default:
+ return newErrString(ln, "Invalid routing type: %s", rt)
+ }
+
+ return nil
+}
+
+func parseRoute(ln int, line string, config *IPConfig) error {
+ var sPrefix, sAddr string
+
+ format := "route <prefix> via <addr>"
+ r := strings.NewReader(line)
+ n, err := fmt.Fscanf(r, "route %s via %s", &sPrefix, &sAddr)
+
+ if err != nil {
+ return err
+ }
+
+ if n != 2 {
+ return newErrString(ln, "route directive must have format %s", format)
+ }
+
+ prefix, err := netip.ParsePrefix(sPrefix)
+ if err != nil {
+ return err
+ }
+
+ addr, err := netip.ParseAddr(sAddr)
+ if err != nil {
+ return err
+ }
+
+ config.StaticRoutes[prefix] = addr
+ return nil
+}
+
+func parseInterface(ln int, line string, config *IPConfig) error {
+ var sName, sPrefix, sBindAddr string
+
+ format := "interface <name> <prefix> <bindAddr>"
+
+ r := strings.NewReader(line)
+ n, err := fmt.Fscanf(r, "interface %s %s %s",
+ &sName, &sPrefix, &sBindAddr)
+
+ if err != nil {
+ return err
+ }
+
+ if n != 3 {
+ return newErrString(ln, "interface directive must have format: %s", format)
+ }
+
+ // Check prefix format first
+ prefix, err := netip.ParsePrefix(sPrefix)
+ if err != nil {
+ return err
+ }
+
+ addr := prefix.Addr() // Get addr part
+ prefix = prefix.Masked() // Clear add bits for prefix
+
+ addrPort, err := netip.ParseAddrPort(sBindAddr)
+ if err != nil {
+ return err
+ }
+
+ iface := InterfaceConfig{
+ Name: sName,
+ AssignedIP: addr,
+ AssignedPrefix: prefix,
+ UDPAddr: addrPort,
+ }
+
+ config.Interfaces = append(config.Interfaces, iface)
+ return nil
+}
+
+func parseNeighbor(ln int, line string, config *IPConfig) error {
+ var sDestAddr, sUDPAddr, sIfName string
+
+ format := "neighbor <vip> at <bindAddr> via <interface name>"
+
+ r := strings.NewReader(line)
+ n, err := fmt.Fscanf(r, "neighbor %s at %s via %s",
+ &sDestAddr, &sUDPAddr, &sIfName)
+
+ if err != nil {
+ return err
+ }
+
+ if n != 3 {
+ newErrString(ln, "neighbor directive must have format: %s", format)
+ }
+
+ destAddr, err := netip.ParseAddr(sDestAddr)
+ if err != nil {
+ return err
+ }
+
+ udpAddr, err := netip.ParseAddrPort(sUDPAddr)
+ if err != nil {
+ return err
+ }
+
+ neighbor := NeighborConfig{
+ DestAddr: destAddr,
+ UDPAddr: udpAddr,
+ InterfaceName: sIfName,
+ }
+
+ config.Neighbors = append(config.Neighbors, neighbor)
+
+ return nil
+}
+
+func newErrString(line int, msg string, args ...any) error {
+ _msg := fmt.Sprintf(msg, args...)
+ return errors.New(fmt.Sprintf("Parse error on line %d: %s", line, _msg))
+}
+
+func newErr(line int, err error) error {
+ return errors.New(fmt.Sprintf("Parse error on line %d: %s", line, err.Error()))
+
+}
+
+// Parse a configuration file
+func ParseConfig(configFile string) (*IPConfig, error) {
+ fd, err := os.Open(configFile)
+ if err != nil {
+ return nil, errors.New("Unable to open file")
+ }
+ defer fd.Close()
+
+ config := &IPConfig{
+ Interfaces: make([]InterfaceConfig, 0, 1),
+ Neighbors: make([]NeighborConfig, 0, 1),
+ OriginatingPrefixes: make([]netip.Prefix, 0, 1),
+
+ RipNeighbors: make([]netip.Addr, 0),
+ StaticRoutes: make(map[netip.Prefix]netip.Addr, 0),
+ }
+
+ scanner := bufio.NewScanner(fd)
+ ln := 0
+ for scanner.Scan() {
+ ln++
+
+ line := scanner.Text()
+ tokens := strings.Split(line, " ")
+
+ if len(tokens) == 0 {
+ continue
+ }
+
+ // Skip comments
+ head := tokens[0]
+ if len(head) == 0 || head == "#" || head[0] == '#' {
+ continue
+ }
+
+ pf, found := parseCommands[head]
+ if !found {
+ return nil, newErrString(ln, "Unrecognized token %s", head)
+ }
+ err = pf(ln, line, config)
+ if err != nil {
+ return nil, newErr(ln, err)
+ }
+ }
+
+ return config, nil
+}
diff --git a/pkg/routingtable/routingtable.go b/pkg/routingtable/routingtable.go
new file mode 100644
index 0000000..90b64ae
--- /dev/null
+++ b/pkg/routingtable/routingtable.go
@@ -0,0 +1,90 @@
+package routingtable
+
+import (
+ "fmt"
+ "github.com/pkg/errors"
+ "net/netip"
+)
+
+type Address struct {
+ Addr netip.Addr
+ Prefix netip.Prefix
+}
+
+type Hop struct {
+ Cost uint32
+ VipAsStr string
+}
+
+type RoutingTable map[Address]Hop
+
+const (
+ STATIC_COST uint32 = 4294967295 // 2^32 - 1
+)
+
+// TODO: consider making this take in arguments, such as a config file
+func New() *RoutingTable {
+ var table = make(RoutingTable)
+ return &table
+}
+
+//func Initialize(config lnxconfig.IPConfig) error {
+// if len(os.Args) != 2 {
+// fmt.Printf("Usage: %s <configFile>\n", os.Args[0])
+// os.Exit(1)
+// }
+// fileName := os.Args[1]
+//
+// lnxConfig, err := lnxconfig.ParseConfig(fileName)
+// if err != nil {
+// panic(err)
+// }
+//
+// // make and populate routing table
+// table = make(map[Address]Route)
+// for _, iface := range lnxConfig.Interfaces {
+// var address = Address{iface.AssignedIP, iface.AssignedPrefix}
+// var route = Route{Address{iface.AssignedIP, iface.AssignedPrefix}, 0, iface.AssignedPrefix}
+// table[address] = route
+// }
+//
+//
+//}
+
+func AddRoute(srcAddr Address, cost uint32, addrAsStr string, tableReference *RoutingTable) error {
+ if _, ok := (*tableReference)[srcAddr]; ok {
+ return errors.New("Route already exists")
+ }
+
+ (*tableReference)[srcAddr] = Hop{cost, addrAsStr}
+ return nil
+}
+
+func RemoveRoute(dest Address, table *RoutingTable) error {
+ if _, ok := (*table)[dest]; !ok {
+ return errors.New("Route doesn't exist")
+ }
+
+ delete(*table, dest)
+ return nil
+}
+
+// TODO: implement this with most specific prefix matching
+func Route(dest Address, table *RoutingTable) (Hop, error) {
+ // get the most specific route
+ for address, route := range *table {
+ if address.Prefix.Contains(dest.Addr) {
+ return route, nil
+ }
+ }
+ return Hop{}, errors.New("Route doesn't exist")
+}
+
+func SprintRoutingTable(table *RoutingTable) string {
+ message := ""
+ for address, route := range *table {
+ message += fmt.Sprintf("%s/%d\t%d\n", address.Addr, address.Prefix, route.Cost)
+ }
+
+ return message
+}
diff --git a/pkg/tcpstack/tcpstack.go b/pkg/tcpstack/tcpstack.go
new file mode 100644
index 0000000..ec451a7
--- /dev/null
+++ b/pkg/tcpstack/tcpstack.go
@@ -0,0 +1,232 @@
+package tcpstack
+
+// import (
+// // "encoding/binary"
+// "fmt"
+// // "syscall"
+// "iptcp/pkg/ipstack"
+// // ipv4header "github.com/brown-csci1680/iptcp-headers"
+// // tcpheader "github.com/brown-csci1680/iptcp-headers"
+// "github.com/google/netstack/tcpip/header"
+// "github.com/pkg/errors"
+// "iptcp/pkg/iptcp_utils"
+// // "log"
+// // "net"
+// "net/netip"
+// // "sync"
+// // "time"
+// "math/rand"
+// "sync"
+// "strings"
+// )
+
+// type ConnectionState string
+// const (
+// Established ConnectionState = "ESTABLISHED"
+// Listening ConnectionState = "LISTENING"
+// Closed ConnectionState = "CLOSED"
+// SYNSENT ConnectionState = "SYNSENT"
+// MAX_WINDOW_SIZE = 65535
+// )
+
+// // VTCPListener represents a listener socket (similar to Go’s net.TCPListener)
+// type VTCPListener struct {
+// LocalAddr string
+// LocalPort uint16
+// RemoteAddr string
+// RemotePort uint16
+// Socket int
+// State ConnectionState
+// }
+
+// // // VTCPConn represents a “normal” socket for a TCP connection between two endpoints (similar to Go’s net.TCPConn)
+// type VTCPConn struct {
+// LocalAddr string
+// LocalPort uint16
+// RemoteAddr string
+// RemotePort uint16
+// Socket int
+// State ConnectionState
+// Buffer []byte
+// }
+
+// type SocketEntry struct {
+// Socket int
+// LocalIP string
+// LocalPort uint16
+// RemoteIP string
+// RemotePort uint16
+// State ConnectionState
+// }
+
+// // create a socket map to print the local and remote ip and port as well as the state of the socket
+// var VHostSocketMaps = make(map[string]map[string]*SocketEntry)
+// var mapMutex = &sync.Mutex{}
+// var socketsMade = 0
+// var myIP = ipstack.GetInterfaces()[0].IpPrefix.Addr()
+
+// // Listen Sockets
+// func VListen(port uint16) (*VTCPListener, error) {
+// // get my ip address
+// // myIP := ipstack.GetInterfaces()[0].IpPrefix.Addr()
+// listener := &VTCPListener{
+// Socket: socketsMade,
+// State: Listening,
+// LocalPort: port,
+// LocalAddr: myIP.String(),
+// }
+
+// vhostMap, ok := VHostSocketMaps[myIP.String()]
+// if !ok {
+// vhostMap = make(map[string]*SocketEntry)
+// VHostSocketMaps[fmt.Sprintf("%s:%d", "", port)] = vhostMap
+// }
+// // add the socket to the socket map
+// mapMutex.Lock()
+// vhostMap[fmt.Sprintf("%s:%d", "", port)] = &SocketEntry{
+// Socket: socketsMade,
+// LocalIP: "0.0.0.0",
+// LocalPort: port,
+// RemoteIP: "0.0.0.0",
+// RemotePort: 0,
+// State: Listening,
+// }
+// mapMutex.Unlock()
+// socketsMade += 1
+// return listener, nil
+
+// }
+
+// func (l *VTCPListener) VAccept() (*VTCPConn, error) {
+
+// for {
+// // wait for a SYN request
+// // if there is a SYN request, create a new socket, send a SYN-ACK response, and return the socket
+// myInterface := ipstack.GetInterfaces()[0]
+// err := ipstack.RecvIP(myInterface, nil)
+// if err != nil {
+// return nil, err
+// } else {
+// break
+// }
+// }
+// socketsMade += 1
+// conn := &VTCPConn{
+// }
+
+// // add the socket to the socket map
+// // mapMutex.Lock()
+
+// return conn, nil
+// // create a new socket
+// // return the socket
+// }
+
+// func GetRandomPort() uint16 {
+// const (
+// minDynamicPort = 49152
+// maxDynamicPort = 65535
+// )
+// return uint16(rand.Intn(maxDynamicPort - minDynamicPort) + minDynamicPort)
+// }
+
+// func VConnect(ip string, port uint16) (*VTCPConn, error) {
+// // get my ip address
+// // myIP := ipstack.GetInterfaces()[0].IpPrefix.Addr()
+// // get random port
+// portRand := GetRandomPort()
+// // Create a new VTCPConn.
+// conn := &VTCPConn{
+// LocalAddr: myIP.String(),
+// LocalPort: portRand,
+// RemoteAddr: ip,
+// RemotePort: port,
+// Socket: socketsMade,
+// State: SYNSENT,
+// Buffer: []byte{},
+// }
+// // get the socket entry from the socket map in the given port
+// vhostMap, ok := VHostSocketMaps[ip]
+// if !ok {
+// return nil, errors.New("socket not found")
+// }
+
+// socketEntry, ok := vhostMap[fmt.Sprintf("%s:%d", ip, port)]
+// if !ok {
+// return nil, errors.New("socket not found")
+// }
+
+// // if the socket is closed, return an error
+// if socketEntry.State == Closed {
+// return nil, errors.New("socket is closed")
+// }
+
+// // lookup neighbor
+// neighbors := ipstack.GetNeighbors()[ipstack.GetInterfaces()[0].Name]
+// var neighbor *ipstack.Neighbor
+// for _, n := range neighbors {
+// if n.VipAddr.String() == ip {
+// neighbor = n
+// }
+// }
+
+// // Send SYN request.
+// if socketEntry.State == Listening {
+// tcpHdr := &header.TCPFields{
+// SrcPort: portRand,
+// DstPort: port,
+// SeqNum: 1,
+// AckNum: 0,
+// DataOffset: 20,
+// Flags: header.TCPFlagSyn,
+// WindowSize: MAX_WINDOW_SIZE,
+// Checksum: 0,
+// UrgentPointer: 0,
+// }
+// payload := []byte{}
+// ipParsed, err := netip.ParseAddr(ip)
+// if err != nil {
+// return nil, err
+// }
+// checksum := iptcp_utils.ComputeTCPChecksum(tcpHdr, myIP, ipParsed, []byte{})
+// tcpHdr.Checksum = checksum
+
+// tcpHeaderBytes := make(header.TCP, iptcp_utils.TcpHeaderLen)
+// tcpHeaderBytes.Encode(tcpHdr)
+
+// ipPacketPayload := make([]byte, 0, len(tcpHeaderBytes)+len(payload))
+// ipPacketPayload = append(ipPacketPayload, tcpHeaderBytes...)
+// ipPacketPayload = append(ipPacketPayload, []byte(payload)...)
+
+
+// _, err = ipstack.SendIP(&myIP, neighbor, 6, ipPacketPayload, ip, nil)
+// if err != nil {
+// return nil, err
+// }
+// }
+
+// // wait for a SYN-ACK response in the socket buffer
+// myIPinterface := ipstack.GetInterfaces()[0]
+// for {
+// err := ipstack.RecvIP(myIPinterface, nil)
+// if err != nil {
+// return nil, err
+// } else {
+// break
+// }
+// }
+
+// return conn, nil
+// }
+
+// func SprintSockets() string {
+// var socketStrings []string
+// // for _, socket := range SocketMap {
+// // socketStrings = append(socketStrings, fmt.Sprintf("%d\t%s\t%d\t%s\t%d\t%s", socket.Socket, socket.LocalIP, socket.LocalPort, socket.RemoteIP, socket.RemotePort, socket.State))
+// // }
+// // only print the sockets in the map with myIP as the key
+// for _, socket := range VHostSocketMaps[myIP.String()] {
+// socketStrings = append(socketStrings, fmt.Sprintf("%d\t%s\t%d\t%s\t%d\t%s", socket.Socket, socket.LocalIP, socket.LocalPort, socket.RemoteIP, socket.RemotePort, socket.State))
+// }
+// return strings.Join(socketStrings, "\t")
+// } \ No newline at end of file