flesh out data types and start some db access

This commit is contained in:
Artem Titoulenko 2021-12-15 19:55:43 -05:00
parent 5daca9ab52
commit e057e83599
9 changed files with 259 additions and 66 deletions

View file

@ -1,22 +1,35 @@
package main package main
import ( import (
"context" "errors"
"fmt"
) )
var cipher = "hey wassup"
type AuthorizationRegistrationService struct{} type AuthorizationRegistrationService struct{}
func (a *AuthorizationRegistrationService) HandleSNAC(ctx context.Context, snac *SNAC) { func (a *AuthorizationRegistrationService) HandleSNAC(session *Session, snac *SNAC) error {
switch snac.Header.Subtype {
// Request MD5 Auth Key // Request MD5 Auth Key
if snac.Header.Subtype == 0x06 { case 0x06:
fmt.Println("damn it's 0x06") tlvs, err := UnmarshalTLVs(snac.Data)
// cipherData := ByteString(cipher) // []byte panicIfError(err)
// snac := NewSNAC(0x17, 0x07, cipherData) usernameTLV := FindTLV(tlvs, 1)
if usernameTLV == nil {
return errors.New("missing username TLV")
}
// resp := NewFLAP(2, snac) // Create cipher for this user
cipher := "howdy"
db.Set("cipher-"+string(usernameTLV.Data), cipher)
cipherData := []byte(cipher)
snac := NewSNAC(0x17, 0x07, cipherData)
snacBytes, err := snac.MarshalBinary()
panicIfError(err)
resp := NewFLAP(session, 2, snacBytes)
return session.Send(resp)
} }
return nil
} }

19
db.go Normal file
View file

@ -0,0 +1,19 @@
package main
type DB struct {
data map[string]string
}
func (d *DB) Init() {
d.data = make(map[string]string)
}
func (d *DB) Get(key string) (data string, ok bool) {
data, ok = d.data[key]
return
}
func (d *DB) Set(key string, value string) error {
d.data[key] = value
return nil
}

38
flap.go
View file

