diff --git a/authstate.go b/authstate.go index f512275a..15de09b4 100644 --- a/authstate.go +++ b/authstate.go @@ -17,10 +17,10 @@ type StateProvider interface { type FederatedStateClient interface { LookupState( - ctx context.Context, s ServerName, roomID, eventID string, roomVersion RoomVersion, + ctx context.Context, origin, s ServerName, roomID, eventID string, roomVersion RoomVersion, ) (res RespState, err error) LookupStateIDs( - ctx context.Context, s ServerName, roomID, eventID string, + ctx context.Context, origin, s ServerName, roomID, eventID string, ) (res RespStateIDs, err error) } @@ -28,6 +28,7 @@ type FederatedStateClient interface { type FederatedStateProvider struct { FedClient FederatedStateClient // The remote server to ask. + Origin ServerName Server ServerName // Set to true to remember the auth event IDs for the room at various states RememberAuthEvents bool @@ -38,7 +39,7 @@ type FederatedStateProvider struct { // StateIDsBeforeEvent implements StateProvider func (p *FederatedStateProvider) StateIDsBeforeEvent(ctx context.Context, event *HeaderedEvent) ([]string, error) { - res, err := p.FedClient.LookupStateIDs(ctx, p.Server, event.RoomID(), event.EventID()) + res, err := p.FedClient.LookupStateIDs(ctx, p.Origin, p.Server, event.RoomID(), event.EventID()) if err != nil { return nil, err } @@ -50,7 +51,7 @@ func (p *FederatedStateProvider) StateIDsBeforeEvent(ctx context.Context, event // StateBeforeEvent implements StateProvider func (p *FederatedStateProvider) StateBeforeEvent(ctx context.Context, roomVer RoomVersion, event *HeaderedEvent, eventIDs []string) (map[string]*Event, error) { - res, err := p.FedClient.LookupState(ctx, p.Server, event.RoomID(), event.EventID(), roomVer) + res, err := p.FedClient.LookupState(ctx, p.Origin, p.Server, event.RoomID(), event.EventID(), roomVer) if err != nil { return nil, err } diff --git a/backfill.go b/backfill.go index 5fa3ea36..a1e50ca8 100644 --- a/backfill.go +++ b/backfill.go @@ -10,7 +10,7 @@ import ( type BackfillClient interface { // Backfill performs a backfill request to the given server. // https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid - Backfill(ctx context.Context, server ServerName, roomID string, limit int, fromEventIDs []string) (Transaction, error) + Backfill(ctx context.Context, origin, server ServerName, roomID string, limit int, fromEventIDs []string) (Transaction, error) } // BackfillRequester contains the necessary functions to perform backfill requests from one server to another. @@ -45,7 +45,7 @@ type BackfillRequester interface { // but to verify it we need to know the prev_events of fromEventIDs. // // TODO: When does it make sense to return errors? -func RequestBackfill(ctx context.Context, b BackfillRequester, keyRing JSONVerifier, +func RequestBackfill(ctx context.Context, origin ServerName, b BackfillRequester, keyRing JSONVerifier, roomID string, ver RoomVersion, fromEventIDs []string, limit int) ([]*HeaderedEvent, error) { if len(fromEventIDs) == 0 { @@ -67,7 +67,7 @@ func RequestBackfill(ctx context.Context, b BackfillRequester, keyRing JSONVerif return nil, fmt.Errorf("gomatrixserverlib: RequestBackfill context cancelled %w", ctx.Err()) } // fetch some events, and try a different server if it fails - txn, err := b.Backfill(ctx, s, roomID, limit, fromEventIDs) + txn, err := b.Backfill(ctx, origin, s, roomID, limit, fromEventIDs) if err != nil { continue // try the next server } diff --git a/backfill_test.go b/backfill_test.go index 48e055a3..4b0eb9b9 100644 --- a/backfill_test.go +++ b/backfill_test.go @@ -12,7 +12,7 @@ import ( type testBackfillRequester struct { servers []ServerName - backfillFn func(server ServerName, roomID string, fromEventIDs []string, limit int) (*Transaction, error) + backfillFn func(origin, server ServerName, roomID string, fromEventIDs []string, limit int) (*Transaction, error) authEventsToProvide [][]byte stateIDsAtEvent map[string][]string callOrderForStateIDsBeforeEvent []string // event IDs called @@ -28,8 +28,8 @@ func (t *testBackfillRequester) StateBeforeEvent(ctx context.Context, roomVer Ro func (t *testBackfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID string) []ServerName { return t.servers } -func (t *testBackfillRequester) Backfill(ctx context.Context, server ServerName, roomID string, limit int, fromEventIDs []string) (Transaction, error) { - txn, err := t.backfillFn(server, roomID, fromEventIDs, limit) +func (t *testBackfillRequester) Backfill(ctx context.Context, origin, server ServerName, roomID string, limit int, fromEventIDs []string) (Transaction, error) { + txn, err := t.backfillFn(origin, server, roomID, fromEventIDs, limit) if err != nil { return Transaction{}, err } @@ -92,14 +92,14 @@ func TestRequestBackfillMultipleServers(t *testing.T) { "$fnwGrQEpiOIUoDU2:baba.is.you": {"$WCraVpPZe5TtHAqs:baba.is.you"}, "$WCraVpPZe5TtHAqs:baba.is.you": nil, }, - backfillFn: func(server ServerName, roomID string, fromEventIDs []string, limit int) (*Transaction, error) { + backfillFn: func(origin, server ServerName, roomID string, fromEventIDs []string, limit int) (*Transaction, error) { if roomID != testRoomID { return nil, fmt.Errorf("bad room id: %s", roomID) } if server == serverA { // server A returns events 1 and 3. return &Transaction{ - Origin: serverA, + Origin: origin, OriginServerTS: AsTimestamp(time.Now()), PDUs: []json.RawMessage{ testBackfillEvents[1], testBackfillEvents[3], @@ -108,7 +108,7 @@ func TestRequestBackfillMultipleServers(t *testing.T) { } else if server == serverB { // server B returns events 0 and 2 and 3. return &Transaction{ - Origin: serverB, + Origin: origin, OriginServerTS: AsTimestamp(time.Now()), PDUs: []json.RawMessage{ testBackfillEvents[0], testBackfillEvents[2], testBackfillEvents[3], @@ -118,7 +118,7 @@ func TestRequestBackfillMultipleServers(t *testing.T) { return nil, fmt.Errorf("bad server name: %s", server) }, } - result, err := RequestBackfill(ctx, tbr, keyRing, testRoomID, RoomVersionV1, testFromEventIDs, testLimit) + result, err := RequestBackfill(ctx, serverA, tbr, keyRing, testRoomID, RoomVersionV1, testFromEventIDs, testLimit) if err != nil { t.Fatalf("RequestBackfill got error: %s", err) } @@ -157,13 +157,13 @@ func TestRequestBackfillTopologicalSort(t *testing.T) { "$fnwGrQEpiOIUoDU2:baba.is.you": {"$WCraVpPZe5TtHAqs:baba.is.you"}, "$WCraVpPZe5TtHAqs:baba.is.you": nil, }, - backfillFn: func(server ServerName, roomID string, fromEventIDs []string, limit int) (*Transaction, error) { + backfillFn: func(origin, server ServerName, roomID string, fromEventIDs []string, limit int) (*Transaction, error) { if roomID != testRoomID { return nil, fmt.Errorf("bad room id: %s", roomID) } if server == serverA { return &Transaction{ - Origin: serverA, + Origin: origin, OriginServerTS: AsTimestamp(time.Now()), PDUs: []json.RawMessage{ testBackfillEvents[0], testBackfillEvents[1], testBackfillEvents[2], testBackfillEvents[3], @@ -173,7 +173,7 @@ func TestRequestBackfillTopologicalSort(t *testing.T) { return nil, fmt.Errorf("bad server name: %s", server) }, } - result, err := RequestBackfill(ctx, tbr, keyRing, testRoomID, RoomVersionV1, testFromEventIDs, testLimit) + result, err := RequestBackfill(ctx, serverA, tbr, keyRing, testRoomID, RoomVersionV1, testFromEventIDs, testLimit) if err != nil { t.Fatalf("RequestBackfill got error: %s", err) } diff --git a/federationclient.go b/federationclient.go index 4b47231c..b8b6279b 100644 --- a/federationclient.go +++ b/federationclient.go @@ -16,9 +16,13 @@ import ( // "Authorization: X-Matrix" headers to requests that need ed25519 signatures type FederationClient struct { Client - serverName ServerName - serverKeyID KeyID - serverPrivateKey ed25519.PrivateKey + identities []*SigningIdentity +} + +type SigningIdentity struct { + ServerName ServerName + KeyID KeyID + PrivateKey ed25519.PrivateKey } // NewFederationClient makes a new FederationClient. You can supply @@ -26,21 +30,29 @@ type FederationClient struct { // TLS validation etc - see WithTransport, WithTimeout, WithSkipVerify, // WithDNSCache etc. func NewFederationClient( - serverName ServerName, keyID KeyID, privateKey ed25519.PrivateKey, + identities []*SigningIdentity, options ...ClientOption, ) *FederationClient { return &FederationClient{ Client: *NewClient( append(options, WithWellKnownSRVLookups(true))..., ), - serverName: serverName, - serverKeyID: keyID, - serverPrivateKey: privateKey, + identities: append([]*SigningIdentity{}, identities...), } } func (ac *FederationClient) doRequest(ctx context.Context, r FederationRequest, resBody interface{}) error { - if err := r.Sign(ac.serverName, ac.serverKeyID, ac.serverPrivateKey); err != nil { + var identity *SigningIdentity + for _, id := range ac.identities { + if id.ServerName == r.fields.Origin { + identity = id + break + } + } + if identity == nil { + return fmt.Errorf("no signing identity for server name %q", r.fields.Origin) + } + if err := r.Sign(identity.ServerName, identity.KeyID, identity.PrivateKey); err != nil { return err } @@ -60,7 +72,7 @@ func (ac *FederationClient) SendTransaction( ctx context.Context, t Transaction, ) (res RespSend, err error) { path := federationPathPrefixV1 + "/send/" + string(t.TransactionID) - req := NewFederationRequest("PUT", t.Destination, path) + req := NewFederationRequest("PUT", t.Origin, t.Destination, path) if err = req.SetContent(t); err != nil { return } @@ -93,14 +105,14 @@ func makeVersionQueryString(roomVersions []RoomVersion) string { // server's key and pass it to SendJoin. // See https://matrix.org/docs/spec/server_server/unstable.html#joining-rooms func (ac *FederationClient) MakeJoin( - ctx context.Context, s ServerName, roomID, userID string, + ctx context.Context, origin, s ServerName, roomID, userID string, roomVersions []RoomVersion, ) (res RespMakeJoin, err error) { versionQueryString := makeVersionQueryString(roomVersions) path := federationPathPrefixV1 + "/make_join/" + url.PathEscape(roomID) + "/" + url.PathEscape(userID) + versionQueryString - req := NewFederationRequest("GET", s, path) + req := NewFederationRequest("GET", origin, s, path) err = ac.doRequest(ctx, req, &res) return } @@ -110,9 +122,9 @@ func (ac *FederationClient) MakeJoin( // This is used to join a room the local server isn't a member of. // See https://matrix.org/docs/spec/server_server/unstable.html#joining-rooms func (ac *FederationClient) SendJoin( - ctx context.Context, s ServerName, event *Event, + ctx context.Context, origin, s ServerName, event *Event, ) (res RespSendJoin, err error) { - return ac.sendJoin(ctx, s, event, false) + return ac.sendJoin(ctx, origin, s, event, false) } // SendJoinPartialState sends a join m.room.member event obtained using MakeJoin via a @@ -121,14 +133,14 @@ func (ac *FederationClient) SendJoin( // This is used to join a room the local server isn't a member of. // See https://matrix.org/docs/spec/server_server/unstable.html#joining-rooms func (ac *FederationClient) SendJoinPartialState( - ctx context.Context, s ServerName, event *Event, + ctx context.Context, origin, s ServerName, event *Event, ) (res RespSendJoin, err error) { - return ac.sendJoin(ctx, s, event, true) + return ac.sendJoin(ctx, origin, s, event, true) } // sendJoin is an internal implementation shared between SendJoin and SendJoinPartialState func (ac *FederationClient) sendJoin( - ctx context.Context, s ServerName, event *Event, partialState bool, + ctx context.Context, origin, s ServerName, event *Event, partialState bool, ) (res RespSendJoin, err error) { path := federationPathPrefixV2 + "/send_join/" + url.PathEscape(event.RoomID()) + "/" + @@ -137,7 +149,7 @@ func (ac *FederationClient) sendJoin( path += "?org.matrix.msc3706.partial_state=true" } - req := NewFederationRequest("PUT", s, path) + req := NewFederationRequest("PUT", origin, s, path) if err = req.SetContent(event); err != nil { return } @@ -148,7 +160,7 @@ func (ac *FederationClient) sendJoin( v1path := federationPathPrefixV1 + "/send_join/" + url.PathEscape(event.RoomID()) + "/" + url.PathEscape(event.EventID()) - v1req := NewFederationRequest("PUT", s, v1path) + v1req := NewFederationRequest("PUT", origin, s, v1path) if err = v1req.SetContent(event); err != nil { return } @@ -170,14 +182,14 @@ func (ac *FederationClient) sendJoin( // server's key and pass it to SendKnock. // See https://spec.matrix.org/v1.3/server-server-api/#knocking-upon-a-room func (ac *FederationClient) MakeKnock( - ctx context.Context, s ServerName, roomID, userID string, + ctx context.Context, origin, s ServerName, roomID, userID string, roomVersions []RoomVersion, ) (res RespMakeKnock, err error) { versionQueryString := makeVersionQueryString(roomVersions) path := federationPathPrefixV1 + "/make_knock/" + url.PathEscape(roomID) + "/" + url.PathEscape(userID) + versionQueryString - req := NewFederationRequest("GET", s, path) + req := NewFederationRequest("GET", origin, s, path) err = ac.doRequest(ctx, req, &res) return } @@ -187,13 +199,13 @@ func (ac *FederationClient) MakeKnock( // This is used to ask to join a room the local server isn't a member of. // See https://spec.matrix.org/v1.3/server-server-api/#knocking-upon-a-room func (ac *FederationClient) SendKnock( - ctx context.Context, s ServerName, event *Event, + ctx context.Context, origin, s ServerName, event *Event, ) (res RespSendKnock, err error) { path := federationPathPrefixV1 + "/send_knock/" + url.PathEscape(event.RoomID()) + "/" + url.PathEscape(event.EventID()) - req := NewFederationRequest("PUT", s, path) + req := NewFederationRequest("PUT", origin, s, path) if err = req.SetContent(event); err != nil { return } @@ -207,12 +219,12 @@ func (ac *FederationClient) SendKnock( // the event_id with our own, and pass it to SendLeave. // See https://matrix.org/docs/spec/server_server/r0.1.1.html#get-matrix-federation-v1-make-leave-roomid-userid func (ac *FederationClient) MakeLeave( - ctx context.Context, s ServerName, roomID, userID string, + ctx context.Context, origin, s ServerName, roomID, userID string, ) (res RespMakeLeave, err error) { path := federationPathPrefixV1 + "/make_leave/" + url.PathEscape(roomID) + "/" + url.PathEscape(userID) - req := NewFederationRequest("GET", s, path) + req := NewFederationRequest("GET", origin, s, path) err = ac.doRequest(ctx, req, &res) return } @@ -222,12 +234,12 @@ func (ac *FederationClient) MakeLeave( // This is used to reject a remote invite. // See https://matrix.org/docs/spec/server_server/r0.1.1.html#put-matrix-federation-v1-send-leave-roomid-eventid func (ac *FederationClient) SendLeave( - ctx context.Context, s ServerName, event *Event, + ctx context.Context, origin, s ServerName, event *Event, ) (err error) { path := federationPathPrefixV2 + "/send_leave/" + url.PathEscape(event.RoomID()) + "/" + url.PathEscape(event.EventID()) - req := NewFederationRequest("PUT", s, path) + req := NewFederationRequest("PUT", origin, s, path) if err = req.SetContent(event); err != nil { return } @@ -239,7 +251,7 @@ func (ac *FederationClient) SendLeave( v1path := federationPathPrefixV1 + "/send_leave/" + url.PathEscape(event.RoomID()) + "/" + url.PathEscape(event.EventID()) - v1req := NewFederationRequest("PUT", s, v1path) + v1req := NewFederationRequest("PUT", origin, s, v1path) if err = v1req.SetContent(event); err != nil { return } @@ -255,12 +267,12 @@ func (ac *FederationClient) SendLeave( // SendInvite sends an invite m.room.member event to an invited server to be // signed by it. This is used to invite a user that is not on the local server. func (ac *FederationClient) SendInvite( - ctx context.Context, s ServerName, event *Event, + ctx context.Context, origin, s ServerName, event *Event, ) (res RespInvite, err error) { path := federationPathPrefixV1 + "/invite/" + url.PathEscape(event.RoomID()) + "/" + url.PathEscape(event.EventID()) - req := NewFederationRequest("PUT", s, path) + req := NewFederationRequest("PUT", origin, s, path) if err = req.SetContent(event); err != nil { return } @@ -271,13 +283,13 @@ func (ac *FederationClient) SendInvite( // SendInviteV2 sends an invite m.room.member event to an invited server to be // signed by it. This is used to invite a user that is not on the local server. func (ac *FederationClient) SendInviteV2( - ctx context.Context, s ServerName, request InviteV2Request, + ctx context.Context, origin, s ServerName, request InviteV2Request, ) (res RespInviteV2, err error) { event := request.Event() path := federationPathPrefixV2 + "/invite/" + url.PathEscape(event.RoomID()) + "/" + url.PathEscape(event.EventID()) - req := NewFederationRequest("PUT", s, path) + req := NewFederationRequest("PUT", origin, s, path) if err = req.SetContent(request); err != nil { return } @@ -287,7 +299,7 @@ func (ac *FederationClient) SendInviteV2( if ok && gerr.Code == 404 { // fallback to v1 which returns [200, body] var resp RespInvite - resp, err = ac.SendInvite(ctx, s, request.Event()) + resp, err = ac.SendInvite(ctx, origin, s, request.Event()) if err != nil { return } @@ -306,11 +318,11 @@ func (ac *FederationClient) SendInviteV2( // This is used to exchange a m.room.third_party_invite event for a m.room.member // one in a room the local server isn't a member of. func (ac *FederationClient) ExchangeThirdPartyInvite( - ctx context.Context, s ServerName, builder EventBuilder, + ctx context.Context, origin, s ServerName, builder EventBuilder, ) (err error) { path := federationPathPrefixV1 + "/exchange_third_party_invite/" + url.PathEscape(builder.RoomID) - req := NewFederationRequest("PUT", s, path) + req := NewFederationRequest("PUT", origin, s, path) if err = req.SetContent(builder); err != nil { return } @@ -322,13 +334,13 @@ func (ac *FederationClient) ExchangeThirdPartyInvite( // LookupState retrieves the room state for a room at an event from a // remote matrix server as full matrix events. func (ac *FederationClient) LookupState( - ctx context.Context, s ServerName, roomID, eventID string, roomVersion RoomVersion, + ctx context.Context, origin, s ServerName, roomID, eventID string, roomVersion RoomVersion, ) (res RespState, err error) { path := federationPathPrefixV1 + "/state/" + url.PathEscape(roomID) + "?event_id=" + url.QueryEscape(eventID) - req := NewFederationRequest("GET", s, path) + req := NewFederationRequest("GET", origin, s, path) err = ac.doRequest(ctx, req, &res) return } @@ -336,13 +348,13 @@ func (ac *FederationClient) LookupState( // LookupStateIDs retrieves the room state for a room at an event from a // remote matrix server as lists of matrix event IDs. func (ac *FederationClient) LookupStateIDs( - ctx context.Context, s ServerName, roomID, eventID string, + ctx context.Context, origin, s ServerName, roomID, eventID string, ) (res RespStateIDs, err error) { path := federationPathPrefixV1 + "/state_ids/" + url.PathEscape(roomID) + "?event_id=" + url.QueryEscape(eventID) - req := NewFederationRequest("GET", s, path) + req := NewFederationRequest("GET", origin, s, path) err = ac.doRequest(ctx, req, &res) return } @@ -351,12 +363,12 @@ func (ac *FederationClient) LookupStateIDs( // given bracket. // https://matrix.org/docs/spec/server_server/r0.1.3#post-matrix-federation-v1-get-missing-events-roomid func (ac *FederationClient) LookupMissingEvents( - ctx context.Context, s ServerName, roomID string, + ctx context.Context, origin, s ServerName, roomID string, missing MissingEvents, roomVersion RoomVersion, ) (res RespMissingEvents, err error) { path := federationPathPrefixV1 + "/get_missing_events/" + url.PathEscape(roomID) - req := NewFederationRequest("POST", s, path) + req := NewFederationRequest("POST", origin, s, path) if err = req.SetContent(missing); err != nil { return } @@ -366,7 +378,7 @@ func (ac *FederationClient) LookupMissingEvents( // Peek starts a peek on a remote server: see MSC2753 func (ac *FederationClient) Peek( - ctx context.Context, s ServerName, roomID, peekID string, + ctx context.Context, origin, s ServerName, roomID, peekID string, roomVersions []RoomVersion, ) (res RespPeek, err error) { versionQueryString := "" @@ -380,7 +392,7 @@ func (ac *FederationClient) Peek( path := federationPathPrefixV1 + "/peek/" + url.PathEscape(roomID) + "/" + url.PathEscape(peekID) + versionQueryString - req := NewFederationRequest("PUT", s, path) + req := NewFederationRequest("PUT", origin, s, path) var empty struct{} if err = req.SetContent(empty); err != nil { return @@ -395,11 +407,11 @@ func (ac *FederationClient) Peek( // If the room alias doesn't exist on the remote server then a 404 gomatrix.HTTPError // is returned. func (ac *FederationClient) LookupRoomAlias( - ctx context.Context, s ServerName, roomAlias string, + ctx context.Context, origin, s ServerName, roomAlias string, ) (res RespDirectory, err error) { path := federationPathPrefixV1 + "/query/directory?room_alias=" + url.QueryEscape(roomAlias) - req := NewFederationRequest("GET", s, path) + req := NewFederationRequest("GET", origin, s, path) err = ac.doRequest(ctx, req, &res) return } @@ -408,10 +420,10 @@ func (ac *FederationClient) LookupRoomAlias( // Spec: https://matrix.org/docs/spec/server_server/r0.1.1.html#get-matrix-federation-v1-publicrooms // thirdPartyInstanceID can only be non-empty if includeAllNetworks is false. func (ac *FederationClient) GetPublicRooms( - ctx context.Context, s ServerName, limit int, since string, + ctx context.Context, origin, s ServerName, limit int, since string, includeAllNetworks bool, thirdPartyInstanceID string, ) (res RespPublicRooms, err error) { - return ac.GetPublicRoomsFiltered(ctx, s, limit, since, "", includeAllNetworks, thirdPartyInstanceID) + return ac.GetPublicRoomsFiltered(ctx, origin, s, limit, since, "", includeAllNetworks, thirdPartyInstanceID) } // searchTerm is used when querying e.g. remote public rooms @@ -432,7 +444,7 @@ type postPublicRoomsReq struct { // Spec: https://spec.matrix.org/v1.1/server-server-api/#post_matrixfederationv1publicrooms // thirdPartyInstanceID can only be non-empty if includeAllNetworks is false. func (ac *FederationClient) GetPublicRoomsFiltered( - ctx context.Context, s ServerName, limit int, since, filter string, + ctx context.Context, origin, s ServerName, limit int, since, filter string, includeAllNetworks bool, thirdPartyInstanceID string, ) (res RespPublicRooms, err error) { if includeAllNetworks && thirdPartyInstanceID != "" { @@ -447,7 +459,7 @@ func (ac *FederationClient) GetPublicRoomsFiltered( Since: since, } path := federationPathPrefixV1 + "/publicRooms" - req := NewFederationRequest("POST", s, path) + req := NewFederationRequest("POST", origin, s, path) if err = req.SetContent(roomsReq); err != nil { return } @@ -461,14 +473,14 @@ func (ac *FederationClient) GetPublicRoomsFiltered( // which field of the profile should be returned. // Spec: https://matrix.org/docs/spec/server_server/r0.1.1.html#get-matrix-federation-v1-query-profile func (ac *FederationClient) LookupProfile( - ctx context.Context, s ServerName, userID string, field string, + ctx context.Context, origin, s ServerName, userID string, field string, ) (res RespProfile, err error) { path := federationPathPrefixV1 + "/query/profile?user_id=" + url.QueryEscape(userID) if field != "" { path += "&field=" + url.QueryEscape(field) } - req := NewFederationRequest("GET", s, path) + req := NewFederationRequest("GET", origin, s, path) err = ac.doRequest(ctx, req, &res) return } @@ -483,9 +495,9 @@ func (ac *FederationClient) LookupProfile( // } // // https://matrix.org/docs/spec/server_server/latest#post-matrix-federation-v1-user-keys-claim -func (ac *FederationClient) ClaimKeys(ctx context.Context, s ServerName, oneTimeKeys map[string]map[string]string) (res RespClaimKeys, err error) { +func (ac *FederationClient) ClaimKeys(ctx context.Context, origin, s ServerName, oneTimeKeys map[string]map[string]string) (res RespClaimKeys, err error) { path := federationPathPrefixV1 + "/user/keys/claim" - req := NewFederationRequest("POST", s, path) + req := NewFederationRequest("POST", origin, s, path) if err = req.SetContent(map[string]interface{}{ "one_time_keys": oneTimeKeys, }); err != nil { @@ -497,9 +509,9 @@ func (ac *FederationClient) ClaimKeys(ctx context.Context, s ServerName, oneTime // QueryKeys queries E2E device keys from a remote server. // https://matrix.org/docs/spec/server_server/latest#post-matrix-federation-v1-user-keys-query -func (ac *FederationClient) QueryKeys(ctx context.Context, s ServerName, keys map[string][]string) (res RespQueryKeys, err error) { +func (ac *FederationClient) QueryKeys(ctx context.Context, origin, s ServerName, keys map[string][]string) (res RespQueryKeys, err error) { path := federationPathPrefixV1 + "/user/keys/query" - req := NewFederationRequest("POST", s, path) + req := NewFederationRequest("POST", origin, s, path) if err = req.SetContent(map[string]interface{}{ "device_keys": keys, }); err != nil { @@ -512,10 +524,10 @@ func (ac *FederationClient) QueryKeys(ctx context.Context, s ServerName, keys ma // GetEvent gets an event by ID from a remote server. // See https://matrix.org/docs/spec/server_server/r0.1.1.html#get-matrix-federation-v1-event-eventid func (ac *FederationClient) GetEvent( - ctx context.Context, s ServerName, eventID string, + ctx context.Context, origin, s ServerName, eventID string, ) (res Transaction, err error) { path := federationPathPrefixV1 + "/event/" + url.PathEscape(eventID) - req := NewFederationRequest("GET", s, path) + req := NewFederationRequest("GET", origin, s, path) err = ac.doRequest(ctx, req, &res) return } @@ -523,10 +535,10 @@ func (ac *FederationClient) GetEvent( // GetEventAuth gets an event auth chain from a remote server. // See https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-event-auth-roomid-eventid func (ac *FederationClient) GetEventAuth( - ctx context.Context, s ServerName, roomVersion RoomVersion, roomID, eventID string, + ctx context.Context, origin, s ServerName, roomVersion RoomVersion, roomID, eventID string, ) (res RespEventAuth, err error) { path := federationPathPrefixV1 + "/event_auth/" + url.PathEscape(roomID) + "/" + url.PathEscape(eventID) - req := NewFederationRequest("GET", s, path) + req := NewFederationRequest("GET", origin, s, path) err = ac.doRequest(ctx, req, &res) return } @@ -534,10 +546,10 @@ func (ac *FederationClient) GetEventAuth( // GetUserDevices returns a list of the user's devices from a remote server. // See https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-user-devices-userid func (ac *FederationClient) GetUserDevices( - ctx context.Context, s ServerName, userID string, + ctx context.Context, origin, s ServerName, userID string, ) (res RespUserDevices, err error) { path := federationPathPrefixV1 + "/user/devices/" + url.PathEscape(userID) - req := NewFederationRequest("GET", s, path) + req := NewFederationRequest("GET", origin, s, path) err = ac.doRequest(ctx, req, &res) return } @@ -546,7 +558,7 @@ func (ac *FederationClient) GetUserDevices( // local database. // See https://matrix.org/docs/spec/server_server/unstable.html#get-matrix-federation-v1-backfill-roomid func (ac *FederationClient) Backfill( - ctx context.Context, s ServerName, roomID string, limit int, eventIDs []string, + ctx context.Context, origin, s ServerName, roomID string, limit int, eventIDs []string, ) (res Transaction, err error) { // Parse the limit into a string so that we can include it in the URL's query. limitStr := strconv.Itoa(limit) @@ -564,17 +576,17 @@ func (ac *FederationClient) Backfill( path := u.RequestURI() // Send the request. - req := NewFederationRequest("GET", s, path) + req := NewFederationRequest("GET", origin, s, path) err = ac.doRequest(ctx, req, &res) return } // MSC2836EventRelationships performs an MSC2836 /event_relationships request. func (ac *FederationClient) MSC2836EventRelationships( - ctx context.Context, dst ServerName, r MSC2836EventRelationshipsRequest, roomVersion RoomVersion, + ctx context.Context, origin, dst ServerName, r MSC2836EventRelationshipsRequest, roomVersion RoomVersion, ) (res MSC2836EventRelationshipsResponse, err error) { path := "/_matrix/federation/unstable/event_relationships" - req := NewFederationRequest("POST", dst, path) + req := NewFederationRequest("POST", origin, dst, path) if err = req.SetContent(r); err != nil { return } @@ -583,13 +595,13 @@ func (ac *FederationClient) MSC2836EventRelationships( } func (ac *FederationClient) MSC2946Spaces( - ctx context.Context, dst ServerName, roomID string, suggestedOnly bool, + ctx context.Context, origin, dst ServerName, roomID string, suggestedOnly bool, ) (res MSC2946SpacesResponse, err error) { path := "/_matrix/federation/v1/hierarchy/" + url.PathEscape(roomID) if suggestedOnly { path += "?suggested_only=true" } - req := NewFederationRequest("GET", dst, path) + req := NewFederationRequest("GET", origin, dst, path) err = ac.doRequest(ctx, req, &res) if err != nil { gerr, ok := err.(gomatrix.HTTPError) @@ -599,7 +611,7 @@ func (ac *FederationClient) MSC2946Spaces( if suggestedOnly { path += "?suggested_only=true" } - req := NewFederationRequest("GET", dst, path) + req := NewFederationRequest("GET", origin, dst, path) err = ac.doRequest(ctx, req, &res) } } diff --git a/federationclient_test.go b/federationclient_test.go index dd0b7775..c9bb0c63 100644 --- a/federationclient_test.go +++ b/federationclient_test.go @@ -46,7 +46,13 @@ func TestSendJoinFallback(t *testing.T) { t.Fatalf("failed to marshal RespSendJoin: %s", err) } fc := gomatrixserverlib.NewFederationClient( - serverName, keyID, privateKey, + []*gomatrixserverlib.SigningIdentity{ + { + ServerName: serverName, + KeyID: keyID, + PrivateKey: privateKey, + }, + }, gomatrixserverlib.WithSkipVerify(true), ) fc.Client = *gomatrixserverlib.NewClient(gomatrixserverlib.WithTransport( @@ -76,7 +82,7 @@ func TestSendJoinFallback(t *testing.T) { if err != nil { t.Fatalf("failed to read event json: %s", err) } - res, err := fc.SendJoin(context.Background(), targetServerName, ev) + res, err := fc.SendJoin(context.Background(), serverName, targetServerName, ev) if err != nil { t.Fatalf("SendJoin returned an error: %s", err) } @@ -102,7 +108,13 @@ func TestSendJoinJSON(t *testing.T) { }`, string(retEv), string(retEv))) fc := gomatrixserverlib.NewFederationClient( - serverName, keyID, privateKey, + []*gomatrixserverlib.SigningIdentity{ + { + ServerName: serverName, + KeyID: keyID, + PrivateKey: privateKey, + }, + }, gomatrixserverlib.WithSkipVerify(true), ) fc.Client = *gomatrixserverlib.NewClient(gomatrixserverlib.WithTransport( @@ -129,7 +141,7 @@ func TestSendJoinJSON(t *testing.T) { if err != nil { t.Fatalf("failed to read event json: %s", err) } - res, err := fc.SendJoin(context.Background(), targetServerName, ev) + res, err := fc.SendJoin(context.Background(), serverName, targetServerName, ev) if err != nil { t.Fatalf("SendJoin returned an error: %s", err) } diff --git a/request.go b/request.go index a1be0008..6d1a2da8 100644 --- a/request.go +++ b/request.go @@ -37,8 +37,9 @@ type FederationRequest struct { // The destination is the name of a matrix homeserver. // The request path must begin with a slash. // Eg. NewFederationRequest("GET", "matrix.org", "/_matrix/federation/v1/send/123") -func NewFederationRequest(method string, destination ServerName, requestURI string) FederationRequest { +func NewFederationRequest(method string, origin, destination ServerName, requestURI string) FederationRequest { var r FederationRequest + r.fields.Origin = origin r.fields.Destination = destination r.fields.Method = strings.ToUpper(method) r.fields.RequestURI = requestURI diff --git a/request_test.go b/request_test.go index de0048cf..a18956d5 100644 --- a/request_test.go +++ b/request_test.go @@ -47,7 +47,7 @@ const examplePutContent = `{"edus":[{"content":{"device_id":"YHRUBZNPFS",` + func TestSignGetRequest(t *testing.T) { request := NewFederationRequest( - "GET", "localhost:44033", + "GET", "localhost:8800", "localhost:44033", "/_matrix/federation/v1/query/directory?room_alias=%23test%3Alocalhost%3A44033", ) if err := request.Sign("localhost:8800", "ed25519:a_Obwu", privateKey1); err != nil { @@ -104,7 +104,7 @@ func TestVerifyGetRequest(t *testing.T) { func TestSignPutRequest(t *testing.T) { request := NewFederationRequest( - "PUT", "localhost:44033", "/_matrix/federation/v1/send/1493385816575/", + "PUT", "localhost:8800", "localhost:44033", "/_matrix/federation/v1/send/1493385816575/", ) if err := request.SetContent(RawJSON([]byte(examplePutContent))); err != nil { t.Fatal(err)