diff --git a/go.mod b/go.mod index 53c9df3..a38dd09 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,9 @@ module aim-oscar go 1.16 + +require ( + github.com/ghostiam/binstruct v1.0.1 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/xhebox/bstruct v0.0.0-20210125172452-c36e1006cf29 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..27738bc --- /dev/null +++ b/go.sum @@ -0,0 +1,13 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/ghostiam/binstruct v1.0.1 h1:sg7Hi5c5b0noCDfRQpx3K2x9HLJYRafKnQRif7Orkek= +github.com/ghostiam/binstruct v1.0.1/go.mod h1:+NZwEDbcfME8MhF7nQRjAZV4U00c6XpNuk+nkvOxzvo= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/xhebox/bstruct v0.0.0-20210125172452-c36e1006cf29 h1:lMx3E/urZ0wDiG27VXTEEq0RG/MHsscHPkYiS5KIjcE= +github.com/xhebox/bstruct v0.0.0-20210125172452-c36e1006cf29/go.mod h1:KZ+p86hKYsjjg3oG2RfDFJFGkn85Z60Xhc5E5P58FXY= diff --git a/main.go b/main.go index 4856b51..5063f0d 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,8 @@ import ( "os" "os/signal" "syscall" + + "github.com/ghostiam/binstruct" ) const ( @@ -47,7 +49,7 @@ func main() { os.Exit(1) } - ctx := context.WithValue(context.Background(), "session", NewSession(conn)) + ctx := NewContextWithSession(context.Background(), conn) log.Printf("Connection from %v", conn.RemoteAddr()) go handleTCPConnection(ctx, conn) @@ -65,7 +67,8 @@ func handleTCPConnection(ctx context.Context, conn net.Conn) { buf := make([]byte, 1024) for { - session := ctx.Value("session").(*Session) + session, err := CurrentSession(ctx) + panicIfError(err) if !session.GreetedClient { // send a hello hello := NewFLAP(ctx, 1, []byte{0, 0, 0, 1}) @@ -85,17 +88,18 @@ func handleTCPConnection(ctx context.Context, conn net.Conn) { } fmt.Printf("%v ->\n%s\n\n", conn.RemoteAddr(), prettyBytes(buf[:n])) - handleMessage(ctx, buf[:n]) + + flap := &FLAP{} + binstruct.UnmarshalBE(buf[:n], flap) + + handleMessage(ctx, flap) } } -func handleMessage(ctx context.Context, buf []byte) { - flap := &FLAP{} - flap.UnmarshalBinary(buf) +func handleMessage(ctx context.Context, flap *FLAP) { + if flap.Channel == 1 { - if flap.Header.Channel == 1 { - - } else if flap.Header.Channel == 2 { + } else if flap.Channel == 2 { snac := &SNAC{} err := snac.UnmarshalBinary(flap.Data) panicIfError(err) diff --git a/session.go b/session.go index 7d82128..8e575c4 100644 --- a/session.go +++ b/session.go @@ -1,11 +1,22 @@ package main import ( + "context" "encoding" "fmt" "net" ) +type sessionKey string + +func (s sessionKey) String() string { + return "oscar-" + string(s) +} + +var ( + currentSession = sessionKey("session") +) + type Session struct { Conn net.Conn SequenceNumber uint16 @@ -20,6 +31,19 @@ func NewSession(conn net.Conn) *Session { } } +func NewContextWithSession(ctx context.Context, conn net.Conn) context.Context { + session := NewSession(conn) + return context.WithValue(ctx, currentSession, session) +} + +func CurrentSession(ctx context.Context) (session *Session, err error) { + session, ok := ctx.Value(currentSession).(*Session) + if !ok { + return nil, fmt.Errorf("no session in context") + } + return +} + func (s *Session) Send(m encoding.BinaryMarshaler) error { bytes, err := m.MarshalBinary() if err != nil { diff --git a/util.go b/util.go index 2aa5091..52d6ff9 100644 --- a/util.go +++ b/util.go @@ -1,11 +1,7 @@ package main import ( - "bytes" - "encoding/binary" "encoding/hex" - "fmt" - "io" ) func prettyBytes(bytes []byte) string { @@ -22,28 +18,8 @@ func prettyBytes(bytes []byte) string { return res } -func printBytes(bytes []byte) { - fmt.Printf("%s\n", prettyBytes(bytes)) -} - func panicIfError(err error) { if err != nil { panic(err) } } - -func readNBytes(buf *bytes.Buffer, n int) ([]byte, error) { - res := make([]byte, n) - _, err := io.ReadFull(buf, res) - return res, err -} - -func mustReadNBytes(buf *bytes.Buffer, n int) []byte { - res, err := readNBytes(buf, n) - panicIfError(err) - return res -} - -func mustRead(buf *bytes.Buffer, dest interface{}) { - panicIfError(binary.Read(buf, binary.BigEndian, dest)) -}