package main
|
|
import (
|
"context"
|
"flag"
|
|
"fmt"
|
|
"github.com/charmbracelet/huh"
|
"github.com/charmbracelet/huh/spinner"
|
"go.uber.org/zap"
|
"go.uber.org/zap/zapcore"
|
)
|
|
var (
	// log is the program-wide sugared zap logger. It is nil until
	// setupLogging assigns it, so setupLogging must be called before use.
	log *zap.SugaredLogger
)
|
|
func setupLogging(debug bool) {
|
config := zap.NewDevelopmentConfig()
|
config.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
|
if debug {
|
config.Level = zap.NewAtomicLevelAt(zapcore.DebugLevel)
|
} else {
|
config.Level = zap.NewAtomicLevelAt(zapcore.InfoLevel)
|
}
|
logger, _ := config.Build()
|
log = logger.Sugar()
|
}
|
|
func main() {
|
|
cfgPath := flag.String("config", "", "Configuration File")
|
nArticles := flag.Int("n", 5, "Number of articles to translate")
|
debug := flag.Bool("debug", false, "Show debugging information")
|
flag.Parse()
|
|
setupLogging(*debug)
|
|
configFile := *cfgPath
|
|
if configFile == "" {
|
configFile = defaultConfigPath()
|
}
|
|
cfg, err := load_config(configFile)
|
if err != nil {
|
log.Fatal(err)
|
}
|
|
localLanguage := cfg.Language
|
|
fmt.Println(Introduction[localLanguage])
|
|
client, err := NewClient(cfg.Model, *debug)
|
if err != nil {
|
log.Fatal(err)
|
}
|
|
var sourceName string
|
|
sources := []huh.Option[string]{}
|
for i := range cfg.Sources {
|
sourceName := cfg.Sources[i].Name()
|
sources = append(sources, huh.NewOption(sourceName, sourceName))
|
}
|
|
form := huh.NewForm(
|
huh.NewGroup(
|
huh.NewSelect[string]().Title(SelectNewsSource[localLanguage]).
|
Options(sources...).Value(&sourceName),
|
),
|
).WithTheme(huh.ThemeBase16())
|
|
if err := form.Run(); err != nil {
|
log.Fatal(err)
|
}
|
|
source := cfg.ByName(sourceName)
|
|
ctx := context.Background()
|
|
var nArticlesRemote int
|
|
task := spinner.New().Title("Downloading Article Source...").Action(func() {
|
n, err := source.Len(ctx)
|
if err != nil {
|
log.Fatal(err)
|
}
|
nArticlesRemote = n
|
})
|
|
if err := task.Run(); err != nil {
|
log.Fatal(err)
|
}
|
|
count := min(*nArticles, nArticlesRemote)
|
|
for i := range count {
|
|
article, err := source.Get(ctx, i)
|
if err != nil {
|
log.Fatal(err)
|
}
|
|
for _, chunk := range article.Chunks() {
|
|
var inputText string
|
|
form = huh.NewForm(
|
huh.NewGroup(
|
huh.NewText().Title(chunk).
|
Description(
|
fmt.Sprintf(
|
SummarizeText[localLanguage],
|
localLanguage)).Value(&inputText),
|
),
|
).WithTheme(huh.ThemeBase16())
|
|
if err := form.Run(); err != nil {
|
log.Fatal(err)
|
}
|
|
sourceText := chunk
|
if localLanguage != source.Language() {
|
// translate the remote language to the local language before doing
|
// the comparision.
|
translated, err := client.Translate(ctx, chunk, localLanguage, source.Language())
|
if err != nil {
|
log.Fatal(err)
|
}
|
log.Debugf("Translated source:\n%s\nInto:\n%s", sourceText, translated)
|
sourceText = translated
|
}
|
|
var score int
|
|
task := spinner.New().Title("Comparing text").Action(func() {
|
n, err := client.Compare(ctx, sourceText, inputText)
|
if err != nil {
|
log.Fatal(err)
|
}
|
score = n
|
})
|
|
if err := task.Run(); err != nil {
|
log.Fatal(err)
|
}
|
|
log.Infof("Accuracy: %d", score)
|
|
}
|
}
|
}
|