subscribe-bot/db/db.go
2021-07-21 17:15:11 -05:00

141 lines
3.3 KiB
Go

package db
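// The key/value path layout below appears to be the scheme the byte-string key
// names declared further down (LATEST_EVENT, MAPPERS, CHANNELS) refer to. The
// data itself is stored in SQLite through GORM (see the models migrated in
// OpenDb). Note that the CHANNELS key is spelled "channels", not "channel".
// mapper/<mapper_id>/trackers/<channel_id> -> priority
// mapper/<mapper_id>/latestEvent
// channel/<channel_id>/tracks/<mapper_id> -> priority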
import (
"strconv"
"github.com/pkg/errors"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"subscribe-bot/osuapi"
)
// Byte-string key names matching the key/value layout described in the
// comment near the top of this file. The GORM-based code in this file does not
// read them; they are exported, so they are presumably kept for compatibility
// with other packages — TODO confirm and remove if unused.
var (
	// MAPPERS is the "mapper/..." prefix.
	MAPPERS = []byte("mapper")
	// CHANNELS is the per-channel prefix (plural, unlike the layout comment).
	CHANNELS = []byte("channels")
	// LATEST_EVENT is the per-mapper latest-event key.
	LATEST_EVENT = []byte("latestEvent")
)
// Db bundles the GORM handle for the SQLite database opened by OpenDb with the
// osuapi client that is used to look up users and their events.
type Db struct {
	// gorm is the database handle. The field shares its name with the gorm
	// package, so inside methods it is always reached as db.gorm.
	gorm *gorm.DB
	// api is queried by getUser (user details) and UpdateMapperLastEvent
	// (latest events).
	api *osuapi.Osuapi
}
// OpenDb opens the SQLite database at path through GORM, migrates the schema
// for every model this package stores (Config, User, Beatmapset,
// DiscordChannel), and returns a Db that uses api for osu! API lookups.
//
// Failures to open or migrate the database are returned as errors instead of
// panicking, so the caller decides whether they are fatal.
func OpenDb(path string, api *osuapi.Osuapi) (db *Db, err error) {
	// Named conn rather than gorm so the gorm package stays addressable below.
	conn, err := gorm.Open(sqlite.Open(path), &gorm.Config{})
	if err != nil {
		return nil, errors.Wrap(err, "could not open database "+path)
	}

	// Migrate one model at a time, in the same order as before, so an error
	// names the model whose migration failed.
	models := []interface{}{&Config{}, &User{}, &Beatmapset{}, &DiscordChannel{}}
	for _, model := range models {
		if err = conn.AutoMigrate(model); err != nil {
			// Best effort: release the connection pool before giving up; the
			// migration error is the one worth reporting.
			if sqlDB, dbErr := conn.DB(); dbErr == nil {
				_ = sqlDB.Close()
			}
			return nil, errors.Wrapf(err, "could not migrate %T", model)
		}
	}

	return &Db{gorm: conn, api: api}, nil
}
// IterTrackingChannels calls fn once for every Discord channel that tracks the
// mapper with the given user ID (the mapper's "TrackingChannels" association).
//
// Iteration stops at the first error returned by fn, and that error is
// returned unchanged. A failure to load the channel list is returned wrapped.
func (db *Db) IterTrackingChannels(mapperId int, fn func(channel DiscordChannel) error) (err error) {
	var channels []DiscordChannel
	err = db.gorm.Model(&User{ID: mapperId}).Association("TrackingChannels").Find(&channels)
	if err != nil {
		return errors.Wrap(err, "could not load channels tracking mapper "+strconv.Itoa(mapperId))
	}
	for _, channel := range channels {
		if err = fn(channel); err != nil {
			return err
		}
	}
	return nil
}
// IterChannelTrackedMappers calls fn once for every mapper tracked by the
// Discord channel with the given ID (the channel's "TrackedMappers"
// association).
//
// Iteration stops at the first error returned by fn, and that error is
// returned unchanged. A failure to load the mapper list is returned wrapped.
func (db *Db) IterChannelTrackedMappers(channelId string, fn func(user User) error) (err error) {
	var mappers []User
	err = db.gorm.Model(&DiscordChannel{ID: channelId}).Association("TrackedMappers").Find(&mappers)
	if err != nil {
		return errors.Wrap(err, "could not load mappers tracked by channel "+channelId)
	}
	for _, mapper := range mappers {
		if err = fn(mapper); err != nil {
			return err
		}
	}
	return nil
}
// IterAllTrackedMappers calls fn once for every User row stored in the
// database.
//
// NOTE(review): this is every stored user. Rows are created by getUser when a
// mapper is tracked, and nothing shown here deletes them, so the set may
// include mappers no channel tracks any more. Confirm that this is intended.
//
// Iteration stops at the first error returned by fn, and that error is
// returned unchanged. A failure to load the users is returned wrapped.
func (db *Db) IterAllTrackedMappers(fn func(user User) error) (err error) {
	var mappers []User
	if err = db.gorm.Find(&mappers).Error; err != nil {
		return errors.Wrap(err, "could not load mappers")
	}
	for _, mapper := range mappers {
		if err = fn(mapper); err != nil {
			return err
		}
	}
	return nil
}
// UpdateMapperLastEvent stores eventId as the latest_event_id of the user with
// the given ID.
//
// Passing -1 asks the API for the user's events and stores the ID of the first
// event returned instead. Presumably that is the most recent event; the call
// is GetUserEvents(userId, 1, 0), and the meaning of the 1 and 0 arguments
// should be confirmed in osuapi. If no events come back, -1 itself is stored.
//
// Both API failures and database failures are returned wrapped.
func (db *Db) UpdateMapperLastEvent(userId int, eventId int) (err error) {
	if eventId == -1 {
		var events []osuapi.Event
		events, err = db.api.GetUserEvents(userId, 1, 0)
		if err != nil {
			return errors.Wrap(err, "couldn't get user events from API")
		}
		if len(events) > 0 {
			eventId = events[0].ID
		}
	}

	err = db.gorm.Model(&User{}).Where("id = ?", userId).Update("latest_event_id", eventId).Error
	if err != nil {
		return errors.Wrap(err, "could not store latest event for mapper "+strconv.Itoa(userId))
	}
	return nil
}
// MapperLastEvent returns the latest_event_id stored for the user with the
// given ID.
//
// has is false, and id is 0, when no row exists for that user or the lookup
// fails. The signature has no error return, so both cases are reported the
// same way.
func (db *Db) MapperLastEvent(userId int) (has bool, id int) {
	var user User
	// The id filter is essential: without it, First returns whichever user row
	// comes first, regardless of userId.
	result := db.gorm.Select("latest_event_id").Where("id = ?", userId).First(&user)
	if result.Error != nil {
		return false, 0
	}
	return true, user.LatestEventID
}
// ChannelTrackMapper adds the mapper with the given user ID to the
// "TrackedMappers" of the Discord channel channelId. It first fetches the
// mapper from the API and upserts them via getUser. It then refreshes the
// mapper's latest event ID from the API.
//
// priority is accepted for compatibility but is not currently stored.
func (db *Db) ChannelTrackMapper(channelId string, mapperId int, priority int) (err error) {
	// getUser returns (user, err). Check the error before handing the user to
	// GORM; passing the call directly would spread both values into Append's
	// variadic arguments and silently drop the error.
	user, err := db.getUser(mapperId)
	if err != nil {
		return errors.Wrap(err, "could not look up mapper "+strconv.Itoa(mapperId))
	}

	err = db.gorm.Model(&DiscordChannel{ID: channelId}).Association("TrackedMappers").Append(user)
	if err != nil {
		return errors.Wrap(err, "could not add tracking for channel "+channelId)
	}

	if err = db.UpdateMapperLastEvent(mapperId, -1); err != nil {
		return errors.Wrap(err, "could not update mapper latest event")
	}
	return nil
}
// Close releases the database connection pool behind db. It is a no-op on a
// nil or unopened Db. Errors are ignored because Close has no error return.
func (db *Db) Close() {
	if db == nil || db.gorm == nil {
		return
	}
	sqlDB, err := db.gorm.DB()
	if err != nil {
		return
	}
	// Best effort: there is no way to report a close failure to the caller.
	_ = sqlDB.Close()
}
// getUser fetches the user with the given ID from the API. It upserts their ID,
// username and country into the database, then returns the record that was
// written. On any error the returned user is nil.
//
// NOTE(review): OnConflict{UpdateAll: true} overwrites the other columns of an
// existing row with this freshly built User's values. That presumably includes
// latest_event_id, which would be reset to 0. Confirm that this is intended.
func (db *Db) getUser(userId int) (user *User, err error) {
	// TODO: cache user info for some time?
	apiUser, err := db.api.GetUser(strconv.Itoa(userId))
	if err != nil {
		return nil, errors.Wrap(err, "could not retrieve user from the API")
	}

	user = &User{
		ID:       userId,
		Username: apiUser.Username,
		Country:  apiUser.CountryCode,
	}
	if err = db.gorm.Clauses(clause.OnConflict{UpdateAll: true}).Create(user).Error; err != nil {
		return nil, errors.Wrap(err, "could not save user "+strconv.Itoa(userId))
	}
	return user, nil
}