// Package ipstack implements a small virtual IP layer on top of UDP "links".
// It owns the node's interfaces, neighbors and routing table, sends and
// forwards IPv4 packets, and exchanges RIP-style route advertisements.
package ipstack

import (
	"encoding/binary"
	"fmt"
	"log"
	"net"
	"net/netip"
	"sync"
	"sync/atomic"
	"time"

	ipv4header "github.com/brown-csci1680/iptcp-headers"
	"github.com/google/netstack/tcpip/header"
	"github.com/pkg/errors"

	"iptcp/pkg/lnxconfig"
)

const (
	// MAX_IP_PACKET_SIZE is the size of the receive buffer and therefore the
	// largest packet (header + payload) this node sends or accepts.
	MAX_IP_PACKET_SIZE = 1400
	// LOCAL_COST is the cost stored for routes created by Initialize.
	LOCAL_COST uint32 = 0
	// STATIC_COST is 2^32 - 1. RIP entries carrying this cost are ignored
	// because adding one hop to it would wrap around to zero.
	STATIC_COST uint32 = 4294967295
	// MaxEntries is the largest number of entries put into one RIP message.
	MaxEntries = 64
	// INFINITY is the cost advertised for routes that were themselves learned
	// through RIP.
	INFINITY = 15
)

const (
	testProtocol = 0  // protocol number of printable test packets
	ripProtocol  = 12 // protocol number of RIP messages

	ripCommandRequest  = 1
	ripCommandResponse = 2

	// Wire layout used by SendRIPMessage and the receive path: byte 0 is the
	// command, byte 1 the entry count, bytes 2-5 are left zero, and 12-byte
	// entries (address, mask, cost; big endian) follow.
	ripHeaderLen = 4 + 1 + 1
	ripEntryLen  = 12

	// neighborPrefixBits is the prefix length used for every neighbor and
	// RIP-learned route key.
	// TODO: revisit whether a fixed /24 is correct.
	neighborPrefixBits = 24

	defaultTTL = 32
)

// STRUCTS ---------------------------------------------------------------------

// Interface is one local virtual interface and the UDP socket it listens on.
type Interface struct {
	Name          string
	IpPrefix      netip.Prefix   // assigned virtual IP plus prefix length
	UdpAddr       netip.AddrPort // UDP address the interface is bound to
	RecvSocket    net.UDPConn
	SocketChannel chan bool // carries up (true) / down (false) signals to the listener
	State         bool      // true while the interface is up
}

// Neighbor is a directly reachable node and the UDP socket used to reach it.
type Neighbor struct {
	VipAddr       netip.Addr     // neighbor's virtual IP
	UdpAddr       netip.AddrPort // neighbor's UDP address
	SendSocket    net.UDPConn
	SocketChannel chan bool
}

// RIPMessage is a RIP command together with its route entries.
type RIPMessage struct {
	command    uint8
	numEntries uint8
	entries    []RIPEntry
}

// RIPEntry is a single advertised route. In this implementation the mask
// field carries the advertising node's own virtual IP, which receivers store
// as the next hop (see SendUpdates and the RIP receive path).
type RIPEntry struct {
	address uint32
	cost    uint32
	mask    uint32
}

// Hop is a routing table value.
//
// VipAsStr depends on Type: for "L" routes it is the local interface name,
// for "S" routes the static route's string form, and for "R" routes the
// virtual IP of the node that advertised the route.
type Hop struct {
	Cost     uint32
	VipAsStr string
	Type     string // "L", "S" or "R"; also used when printing
}

// HandlerFunc is the signature of a per-protocol packet handler.
// NOTE(review): nothing in this file registers or calls handlers; the meaning
// of the arguments is defined by callers elsewhere.
type HandlerFunc func(int, string, *[]byte) error

// GLOBAL VARIABLES (data structures) ------------------------------------------

