diff --git a/0x17_authorization_registration_service.go b/0x17_authorization_registration_service.go index badc5ea..9c56e36 100644 --- a/0x17_authorization_registration_service.go +++ b/0x17_authorization_registration_service.go @@ -1,22 +1,35 @@ package main import ( - "context" - "fmt" + "errors" ) -var cipher = "hey wassup" - type AuthorizationRegistrationService struct{} -func (a *AuthorizationRegistrationService) HandleSNAC(ctx context.Context, snac *SNAC) { +func (a *AuthorizationRegistrationService) HandleSNAC(session *Session, snac *SNAC) error { + switch snac.Header.Subtype { // Request MD5 Auth Key - if snac.Header.Subtype == 0x06 { - fmt.Println("damn it's 0x06") - // cipherData := ByteString(cipher) // []byte + case 0x06: + tlvs, err := UnmarshalTLVs(snac.Data) + panicIfError(err) - // snac := NewSNAC(0x17, 0x07, cipherData) + usernameTLV := FindTLV(tlvs, 1) + if usernameTLV == nil { + return errors.New("missing username TLV") + } - // resp := NewFLAP(2, snac) + // Create cipher for this user + cipher := "howdy" + db.Set("cipher-"+string(usernameTLV.Data), cipher) + cipherData := []byte(cipher) + + snac := NewSNAC(0x17, 0x07, cipherData) + snacBytes, err := snac.MarshalBinary() + panicIfError(err) + + resp := NewFLAP(session, 2, snacBytes) + return session.Send(resp) } + + return nil } diff --git a/db.go b/db.go new file mode 100644 index 0000000..3e19fc8 --- /dev/null +++ b/db.go @@ -0,0 +1,19 @@ +package main + +type DB struct { + data map[string]string +} + +func (d *DB) Init() { + d.data = make(map[string]string) +} + +func (d *DB) Get(key string) (data string, ok bool) { + data, ok = d.data[key] + return +} + +func (d *DB) Set(key string, value string) error { + d.data[key] = value + return nil +} diff --git a/flap.go b/flap.go index f6298d7..526d545 100644 --- a/flap.go +++ b/flap.go @@ -2,7 +2,6 @@ package main import ( "bytes" - "context" "encoding" "encoding/binary" "fmt" @@ -22,8 +21,7 @@ type FLAP struct { Data []byte } -func NewFLAP(ctx context.Context, channel uint8, data []byte) *FLAP { - session := ctx.Value("session").(*Session) +func NewFLAP(session *Session, channel uint8, data []byte) *FLAP { session.SequenceNumber += 1 return &FLAP{ @@ -36,6 +34,21 @@ func NewFLAP(ctx context.Context, channel uint8, data []byte) *FLAP { } } +func (f *FLAP) MarshalBinary() ([]byte, error) { + var buf bytes.Buffer + buf.WriteByte(0x2a) + binary.Write(&buf, binary.BigEndian, f.Header) + n, err := buf.Write(f.Data) + if n != int(f.Header.DataLength) { + return nil, fmt.Errorf("needed to write %d bytes to buffer but wrote %d", f.Header.DataLength, n) + } + if err != nil { + return nil, err + } + + return buf.Bytes(), nil +} + func (f *FLAP) UnmarshalBinary(data []byte) error { buf := bytes.NewBuffer(data) start, err := buf.ReadByte() @@ -54,17 +67,10 @@ func (f *FLAP) UnmarshalBinary(data []byte) error { return nil } -func (f *FLAP) MarshalBinary() ([]byte, error) { - var buf bytes.Buffer - buf.WriteByte(0x2a) - binary.Write(&buf, binary.BigEndian, f.Header) - n, err := buf.Write(f.Data) - if n != int(f.Header.DataLength) { - return nil, fmt.Errorf("needed to write %d bytes to buffer but wrote %d", f.Header.DataLength, n) - } - if err != nil { - return nil, err - } - - return buf.Bytes(), nil +func (f *FLAP) Len() int { + return 6 + int(f.Header.DataLength) +} + +func (f *FLAP) String() string { + return fmt.Sprintf("FLAP(CH:%d, SEQ:%d):\n%s", f.Header.Channel, f.Header.SequenceNumber, prettyBytes(f.Data)) } diff --git a/main.go b/main.go index 5063f0d..e80e43a 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,6 @@ package main import ( - "context" "fmt" "io" "log" @@ -10,7 +9,7 @@ import ( "os/signal" "syscall" - "github.com/ghostiam/binstruct" + "github.com/pkg/errors" ) const ( @@ -20,8 +19,11 @@ const ( ) var services = make(map[uint16]Service) +var db *DB = nil func init() { + db = &DB{} + db.Init() services[0x17] = &AuthorizationRegistrationService{} } @@ -49,29 +51,27 @@ func main() { os.Exit(1) } - ctx := NewContextWithSession(context.Background(), conn) + session := NewSession(conn) log.Printf("Connection from %v", conn.RemoteAddr()) - go handleTCPConnection(ctx, conn) + go handleTCPConnection(session, conn) } } -func handleTCPConnection(ctx context.Context, conn net.Conn) { - defer (func() { - if r := recover(); r != nil { - log.Println("Error handling message: ", r.(error).Error()) - } - conn.Close() - log.Printf("Closed connection to %v", conn.RemoteAddr()) - })() +func handleTCPConnection(session *Session, conn net.Conn) { + // defer (func() { + // if err := recover(); err != nil { + // log.Printf("Error handling message: %+v\n", err.(error)) + // } + // conn.Close() + // log.Printf("Closed connection to %v", conn.RemoteAddr()) + // })() buf := make([]byte, 1024) for { - session, err := CurrentSession(ctx) - panicIfError(err) if !session.GreetedClient { // send a hello - hello := NewFLAP(ctx, 1, []byte{0, 0, 0, 1}) + hello := NewFLAP(session, 1, []byte{0, 0, 0, 1}) err := session.Send(hello) panicIfError(err) session.GreetedClient = true @@ -87,25 +87,47 @@ func handleTCPConnection(ctx context.Context, conn net.Conn) { return } - fmt.Printf("%v ->\n%s\n\n", conn.RemoteAddr(), prettyBytes(buf[:n])) + // Try to parse all of the FLAPs in the buffer if we have enough bytes to + // fill a FLAP header + for len(buf) >= 6 && buf[0] == 0x2a { + dataLength := Word(buf[4:6]) + flapLength := int(dataLength) + 6 + if len(buf) < flapLength { + log.Printf("not enough data, only %d bytes\n", len(buf)) + break + } - flap := &FLAP{} - binstruct.UnmarshalBE(buf[:n], flap) - - handleMessage(ctx, flap) + flap := &FLAP{} + if err := flap.UnmarshalBinary(buf[:flapLength]); err != nil { + panicIfError(errors.Wrap(err, "could not unmarshal FLAP")) + } + buf = buf[flapLength:] + fmt.Printf("%v ->\n%+v\n", conn.RemoteAddr(), flap) + handleMessage(session, flap) + } } } -func handleMessage(ctx context.Context, flap *FLAP) { - if flap.Channel == 1 { +func handleMessage(session *Session, flap *FLAP) { + if flap.Header.Channel == 1 { - } else if flap.Channel == 2 { + } else if flap.Header.Channel == 2 { snac := &SNAC{} err := snac.UnmarshalBinary(flap.Data) panicIfError(err) + fmt.Printf("%+v\n", snac) + if tlvs, err := UnmarshalTLVs(snac.Data); err == nil { + for _, tlv := range tlvs { + fmt.Printf("%+v\n", tlv) + } + } else { + fmt.Printf("%s\n\n", prettyBytes(snac.Data)) + } + if service, ok := services[snac.Header.Family]; ok { - service.HandleSNAC(ctx, snac) + err = service.HandleSNAC(session, snac) + panicIfError(err) } } } diff --git a/service.go b/service.go index d7de403..74c7ede 100644 --- a/service.go +++ b/service.go @@ -1,9 +1,5 @@ package main -import ( - "context" -) - type Service interface { - HandleSNAC(context.Context, *SNAC) + HandleSNAC(*Session, *SNAC) error } diff --git a/session.go b/session.go index 8e575c4..ff4f9ae 100644 --- a/session.go +++ b/session.go @@ -5,6 +5,8 @@ import ( "encoding" "fmt" "net" + + "github.com/pkg/errors" ) type sessionKey string @@ -37,20 +39,20 @@ func NewContextWithSession(ctx context.Context, conn net.Conn) context.Context { } func CurrentSession(ctx context.Context) (session *Session, err error) { - session, ok := ctx.Value(currentSession).(*Session) - if !ok { - return nil, fmt.Errorf("no session in context") + s := ctx.Value(currentSession) + if s == nil { + return nil, errors.New("no session in context") } - return + return s.(*Session), nil } func (s *Session) Send(m encoding.BinaryMarshaler) error { bytes, err := m.MarshalBinary() if err != nil { - return err + return errors.Wrap(err, "could not marshal message") } fmt.Printf("-> %v\n%s\n\n", s.Conn.RemoteAddr(), prettyBytes(bytes)) _, err = s.Conn.Write(bytes) - return err + return errors.Wrap(err, "could not write to client connection") } diff --git a/snac.go b/snac.go index 666727f..3d3b0fc 100644 --- a/snac.go +++ b/snac.go @@ -2,15 +2,19 @@ package main import ( "bytes" + "encoding" "encoding/binary" "fmt" ) +var _ encoding.BinaryUnmarshaler = &SNAC{} +var _ encoding.BinaryMarshaler = &SNAC{} + type SNACHeader struct { Family uint16 Subtype uint16 Flags uint16 - RequestID uint16 + RequestID uint32 } type SNAC struct { @@ -19,6 +23,9 @@ type SNAC struct { } func NewSNAC(family uint16, subtype uint16, data []byte) *SNAC { + d := make([]byte, 0, len(data)) + copy(d, data) + return &SNAC{ Header: SNACHeader{ Family: family, @@ -26,7 +33,7 @@ func NewSNAC(family uint16, subtype uint16, data []byte) *SNAC { Flags: 0, RequestID: 0, }, - Data: data, + Data: d, } } @@ -47,12 +54,16 @@ func (s *SNAC) MarshalBinary() ([]byte, error) { func (s *SNAC) UnmarshalBinary(data []byte) error { buf := bytes.NewBuffer(data) - - if err := binary.Read(buf, binary.BigEndian, &s.Header); s != nil { + if err := binary.Read(buf, binary.BigEndian, &s.Header); err != nil { return err } - s.Data = buf.Bytes() + s.Data = make([]byte, buf.Len()) + copy(s.Data, buf.Bytes()) return nil } + +func (s *SNAC) String() string { + return fmt.Sprintf("SNAC(%#x, %#x)", s.Header.Family, s.Header.Subtype) +} diff --git a/tlv.go b/tlv.go new file mode 100644 index 0000000..2d3fe2e --- /dev/null +++ b/tlv.go @@ -0,0 +1,114 @@ +package main + +import ( + "encoding" + "encoding/binary" + "fmt" + "io" + + "github.com/pkg/errors" +) + +var _ encoding.BinaryUnmarshaler = &TLV{} +var _ encoding.BinaryMarshaler = &TLV{} + +type TLV struct { + Type uint16 + DataLength uint16 + Data []byte +} + +func NewTLV(tlvType uint16, data []byte) *TLV { + return &TLV{ + Type: tlvType, + DataLength: uint16(len(data)), + Data: data, + } +} + +func (t *TLV) Len() int { + return 4 + int(t.DataLength) +} + +func (t *TLV) MarshalBinary() ([]byte, error) { + buf := make([]byte, 4+t.DataLength) + binary.BigEndian.PutUint16(buf[:2], t.Type) + binary.BigEndian.PutUint16(buf[2:4], t.DataLength) + copy(buf[4:], t.Data) + return buf, nil +} + +func (t *TLV) UnmarshalBinary(data []byte) error { + if len(data) < 4 { + return io.ErrUnexpectedEOF + } + t.Type = Word(data[:2]) + t.DataLength = Word(data[2:4]) + if len(data) < 4+int(t.DataLength) { + return io.ErrUnexpectedEOF + } + t.Data = make([]byte, int(t.DataLength)) + copy(t.Data, data[4:4+int(t.DataLength)]) + return nil +} + +func (t *TLV) String() string { + return fmt.Sprintf("TLV(%#x):\n%s", t.Type, prettyBytes(t.Data)) +} + +func UnmarshalTLVs(data []byte) ([]*TLV, error) { + tlvs := make([]*TLV, 0) + d := make([]byte, len(data)) + copy(d, data) + + for len(d) > 0 { + tlv := &TLV{} + if err := tlv.UnmarshalBinary(d); err != nil { + return nil, errors.Wrap(err, "enexpected end to unmarshalling TLVs") + } + tlvs = append(tlvs, tlv) + d = d[tlv.Len():] + } + return tlvs, nil +} + +func FindTLV(tlvs []*TLV, tlvType uint16) *TLV { + for _, tlv := range tlvs { + if tlv.Type == tlvType { + return tlv + } + } + return nil +} + +// type TLVReader struct { +// buf []byte +// pos int +// } + +// func (r *TLVReader) ReadNextTLV() (*TLV, error) { +// if len(r.buf) < 4 { +// return nil, io.ErrUnexpectedEOF +// } + +// t := &TLV{} +// t.Type = Word(r.buf[r.pos:r.pos+2]) +// r.pos = r.pos + 2 +// t.DataLength = Word(r.buf[r.pos : r.pos+2]) +// r.pos = r.pos + 2 +// copy(p[2:4], r.buf[r.pos:r.pos+2]) +// r.pos = r.pos + 2 + +// // If there is not enough space to write the expected amount of data, error +// if dataLength > len(p)+4 { +// return 0, io.ErrUnexpectedEOF +// } + +// n = n + copy(p[4:dataLength], r.buf[r.pos:r.pos+dataLength]) +// r.pos = r.pos + dataLength +// return n, nil +// } + +// func NewTLVReader(data []byte) *TLVReader { +// return &TLVReader{buf: data, pos: 0} +// } diff --git a/util.go b/util.go index 52d6ff9..44e1b17 100644 --- a/util.go +++ b/util.go @@ -23,3 +23,13 @@ func panicIfError(err error) { panic(err) } } + +func Word(b []byte) uint16 { + var _ = b[1] + return uint16(b[1]) | uint16(b[0])<<8 +} + +func DWord(b []byte) uint32 { + var _ = b[3] + return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24 +}