package graphql

import (
	"context"
	"fmt"
	"strings"
	"time"

	"github.com/Nerzal/gocloak/v11"
	"github.com/golang-jwt/jwt/v4"
	"github.com/gshopify/service-wrapper/auth"
	"github.com/gshopify/service-wrapper/config"
	"github.com/gshopify/service-wrapper/model"
	"github.com/gshopify/service-wrapper/scalar"
	"gshopper.com/gshopify/customer/graphql/generated"
)

// Resolver is the root GraphQL resolver for the customer service. It holds the
// auth configuration and the gocloak client shared by the mutation and query
// resolvers.
type Resolver struct {
	conf   *auth.Config
	client gocloak.GoCloak
}

// NewResolver loads the auth configuration through the shared config instance
// and creates a gocloak client for conf.Endpoint. It returns an error if the
// configuration cannot be loaded.
func NewResolver() (*Resolver, error) {
	r := &Resolver{
		conf: auth.New(),
	}

	if err := config.Instance().Load(context.Background(), r.conf); err != nil {
		return nil, fmt.Errorf("loading auth config: %w", err)
	}

	r.client = gocloak.NewClient(r.conf.Endpoint)
	return r, nil
}

// decodeAccessToken decodes the access token t (surrounding whitespace is
// ignored) against the configured realm and returns the parsed token together
// with its "sid" claim.
//
// It returns an error when t is empty, when gocloak fails to decode it, when
// the decoded token is not marked valid, or when the claims carry no non-empty
// string "sid".
func (r *Resolver) decodeAccessToken(ctx context.Context, t string) (*jwt.Token, string, error) {
	t = strings.TrimSpace(t)
	if t == "" {
		return nil, "", fmt.Errorf("could not decode accessToken: Token is empty")
	}

	token, claims, err := r.client.DecodeAccessToken(ctx, t, r.conf.Realm)
	if err != nil {
		return nil, "", fmt.Errorf("could not decode accessToken: %w", err)
	}

	// Guard the pointers as well as the validity flag so an unexpected nil
	// result from the decoder cannot cause a nil-pointer panic below.
	if token == nil || !token.Valid {
		return nil, "", fmt.Errorf("could not decode accessToken: Token is NOT valid")
	}
	if claims == nil {
		return nil, "", fmt.Errorf("could not claim session id")
	}

	// A missing key or a non-string value both yield "".
	sessionID, _ := (*claims)["sid"].(string)
	if sessionID == "" {
		return nil, "", fmt.Errorf("could not claim session id")
	}

	return token, sessionID, nil
}

// refreshTTL returns how long the refresh token in token stays valid, based on
// the RefreshExpiresIn value (seconds) reported by Keycloak.
func refreshTTL(token *gocloak.JWT) time.Duration {
	return time.Duration(token.RefreshExpiresIn) * time.Second
}

// newCustomerAccessToken builds the GraphQL access-token payload from a
// Keycloak token.
//
// NOTE(review): ExpiresAt is derived from RefreshExpiresIn (refresh-token
// lifetime), not ExpiresIn (access-token lifetime), while decodeAccessToken
// rejects tokens that are no longer valid — confirm this is intended.
func newCustomerAccessToken(token *gocloak.JWT) *generated.CustomerAccessToken {
	return &generated.CustomerAccessToken{
		AccessToken: token.AccessToken,
		ExpiresAt:   scalar.NewDateTimeIn(token.RefreshExpiresIn).String(),
	}
}

// stringValue dereferences p, returning "" when p is nil.
func stringValue(p *string) string {
	if p == nil {
		return ""
	}
	return *p
}

// CustomerAccessTokenCreate logs the customer in with the given email and
// password, stores the refresh token in the session manager keyed by the
// token's SessionState, and returns the access token.
//
// Login and session-storage failures are reported as CustomerUserErrors in the
// payload; the returned error is always nil.
func (r *mutationResolver) CustomerAccessTokenCreate(
	ctx context.Context, input generated.CustomerAccessTokenCreateInput) (*generated.CustomerAccessTokenCreatePayload, error) {
	session := auth.SessionManager()
	response := &generated.CustomerAccessTokenCreatePayload{}

	token, err := r.client.Login(ctx, r.conf.ClientId, r.conf.ClientSecret, r.conf.Realm, input.Email, input.Password)
	if err != nil {
		response.CustomerUserErrors = append(response.CustomerUserErrors,
			CustomerError(generated.CustomerErrorCodeUnidentifiedCustomer, err))
		return response, nil
	}

	// Renew/Delete later look the refresh token up by the access token's "sid"
	// claim; this presumably matches SessionState for Keycloak — verify.
	if err := session.PutToken(ctx, token.SessionState, token.RefreshToken, refreshTTL(token)); err != nil {
		response.CustomerUserErrors = append(response.CustomerUserErrors,
			CustomerError(generated.CustomerErrorCodeTokenInvalid, err))
		return response, nil
	}

	response.CustomerAccessToken = newCustomerAccessToken(token)
	return response, nil
}