var (
	// myVIP is a copy of the first interface listed in the config.
	myVIP Interface
	// myInterfaces holds every configured interface.
	myInterfaces []*Interface
	// myNeighbors maps a local interface name to the neighbors behind it.
	myNeighbors = make(map[string][]*Neighbor)

	protocolHandlers = make(map[uint16]HandlerFunc)

	// routingMu guards routingTable, which is touched by every interface
	// listener as well as by SendUpdates and CheckAndUpdateRoutingTable.
	routingMu    sync.RWMutex
	routingTable = make(map[netip.Prefix]Hop)

	// defaultRoute is the key under which static routes are stored.
	defaultRoute = netip.MustParsePrefix("0.0.0.0/0")
)

// createUDPConn opens a UDP socket for UdpAddr and stores it in *conn.
// When isN is false the socket is bound to UdpAddr (listening side); when isN
// is true the socket is connected to UdpAddr (sending side, for a neighbor).
//
// reference: https://github.com/brown-csci1680/lecture-examples/blob/main/ip-demo/cmd/udp-ip-recv/main.go
func createUDPConn(UdpAddr netip.AddrPort, conn *net.UDPConn, isN bool) error {
	if conn == nil {
		return errors.New("createUDPConn: nil connection")
	}
	addrString := UdpAddr.String()
	udpAddr, err := net.ResolveUDPAddr("udp4", addrString)
	if err != nil {
		return errors.WithMessage(err, "Error resolving address->\t"+addrString)
	}

	var sock *net.UDPConn
	if isN {
		sock, err = net.DialUDP("udp4", nil, udpAddr)
		if err != nil {
			return errors.WithMessage(err, "Could not dial UDP address->\t"+addrString)
		}
	} else {
		sock, err = net.ListenUDP("udp4", udpAddr)
		if err != nil {
			return errors.WithMessage(err, "Could not bind to UDP port->\t"+addrString)
		}
	}
	*conn = *sock
	return nil
}

// setRoute stores hop under prefix while holding the routing table lock.
func setRoute(prefix netip.Prefix, hop Hop) {
	routingMu.Lock()
	defer routingMu.Unlock()
	routingTable[prefix] = hop
}

// Initialize parses the lnx config at lnxFilePath, opens a listening socket
// and starts a listener goroutine for every interface, opens a sending socket
// for every neighbor, and seeds the routing table with neighbor ("L") and
// static ("S") routes. It returns the first error encountered; sockets opened
// before the failure are left open.
func Initialize(lnxFilePath string) error {
	lnxConfig, err := lnxconfig.ParseConfig(lnxFilePath)
	if err != nil {
		return errors.WithMessage(err, "Error parsing config file->\t"+lnxFilePath)
	}

	// 1) interfaces
	haveVIP := false
	for _, iface := range lnxConfig.Interfaces {
		i := &Interface{
			Name:          iface.Name,
			IpPrefix:      netip.PrefixFrom(iface.AssignedIP, iface.AssignedPrefix.Bits()),
			UdpAddr:       iface.UDPAddr,
			SocketChannel: make(chan bool),
			State:         true,
		}
		if err := createUDPConn(iface.UDPAddr, &i.RecvSocket, false); err != nil {
			return errors.WithMessage(err, "Error creating UDP socket for interface->\t"+iface.Name)
		}
		// The first interface doubles as the node's "own" address. It is
		// copied only after the socket exists so the copy is usable too.
		if !haveVIP {
			myVIP = *i
			haveVIP = true
		}
		go InterfaceListenerRoutine(i.RecvSocket, i.SocketChannel)
		myInterfaces = append(myInterfaces, i)
	}

	// 2) neighbors
	for _, neighbor := range lnxConfig.Neighbors {
		n := &Neighbor{
			VipAddr:       neighbor.DestAddr,
			UdpAddr:       neighbor.UDPAddr,
			SocketChannel: make(chan bool),
		}
		if err := createUDPConn(neighbor.UDPAddr, &n.SendSocket, true); err != nil {
			return errors.WithMessage(err, "Error creating UDP socket for neighbor->\t"+neighbor.DestAddr.String())
		}
		myNeighbors[neighbor.InterfaceName] = append(myNeighbors[neighbor.InterfaceName], n)
		setRoute(
			netip.PrefixFrom(neighbor.DestAddr, neighborPrefixBits),
			Hop{LOCAL_COST, neighbor.InterfaceName, "L"},
		)
	}

	// 3) static routes
	// NOTE(review): every static route is stored under 0.0.0.0/0, so only the
	// last one survives — confirm this is intended.
	for _, route := range lnxConfig.StaticRoutes {
		setRoute(defaultRoute, Hop{LOCAL_COST, route.String(), "S"})
	}
	return nil
}

