flesh out data types and start some db access

This commit is contained in:
Artem Titoulenko 2021-12-15 19:55:43 -05:00
parent 5daca9ab52
commit e057e83599
9 changed files with 259 additions and 66 deletions

View file

@ -1,22 +1,35 @@
package main
import (
"context"
"fmt"
"errors"
)
var cipher = "hey wassup"
type AuthorizationRegistrationService struct{}
func (a *AuthorizationRegistrationService) HandleSNAC(ctx context.Context, snac *SNAC) {
func (a *AuthorizationRegistrationService) HandleSNAC(session *Session, snac *SNAC) error {
switch snac.Header.Subtype {
// Request MD5 Auth Key
if snac.Header.Subtype == 0x06 {
fmt.Println("damn it's 0x06")
// cipherData := ByteString(cipher) // []byte
case 0x06:
tlvs, err := UnmarshalTLVs(snac.Data)
panicIfError(err)
// snac := NewSNAC(0x17, 0x07, cipherData)
usernameTLV := FindTLV(tlvs, 1)
if usernameTLV == nil {
return errors.New("missing username TLV")
}
// resp := NewFLAP(2, snac)
// Create cipher for this user
cipher := "howdy"
db.Set("cipher-"+string(usernameTLV.Data), cipher)
cipherData := []byte(cipher)
snac := NewSNAC(0x17, 0x07, cipherData)
snacBytes, err := snac.MarshalBinary()
panicIfError(err)
resp := NewFLAP(session, 2, snacBytes)
return session.Send(resp)
}
return nil
}

19
db.go Normal file
View file

@ -0,0 +1,19 @@
package main
type DB struct {
data map[string]string
}
func (d *DB) Init() {
d.data = make(map[string]string)
}
func (d *DB) Get(key string) (data string, ok bool) {
data, ok = d.data[key]
return
}
func (d *DB) Set(key string, value string) error {
d.data[key] = value
return nil
}

38
flap.go
View file

@ -2,7 +2,6 @@ package main
import (
"bytes"
"context"
"encoding"
"encoding/binary"
"fmt"
@ -22,8 +21,7 @@ type FLAP struct {
Data []byte
}
func NewFLAP(ctx context.Context, channel uint8, data []byte) *FLAP {
session := ctx.Value("session").(*Session)
func NewFLAP(session *Session, channel uint8, data []byte) *FLAP {
session.SequenceNumber += 1
return &FLAP{
@ -36,6 +34,21 @@ func NewFLAP(ctx context.Context, channel uint8, data []byte) *FLAP {
}
}
func (f *FLAP) MarshalBinary() ([]byte, error) {
var buf bytes.Buffer
buf.WriteByte(0x2a)
binary.Write(&buf, binary.BigEndian, f.Header)
n, err := buf.Write(f.Data)
if n != int(f.Header.DataLength) {
return nil, fmt.Errorf("needed to write %d bytes to buffer but wrote %d", f.Header.DataLength, n)
}
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func (f *FLAP) UnmarshalBinary(data []byte) error {
buf := bytes.NewBuffer(data)
start, err := buf.ReadByte()
@ -54,17 +67,10 @@ func (f *FLAP) UnmarshalBinary(data []byte) error {
return nil
}
func (f *FLAP) MarshalBinary() ([]byte, error) {
var buf bytes.Buffer
buf.WriteByte(0x2a)
binary.Write(&buf, binary.BigEndian, f.Header)
n, err := buf.Write(f.Data)
if n != int(f.Header.DataLength) {
return nil, fmt.Errorf("needed to write %d bytes to buffer but wrote %d", f.Header.DataLength, n)
}
if err != nil {
return nil, err
}
return buf.Bytes(), nil
func (f *FLAP) Len() int {
return 6 + int(f.Header.DataLength)
}
func (f *FLAP) String() string {
return fmt.Sprintf("FLAP(CH:%d, SEQ:%d):\n%s", f.Header.Channel, f.Header.SequenceNumber, prettyBytes(f.Data))
}

70
main.go
View file

@ -1,7 +1,6 @@
package main
import (
"context"
"fmt"
"io"
"log"
@ -10,7 +9,7 @@ import (
"os/signal"
"syscall"
"github.com/ghostiam/binstruct"
"github.com/pkg/errors"
)
const (
@ -20,8 +19,11 @@ const (
)
var services = make(map[uint16]Service)
var db *DB = nil
func init() {
db = &DB{}
db.Init()
services[0x17] = &AuthorizationRegistrationService{}
}
@ -49,29 +51,27 @@ func main() {
os.Exit(1)
}
ctx := NewContextWithSession(context.Background(), conn)
session := NewSession(conn)
log.Printf("Connection from %v", conn.RemoteAddr())
go handleTCPConnection(ctx, conn)
go handleTCPConnection(session, conn)
}
}
func handleTCPConnection(ctx context.Context, conn net.Conn) {
defer (func() {
if r := recover(); r != nil {
log.Println("Error handling message: ", r.(error).Error())
}
conn.Close()
log.Printf("Closed connection to %v", conn.RemoteAddr())
})()
func handleTCPConnection(session *Session, conn net.Conn) {
// defer (func() {
// if err := recover(); err != nil {
// log.Printf("Error handling message: %+v\n", err.(error))
// }
// conn.Close()
// log.Printf("Closed connection to %v", conn.RemoteAddr())
// })()
buf := make([]byte, 1024)
for {
session, err := CurrentSession(ctx)
panicIfError(err)
if !session.GreetedClient {
// send a hello
hello := NewFLAP(ctx, 1, []byte{0, 0, 0, 1})
hello := NewFLAP(session, 1, []byte{0, 0, 0, 1})
err := session.Send(hello)
panicIfError(err)
session.GreetedClient = true
@ -87,25 +87,47 @@ func handleTCPConnection(ctx context.Context, conn net.Conn) {
return
}
fmt.Printf("%v ->\n%s\n\n", conn.RemoteAddr(), prettyBytes(buf[:n]))
// Try to parse all of the FLAPs in the buffer if we have enough bytes to
// fill a FLAP header
for len(buf) >= 6 && buf[0] == 0x2a {
dataLength := Word(buf[4:6])
flapLength := int(dataLength) + 6
if len(buf) < flapLength {
log.Printf("not enough data, only %d bytes\n", len(buf))
break
}
flap := &FLAP{}
binstruct.UnmarshalBE(buf[:n], flap)
handleMessage(ctx, flap)
flap := &FLAP{}
if err := flap.UnmarshalBinary(buf[:flapLength]); err != nil {
panicIfError(errors.Wrap(err, "could not unmarshal FLAP"))
}
buf = buf[flapLength:]
fmt.Printf("%v ->\n%+v\n", conn.RemoteAddr(), flap)
handleMessage(session, flap)
}
}
}
func handleMessage(ctx context.Context, flap *FLAP) {
if flap.Channel == 1 {
func handleMessage(session *Session, flap *FLAP) {
if flap.Header.Channel == 1 {
} else if flap.Channel == 2 {
} else if flap.Header.Channel == 2 {
snac := &SNAC{}
err := snac.UnmarshalBinary(flap.Data)
panicIfError(err)
fmt.Printf("%+v\n", snac)
if tlvs, err := UnmarshalTLVs(snac.Data); err == nil {
for _, tlv := range tlvs {
fmt.Printf("%+v\n", tlv)
}
} else {
fmt.Printf("%s\n\n", prettyBytes(snac.Data))
}
if service, ok := services[snac.Header.Family]; ok {
service.HandleSNAC(ctx, snac)
err = service.HandleSNAC(session, snac)
panicIfError(err)
}
}
}

View file

@ -1,9 +1,5 @@
package main
import (
"context"
)
type Service interface {
HandleSNAC(context.Context, *SNAC)
HandleSNAC(*Session, *SNAC) error
}

View file

@ -5,6 +5,8 @@ import (
"encoding"
"fmt"
"net"
"github.com/pkg/errors"
)
type sessionKey string
@ -37,20 +39,20 @@ func NewContextWithSession(ctx context.Context, conn net.Conn) context.Context {
}
func CurrentSession(ctx context.Context) (session *Session, err error) {
session, ok := ctx.Value(currentSession).(*Session)
if !ok {
return nil, fmt.Errorf("no session in context")
s := ctx.Value(currentSession)
if s == nil {
return nil, errors.New("no session in context")
}
return
return s.(*Session), nil
}
func (s *Session) Send(m encoding.BinaryMarshaler) error {
bytes, err := m.MarshalBinary()
if err != nil {
return err
return errors.Wrap(err, "could not marshal message")
}
fmt.Printf("-> %v\n%s\n\n", s.Conn.RemoteAddr(), prettyBytes(bytes))
_, err = s.Conn.Write(bytes)
return err
return errors.Wrap(err, "could not write to client connection")
}

21
snac.go
View file

@ -2,15 +2,19 @@ package main
import (
"bytes"
"encoding"
"encoding/binary"
"fmt"
)
var _ encoding.BinaryUnmarshaler = &SNAC{}
var _ encoding.BinaryMarshaler = &SNAC{}
type SNACHeader struct {
Family uint16
Subtype uint16
Flags uint16
RequestID uint16
RequestID uint32
}
type SNAC struct {
@ -19,6 +23,9 @@ type SNAC struct {
}
func NewSNAC(family uint16, subtype uint16, data []byte) *SNAC {
d := make([]byte, 0, len(data))
copy(d, data)
return &SNAC{
Header: SNACHeader{
Family: family,
@ -26,7 +33,7 @@ func NewSNAC(family uint16, subtype uint16, data []byte) *SNAC {
Flags: 0,
RequestID: 0,
},
Data: data,
Data: d,
}
}
@ -47,12 +54,16 @@ func (s *SNAC) MarshalBinary() ([]byte, error) {
func (s *SNAC) UnmarshalBinary(data []byte) error {
buf := bytes.NewBuffer(data)
if err := binary.Read(buf, binary.BigEndian, &s.Header); s != nil {
if err := binary.Read(buf, binary.BigEndian, &s.Header); err != nil {
return err
}
s.Data = buf.Bytes()
s.Data = make([]byte, buf.Len())
copy(s.Data, buf.Bytes())
return nil
}
func (s *SNAC) String() string {
return fmt.Sprintf("SNAC(%#x, %#x)", s.Header.Family, s.Header.Subtype)
}

114
tlv.go Normal file
View file

@ -0,0 +1,114 @@
package main
import (
"encoding"
"encoding/binary"
"fmt"
"io"
"github.com/pkg/errors"
)
var _ encoding.BinaryUnmarshaler = &TLV{}
var _ encoding.BinaryMarshaler = &TLV{}
type TLV struct {
Type uint16
DataLength uint16
Data []byte
}
func NewTLV(tlvType uint16, data []byte) *TLV {
return &TLV{
Type: tlvType,
DataLength: uint16(len(data)),
Data: data,
}
}
func (t *TLV) Len() int {
return 4 + int(t.DataLength)
}
func (t *TLV) MarshalBinary() ([]byte, error) {
buf := make([]byte, 4+t.DataLength)
binary.BigEndian.PutUint16(buf[:2], t.Type)
binary.BigEndian.PutUint16(buf[2:4], t.DataLength)
copy(buf[4:], t.Data)
return buf, nil
}
func (t *TLV) UnmarshalBinary(data []byte) error {
if len(data) < 4 {
return io.ErrUnexpectedEOF
}
t.Type = Word(data[:2])
t.DataLength = Word(data[2:4])
if len(data) < 4+int(t.DataLength) {
return io.ErrUnexpectedEOF
}
t.Data = make([]byte, int(t.DataLength))
copy(t.Data, data[4:4+int(t.DataLength)])
return nil
}
func (t *TLV) String() string {
return fmt.Sprintf("TLV(%#x):\n%s", t.Type, prettyBytes(t.Data))
}
func UnmarshalTLVs(data []byte) ([]*TLV, error) {
tlvs := make([]*TLV, 0)
d := make([]byte, len(data))
copy(d, data)
for len(d) > 0 {
tlv := &TLV{}
if err := tlv.UnmarshalBinary(d); err != nil {
return nil, errors.Wrap(err, "enexpected end to unmarshalling TLVs")
}
tlvs = append(tlvs, tlv)
d = d[tlv.Len():]
}
return tlvs, nil
}
func FindTLV(tlvs []*TLV, tlvType uint16) *TLV {
for _, tlv := range tlvs {
if tlv.Type == tlvType {
return tlv
}
}
return nil
}
// type TLVReader struct {
// buf []byte
// pos int
// }
// func (r *TLVReader) ReadNextTLV() (*TLV, error) {
// if len(r.buf) < 4 {
// return nil, io.ErrUnexpectedEOF
// }
// t := &TLV{}
// t.Type = Word(r.buf[r.pos:r.pos+2])
// r.pos = r.pos + 2
// t.DataLength = Word(r.buf[r.pos : r.pos+2])
// r.pos = r.pos + 2
// copy(p[2:4], r.buf[r.pos:r.pos+2])
// r.pos = r.pos + 2
// // If there is not enough space to write the expected amount of data, error
// if dataLength > len(p)+4 {
// return 0, io.ErrUnexpectedEOF
// }
// n = n + copy(p[4:dataLength], r.buf[r.pos:r.pos+dataLength])
// r.pos = r.pos + dataLength
// return n, nil
// }
// func NewTLVReader(data []byte) *TLVReader {
// return &TLVReader{buf: data, pos: 0}
// }

10
util.go
View file

@ -23,3 +23,13 @@ func panicIfError(err error) {
panic(err)
}
}
func Word(b []byte) uint16 {
var _ = b[1]
return uint16(b[1]) | uint16(b[0])<<8
}
func DWord(b []byte) uint32 {
var _ = b[3]
return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
}