diff options
Diffstat (limited to 'pkg/ipstack')
-rw-r--r-- | pkg/ipstack/ipstack.go | 1565 | ||||
-rw-r--r-- | pkg/ipstack/ipstack_test.go | 301 |
2 files changed, 1866 insertions, 0 deletions
diff --git a/pkg/ipstack/ipstack.go b/pkg/ipstack/ipstack.go new file mode 100644 index 0000000..0d317c2 --- /dev/null +++ b/pkg/ipstack/ipstack.go @@ -0,0 +1,1565 @@ +package ipstack + +// code begins on line 97 after imports, constants, and structs definitions +// This class is divided as follows: +// 1) INIT FUNCTIONS +// 2) DOWN/UP FUNCTIONS +// 3) SEND/RECV FUNCTIONS +// 4) CHECKSUM FUNCTIONS +// 5) RIP FUNCTIONS +// 6) PROTOCOL HANDLERS +// 7) HELPER FUNCTIONS +// 8) GETTER FUNCTIONS +// 9) PRINT FUNCTIONS +// 10) CLEANUP FUNCTION + +import ( + "encoding/binary" + "fmt" + ipv4header "github.com/brown-csci1680/iptcp-headers" + "github.com/google/netstack/tcpip/header" + "github.com/pkg/errors" + "iptcp/pkg/lnxconfig" + "iptcp/pkg/iptcp_utils" + "log" + "net" + "net/netip" + "sync" + "time" + "strings" + "math/rand" + // "github.com/google/netstack/tcpip/header" +) + +const ( + MAX_IP_PACKET_SIZE = 1400 + LOCAL_COST uint32 = 0 + STATIC_COST uint32 = 4294967295 // 2^32 - 1 + INFINITY = 16 + SIZE_OF_RIP_ENTRY = 12 + RIP_PROTOCOL = 200 + TEST_PROTOCOL = 0 + TCP_PROTOCOL = 6 + SIZE_OF_RIP_HEADER = 4 + MAX_TIMEOUT = 12 +) + +// STRUCTS --------------------------------------------------------------------- +type Interface struct { + Name string + IpPrefix netip.Prefix + UdpAddr netip.AddrPort + + Socket net.UDPConn + SocketChannel chan bool + State bool +} + +type Neighbor struct { + Name string + VipAddr netip.Addr + UdpAddr netip.AddrPort +} + +type RIPHeader struct { + command uint16 + numEntries uint16 +} + +type RIPEntry struct { + prefix netip.Prefix + cost uint32 +} + +type Hop struct { + Cost uint32 + Type string + + Interface *Interface + VIP netip.Addr +} + +// GLOBAL VARIABLES (data structures) ------------------------------------------ +var myInterfaces []*Interface +var myNeighbors = make(map[string][]*Neighbor) + +var myRIPNeighbors = make(map[string]*Neighbor) + +type HandlerFunc func(src *Interface, message []byte, hdr *ipv4header.IPv4Header) error + +var protocolHandlers = make(map[int]HandlerFunc) + +var routingTable = make(map[netip.Prefix]Hop) + +var timeoutTableMu sync.Mutex +var timeoutTable = make(map[netip.Prefix]int) + +// ************************************** INIT FUNCTIONS ********************************************************** +// reference: https://github.com/brown-csci1680/lecture-examples/blob/main/ip-demo/cmd/udp-ip-recv/main.go + +// createUDPListener creates a UDP listener on the given UDP address. +// It SETS the conn parameter to the created UDP socket. +func createUDPListener(UdpAddr netip.AddrPort, conn *net.UDPConn) error { + listenString := UdpAddr.String() + listenAddr, err := net.ResolveUDPAddr("udp4", listenString) + if err != nil { + return errors.WithMessage(err, "Error resolving address->\t"+listenString) + } + tmpConn, err := net.ListenUDP("udp4", listenAddr) + if err != nil { + return errors.WithMessage(err, "Could not bind to UDP port->\t"+listenString) + } + *conn = *tmpConn + + return nil +} + +// Initialize initializes the data structures and creates the UDP sockets. +// +// It will return an error if the lnx file is not valid or if a socket fails to be created. +// +// After parsing the lnx file, it does the following: +// 1. adds each local interface to the routing table, as dictated by its subnet +// 2. adds neighbors to interface->neighbors[] map +// 3. adds RIP neighbors to RIP neighbor list +// 4. adds static routes to routing table +func Initialize(lnxFilePath string) error { + // Parse the file + lnxConfig, err := lnxconfig.ParseConfig(lnxFilePath) + if err != nil { + return errors.WithMessage(err, "Error parsing config file->\t"+lnxFilePath) + } + + // 1) add each local "if" to the routing table, as dictated by its subnet + for _, iface := range lnxConfig.Interfaces { + prefix := netip.PrefixFrom(iface.AssignedIP, iface.AssignedPrefix.Bits()) + i := &Interface{ + Name: iface.Name, + IpPrefix: prefix, + UdpAddr: iface.UDPAddr, + Socket: net.UDPConn{}, + SocketChannel: make(chan bool), + State: true, + } + + // create the UDP listener + err := createUDPListener(iface.UDPAddr, &i.Socket) + if err != nil { + return errors.WithMessage(err, "Error creating UDP socket for interface->\t"+iface.Name) + } + + // start the listener routine + go InterfaceListenerRoutine(i) + + // add to the list of interfaces + myInterfaces = append(myInterfaces, i) + + // add to the routing table + routingTable[prefix.Masked()] = Hop{LOCAL_COST, "L", i, prefix.Addr()} + } + + // 2) add neighbors to if->neighbors map + for _, neighbor := range lnxConfig.Neighbors { + n := &Neighbor{ + Name: neighbor.InterfaceName, + VipAddr: neighbor.DestAddr, + UdpAddr: neighbor.UDPAddr, + } + + myNeighbors[neighbor.InterfaceName] = append(myNeighbors[neighbor.InterfaceName], n) + } + + // 3) add RIP neighbors to RIP neighbor list + for _, route := range lnxConfig.RipNeighbors { + // add to RIP neighbors + for _, iface := range myInterfaces { + for _, neighbor := range myNeighbors[iface.Name] { + if neighbor.VipAddr == route { + myRIPNeighbors[neighbor.VipAddr.String()] = neighbor + break + } + } + } + } + + // 4) add static routes to routing table + for prefix, addr := range lnxConfig.StaticRoutes { + // need loops to find the interface that matches the neighbor to send static to + // hops needs this interface + for _, iface := range myInterfaces { + for _, neighbor := range myNeighbors[iface.Name] { + if neighbor.VipAddr == addr { + routingTable[prefix] = Hop{STATIC_COST, "S", iface, addr} + break + } + } + } + } + + return nil +} + +// InterfaceListenerRoutine is a go routine for interfaces to listen on a UDP port. +// +// It is composed two go routines: +// 1. a go routine that hangs on the recv and calls RecvIP() when a packet is received +// 2. a go routine that listens on the channel for a signal to start/stop listening +// +// TODO: (performance) remove isUp and use the interface's value instead +func InterfaceListenerRoutine(i *Interface) { + // decompose the interface + socket := i.Socket + signal := i.SocketChannel + + // booleans to control listening routine + isUp := true + closed := false + + // fmt.Println("MAKING GO ROUTINE TO LISTEN:\t", socket.LocalAddr().String()) + + // go routine that hangs on the recv + go func() { + defer func() { + fmt.Println("exiting go routine that listens on ", socket.LocalAddr().String()) + }() + + for { + if closed { // stop this go routine if channel is closed + return + } + err := RecvIP(i, &isUp) + if err != nil { + continue + } + } + }() + + for { + select { + // if the channel is closed, exit + case sig, ok := <-signal: + if !ok { + fmt.Println("channel closed, exiting") + closed = true + return + } + // fmt.Println("received isUP SIGNAL with value", sig) + isUp = sig + // if the channel is not closed, continue + default: + continue + } + } +} + +// ************************************** DOWN/UP FUNCTIONS ****************************************************** + +// InterfaceUp brings up the link layer +// +// It does the following: +// 1. tells the listener (through a channel) to start listening +// 2. updates the interface state to up +// 3. sends RIP request to all neighbors of this iface to quickly update the routing table +func InterfaceUp(iface *Interface) { + // set the state to up and send the signal + iface.State = true + iface.SocketChannel <- true + + // if were a router, send triggered updates on up + if _, ok := protocolHandlers[RIP_PROTOCOL]; ok { + ripEntries := make([]RIPEntry, 0) + ripEntries = append(ripEntries, RIPEntry{iface.IpPrefix.Masked(), LOCAL_COST}) + SendTriggeredUpdates(ripEntries) + + // send a request to all neighbors of this iface to get info ASAP + for _, neighbor := range myNeighbors[iface.Name] { + message := MakeRipMessage(1, nil) + addr := iface.IpPrefix.Addr() + _, err := SendIP(&addr, neighbor, RIP_PROTOCOL, message, neighbor.VipAddr.String(), nil) + if err != nil { + fmt.Println("Error sending RIP request to neighbor on interfaceup", err) + } + } + } + +} + +func InterfaceUpREPL(ifaceName string) { + iface, err := GetInterfaceByName(ifaceName) + if err != nil { + fmt.Println("Error getting interface by name", err) + return + } + // set the state to up and send the signal + InterfaceUp(iface) +} + +// InterfaceDown cuts off the link layer. +// +// It does the following: +// 1. tells the listener (through a channel) to stop listening +// 2. updates the interface state to down +// 3. updates the routing table by removing the routes those neighbors connected to, sending triggered updates. +func InterfaceDown(iface *Interface) { + // set the state to down and send the signal + iface.SocketChannel <- false + iface.State = false + + // if were a router, send triggered updates on down + if _, ok := protocolHandlers[RIP_PROTOCOL]; ok { + ripEntries := make([]RIPEntry, 0) + ripEntries = append(ripEntries, RIPEntry{iface.IpPrefix.Masked(), INFINITY}) + SendTriggeredUpdates(ripEntries) + } +} + +func InterfaceDownREPL(ifaceName string) { + iface, err := GetInterfaceByName(ifaceName) + if err != nil { + fmt.Println("Error getting interface by name", err) + return + } + // set the state to down and send the signal + InterfaceDown(iface) +} + +// ************************************** SEND/RECV FUNCTIONS ******************************************************* + +// SendIP sends an IP packet to a destination +// +// If the header is nil, then a new header is created +// If the header is not nil, then it will use that header after decrementing TTL & recomputing checksum +// +// TODO: (performance) have this take in an interface instead of src for performance +func SendIP(src *netip.Addr, dest *Neighbor, protocolNum int, message []byte, destIP string, hdr *ipv4header.IPv4Header) (int, error) { + // check if the interface is up + iface, err := GetInterfaceByName(dest.Name) + if !iface.State { + return 0, errors.Errorf("error SEND: %s is down", iface.Name) + } + // if the header is nil, create a new one + if hdr == nil { + hdr = &ipv4header.IPv4Header{ + Version: 4, + Len: 20, // Header length is always 20 when no IP options + TOS: 0, + TotalLen: ipv4header.HeaderLen + len(message), + ID: 0, + Flags: 0, + FragOff: 0, + TTL: 32, + Protocol: protocolNum, + Checksum: 0, // Should be 0 until checksum is computed + Src: *src, + Dst: netip.MustParseAddr(destIP), + Options: []byte{}, + } + } else { + // if the header is not nil, decrement the TTL + hdr = &ipv4header.IPv4Header{ + Version: 4, + Len: 20, // Header length is always 20 when no IP options + TOS: 0, + TotalLen: ipv4header.HeaderLen + len(message), + ID: 0, + Flags: 0, + FragOff: 0, + TTL: hdr.TTL - 1, + Protocol: protocolNum, + Checksum: 0, // Should be 0 until checksum is computed + Src: *src, + Dst: netip.MustParseAddr(destIP), + Options: []byte{}, + } + } + + // Assemble the header into a byte array + headerBytes, err := hdr.Marshal() + if err != nil { + return 0, err + } + + // Compute the checksum (see below) + // Cast back to an int, which is what the Header structure expects + hdr.Checksum = int(ComputeChecksum(headerBytes)) + + headerBytes, err = hdr.Marshal() + if err != nil { + log.Fatalln("Error marshalling header: ", err) + } + + // Combine the header and the message into a single byte array + bytesToSend := make([]byte, 0, len(headerBytes)+len(message)) + bytesToSend = append(bytesToSend, headerBytes...) + bytesToSend = append(bytesToSend, []byte(message)...) + + sendAddr, err := net.ResolveUDPAddr("udp4", dest.UdpAddr.String()) + if err != nil { + return -1, errors.WithMessage(err, "Could not bind to UDP port->\t"+dest.UdpAddr.String()) + } + + // send the packet + bytesWritten, err := iface.Socket.WriteToUDP(bytesToSend, sendAddr) + if err != nil { + fmt.Println("Error writing to UDP socket") + return 0, errors.WithMessage(err, "Error writing to UDP socket") + } + + return bytesWritten, nil +} + +// RecvIP receives an IP packet from the interface +// To be called by the listener routine, representing one interface +// Upon receiving a packet, this function: +// 1. determines if packet is valid (checksum, TTL) +// 2. determines if the packet is for me. if so, SENDUP (call correct handler) +// 3. the packet is not SENTUP, then checks the routing table +// 4. if there is no route in the routing table, then prints an error and DROPS the packet +func RecvIP(iface *Interface, isOpen *bool) error { + buffer := make([]byte, MAX_IP_PACKET_SIZE) + + // Read on the UDP port + // fmt.Println("wating to read from UDP socket") + _, _, err := iface.Socket.ReadFromUDP(buffer) + if err != nil { + return err + } + + // check if the interface is up + if !*isOpen { + return errors.Errorf("error RECV: %s is down", iface.Name) + } + + // Marshal the received byte array into a UDP header + hdr, err := ipv4header.ParseHeader(buffer) + if err != nil { + fmt.Println("Error parsing header", err) + return err + } + + // checksum validation + headerSize := hdr.Len + headerBytes := buffer[:headerSize] + checksumFromHeader := uint16(hdr.Checksum) + computedChecksum := ValidateChecksum(headerBytes, checksumFromHeader) + + var checksumState string + if computedChecksum == checksumFromHeader { + checksumState = "OK" + } else { + checksumState = "FAIL" + } + + // Next, get the message, which starts after the header + messageLen := hdr.TotalLen - hdr.Len + message := buffer[headerSize : messageLen+headerSize] + + // 1) check if the TTL & checksum is valid + TTL := hdr.TTL + if TTL == 0 { + // drop the packet + return nil + } + + // check if the checksum is valid + if checksumState == "FAIL" { + // drop the packet + // fmt.Println("checksum failed, dropping packet") + return nil + } + + //if hdr.Protocol != RIP_PROTOCOL { + // fmt.Println("I see a non-rip packet") + //} + + // at this point, the packet is valid. next steps consider the forwarding of the packet + + // 2) check if the message is for me, if so, sendUP (aka call the correct handler) + for _, myIface := range myInterfaces { + if hdr.Dst == myIface.IpPrefix.Addr() { + // see if there is a handler for this protocol + if handler, ok := protocolHandlers[hdr.Protocol]; ok { + if hdr.Protocol != RIP_PROTOCOL { + // fmt.Println("this test packet is exactly for me") + } + err := handler(myIface, message, hdr) + if err != nil { + fmt.Println(err) + } + } + return nil + } + } + + // 3) check forwarding table. + // - if it's a local hop, send to that iface + // - if it's a RIP hop, send to the neighbor with that VIP + // fmt.Println("checking routing table") + hop, err := Route(hdr.Dst) + if err == nil { // on no err, found a match + // fmt.Println("found route", hop.VIP) + if hop.Type == "S" { + // default, static route + // drop in this case + return nil + } + + // - local hop + if hop.Type == "L" { + // if it's a local route, then the name is the interface name + for _, neighbor := range myNeighbors[hop.Interface.Name] { + if neighbor.VipAddr == hdr.Dst { + _, err2 := SendIP(&hdr.Src, neighbor, hdr.Protocol, message, hdr.Dst.String(), hdr) + if err2 != nil { + return err2 + } + } + } + } + + // - rip hop + if hop.Type == "R" { + // if it's a rip route, then the check is against the hop vip + for _, neighbor := range myNeighbors[hop.Interface.Name] { + if neighbor.VipAddr == hop.VIP { + _, err2 := SendIP(&hdr.Src, neighbor, hdr.Protocol, message, hdr.Dst.String(), hdr) + if err2 != nil { + return err2 + } + } + } + } + } + + // if not in table, drop packet + return nil +} + +// ************************************** CHECKSUM FUNCTIONS ****************************************************** +// reference: https://github.com/brown-csci1680/lecture-examples/blob/main/ip-demo/cmd/udp-ip-recv/main.go +func ComputeChecksum(b []byte) uint16 { + checksum := header.Checksum(b, 0) + checksumInv := checksum ^ 0xffff + + return checksumInv +} + +func ValidateChecksum(b []byte, fromHeader uint16) uint16 { + checksum := header.Checksum(b, fromHeader) + + return checksum +} + +// ************************************** RIP FUNCTIONS ******************************************************* + +// PeriodicUpdateRoutine sends RIP updates to neighbors every 5 seconds +// TODO: (performace) consider making this multithreaded and loops above more efficient +func PeriodicUpdateRoutine() { + for { + // for each periodic update, we want to send our nodes in the table + for _, iface := range myInterfaces { + for _, n := range myNeighbors[iface.Name] { + _, in := myRIPNeighbors[n.VipAddr.String()] + // if the neighbor is not a RIP neighbor, skip it + if !in { + continue + } + + // Sending to a rip neighbor + // create the entries + entries := make([]RIPEntry, 0) + for prefix, hop := range routingTable { + // implement split horizon + poison reverse at entry level + var cost uint32 + if hop.VIP == n.VipAddr { + cost = INFINITY + } else { + cost = hop.Cost + } + entries = append(entries, + RIPEntry{ + prefix: prefix, + cost: cost, + }) + } + + // make the message and send it + message := MakeRipMessage(2, entries) + addr := iface.IpPrefix.Addr() + _, err := SendIP(&addr, n, RIP_PROTOCOL, message, n.VipAddr.String(), nil) + if err != nil { + // fmt.Printf("Error sending RIP message to %s\n", n.VipAddr.String()) + continue + } + } + } + + // wait 5 sec and repeat + time.Sleep(5 * time.Second) + } +} + +// SendTriggeredUpdates sends the entries consumed to ALL neighbors +func SendTriggeredUpdates(newEntries []RIPEntry) { + for _, iface := range myInterfaces { + for _, n := range myNeighbors[iface.Name] { + // only send to RIP neighbors, else skip + _, in := myRIPNeighbors[n.VipAddr.String()] + if !in { + continue + } + + // send the made entries to the neighbor + message := MakeRipMessage(2, newEntries) + addr := iface.IpPrefix.Addr() + _, err := SendIP(&addr, n, RIP_PROTOCOL, message, n.VipAddr.String(), nil) + if err != nil { + // fmt.Printf("Error sending RIP triggered update to %s\n", n.VipAddr.String()) + continue + } + } + } +} + +// ManageTimeoutsRoutine manages the timeout table by incrementing the timeouts every second. +// If a timeout reaches MAX_TIMEOUT, then the entry is deleted from the routing table and a triggered update is sent. +func ManageTimeoutsRoutine() { + for { + time.Sleep(time.Second) + + timeoutTableMu.Lock() + // check if any timeouts have occurred + for key, _ := range timeoutTable { + timeoutTable[key]++ + // if the timeout is MAX_TIMEOUT, delete the entry + if timeoutTable[key] == MAX_TIMEOUT { + delete(timeoutTable, key) + + newEntries := make([]RIPEntry, 0) + delete(routingTable, key) + newEntries = append(newEntries, RIPEntry{key, INFINITY}) + + // send triggered update on timeout + if len(newEntries) > 0 { + SendTriggeredUpdates(newEntries) + } + } + } + timeoutTableMu.Unlock() + //fmt.Println("Timeout table: ", timeoutTable) + } +} + +// StartRipRoutines handles all the routines for RIP +// 1. sends a RIP request to every neighbor +// 2. starts the routine that sends periodic updates every 5 seconds +// 3. starts the routine that manages the timeout table +func StartRipRoutines() { + // send a request to every neighbor + go func() { + for _, iface := range myInterfaces { + for _, neighbor := range myNeighbors[iface.Name] { + // only send to RIP neighbors, else skip + _, in := myRIPNeighbors[neighbor.VipAddr.String()] + if !in { + continue + } + // send a request + message := MakeRipMessage(1, nil) + addr := iface.IpPrefix.Addr() + _, err := SendIP(&addr, neighbor, RIP_PROTOCOL, message, neighbor.VipAddr.String(), nil) + if err != nil { + return + } + } + } + }() + + // start a routine that sends updates every 5 seconds + go PeriodicUpdateRoutine() + + // make a "timeout" table, for each response we add to the table via rip + go ManageTimeoutsRoutine() +} + +// ************************************** PROTOCOL HANDLERS ******************************************************* + +// RegisterProtocolHandler registers a protocol handler for a given protocol number +// Returns true if the protocol number is valid, false otherwise +func RegisterProtocolHandler(protocolNum int) bool { + switch protocolNum { + case RIP_PROTOCOL: + protocolHandlers[protocolNum] = HandleRIP + go StartRipRoutines() + return true + case TEST_PROTOCOL: + protocolHandlers[protocolNum] = HandleTestPackets + return true + case TCP_PROTOCOL: + protocolHandlers[protocolNum] = HandleTCP + return true + default: + return false + } +} + +// HandleRIP handles incoming RIP packets in the following way: +// 1. if the command is a request, send a RIP response only to that requestor +// 2. if the command is a response, parse the entries, update the routing table from them, +// and send applicable triggered updates (see implementation for how to update) +func HandleRIP(src *Interface, message []byte, hdr *ipv4header.IPv4Header) error { + // parse the RIP message + command := int(binary.BigEndian.Uint16(message[0:2])) + switch command { + // request message + case 1: + //fmt.Println("Received RIP command for specific info") + + // only send if the person asking is a RIP neighbor + neighbor, in := myRIPNeighbors[hdr.Src.String()] + if !in { + break + } + + // build the entries + entries := make([]RIPEntry, 0) + for prefix, hop := range routingTable { + // implement split horizon + poison reverse at entry level + var cost uint32 + if hop.VIP == hdr.Src { + cost = INFINITY + } else { + cost = hop.Cost + } + entries = append(entries, + RIPEntry{ + prefix: prefix, + cost: cost, + }) + } + // send the entries + res := MakeRipMessage(2, entries) + _, err := SendIP(&hdr.Dst, neighbor, RIP_PROTOCOL, res, hdr.Src.String(), nil) + if err != nil { + return err + } + break + // response message + case 2: + // fmt.Println("Received RIP response with", numEntries, "entries") + numEntries := int(binary.BigEndian.Uint16(message[2:4])) + + // parse the entries + entries := make([]RIPEntry, 0) + for i := 0; i < numEntries; i++ { + offset := SIZE_OF_RIP_HEADER + i*SIZE_OF_RIP_ENTRY + + // each field is 4 bytes + cost := binary.BigEndian.Uint32(message[offset : offset+4]) + address, _ := netip.AddrFromSlice(message[offset+4 : offset+8]) + mask := net.IPv4Mask(message[offset+8], message[offset+9], message[offset+10], message[offset+11]) + + // make the prefix + bits, _ := mask.Size() + prefix := netip.PrefixFrom(address, bits) + + entries = append(entries, RIPEntry{prefix, cost}) + } + + // update the routing table + triggeredEntries := make([]RIPEntry, 0) + for _, entry := range entries { + destination := entry.prefix.Masked() + + // make upperbound for cost infinity + var newCost uint32 + if entry.cost == INFINITY { + newCost = INFINITY + } else { + newCost = entry.cost + 1 + } + + hop, isin := routingTable[destination] + // if prefix not in table, add it (as long as it's not infinity) + if !isin { + if newCost != INFINITY { + // given an update to table, this is now a triggeredUpdate + // triggeredEntries = append(triggeredEntries, RIPEntry{destination, entry.cost + 1}) + + routingTable[destination] = Hop{newCost, "R", src, hdr.Src} + timeoutTable[destination] = 0 + } + continue + } + + // if the entry is in the table, only two cases affect the table: + // 1) the entry SRC is updating (or confirming) the hop to itself + // in this case, only update if the cost is different + // if it's infinity, then the route has expired. + // we must set the cost to INF then delete the entry after 12 seconds + // + // 2) a different entry SRC reveals a shorter path to the destination + // in this case, update the routing table to use this new path + // + // all other cases don't meaningfully change the route + + // first, upon an update from this prefix, reset its timeout + if hop.Type == "R" { + timeoutTableMu.Lock() + _, in := timeoutTable[destination] + if in { + if routingTable[destination].VIP == hdr.Src { + timeoutTable[destination] = 0 + } + } + timeoutTableMu.Unlock() + } + + // case 1) the entry SRC == the hop to itself + if hop.VIP == hdr.Src && + newCost != hop.Cost { + // given an update to table, this is now a triggeredUpdate + triggeredEntries = append(triggeredEntries, RIPEntry{destination, newCost}) + routingTable[destination] = Hop{newCost, "R", src, hop.VIP} + + // if we receive infinity from the same neighbor, then delete the route after 12 sec + if entry.cost == INFINITY { + // remove after GC time if the COST is still INFINITY + go func() { + time.Sleep(time.Second * time.Duration(MAX_TIMEOUT)) + if routingTable[destination].Cost == INFINITY { + delete(routingTable, destination) + timeoutTableMu.Lock() + delete(timeoutTable, destination) + timeoutTableMu.Unlock() + } + }() + } + continue + } + + // case 2) a shorter route for this destination is revealed from a different neighbor + if newCost < hop.Cost && newCost != INFINITY { + triggeredEntries = append(triggeredEntries, RIPEntry{destination, entry.cost + 1}) + routingTable[destination] = Hop{entry.cost + 1, "R", src, hdr.Src} + continue + } + } + + // send out triggered updates + if len(triggeredEntries) > 0 { + SendTriggeredUpdates(triggeredEntries) + } + } + + return nil +} + +// prints the test packet as per the spec +func HandleTestPackets(src *Interface, message []byte, hdr *ipv4header.IPv4Header) error { + fmt.Printf("Received test packet: Src: %s, Dst: %s, TTL: %d, Data: %s\n", + hdr.Src.String(), hdr.Dst.String(), hdr.TTL, string(message)) + return nil +} + +func HandleTCP(src *Interface, message []byte, hdr *ipv4header.IPv4Header) error { + fmt.Println("I see a TCP packet") + + tcpHeaderAndData := message + tcpHdr := iptcp_utils.ParseTCPHeader(tcpHeaderAndData) + tcpPayload := tcpHeaderAndData[tcpHdr.DataOffset:] + tcpChecksumFromHeader := tcpHdr.Checksum + tcpHdr.Checksum = 0 + tcpComputedChecksum := iptcp_utils.ComputeTCPChecksum(&tcpHdr, hdr.Src, hdr.Dst, tcpPayload) + + var tcpChecksumState string + if tcpComputedChecksum == tcpChecksumFromHeader { + tcpChecksumState = "OK" + } else { + tcpChecksumState = "FAIL" + } + + if tcpChecksumState == "FAIL" { + // drop the packet + fmt.Println("checksum failed, dropping packet") + return nil + } + + switch tcpHdr.Flags { + case header.TCPFlagSyn: + fmt.Println("I see a SYN flag") + // if the SYN flag is set, then send a SYNACK + available := false + + socketEntry, in := VHostSocketMaps[SocketKey{hdr.Dst.String(), tcpHdr.DstPort, hdr.Src.String(), tcpHdr.SrcPort}] + if !in { + fmt.Println("no socket entry found") + } else if socketEntry.State == Established { + fmt.Println("socket entry found") + + // make ack header + tcpHdr := &header.TCPFields{ + SrcPort: tcpHdr.DstPort, + DstPort: tcpHdr.SrcPort, + SeqNum: tcpHdr.SeqNum, + AckNum: tcpHdr.SeqNum + 1, + DataOffset: 20, + Flags: 0x10, + WindowSize: MAX_WINDOW_SIZE, + Checksum: 0, + UrgentPointer: 0, + } + // make the payload + err := SendTCP(tcpHdr, message, hdr.Dst, hdr.Src) + if err != nil { + fmt.Println(err) + } + socketEntry.Conn.RecvBuffer.buffer = append(socketEntry.Conn.RecvBuffer.buffer, tcpPayload...) + socketEntry.Conn.RecvBuffer.recvNext += uint32(len(tcpPayload)) + break + } + // add to table if available + mapMutex.Lock() + for _, socketEntry := range VHostSocketMaps { + // todo: check between all 4 field in tuple + if socketEntry.LocalPort == tcpHdr.DstPort && socketEntry.LocalIP == hdr.Dst.String() && socketEntry.State == Listening{ + // add a new socketEntry to the map + newEntry := &SocketEntry{ + LocalPort: tcpHdr.DstPort, + RemotePort: tcpHdr.SrcPort, + LocalIP: hdr.Dst.String(), + RemoteIP: hdr.Src.String(), + State: SYNRECIEVED, + Socket: socketsMade, + } + // add the entry to the map + key := SocketKey{hdr.Dst.String(), tcpHdr.DstPort, hdr.Src.String(), tcpHdr.SrcPort} + VHostSocketMaps[key] = newEntry + socketsMade += 1 + // add the entry to the map + available = true + break + } + } + mapMutex.Unlock() + + // if no socket is available, then drop the packet + if !available { + fmt.Println("no socket available") + return nil + } + // make the header + tcpHdr := &header.TCPFields{ + SrcPort: tcpHdr.DstPort, + DstPort: tcpHdr.SrcPort, + SeqNum: tcpHdr.SeqNum, + AckNum: tcpHdr.SeqNum + 1, + DataOffset: 20, + Flags: 0x12, + WindowSize: MAX_WINDOW_SIZE, + Checksum: 0, + UrgentPointer: 0, + } + // make the payload + synAckPayload := []byte{} + err := SendTCP(tcpHdr, synAckPayload, hdr.Dst, hdr.Src) + if err != nil { + fmt.Println(err) + } + break + case header.TCPFlagAck | header.TCPFlagSyn: + fmt.Println("I see a SYNACK flag") + // lookup for socket entry and update its state + mapMutex.Lock() + for _, socketEntry := range VHostSocketMaps { + if socketEntry.LocalPort == tcpHdr.DstPort && socketEntry.LocalIP == hdr.Dst.String() && socketEntry.State == SYNSENT { + socketEntry.State = Established + break + } + } + mapMutex.Unlock() + + // send an ACK + // make the header + tcpHdr := &header.TCPFields{ + SrcPort: tcpHdr.DstPort, + DstPort: tcpHdr.SrcPort, + SeqNum: tcpHdr.SeqNum + 1, + AckNum: tcpHdr.SeqNum, + DataOffset: 20, + Flags: 0x10, + WindowSize: MAX_WINDOW_SIZE, + Checksum: 0, + UrgentPointer: 0, + } + // make the payload + ackPayload := []byte{} + err := SendTCP(tcpHdr, ackPayload, hdr.Dst, hdr.Src) + if err != nil { + fmt.Println(err) + } + break + case header.TCPFlagAck: + fmt.Println("I see an ACK flag") + // lookup for socket entry and update its state + // set synChan to true (TODO) + key := SocketKey{hdr.Dst.String(), tcpHdr.DstPort, hdr.Src.String(), tcpHdr.SrcPort} + socketEntry, in := VHostSocketMaps[key] + if !in { + fmt.Println("no socket entry found") + } else if (socketEntry.State == Established) { + fmt.Println("socket entry found") + // socketEntry.Conn.RecvBuffer.buffer = append(socketEntry.Conn.RecvBuffer.buffer, tcpPayload...) + socketEntry.Conn.SendBuffer.una += uint32(len(tcpPayload)) + break + } + + mapMutex.Lock() + for _, socketEntry := range VHostSocketMaps { + if socketEntry.LocalPort == tcpHdr.DstPort && socketEntry.LocalIP == hdr.Dst.String() && socketEntry.State == SYNRECIEVED { + socketEntry.State = Established + break + } + } + mapMutex.Unlock() + break + default: + fmt.Println("I see a non TCP packet") + break + } + + + return nil +} + +// *********************************************** HELPERS ********************************************************** + +// Route returns the next HOP, based on longest prefix match for a given ip +// TODO: revisit how to do this at the bit level, not hardcoded for 32 & 24 +func Route(src netip.Addr) (Hop, error) { + possibleBits := [2]int{32, 24} + for _, bits := range possibleBits { + cmpPrefix := netip.PrefixFrom(src, bits) + for prefix, hop := range routingTable { + if cmpPrefix.Overlaps(prefix) { + return hop, nil + } + } + } + return Hop{}, errors.Errorf("error ROUTE: destination %s does not exist on routing table.", src) +} + +// MakeRipMessage returns the byte array to be used in SendIp for a RIP packet +func MakeRipMessage(command uint16, entries []RIPEntry) []byte { + if command == 1 { // request message + buf := make([]byte, SIZE_OF_RIP_HEADER) + binary.BigEndian.PutUint16(buf[0:2], command) + binary.BigEndian.PutUint16(buf[2:4], uint16(0)) + return buf + } + + // command == 2, response message + + // create the buffer + bufLen := SIZE_OF_RIP_HEADER + // sizeof uint16 is 2, we have two of them + len(entries)*SIZE_OF_RIP_ENTRY // each entry is 12 + + buf := make([]byte, bufLen) + + // fill in the header + binary.BigEndian.PutUint16(buf[0:2], command) + binary.BigEndian.PutUint16(buf[2:4], uint16(len(entries))) + + // fill in the entries + for i, entry := range entries { + offset := SIZE_OF_RIP_HEADER + i*SIZE_OF_RIP_ENTRY + binary.BigEndian.PutUint32(buf[offset:offset+4], entry.cost) // 0-3 = 4 bytes + copy(buf[offset+4:offset+8], entry.prefix.Addr().AsSlice()) // 4-7 = 4 bytes + + // convert the prefix to a uint32 + ipv4Netmask := uint32(0xffffffff) + ipv4Netmask <<= 32 - entry.prefix.Bits() + binary.BigEndian.PutUint32(buf[offset+8:offset+12], ipv4Netmask) + } + + return buf +} + +// ************************************** GETTER FUNCTIONS ********************************************************** +func GetInterfaceByName(ifaceName string) (*Interface, error) { + // iterate through the interfaces and return the one with the same name + for _, iface := range myInterfaces { + if iface.Name == ifaceName { + return iface, nil + } + } + return nil, errors.Errorf("No interface with name %s", ifaceName) +} + +func GetInterfaces() []*Interface { + return myInterfaces +} + +func GetNeighbors() map[string][]*Neighbor { + return myNeighbors +} + +func GetRoutes() map[netip.Prefix]Hop { + return routingTable +} + +// ************************************** PRINT FUNCTIONS ********************************************************** + +// SprintInterfaces returns a string representation of the interfaces data structure +func SprintInterfaces() string { + tmp := "" + for _, iface := range myInterfaces { + if iface.State { + // if the state is up, print UP + tmp += fmt.Sprintf("%s\t%s\t%s\n", iface.Name, iface.IpPrefix.String(), "UP") + } else { + // if the state is down, print DOWN + tmp += fmt.Sprintf("%s\t%s\t%s\n", iface.Name, iface.IpPrefix.String(), "DOWN") + } + } + return tmp +} + +// SprintNeighbors returns a string representation of the neighbors data structure +func SprintNeighbors() string { + tmp := "" + for _, iface := range myInterfaces { + if !iface.State { + // if the interface is down, skip it + continue + } + for _, n := range myNeighbors[iface.Name] { + tmp += fmt.Sprintf("%s\t%s\t%s\n", iface.Name, n.VipAddr.String(), n.UdpAddr.String()) + } + } + return tmp +} + +// SprintRoutingTable returns a string representation of the routing table +func SprintRoutingTable() string { + tmp := "" + for prefix, hop := range routingTable { + if hop.Type == "L" { + // if the hop is local, print LOCAL + tmp += fmt.Sprintf("%s\t%s\tLOCAL:%s\t%d\n", hop.Type, prefix.String(), hop.Interface.Name, hop.Cost) + } else if hop.Type == "S" { + // if the hop is static, don't print the cost + tmp += fmt.Sprintf("%s\t%s\t%s\t%s\n", hop.Type, prefix.String(), hop.VIP.String(), "-") + } else { + tmp += fmt.Sprintf("%s\t%s\t%s\t%d\n", hop.Type, prefix.String(), hop.VIP.String(), hop.Cost) + } + } + return tmp +} + +// ************************************** CLEANUP FUNCTIONS ********************************************************** + +// CleanUp cleans up the data structures and closes the UDP sockets +func CleanUp() { + fmt.Print("Cleaning up...\n") + + // go through the interfaces, pop thread & close the UDP FDs + for _, iface := range myInterfaces { + // close the channel + if iface.SocketChannel != nil { + close(iface.SocketChannel) + } + // close the UDP FD + err := iface.Socket.Close() + if err != nil { + continue + } + } + + // delete all the neighbors + myNeighbors = make(map[string][]*Neighbor) + // delete all the interfaces + myInterfaces = nil + // delete the routing table + routingTable = make(map[netip.Prefix]Hop) + + time.Sleep(5 * time.Millisecond) +} + +// ************************************** TCP FUNCTIONS ********************************************************** + +type ConnectionState string +const ( + Established ConnectionState = "ESTABLISHED" + Listening ConnectionState = "LISTENING" + Closed ConnectionState = "CLOSED" + SYNSENT ConnectionState = "SYNSENT" + SYNRECIEVED ConnectionState = "SYNRECIEVED" + MAX_WINDOW_SIZE = 65535 +) + +// VTCPListener represents a listener socket (similar to Go’s net.TCPListener) +type VTCPListener struct { + LocalAddr string + LocalPort uint16 + RemoteAddr string + RemotePort uint16 + Socket int + State ConnectionState +} + +// // VTCPConn represents a “normal” socket for a TCP connection between two endpoints (similar to Go’s net.TCPConn) +type VTCPConn struct { + LocalAddr string + LocalPort uint16 + RemoteAddr string + RemotePort uint16 + Socket int + State ConnectionState + SendBuffer *SendBuffer + RecvBuffer *RecvBuffer +} + +type SocketEntry struct { + Socket int + LocalIP string + LocalPort uint16 + RemoteIP string + RemotePort uint16 + State ConnectionState + Conn *VTCPConn +} + +type SocketKey struct { + LocalIP string + LocalPort uint16 + RemoteIP string + RemotePort uint16 +} + +type RecvBuffer struct { + recvNext uint32 + lbr uint32 + buffer []byte +} + +type SendBuffer struct { + una uint32 + nxt uint32 + lbr uint32 + buffer []byte +} + +// create a socket map +// var VHostSocketMaps = make(map[int]*SocketEntry) +var VHostSocketMaps = make(map[SocketKey]*SocketEntry) +// create a channel map +var VHostChannelMaps = make(map[int]chan []byte) +var mapMutex = &sync.Mutex{} +var socketsMade = 0 +var startingSeqNum = rand.Uint32() + +// Listen Sockets +func VListen(port uint16) (*VTCPListener, error) { + myIP := GetInterfaces()[0].IpPrefix.Addr() + listener := &VTCPListener{ + Socket: socketsMade, + State: Listening, + LocalPort: port, + LocalAddr: myIP.String(), + } + + // add the socket to the socket map + mapMutex.Lock() + + key := SocketKey{myIP.String(), port, "", 0} + VHostSocketMaps[key] = &SocketEntry{ + Socket: socketsMade, + LocalIP: myIP.String(), + LocalPort: port, + RemoteIP: "0.0.0.0", + RemotePort: 0, + State: Listening, + } + mapMutex.Unlock() + socketsMade += 1 + return listener, nil + +} + +func (l *VTCPListener) VAccept() (*VTCPConn, error) { + // synChan = make(chan bool) + for { + // wait for a SYN request + mapMutex.Lock() + for _, socketEntry := range VHostSocketMaps { + if socketEntry.State == Established { + // create a new VTCPConn + conn := &VTCPConn{ + LocalAddr: socketEntry.LocalIP, + LocalPort: socketEntry.LocalPort, + RemoteAddr: socketEntry.RemoteIP, + RemotePort: socketEntry.RemotePort, + Socket: socketEntry.Socket, + State: Established, + SendBuffer: &SendBuffer{ + una: 0, + nxt: 0, + lbr: 0, + buffer: make([]byte, MAX_WINDOW_SIZE), + }, + RecvBuffer: &RecvBuffer{ + recvNext: 0, + lbr: 0, + buffer: make([]byte, MAX_WINDOW_SIZE), + }, + } + socketEntry.Conn = conn + mapMutex.Unlock() + return conn, nil + } + } + mapMutex.Unlock() + + } +} + +func GetRandomPort() uint16 { + const ( + minDynamicPort = 49152 + maxDynamicPort = 65535 + ) + return uint16(rand.Intn(maxDynamicPort - minDynamicPort) + minDynamicPort) +} + +func VConnect(ip string, port uint16) (*VTCPConn, error) { + // get my ip address + myIP := GetInterfaces()[0].IpPrefix.Addr() + // get random port + portRand := GetRandomPort() + + tcpHdr := &header.TCPFields{ + SrcPort: portRand, + DstPort: port, + SeqNum: startingSeqNum, + AckNum: 0, + DataOffset: 20, + Flags: header.TCPFlagSyn, + WindowSize: MAX_WINDOW_SIZE, + Checksum: 0, + UrgentPointer: 0, + } + payload := []byte{} + ipParsed, err := netip.ParseAddr(ip) + if err != nil { + return nil, err + } + + err = SendTCP(tcpHdr, payload, myIP, ipParsed) + if err != nil { + return nil, err + } + + conn := &VTCPConn{ + LocalAddr: myIP.String(), + LocalPort: portRand, + RemoteAddr: ip, + RemotePort: port, + Socket: socketsMade, + State: Established, + SendBuffer: &SendBuffer{ + una: 0, + nxt: 0, + lbr: 0, + buffer: make([]byte, MAX_WINDOW_SIZE), + }, + RecvBuffer: &RecvBuffer{ + recvNext: 0, + lbr: 0, + buffer: make([]byte, MAX_WINDOW_SIZE), + }, + } + + // add the socket to the socket map + key := SocketKey{myIP.String(), portRand, ip, port} + mapMutex.Lock() + VHostSocketMaps[key] = &SocketEntry{ + Socket: socketsMade, + LocalIP: myIP.String(), + LocalPort: portRand, + RemoteIP: ip, + RemotePort: port, + State: SYNSENT, + Conn: conn, + } + mapMutex.Unlock() + socketsMade += 1 + + return conn, nil +} + +func SendTCP(tcpHdr *header.TCPFields, payload []byte, myIP netip.Addr, ipParsed netip.Addr) error { + checksum := iptcp_utils.ComputeTCPChecksum(tcpHdr, myIP, ipParsed, payload) + tcpHdr.Checksum = checksum + + tcpHeaderBytes := make(header.TCP, iptcp_utils.TcpHeaderLen) + tcpHeaderBytes.Encode(tcpHdr) + + ipPacketPayload := make([]byte, 0, len(tcpHeaderBytes)+len(payload)) + ipPacketPayload = append(ipPacketPayload, tcpHeaderBytes...) + ipPacketPayload = append(ipPacketPayload, []byte(payload)...) + + // lookup neighbor + address := ipParsed + hop, err := Route(address) + if err != nil { + fmt.Println(err) + return err + } + myAddr := hop.Interface.IpPrefix.Addr() + + for _, neighbor := range GetNeighbors()[hop.Interface.Name] { + if neighbor.VipAddr == address || + neighbor.VipAddr == hop.VIP && hop.Type == "S" { + bytesWritten, err := SendIP(&myAddr, neighbor, TCP_PROTOCOL, ipPacketPayload, ipParsed.String(), nil) + fmt.Printf("Sent %d bytes to %s\n", bytesWritten, neighbor.VipAddr.String()) + if err != nil { + fmt.Println(err) + } + } + } + return nil +} + +func SprintSockets() string { + tmp := "" + for _, socket := range VHostSocketMaps { + // remove the spaces of the local and remote ip variables + socket.LocalIP = strings.ReplaceAll(socket.LocalIP, " ", "") + socket.RemoteIP = strings.ReplaceAll(socket.RemoteIP, " ", "") + if socket.RemotePort == 0 { + tmp += fmt.Sprintf("%d\t%s\t%d\t%s\t\t%d\t%s\n", socket.Socket, socket.LocalIP, socket.LocalPort, socket.RemoteIP, socket.RemotePort, socket.State) + continue + } + tmp += fmt.Sprintf("%d\t%s\t%d\t%s\t%d\t%s\n", socket.Socket, socket.LocalIP, socket.LocalPort, socket.RemoteIP, socket.RemotePort, socket.State) + } + return tmp +} + +// MILESTONE 2 +func (c *VTCPConn) VClose() error { + // check if the socket is in the map + key := SocketKey{c.LocalAddr, c.LocalPort, c.RemoteAddr, c.RemotePort} + mapMutex.Lock() + socketEntry, in := VHostSocketMaps[key] + mapMutex.Unlock() + if !in { + return errors.Errorf("error VClose: socket %d does not exist", c.Socket) + } + + // change the state to closed + socketEntry.State = Closed + return nil +} + + +// advertise window = max window size - (next - 1 - lbr) + +// early arrivals queue +var earlyArrivals = make([][]byte, 0) + +// retranmission queue +var retransmissionQueue = make([][]byte, 0) + +func (c *VTCPConn) VWrite(payload []byte) (int, error) { + // check if the socket is in the map + key := SocketKey{c.LocalAddr, c.LocalPort, c.RemoteAddr, c.RemotePort} + mapMutex.Lock() + socketEntry, in := VHostSocketMaps[key] + mapMutex.Unlock() + if !in { + return 0, errors.Errorf("error VWrite: socket %d does not exist", c.Socket) + } + + // check if the state is established + if socketEntry.State != Established { + return 0, errors.Errorf("error VWrite: socket %d is not in established state", c.Socket) + } + + // check if the payload is empty + if len(payload) == 0 { + return 0, nil + } + + // check if the payload is larger than the window size + if len(payload) > MAX_WINDOW_SIZE { + return 0, errors.Errorf("error VWrite: payload is larger than the window size") + } + + // check if the payload is larger than the available window size + if len(payload) > int(MAX_WINDOW_SIZE - (c.SendBuffer.nxt - 1 - c.SendBuffer.lbr)) { + return 0, errors.Errorf("error VWrite: payload is larger than the available window size") + } + + // make the header + advertisedWindow := MAX_WINDOW_SIZE - (c.SendBuffer.nxt - 1 - c.SendBuffer.lbr) + tcpHdr := &header.TCPFields{ + SrcPort: c.LocalPort, + DstPort: c.RemotePort, + SeqNum: c.SendBuffer.nxt, + AckNum: c.SendBuffer.una, + DataOffset: 20, + Flags: header.TCPFlagSyn, + WindowSize: uint16(advertisedWindow), + Checksum: 0, + UrgentPointer: 0, + } + + myIP := GetInterfaces()[0].IpPrefix.Addr() + ipParsed, err := netip.ParseAddr(c.RemoteAddr) + if err != nil { + return 0, err + } + + err = SendTCP(tcpHdr, payload, myIP, ipParsed) + if err != nil { + return 0, err + } + // update the next sequence number + // c.SendBuffer.nxt += uint32(len(payload)) + + + c.SendBuffer.lbr += uint32(len(payload)) + return len(payload), nil +} + + +func (c *VTCPConn) VRead(numBytesToRead int) (int, string, error) { + // check if the socket is in the map + key := SocketKey{c.LocalAddr, c.LocalPort, c.RemoteAddr, c.RemotePort} + // mapMutex.Lock() + socketEntry, in := VHostSocketMaps[key] + // mapMutex.Unlock() + // check if the socket is in the map + if !in { + return 0, "", errors.Errorf("error VRead: socket %d does not exist", c.Socket) + } + + // check if the state is established + if socketEntry.State != Established { + return 0, "", errors.Errorf("error VRead: socket %d is not in established state", c.Socket) + } + fmt.Println("I am in VRead") + fmt.Println("I have", c.RecvBuffer.recvNext - c.RecvBuffer.lbr, "bytes to read") + fmt.Println(c.RecvBuffer.recvNext, c.RecvBuffer.lbr) + if (c.RecvBuffer.lbr < c.RecvBuffer.recvNext && c.RecvBuffer.recvNext - c.RecvBuffer.lbr >= uint32(numBytesToRead)) { + fmt.Println("I have enough data to read") + toReturn := string(socketEntry.Conn.RecvBuffer.buffer[c.RecvBuffer.lbr:c.RecvBuffer.lbr+uint32(numBytesToRead)]) + // update the last byte read + c.RecvBuffer.lbr += uint32(numBytesToRead) + // return the data + return numBytesToRead, toReturn, nil + } + + return 0, "", nil +}
\ No newline at end of file diff --git a/pkg/ipstack/ipstack_test.go b/pkg/ipstack/ipstack_test.go new file mode 100644 index 0000000..e782f67 --- /dev/null +++ b/pkg/ipstack/ipstack_test.go @@ -0,0 +1,301 @@ +package ipstack + +import ( + "fmt" + "net/netip" + "testing" +) + +//func TestInitialize(t *testing.T) { +// lnxFilePath := "../../doc-example/r2.lnx" +// err := Initialize(lnxFilePath) +// if err != nil { +// t.Error(err) +// } +// fmt.Printf("Interfaces:\n%s\n\n", SprintInterfaces()) +// fmt.Printf("Neighbors:\n%s\n", SprintNeighbors()) +// fmt.Printf("RoutingTable:\n%s\n", SprintRoutingTable()) +// +// fmt.Println("TestInitialize successful") +// t.Cleanup(func() { CleanUp() }) +//} +// +//func TestInterfaceUpThenDown(t *testing.T) { +// lnxFilePath := "../../doc-example/r2.lnx" +// err := Initialize(lnxFilePath) +// if err != nil { +// t.Error(err) +// } +// +// iface, err := GetInterfaceByName("if0") +// if err != nil { +// t.Error(err) +// } +// +// InterfaceUp(iface) +// if iface.State == false { +// t.Error("iface state should be true") +// } +// +// fmt.Printf("Interfaces:\n%s\n", SprintInterfaces()) +// +// time.Sleep(5 * time.Millisecond) // allow time to print +// +// InterfaceDown(iface) +// if iface.State == true { +// t.Error("iface state should be false") +// } +// +// time.Sleep(5 * time.Millisecond) // allow time to print +// +// fmt.Printf("Interfaces:\n%s\n", SprintInterfaces()) +// +// fmt.Println("TestInterfaceUpThenDown successful") +// t.Cleanup(func() { CleanUp() }) +//} +// +//func TestInterfaceUpThenDownTwice(t *testing.T) { +// lnxFilePath := "../../doc-example/r2.lnx" +// err := Initialize(lnxFilePath) +// if err != nil { +// t.Error(err) +// } +// +// iface, err := GetInterfaceByName("if0") +// if err != nil { +// t.Error(err) +// } +// +// InterfaceUp(iface) +// if iface.State == false { +// t.Error("iface state should be true") +// } +// +// fmt.Printf("Interfaces:\n%s\n", SprintInterfaces()) +// +// time.Sleep(5 * time.Millisecond) // allow time to print +// +// fmt.Println("putting interface down") +// InterfaceDown(iface) +// if iface.State == true { +// t.Error("iface state should be false") +// } +// +// time.Sleep(3 * time.Millisecond) +// +// fmt.Println("putting interface back up for 3 iterations") +// InterfaceUp(iface) +// if iface.State == false { +// t.Error("iface state should be true") +// } +// time.Sleep(3 * time.Millisecond) // allow time to print +// +// fmt.Println("putting interface down") +// InterfaceDown(iface) +// if iface.State == true { +// t.Error("iface state should be false") +// } +// +// time.Sleep(5 * time.Millisecond) // allow time to print +// +// fmt.Printf("Interfaces:\n%s\n", SprintInterfaces()) +// +// fmt.Println("TestInterfaceUpThenDownTwice successful") +// t.Cleanup(func() { CleanUp() }) +//} +// +//func TestSendIPToNeighbor(t *testing.T) { +// lnxFilePath := "../../doc-example/r2.lnx" +// err := Initialize(lnxFilePath) +// if err != nil { +// t.Error(err) +// } +// +// // get the first neighbor of this interface +// iface, err := GetInterfaceByName("if0") +// if err != nil { +// t.Error(err) +// } +// neighbors, err := GetNeighborsToInterface("if0") +// if err != nil { +// t.Error(err) +// } +// +// // setup a neighbor listener socket +// testNeighbor := neighbors[0] +// // close the socket so we can listen on it +// err = testNeighbor.SendSocket.Close() +// if err != nil { +// t.Error(err) +// } +// +// fmt.Printf("Interfaces:\n%s\n", SprintInterfaces()) +// fmt.Printf("Neighbors:\n%s\n", SprintNeighbors()) +// +// listenString := testNeighbor.UdpAddr.String() +// fmt.Println("listening on " + listenString) +// listenAddr, err := net.ResolveUDPAddr("udp4", listenString) +// if err != nil { +// t.Error(err) +// } +// recvSocket, err := net.ListenUDP("udp4", listenAddr) +// if err != nil { +// t.Error(err) +// } +// testNeighbor.SendSocket = *recvSocket +// +// sent := false +// go func() { +// buffer := make([]byte, MAX_IP_PACKET_SIZE) +// fmt.Println("wating to read from UDP socket") +// _, sourceAddr, err := recvSocket.ReadFromUDP(buffer) +// if err != nil { +// t.Error(err) +// } +// fmt.Println("read from UDP socket") +// hdr, err := ipv4header.ParseHeader(buffer) +// if err != nil { +// t.Error(err) +// } +// headerSize := hdr.Len +// headerBytes := buffer[:headerSize] +// checksumFromHeader := uint16(hdr.Checksum) +// computedChecksum := ValidateChecksum(headerBytes, checksumFromHeader) +// +// var checksumState string +// if computedChecksum == checksumFromHeader { +// checksumState = "OK" +// } else { +// checksumState = "FAIL" +// } +// message := buffer[headerSize:] +// fmt.Printf("Received IP packet from %s\nHeader: %v\nChecksum: %s\nMessage: %s\n", +// sourceAddr.String(), hdr, checksumState, string(message)) +// if err != nil { +// t.Error(err) +// } +// +// sent = true +// }() +// +// time.Sleep(10 * time.Millisecond) +// +// // send a message to the neighbor +// fmt.Printf("sending message to neighbor\t%t\n", sent) +// err = SendIP(*iface, *testNeighbor, 0, []byte("You are my firest neighbor!")) +// if err != nil { +// t.Error(err) +// } +// +// fmt.Printf("SENT message to neighbor\t%t\n", sent) +// // give a little time for the message to be sent +// time.Sleep(1000 * time.Millisecond) +// if !sent { +// t.Error("Message not sent") +// t.Fail() +// } +// +// fmt.Println("TestSendIPToNeighbor successful") +// t.Cleanup(func() { CleanUp() }) +//} +// +//func TestRecvIP(t *testing.T) { +// lnxFilePath := "../../doc-example/r2.lnx" +// err := Initialize(lnxFilePath) +// if err != nil { +// t.Error(err) +// } +// +// // get the first neighbor of this interface to RecvIP from +// iface, err := GetInterfaceByName("if0") +// if err != nil { +// t.Error(err) +// } +// InterfaceUp(iface) +// +// // setup a random socket to send an ip packet from +// listenAddr, err := net.ResolveUDPAddr("udp4", "127.0.0.1:6969") +// sendSocket, err := net.ListenUDP("udp4", listenAddr) +// +// // send a message to the neighbor +// ifaceAsNeighbor := Neighbor{ +// VipAddr: iface.IpPrefix.Addr(), +// UdpAddr: iface.UdpAddr, +// SendSocket: iface.RecvSocket, +// SocketChannel: iface.SocketChannel, +// } +// fakeIface := Interface{ +// Name: "if69", +// IpPrefix: netip.MustParsePrefix("10.69.0.1/24"), +// UdpAddr: netip.MustParseAddrPort("127.0.0.1:6969"), +// RecvSocket: net.UDPConn{}, +// SocketChannel: nil, +// State: true, +// } +// err = SendIP(fakeIface, ifaceAsNeighbor, 0, []byte("hello")) +// if err != nil { +// return +// } +// +// time.Sleep(10 * time.Millisecond) +// +// // TODO: potenially make this a channel, so it actually checks values. +// // For now, you must read the message from the console. +// +// err = sendSocket.Close() +// if err != nil { +// t.Error(err) +// } +// t.Cleanup(func() { CleanUp() }) +//} + +func TestIntersect(t *testing.T) { + net1 := netip.MustParsePrefix("10.0.0.0/24") + net2 := netip.MustParsePrefix("1.1.1.2/24") + net3 := netip.MustParsePrefix("1.0.0.1/24") + net4 := netip.MustParsePrefix("0.0.0.0/0") // default route + net5 := netip.MustParsePrefix("10.2.0.3/32") + + res00 := intersect(net5, net1) + if res00 { + t.Error("net5 -> net1 should not intersect") + t.Fail() + } + res01 := intersect(net5, net4) + if !res01 { + t.Error("net5 -> net4 should intersect") + t.Fail() + } + + res6 := intersect(net1, net3) + if res6 { + t.Error("net1 and net3 should not intersect") + t.Fail() + } + res2 := intersect(net2, net3) + if res2 { + t.Error("net2 and net3 should not intersect") + t.Fail() + } + res3 := intersect(net1, net4) + if !res3 { + t.Error("net1 and net4 should intersect") + t.Fail() + } + res4 := intersect(net2, net4) + if !res4 { + t.Error("net2 and net4 should intersect") + t.Fail() + } + res5 := intersect(net3, net4) + if !res5 { + t.Error("net3 and net4 should intersect") + t.Fail() + } + + fmt.Println("TestIntersect successful") +} + +func intersect(n1, n2 netip.Prefix) bool { + return n1.Overlaps(n2) +} |