// InterfaceListenerRoutine receives packets on socket until the socket fails
// (for example because it was closed) and concurrently applies the up/down
// signals arriving on signal. It returns once signal is closed.
//
// While the interface is down, received packets are discarded. Malformed or
// unroutable packets are logged and dropped without stopping the listener.
func InterfaceListenerRoutine(socket net.UDPConn, signal <-chan bool) {
	// Both flags are shared with the receive goroutine and accessed atomically.
	var (
		up     int32 = 1
		closed int32
	)

	local := "<unbound>"
	if addr := socket.LocalAddr(); addr != nil {
		local = addr.String()
	}
	fmt.Println("MAKING GO ROUTINE TO LISTEN:\t", local)

	go func() {
		defer fmt.Println("exiting go routine that listens on ", local)
		// handlePacket is synchronous and keeps no reference to the packet,
		// so one buffer can be reused for every read.
		buffer := make([]byte, MAX_IP_PACKET_SIZE)
		for atomic.LoadInt32(&closed) == 0 {
			n, _, err := socket.ReadFromUDP(buffer)
			if err != nil {
				fmt.Println("Error receiving IP packet", err)
				return
			}
			if atomic.LoadInt32(&up) == 0 {
				continue // interface is down: drop silently
			}
			if err := handlePacket(buffer[:n]); err != nil {
				log.Printf("dropping packet received on %s: %v", local, err)
			}
		}
	}()

	// Block on the channel instead of polling it.
	for sig := range signal {
		fmt.Println("received isUP SIGNAL with value", sig)
		if sig {
			atomic.StoreInt32(&up, 1)
		} else {
			atomic.StoreInt32(&up, 0)
		}
	}
	fmt.Println("channel closed, exiting")
	atomic.StoreInt32(&closed, 1)
}

// InterfaceUp marks iface as up and notifies its listener. The send blocks
// until the listener picks up the signal.
func InterfaceUp(iface *Interface) {
	if iface == nil {
		return
	}
	iface.State = true
	iface.SocketChannel <- true
}

// InterfaceUpREPL brings the interface called ifaceName up, printing a
// message if no such interface exists.
func InterfaceUpREPL(ifaceName string) {
	iface, err := GetInterfaceByName(ifaceName)
	if err != nil {
		fmt.Println("Error getting interface by name", err)
		return
	}
	InterfaceUp(iface)
}

// InterfaceDown notifies iface's listener that the interface is down and then
// marks it as down. The send blocks until the listener picks up the signal.
func InterfaceDown(iface *Interface) {
	if iface == nil {
		return
	}
	iface.SocketChannel <- false
	iface.State = false
}

// InterfaceDownREPL brings the interface called ifaceName down, printing a
// message if no such interface exists.
func InterfaceDownREPL(ifaceName string) {
	iface, err := GetInterfaceByName(ifaceName)
	if err != nil {
		fmt.Println("Error getting interface by name", err)
		return
	}
	InterfaceDown(iface)
}

// GetInterfaceByName returns the configured interface called ifaceName.
func GetInterfaceByName(ifaceName string) (*Interface, error) {
	for _, iface := range myInterfaces {
		if iface.Name == ifaceName {
			return iface, nil
		}
	}
	return nil, errors.Errorf("No interface with name %s", ifaceName)
}

// GetNeighborByIP returns the neighbor whose virtual IP prints as ipAddr.
func GetNeighborByIP(ipAddr string) (*Neighbor, error) {
	for _, neighbors := range myNeighbors {
		for _, neighbor := range neighbors {
			if neighbor.VipAddr.String() == ipAddr {
				return neighbor, nil
			}
		}
	}
	return nil, errors.Errorf("No neighbor with ip %s", ipAddr)
}