@ -2,7 +2,6 @@ package main
import ( import (
"bytes" "bytes"
"context"
"encoding" "encoding"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
@ -22,8 +21,7 @@ type FLAP struct {
Data []byte Data []byte
} }
func NewFLAP(ctx context.Context, channel uint8, data []byte) *FLAP { func NewFLAP(session *Session, channel uint8, data []byte) *FLAP {
session := ctx.Value("session").(*Session)
session.SequenceNumber += 1 session.SequenceNumber += 1
return &FLAP{ return &FLAP{
@ -36,6 +34,21 @@ func NewFLAP(ctx context.Context, channel uint8, data []byte) *FLAP {
} }
} }
func (f *FLAP) MarshalBinary() ([]byte, error) {
var buf bytes.Buffer
buf.WriteByte(0x2a)
binary.Write(&buf, binary.BigEndian, f.Header)
n, err := buf.Write(f.Data)
if n != int(f.Header.DataLength) {
return nil, fmt.Errorf("needed to write %d bytes to buffer but wrote %d", f.Header.DataLength, n)
}
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}
func (f *FLAP) UnmarshalBinary(data []byte) error { func (f *FLAP) UnmarshalBinary(data []byte) error {
buf := bytes.NewBuffer(data) buf := bytes.NewBuffer(data)
start, err := buf.ReadByte() start, err := buf.ReadByte()
@ -54,17 +67,10 @@ func (f *FLAP) UnmarshalBinary(data []byte) error {
return nil return nil
} }
func (f *FLAP) MarshalBinary() ([]byte, error) { func (f *FLAP) Len() int {
var buf bytes.Buffer return 6 + int(f.Header.DataLength)
buf.WriteByte(0x2a) }
binary.Write(&buf, binary.BigEndian, f.Header)
n, err := buf.Write(f.Data) func (f *FLAP) String() string {
if n != int(f.Header.DataLength) { return fmt.Sprintf("FLAP(CH:%d, SEQ:%d):\n%s", f.Header.Channel, f.Header.SequenceNumber, prettyBytes(f.Data))
return nil, fmt.Errorf("needed to write %d bytes to buffer but wrote %d", f.Header.DataLength, n)
}
if err != nil {
return nil, err
}
return buf.Bytes(), nil
} }

70
main.go
View file

@ -1,7 +1,6 @@
package main package main
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -10,7 +9,7 @@ import (
"os/signal" "os/signal"
"syscall" "syscall"
"github.com/ghostiam/binstruct" "github.com/pkg/errors"
) )
const ( const (
@ -20,8 +19,11 @@ const (
) )
var services = make(map[uint16]Service) var services = make(map[uint16]Service)
var db *DB = nil
func init() { func init() {
db = &DB{}
db.Init()
services[0x17] = &AuthorizationRegistrationService{} services[0x17] = &AuthorizationRegistrationService{}
} }
@ -49,29 +51,27 @@ func main() {
os.Exit(1) os.Exit(1)
} }
ctx := NewContextWithSession(context.Background(), conn) session := NewSession(conn)
log.Printf("Connection from %v", conn.RemoteAddr()) log.Printf("Connection from %v", conn.RemoteAddr())
go handleTCPConnection(ctx, conn) go handleTCPConnection(session, conn)
} }
} }
func handleTCPConnection(ctx context.Context, conn net.Conn) { func handleTCPConnection(session *Session, conn net.Conn) {
defer (func() { // defer (func() {
if r := recover(); r != nil { // if err := recover(); err != nil {
log.Println("Error handling message: ", r.(error).Error()) // log.Printf("Error handling message: %+v\n", err.(error))
} // }
conn.Close() // conn.Close()
log.Printf("Closed connection to %v", conn.RemoteAddr()) // log.Printf("Closed connection to %v", conn.RemoteAddr())
})() // })()
buf := make([]byte, 1024) buf := make([]byte, 1024)
for { for {
session, err := CurrentSession(ctx)
panicIfError(err)
if !session.GreetedClient { if !session.GreetedClient {
// send a hello // send a hello
hello := NewFLAP(ctx, 1, []byte{0, 0, 0, 1}) hello := NewFLAP(session, 1, []byte{0, 0, 0, 1})
err := session.Send(hello) err := session.Send(hello)
panicIfError(err) panicIfError(err)
session.GreetedClient = true session.GreetedClient = true
@ -87,25 +87,47 @@ func handleTCPConnection(ctx context.Context, conn net.Conn) {
return return
} }
fmt.Printf("%v ->\n%s\n\n", conn.RemoteAddr(), prettyBytes(buf[:n])) // Try to parse all of the FLAPs in the buffer if we have enough bytes to
// fill a FLAP header
for len(buf) >= 6 && buf[0] == 0x2a {
dataLength := Word(buf[4:6])
flapLength := int(dataLength) + 6
if len(buf) < flapLength {
log.Printf("not enough data, only %d bytes\n", len(buf))
break
}
flap := &FLAP{} flap := &FLAP{}
binstruct.UnmarshalBE(buf[:n], flap) if err := flap.UnmarshalBinary(buf[:flapLength]); err != nil {
panicIfError(errors.Wrap(err, "could not unmarshal FLAP"))
handleMessage(ctx, flap) }
buf = buf[flapLength:]
fmt.Printf("%v ->\n%+v\n", conn.RemoteAddr(), flap)
handleMessage(session, flap)
}
} }
} }
func handleMessage(ctx context.Context, flap *FLAP) { func handleMessage(session *Session, flap *FLAP) {
if flap.Channel == 1 { if flap.Header.Channel == 1 {
} else if flap.Channel == 2 { } else if flap.Header.Channel == 2 {
snac := &SNAC{} snac := &SNAC{}
err := snac.UnmarshalBinary(flap.Data) err := snac.UnmarshalBinary(flap.Data)
panicIfError(err) panicIfError(err)
fmt.Printf("%+v\n", snac)
if tlvs, err := UnmarshalTLVs(snac.Data); err == nil {
for _, tlv := range tlvs {
fmt.Printf("%+v\n", tlv)
}
} else {
fmt.Printf("%s\n\n", prettyBytes(snac.Data))
}
if service, ok := services[snac.Header.Family]; ok { if service, ok := services[snac.Header.Family]; ok {
service.HandleSNAC(ctx, snac) err = service.HandleSNAC(session, snac)
panicIfError(err)
} }
} }
} }

View file

@ -1,9 +1,5 @@
package main package main
import (
"context"
)
type Service interface { type Service interface {
HandleSNAC(context.Context, *SNAC) HandleSNAC(*Session, *SNAC) error
} }

View file

@ -5,6 +5,8 @@ import (
"encoding" "encoding"
"fmt" "fmt"
"net" "net"
"github.com/pkg/errors"
) )
type sessionKey string type sessionKey string
@ -37,20 +39,20 @@ func NewContextWithSession(ctx context.Context, conn net.Conn) context.Context {
} }
func CurrentSession(ctx context.Context) (session *Session, err error) { func CurrentSession(ctx context.Context) (session *Session, err error) {
session, ok := ctx.Value(currentSession).(*Session) s := ctx.Value(currentSession)
if !ok { if s == nil {
return nil, fmt.Errorf("no session in context") return nil, errors.New("no session in context")
} }
return return s.(*Session), nil
} }
func (s *Session) Send(m encoding.BinaryMarshaler) error { func (s *Session) Send(m encoding.BinaryMarshaler) error {
bytes, err := m.MarshalBinary() bytes, err := m.MarshalBinary()
if err != nil { if err != nil {
return err return errors.Wrap(err, "could not marshal message")
} }
fmt.Printf("-> %v\n%s\n\n", s.Conn.RemoteAddr(), prettyBytes(bytes)) fmt.Printf("-> %v\n%s\n\n", s.Conn.RemoteAddr(), prettyBytes(bytes))
_, err = s.Conn.Write(bytes) _, err = s.Conn.Write(bytes)
return err return errors.Wrap(err, "could not write to client connection")
} }

