diff --git a/internal/pkg/apiclient/updatedb.go b/internal/pkg/apiclient/updatedb.go index a39cb5a..9a9dbc5 100644 --- a/internal/pkg/apiclient/updatedb.go +++ b/internal/pkg/apiclient/updatedb.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "fmt" "net/http" + "reflect" "github.com/AFASystems/presence/internal/pkg/config" "github.com/AFASystems/presence/internal/pkg/controller" @@ -25,10 +26,8 @@ func UpdateDB(db *gorm.DB, ctx context.Context, cfg *config.Config, writer *kafk return err } - trackers, err := GetTrackers(token, client) - if err != nil { - fmt.Printf("Error in getting trackers: %+v\n", err) - } else { + if trackers, err := GetTrackers(token, client); err == nil { + syncTable(db, trackers) if err := controller.SendKafkaMessage(writer, &model.ApiUpdate{Method: "DELETE", MAC: "all"}, ctx); err != nil { fmt.Printf("Error in sending delete all from lookup message: %v", err) } @@ -44,63 +43,36 @@ func UpdateDB(db *gorm.DB, ctx context.Context, cfg *config.Config, writer *kafk fmt.Printf("Error in sending POST kafka message: %v", err) } } + } - var ids []string - for _, t := range trackers { - ids = append(ids, t.ID) - } - db.Where("id NOT IN ?", ids).Delete(&model.Tracker{}) + if gateways, err := GetGateways(token, client); err == nil { + syncTable(db, gateways) } - gateways, err := GetGateways(token, client) - if err != nil { - fmt.Printf("Error in getting gateways: %+v\n", err) - } else { - var ids []string - for _, g := range gateways { - ids = append(ids, g.ID) - } - db.Where("id NOT IN ?", ids).Delete(&model.Gateway{}) + if zones, err := GetZones(token, client); err == nil { + syncTable(db, zones) } - zones, err := GetZones(token, client) - if err != nil { - fmt.Printf("Error in getting zones: %+v\n", err) - } else { - var ids []string - for _, z := range zones { - ids = append(ids, z.ID) - } - db.Where("id NOT IN ?", ids).Delete(&model.Zone{}) + if trackerZones, err := GetTrackerZones(token, client); err == nil { + syncTable(db, trackerZones) } - trackerZones, err := GetTrackerZones(token, client) - if err != nil { - fmt.Printf("Error in getting tracker zones: %+v\n", err) - } else { - var ids []string - for _, tz := range trackerZones { - ids = append(ids, tz.ID) - } - db.Where("id NOT IN ?", ids).Delete(&model.TrackerZones{}) + return nil +} + +func syncTable[T any](db *gorm.DB, data []T) { + if len(data) == 0 { + return } - db.Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "id"}}, - UpdateAll: true, - }).Create(&trackers) - db.Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "id"}}, - UpdateAll: true, - }).Create(&gateways) - db.Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "id"}}, - UpdateAll: true, - }).Create(&zones) - db.Clauses(clause.OnConflict{ - Columns: []clause.Column{{Name: "id"}}, - UpdateAll: true, - }).Create(&trackerZones) + var ids []string + for _, item := range data { + v := reflect.ValueOf(item).FieldByName("ID").String() + ids = append(ids, v) + } - return nil + db.Transaction(func(tx *gorm.DB) error { + tx.Where("id NOT IN ?", ids).Delete(new(T)) + return tx.Clauses(clause.OnConflict{UpdateAll: true}).Create(&data).Error + }) }