aboutsummaryrefslogtreecommitdiff
path: root/pkg/ipstack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/ipstack')
-rw-r--r--pkg/ipstack/ipstack.go1565
-rw-r--r--pkg/ipstack/ipstack_test.go301
2 files changed, 1866 insertions, 0 deletions
diff --git a/pkg/ipstack/ipstack.go b/pkg/ipstack/ipstack.go
new file mode 100644
index 0000000..0d317c2
--- /dev/null
+++ b/pkg/ipstack/ipstack.go
@@ -0,0 +1,1565 @@
+package ipstack
+
+// code begins on line 97 after imports, constants, and structs definitions
+// This class is divided as follows:
+// 1) INIT FUNCTIONS
+// 2) DOWN/UP FUNCTIONS
+// 3) SEND/RECV FUNCTIONS
+// 4) CHECKSUM FUNCTIONS
+// 5) RIP FUNCTIONS
+// 6) PROTOCOL HANDLERS
+// 7) HELPER FUNCTIONS
+// 8) GETTER FUNCTIONS
+// 9) PRINT FUNCTIONS
+// 10) CLEANUP FUNCTION
+
+import (
+ "encoding/binary"
+ "fmt"
+ ipv4header "github.com/brown-csci1680/iptcp-headers"
+ "github.com/google/netstack/tcpip/header"
+ "github.com/pkg/errors"
+ "iptcp/pkg/lnxconfig"
+ "iptcp/pkg/iptcp_utils"
+ "log"
+ "net"
+ "net/netip"
+ "sync"
+ "time"
+ "strings"
+ "math/rand"
+ // "github.com/google/netstack/tcpip/header"
+)
+
+const (
+ MAX_IP_PACKET_SIZE = 1400
+ LOCAL_COST uint32 = 0
+ STATIC_COST uint32 = 4294967295 // 2^32 - 1
+ INFINITY = 16
+ SIZE_OF_RIP_ENTRY = 12
+ RIP_PROTOCOL = 200
+ TEST_PROTOCOL = 0
+ TCP_PROTOCOL = 6
+ SIZE_OF_RIP_HEADER = 4
+ MAX_TIMEOUT = 12
+)
+
+// STRUCTS ---------------------------------------------------------------------
+type Interface struct {
+ Name string
+ IpPrefix netip.Prefix
+ UdpAddr netip.AddrPort
+
+ Socket net.UDPConn
+ SocketChannel chan bool
+ State bool
+}
+
+type Neighbor struct {
+ Name string
+ VipAddr netip.Addr
+ UdpAddr netip.AddrPort
+}
+
+type RIPHeader struct {
+ command uint16
+ numEntries uint16
+}
+
+type RIPEntry struct {
+ prefix netip.Prefix
+ cost uint32
+}
+
+type Hop struct {
+ Cost uint32
+ Type string
+
+ Interface *Interface
+ VIP netip.Addr
+}
+
+// GLOBAL VARIABLES (data structures) ------------------------------------------
+var myInterfaces []*Interface
+var myNeighbors = make(map[string][]*Neighbor)
+
+var myRIPNeighbors = make(map[string]*Neighbor)
+
+type HandlerFunc func(src *Interface, message []byte, hdr *ipv4header.IPv4Header) error
+
+var protocolHandlers = make(map[int]HandlerFunc)
+
+var routingTable = make(map[netip.Prefix]Hop)
+
+var timeoutTableMu sync.Mutex
+var timeoutTable = make(map[netip.Prefix]int)
+
+// ************************************** INIT FUNCTIONS **********************************************************
+// reference: https://github.com/brown-csci1680/lecture-examples/blob/main/ip-demo/cmd/udp-ip-recv/main.go
+
+// createUDPListener creates a UDP listener on the given UDP address.
+// It SETS the conn parameter to the created UDP socket.
+func createUDPListener(UdpAddr netip.AddrPort, conn *net.UDPConn) error {
+ listenString := UdpAddr.String()
+ listenAddr, err := net.ResolveUDPAddr("udp4", listenString)
+ if err != nil {
+ return errors.WithMessage(err, "Error resolving address->\t"+listenString)
+ }
+ tmpConn, err := net.ListenUDP("udp4", listenAddr)
+ if err != nil {
+ return errors.WithMessage(err, "Could not bind to UDP port->\t"+listenString)
+ }
+ *conn = *tmpConn
+
+ return nil
+}
+
+// Initialize initializes the data structures and creates the UDP sockets.
+//
+// It will return an error if the lnx file is not valid or if a socket fails to be created.
+//
+// After parsing the lnx file, it does the following:
+// 1. adds each local interface to the routing table, as dictated by its subnet
+// 2. adds neighbors to interface->neighbors[] map
+// 3. adds RIP neighbors to RIP neighbor list
+// 4. adds static routes to routing table
+func Initialize(lnxFilePath string) error {
+ // Parse the file
+ lnxConfig, err := lnxconfig.ParseConfig(lnxFilePath)
+ if err != nil {
+ return errors.WithMessage(err, "Error parsing config file->\t"+lnxFilePath)
+ }
+
+ // 1) add each local "if" to the routing table, as dictated by its subnet
+ for _, iface := range lnxConfig.Interfaces {
+ prefix := netip.PrefixFrom(iface.AssignedIP, iface.AssignedPrefix.Bits())
+ i := &Interface{
+ Name: iface.Name,
+ IpPrefix: prefix,
+ UdpAddr: iface.UDPAddr,
+ Socket: net.UDPConn{},
+ SocketChannel: make(chan bool),
+ State: true,
+ }
+
+ // create the UDP listener
+ err := createUDPListener(iface.UDPAddr, &i.Socket)
+ if err != nil {
+ return errors.WithMessage(err, "Error creating UDP socket for interface->\t"+iface.Name)
+ }
+
+ // start the listener routine
+ go InterfaceListenerRoutine(i)
+
+ // add to the list of interfaces
+ myInterfaces = append(myInterfaces, i)
+
+ // add to the routing table
+ routingTable[prefix.Masked()] = Hop{LOCAL_COST, "L", i, prefix.Addr()}
+ }
+
+ // 2) add neighbors to if->neighbors map
+ for _, neighbor := range lnxConfig.Neighbors {
+ n := &Neighbor{
+ Name: neighbor.InterfaceName,
+ VipAddr: neighbor.DestAddr,
+ UdpAddr: neighbor.UDPAddr,
+ }
+
+ myNeighbors[neighbor.InterfaceName] = append(myNeighbors[neighbor.InterfaceName], n)
+ }
+
+ // 3) add RIP neighbors to RIP neighbor list
+ for _, route := range lnxConfig.RipNeighbors {
+ // add to RIP neighbors
+ for _, iface := range myInterfaces {
+ for _, neighbor := range myNeighbors[iface.Name] {
+ if neighbor.VipAddr == route {
+ myRIPNeighbors[neighbor.VipAddr.String()] = neighbor
+ break
+ }
+ }
+ }
+ }
+
+ // 4) add static routes to routing table
+ for prefix, addr := range lnxConfig.StaticRoutes {
+ // need loops to find the interface that matches the neighbor to send static to
+ // hops needs this interface
+ for _, iface := range myInterfaces {
+ for _, neighbor := range myNeighbors[iface.Name] {
+ if neighbor.VipAddr == addr {
+ routingTable[prefix] = Hop{STATIC_COST, "S", iface, addr}
+ break
+ }
+ }
+ }
+ }
+
+ return nil
+}
+
+// InterfaceListenerRoutine is a go routine for interfaces to listen on a UDP port.
+//
+// It is composed two go routines:
+// 1. a go routine that hangs on the recv and calls RecvIP() when a packet is received
+// 2. a go routine that listens on the channel for a signal to start/stop listening
+//
+// TODO: (performance) remove isUp and use the interface's value instead
+func InterfaceListenerRoutine(i *Interface) {
+ // decompose the interface
+ socket := i.Socket
+ signal := i.SocketChannel
+
+ // booleans to control listening routine
+ isUp := true
+ closed := false
+
+ // fmt.Println("MAKING GO ROUTINE TO LISTEN:\t", socket.LocalAddr().String())
+
+ // go routine that hangs on the recv
+ go func() {
+ defer func() {
+ fmt.Println("exiting go routine that listens on ", socket.LocalAddr().String())
+ }()
+
+ for {
+ if closed { // stop this go routine if channel is closed
+ return
+ }
+ err := RecvIP(i, &isUp)
+ if err != nil {
+ continue
+ }
+ }
+ }()
+
+ for {
+ select {
+ // if the channel is closed, exit
+ case sig, ok := <-signal:
+ if !ok {
+ fmt.Println("channel closed, exiting")
+ closed = true
+ return
+ }
+ // fmt.Println("received isUP SIGNAL with value", sig)
+ isUp = sig
+ // if the channel is not closed, continue
+ default:
+ continue
+ }
+ }
+}
+
+// ************************************** DOWN/UP FUNCTIONS ******************************************************
+
+// InterfaceUp brings up the link layer
+//
+// It does the following:
+// 1. tells the listener (through a channel) to start listening
+// 2. updates the interface state to up
+// 3. sends RIP request to all neighbors of this iface to quickly update the routing table
+func InterfaceUp(iface *Interface) {
+ // set the state to up and send the signal
+ iface.State = true
+ iface.SocketChannel <- true
+
+ // if were a router, send triggered updates on up
+ if _, ok := protocolHandlers[RIP_PROTOCOL]; ok {
+ ripEntries := make([]RIPEntry, 0)
+ ripEntries = append(ripEntries, RIPEntry{iface.IpPrefix.Masked(), LOCAL_COST})
+ SendTriggeredUpdates(ripEntries)
+
+ // send a request to all neighbors of this iface to get info ASAP
+ for _, neighbor := range myNeighbors[iface.Name] {
+ message := MakeRipMessage(1, nil)
+ addr := iface.IpPrefix.Addr()
+ _, err := SendIP(&addr, neighbor, RIP_PROTOCOL, message, neighbor.VipAddr.String(), nil)
+ if err != nil {
+ fmt.Println("Error sending RIP request to neighbor on interfaceup", err)
+ }
+ }
+ }
+
+}
+
+func InterfaceUpREPL(ifaceName string) {
+ iface, err := GetInterfaceByName(ifaceName)
+ if err != nil {
+ fmt.Println("Error getting interface by name", err)
+ return
+ }
+ // set the state to up and send the signal
+ InterfaceUp(iface)
+}
+
+// InterfaceDown cuts off the link layer.
+//
+// It does the following:
+// 1. tells the listener (through a channel) to stop listening
+// 2. updates the interface state to down
+// 3. updates the routing table by removing the routes those neighbors connected to, sending triggered updates.
+func InterfaceDown(iface *Interface) {
+ // set the state to down and send the signal
+ iface.SocketChannel <- false
+ iface.State = false
+
+ // if were a router, send triggered updates on down
+ if _, ok := protocolHandlers[RIP_PROTOCOL]; ok {
+ ripEntries := make([]RIPEntry, 0)
+ ripEntries = append(ripEntries, RIPEntry{iface.IpPrefix.Masked(), INFINITY})
+ SendTriggeredUpdates(ripEntries)
+ }
+}
+
+func InterfaceDownREPL(ifaceName string) {
+ iface, err := GetInterfaceByName(ifaceName)
+ if err != nil {
+ fmt.Println("Error getting interface by name", err)
+ return
+ }
+ // set the state to down and send the signal
+ InterfaceDown(iface)
+}
+
+// ************************************** SEND/RECV FUNCTIONS *******************************************************
+
+// SendIP sends an IP packet to a destination
+//
+// If the header is nil, then a new header is created
+// If the header is not nil, then it will use that header after decrementing TTL & recomputing checksum
+//
+// TODO: (performance) have this take in an interface instead of src for performance
+func SendIP(src *netip.Addr, dest *Neighbor, protocolNum int, message []byte, destIP string, hdr *ipv4header.IPv4Header) (int, error) {
+ // check if the interface is up
+ iface, err := GetInterfaceByName(dest.Name)
+ if !iface.State {
+ return 0, errors.Errorf("error SEND: %s is down", iface.Name)
+ }
+ // if the header is nil, create a new one
+ if hdr == nil {
+ hdr = &ipv4header.IPv4Header{
+ Version: 4,
+ Len: 20, // Header length is always 20 when no IP options
+ TOS: 0,
+ TotalLen: ipv4header.HeaderLen + len(message),
+ ID: 0,
+ Flags: 0,
+ FragOff: 0,
+ TTL: 32,
+ Protocol: protocolNum,
+ Checksum: 0, // Should be 0 until checksum is computed
+ Src: *src,
+ Dst: netip.MustParseAddr(destIP),
+ Options: []byte{},
+ }
+ } else {
+ // if the header is not nil, decrement the TTL
+ hdr = &ipv4header.IPv4Header{
+ Version: 4,
+ Len: 20, // Header length is always 20 when no IP options
+ TOS: 0,
+ TotalLen: ipv4header.HeaderLen + len(message),
+ ID: 0,
+ Flags: 0,
+ FragOff: 0,
+ TTL: hdr.TTL - 1,
+ Protocol: protocolNum,
+ Checksum: 0, // Should be 0 until checksum is computed
+ Src: *src,
+ Dst: netip.MustParseAddr(destIP),
+ Options: []byte{},
+ }
+ }
+
+ // Assemble the header into a byte array
+ headerBytes, err := hdr.Marshal()
+ if err != nil {
+ return 0, err
+ }
+
+ // Compute the checksum (see below)
+ // Cast back to an int, which is what the Header structure expects
+ hdr.Checksum = int(ComputeChecksum(headerBytes))
+
+ headerBytes, err = hdr.Marshal()
+ if err != nil {
+ log.Fatalln("Error marshalling header: ", err)
+ }
+
+ // Combine the header and the message into a single byte array
+ bytesToSend := make([]byte, 0, len(headerBytes)+len(message))
+ bytesToSend = append(bytesToSend, headerBytes...)
+ bytesToSend = append(bytesToSend, []byte(message)...)
+
+ sendAddr, err := net.ResolveUDPAddr("udp4", dest.UdpAddr.String())
+ if err != nil {
+ return -1, errors.WithMessage(err, "Could not bind to UDP port->\t"+dest.UdpAddr.String())
+ }
+
+ // send the packet
+ bytesWritten, err := iface.Socket.WriteToUDP(bytesToSend, sendAddr)
+ if err != nil {
+ fmt.Println("Error writing to UDP socket")
+ return 0, errors.WithMessage(err, "Error writing to UDP socket")
+ }
+
+ return bytesWritten, nil
+}
+
+// RecvIP receives an IP packet from the interface
+// To be called by the listener routine, representing one interface
+// Upon receiving a packet, this function:
+// 1. determines if packet is valid (checksum, TTL)
+// 2. determines if the packet is for me. if so, SENDUP (call correct handler)
+// 3. the packet is not SENTUP, then checks the routing table
+// 4. if there is no route in the routing table, then prints an error and DROPS the packet
+func RecvIP(iface *Interface, isOpen *bool) error {
+ buffer := make([]byte, MAX_IP_PACKET_SIZE)
+
+ // Read on the UDP port
+ // fmt.Println("wating to read from UDP socket")
+ _, _, err := iface.Socket.ReadFromUDP(buffer)
+ if err != nil {
+ return err
+ }
+
+ // check if the interface is up
+ if !*isOpen {
+ return errors.Errorf("error RECV: %s is down", iface.Name)
+ }
+
+ // Marshal the received byte array into a UDP header
+ hdr, err := ipv4header.ParseHeader(buffer)
+ if err != nil {
+ fmt.Println("Error parsing header", err)
+ return err
+ }
+
+ // checksum validation
+ headerSize := hdr.Len
+ headerBytes := buffer[:headerSize]
+ checksumFromHeader := uint16(hdr.Checksum)
+ computedChecksum := ValidateChecksum(headerBytes, checksumFromHeader)
+
+ var checksumState string
+ if computedChecksum == checksumFromHeader {
+ checksumState = "OK"
+ } else {
+ checksumState = "FAIL"
+ }
+
+ // Next, get the message, which starts after the header
+ messageLen := hdr.TotalLen - hdr.Len
+ message := buffer[headerSize : messageLen+headerSize]
+
+ // 1) check if the TTL & checksum is valid
+ TTL := hdr.TTL
+ if TTL == 0 {
+ // drop the packet
+ return nil
+ }
+
+ // check if the checksum is valid
+ if checksumState == "FAIL" {
+ // drop the packet
+ // fmt.Println("checksum failed, dropping packet")
+ return nil
+ }
+
+ //if hdr.Protocol != RIP_PROTOCOL {
+ // fmt.Println("I see a non-rip packet")
+ //}
+
+ // at this point, the packet is valid. next steps consider the forwarding of the packet
+
+ // 2) check if the message is for me, if so, sendUP (aka call the correct handler)
+ for _, myIface := range myInterfaces {
+ if hdr.Dst == myIface.IpPrefix.Addr() {
+ // see if there is a handler for this protocol
+ if handler, ok := protocolHandlers[hdr.Protocol]; ok {
+ if hdr.Protocol != RIP_PROTOCOL {
+ // fmt.Println("this test packet is exactly for me")
+ }
+ err := handler(myIface, message, hdr)
+ if err != nil {
+ fmt.Println(err)
+ }
+ }
+ return nil
+ }
+ }
+
+ // 3) check forwarding table.
+ // - if it's a local hop, send to that iface
+ // - if it's a RIP hop, send to the neighbor with that VIP
+ // fmt.Println("checking routing table")
+ hop, err := Route(hdr.Dst)
+ if err == nil { // on no err, found a match
+ // fmt.Println("found route", hop.VIP)
+ if hop.Type == "S" {
+ // default, static route
+ // drop in this case
+ return nil
+ }
+
+ // - local hop
+ if hop.Type == "L" {
+ // if it's a local route, then the name is the interface name
+ for _, neighbor := range myNeighbors[hop.Interface.Name] {
+ if neighbor.VipAddr == hdr.Dst {
+ _, err2 := SendIP(&hdr.Src, neighbor, hdr.Protocol, message, hdr.Dst.String(), hdr)
+ if err2 != nil {
+ return err2
+ }
+ }
+ }
+ }
+
+ // - rip hop
+ if hop.Type == "R" {
+ // if it's a rip route, then the check is against the hop vip
+ for _, neighbor := range myNeighbors[hop.Interface.Name] {
+ if neighbor.VipAddr == hop.VIP {
+ _, err2 := SendIP(&hdr.Src, neighbor, hdr.Protocol, message, hdr.Dst.String(), hdr)
+ if err2 != nil {
+ return err2
+ }
+ }
+ }
+ }
+ }
+
+ // if not in table, drop packet
+ return nil
+}
+
+// ************************************** CHECKSUM FUNCTIONS ******************************************************
+// reference: https://github.com/brown-csci1680/lecture-examples/blob/main/ip-demo/cmd/udp-ip-recv/main.go
+func ComputeChecksum(b []byte) uint16 {
+ checksum := header.Checksum(b, 0)
+ checksumInv := checksum ^ 0xffff
+
+ return checksumInv
+}
+
+func ValidateChecksum(b []byte, fromHeader uint16) uint16 {
+ checksum := header.Checksum(b, fromHeader)
+
+ return checksum
+}
+
+// ************************************** RIP FUNCTIONS *******************************************************
+
+// PeriodicUpdateRoutine sends RIP updates to neighbors every 5 seconds
+// TODO: (performace) consider making this multithreaded and loops above more efficient
+func PeriodicUpdateRoutine() {
+ for {
+ // for each periodic update, we want to send our nodes in the table
+ for _, iface := range myInterfaces {
+ for _, n := range myNeighbors[iface.Name] {
+ _, in := myRIPNeighbors[n.VipAddr.String()]
+ // if the neighbor is not a RIP neighbor, skip it
+ if !in {
+ continue
+ }
+
+ // Sending to a rip neighbor
+ // create the entries
+ entries := make([]RIPEntry, 0)
+ for prefix, hop := range routingTable {
+ // implement split horizon + poison reverse at entry level
+ var cost uint32
+ if hop.VIP == n.VipAddr {
+ cost = INFINITY
+ } else {
+ cost = hop.Cost
+ }
+ entries = append(entries,
+ RIPEntry{
+ prefix: prefix,
+ cost: cost,
+ })
+ }
+
+ // make the message and send it
+ message := MakeRipMessage(2, entries)
+ addr := iface.IpPrefix.Addr()
+ _, err := SendIP(&addr, n, RIP_PROTOCOL, message, n.VipAddr.String(), nil)
+ if err != nil {
+ // fmt.Printf("Error sending RIP message to %s\n", n.VipAddr.String())
+ continue
+ }
+ }
+ }
+
+ // wait 5 sec and repeat
+ time.Sleep(5 * time.Second)
+ }
+}
+
+// SendTriggeredUpdates sends the entries consumed to ALL neighbors
+func SendTriggeredUpdates(newEntries []RIPEntry) {
+ for _, iface := range myInterfaces {
+ for _, n := range myNeighbors[iface.Name] {
+ // only send to RIP neighbors, else skip
+ _, in := myRIPNeighbors[n.VipAddr.String()]
+ if !in {
+ continue
+ }
+
+ // send the made entries to the neighbor
+ message := MakeRipMessage(2, newEntries)
+ addr := iface.IpPrefix.Addr()
+ _, err := SendIP(&addr, n, RIP_PROTOCOL, message, n.VipAddr.String(), nil)
+ if err != nil {
+ // fmt.Printf("Error sending RIP triggered update to %s\n", n.VipAddr.String())
+ continue
+ }
+ }
+ }
+}
+
+// ManageTimeoutsRoutine manages the timeout table by incrementing the timeouts every second.
+// If a timeout reaches MAX_TIMEOUT, then the entry is deleted from the routing table and a triggered update is sent.
+func ManageTimeoutsRoutine() {
+ for {
+ time.Sleep(time.Second)
+
+ timeoutTableMu.Lock()
+ // check if any timeouts have occurred
+ for key, _ := range timeoutTable {
+ timeoutTable[key]++
+ // if the timeout is MAX_TIMEOUT, delete the entry
+ if timeoutTable[key] == MAX_TIMEOUT {
+ delete(timeoutTable, key)
+
+ newEntries := make([]RIPEntry, 0)
+ delete(routingTable, key)
+ newEntries = append(newEntries, RIPEntry{key, INFINITY})
+
+ // send triggered update on timeout
+ if len(newEntries) > 0 {
+ SendTriggeredUpdates(newEntries)
+ }
+ }
+ }
+ timeoutTableMu.Unlock()
+ //fmt.Println("Timeout table: ", timeoutTable)
+ }
+}
+
+// StartRipRoutines handles all the routines for RIP
+// 1. sends a RIP request to every neighbor
+// 2. starts the routine that sends periodic updates every 5 seconds
+// 3. starts the routine that manages the timeout table
+func StartRipRoutines() {
+ // send a request to every neighbor
+ go func() {
+ for _, iface := range myInterfaces {
+ for _, neighbor := range myNeighbors[iface.Name] {
+ // only send to RIP neighbors, else skip
+ _, in := myRIPNeighbors[neighbor.VipAddr.String()]
+ if !in {
+ continue
+ }
+ // send a request
+ message := MakeRipMessage(1, nil)
+ addr := iface.IpPrefix.Addr()
+ _, err := SendIP(&addr, neighbor, RIP_PROTOCOL, message, neighbor.VipAddr.String(), nil)
+ if err != nil {
+ return
+ }
+ }
+ }
+ }()
+
+ // start a routine that sends updates every 5 seconds
+ go PeriodicUpdateRoutine()
+
+ // make a "timeout" table, for each response we add to the table via rip
+ go ManageTimeoutsRoutine()
+}
+
+// ************************************** PROTOCOL HANDLERS *******************************************************
+
+// RegisterProtocolHandler registers a protocol handler for a given protocol number
+// Returns true if the protocol number is valid, false otherwise
+func RegisterProtocolHandler(protocolNum int) bool {
+ switch protocolNum {
+ case RIP_PROTOCOL:
+ protocolHandlers[protocolNum] = HandleRIP
+ go StartRipRoutines()
+ return true
+ case TEST_PROTOCOL:
+ protocolHandlers[protocolNum] = HandleTestPackets
+ return true
+ case TCP_PROTOCOL:
+ protocolHandlers[protocolNum] = HandleTCP
+ return true
+ default:
+ return false
+ }
+}
+
+// HandleRIP handles incoming RIP packets in the following way:
+// 1. if the command is a request, send a RIP response only to that requestor
+// 2. if the command is a response, parse the entries, update the routing table from them,
+// and send applicable triggered updates (see implementation for how to update)
+func HandleRIP(src *Interface, message []byte, hdr *ipv4header.IPv4Header) error {
+ // parse the RIP message
+ command := int(binary.BigEndian.Uint16(message[0:2]))
+ switch command {
+ // request message
+ case 1:
+ //fmt.Println("Received RIP command for specific info")
+
+ // only send if the person asking is a RIP neighbor
+ neighbor, in := myRIPNeighbors[hdr.Src.String()]
+ if !in {
+ break
+ }
+
+ // build the entries
+ entries := make([]RIPEntry, 0)
+ for prefix, hop := range routingTable {
+ // implement split horizon + poison reverse at entry level
+ var cost uint32
+ if hop.VIP == hdr.Src {
+ cost = INFINITY
+ } else {
+ cost = hop.Cost
+ }
+ entries = append(entries,
+ RIPEntry{
+ prefix: prefix,
+ cost: cost,
+ })
+ }
+ // send the entries
+ res := MakeRipMessage(2, entries)
+ _, err := SendIP(&hdr.Dst, neighbor, RIP_PROTOCOL, res, hdr.Src.String(), nil)
+ if err != nil {
+ return err
+ }
+ break
+ // response message
+ case 2:
+ // fmt.Println("Received RIP response with", numEntries, "entries")
+ numEntries := int(binary.BigEndian.Uint16(message[2:4]))
+
+ // parse the entries
+ entries := make([]RIPEntry, 0)
+ for i := 0; i < numEntries; i++ {
+ offset := SIZE_OF_RIP_HEADER + i*SIZE_OF_RIP_ENTRY
+
+ // each field is 4 bytes
+ cost := binary.BigEndian.Uint32(message[offset : offset+4])
+ address, _ := netip.AddrFromSlice(message[offset+4 : offset+8])
+ mask := net.IPv4Mask(message[offset+8], message[offset+9], message[offset+10], message[offset+11])
+
+ // make the prefix
+ bits, _ := mask.Size()
+ prefix := netip.PrefixFrom(address, bits)
+
+ entries = append(entries, RIPEntry{prefix, cost})
+ }
+
+ // update the routing table
+ triggeredEntries := make([]RIPEntry, 0)
+ for _, entry := range entries {
+ destination := entry.prefix.Masked()
+
+ // make upperbound for cost infinity
+ var newCost uint32
+ if entry.cost == INFINITY {
+ newCost = INFINITY
+ } else {
+ newCost = entry.cost + 1
+ }
+
+ hop, isin := routingTable[destination]
+ // if prefix not in table, add it (as long as it's not infinity)
+ if !isin {
+ if newCost != INFINITY {
+ // given an update to table, this is now a triggeredUpdate
+ // triggeredEntries = append(triggeredEntries, RIPEntry{destination, entry.cost + 1})
+
+ routingTable[destination] = Hop{newCost, "R", src, hdr.Src}
+ timeoutTable[destination] = 0
+ }
+ continue
+ }
+
+ // if the entry is in the table, only two cases affect the table:
+ // 1) the entry SRC is updating (or confirming) the hop to itself
+ // in this case, only update if the cost is different
+ // if it's infinity, then the route has expired.
+ // we must set the cost to INF then delete the entry after 12 seconds
+ //
+ // 2) a different entry SRC reveals a shorter path to the destination
+ // in this case, update the routing table to use this new path
+ //
+ // all other cases don't meaningfully change the route
+
+ // first, upon an update from this prefix, reset its timeout
+ if hop.Type == "R" {
+ timeoutTableMu.Lock()
+ _, in := timeoutTable[destination]
+ if in {
+ if routingTable[destination].VIP == hdr.Src {
+ timeoutTable[destination] = 0
+ }
+ }
+ timeoutTableMu.Unlock()
+ }
+
+ // case 1) the entry SRC == the hop to itself
+ if hop.VIP == hdr.Src &&
+ newCost != hop.Cost {
+ // given an update to table, this is now a triggeredUpdate
+ triggeredEntries = append(triggeredEntries, RIPEntry{destination, newCost})
+ routingTable[destination] = Hop{newCost, "R", src, hop.VIP}
+
+ // if we receive infinity from the same neighbor, then delete the route after 12 sec
+ if entry.cost == INFINITY {
+ // remove after GC time if the COST is still INFINITY
+ go func() {
+ time.Sleep(time.Second * time.Duration(MAX_TIMEOUT))
+ if routingTable[destination].Cost == INFINITY {
+ delete(routingTable, destination)
+ timeoutTableMu.Lock()
+ delete(timeoutTable, destination)
+ timeoutTableMu.Unlock()
+ }
+ }()
+ }
+ continue
+ }
+
+ // case 2) a shorter route for this destination is revealed from a different neighbor
+ if newCost < hop.Cost && newCost != INFINITY {
+ triggeredEntries = append(triggeredEntries, RIPEntry{destination, entry.cost + 1})
+ routingTable[destination] = Hop{entry.cost + 1, "R", src, hdr.Src}
+ continue
+ }
+ }
+
+ // send out triggered updates
+ if len(triggeredEntries) > 0 {
+ SendTriggeredUpdates(triggeredEntries)
+ }
+ }
+
+ return nil
+}
+
+// prints the test packet as per the spec
+func HandleTestPackets(src *Interface, message []byte, hdr *ipv4header.IPv4Header) error {
+ fmt.Printf("Received test packet: Src: %s, Dst: %s, TTL: %d, Data: %s\n",
+ hdr.Src.String(), hdr.Dst.String(), hdr.TTL, string(message))
+ return nil
+}
+
+func HandleTCP(src *Interface, message []byte, hdr *ipv4header.IPv4Header) error {
+ fmt.Println("I see a TCP packet")
+
+ tcpHeaderAndData := message
+ tcpHdr := iptcp_utils.ParseTCPHeader(tcpHeaderAndData)
+ tcpPayload := tcpHeaderAndData[tcpHdr.DataOffset:]
+ tcpChecksumFromHeader := tcpHdr.Checksum
+ tcpHdr.Checksum = 0
+ tcpComputedChecksum := iptcp_utils.ComputeTCPChecksum(&tcpHdr, hdr.Src, hdr.Dst, tcpPayload)
+
+ var tcpChecksumState string
+ if tcpComputedChecksum == tcpChecksumFromHeader {
+ tcpChecksumState = "OK"
+ } else {
+ tcpChecksumState = "FAIL"
+ }
+
+ if tcpChecksumState == "FAIL" {
+ // drop the packet
+ fmt.Println("checksum failed, dropping packet")
+ return nil
+ }
+
+ switch tcpHdr.Flags {
+ case header.TCPFlagSyn:
+ fmt.Println("I see a SYN flag")
+ // if the SYN flag is set, then send a SYNACK
+ available := false
+
+ socketEntry, in := VHostSocketMaps[SocketKey{hdr.Dst.String(), tcpHdr.DstPort, hdr.Src.String(), tcpHdr.SrcPort}]
+ if !in {
+ fmt.Println("no socket entry found")
+ } else if socketEntry.State == Established {
+ fmt.Println("socket entry found")
+
+ // make ack header
+ tcpHdr := &header.TCPFields{
+ SrcPort: tcpHdr.DstPort,
+ DstPort: tcpHdr.SrcPort,
+ SeqNum: tcpHdr.SeqNum,
+ AckNum: tcpHdr.SeqNum + 1,
+ DataOffset: 20,
+ Flags: 0x10,
+ WindowSize: MAX_WINDOW_SIZE,
+ Checksum: 0,
+ UrgentPointer: 0,
+ }
+ // make the payload
+ err := SendTCP(tcpHdr, message, hdr.Dst, hdr.Src)
+ if err != nil {
+ fmt.Println(err)
+ }
+ socketEntry.Conn.RecvBuffer.buffer = append(socketEntry.Conn.RecvBuffer.buffer, tcpPayload...)
+ socketEntry.Conn.RecvBuffer.recvNext += uint32(len(tcpPayload))
+ break
+ }
+ // add to table if available
+ mapMutex.Lock()
+ for _, socketEntry := range VHostSocketMaps {
+ // todo: check between all 4 field in tuple
+ if socketEntry.LocalPort == tcpHdr.DstPort && socketEntry.LocalIP == hdr.Dst.String() && socketEntry.State == Listening{
+ // add a new socketEntry to the map
+ newEntry := &SocketEntry{
+ LocalPort: tcpHdr.DstPort,
+ RemotePort: tcpHdr.SrcPort,
+ LocalIP: hdr.Dst.String(),
+ RemoteIP: hdr.Src.String(),
+ State: SYNRECIEVED,
+ Socket: socketsMade,
+ }
+ // add the entry to the map
+ key := SocketKey{hdr.Dst.String(), tcpHdr.DstPort, hdr.Src.String(), tcpHdr.SrcPort}
+ VHostSocketMaps[key] = newEntry
+ socketsMade += 1
+ // add the entry to the map
+ available = true
+ break
+ }
+ }
+ mapMutex.Unlock()
+
+ // if no socket is available, then drop the packet
+ if !available {
+ fmt.Println("no socket available")
+ return nil
+ }
+ // make the header
+ tcpHdr := &header.TCPFields{
+ SrcPort: tcpHdr.DstPort,
+ DstPort: tcpHdr.SrcPort,
+ SeqNum: tcpHdr.SeqNum,
+ AckNum: tcpHdr.SeqNum + 1,
+ DataOffset: 20,
+ Flags: 0x12,
+ WindowSize: MAX_WINDOW_SIZE,
+ Checksum: 0,
+ UrgentPointer: 0,
+ }
+ // make the payload
+ synAckPayload := []byte{}
+ err := SendTCP(tcpHdr, synAckPayload, hdr.Dst, hdr.Src)
+ if err != nil {
+ fmt.Println(err)
+ }
+ break
+ case header.TCPFlagAck | header.TCPFlagSyn:
+ fmt.Println("I see a SYNACK flag")
+ // lookup for socket entry and update its state
+ mapMutex.Lock()
+ for _, socketEntry := range VHostSocketMaps {
+ if socketEntry.LocalPort == tcpHdr.DstPort && socketEntry.LocalIP == hdr.Dst.String() && socketEntry.State == SYNSENT {
+ socketEntry.State = Established
+ break
+ }
+ }
+ mapMutex.Unlock()
+
+ // send an ACK
+ // make the header
+ tcpHdr := &header.TCPFields{
+ SrcPort: tcpHdr.DstPort,
+ DstPort: tcpHdr.SrcPort,
+ SeqNum: tcpHdr.SeqNum + 1,
+ AckNum: tcpHdr.SeqNum,
+ DataOffset: 20,
+ Flags: 0x10,
+ WindowSize: MAX_WINDOW_SIZE,
+ Checksum: 0,
+ UrgentPointer: 0,
+ }
+ // make the payload
+ ackPayload := []byte{}
+ err := SendTCP(tcpHdr, ackPayload, hdr.Dst, hdr.Src)
+ if err != nil {
+ fmt.Println(err)
+ }
+ break
+ case header.TCPFlagAck:
+ fmt.Println("I see an ACK flag")
+ // lookup for socket entry and update its state
+ // set synChan to true (TODO)
+ key := SocketKey{hdr.Dst.String(), tcpHdr.DstPort, hdr.Src.String(), tcpHdr.SrcPort}
+ socketEntry, in := VHostSocketMaps[key]
+ if !in {
+ fmt.Println("no socket entry found")
+ } else if (socketEntry.State == Established) {
+ fmt.Println("socket entry found")
+ // socketEntry.Conn.RecvBuffer.buffer = append(socketEntry.Conn.RecvBuffer.buffer, tcpPayload...)
+ socketEntry.Conn.SendBuffer.una += uint32(len(tcpPayload))
+ break
+ }
+
+ mapMutex.Lock()
+ for _, socketEntry := range VHostSocketMaps {
+ if socketEntry.LocalPort == tcpHdr.DstPort && socketEntry.LocalIP == hdr.Dst.String() && socketEntry.State == SYNRECIEVED {
+ socketEntry.State = Established
+ break
+ }
+ }
+ mapMutex.Unlock()
+ break
+ default:
+ fmt.Println("I see a non TCP packet")
+ break
+ }
+
+
+ return nil
+}
+
+// *********************************************** HELPERS **********************************************************
+
+// Route returns the next HOP, based on longest prefix match for a given ip
+// TODO: revisit how to do this at the bit level, not hardcoded for 32 & 24
+func Route(src netip.Addr) (Hop, error) {
+ possibleBits := [2]int{32, 24}
+ for _, bits := range possibleBits {
+ cmpPrefix := netip.PrefixFrom(src, bits)
+ for prefix, hop := range routingTable {
+ if cmpPrefix.Overlaps(prefix) {
+ return hop, nil
+ }
+ }
+ }
+ return Hop{}, errors.Errorf("error ROUTE: destination %s does not exist on routing table.", src)
+}
+
+// MakeRipMessage returns the byte array to be used in SendIp for a RIP packet
+func MakeRipMessage(command uint16, entries []RIPEntry) []byte {
+ if command == 1 { // request message
+ buf := make([]byte, SIZE_OF_RIP_HEADER)
+ binary.BigEndian.PutUint16(buf[0:2], command)
+ binary.BigEndian.PutUint16(buf[2:4], uint16(0))
+ return buf
+ }
+
+ // command == 2, response message
+
+ // create the buffer
+ bufLen := SIZE_OF_RIP_HEADER + // sizeof uint16 is 2, we have two of them
+ len(entries)*SIZE_OF_RIP_ENTRY // each entry is 12
+
+ buf := make([]byte, bufLen)
+
+ // fill in the header
+ binary.BigEndian.PutUint16(buf[0:2], command)
+ binary.BigEndian.PutUint16(buf[2:4], uint16(len(entries)))
+
+ // fill in the entries
+ for i, entry := range entries {
+ offset := SIZE_OF_RIP_HEADER + i*SIZE_OF_RIP_ENTRY
+ binary.BigEndian.PutUint32(buf[offset:offset+4], entry.cost) // 0-3 = 4 bytes
+ copy(buf[offset+4:offset+8], entry.prefix.Addr().AsSlice()) // 4-7 = 4 bytes
+
+ // convert the prefix to a uint32
+ ipv4Netmask := uint32(0xffffffff)
+ ipv4Netmask <<= 32 - entry.prefix.Bits()
+ binary.BigEndian.PutUint32(buf[offset+8:offset+12], ipv4Netmask)
+ }
+
+ return buf
+}
+
+// ************************************** GETTER FUNCTIONS **********************************************************
+func GetInterfaceByName(ifaceName string) (*Interface, error) {
+ // iterate through the interfaces and return the one with the same name
+ for _, iface := range myInterfaces {
+ if iface.Name == ifaceName {
+ return iface, nil
+ }
+ }
+ return nil, errors.Errorf("No interface with name %s", ifaceName)
+}
+
+func GetInterfaces() []*Interface {
+ return myInterfaces
+}
+
+func GetNeighbors() map[string][]*Neighbor {
+ return myNeighbors
+}
+
+func GetRoutes() map[netip.Prefix]Hop {
+ return routingTable
+}
+
+// ************************************** PRINT FUNCTIONS **********************************************************
+
+// SprintInterfaces returns a string representation of the interfaces data structure
+func SprintInterfaces() string {
+ tmp := ""
+ for _, iface := range myInterfaces {
+ if iface.State {
+ // if the state is up, print UP
+ tmp += fmt.Sprintf("%s\t%s\t%s\n", iface.Name, iface.IpPrefix.String(), "UP")
+ } else {
+ // if the state is down, print DOWN
+ tmp += fmt.Sprintf("%s\t%s\t%s\n", iface.Name, iface.IpPrefix.String(), "DOWN")
+ }
+ }
+ return tmp
+}
+
+// SprintNeighbors returns a string representation of the neighbors data structure
+func SprintNeighbors() string {
+ tmp := ""
+ for _, iface := range myInterfaces {
+ if !iface.State {
+ // if the interface is down, skip it
+ continue
+ }
+ for _, n := range myNeighbors[iface.Name] {
+ tmp += fmt.Sprintf("%s\t%s\t%s\n", iface.Name, n.VipAddr.String(), n.UdpAddr.String())
+ }
+ }
+ return tmp
+}
+
+// SprintRoutingTable returns a string representation of the routing table
+func SprintRoutingTable() string {
+ tmp := ""
+ for prefix, hop := range routingTable {
+ if hop.Type == "L" {
+ // if the hop is local, print LOCAL
+ tmp += fmt.Sprintf("%s\t%s\tLOCAL:%s\t%d\n", hop.Type, prefix.String(), hop.Interface.Name, hop.Cost)
+ } else if hop.Type == "S" {
+ // if the hop is static, don't print the cost
+ tmp += fmt.Sprintf("%s\t%s\t%s\t%s\n", hop.Type, prefix.String(), hop.VIP.String(), "-")
+ } else {
+ tmp += fmt.Sprintf("%s\t%s\t%s\t%d\n", hop.Type, prefix.String(), hop.VIP.String(), hop.Cost)
+ }
+ }
+ return tmp
+}
+
+// ************************************** CLEANUP FUNCTIONS **********************************************************
+
+// CleanUp cleans up the data structures and closes the UDP sockets
+func CleanUp() {
+ fmt.Print("Cleaning up...\n")
+
+ // go through the interfaces, pop thread & close the UDP FDs
+ for _, iface := range myInterfaces {
+ // close the channel
+ if iface.SocketChannel != nil {
+ close(iface.SocketChannel)
+ }
+ // close the UDP FD
+ err := iface.Socket.Close()
+ if err != nil {
+ continue
+ }
+ }
+
+ // delete all the neighbors
+ myNeighbors = make(map[string][]*Neighbor)
+ // delete all the interfaces
+ myInterfaces = nil
+ // delete the routing table
+ routingTable = make(map[netip.Prefix]Hop)
+
+ time.Sleep(5 * time.Millisecond)
+}
+
+// ************************************** TCP FUNCTIONS **********************************************************
+
+type ConnectionState string
+const (
+ Established ConnectionState = "ESTABLISHED"
+ Listening ConnectionState = "LISTENING"
+ Closed ConnectionState = "CLOSED"
+ SYNSENT ConnectionState = "SYNSENT"
+ SYNRECIEVED ConnectionState = "SYNRECIEVED"
+ MAX_WINDOW_SIZE = 65535
+)
+
+// VTCPListener represents a listener socket (similar to Go’s net.TCPListener)
+type VTCPListener struct {
+ LocalAddr string
+ LocalPort uint16
+ RemoteAddr string
+ RemotePort uint16
+ Socket int
+ State ConnectionState
+}
+
+// // VTCPConn represents a “normal” socket for a TCP connection between two endpoints (similar to Go’s net.TCPConn)
+type VTCPConn struct {
+ LocalAddr string
+ LocalPort uint16
+ RemoteAddr string
+ RemotePort uint16
+ Socket int
+ State ConnectionState
+ SendBuffer *SendBuffer
+ RecvBuffer *RecvBuffer
+}
+
+type SocketEntry struct {
+ Socket int
+ LocalIP string
+ LocalPort uint16
+ RemoteIP string
+ RemotePort uint16
+ State ConnectionState
+ Conn *VTCPConn
+}
+
+type SocketKey struct {
+ LocalIP string
+ LocalPort uint16
+ RemoteIP string
+ RemotePort uint16
+}
+
+type RecvBuffer struct {
+ recvNext uint32
+ lbr uint32
+ buffer []byte
+}
+
+type SendBuffer struct {
+ una uint32
+ nxt uint32
+ lbr uint32
+ buffer []byte
+}
+
+// create a socket map
+// var VHostSocketMaps = make(map[int]*SocketEntry)
+var VHostSocketMaps = make(map[SocketKey]*SocketEntry)
+// create a channel map
+var VHostChannelMaps = make(map[int]chan []byte)
+var mapMutex = &sync.Mutex{}
+var socketsMade = 0
+var startingSeqNum = rand.Uint32()
+
+// Listen Sockets
+func VListen(port uint16) (*VTCPListener, error) {
+ myIP := GetInterfaces()[0].IpPrefix.Addr()
+ listener := &VTCPListener{
+ Socket: socketsMade,
+ State: Listening,
+ LocalPort: port,
+ LocalAddr: myIP.String(),
+ }
+
+ // add the socket to the socket map
+ mapMutex.Lock()
+
+ key := SocketKey{myIP.String(), port, "", 0}
+ VHostSocketMaps[key] = &SocketEntry{
+ Socket: socketsMade,
+ LocalIP: myIP.String(),
+ LocalPort: port,
+ RemoteIP: "0.0.0.0",
+ RemotePort: 0,
+ State: Listening,
+ }
+ mapMutex.Unlock()
+ socketsMade += 1
+ return listener, nil
+
+}
+
+func (l *VTCPListener) VAccept() (*VTCPConn, error) {
+ // synChan = make(chan bool)
+ for {
+ // wait for a SYN request
+ mapMutex.Lock()
+ for _, socketEntry := range VHostSocketMaps {
+ if socketEntry.State == Established {
+ // create a new VTCPConn
+ conn := &VTCPConn{
+ LocalAddr: socketEntry.LocalIP,
+ LocalPort: socketEntry.LocalPort,
+ RemoteAddr: socketEntry.RemoteIP,
+ RemotePort: socketEntry.RemotePort,
+ Socket: socketEntry.Socket,
+ State: Established,
+ SendBuffer: &SendBuffer{
+ una: 0,
+ nxt: 0,
+ lbr: 0,
+ buffer: make([]byte, MAX_WINDOW_SIZE),
+ },
+ RecvBuffer: &RecvBuffer{
+ recvNext: 0,
+ lbr: 0,
+ buffer: make([]byte, MAX_WINDOW_SIZE),
+ },
+ }
+ socketEntry.Conn = conn
+ mapMutex.Unlock()
+ return conn, nil
+ }
+ }
+ mapMutex.Unlock()
+
+ }
+}
+
+func GetRandomPort() uint16 {
+ const (
+ minDynamicPort = 49152
+ maxDynamicPort = 65535
+ )
+ return uint16(rand.Intn(maxDynamicPort - minDynamicPort) + minDynamicPort)
+}
+
+func VConnect(ip string, port uint16) (*VTCPConn, error) {
+ // get my ip address
+ myIP := GetInterfaces()[0].IpPrefix.Addr()
+ // get random port
+ portRand := GetRandomPort()
+
+ tcpHdr := &header.TCPFields{
+ SrcPort: portRand,
+ DstPort: port,
+ SeqNum: startingSeqNum,
+ AckNum: 0,
+ DataOffset: 20,
+ Flags: header.TCPFlagSyn,
+ WindowSize: MAX_WINDOW_SIZE,
+ Checksum: 0,
+ UrgentPointer: 0,
+ }
+ payload := []byte{}
+ ipParsed, err := netip.ParseAddr(ip)
+ if err != nil {
+ return nil, err
+ }
+
+ err = SendTCP(tcpHdr, payload, myIP, ipParsed)
+ if err != nil {
+ return nil, err
+ }
+
+ conn := &VTCPConn{
+ LocalAddr: myIP.String(),
+ LocalPort: portRand,
+ RemoteAddr: ip,
+ RemotePort: port,
+ Socket: socketsMade,
+ State: Established,
+ SendBuffer: &SendBuffer{
+ una: 0,
+ nxt: 0,
+ lbr: 0,
+ buffer: make([]byte, MAX_WINDOW_SIZE),
+ },
+ RecvBuffer: &RecvBuffer{
+ recvNext: 0,
+ lbr: 0,
+ buffer: make([]byte, MAX_WINDOW_SIZE),
+ },
+ }
+
+ // add the socket to the socket map
+ key := SocketKey{myIP.String(), portRand, ip, port}
+ mapMutex.Lock()
+ VHostSocketMaps[key] = &SocketEntry{
+ Socket: socketsMade,
+ LocalIP: myIP.String(),
+ LocalPort: portRand,
+ RemoteIP: ip,
+ RemotePort: port,
+ State: SYNSENT,
+ Conn: conn,
+ }
+ mapMutex.Unlock()
+ socketsMade += 1
+
+ return conn, nil
+}
+
+func SendTCP(tcpHdr *header.TCPFields, payload []byte, myIP netip.Addr, ipParsed netip.Addr) error {
+ checksum := iptcp_utils.ComputeTCPChecksum(tcpHdr, myIP, ipParsed, payload)
+ tcpHdr.Checksum = checksum
+
+ tcpHeaderBytes := make(header.TCP, iptcp_utils.TcpHeaderLen)
+ tcpHeaderBytes.Encode(tcpHdr)
+
+ ipPacketPayload := make([]byte, 0, len(tcpHeaderBytes)+len(payload))
+ ipPacketPayload = append(ipPacketPayload, tcpHeaderBytes...)
+ ipPacketPayload = append(ipPacketPayload, []byte(payload)...)
+
+ // lookup neighbor
+ address := ipParsed
+ hop, err := Route(address)
+ if err != nil {
+ fmt.Println(err)
+ return err
+ }
+ myAddr := hop.Interface.IpPrefix.Addr()
+
+ for _, neighbor := range GetNeighbors()[hop.Interface.Name] {
+ if neighbor.VipAddr == address ||
+ neighbor.VipAddr == hop.VIP && hop.Type == "S" {
+ bytesWritten, err := SendIP(&myAddr, neighbor, TCP_PROTOCOL, ipPacketPayload, ipParsed.String(), nil)
+ fmt.Printf("Sent %d bytes to %s\n", bytesWritten, neighbor.VipAddr.String())
+ if err != nil {
+ fmt.Println(err)
+ }
+ }
+ }
+ return nil
+}
+
+func SprintSockets() string {
+ tmp := ""
+ for _, socket := range VHostSocketMaps {
+ // remove the spaces of the local and remote ip variables
+ socket.LocalIP = strings.ReplaceAll(socket.LocalIP, " ", "")
+ socket.RemoteIP = strings.ReplaceAll(socket.RemoteIP, " ", "")
+ if socket.RemotePort == 0 {
+ tmp += fmt.Sprintf("%d\t%s\t%d\t%s\t\t%d\t%s\n", socket.Socket, socket.LocalIP, socket.LocalPort, socket.RemoteIP, socket.RemotePort, socket.State)
+ continue
+ }
+ tmp += fmt.Sprintf("%d\t%s\t%d\t%s\t%d\t%s\n", socket.Socket, socket.LocalIP, socket.LocalPort, socket.RemoteIP, socket.RemotePort, socket.State)
+ }
+ return tmp
+}
+
+// MILESTONE 2
+func (c *VTCPConn) VClose() error {
+ // check if the socket is in the map
+ key := SocketKey{c.LocalAddr, c.LocalPort, c.RemoteAddr, c.RemotePort}
+ mapMutex.Lock()
+ socketEntry, in := VHostSocketMaps[key]
+ mapMutex.Unlock()
+ if !in {
+ return errors.Errorf("error VClose: socket %d does not exist", c.Socket)
+ }
+
+ // change the state to closed
+ socketEntry.State = Closed
+ return nil
+}
+
+
+// advertise window = max window size - (next - 1 - lbr)
+
+// early arrivals queue
+var earlyArrivals = make([][]byte, 0)
+
+// retranmission queue
+var retransmissionQueue = make([][]byte, 0)
+
+func (c *VTCPConn) VWrite(payload []byte) (int, error) {
+ // check if the socket is in the map
+ key := SocketKey{c.LocalAddr, c.LocalPort, c.RemoteAddr, c.RemotePort}
+ mapMutex.Lock()
+ socketEntry, in := VHostSocketMaps[key]
+ mapMutex.Unlock()
+ if !in {
+ return 0, errors.Errorf("error VWrite: socket %d does not exist", c.Socket)
+ }
+
+ // check if the state is established
+ if socketEntry.State != Established {
+ return 0, errors.Errorf("error VWrite: socket %d is not in established state", c.Socket)
+ }
+
+ // check if the payload is empty
+ if len(payload) == 0 {
+ return 0, nil
+ }
+
+ // check if the payload is larger than the window size
+ if len(payload) > MAX_WINDOW_SIZE {
+ return 0, errors.Errorf("error VWrite: payload is larger than the window size")
+ }
+
+ // check if the payload is larger than the available window size
+ if len(payload) > int(MAX_WINDOW_SIZE - (c.SendBuffer.nxt - 1 - c.SendBuffer.lbr)) {
+ return 0, errors.Errorf("error VWrite: payload is larger than the available window size")
+ }
+
+ // make the header
+ advertisedWindow := MAX_WINDOW_SIZE - (c.SendBuffer.nxt - 1 - c.SendBuffer.lbr)
+ tcpHdr := &header.TCPFields{
+ SrcPort: c.LocalPort,
+ DstPort: c.RemotePort,
+ SeqNum: c.SendBuffer.nxt,
+ AckNum: c.SendBuffer.una,
+ DataOffset: 20,
+ Flags: header.TCPFlagSyn,
+ WindowSize: uint16(advertisedWindow),
+ Checksum: 0,
+ UrgentPointer: 0,
+ }
+
+ myIP := GetInterfaces()[0].IpPrefix.Addr()
+ ipParsed, err := netip.ParseAddr(c.RemoteAddr)
+ if err != nil {
+ return 0, err
+ }
+
+ err = SendTCP(tcpHdr, payload, myIP, ipParsed)
+ if err != nil {
+ return 0, err
+ }
+ // update the next sequence number
+ // c.SendBuffer.nxt += uint32(len(payload))
+
+
+ c.SendBuffer.lbr += uint32(len(payload))
+ return len(payload), nil
+}
+
+
+func (c *VTCPConn) VRead(numBytesToRead int) (int, string, error) {
+ // check if the socket is in the map
+ key := SocketKey{c.LocalAddr, c.LocalPort, c.RemoteAddr, c.RemotePort}
+ // mapMutex.Lock()
+ socketEntry, in := VHostSocketMaps[key]
+ // mapMutex.Unlock()
+ // check if the socket is in the map
+ if !in {
+ return 0, "", errors.Errorf("error VRead: socket %d does not exist", c.Socket)
+ }
+
+ // check if the state is established
+ if socketEntry.State != Established {
+ return 0, "", errors.Errorf("error VRead: socket %d is not in established state", c.Socket)
+ }
+ fmt.Println("I am in VRead")
+ fmt.Println("I have", c.RecvBuffer.recvNext - c.RecvBuffer.lbr, "bytes to read")
+ fmt.Println(c.RecvBuffer.recvNext, c.RecvBuffer.lbr)
+ if (c.RecvBuffer.lbr < c.RecvBuffer.recvNext && c.RecvBuffer.recvNext - c.RecvBuffer.lbr >= uint32(numBytesToRead)) {
+ fmt.Println("I have enough data to read")
+ toReturn := string(socketEntry.Conn.RecvBuffer.buffer[c.RecvBuffer.lbr:c.RecvBuffer.lbr+uint32(numBytesToRead)])
+ // update the last byte read
+ c.RecvBuffer.lbr += uint32(numBytesToRead)
+ // return the data
+ return numBytesToRead, toReturn, nil
+ }
+
+ return 0, "", nil
+} \ No newline at end of file
diff --git a/pkg/ipstack/ipstack_test.go b/pkg/ipstack/ipstack_test.go
new file mode 100644
index 0000000..e782f67
--- /dev/null
+++ b/pkg/ipstack/ipstack_test.go
@@ -0,0 +1,301 @@
+package ipstack
+
+import (
+ "fmt"
+ "net/netip"
+ "testing"
+)
+
+//func TestInitialize(t *testing.T) {
+// lnxFilePath := "../../doc-example/r2.lnx"
+// err := Initialize(lnxFilePath)
+// if err != nil {
+// t.Error(err)
+// }
+// fmt.Printf("Interfaces:\n%s\n\n", SprintInterfaces())
+// fmt.Printf("Neighbors:\n%s\n", SprintNeighbors())
+// fmt.Printf("RoutingTable:\n%s\n", SprintRoutingTable())
+//
+// fmt.Println("TestInitialize successful")
+// t.Cleanup(func() { CleanUp() })
+//}
+//
+//func TestInterfaceUpThenDown(t *testing.T) {
+// lnxFilePath := "../../doc-example/r2.lnx"
+// err := Initialize(lnxFilePath)
+// if err != nil {
+// t.Error(err)
+// }
+//
+// iface, err := GetInterfaceByName("if0")
+// if err != nil {
+// t.Error(err)
+// }
+//
+// InterfaceUp(iface)
+// if iface.State == false {
+// t.Error("iface state should be true")
+// }
+//
+// fmt.Printf("Interfaces:\n%s\n", SprintInterfaces())
+//
+// time.Sleep(5 * time.Millisecond) // allow time to print
+//
+// InterfaceDown(iface)
+// if iface.State == true {
+// t.Error("iface state should be false")
+// }
+//
+// time.Sleep(5 * time.Millisecond) // allow time to print
+//
+// fmt.Printf("Interfaces:\n%s\n", SprintInterfaces())
+//
+// fmt.Println("TestInterfaceUpThenDown successful")
+// t.Cleanup(func() { CleanUp() })
+//}
+//
+//func TestInterfaceUpThenDownTwice(t *testing.T) {
+// lnxFilePath := "../../doc-example/r2.lnx"
+// err := Initialize(lnxFilePath)
+// if err != nil {
+// t.Error(err)
+// }
+//
+// iface, err := GetInterfaceByName("if0")
+// if err != nil {
+// t.Error(err)
+// }
+//
+// InterfaceUp(iface)
+// if iface.State == false {
+// t.Error("iface state should be true")
+// }
+//
+// fmt.Printf("Interfaces:\n%s\n", SprintInterfaces())
+//
+// time.Sleep(5 * time.Millisecond) // allow time to print
+//
+// fmt.Println("putting interface down")
+// InterfaceDown(iface)
+// if iface.State == true {
+// t.Error("iface state should be false")
+// }
+//
+// time.Sleep(3 * time.Millisecond)
+//
+// fmt.Println("putting interface back up for 3 iterations")
+// InterfaceUp(iface)
+// if iface.State == false {
+// t.Error("iface state should be true")
+// }
+// time.Sleep(3 * time.Millisecond) // allow time to print
+//
+// fmt.Println("putting interface down")
+// InterfaceDown(iface)
+// if iface.State == true {
+// t.Error("iface state should be false")
+// }
+//
+// time.Sleep(5 * time.Millisecond) // allow time to print
+//
+// fmt.Printf("Interfaces:\n%s\n", SprintInterfaces())
+//
+// fmt.Println("TestInterfaceUpThenDownTwice successful")
+// t.Cleanup(func() { CleanUp() })
+//}
+//
+//func TestSendIPToNeighbor(t *testing.T) {
+// lnxFilePath := "../../doc-example/r2.lnx"
+// err := Initialize(lnxFilePath)
+// if err != nil {
+// t.Error(err)
+// }
+//
+// // get the first neighbor of this interface
+// iface, err := GetInterfaceByName("if0")
+// if err != nil {
+// t.Error(err)
+// }
+// neighbors, err := GetNeighborsToInterface("if0")
+// if err != nil {
+// t.Error(err)
+// }
+//
+// // setup a neighbor listener socket
+// testNeighbor := neighbors[0]
+// // close the socket so we can listen on it
+// err = testNeighbor.SendSocket.Close()
+// if err != nil {
+// t.Error(err)
+// }
+//
+// fmt.Printf("Interfaces:\n%s\n", SprintInterfaces())
+// fmt.Printf("Neighbors:\n%s\n", SprintNeighbors())
+//
+// listenString := testNeighbor.UdpAddr.String()
+// fmt.Println("listening on " + listenString)
+// listenAddr, err := net.ResolveUDPAddr("udp4", listenString)
+// if err != nil {
+// t.Error(err)
+// }
+// recvSocket, err := net.ListenUDP("udp4", listenAddr)
+// if err != nil {
+// t.Error(err)
+// }
+// testNeighbor.SendSocket = *recvSocket
+//
+// sent := false
+// go func() {
+// buffer := make([]byte, MAX_IP_PACKET_SIZE)
+// fmt.Println("wating to read from UDP socket")
+// _, sourceAddr, err := recvSocket.ReadFromUDP(buffer)
+// if err != nil {
+// t.Error(err)
+// }
+// fmt.Println("read from UDP socket")
+// hdr, err := ipv4header.ParseHeader(buffer)
+// if err != nil {
+// t.Error(err)
+// }
+// headerSize := hdr.Len
+// headerBytes := buffer[:headerSize]
+// checksumFromHeader := uint16(hdr.Checksum)
+// computedChecksum := ValidateChecksum(headerBytes, checksumFromHeader)
+//
+// var checksumState string
+// if computedChecksum == checksumFromHeader {
+// checksumState = "OK"
+// } else {
+// checksumState = "FAIL"
+// }
+// message := buffer[headerSize:]
+// fmt.Printf("Received IP packet from %s\nHeader: %v\nChecksum: %s\nMessage: %s\n",
+// sourceAddr.String(), hdr, checksumState, string(message))
+// if err != nil {
+// t.Error(err)
+// }
+//
+// sent = true
+// }()
+//
+// time.Sleep(10 * time.Millisecond)
+//
+// // send a message to the neighbor
+// fmt.Printf("sending message to neighbor\t%t\n", sent)
+// err = SendIP(*iface, *testNeighbor, 0, []byte("You are my firest neighbor!"))
+// if err != nil {
+// t.Error(err)
+// }
+//
+// fmt.Printf("SENT message to neighbor\t%t\n", sent)
+// // give a little time for the message to be sent
+// time.Sleep(1000 * time.Millisecond)
+// if !sent {
+// t.Error("Message not sent")
+// t.Fail()
+// }
+//
+// fmt.Println("TestSendIPToNeighbor successful")
+// t.Cleanup(func() { CleanUp() })
+//}
+//
+//func TestRecvIP(t *testing.T) {
+// lnxFilePath := "../../doc-example/r2.lnx"
+// err := Initialize(lnxFilePath)
+// if err != nil {
+// t.Error(err)
+// }
+//
+// // get the first neighbor of this interface to RecvIP from
+// iface, err := GetInterfaceByName("if0")
+// if err != nil {
+// t.Error(err)
+// }
+// InterfaceUp(iface)
+//
+// // setup a random socket to send an ip packet from
+// listenAddr, err := net.ResolveUDPAddr("udp4", "127.0.0.1:6969")
+// sendSocket, err := net.ListenUDP("udp4", listenAddr)
+//
+// // send a message to the neighbor
+// ifaceAsNeighbor := Neighbor{
+// VipAddr: iface.IpPrefix.Addr(),
+// UdpAddr: iface.UdpAddr,
+// SendSocket: iface.RecvSocket,
+// SocketChannel: iface.SocketChannel,
+// }
+// fakeIface := Interface{
+// Name: "if69",
+// IpPrefix: netip.MustParsePrefix("10.69.0.1/24"),
+// UdpAddr: netip.MustParseAddrPort("127.0.0.1:6969"),
+// RecvSocket: net.UDPConn{},
+// SocketChannel: nil,
+// State: true,
+// }
+// err = SendIP(fakeIface, ifaceAsNeighbor, 0, []byte("hello"))
+// if err != nil {
+// return
+// }
+//
+// time.Sleep(10 * time.Millisecond)
+//
+// // TODO: potenially make this a channel, so it actually checks values.
+// // For now, you must read the message from the console.
+//
+// err = sendSocket.Close()
+// if err != nil {
+// t.Error(err)
+// }
+// t.Cleanup(func() { CleanUp() })
+//}
+
+func TestIntersect(t *testing.T) {
+ net1 := netip.MustParsePrefix("10.0.0.0/24")
+ net2 := netip.MustParsePrefix("1.1.1.2/24")
+ net3 := netip.MustParsePrefix("1.0.0.1/24")
+ net4 := netip.MustParsePrefix("0.0.0.0/0") // default route
+ net5 := netip.MustParsePrefix("10.2.0.3/32")
+
+ res00 := intersect(net5, net1)
+ if res00 {
+ t.Error("net5 -> net1 should not intersect")
+ t.Fail()
+ }
+ res01 := intersect(net5, net4)
+ if !res01 {
+ t.Error("net5 -> net4 should intersect")
+ t.Fail()
+ }
+
+ res6 := intersect(net1, net3)
+ if res6 {
+ t.Error("net1 and net3 should not intersect")
+ t.Fail()
+ }
+ res2 := intersect(net2, net3)
+ if res2 {
+ t.Error("net2 and net3 should not intersect")
+ t.Fail()
+ }
+ res3 := intersect(net1, net4)
+ if !res3 {
+ t.Error("net1 and net4 should intersect")
+ t.Fail()
+ }
+ res4 := intersect(net2, net4)
+ if !res4 {
+ t.Error("net2 and net4 should intersect")
+ t.Fail()
+ }
+ res5 := intersect(net3, net4)
+ if !res5 {
+ t.Error("net3 and net4 should intersect")
+ t.Fail()
+ }
+
+ fmt.Println("TestIntersect successful")
+}
+
+func intersect(n1, n2 netip.Prefix) bool {
+ return n1.Overlaps(n2)
+}