141 lines
3.3 KiB
Go
141 lines
3.3 KiB
Go
package db
|
|
|
|
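// Database is laid out like this:
//
//	mapper/<mapper_id>/trackers/<channel_id> -> priority
//	mapper/<mapper_id>/latestEvent
//	channel/<channel_id>/tracks/<mapper_id> -> priority
//
// NOTE(review): this describes a key/value layout, while the code in this file
// stores data through gorm models (Config, User, Beatmapset, DiscordChannel).
// Confirm whether this comment is still accurate.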
|
|
|
|
import (
|
|
"strconv"
|
|
|
|
"github.com/pkg/errors"
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/clause"
|
|
|
|
"subscribe-bot/osuapi"
|
|
)
|
|
|
|
// Byte-slice names matching the path segments of a key/value layout.
// Nothing else in this file reads them.
var (
	// MAPPERS holds the "mapper" key segment.
	MAPPERS = []byte("mapper")
	// CHANNELS holds the "channels" key segment.
	CHANNELS = []byte("channels")
	// LATEST_EVENT holds the "latestEvent" key segment.
	LATEST_EVENT = []byte("latestEvent")
)
|
|
|
|
type Db struct {
|
|
gorm *gorm.DB
|
|
api *osuapi.Osuapi
|
|
}
|
|
|
|
func OpenDb(path string, api *osuapi.Osuapi) (db *Db, err error) {
|
|
gorm, err := gorm.Open(sqlite.Open(path), &gorm.Config{})
|
|
if err != nil {
|
|
panic("failed to connect database")
|
|
}
|
|
|
|
// auto-migrate
|
|
gorm.AutoMigrate(&Config{})
|
|
gorm.AutoMigrate(&User{})
|
|
gorm.AutoMigrate(&Beatmapset{})
|
|
gorm.AutoMigrate(&DiscordChannel{})
|
|
|
|
db = &Db{gorm, api}
|
|
return
|
|
}
|
|
|
|
// Loop over channels that are tracking this specific mapper
|
|
func (db *Db) IterTrackingChannels(mapperId int, fn func(channel DiscordChannel) error) (err error) {
|
|
var channels []DiscordChannel
|
|
db.gorm.Model(&User{ID: mapperId}).Association("TrackingChannels").Find(&channels)
|
|
for _, channel := range channels {
|
|
fn(channel)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// Loop over tracked mappers for this channel
|
|
func (db *Db) IterChannelTrackedMappers(channelId string, fn func(user User) error) (err error) {
|
|
var mappers []User
|
|
db.gorm.Model(&DiscordChannel{ID: channelId}).Association("TrackedMappers").Find(&mappers)
|
|
for _, mapper := range mappers {
|
|
fn(mapper)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// Loop over all tracked mappers
|
|
func (db *Db) IterAllTrackedMappers(fn func(user User) error) (err error) {
|
|
var mappers []User
|
|
db.gorm.Find(&mappers)
|
|
for _, mapper := range mappers {
|
|
fn(mapper)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (db *Db) UpdateMapperLastEvent(userId int, eventId int) (err error) {
|
|
if eventId == -1 {
|
|
var events []osuapi.Event
|
|
events, err = db.api.GetUserEvents(userId, 1, 0)
|
|
if err != nil {
|
|
err = errors.Wrap(err, "couldn't get user events from API")
|
|
return
|
|
}
|
|
|
|
if len(events) > 0 {
|
|
eventId = events[0].ID
|
|
}
|
|
}
|
|
|
|
db.gorm.Model(&User{}).Where("id = ?", userId).Update("latest_event_id", eventId)
|
|
return nil
|
|
}
|
|
|
|
// Get the latest event ID of this mapper, if they have one
|
|
func (db *Db) MapperLastEvent(userId int) (has bool, id int) {
|
|
var user User
|
|
db.gorm.Select("latest_event_id").First(&user)
|
|
return true, user.LatestEventID
|
|
}
|
|
|
|
// Start tracking a new mapper (if they're not already tracked)
|
|
func (db *Db) ChannelTrackMapper(channelId string, mapperId int, priority int) (err error) {
|
|
err = db.gorm.Model(&DiscordChannel{ID: channelId}).Association("TrackedMappers").Append(db.getUser(mapperId))
|
|
if err != nil {
|
|
err = errors.Wrap(err, "could not add tracking for channel "+channelId)
|
|
return
|
|
}
|
|
|
|
err = db.UpdateMapperLastEvent(mapperId, -1)
|
|
if err != nil {
|
|
err = errors.Wrap(err, "could not update mapper latest event")
|
|
return
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (db *Db) Close() {
|
|
}
|
|
|
|
func (db *Db) getUser(userId int) (user *User, err error) {
|
|
// TODO: cache user info for some time?
|
|
|
|
apiUser, err := db.api.GetUser(strconv.Itoa(userId))
|
|
if err != nil {
|
|
err = errors.Wrap(err, "could not retrieve user from the API")
|
|
return
|
|
}
|
|
|
|
user = &User{
|
|
ID: userId,
|
|
Username: apiUser.Username,
|
|
Country: apiUser.CountryCode,
|
|
}
|
|
db.gorm.Clauses(clause.OnConflict{UpdateAll: true}).Create(user)
|
|
|
|
return
|
|
}
|