diff options
Diffstat (limited to 'pkg/ipstack')
-rw-r--r-- | pkg/ipstack/ipstack.go | 191 | ||||
-rw-r--r-- | pkg/ipstack/ipstack_test.go | 45 |
2 files changed, 143 insertions, 93 deletions
diff --git a/pkg/ipstack/ipstack.go b/pkg/ipstack/ipstack.go index e4d1651..fb18161 100644 --- a/pkg/ipstack/ipstack.go +++ b/pkg/ipstack/ipstack.go @@ -6,6 +6,7 @@ import ( "iptcp/pkg/lnxconfig" "net" "net/netip" + "time" ) const ( @@ -20,7 +21,7 @@ type Interface struct { IpPrefix netip.Prefix RecvSocket net.Conn - SocketChannel chan<- bool + SocketChannel chan bool State bool } @@ -29,7 +30,7 @@ type Neighbor struct { UdpAddr netip.AddrPort SendSocket net.Conn - SocketChannel chan<- bool + SocketChannel chan bool } type RIPMessage struct { @@ -50,8 +51,8 @@ type Hop struct { } // GLOBAL VARIABLES (data structures) ------------------------------------------ -var myInterfaces []Interface -var myNeighbors = make(map[string][]Neighbor) +var myInterfaces []*Interface +var myNeighbors = make(map[string][]*Neighbor) // var myRIPNeighbors = make(map[string]Neighbor) type HandlerFunc func(int, string, *[]byte) error @@ -91,6 +92,7 @@ func Initialize(lnxFilePath string) error { } // 1) initialize the interfaces on this node here and into the routing table + static := false for _, iface := range lnxConfig.Interfaces { prefix := netip.PrefixFrom(iface.AssignedIP, iface.AssignedPrefix.Bits()) i := &Interface{ @@ -104,14 +106,18 @@ func Initialize(lnxFilePath string) error { if err != nil { return errors.WithMessage(err, "Error creating UDP socket for interface->\t"+iface.Name) } - myInterfaces = append(myInterfaces, *i) - - // add to routing table - //ifacePrefix := netip.PrefixFrom(iface.AssignedIP, iface.AssignedPrefix.Bits()) - //routingTable[ifacePrefix] = Hop{STATIC_COST, iface.Name} + myInterfaces = append(myInterfaces, i) + + // TODO: (FOR HOSTS ONLY) + // add STATIC to routing table + if !static { + ifacePrefix := netip.MustParsePrefix("0.0.0.0/0") + routingTable[ifacePrefix] = Hop{STATIC_COST, iface.Name} + static = true + } } - // 2) initialize the neighbors connected to the node + // 2) initialize the neighbors connected to the node and into the routing table for _, neighbor := range lnxConfig.Neighbors { n := &Neighbor{ VipAddr: neighbor.DestAddr, @@ -122,7 +128,7 @@ func Initialize(lnxFilePath string) error { if err != nil { return errors.WithMessage(err, "Error creating UDP socket for neighbor->\t"+neighbor.DestAddr.String()) } - myNeighbors[neighbor.InterfaceName] = append(myNeighbors[neighbor.InterfaceName], *n) + myNeighbors[neighbor.InterfaceName] = append(myNeighbors[neighbor.InterfaceName], n) // add to routing table neighborPrefix := netip.PrefixFrom(neighbor.DestAddr, 24) @@ -132,84 +138,51 @@ func Initialize(lnxFilePath string) error { return nil } -/* - -func ListerToInterfaces() { - for _, iface := range myInterfaces { - go RecvIp(iface) - } -} - -func RecvIp(iface Interface) error { +func InterfaceListenerRoutine(iface Interface, signal <-chan bool) { + isDown := false for { - buffer := make([]byte, MAX_IP_PACKET_SIZE) - _, sourceAddr, err := iface.udp.ReadFrom(buffer) - if err != nil { - log.Panicln("Error reading from UDP socket ", err) - } - - hdr, err := ipv4header.ParseHeader(buffer) - - if err != nil { - fmt.Println("Error parsing header", err) - continue - } - - headerSize := hdr.Len - headerBytes := buffer[:headerSize] - - checksumFromHeader := uint16(hdr.Checksum) - computedChecksum := ValidateChecksum(headerBytes, checksumFromHeader) - - var checksumState string - if computedChecksum == checksumFromHeader { - checksumState = "OK" - } else { - checksumState = "FAIL" - continue - } - - // check ttl - ttl := data[8] - if ttl == 0 { - fmt.Println("TTL is 0") - continue + select { + case open, sig := <-signal: + if !open { + fmt.Println("channel closed, exiting") + return + } + fmt.Println("received SIGNAL with value", sig) + if sig { + isDown = <-signal + } + default: + if isDown { + continue + } + fmt.Println("no activity, actively listening on", iface.Name) + // TODO: remove these training wheels + time.Sleep(1 * time.Millisecond) } + } +} - destAddr := netip.AddrFrom(data[16:20]) - protocolNum := data[9] - - if destAddr == iface.addr { - // send to handler - protocolHandlers[protocolNum](data) - // message := buffer[headerSize:] - - // fmt.Printf("Received IP packet from %s\nHeader: %v\nChecksum: %s\nMessage: %s\n", - // sourceAddr.String(), hdr, checksumState, string(message)) - } else { - // decrement ttl and update checksum - data[8] = ttl - 1 - data[10] = 0 - data[11] = 0 - newChecksum := int(ComputeChecksum(data[:headerSize])) - data[10] = newChecksum >> 8 - data[11] = newChecksum & 0xff - - // check neighbors - for _, neighbor := range iface.neighbors { - if neighbor == destAddr { - // send to neighbor - // SendIp(destAddr, protocolNum, data) - } - } +// When an interface goes up, we need to start it's go routine that listens +func InterfaceUp(iface *Interface) { + iface.SocketChannel = make(chan bool) + iface.State = true + go func() { + InterfaceListenerRoutine(*iface, iface.SocketChannel) + }() +} - // check forwarding table +func InterfaceDown(iface *Interface) { + iface.SocketChannel <- true + iface.State = false +} - } +/* +func ListerToInterfaces() { + for _, iface := range myInterfaces { + go RecvIp(iface) } } - func ValidateChecksum(b []byte, fromHeader uint16) uint16 { checksum := header.Checksum(b, fromHeader) @@ -323,25 +296,67 @@ func GetNeighbors() []netip.Addr { } */ -func PrintInterfaces() { +func GetInterfaceByName(ifaceName string) (*Interface, error) { for _, iface := range myInterfaces { - fmt.Printf("%s\t%s\t%t\n", iface.Name, iface.IpPrefix.String(), iface.State) + if iface.Name == ifaceName { + return iface, nil + } } + + return nil, errors.Errorf("No interface with name %s", ifaceName) } -func PrintNeighbors() { +func SprintInterfaces() string { + buf := "" + for _, iface := range myInterfaces { + buf += fmt.Sprintf("%s\t%s\t%t\n", iface.Name, iface.IpPrefix.String(), iface.State) + } + return buf +} + +func SprintNeighbors() string { + buf := "" for ifaceName, neighbor := range myNeighbors { for _, n := range neighbor { - fmt.Printf("%s\t%s\t%s\n", ifaceName, n.UdpAddr.String(), n.VipAddr.String()) + buf += fmt.Sprintf("%s\t%s\t%s\n", ifaceName, n.UdpAddr.String(), n.VipAddr.String()) } } + return buf } func SprintRoutingTable() string { - message := "" + buf := "" for prefix, hop := range routingTable { - message += fmt.Sprintf("%s\t%s\t%d\n", prefix.String(), hop.VipAsStr, hop.Cost) + buf += fmt.Sprintf("%s\t%s\t%d\n", prefix.String(), hop.VipAsStr, hop.Cost) + } + return buf +} + +func DebugNeighbors() { + for ifaceName, neighbor := range myNeighbors { + for _, n := range neighbor { + fmt.Printf("%s\t%s\t%s\t%s\n", ifaceName, n.UdpAddr.String(), n.VipAddr.String(), n.SendSocket) + } + } +} + +func CleanUp() { + fmt.Print("Cleaning up...\n") + // go through the interfaces, pop thread & close the UDP FDs + for _, iface := range myInterfaces { + if iface.SocketChannel != nil { + close(iface.SocketChannel) + } + iface.RecvSocket.Close() } - return message + // go through the neighbors, pop thread & close the UDP FDs + for _, neighbor := range myNeighbors { + for _, n := range neighbor { + if n.SocketChannel != nil { + close(n.SocketChannel) + } + n.SendSocket.Close() + } + } } diff --git a/pkg/ipstack/ipstack_test.go b/pkg/ipstack/ipstack_test.go index 5530b9d..d5b755a 100644 --- a/pkg/ipstack/ipstack_test.go +++ b/pkg/ipstack/ipstack_test.go @@ -3,6 +3,7 @@ package ipstack import ( "fmt" "testing" + "time" ) func TestInitialize(t *testing.T) { @@ -11,10 +12,44 @@ func TestInitialize(t *testing.T) { if err != nil { t.Error(err) } + fmt.Printf("Interfaces:\n%s\n\n", SprintInterfaces()) + fmt.Printf("Neighbors:\n%s\n", SprintNeighbors()) + fmt.Printf("RoutingTable:\n%s\n", SprintRoutingTable()) + fmt.Println("TestInitialize successful") - PrintInterfaces() - fmt.Println("Interfaces^^") - PrintNeighbors() - fmt.Println("Neighbors^^") - fmt.Println(SprintRoutingTable()) + t.Cleanup(func() { CleanUp() }) +} + +func TestInterfaceUpThenDown(t *testing.T) { + lnxFilePath := "../../doc-example/r2.lnx" + err := Initialize(lnxFilePath) + if err != nil { + t.Error(err) + } + + iface, err := GetInterfaceByName("if0") + if err != nil { + t.Error(err) + } + + InterfaceUp(iface) + if iface.State == false { + t.Error("iface state should be true") + } + + fmt.Printf("Interfaces:\n%s\n", SprintInterfaces()) + + time.Sleep(5 * time.Millisecond) // allow time to print + + InterfaceDown(iface) + if iface.State == true { + t.Error("iface state should be false") + } + + time.Sleep(5 * time.Millisecond) // allow time to print + + fmt.Printf("Interfaces:\n%s\n", SprintInterfaces()) + + fmt.Println("TestInterfaceUpThenDown successful") + t.Cleanup(func() { CleanUp() }) } |