21
snac.go
View file

@ -2,15 +2,19 @@ package main
import ( import (
"bytes" "bytes"
"encoding"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
) )
var _ encoding.BinaryUnmarshaler = &SNAC{}
var _ encoding.BinaryMarshaler = &SNAC{}
type SNACHeader struct { type SNACHeader struct {
Family uint16 Family uint16
Subtype uint16 Subtype uint16
Flags uint16 Flags uint16
RequestID uint16 RequestID uint32
} }
type SNAC struct { type SNAC struct {
@ -19,6 +23,9 @@ type SNAC struct {
} }
func NewSNAC(family uint16, subtype uint16, data []byte) *SNAC { func NewSNAC(family uint16, subtype uint16, data []byte) *SNAC {
d := make([]byte, 0, len(data))
copy(d, data)
return &SNAC{ return &SNAC{
Header: SNACHeader{ Header: SNACHeader{
Family: family, Family: family,
@ -26,7 +33,7 @@ func NewSNAC(family uint16, subtype uint16, data []byte) *SNAC {
Flags: 0, Flags: 0,
RequestID: 0, RequestID: 0,
}, },
Data: data, Data: d,
} }
} }
@ -47,12 +54,16 @@ func (s *SNAC) MarshalBinary() ([]byte, error) {
func (s *SNAC) UnmarshalBinary(data []byte) error { func (s *SNAC) UnmarshalBinary(data []byte) error {
buf := bytes.NewBuffer(data) buf := bytes.NewBuffer(data)
if err := binary.Read(buf, binary.BigEndian, &s.Header); err != nil {
if err := binary.Read(buf, binary.BigEndian, &s.Header); s != nil {
return err return err
} }
s.Data = buf.Bytes() s.Data = make([]byte, buf.Len())
copy(s.Data, buf.Bytes())
return nil return nil
} }
func (s *SNAC) String() string {
return fmt.Sprintf("SNAC(%#x, %#x)", s.Header.Family, s.Header.Subtype)
}

114
tlv.go Normal file
View file

@ -0,0 +1,114 @@
package main
import (
"encoding"
"encoding/binary"
"fmt"
"io"
"github.com/pkg/errors"
)
var _ encoding.BinaryUnmarshaler = &TLV{}
var _ encoding.BinaryMarshaler = &TLV{}
type TLV struct {
Type uint16
DataLength uint16
Data []byte
}
func NewTLV(tlvType uint16, data []byte) *TLV {
return &TLV{
Type: tlvType,
DataLength: uint16(len(data)),
Data: data,
}
}
func (t *TLV) Len() int {
return 4 + int(t.DataLength)
}
func (t *TLV) MarshalBinary() ([]byte, error) {
buf := make([]byte, 4+t.DataLength)
binary.BigEndian.PutUint16(buf[:2], t.Type)
binary.BigEndian.PutUint16(buf[2:4], t.DataLength)
copy(buf[4:], t.Data)
return buf, nil
}
func (t *TLV) UnmarshalBinary(data []byte) error {
if len(data) < 4 {
return io.ErrUnexpectedEOF
}
t.Type = Word(data[:2])
t.DataLength = Word(data[2:4])
if len(data) < 4+int(t.DataLength) {
return io.ErrUnexpectedEOF
}
t.Data = make([]byte, int(t.DataLength))
copy(t.Data, data[4:4+int(t.DataLength)])
return nil
}
func (t *TLV) String() string {
return fmt.Sprintf("TLV(%#x):\n%s", t.Type, prettyBytes(t.Data))
}
func UnmarshalTLVs(data []byte) ([]*TLV, error) {
tlvs := make([]*TLV, 0)
d := make([]byte, len(data))
copy(d, data)
for len(d) > 0 {
tlv := &TLV{}
if err := tlv.UnmarshalBinary(d); err != nil {
return nil, errors.Wrap(err, "enexpected end to unmarshalling TLVs")
}
tlvs = append(tlvs, tlv)
d = d[tlv.Len():]
}
return tlvs, nil
}
func FindTLV(tlvs []*TLV, tlvType uint16) *TLV {
for _, tlv := range tlvs {
if tlv.Type == tlvType {
return tlv
}
}
return nil
}
// type TLVReader struct {
// buf []byte
// pos int
// }
// func (r *TLVReader) ReadNextTLV() (*TLV, error) {
// if len(r.buf) < 4 {
// return nil, io.ErrUnexpectedEOF
// }
// t := &TLV{}
// t.Type = Word(r.buf[r.pos:r.pos+2])
// r.pos = r.pos + 2
// t.DataLength = Word(r.buf[r.pos : r.pos+2])
// r.pos = r.pos + 2
// copy(p[2:4], r.buf[r.pos:r.pos+2])
// r.pos = r.pos + 2
// // If there is not enough space to write the expected amount of data, error
// if dataLength > len(p)+4 {
// return 0, io.ErrUnexpectedEOF
// }
// n = n + copy(p[4:dataLength], r.buf[r.pos:r.pos+dataLength])
// r.pos = r.pos + dataLength
// return n, nil
// }
// func NewTLVReader(data []byte) *TLVReader {
// return &TLVReader{buf: data, pos: 0}
// }

10
util.go
View file

@ -23,3 +23,13 @@ func panicIfError(err error) {
panic(err) panic(err)
} }
} }
func Word(b []byte) uint16 {
var _ = b[1]
return uint16(b[1]) | uint16(b[0])<<8
}
func DWord(b []byte) uint32 {
var _ = b[3]
return uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
}