// routesTo returns the hops of every routing table entry whose prefix address
// equals dst. Matching is on the exact address, not on prefix containment.
func routesTo(dst netip.Addr) []Hop {
	dst = dst.Unmap()
	routingMu.RLock()
	defer routingMu.RUnlock()

	var hops []Hop
	for prefix, hop := range routingTable {
		if prefix.Addr().Unmap() == dst {
			hops = append(hops, hop)
		}
	}
	return hops
}

// GetRouteByIP returns the neighbor to use as next hop for ipAddr according
// to the routing table. Entries whose next hop is not a known neighbor are
// reported and skipped.
func GetRouteByIP(ipAddr string) (*Neighbor, error) {
	dst, err := netip.ParseAddr(ipAddr)
	if err != nil {
		return nil, errors.Errorf("No route to ip %s", ipAddr)
	}
	for _, hop := range routesTo(dst) {
		fmt.Println("found route", hop.VipAsStr)
		neighbor, err := GetNeighborByIP(hop.VipAsStr)
		if err != nil {
			fmt.Println("Error getting neighbors to interface", err)
			continue // fix with longest prefix matching?
		}
		return neighbor, nil
	}
	return nil, errors.Errorf("No route to ip %s", ipAddr)
}

// GetNeighborsToInterface returns the neighbors reachable through the
// interface called ifaceName.
func GetNeighborsToInterface(ifaceName string) ([]*Neighbor, error) {
	neighbors, ok := myNeighbors[ifaceName]
	if !ok {
		return nil, errors.Errorf("No interface with name %s", ifaceName)
	}
	return neighbors, nil
}

// GetMyVIP returns a copy of the node's first interface.
func GetMyVIP() Interface {
	return myVIP
}

// SprintInterfaces prints one "name, prefix, UP|DOWN" line per interface.
func SprintInterfaces() {
	for _, iface := range myInterfaces {
		state := "DOWN"
		if iface.State {
			state = "UP"
		}
		fmt.Printf("%s\t%s\t%s\n", iface.Name, iface.IpPrefix.String(), state)
	}
}

// SprintNeighbors prints one "interface, virtual IP, UDP address" line per
// neighbor.
func SprintNeighbors() {
	for ifaceName, neighbors := range myNeighbors {
		for _, n := range neighbors {
			fmt.Printf("%s\t%s\t%s\n", ifaceName, n.VipAddr.String(), n.UdpAddr.String())
		}
	}
}

// SprintRoutingTable prints one line per routing table entry. Local routes
// are shown with cost 0 and static routes without a cost.
func SprintRoutingTable() {
	routingMu.RLock()
	defer routingMu.RUnlock()

	for prefix, hop := range routingTable {
		switch hop.Type {
		case "L":
			fmt.Printf("%s\t%s\tLOCAL:%s\t%d\n", hop.Type, prefix.String(), hop.VipAsStr, 0)
		case "S":
			fmt.Printf("%s\t%s\t%s\t%s\n", hop.Type, prefix.String(), hop.VipAsStr, "-")
		default:
			fmt.Printf("%s\t%s\t%s\t%d\n", hop.Type, prefix.String(), hop.VipAsStr, hop.Cost)
		}
	}
}

// DebugNeighbors prints one "interface, UDP address, virtual IP" line per
// neighbor.
func DebugNeighbors() {
	for ifaceName, neighbors := range myNeighbors {
		for _, n := range neighbors {
			fmt.Printf("%s\t%s\t%s\n", ifaceName, n.UdpAddr.String(), n.VipAddr.String())
		}
	}
}

// func RemoveNeighbor(neighbor Neighbor) {
// 	// TODO: remove from routing table
// 	myRoutes := GetRoutes()
// 	for prefix, hop := range myRoutes {
// 		if hop.VipAsStr == neighbor.VipAddr.String() {
// 			delete(myRoutes, prefix)
// 		}
// 	}
// 	// TODO: remove from myNeighbors
// 	myNeighbors[neighbor.VipAddr.String()] = nil
// 	// TODO: close the UDP socket
// 	err := neighbor.SendSocket.Close()
// 	if err != nil {
// 		fmt.Println("Error closing UDP socket", err)
// 	}
// } // untested function above

