From ea15c86e84618809db8f8ce120ca191bcdd66c72 Mon Sep 17 00:00:00 2001 From: Gabe <7622243+decentralgabe@users.noreply.github.com> Date: Wed, 23 Oct 2024 15:57:36 -0700 Subject: [PATCH] SD-JWT impl (#6) * update deps * first cut * tmp * working ish * working i think * test * sd jwt vc done * thumbprint * verifiable presentation * lints --- credential/credential.go | 54 ++++-- go.mod | 7 +- go.sum | 16 +- jose/jose.go | 44 +++-- jose/jose_test.go | 4 +- sd-jwt/sd_jwt.go | 375 ++++++++++++++++++++++++++++++++++++++ sd-jwt/sd_jwt_test.go | 384 +++++++++++++++++++++++++++++++++++++++ util/crypto.go | 9 + 8 files changed, 853 insertions(+), 40 deletions(-) create mode 100644 sd-jwt/sd_jwt.go create mode 100644 sd-jwt/sd_jwt_test.go diff --git a/credential/credential.go b/credential/credential.go index 95985f8..3871ca1 100644 --- a/credential/credential.go +++ b/credential/credential.go @@ -43,6 +43,21 @@ type VerifiableCredential struct { Evidence util.SingleOrArray[any] `json:"evidence,omitempty"` } +// ToMap converts the VerifiableCredential to a map[string]any +func (vc *VerifiableCredential) ToMap() (map[string]any, error) { + jsonBytes, err := json.Marshal(vc) + if err != nil { + return nil, fmt.Errorf("failed to marshal VerifiableCredential: %w", err) + } + + var result map[string]any + if err = json.Unmarshal(jsonBytes, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal VerifiableCredential to map: %w", err) + } + + return result, nil +} + // IssuerHolder represents the issuer of a Verifiable Credential or holder of a Verifiable Presentation, which can be // either a URL string or an object containing an ID property type IssuerHolder struct { @@ -134,19 +149,19 @@ type Schema struct { DigestSRI string `json:"digestSRI,omitempty"` } -func (v *VerifiableCredential) IsEmpty() bool { - if v == nil { +func (vc *VerifiableCredential) IsEmpty() bool { + if vc == nil { return true } - return reflect.DeepEqual(v, &VerifiableCredential{}) + return reflect.DeepEqual(vc, &VerifiableCredential{}) } -func (v *VerifiableCredential) IsValid() error { - return util.NewValidator().Struct(v) +func (vc *VerifiableCredential) IsValid() error { + return util.NewValidator().Struct(vc) } -func (v *VerifiableCredential) IssuerID() string { - return v.Issuer.ID() +func (vc *VerifiableCredential) IssuerID() string { + return vc.Issuer.ID() } // VerifiablePresentation https://www.w3.org/TR/vc-data-model-2.0/#verifiable-presentations @@ -158,13 +173,28 @@ type VerifiablePresentation struct { VerifiableCredential []VerifiableCredential `json:"verifiableCredential,omitempty"` } -func (v *VerifiablePresentation) IsEmpty() bool { - if v == nil { +// ToMap converts the VerifiablePresentation to a map[string]any +func (vp *VerifiablePresentation) ToMap() (map[string]any, error) { + jsonBytes, err := json.Marshal(vp) + if err != nil { + return nil, fmt.Errorf("failed to marshal VerifiablePresentation: %w", err) + } + + var result map[string]any + if err = json.Unmarshal(jsonBytes, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal VerifiablePresentation to map: %w", err) + } + + return result, nil +} + +func (vp *VerifiablePresentation) IsEmpty() bool { + if vp == nil { return true } - return reflect.DeepEqual(v, &VerifiablePresentation{}) + return reflect.DeepEqual(vp, &VerifiablePresentation{}) } -func (v *VerifiablePresentation) IsValid() error { - return util.NewValidator().Struct(v) +func (vp *VerifiablePresentation) IsValid() error { + return util.NewValidator().Struct(vp) } diff --git a/go.mod b/go.mod index 7e83604..5a3b4b4 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( ) require ( + github.com/MichaelFraser99/go-sd-jwt v1.2.1 github.com/btcsuite/btcd/btcec/v2 v2.3.4 github.com/davecgh/go-spew v1.1.1 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.3.0 @@ -28,9 +29,9 @@ require ( github.com/lestrrat-go/option v1.0.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/segmentio/asm v1.2.0 // indirect - golang.org/x/crypto v0.26.0 // indirect - golang.org/x/net v0.28.0 // indirect + golang.org/x/crypto v0.28.0 // indirect + golang.org/x/net v0.30.0 // indirect golang.org/x/sys v0.26.0 // indirect - golang.org/x/text v0.17.0 // indirect + golang.org/x/text v0.19.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index f8b11fb..e432725 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +github.com/MichaelFraser99/go-jose v0.9.0 h1:7vUcuJs5vGP0F+AQDStv6puqMYMmx75B4/Qc2CeKQR8= +github.com/MichaelFraser99/go-jose v0.9.0/go.mod h1:kdRvg7/FPcDnsEz8PyCg5hhcBlLud9F0jB4Xy/u771c= +github.com/MichaelFraser99/go-sd-jwt v1.2.1 h1:1Rf+Wy4jdPnRXRI4dvhjUsH2ygERYIrZETtiBtqIPos= +github.com/MichaelFraser99/go-sd-jwt v1.2.1/go.mod h1:1Kt/SQQEpexmeO0NrfPACRwn51NdhcqORikJDNDQMVA= github.com/btcsuite/btcd/btcec/v2 v2.3.4 h1:3EJjcN70HCu/mwqlUsGK8GcNVyLVxFDlWurTXGPFfiQ= github.com/btcsuite/btcd/btcec/v2 v2.3.4/go.mod h1:zYzJ8etWJQIv1Ogk7OzpWjowwOdXY1W/17j2MW85J04= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -47,17 +51,17 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= -golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= -golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= -golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= +golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= +golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.26.0 h1:KHjCJyddX0LoSTb3J+vWpupP9p0oznkqVk/IfjymZbo= golang.org/x/sys v0.26.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.25.0 h1:WtHI/ltw4NvSUig5KARz9h521QvRC8RmF/cuYqifU24= golang.org/x/term v0.25.0/go.mod h1:RPyXicDX+6vLxogjjRxjgD2TKtmAO6NZBsBRfrOLu7M= -golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= -golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/jose/jose.go b/jose/jose.go index b447d2f..30b1a6f 100644 --- a/jose/jose.go +++ b/jose/jose.go @@ -1,6 +1,7 @@ package jose import ( + "errors" "fmt" "time" @@ -16,21 +17,23 @@ import ( const ( VCJOSEType = "vc+jwt" VPJOSEType = "vp+jwt" - VCJWTTyp = "JWT" - VCJWTAlg = "alg" - VCJWTKid = "kid" ) // SignVerifiableCredential dynamically signs a VerifiableCredential based on the key type. -func SignVerifiableCredential(vc *credential.VerifiableCredential, key jwk.Key) (string, error) { - // Marshal the VerifiableCredential to a map - vcMap := make(map[string]any) - vcBytes, err := json.Marshal(vc) - if err != nil { - return "", err +func SignVerifiableCredential(vc credential.VerifiableCredential, key jwk.Key) (*string, error) { + if vc.IsEmpty() { + return nil, errors.New("VerifiableCredential is empty") } - if err = json.Unmarshal(vcBytes, &vcMap); err != nil { - return "", err + if key.KeyID() == "" { + return nil, errors.New("key ID is required") + } + if key.Algorithm().String() == "" { + return nil, errors.New("key algorithm is required") + } + // Convert VC to a map + vcMap, err := vc.ToMap() + if err != nil { + return nil, fmt.Errorf("failed to convert VC to map: %w", err) } // Add standard claims @@ -50,30 +53,31 @@ func SignVerifiableCredential(vc *credential.VerifiableCredential, key jwk.Key) // Marshal the claims to JSON payload, err := json.Marshal(vcMap) if err != nil { - return "", err + return nil, err } // Add protected header values jwsHeaders := jws.NewHeaders() headers := map[string]string{ - "typ": VPJOSEType, - "cty": credential.VPContentType, + "typ": VCJOSEType, + "cty": credential.VCContentType, "alg": key.Algorithm().String(), "kid": key.KeyID(), } for k, v := range headers { if err = jwsHeaders.Set(k, v); err != nil { - return "", err + return nil, err } } // Sign the payload signed, err := jws.Sign(payload, jws.WithKey(key.Algorithm(), key, jws.WithProtectedHeaders(jwsHeaders))) if err != nil { - return "", err + return nil, err } - return string(signed), nil + result := string(signed) + return &result, nil } // VerifyVerifiableCredential verifies a VerifiableCredential JWT using the provided key. @@ -96,6 +100,12 @@ func VerifyVerifiableCredential(jwt string, key jwk.Key) (*credential.Verifiable // SignVerifiablePresentation dynamically signs a VerifiablePresentation based on the key type. func SignVerifiablePresentation(vp credential.VerifiablePresentation, key jwk.Key) (string, error) { var alg jwa.SignatureAlgorithm + if key.KeyID() == "" { + return "", errors.New("key ID is required") + } + if key.Algorithm().String() == "" { + return "", errors.New("key algorithm is required") + } kty := key.KeyType() switch kty { diff --git a/jose/jose_test.go b/jose/jose_test.go index c30c114..e1f1ead 100644 --- a/jose/jose_test.go +++ b/jose/jose_test.go @@ -27,7 +27,7 @@ func Test_Sign_Verify_VerifiableCredential(t *testing.T) { key, err := util.GenerateJWKWithAlgorithm(tt.curve) require.NoError(t, err) - vc := &credential.VerifiableCredential{ + vc := credential.VerifiableCredential{ Context: []string{"https://www.w3.org/2018/credentials/v1"}, ID: "https://example.edu/credentials/1872", Type: []string{"VerifiableCredential"}, @@ -43,7 +43,7 @@ func Test_Sign_Verify_VerifiableCredential(t *testing.T) { assert.NotEmpty(t, jwt) // Verify the VC - verifiedVC, err := VerifyVerifiableCredential(jwt, key) + verifiedVC, err := VerifyVerifiableCredential(*jwt, key) require.NoError(t, err) assert.Equal(t, vc.ID, verifiedVC.ID) assert.Equal(t, vc.Issuer.ID(), verifiedVC.Issuer.ID()) diff --git a/sd-jwt/sd_jwt.go b/sd-jwt/sd_jwt.go new file mode 100644 index 0000000..4a9b404 --- /dev/null +++ b/sd-jwt/sd_jwt.go @@ -0,0 +1,375 @@ +package sdjwt + +import ( + "crypto" + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" + + sdjwt "github.com/MichaelFraser99/go-sd-jwt" + "github.com/MichaelFraser99/go-sd-jwt/disclosure" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jws" + + "github.com/TBD54566975/vc-jose-cose-go/credential" +) + +const ( + VCSDJWTType = "vc+sd-jwt" + VPSDJWTType = "vp+sd-jwt" +) + +// DisclosurePath represents a path to a field that should be made selectively disclosable +// Example paths: +// - "credentialSubject.id" +// - "credentialSubject.address.streetAddress" +// - "credentialSubject.nationalities[0]" for array element +type DisclosurePath string + +// SignVerifiableCredential creates an SD-JWT from a VerifiableCredential, making specified fields +// selectively disclosable according to the provided paths. +func SignVerifiableCredential(vc credential.VerifiableCredential, disclosurePaths []DisclosurePath, key jwk.Key) (*string, error) { + if vc.IsEmpty() { + return nil, errors.New("VerifiableCredential is empty") + } + if key.KeyID() == "" { + return nil, errors.New("key ID is required") + } + if key.Algorithm().String() == "" { + return nil, errors.New("key algorithm is required") + } + + // Convert VC to a map for manipulation + vcMap, err := vc.ToMap() + if err != nil { + return nil, fmt.Errorf("failed to convert VC to map: %w", err) + } + + // Add standard claims + if !vc.Issuer.IsEmpty() { + vcMap["iss"] = vc.Issuer.ID() + } + if vc.ID != "" { + vcMap["jti"] = vc.ID + } + if vc.ValidFrom != "" { + vcMap["iat"] = vc.ValidFrom + } + if vc.ValidUntil != "" { + vcMap["exp"] = vc.ValidUntil + } + + // Process disclosures + disclosures := make([]disclosure.Disclosure, 0, len(disclosurePaths)) + processedMap, err := processDisclosures(vcMap, disclosurePaths, &disclosures) + if err != nil { + return nil, fmt.Errorf("failed to process disclosures: %w", err) + } + vcMap = processedMap + + // Marshal the claims to JSON + payload, err := json.Marshal(vcMap) + if err != nil { + return nil, err + } + + // Add protected header values + jwsHeaders := jws.NewHeaders() + headers := map[string]string{ + "typ": VCSDJWTType, + "cty": credential.VCContentType, + "alg": key.Algorithm().String(), + "kid": key.KeyID(), + } + for k, v := range headers { + if err = jwsHeaders.Set(k, v); err != nil { + return nil, err + } + } + + // Sign the JWS issuer key + signed, err := jws.Sign(payload, jws.WithKey(key.Algorithm(), key, jws.WithProtectedHeaders(jwsHeaders))) + if err != nil { + return nil, err + } + + // Combine JWT with disclosures + sdJWTParts := []string{(string)(signed)} + for _, d := range disclosures { + sdJWTParts = append(sdJWTParts, d.EncodedValue) + } + + sdJwt := fmt.Sprintf("%s~", strings.Join(sdJWTParts, "~")) + return &sdJwt, nil +} + +// processDisclosures traverses the credential map and creates disclosures for specified paths +func processDisclosures(data map[string]any, paths []DisclosurePath, disclosures *[]disclosure.Disclosure) (map[string]any, error) { + result := make(map[string]any) + for k, v := range data { + result[k] = v + } + for _, path := range paths { + parts := strings.Split(string(path), ".") + if err := processPath(result, parts, disclosures); err != nil { + return nil, fmt.Errorf("failed to process path %s: %w", path, err) + } + } + return result, nil +} + +// processPath handles a single disclosure path +func processPath(data map[string]any, pathParts []string, disclosures *[]disclosure.Disclosure) error { + if len(pathParts) == 0 { + return nil + } + + // Split path part into field name and optional array index + parts := strings.SplitN(pathParts[0], "[", 2) + field := parts[0] + arrayIndex := -1 + + // Check if we have an array index + if len(parts) == 2 { + // Remove trailing ']' + indexStr := strings.TrimSuffix(parts[1], "]") + var err error + arrayIndex, err = strconv.Atoi(indexStr) + if err != nil { + return fmt.Errorf("invalid array index '%s' in path: %s", indexStr, pathParts[0]) + } + } + + value, exists := data[field] + if !exists { + return fmt.Errorf("field not found: %s", field) + } + + // If this is the last path part, create the disclosure + if len(pathParts) == 1 { + if arrayIndex >= 0 { + arr, ok := value.([]any) + if !ok { + return fmt.Errorf("field %s is not an array", field) + } + if arrayIndex >= len(arr) { + return fmt.Errorf("array index %d out of bounds for field %s", arrayIndex, field) + } + // Create disclosure for array element + d, err := disclosure.NewFromArrayElement(arr[arrayIndex], nil) + if err != nil { + return err + } + *disclosures = append(*disclosures, *d) + + // Replace with digest + arr[arrayIndex] = map[string]any{ + "...": string(d.Hash(crypto.SHA256.New())), + } + data[field] = arr + } else { + // Create disclosure for object property + d, err := disclosure.NewFromObject(field, value, nil) + if err != nil { + return err + } + *disclosures = append(*disclosures, *d) + + // Add hash to _sd array + hash := d.Hash(crypto.SHA256.New()) + if data["_sd"] == nil { + data["_sd"] = []string{string(hash)} + } else { + data["_sd"] = append(data["_sd"].([]string), string(hash)) + } + delete(data, field) + } + return nil + } + + // Need to traverse deeper + if arrayIndex >= 0 { + arr, ok := value.([]any) + if !ok { + return fmt.Errorf("field %s is not an array", field) + } + if arrayIndex >= len(arr) { + return fmt.Errorf("array index %d out of bounds for field %s", arrayIndex, field) + } + nextMap, ok := arr[arrayIndex].(map[string]any) + if !ok { + return fmt.Errorf("array element at index %d of field %s is not an object", arrayIndex, field) + } + if err := processPath(nextMap, pathParts[1:], disclosures); err != nil { + return err + } + arr[arrayIndex] = nextMap + data[field] = arr + return nil + } + + nextMap, ok := value.(map[string]any) + if !ok { + return fmt.Errorf("field %s is not an object", field) + } + return processPath(nextMap, pathParts[1:], disclosures) +} + +// VerifyVerifiableCredential verifies an SD-JWT credential and returns the disclosed claims +func VerifyVerifiableCredential(sdJwtStr string, key jwk.Key) (*credential.VerifiableCredential, error) { + // Parse and verify the SD-JWT + sdJwt, err := sdjwt.New(sdJwtStr) + if err != nil { + return nil, fmt.Errorf("failed to parse SD-JWT: %w", err) + } + + // Get disclosed claims + claims, err := sdJwt.GetDisclosedClaims() + if err != nil { + return nil, fmt.Errorf("failed to get disclosed claims: %w", err) + } + + // Convert claims back to VerifiableCredential + vcBytes, err := json.Marshal(claims) + if err != nil { + return nil, fmt.Errorf("failed to marshal claims: %w", err) + } + + var vc credential.VerifiableCredential + if err = json.Unmarshal(vcBytes, &vc); err != nil { + return nil, fmt.Errorf("failed to unmarshal VC: %w", err) + } + + // Extract signature from SD-JWT + parts := strings.Split(sdJwtStr, "~") + if len(parts) < 1 { + return nil, errors.New("invalid SD-JWT format") + } + + jwsParts := strings.Split(parts[0], ".") + if len(jwsParts) != 3 { + return nil, errors.New("invalid JWT format") + } + + if _, err = jws.Verify([]byte(parts[0]), jws.WithKey(key.Algorithm(), key)); err != nil { + return nil, fmt.Errorf("invalid JWT signature: %w", err) + } + + return &vc, nil +} + +// SignVerifiablePresentation creates an SD-JWT from a VerifiablePresentation, making specified fields +// selectively disclosable according to the provided paths. +func SignVerifiablePresentation(vp credential.VerifiablePresentation, disclosurePaths []DisclosurePath, key jwk.Key) (*string, error) { + if vp.IsEmpty() { + return nil, errors.New("VerifiablePresentation is empty") + } + if key.KeyID() == "" { + return nil, errors.New("key ID is required") + } + if key.Algorithm().String() == "" { + return nil, errors.New("key algorithm is required") + } + + // Convert VP to a map for manipulation + vpMap, err := vp.ToMap() + if err != nil { + return nil, fmt.Errorf("failed to convert VP to map: %w", err) + } + + // Add standard claims + if vp.ID != "" { + vpMap["jti"] = vp.ID + } + if !vp.Holder.IsEmpty() { + vpMap["iss"] = vp.Holder.ID() + } + + // Process disclosures + disclosures := make([]disclosure.Disclosure, 0, len(disclosurePaths)) + processedMap, err := processDisclosures(vpMap, disclosurePaths, &disclosures) + if err != nil { + return nil, fmt.Errorf("failed to process disclosures: %w", err) + } + vpMap = processedMap + + // Marshal the claims to JSON + payload, err := json.Marshal(vpMap) + if err != nil { + return nil, err + } + + // Add protected header values + jwsHeaders := jws.NewHeaders() + headers := map[string]string{ + "typ": VPSDJWTType, + "cty": credential.VPContentType, + "alg": key.Algorithm().String(), + "kid": key.KeyID(), + } + for k, v := range headers { + if err = jwsHeaders.Set(k, v); err != nil { + return nil, err + } + } + + // Sign the JWS with the holder's key + signed, err := jws.Sign(payload, jws.WithKey(key.Algorithm(), key, jws.WithProtectedHeaders(jwsHeaders))) + if err != nil { + return nil, err + } + + // Combine JWT with disclosures + sdJWTParts := []string{(string)(signed)} + for _, d := range disclosures { + sdJWTParts = append(sdJWTParts, d.EncodedValue) + } + + sdJwt := fmt.Sprintf("%s~", strings.Join(sdJWTParts, "~")) + return &sdJwt, nil +} + +// VerifyVerifiablePresentation verifies an SD-JWT presentation and returns the disclosed claims +func VerifyVerifiablePresentation(sdJwtStr string, key jwk.Key) (*credential.VerifiablePresentation, error) { + // Parse and verify the SD-JWT + sdJwt, err := sdjwt.New(sdJwtStr) + if err != nil { + return nil, fmt.Errorf("failed to parse SD-JWT: %w", err) + } + + // Get disclosed claims + claims, err := sdJwt.GetDisclosedClaims() + if err != nil { + return nil, fmt.Errorf("failed to get disclosed claims: %w", err) + } + + // Convert claims back to VerifiablePresentation + vpBytes, err := json.Marshal(claims) + if err != nil { + return nil, fmt.Errorf("failed to marshal claims: %w", err) + } + + var vp credential.VerifiablePresentation + if err = json.Unmarshal(vpBytes, &vp); err != nil { + return nil, fmt.Errorf("failed to unmarshal VP: %w", err) + } + + // Extract signature from SD-JWT + parts := strings.Split(sdJwtStr, "~") + if len(parts) < 1 { + return nil, errors.New("invalid SD-JWT format") + } + + jwsParts := strings.Split(parts[0], ".") + if len(jwsParts) != 3 { + return nil, errors.New("invalid JWT format") + } + + if _, err = jws.Verify([]byte(parts[0]), jws.WithKey(key.Algorithm(), key)); err != nil { + return nil, fmt.Errorf("invalid JWT signature: %w", err) + } + + return &vp, nil +} diff --git a/sd-jwt/sd_jwt_test.go b/sd-jwt/sd_jwt_test.go new file mode 100644 index 0000000..89982a9 --- /dev/null +++ b/sd-jwt/sd_jwt_test.go @@ -0,0 +1,384 @@ +package sdjwt + +import ( + "testing" + + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/TBD54566975/vc-jose-cose-go/credential" + "github.com/TBD54566975/vc-jose-cose-go/util" +) + +func Test_Sign_Verify_VerifiableCredential(t *testing.T) { + simpleVC := credential.VerifiableCredential{ + Context: []string{"https://www.w3.org/2018/credentials/v1"}, + ID: "https://example.edu/credentials/1872", + Type: []string{"VerifiableCredential"}, + Issuer: credential.NewIssuerHolderFromString("did:example:issuer"), + ValidFrom: "2010-01-01T19:23:24Z", + CredentialSubject: map[string]any{ + "id": "did:example:ebfeb1f712ebc6f1c276e12ec21", + }, + } + + detailVC := credential.VerifiableCredential{ + Context: []string{"https://www.w3.org/2018/credentials/v1"}, + ID: "https://example.edu/credentials/1872", + Type: []string{"VerifiableCredential"}, + Issuer: credential.NewIssuerHolderFromString("did:example:issuer"), + ValidFrom: "2010-01-01T19:23:24Z", + CredentialSubject: map[string]any{ + "id": "did:example:ebfeb1f712ebc6f1c276e12ec21", + "address": map[string]any{ + "streetAddress": "123 Main St", + "city": "Anytown", + "country": "US", + }, + "details": []any{ + "Detail 1", + "Detail 2", + }, + }, + } + + tests := []struct { + name string + curve jwa.EllipticCurveAlgorithm + disclosurePaths []DisclosurePath + vc *credential.VerifiableCredential + verifyFields func(*testing.T, *credential.VerifiableCredential) + }{ + { + name: "EC P-256 with simple credential subject disclosure", + curve: jwa.P256, + disclosurePaths: []DisclosurePath{ + "credentialSubject.id", + }, + vc: &simpleVC, + verifyFields: func(t *testing.T, vc *credential.VerifiableCredential) { + assert.Equal(t, "did:example:ebfeb1f712ebc6f1c276e12ec21", vc.CredentialSubject["id"]) + }, + }, + { + name: "EC P-256 with complex nested disclosures", + curve: jwa.P256, + disclosurePaths: []DisclosurePath{ + "credentialSubject.id", + "credentialSubject.address.streetAddress", + "credentialSubject.details[0]", + }, + vc: &detailVC, + verifyFields: func(t *testing.T, vc *credential.VerifiableCredential) { + assert.Equal(t, "did:example:ebfeb1f712ebc6f1c276e12ec21", vc.CredentialSubject["id"]) + address := vc.CredentialSubject["address"].(map[string]any) + assert.Equal(t, "123 Main St", address["streetAddress"]) + assert.Equal(t, "Anytown", address["city"]) + details := vc.CredentialSubject["details"].([]any) + assert.Equal(t, "Detail 1", details[0]) + }, + }, + { + name: "EC P-256 with top level disclosures", + curve: jwa.P256, + disclosurePaths: []DisclosurePath{ + "id", + "validFrom", + "credentialSubject.id", + }, + vc: &simpleVC, + verifyFields: func(t *testing.T, vc *credential.VerifiableCredential) { + assert.Equal(t, "https://example.edu/credentials/1872", vc.ID) + assert.Equal(t, "2010-01-01T19:23:24Z", vc.ValidFrom) + assert.Equal(t, "did:example:ebfeb1f712ebc6f1c276e12ec21", vc.CredentialSubject["id"]) + }, + }, + { + name: "EC P-384 with simple credential subject disclosure", + curve: jwa.P384, + disclosurePaths: []DisclosurePath{ + "credentialSubject.id", + }, + vc: &simpleVC, + verifyFields: func(t *testing.T, vc *credential.VerifiableCredential) { + assert.Equal(t, "did:example:ebfeb1f712ebc6f1c276e12ec21", vc.CredentialSubject["id"]) + }, + }, + { + name: "EC P-384 with complex nested disclosures", + curve: jwa.P384, + disclosurePaths: []DisclosurePath{ + "credentialSubject.id", + "credentialSubject.address.streetAddress", + "credentialSubject.details[0]", + }, + vc: &detailVC, + verifyFields: func(t *testing.T, vc *credential.VerifiableCredential) { + assert.Equal(t, "did:example:ebfeb1f712ebc6f1c276e12ec21", vc.CredentialSubject["id"]) + address := vc.CredentialSubject["address"].(map[string]any) + assert.Equal(t, "123 Main St", address["streetAddress"]) + assert.Equal(t, "Anytown", address["city"]) + details := vc.CredentialSubject["details"].([]any) + assert.Equal(t, "Detail 1", details[0]) + }, + }, + { + name: "EC P-384 with top level disclosures", + curve: jwa.P384, + disclosurePaths: []DisclosurePath{ + "id", + "validFrom", + "credentialSubject.id", + }, + vc: &simpleVC, + verifyFields: func(t *testing.T, vc *credential.VerifiableCredential) { + assert.Equal(t, "https://example.edu/credentials/1872", vc.ID) + assert.Equal(t, "2010-01-01T19:23:24Z", vc.ValidFrom) + assert.Equal(t, "did:example:ebfeb1f712ebc6f1c276e12ec21", vc.CredentialSubject["id"]) + }, + }, + { + name: "EC P-521 with simple credential subject disclosure", + curve: jwa.P521, + disclosurePaths: []DisclosurePath{ + "credentialSubject.id", + }, + vc: &simpleVC, + verifyFields: func(t *testing.T, vc *credential.VerifiableCredential) { + assert.Equal(t, "did:example:ebfeb1f712ebc6f1c276e12ec21", vc.CredentialSubject["id"]) + }, + }, + { + name: "EC P-521 with complex nested disclosures", + curve: jwa.P521, + disclosurePaths: []DisclosurePath{ + "credentialSubject.id", + "credentialSubject.address.streetAddress", + "credentialSubject.details[0]", + }, + vc: &detailVC, + verifyFields: func(t *testing.T, vc *credential.VerifiableCredential) { + assert.Equal(t, "did:example:ebfeb1f712ebc6f1c276e12ec21", vc.CredentialSubject["id"]) + address := vc.CredentialSubject["address"].(map[string]any) + assert.Equal(t, "123 Main St", address["streetAddress"]) + assert.Equal(t, "Anytown", address["city"]) + details := vc.CredentialSubject["details"].([]any) + assert.Equal(t, "Detail 1", details[0]) + }, + }, + { + name: "EC P-521 with top level disclosures", + curve: jwa.P521, + disclosurePaths: []DisclosurePath{ + "id", + "validFrom", + "credentialSubject.id", + }, + vc: &simpleVC, + verifyFields: func(t *testing.T, vc *credential.VerifiableCredential) { + assert.Equal(t, "https://example.edu/credentials/1872", vc.ID) + assert.Equal(t, "2010-01-01T19:23:24Z", vc.ValidFrom) + assert.Equal(t, "did:example:ebfeb1f712ebc6f1c276e12ec21", vc.CredentialSubject["id"]) + }, + }, + { + name: "OKP EdDSA with simple credential subject disclosure", + curve: jwa.Ed25519, + disclosurePaths: []DisclosurePath{ + "credentialSubject.id", + }, + vc: &simpleVC, + verifyFields: func(t *testing.T, vc *credential.VerifiableCredential) { + assert.Equal(t, "did:example:ebfeb1f712ebc6f1c276e12ec21", vc.CredentialSubject["id"]) + }, + }, + { + name: "OKP EdDSA with complex nested disclosures", + curve: jwa.Ed25519, + disclosurePaths: []DisclosurePath{ + "credentialSubject.id", + "credentialSubject.address.streetAddress", + "credentialSubject.details[0]", + }, + vc: &detailVC, + verifyFields: func(t *testing.T, vc *credential.VerifiableCredential) { + assert.Equal(t, "did:example:ebfeb1f712ebc6f1c276e12ec21", vc.CredentialSubject["id"]) + address := vc.CredentialSubject["address"].(map[string]any) + assert.Equal(t, "123 Main St", address["streetAddress"]) + assert.Equal(t, "Anytown", address["city"]) + details := vc.CredentialSubject["details"].([]any) + assert.Equal(t, "Detail 1", details[0]) + }, + }, + { + name: "OKP EdDSA with top level disclosures", + curve: jwa.Ed25519, + disclosurePaths: []DisclosurePath{ + "id", + "validFrom", + "credentialSubject.id", + }, + vc: &simpleVC, + verifyFields: func(t *testing.T, vc *credential.VerifiableCredential) { + assert.Equal(t, "https://example.edu/credentials/1872", vc.ID) + assert.Equal(t, "2010-01-01T19:23:24Z", vc.ValidFrom) + assert.Equal(t, "did:example:ebfeb1f712ebc6f1c276e12ec21", vc.CredentialSubject["id"]) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Generate issuer key + issuerKey, err := util.GenerateJWKWithAlgorithm(tt.curve) + require.NoError(t, err) + + // Sign the credential + sdJwt, err := SignVerifiableCredential(*tt.vc, tt.disclosurePaths, issuerKey) + require.NoError(t, err) + require.NotNil(t, sdJwt) + + // Verify the credential + verifiedVC, err := VerifyVerifiableCredential(*sdJwt, issuerKey) + require.NoError(t, err) + require.NotNil(t, verifiedVC) + + // Verify standard fields + assert.Equal(t, tt.vc.Context, verifiedVC.Context) + assert.Equal(t, tt.vc.Type, verifiedVC.Type) + assert.Equal(t, tt.vc.Issuer, verifiedVC.Issuer) + + // Apply any test-specific verification + if tt.verifyFields != nil { + tt.verifyFields(t, verifiedVC) + } + + // Verify validation fails with wrong key + wrongKey, err := util.GenerateJWKWithAlgorithm(tt.curve) + require.NoError(t, err) + _, err = VerifyVerifiableCredential(*sdJwt, wrongKey) + assert.Error(t, err) + }) + } +} + +func Test_Sign_Verify_VerifiablePresentation(t *testing.T) { + simpleVP := credential.VerifiablePresentation{ + Context: []string{"https://www.w3.org/2018/credentials/v1"}, + ID: "urn:uuid:3978344f-8596-4c3a-a978-8fcaba3903c5", + Type: []string{"VerifiablePresentation"}, + Holder: credential.NewIssuerHolderFromString("did:example:holder"), + VerifiableCredential: []credential.VerifiableCredential{ + { + Context: []string{"https://www.w3.org/2018/credentials/v1"}, + ID: "https://example.edu/credentials/1872", + Type: []string{"VerifiableCredential"}, + Issuer: credential.NewIssuerHolderFromString("did:example:issuer"), + ValidFrom: "2010-01-01T19:23:24Z", + CredentialSubject: map[string]any{ + "id": "did:example:ebfeb1f712ebc6f1c276e12ec21", + }, + }, + }, + } + + tests := []struct { + name string + curve jwa.EllipticCurveAlgorithm + disclosurePaths []DisclosurePath + vp *credential.VerifiablePresentation + verifyFields func(*testing.T, *credential.VerifiablePresentation) + }{ + { + name: "EC P-256 with simple presentation disclosure", + curve: jwa.P256, + disclosurePaths: []DisclosurePath{ + "holder", + "verifiableCredential[0].credentialSubject.id", + }, + vp: &simpleVP, + verifyFields: func(t *testing.T, vp *credential.VerifiablePresentation) { + assert.Equal(t, "did:example:holder", vp.Holder.ID()) + assert.Len(t, vp.VerifiableCredential, 1) + assert.Equal(t, "did:example:ebfeb1f712ebc6f1c276e12ec21", vp.VerifiableCredential[0].CredentialSubject["id"]) + }, + }, + { + name: "EC P-384 with simple presentation disclosure", + curve: jwa.P384, + disclosurePaths: []DisclosurePath{ + "holder", + "verifiableCredential[0].credentialSubject.id", + }, + vp: &simpleVP, + verifyFields: func(t *testing.T, vp *credential.VerifiablePresentation) { + assert.Equal(t, "did:example:holder", vp.Holder.ID()) + assert.Len(t, vp.VerifiableCredential, 1) + assert.Equal(t, "did:example:ebfeb1f712ebc6f1c276e12ec21", vp.VerifiableCredential[0].CredentialSubject["id"]) + }, + }, + { + name: "EC P-521 with simple presentation disclosure", + curve: jwa.P521, + disclosurePaths: []DisclosurePath{ + "holder", + "verifiableCredential[0].credentialSubject.id", + }, + vp: &simpleVP, + verifyFields: func(t *testing.T, vp *credential.VerifiablePresentation) { + assert.Equal(t, "did:example:holder", vp.Holder.ID()) + assert.Len(t, vp.VerifiableCredential, 1) + assert.Equal(t, "did:example:ebfeb1f712ebc6f1c276e12ec21", vp.VerifiableCredential[0].CredentialSubject["id"]) + }, + }, + { + name: "OKP EdDSA with simple presentation disclosure", + curve: jwa.Ed25519, + disclosurePaths: []DisclosurePath{ + "holder", + "verifiableCredential[0].credentialSubject.id", + }, + vp: &simpleVP, + verifyFields: func(t *testing.T, vp *credential.VerifiablePresentation) { + assert.Equal(t, "did:example:holder", vp.Holder.ID()) + assert.Len(t, vp.VerifiableCredential, 1) + assert.Equal(t, "did:example:ebfeb1f712ebc6f1c276e12ec21", vp.VerifiableCredential[0].CredentialSubject["id"]) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Generate holder key + holderKey, err := util.GenerateJWKWithAlgorithm(tt.curve) + require.NoError(t, err) + + // Sign the presentation + sdJwt, err := SignVerifiablePresentation(*tt.vp, tt.disclosurePaths, holderKey) + require.NoError(t, err) + require.NotNil(t, sdJwt) + + // Verify the presentation + verifiedVP, err := VerifyVerifiablePresentation(*sdJwt, holderKey) + require.NoError(t, err) + require.NotNil(t, verifiedVP) + + // Verify standard fields + assert.Equal(t, tt.vp.Context, verifiedVP.Context) + assert.Equal(t, tt.vp.Type, verifiedVP.Type) + assert.Equal(t, tt.vp.Holder, verifiedVP.Holder) + + // Apply any test-specific verification + if tt.verifyFields != nil { + tt.verifyFields(t, verifiedVP) + } + + // Verify validation fails with wrong key + wrongKey, err := util.GenerateJWKWithAlgorithm(tt.curve) + require.NoError(t, err) + _, err = VerifyVerifiablePresentation(*sdJwt, wrongKey) + assert.Error(t, err) + }) + } +} diff --git a/util/crypto.go b/util/crypto.go index c8b7149..d7bbaf0 100644 --- a/util/crypto.go +++ b/util/crypto.go @@ -8,6 +8,7 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" + "encoding/base64" "fmt" "reflect" @@ -67,6 +68,14 @@ func GenerateJWKWithAlgorithm(eca jwa.EllipticCurveAlgorithm) (jwk.Key, error) { return nil, fmt.Errorf("unsupported elliptic curve algorithm: %s", eca) } + thumbprintBytes, err := jwkKey.Thumbprint(crypto.SHA256) + if err != nil { + return nil, fmt.Errorf("failed to compute thumbprint: %w", err) + } + thumbprint := base64.RawURLEncoding.EncodeToString(thumbprintBytes) + if err = jwkKey.Set(jwk.KeyIDKey, thumbprint); err != nil { + return nil, fmt.Errorf("failed to set key ID in JWK: %w", err) + } // Set the algorithm in the JWK if err = jwkKey.Set(jwk.AlgorithmKey, alg); err != nil { return nil, fmt.Errorf("failed to set algorithm in JWK: %w", err)