1 | package main
|
2 |
|
3 | import (
|
4 | "context"
|
5 | "flag"
|
6 |
|
7 | "fmt"
|
8 |
|
9 | "github.com/charmbracelet/huh"
|
10 | "github.com/charmbracelet/huh/spinner"
|
11 | "go.uber.org/zap"
|
12 | "go.uber.org/zap/zapcore"
|
13 | )
|
14 |
|
15 | var log *zap.SugaredLogger
|
16 |
|
17 | func setupLogging(debug bool) {
|
18 | config := zap.NewDevelopmentConfig()
|
19 | config.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
|
20 | if debug {
|
21 | config.Level = zap.NewAtomicLevelAt(zapcore.DebugLevel)
|
22 | } else {
|
23 | config.Level = zap.NewAtomicLevelAt(zapcore.InfoLevel)
|
24 | }
|
25 | logger, _ := config.Build()
|
26 | log = logger.Sugar()
|
27 | }
|
28 |
|
29 | func main() {
|
30 |
|
31 | cfgPath := flag.String("config", "", "Configuration File")
|
32 | nArticles := flag.Int("n", 5, "Number of articles to translate")
|
33 | debug := flag.Bool("debug", false, "Show debugging information")
|
34 | flag.Parse()
|
35 |
|
36 | setupLogging(*debug)
|
37 |
|
38 | configFile := *cfgPath
|
39 |
|
40 | if configFile == "" {
|
41 | configFile = defaultConfigPath()
|
42 | }
|
43 |
|
44 | cfg, err := load_config(configFile)
|
45 | if err != nil {
|
46 | log.Fatal(err)
|
47 | }
|
48 |
|
49 | localLanguage := cfg.Language
|
50 |
|
51 | fmt.Println(Introduction[localLanguage])
|
52 |
|
53 | client, err := NewClient(cfg.Model, *debug)
|
54 | if err != nil {
|
55 | log.Fatal(err)
|
56 | }
|
57 |
|
58 | var sourceName string
|
59 |
|
60 | sources := []huh.Option[string]{}
|
61 | for i := range cfg.Sources {
|
62 | sourceName := cfg.Sources[i].Name()
|
63 | sources = append(sources, huh.NewOption(sourceName, sourceName))
|
64 | }
|
65 |
|
66 | form := huh.NewForm(
|
67 | huh.NewGroup(
|
68 | huh.NewSelect[string]().Title(SelectNewsSource[localLanguage]).
|
69 | Options(sources...).Value(&sourceName),
|
70 | ),
|
71 | ).WithTheme(huh.ThemeBase16())
|
72 |
|
73 | if err := form.Run(); err != nil {
|
74 | log.Fatal(err)
|
75 | }
|
76 |
|
77 | source := cfg.ByName(sourceName)
|
78 |
|
79 | ctx := context.Background()
|
80 |
|
81 | var nArticlesRemote int
|
82 |
|
83 | task := spinner.New().Title("Downloading Article Source...").Action(func() {
|
84 | n, err := source.Len(ctx)
|
85 | if err != nil {
|
86 | log.Fatal(err)
|
87 | }
|
88 | nArticlesRemote = n
|
89 | })
|
90 |
|
91 | if err := task.Run(); err != nil {
|
92 | log.Fatal(err)
|
93 | }
|
94 |
|
95 | count := min(*nArticles, nArticlesRemote)
|
96 |
|
97 | for i := range count {
|
98 |
|
99 | article, err := source.Get(ctx, i)
|
100 | if err != nil {
|
101 | log.Fatal(err)
|
102 | }
|
103 |
|
104 | for _, chunk := range article.Chunks() {
|
105 |
|
106 | var inputText string
|
107 |
|
108 | form = huh.NewForm(
|
109 | huh.NewGroup(
|
110 | huh.NewText().Title(chunk).
|
111 | Description(
|
112 | fmt.Sprintf(
|
113 | SummarizeText[localLanguage],
|
114 | localLanguage)).Value(&inputText),
|
115 | ),
|
116 | ).WithTheme(huh.ThemeBase16())
|
117 |
|
118 | if err := form.Run(); err != nil {
|
119 | log.Fatal(err)
|
120 | }
|
121 |
|
122 | sourceText := chunk
|
123 | if localLanguage != source.Language() {
|
124 | // translate the remote language to the local language before doing
|
125 | // the comparision.
|
126 | translated, err := client.Translate(ctx, chunk, localLanguage, source.Language())
|
127 | if err != nil {
|
128 | log.Fatal(err)
|
129 | }
|
130 | log.Debugf("Translated source:\n%s\nInto:\n%s", sourceText, translated)
|
131 | sourceText = translated
|
132 | }
|
133 |
|
134 | var score int
|
135 |
|
136 | task := spinner.New().Title("Comparing text").Action(func() {
|
137 | n, err := client.Compare(ctx, sourceText, inputText)
|
138 | if err != nil {
|
139 | log.Fatal(err)
|
140 | }
|
141 | score = n
|
142 | })
|
143 |
|
144 | if err := task.Run(); err != nil {
|
145 | log.Fatal(err)
|
146 | }
|
147 |
|
148 | log.Infof("Accuracy: %d", score)
|
149 |
|
150 | }
|
151 | }
|
152 | }
|