// CleanUp closes every signal channel and UDP socket, clears the neighbor,
// interface and routing tables, and then waits briefly so the listener
// goroutines can observe the closed sockets.
func CleanUp() {
	fmt.Print("Cleaning up...\n")

	for _, iface := range myInterfaces {
		if iface.SocketChannel != nil {
			close(iface.SocketChannel)
		}
		// Best effort: a socket that fails to close cannot be recovered here.
		_ = iface.RecvSocket.Close()
	}
	for _, neighbors := range myNeighbors {
		for _, n := range neighbors {
			if n.SocketChannel != nil {
				close(n.SocketChannel)
			}
			// Best effort, as above.
			_ = n.SendSocket.Close()
		}
	}

	myNeighbors = make(map[string][]*Neighbor)
	myInterfaces = nil

	routingMu.Lock()
	routingTable = make(map[netip.Prefix]Hop)
	routingMu.Unlock()

	time.Sleep(5 * time.Millisecond)
}

// buildPacket returns an IPv4 packet (option-less header with a valid
// checksum, followed by payload) from src to dst.
func buildPacket(src, dst netip.Addr, protocol int, payload []byte) ([]byte, error) {
	hdr := ipv4header.IPv4Header{
		Version:  4,
		Len:      20, // Header length is always 20 when no IP options
		TOS:      0,
		TotalLen: ipv4header.HeaderLen + len(payload),
		ID:       0,
		Flags:    0,
		FragOff:  0,
		TTL:      defaultTTL,
		Protocol: protocol,
		Checksum: 0, // Must be 0 while the checksum is computed
		Src:      src,
		Dst:      dst,
		Options:  []byte{},
	}

	headerBytes, err := hdr.Marshal()
	if err != nil {
		return nil, err
	}
	hdr.Checksum = int(ComputeChecksum(headerBytes))
	headerBytes, err = hdr.Marshal()
	if err != nil {
		return nil, err
	}

	packet := make([]byte, 0, len(headerBytes)+len(payload))
	packet = append(packet, headerBytes...)
	return append(packet, payload...), nil
}

// SendIP wraps message in an IPv4 header addressed from src to destIP and
// writes it to dest's UDP socket.
//
// It returns an error if dest is nil, destIP is not a valid address, the
// packet would exceed MAX_IP_PACKET_SIZE, or the header cannot be marshalled
// or written.
//
// TODO: have it take TTL so we can decrement it when forwarding
func SendIP(src Interface, dest *Neighbor, protocolNum int, message []byte, destIP string) error {
	if dest == nil {
		return errors.New("SendIP: nil neighbor")
	}
	dst, err := netip.ParseAddr(destIP)
	if err != nil {
		return errors.WithMessage(err, "invalid destination address->\t"+destIP)
	}
	if size := ipv4header.HeaderLen + len(message); size > MAX_IP_PACKET_SIZE {
		return errors.Errorf("packet of %d bytes exceeds maximum of %d", size, MAX_IP_PACKET_SIZE)
	}

	packet, err := buildPacket(src.IpPrefix.Addr(), dst, protocolNum, message)
	if err != nil {
		return err
	}

	// Send the message to the "link-layer" addr:port on UDP
	bytesWritten, err := dest.SendSocket.Write(packet)
	if err != nil {
		return err
	}
	fmt.Printf("Sent %d bytes to %s\n", bytesWritten, dest.UdpAddr.String())
	return nil
}

// RecvIP reads one packet from conn and processes it: RIP messages update the
// routing table, test packets addressed to this node are printed, and
// everything else is forwarded towards its destination or dropped.
//
// It returns an error if the read fails, if *isOpen is false after the read
// ("interface is down"), if the header cannot be parsed, or if forwarding
// fails. A nil isOpen is treated as "up". Packets that are merely dropped
// yield a nil error.
func RecvIP(conn net.UDPConn, isOpen *bool) error {
	buffer := make([]byte, MAX_IP_PACKET_SIZE)
	n, _, err := conn.ReadFromUDP(buffer)
	if err != nil {
		return err
	}
	if isOpen != nil && !*isOpen {
		return errors.New("interface is down")
	}
	return handlePacket(buffer[:n])
}

