Gomain.go -rw-r--r-- 3 KiB
1package main
2
3import (
4 "context"
5 "flag"
6
7 "fmt"
8
9 "github.com/charmbracelet/huh"
10 "github.com/charmbracelet/huh/spinner"
11 "go.uber.org/zap"
12 "go.uber.org/zap/zapcore"
13)
14
15var log *zap.SugaredLogger
16
17func setupLogging(debug bool) {
18 config := zap.NewDevelopmentConfig()
19 config.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
20 if debug {
21 config.Level = zap.NewAtomicLevelAt(zapcore.DebugLevel)
22 } else {
23 config.Level = zap.NewAtomicLevelAt(zapcore.InfoLevel)
24 }
25 logger, _ := config.Build()
26 log = logger.Sugar()
27}
28
29func main() {
30
31 cfgPath := flag.String("config", "", "Configuration File")
32 nArticles := flag.Int("n", 5, "Number of articles to translate")
33 debug := flag.Bool("debug", false, "Show debugging information")
34 flag.Parse()
35
36 setupLogging(*debug)
37
38 configFile := *cfgPath
39
40 if configFile == "" {
41 configFile = defaultConfigPath()
42 }
43
44 cfg, err := load_config(configFile)
45 if err != nil {
46 log.Fatal(err)
47 }
48
49 localLanguage := cfg.Language
50
51 fmt.Println(Introduction[localLanguage])
52
53 client, err := NewClient(cfg.Model, *debug)
54 if err != nil {
55 log.Fatal(err)
56 }
57
58 var sourceName string
59
60 sources := []huh.Option[string]{}
61 for i := range cfg.Sources {
62 sourceName := cfg.Sources[i].Name()
63 sources = append(sources, huh.NewOption(sourceName, sourceName))
64 }
65
66 form := huh.NewForm(
67 huh.NewGroup(
68 huh.NewSelect[string]().Title(SelectNewsSource[localLanguage]).
69 Options(sources...).Value(&sourceName),
70 ),
71 ).WithTheme(huh.ThemeBase16())
72
73 if err := form.Run(); err != nil {
74 log.Fatal(err)
75 }
76
77 source := cfg.ByName(sourceName)
78
79 ctx := context.Background()
80
81 var nArticlesRemote int
82
83 task := spinner.New().Title("Downloading Article Source...").Action(func() {
84 n, err := source.Len(ctx)
85 if err != nil {
86 log.Fatal(err)
87 }
88 nArticlesRemote = n
89 })
90
91 if err := task.Run(); err != nil {
92 log.Fatal(err)
93 }
94
95 count := min(*nArticles, nArticlesRemote)
96
97 for i := range count {
98
99 article, err := source.Get(ctx, i)
100 if err != nil {
101 log.Fatal(err)
102 }
103
104 for _, chunk := range article.Chunks() {
105
106 var inputText string
107
108 form = huh.NewForm(
109 huh.NewGroup(
110 huh.NewText().Title(chunk).
111 Description(
112 fmt.Sprintf(
113 SummarizeText[localLanguage],
114 localLanguage)).Value(&inputText),
115 ),
116 ).WithTheme(huh.ThemeBase16())
117
118 if err := form.Run(); err != nil {
119 log.Fatal(err)
120 }
121
122 sourceText := chunk
123 if localLanguage != source.Language() {
124 // translate the remote language to the local language before doing
125 // the comparision.
126 translated, err := client.Translate(ctx, chunk, localLanguage, source.Language())
127 if err != nil {
128 log.Fatal(err)
129 }
130 log.Debugf("Translated source:\n%s\nInto:\n%s", sourceText, translated)
131 sourceText = translated
132 }
133
134 var score int
135
136 task := spinner.New().Title("Comparing text").Action(func() {
137 n, err := client.Compare(ctx, sourceText, inputText)
138 if err != nil {
139 log.Fatal(err)
140 }
141 score = n
142 })
143
144 if err := task.Run(); err != nil {
145 log.Fatal(err)
146 }
147
148 log.Infof("Accuracy: %d", score)
149
150 }
151 }
152}