diff --git a/auth-server/server.go b/auth-server/server.go index 3959026..0a3557c 100644 --- a/auth-server/server.go +++ b/auth-server/server.go @@ -2,6 +2,7 @@ package main import ( "context" + "fmt" "github.com/nspcc-dev/neo-go/pkg/rpcclient" "github.com/nspcc-dev/neo-go/pkg/rpcclient/actor" @@ -77,15 +78,61 @@ func (model StorageClientInfo) GetSecret() string { } func (model StorageClientInfo) GetDomain() string { - // implement as in-memory + return model.GetDomain() } func (model StorageClientInfo) IsPublic() bool { - // implement as in-memory + return model.IsPublic() } func (model StorageClientInfo) GetUserID() string { - // implement as in-memory + return model.GetUserID() +} + +type InMemoryClient struct { + Id string + Domain string + IsPublic bool +} + +var clients map[string]InMemoryClient + +func addInMemoryClient(id string, domain string, isPublic bool) error { + result := InMemoryClient{ + Id: id, + Domain: domain, + IsPublic: isPublic, + } + _, contains := clients[id] + if contains { + return fmt.Errorf("client with id %s already exist", id) + } else { + clients[id] = result + return nil + } +} + +func getInMemoryClient(id string) (*InMemoryClient, error) { + result, contains := clients[id] + if contains { + return &result, nil + } else { + return nil, fmt.Errorf("client with id %s not found", id) + } +} + +func updateInMemoryClientById(id string, client InMemoryClient) error { + _, contains := clients[id] + if !contains || id != client.Id { + return fmt.Errorf("client with id %s not found", id) + } else { + clients[id] = client + return nil + } +} + +func deleteInMemoryClient(id string) { + delete(clients, id) } func main() { @@ -125,6 +172,13 @@ func main() { http.HandleFunc("/login", func(w http.ResponseWriter, r *http.Request) { srv.HandleTokenRequest(w, r) + + id := r.Header.Get("client_id") + _, err := getInMemoryClient(id) + if err != nil { + slog.Warn("Client with id " + id + "not found") + w.WriteHeader(http.StatusBadRequest) + } }) http.HandleFunc("/register", func(writer http.ResponseWriter, request *http.Request) { @@ -151,6 +205,16 @@ func main() { return } + client, _ := clientStore.GetByID(context.Background(), id) + err = addInMemoryClient( + client.GetID(), + client.GetDomain(), + client.IsPublic()) + if err != nil { + slog.Error(err.Error()) + return + } + writer.WriteHeader(http.StatusOK) })