This commit is contained in:
Daniel 2022-12-18 21:26:34 -05:00
parent 9224b21db7
commit 6aa2c46717
8 changed files with 107 additions and 66 deletions

View file

@ -112,3 +112,5 @@ func (a *Authenticator) Check(f *LoginFlow, req *http.Request, rm map[string]int
return nil, clientID, ErrInvalidAuth
}
//func (a *Authenticator) GetOrCreateCreds(

View file

@ -104,7 +104,6 @@ func (f *LoginFlow) progress(a *Authenticator, c echo.Context) error {
return c.JSON(http.StatusBadRequest, f.ShowForm([]string{err.Error()}))
}
user, clientID, err := a.Check(f, c.Request(), rm)
switch err {
case nil:

View file

@ -27,11 +27,22 @@ type HAUser struct {
func (hau *HAUser) UserData() provider.ProviderUser {
return &UserData{ // strip secret
Username: hau.Username,
AuthProvider: hau.AuthProvider,
}
}
func (hau *HAUser) Provider() provider.AuthProvider {
return hau.AuthProvider
}
func (hau *UserData) Provider() provider.AuthProvider {
return hau.AuthProvider
}
type UserData struct {
Username string `json:"username"`
provider.AuthProvider `json:"-"`
}
func (ud *UserData) UserData() provider.ProviderUser {
@ -84,6 +95,28 @@ func (hap *HomeAssistantProvider) hashPass(p string) ([]byte, error) {
return bcrypt.GenerateFromPassword([]byte(p), bcrypt.DefaultCost)
}
func (hap *HomeAssistantProvider) EqualCreds(c1, c2 provider.ProviderUser) bool {
switch c1c := c1.(type) {
case *HAUser:
switch c2c := c2.(type) {
case *HAUser:
return c2c.Username == c1c.Username
case *UserData:
return c2c.Username == c1c.Username
}
case *UserData:
switch c2c := c2.(type) {
case *HAUser:
return c2c.Username == c1c.Username
case *UserData:
return c2c.Username == c1c.Username
}
}
return false
}
func (hap *HomeAssistantProvider) ValidateCreds(r *http.Request, rm map[string]interface{}) (provider.ProviderUser, bool) {
usernameE, hasU := rm["username"]
passwordE, hasP := rm["password"]
@ -117,6 +150,7 @@ func (hap *HomeAssistantProvider) ValidateCreds(r *http.Request, rm map[string]i
err = bcrypt.CompareHashAndPassword(hash, []byte(password))
if err == nil {
found.AuthProvider = hap
return found, true
}

View file

@ -17,6 +17,7 @@ type AuthProvider interface { // TODO: this should include stepping
FlowSchema() flow.Schema
NewCredData() interface{}
ValidateCreds(r *http.Request, reqMap map[string]interface{}) (user ProviderUser, success bool)
EqualCreds(c1, c2 ProviderUser) bool
Lookup(ProviderUser) ProviderUser
}
@ -24,39 +25,10 @@ func Register(providerName string, f func(storage.Store) (AuthProvider, error))
Providers[providerName] = f
}
type Credentials struct {
ID CredID `json:"id"`
UserID UserID `json:"user_id"`
AuthProviderType string `json:"auth_provider_type"`
AuthProviderID *string `json:"auth_provider_id"`
DataRaw *json.RawMessage `json:"data,omitempty"`
user ProviderUser `json:"-"`
}
func (cred *Credentials) MarshalJSON() ([]byte, error) {
type CredAlias Credentials // alias so ø method set and we don't recurse
nCd := (*CredAlias)(cred)
if cred.user != nil {
providerData := cred.user.UserData()
if providerData != nil {
b, err := json.Marshal(providerData)
if err != nil {
return nil, err
}
dr := json.RawMessage(b)
nCd.DataRaw = &dr
}
}
return json.Marshal(nCd)
}
type ProviderUser interface {
// TODO: make sure this is sane with all the ProviderUser and UserData type stuff
UserData() ProviderUser
Credentials() *Credentials
Provider() AuthProvider
}
type AuthProviderBase struct {

View file

@ -19,11 +19,22 @@ type User struct {
func (hau *User) UserData() provider.ProviderUser {
return &UserData{
UserID: hau.UserID,
AuthProvider: hau.AuthProvider,
}
}
func (hau *UserData) Provider() provider.AuthProvider {
return hau.AuthProvider
}
func (hau *User) Provider() provider.AuthProvider {
return hau.AuthProvider
}
type UserData struct {
UserID string `json:"user_id"`
provider.AuthProvider `json:"-"`
}
func (ud *UserData) UserData() provider.ProviderUser {
@ -36,6 +47,11 @@ type TrustedNetworksProvider struct {
provider.AuthProviderBase `json:"-"`
}
func (hap *TrustedNetworksProvider) EqualCreds(c1, c2 provider.ProviderUser) bool {
panic("not implemented")
return false
}
func New(s storage.Store) (provider.AuthProvider, error) {
hap := &TrustedNetworksProvider{
AuthProviderBase: provider.AuthProviderBase{

View file

@ -10,8 +10,6 @@ import (
"dynatron.me/x/blasphem/internal/common"
"dynatron.me/x/blasphem/internal/generate"
"dynatron.me/x/blasphem/pkg/auth/provider"
"github.com/rs/zerolog/log"
)
type (
@ -90,7 +88,6 @@ func (ss *authCodeStore) cull() {
}
func (ss *authCodeStore) store(clientID ClientID, cred *Credentials) string {
log.Info().Msgf("store cred is %+v", cred)
ss.cull()
code := generate.UUID()
ss.s[authCodeTuple{clientID, AuthCode(code)}] = flowResult{Time: time.Now(), Cred: cred}
@ -117,15 +114,16 @@ type Credentials struct {
AuthProviderType string `json:"auth_provider_type"`
AuthProviderID *string `json:"auth_provider_id"`
DataRaw *json.RawMessage `json:"data,omitempty"`
user provider.ProviderUser `json:"-"`
User provider.ProviderUser `json:"-"`
}
func (cred *Credentials) MarshalJSON() ([]byte, error) {
type CredAlias Credentials // alias so ø method set and we don't recurse
nCd := (*CredAlias)(cred)
if cred.user != nil {
providerData := cred.user.UserData()
if cred.User != nil {
providerData := cred.User.UserData()
if providerData != nil {
b, err := json.Marshal(providerData)
if err != nil {
@ -170,6 +168,7 @@ func (c *ClientID) IsValid() bool {
}
type AuthCode string
func (ac *AuthCode) IsValid() bool {
return *ac != ""
}

View file

@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"dynatron.me/x/blasphem/internal/generate"
"dynatron.me/x/blasphem/pkg/auth/provider"
"dynatron.me/x/blasphem/pkg/storage"
@ -20,19 +21,40 @@ type AuthStore interface {
}
type authStore struct {
Users []User `json:"users"`
Groups []Group `json:"groups"`
Credentials []Credentials `json:"credentials"`
Refresh []RefreshToken `json:"refresh_tokens"`
Users []*User `json:"users"`
Groups []*Group `json:"groups"`
Credentials []*Credentials `json:"credentials"`
Refresh []*RefreshToken `json:"refresh_tokens"`
userMap map[UserID]*User
providerUsers map[provider.ProviderUser]*Credentials
}
func strPtrEq(n1, n2 *string) bool {
return (n1 == n2 || (n1 != nil && n2 != nil && *n1 == *n2))
}
func (as *authStore) Credential(p provider.ProviderUser) *Credentials {
c, have := as.providerUsers[p]
if !have {
for _, cr := range as.Credentials {
if p.Provider() != nil &&
strPtrEq(cr.AuthProviderID, p.Provider().ProviderID()) &&
cr.AuthProviderType == p.Provider().ProviderType() &&
p.Provider().EqualCreds(cr.User.UserData(), p.UserData()) {
return cr
}
}
return nil
}
func (as *authStore) newCredential(p provider.ProviderUser) *Credentials {
// XXX: probably broken
prov := p.Provider()
id := generate.UUID()
c := &Credentials{
ID: CredID(id),
AuthProviderType: prov.ProviderBase().Type,
AuthProviderID: prov.ProviderBase().ID,
}
return c
@ -46,7 +68,7 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error)
as.providerUsers = make(map[provider.ProviderUser]*Credentials)
for _, u := range as.Users {
as.userMap[u.ID] = &u
as.userMap[u.ID] = u
}
for _, c := range as.Credentials {
@ -63,11 +85,11 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error)
return nil, err
}
c.user = prov.Lookup(pd.(provider.ProviderUser))
if c.user == nil {
c.User = prov.Lookup(pd.(provider.ProviderUser))
if c.User == nil {
return nil, fmt.Errorf("cannot find user in provider %s", prov.ProviderName())
}
as.providerUsers[c.user] = &c
as.providerUsers[c.User] = c
}
u, hasUser := as.userMap[c.UserID]
@ -90,7 +112,7 @@ func (a *Authenticator) newAuthStore(s storage.Store) (as *authStore, err error)
// don't leak memory
for j := i; j < len(as.Refresh); j++ {
as.Refresh[j] = RefreshToken{}
as.Refresh[j] = nil
}
as.Refresh = as.Refresh[:i]

View file

@ -1,8 +1,6 @@
package auth
import (
"github.com/rs/zerolog/log"
)
import ()
type UserID string
type GroupID string
@ -19,7 +17,7 @@ type User struct {
Data interface{} `json:"data,omitempty"`
UserMetadata
Creds []Credentials `json:"-"`
Creds []*Credentials `json:"-"`
}
type UserMetadata struct {
@ -39,7 +37,6 @@ func (u *User) allowedToAuth() error {
}
func (a *Authenticator) getOrCreateUser(c *Credentials) (*User, error) {
log.Debug().Interface("userdata", c).Msg("getOrCreateUser")
u := a.store.User(c.UserID)
if u == nil {
return nil, ErrInvalidAuth