Skip to content
Snippets Groups Projects
ssmc.go 4.66 KiB
package main

import (
        //"encoding/binary"
        "encoding/binary"
        "fmt"
        "net"
        "os"

        //"runtime"
        "regexp"
        "time"

        "golang.org/x/net/ipv4"
)

// usage prints the command-line synopsis to stderr and exits with status 2,
// the conventional exit code for a command-line usage error.
func usage() {
	fmt.Fprintln(os.Stderr, "USAGE: ssm <source>%<group>:<port> [-i <interface>]")
	os.Exit(2)
}

func main() {
        arguments := os.Args
        if len(arguments) < 2 {
                usage()
        }

        /* Process the SSM group join argument */

        ssmarg := arguments[1]
        re := regexp.MustCompile("^(.*)%(.*):([0-9]+)$")

        if re.MatchString(ssmarg) == false {
                usage()
        }

        ssm := re.FindAllStringSubmatch(ssmarg, -1)[0]
        source := ssm[1]
        group := ssm[2]
        port := ssm[3]

        var bindif string
        allnets, _ := net.Interfaces()

        if len(allnets) == 0 {
                fmt.Printf("Did not find any network interfaces\n")
                return
        } else if len(allnets) == 1 {
                // Only one interface available so use that.
                bindif = allnets[0].Name
        } else {
                bindif = allnets[1].Name
        }

        if len(arguments) == 4 {
                switch arguments[2] {
                case "-i":
                        bindif = arguments[3]
                default:
                        fmt.Printf("INVALID ARG: %v\n", arguments[2])
                        usage()
                }

        }

        ssmif, err := net.InterfaceByName(bindif)
        if err != nil {
                fmt.Printf("Could not bind to interface '%v'\n", bindif)
                return
        }

        c, err := net.ListenPacket("udp4", "0.0.0.0:"+port)

        if err != nil {
                fmt.Println(err)
                return
        }

        fmt.Printf("Server listening on UDP port %v\nJoining multicast (S,G)=%v,%v w/iface %v\n", port, source, group, bindif)

        defer c.Close()

        p := ipv4.NewPacketConn(c)

        ssmsource := net.UDPAddr{IP: net.ParseIP(source)}
        ssmgroup := net.UDPAddr{IP: net.ParseIP(group)}

        if err := p.JoinSourceSpecificGroup(ssmif, &ssmgroup, &ssmsource); err != nil {
                // error handling
                fmt.Println(err)
                return
        }

        b := make([]byte, 9000)
        var index int32
        var last_index int32
        index = 0
        last_index = 0
        first_packet := true
        packet_count := 0
        var stream_start time.Time
        packetSize := 0

        for {
                n, _, _, err := p.ReadFrom(b)
                if err != nil {
                        // error handling
                        fmt.Println(err)
                        return
                }
                if n != 0 {

                        packet_count++
                        index = int32(binary.BigEndian.Uint32(b[0:4]))
                        if index < 0 {
                                // end of stream
                                fmt.Printf("!EOS [i:%v]\n", index)
                                first_packet = true
                                packet_count = 0
                                last_index = 0
                                index = 0
                                continue
                        }

                        if first_packet {
                                fmt.Printf("S: %v", index)
                                first_packet = false
                                stream_start = time.Now()
                        } else if index == (last_index + 1) {
                                fmt.Printf("\033[32m!\033[0m")
                        } else if index < last_index {
                                fmt.Printf("\033[34m<\033[0m")
                        } else if index > (last_index + 1) {
                                fmt.Printf("\033[36m>\033[0m")
                        } else {
                                fmt.Printf("\033[31m.\033[0m")
                        }
                        last_index = index

                        if n != packetSize {
                                fmt.Printf("PS[%v]", n)
                                packetSize = n
                        }

                        if time.Since(stream_start)/time.Second >= 5 {
                                fmt.Printf("[%vpps]", packet_count/5)
                                stream_start = time.Now()
                                packet_count = 0
                        }
                } else {
                        break
                }
        }

        if err := p.LeaveSourceSpecificGroup(ssmif, &ssmgroup, &ssmsource); err != nil {
                // error handling
                fmt.Println(err)
                return
        }
}