1 | package main
|
2 |
|
import (
	"context"
	"fmt"
	"strconv"
	"strings"

	"github.com/ollama/ollama/api"
)
|
10 |
|
// Client bundles an Ollama API client with the model name sent on every
// chat request and a debug flag.
type Client struct {
	// c is the underlying Ollama API client used to issue chat requests.
	c *api.Client
	// debug, when set, makes Translate print each response chunk to stdout.
	debug bool
	// model is the model name placed in every api.ChatRequest.
	model string
}
|
16 |
|
17 | func NewClient(model string, debug bool) (*Client, error) {
|
18 |
|
19 | client, err := api.ClientFromEnvironment()
|
20 | if err != nil {
|
21 | return nil, err
|
22 | }
|
23 |
|
24 | return &Client{
|
25 | c: client,
|
26 | debug: debug,
|
27 | model: model,
|
28 | }, nil
|
29 | }
|
30 |
|
31 | func (c Client) Translate(ctx context.Context, input string, local, remote Language) (string, error) {
|
32 | messages := []api.Message{
|
33 | {
|
34 | Role: "system",
|
35 | Content: fmt.Sprintf(
|
36 | TRANSLATE_TEMPLATE, remote.String(), local.String()),
|
37 | },
|
38 | {
|
39 | Role: "user",
|
40 | Content: input,
|
41 | },
|
42 | }
|
43 |
|
44 | text := ""
|
45 |
|
46 | req := &api.ChatRequest{
|
47 | Model: c.model,
|
48 | Messages: messages,
|
49 | }
|
50 |
|
51 | respFunc := func(resp api.ChatResponse) error {
|
52 | if c.debug {
|
53 | fmt.Print(resp.Message.Content)
|
54 | }
|
55 | text += resp.Message.Content
|
56 | return nil
|
57 | }
|
58 |
|
59 | err := c.c.Chat(ctx, req, respFunc)
|
60 | if err != nil {
|
61 | return "", nil
|
62 | }
|
63 |
|
64 | return text, nil
|
65 | }
|
66 |
|
67 | func (c Client) Compare(ctx context.Context, t1, t2 string) (int, error) {
|
68 | promptInput := fmt.Sprintf(COMPARISON_TEMPLATE, t1, t2)
|
69 | log.Debugf("input: %s", promptInput)
|
70 | messages := []api.Message{
|
71 | {
|
72 | Role: "system",
|
73 | Content: COMPARISION_TEMPLATE_SYSTEM,
|
74 | },
|
75 | {
|
76 | Role: "user",
|
77 | Content: promptInput,
|
78 | },
|
79 | }
|
80 |
|
81 | text := ""
|
82 |
|
83 | req := &api.ChatRequest{
|
84 | Model: c.model,
|
85 | Messages: messages,
|
86 | }
|
87 |
|
88 | respFunc := func(resp api.ChatResponse) error {
|
89 | text += resp.Message.Content
|
90 | return nil
|
91 | }
|
92 |
|
93 | err := c.c.Chat(ctx, req, respFunc)
|
94 | if err != nil {
|
95 | return -1, err
|
96 | }
|
97 |
|
98 | score, err := strconv.Atoi(text)
|
99 |
|
100 | if err != nil {
|
101 | log.Errorf("Failed to extract numeric score from model output: %s", err)
|
102 | }
|
103 |
|
104 | return score, nil
|
105 | }
|