// handlePacket validates and dispatches one received packet. pkt must hold
// exactly the bytes that were received; it is not retained after the call.
func handlePacket(pkt []byte) error {
	hdr, err := ipv4header.ParseHeader(pkt)
	if err != nil {
		return errors.WithMessage(err, "Error parsing header")
	}

	headerSize := hdr.Len
	if headerSize < ipv4header.HeaderLen || headerSize > len(pkt) {
		return nil // nonsensical header length: drop
	}

	// 1) drop packets with a bad checksum or an expired TTL
	checksumFromHeader := uint16(hdr.Checksum)
	if ValidateChecksum(pkt[:headerSize], checksumFromHeader) != checksumFromHeader {
		return nil
	}
	if hdr.TTL == 0 {
		return nil
	}

	// The payload ends where the header says it does, when that is plausible;
	// otherwise everything that was received is used.
	end := len(pkt)
	if hdr.TotalLen >= headerSize && hdr.TotalLen < end {
		end = hdr.TotalLen
	}
	message := pkt[headerSize:end]

	// 2) RIP is consumed locally regardless of the destination address
	if hdr.Protocol == ripProtocol {
		handleRIP(message)
		return nil
	}

	// 3) packets for one of this node's own addresses are delivered here and
	// never forwarded
	if isLocalAddr(hdr.Dst) {
		fmt.Println("for me")
		if hdr.Protocol == testProtocol {
			fmt.Printf("Received test packet: Src: %s, Dst: %s, TTL: %d, Data: %s\n",
				hdr.Src.String(), hdr.Dst.String(), hdr.TTL, string(message))
		}
		return nil
	}

	// 4) forward: decrement the TTL and drop the packet if that exhausts it
	hdr.TTL--
	if hdr.TTL <= 0 {
		return nil
	}
	next, err := nextHopFor(hdr.Dst)
	if err != nil {
		return err
	}
	if next == nil {
		return nil // no neighbor and no route: drop
	}

	// The checksum field must be zero while the new checksum is computed.
	hdr.TotalLen = headerSize + len(message)
	hdr.Checksum = 0
	headerBytes, err := hdr.Marshal()
	if err != nil {
		return err
	}
	hdr.Checksum = int(ComputeChecksum(headerBytes))
	headerBytes, err = hdr.Marshal()
	if err != nil {
		return err
	}

	out := make([]byte, 0, len(headerBytes)+len(message))
	out = append(out, headerBytes...)
	out = append(out, message...)

	// The packet already carries its own header, so it is written to the link
	// as is instead of being wrapped in a second header.
	if _, err := next.SendSocket.Write(out); err != nil {
		return errors.WithMessage(err, "Error sending IP packet")
	}
	return nil
}

// isLocalAddr reports whether dst is the address of myVIP or of any
// configured interface.
func isLocalAddr(dst netip.Addr) bool {
	dst = dst.Unmap()
	if dst == myVIP.IpPrefix.Addr().Unmap() {
		return true
	}
	for _, iface := range myInterfaces {
		if dst == iface.IpPrefix.Addr().Unmap() {
			return true
		}
	}
	return false
}

// nextHopFor picks the neighbor a packet for dst should be written to: dst
// itself when it is a direct neighbor, otherwise the next hop of a matching
// routing table entry. It returns (nil, nil) when nothing matches, and an
// error when routes match but none of their next hops is a known neighbor.
func nextHopFor(dst netip.Addr) (*Neighbor, error) {
	if neighbor, err := GetNeighborByIP(dst.String()); err == nil {
		fmt.Println("for neighbor")
		return neighbor, nil
	}

	var lastErr error
	for _, hop := range routesTo(dst) {
		neighbor, err := GetNeighborByIP(hop.VipAsStr)
		if err != nil {
			lastErr = errors.WithMessage(err, "Error getting neighbor by IP")
			continue
		}
		return neighbor, nil
	}
	return nil, lastErr
}