// CustomerAccessTokenRenew exchanges the refresh token stored for the session
// of access token t for a new token pair, stores the new refresh token, and
// returns the new access token.
//
// Failures are reported as UserErrors in the payload; the returned error is
// always nil.
func (r *mutationResolver) CustomerAccessTokenRenew(ctx context.Context, t string) (*generated.CustomerAccessTokenRenewPayload, error) {
	session := auth.SessionManager()
	response := &generated.CustomerAccessTokenRenewPayload{}

	_, sid, err := r.decodeAccessToken(ctx, t)
	if err != nil {
		// Any decode/validation failure is surfaced as "token does not exist".
		response.UserErrors = append(response.UserErrors, ErrTokenNotExists)
		return response, nil
	}

	refresh, err := session.Token(ctx, sid)
	if err != nil {
		response.UserErrors = append(response.UserErrors, ErrToken(err.Error()))
		return response, nil
	}

	token, err := r.client.RefreshToken(ctx, refresh, r.conf.ClientId, r.conf.ClientSecret, r.conf.Realm)
	if err != nil {
		response.UserErrors = append(response.UserErrors, ErrTokenExpired)
		return response, nil
	}

	if err := session.PutToken(ctx, token.SessionState, token.RefreshToken, refreshTTL(token)); err != nil {
		response.UserErrors = append(response.UserErrors, ErrToken(err.Error()))
		return response, nil
	}

	response.CustomerAccessToken = newCustomerAccessToken(token)
	return response, nil
}

// CustomerAccessTokenDelete logs out the Keycloak session belonging to access
// token t using the refresh token stored for that session, and echoes back the
// deleted token and its session id.
//
// Failures are reported as UserErrors in the payload; the returned error is
// always nil.
//
// NOTE(review): the stored refresh token is not removed from the session
// manager here; confirm it is expired/cleaned up elsewhere.
func (r *mutationResolver) CustomerAccessTokenDelete(ctx context.Context, t string) (*generated.CustomerAccessTokenDeletePayload, error) {
	response := &generated.CustomerAccessTokenDeletePayload{}
	session := auth.SessionManager()

	_, sid, err := r.decodeAccessToken(ctx, t)
	if err != nil {
		response.UserErrors = append(response.UserErrors, ErrToken(err.Error()))
		return response, nil
	}

	refresh, err := session.Token(ctx, sid)
	if err != nil {
		response.UserErrors = append(response.UserErrors, ErrToken(err.Error()))
		return response, nil
	}

	if err := r.client.Logout(ctx, r.conf.ClientId, r.conf.ClientSecret, r.conf.Realm, refresh); err != nil {
		response.UserErrors = append(response.UserErrors, ErrToken(err.Error()))
		return response, nil
	}

	response.DeletedAccessToken = &t
	response.DeletedCustomerAccessTokenID = &sid
	return response, nil
}

// Customer validates access token t and returns the customer built from the
// Keycloak userinfo endpoint. It returns an error if the token is invalid, the
// userinfo request fails, or the userinfo response has no subject.
func (r *queryResolver) Customer(ctx context.Context, t string) (*generated.Customer, error) {
	// decodeAccessToken trims internally; trim here too so the exact token
	// that was validated is the one sent to the userinfo endpoint.
	t = strings.TrimSpace(t)
	if _, _, err := r.decodeAccessToken(ctx, t); err != nil {
		return nil, err
	}

	userinfo, err := r.client.GetUserInfo(ctx, t, r.conf.Realm)
	if err != nil {
		return nil, err
	}
	// Sub becomes the customer ID, so its absence is an error rather than a
	// nil-pointer dereference.
	if userinfo == nil || userinfo.Sub == nil {
		return nil, fmt.Errorf("could not load customer: userinfo has no subject")
	}

	return &generated.Customer{
		AcceptsMarketing: false, // TODO:
		Addresses:        nil,   // TODO:
		DefaultAddress:   nil,   // TODO:
		DisplayName:      stringValue(userinfo.PreferredUsername),
		Email:            userinfo.Email,
		// OIDC "name" is the full name; "given_name" pairs with "family_name".
		FirstName:      userinfo.GivenName,
		ID:             *userinfo.Sub,
		LastName:       userinfo.FamilyName,
		NumberOfOrders: model.UInt(0), // TODO:
		Phone:          userinfo.PhoneNumber,
		Tags:           nil, // TODO:
	}, nil
}

// Mutation returns generated.MutationResolver implementation.
func (r *Resolver) Mutation() generated.MutationResolver { return &mutationResolver{r} }

// Query returns generated.QueryResolver implementation.
func (r *Resolver) Query() generated.QueryResolver { return &queryResolver{r} }

type mutationResolver struct{ *Resolver }
type queryResolver struct{ *Resolver }