aboutsummaryrefslogtreecommitdiff
path: root/pkg/ipstack/ipstack.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/ipstack/ipstack.go')
-rw-r--r--pkg/ipstack/ipstack.go413
1 files changed, 413 insertions, 0 deletions
diff --git a/pkg/ipstack/ipstack.go b/pkg/ipstack/ipstack.go
new file mode 100644
index 0000000..be8bc1e
--- /dev/null
+++ b/pkg/ipstack/ipstack.go
@@ -0,0 +1,413 @@
+package ipstack
+
+import (
+ "fmt"
+ ipv4header "github.com/brown-csci1680/iptcp-headers"
+ "github.com/google/netstack/tcpip/header"
+ "github.com/pkg/errors"
+ "iptcp/pkg/lnxconfig"
+ "log"
+ "net"
+ "net/netip"
+ "time"
+)
+
+const (
+ MAX_IP_PACKET_SIZE = 1400
+ LOCAL_COST uint32 = 0
+ STATIC_COST uint32 = 4294967295 // 2^32 - 1
+)
+
+// STRUCTS ---------------------------------------------------------------------
+type Interface struct {
+ Name string
+ IpPrefix netip.Prefix
+ UdpAddr netip.AddrPort
+
+ RecvSocket net.UDPConn
+ SocketChannel chan bool
+ State bool
+}
+
+type Neighbor struct {
+ VipAddr netip.Addr
+ UdpAddr netip.AddrPort
+
+ SendSocket net.UDPConn
+ SocketChannel chan bool
+}
+
+type RIPMessage struct {
+ command uint8
+ numEntries uint8
+ entries []RIPEntry
+}
+
+type RIPEntry struct {
+ addr netip.Addr
+ cost uint32
+ mask netip.Prefix
+}
+
+type Hop struct {
+ Cost uint32
+ VipAsStr string
+}
+
+// GLOBAL VARIABLES (data structures) ------------------------------------------
+var myVIP netip.Addr
+var myInterfaces []*Interface
+var myNeighbors = make(map[string][]*Neighbor)
+
+// var myRIPNeighbors = make(map[string]Neighbor)
+type HandlerFunc func(int, string, *[]byte) error
+
+var protocolHandlers = make(map[uint16]HandlerFunc)
+
+// var routingTable = routingtable.New()
+var routingTable = make(map[netip.Prefix]Hop)
+
+// reference: https://github.com/brown-csci1680/lecture-examples/blob/main/ip-demo/cmd/udp-ip-recv/main.go
+func createUDPConn(UdpAddr netip.AddrPort, conn *net.UDPConn) error {
+ listenString := UdpAddr.String()
+ listenAddr, err := net.ResolveUDPAddr("udp4", listenString)
+ if err != nil {
+ return errors.WithMessage(err, "Error resolving address->\t"+listenString)
+ }
+ tmpConn, err := net.ListenUDP("udp4", listenAddr)
+ if err != nil {
+ return errors.WithMessage(err, "Could not bind to UDP port->\t"+listenString)
+ }
+ *conn = *tmpConn
+
+ return nil
+}
+
+func Initialize(lnxFilePath string) error {
+ //if len(os.Args) != 2 {
+ // fmt.Printf("Usage: %s <configFile>\n", os.Args[0])
+ // os.Exit(1)
+ //}
+ //lnxFilePath := os.Args[1]
+
+ // Parse the file
+ lnxConfig, err := lnxconfig.ParseConfig(lnxFilePath)
+ if err != nil {
+ return errors.WithMessage(err, "Error parsing config file->\t"+lnxFilePath)
+ }
+
+ // 1) initialize the interfaces on this node here and into the routing table
+ static := false
+ for _, iface := range lnxConfig.Interfaces {
+ prefix := netip.PrefixFrom(iface.AssignedIP, iface.AssignedPrefix.Bits())
+ i := &Interface{
+ Name: iface.Name,
+ IpPrefix: prefix,
+ UdpAddr: iface.UDPAddr,
+ RecvSocket: net.UDPConn{},
+ SocketChannel: make(chan bool),
+ State: false,
+ }
+
+ err := createUDPConn(iface.UDPAddr, &i.RecvSocket)
+ if err != nil {
+ return errors.WithMessage(err, "Error creating UDP socket for interface->\t"+iface.Name)
+ }
+ go InterfaceListenerRoutine(i.RecvSocket, i.SocketChannel)
+ myInterfaces = append(myInterfaces, i)
+
+ // TODO: (FOR HOSTS ONLY)
+ // add STATIC to routing table
+ if !static {
+ ifacePrefix := netip.MustParsePrefix("0.0.0.0/0")
+ routingTable[ifacePrefix] = Hop{STATIC_COST, iface.Name}
+ static = true
+ }
+ }
+
+ // 2) initialize the neighbors connected to the node and into the routing table
+ for _, neighbor := range lnxConfig.Neighbors {
+ n := &Neighbor{
+ VipAddr: neighbor.DestAddr,
+ UdpAddr: neighbor.UDPAddr,
+ SendSocket: net.UDPConn{},
+ SocketChannel: make(chan bool),
+ }
+
+ err := createUDPConn(neighbor.UDPAddr, &n.SendSocket)
+ if err != nil {
+ return errors.WithMessage(err, "Error creating UDP socket for neighbor->\t"+neighbor.DestAddr.String())
+ }
+
+ myNeighbors[neighbor.InterfaceName] = append(myNeighbors[neighbor.InterfaceName], n)
+
+ // add to routing table
+ // TODO: REVISIT AND SEE IF "24" IS CORRECT
+ neighborPrefix := netip.PrefixFrom(neighbor.DestAddr, 24)
+ routingTable[neighborPrefix] = Hop{LOCAL_COST, neighbor.InterfaceName}
+ }
+
+ return nil
+}
+
+func InterfaceListenerRoutine(socket net.UDPConn, signal <-chan bool) {
+ isUp := false
+ closed := false
+
+ // go routine that hangs on the recv
+ fmt.Println("MAKING GO ROUTINE TO LISTEN:\t", socket.LocalAddr().String())
+ go func() {
+ defer func() { // on close, set isUp to false
+ fmt.Println("exiting go routine that listens on ", socket.LocalAddr().String())
+ }()
+
+ for {
+ if closed { // stop this go routine if channel is closed
+ return
+ }
+ if !isUp { // don't call the listeners if interface is down
+ continue
+ }
+ // TODO: remove these "training wheels"
+ time.Sleep(1 * time.Millisecond)
+ err := RecvIP(socket, &isUp)
+ if err != nil {
+ fmt.Println("Error receiving IP packet", err)
+ return
+ }
+ }
+ }()
+
+ for {
+ select {
+ case sig, ok := <-signal:
+ if !ok {
+ fmt.Println("channel closed, exiting")
+ closed = true
+ return
+ }
+ fmt.Println("received isUP SIGNAL with value", sig)
+ isUp = sig
+ default:
+ continue
+ }
+ }
+}
+
+func InterfaceUp(iface *Interface) {
+ iface.State = true
+ iface.SocketChannel <- true
+}
+
+func InterfaceDown(iface *Interface) {
+ iface.SocketChannel <- false
+ iface.State = false
+}
+
+func GetInterfaceByName(ifaceName string) (*Interface, error) {
+ for _, iface := range myInterfaces {
+ if iface.Name == ifaceName {
+ return iface, nil
+ }
+ }
+
+ return nil, errors.Errorf("No interface with name %s", ifaceName)
+}
+
+func GetNeighborsToInterface(ifaceName string) ([]*Neighbor, error) {
+ if neighbors, ok := myNeighbors[ifaceName]; ok {
+ return neighbors, nil
+ }
+
+ return nil, errors.Errorf("No interface with name %s", ifaceName)
+}
+
+func SprintInterfaces() string {
+ buf := ""
+ for _, iface := range myInterfaces {
+ buf += fmt.Sprintf("%s\t%s\t%t\n", iface.Name, iface.IpPrefix.String(), iface.State)
+ }
+ return buf
+}
+
+func SprintNeighbors() string {
+ buf := ""
+ for ifaceName, neighbor := range myNeighbors {
+ for _, n := range neighbor {
+ buf += fmt.Sprintf("%s\t%s\t%s\n", ifaceName, n.UdpAddr.String(), n.VipAddr.String())
+ }
+ }
+ return buf
+}
+
+func SprintRoutingTable() string {
+ buf := ""
+ for prefix, hop := range routingTable {
+ buf += fmt.Sprintf("%s\t%s\t%d\n", prefix.String(), hop.VipAsStr, hop.Cost)
+ }
+ return buf
+}
+
+func DebugNeighbors() {
+ for ifaceName, neighbor := range myNeighbors {
+ for _, n := range neighbor {
+ fmt.Printf("%s\t%s\t%s\n", ifaceName, n.UdpAddr.String(), n.VipAddr.String())
+ }
+ }
+}
+
+func CleanUp() {
+ fmt.Print("Cleaning up...\n")
+ // go through the interfaces, pop thread & close the UDP FDs
+ for _, iface := range myInterfaces {
+ if iface.SocketChannel != nil {
+ close(iface.SocketChannel)
+ }
+ err := iface.RecvSocket.Close()
+ if err != nil {
+ continue
+ }
+ }
+
+ // go through the neighbors, pop thread & close the UDP FDs
+ for _, neighbor := range myNeighbors {
+ for _, n := range neighbor {
+ if n.SocketChannel != nil {
+ close(n.SocketChannel)
+ }
+ err := n.SendSocket.Close()
+ if err != nil {
+ continue
+ }
+ }
+ }
+
+ // delete all the neighbors
+ myNeighbors = make(map[string][]*Neighbor)
+ // delete all the interfaces
+ myInterfaces = nil
+ // delete the routing table
+ routingTable = make(map[netip.Prefix]Hop)
+
+ time.Sleep(5 * time.Millisecond)
+}
+
+// TODO: have it take TTL so we can decrement it when forwarding
+func SendIP(src Interface, dest Neighbor, protocolNum int, message []byte) error {
+ hdr := ipv4header.IPv4Header{
+ Version: 4,
+ Len: 20, // Header length is always 20 when no IP options
+ TOS: 0,
+ TotalLen: ipv4header.HeaderLen + len(message),
+ ID: 0,
+ Flags: 0,
+ FragOff: 0,
+ TTL: 32,
+ Protocol: protocolNum,
+ Checksum: 0, // Should be 0 until checksum is computed
+ Src: src.IpPrefix.Addr(),
+ Dst: dest.VipAddr,
+ Options: []byte{},
+ }
+
+ // Assemble the header into a byte array
+ headerBytes, err := hdr.Marshal()
+ if err != nil {
+ return err
+ }
+
+ // Compute the checksum (see below)
+ // Cast back to an int, which is what the Header structure expects
+ hdr.Checksum = int(ComputeChecksum(headerBytes))
+
+ headerBytes, err = hdr.Marshal()
+ if err != nil {
+ log.Fatalln("Error marshalling header: ", err)
+ }
+
+ bytesToSend := make([]byte, 0, len(headerBytes)+len(message))
+ bytesToSend = append(bytesToSend, headerBytes...)
+ bytesToSend = append(bytesToSend, []byte(message)...)
+
+ // Send the message to the "link-layer" addr:port on UDP
+ listenAddr, err := net.ResolveUDPAddr("udp4", dest.UdpAddr.String())
+ if err != nil {
+ return err
+ }
+ bytesWritten, err := dest.SendSocket.WriteToUDP(bytesToSend, listenAddr)
+ if err != nil {
+ return err
+ }
+ fmt.Printf("Sent %d bytes to %s\n", bytesWritten, listenAddr.String())
+
+ return nil
+}
+
+func RecvIP(conn net.UDPConn, isOpen *bool) error {
+ buffer := make([]byte, MAX_IP_PACKET_SIZE) // TODO: fix wordking
+
+ // Read on the UDP port
+ fmt.Println("wating to read from UDP socket")
+ _, sourceAddr, err := conn.ReadFromUDP(buffer)
+ if err != nil {
+ return err
+ }
+
+ if !*isOpen {
+ return errors.New("interface is down")
+ }
+
+ // Marshal the received byte array into a UDP header
+ // NOTE: This does not validate the checksum or check any fields
+ // (You'll need to do this part yourself)
+ hdr, err := ipv4header.ParseHeader(buffer)
+ if err != nil {
+ // What should you if the message fails to parse?
+ // Your node should not crash or exit when you get a bad message.
+ // Instead, simply drop the packet and return to processing.
+ fmt.Println("Error parsing header", err)
+ return err
+ }
+
+ headerSize := hdr.Len
+ headerBytes := buffer[:headerSize]
+ checksumFromHeader := uint16(hdr.Checksum)
+ computedChecksum := ValidateChecksum(headerBytes, checksumFromHeader)
+
+ var checksumState string
+ if computedChecksum == checksumFromHeader {
+ checksumState = "OK"
+ } else {
+ checksumState = "FAIL"
+ }
+
+ // Next, get the message, which starts after the header
+ message := buffer[headerSize:]
+
+ // Finally, print everything out
+ fmt.Printf("Received IP packet from %s\nHeader: %v\nChecksum: %s\nMessage: %s\n",
+ sourceAddr.String(), hdr, checksumState, string(message))
+
+ // TODO: handle the message
+ // 1) check if the TTL & checksum is valid
+ // 2) check if the message is for me, if so, sendUP (aka call the correct handler)
+ // if not, need to forward the packer to a neighbor or check the table
+ // after decrementing TTL and updating checksum
+ // 3) check if message is for a neighbor, if so, sendIP there
+ // 4) check if message is for a neighbor, if so, forward to the neighbor with that VIP
+
+ return nil
+}
+
+func ComputeChecksum(b []byte) uint16 {
+ checksum := header.Checksum(b, 0)
+ checksumInv := checksum ^ 0xffff
+
+ return checksumInv
+}
+
+func ValidateChecksum(b []byte, fromHeader uint16) uint16 {
+ checksum := header.Checksum(b, fromHeader)
+
+ return checksum
+}