// handleRIP processes the payload of a RIP packet. Requests are currently
// ignored; responses add previously unknown routes to the routing table.
func handleRIP(message []byte) {
	if len(message) < 2 {
		return // too short to hold command and entry count
	}
	switch message[0] {
	case ripCommandRequest:
		// SendUpdates()
	case ripCommandResponse:
		installRIPRoutes(parseRIPEntries(message))
	}
}

// parseRIPEntries decodes the entries of a RIP message. Entries that the
// message announces but does not actually contain are ignored.
func parseRIPEntries(message []byte) []RIPEntry {
	if len(message) < 2 {
		return nil
	}
	count := int(message[1])
	entries := make([]RIPEntry, 0, count)
	for i := 0; i < count; i++ {
		offset := ripHeaderLen + i*ripEntryLen
		if offset+ripEntryLen > len(message) {
			break
		}
		entries = append(entries, RIPEntry{
			address: binary.BigEndian.Uint32(message[offset : offset+4]),
			mask:    binary.BigEndian.Uint32(message[offset+4 : offset+8]),
			cost:    binary.BigEndian.Uint32(message[offset+8 : offset+12]),
		})
	}
	return entries
}

// installRIPRoutes adds an "R" route, one hop more expensive than advertised,
// for every entry whose destination is not in the routing table yet. Existing
// routes are never replaced.
// NOTE(review): entries advertised with cost INFINITY are still installed;
// confirm whether they should be treated as unreachable instead.
func installRIPRoutes(entries []RIPEntry) {
	routingMu.Lock()
	defer routingMu.Unlock()

	for _, entry := range entries {
		if entry.cost == STATIC_COST {
			continue // cost + 1 would wrap around to 0
		}
		prefix := netip.PrefixFrom(uint32ToAddr(entry.address), neighborPrefixBits)
		if _, known := routingTable[prefix]; known {
			continue
		}
		// The mask field carries the advertiser's address (see RIPEntry).
		routingTable[prefix] = Hop{entry.cost + 1, uint32ToAddr(entry.mask).String(), "R"}
	}
}

// uint32ToAddr converts a big-endian IPv4 address to a netip.Addr.
func uint32ToAddr(v uint32) netip.Addr {
	var b [4]byte
	binary.BigEndian.PutUint32(b[:], v)
	return netip.AddrFrom4(b)
}

// addrToUint32 converts an IPv4 (or IPv4-mapped) address to its big-endian
// integer form. ok is false for any other address.
func addrToUint32(addr netip.Addr) (v uint32, ok bool) {
	addr = addr.Unmap()
	if !addr.Is4() {
		return 0, false
	}
	b := addr.As4()
	return binary.BigEndian.Uint32(b[:]), true
}

// ComputeChecksum returns the one's complement of the checksum over b, i.e.
// the value to store in an IPv4 header whose checksum field was zero in b.
func ComputeChecksum(b []byte) uint16 {
	return header.Checksum(b, 0) ^ 0xffff
}

// ValidateChecksum returns the checksum over b seeded with fromHeader. The
// receive path treats a header as valid when the result equals fromHeader.
func ValidateChecksum(b []byte, fromHeader uint16) uint16 {
	return header.Checksum(b, fromHeader)
}

// GetInterfaces returns the slice of configured interfaces.
func GetInterfaces() []*Interface {
	return myInterfaces
}

// GetNeighbors returns the live neighbor map, keyed by interface name.
func GetNeighbors() map[string][]*Neighbor {
	return myNeighbors
}

// GetRoutes returns the live routing table. The map is shared with the
// listener goroutines, so callers must not modify it.
func GetRoutes() map[netip.Prefix]Hop {
	return routingTable
}

// NewRIPMessage builds a RIP message carrying entries. The entry count is
// stored in a single byte, so it is only accurate for up to 255 entries.
func NewRIPMessage(command uint8, entries []RIPEntry) *RIPMessage {
	return &RIPMessage{
		command:    command,
		numEntries: uint8(len(entries)),
		entries:    entries,
	}
}

