aboutsummaryrefslogtreecommitdiff
path: root/pkg/ipstack
diff options
context:
space:
mode:
authorDavid Doan <daviddoan@Davids-MacBook-Pro-70.local>2023-10-11 16:39:56 -0400
committerDavid Doan <daviddoan@Davids-MacBook-Pro-70.local>2023-10-11 16:39:56 -0400
commitdb0a9f0a4605d85ba4e535ba0ab590776cc4ba0a (patch)
tree4cfc83ca904f33b7d052ce5a6784d5996e827c90 /pkg/ipstack
parentded5a362b43715497a6f887354dd1a20bc9a621b (diff)
parentce42396f99e1d99d7e3b3016acabd9380627a297 (diff)
merge
Diffstat (limited to 'pkg/ipstack')
-rw-r--r--pkg/ipstack/ipstack.go413
-rw-r--r--pkg/ipstack/ipstack_test.go253
2 files changed, 666 insertions, 0 deletions
diff --git a/pkg/ipstack/ipstack.go b/pkg/ipstack/ipstack.go
new file mode 100644
index 0000000..be8bc1e
--- /dev/null
+++ b/pkg/ipstack/ipstack.go
@@ -0,0 +1,413 @@
+package ipstack
+
+import (
+ "fmt"
+ ipv4header "github.com/brown-csci1680/iptcp-headers"
+ "github.com/google/netstack/tcpip/header"
+ "github.com/pkg/errors"
+ "iptcp/pkg/lnxconfig"
+ "log"
+ "net"
+ "net/netip"
+ "time"
+)
+
+const (
+ MAX_IP_PACKET_SIZE = 1400
+ LOCAL_COST uint32 = 0
+ STATIC_COST uint32 = 4294967295 // 2^32 - 1
+)
+
+// STRUCTS ---------------------------------------------------------------------
+type Interface struct {
+ Name string
+ IpPrefix netip.Prefix
+ UdpAddr netip.AddrPort
+
+ RecvSocket net.UDPConn
+ SocketChannel chan bool
+ State bool
+}
+
+type Neighbor struct {
+ VipAddr netip.Addr
+ UdpAddr netip.AddrPort
+
+ SendSocket net.UDPConn
+ SocketChannel chan bool
+}
+
+type RIPMessage struct {
+ command uint8
+ numEntries uint8
+ entries []RIPEntry
+}
+
+type RIPEntry struct {
+ addr netip.Addr
+ cost uint32
+ mask netip.Prefix
+}
+
+type Hop struct {
+ Cost uint32
+ VipAsStr string
+}
+
+// GLOBAL VARIABLES (data structures) ------------------------------------------
+var myVIP netip.Addr
+var myInterfaces []*Interface
+var myNeighbors = make(map[string][]*Neighbor)
+
+// var myRIPNeighbors = make(map[string]Neighbor)
+type HandlerFunc func(int, string, *[]byte) error
+
+var protocolHandlers = make(map[uint16]HandlerFunc)
+
+// var routingTable = routingtable.New()
+var routingTable = make(map[netip.Prefix]Hop)
+
+// reference: https://github.com/brown-csci1680/lecture-examples/blob/main/ip-demo/cmd/udp-ip-recv/main.go
+func createUDPConn(UdpAddr netip.AddrPort, conn *net.UDPConn) error {
+ listenString := UdpAddr.String()
+ listenAddr, err := net.ResolveUDPAddr("udp4", listenString)
+ if err != nil {
+ return errors.WithMessage(err, "Error resolving address->\t"+listenString)
+ }
+ tmpConn, err := net.ListenUDP("udp4", listenAddr)
+ if err != nil {
+ return errors.WithMessage(err, "Could not bind to UDP port->\t"+listenString)
+ }
+ *conn = *tmpConn
+
+ return nil
+}
+
+func Initialize(lnxFilePath string) error {
+ //if len(os.Args) != 2 {
+ // fmt.Printf("Usage: %s <configFile>\n", os.Args[0])
+ // os.Exit(1)
+ //}
+ //lnxFilePath := os.Args[1]
+
+ // Parse the file
+ lnxConfig, err := lnxconfig.ParseConfig(lnxFilePath)
+ if err != nil {
+ return errors.WithMessage(err, "Error parsing config file->\t"+lnxFilePath)
+ }
+
+ // 1) initialize the interfaces on this node here and into the routing table
+ static := false
+ for _, iface := range lnxConfig.Interfaces {
+ prefix := netip.PrefixFrom(iface.AssignedIP, iface.AssignedPrefix.Bits())
+ i := &Interface{
+ Name: iface.Name,
+ IpPrefix: prefix,
+ UdpAddr: iface.UDPAddr,
+ RecvSocket: net.UDPConn{},
+ SocketChannel: make(chan bool),
+ State: false,
+ }
+
+ err := createUDPConn(iface.UDPAddr, &i.RecvSocket)
+ if err != nil {
+ return errors.WithMessage(err, "Error creating UDP socket for interface->\t"+iface.Name)
+ }
+ go InterfaceListenerRoutine(i.RecvSocket, i.SocketChannel)
+ myInterfaces = append(myInterfaces, i)
+
+ // TODO: (FOR HOSTS ONLY)
+ // add STATIC to routing table
+ if !static {
+ ifacePrefix := netip.MustParsePrefix("0.0.0.0/0")
+ routingTable[ifacePrefix] = Hop{STATIC_COST, iface.Name}
+ static = true
+ }
+ }
+
+ // 2) initialize the neighbors connected to the node and into the routing table
+ for _, neighbor := range lnxConfig.Neighbors {
+ n := &Neighbor{
+ VipAddr: neighbor.DestAddr,
+ UdpAddr: neighbor.UDPAddr,
+ SendSocket: net.UDPConn{},
+ SocketChannel: make(chan bool),
+ }
+
+ err := createUDPConn(neighbor.UDPAddr, &n.SendSocket)
+ if err != nil {
+ return errors.WithMessage(err, "Error creating UDP socket for neighbor->\t"+neighbor.DestAddr.String())
+ }
+
+ myNeighbors[neighbor.InterfaceName] = append(myNeighbors[neighbor.InterfaceName], n)
+
+ // add to routing table
+ // TODO: REVISIT AND SEE IF "24" IS CORRECT
+ neighborPrefix := netip.PrefixFrom(neighbor.DestAddr, 24)
+ routingTable[neighborPrefix] = Hop{LOCAL_COST, neighbor.InterfaceName}
+ }
+
+ return nil
+}
+
+func InterfaceListenerRoutine(socket net.UDPConn, signal <-chan bool) {
+ isUp := false
+ closed := false
+
+ // go routine that hangs on the recv
+ fmt.Println("MAKING GO ROUTINE TO LISTEN:\t", socket.LocalAddr().String())
+ go func() {
+ defer func() { // on close, set isUp to false
+ fmt.Println("exiting go routine that listens on ", socket.LocalAddr().String())
+ }()
+
+ for {
+ if closed { // stop this go routine if channel is closed
+ return
+ }
+ if !isUp { // don't call the listeners if interface is down
+ continue
+ }
+ // TODO: remove these "training wheels"
+ time.Sleep(1 * time.Millisecond)
+ err := RecvIP(socket, &isUp)
+ if err != nil {
+ fmt.Println("Error receiving IP packet", err)
+ return
+ }
+ }
+ }()
+
+ for {
+ select {
+ case sig, ok := <-signal:
+ if !ok {
+ fmt.Println("channel closed, exiting")
+ closed = true
+ return
+ }
+ fmt.Println("received isUP SIGNAL with value", sig)
+ isUp = sig
+ default:
+ continue
+ }
+ }
+}
+
+func InterfaceUp(iface *Interface) {
+ iface.State = true
+ iface.SocketChannel <- true
+}
+
+func InterfaceDown(iface *Interface) {
+ iface.SocketChannel <- false
+ iface.State = false
+}
+
+func GetInterfaceByName(ifaceName string) (*Interface, error) {
+ for _, iface := range myInterfaces {
+ if iface.Name == ifaceName {
+ return iface, nil
+ }
+ }
+
+ return nil, errors.Errorf("No interface with name %s", ifaceName)
+}
+
+func GetNeighborsToInterface(ifaceName string) ([]*Neighbor, error) {
+ if neighbors, ok := myNeighbors[ifaceName]; ok {
+ return neighbors, nil
+ }
+
+ return nil, errors.Errorf("No interface with name %s", ifaceName)
+}
+
+func SprintInterfaces() string {
+ buf := ""
+ for _, iface := range myInterfaces {
+ buf += fmt.Sprintf("%s\t%s\t%t\n", iface.Name, iface.IpPrefix.String(), iface.State)
+ }
+ return buf
+}
+
+func SprintNeighbors() string {
+ buf := ""
+ for ifaceName, neighbor := range myNeighbors {
+ for _, n := range neighbor {
+ buf += fmt.Sprintf("%s\t%s\t%s\n", ifaceName, n.UdpAddr.String(), n.VipAddr.String())
+ }
+ }
+ return buf
+}
+
+func SprintRoutingTable() string {
+ buf := ""
+ for prefix, hop := range routingTable {
+ buf += fmt.Sprintf("%s\t%s\t%d\n", prefix.String(), hop.VipAsStr, hop.Cost)
+ }
+ return buf
+}
+
+func DebugNeighbors() {
+ for ifaceName, neighbor := range myNeighbors {
+ for _, n := range neighbor {
+ fmt.Printf("%s\t%s\t%s\n", ifaceName, n.UdpAddr.String(), n.VipAddr.String())
+ }
+ }
+}
+
+func CleanUp() {
+ fmt.Print("Cleaning up...\n")
+ // go through the interfaces, pop thread & close the UDP FDs
+ for _, iface := range myInterfaces {
+ if iface.SocketChannel != nil {
+ close(iface.SocketChannel)
+ }
+ err := iface.RecvSocket.Close()
+ if err != nil {
+ continue
+ }
+ }
+
+ // go through the neighbors, pop thread & close the UDP FDs
+ for _, neighbor := range myNeighbors {
+ for _, n := range neighbor {
+ if n.SocketChannel != nil {
+ close(n.SocketChannel)
+ }
+ err := n.SendSocket.Close()
+ if err != nil {
+ continue
+ }
+ }
+ }
+
+ // delete all the neighbors
+ myNeighbors = make(map[string][]*Neighbor)
+ // delete all the interfaces
+ myInterfaces = nil
+ // delete the routing table
+ routingTable = make(map[netip.Prefix]Hop)
+
+ time.Sleep(5 * time.Millisecond)
+}
+
+// TODO: have it take TTL so we can decrement it when forwarding
+func SendIP(src Interface, dest Neighbor, protocolNum int, message []byte) error {
+ hdr := ipv4header.IPv4Header{
+ Version: 4,
+ Len: 20, // Header length is always 20 when no IP options
+ TOS: 0,
+ TotalLen: ipv4header.HeaderLen + len(message),
+ ID: 0,
+ Flags: 0,
+ FragOff: 0,
+ TTL: 32,
+ Protocol: protocolNum,
+ Checksum: 0, // Should be 0 until checksum is computed
+ Src: src.IpPrefix.Addr(),
+ Dst: dest.VipAddr,
+ Options: []byte{},
+ }
+
+ // Assemble the header into a byte array
+ headerBytes, err := hdr.Marshal()
+ if err != nil {
+ return err
+ }
+
+ // Compute the checksum (see below)
+ // Cast back to an int, which is what the Header structure expects
+ hdr.Checksum = int(ComputeChecksum(headerBytes))
+
+ headerBytes, err = hdr.Marshal()
+ if err != nil {
+ log.Fatalln("Error marshalling header: ", err)
+ }
+
+ bytesToSend := make([]byte, 0, len(headerBytes)+len(message))
+ bytesToSend = append(bytesToSend, headerBytes...)
+ bytesToSend = append(bytesToSend, []byte(message)...)
+
+ // Send the message to the "link-layer" addr:port on UDP
+ listenAddr, err := net.ResolveUDPAddr("udp4", dest.UdpAddr.String())
+ if err != nil {
+ return err
+ }
+ bytesWritten, err := dest.SendSocket.WriteToUDP(bytesToSend, listenAddr)
+ if err != nil {
+ return err
+ }
+ fmt.Printf("Sent %d bytes to %s\n", bytesWritten, listenAddr.String())
+
+ return nil
+}
+
+func RecvIP(conn net.UDPConn, isOpen *bool) error {
+ buffer := make([]byte, MAX_IP_PACKET_SIZE) // TODO: fix wordking
+
+ // Read on the UDP port
+ fmt.Println("wating to read from UDP socket")
+ _, sourceAddr, err := conn.ReadFromUDP(buffer)
+ if err != nil {
+ return err
+ }
+
+ if !*isOpen {
+ return errors.New("interface is down")
+ }
+
+ // Marshal the received byte array into a UDP header
+ // NOTE: This does not validate the checksum or check any fields
+ // (You'll need to do this part yourself)
+ hdr, err := ipv4header.ParseHeader(buffer)
+ if err != nil {
+ // What should you if the message fails to parse?
+ // Your node should not crash or exit when you get a bad message.
+ // Instead, simply drop the packet and return to processing.
+ fmt.Println("Error parsing header", err)
+ return err
+ }
+
+ headerSize := hdr.Len
+ headerBytes := buffer[:headerSize]
+ checksumFromHeader := uint16(hdr.Checksum)
+ computedChecksum := ValidateChecksum(headerBytes, checksumFromHeader)
+
+ var checksumState string
+ if computedChecksum == checksumFromHeader {
+ checksumState = "OK"
+ } else {
+ checksumState = "FAIL"
+ }
+
+ // Next, get the message, which starts after the header
+ message := buffer[headerSize:]
+
+ // Finally, print everything out
+ fmt.Printf("Received IP packet from %s\nHeader: %v\nChecksum: %s\nMessage: %s\n",
+ sourceAddr.String(), hdr, checksumState, string(message))
+
+ // TODO: handle the message
+ // 1) check if the TTL & checksum is valid
+ // 2) check if the message is for me, if so, sendUP (aka call the correct handler)
+ // if not, need to forward the packer to a neighbor or check the table
+ // after decrementing TTL and updating checksum
+ // 3) check if message is for a neighbor, if so, sendIP there
+ // 4) check if message is for a neighbor, if so, forward to the neighbor with that VIP
+
+ return nil
+}
+
+func ComputeChecksum(b []byte) uint16 {
+ checksum := header.Checksum(b, 0)
+ checksumInv := checksum ^ 0xffff
+
+ return checksumInv
+}
+
+func ValidateChecksum(b []byte, fromHeader uint16) uint16 {
+ checksum := header.Checksum(b, fromHeader)
+
+ return checksum
+}
diff --git a/pkg/ipstack/ipstack_test.go b/pkg/ipstack/ipstack_test.go
new file mode 100644
index 0000000..941c4e9
--- /dev/null
+++ b/pkg/ipstack/ipstack_test.go
@@ -0,0 +1,253 @@
+package ipstack
+
+import (
+ "fmt"
+ ipv4header "github.com/brown-csci1680/iptcp-headers"
+ "net"
+ "net/netip"
+ "testing"
+ "time"
+)
+
+func TestInitialize(t *testing.T) {
+ lnxFilePath := "../../doc-example/r2.lnx"
+ err := Initialize(lnxFilePath)
+ if err != nil {
+ t.Error(err)
+ }
+ fmt.Printf("Interfaces:\n%s\n\n", SprintInterfaces())
+ fmt.Printf("Neighbors:\n%s\n", SprintNeighbors())
+ fmt.Printf("RoutingTable:\n%s\n", SprintRoutingTable())
+
+ fmt.Println("TestInitialize successful")
+ t.Cleanup(func() { CleanUp() })
+}
+
+func TestInterfaceUpThenDown(t *testing.T) {
+ lnxFilePath := "../../doc-example/r2.lnx"
+ err := Initialize(lnxFilePath)
+ if err != nil {
+ t.Error(err)
+ }
+
+ iface, err := GetInterfaceByName("if0")
+ if err != nil {
+ t.Error(err)
+ }
+
+ InterfaceUp(iface)
+ if iface.State == false {
+ t.Error("iface state should be true")
+ }
+
+ fmt.Printf("Interfaces:\n%s\n", SprintInterfaces())
+
+ time.Sleep(5 * time.Millisecond) // allow time to print
+
+ InterfaceDown(iface)
+ if iface.State == true {
+ t.Error("iface state should be false")
+ }
+
+ time.Sleep(5 * time.Millisecond) // allow time to print
+
+ fmt.Printf("Interfaces:\n%s\n", SprintInterfaces())
+
+ fmt.Println("TestInterfaceUpThenDown successful")
+ t.Cleanup(func() { CleanUp() })
+}
+
+func TestInterfaceUpThenDownTwice(t *testing.T) {
+ lnxFilePath := "../../doc-example/r2.lnx"
+ err := Initialize(lnxFilePath)
+ if err != nil {
+ t.Error(err)
+ }
+
+ iface, err := GetInterfaceByName("if0")
+ if err != nil {
+ t.Error(err)
+ }
+
+ InterfaceUp(iface)
+ if iface.State == false {
+ t.Error("iface state should be true")
+ }
+
+ fmt.Printf("Interfaces:\n%s\n", SprintInterfaces())
+
+ time.Sleep(5 * time.Millisecond) // allow time to print
+
+ fmt.Println("putting interface down")
+ InterfaceDown(iface)
+ if iface.State == true {
+ t.Error("iface state should be false")
+ }
+
+ time.Sleep(3 * time.Millisecond)
+
+ fmt.Println("putting interface back up for 3 iterations")
+ InterfaceUp(iface)
+ if iface.State == false {
+ t.Error("iface state should be true")
+ }
+ time.Sleep(3 * time.Millisecond) // allow time to print
+
+ fmt.Println("putting interface down")
+ InterfaceDown(iface)
+ if iface.State == true {
+ t.Error("iface state should be false")
+ }
+
+ time.Sleep(5 * time.Millisecond) // allow time to print
+
+ fmt.Printf("Interfaces:\n%s\n", SprintInterfaces())
+
+ fmt.Println("TestInterfaceUpThenDownTwice successful")
+ t.Cleanup(func() { CleanUp() })
+}
+
+func TestSendIPToNeighbor(t *testing.T) {
+ lnxFilePath := "../../doc-example/r2.lnx"
+ err := Initialize(lnxFilePath)
+ if err != nil {
+ t.Error(err)
+ }
+
+ // get the first neighbor of this interface
+ iface, err := GetInterfaceByName("if0")
+ if err != nil {
+ t.Error(err)
+ }
+ neighbors, err := GetNeighborsToInterface("if0")
+ if err != nil {
+ t.Error(err)
+ }
+
+ // setup a neighbor listener socket
+ testNeighbor := neighbors[0]
+ // close the socket so we can listen on it
+ err = testNeighbor.SendSocket.Close()
+ if err != nil {
+ t.Error(err)
+ }
+
+ fmt.Printf("Interfaces:\n%s\n", SprintInterfaces())
+ fmt.Printf("Neighbors:\n%s\n", SprintNeighbors())
+
+ listenString := testNeighbor.UdpAddr.String()
+ fmt.Println("listening on " + listenString)
+ listenAddr, err := net.ResolveUDPAddr("udp4", listenString)
+ if err != nil {
+ t.Error(err)
+ }
+ recvSocket, err := net.ListenUDP("udp4", listenAddr)
+ if err != nil {
+ t.Error(err)
+ }
+ testNeighbor.SendSocket = *recvSocket
+
+ sent := false
+ go func() {
+ buffer := make([]byte, MAX_IP_PACKET_SIZE)
+ fmt.Println("wating to read from UDP socket")
+ _, sourceAddr, err := recvSocket.ReadFromUDP(buffer)
+ if err != nil {
+ t.Error(err)
+ }
+ fmt.Println("read from UDP socket")
+ hdr, err := ipv4header.ParseHeader(buffer)
+ if err != nil {
+ t.Error(err)
+ }
+ headerSize := hdr.Len
+ headerBytes := buffer[:headerSize]
+ checksumFromHeader := uint16(hdr.Checksum)
+ computedChecksum := ValidateChecksum(headerBytes, checksumFromHeader)
+
+ var checksumState string
+ if computedChecksum == checksumFromHeader {
+ checksumState = "OK"
+ } else {
+ checksumState = "FAIL"
+ }
+ message := buffer[headerSize:]
+ fmt.Printf("Received IP packet from %s\nHeader: %v\nChecksum: %s\nMessage: %s\n",
+ sourceAddr.String(), hdr, checksumState, string(message))
+ if err != nil {
+ t.Error(err)
+ }
+
+ sent = true
+ }()
+
+ time.Sleep(10 * time.Millisecond)
+
+ // send a message to the neighbor
+ fmt.Printf("sending message to neighbor\t%t\n", sent)
+ err = SendIP(*iface, *testNeighbor, 0, []byte("You are my firest neighbor!"))
+ if err != nil {
+ t.Error(err)
+ }
+
+ fmt.Printf("SENT message to neighbor\t%t\n", sent)
+ // give a little time for the message to be sent
+ time.Sleep(1000 * time.Millisecond)
+ if !sent {
+ t.Error("Message not sent")
+ t.Fail()
+ }
+
+ fmt.Println("TestSendIPToNeighbor successful")
+ t.Cleanup(func() { CleanUp() })
+}
+
+func TestRecvIP(t *testing.T) {
+ lnxFilePath := "../../doc-example/r2.lnx"
+ err := Initialize(lnxFilePath)
+ if err != nil {
+ t.Error(err)
+ }
+
+ // get the first neighbor of this interface to RecvIP from
+ iface, err := GetInterfaceByName("if0")
+ if err != nil {
+ t.Error(err)
+ }
+ InterfaceUp(iface)
+
+ // setup a random socket to send an ip packet from
+ listenAddr, err := net.ResolveUDPAddr("udp4", "127.0.0.1:6969")
+ sendSocket, err := net.ListenUDP("udp4", listenAddr)
+
+ // send a message to the neighbor
+ ifaceAsNeighbor := Neighbor{
+ VipAddr: iface.IpPrefix.Addr(),
+ UdpAddr: iface.UdpAddr,
+ SendSocket: iface.RecvSocket,
+ SocketChannel: iface.SocketChannel,
+ }
+ fakeIface := Interface{
+ Name: "if69",
+ IpPrefix: netip.MustParsePrefix("10.69.0.1/24"),
+ UdpAddr: netip.MustParseAddrPort("127.0.0.1:6969"),
+ RecvSocket: net.UDPConn{},
+ SocketChannel: nil,
+ State: true,
+ }
+ err = SendIP(fakeIface, ifaceAsNeighbor, 0, []byte("hello"))
+ if err != nil {
+ return
+ }
+
+ time.Sleep(10 * time.Millisecond)
+
+ // TODO: potenially make this a channel, so it actually checks values.
+ // For now, you must read the message from the console.
+
+ err = sendSocket.Close()
+ if err != nil {
+ t.Error(err)
+ }
+ t.Cleanup(func() { CleanUp() })
+}