// marshalRIP encodes message using the layout described at ripHeaderLen.
func marshalRIP(message *RIPMessage) []byte {
	buf := make([]byte, ripHeaderLen+len(message.entries)*ripEntryLen)
	buf[0] = message.command
	buf[1] = message.numEntries
	for i, entry := range message.entries {
		offset := ripHeaderLen + i*ripEntryLen
		binary.BigEndian.PutUint32(buf[offset:offset+4], entry.address)
		binary.BigEndian.PutUint32(buf[offset+4:offset+8], entry.mask)
		binary.BigEndian.PutUint32(buf[offset+8:offset+12], entry.cost)
	}
	return buf
}

// SendRIPMessage sends message from src to the neighbor dest as an IPv4
// packet with protocol number 12. It returns an error if dest or message is
// nil, or if the packet cannot be marshalled or written.
func SendRIPMessage(src Interface, dest *Neighbor, message *RIPMessage) error {
	if dest == nil || message == nil {
		return errors.New("SendRIPMessage: nil neighbor or message")
	}
	packet, err := buildPacket(src.IpPrefix.Addr(), dest.VipAddr, ripProtocol, marshalRIP(message))
	if err != nil {
		return err
	}
	_, err = dest.SendSocket.Write(packet)
	return err
}

// advertisedRoutes returns one RIP entry (mask left zero) for every routing
// table entry that gets advertised: static routes and the 0.0.0.0 address are
// skipped, and routes learned through RIP are advertised with cost INFINITY.
func advertisedRoutes() []RIPEntry {
	routingMu.RLock()
	defer routingMu.RUnlock()

	entries := make([]RIPEntry, 0, len(routingTable))
	for prefix, hop := range routingTable {
		if hop.Type == "S" {
			continue
		}
		address, ok := addrToUint32(prefix.Addr())
		if !ok || address == 0 {
			continue
		}
		cost := hop.Cost
		if hop.Type == "R" {
			cost = INFINITY
		}
		entries = append(entries, RIPEntry{address: address, cost: cost})
	}
	return entries
}

// SendUpdates advertises the routing table to every neighbor. Each neighbor
// receives every route once, split into messages of at most MaxEntries
// entries. The mask field of each entry is set to the address of the local
// interface the neighbor is attached to (falling back to myVIP), because
// receivers use that address as the next hop. Send failures are ignored.
func SendUpdates() {
	routes := advertisedRoutes()
	if len(routes) == 0 {
		return
	}

	for ifaceName, neighbors := range myNeighbors {
		src := myVIP
		if iface, err := GetInterfaceByName(ifaceName); err == nil {
			src = *iface
		}
		self, ok := addrToUint32(src.IpPrefix.Addr())
		if !ok {
			continue
		}

		entries := make([]RIPEntry, len(routes))
		copy(entries, routes)
		for i := range entries {
			entries[i].mask = self
		}

		for _, neighbor := range neighbors {
			for start := 0; start < len(entries); start += MaxEntries {
				end := start + MaxEntries
				if end > len(entries) {
					end = len(entries)
				}
				// Best effort: an unreachable neighbor simply misses this
				// update and receives the next one.
				_ = SendRIPMessage(src, neighbor, NewRIPMessage(ripCommandResponse, entries[start:end]))
			}
		}
	}
}

// expireLearnedRoutes removes every "R" route from the routing table and
// reports whether anything was removed.
func expireLearnedRoutes() bool {
	routingMu.Lock()
	defer routingMu.Unlock()

	removed := false
	for prefix, hop := range routingTable {
		if hop.Type == "R" {
			delete(routingTable, prefix)
			removed = true
		}
	}
	return removed
}

// CheckAndUpdateRoutingTable runs forever: every 12 seconds it removes all
// routes learned through RIP and, if any were removed, advertises the
// remaining table. Routes carry no timestamp, so learned routes are dropped
// whether or not they were refreshed in the meantime. Run it in its own
// goroutine.
func CheckAndUpdateRoutingTable() {
	for {
		time.Sleep(12 * time.Second)
		if expireLearnedRoutes() {
			SendUpdates()
		}
	}
}