Mirror of https://github.com/likelovewant/ollama-for-amd.git
Synced 2025-12-22 14:53:56 +00:00
Compare commits
622 Commits
59e3a35203
b3e6120736
fb92b61754
8149a3c86e
0cc90a8186
e42300f25b
66e73809a1
c632fdbad8
517807cdf2
ead4a9a1d0
4383a3ab7a
9d97e6a9f1
1081532430
59412fbb43
86834a2797
85ccf7354d
30fb7e19f8
d3450dd52e
4bcb04ad88
e3d5708754
4be4dc8717
109d4fc3b4
2cb0a580f3
7cce5aac76
131c496340
4ae4f47b16
073fa31df5
91fc3c48e3
6de62664d9
463a6caad8
fc5fb09f51
05ccb17c6e
f804e8a460
9cfbffafc5
470d580205
b517bb1c19
e3ade453a8
048bd4472a
ec8bf5e6c5
709bbb0b6d
abeec240f9
df335aac09
026bc29237
883d031268
5271ff8559
d6f7233a1c
8de1da4767
d925b5350c
6eaf194b85
d5a0d8d904
ef7d26ba2c
1a19df1f3a
7ccfd97a93
c385ca8672
837379a94c
a24f90604f
dc5a645434
bb71654ebe
d4af9f04f9
a343ae53a4
d0cf6c8281
8f4ec9ab28
dbfd7bd027
ee04dbba51
ea7657b54a
2c776f0780
79f6376f5b
756c78cfc7
d7f4f788d1
114c3f2265
f2e9c9aff5
aa9d889522
735c41f9ca
223a619468
759dd78dd6
44bc36d063
8f14e1f5f6
203c137810
fa8be9e35c
8a75e9ee15
9231379bce
c7ba6128b4
8970233a2b
cde948f976
7c8aba0d83
4742e12c23
2d06977ade
30f8a68c4c
e378e33421
fcec04bf42
ee92ca3e1d
8253ad4d2b
fa7776fd24
0d38b66502
e5e077b4b7
4183bb0574
ff89ba90bc
6dcc5dfb9c
25911a6e6b
8afa6e83f2
ea85e27bbd
c116a7523d
3515cc377c
bbf66c0b96
764be7480f
b72e5adb14
80b538e312
4f8a0166cc
1e6eab5c33
6c733bf0a6
3bac5cba60
4151ef8cf7
e4ff6e6c0f
82da19c634
bdd9d22dfd
5fc38d042f
475a11d08e
191d94289d
802ad16ce4
5e67f4f90e
e840ccb523
b4fe3adc0a
d73f8aa8c3
92c2e8a56c
2e3fd86d48
4261a3b0b2
acef9b4c1b
9a43994c45
f8a6e88819
35fda7b4af
66fb8575ce
20c3266e94
34088dbcfb
e41dd73705
43107b15b9
1f91cb0c8c
12d8ad0d38
592d21e7db
5a08b01f5b
4f473e224c
9d60bb44cf
f371260e75
c9e6d7719e
2c4ce40334
5d8c173529
44b17d2bfa
4ad87b58bb
3b8b692218
4129af9205
45f216a9c7
d0b32def60
11ffc36157
ba04902670
3944602f51
73b642e6f3
ad118d8b13
f08534137b
4b4a90f233
03274a6b2f
cc6463ebca
405d2f628f
a3f7dd3e98
c85c0ebf89
10a8e04a8d
1c6669e64c
b2b270ad5d
2bb69b40c7
65bff664cb
c088ac0e79
0a066cfd91
87b7af6cee
f2527b08fb
71a4057fcf
5ab7422508
8bcb3125c1
6baf1e31e2
ed567ef43b
a6e64fbdf2
60cfa2a203
55bbf3b4a1
6bda1d2479
50f2219dd6
9e125d884c
a6fbfc880c
502028968d
5a8eb0e151
9f8a18ec05
6b04cad7e8
45f56355d5
0dabb4ef6a
2e77aa1ae7
deaabe292d
af21a5ac39
f63d7f68eb
82ad1dbc07
feeabdadd2
fc0309615e
09d308d6b6
a8ed68bd93
2ae65ae471
a3b6886b7d
c6a6d7294d
2cf007c9d1
0683efa637
0943001193
5c42800fca
65f10c2823
aaa7818000
f15ffc4320
d008f108cc
5f57b0ef42
aa25aff10d
ea79003180
9239a254e0
066d0f4746
aea6fb9b58
012cf65340
a45231af47
2307fc2bcd
6623898198
eda472df1b
f18e0cb550
68b58c5cb8
e8b981fa5d
884d26093c
1f371ea92f
73d6a82cce
6db8a3771c
d950ff12c0
adff143bcd
fbe6ae285a
fdd4d479a3
61aeaf7e81
7359b02707
c890011322
e0ed984cde
139f84cf21
375839ea2d
69b2fe9282
9ed8bf14cb
e6a800ca11
ff180c3466
3fe74fba42
1a0cfd080a
94ab428e3f
d755577473
a2cc8571c5
7edfdd2f5f
333e360422
cb104a2082
27da2cddc5
feb8923ada
fe623c2cf4
3c14461d5d
499ae7311f
ef202789fa
55760195e6
bd68d3ae50
ff80718e9c
0aa8b371dd
23125648b8
0478d440f0
8cc33f4c2b
f46df4e5d2
c6bcdc4223
4b903f088a
c7f4ae7b9c
526b2ed102
a7240c6d63
9d6df90805
0cefd46f23
ad035ad595
f95a1f2bef
82a9e9462a
76724e2f29
ecf14a220f
69ce44b33c
5969674cf1
867d75b21e
3fa78598a1
0d6e35d3c6
20c5fd39c8
6e9a7a2568
b585a58121
fa9973cd7f
3d9498a425
3098c8b29b
5e380c3b42
392de84031
5d967d59b1
af31ccefc0
fa393554b9
307e3b3e1d
4090aca97b
92ce438de0
424810450f
95e744beeb
3b2d2c8326
d931ee8f22
7073600797
b1c40138da
17466217e5
1703d1472e
913905028b
7e5c8eee5c
6a74bba7e7
76ea735aaf
dd1d4e99e7
a6ef73f4f2
c2f5d6662b
57fb759f3c
8dd12c873d
e6d2d04121
074bac8447
8e8f2c6d67
938e8447e8
d5d5f0c445
5478571e92
a7835c6716
ad3c7c9bda
415c8fcc3d
718eda1b3e
421b7edeb4
7b68e254c2
7bec2724a5
a27462b708
6bf0b8193a
db428adbb8
fe5b9bb21b
6ec71d8fb6
44b466eeb2
a25f3f8260
dd93e1af85
d2ee599dcf
6ed8898590
5cfc1c39f3
f0ad49ea17
7ba9fa9c7d
8bf11b84c1
470af8ab89
178761aef3
f0c66e6dea
54055a6dae
340448d2d1
ced7d0e53d
a0dba0f8ae
5e20b170a7
d26c18e25c
8d376acc9b
dc1e81f027
5d0279164c
214a7678ea
4892872c18
0b9198bf47
e9e5f61c45
11dde41824
a53d744b01
e82cdb5f24
40b10eee6d
424f648632
2eb1fb3231
0806521642
88738b357b
4e535e6188
40b8fdbdca
d9472e31b7
1d99451ad7
09bb2e30f6
dc264be6ff
fbe7039618
943464ccb8
369de832cd
3457a315b2
ed4e139314
56dc316a57
2fec73eef6
1e7f62cb42
ccb7eb8135
637fd21230
0fe487e732
6bfaa6e282
378d3210dc
97fe45e36d
64a9cc8f05
f50d691254
34c3b68fc8
f33ccd5d27
bc108b9ad6
0c3d27ae42
ef65174df2
42ecb9f138
5c0331fd83
e7019c9455
d98bfe7e70
6747099d71
ccc8c6777b
dbb149e6f7
a807985e59
8643c4d5bf
76014b9ac7
b0c3aba590
19c0c25de8
2f723ac2d6
249fbbe52f
c38680b8a1
16fca86c4a
0f3f9e353d
eceb276901
6bd0a983cd
1861fbdeb5
3b96a93672
e53b3cbd0c
b51e0f397c
b42970063d
493385eb3e
9876c9faa4
4e415029b3
e172f095ba
c001b98087
23fc8e92eb
4059a297a6
66b2539238
ef27d52e79
b2a465296d
5d097277ef
071a9872cb
cc2978039c
e9c7bade80
0bd0454ea7
6097b74894
2c9f7a9e17
01aa788722
ead27aa9fe
b816ff86c9
e5d84fb90b
dd66712e31
f66216e399
f4f0992b6e
1feff61977
5e0b904e88
9bd1a6116c
131f0355a5
17bb5ea679
ce929984a3
4b34930a31
74bd09652d
fb6252d786
c794fef2f2
00ebda8cc4
d14ce75b95
2d6eac9084
3ed7ad3ab3
6d1103048e
0ff28758b3
d3e9ca3eda
0fbfcf3c9c
0c220935bd
ffbfe833da
42a14f7f63
f8c3dbe5b5
b078dd157c
2ddacd7516
da0e345200
df94175a0f
61a8825216
a69a1e6e63
021dcf089d
bf24498b1e
95e271d98f
364629b8d6
108fe02165
4561fff36e
50b5962042
457576739f
e27e4a3c1b
088514bbd4
2c8b484643
8294676150
ef378ad673
2d2247e59e
7bf793a600
282bfaaa95
9679f40146
3892c3a703
4e320b8b90
4cd0c73408
eb2b22b042
4ea4d2b189
8d76fa23ef
74b44fdf8f
65b88c544f
a422ba39c9
d2ec22371e
033cec232a
543240fb5f
4bed739259
80c7ce381b
ccfd41c4f0
3e102b7dad
ec46f3286c
5e2e0b46b1
45a13b1dec
5c0b663969
30d7a59ba8
4aeb67ef4c
3ba91634c1
1b7433b71e
a70820daa0
6b45b1d6b4
85ab552028
c3945aaa1d
b3af953a55
3a65093078
ad4e0bf3be
88ab587807
aee28501b5
83f0ec8269
c6b6938b3a
fb4664fcec
20e3593863
63a394068c
ab39e08eb9
11bfa62796
f63e62e546
65b0f329d1
06007c0a18
a8e83a7654
475005504e
2c40c4d35e
e95278932b
9d2a20a763
2e54d72fc3
6b32a2d549
c5cbe4fc2a
f888912870
9e4642e9b3
6b0486c216
d368c039f0
9b54267e69
46bb0169c4
8934324b72
0e886595bf
c62861f4fa
0df1800436
631fecc6d9
4346c2409d
4b037a97dc
5f74d1fd47
4dcf80167a
26a26998fb
9926eae015
8585b7b151
7e34f4fbfa
fe776293f7
d8a5d96b98
757668c42f
96ec8afd09
e093db92c4
a1cda80bcb
642a2496fe
4614fafae0
4100ed7bdd
f52b2615ef
25f9b152f9
6da8b6a879
0daaaef8c9
98272fbd58
b27e8f3f10
45df786f09
daaf42e4a4
2dc60d4620
b5312f30e8
26c2e0bd35
bf920883d5
58b9ec1f6b
7bae7fa5ce
764e199d67
bfce55db3d
bab6f34dc0
0682dae027
1f6986e919
4289c74359
25248f4bd5
e82001c122
a7e63b82be
b70fc4d51e
e2252d0fc6
cae5d4d4ea
d80ea37d36
05a01fdecb
8fe6f69f28
1fdb351c37
7a01ad7614
55ab9f371a
fefbf8f74b
b428ddd796
ba7d31240e
d25efe3954
36dfb906bb
a6f0f908b9
3b1ddb2b3a
1579c4f06d
3519dd1c6e
e41c4cbea7
ee048b76d4
af68d60a58
92731dfc6f
21aa666a1e
ee141cc821
55e5776c44
854a9195f3
96a97adf9b
e75c6126e9
cda6f5c66c
1f7de23036
bebb6823c0
31e472baa4
657685e85d
a14912858e
eed11ded30
b42aba40ed
25885e5335
178  .github/workflows/release.yaml  (vendored)

@@ -23,7 +23,7 @@ jobs:
           echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" >>$GITHUB_OUTPUT

   darwin-build:
-    runs-on: macos-13
+    runs-on: macos-13-xlarge
     environment: release
     needs: setup-environment
     strategy:
@@ -54,48 +54,6 @@ jobs:
           name: build-${{ matrix.os }}-${{ matrix.arch }}
           path: dist/*

-  darwin-sign:
-    runs-on: macos-13
-    environment: release
-    needs: darwin-build
-    steps:
-      - uses: actions/checkout@v4
-      - run: |
-          echo $MACOS_SIGNING_KEY | base64 --decode > certificate.p12
-          security create-keychain -p password build.keychain
-          security default-keychain -s build.keychain
-          security unlock-keychain -p password build.keychain
-          security import certificate.p12 -k build.keychain -P $MACOS_SIGNING_KEY_PASSWORD -T /usr/bin/codesign
-          security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k password build.keychain
-          security set-keychain-settings -lut 3600 build.keychain
-        env:
-          MACOS_SIGNING_KEY: ${{ secrets.MACOS_SIGNING_KEY }}
-          MACOS_SIGNING_KEY_PASSWORD: ${{ secrets.MACOS_SIGNING_KEY_PASSWORD }}
-      - uses: actions/download-artifact@v4
-        with:
-          name: build-darwin-amd64
-          path: dist/darwin-amd64
-      - uses: actions/download-artifact@v4
-        with:
-          name: build-darwin-arm64
-          path: dist/darwin-arm64
-      - run: |
-          export VERSION=${GITHUB_REF_NAME#v}
-          ./scripts/build_darwin.sh sign macapp
-        env:
-          APPLE_IDENTITY: ${{ secrets.APPLE_IDENTITY }}
-          APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }}
-          APPLE_TEAM_ID: ${{ vars.APPLE_TEAM_ID }}
-          APPLE_ID: ${{ vars.APPLE_ID }}
-          SDKROOT: /Applications/Xcode_14.1.0.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
-          DEVELOPER_DIR: /Applications/Xcode_14.1.0.app/Contents/Developer
-      - uses: actions/upload-artifact@v4
-        with:
-          name: dist-darwin
-          path: |
-            dist/Ollama-darwin.zip
-            dist/ollama-darwin.tgz
-
   windows-depends:
     strategy:
       matrix:
@@ -103,21 +61,18 @@ jobs:
         arch: [amd64]
         preset: ['CPU']
         include:
-          - os: windows
-            arch: amd64
-            preset: 'CUDA 11'
-            install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
-            cuda-version: '11.3'
           - os: windows
             arch: amd64
             preset: 'CUDA 12'
             install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
             cuda-version: '12.8'
+            flags: ''
           - os: windows
             arch: amd64
             preset: 'ROCm 6'
            install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
             rocm-version: '6.2'
+            flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
     runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
     environment: release
     env:
@@ -160,6 +115,9 @@ jobs:
           echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
           echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
           echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
+          echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
+          echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
+          echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
       - if: matrix.preset == 'CPU'
         run: |
           echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
@@ -178,9 +136,9 @@ jobs:
           key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}
       - name: Build target "${{ matrix.preset }}"
         run: |
-          Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
-          Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
-          cmake --preset "${{ matrix.preset }}"
+          Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
+          Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
+          cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
           cmake --build --parallel --preset "${{ matrix.preset }}"
           cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8
         env:
@@ -230,61 +188,11 @@ jobs:
           go-version-file: go.mod
       - run: |
           go build -o dist/${{ matrix.os }}-${{ matrix.arch }}/ .
-      - if: matrix.arch == 'arm64'
-        run: |
-          Invoke-WebRequest -Uri "https://aka.ms/vs/17/release/vc_redist.arm64.exe" -OutFile "dist\windows-arm64\vc_redist.arm64.exe"
-      - run: |
-          $env:VERSION='${{ github.ref_name }}' -Replace "v(.*)", '$1'
-          & .\scripts\build_windows.ps1 buildApp
-        env:
-          VCToolsRedistDir: stub
       - uses: actions/upload-artifact@v4
         with:
           name: build-${{ matrix.os }}-${{ matrix.arch }}
           path: |
             dist\${{ matrix.os }}-${{ matrix.arch }}\*.exe
-            dist\${{ matrix.os }}-${{ matrix.arch }}-app.exe

-  windows-sign:
-    runs-on: windows-2022
-    environment: release
-    needs: [windows-depends, windows-build]
-    steps:
-      - uses: actions/checkout@v4
-      - uses: google-github-actions/auth@v2
-        with:
-          project_id: ollama
-          credentials_json: ${{ secrets.GOOGLE_SIGNING_CREDENTIALS }}
-      - run: |
-          $ErrorActionPreference = "Stop"
-          Invoke-WebRequest -Uri "https://go.microsoft.com/fwlink/p/?LinkId=323507" -OutFile "${{ runner.temp }}\sdksetup.exe"
-          Start-Process "${{ runner.temp }}\sdksetup.exe" -ArgumentList @("/q") -NoNewWindow -Wait
-
-          Invoke-WebRequest -Uri "https://github.com/GoogleCloudPlatform/kms-integrations/releases/download/cng-v1.0/kmscng-1.0-windows-amd64.zip" -OutFile "${{ runner.temp }}\plugin.zip"
-          Expand-Archive -Path "${{ runner.temp }}\plugin.zip" -DestinationPath "${{ runner.temp }}\plugin\"
-          & "${{ runner.temp }}\plugin\*\kmscng.msi" /quiet
-
-          echo "${{ vars.OLLAMA_CERT }}" >ollama_inc.crt
-      - uses: actions/download-artifact@v4
-        with:
-          pattern: build-windows-*
-          path: dist\
-          merge-multiple: true
-      - uses: actions/download-artifact@v4
-        with:
-          pattern: depends-windows-amd64-*
-          path: dist\windows-amd64\
-          merge-multiple: true
-      - run: |
-          & .\scripts\build_windows.ps1 gatherDependencies sign buildInstaller distZip
-        env:
-          KEY_CONTAINER: ${{ vars.KEY_CONTAINER }}
-      - uses: actions/upload-artifact@v4
-        with:
-          name: dist-windows
-          path: |
-            dist\OllamaSetup.exe
-            dist\ollama-windows-*.zip
-
   linux-build:
     strategy:
@@ -317,21 +225,26 @@ jobs:
            CGO_CFLAGS=${{ env.CGO_CFLAGS }}
            CGO_CXXFLAGS=${{ env.CGO_CXXFLAGS }}
          outputs: type=local,dest=dist/${{ matrix.os }}-${{ matrix.arch }}
-          cache-from: type=registry,ref=ollama/ollama:latest
+          cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
          cache-to: type=inline
      - run: |
          for COMPONENT in bin/* lib/ollama/*; do
            case "$COMPONENT" in
              bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
-              lib/ollama/*.so) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
-              lib/ollama/cuda_v11) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
-              lib/ollama/cuda_v12) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
+              lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
+              lib/ollama/cuda_sbsa) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
              lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
              lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
              lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
            esac
          done
        working-directory: dist/${{ matrix.os }}-${{ matrix.arch }}
+      - run: |
+          echo "Manifests"
+          for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in ; do
+            echo $ARCHIVE
+            cat $ARCHIVE
+          done
      - run: |
          for ARCHIVE in dist/${{ matrix.os }}-${{ matrix.arch }}/*.tar.in; do
            tar c -C dist/${{ matrix.os }}-${{ matrix.arch }} -T $ARCHIVE --owner 0 --group 0 | pigz -9vc >$(basename ${ARCHIVE//.*/}.tgz);
@@ -385,8 +298,8 @@ jobs:
          context: .
          platforms: ${{ matrix.os }}/${{ matrix.arch }}
          build-args: ${{ matrix.build-args }}
-          outputs: type=image,name=ollama/ollama,push-by-digest=true,name-canonical=true,push=true
-          cache-from: type=registry,ref=ollama/ollama:latest
+          outputs: type=image,name=${{ vars.DOCKER_REPO }},push-by-digest=true,name-canonical=true,push=true
+          cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest
          cache-to: type=inline
      - run: |
          mkdir -p ${{ matrix.os }}-${{ matrix.arch }}
@@ -418,7 +331,7 @@ jobs:
            latest=false
            suffix=${{ matrix.suffix }}
          images: |
-            ollama/ollama
+            ${{ vars.DOCKER_REPO }}
          tags: |
            type=ref,enable=true,priority=600,prefix=pr-,event=pr
            type=semver,pattern={{version}}
@@ -428,40 +341,24 @@ jobs:
          path: ${{ runner.temp }}
          merge-multiple: true
      - run: |
-          docker buildx imagetools create $(echo '${{ steps.metadata.outputs.json }}' | jq -cr '.tags | map("-t", .) | join(" ")') $(cat *-${{ matrix.suffix }}.txt | xargs printf 'ollama/ollama@%s ')
-          docker buildx imagetools inspect ollama/ollama:${{ steps.metadata.outputs.version }}
+          docker buildx imagetools create $(echo '${{ steps.metadata.outputs.json }}' | jq -cr '.tags | map("-t", .) | join(" ")') $(cat *-${{ matrix.suffix }}.txt | xargs printf '${{ vars.DOCKER_REPO }}@%s ')
+          docker buildx imagetools inspect ${{ vars.DOCKER_REPO }}:${{ steps.metadata.outputs.version }}
        working-directory: ${{ runner.temp }}

-  # Aggregate all the assets and ship a release
-  release:
-    needs: [darwin-sign, windows-sign, linux-build]
-    runs-on: linux
+  # Trigger downstream release process
+  trigger:
+    runs-on: ubuntu-latest
    environment: release
+    needs: [darwin-build, windows-build, windows-depends, linux-build]
    permissions:
      contents: write
    env:
      GH_TOKEN: ${{ github.token }}
    steps:
      - uses: actions/checkout@v4
-      - uses: actions/download-artifact@v4
-        with:
-          name: dist-darwin
-          path: dist
-      - uses: actions/download-artifact@v4
-        with:
-          name: dist-windows
-          path: dist
-      - uses: actions/download-artifact@v4
-        with:
-          pattern: dist-linux-*
-          path: dist
-          merge-multiple: true
-      - run: find . -type f -not -name 'sha256sum.txt' | xargs sha256sum | tee sha256sum.txt
-        working-directory: dist
-      - name: Create or update Release
+      - name: Create or update Release for tag
        run: |
          RELEASE_VERSION="$(echo ${GITHUB_REF_NAME} | cut -f1 -d-)"

          echo "Looking for existing release for ${RELEASE_VERSION}"
          OLD_TAG=$(gh release ls --json name,tagName | jq -r ".[] | select(.name == \"${RELEASE_VERSION}\") | .tagName")
          if [ -n "$OLD_TAG" ]; then
@@ -475,5 +372,12 @@ jobs:
            --generate-notes \
            --prerelease
          fi
-          echo "Uploading artifacts for tag ${GITHUB_REF_NAME}"
-          gh release upload ${GITHUB_REF_NAME} dist/* --clobber
+      - name: Trigger downstream release process
+        run: |
+          curl -L \
+            -X POST \
+            -H "Accept: application/vnd.github+json" \
+            -H "Authorization: Bearer ${{ secrets.RELEASE_TOKEN }}" \
+            -H "X-GitHub-Api-Version: 2022-11-28" \
+            https://api.github.com/repos/ollama/${{ vars.RELEASE_REPO }}/dispatches \
+            -d "{\"event_type\": \"trigger-workflow\", \"client_payload\": {\"run_id\": \"${GITHUB_RUN_ID}\", \"version\": \"${GITHUB_REF_NAME#v}\", \"origin\": \"${GITHUB_REPOSITORY}\", \"publish\": \"1\"}}"
19  .github/workflows/test.yaml  (vendored)

@@ -36,7 +36,7 @@ jobs:
              | xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))"
          }

-          echo changed=$(changed 'llama/llama.cpp/**' 'ml/backend/ggml/ggml/**') | tee -a $GITHUB_OUTPUT
+          echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT

  linux:
    needs: [changes]
@@ -46,7 +46,7 @@ jobs:
        include:
          - preset: CPU
          - preset: CUDA
-            container: nvidia/cuda:11.8.0-devel-ubuntu22.04
+            container: nvidia/cuda:12.8.1-devel-ubuntu22.04
            flags: '-DCMAKE_CUDA_ARCHITECTURES=87'
          - preset: ROCm
            container: rocm/dev-ubuntu-22.04:6.1.2
@@ -78,11 +78,11 @@ jobs:
        include:
          - preset: CPU
          - preset: CUDA
-            install: https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
+            install: https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_571.96_windows.exe
            flags: '-DCMAKE_CUDA_ARCHITECTURES=80'
          - preset: ROCm
            install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
-            flags: '-DAMDGPU_TARGETS=gfx1010'
+            flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
    runs-on: windows
    steps:
      - run: |
@@ -102,7 +102,7 @@ jobs:
          $ErrorActionPreference = "Stop"
          if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
            Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
-            Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_11.3", "nvcc_11.3", "cublas_11.3", "cublas_dev_11.3")) -NoNewWindow -Wait
+            Start-Process -FilePath .\install.exe -ArgumentList (@("-s", "cudart_12.8", "nvcc_12.8", "cublas_12.8", "cublas_dev_12.8")) -NoNewWindow -Wait
          }

          $cudaPath = (Resolve-Path "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*").path
@@ -120,6 +120,9 @@ jobs:
          echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
          echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
          echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
+          echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
+          echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
+          echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
      - if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
        uses: actions/cache/save@v4
        with:
@@ -133,8 +136,8 @@ jobs:
          path: ${{ github.workspace }}\.ccache
          key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
      - run: |
-          Import-Module 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
-          Enter-VsDevShell -VsInstallPath 'C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
+          Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
+          Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
          cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }}
          cmake --build --parallel --preset "${{ matrix.preset }}"
        env:
@@ -237,5 +240,5 @@ jobs:
      - uses: actions/checkout@v4
      - name: Verify patches apply cleanly and do not change files
        run: |
-          make -f Makefile.sync clean sync
+          make -f Makefile.sync clean checkout apply-patches sync
          git diff --compact-summary --exit-code
.golangci.yaml

@@ -19,8 +19,8 @@ linters:
    - nolintlint
    - nosprintfhostport
    - staticcheck
-    - tenv
    - unconvert
+    - usetesting
    - wastedassign
    - whitespace
  disable:
CMakeLists.txt

@@ -3,6 +3,7 @@ cmake_minimum_required(VERSION 3.21)
 project(Ollama C CXX)

 include(CheckLanguage)
+include(GNUInstallDirs)

 find_package(Threads REQUIRED)

@@ -23,6 +24,8 @@ set(GGML_SCHED_MAX_COPIES 4)
 set(GGML_LLAMAFILE ON)
 set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
 set(GGML_CUDA_GRAPHS ON)
+set(GGML_CUDA_FA ON)
+set(GGML_CUDA_COMPRESSION_MODE default)

 if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
     OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+"))
@@ -49,6 +52,8 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
 include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)

+add_compile_definitions(NDEBUG GGML_VERSION=0x0 GGML_COMMIT=0x0)
+
 set(GGML_CPU ON)
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
 set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)
@@ -74,32 +79,16 @@ if(CMAKE_CUDA_COMPILER)

     find_package(CUDAToolkit)
     add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cuda)
-    set(OLLAMA_CUDA_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/cuda_v${CUDAToolkit_VERSION_MAJOR})
     install(TARGETS ggml-cuda
         RUNTIME_DEPENDENCIES
             DIRECTORIES ${CUDAToolkit_BIN_DIR} ${CUDAToolkit_LIBRARY_DIR}
             PRE_INCLUDE_REGEXES cublas cublasLt cudart
             PRE_EXCLUDE_REGEXES ".*"
-        RUNTIME DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
-        LIBRARY DESTINATION ${OLLAMA_CUDA_INSTALL_DIR} COMPONENT CUDA
+        RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
+        LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT CUDA
     )
 endif()

-set(AMDGPU_TARGETS
-    gfx1010:xnack-
-    gfx1011
-    gfx1012:xnack-
-    gfx1030
-    gfx1031
-    gfx1032
-    gfx1034
-    gfx1035
-    gfx1100
-    gfx1101
-    gfx1102
-    gfx1103
-    gfx1150
-    CACHE STRING "List of AMDGPU targets to build for")
-
 set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX ""
     CACHE STRING
@@ -112,7 +101,7 @@ if(CMAKE_HIP_COMPILER)

     find_package(hip REQUIRED)
     if(NOT AMDGPU_TARGETS)
-        list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(803|900(:xnack-)|902|906(:xnack-)|90c(:xnack-)|1010(:xnack-)|1011|1012(:xnack-)|103[0-6]|110[0-3]|1150)$")
+        list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(803|902|906(:xnack-)|90c(:xnack-)|1010(:xnack-)|1011(:xnack-)|1012(:xnack-)|103[0-6]|110[0-3]|115[01]|120[01])$")
     elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
         list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
     endif()
@@ -121,12 +110,18 @@ if(CMAKE_HIP_COMPILER)
     add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)

     if (WIN32)
-        target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY=1)
+        target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
     endif()

+    target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM)
+
     set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
     install(TARGETS ggml-hip
-        RUNTIME_DEPENDENCIES
+        RUNTIME_DEPENDENCY_SET rocm
+        RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
+        LIBRARY DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT HIP
+    )
+    install(RUNTIME_DEPENDENCY_SET rocm
         DIRECTORIES ${HIP_BIN_INSTALL_DIR} ${HIP_LIB_INSTALL_DIR}
         PRE_INCLUDE_REGEXES hipblas rocblas amdhip64 rocsolver amd_comgr hsa-runtime64 rocsparse tinfo rocprofiler-register drm drm_amdgpu numa elf
         PRE_EXCLUDE_REGEXES ".*"
CMakePresets.json

@@ -6,7 +6,8 @@
       "binaryDir": "${sourceDir}/build",
       "installDir": "${sourceDir}/dist",
       "cacheVariables": {
-        "CMAKE_BUILD_TYPE": "Release"
+        "CMAKE_BUILD_TYPE": "Release",
+        "CMAKE_MSVC_RUNTIME_LIBRARY": "MultiThreaded"
       }
     },
     {
@@ -17,18 +18,12 @@
       "name": "CUDA",
       "inherits": [ "Default" ]
     },
-    {
-      "name": "CUDA 11",
-      "inherits": [ "CUDA" ],
-      "cacheVariables": {
-        "CMAKE_CUDA_ARCHITECTURES": "50;52;53;60;61;70;75;80;86"
-      }
-    },
     {
       "name": "CUDA 12",
       "inherits": [ "CUDA" ],
       "cacheVariables": {
-        "CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;100"
+        "CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120",
+        "CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
       }
     },
     {
@@ -56,7 +51,8 @@
       "name": "ROCm 6",
       "inherits": [ "ROCm" ],
       "cacheVariables": {
-        "AMDGPU_TARGETS": "gfx803;gfx902;gfx1011;gfx1030;gfx1031;gfx1032;gfx1034;gfx1035;gfx1036;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx900:xnack-;gfx906:xnack-;gfx90c:xnack-;gfx1010:xnack-;gfx1012:xnack-;"
+        "CMAKE_HIP_FLAGS": "-parallel-jobs=4",
+        "AMDGPU_TARGETS": "gfx803;gfx902;gfx1030;gfx1031;gfx1032;gfx1034;gfx1035;gfx1036;gfx1100;gfx1101;gfx1102;gfx1103;gfx1150;gfx1151;gfx1200;gfx1201;gfx900:xnack-;gfx906:xnack-;gfx90c:xnack-;gfx1010:xnack-;gfx1011:xnack-;gfx1012:xnack-;"
       }
     }
   ],
@@ -76,11 +72,6 @@
       "configurePreset": "CUDA",
       "targets": [ "ggml-cuda" ]
     },
-    {
-      "name": "CUDA 11",
-      "inherits": [ "CUDA" ],
-      "configurePreset": "CUDA 11"
-    },
     {
       "name": "CUDA 12",
       "inherits": [ "CUDA" ],
CONTRIBUTING.md

@@ -51,7 +51,7 @@ see if the change were accepted.

 The title should look like:

     <package>: <short description>

 The package is the most affected Go package. If the change does not affect Go
 code, then use the directory name instead. Changes to a single well-known
@@ -65,7 +65,8 @@ continuation of the sentence:
 Examples:

     llm/backend/mlx: support the llama architecture
-    CONTRIBUTING: provide clairity on good commit messages, and bad
+    CONTRIBUTING: provide clarity on good commit messages, and bad
+    docs: simplify manual installation with shorter curl commands

 Bad Examples:

37  Dockerfile

@@ -7,12 +7,13 @@ ARG JETPACK5VERSION=r35.4.1
 ARG JETPACK6VERSION=r36.4.0
 ARG CMAKEVERSION=3.31.2

-# CUDA v11 requires gcc v10. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
+# We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version
 FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64
 RUN yum install -y yum-utils \
     && yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
     && rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
-    && dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 \
+    && dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
+    && dnf install -y ccache \
     && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
 ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH

@@ -38,15 +39,6 @@ RUN --mount=type=cache,target=/root/.ccache \
     && cmake --build --parallel --preset 'CPU' \
     && cmake --install build --component CPU --strip --parallel 8

-FROM base AS cuda-11
-ARG CUDA11VERSION=11.3
-RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
-ENV PATH=/usr/local/cuda-11/bin:$PATH
-RUN --mount=type=cache,target=/root/.ccache \
-    cmake --preset 'CUDA 11' \
-    && cmake --build --parallel --preset 'CUDA 11' \
-    && cmake --install build --component CUDA --strip --parallel 8
-
 FROM base AS cuda-12
 ARG CUDA12VERSION=12.8
 RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
@@ -86,34 +78,35 @@ RUN --mount=type=cache,target=/root/.ccache \
     && cmake --install build --component CUDA --strip --parallel 8

 FROM base AS build
-ARG GOVERSION=1.23.4
-RUN curl -fsSL https://golang.org/dl/go${GOVERSION}.linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
-ENV PATH=/usr/local/go/bin:$PATH
 WORKDIR /go/src/github.com/ollama/ollama
+COPY go.mod go.sum .
+RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
+ENV PATH=/usr/local/go/bin:$PATH
+RUN go mod download
 COPY . .
 ARG GOFLAGS="'-ldflags=-w -s'"
 ENV CGO_ENABLED=1
+ARG CGO_CFLAGS
+ARG CGO_CXXFLAGS
 RUN --mount=type=cache,target=/root/.cache/go-build \
     go build -trimpath -buildmode=pie -o /bin/ollama .

 FROM --platform=linux/amd64 scratch AS amd64
-COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
-COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
+COPY --from=cuda-12 dist/lib/ollama /lib/ollama

 FROM --platform=linux/arm64 scratch AS arm64
-COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
-COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
-COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 lib/ollama/cuda_jetpack5
-COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 lib/ollama/cuda_jetpack6
+COPY --from=cuda-12 dist/lib/ollama /lib/ollama/cuda_sbsa
+COPY --from=jetpack-5 dist/lib/ollama /lib/ollama/cuda_jetpack5
+COPY --from=jetpack-6 dist/lib/ollama /lib/ollama/cuda_jetpack6

 FROM scratch AS rocm
-COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm
+COPY --from=rocm-6 dist/lib/ollama /lib/ollama

 FROM ${FLAVOR} AS archive
 COPY --from=cpu dist/lib/ollama /lib/ollama
 COPY --from=build /bin/ollama /bin/ollama

-FROM ubuntu:20.04
+FROM ubuntu:24.04
 RUN apt-get update \
     && apt-get install -y ca-certificates \
     && apt-get clean \
Makefile.sync

@@ -1,6 +1,6 @@
-UPSTREAM=https://github.com/ggerganov/llama.cpp.git
+UPSTREAM=https://github.com/ggml-org/llama.cpp.git
 WORKDIR=llama/vendor
-FETCH_HEAD=d7cfe1ffe0f435d0048a6058d529daf76e072d9c
+FETCH_HEAD=e54d41befcc1575f4c898c5ff4ef43970cead75f

 .PHONY: help
 help:
@@ -12,31 +12,42 @@ help:
 	@echo "  clean    Clean local repository"
 	@echo
 	@echo "Example:"
-	@echo "  make -f $(lastword $(MAKEFILE_LIST)) clean sync"
+	@echo "  make -f $(lastword $(MAKEFILE_LIST)) clean apply-patches sync"

 .PHONY: sync
-sync: llama/build-info.cpp llama/llama.cpp ml/backend/ggml/ggml apply-patches
+sync: llama/build-info.cpp ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal

-.PHONY: llama/build-info.cpp
-llama/build-info.cpp: llama/build-info.cpp.in
-	sed -e 's|@FETCH_HEAD@|$(FETCH_HEAD)|' $< > $@
+llama/build-info.cpp: llama/build-info.cpp.in llama/llama.cpp
+	sed -e 's|@FETCH_HEAD@|$(FETCH_HEAD)|' <$< >$@
+
+ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal: ml/backend/ggml/ggml
+	go generate ./$(@D)

 .PHONY: llama/llama.cpp
-llama/llama.cpp: llama/vendor/ apply-patches
-	rsync -arvzc -f "merge $@/.rsync-filter" $< $@
+llama/llama.cpp: llama/vendor
+	rsync -arvzc --delete -f "include LICENSE" -f "merge $@/.rsync-filter" $(addprefix $<,/LICENSE /) $@

-.PHONY: ml/backend/ggml/ggml apply-patches
-ml/backend/ggml/ggml: llama/vendor/ggml/ apply-patches
-	rsync -arvzc -f "merge $@/.rsync-filter" $< $@
+.PHONY: ml/backend/ggml/ggml
+ml/backend/ggml/ggml: llama/vendor
+	rsync -arvzc --delete -f "include LICENSE" -f "merge $@/.rsync-filter" $(addprefix $<,/LICENSE /ggml/) $@

 PATCHES=$(wildcard llama/patches/*.patch)
+PATCHED=$(join $(dir $(PATCHES)), $(addsuffix ed, $(addprefix ., $(notdir $(PATCHES)))))

 .PHONY: apply-patches
 .NOTPARALLEL:
-apply-patches: $(addsuffix ed, $(PATCHES))
+apply-patches: $(PATCHED)

-%.patched: %.patch
-	@if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then touch $@; else git -C $(WORKDIR) am --abort; exit 1; fi
+llama/patches/.%.patched: llama/patches/%.patch
+	@if git -c user.name=nobody -c 'user.email=<>' -C $(WORKDIR) am -3 $(realpath $<); then \
+		touch $@; \
+	else \
+		echo "Patch failed. Resolve any conflicts then continue."; \
+		echo "1. Run 'git -C $(WORKDIR) am --continue'"; \
+		echo "2. Run 'make -f $(lastword $(MAKEFILE_LIST)) format-patches'"; \
+		echo "3. Run 'make -f $(lastword $(MAKEFILE_LIST)) clean apply-patches'"; \
+		exit 1; \
+	fi

 .PHONY: checkout
 checkout: $(WORKDIR)
@@ -57,4 +68,5 @@ format-patches: llama/patches

 .PHONE: clean
 clean: checkout
-	$(RM) $(addsuffix ed, $(PATCHES))
+	@git -C $(WORKDIR) am --abort || true
+	$(RM) llama/patches/.*.patched
107
README.md
107
README.md
@@ -1,6 +1,6 @@
|
|||||||
<div align="center">
|
<div align="center">
|
||||||
<a href="https://ollama.com" />
|
<a href="https://ollama.com">
|
||||||
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
|
<img alt="ollama" width="240" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
|
||||||
</a>
|
</a>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -10,7 +10,7 @@ Get up and running with large language models.
|
|||||||
|
|
||||||
### macOS
|
### macOS
|
||||||
|
|
||||||
[Download](https://ollama.com/download/Ollama-darwin.zip)
|
[Download](https://ollama.com/download/Ollama.dmg)
|
||||||
|
|
||||||
### Windows
|
### Windows
|
||||||
|
|
||||||
@@ -26,7 +26,7 @@ Please download from ollama [official](https://ollama.com/download/OllamaSetup.e
|
|||||||
|
|
||||||
Example extra list add on this repo.
|
Example extra list add on this repo.
|
||||||
```
|
```
|
||||||
"gfx803" "gfx900:xnack-" "gfx902" gfx906:xnack- "gfx1010:xnack-" "gfx1011" "gfx1012:xnack-" "gfx1031" "gfx1032" "gfx1034" "gfx1035" "gfx1036" "gfx1103" "gfx1150(expertimental)"...
|
(ROCm5) "gfx803" "gfx900:xnack-" "gfx902" (ROCm6) gfx906:xnack- "gfx1010:xnack-" "gfx1011" "gfx1012:xnack-" "gfx1031" "gfx1032" "gfx1034" "gfx1035" "gfx1036" "gfx1103" "gfx1150" "gfx1201" (expertimental)"...
|
||||||
```
|
```
|
||||||
Please follow the [wiki](https://github.com/likelovewant/ollama-for-amd/wiki) guide to build or use the pre-release version.
|
Please follow the [wiki](https://github.com/likelovewant/ollama-for-amd/wiki) guide to build or use the pre-release version.
|
||||||
|
|
||||||
@@ -62,10 +62,10 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
 
 ## Quickstart
 
-To run and chat with [Llama 3.2](https://ollama.com/library/llama3.2):
+To run and chat with [Gemma 3](https://ollama.com/library/gemma3):
 
 ```shell
-ollama run llama3.2
+ollama run gemma3
 ```
 
 ## Model library
@@ -76,8 +76,15 @@ Here are some example models that can be downloaded:
 
 | Model | Parameters | Size | Download |
 | ------------------ | ---------- | ----- | -------------------------------- |
+| Gemma 3 | 1B | 815MB | `ollama run gemma3:1b` |
+| Gemma 3 | 4B | 3.3GB | `ollama run gemma3` |
+| Gemma 3 | 12B | 8.1GB | `ollama run gemma3:12b` |
+| Gemma 3 | 27B | 17GB | `ollama run gemma3:27b` |
+| QwQ | 32B | 20GB | `ollama run qwq` |
 | DeepSeek-R1 | 7B | 4.7GB | `ollama run deepseek-r1` |
 | DeepSeek-R1 | 671B | 404GB | `ollama run deepseek-r1:671b` |
+| Llama 4 | 109B | 67GB | `ollama run llama4:scout` |
+| Llama 4 | 400B | 245GB | `ollama run llama4:maverick` |
 | Llama 3.3 | 70B | 43GB | `ollama run llama3.3` |
 | Llama 3.2 | 3B | 2.0GB | `ollama run llama3.2` |
 | Llama 3.2 | 1B | 1.3GB | `ollama run llama3.2:1b` |
@@ -86,10 +93,7 @@ Here are some example models that can be downloaded:
 | Llama 3.1 | 8B | 4.7GB | `ollama run llama3.1` |
 | Llama 3.1 | 405B | 231GB | `ollama run llama3.1:405b` |
 | Phi 4 | 14B | 9.1GB | `ollama run phi4` |
-| Phi 3 Mini | 3.8B | 2.3GB | `ollama run phi3` |
+| Phi 4 Mini | 3.8B | 2.5GB | `ollama run phi4-mini` |
-| Gemma 2 | 2B | 1.6GB | `ollama run gemma2:2b` |
-| Gemma 2 | 9B | 5.5GB | `ollama run gemma2` |
-| Gemma 2 | 27B | 16GB | `ollama run gemma2:27b` |
 | Mistral | 7B | 4.1GB | `ollama run mistral` |
 | Moondream 2 | 1.4B | 829MB | `ollama run moondream` |
 | Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
@@ -97,7 +101,7 @@ Here are some example models that can be downloaded:
 | Code Llama | 7B | 3.8GB | `ollama run codellama` |
 | Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
 | LLaVA | 7B | 4.5GB | `ollama run llava` |
-| Solar | 10.7B | 6.1GB | `ollama run solar` |
+| Granite-3.3 | 8B | 4.9GB | `ollama run granite3.3` |
 
 > [!NOTE]
 > You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
@@ -297,6 +301,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 ### Web & Desktop
 
 - [Open WebUI](https://github.com/open-webui/open-webui)
+- [SwiftChat (macOS with ReactNative)](https://github.com/aws-samples/swift-chat)
 - [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
 - [Hollama](https://github.com/fmaclen/hollama)
 - [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
@@ -304,12 +309,13 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
 - [HTML UI](https://github.com/rtcfirefly/ollama-ui)
 - [Saddle](https://github.com/jikkuatwork/saddle)
+- [TagSpaces](https://www.tagspaces.org) (A platform for file-based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
 - [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
 - [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
 - [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
 - [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
 - [Ollamac](https://github.com/kevinhermawan/Ollamac)
-- [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md)
+- [big-AGI](https://github.com/enricoros/big-AGI)
 - [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
 - [Amica](https://github.com/semperai/amica)
 - [chatd](https://github.com/BruceMacD/chatd)
@@ -330,6 +336,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
 - [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
 - [IntelliBar](https://intellibar.app/) (AI-powered assistant for macOS)
+- [Jirapt](https://github.com/AliAhmedNada/jirapt) (Jira Integration to generate issues, tasks, epics)
+- [ojira](https://github.com/AliAhmedNada/ojira) (Jira chrome plugin to easily generate descriptions for tasks)
 - [QA-Pilot](https://github.com/reid41/QA-Pilot) (Interactive chat tool that can leverage Ollama models for rapid understanding and navigation of GitHub code repositories)
 - [ChatOllama](https://github.com/sugarforever/chat-ollama) (Open Source Chatbot based on Ollama with Knowledge Bases)
 - [CRAG Ollama Chat](https://github.com/Nagi-ovo/CRAG-Ollama-Chat) (Simple Web Search with Corrective RAG)
@@ -343,13 +351,14 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [RWKV-Runner](https://github.com/josStorer/RWKV-Runner) (RWKV offline LLM deployment tool, also usable as a client for ChatGPT and Ollama)
 - [Ollama Grid Search](https://github.com/dezoito/ollama-grid-search) (app to evaluate and compare models)
 - [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
+- [Casibase](https://casibase.org) (An open source AI knowledge base and dialogue system combining the latest RAG, SSO, ollama support, and multiple large language models.)
 - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
 - [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
 - [Shinkai Desktop](https://github.com/dcSpark/shinkai-apps) (Two click install Local AI using Ollama + Files + RAG)
-- [AiLama](https://github.com/zeyoyt/ailama) (A Discord User App that allows you to interact with Ollama anywhere in discord )
+- [AiLama](https://github.com/zeyoyt/ailama) (A Discord User App that allows you to interact with Ollama anywhere in Discord)
 - [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
 - [R2R](https://github.com/SciPhi-AI/R2R) (Open-source RAG engine)
-- [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple easy to use GUI with sample custom LLM for Drivers Education)
+- [Ollama-Kis](https://github.com/elearningshow/ollama-kis) (A simple easy-to-use GUI with sample custom LLM for Drivers Education)
 - [OpenGPA](https://opengpa.org) (Open-source offline-first Enterprise Agentic Application)
 - [Painting Droid](https://github.com/mateuszmigas/painting-droid) (Painting app with AI integrations)
 - [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
@@ -358,22 +367,22 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [LLMStack](https://github.com/trypromptly/LLMStack) (No-code multi-agent framework to build LLM agents and workflows)
 - [BoltAI for Mac](https://boltai.com) (AI Chat Client for Mac)
 - [Harbor](https://github.com/av/harbor) (Containerized LLM Toolkit with Ollama as default backend)
-- [PyGPT](https://github.com/szczyglis-dev/py-gpt) (AI desktop assistant for Linux, Windows and Mac)
+- [PyGPT](https://github.com/szczyglis-dev/py-gpt) (AI desktop assistant for Linux, Windows, and Mac)
-- [Alpaca](https://github.com/Jeffser/Alpaca) (An Ollama client application for linux and macos made with GTK4 and Adwaita)
+- [Alpaca](https://github.com/Jeffser/Alpaca) (An Ollama client application for Linux and macOS made with GTK4 and Adwaita)
 - [AutoGPT](https://github.com/Significant-Gravitas/AutoGPT/blob/master/docs/content/platform/ollama.md) (AutoGPT Ollama integration)
 - [Go-CREW](https://www.jonathanhecl.com/go-crew/) (Powerful Offline RAG in Golang)
 - [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
-- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot and Ollama4j
+- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j
 - [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
-- [Claude Dev](https://github.com/saoudrizwan/claude-dev) - VSCode extension for multi-file/whole-repo coding
+- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
 - [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
-- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
+- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
 - [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
 - [crewAI with Mesop](https://github.com/rapidarchitect/ollama-crew-mesop) (Mesop Web Interface to run crewAI with Ollama)
 - [Tkinter-based client](https://github.com/chyok/ollama-gui) (Python tkinter-based Client for Ollama)
 - [LLMChat](https://github.com/trendy-design/llmchat) (Privacy focused, 100% local, intuitive all-in-one chat interface)
 - [Local Multimodal AI Chat](https://github.com/Leon-Sander/Local-Multimodal-AI-Chat) (Ollama-based LLM Chat with support for multiple features, including PDF RAG, voice chat, image-based interactions, and integration with OpenAI.)
-- [ARGO](https://github.com/xark-argo/argo) (Locally download and run Ollama and Huggingface models with RAG on Mac/Windows/Linux)
+- [ARGO](https://github.com/xark-argo/argo) (Locally download and run Ollama and Huggingface models with RAG and deep research on Mac/Windows/Linux)
 - [OrionChat](https://github.com/EliasPereirah/OrionChat) - OrionChat is a web interface for chatting with different AI providers
 - [G1](https://github.com/bklieger-groq/g1) (Prototype of using prompting strategies to improve the LLM's reasoning through o1-like reasoning chains.)
 - [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
@@ -385,7 +394,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
 - [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
 - [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
-- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings)
+- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard, and said in the meetings)
 - [Hexabot](https://github.com/hexastack/hexabot) (A conversational AI builder)
 - [Reddit Rate](https://github.com/rapidarchitect/reddit_analyzer) (Search and Rate Reddit topics with a weighted summation)
 - [OpenTalkGpt](https://github.com/adarshM84/OpenTalkGpt) (Chrome Extension to manage open-source models supported by Ollama, create custom models, and chat with models from a user-friendly UI)
@@ -403,11 +412,29 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints)
 - [LocalLLM](https://github.com/qusaismael/localllm) (Minimal Web-App to run ollama models on it with a GUI)
 - [Ollamazing](https://github.com/buiducnhat/ollamazing) (Web extension to run Ollama models)
-- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivent endpoint with Ollama support for running locally)
+- [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivalent endpoint with Ollama support for running locally)
 - [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & Adaptable RAG Chatbot)
 - [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot)
 - [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
 - [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms)
+- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
+- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
+- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
+- [Flufy](https://github.com/Aharon-Bensadoun/Flufy) (A beautiful chat interface for interacting with Ollama's API. Built with React, TypeScript, and Material-UI.)
+- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
+- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
+- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
+- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
+- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
+- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
+- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
+- [macLlama (macOS native)](https://github.com/hellotunamayo/macLlama) (A native macOS GUI application for interacting with Ollama models, featuring a chat interface.)
+- [GPTranslate](https://github.com/philberndt/GPTranslate) (A fast and lightweight, AI powered desktop translation application written with Rust and Tauri. Features real-time translation with OpenAI/Azure/Ollama.)
+- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
+- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
+- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
+- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
+- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
 
 ### Cloud
 
@@ -447,10 +474,17 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [SwollamaCLI](https://github.com/marcusziade/Swollama) bundled with the Swollama Swift package. [Demo](https://github.com/marcusziade/Swollama?tab=readme-ov-file#cli-usage)
 - [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more.
 - [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
+- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
 - [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
+- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull, and download models from Ollama Registry in your terminal.
+- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)
+- [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) - AWS Strands Agents with Ollama Examples
+- [ollama-multirun](https://github.com/attogram/ollama-multirun) - A bash shell script to run a single prompt against any or all of your locally installed ollama models, saving the output and performance statistics as easily navigable web pages. ([Demo](https://attogram.github.io/ai_test_zone/))
+- [ollama-bash-toolshed](https://github.com/attogram/ollama-bash-toolshed) - Bash scripts to chat with tool using models. Add new tools to your shed with ease. Runs on Ollama.
 
 ### Apple Vision Pro
 
+- [SwiftChat](https://github.com/aws-samples/swift-chat) (Cross-platform AI chat app supporting Apple Vision Pro via "Designed for iPad")
 - [Enchanted](https://github.com/AugustDev/enchanted)
 
 ### Database
@@ -473,7 +507,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 
 ### Libraries
 
-- [LangChain](https://python.langchain.com/docs/integrations/llms/ollama) and [LangChain.js](https://js.langchain.com/docs/integrations/chat/ollama/) with [example](https://js.langchain.com/docs/tutorials/local_rag/)
+- [LangChain](https://python.langchain.com/docs/integrations/chat/ollama/) and [LangChain.js](https://js.langchain.com/docs/integrations/chat/ollama/) with [example](https://js.langchain.com/docs/tutorials/local_rag/)
 - [Firebase Genkit](https://firebase.google.com/docs/genkit/plugins/ollama)
 - [crewAI](https://github.com/crewAIInc/crewAI)
 - [Yacana](https://remembersoftwares.github.io/yacana/) (User-friendly multi-agent framework for brainstorming and executing predetermined flows with built-in tool integration)
@@ -520,18 +554,26 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Swollama for Swift](https://github.com/marcusziade/Swollama) with [DocC](https://marcusziade.github.io/Swollama/documentation/swollama/)
 - [GoLamify](https://github.com/prasad89/golamify)
 - [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
-- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in unified API)
+- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API)
 - [LlmTornado](https://github.com/lofcz/llmtornado) (C# library providing a unified interface for major FOSS & Commercial inference APIs)
 - [Ollama for Zig](https://github.com/dravenk/ollama-zig)
 - [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
 - [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
+- [Ollama for D](https://github.com/kassane/ollama-d)
+- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)
+- [any-llm](https://github.com/mozilla-ai/any-llm) (A single interface to use different llm providers by [mozilla.ai](https://www.mozilla.ai/))
+- [any-agent](https://github.com/mozilla-ai/any-agent) (A single interface to use and evaluate different agent frameworks by [mozilla.ai](https://www.mozilla.ai/))
+- [Neuro SAN](https://github.com/cognizant-ai-lab/neuro-san-studio) (Data-driven multi-agent orchestration framework) with [example](https://github.com/cognizant-ai-lab/neuro-san-studio/blob/main/docs/user_guide.md#ollama)
 
 ### Mobile
 
+- [SwiftChat](https://github.com/aws-samples/swift-chat) (Lightning-fast Cross-platform AI chat app with native UI for Android, iOS, and iPad)
 - [Enchanted](https://github.com/AugustDev/enchanted)
 - [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
 - [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
-- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
+- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
+- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device)
+- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
 
 ### Extensions & Plugins
 
@@ -553,7 +595,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Obsidian Local GPT plugin](https://github.com/pfrankov/obsidian-local-gpt)
 - [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
 - [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
-- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use ollama as a copilot like Github copilot)
+- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use Ollama as a copilot like GitHub Copilot)
 - [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
 - [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and Hugging Face)
 - [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
@@ -563,8 +605,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
 - [ChatGPTBox: All in one browser extension](https://github.com/josStorer/chatGPTBox) with [Integrating Tutorial](https://github.com/josStorer/chatGPTBox/issues/616#issuecomment-1975186467)
 - [Discord AI chat/moderation bot](https://github.com/rapmd73/Companion) Chat/moderation bot written in python. Uses Ollama to create personalities.
-- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depends on ollama server)
+- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depend on ollama server)
-- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front end Open WebUI service.)
+- [Terraform AWS Ollama & Open WebUI](https://github.com/xuyangbocn/terraform-aws-self-host-llm) (A Terraform module to deploy on AWS a ready-to-use Ollama service, together with its front-end Open WebUI service.)
 - [node-red-contrib-ollama](https://github.com/jakubburkiewicz/node-red-contrib-ollama)
 - [Local AI Helper](https://github.com/ivostoykov/localAI) (Chrome and Firefox extensions that enable interactions with the active tab and customisable API endpoints. Includes secure storage for user prompts.)
 - [vnc-lm](https://github.com/jake83741/vnc-lm) (Discord bot for messaging with LLMs through Ollama and LiteLLM. Seamlessly move between local and flagship models.)
@@ -577,12 +619,19 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language
 - [Simple-Discord-AI](https://github.com/zyphixor/simple-discord-ai)
 - [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
+- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
+- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
+- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
+- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
+- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
+- [NOMYO Router](https://github.com/nomyo-ai/nomyo-router) (A transparent Ollama proxy with model deployment aware routing which auto-manages multiple Ollama instances in a given network)
 
 ### Supported backends
 
-- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
+- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
 
 ### Observability
+- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native intergration to Ollama.
 - [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
 - [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
 - [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.

api/client.go

@@ -24,7 +24,10 @@ import (
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/auth"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
"github.com/ollama/ollama/version"
|
"github.com/ollama/ollama/version"
|
||||||
@@ -76,6 +79,14 @@ func NewClient(base *url.URL, http *http.Client) *Client {
 	}
 }
 
+func getAuthorizationToken(ctx context.Context, challenge string) (string, error) {
+	token, err := auth.Sign(ctx, []byte(challenge))
+	if err != nil {
+		return "", err
+	}
+	return token, nil
+}
+
 func (c *Client) do(ctx context.Context, method, path string, reqData, respData any) error {
 	var reqBody io.Reader
 	var data []byte
@@ -97,6 +108,21 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 	}
 
 	requestURL := c.base.JoinPath(path)
+
+	var token string
+	if envconfig.UseAuth() || c.base.Hostname() == "ollama.com" {
+		now := strconv.FormatInt(time.Now().Unix(), 10)
+		chal := fmt.Sprintf("%s,%s?ts=%s", method, path, now)
+		token, err = getAuthorizationToken(ctx, chal)
+		if err != nil {
+			return err
+		}
+
+		q := requestURL.Query()
+		q.Set("ts", now)
+		requestURL.RawQuery = q.Encode()
+	}
+
 	request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), reqBody)
 	if err != nil {
 		return err
@@ -106,6 +132,10 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 	request.Header.Set("Accept", "application/json")
 	request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
 
+	if token != "" {
+		request.Header.Set("Authorization", token)
+	}
+
 	respObj, err := c.http.Do(request)
 	if err != nil {
 		return err
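Taken together, the hunks above add lightweight request signing: the client builds a challenge of the form `<METHOD>,<PATH>?ts=<unix-seconds>`, signs it with `auth.Sign`, sends the signature in the `Authorization` header, and echoes `ts` as a query parameter so the server can rebuild the identical challenge. A minimal sketch of that flow; the host and `/api/push` path are illustrative assumptions, and only `auth.Sign` and the challenge format come from the diff:

```go
// Sketch of the signing flow the patched client performs for ollama.com (or
// when envconfig.UseAuth() is set). auth.Sign uses the local Ollama key pair.
package main

import (
	"context"
	"fmt"
	"net/http"
	"strconv"
	"time"

	"github.com/ollama/ollama/auth"
)

func signedRequest(ctx context.Context, method, base, path string) (*http.Request, error) {
	now := strconv.FormatInt(time.Now().Unix(), 10)
	// The server reconstructs this exact string from the request line and ts.
	challenge := fmt.Sprintf("%s,%s?ts=%s", method, path, now)

	token, err := auth.Sign(ctx, []byte(challenge))
	if err != nil {
		return nil, err
	}

	req, err := http.NewRequestWithContext(ctx, method, base+path+"?ts="+now, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Authorization", token) // signature travels in the header
	return req, nil
}

func main() {
	// "/api/push" is only an illustrative path; any method/path pair works.
	req, err := signedRequest(context.Background(), http.MethodPost, "https://ollama.com", "/api/push")
	if err != nil {
		panic(err) // auth.Sign requires a local Ollama key pair to exist
	}
	fmt.Println(req.URL.String())
}
```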
@@ -143,6 +173,22 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
 	}
 
 	requestURL := c.base.JoinPath(path)
+
+	var token string
+	if envconfig.UseAuth() || c.base.Hostname() == "ollama.com" {
+		var err error
+		now := strconv.FormatInt(time.Now().Unix(), 10)
+		chal := fmt.Sprintf("%s,%s?ts=%s", method, path, now)
+		token, err = getAuthorizationToken(ctx, chal)
+		if err != nil {
+			return err
+		}
+
+		q := requestURL.Query()
+		q.Set("ts", now)
+		requestURL.RawQuery = q.Encode()
+	}
+
 	request, err := http.NewRequestWithContext(ctx, method, requestURL.String(), buf)
 	if err != nil {
 		return err
@@ -152,6 +198,10 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
 	request.Header.Set("Accept", "application/x-ndjson")
 	request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
 
+	if token != "" {
+		request.Header.Set("Authorization", token)
+	}
+
 	response, err := c.http.Do(request)
 	if err != nil {
 		return err
@@ -172,10 +222,6 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
 			return fmt.Errorf("unmarshal: %w", err)
 		}
 
-		if errorResponse.Error != "" {
-			return errors.New(errorResponse.Error)
-		}
-
 		if response.StatusCode >= http.StatusBadRequest {
 			return StatusError{
 				StatusCode: response.StatusCode,
@@ -184,6 +230,10 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
 			}
 		}
 
+		if errorResponse.Error != "" {
+			return errors.New(errorResponse.Error)
+		}
+
 		if err := fn(bts); err != nil {
 			return err
 		}

api/client_test.go

@@ -1,7 +1,6 @@
 package api
 
 import (
-	"context"
 	"encoding/json"
 	"fmt"
 	"net/http"
@@ -90,6 +89,16 @@ func TestClientStream(t *testing.T) {
 			},
 			wantErr: "mid-stream error",
 		},
+		{
+			name: "http status error takes precedence over general error",
+			responses: []any{
+				testError{
+					message:    "custom error message",
+					statusCode: http.StatusInternalServerError,
+				},
+			},
+			wantErr: "500",
+		},
 		{
 			name: "successful stream completion",
 			responses: []any{
@@ -137,7 +146,7 @@ func TestClientStream(t *testing.T) {
 	client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
 
 	var receivedChunks []ChatResponse
-	err := client.stream(context.Background(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
+	err := client.stream(t.Context(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
 		var resp ChatResponse
 		if err := json.Unmarshal(chunk, &resp); err != nil {
 			return fmt.Errorf("failed to unmarshal chunk: %w", err)
@@ -223,7 +232,7 @@ func TestClientDo(t *testing.T) {
 		ID      string `json:"id"`
 		Success bool   `json:"success"`
 	}
-	err := client.do(context.Background(), http.MethodPost, "/v1/messages", nil, &resp)
+	err := client.do(t.Context(), http.MethodPost, "/v1/messages", nil, &resp)
 
 	if tc.wantErr != "" {
 		if err == nil {

api/types.go (374 changes)

@@ -12,6 +12,7 @@ import (
 	"time"
 
 	"github.com/ollama/ollama/envconfig"
+	"github.com/ollama/ollama/types/model"
 )
 
 // StatusError is an error with an HTTP status code and message.
@@ -75,13 +76,24 @@ type GenerateRequest struct {
 	// this request.
 	KeepAlive *Duration `json:"keep_alive,omitempty"`
 
-	// Images is an optional list of base64-encoded images accompanying this
+	// Images is an optional list of raw image bytes accompanying this
 	// request, for multimodal models.
 	Images []ImageData `json:"images,omitempty"`
 
 	// Options lists model-specific options. For example, temperature can be
 	// set through this field, if the model supports it.
-	Options map[string]interface{} `json:"options"`
+	Options map[string]any `json:"options"`
+
+	// Think controls whether thinking/reasoning models will think before
+	// responding. Can be a boolean (true/false) or a string ("high", "medium", "low")
+	// for supported models. Needs to be a pointer so we can distinguish between false
+	// (request that thinking _not_ be used) and unset (use the old behavior
+	// before this option was introduced)
+	Think *ThinkValue `json:"think,omitempty"`
+
+	// DebugRenderOnly is a debug option that, when set to true, returns the rendered
+	// template instead of calling the model.
+	DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
 }
 
 // ChatRequest describes a request sent by [Client.Chat].
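For callers, the new `Think` field is easiest to see on the wire: `ThinkValue` (defined elsewhere in the package) accepts either a boolean or an effort string, so a request body simply carries `"think": true` or `"think": "high"`. A small sketch against a local server; the model name is illustrative, and `"stream": false` just keeps the example to a single JSON response:

```go
// Sketch of a /api/generate call using the new think option. The reply from a
// thinking-capable model separates reasoning text from the final answer
// (compare the Thinking field added to Message further down).
package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

func main() {
	body, _ := json.Marshal(map[string]any{
		"model":  "deepseek-r1", // illustrative; any thinking-capable model
		"prompt": "Why is the sky blue?",
		"think":  true,  // or "high" / "medium" / "low" where supported
		"stream": false, // single JSON object instead of an ndjson stream
	})
	resp, err := http.Post("http://localhost:11434/api/generate", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	fmt.Println(string(out))
}
```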
@@ -106,7 +118,16 @@ type ChatRequest struct {
 	Tools `json:"tools,omitempty"`
 
 	// Options lists model-specific options.
-	Options map[string]interface{} `json:"options"`
+	Options map[string]any `json:"options"`
+
+	// Think controls whether thinking/reasoning models will think before
+	// responding. Can be a boolean (true/false) or a string ("high", "medium", "low")
+	// for supported models.
+	Think *ThinkValue `json:"think,omitempty"`
+
+	// DebugRenderOnly is a debug option that, when set to true, returns the rendered
+	// template instead of calling the model.
+	DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
 }
 
 type Tools []Tool
@@ -125,10 +146,14 @@ func (t Tool) String() string {
 // role ("system", "user", or "assistant"), the content and an optional list
 // of images.
 type Message struct {
 	Role    string `json:"role"`
 	Content string `json:"content"`
+	// Thinking contains the text that was inside thinking tags in the
+	// original model output when ChatRequest.Think is enabled.
+	Thinking  string      `json:"thinking,omitempty"`
 	Images    []ImageData `json:"images,omitempty"`
 	ToolCalls []ToolCall  `json:"tool_calls,omitempty"`
+	ToolName  string      `json:"tool_name,omitempty"`
 }
 
 func (m *Message) UnmarshalJSON(b []byte) error {
@@ -162,21 +187,122 @@ func (t *ToolCallFunctionArguments) String() string {
 
 type Tool struct {
 	Type     string       `json:"type"`
+	Items    any          `json:"items,omitempty"`
 	Function ToolFunction `json:"function"`
 }
 
+// PropertyType can be either a string or an array of strings
+type PropertyType []string
+
+// UnmarshalJSON implements the json.Unmarshaler interface
+func (pt *PropertyType) UnmarshalJSON(data []byte) error {
+	// Try to unmarshal as a string first
+	var s string
+	if err := json.Unmarshal(data, &s); err == nil {
+		*pt = []string{s}
+		return nil
+	}
+
+	// If that fails, try to unmarshal as an array of strings
+	var a []string
+	if err := json.Unmarshal(data, &a); err != nil {
+		return err
+	}
+	*pt = a
+	return nil
+}
+
+// MarshalJSON implements the json.Marshaler interface
+func (pt PropertyType) MarshalJSON() ([]byte, error) {
+	if len(pt) == 1 {
+		// If there's only one type, marshal as a string
+		return json.Marshal(pt[0])
+	}
+	// Otherwise marshal as an array
+	return json.Marshal([]string(pt))
+}
+
+// String returns a string representation of the PropertyType
+func (pt PropertyType) String() string {
+	if len(pt) == 0 {
+		return ""
+	}
+	if len(pt) == 1 {
+		return pt[0]
+	}
+	return fmt.Sprintf("%v", []string(pt))
+}
+
+type ToolProperty struct {
+	AnyOf       []ToolProperty `json:"anyOf,omitempty"`
+	Type        PropertyType   `json:"type"`
+	Items       any            `json:"items,omitempty"`
+	Description string         `json:"description"`
+	Enum        []any          `json:"enum,omitempty"`
+}
+
+// ToTypeScriptType converts a ToolProperty to a TypeScript type string
+func (tp ToolProperty) ToTypeScriptType() string {
+	if len(tp.AnyOf) > 0 {
+		var types []string
+		for _, anyOf := range tp.AnyOf {
+			types = append(types, anyOf.ToTypeScriptType())
+		}
+		return strings.Join(types, " | ")
+	}
+
+	if len(tp.Type) == 0 {
+		return "any"
+	}
+
+	if len(tp.Type) == 1 {
+		return mapToTypeScriptType(tp.Type[0])
+	}
+
+	var types []string
+	for _, t := range tp.Type {
+		types = append(types, mapToTypeScriptType(t))
+	}
+	return strings.Join(types, " | ")
+}
+
+// mapToTypeScriptType maps JSON Schema types to TypeScript types
+func mapToTypeScriptType(jsonType string) string {
+	switch jsonType {
+	case "string":
+		return "string"
+	case "number", "integer":
+		return "number"
+	case "boolean":
+		return "boolean"
+	case "array":
+		return "any[]"
+	case "object":
+		return "Record<string, any>"
+	case "null":
+		return "null"
+	default:
+		return "any"
+	}
+}
+
+type ToolFunctionParameters struct {
+	Type       string                  `json:"type"`
+	Defs       any                     `json:"$defs,omitempty"`
+	Items      any                     `json:"items,omitempty"`
+	Required   []string                `json:"required"`
+	Properties map[string]ToolProperty `json:"properties"`
+}
+
+func (t *ToolFunctionParameters) String() string {
+	bts, _ := json.Marshal(t)
+	return string(bts)
+}
+
 type ToolFunction struct {
 	Name        string `json:"name"`
 	Description string `json:"description"`
-	Parameters  struct {
-		Type       string   `json:"type"`
-		Required   []string `json:"required"`
-		Properties map[string]struct {
-			Type        string   `json:"type"`
-			Description string   `json:"description"`
-			Enum        []string `json:"enum,omitempty"`
-		} `json:"properties"`
-	} `json:"parameters"`
+	Parameters  ToolFunctionParameters `json:"parameters"`
 }
 
 func (t *ToolFunction) String() string {
@@ -197,6 +323,19 @@ type ChatResponse struct {
|
|||||||
Metrics
|
Metrics
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DebugInfo contains debug information for template rendering
|
||||||
|
type DebugInfo struct {
|
||||||
|
RenderedTemplate string `json:"rendered_template"`
|
||||||
|
ImageCount int `json:"image_count,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DebugTemplateResponse is returned when _debug_render_only is set to true
|
||||||
|
type DebugTemplateResponse struct {
|
||||||
|
Model string `json:"model"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
DebugInfo DebugInfo `json:"_debug_info"`
|
||||||
|
}
|
||||||
|
|
||||||
type Metrics struct {
|
type Metrics struct {
|
||||||
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
||||||
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
||||||
@@ -224,9 +363,6 @@ type Options struct {
|
|||||||
RepeatPenalty float32 `json:"repeat_penalty,omitempty"`
|
RepeatPenalty float32 `json:"repeat_penalty,omitempty"`
|
||||||
PresencePenalty float32 `json:"presence_penalty,omitempty"`
|
PresencePenalty float32 `json:"presence_penalty,omitempty"`
|
||||||
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
|
FrequencyPenalty float32 `json:"frequency_penalty,omitempty"`
|
||||||
Mirostat int `json:"mirostat,omitempty"`
|
|
||||||
MirostatTau float32 `json:"mirostat_tau,omitempty"`
|
|
||||||
MirostatEta float32 `json:"mirostat_eta,omitempty"`
|
|
||||||
Stop []string `json:"stop,omitempty"`
|
Stop []string `json:"stop,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -236,12 +372,7 @@ type Runner struct {
|
|||||||
NumBatch int `json:"num_batch,omitempty"`
|
NumBatch int `json:"num_batch,omitempty"`
|
||||||
NumGPU int `json:"num_gpu,omitempty"`
|
NumGPU int `json:"num_gpu,omitempty"`
|
||||||
MainGPU int `json:"main_gpu,omitempty"`
|
MainGPU int `json:"main_gpu,omitempty"`
|
||||||
LowVRAM bool `json:"low_vram,omitempty"`
|
|
||||||
F16KV bool `json:"f16_kv,omitempty"` // Deprecated: This option is ignored
|
|
||||||
LogitsAll bool `json:"logits_all,omitempty"`
|
|
||||||
VocabOnly bool `json:"vocab_only,omitempty"`
|
|
||||||
UseMMap *bool `json:"use_mmap,omitempty"`
|
UseMMap *bool `json:"use_mmap,omitempty"`
|
||||||
UseMLock bool `json:"use_mlock,omitempty"`
|
|
||||||
NumThread int `json:"num_thread,omitempty"`
|
NumThread int `json:"num_thread,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -260,7 +391,7 @@ type EmbedRequest struct {
|
|||||||
Truncate *bool `json:"truncate,omitempty"`
|
Truncate *bool `json:"truncate,omitempty"`
|
||||||
|
|
||||||
// Options lists model-specific options.
|
// Options lists model-specific options.
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]any `json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// EmbedResponse is the response from [Client.Embed].
|
// EmbedResponse is the response from [Client.Embed].
|
||||||
@@ -286,7 +417,7 @@ type EmbeddingRequest struct {
|
|||||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||||
|
|
||||||
// Options lists model-specific options.
|
// Options lists model-specific options.
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]any `json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// EmbeddingResponse is the response from [Client.Embeddings].
|
// EmbeddingResponse is the response from [Client.Embeddings].
|
||||||
@@ -332,7 +463,7 @@ type ShowRequest struct {
|
|||||||
Template string `json:"template"`
|
Template string `json:"template"`
|
||||||
Verbose bool `json:"verbose"`
|
Verbose bool `json:"verbose"`
|
||||||
|
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]any `json:"options"`
|
||||||
|
|
||||||
// Deprecated: set the model name with Model instead
|
// Deprecated: set the model name with Model instead
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
@@ -340,16 +471,18 @@ type ShowRequest struct {
|
|||||||
|
|
||||||
// ShowResponse is the response returned from [Client.Show].
|
// ShowResponse is the response returned from [Client.Show].
|
||||||
type ShowResponse struct {
|
type ShowResponse struct {
|
||||||
License string `json:"license,omitempty"`
|
License string `json:"license,omitempty"`
|
||||||
Modelfile string `json:"modelfile,omitempty"`
|
Modelfile string `json:"modelfile,omitempty"`
|
||||||
Parameters string `json:"parameters,omitempty"`
|
Parameters string `json:"parameters,omitempty"`
|
||||||
Template string `json:"template,omitempty"`
|
Template string `json:"template,omitempty"`
|
||||||
System string `json:"system,omitempty"`
|
System string `json:"system,omitempty"`
|
||||||
Details ModelDetails `json:"details,omitempty"`
|
Details ModelDetails `json:"details,omitempty"`
|
||||||
Messages []Message `json:"messages,omitempty"`
|
Messages []Message `json:"messages,omitempty"`
|
||||||
ModelInfo map[string]any `json:"model_info,omitempty"`
|
ModelInfo map[string]any `json:"model_info,omitempty"`
|
||||||
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
|
ProjectorInfo map[string]any `json:"projector_info,omitempty"`
|
||||||
ModifiedAt time.Time `json:"modified_at,omitempty"`
|
Tensors []Tensor `json:"tensors,omitempty"`
|
||||||
|
Capabilities []model.Capability `json:"capabilities,omitempty"`
|
||||||
|
ModifiedAt time.Time `json:"modified_at,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// CopyRequest is the request passed to [Client.Copy].
|
// CopyRequest is the request passed to [Client.Copy].
|
||||||
@@ -361,9 +494,9 @@ type CopyRequest struct {
|
|||||||
// PullRequest is the request passed to [Client.Pull].
|
// PullRequest is the request passed to [Client.Pull].
|
||||||
type PullRequest struct {
|
type PullRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Insecure bool `json:"insecure,omitempty"`
|
Insecure bool `json:"insecure,omitempty"` // Deprecated: ignored
|
||||||
Username string `json:"username"`
|
Username string `json:"username"` // Deprecated: ignored
|
||||||
Password string `json:"password"`
|
Password string `json:"password"` // Deprecated: ignored
|
||||||
Stream *bool `json:"stream,omitempty"`
|
Stream *bool `json:"stream,omitempty"`
|
||||||
|
|
||||||
// Deprecated: set the model name with Model instead
|
// Deprecated: set the model name with Model instead
|
||||||
@@ -413,20 +546,14 @@ type ListModelResponse struct {
|
|||||||
|
|
||||||
// ProcessModelResponse is a single model description in [ProcessResponse].
|
// ProcessModelResponse is a single model description in [ProcessResponse].
|
||||||
type ProcessModelResponse struct {
|
type ProcessModelResponse struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Size int64 `json:"size"`
|
Size int64 `json:"size"`
|
||||||
Digest string `json:"digest"`
|
Digest string `json:"digest"`
|
||||||
Details ModelDetails `json:"details,omitempty"`
|
Details ModelDetails `json:"details,omitempty"`
|
||||||
ExpiresAt time.Time `json:"expires_at"`
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
SizeVRAM int64 `json:"size_vram"`
|
SizeVRAM int64 `json:"size_vram"`
|
||||||
}
|
ContextLength int `json:"context_length"`
|
||||||
|
|
||||||
type RetrieveModelResponse struct {
|
|
||||||
Id string `json:"id"`
|
|
||||||
Object string `json:"object"`
|
|
||||||
Created int64 `json:"created"`
|
|
||||||
OwnedBy string `json:"owned_by"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenResponse struct {
|
type TokenResponse struct {
|
||||||
@@ -444,6 +571,10 @@ type GenerateResponse struct {
|
|||||||
// Response is the textual response itself.
|
// Response is the textual response itself.
|
||||||
Response string `json:"response"`
|
Response string `json:"response"`
|
||||||
|
|
||||||
|
// Thinking contains the text that was inside thinking tags in the
|
||||||
|
// original model output when ChatRequest.Think is enabled.
|
||||||
|
Thinking string `json:"thinking,omitempty"`
|
||||||
|
|
||||||
// Done specifies if the response is complete.
|
// Done specifies if the response is complete.
|
||||||
Done bool `json:"done"`
|
Done bool `json:"done"`
|
||||||
|
|
||||||
@@ -455,6 +586,8 @@ type GenerateResponse struct {
|
|||||||
Context []int `json:"context,omitempty"`
|
Context []int `json:"context,omitempty"`
|
||||||
|
|
||||||
Metrics
|
Metrics
|
||||||
|
|
||||||
|
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ModelDetails provides details about a model.
|
// ModelDetails provides details about a model.
|
||||||
@@ -467,6 +600,13 @@ type ModelDetails struct {
|
|||||||
QuantizationLevel string `json:"quantization_level"`
|
QuantizationLevel string `json:"quantization_level"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tensor describes the metadata for a given tensor.
|
||||||
|
type Tensor struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Shape []uint64 `json:"shape"`
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Metrics) Summary() {
|
func (m *Metrics) Summary() {
|
||||||
if m.TotalDuration > 0 {
|
if m.TotalDuration > 0 {
|
||||||
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
|
fmt.Fprintf(os.Stderr, "total duration: %v\n", m.TotalDuration)
|
||||||
@@ -495,7 +635,7 @@ func (m *Metrics) Summary() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (opts *Options) FromMap(m map[string]interface{}) error {
|
func (opts *Options) FromMap(m map[string]any) error {
|
||||||
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||||
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
||||||
|
|
||||||
@@ -552,12 +692,12 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
|
|||||||
}
|
}
|
||||||
field.SetString(val)
|
field.SetString(val)
|
||||||
case reflect.Slice:
|
case reflect.Slice:
|
||||||
// JSON unmarshals to []interface{}, not []string
|
// JSON unmarshals to []any, not []string
|
||||||
val, ok := val.([]interface{})
|
val, ok := val.([]any)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("option %q must be of type array", key)
|
return fmt.Errorf("option %q must be of type array", key)
|
||||||
}
|
}
|
||||||
// convert []interface{} to []string
|
// convert []any to []string
|
||||||
slice := make([]string, len(val))
|
slice := make([]string, len(val))
|
||||||
for i, item := range val {
|
for i, item := range val {
|
||||||
str, ok := item.(string)
|
str, ok := item.(string)
|
||||||
@@ -604,9 +744,6 @@ func DefaultOptions() Options {
|
|||||||
RepeatPenalty: 1.1,
|
RepeatPenalty: 1.1,
|
||||||
PresencePenalty: 0.0,
|
PresencePenalty: 0.0,
|
||||||
FrequencyPenalty: 0.0,
|
FrequencyPenalty: 0.0,
|
||||||
Mirostat: 0,
|
|
||||||
MirostatTau: 5.0,
|
|
||||||
MirostatEta: 0.1,
|
|
||||||
Seed: -1,
|
Seed: -1,
|
||||||
|
|
||||||
Runner: Runner{
|
Runner: Runner{
|
||||||
@@ -615,13 +752,118 @@ func DefaultOptions() Options {
|
|||||||
NumBatch: 512,
|
NumBatch: 512,
|
||||||
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
|
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
|
||||||
NumThread: 0, // let the runtime decide
|
NumThread: 0, // let the runtime decide
|
||||||
LowVRAM: false,
|
|
||||||
UseMLock: false,
|
|
||||||
UseMMap: nil,
|
UseMMap: nil,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ThinkValue represents a value that can be a boolean or a string ("high", "medium", "low")
|
||||||
|
type ThinkValue struct {
|
||||||
|
// Value can be a bool or string
|
||||||
|
Value interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValid checks if the ThinkValue is valid
|
||||||
|
func (t *ThinkValue) IsValid() bool {
|
||||||
|
if t == nil || t.Value == nil {
|
||||||
|
return true // nil is valid (means not set)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := t.Value.(type) {
|
||||||
|
case bool:
|
||||||
|
return true
|
||||||
|
case string:
|
||||||
|
return v == "high" || v == "medium" || v == "low"
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsBool returns true if the value is a boolean
|
||||||
|
func (t *ThinkValue) IsBool() bool {
|
||||||
|
if t == nil || t.Value == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, ok := t.Value.(bool)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsString returns true if the value is a string
|
||||||
|
func (t *ThinkValue) IsString() bool {
|
||||||
|
if t == nil || t.Value == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, ok := t.Value.(string)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bool returns the value as a bool (true if enabled in any way)
|
||||||
|
func (t *ThinkValue) Bool() bool {
|
||||||
|
if t == nil || t.Value == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := t.Value.(type) {
|
||||||
|
case bool:
|
||||||
|
return v
|
||||||
|
case string:
|
||||||
|
// Any string value ("high", "medium", "low") means thinking is enabled
|
||||||
|
return v == "high" || v == "medium" || v == "low"
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the value as a string
|
||||||
|
func (t *ThinkValue) String() string {
|
||||||
|
if t == nil || t.Value == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
switch v := t.Value.(type) {
|
||||||
|
case string:
|
||||||
|
return v
|
||||||
|
case bool:
|
||||||
|
if v {
|
||||||
|
return "medium" // Default level when just true
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements json.Unmarshaler
|
||||||
|
func (t *ThinkValue) UnmarshalJSON(data []byte) error {
|
||||||
|
// Try to unmarshal as bool first
|
||||||
|
var b bool
|
||||||
|
if err := json.Unmarshal(data, &b); err == nil {
|
||||||
|
t.Value = b
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to unmarshal as string
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(data, &s); err == nil {
|
||||||
|
// Validate string values
|
||||||
|
if s != "high" && s != "medium" && s != "low" {
|
||||||
|
return fmt.Errorf("invalid think value: %q (must be \"high\", \"medium\", \"low\", true, or false)", s)
|
||||||
|
}
|
||||||
|
t.Value = s
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("think must be a boolean or string (\"high\", \"medium\", \"low\")")
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements json.Marshaler
|
||||||
|
func (t *ThinkValue) MarshalJSON() ([]byte, error) {
|
||||||
|
if t == nil || t.Value == nil {
|
||||||
|
return []byte("null"), nil
|
||||||
|
}
|
||||||
|
return json.Marshal(t.Value)
|
||||||
|
}
|
||||||
|
|
||||||
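Taken together, these methods make ThinkValue behave like a tri-state option: unset (nil), boolean on/off, or a named effort level. A small sketch of the intended semantics, assuming the api package as patched above:

	package main

	import (
		"encoding/json"
		"fmt"

		"github.com/ollama/ollama/api"
	)

	func main() {
		var req struct {
			Think *api.ThinkValue `json:"think,omitempty"`
		}

		// "high" is one of the accepted effort levels.
		_ = json.Unmarshal([]byte(`{"think": "high"}`), &req)
		fmt.Println(req.Think.Bool(), req.Think.String()) // true high

		// Plain true enables thinking and reports the default level.
		_ = json.Unmarshal([]byte(`{"think": true}`), &req)
		fmt.Println(req.Think.Bool(), req.Think.String()) // true medium

		// Anything else is rejected by UnmarshalJSON.
		err := json.Unmarshal([]byte(`{"think": "max"}`), &req)
		fmt.Println(err != nil) // true
	}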
 type Duration struct {
 	time.Duration
 }

@@ -646,7 +888,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
 		if t < 0 {
 			d.Duration = time.Duration(math.MaxInt64)
 		} else {
-			d.Duration = time.Duration(int(t) * int(time.Second))
+			d.Duration = time.Duration(t * float64(time.Second))
 		}
 	case string:
 		d.Duration, err = time.ParseDuration(t)
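The one-line change above is a precision fix: the old expression truncated the float to an integer before multiplying, so a keep_alive of 42.5 became 42s; multiplying in float first preserves the fractional part, which is exactly what the updated test below asserts with 42500 * time.Millisecond. A sketch of the difference:

	package main

	import (
		"fmt"
		"time"
	)

	func main() {
		t := 42.5 // seconds, as decoded from JSON

		old := time.Duration(int(t) * int(time.Second))
		fixed := time.Duration(t * float64(time.Second))

		fmt.Println(old)   // 42s (fraction lost by int(t))
		fmt.Println(fixed) // 42.5s
	}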
@@ -664,7 +906,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
 }

 // FormatParams converts specified parameter options to their correct types
-func FormatParams(params map[string][]string) (map[string]interface{}, error) {
+func FormatParams(params map[string][]string) (map[string]any, error) {
 	opts := Options{}
 	valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
 	typeOpts := reflect.TypeOf(opts)           // types of the fields in the options struct

@@ -678,7 +920,7 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) {
 		}
 	}

-	out := make(map[string]interface{})
+	out := make(map[string]any)
 	// iterate params and set values based on json struct tags
 	for key, vals := range params {
 		if opt, ok := jsonOpts[key]; !ok {

@@ -17,6 +17,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
 		req string
 		exp *Duration
 	}{
+		{
+			name: "Unset",
+			req:  `{ }`,
+			exp:  nil,
+		},
 		{
 			name: "Positive Integer",
 			req:  `{ "keep_alive": 42 }`,

@@ -25,7 +30,7 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
 		{
 			name: "Positive Float",
 			req:  `{ "keep_alive": 42.5 }`,
-			exp:  &Duration{42 * time.Second},
+			exp:  &Duration{42500 * time.Millisecond},
 		},
 		{
 			name: "Positive Integer String",

@@ -134,7 +139,7 @@ func TestUseMmapParsingFromJSON(t *testing.T) {

 	for _, test := range tests {
 		t.Run(test.name, func(t *testing.T) {
-			var oMap map[string]interface{}
+			var oMap map[string]any
 			err := json.Unmarshal([]byte(test.req), &oMap)
 			require.NoError(t, err)
 			opts := DefaultOptions()

@@ -231,3 +236,255 @@ func TestMessage_UnmarshalJSON(t *testing.T) {
 		}
 	}
 }
+
+func TestToolFunction_UnmarshalJSON(t *testing.T) {
+	tests := []struct {
+		name    string
+		input   string
+		wantErr string
+	}{
+		{
+			name: "valid enum with same types",
+			input: `{
+				"name": "test",
+				"description": "test function",
+				"parameters": {
+					"type": "object",
+					"required": ["test"],
+					"properties": {
+						"test": {
+							"type": "string",
+							"description": "test prop",
+							"enum": ["a", "b", "c"]
+						}
+					}
+				}
+			}`,
+			wantErr: "",
+		},
+		{
+			name: "empty enum array",
+			input: `{
+				"name": "test",
+				"description": "test function",
+				"parameters": {
+					"type": "object",
+					"required": ["test"],
+					"properties": {
+						"test": {
+							"type": "string",
+							"description": "test prop",
+							"enum": []
+						}
+					}
+				}
+			}`,
+			wantErr: "",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			var tf ToolFunction
+			err := json.Unmarshal([]byte(tt.input), &tf)
+
+			if tt.wantErr != "" {
+				require.Error(t, err)
+				assert.Contains(t, err.Error(), tt.wantErr)
+			} else {
+				require.NoError(t, err)
+			}
+		})
+	}
+}
+
+func TestPropertyType_UnmarshalJSON(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    string
+		expected PropertyType
+	}{
+		{
+			name:     "string type",
+			input:    `"string"`,
+			expected: PropertyType{"string"},
+		},
+		{
+			name:     "array of types",
+			input:    `["string", "number"]`,
+			expected: PropertyType{"string", "number"},
+		},
+		{
+			name:     "array with single type",
+			input:    `["string"]`,
+			expected: PropertyType{"string"},
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			var pt PropertyType
+			if err := json.Unmarshal([]byte(test.input), &pt); err != nil {
+				t.Errorf("Unexpected error: %v", err)
+			}
+
+			if len(pt) != len(test.expected) {
+				t.Errorf("Length mismatch: got %v, expected %v", len(pt), len(test.expected))
+			}
+
+			for i, v := range pt {
+				if v != test.expected[i] {
+					t.Errorf("Value mismatch at index %d: got %v, expected %v", i, v, test.expected[i])
+				}
+			}
+		})
+	}
+}
+
+func TestPropertyType_MarshalJSON(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    PropertyType
+		expected string
+	}{
+		{
+			name:     "single type",
+			input:    PropertyType{"string"},
+			expected: `"string"`,
+		},
+		{
+			name:     "multiple types",
+			input:    PropertyType{"string", "number"},
+			expected: `["string","number"]`,
+		},
+		{
+			name:     "empty type",
+			input:    PropertyType{},
+			expected: `[]`,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			data, err := json.Marshal(test.input)
+			if err != nil {
+				t.Errorf("Unexpected error: %v", err)
+			}
+
+			if string(data) != test.expected {
+				t.Errorf("Marshaled data mismatch: got %v, expected %v", string(data), test.expected)
+			}
+		})
+	}
+}
+
+func TestThinking_UnmarshalJSON(t *testing.T) {
+	tests := []struct {
+		name             string
+		input            string
+		expectedThinking *ThinkValue
+		expectedError    bool
+	}{
+		{
+			name:             "true",
+			input:            `{ "think": true }`,
+			expectedThinking: &ThinkValue{Value: true},
+		},
+		{
+			name:             "false",
+			input:            `{ "think": false }`,
+			expectedThinking: &ThinkValue{Value: false},
+		},
+		{
+			name:             "unset",
+			input:            `{ }`,
+			expectedThinking: nil,
+		},
+		{
+			name:             "string_high",
+			input:            `{ "think": "high" }`,
+			expectedThinking: &ThinkValue{Value: "high"},
+		},
+		{
+			name:             "string_medium",
+			input:            `{ "think": "medium" }`,
+			expectedThinking: &ThinkValue{Value: "medium"},
+		},
+		{
+			name:             "string_low",
+			input:            `{ "think": "low" }`,
+			expectedThinking: &ThinkValue{Value: "low"},
+		},
+		{
+			name:             "invalid_string",
+			input:            `{ "think": "invalid" }`,
+			expectedThinking: nil,
+			expectedError:    true,
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			var req GenerateRequest
+			err := json.Unmarshal([]byte(test.input), &req)
+			if test.expectedError {
+				require.Error(t, err)
+			} else {
+				require.NoError(t, err)
+				if test.expectedThinking == nil {
+					assert.Nil(t, req.Think)
+				} else {
+					require.NotNil(t, req.Think)
+					assert.Equal(t, test.expectedThinking.Value, req.Think.Value)
+				}
+			}
+		})
+	}
+}
+
+func TestToolFunctionParameters_String(t *testing.T) {
+	tests := []struct {
+		name     string
+		params   ToolFunctionParameters
+		expected string
+	}{
+		{
+			name: "simple object with string property",
+			params: ToolFunctionParameters{
+				Type:     "object",
+				Required: []string{"name"},
+				Properties: map[string]ToolProperty{
+					"name": {
+						Type:        PropertyType{"string"},
+						Description: "The name of the person",
+					},
+				},
+			},
+			expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
+		},
+		{
+			name: "marshal failure returns empty string",
+			params: ToolFunctionParameters{
+				Type: "object",
+				Defs: func() any {
+					// Create a cycle that will cause json.Marshal to fail
+					type selfRef struct {
+						Self *selfRef
+					}
+					s := &selfRef{}
+					s.Self = s
+					return s
+				}(),
+				Properties: map[string]ToolProperty{},
+			},
+			expected: "",
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			result := test.params.String()
+			assert.Equal(t, test.expected, result)
+		})
+	}
+}

142 api/types_typescript_test.go (Normal file)
@@ -0,0 +1,142 @@
+package api
+
+import (
+	"testing"
+)
+
+func TestToolParameterToTypeScriptType(t *testing.T) {
+	tests := []struct {
+		name     string
+		param    ToolProperty
+		expected string
+	}{
+		{
+			name: "single string type",
+			param: ToolProperty{
+				Type: PropertyType{"string"},
+			},
+			expected: "string",
+		},
+		{
+			name: "single number type",
+			param: ToolProperty{
+				Type: PropertyType{"number"},
+			},
+			expected: "number",
+		},
+		{
+			name: "integer maps to number",
+			param: ToolProperty{
+				Type: PropertyType{"integer"},
+			},
+			expected: "number",
+		},
+		{
+			name: "boolean type",
+			param: ToolProperty{
+				Type: PropertyType{"boolean"},
+			},
+			expected: "boolean",
+		},
+		{
+			name: "array type",
+			param: ToolProperty{
+				Type: PropertyType{"array"},
+			},
+			expected: "any[]",
+		},
+		{
+			name: "object type",
+			param: ToolProperty{
+				Type: PropertyType{"object"},
+			},
+			expected: "Record<string, any>",
+		},
+		{
+			name: "null type",
+			param: ToolProperty{
+				Type: PropertyType{"null"},
+			},
+			expected: "null",
+		},
+		{
+			name: "multiple types as union",
+			param: ToolProperty{
+				Type: PropertyType{"string", "number"},
+			},
+			expected: "string | number",
+		},
+		{
+			name: "string or null union",
+			param: ToolProperty{
+				Type: PropertyType{"string", "null"},
+			},
+			expected: "string | null",
+		},
+		{
+			name: "anyOf with single types",
+			param: ToolProperty{
+				AnyOf: []ToolProperty{
+					{Type: PropertyType{"string"}},
+					{Type: PropertyType{"number"}},
+				},
+			},
+			expected: "string | number",
+		},
+		{
+			name: "anyOf with multiple types in each branch",
+			param: ToolProperty{
+				AnyOf: []ToolProperty{
+					{Type: PropertyType{"string", "null"}},
+					{Type: PropertyType{"number"}},
+				},
+			},
+			expected: "string | null | number",
+		},
+		{
+			name: "nested anyOf",
+			param: ToolProperty{
+				AnyOf: []ToolProperty{
+					{Type: PropertyType{"boolean"}},
+					{
+						AnyOf: []ToolProperty{
+							{Type: PropertyType{"string"}},
+							{Type: PropertyType{"number"}},
+						},
+					},
+				},
+			},
+			expected: "boolean | string | number",
+		},
+		{
+			name: "empty type returns any",
+			param: ToolProperty{
+				Type: PropertyType{},
+			},
+			expected: "any",
+		},
+		{
+			name: "unknown type maps to any",
+			param: ToolProperty{
+				Type: PropertyType{"unknown_type"},
+			},
+			expected: "any",
+		},
+		{
+			name: "multiple types including array",
+			param: ToolProperty{
+				Type: PropertyType{"string", "array", "null"},
+			},
+			expected: "string | any[] | null",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := tt.param.ToTypeScriptType()
+			if result != tt.expected {
+				t.Errorf("ToTypeScriptType() = %q, want %q", result, tt.expected)
+			}
+		})
+	}
+}

@@ -4,20 +4,14 @@ import (
 	"fmt"
 	"log/slog"
 	"os"
-	"path/filepath"
 	"strconv"
 	"strings"

 	"github.com/ollama/ollama/envconfig"
+	"github.com/ollama/ollama/logutil"
 )

 func InitLogging() {
-	level := slog.LevelInfo
-
-	if envconfig.Debug() {
-		level = slog.LevelDebug
-	}
-
 	var logFile *os.File
 	var err error
 	// Detect if we're a GUI app on windows, and if not, send logs to console

@@ -33,20 +27,8 @@ func InitLogging() {
 			return
 		}
 	}
-	handler := slog.NewTextHandler(logFile, &slog.HandlerOptions{
-		Level:     level,
-		AddSource: true,
-		ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
-			if attr.Key == slog.SourceKey {
-				source := attr.Value.Any().(*slog.Source)
-				source.File = filepath.Base(source.File)
-			}
-			return attr
-		},
-	})
-
-	slog.SetDefault(slog.New(handler))
-
+	slog.SetDefault(logutil.NewLogger(logFile, envconfig.LogLevel()))
 	slog.Info("ollama app started")
 }

492 cmd/cmd.go
@@ -18,6 +18,8 @@ import (
 	"os/signal"
 	"path/filepath"
 	"runtime"
+	"slices"
+	"sort"
 	"strconv"
 	"strings"
 	"sync/atomic"

@@ -29,20 +31,39 @@ import (
 	"github.com/olekukonko/tablewriter"
 	"github.com/spf13/cobra"
 	"golang.org/x/crypto/ssh"
+	"golang.org/x/sync/errgroup"
 	"golang.org/x/term"

 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
-	"github.com/ollama/ollama/llama"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
+	"github.com/ollama/ollama/readline"
 	"github.com/ollama/ollama/runner"
 	"github.com/ollama/ollama/server"
 	"github.com/ollama/ollama/types/model"
+	"github.com/ollama/ollama/types/syncmap"
 	"github.com/ollama/ollama/version"
 )

+// ensureThinkingSupport emits a warning if the model does not advertise thinking support
+func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) {
+	if name == "" {
+		return
+	}
+	resp, err := client.Show(ctx, &api.ShowRequest{Model: name})
+	if err != nil {
+		return
+	}
+	for _, cap := range resp.Capabilities {
+		if cap == model.CapabilityThinking {
+			return
+		}
+	}
+	fmt.Fprintf(os.Stderr, "warning: model %q does not support thinking output\n", name)
+}
+
 var errModelfileNotFound = errors.New("specified Modelfile wasn't found")

 func getModelfileName(cmd *cobra.Command) (string, error) {

@@ -105,7 +126,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 	}
 	spinner.Stop()

-	req.Name = args[0]
+	req.Model = args[0]
 	quantize, _ := cmd.Flags().GetString("quantize")
 	if quantize != "" {
 		req.Quantize = quantize

@@ -116,34 +137,54 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}

-	if len(req.Files) > 0 {
-		fileMap := map[string]string{}
-		for f, digest := range req.Files {
-			if _, err := createBlob(cmd, client, f, digest, p); err != nil {
-				return err
-			}
-			fileMap[filepath.Base(f)] = digest
-		}
-		req.Files = fileMap
+	var g errgroup.Group
+	g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
+
+	files := syncmap.NewSyncMap[string, string]()
+	for f, digest := range req.Files {
+		g.Go(func() error {
+			if _, err := createBlob(cmd, client, f, digest, p); err != nil {
+				return err
+			}
+
+			// TODO: this is incorrect since the file might be in a subdirectory
+			// instead this should take the path relative to the model directory
+			// but the current implementation does not allow this
+			files.Store(filepath.Base(f), digest)
+			return nil
+		})
 	}

-	if len(req.Adapters) > 0 {
-		fileMap := map[string]string{}
-		for f, digest := range req.Adapters {
-			if _, err := createBlob(cmd, client, f, digest, p); err != nil {
-				return err
-			}
-			fileMap[filepath.Base(f)] = digest
-		}
-		req.Adapters = fileMap
+	adapters := syncmap.NewSyncMap[string, string]()
+	for f, digest := range req.Adapters {
+		g.Go(func() error {
+			if _, err := createBlob(cmd, client, f, digest, p); err != nil {
+				return err
+			}
+
+			// TODO: same here
+			adapters.Store(filepath.Base(f), digest)
+			return nil
+		})
 	}

+	if err := g.Wait(); err != nil {
+		return err
+	}
+
+	req.Files = files.Items()
+	req.Adapters = adapters.Items()
+
 	bars := make(map[string]*progress.Bar)
 	fn := func(resp api.ProgressResponse) error {
 		if resp.Digest != "" {
 			bar, ok := bars[resp.Digest]
 			if !ok {
-				bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
+				msg := resp.Status
+				if msg == "" {
+					msg = fmt.Sprintf("pulling %s...", resp.Digest[7:19])
+				}
+				bar = progress.NewBar(msg, resp.Total, resp.Completed)
 				bars[resp.Digest] = bar
 				p.Add(resp.Digest, bar)
 			}
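The rework above replaces sequential blob uploads with a bounded worker pool: errgroup.SetLimit caps in-flight createBlob calls at GOMAXPROCS-1, g.Wait surfaces the first error, and the syncmap collects results written from concurrent goroutines. The generic shape of that pattern, sketched here independently of ollama's syncmap helper (the item names are placeholders):

	package main

	import (
		"fmt"
		"runtime"
		"sync"

		"golang.org/x/sync/errgroup"
	)

	func main() {
		var g errgroup.Group
		g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1)) // bound concurrency

		var mu sync.Mutex
		results := map[string]string{}

		for _, item := range []string{"a", "b", "c"} {
			g.Go(func() error {
				digest := "sha256:" + item // stand-in for real work
				mu.Lock()
				results[item] = digest
				mu.Unlock()
				return nil
			})
		}

		if err := g.Wait(); err != nil {
			fmt.Println("error:", err)
			return
		}
		fmt.Println(results)
	}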
@@ -212,7 +253,7 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string, digest stri
 		}
 	}()

-	if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
+	if err := client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
 		return "", err
 	}
 	return digest, nil

@@ -242,6 +283,9 @@ func loadOrUnloadModel(cmd *cobra.Command, opts *runOptions) error {
 	req := &api.GenerateRequest{
 		Model:     opts.Model,
 		KeepAlive: opts.KeepAlive,
+
+		// pass Think here so we fail before getting to the chat prompt if the model doesn't support it
+		Think: opts.Think,
 	}

 	return client.Generate(cmd.Context(), req, func(api.GenerateResponse) error { return nil })

@@ -256,6 +300,7 @@ func StopHandler(cmd *cobra.Command, args []string) error {
 		if strings.Contains(err.Error(), "not found") {
 			return fmt.Errorf("couldn't find model \"%s\" to stop", args[0])
 		}
+		return err
 	}
 	return nil
 }

@@ -266,7 +311,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	opts := runOptions{
 		Model:    args[0],
 		WordWrap: os.Getenv("TERM") == "xterm-256color",
-		Options:  map[string]interface{}{},
+		Options:  map[string]any{},
 	}

 	format, err := cmd.Flags().GetString("format")

@@ -275,6 +320,34 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	}
 	opts.Format = format

+	thinkFlag := cmd.Flags().Lookup("think")
+	if thinkFlag.Changed {
+		thinkStr, err := cmd.Flags().GetString("think")
+		if err != nil {
+			return err
+		}
+
+		// Handle different values for --think
+		switch thinkStr {
+		case "", "true":
+			// --think or --think=true
+			opts.Think = &api.ThinkValue{Value: true}
+		case "false":
+			opts.Think = &api.ThinkValue{Value: false}
+		case "high", "medium", "low":
+			opts.Think = &api.ThinkValue{Value: thinkStr}
+		default:
+			return fmt.Errorf("invalid value for --think: %q (must be true, false, high, medium, or low)", thinkStr)
+		}
+	} else {
+		opts.Think = nil
+	}
+	hidethinking, err := cmd.Flags().GetBool("hidethinking")
+	if err != nil {
+		return err
+	}
+	opts.HideThinking = hidethinking
+
 	keepAlive, err := cmd.Flags().GetString("keepalive")
 	if err != nil {
 		return err

@@ -338,10 +411,26 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}

-	// TODO(jessegross): We should either find another way to know if this is
-	// a vision model or remove the logic. Also consider that other modalities will
-	// need different behavior anyways.
-	opts.MultiModal = len(info.ProjectorInfo) != 0 || envconfig.NewEngine()
+	opts.Think, err = inferThinkingOption(&info.Capabilities, &opts, thinkFlag.Changed)
+	if err != nil {
+		return err
+	}
+
+	opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
+
+	// TODO: remove the projector info and vision info checks below,
+	// these are left in for backwards compatibility with older servers
+	// that don't have the capabilities field in the model info
+	if len(info.ProjectorInfo) != 0 {
+		opts.MultiModal = true
+	}
+	for k := range info.ModelInfo {
+		if strings.Contains(k, ".vision.") {
+			opts.MultiModal = true
+			break
+		}
+	}
+
 	opts.ParentModel = info.Details.ParentModel

 	if interactive {

@@ -506,12 +595,13 @@ func ListRunningHandler(cmd *cobra.Command, args []string) error {
 			} else {
 				until = format.HumanTime(m.ExpiresAt, "Never")
 			}
-			data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, until})
+			ctxStr := strconv.Itoa(m.ContextLength)
+			data = append(data, []string{m.Name, m.Digest[:12], format.HumanBytes(m.Size), procStr, ctxStr, until})
 		}
 	}

 	table := tablewriter.NewWriter(os.Stdout)
-	table.SetHeader([]string{"NAME", "ID", "SIZE", "PROCESSOR", "UNTIL"})
+	table.SetHeader([]string{"NAME", "ID", "SIZE", "PROCESSOR", "CONTEXT", "UNTIL"})
 	table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
 	table.SetAlignment(tablewriter.ALIGN_LEFT)
 	table.SetHeaderLine(false)

@@ -562,8 +652,9 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 	parameters, errParams := cmd.Flags().GetBool("parameters")
 	system, errSystem := cmd.Flags().GetBool("system")
 	template, errTemplate := cmd.Flags().GetBool("template")
+	verbose, errVerbose := cmd.Flags().GetBool("verbose")

-	for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate} {
+	for _, boolErr := range []error{errLicense, errModelfile, errParams, errSystem, errTemplate, errVerbose} {
 		if boolErr != nil {
 			return errors.New("error retrieving flags")
 		}

@@ -601,7 +692,7 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 		return errors.New("only one of '--license', '--modelfile', '--parameters', '--system', or '--template' can be specified")
 	}

-	req := api.ShowRequest{Name: args[0]}
+	req := api.ShowRequest{Name: args[0], Verbose: verbose}
 	resp, err := client.Show(cmd.Context(), &req)
 	if err != nil {
 		return err

@@ -624,10 +715,10 @@ func ShowHandler(cmd *cobra.Command, args []string) error {
 		return nil
 	}

-	return showInfo(resp, os.Stdout)
+	return showInfo(resp, verbose, os.Stdout)
 }

-func showInfo(resp *api.ShowResponse, w io.Writer) error {
+func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
 	tableRender := func(header string, rows func() [][]string) {
 		fmt.Fprintln(w, " ", header)
 		table := tablewriter.NewWriter(w)

@@ -661,6 +752,15 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
 		return
 	})

+	if len(resp.Capabilities) > 0 {
+		tableRender("Capabilities", func() (rows [][]string) {
+			for _, capability := range resp.Capabilities {
+				rows = append(rows, []string{"", capability.String()})
+			}
+			return
+		})
+	}
+
 	if resp.ProjectorInfo != nil {
 		tableRender("Projector", func() (rows [][]string) {
 			arch := resp.ProjectorInfo["general.architecture"].(string)

@@ -684,12 +784,89 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
 		})
 	}

+	if resp.ModelInfo != nil && verbose {
+		tableRender("Metadata", func() (rows [][]string) {
+			keys := make([]string, 0, len(resp.ModelInfo))
+			for k := range resp.ModelInfo {
+				keys = append(keys, k)
+			}
+			sort.Strings(keys)
+
+			for _, k := range keys {
+				var v string
+				switch vData := resp.ModelInfo[k].(type) {
+				case bool:
+					v = fmt.Sprintf("%t", vData)
+				case string:
+					v = vData
+				case float64:
+					v = fmt.Sprintf("%g", vData)
+				case []any:
+					targetWidth := 10 // Small width where we are displaying the data in a column

+					var itemsToShow int
+					totalWidth := 1 // Start with 1 for opening bracket
+
+					// Find how many we can fit
+					for i := range vData {
+						itemStr := fmt.Sprintf("%v", vData[i])
+						width := runewidth.StringWidth(itemStr)
+
+						// Add separator width (", ") for all items except the first
+						if i > 0 {
+							width += 2
+						}
+
+						// Check if adding this item would exceed our width limit
+						if totalWidth+width > targetWidth && i > 0 {
+							break
+						}
+
+						totalWidth += width
+						itemsToShow++
+					}
+
+					// Format the output
+					if itemsToShow < len(vData) {
+						v = fmt.Sprintf("%v", vData[:itemsToShow])
+						v = strings.TrimSuffix(v, "]")
+						v += fmt.Sprintf(" ...+%d more]", len(vData)-itemsToShow)
+					} else {
+						v = fmt.Sprintf("%v", vData)
+					}
+				default:
+					v = fmt.Sprintf("%T", vData)
+				}
+				rows = append(rows, []string{"", k, v})
+			}
+			return
+		})
+	}
+
+	if len(resp.Tensors) > 0 && verbose {
+		tableRender("Tensors", func() (rows [][]string) {
+			for _, t := range resp.Tensors {
+				rows = append(rows, []string{"", t.Name, t.Type, fmt.Sprint(t.Shape)})
+			}
+			return
+		})
+	}
+
 	head := func(s string, n int) (rows [][]string) {
 		scanner := bufio.NewScanner(strings.NewReader(s))
-		for scanner.Scan() && (len(rows) < n || n < 0) {
-			if text := scanner.Text(); text != "" {
-				rows = append(rows, []string{"", strings.TrimSpace(text)})
-			}
-		}
+		count := 0
+		for scanner.Scan() {
+			text := strings.TrimSpace(scanner.Text())
+			if text == "" {
+				continue
+			}
+			count++
+			if n < 0 || count <= n {
+				rows = append(rows, []string{"", text})
+			}
+		}
+		if n >= 0 && count > n {
+			rows = append(rows, []string{"", "..."})
+		}
 		return
 	}

@@ -744,13 +921,38 @@ func PullHandler(cmd *cobra.Command, args []string) error {

 	fn := func(resp api.ProgressResponse) error {
 		if resp.Digest != "" {
+			if resp.Completed == 0 {
+				// This is the initial status update for the
+				// layer, which the server sends before
+				// beginning the download, for clients to
+				// compute total size and prepare for
+				// downloads, if needed.
+				//
+				// Skipping this here to avoid showing a 0%
+				// progress bar, which *should* clue the user
+				// into the fact that many things are being
+				// downloaded and that the current active
+				// download is not that last. However, in rare
+				// cases it seems to be triggering to some, and
+				// it isn't worth explaining, so just ignore
+				// and regress to the old UI that keeps giving
+				// you the "But wait, there is more!" after
+				// each "100% done" bar, which is "better."
+				return nil
+			}
+
 			if spinner != nil {
 				spinner.Stop()
 			}

 			bar, ok := bars[resp.Digest]
 			if !ok {
-				bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
+				name, isDigest := strings.CutPrefix(resp.Digest, "sha256:")
+				name = strings.TrimSpace(name)
+				if isDigest {
+					name = name[:min(12, len(name))]
+				}
+				bar = progress.NewBar(fmt.Sprintf("pulling %s:", name), resp.Total, resp.Completed)
 				bars[resp.Digest] = bar
 				p.Add(resp.Digest, bar)
 			}
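The progress-bar label change above shortens raw digests to twelve hex characters while leaving non-digest status keys untouched; the min guard only protects against strings shorter than twelve characters. A sketch with a hypothetical digest value:

	package main

	import (
		"fmt"
		"strings"
	)

	func main() {
		// Hypothetical layer digest, for illustration only.
		digest := "sha256:29302d6ca53a0cdd0e5b2035d52666b0b0e0e0a13d9c6a6f4b2c3d4e5f607182"

		name, isDigest := strings.CutPrefix(digest, "sha256:")
		name = strings.TrimSpace(name)
		if isDigest {
			name = name[:min(12, len(name))]
		}
		fmt.Printf("pulling %s:\n", name) // pulling 29302d6ca53a:
	}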
@@ -770,27 +972,25 @@ func PullHandler(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
request := api.PullRequest{Name: args[0], Insecure: insecure}
|
request := api.PullRequest{Name: args[0], Insecure: insecure}
|
||||||
if err := client.Pull(cmd.Context(), &request, fn); err != nil {
|
return client.Pull(cmd.Context(), &request, fn)
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type generateContextKey string
|
type generateContextKey string
|
||||||
|
|
||||||
type runOptions struct {
|
type runOptions struct {
|
||||||
Model string
|
Model string
|
||||||
ParentModel string
|
ParentModel string
|
||||||
Prompt string
|
Prompt string
|
||||||
Messages []api.Message
|
Messages []api.Message
|
||||||
WordWrap bool
|
WordWrap bool
|
||||||
Format string
|
Format string
|
||||||
System string
|
System string
|
||||||
Images []api.ImageData
|
Images []api.ImageData
|
||||||
Options map[string]interface{}
|
Options map[string]any
|
||||||
MultiModal bool
|
MultiModal bool
|
||||||
KeepAlive *api.Duration
|
KeepAlive *api.Duration
|
||||||
|
Think *api.ThinkValue
|
||||||
|
HideThinking bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type displayResponseState struct {
|
type displayResponseState struct {
|
||||||
@@ -829,10 +1029,11 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
|
|||||||
}
|
}
|
||||||
|
|
||||||
switch ch {
|
switch ch {
|
||||||
case ' ':
|
case ' ', '\t':
|
||||||
state.wordBuffer = ""
|
state.wordBuffer = ""
|
||||||
case '\n':
|
case '\n', '\r':
|
||||||
state.lineLength = 0
|
state.lineLength = 0
|
||||||
|
state.wordBuffer = ""
|
||||||
default:
|
default:
|
||||||
state.wordBuffer += string(ch)
|
state.wordBuffer += string(ch)
|
||||||
}
|
}
|
||||||
@@ -846,6 +1047,26 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func thinkingOutputOpeningText(plainText bool) string {
|
||||||
|
text := "Thinking...\n"
|
||||||
|
|
||||||
|
if plainText {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
|
||||||
|
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault + readline.ColorGrey
|
||||||
|
}
|
||||||
|
|
||||||
|
func thinkingOutputClosingText(plainText bool) string {
|
||||||
|
text := "...done thinking.\n\n"
|
||||||
|
|
||||||
|
if plainText {
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
|
||||||
|
return readline.ColorGrey + readline.ColorBold + text + readline.ColorDefault
|
||||||
|
}
|
||||||
|
|
||||||
func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||||
client, err := api.ClientFromEnvironment()
|
client, err := api.ClientFromEnvironment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -870,19 +1091,55 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 	}()
 
 	var state *displayResponseState = &displayResponseState{}
+	var thinkingContent strings.Builder
 	var latest api.ChatResponse
 	var fullResponse strings.Builder
-	var role string
+	var thinkTagOpened bool = false
+	var thinkTagClosed bool = false
 
+	role := "assistant"
+
 	fn := func(response api.ChatResponse) error {
-		p.StopAndClear()
+		if response.Message.Content != "" || !opts.HideThinking {
+			p.StopAndClear()
+		}
+
 		latest = response
 
 		role = response.Message.Role
+		if response.Message.Thinking != "" && !opts.HideThinking {
+			if !thinkTagOpened {
+				fmt.Print(thinkingOutputOpeningText(false))
+				thinkTagOpened = true
+				thinkTagClosed = false
+			}
+			thinkingContent.WriteString(response.Message.Thinking)
+			displayResponse(response.Message.Thinking, opts.WordWrap, state)
+		}
 
 		content := response.Message.Content
+		if thinkTagOpened && !thinkTagClosed && (content != "" || len(response.Message.ToolCalls) > 0) {
+			if !strings.HasSuffix(thinkingContent.String(), "\n") {
+				fmt.Println()
+			}
+			fmt.Print(thinkingOutputClosingText(false))
+			thinkTagOpened = false
+			thinkTagClosed = true
+			state = &displayResponseState{}
+		}
+		// purposefully not putting thinking blocks in the response, which would
+		// only be needed if we later added tool calling to the cli (they get
+		// filtered out anyway since current models don't expect them unless you're
+		// about to finish some tool calls)
 		fullResponse.WriteString(content)
 
+		if response.Message.ToolCalls != nil {
+			toolCalls := response.Message.ToolCalls
+			if len(toolCalls) > 0 {
+				fmt.Print(renderToolCalls(toolCalls, false))
+			}
+		}
+
 		displayResponse(content, opts.WordWrap, state)
 
 		return nil
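Stripped of the CLI details, the thinkTagOpened/thinkTagClosed bookkeeping above is a two-flag state machine that prints the thinking block exactly once, then closes it on the first non-thinking token. A self-contained sketch of the same transitions, with simplified names and an illustrative stream:

    package main

    import (
        "fmt"
        "strings"
    )

    func main() {
        var thinking strings.Builder
        opened, closed := false, false
        // Illustrative stream: thinking chunks arrive first, then content.
        stream := []struct{ thinking, content string }{
            {"weighing options", ""},
            {"... done", ""},
            {"", "Hello"},
            {"", " world\n"},
        }
        for _, msg := range stream {
            if msg.thinking != "" {
                if !opened {
                    fmt.Print("Thinking...\n")
                    opened, closed = true, false
                }
                thinking.WriteString(msg.thinking)
                fmt.Print(msg.thinking)
            }
            // Close the block on the first non-thinking content.
            if opened && !closed && msg.content != "" {
                if !strings.HasSuffix(thinking.String(), "\n") {
                    fmt.Println() // delimiter starts on its own line
                }
                fmt.Print("...done thinking.\n\n")
                opened, closed = false, true
            }
            fmt.Print(msg.content)
        }
    }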
@@ -897,6 +1154,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 		Messages: opts.Messages,
 		Format:   json.RawMessage(opts.Format),
 		Options:  opts.Options,
+		Think:    opts.Think,
 	}
 
 	if opts.KeepAlive != nil {
@@ -907,6 +1165,14 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
 		if errors.Is(err, context.Canceled) {
 			return nil, nil
 		}
+
+		// this error should ideally be wrapped properly by the client
+		if strings.Contains(err.Error(), "upstream error") {
+			p.StopAndClear()
+			fmt.Println("An error occurred while processing your message. Please try again.")
+			fmt.Println()
+			return nil, nil
+		}
 		return nil, err
 	}
 
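The comment concedes that matching on the error string is a stopgap. The idiomatic fix on the client side would be a wrapped sentinel; a hypothetical sketch (ErrUpstream is not an existing symbol in the api package):

    // Hypothetical client-side wrapping, shown for illustration only.
    var ErrUpstream = errors.New("upstream error")

    func chatOnce() error {
        return fmt.Errorf("chat request failed: %w", ErrUpstream) // wrap, don't flatten to a string
    }

    // Call sites could then test with errors.Is instead of strings.Contains:
    //	if errors.Is(err, ErrUpstream) { ... }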
@@ -958,15 +1224,49 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 	}()
 
 	var state *displayResponseState = &displayResponseState{}
+	var thinkingContent strings.Builder
+	var thinkTagOpened bool = false
+	var thinkTagClosed bool = false
+
+	plainText := !term.IsTerminal(int(os.Stdout.Fd()))
 
 	fn := func(response api.GenerateResponse) error {
-		p.StopAndClear()
-
 		latest = response
 		content := response.Response
 
+		if response.Response != "" || !opts.HideThinking {
+			p.StopAndClear()
+		}
+
+		if response.Thinking != "" && !opts.HideThinking {
+			if !thinkTagOpened {
+				fmt.Print(thinkingOutputOpeningText(plainText))
+				thinkTagOpened = true
+				thinkTagClosed = false
+			}
+			thinkingContent.WriteString(response.Thinking)
+			displayResponse(response.Thinking, opts.WordWrap, state)
+		}
+
+		if thinkTagOpened && !thinkTagClosed && (content != "" || len(response.ToolCalls) > 0) {
+			if !strings.HasSuffix(thinkingContent.String(), "\n") {
+				fmt.Println()
+			}
+			fmt.Print(thinkingOutputClosingText(plainText))
+			thinkTagOpened = false
+			thinkTagClosed = true
+			state = &displayResponseState{}
+		}
+
 		displayResponse(content, opts.WordWrap, state)
 
+		if response.ToolCalls != nil {
+			toolCalls := response.ToolCalls
+			if len(toolCalls) > 0 {
+				fmt.Print(renderToolCalls(toolCalls, plainText))
+			}
+		}
+
 		return nil
 	}
 
@@ -990,6 +1290,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
 		System:    opts.System,
 		Options:   opts.Options,
 		KeepAlive: opts.KeepAlive,
+		Think:     opts.Think,
 	}
 
 	if err := client.Generate(ctx, &request, fn); err != nil {
@@ -1093,11 +1394,11 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
 		return err
 	}
 	if err := client.Heartbeat(cmd.Context()); err != nil {
-		if !strings.Contains(err.Error(), " refused") {
+		if !(strings.Contains(err.Error(), " refused") || strings.Contains(err.Error(), "could not connect")) {
 			return err
 		}
 		if err := startApp(cmd.Context(), client); err != nil {
-			return errors.New("could not connect to ollama app, is it running?")
+			return fmt.Errorf("ollama server not responding - %w", err)
 		}
 	}
 	return nil
@@ -1168,14 +1469,14 @@ func NewCLI() *cobra.Command {
 
 	createCmd := &cobra.Command{
 		Use:     "create MODEL",
-		Short:   "Create a model from a Modelfile",
+		Short:   "Create a model",
 		Args:    cobra.ExactArgs(1),
 		PreRunE: checkServerHeartbeat,
 		RunE:    CreateHandler,
 	}
 
-	createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\"")
-	createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)")
+	createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\")")
+	createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
 
 	showCmd := &cobra.Command{
 		Use:     "show MODEL",
@@ -1190,6 +1491,7 @@ func NewCLI() *cobra.Command {
 	showCmd.Flags().Bool("parameters", false, "Show parameters of a model")
 	showCmd.Flags().Bool("template", false, "Show template of a model")
 	showCmd.Flags().Bool("system", false, "Show system message of a model")
+	showCmd.Flags().BoolP("verbose", "v", false, "Show detailed model information")
 
 	runCmd := &cobra.Command{
 		Use:     "run MODEL [PROMPT]",
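The new -v/--verbose flag gates the extra Metadata and Tensors sections that the cmd_test.go changes further down assert on. Usage, with an illustrative model name:

    ollama show llama3 --verbose
    ollama show llama3 -v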
@@ -1204,6 +1506,9 @@ func NewCLI() *cobra.Command {
 	runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
 	runCmd.Flags().Bool("nowordwrap", false, "Don't wrap words to the next line automatically")
 	runCmd.Flags().String("format", "", "Response format (e.g. json)")
+	runCmd.Flags().String("think", "", "Enable thinking mode: true/false or high/medium/low for supported models")
+	runCmd.Flags().Lookup("think").NoOptDefVal = "true"
+	runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)")
 
 	stopCmd := &cobra.Command{
 		Use:     "stop MODEL",
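Because the flag's NoOptDefVal is set to "true", --think works both as a bare switch and with a value. All of the following are valid invocations (the model name is illustrative):

    ollama run deepseek-r1 --think
    ollama run deepseek-r1 --think=low
    ollama run deepseek-r1 --think=false
    ollama run deepseek-r1 --hidethinking "one-shot prompt"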
@@ -1255,7 +1560,6 @@ func NewCLI() *cobra.Command {
 		PreRunE: checkServerHeartbeat,
 		RunE:    ListRunningHandler,
 	}
-
 	copyCmd := &cobra.Command{
 		Use:   "cp SOURCE DESTINATION",
 		Short: "Copy a model",
@@ -1274,7 +1578,6 @@ func NewCLI() *cobra.Command {
 
 	runnerCmd := &cobra.Command{
 		Use:    "runner",
-		Short:  llama.PrintSystemInfo(),
 		Hidden: true,
 		RunE: func(cmd *cobra.Command, args []string) error {
 			return runner.Execute(os.Args[1:])
@@ -1309,6 +1612,7 @@ func NewCLI() *cobra.Command {
 	appendEnvDocs(cmd, []envconfig.EnvVar{
 		envVars["OLLAMA_DEBUG"],
 		envVars["OLLAMA_HOST"],
+		envVars["OLLAMA_CONTEXT_LENGTH"],
 		envVars["OLLAMA_KEEP_ALIVE"],
 		envVars["OLLAMA_MAX_LOADED_MODELS"],
 		envVars["OLLAMA_MAX_QUEUE"],
@@ -1317,7 +1621,6 @@ func NewCLI() *cobra.Command {
 		envVars["OLLAMA_NOPRUNE"],
 		envVars["OLLAMA_ORIGINS"],
 		envVars["OLLAMA_SCHED_SPREAD"],
-		envVars["OLLAMA_TMPDIR"],
 		envVars["OLLAMA_FLASH_ATTENTION"],
 		envVars["OLLAMA_KV_CACHE_TYPE"],
 		envVars["OLLAMA_LLM_LIBRARY"],
@@ -1346,3 +1649,70 @@ func NewCLI() *cobra.Command {
 
 	return rootCmd
 }
+
+// If the user has explicitly set thinking options, either through the CLI or
+// through the `/set think` or `set nothink` interactive options, then we
+// respect them. Otherwise, we check model capabilities to see if the model
+// supports thinking. If the model does support thinking, we enable it.
+// Otherwise, we unset the thinking option (which is different than setting it
+// to false).
+//
+// If capabilities are not provided, we fetch them from the server.
+func inferThinkingOption(caps *[]model.Capability, runOpts *runOptions, explicitlySetByUser bool) (*api.ThinkValue, error) {
+	if explicitlySetByUser {
+		return runOpts.Think, nil
+	}
+
+	if caps == nil {
+		client, err := api.ClientFromEnvironment()
+		if err != nil {
+			return nil, err
+		}
+		ret, err := client.Show(context.Background(), &api.ShowRequest{
+			Model: runOpts.Model,
+		})
+		if err != nil {
+			return nil, err
+		}
+		caps = &ret.Capabilities
+	}
+
+	thinkingSupported := false
+	for _, cap := range *caps {
+		if cap == model.CapabilityThinking {
+			thinkingSupported = true
+		}
+	}
+
+	if thinkingSupported {
+		return &api.ThinkValue{Value: true}, nil
+	}
+
+	return nil, nil
+}
+
+func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string {
+	out := ""
+	formatExplanation := ""
+	formatValues := ""
+	if !plainText {
+		formatExplanation = readline.ColorGrey + readline.ColorBold
+		formatValues = readline.ColorDefault
+		out += formatExplanation
+	}
+	for i, toolCall := range toolCalls {
+		argsAsJSON, err := json.Marshal(toolCall.Function.Arguments)
+		if err != nil {
+			return ""
+		}
+		if i > 0 {
+			out += "\n"
+		}
+		// all tool calls are unexpected since we don't currently support registering any in the CLI
+		out += fmt.Sprintf("  Model called a non-existent function '%s()' with arguments: %s", formatValues+toolCall.Function.Name+formatExplanation, formatValues+string(argsAsJSON)+formatExplanation)
+	}
+	if !plainText {
+		out += readline.ColorDefault
+	}
+	return out
+}
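A sketch of how inferThinkingOption is typically consumed before starting a chat; variable names are illustrative, and passing nil capabilities triggers the Show round-trip shown above:

    // Decide the Think option for a model the user did not configure explicitly.
    think, err := inferThinkingOption(nil, &opts, false)
    if err != nil {
        return err
    }
    opts.Think = think // nil when the model lacks the thinking capability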
cmd/cmd_test.go (246 changes)
@@ -2,7 +2,6 @@ package cmd
 
 import (
 	"bytes"
-	"context"
 	"encoding/json"
 	"io"
 	"net/http"
@@ -16,6 +15,7 @@ import (
 	"github.com/spf13/cobra"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/types/model"
 )
 
 func TestShowInfo(t *testing.T) {
@@ -27,7 +27,7 @@ func TestShowInfo(t *testing.T) {
 			ParameterSize:     "7B",
 			QuantizationLevel: "FP16",
 		},
-	}, &b); err != nil {
+	}, false, &b); err != nil {
 		t.Fatal(err)
 	}
 
@@ -57,7 +57,7 @@ func TestShowInfo(t *testing.T) {
 			ParameterSize:     "7B",
 			QuantizationLevel: "FP16",
 		},
-	}, &b); err != nil {
+	}, false, &b); err != nil {
 		t.Fatal(err)
 	}
 
@@ -68,6 +68,60 @@ func TestShowInfo(t *testing.T) {
     embedding length    0
     quantization        FP16
 
+`
+		if diff := cmp.Diff(expect, b.String()); diff != "" {
+			t.Errorf("unexpected output (-want +got):\n%s", diff)
+		}
+	})
+
+	t.Run("verbose model", func(t *testing.T) {
+		var b bytes.Buffer
+		if err := showInfo(&api.ShowResponse{
+			Details: api.ModelDetails{
+				Family:            "test",
+				ParameterSize:     "8B",
+				QuantizationLevel: "FP16",
+			},
+			Parameters: `
+			stop up`,
+			ModelInfo: map[string]any{
+				"general.architecture":    "test",
+				"general.parameter_count": float64(8_000_000_000),
+				"some.true_bool":          true,
+				"some.false_bool":         false,
+				"test.context_length":     float64(1000),
+				"test.embedding_length":   float64(11434),
+			},
+			Tensors: []api.Tensor{
+				{Name: "blk.0.attn_k.weight", Type: "BF16", Shape: []uint64{42, 3117}},
+				{Name: "blk.0.attn_q.weight", Type: "FP16", Shape: []uint64{3117, 42}},
+			},
+		}, true, &b); err != nil {
+			t.Fatal(err)
+		}
+
+		expect := `  Model
+    architecture        test
+    parameters          8B
+    context length      1000
+    embedding length    11434
+    quantization        FP16
+
+  Parameters
+    stop    up
+
+  Metadata
+    general.architecture       test
+    general.parameter_count    8e+09
+    some.false_bool            false
+    some.true_bool             true
+    test.context_length        1000
+    test.embedding_length      11434
+
+  Tensors
+    blk.0.attn_k.weight    BF16    [42 3117]
+    blk.0.attn_q.weight    FP16    [3117 42]
+
 `
 	if diff := cmp.Diff(expect, b.String()); diff != "" {
 		t.Errorf("unexpected output (-want +got):\n%s", diff)
@@ -89,7 +143,7 @@ func TestShowInfo(t *testing.T) {
 			stop you
 			stop up
 			temperature 99`,
-	}, &b); err != nil {
+	}, false, &b); err != nil {
 		t.Fatal(err)
 	}
 
@@ -126,7 +180,7 @@ func TestShowInfo(t *testing.T) {
 				"clip.vision.embedding_length": float64(0),
 				"clip.vision.projection_dim":   float64(0),
 			},
-	}, &b); err != nil {
+	}, false, &b); err != nil {
 		t.Fatal(err)
 	}
 
@@ -159,7 +213,7 @@ func TestShowInfo(t *testing.T) {
 Ahoy, matey!
 Weigh anchor!
 `,
-	}, &b); err != nil {
+	}, false, &b); err != nil {
 		t.Fatal(err)
 	}
 
@@ -171,6 +225,7 @@ Weigh anchor!
   System
     You are a pirate!
     Ahoy, matey!
+    ...
 
 `
 	if diff := cmp.Diff(expect, b.String()); diff != "" {
@@ -188,7 +243,7 @@ Weigh anchor!
 			QuantizationLevel: "FP16",
 		},
 		License: license,
-	}, &b); err != nil {
+	}, false, &b); err != nil {
 		t.Fatal(err)
 	}
 
@@ -206,6 +261,34 @@ Weigh anchor!
 		t.Errorf("unexpected output (-want +got):\n%s", diff)
 	}
 	})
+
+	t.Run("capabilities", func(t *testing.T) {
+		var b bytes.Buffer
+		if err := showInfo(&api.ShowResponse{
+			Details: api.ModelDetails{
+				Family:            "test",
+				ParameterSize:     "7B",
+				QuantizationLevel: "FP16",
+			},
+			Capabilities: []model.Capability{model.CapabilityVision, model.CapabilityTools},
+		}, false, &b); err != nil {
+			t.Fatal(err)
+		}
+
+		expect := "  Model\n" +
+			"    architecture    test    \n" +
+			"    parameters      7B      \n" +
+			"    quantization    FP16    \n" +
+			"\n" +
+			"  Capabilities\n" +
+			"    vision    \n" +
+			"    tools     \n" +
+			"\n"
+
+		if diff := cmp.Diff(expect, b.String()); diff != "" {
+			t.Errorf("unexpected output (-want +got):\n%s", diff)
+		}
+	})
 }
 
 func TestDeleteHandler(t *testing.T) {
@@ -254,7 +337,7 @@ func TestDeleteHandler(t *testing.T) {
 	t.Cleanup(mockServer.Close)
 
 	cmd := &cobra.Command{}
-	cmd.SetContext(context.TODO())
+	cmd.SetContext(t.Context())
 	if err := DeleteHandler(cmd, []string{"test-model"}); err != nil {
 		t.Fatalf("DeleteHandler failed: %v", err)
 	}
@@ -316,11 +399,6 @@ func TestGetModelfileName(t *testing.T) {
 			var expectedFilename string
 
 			if tt.fileExists {
-				tempDir, err := os.MkdirTemp("", "modelfiledir")
-				defer os.RemoveAll(tempDir)
-				if err != nil {
-					t.Fatalf("temp modelfile dir creation failed: %v", err)
-				}
 				var fn string
 				if tt.modelfileName != "" {
 					fn = tt.modelfileName
@@ -328,10 +406,11 @@ func TestGetModelfileName(t *testing.T) {
 					fn = "Modelfile"
 				}
 
-				tempFile, err := os.CreateTemp(tempDir, fn)
+				tempFile, err := os.CreateTemp(t.TempDir(), fn)
 				if err != nil {
 					t.Fatalf("temp modelfile creation failed: %v", err)
 				}
+				defer tempFile.Close()
 
 				expectedFilename = tempFile.Name()
 				err = cmd.Flags().Set("file", expectedFilename)
@@ -446,7 +525,7 @@ func TestPushHandler(t *testing.T) {
 
 	cmd := &cobra.Command{}
 	cmd.Flags().Bool("insecure", false, "")
-	cmd.SetContext(context.TODO())
+	cmd.SetContext(t.Context())
 
 	// Redirect stderr to capture progress output
 	oldStderr := os.Stderr
@@ -551,7 +630,7 @@ func TestListHandler(t *testing.T) {
 	t.Setenv("OLLAMA_HOST", mockServer.URL)
 
 	cmd := &cobra.Command{}
-	cmd.SetContext(context.TODO())
+	cmd.SetContext(t.Context())
 
 	// Capture stdout
 	oldStdout := os.Stdout
@@ -606,7 +685,7 @@ func TestCreateHandler(t *testing.T) {
 			return
 		}
 
-		if req.Name != "test-model" {
+		if req.Model != "test-model" {
 			t.Errorf("expected model name 'test-model', got %s", req.Name)
 		}
 
@@ -646,7 +725,7 @@ func TestCreateHandler(t *testing.T) {
 	}))
 	t.Setenv("OLLAMA_HOST", mockServer.URL)
 	t.Cleanup(mockServer.Close)
-	tempFile, err := os.CreateTemp("", "modelfile")
+	tempFile, err := os.CreateTemp(t.TempDir(), "modelfile")
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -666,7 +745,7 @@ func TestCreateHandler(t *testing.T) {
 	}
 
 	cmd.Flags().Bool("insecure", false, "")
-	cmd.SetContext(context.TODO())
+	cmd.SetContext(t.Context())
 
 	// Redirect stderr to capture progress output
 	oldStderr := os.Stderr
@@ -707,3 +786,132 @@ func TestCreateHandler(t *testing.T) {
 		})
 	}
 }
+
+func TestNewCreateRequest(t *testing.T) {
+	tests := []struct {
+		name     string
+		from     string
+		opts     runOptions
+		expected *api.CreateRequest
+	}{
+		{
+			"basic test",
+			"newmodel",
+			runOptions{
+				Model:       "mymodel",
+				ParentModel: "",
+				Prompt:      "You are a fun AI agent",
+				Messages:    []api.Message{},
+				WordWrap:    true,
+			},
+			&api.CreateRequest{
+				From:  "mymodel",
+				Model: "newmodel",
+			},
+		},
+		{
+			"parent model test",
+			"newmodel",
+			runOptions{
+				Model:       "mymodel",
+				ParentModel: "parentmodel",
+				Messages:    []api.Message{},
+				WordWrap:    true,
+			},
+			&api.CreateRequest{
+				From:  "parentmodel",
+				Model: "newmodel",
+			},
+		},
+		{
+			"parent model as filepath test",
+			"newmodel",
+			runOptions{
+				Model:       "mymodel",
+				ParentModel: "/some/file/like/etc/passwd",
+				Messages:    []api.Message{},
+				WordWrap:    true,
+			},
+			&api.CreateRequest{
+				From:  "mymodel",
+				Model: "newmodel",
+			},
+		},
+		{
+			"parent model as windows filepath test",
+			"newmodel",
+			runOptions{
+				Model:       "mymodel",
+				ParentModel: "D:\\some\\file\\like\\etc\\passwd",
+				Messages:    []api.Message{},
+				WordWrap:    true,
+			},
+			&api.CreateRequest{
+				From:  "mymodel",
+				Model: "newmodel",
+			},
+		},
+		{
+			"options test",
+			"newmodel",
+			runOptions{
+				Model:       "mymodel",
+				ParentModel: "parentmodel",
+				Options: map[string]any{
+					"temperature": 1.0,
+				},
+			},
+			&api.CreateRequest{
+				From:  "parentmodel",
+				Model: "newmodel",
+				Parameters: map[string]any{
+					"temperature": 1.0,
+				},
+			},
+		},
+		{
+			"messages test",
+			"newmodel",
+			runOptions{
+				Model:       "mymodel",
+				ParentModel: "parentmodel",
+				System:      "You are a fun AI agent",
+				Messages: []api.Message{
+					{
+						Role:    "user",
+						Content: "hello there!",
+					},
+					{
+						Role:    "assistant",
+						Content: "hello to you!",
+					},
+				},
+				WordWrap: true,
+			},
+			&api.CreateRequest{
+				From:   "parentmodel",
+				Model:  "newmodel",
+				System: "You are a fun AI agent",
+				Messages: []api.Message{
+					{
+						Role:    "user",
+						Content: "hello there!",
+					},
+					{
+						Role:    "assistant",
+						Content: "hello to you!",
+					},
+				},
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			actual := NewCreateRequest(tt.from, tt.opts)
+			if !cmp.Equal(actual, tt.expected) {
+				t.Errorf("expected output %#v, got %#v", tt.expected, actual)
+			}
+		})
+	}
+}
cmd/interactive.go
@@ -18,6 +18,7 @@ import (
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/readline"
 	"github.com/ollama/ollama/types/errtypes"
+	"github.com/ollama/ollama/types/model"
 )
 
 type MultilineState int
@@ -43,7 +44,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 	fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
 
 	if opts.MultiModal {
-		fmt.Fprintf(os.Stderr, "Use %s to include .jpg or .png images.\n", filepath.FromSlash("/path/to/file"))
+		fmt.Fprintf(os.Stderr, "Use %s to include .jpg, .png, or .webp images.\n", filepath.FromSlash("/path/to/file"))
 	}
 
 	fmt.Fprintln(os.Stderr, "")
@@ -61,6 +62,8 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 	fmt.Fprintln(os.Stderr, "  /set noformat      Disable formatting")
 	fmt.Fprintln(os.Stderr, "  /set verbose       Show LLM stats")
 	fmt.Fprintln(os.Stderr, "  /set quiet         Disable LLM stats")
+	fmt.Fprintln(os.Stderr, "  /set think         Enable thinking")
+	fmt.Fprintln(os.Stderr, "  /set nothink       Disable thinking")
 	fmt.Fprintln(os.Stderr, "")
 }
 
@@ -127,6 +130,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 
 	var sb strings.Builder
 	var multiline MultilineState
+	var thinkExplicitlySet bool = opts.Think != nil
 
 	for {
 		line, err := scanner.Readline()
@@ -194,7 +198,19 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 			opts.Model = args[1]
 			opts.Messages = []api.Message{}
 			fmt.Printf("Loading model '%s'\n", opts.Model)
+			opts.Think, err = inferThinkingOption(nil, &opts, thinkExplicitlySet)
+			if err != nil {
+				return err
+			}
 			if err := loadOrUnloadModel(cmd, &opts); err != nil {
+				if strings.Contains(err.Error(), "not found") {
+					fmt.Printf("error: %v\n", err)
+					continue
+				}
+				if strings.Contains(err.Error(), "does not support thinking") {
+					fmt.Printf("error: %v\n", err)
+					continue
+				}
 				return err
 			}
 			continue
@@ -255,6 +271,35 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 				return err
 			}
 			fmt.Println("Set 'quiet' mode.")
+		case "think":
+			thinkValue := api.ThinkValue{Value: true}
+			var maybeLevel string
+			if len(args) > 2 {
+				maybeLevel = args[2]
+			}
+			if maybeLevel != "" {
+				// TODO(drifkin): validate the level, could be model dependent
+				// though... It will also be validated on the server once a call is
+				// made.
+				thinkValue.Value = maybeLevel
+			}
+			opts.Think = &thinkValue
+			thinkExplicitlySet = true
+			if client, err := api.ClientFromEnvironment(); err == nil {
+				ensureThinkingSupport(cmd.Context(), client, opts.Model)
+			}
+			if maybeLevel != "" {
+				fmt.Printf("Set 'think' mode to '%s'.\n", maybeLevel)
+			} else {
+				fmt.Println("Set 'think' mode.")
+			}
+		case "nothink":
+			opts.Think = &api.ThinkValue{Value: false}
+			thinkExplicitlySet = true
+			if client, err := api.ClientFromEnvironment(); err == nil {
+				ensureThinkingSupport(cmd.Context(), client, opts.Model)
+			}
+			fmt.Println("Set 'nothink' mode.")
 		case "format":
 			if len(args) < 3 || args[2] != "json" {
 				fmt.Println("Invalid or missing format. For 'json' mode use '/set format json'")
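Inside an interactive session the new commands mirror the --think flag; the optional third token sets a level:

    /set think
    /set think high
    /set nothink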
@@ -343,7 +388,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 
 			switch args[1] {
 			case "info":
-				_ = showInfo(resp, os.Stderr)
+				_ = showInfo(resp, false, os.Stderr)
 			case "license":
 				if resp.License == "" {
 					fmt.Println("No license was specified for this model.")
@@ -353,18 +398,21 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 			case "modelfile":
 				fmt.Println(resp.Modelfile)
 			case "parameters":
+				fmt.Println("Model defined parameters:")
 				if resp.Parameters == "" {
-					fmt.Println("No parameters were specified for this model.")
+					fmt.Println("  No additional parameters were specified for this model.")
 				} else {
-					if len(opts.Options) > 0 {
-						fmt.Println("User defined parameters:")
-						for k, v := range opts.Options {
-							fmt.Printf("%-*s %v\n", 30, k, v)
-						}
-						fmt.Println()
+					for _, l := range strings.Split(resp.Parameters, "\n") {
+						fmt.Printf("  %s\n", l)
 					}
-					fmt.Println("Model defined parameters:")
-					fmt.Println(resp.Parameters)
+				}
+				fmt.Println()
+				if len(opts.Options) > 0 {
+					fmt.Println("User defined parameters:")
+					for k, v := range opts.Options {
+						fmt.Printf("  %-*s %v\n", 30, k, v)
+					}
+					fmt.Println()
 				}
 			case "system":
 				switch {
@@ -443,6 +491,12 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 
 		assistant, err := chat(cmd, opts)
 		if err != nil {
+			if strings.Contains(err.Error(), "does not support thinking") ||
+				strings.Contains(err.Error(), "invalid think value") {
+				fmt.Printf("error: %v\n", err)
+				sb.Reset()
+				continue
+			}
 			return err
 		}
 		if assistant != nil {
@@ -455,9 +509,16 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 }
 
 func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
+	parentModel := opts.ParentModel
+
+	modelName := model.ParseName(parentModel)
+	if !modelName.IsValid() {
+		parentModel = ""
+	}
+
 	req := &api.CreateRequest{
-		Name: name,
-		From: cmp.Or(opts.ParentModel, opts.Model),
+		Model: name,
+		From:  cmp.Or(parentModel, opts.Model),
 	}
 
 	if opts.System != "" {
@@ -491,6 +552,7 @@ func normalizeFilePath(fp string) string {
 		"\\\\", "\\", // Escaped backslash
 		"\\*", "*", // Escaped asterisk
 		"\\?", "?", // Escaped question mark
+		"\\~", "~", // Escaped tilde
 	).Replace(fp)
 }
 
@@ -498,7 +560,7 @@ func extractFileNames(input string) []string {
 	// Regex to match file paths starting with optional drive letter, / ./ \ or .\ and include escaped or unescaped spaces (\ or %20)
 	// and followed by more characters and a file extension
 	// This will capture non filename strings, but we'll check for file existence to remove mismatches
-	regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png)\b`
+	regexPattern := `(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`
 	re := regexp.MustCompile(regexPattern)
 
 	return re.FindAllString(input, -1)
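A quick standalone check that the widened pattern now picks up WebP (case-insensitively) while still ignoring unsupported extensions; this sketch is not part of the diff:

    package main

    import (
        "fmt"
        "regexp"
    )

    func main() {
        re := regexp.MustCompile(`(?:[a-zA-Z]:)?(?:\./|/|\\)[\S\\ ]+?\.(?i:jpg|jpeg|png|webp)\b`)
        input := `see /tmp/a.webp and ./pics/b.PNG but not ./c.svg`
        fmt.Println(re.FindAllString(input, -1)) // [/tmp/a.webp ./pics/b.PNG]
    }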
@@ -518,6 +580,8 @@ func extractFileData(input string) (string, []api.ImageData, error) {
 			return "", imgs, err
 		}
 		fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
+		input = strings.ReplaceAll(input, "'"+nfp+"'", "")
+		input = strings.ReplaceAll(input, "'"+fp+"'", "")
 		input = strings.ReplaceAll(input, fp, "")
 		imgs = append(imgs, data)
 	}
@@ -538,7 +602,7 @@ func getImageData(filePath string) ([]byte, error) {
 	}
 
 	contentType := http.DetectContentType(buf)
-	allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
+	allowedTypes := []string{"image/jpeg", "image/jpg", "image/png", "image/webp"}
 	if !slices.Contains(allowedTypes, contentType) {
 		return nil, fmt.Errorf("invalid image type: %s", contentType)
 	}
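http.DetectContentType implements the WHATWG MIME sniffing algorithm, which already recognizes WebP from the leading RIFF container magic, so allowlisting "image/webp" is all the change needed here; no new decoder is involved. A small illustration (the bytes are the standard container magic, padded out for the sniffer):

    header := append([]byte("RIFF\x00\x00\x00\x00WEBPVP8 "), make([]byte, 512)...)
    fmt.Println(http.DetectContentType(header[:512])) // image/webp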
cmd/interactive_test.go
@@ -1,6 +1,8 @@
 package cmd
 
 import (
+	"os"
+	"path/filepath"
 	"testing"
 
 	"github.com/stretchr/testify/assert"
@@ -10,14 +12,17 @@ func TestExtractFilenames(t *testing.T) {
 	// Unix style paths
 	input := ` some preamble
 ./relative\ path/one.png inbetween1 ./not a valid two.jpg inbetween2 ./1.svg
-/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG`
+/unescaped space /three.jpeg inbetween3 /valid\ path/dir/four.png "./quoted with spaces/five.JPG
+/unescaped space /six.webp inbetween6 /valid\ path/dir/seven.WEBP`
 	res := extractFileNames(input)
-	assert.Len(t, res, 5)
+	assert.Len(t, res, 7)
 	assert.Contains(t, res[0], "one.png")
 	assert.Contains(t, res[1], "two.jpg")
 	assert.Contains(t, res[2], "three.jpeg")
 	assert.Contains(t, res[3], "four.png")
 	assert.Contains(t, res[4], "five.JPG")
+	assert.Contains(t, res[5], "six.webp")
+	assert.Contains(t, res[6], "seven.WEBP")
 	assert.NotContains(t, res[4], '"')
 	assert.NotContains(t, res, "inbetween1")
 	assert.NotContains(t, res, "./1.svg")
@@ -28,10 +33,12 @@ func TestExtractFilenames(t *testing.T) {
 /absolute/nospace/three.jpeg inbetween3 /absolute/with space/four.png inbetween4
 ./relative\ path/five.JPG inbetween5 "./relative with/spaces/six.png inbetween6
 d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8
-d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG some ending
+d:\program files\someplace\nine.png inbetween9 "E:\program files\someplace\ten.PNG
+c:/users/jdoe/eleven.webp inbetween11 c:/program files/someplace/twelve.WebP inbetween12
+d:\path with\spaces\thirteen.WEBP some ending
 `
 	res = extractFileNames(input)
-	assert.Len(t, res, 10)
+	assert.Len(t, res, 13)
 	assert.NotContains(t, res, "inbetween2")
 	assert.Contains(t, res[0], "one.png")
 	assert.Contains(t, res[0], "c:")
@@ -49,4 +56,31 @@ d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8
 	assert.Contains(t, res[8], "d:")
 	assert.Contains(t, res[9], "ten.PNG")
 	assert.Contains(t, res[9], "E:")
+	assert.Contains(t, res[10], "eleven.webp")
+	assert.Contains(t, res[10], "c:")
+	assert.Contains(t, res[11], "twelve.WebP")
+	assert.Contains(t, res[11], "c:")
+	assert.Contains(t, res[12], "thirteen.WEBP")
+	assert.Contains(t, res[12], "d:")
+}
+
+// Ensure that file paths wrapped in single quotes are removed with the quotes.
+func TestExtractFileDataRemovesQuotedFilepath(t *testing.T) {
+	dir := t.TempDir()
+	fp := filepath.Join(dir, "img.jpg")
+	data := make([]byte, 600)
+	copy(data, []byte{
+		0xff, 0xd8, 0xff, 0xe0, 0x00, 0x10, 'J', 'F', 'I', 'F',
+		0x00, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+		0xff, 0xd9,
+	})
+	if err := os.WriteFile(fp, data, 0o600); err != nil {
+		t.Fatalf("failed to write test image: %v", err)
+	}
+
+	input := "before '" + fp + "' after"
+	cleaned, imgs, err := extractFileData(input)
+	assert.NoError(t, err)
+	assert.Len(t, imgs, 1)
+	assert.Equal(t, cleaned, "before after")
 }
cmd/start_darwin.go
@@ -5,7 +5,7 @@ import (
 	"errors"
 	"os"
 	"os/exec"
-	"strings"
+	"regexp"
 
 	"github.com/ollama/ollama/api"
 )
@@ -19,11 +19,12 @@ func startApp(ctx context.Context, client *api.Client) error {
 	if err != nil {
 		return err
 	}
-	if !strings.Contains(link, "Ollama.app") {
+	r := regexp.MustCompile(`^.*/Ollama\s?\d*.app`)
+	m := r.FindStringSubmatch(link)
+	if len(m) != 1 {
 		return errors.New("could not find ollama app")
 	}
-	path := strings.Split(link, "Ollama.app")
-	if err := exec.Command("/usr/bin/open", "-a", path[0]+"Ollama.app").Run(); err != nil {
+	if err := exec.Command("/usr/bin/open", "-j", "-a", m[0], "--args", "--fast-startup").Run(); err != nil {
 		return err
 	}
 	return waitForServer(ctx, client)
cmd/start_windows.go
@@ -4,17 +4,27 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"log/slog"
 	"os"
 	"os/exec"
+	"path"
 	"path/filepath"
 	"strings"
 	"syscall"
+	"unsafe"
 
 	"github.com/ollama/ollama/api"
+	"golang.org/x/sys/windows"
+)
+
+const (
+	Installer = "OllamaSetup.exe"
 )
 
 func startApp(ctx context.Context, client *api.Client) error {
-	// log.Printf("XXX Attempting to find and start ollama app")
+	if len(isProcRunning(Installer)) > 0 {
+		return fmt.Errorf("upgrade in progress...")
+	}
 	AppName := "ollama app.exe"
 	exe, err := os.Executable()
 	if err != nil {
@@ -35,14 +45,11 @@ func startApp(ctx context.Context, client *api.Client) error {
 			}
 		}
 	}
-	// log.Printf("XXX attempting to start app %s", appExe)
+
 	cmd_path := "c:\\Windows\\system32\\cmd.exe"
-	cmd := exec.Command(cmd_path, "/c", appExe)
-	// TODO - these hide flags aren't working - still pops up a command window for some reason
+	cmd := exec.Command(cmd_path, "/c", appExe, "--hide", "--fast-startup")
 	cmd.SysProcAttr = &syscall.SysProcAttr{CreationFlags: 0x08000000, HideWindow: true}
-
-	// TODO this didn't help either...
 	cmd.Stdin = strings.NewReader("")
 	cmd.Stdout = os.Stdout
 	cmd.Stderr = os.Stderr
@@ -56,3 +63,50 @@ func startApp(ctx context.Context, client *api.Client) error {
 	}
 	return waitForServer(ctx, client)
 }
+
+func isProcRunning(procName string) []uint32 {
+	pids := make([]uint32, 2048)
+	var ret uint32
+	if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 {
+		slog.Debug("failed to check for running installers", "error", err)
+		return nil
+	}
+	if ret > uint32(len(pids)) {
+		pids = make([]uint32, ret+10)
+		if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 {
+			slog.Debug("failed to check for running installers", "error", err)
+			return nil
+		}
+	}
+	if ret < uint32(len(pids)) {
+		pids = pids[:ret]
+	}
+	var matches []uint32
+	for _, pid := range pids {
+		if pid == 0 {
+			continue
+		}
+		hProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION|windows.PROCESS_VM_READ, false, pid)
+		if err != nil {
+			continue
+		}
+		defer windows.CloseHandle(hProcess)
+		var module windows.Handle
+		var cbNeeded uint32
+		cb := (uint32)(unsafe.Sizeof(module))
+		if err := windows.EnumProcessModules(hProcess, &module, cb, &cbNeeded); err != nil {
+			continue
+		}
+		var sz uint32 = 1024 * 8
+		moduleName := make([]uint16, sz)
+		cb = uint32(len(moduleName)) * (uint32)(unsafe.Sizeof(uint16(0)))
+		if err := windows.GetModuleBaseName(hProcess, module, &moduleName[0], cb); err != nil && err != syscall.ERROR_INSUFFICIENT_BUFFER {
+			continue
+		}
+		exeFile := path.Base(strings.ToLower(syscall.UTF16ToString(moduleName)))
+		if strings.EqualFold(exeFile, procName) {
+			matches = append(matches, pid)
+		}
+	}
+	return matches
+}
cmd/warn_thinking_test.go (new file, 63 lines)
@@ -0,0 +1,63 @@
+package cmd
+
+import (
+	"encoding/json"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"strings"
+	"testing"
+
+	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/types/model"
+)
+
+// Test that a warning is printed when thinking is requested but not supported.
+func TestWarnMissingThinking(t *testing.T) {
+	cases := []struct {
+		capabilities []model.Capability
+		expectWarn   bool
+	}{
+		{capabilities: []model.Capability{model.CapabilityThinking}, expectWarn: false},
+		{capabilities: []model.Capability{}, expectWarn: true},
+	}
+
+	for _, tc := range cases {
+		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+			if r.URL.Path != "/api/show" || r.Method != http.MethodPost {
+				t.Fatalf("unexpected request to %s %s", r.URL.Path, r.Method)
+			}
+			var req api.ShowRequest
+			if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+				t.Fatalf("decode request: %v", err)
+			}
+			resp := api.ShowResponse{Capabilities: tc.capabilities}
+			if err := json.NewEncoder(w).Encode(resp); err != nil {
+				t.Fatalf("encode response: %v", err)
+			}
+		}))
+		defer srv.Close()
+
+		t.Setenv("OLLAMA_HOST", srv.URL)
+		client, err := api.ClientFromEnvironment()
+		if err != nil {
+			t.Fatal(err)
+		}
+		oldStderr := os.Stderr
+		r, w, _ := os.Pipe()
+		os.Stderr = w
+		ensureThinkingSupport(t.Context(), client, "m")
+		w.Close()
+		os.Stderr = oldStderr
+		out, _ := io.ReadAll(r)
+
+		warned := strings.Contains(string(out), "warning:")
+		if tc.expectWarn && !warned {
+			t.Errorf("expected warning, got none")
+		}
+		if !tc.expectWarn && warned {
+			t.Errorf("did not expect warning, got: %s", string(out))
+		}
+	}
+}
convert/convert.go

@@ -1,12 +1,14 @@
 package convert
 
 import (
+	"cmp"
 	"encoding/json"
 	"errors"
 	"fmt"
-	"io"
 	"io/fs"
 	"log/slog"
+	"os"
+	"slices"
 	"strings"
 
 	"github.com/ollama/ollama/fs/ggml"
@@ -15,6 +17,10 @@ import (
 type ModelParameters struct {
 	Architectures []string `json:"architectures"`
 	VocabSize     uint32   `json:"vocab_size"`
+
+	TextModel struct {
+		VocabSize uint32 `json:"vocab_size"`
+	} `json:"text_config"`
 }
 
 type AdapterParameters struct {
@@ -47,8 +53,11 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV {
 	}
 
 	for _, sv := range t.SpecialVocabulary {
-		kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID)
 		kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
+		kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID)
+		if len(sv.IDs) > 0 {
+			kv[fmt.Sprintf("tokenizer.ggml.%s_token_ids", sv.Key())] = sv.IDs
+		}
 	}
 
 	return kv
@@ -79,27 +88,17 @@ func (ModelParameters) specialTokenTypes() []string {
 	}
 }
 
-func (ModelParameters) writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
-	return ggml.WriteGGUF(ws, kv, ts)
-}
-
-func (AdapterParameters) writeFile(ws io.WriteSeeker, kv ggml.KV, ts []ggml.Tensor) error {
-	return ggml.WriteGGUF(ws, kv, ts)
-}
-
 type ModelConverter interface {
 	// KV maps parameters to LLM key-values
 	KV(*Tokenizer) ggml.KV
 	// Tensors maps input tensors to LLM tensors. Model specific modifications can be done here.
-	Tensors([]Tensor) []ggml.Tensor
+	Tensors([]Tensor) []*ggml.Tensor
 	// Replacements returns a list of string pairs to replace in tensor names.
 	// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
 	Replacements() []string
 
 	// specialTokenTypes returns any special token types the model uses
 	specialTokenTypes() []string
-	// writeFile writes the model to the provided io.WriteSeeker
-	writeFile(io.WriteSeeker, ggml.KV, []ggml.Tensor) error
 }
 
 type moreParser interface {
@@ -110,15 +109,13 @@ type AdapterConverter interface {
 	// KV maps parameters to LLM key-values
 	KV(ggml.KV) ggml.KV
 	// Tensors maps input tensors to LLM tensors. Adapter specific modifications can be done here.
-	Tensors([]Tensor) []ggml.Tensor
+	Tensors([]Tensor) []*ggml.Tensor
 	// Replacements returns a list of string pairs to replace in tensor names.
 	// See [strings.Replacer](https://pkg.go.dev/strings#Replacer) for details
 	Replacements() []string
-
-	writeFile(io.WriteSeeker, ggml.KV, []ggml.Tensor) error
 }
 
-func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV ggml.KV) error {
+func ConvertAdapter(fsys fs.FS, f *os.File, baseKV ggml.KV) error {
 	bts, err := fs.ReadFile(fsys, "adapter_config.json")
 	if err != nil {
 		return err
@@ -153,14 +150,14 @@ func ConvertAdapter(fsys fs.FS, ws io.WriteSeeker, baseKV ggml.KV) error {
 		return err
 	}
 
-	return conv.writeFile(ws, conv.KV(baseKV), conv.Tensors(ts))
+	return writeFile(f, conv.KV(baseKV), conv.Tensors(ts))
 }
 
 // Convert writes an Ollama compatible model to the provided io.WriteSeeker based on configurations
 // and files it finds in the input path.
 // Supported input model formats include safetensors.
 // Supported input tokenizers files include tokenizer.json (preferred) and tokenizer.model.
-func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
+func ConvertModel(fsys fs.FS, f *os.File) error {
 	bts, err := fs.ReadFile(fsys, "config.json")
 	if err != nil {
 		return err
@@ -177,24 +174,38 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
 
 	var conv ModelConverter
 	switch p.Architectures[0] {
-	case "LlamaForCausalLM", "MistralForCausalLM":
+	case "LlamaForCausalLM":
 		conv = &llamaModel{}
+	case "MllamaForConditionalGeneration":
+		conv = &mllamaModel{}
+	case "Llama4ForConditionalGeneration":
+		conv = &llama4Model{}
+	case "Mistral3ForConditionalGeneration":
+		conv = &mistral3Model{}
 	case "MixtralForCausalLM":
 		conv = &mixtralModel{}
 	case "GemmaForCausalLM":
 		conv = &gemmaModel{}
 	case "Gemma2ForCausalLM":
 		conv = &gemma2Model{}
+	case "Gemma3ForCausalLM", "Gemma3ForConditionalGeneration":
+		conv = &gemma3Model{Architecture: p.Architectures[0]}
+	case "Gemma3nForConditionalGeneration":
+		conv = &gemma3nModel{}
 	case "Phi3ForCausalLM":
 		conv = &phi3Model{}
 	case "Qwen2ForCausalLM":
 		conv = &qwen2Model{}
+	case "Qwen2_5_VLForConditionalGeneration":
+		conv = &qwen25VLModel{}
 	case "BertModel":
 		conv = &bertModel{}
 	case "CohereForCausalLM":
 		conv = &commandrModel{}
+	case "GptOssForCausalLM":
+		conv = &gptossModel{}
 	default:
-		return errors.New("unsupported architecture")
+		return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
 	}
 
 	if err := json.Unmarshal(bts, conv); err != nil {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
vocabSize := int(p.VocabSize)
|
vocabSize := int(cmp.Or(p.VocabSize, p.TextModel.VocabSize))
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
|
case vocabSize == 0:
|
||||||
|
slog.Debug("vocabulary size was not explicitly set by the model", "default size", len(t.Vocabulary.Tokens))
|
||||||
case vocabSize > len(t.Vocabulary.Tokens):
|
case vocabSize > len(t.Vocabulary.Tokens):
|
||||||
slog.Warn("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
|
slog.Debug("vocabulary is smaller than expected, padding with dummy tokens", "expect", vocabSize, "actual", len(t.Vocabulary.Tokens))
|
||||||
for i := range vocabSize - len(t.Vocabulary.Tokens) {
|
for i := range vocabSize - len(t.Vocabulary.Tokens) {
|
||||||
t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
|
t.Vocabulary.Tokens = append(t.Vocabulary.Tokens, fmt.Sprintf("[PAD%d]", i))
|
||||||
t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
|
t.Vocabulary.Scores = append(t.Vocabulary.Scores, -1)
|
||||||
t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
|
t.Vocabulary.Types = append(t.Vocabulary.Types, tokenTypeUserDefined)
|
||||||
}
|
}
|
||||||
case vocabSize < len(t.Vocabulary.Tokens):
|
case vocabSize < len(t.Vocabulary.Tokens):
|
||||||
return fmt.Errorf("vocabulary is larger than expected '%d' instead of '%d'", len(t.Vocabulary.Tokens), vocabSize)
|
slog.Debug("vocabulary is larger than expected", "want", vocabSize, "got", len(t.Vocabulary.Tokens))
|
||||||
|
p.VocabSize = uint32(len(t.Vocabulary.Tokens))
|
||||||
|
p.TextModel.VocabSize = uint32(len(t.Vocabulary.Tokens))
|
||||||
default:
|
default:
|
||||||
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
|
slog.Debug("vocabulary", "size", len(t.Vocabulary.Tokens))
|
||||||
}
|
}
|
||||||
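The `cmp.Or` call above (standard library `cmp`, Go 1.22+) returns the first non-zero value of its arguments, which is how the converter falls back from a top-level `vocab_size` to one nested under `text_config`. A small standalone illustration with made-up values:

```go
package main

import (
	"cmp"
	"fmt"
)

func main() {
	var topLevel uint32 = 0       // vocab_size missing from config.json
	var textConfig uint32 = 32000 // text_config.vocab_size present

	// cmp.Or returns the first of its arguments that is not the zero value.
	fmt.Println(cmp.Or(topLevel, textConfig)) // 32000
}
```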
@@ -232,5 +248,13 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
         return err
     }

-    return conv.writeFile(ws, conv.KV(t), conv.Tensors(ts))
+    return writeFile(f, conv.KV(t), conv.Tensors(ts))
+}
+
+func writeFile(f *os.File, kv ggml.KV, ts []*ggml.Tensor) error {
+    for i := range ts {
+        ts[i].Shape = slices.Clone(ts[i].Shape)
+        slices.Reverse(ts[i].Shape)
+    }
+    return ggml.WriteGGUF(f, kv, ts)
 }
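The shape reversal in `writeFile` reflects the dimension-order mismatch between safetensors (outermost dimension first) and GGUF/GGML (which lists dims innermost first). A minimal sketch of the effect, independent of the real `ggml.Tensor` type:

```go
package main

import (
	"fmt"
	"slices"
)

func main() {
	// A [vocab, hidden] embedding as safetensors reports it.
	shape := []uint64{32000, 4096}

	// Clone before reversing so the caller's slice is untouched,
	// mirroring what writeFile does in this change.
	ggufShape := slices.Clone(shape)
	slices.Reverse(ggufShape)

	fmt.Println(ggufShape) // [4096 32000], GGML dims, innermost first
}
```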
@@ -132,8 +132,8 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
     return kv
 }

-func (p *bertModel) Tensors(ts []Tensor) []ggml.Tensor {
-    var out []ggml.Tensor
+func (p *bertModel) Tensors(ts []Tensor) []*ggml.Tensor {
+    var out []*ggml.Tensor
     for _, t := range ts {
         if slices.Contains([]string{
             "embeddings.position_ids",
@@ -143,7 +143,7 @@ func (p *bertModel) Tensors(ts []Tensor) []ggml.Tensor {
             continue
         }

-        out = append(out, ggml.Tensor{
+        out = append(out, &ggml.Tensor{
             Name:  t.Name(),
             Kind:  t.Kind(),
             Shape: t.Shape(),
@@ -43,10 +43,10 @@ func (p *commandrModel) KV(t *Tokenizer) ggml.KV {
     return kv
 }

-func (p *commandrModel) Tensors(ts []Tensor) []ggml.Tensor {
-    var out []ggml.Tensor
+func (p *commandrModel) Tensors(ts []Tensor) []*ggml.Tensor {
+    var out []*ggml.Tensor
     for _, t := range ts {
-        out = append(out, ggml.Tensor{
+        out = append(out, &ggml.Tensor{
             Name:  t.Name(),
             Kind:  t.Kind(),
             Shape: t.Shape(),
@@ -42,14 +42,14 @@ func (p *gemmaModel) KV(t *Tokenizer) ggml.KV {
     return kv
 }

-func (p *gemmaModel) Tensors(ts []Tensor) []ggml.Tensor {
-    var out []ggml.Tensor
+func (p *gemmaModel) Tensors(ts []Tensor) []*ggml.Tensor {
+    var out []*ggml.Tensor
     for _, t := range ts {
-        if strings.HasSuffix(t.Name(), "_norm.weight") {
+        if !strings.HasPrefix(t.Name(), "v.") && strings.HasSuffix(t.Name(), "_norm.weight") {
             t.SetRepacker(p.addOne)
         }

-        out = append(out, ggml.Tensor{
+        out = append(out, &ggml.Tensor{
             Name:  t.Name(),
             Kind:  t.Kind(),
             Shape: t.Shape(),
@@ -21,8 +21,8 @@ func (p *gemma2Adapter) KV(baseKV ggml.KV) ggml.KV {
     return kv
 }

-func (p *gemma2Adapter) Tensors(ts []Tensor) []ggml.Tensor {
-    var out []ggml.Tensor
+func (p *gemma2Adapter) Tensors(ts []Tensor) []*ggml.Tensor {
+    var out []*ggml.Tensor
     for _, t := range ts {
         shape := t.Shape()
         if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
@@ -31,7 +31,7 @@ func (p *gemma2Adapter) Tensors(ts []Tensor) []ggml.Tensor {
             t.SetRepacker(p.repack)
         }

-        out = append(out, ggml.Tensor{
+        out = append(out, &ggml.Tensor{
             Name:  t.Name(),
             Kind:  t.Kind(),
             Shape: t.Shape(),
convert/convert_gemma3.go (new file, 142 lines)
@@ -0,0 +1,142 @@
+package convert
+
+import (
+    "cmp"
+
+    "github.com/ollama/ollama/fs/ggml"
+)
+
+type gemma3Model struct {
+    gemmaModel
+    Architecture string
+    TextModel    struct {
+        HeadDim          uint32 `json:"head_dim"`
+        HiddenSize       uint32 `json:"hidden_size"`
+        HiddenLayers     uint32 `json:"num_hidden_layers"`
+        IntermediateSize uint32 `json:"intermediate_size"`
+        SlidingWindow    uint32 `json:"sliding_window"`
+    } `json:"text_config"`
+    VisionModel struct {
+        NumAttentionHeads uint32  `json:"num_attention_heads"` // attention.head_count 16
+        LayerNormEpsilon  float32 `json:"layer_norm_eps"`      // attention.layer_norm_epsilon 1e-05
+        NumHiddenLayers   uint32  `json:"num_hidden_layers"`   // block_count 32
+        HiddenSize        uint32  `json:"hidden_size"`         // embedding_length 1280
+        IntermediateSize  uint32  `json:"intermediate_size"`   // feed_forward_length 5120
+        ImageSize         uint32  `json:"image_size"`          // image_size 560
+        NumChannels       uint32  `json:"num_channels"`        // num_channels 3
+        PatchSize         uint32  `json:"patch_size"`          // patch_size 14
+    } `json:"vision_config"`
+    MaxPositionEmbeddings    uint32  `json:"max_position_embeddings"`
+    NumAttentionHeads        uint32  `json:"num_attention_heads"`
+    NumKeyValueHeads         uint32  `json:"num_key_value_heads"`
+    RMSNormEPS               float32 `json:"rms_norm_eps"`
+    HeadDim                  uint32  `json:"head_dim"`
+    FinalLogitSoftcap        float32 `json:"final_logit_softcapping"`
+    RopeLocalTheta           float32 `json:"rope_local_base_freq"`
+    RopeGlobalTheta          float32 `json:"rope_global_base_freq"`
+    SlidingWindow            uint32  `json:"sliding_window"`
+    MultiModalTokensPerImage uint32  `json:"mm_tokens_per_image"`
+}
+
+const (
+    gemma4BLayerCount  = 34
+    gemma12BLayerCount = 48
+    gemma27BLayerCount = 62
+)
+
+func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
+    kv := p.ModelParameters.KV(t)
+    kv["general.architecture"] = "gemma3"
+
+    numBlocks := cmp.Or(p.HiddenLayers, p.TextModel.HiddenLayers)
+    kv["gemma3.block_count"] = numBlocks
+
+    var (
+        numHeads   uint32
+        numKVHeads uint32
+    )
+
+    switch numBlocks {
+    case gemma4BLayerCount:
+        numHeads = 8
+        numKVHeads = 4
+    case gemma12BLayerCount:
+        numHeads = 16
+        numKVHeads = 8
+    case gemma27BLayerCount:
+        numHeads = 32
+        numKVHeads = 16
+    default:
+        numHeads = p.NumAttentionHeads
+        numKVHeads = p.NumKeyValueHeads
+    }
+
+    kv["gemma3.attention.head_count"] = numHeads
+    kv["gemma3.attention.head_count_kv"] = numKVHeads
+
+    switch p.Architecture {
+    case "Gemma3ForCausalLM":
+        kv["gemma3.context_length"] = p.MaxPositionEmbeddings
+        kv["gemma3.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
+        kv["gemma3.attention.key_length"] = p.HeadDim
+        kv["gemma3.attention.value_length"] = p.HeadDim
+        kv["gemma3.attention.sliding_window"] = p.SlidingWindow
+        kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
+        kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
+        kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
+        kv["gemma3.embedding_length"] = p.HiddenSize
+        kv["gemma3.feed_forward_length"] = p.IntermediateSize
+    default:
+        kv["gemma3.context_length"] = cmp.Or(p.MaxPositionEmbeddings, 131072)
+        kv["gemma3.embedding_length"] = p.TextModel.HiddenSize
+        kv["gemma3.feed_forward_length"] = p.TextModel.IntermediateSize
+        kv["gemma3.attention.sliding_window"] = p.TextModel.SlidingWindow
+        kv["gemma3.vision.block_count"] = p.VisionModel.NumHiddenLayers
+        kv["gemma3.vision.embedding_length"] = p.VisionModel.HiddenSize
+        kv["gemma3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
+        kv["gemma3.vision.image_size"] = p.VisionModel.ImageSize
+        kv["gemma3.vision.patch_size"] = p.VisionModel.PatchSize
+        kv["gemma3.vision.num_channels"] = cmp.Or(p.VisionModel.NumChannels, 3)
+        kv["gemma3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
+        kv["gemma3.vision.attention.layer_norm_epsilon"] = cmp.Or(p.VisionModel.LayerNormEpsilon, 1e-6)
+        kv["gemma3.attention.key_length"] = cmp.Or(p.TextModel.HeadDim, 256)
+        kv["gemma3.attention.value_length"] = cmp.Or(p.TextModel.HeadDim, 256)
+    }
+
+    if p.MultiModalTokensPerImage > 0 {
+        kv["gemma3.mm.tokens_per_image"] = p.MultiModalTokensPerImage
+    }
+
+    return kv
+}
+
+func (p *gemma3Model) Replacements() []string {
+    return []string{
+        "lm_head", "output",
+        "model.embed_tokens", "token_embd",
+        "model.norm", "output_norm",
+        "vision_tower.vision_model.embeddings", "v",
+        "vision_tower.vision_model", "v",
+        "vision_model.vision_model.embeddings", "v",
+        "vision_model.vision_model", "v",
+        "language_model.", "",
+        "model.layers", "blk",
+        "encoder.layers", "blk",
+        "input_layernorm", "attn_norm",
+        "self_attn.q_proj", "attn_q",
+        "self_attn.q_norm", "attn_q_norm",
+        "self_attn.k_proj", "attn_k",
+        "self_attn.k_norm", "attn_k_norm",
+        "self_attn.v_proj", "attn_v",
+        "self_attn.o_proj", "attn_output",
+        "self_attn.out_proj", "attn_output",
+        "mlp.gate_proj", "ffn_gate",
+        "mlp.down_proj", "ffn_down",
+        "mlp.up_proj", "ffn_up",
+        "post_attention_layernorm", "post_attention_norm",
+        "pre_feedforward_layernorm", "ffn_norm",
+        "post_feedforward_layernorm", "post_ffw_norm",
+        "input_projection_weight", "input_projection.weight",
+        "multi_modal_projector", "mm",
+    }
+}
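The `Replacements` pairs above are old/new substrings applied to every source tensor name. A standalone sketch of how such a pair list rewrites a Hugging Face name into the GGUF naming scheme, assuming the framework applies the pairs with something like `strings.NewReplacer` (the applying code is not part of this diff):

```go
package main

import (
	"fmt"
	"strings"
)

func main() {
	// A subset of the gemma3 replacement pairs from the diff.
	r := strings.NewReplacer(
		"model.layers", "blk",
		"input_layernorm", "attn_norm",
		"self_attn.q_proj", "attn_q",
	)

	fmt.Println(r.Replace("model.layers.0.self_attn.q_proj.weight"))
	// Output: blk.0.attn_q.weight
}
```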
convert/convert_gemma3n.go (new file, 165 lines)
@@ -0,0 +1,165 @@
+package convert
+
+import (
+    "slices"
+    "strings"
+
+    "github.com/ollama/ollama/fs/ggml"
+    "github.com/pdevine/tensor"
+    "github.com/pdevine/tensor/native"
+    "gonum.org/v1/gonum/stat/distuv"
+)
+
+type gemma3nModel struct {
+    ModelParameters
+
+    TextModel struct {
+        ActivationSparsityPattern []float32 `json:"activation_sparsity_pattern"`
+        AltupActiveIdx            uint32    `json:"altup_active_idx"`
+        AltupCoefClip             float32   `json:"altup_coef_clip"`
+        AltupCorrectScale         bool      `json:"altup_correct_scale"`
+        AltupLRMultiplier         float32   `json:"altup_lr_multiplier"`
+        AltupNumInputs            uint32    `json:"altup_num_inputs"`
+        HeadDim                   uint32    `json:"head_dim"`
+        HiddenSize                uint32    `json:"hidden_size"`
+        HiddenSizePerLayerInput   uint32    `json:"hidden_size_per_layer_input"`
+        IntermediateSize          uint32    `json:"intermediate_size"`
+        MaxPositionEmbeddings     uint32    `json:"max_position_embeddings"`
+        NumAttentionHeads         uint32    `json:"num_attention_heads"`
+        NumHiddenLayers           uint32    `json:"num_hidden_layers"`
+        NumKeyValueHeads          uint32    `json:"num_key_value_heads"`
+        NumKVSharedLayers         uint32    `json:"num_kv_shared_layers"`
+        RMSNormEPS                float32   `json:"rms_norm_eps"`
+        RopeLocalBaseFreq         float32   `json:"rope_local_base_freq"`
+        RopeTheta                 float32   `json:"rope_theta"`
+        SlidingWindow             uint32    `json:"sliding_window"`
+        LayerTypes                []string  `json:"layer_types"`
+    } `json:"text_config"`
+    VisionModel struct{} `json:"vision_config"`
+}
+
+func (m *gemma3nModel) KV(t *Tokenizer) ggml.KV {
+    kv := m.ModelParameters.KV(t)
+    kv["general.architecture"] = "gemma3n"
+    kv["gemma3n.activation_sparsity_scale"] = slices.Collect(func(yield func(float32) bool) {
+        norm := distuv.Normal{Mu: 0, Sigma: 1}
+        for _, v := range m.TextModel.ActivationSparsityPattern {
+            if !yield(float32(norm.Quantile(float64(v)))) {
+                break
+            }
+        }
+    })
+    kv["gemma3n.altup.active_idx"] = m.TextModel.AltupActiveIdx
+    kv["gemma3n.altup.correct_scale"] = m.TextModel.AltupCorrectScale
+    kv["gemma3n.altup.lr_multiplier"] = m.TextModel.AltupLRMultiplier
+    kv["gemma3n.altup.num_inputs"] = m.TextModel.AltupNumInputs
+    kv["gemma3n.attention.head_count_kv"] = m.TextModel.NumKeyValueHeads
+    kv["gemma3n.attention.head_count"] = m.TextModel.NumAttentionHeads
+    kv["gemma3n.attention.layer_norm_rms_epsilon"] = m.TextModel.RMSNormEPS
+    kv["gemma3n.attention.sliding_window"] = m.TextModel.SlidingWindow
+    kv["gemma3n.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
+        for _, t := range m.TextModel.LayerTypes {
+            if !yield(t == "sliding_attention") {
+                break
+            }
+        }
+    })
+    kv["gemma3n.attention.shared_kv_layers"] = m.TextModel.NumKVSharedLayers
+    kv["gemma3n.block_count"] = m.TextModel.NumHiddenLayers
+    kv["gemma3n.context_length"] = m.TextModel.MaxPositionEmbeddings
+    kv["gemma3n.embedding_length_per_layer_input"] = m.TextModel.HiddenSizePerLayerInput
+    kv["gemma3n.embedding_length"] = m.TextModel.HiddenSize
+    kv["gemma3n.feed_forward_length"] = m.TextModel.IntermediateSize
+    kv["gemma3n.head_dim"] = m.TextModel.HeadDim
+    kv["gemma3n.rope.freq_base_local"] = m.TextModel.RopeLocalBaseFreq
+    kv["gemma3n.rope.freq_base"] = m.TextModel.RopeTheta
+    return kv
+}
+
+func (m *gemma3nModel) Tensors(ts []Tensor) []*ggml.Tensor {
+    out, ts := mergeTensors(ts,
+        merge{"altup_proj.*.weight", "altup_proj.weight"},
+        merge{"altup_unembd_proj.*.weight", "altup_unembd_proj.weight"},
+    )
+
+    for _, t := range ts {
+        switch {
+        case strings.Contains(t.Name(), "audio_tower"),
+            strings.Contains(t.Name(), "embed_audio"),
+            strings.Contains(t.Name(), "vision_tower"),
+            strings.Contains(t.Name(), "embed_vision"):
+            // TODO: handle audio and vision towers
+            continue
+        case strings.Contains(t.Name(), "altup_predict_coef"),
+            strings.Contains(t.Name(), "altup_correct_coef"):
+            if m.TextModel.AltupCoefClip > 0 {
+                t.SetRepacker(func(name string, data []float32, shape []uint64) (_ []float32, err error) {
+                    dims := make([]int, len(shape))
+                    for i := range shape {
+                        dims[i] = int(shape[i])
+                    }
+
+                    var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
+
+                    t, err = tensor.Clamp(t, -m.TextModel.AltupCoefClip, m.TextModel.AltupCoefClip)
+                    if err != nil {
+                        return nil, err
+                    }
+
+                    if err := t.Reshape(t.Shape().TotalSize()); err != nil {
+                        return nil, err
+                    }
+
+                    return native.VectorF32(t.(*tensor.Dense))
+                })
+            }
+        }
+
+        out = append(out, &ggml.Tensor{
+            Name:     t.Name(),
+            Kind:     t.Kind(),
+            Shape:    t.Shape(),
+            WriterTo: t,
+        })
+    }
+
+    return out
+}
+
+func (m *gemma3nModel) Replacements() []string {
+    return []string{
+        "model.language_model.embed_tokens_per_layer", "per_layer_token_embd",
+        "model.language_model.embed_tokens", "token_embd",
+        "model.language_model.per_layer_model_projection", "per_layer_model_proj",
+        "model.language_model.per_layer_projection_norm", "per_layer_proj_norm",
+        "model.language_model.altup_projections", "altup_proj",
+        "model.language_model.altup_unembed_projections", "altup_unembd_proj",
+        "model.language_model.norm", "output_norm",
+        "model.language_model.layers", "blk",
+
+        "input_layernorm", "attn_norm",
+        "self_attn.q_proj", "attn_q",
+        "self_attn.q_norm", "attn_q_norm",
+        "self_attn.k_proj", "attn_k",
+        "self_attn.k_norm", "attn_k_norm",
+        "self_attn.v_proj", "attn_v",
+        "self_attn.o_proj", "attn_output",
+        "post_attention_layernorm", "post_attention_norm",
+        "pre_feedforward_layernorm", "ffn_norm",
+        "mlp.gate_proj", "ffn_gate",
+        "mlp.up_proj", "ffn_up",
+        "mlp.down_proj", "ffn_down",
+        "post_feedforward_layernorm", "post_ffw_norm",
+        "per_layer_input_gate", "inp_gate",
+        "per_layer_projection", "proj",
+        "post_per_layer_input_norm", "post_norm",
+        "altup.", "altup_",
+        "modality_router", "router",
+        "prediction_coefs", "predict_coef",
+        "correction_coefs", "correct_coef",
+        "correct_output_scale", "correct_scale.weight",
+        "laurel.", "laurel_",
+        "linear_left", "l",
+        "linear_right", "r",
+        "post_laurel_norm", "post_norm",
+    }
+}
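The `activation_sparsity_scale` values above convert a target sparsity fraction into a standard-normal cutoff via the inverse CDF (`distuv.Normal.Quantile`): keeping only activations above the quantile of a unit Gaussian zeroes out roughly that fraction of them. A small check of the mapping (values rounded):

```go
package main

import (
	"fmt"

	"gonum.org/v1/gonum/stat/distuv"
)

func main() {
	norm := distuv.Normal{Mu: 0, Sigma: 1}

	// A 0.95 target sparsity maps to the z-score below which 95% of a
	// standard normal falls, roughly 1.645.
	for _, sparsity := range []float64{0.5, 0.95} {
		fmt.Printf("sparsity %.2f -> cutoff %.3f\n", sparsity, norm.Quantile(sparsity))
	}
	// sparsity 0.50 -> cutoff 0.000
	// sparsity 0.95 -> cutoff 1.645
}
```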
convert/convert_gptoss.go (new file, 223 lines)
@@ -0,0 +1,223 @@
+package convert
+
+import (
+    "bytes"
+    "cmp"
+    "encoding/binary"
+    "io"
+    "slices"
+    "strings"
+
+    "github.com/ollama/ollama/fs/ggml"
+    "github.com/pdevine/tensor"
+    "github.com/pdevine/tensor/native"
+)
+
+type gptossModel struct {
+    ModelParameters
+    HiddenLayers          uint32  `json:"num_hidden_layers"`
+    MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
+    HiddenSize            uint32  `json:"hidden_size"`
+    IntermediateSize      uint32  `json:"intermediate_size"`
+    AttentionHeads        uint32  `json:"num_attention_heads"`
+    KeyValueHeads         uint32  `json:"num_key_value_heads"`
+    HeadDim               uint32  `json:"head_dim"`
+    Experts               uint32  `json:"num_experts"`
+    LocalExperts          uint32  `json:"num_local_experts"`
+    ExpertsPerToken       uint32  `json:"experts_per_token"`
+    RMSNormEpsilon        float32 `json:"rms_norm_eps"`
+    InitialContextLength  uint32  `json:"initial_context_length"`
+    RopeTheta             float32 `json:"rope_theta"`
+    RopeScalingFactor     float32 `json:"rope_scaling_factor"`
+    RopeScaling           struct {
+        Factor float32 `json:"factor"`
+    } `json:"rope_scaling"`
+    SlidingWindow uint32 `json:"sliding_window"`
+}
+
+var _ ModelConverter = (*gptossModel)(nil)
+
+func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
+    kv := m.ModelParameters.KV(t)
+    kv["general.architecture"] = "gptoss"
+    kv["general.file_type"] = uint32(4)
+    kv["gptoss.context_length"] = cmp.Or(m.MaxPositionEmbeddings, uint32(m.RopeScalingFactor*float32(m.InitialContextLength)))
+    kv["gptoss.block_count"] = m.HiddenLayers
+    kv["gptoss.embedding_length"] = m.HiddenSize
+    kv["gptoss.feed_forward_length"] = m.IntermediateSize
+    kv["gptoss.expert_count"] = cmp.Or(m.Experts, m.LocalExperts)
+    kv["gptoss.expert_used_count"] = m.ExpertsPerToken
+    kv["gptoss.attention.head_count"] = m.AttentionHeads
+    kv["gptoss.attention.head_count_kv"] = m.KeyValueHeads
+    kv["gptoss.attention.key_length"] = m.HeadDim
+    kv["gptoss.attention.value_length"] = m.HeadDim
+    kv["gptoss.attention.layer_norm_rms_epsilon"] = cmp.Or(m.RMSNormEpsilon, 1e-5)
+    kv["gptoss.attention.sliding_window"] = m.SlidingWindow
+    kv["gptoss.rope.freq_base"] = m.RopeTheta
+    kv["gptoss.rope.scaling.factor"] = cmp.Or(m.RopeScalingFactor, m.RopeScaling.Factor)
+    kv["gptoss.rope.scaling.original_context_length"] = m.InitialContextLength
+    kv["tokenizer.ggml.bos_token_id"] = uint32(199998) // <|startoftext|>
+    kv["tokenizer.ggml.add_bos_token"] = false
+    kv["tokenizer.ggml.eos_token_id"] = uint32(199999) // <|endoftext|>
+    kv["tokenizer.ggml.eos_token_ids"] = []int32{
+        199999, /* <|endoftext|> */
+        200002, /* <|return|> */
+        200012, /* <|call|> */
+    }
+    kv["tokenizer.ggml.add_eos_token"] = false
+    return kv
+}
+
+func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
+    var out []*ggml.Tensor
+    mxfp4s := make(map[string]*mxfp4)
+    for _, t := range ts {
+        if strings.HasSuffix(t.Name(), ".blocks") || strings.HasSuffix(t.Name(), ".scales") {
+            dot := strings.LastIndex(t.Name(), ".")
+            name, suffix := t.Name()[:dot], t.Name()[dot+1:]
+            if _, ok := mxfp4s[name]; !ok {
+                mxfp4s[name] = &mxfp4{}
+            }
+
+            switch suffix {
+            case "blocks":
+                mxfp4s[name].blocks = t
+            case "scales":
+                mxfp4s[name].scales = t
+            }
+        } else {
+            out = append(out, &ggml.Tensor{
+                Name:     t.Name(),
+                Kind:     t.Kind(),
+                Shape:    t.Shape(),
+                WriterTo: t,
+            })
+        }
+    }
+
+    for name, mxfp4 := range mxfp4s {
+        dims := mxfp4.blocks.Shape()
+
+        if !strings.HasSuffix(name, ".weight") {
+            name += ".weight"
+        }
+
+        out = append(out, &ggml.Tensor{
+            Name:     name,
+            Kind:     uint32(ggml.TensorTypeMXFP4),
+            Shape:    []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
+            WriterTo: mxfp4,
+        })
+    }
+
+    return out
+}
+
+func (m *gptossModel) Replacements() []string {
+    var replacements []string
+    if m.MaxPositionEmbeddings > 0 {
+        // hf flavored model
+        replacements = []string{
+            "lm_head", "output",
+            "model.embed_tokens", "token_embd",
+            "model.layers", "blk",
+            "input_layernorm", "attn_norm",
+            "self_attn.q_proj", "attn_q",
+            "self_attn.k_proj", "attn_k",
+            "self_attn.v_proj", "attn_v",
+            "self_attn.o_proj", "attn_out",
+            "self_attn.sinks", "attn_sinks",
+            "post_attention_layernorm", "ffn_norm",
+            "mlp.router", "ffn_gate_inp",
+            "mlp.experts.gate_up_proj_", "ffn_gate_up_exps.",
+            "mlp.experts.down_proj_", "ffn_down_exps.",
+            "model.norm", "output_norm",
+        }
+    } else {
+        replacements = []string{
+            // noop replacements so other replacements will not be applied
+            ".blocks", ".blocks",
+            ".scales", ".scales",
+            // real replacements
+            "block", "blk",
+            "attn.norm", "attn_norm",
+            "attn.qkv", "attn_qkv",
+            "attn.sinks", "attn_sinks",
+            "attn.out", "attn_out",
+            "mlp.norm", "ffn_norm",
+            "mlp.gate", "ffn_gate_inp",
+            "mlp.mlp1_", "ffn_gate_up_exps.",
+            "mlp.mlp2_", "ffn_down_exps.",
+            "embedding", "token_embd",
+            "norm", "output_norm",
+            "unembedding", "output",
+            "scale", "weight",
+        }
+    }
+    return replacements
+}
+
+type mxfp4 struct {
+    blocks, scales Tensor
+}
+
+func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
+    var b bytes.Buffer
+    if _, err := m.blocks.WriteTo(&b); err != nil {
+        return 0, err
+    }
+
+    blocksDims := make([]int, len(m.blocks.Shape()))
+    for i, d := range m.blocks.Shape() {
+        blocksDims[i] = int(d)
+    }
+
+    bts := b.Bytes()
+    var tmp [16]byte
+    for i := 0; i < b.Len(); i += 16 {
+        for j := range 8 {
+            // transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
+            a, b := bts[i+j], bts[i+j+8]
+            tmp[2*j+0] = (a & 0x0F) | (b << 4)
+            tmp[2*j+1] = (a >> 4) | (b & 0xF0)
+        }
+
+        copy(bts[i:i+16], tmp[:])
+    }
+
+    var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(bts))
+
+    var s bytes.Buffer
+    if _, err := m.scales.WriteTo(&s); err != nil {
+        return 0, err
+    }
+
+    scalesDims := slices.Repeat([]int{1}, len(m.blocks.Shape()))
+    for i, d := range m.scales.Shape() {
+        scalesDims[i] = int(d)
+    }
+
+    var scales tensor.Tensor = tensor.New(tensor.WithShape(scalesDims...), tensor.WithBacking(s.Bytes()))
+
+    out, err := tensor.Concat(3, scales, blocks)
+    if err != nil {
+        return 0, err
+    }
+
+    out = tensor.Materialize(out)
+
+    if err := out.Reshape(out.Shape().TotalSize()); err != nil {
+        return 0, err
+    }
+
+    u8s, err := native.VectorU8(out.(*tensor.Dense))
+    if err != nil {
+        return 0, err
+    }
+
+    if err := binary.Write(w, binary.LittleEndian, u8s); err != nil {
+        return 0, err
+    }
+
+    return int64(len(u8s)), nil
+}
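The inner loop of `mxfp4.WriteTo` re-interleaves packed 4-bit codes: within each 16-byte block it pairs byte `j` with byte `j+8` and recombines their nibbles, matching the `a1b2c3 ... x7y8z9 -> 71xa82yb93zc` comment. A tiny standalone demonstration of that byte transform, with sample values invented for illustration:

```go
package main

import "fmt"

func main() {
	// One 16-byte MXFP4 block; each byte packs two 4-bit codes.
	bts := []byte{
		0xA1, 0xB2, 0xC3, 0xD4, 0xE5, 0xF6, 0x07, 0x18, // bytes j = 0..7
		0xA7, 0xB8, 0xC9, 0xDA, 0xEB, 0xFC, 0x0D, 0x1E, // bytes j+8
	}

	var tmp [16]byte
	for j := range 8 {
		a, b := bts[j], bts[j+8]
		tmp[2*j+0] = (a & 0x0F) | (b << 4) // low nibbles of a and b
		tmp[2*j+1] = (a >> 4) | (b & 0xF0) // high nibbles of a and b
	}

	fmt.Printf("% X\n", tmp[:])
	// Output: 71 AA 82 BB 93 CC A4 DD B5 EE C6 FF D7 00 E8 11
}
```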
@@ -28,12 +28,12 @@ type llamaModel struct {
     NumKeyValueHeads uint32  `json:"num_key_value_heads"`
     RopeTheta        float32 `json:"rope_theta"`
     RopeScaling      struct {
         Type                string  `json:"type"`
         RopeType            string  `json:"rope_type"`
         Factor              float32 `json:"factor"`
         LowFrequencyFactor  float32 `json:"low_freq_factor"`
         HighFrequencyFactor float32 `json:"high_freq_factor"`
-        OriginalMaxPositionalEmbeddings uint32 `json:"original_max_positional_embeddings"`
+        OriginalMaxPositionEmbeddings   uint32 `json:"original_max_position_embeddings"`

         factors ropeFactor
     } `json:"rope_scaling"`
@@ -42,6 +42,8 @@ type llamaModel struct {
     LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
     NormEpsilon      float32 `json:"norm_epsilon"`
     HeadDim          uint32  `json:"head_dim"`
+
+    skipRepack bool
 }

 var _ ModelConverter = (*llamaModel)(nil)
@@ -70,6 +72,10 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
         kv["llama.rope.dimension_count"] = p.HiddenSize / headCount
     }

+    if p.HeadDim > 0 {
+        kv["llama.attention.head_dim"] = p.HeadDim
+    }
+
     if p.RopeTheta > 0 {
         kv["llama.rope.freq_base"] = p.RopeTheta
     }
@@ -84,7 +90,7 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
         factorLow := cmp.Or(p.RopeScaling.LowFrequencyFactor, 1.0)
         factorHigh := cmp.Or(p.RopeScaling.HighFrequencyFactor, 4.0)

-        original := cmp.Or(p.RopeScaling.OriginalMaxPositionalEmbeddings, 8192)
+        original := cmp.Or(p.RopeScaling.OriginalMaxPositionEmbeddings, 8192)
         lambdaLow := float32(original) / factorLow
         lambdaHigh := float32(original) / factorHigh

@@ -120,11 +126,11 @@ func (p *llamaModel) KV(t *Tokenizer) ggml.KV {
     return kv
 }

-func (p *llamaModel) Tensors(ts []Tensor) []ggml.Tensor {
-    var out []ggml.Tensor
+func (p *llamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
+    var out []*ggml.Tensor

     if p.RopeScaling.factors != nil {
-        out = append(out, ggml.Tensor{
+        out = append(out, &ggml.Tensor{
             Name:  "rope_freqs.weight",
             Kind:  0,
             Shape: []uint64{uint64(len(p.RopeScaling.factors))},
@@ -133,12 +139,14 @@ func (p *llamaModel) Tensors(ts []Tensor) []ggml.Tensor {
     }

     for _, t := range ts {
-        if strings.HasSuffix(t.Name(), "attn_q.weight") ||
-            strings.HasSuffix(t.Name(), "attn_k.weight") {
-            t.SetRepacker(p.repack)
+        if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") ||
+            strings.HasSuffix(t.Name(), "attn_q_proj.weight") || strings.HasSuffix(t.Name(), "attn_k_proj.weight") {
+            if !p.skipRepack {
+                t.SetRepacker(p.repack)
+            }
         }

-        out = append(out, ggml.Tensor{
+        out = append(out, &ggml.Tensor{
             Name:  t.Name(),
             Kind:  t.Kind(),
             Shape: t.Shape(),
@@ -174,9 +182,9 @@ func (p *llamaModel) repack(name string, data []float32, shape []uint64) ([]floa
     }

     var heads uint32
-    if strings.HasSuffix(name, "attn_q.weight") {
+    if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(name, "attn_q_proj.weight") {
         heads = p.NumAttentionHeads
-    } else if strings.HasSuffix(name, "attn_k.weight") {
+    } else if strings.HasSuffix(name, "attn_k.weight") || strings.HasSuffix(name, "attn_k_proj.weight") {
         heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
     } else {
         return nil, fmt.Errorf("unknown tensor for repack: %s", name)
convert/convert_llama4.go (new file, 169 lines)
@@ -0,0 +1,169 @@
+package convert
+
+import (
+    "slices"
+    "strings"
+
+    "github.com/pdevine/tensor"
+    "github.com/pdevine/tensor/native"
+
+    "github.com/ollama/ollama/fs/ggml"
+)
+
+type llama4Model struct {
+    ModelParameters
+    TextModel struct {
+        llamaModel
+        NumExpertsPerToken     uint32 `json:"num_experts_per_tok"`
+        NumLocalExperts        uint32 `json:"num_local_experts"`
+        InterleaveMOELayerStep uint32 `json:"interleave_moe_layer_step"`
+        UseQKNorm              bool   `json:"use_qk_norm"`
+        IntermediateSizeMLP    uint32 `json:"intermediate_size_mlp"`
+        AttentionChunkSize     uint32 `json:"attention_chunk_size"`
+    } `json:"text_config"`
+    VisionModel struct {
+        NumHiddenLayers   uint32  `json:"num_hidden_layers"`
+        HiddenSize        uint32  `json:"hidden_size"`
+        IntermediateSize  uint32  `json:"intermediate_size"`
+        NumAttentionHeads uint32  `json:"num_attention_heads"`
+        ImageSize         uint32  `json:"image_size"`
+        PatchSize         uint32  `json:"patch_size"`
+        RopeTheta         float32 `json:"rope_theta"`
+        NormEpsilon       float32 `json:"norm_eps"`
+        PixelShuffleRatio float32 `json:"pixel_shuffle_ratio"`
+    } `json:"vision_config"`
+}
+
+// KV implements ModelConverter.
+func (p *llama4Model) KV(t *Tokenizer) ggml.KV {
+    kv := p.ModelParameters.KV(t)
+    kv["general.architecture"] = "llama4"
+
+    for k, v := range p.TextModel.KV(t) {
+        if strings.HasPrefix(k, "llama.") {
+            kv[strings.ReplaceAll(k, "llama.", "llama4.")] = v
+        }
+    }
+
+    kv["llama4.feed_forward_length"] = p.TextModel.IntermediateSizeMLP
+    kv["llama4.expert_feed_forward_length"] = p.TextModel.IntermediateSize
+
+    kv["llama4.expert_count"] = p.TextModel.NumLocalExperts
+    kv["llama4.expert_used_count"] = p.TextModel.NumExpertsPerToken
+    kv["llama4.interleave_moe_layer_step"] = p.TextModel.InterleaveMOELayerStep
+    kv["llama4.use_qk_norm"] = p.TextModel.UseQKNorm
+    kv["llama4.attention.chunk_size"] = p.TextModel.AttentionChunkSize
+
+    kv["llama4.vision.block_count"] = p.VisionModel.NumHiddenLayers
+    kv["llama4.vision.embedding_length"] = p.VisionModel.HiddenSize
+    kv["llama4.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
+    kv["llama4.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
+    kv["llama4.vision.image_size"] = p.VisionModel.ImageSize
+    kv["llama4.vision.patch_size"] = p.VisionModel.PatchSize
+    kv["llama4.vision.rope.freq_base"] = p.VisionModel.RopeTheta
+    kv["llama4.vision.layer_norm_epsilon"] = p.VisionModel.NormEpsilon
+    kv["llama4.vision.pixel_shuffle_ratio"] = p.VisionModel.PixelShuffleRatio
+    return kv
+}
+
+// Replacements implements ModelConverter.
+func (p *llama4Model) Replacements() []string {
+    return append(
+        p.TextModel.Replacements(),
+        "language_model.", "",
+        "vision_model", "v",
+        "multi_modal_projector", "mm",
+        "feed_forward.down_proj", "ffn_down",
+        "feed_forward.up_proj", "ffn_up",
+        "feed_forward.gate_proj", "ffn_gate",
+        "feed_forward.", "ffn_",
+        "shared_expert.down_proj", "down_shexp",
+        "shared_expert.gate_proj", "gate_shexp",
+        "shared_expert.up_proj", "up_shexp",
+        "experts.down_proj", "down_exps.weight",
+        "experts.gate_up_proj", "gate_up_exps.weight",
+        "router", "gate_inp",
+        "patch_embedding.linear", "patch_embedding",
+    )
+}
+
+// Tensors implements ModelConverter.
+func (p *llama4Model) Tensors(ts []Tensor) []*ggml.Tensor {
+    var out []*ggml.Tensor
+
+    var textTensors []Tensor
+    for _, t := range ts {
+        if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") {
+            out = append(out, &ggml.Tensor{
+                Name:     t.Name(),
+                Kind:     t.Kind(),
+                Shape:    t.Shape(),
+                WriterTo: t,
+            })
+        } else if strings.Contains(t.Name(), "ffn_gate_up_exps") {
+            // gate and up projectors are fused
+            // dims[1], dims[2] must be swapped
+            // [experts, hidden_size, intermediate_size * 2] --> [experts, intermediate_size, hidden_size]
+            halfDim := int(t.Shape()[2]) / 2
+
+            newShape := slices.Clone(t.Shape())
+            newShape[1], newShape[2] = newShape[2]/2, newShape[1]
+            for i, name := range []string{"ffn_gate_exps", "ffn_up_exps"} {
+                // clone tensor since we need separate repackers
+                tt := t.Clone()
+                tt.SetRepacker(p.repack(nil, nil, tensor.S(i*halfDim, (i+1)*halfDim)))
+                out = append(out, &ggml.Tensor{
+                    Name:     strings.ReplaceAll(tt.Name(), "ffn_gate_up_exps", name),
+                    Kind:     tt.Kind(),
+                    Shape:    newShape,
+                    WriterTo: tt,
+                })
+            }
+        } else if strings.Contains(t.Name(), "ffn_down_exps") {
+            // dims[1], dims[2] must be swapped
+            // [experts, intermediate_size, hidden_size] --> [experts, hidden_size, intermediate_size]
+            t.SetRepacker(p.repack())
+            newShape := slices.Clone(t.Shape())
+            newShape[1], newShape[2] = newShape[2], newShape[1]
+            out = append(out, &ggml.Tensor{
+                Name:     t.Name(),
+                Kind:     t.Kind(),
+                Shape:    newShape,
+                WriterTo: t,
+            })
+        } else {
+            textTensors = append(textTensors, t)
+        }
+    }
+
+    p.TextModel.skipRepack = true
+    out = append(out, p.TextModel.Tensors(textTensors)...)
+    return out
+}
+
+func (p *llama4Model) repack(slice ...tensor.Slice) Repacker {
+    return func(name string, data []float32, shape []uint64) ([]float32, error) {
+        dims := make([]int, len(shape))
+        for i, dim := range shape {
+            dims[i] = int(dim)
+        }
+
+        var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
+        t, err := t.Slice(slice...)
+        if err != nil {
+            return nil, err
+        }
+
+        if err := t.T(0, 2, 1); err != nil {
+            return nil, err
+        }
+
+        t = tensor.Materialize(t)
+        // flatten tensor so it can be returned as a vector
+        if err := t.Reshape(t.Shape().TotalSize()); err != nil {
+            return nil, err
+        }
+
+        return native.VectorF32(t.(*tensor.Dense))
+    }
+}
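Conceptually, splitting the fused gate/up projector is just slicing the fused axis in half and emitting each half under its own tensor name; the repacker then swaps the two remaining axes. A toy sketch of the split step on a flat slice (sizes invented; the real code works through the tensor package):

```go
package main

import "fmt"

func main() {
	// Toy fused row of length intermediate_size*2: the first half is the
	// gate projector, the second half the up projector, mirroring how
	// llama4 fuses gate_proj and up_proj along one axis.
	fused := []float32{1, 2, 3, 4, 5, 6} // intermediate_size = 3
	half := len(fused) / 2

	fmt.Println("ffn_gate_exps:", fused[:half]) // [1 2 3]
	fmt.Println("ffn_up_exps:  ", fused[half:]) // [4 5 6]
}
```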
@@ -29,8 +29,8 @@ func (p *llamaAdapter) KV(baseKV ggml.KV) ggml.KV {
     return kv
 }

-func (p *llamaAdapter) Tensors(ts []Tensor) []ggml.Tensor {
-    var out []ggml.Tensor
+func (p *llamaAdapter) Tensors(ts []Tensor) []*ggml.Tensor {
+    var out []*ggml.Tensor
     for _, t := range ts {
         shape := t.Shape()
         if (strings.HasSuffix(t.Name(), "weight.lora_a") && shape[0] > shape[1]) ||
@@ -41,7 +41,7 @@ func (p *llamaAdapter) Tensors(ts []Tensor) []ggml.Tensor {
             t.SetRepacker(p.repack)
         }

-        out = append(out, ggml.Tensor{
+        out = append(out, &ggml.Tensor{
             Name:  t.Name(),
             Kind:  t.Kind(),
             Shape: shape,
convert/convert_mistral.go (new file, 190 lines)
@@ -0,0 +1,190 @@
+package convert
+
+import (
+    "cmp"
+    "fmt"
+    "strings"
+
+    "github.com/pdevine/tensor"
+    "github.com/pdevine/tensor/native"
+
+    "github.com/ollama/ollama/fs/ggml"
+)
+
+type mistral3Model struct {
+    ModelParameters
+    ImageTokenIndex    uint32 `json:"image_token_index"`
+    SpatialMergeSize   uint32 `json:"spatial_merge_size"`
+    VisionFeatureLayer int32  `json:"vision_feature_layer"`
+    TextModel          struct {
+        NumHiddenLayers       uint32  `json:"num_hidden_layers"`
+        MaxPositionEmbeddings uint32  `json:"max_position_embeddings"`
+        HiddenSize            uint32  `json:"hidden_size"`
+        IntermediateSize      uint32  `json:"intermediate_size"`
+        NumAttentionHeads     uint32  `json:"num_attention_heads"`
+        NumKeyValueHeads      uint32  `json:"num_key_value_heads"`
+        RopeTheta             float32 `json:"rope_theta"`
+        RMSNormEPS            float32 `json:"rms_norm_eps"`
+        HeadDim               uint32  `json:"head_dim"`
+        SlidingWindow         *uint32 `json:"sliding_window"`
+        HiddenAct             string  `json:"hidden_act"`
+        VocabSize             uint32  `json:"vocab_size"`
+    } `json:"text_config"`
+    VisionModel struct {
+        NumAttentionHeads uint32  `json:"num_attention_heads"`
+        NumHiddenLayers   uint32  `json:"num_hidden_layers"`
+        HiddenSize        uint32  `json:"hidden_size"`
+        IntermediateSize  uint32  `json:"intermediate_size"`
+        ImageSize         uint32  `json:"image_size"`
+        NumChannels       uint32  `json:"num_channels"`
+        PatchSize         uint32  `json:"patch_size"`
+        HeadDim           uint32  `json:"head_dim"`
+        HiddenAct         string  `json:"hidden_act"`
+        RopeTheta         float32 `json:"rope_theta"`
+    } `json:"vision_config"`
+    MultiModalProjectorBias bool   `json:"multimodal_projector_bias"`
+    ProjectorHiddenAct      string `json:"projector_hidden_act"`
+}
+
+func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
+    kv := p.ModelParameters.KV(t)
+    kv["general.architecture"] = "mistral3"
+    kv["mistral3.vocab_size"] = p.TextModel.VocabSize
+
+    // Text configuration
+    kv["mistral3.block_count"] = p.TextModel.NumHiddenLayers
+    kv["mistral3.context_length"] = p.TextModel.MaxPositionEmbeddings
+    kv["mistral3.embedding_length"] = p.TextModel.HiddenSize
+    kv["mistral3.feed_forward_length"] = p.TextModel.IntermediateSize
+    kv["mistral3.attention.head_count"] = p.TextModel.NumAttentionHeads
+    kv["mistral3.attention.head_count_kv"] = p.TextModel.NumKeyValueHeads
+    kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
+    kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
+    kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
+    kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
+    kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
+
+    // Vision configuration
+    kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
+    kv["mistral3.vision.embedding_length"] = p.VisionModel.HiddenSize
+    kv["mistral3.vision.feed_forward_length"] = p.VisionModel.IntermediateSize
+    kv["mistral3.vision.attention.head_count"] = p.VisionModel.NumAttentionHeads
+    kv["mistral3.vision.attention.key_length"] = p.VisionModel.HeadDim
+    kv["mistral3.vision.image_size"] = p.VisionModel.ImageSize
+    kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
+    kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
+    // kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
+    kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
+
+    // Multimodal configuration
+    kv["mistral3.image_token_index"] = p.ImageTokenIndex
+    kv["mistral3.spatial_merge_size"] = p.SpatialMergeSize
+
+    kv["mistral3.mm.projector_bias"] = p.MultiModalProjectorBias
+
+    if p.ProjectorHiddenAct != "" {
+        kv["mistral3.mm.projector_hidden_act"] = p.ProjectorHiddenAct
+    }
+
+    return kv
+}
+
+func (p *mistral3Model) Tensors(ts []Tensor) []*ggml.Tensor {
+    var out []*ggml.Tensor
+
+    for _, t := range ts {
+        if !strings.HasPrefix(t.Name(), "v.") {
+            if strings.HasSuffix(t.Name(), ".attn_q.weight") ||
+                strings.HasSuffix(t.Name(), ".attn_k.weight") {
+                t.SetRepacker(p.repack)
+            }
+        }
+
+        out = append(out, &ggml.Tensor{
+            Name:     t.Name(),
+            Kind:     t.Kind(),
+            Shape:    t.Shape(),
+            WriterTo: t,
+        })
+    }
+
+    return out
+}
+
+func (p *mistral3Model) Replacements() []string {
+    return []string{
+        "language_model.model.norm", "output_norm",
+        "language_model.model.", "",
+        "language_model.", "",
+        "layers", "blk",
+        "transformer.layers", "blk",
+        "vision_tower", "v",
+        "ln_pre", "encoder_norm",
+        "input_layernorm", "attn_norm",
+        "post_attention_layernorm", "ffn_norm",
+        "embed_tokens", "token_embd",
+        "self_attn.q_proj", "attn_q",
+        "self_attn.k_proj", "attn_k",
+        "self_attn.v_proj", "attn_v",
+        "self_attn.o_proj", "attn_output",
+        "mlp.down_proj", "ffn_down",
+        "mlp.gate_proj", "ffn_gate",
+        "mlp.up_proj", "ffn_up",
+        "attention.q_proj", "attn_q",
+        "attention.k_proj", "attn_k",
+        "attention.v_proj", "attn_v",
+        "attention.o_proj", "attn_output",
+        "attention_norm", "attn_norm",
+        "feed_forward.gate_proj", "ffn_gate",
+        "feed_forward.down_proj", "ffn_down",
+        "feed_forward.up_proj", "ffn_up",
+        "multi_modal_projector", "mm",
+        "ffn_norm", "ffn_norm",
+        "lm_head", "output",
+    }
+}
+
+func (p *mistral3Model) repack(name string, data []float32, shape []uint64) ([]float32, error) {
+    var dims []int
+    for _, dim := range shape {
+        dims = append(dims, int(dim))
+    }
+
+    var heads uint32
+    if strings.HasSuffix(name, ".attn_q.weight") {
+        heads = p.TextModel.NumAttentionHeads
+    } else if strings.HasSuffix(name, ".attn_k.weight") {
+        heads = cmp.Or(p.TextModel.NumKeyValueHeads, p.TextModel.NumAttentionHeads)
+    } else {
+        return nil, fmt.Errorf("unknown tensor for repack: %s", name)
+    }
+
+    n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
+    if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
+        return nil, err
+    }
+
+    if err := n.T(0, 2, 1, 3); err != nil {
+        return nil, err
+    }
+
+    if err := n.Reshape(dims...); err != nil {
+        return nil, err
+    }
+
+    if err := n.Transpose(); err != nil {
+        return nil, err
+    }
+
+    ts, err := native.SelectF32(n, 1)
+    if err != nil {
+        return nil, err
+    }
+
+    var f32s []float32
+    for _, t := range ts {
+        f32s = append(f32s, t...)
+    }
+
+    return f32s, nil
+}
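The `repack` above views each attention head's rows as `[2, head_dim/2]`, swaps those two axes, and flattens back, converting between the interleaved and half-split rotary (RoPE) weight layouts. A toy version of the per-head row permutation in pure Go, with no tensor library and invented sizes:

```go
package main

import "fmt"

func main() {
	const headDim = 4
	rows := []string{"r0", "r1", "r2", "r3"} // rows of one attention head

	// View as [2, headDim/2] and swap the axes: row g*(headDim/2)+k
	// moves to position k*2+g, interleaving the two halves.
	permuted := make([]string, headDim)
	for g := 0; g < 2; g++ {
		for k := 0; k < headDim/2; k++ {
			permuted[k*2+g] = rows[g*(headDim/2)+k]
		}
	}

	fmt.Println(permuted) // [r0 r2 r1 r3]
}
```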
@@ -2,9 +2,6 @@ package convert

 import (
     "fmt"
-    "io"
-    "slices"
-    "strings"

     "github.com/ollama/ollama/fs/ggml"
 )
@@ -29,66 +26,39 @@ func (p *mixtralModel) KV(t *Tokenizer) ggml.KV {
     return kv
 }

-func (p *mixtralModel) Tensors(ts []Tensor) []ggml.Tensor {
-    oldnew := []string{
-        "model.layers", "blk",
-        "w1", "ffn_gate_exps",
-        "w2", "ffn_down_exps",
-        "w3", "ffn_up_exps",
-    }
-
-    for i := range p.NumLocalExperts {
-        oldnew = append(oldnew, fmt.Sprintf(".block_sparse_moe.experts.%d.", i), ".")
-    }
-
-    // group experts of the same layer (model.layers.%d) and type (w[123]) into a single tensor
-    namer := strings.NewReplacer(oldnew...)
-    experts := make(map[string]experts)
-
-    // merge experts into a single tensor while removing them from ts
-    ts = slices.DeleteFunc(ts, func(t Tensor) bool {
-        if !strings.Contains(t.Name(), ".block_sparse_moe.experts.") {
-            return false
-        }
-
-        name := namer.Replace(t.Name())
-        experts[name] = append(experts[name], t)
-        return true
-    })
-
-    var out []ggml.Tensor
-    for n, e := range experts {
-        // TODO(mxyng): sanity check experts
-        out = append(out, ggml.Tensor{
-            Name:     n,
-            Kind:     e[0].Kind(),
-            Shape:    append([]uint64{uint64(len(e))}, e[0].Shape()...),
-            WriterTo: e,
-        })
-    }
-
+func (p *mixtralModel) Tensors(ts []Tensor) []*ggml.Tensor {
+    merges := make([]merge, 0, p.NumHiddenLayers*6)
+    for i := range p.NumHiddenLayers {
+        merges = append(merges, merge{
+            fmt.Sprintf("blk.%d.*.w1.weight", i),
+            fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
+        }, merge{
+            fmt.Sprintf("blk.%d.*.w1.bias", i),
+            fmt.Sprintf("blk.%d.ffn_gate_exps.bias", i),
+        }, merge{
+            fmt.Sprintf("blk.%d.*.w2.weight", i),
+            fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
+        }, merge{
+            fmt.Sprintf("blk.%d.*.w2.bias", i),
+            fmt.Sprintf("blk.%d.ffn_up_exps.bias", i),
+        }, merge{
+            fmt.Sprintf("blk.%d.*.w3.weight", i),
+            fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
+        }, merge{
+            fmt.Sprintf("blk.%d.*.w3.bias", i),
+            fmt.Sprintf("blk.%d.ffn_down_exps.bias", i),
+        })
+    }
+
+    out, ts := mergeTensors(ts, merges...)
     return append(out, p.llamaModel.Tensors(ts)...)
 }

 func (p *mixtralModel) Replacements() []string {
     return append(
         p.llamaModel.Replacements(),
+        "model.layers", "blk",
         "block_sparse_moe.gate", "ffn_gate_inp",
+        "block_sparse_moe.experts.", ".",
     )
 }
-
-type experts []Tensor
-
-func (e experts) WriteTo(w io.Writer) (int64, error) {
-    // TODO(mxyng): experts _should_ be numerically sorted by expert but this should check
-    for _, t := range e {
-        // the canonical merged experts tensor stacks all experts along a new, 0 axis,
-        // e.g. `tensor.Stack(0, e[0], e[1:]...)`, which requires allocating temporary buffers
-        // this accomplishes the same thing by writing each expert tensor in sequence
-        if _, err := t.WriteTo(w); err != nil {
-            return 0, err
-        }
-    }
-
-    return 0, nil
-}
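Each `merge` rule above pairs a glob-style pattern with a destination name; the shared `mergeTensors` helper (introduced elsewhere in this change set, not shown in this hunk) presumably collects every tensor matching the pattern and stacks them into a single tensor under the new name. A standalone sketch of just the matching step using `path.Match`; the `merge` field names here are assumptions based on usage:

```go
package main

import (
	"fmt"
	"path"
)

// Assumed shape of the merge rule, inferred from its usage in the diff.
type merge struct{ pattern, name string }

func main() {
	rule := merge{"blk.0.*.w1.weight", "blk.0.ffn_gate_exps.weight"}

	names := []string{
		"blk.0.0.w1.weight", // expert 0: matches
		"blk.0.1.w1.weight", // expert 1: matches
		"blk.0.0.w2.weight", // different projector: no match
	}

	for _, n := range names {
		ok, _ := path.Match(rule.pattern, n)
		fmt.Printf("%s -> match=%v (merged into %s)\n", n, ok, rule.name)
	}
}
```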
179
convert/convert_mllama.go
Normal file
179
convert/convert_mllama.go
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
|
"github.com/pdevine/tensor"
|
||||||
|
"github.com/pdevine/tensor/native"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mllamaModel struct {
|
||||||
|
ModelParameters
|
||||||
|
TextModel struct {
|
||||||
|
llamaModel
|
||||||
|
|
||||||
|
CrossAttentionLayers []int32 `json:"cross_attention_layers"`
|
||||||
|
} `json:"text_config"`
|
||||||
|
	VisionModel struct {
		NumHiddenLayers           uint32  `json:"num_hidden_layers"`
		NumGlobalLayers           uint32  `json:"num_global_layers"`
		IntermediateLayersIndices []int32 `json:"intermediate_layers_indices"`

		HiddenSize       uint32 `json:"hidden_size"`
		IntermediateSize uint32 `json:"intermediate_size"`

		AttentionHeads uint32 `json:"attention_heads"`

		ImageSize   uint32  `json:"image_size"`
		PatchSize   uint32  `json:"patch_size"`
		NumChannels uint32  `json:"num_channels"`
		MaxNumTiles uint32  `json:"max_num_tiles"`
		NormEpsilon float32 `json:"norm_eps"`
		RopeTheta   float32 `json:"rope.freq_base"`
	} `json:"vision_config"`
}

func (m *mllamaModel) KV(t *Tokenizer) ggml.KV {
	kv := m.ModelParameters.KV(t)
	kv["general.architecture"] = "mllama"

	for k, v := range m.TextModel.KV(t) {
		if strings.HasPrefix(k, "llama.") {
			kv[strings.ReplaceAll(k, "llama.", "mllama.")] = v
		}
	}

	kv["mllama.attention.cross_attention_layers"] = m.TextModel.CrossAttentionLayers

	kv["mllama.vision.block_count"] = m.VisionModel.NumHiddenLayers
	kv["mllama.vision.global.block_count"] = m.VisionModel.NumGlobalLayers
	kv["mllama.vision.intermediate_layers_indices"] = m.VisionModel.IntermediateLayersIndices

	kv["mllama.vision.embedding_length"] = m.VisionModel.HiddenSize
	kv["mllama.vision.feed_forward_length"] = m.VisionModel.IntermediateSize

	kv["mllama.vision.attention.head_count"] = m.VisionModel.AttentionHeads
	kv["mllama.vision.attention.layer_norm_epsilon"] = m.VisionModel.NormEpsilon

	kv["mllama.vision.image_size"] = m.VisionModel.ImageSize
	kv["mllama.vision.patch_size"] = m.VisionModel.PatchSize
	kv["mllama.vision.max_num_tiles"] = m.VisionModel.MaxNumTiles
	kv["mllama.vision.num_channels"] = m.VisionModel.NumChannels

	return kv
}

func (m *mllamaModel) Replacements() []string {
	return append(
		m.TextModel.Replacements(),
		"language_model.", "",
		"gate_attn", "attn_gate",
		"gate_ffn", "ffn_gate",
		"cross_attn.", "cross_attn_",
		"vision_model", "v",
		"class_embedding", "class_embd",
		"patch_embedding", "patch_embd",
		"gated_positional_embedding.tile_embedding", "tile_position_embd",
		"gated_positional_embedding.embedding", "position_embd.weight",
		"gated_positional_embedding", "position_embd",
		"embedding.weight", "weight",
		"pre_tile_positional_embedding", "pre_tile_position_embd",
		"post_tile_positional_embedding", "post_tile_position_embd",
		"layernorm_pre", "pre_ln",
		"layernorm_post", "post_ln",
		"global_transformer.layers", "global.blk",
		"transformer.layers", "blk",
		"mlp.fc1", "ffn_up",
		"mlp.fc2", "ffn_down",
		"multi_modal_projector", "mm.0",
	)
}

func (m *mllamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
	var out []*ggml.Tensor
	var text []Tensor
	for _, t := range ts {
		if !strings.HasPrefix(t.Name(), "v.") && !strings.HasPrefix(t.Name(), "mm.") {
			text = append(text, t)
		} else if t.Name() == "v.position_embd.gate" {
			for _, name := range []string{"v.position_embd.gate", "v.tile_position_embd.gate"} {
				tt := t.Clone()
				tt.SetRepacker(m.repack(name))
				out = append(out, &ggml.Tensor{
					Name:     name,
					Kind:     t.Kind(),
					Shape:    t.Shape(),
					WriterTo: tt,
				})
			}
		} else {
			if t.Name() == "v.pre_tile_position_embd.gate" || t.Name() == "v.post_tile_position_embd.gate" {
				t.SetRepacker(m.repack(t.Name()))
			} else if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
				t.SetRepacker(m.repack(t.Name()))
			} else if strings.HasSuffix(t.Name(), "attn_gate") || strings.HasSuffix(t.Name(), "ffn_gate") {
				t.SetRepacker(m.repack(t.Name()))
			}

			out = append(out, &ggml.Tensor{
				Name:     t.Name(),
				Kind:     t.Kind(),
				Shape:    t.Shape(),
				WriterTo: t,
			})
		}
	}

	return append(out, m.TextModel.Tensors(text)...)
}

func (m *mllamaModel) repack(name string) Repacker {
	return func(_ string, data []float32, shape []uint64) (_ []float32, err error) {
		dims := make([]int, len(shape))
		for i, dim := range shape {
			dims[i] = int(dim)
		}

		var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))

		if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(name, "attn_k.weight") {
			heads := m.VisionModel.AttentionHeads
			if err := t.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
				return nil, err
			}

			if err := t.T(0, 2, 1, 3); err != nil {
				return nil, err
			}

			if err := t.Reshape(dims...); err != nil {
				return nil, err
			}

			if err := t.Transpose(); err != nil {
				return nil, err
			}
		} else {
			t, err = tensor.Tanh(t)
			if err != nil {
				return nil, err
			}

			if name == "v.position_embd.gate" {
				t, err = tensor.Sub(float32(1), t)
				if err != nil {
					return nil, err
				}
			}
		}

		t = tensor.Materialize(t)
		// flatten tensor so it can be returned as a vector
		if err := t.Reshape(t.Shape().TotalSize()); err != nil {
			return nil, err
		}

		return native.VectorF32(t.(*tensor.Dense))
	}
}
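For reference, the non-attention branch of repack above squashes each gate value with tanh and, for v.position_embd.gate only, stores 1 - tanh(g). A minimal standalone sketch of that arithmetic, using only the standard library and hypothetical gate values:

package main

import (
	"fmt"
	"math"
)

// gate mirrors the gate branch of repack: every element is squashed
// with tanh, and the position embedding gate is additionally inverted.
func gate(name string, data []float32) []float32 {
	out := make([]float32, len(data))
	for i, g := range data {
		t := float32(math.Tanh(float64(g)))
		if name == "v.position_embd.gate" {
			t = 1 - t
		}
		out[i] = t
	}
	return out
}

func main() {
	fmt.Println(gate("v.position_embd.gate", []float32{0, 1, -1}))      // [1 0.238... 1.761...]
	fmt.Println(gate("v.tile_position_embd.gate", []float32{0, 1, -1})) // [0 0.761... -0.761...]
}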
@@ -68,19 +68,19 @@ func (p *phi3Model) KV(t *Tokenizer) ggml.KV {
 	return kv
 }
 
-func (p *phi3Model) Tensors(ts []Tensor) []ggml.Tensor {
+func (p *phi3Model) Tensors(ts []Tensor) []*ggml.Tensor {
 	var addRopeFactors sync.Once
 
-	out := make([]ggml.Tensor, 0, len(ts)+2)
+	out := make([]*ggml.Tensor, 0, len(ts)+2)
 	for _, t := range ts {
 		if strings.HasPrefix(t.Name(), "blk.0.") {
 			addRopeFactors.Do(func() {
-				out = append(out, ggml.Tensor{
+				out = append(out, &ggml.Tensor{
 					Name:     "rope_factors_long.weight",
 					Kind:     0,
 					Shape:    []uint64{uint64(len(p.RopeScaling.LongFactor))},
 					WriterTo: p.RopeScaling.LongFactor,
-				}, ggml.Tensor{
+				}, &ggml.Tensor{
 					Name:     "rope_factors_short.weight",
 					Kind:     0,
 					Shape:    []uint64{uint64(len(p.RopeScaling.ShortFactor))},
@@ -89,7 +89,7 @@ func (p *phi3Model) Tensors(ts []Tensor) []ggml.Tensor {
 			})
 		}
 
-		out = append(out, ggml.Tensor{
+		out = append(out, &ggml.Tensor{
 			Name:     t.Name(),
 			Kind:     t.Kind(),
 			Shape:    t.Shape(),
@@ -118,6 +118,5 @@ func (p *phi3Model) Replacements() []string {
 type ropeFactor []float32
 
 func (r ropeFactor) WriteTo(w io.Writer) (int64, error) {
-	err := binary.Write(w, binary.LittleEndian, r)
-	return 0, err
+	return 0, binary.Write(w, binary.LittleEndian, r)
 }
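The ropeFactor.WriteTo above keeps returning 0 for the byte count, which is tolerated here because its caller ignores the count. For context, io.WriterTo implementations conventionally report the bytes actually written; a sketch of that variant (not the converter's code):

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

type ropeFactor []float32

// WriteTo reports the number of bytes written: one 4-byte
// little-endian float32 per element.
func (r ropeFactor) WriteTo(w io.Writer) (int64, error) {
	if err := binary.Write(w, binary.LittleEndian, r); err != nil {
		return 0, err
	}
	return int64(len(r) * 4), nil
}

func main() {
	var b bytes.Buffer
	n, err := ropeFactor{1, 2, 3}.WriteTo(&b)
	fmt.Println(n, err, b.Len()) // 12 <nil> 12
}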
@@ -15,6 +15,7 @@ type qwen2Model struct {
 		Type                          string     `json:"type"`
 		Factor                        ropeFactor `json:"factor"`
 		OriginalMaxPositionEmbeddings uint32     `json:"original_max_position_embeddings"`
+		MropeSection                  []int32    `json:"mrope_section"`
 	} `json:"rope_scaling"`
 	RMSNormEPS float32 `json:"rms_norm_eps"`
 }
@@ -39,16 +40,18 @@ func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
 	case "yarn":
 		kv["qwen2.rope.scaling.type"] = q.RopeScaling.Type
 		kv["qwen2.rope.scaling.factor"] = q.RopeScaling.Factor
+	case "mrope", "default":
+		kv["qwen2.rope.mrope_section"] = q.RopeScaling.MropeSection
 	default:
 		panic("unknown rope scaling type")
 	}
 	return kv
 }
 
-func (q *qwen2Model) Tensors(ts []Tensor) []ggml.Tensor {
-	var out []ggml.Tensor
+func (q *qwen2Model) Tensors(ts []Tensor) []*ggml.Tensor {
+	var out []*ggml.Tensor
 	for _, t := range ts {
-		out = append(out, ggml.Tensor{
+		out = append(out, &ggml.Tensor{
 			Name:  t.Name(),
 			Kind:  t.Kind(),
 			Shape: t.Shape(),
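The new MropeSection field is decoded straight from the checkpoint's rope_scaling block in config.json. A minimal sketch of how such a block unmarshals, with a hypothetical input fragment:

package main

import (
	"encoding/json"
	"fmt"
)

type ropeScaling struct {
	Type         string  `json:"type"`
	MropeSection []int32 `json:"mrope_section"`
}

func main() {
	// hypothetical config.json fragment for an mrope model
	raw := `{"type": "mrope", "mrope_section": [16, 24, 24]}`

	var rs ropeScaling
	if err := json.Unmarshal([]byte(raw), &rs); err != nil {
		panic(err)
	}
	fmt.Printf("%s %v\n", rs.Type, rs.MropeSection) // mrope [16 24 24]
}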
102  convert/convert_qwen25vl.go  Normal file
@@ -0,0 +1,102 @@
package convert

import (
	"cmp"
	"slices"
	"strings"

	"github.com/ollama/ollama/fs/ggml"
)

type qwen25VLModel struct {
	qwen2Model

	VisionModel struct {
		Depth               uint32  `json:"depth"`
		HiddenSize          uint32  `json:"hidden_size"`
		NumHeads            uint32  `json:"num_heads"`
		InChannels          uint32  `json:"in_chans"`
		PatchSize           uint32  `json:"patch_size"`
		SpatialMergeSize    uint32  `json:"spatial_merge_size"`
		SpatialPatchSize    uint32  `json:"spatial_patch_size"`
		WindowSize          uint32  `json:"window_size"`
		RMSNormEps          float32 `json:"layer_norm_epsilon"`
		RopeTheta           float32 `json:"rope_theta"`
		FullAttentionBlocks []int32 `json:"fullatt_block_indexes"`
		TemporalPatchSize   uint32  `json:"temporal_patch_size"`
	} `json:"vision_config"`
}

var _ ModelConverter = (*qwen25VLModel)(nil)

func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
	kv := q.ModelParameters.KV(t)
	kv["general.architecture"] = "qwen25vl"

	for k, v := range q.qwen2Model.KV(t) {
		if strings.HasPrefix(k, "qwen2.") {
			kv[strings.Replace(k, "qwen2.", "qwen25vl.", 1)] = v
		}
	}

	if q.VisionModel.FullAttentionBlocks == nil {
		kv["qwen25vl.vision.fullatt_block_indexes"] = []int32{7, 15, 23, 31}
	}

	kv["qwen25vl.vision.block_count"] = cmp.Or(q.VisionModel.Depth, 32)
	kv["qwen25vl.vision.embedding_length"] = q.VisionModel.HiddenSize
	kv["qwen25vl.vision.attention.head_count"] = cmp.Or(q.VisionModel.NumHeads, 16)
	kv["qwen25vl.vision.num_channels"] = q.VisionModel.InChannels
	kv["qwen25vl.vision.patch_size"] = cmp.Or(q.VisionModel.PatchSize, 14)
	kv["qwen25vl.vision.spatial_merge_size"] = cmp.Or(q.VisionModel.SpatialMergeSize, 2)
	kv["qwen25vl.vision.spatial_patch_size"] = q.VisionModel.SpatialPatchSize
	kv["qwen25vl.vision.window_size"] = cmp.Or(q.VisionModel.WindowSize, 112)
	kv["qwen25vl.vision.attention.layer_norm_epsilon"] = cmp.Or(q.VisionModel.RMSNormEps, 1e-6)
	kv["qwen25vl.vision.rope.freq_base"] = cmp.Or(q.VisionModel.RopeTheta, 1e4)
	kv["qwen25vl.vision.fullatt_block_indexes"] = q.VisionModel.FullAttentionBlocks
	kv["qwen25vl.vision.temporal_patch_size"] = cmp.Or(q.VisionModel.TemporalPatchSize, 2)

	return kv
}

func (q *qwen25VLModel) Tensors(ts []Tensor) []*ggml.Tensor {
	var out []*ggml.Tensor

	for _, t := range ts {
		if strings.Contains(t.Name(), "patch_embed.proj") {
			for t := range splitDim(t, 2,
				split{Replacer: strings.NewReplacer("patch_embed.proj", "patch_embd_0")},
				split{Replacer: strings.NewReplacer("patch_embed.proj", "patch_embd_1")},
			) {
				t.Shape = slices.DeleteFunc(t.Shape, func(i uint64) bool { return i == 1 })
				out = append(out, t)
			}
		} else if strings.Contains(t.Name(), "attn.qkv") {
			out = append(out, slices.Collect(splitDim(t, 0,
				split{Replacer: strings.NewReplacer("attn.qkv", "attn_q")},
				split{Replacer: strings.NewReplacer("attn.qkv", "attn_k")},
				split{Replacer: strings.NewReplacer("attn.qkv", "attn_v")},
			))...)
		} else {
			out = append(out, &ggml.Tensor{
				Name:     t.Name(),
				Kind:     t.Kind(),
				Shape:    t.Shape(),
				WriterTo: t,
			})
		}
	}

	return out
}

func (p *qwen25VLModel) Replacements() []string {
	return append(
		p.qwen2Model.Replacements(),
		"visual", "v",
		"blocks", "blk",
		"attn.proj", "attn_out",
		"norm1", "ln1",
		"norm2", "ln2",
	)
}
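The KV method above leans on cmp.Or (Go 1.22+) to fall back to a default whenever the config omits a field, since a missing JSON number decodes to the zero value. In isolation:

package main

import (
	"cmp"
	"fmt"
)

func main() {
	var patchSize uint32 // zero: the field was absent from vision_config
	fmt.Println(cmp.Or(patchSize, 14)) // 14

	patchSize = 16 // explicitly configured
	fmt.Println(cmp.Or(patchSize, 14)) // 16
}

cmp.Or returns its first non-zero argument, so an explicit config value always wins over the default.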
@@ -11,15 +11,13 @@ import (
 	"io"
 	"io/fs"
 	"log/slog"
-	"math"
+	"maps"
 	"os"
 	"path/filepath"
 	"slices"
 	"strings"
 	"testing"
 
-	"golang.org/x/exp/maps"
-
 	"github.com/ollama/ollama/fs/ggml"
 )
@@ -48,7 +46,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
 	}
 	t.Cleanup(func() { r.Close() })
 
-	m, _, err := ggml.Decode(r, math.MaxInt)
+	m, err := ggml.Decode(r, -1)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -131,15 +129,14 @@ func TestConvertModel(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
+	defer expectFile.Close()
 
 	var expect map[string]string
 	if err := json.NewDecoder(expectFile).Decode(&expect); err != nil {
 		t.Fatal(err)
 	}
 
-	keys := maps.Keys(expect)
-	slices.Sort(keys)
-	for _, k := range keys {
+	for _, k := range slices.Sorted(maps.Keys(expect)) {
 		if v, ok := actual[k]; !ok {
 			t.Errorf("missing %s", k)
 		} else if v != expect[k] {
@@ -332,7 +329,7 @@ func TestConvertAdapter(t *testing.T) {
 	}
 	defer r.Close()
 
-	m, _, err := ggml.Decode(r, math.MaxInt)
+	m, err := ggml.Decode(r, -1)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -343,9 +340,7 @@ func TestConvertAdapter(t *testing.T) {
 
 	actual := generateResultsJSON(t, r, m.KV(), m.Tensors())
 
-	keys := maps.Keys(c.Expected)
-	slices.Sort(keys)
-	for _, k := range keys {
+	for _, k := range slices.Sorted(maps.Keys(c.Expected)) {
 		if v, ok := actual[k]; !ok {
 			t.Errorf("missing %s", k)
 		} else if v != c.Expected[k] {
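The tests now iterate with slices.Sorted(maps.Keys(...)), the stdlib replacement for the old golang.org/x/exp/maps plus slices.Sort pattern: in Go 1.23 maps.Keys returns an iterator, and slices.Sorted collects and sorts it in one step. A compact demonstration:

package main

import (
	"fmt"
	"maps"
	"slices"
)

func main() {
	expect := map[string]string{"b": "2", "a": "1", "c": "3"}

	// maps.Keys yields keys in arbitrary order; slices.Sorted collects
	// the iterator into a sorted slice for deterministic iteration.
	for _, k := range slices.Sorted(maps.Keys(expect)) {
		fmt.Println(k, expect[k]) // a 1, b 2, c 3
	}
}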
@@ -1,58 +0,0 @@
package convert

import (
	"archive/zip"
	"errors"
	"io"
	"io/fs"
	"os"
	"path/filepath"
)

type ZipReader struct {
	r *zip.Reader
	p string

	// limit is the maximum size of a file that can be read directly
	// from the zip archive. Files larger than this size will be extracted
	limit int64
}

func NewZipReader(r *zip.Reader, p string, limit int64) fs.FS {
	return &ZipReader{r, p, limit}
}

func (z *ZipReader) Open(name string) (fs.File, error) {
	r, err := z.r.Open(name)
	if err != nil {
		return nil, err
	}
	defer r.Close()

	if fi, err := r.Stat(); err != nil {
		return nil, err
	} else if fi.Size() < z.limit {
		return r, nil
	}

	if !filepath.IsLocal(name) {
		return nil, zip.ErrInsecurePath
	}

	n := filepath.Join(z.p, name)
	if _, err := os.Stat(n); errors.Is(err, os.ErrNotExist) {
		w, err := os.Create(n)
		if err != nil {
			return nil, err
		}
		defer w.Close()

		if _, err := io.Copy(w, r); err != nil {
			return nil, err
		}
	} else if err != nil {
		return nil, err
	}

	return os.Open(n)
}
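The removed ZipReader guarded extraction with filepath.IsLocal, which rejects absolute paths and ".." traversal before a zip entry name is joined onto a destination directory. The check in isolation, with illustrative entry names:

package main

import (
	"fmt"
	"path/filepath"
)

func main() {
	for _, name := range []string{
		"model.safetensors",      // plain local name: ok
		"weights/part-00001.bin", // nested local path: ok
		"../../etc/passwd",       // traversal outside the root: rejected
		"/etc/passwd",            // absolute path: rejected
	} {
		fmt.Println(name, filepath.IsLocal(name))
	}
}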
@@ -11,14 +11,15 @@ type Tensor interface {
 	Name() string
 	Shape() []uint64
 	Kind() uint32
-	SetRepacker(repacker)
+	SetRepacker(Repacker)
 	WriteTo(io.Writer) (int64, error)
+	Clone() Tensor
 }
 
 type tensorBase struct {
 	name  string
 	shape []uint64
-	repacker
+	repacker Repacker
 }
 
 func (t tensorBase) Name() string {
@@ -30,42 +31,46 @@ func (t tensorBase) Shape() []uint64 {
 }
 
 const (
-	tensorKindF32 uint32 = iota
-	tensorKindF16
+	tensorKindFP32 uint32 = iota
+	tensorKindFP16
+	tensorKindBF16  = 30
+	tensorKindMXFP4 = 39
 )
 
 func (t tensorBase) Kind() uint32 {
 	if strings.HasSuffix(t.name, ".ffn_gate_inp.weight") ||
-		t.name == "token_types.weight" {
+		strings.HasSuffix(t.name, ".bias") ||
+		t.name == "token_types.weight" ||
+		t.name == "v.positional_embedding_vlm" ||
+		t.name == "v.tile_position_embd.weight" ||
+		t.name == "v.pre_tile_position_embd.weight" ||
+		t.name == "v.post_tile_position_embd.weight" {
 		// these tensors are always F32
-		return 0
+		return tensorKindFP32
 	}
 
 	switch len(t.shape) {
 	case 0:
 		panic("invalid tensor shape")
 	case 1:
-		return tensorKindF32
+		return tensorKindFP32
 	default:
-		return tensorKindF16
+		return tensorKindFP16
 	}
}
 
-func (t *tensorBase) SetRepacker(fn repacker) {
+func (t *tensorBase) SetRepacker(fn Repacker) {
 	t.repacker = fn
 }
 
-type repacker func(string, []float32, []uint64) ([]float32, error)
+type Repacker func(string, []float32, []uint64) ([]float32, error)
 
 func parseTensors(fsys fs.FS, replacer *strings.Replacer) ([]Tensor, error) {
 	patterns := []struct {
 		Pattern string
 		Func    func(fs.FS, *strings.Replacer, ...string) ([]Tensor, error)
 	}{
-		{"model-*-of-*.safetensors", parseSafetensors},
-		{"model.safetensors", parseSafetensors},
-		{"adapters.safetensors", parseSafetensors},
-		{"adapter_model.safetensors", parseSafetensors},
+		{"*.safetensors", parseSafetensors},
 		{"pytorch_model-*-of-*.bin", parseTorch},
 		{"pytorch_model.bin", parseTorch},
 		{"consolidated.*.pth", parseTorch},
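Exporting Repacker lets converters in other files register per-tensor transforms. A sketch of a trivial repacker matching the exported signature above, redeclared locally so the snippet is self-contained (the negation itself is just an illustrative transform):

package main

import "fmt"

type Repacker func(name string, data []float32, shape []uint64) ([]float32, error)

func main() {
	// a repacker receives the tensor name, its flat data, and its shape,
	// and returns the transformed flat data
	var negate Repacker = func(name string, data []float32, shape []uint64) ([]float32, error) {
		out := make([]float32, len(data))
		for i, v := range data {
			out[i] = -v
		}
		return out, nil
	}

	out, err := negate("blk.0.attn_q.weight", []float32{1, -2, 3}, []uint64{3})
	fmt.Println(out, err) // [-1 2 -3] <nil>
}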
@@ -1,6 +1,7 @@
 package convert
 
 import (
+	"bufio"
 	"bytes"
 	"encoding/binary"
 	"encoding/json"
@@ -8,12 +9,12 @@ import (
 	"fmt"
 	"io"
 	"io/fs"
+	"maps"
 	"slices"
 	"strings"
 
 	"github.com/d4l3k/go-bfloat16"
 	"github.com/x448/float16"
-	"golang.org/x/exp/maps"
 )
 
 type safetensorMetadata struct {
@@ -46,8 +47,7 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T
 		return nil, err
 	}
 
-	keys := maps.Keys(headers)
-	slices.Sort(keys)
+	keys := slices.Sorted(maps.Keys(headers))
 
 	names := make(map[string]struct{}, len(keys))
 
@@ -94,6 +94,30 @@ type safetensor struct {
 	*tensorBase
 }
 
+func (st safetensor) Kind() uint32 {
+	kind := st.tensorBase.Kind()
+	if st.dtype == "BF16" && kind != tensorKindFP32 {
+		kind = tensorKindBF16
+	}
+
+	return kind
+}
+
+func (st safetensor) Clone() Tensor {
+	return &safetensor{
+		fs:     st.fs,
+		path:   st.path,
+		dtype:  st.dtype,
+		offset: st.offset,
+		size:   st.size,
+		tensorBase: &tensorBase{
+			name:     st.name,
+			repacker: st.repacker,
+			shape:    slices.Clone(st.shape),
+		},
+	}
+}
+
 func (st safetensor) WriteTo(w io.Writer) (int64, error) {
 	f, err := st.fs.Open(st.path)
 	if err != nil {
@@ -101,26 +125,41 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
 	}
 	defer f.Close()
 
-	if seeker, ok := f.(io.Seeker); ok {
-		if _, err := seeker.Seek(st.offset, io.SeekStart); err != nil {
-			return 0, err
-		}
-	} else {
-		if _, err := io.CopyN(io.Discard, f, st.offset); err != nil {
-			return 0, err
+	r, err := func() (io.Reader, error) {
+		if readerAt, ok := f.(io.ReaderAt); ok {
+			return io.NewSectionReader(readerAt, st.offset, st.size), nil
+		} else if seeker, ok := f.(io.Seeker); ok {
+			_, err := seeker.Seek(st.offset, io.SeekStart)
+			return f, err
+		} else {
+			_, err := io.CopyN(io.Discard, f, st.offset)
+			return f, err
 		}
+	}()
+	if err != nil {
+		return 0, err
+	}
+
+	br := bufio.NewReaderSize(r, min(32<<10, int(st.size)))
+	// special case when input and output are same type and the
+	// tensor doesn't need repacking
+	if (st.repacker == nil) &&
+		((st.dtype == "F32" && st.Kind() == tensorKindFP32) ||
+			(st.dtype == "F16" && st.Kind() == tensorKindFP16) ||
+			(st.dtype == "U8")) {
+		return io.CopyN(w, br, st.size)
 	}
 
 	var f32s []float32
 	switch st.dtype {
 	case "F32":
 		f32s = make([]float32, st.size/4)
-		if err = binary.Read(f, binary.LittleEndian, f32s); err != nil {
+		if err = binary.Read(br, binary.LittleEndian, f32s); err != nil {
 			return 0, err
 		}
 	case "F16":
 		u16s := make([]uint16, st.size/2)
-		if err = binary.Read(f, binary.LittleEndian, u16s); err != nil {
+		if err = binary.Read(br, binary.LittleEndian, u16s); err != nil {
 			return 0, err
 		}
 
@@ -131,7 +170,7 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
 
 	case "BF16":
 		u8s := make([]uint8, st.size)
-		if err = binary.Read(f, binary.LittleEndian, u8s); err != nil {
+		if err = binary.Read(br, binary.LittleEndian, u8s); err != nil {
 			return 0, err
 		}
 
@@ -148,15 +187,18 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
 	}
 
 	switch st.Kind() {
-	case tensorKindF32:
-		return 0, binary.Write(w, binary.LittleEndian, f32s)
-	case tensorKindF16:
+	case tensorKindFP32:
+		return int64(len(f32s) * 4), binary.Write(w, binary.LittleEndian, f32s)
+	case tensorKindFP16:
 		f16s := make([]uint16, len(f32s))
 		for i := range f32s {
 			f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
 		}
 
-		return 0, binary.Write(w, binary.LittleEndian, f16s)
+		return int64(len(f16s) * 2), binary.Write(w, binary.LittleEndian, f16s)
+	case tensorKindBF16:
+		u8s := bfloat16.EncodeFloat32(f32s)
+		return int64(len(u8s)), binary.Write(w, binary.LittleEndian, u8s)
 	default:
 		return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
 	}
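The rewritten WriteTo prefers io.ReaderAt so each tensor can be read as an independent window of the file without seeking a shared handle. The mechanics in isolation, with a string standing in for a safetensors file:

package main

import (
	"fmt"
	"io"
	"strings"
)

func main() {
	// stands in for a safetensors file: a header followed by tensor bytes
	f := strings.NewReader("HEADERtensor-bytes")

	// an io.SectionReader exposes just [offset, offset+size) of the file,
	// like the st.offset/st.size window in WriteTo above
	sec := io.NewSectionReader(f, 6, 12)

	b, err := io.ReadAll(sec)
	fmt.Printf("%q %v\n", b, err) // "tensor-bytes" <nil>
}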
232  convert/reader_test.go  Normal file
@@ -0,0 +1,232 @@
package convert

import (
	"bytes"
	"encoding/binary"
	"os"
	"path/filepath"
	"testing"

	"github.com/d4l3k/go-bfloat16"
	"github.com/google/go-cmp/cmp"
	"github.com/x448/float16"
)

func TestSafetensors(t *testing.T) {
	t.Parallel()

	root, err := os.OpenRoot(t.TempDir())
	if err != nil {
		t.Fatal(err)
	}
	defer root.Close()

	cases := []struct {
		name,
		dtype string
		offset,
		size  int64
		shape []uint64
		setup func(*testing.T, *os.File)
		want  []byte
	}{
		{
			name:  "fp32-fp32",
			dtype: "F32",
			size:  32 * 4, // 32 floats, each 4 bytes
			shape: []uint64{32},
			setup: func(t *testing.T, f *os.File) {
				f32s := make([]float32, 32)
				for i := range f32s {
					f32s[i] = float32(i)
				}

				if err := binary.Write(f, binary.LittleEndian, f32s); err != nil {
					t.Fatal(err)
				}
			},
			want: []byte{
				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40,
				0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40,
				0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41,
				0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41,
				0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41,
				0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41,
				0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41,
				0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41,
			},
		},
		{
			name:  "fp32-fp16",
			dtype: "F32",
			size:  32 * 4, // 32 floats, each 4 bytes
			shape: []uint64{16, 2},
			setup: func(t *testing.T, f *os.File) {
				f32s := make([]float32, 32)
				for i := range f32s {
					f32s[i] = float32(i)
				}

				if err := binary.Write(f, binary.LittleEndian, f32s); err != nil {
					t.Fatal(err)
				}
			},
			want: []byte{
				0x00, 0x00, 0x00, 0x3c, 0x00, 0x40, 0x00, 0x42, 0x00, 0x44, 0x00, 0x45, 0x00, 0x46, 0x00, 0x47,
				0x00, 0x48, 0x80, 0x48, 0x00, 0x49, 0x80, 0x49, 0x00, 0x4a, 0x80, 0x4a, 0x00, 0x4b, 0x80, 0x4b,
				0x00, 0x4c, 0x40, 0x4c, 0x80, 0x4c, 0xc0, 0x4c, 0x00, 0x4d, 0x40, 0x4d, 0x80, 0x4d, 0xc0, 0x4d,
				0x00, 0x4e, 0x40, 0x4e, 0x80, 0x4e, 0xc0, 0x4e, 0x00, 0x4f, 0x40, 0x4f, 0x80, 0x4f, 0xc0, 0x4f,
			},
		},
		{
			name:  "fp16-fp16",
			dtype: "F16",
			size:  32 * 2, // 32 floats, each 2 bytes
			shape: []uint64{16, 2},
			setup: func(t *testing.T, f *os.File) {
				u16s := make([]uint16, 32)
				for i := range u16s {
					u16s[i] = float16.Fromfloat32(float32(i)).Bits()
				}

				if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
					t.Fatal(err)
				}
			},
			want: []byte{
				0x00, 0x00, 0x00, 0x3c, 0x00, 0x40, 0x00, 0x42, 0x00, 0x44, 0x00, 0x45, 0x00, 0x46, 0x00, 0x47,
				0x00, 0x48, 0x80, 0x48, 0x00, 0x49, 0x80, 0x49, 0x00, 0x4a, 0x80, 0x4a, 0x00, 0x4b, 0x80, 0x4b,
				0x00, 0x4c, 0x40, 0x4c, 0x80, 0x4c, 0xc0, 0x4c, 0x00, 0x4d, 0x40, 0x4d, 0x80, 0x4d, 0xc0, 0x4d,
				0x00, 0x4e, 0x40, 0x4e, 0x80, 0x4e, 0xc0, 0x4e, 0x00, 0x4f, 0x40, 0x4f, 0x80, 0x4f, 0xc0, 0x4f,
			},
		},
		{
			name:  "fp16-fp32",
			dtype: "F16",
			size:  32 * 2, // 32 floats, each 2 bytes
			shape: []uint64{32},
			setup: func(t *testing.T, f *os.File) {
				u16s := make([]uint16, 32)
				for i := range u16s {
					u16s[i] = float16.Fromfloat32(float32(i)).Bits()
				}

				if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
					t.Fatal(err)
				}
			},
			want: []byte{
				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40,
				0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40,
				0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41,
				0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41,
				0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41,
				0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41,
				0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41,
				0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41,
			},
		},
		{
			name:  "bf16-bf16",
			dtype: "BF16",
			size:  32 * 2, // 32 brain floats, each 2 bytes
			shape: []uint64{16, 2},
			setup: func(t *testing.T, f *os.File) {
				f32s := make([]float32, 32)
				for i := range f32s {
					f32s[i] = float32(i)
				}

				if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil {
					t.Fatal(err)
				}
			},
			want: []byte{
				0x00, 0x00, 0x80, 0x3f, 0x00, 0x40, 0x40, 0x40, 0x80, 0x40, 0xa0, 0x40, 0xc0, 0x40, 0xe0, 0x40,
				0x00, 0x41, 0x10, 0x41, 0x20, 0x41, 0x30, 0x41, 0x40, 0x41, 0x50, 0x41, 0x60, 0x41, 0x70, 0x41,
				0x80, 0x41, 0x88, 0x41, 0x90, 0x41, 0x98, 0x41, 0xa0, 0x41, 0xa8, 0x41, 0xb0, 0x41, 0xb8, 0x41,
				0xc0, 0x41, 0xc8, 0x41, 0xd0, 0x41, 0xd8, 0x41, 0xe0, 0x41, 0xe8, 0x41, 0xf0, 0x41, 0xf8, 0x41,
			},
		},
		{
			name:  "bf16-fp32",
			dtype: "BF16",
			size:  32 * 2, // 32 brain floats, each 2 bytes
			shape: []uint64{32},
			setup: func(t *testing.T, f *os.File) {
				f32s := make([]float32, 32)
				for i := range f32s {
					f32s[i] = float32(i)
				}

				if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil {
					t.Fatal(err)
				}
			},
			want: []byte{
				0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40,
				0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40,
				0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41,
				0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41,
				0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41,
				0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41,
				0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41,
				0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41,
			},
		},
		{
			name:  "u8-u8",
			dtype: "U8",
			size:  32, // 32 unsigned bytes, each 1 byte
			shape: []uint64{32},
			setup: func(t *testing.T, f *os.File) {
				u8s := make([]uint8, 32)
				for i := range u8s {
					u8s[i] = uint8(i)
				}

				if err := binary.Write(f, binary.LittleEndian, u8s); err != nil {
					t.Fatal(err)
				}
			},
			want: []byte{
				0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
				0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
			},
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			path := filepath.Base(t.Name())
			st := safetensor{
				fs:     root.FS(),
				path:   path,
				dtype:  tt.dtype,
				offset: tt.offset,
				size:   tt.size,
				tensorBase: &tensorBase{
					name:  tt.name,
					shape: tt.shape,
				},
			}

			f, err := root.Create(path)
			if err != nil {
				t.Fatal(err)
			}
			defer f.Close()

			tt.setup(t, f)

			var b bytes.Buffer
			if _, err := st.WriteTo(&b); err != nil {
				t.Fatal(err)
			}

			if diff := cmp.Diff(tt.want, b.Bytes()); diff != "" {
				t.Errorf("safetensor.WriteTo() mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
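The expected byte tables in these cases are just the little-endian encodings of the converted values; for instance, the fp16 pattern for 1.0 is 0x3c00, which shows up as 0x00, 0x3c in the tables. A quick check using the same float16 package:

package main

import (
	"fmt"

	"github.com/x448/float16"
)

func main() {
	for _, v := range []float32{0, 1, 2, 3} {
		bits := float16.Fromfloat32(v).Bits()
		// little-endian byte order, as in the tests' want tables
		fmt.Printf("%g -> 0x%02x, 0x%02x\n", v, byte(bits), byte(bits>>8))
	}
}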
@@ -43,6 +43,17 @@ type torch struct {
 	*tensorBase
 }
 
+func (t torch) Clone() Tensor {
+	return torch{
+		storage: t.storage,
+		tensorBase: &tensorBase{
+			name:     t.name,
+			shape:    t.shape,
+			repacker: t.repacker,
+		},
+	}
+}
+
 func (pt torch) WriteTo(w io.Writer) (int64, error) {
 	return 0, nil
 }
@@ -1360,7 +1360,7 @@ func file_sentencepiece_model_proto_rawDescGZIP() []byte {
 
 var file_sentencepiece_model_proto_enumTypes = make([]protoimpl.EnumInfo, 2)
 var file_sentencepiece_model_proto_msgTypes = make([]protoimpl.MessageInfo, 6)
-var file_sentencepiece_model_proto_goTypes = []interface{}{
+var file_sentencepiece_model_proto_goTypes = []any{
 	(TrainerSpec_ModelType)(0),         // 0: sentencepiece.TrainerSpec.ModelType
 	(ModelProto_SentencePiece_Type)(0), // 1: sentencepiece.ModelProto.SentencePiece.Type
 	(*TrainerSpec)(nil),                // 2: sentencepiece.TrainerSpec
@@ -1392,7 +1392,7 @@ func file_sentencepiece_model_proto_init() {
 		return
 	}
 	if !protoimpl.UnsafeEnabled {
-		file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
+		file_sentencepiece_model_proto_msgTypes[0].Exporter = func(v any, i int) any {
 			switch v := v.(*TrainerSpec); i {
 			case 0:
 				return &v.state
@@ -1406,7 +1406,7 @@ func file_sentencepiece_model_proto_init() {
 				return nil
 			}
 		}
-		file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
+		file_sentencepiece_model_proto_msgTypes[1].Exporter = func(v any, i int) any {
 			switch v := v.(*NormalizerSpec); i {
 			case 0:
 				return &v.state
@@ -1420,7 +1420,7 @@ func file_sentencepiece_model_proto_init() {
 				return nil
 			}
 		}
-		file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
+		file_sentencepiece_model_proto_msgTypes[2].Exporter = func(v any, i int) any {
 			switch v := v.(*SelfTestData); i {
 			case 0:
 				return &v.state
@@ -1434,7 +1434,7 @@ func file_sentencepiece_model_proto_init() {
 				return nil
 			}
 		}
-		file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
+		file_sentencepiece_model_proto_msgTypes[3].Exporter = func(v any, i int) any {
 			switch v := v.(*ModelProto); i {
 			case 0:
 				return &v.state
@@ -1448,7 +1448,7 @@ func file_sentencepiece_model_proto_init() {
 				return nil
 			}
 		}
-		file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
+		file_sentencepiece_model_proto_msgTypes[4].Exporter = func(v any, i int) any {
 			switch v := v.(*SelfTestData_Sample); i {
 			case 0:
 				return &v.state
@@ -1460,7 +1460,7 @@ func file_sentencepiece_model_proto_init() {
 				return nil
 			}
 		}
-		file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
+		file_sentencepiece_model_proto_msgTypes[5].Exporter = func(v any, i int) any {
 			switch v := v.(*ModelProto_SentencePiece); i {
 			case 0:
 				return &v.state
129  convert/tensor.go  Normal file
@@ -0,0 +1,129 @@
package convert

import (
	"cmp"
	"io"
	"iter"
	"path"
	"slices"
	"strings"

	"github.com/pdevine/tensor"
	"github.com/pdevine/tensor/native"

	"github.com/ollama/ollama/fs/ggml"
)

type split struct {
	*strings.Replacer
	dim int

	// fn is an optional function to apply to the tensor after slicing
	fn func(tensor.Tensor) (tensor.Tensor, error)
}

// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
// is split evenly based on the number of replacers provided unless a specific count is given.
func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
	return func(yield func(*ggml.Tensor) bool) {
		var offset int
		for _, split := range splits {
			t := t.Clone()
			shape := slices.Clone(t.Shape())
			shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits)))

			slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
			slice[dim] = tensor.S(offset, offset+int(shape[dim]))
			offset += int(shape[dim])

			t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
				dims := make([]int, len(shape))
				for i := range shape {
					dims[i] = int(shape[i])
				}

				var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
				tt, err := tt.Slice(slice...)
				if err != nil {
					return nil, err
				}

				tt = tensor.Materialize(tt)

				if split.fn != nil {
					tt, err = split.fn(tt)
					if err != nil {
						return nil, err
					}
				}

				// flatten tensor so it can be written as a vector
				if err := tt.Reshape(tt.Shape().TotalSize()); err != nil {
					return nil, err
				}

				return native.VectorF32(tt.(*tensor.Dense))
			})

			if !yield(&ggml.Tensor{
				Name:     split.Replace(t.Name()),
				Kind:     t.Kind(),
				Shape:    shape,
				WriterTo: t,
			}) {
				break
			}
		}
	}
}

type merge struct {
	pattern, name string
}

// mergeTensors merges tensors that match a given pattern into a single tensor.
func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []Tensor) {
	var matched []Tensor
	for i := range merges {
		matched, unmatched = slicesSplitFunc(unmatched, func(t Tensor) bool {
			matched, _ := path.Match(merges[i].pattern, t.Name())
			return matched
		})

		if len(matched) > 0 {
			out = append(out, &ggml.Tensor{
				Name:     merges[i].name,
				Kind:     matched[0].Kind(),
				Shape:    append([]uint64{uint64(len(matched))}, matched[0].Shape()...),
				WriterTo: mergeGroup(matched),
			})
		}
	}

	return out, unmatched
}

// slicesSplitFunc splits a slice into two slices based on a predicate function.
func slicesSplitFunc[S ~[]E, E comparable](s S, fn func(e E) bool) (matched, unmatched S) {
	for _, e := range s {
		if fn(e) {
			matched = append(matched, e)
		} else {
			unmatched = append(unmatched, e)
		}
	}

	return matched, unmatched
}

type mergeGroup []Tensor

func (g mergeGroup) WriteTo(w io.Writer) (int64, error) {
	for _, t := range g {
		if _, err := t.WriteTo(w); err != nil {
			return 0, err
		}
	}

	return 0, nil
}
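splitDim returns an iter.Seq rather than a slice, so callers can range over the pieces directly (as convert_qwen25vl.go does) or collect them with slices.Collect. The pattern in miniature, with a hypothetical halves generator standing in for splitDim:

package main

import (
	"fmt"
	"iter"
	"slices"
)

// halves yields the two halves of s one at a time, the same way
// splitDim yields one *ggml.Tensor per requested split.
func halves(s []int) iter.Seq[[]int] {
	return func(yield func([]int) bool) {
		mid := len(s) / 2
		if !yield(s[:mid]) {
			return // consumer stopped early
		}
		yield(s[mid:])
	}
}

func main() {
	for h := range halves([]int{0, 1, 2, 3}) { // range directly...
		fmt.Println(h)
	}
	fmt.Println(slices.Collect(halves([]int{0, 1, 2, 3}))) // ...or collect
}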
953
convert/tensor_test.go
Normal file
953
convert/tensor_test.go
Normal file
@@ -0,0 +1,953 @@
|
|||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"iter"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
|
"github.com/pdevine/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeTensor struct {
|
||||||
|
name string
|
||||||
|
shape []uint64
|
||||||
|
data []float32
|
||||||
|
|
||||||
|
repacker Repacker
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f fakeTensor) Name() string {
|
||||||
|
return f.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f fakeTensor) Shape() []uint64 {
|
||||||
|
return f.shape
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f fakeTensor) Kind() uint32 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeTensor) SetRepacker(fn Repacker) {
|
||||||
|
f.repacker = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f fakeTensor) Clone() Tensor {
|
||||||
|
return &fakeTensor{
|
||||||
|
name: f.name,
|
||||||
|
shape: slices.Clone(f.shape),
|
||||||
|
data: slices.Clone(f.data),
|
||||||
|
repacker: f.repacker,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f fakeTensor) WriteTo(w io.Writer) (n int64, err error) {
|
||||||
|
data := f.data
|
||||||
|
if f.repacker != nil {
|
||||||
|
data, err = f.repacker(f.name, data, f.shape)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := binary.Write(w, binary.LittleEndian, data); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return int64(len(data) * 4), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mul(shape []uint64) int {
|
||||||
|
n := 1
|
||||||
|
for _, dim := range shape {
|
||||||
|
n *= int(dim)
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitDim(t *testing.T) {
|
||||||
|
t.Run("2d", func(t *testing.T) {
|
||||||
|
r := fakeTensor{
|
||||||
|
name: "a.b",
|
||||||
|
shape: []uint64{3, 4},
|
||||||
|
data: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("no split", func(t *testing.T) {
|
||||||
|
for tt := range splitDim(&r, 0, split{Replacer: strings.NewReplacer("a", "x")}) {
|
||||||
|
if tt.Name != "x.b" {
|
||||||
|
t.Fatalf("expected name 'x', got '%s'", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{3, 4}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("even split", func(t *testing.T) {
|
||||||
|
next, stop := iter.Pull(splitDim(&r, 1,
|
||||||
|
split{Replacer: strings.NewReplacer("a", "x")},
|
||||||
|
split{Replacer: strings.NewReplacer("b", "y")},
|
||||||
|
))
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "x.b" {
|
||||||
|
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{0, 1, 4, 5, 8, 9}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "a.y" {
|
||||||
|
t.Fatal("expected name 'a.y', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{2, 3, 6, 7, 10, 11}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("uneven split", func(t *testing.T) {
|
||||||
|
next, stop := iter.Pull(splitDim(&r, 0,
|
||||||
|
split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
|
||||||
|
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
|
||||||
|
))
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "x.b" {
|
||||||
|
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{2, 4}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "a.y" {
|
||||||
|
t.Fatal("expected name 'a.y', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{1, 4}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{8, 9, 10, 11}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("three way split", func(t *testing.T) {
|
||||||
|
next, stop := iter.Pull(splitDim(&r, 0,
|
||||||
|
split{Replacer: strings.NewReplacer("a", "x"), dim: 1},
|
||||||
|
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
|
||||||
|
split{Replacer: strings.NewReplacer("b", "z"), dim: 1},
|
||||||
|
))
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "x.b" {
|
||||||
|
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{1, 4}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "a.y" {
|
||||||
|
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{1, 4}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{4, 5, 6, 7}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "a.z" {
|
||||||
|
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{1, 4}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{8, 9, 10, 11}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("uneven three way split", func(t *testing.T) {
|
||||||
|
next, stop := iter.Pull(splitDim(&r, 1,
|
||||||
|
split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
|
||||||
|
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
|
||||||
|
split{Replacer: strings.NewReplacer("b", "z"), dim: 1},
|
||||||
|
))
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "x.b" {
|
||||||
|
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{0, 1, 4, 5, 8, 9}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "a.y" {
|
||||||
|
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{3, 1}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{2, 6, 10}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "a.z" {
|
||||||
|
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{3, 1}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{3, 7, 11}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("split with transpose", func(t *testing.T) {
|
||||||
|
next, stop := iter.Pull(splitDim(&r, 1,
|
||||||
|
split{Replacer: strings.NewReplacer("a", "x")},
|
||||||
|
split{Replacer: strings.NewReplacer("b", "y"), fn: func(tt tensor.Tensor) (tensor.Tensor, error) {
|
||||||
|
return tensor.Transpose(tt, 1, 0)
|
||||||
|
}},
|
||||||
|
))
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "x.b" {
|
||||||
|
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{0, 1, 4, 5, 8, 9}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "a.y" {
|
||||||
|
t.Fatal("expected name 'a.y', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{3, 2}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{2, 6, 10, 3, 7, 11}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("3d", func(t *testing.T) {
|
||||||
|
r := fakeTensor{
|
||||||
|
name: "a.b",
|
||||||
|
shape: []uint64{3, 4, 2},
|
||||||
|
data: []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("no split", func(t *testing.T) {
|
||||||
|
for tt := range splitDim(&r, 0, split{Replacer: strings.NewReplacer("a", "x")}) {
|
||||||
|
if tt.Name != "x.b" {
|
||||||
|
t.Fatalf("expected name 'x', got '%s'", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{3, 4, 2}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("even split", func(t *testing.T) {
|
||||||
|
next, stop := iter.Pull(splitDim(&r, 1,
|
||||||
|
split{Replacer: strings.NewReplacer("a", "x")},
|
||||||
|
split{Replacer: strings.NewReplacer("b", "y")},
|
||||||
|
))
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "x.b" {
|
||||||
|
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{3, 2, 2}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "a.y" {
|
||||||
|
t.Fatal("expected name 'a.y', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{3, 2, 2}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("uneven split", func(t *testing.T) {
|
||||||
|
next, stop := iter.Pull(splitDim(&r, 0,
|
||||||
|
split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
|
||||||
|
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
|
||||||
|
))
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "x.b" {
|
||||||
|
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{2, 4, 2}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "a.y" {
|
||||||
|
t.Fatal("expected name 'a.y', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{1, 4, 2}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{16, 17, 18, 19, 20, 21, 22, 23}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("three way split", func(t *testing.T) {
|
||||||
|
next, stop := iter.Pull(splitDim(&r, 0,
|
||||||
|
split{Replacer: strings.NewReplacer("a", "x"), dim: 1},
|
||||||
|
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
|
||||||
|
split{Replacer: strings.NewReplacer("b", "z"), dim: 1},
|
||||||
|
))
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "x.b" {
|
||||||
|
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{1, 4, 2}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 4, 5, 6, 7}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "a.y" {
|
||||||
|
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{1, 4, 2}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{8, 9, 10, 11, 12, 13, 14, 15}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "a.z" {
|
||||||
|
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{1, 4, 2}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{16, 17, 18, 19, 20, 21, 22, 23}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("uneven three way split", func(t *testing.T) {
|
||||||
|
next, stop := iter.Pull(splitDim(&r, 1,
|
||||||
|
split{Replacer: strings.NewReplacer("a", "x"), dim: 2},
|
||||||
|
split{Replacer: strings.NewReplacer("b", "y"), dim: 1},
|
||||||
|
split{Replacer: strings.NewReplacer("b", "z"), dim: 1},
|
||||||
|
))
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "x.b" {
|
||||||
|
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{3, 2, 2}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "a.y" {
|
||||||
|
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{3, 1, 2}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{4, 5, 12, 13, 20, 21}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
tt, ok := next()
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("expected at least one split")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.Name != "a.z" {
|
||||||
|
t.Fatal("expected name 'x.b', got", tt.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(tt.Shape, []uint64{3, 1, 2}); diff != "" {
|
||||||
|
t.Errorf("unexpected shape (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
var b bytes.Buffer
|
||||||
|
if _, err := tt.WriteTo(&b); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f32s := make([]float32, mul(tt.Shape))
|
||||||
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if diff := cmp.Diff(f32s, []float32{6, 7, 14, 15, 22, 23}); diff != "" {
|
||||||
|
t.Errorf("unexpected data (-want +got):\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
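For context on what these subtests exercise: splitDim is the helper a model converter uses to carve one fused tensor into several separately named tensors along a chosen dimension. The sketch below is illustrative only; the tensor name, shapes, and the `fused`/`out` variables are hypothetical, but it drives the same split/Replacer API the tests above cover.

    // Hypothetical use: carve a fused "blk.0.attn_qkv.weight" of shape [3n, m]
    // into equal q, k, and v tensors along dimension 0. Each yielded tensor is
    // renamed by its split's Replacer, e.g. "blk.0.attn_q.weight" with shape [n, m].
    for tt := range splitDim(&fused, 0,
        split{Replacer: strings.NewReplacer("qkv", "q")},
        split{Replacer: strings.NewReplacer("qkv", "k")},
        split{Replacer: strings.NewReplacer("qkv", "v")},
    ) {
        out = append(out, tt)
    }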
func TestMerge(t *testing.T) {
    unmatched := []Tensor{
        &fakeTensor{
            name:  "a.0.b",
            shape: []uint64{5, 2},
            data:  []float32{10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
        },
        &fakeTensor{
            name:  "a.1.b",
            shape: []uint64{5, 2},
            data:  []float32{20, 21, 22, 23, 24, 25, 26, 27, 28, 29},
        },
        &fakeTensor{
            name:  "c.0.d",
            shape: []uint64{5, 2},
            data:  []float32{30, 31, 32, 33, 34, 35, 36, 37, 38, 39},
        },
        &fakeTensor{
            name:  "c.1.d",
            shape: []uint64{5, 2},
            data:  []float32{40, 41, 42, 43, 44, 45, 46, 47, 48, 49},
        },
        &fakeTensor{
            name:  "e.0.f",
            shape: []uint64{5, 2},
            data:  []float32{50, 51, 52, 53, 54, 55, 56, 57, 58, 59},
        },
    }

    checkMatched := func(t *testing.T, n int, matched []*ggml.Tensor) {
        for i := range n {
            got := matched[i]
            if diff := cmp.Diff([]uint64{2, 5, 2}, got.Shape); diff != "" {
                t.Errorf("unexpected (-want +got):\n%s", diff)
            }

            var b bytes.Buffer
            if _, err := got.WriteTo(&b); err != nil {
                t.Fatal(err)
            }

            f32s := make([]float32, 20)
            if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
                t.Fatal(err)
            }

            offset := 10 + (i * 20)
            want := make([]float32, 20)
            for j := range 20 {
                want[j] = float32(offset + j)
            }

            if diff := cmp.Diff(want, f32s); diff != "" {
                t.Errorf("unexpected data (-want +got):\n%s", diff)
            }
        }
    }

t.Run("single merge", func(t *testing.T) {
|
||||||
|
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"})
|
||||||
|
if len(unmatched) != 3 {
|
||||||
|
t.Error("expected 3 remaining tensors, got", len(unmatched))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(matched) != 1 {
|
||||||
|
t.Error("expected 1 merged tensor, got", len(matched))
|
||||||
|
}
|
||||||
|
|
||||||
|
checkMatched(t, 1, matched)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple merges", func(t *testing.T) {
|
||||||
|
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"}, merge{"c.*.d", "c.d"})
|
||||||
|
if len(unmatched) != 1 {
|
||||||
|
t.Error("expected 1 remaining tensors, got", len(unmatched))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(matched) != 2 {
|
||||||
|
t.Error("expected 2 merged tensor, got", len(matched))
|
||||||
|
}
|
||||||
|
|
||||||
|
checkMatched(t, 2, matched)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no match", func(t *testing.T) {
|
||||||
|
matched, unmatched := mergeTensors(unmatched, merge{"x.*.y", "x.y"})
|
||||||
|
if len(unmatched) != 5 {
|
||||||
|
t.Error("expected 5 remaining tensors, got", len(unmatched))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(matched) != 0 {
|
||||||
|
t.Error("expected no merged tensors, got", len(matched))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
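Taken together with checkMatched, the merge contract under test is: every tensor whose name matches the glob is stacked along a new leading axis, the result takes the replacement name, and non-matching tensors pass through untouched. A minimal sketch of the expected shape arithmetic, reusing the fixture names above:

    // "a.0.b" and "a.1.b", each [5, 2], match "a.*.b" and stack into one
    // "a.b" of shape [2, 5, 2]; the data of "a.0.b" occupies the first 10
    // elements and "a.1.b" the next 10. Everything else lands in rest.
    matched, rest := mergeTensors(tensors, merge{"a.*.b", "a.b"})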
@@ -8,11 +8,10 @@ import (
 	"fmt"
 	"io/fs"
 	"log/slog"
+	"maps"
 	"os"
 	"slices"
 	"strings"
-
-	"golang.org/x/exp/maps"
 )
 
 const (

@@ -110,6 +109,7 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
 	}
 
 	if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) {
+		// noop
 	} else if err != nil {
 		return nil, err
 	} else {

@@ -171,6 +171,34 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
 		}
 	}
 
+	if f, err := fsys.Open("generation_config.json"); errors.Is(err, os.ErrNotExist) {
+	} else if err != nil {
+		return nil, err
+	} else {
+		defer f.Close()
+
+		var p map[string]json.RawMessage
+		if err := json.NewDecoder(f).Decode(&p); err != nil {
+			return nil, err
+		}
+
+		for _, st := range specialTokenTypes {
+			if bts, ok := p[fmt.Sprintf("%s_token_id", st)]; ok {
+				var ids []int32
+				if err := json.Unmarshal(bts, &ids); err != nil {
+					// value is not a list so the existing ID is used
+					continue
+				}
+
+				if i := slices.IndexFunc(t.SpecialVocabulary, func(sv *SpecialVocabulary) bool {
+					return sv.Type == st
+				}); i >= 0 {
+					t.SpecialVocabulary[i].IDs = ids
+				}
+			}
+		}
+	}
+
 	return t, nil
 }

@@ -231,11 +259,8 @@ func parseVocabularyFromTokenizer(fsys fs.FS) (*Vocabulary, error) {
 		tokens[token.ID] = token
 	}
 
-	keys := maps.Keys(tokens)
-	slices.Sort(keys)
-
 	v := Vocabulary{Model: "gpt2"}
-	for _, k := range keys {
+	for _, k := range slices.Sorted(maps.Keys(tokens)) {
 		token := tokens[k]
 		v.Tokens = append(v.Tokens, token.Content)
 		v.Scores = append(v.Scores, float32(token.ID))

@@ -280,6 +305,9 @@ type SpecialVocabulary struct {
 	ID       int
 	Content  string
 	AddToken bool
+
+	// IDs is populated by generation_config.json
+	IDs []int32
 }
 
 func (sv SpecialVocabulary) Key() string {
@@ -6,7 +6,9 @@ import (
 	"errors"
 	"fmt"
 	"io/fs"
+	"log/slog"
 	"os"
+	"reflect"
 	"slices"
 
 	"google.golang.org/protobuf/proto"

@@ -15,6 +17,8 @@ import (
 )
 
 func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
+	slog.Debug("using spm vocabulary")
+
 	ast, err := parseAdditionalSpecialTokens(fsys)
 	if err != nil {
 		return nil, err

@@ -43,10 +47,19 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
 			v.Types = append(v.Types, int32(t))
 		default:
 			tt := int32(sentencepiece.ModelProto_SentencePiece_NORMAL)
-			if slices.Contains(ast, piece.GetPiece()) {
+			// temporary fix to handle gemma3 broken configs
+			if slices.Contains([]string{"<end_of_turn>", "<start_of_turn>"}, piece.GetPiece()) {
 				tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
 			}
+
+			for _, t := range ast {
+				if t.Content == piece.GetPiece() {
+					tt = int32(sentencepiece.ModelProto_SentencePiece_CONTROL)
+					break
+				}
+			}
+
 			v.Types = append(v.Types, tt)
 		}
 	}

@@ -78,10 +91,16 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
 		return cmp.Compare(i.id, j.id)
 	})
 
-	n := len(v.Tokens)
-	for i, t := range ts {
-		if t.id != i+n {
-			return nil, fmt.Errorf("invalid token id: %d", t.id)
+	for _, t := range ts {
+		if t.id < len(v.Tokens) {
+			if v.Tokens[t.id] == t.content {
+				slog.Warn("tokenizer", "duplicate token", t.content, "id", t.id)
+				continue
+			}
+			return nil, fmt.Errorf("token mismatch: %s != %s at pos [%d]", t.content, v.Tokens[t.id], t.id)
+		}
+		if t.id != len(v.Tokens) {
+			return nil, fmt.Errorf("invalid token id: [%d] as pos [%d]", t.id, len(v.Tokens))
 		}
 
 		v.Tokens = append(v.Tokens, t.content)

@@ -92,7 +111,15 @@ func parseSentencePiece(fsys fs.FS) (*Vocabulary, error) {
 	return &v, nil
 }
 
-func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) {
+type specialToken struct {
+	Content    string `json:"content"`
+	Lstrip     bool   `json:"lstrip"`
+	Normalized bool   `json:"normalized"`
+	Rstrip     bool   `json:"rstrip"`
+	SingleWord bool   `json:"single_word"`
+}
+
+func parseAdditionalSpecialTokens(fsys fs.FS) ([]specialToken, error) {
 	f, err := fsys.Open("special_tokens_map.json")
 	if errors.Is(err, os.ErrNotExist) {
 		return nil, nil

@@ -102,12 +129,43 @@ func parseAdditionalSpecialTokens(fsys fs.FS) ([]string, error) {
 	defer f.Close()
 
 	var m struct {
-		AdditionalSpecialTokens []string `json:"additional_special_tokens"`
+		AdditionalSpecialTokens any `json:"additional_special_tokens"`
 	}
 
 	if err := json.NewDecoder(f).Decode(&m); err != nil {
 		return nil, err
 	}
 
-	return m.AdditionalSpecialTokens, nil
+	var ast []specialToken
+
+	switch st := m.AdditionalSpecialTokens.(type) {
+	case []string:
+		for _, s := range st {
+			ast = append(ast, specialToken{Content: s})
+		}
+	case []any:
+		for _, s := range st {
+			// marshal and unmarshal the object to get the special token
+			tMap := s.(map[string]any)
+			data, err := json.Marshal(tMap)
+			if err != nil {
+				return nil, err
+			}
+
+			var token specialToken
+			err = json.Unmarshal(data, &token)
+			if err != nil {
+				return nil, err
+			}
+
+			ast = append(ast, token)
+		}
+	default:
+		slog.Warn("special token", "unknown token", reflect.TypeOf(st))
+	}
+
+	slog.Debug("spm tokenizer", "additional tokens", ast)
+
+	return ast, nil
 }
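The motivation for loosening `additional_special_tokens` to `any` is that special_tokens_map.json appears in the wild in two shapes: a bare list of strings, and a list of token objects. A minimal, self-contained sketch (sample documents invented for illustration) showing both decoding through a single field; note that encoding/json delivers a JSON array as []any either way, which the []any branch above then normalizes:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    func main() {
        docs := []string{
            // older style: plain strings
            `{"additional_special_tokens": ["<start_of_turn>", "<end_of_turn>"]}`,
            // newer style: token objects with flags
            `{"additional_special_tokens": [{"content": "<start_of_turn>", "normalized": false}]}`,
        }
        for _, doc := range docs {
            var m struct {
                AdditionalSpecialTokens any `json:"additional_special_tokens"`
            }
            if err := json.Unmarshal([]byte(doc), &m); err != nil {
                panic(err)
            }
            // both decode as []interface{}; the element type distinguishes the shapes
            fmt.Printf("%T\n", m.AdditionalSpecialTokens)
        }
    }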
@@ -247,6 +247,67 @@ func TestParseTokenizer(t *testing.T) {
 				Pre: "default",
 			},
 		},
+		{
+			name: "generation config eos token ids",
+			fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
+				"tokenizer.json": strings.NewReader(`{
+					"added_tokens": [
+						{
+							"id": 0,
+							"content": "<bos>",
+							"special": true
+						},
+						{
+							"id": 1,
+							"content": "<eos>",
+							"special": true
+						},
+						{
+							"id": 2,
+							"content": "<eot>",
+							"special": true
+						},
+						{
+							"id": 3,
+							"content": "<eom>",
+							"special": true
+						}
+					],
+					"model": {
+						"vocab": {
+							"<bos>": 0,
+							"<eos>": 1,
+							"<eot>": 2,
+							"<eom>": 3
+						}
+					}
+				}`),
+				"tokenizer_config.json": strings.NewReader(`{
+					"add_bos_token": true,
+					"add_eos_token": false,
+					"bos_token": "<bos>",
+					"eos_token": "<eos>"
+				}`),
+				"generation_config.json": strings.NewReader(`{
+					"bos_token_id": 0,
+					"eos_token_id": [1, 2, 3]
+				}`),
+			}),
+			specialTokenTypes: []string{"pad", "eos", "bos", "unk"},
+			want: &Tokenizer{
+				Vocabulary: &Vocabulary{
+					Model:  "gpt2",
+					Tokens: []string{"<bos>", "<eos>", "<eot>", "<eom>"},
+					Scores: []float32{0, 1, 2, 3},
+					Types:  []int32{3, 3, 3, 3},
+				},
+				SpecialVocabulary: []*SpecialVocabulary{
+					{Type: "eos", Content: "<eos>", ID: 1, IDs: []int32{1, 2, 3}, AddToken: false},
+					{Type: "bos", Content: "<bos>", ID: 0, AddToken: true},
+				},
+				Pre: "default",
+			},
+		},
 	}
 
 	for _, tt := range cases {
@@ -58,7 +58,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
 	driverMajor, driverMinor, err := AMDDriverVersion()
 	if err != nil {
 		// TODO - if we see users crash and burn with the upstreamed kernel this can be adjusted to hard-fail rocm support and fallback to CPU
-		slog.Warn("ollama recommends running the https://www.amd.com/en/support/linux-drivers", "error", err)
+		slog.Warn("ollama recommends running the https://www.amd.com/en/support/download/linux-drivers.html", "error", err)
 	}
 
 	// Determine if the user has already pre-selected which GPUs to look at, then ignore the others

@@ -97,6 +97,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
 		return a < b
 	})
 	gpuCount := 0
+	gpuOrdinalID := 0
 	for _, match := range matches {
 		slog.Debug("evaluating amdgpu node " + match)
 		fp, err := os.Open(match)

@@ -187,10 +188,6 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
 			continue
 		}
 
-		// Keep track of numeric IDs based on valid GPUs
-		gpuID := gpuCount
-		gpuCount += 1
-
 		// Look up the memory for the current node
 		totalMemory := uint64(0)
 		usedMemory := uint64(0)

@@ -269,7 +266,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
 		if uniqueID != 0 {
 			ID = fmt.Sprintf("GPU-%016x", uniqueID)
 		} else {
-			ID = strconv.Itoa(gpuID)
+			ID = strconv.Itoa(gpuOrdinalID)
 		}
 
 		gpuInfo := RocmGPUInfo{

@@ -280,6 +277,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
 				FreeMemory: (totalMemory - usedMemory),
 			},
 			ID:            ID,
+			filterID:      gpuOrdinalID,
 			Name:          name,
 			Compute:       fmt.Sprintf("gfx%d%x%x", major, minor, patch),
 			MinimumMemory: rocmMinimumMemory,

@@ -287,23 +285,50 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
 			DriverMinor: driverMinor,
 			},
 			usedFilepath: usedFile,
-			index:        gpuID,
+			index:        gpuCount,
 		}
+
+		// Keep track of numeric IDs based on valid GPUs
+		gpuCount += 1
+
+		// If the user wants to filter to a subset of devices, filter out if we aren't a match
+		if len(visibleDevices) > 0 {
+			include := false
+			for _, visible := range visibleDevices {
+				if (uniqueID != 0 && visible == gpuInfo.ID) || visible == strconv.Itoa(gpuInfo.index) {
+					include = true
+					break
+				}
+			}
+			if !include {
+				reason := "filtering out device per user request"
+				slog.Info(reason, "id", gpuInfo.ID, "index", gpuInfo.index, "visible_devices", visibleDevices)
+				unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
+					GpuInfo: gpuInfo.GpuInfo,
+					Reason:  reason,
+				})
+
+				continue
+			}
+		}
+
+		// Ordinal IDs are based on the visible GPUs
+		gpuOrdinalID += 1
 
 		// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
 		if totalMemory < IGPUMemLimit {
 			reason := "unsupported Radeon iGPU detected skipping"
-			slog.Info(reason, "id", gpuID, "total", format.HumanBytes2(totalMemory))
+			slog.Info(reason, "id", gpuInfo.ID, "total", format.HumanBytes2(totalMemory))
 			unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
 				GpuInfo: gpuInfo.GpuInfo,
 				Reason:  reason,
 			})
 			continue
 		}
-		minVer, err := strconv.Atoi(RocmComputeMajorMin)
-		if err != nil {
-			slog.Error("invalid RocmComputeMajorMin setting", "value", RocmComputeMajorMin, "error", err)
-		}
+		//minVer, err := strconv.Atoi(RocmComputeMajorMin)
+		//if err != nil {
+		//	slog.Error("invalid RocmComputeMajorMin setting", "value", RocmComputeMajorMin, "error", err)
+		//}
 		// if int(major) < minVer {
 		// 	reason := fmt.Sprintf("amdgpu too old gfx%d%x%x", major, minor, patch)
 		// 	slog.Warn(reason, "gpu", gpuID)

@@ -315,29 +340,8 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
 		// 	continue
 		//}
 
-		slog.Debug("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
-		slog.Debug("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
-
-		// If the user wants to filter to a subset of devices, filter out if we aren't a match
-		if len(visibleDevices) > 0 {
-			include := false
-			for _, visible := range visibleDevices {
-				if visible == gpuInfo.ID || visible == strconv.Itoa(gpuInfo.index) {
-					include = true
-					break
-				}
-			}
-			if !include {
-				reason := "filtering out device per user request"
-				slog.Info(reason, "id", gpuInfo.ID, "visible_devices", visibleDevices)
-				unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
-					GpuInfo: gpuInfo.GpuInfo,
-					Reason:  reason,
-				})
-
-				continue
-			}
-		}
+		slog.Debug("amdgpu memory", "gpu", gpuInfo.ID, "total", format.HumanBytes2(totalMemory))
+		slog.Debug("amdgpu memory", "gpu", gpuInfo.ID, "available", format.HumanBytes2(totalMemory-usedMemory))
 
 		// Final validation is gfx compatibility - load the library if we haven't already loaded it
 		// even if the user overrides, we still need to validate the library

@@ -391,7 +395,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
 
 		// Check for env var workarounds
 		if name == "1002:687f" { // Vega RX 56
-			gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, [2]string{"HSA_ENABLE_SDMA", "0"})
+			gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, "HSA_ENABLE_SDMA=0")
 		}
 
 		// The GPU has passed all the verification steps and is supported

@@ -520,19 +524,26 @@ func verifyKFDDriverAccess() error {
 	return nil
 }
 
-func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
+func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
 	ids := []string{}
 	for _, info := range gpuInfo {
 		if info.Library != "rocm" {
-			// TODO shouldn't happen if things are wired correctly...
-			slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
 			continue
 		}
-		ids = append(ids, info.ID)
+		// If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number
+		if _, err := strconv.Atoi(info.ID); err == nil {
+			ids = append(ids, fmt.Sprintf("%d", info.filterID))
+		} else {
+			ids = append(ids, info.ID)
+		}
 	}
+	if len(ids) == 0 {
+		return ""
+	}
+
 	// There are 3 potential env vars to use to select GPUs.
 	// ROCR_VISIBLE_DEVICES supports UUID or numeric so is our preferred on linux
 	// GPU_DEVICE_ORDINAL supports numeric IDs only
 	// HIP_VISIBLE_DEVICES supports numeric IDs only
-	return "ROCR_VISIBLE_DEVICES", strings.Join(ids, ",")
+	return "ROCR_VISIBLE_DEVICES=" + strings.Join(ids, ",")
 }
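The filterID bookkeeping exists because the ROCm runtime evaluates numeric entries in ROCR_VISIBLE_DEVICES against the unfiltered device list, while Ollama re-numbers its own IDs after filtering. A stripped-down sketch of the translation rule; the `dev` struct and `visibleDevicesEnv` helper are hypothetical, the real logic lives in rocmGetVisibleDevicesEnv above:

    package main

    import (
        "fmt"
        "strconv"
        "strings"
    )

    type dev struct {
        ID       string // post-filter ID: small ordinal, or stable "GPU-<uuid>"
        filterID int    // pre-filter ordinal the ROCm runtime expects
    }

    func visibleDevicesEnv(devs []dev) string {
        var ids []string
        for _, d := range devs {
            if _, err := strconv.Atoi(d.ID); err == nil {
                // numeric IDs were re-numbered after filtering; translate back
                ids = append(ids, strconv.Itoa(d.filterID))
            } else {
                ids = append(ids, d.ID) // UUID-style IDs are stable either way
            }
        }
        if len(ids) == 0 {
            return ""
        }
        return "ROCR_VISIBLE_DEVICES=" + strings.Join(ids, ",")
    }

    func main() {
        // user kept devices 1 and 2 of an original 0,1,2; Ollama renamed them 0,1
        fmt.Println(visibleDevicesEnv([]dev{{ID: "0", filterID: 1}, {ID: "1", filterID: 2}}))
        // prints: ROCR_VISIBLE_DEVICES=1,2
    }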
@@ -111,6 +111,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
 			UnreliableFreeMemory: true,
 
 			ID:             strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices
+			filterID:       i,
 			DependencyPath: []string{libDir},
 			MinimumMemory:  rocmMinimumMemory,
 			Name:           name,

@@ -200,19 +201,26 @@ func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
 	return nil
 }
 
-func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
+func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
 	ids := []string{}
 	for _, info := range gpuInfo {
 		if info.Library != "rocm" {
-			// TODO shouldn't happen if things are wired correctly...
-			slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
 			continue
 		}
-		ids = append(ids, info.ID)
+		// If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number
+		if _, err := strconv.Atoi(info.ID); err == nil {
+			ids = append(ids, fmt.Sprintf("%d", info.filterID))
+		} else {
+			ids = append(ids, info.ID)
+		}
 	}
+	if len(ids) == 0 {
+		return ""
+	}
+
 	// There are 3 potential env vars to use to select GPUs.
 	// ROCR_VISIBLE_DEVICES supports UUID or numeric but does not work on Windows
 	// HIP_VISIBLE_DEVICES supports numeric IDs only
 	// GPU_DEVICE_ORDINAL supports numeric IDs only
-	return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
+	return "HIP_VISIBLE_DEVICES=" + strings.Join(ids, ",")
 }
@@ -12,7 +12,7 @@ func IsNUMA() bool {
 		// numa support in llama.cpp is linux only
 		return false
 	}
-	ids := map[string]interface{}{}
+	ids := map[string]any{}
 	packageIds, _ := filepath.Glob("/sys/devices/system/cpu/cpu*/topology/physical_package_id")
 	for _, packageId := range packageIds {
 		id, err := os.ReadFile(packageId)
@@ -3,6 +3,7 @@
 package discover
 
 import (
+	"fmt"
 	"log/slog"
 	"os"
 	"regexp"

@@ -15,19 +16,6 @@ import (
 // Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
 var CudaTegra string = os.Getenv("JETSON_JETPACK")
 
-func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
-	ids := []string{}
-	for _, info := range gpuInfo {
-		if info.Library != "cuda" {
-			// TODO shouldn't happen if things are wired correctly...
-			slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
-			continue
-		}
-		ids = append(ids, info.ID)
-	}
-	return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
-}
-
 func cudaVariant(gpuInfo CudaGPUInfo) string {
 	if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
 		if CudaTegra != "" {

@@ -55,10 +43,13 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
 				}
 			}
 		}
+		return "sbsa"
 	}
 
 	// driver 12.0 has problems with the cuda v12 library, so run v11 on those older drivers
 	if gpuInfo.DriverMajor < 12 || (gpuInfo.DriverMajor == 12 && gpuInfo.DriverMinor == 0) {
+		// The detected driver is older than Feb 2023
+		slog.Warn("old CUDA driver detected - please upgrade to a newer driver", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
 		return "v11"
 	}
 	return "v12"
@@ -263,6 +263,8 @@ func GetGPUInfo() GpuInfoList {
 			var driverMinor int
 			if cHandles.cudart != nil {
 				C.cudart_bootstrap(*cHandles.cudart, C.int(i), &memInfo)
+				driverMajor = int(cHandles.cudart.driver_major)
+				driverMinor = int(cHandles.cudart.driver_minor)
 			} else {
 				C.nvcuda_bootstrap(*cHandles.nvcuda, C.int(i), &memInfo)
 				driverMajor = int(cHandles.nvcuda.driver_major)

@@ -369,6 +371,15 @@ func GetGPUInfo() GpuInfoList {
 		}
 
 		rocmGPUs, err = AMDGetGPUInfo()
+
+		// The ID field is used in context of the filtered set of GPUS
+		// so we have to replace any of these numeric IDs with their
+		// placement in this set of GPUs
+		for i := range rocmGPUs {
+			if _, err := strconv.Atoi(rocmGPUs[i].ID); err == nil {
+				rocmGPUs[i].ID = strconv.Itoa(i)
+			}
+		}
 		if err != nil {
 			bootstrapErrors = append(bootstrapErrors, err)
 		}

@@ -670,7 +681,7 @@ func loadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string, e
 }
 
 func getVerboseState() C.uint16_t {
-	if envconfig.Debug() {
+	if envconfig.LogLevel() < slog.LevelInfo {
 		return C.uint16_t(1)
 	}
 	return C.uint16_t(0)

@@ -678,23 +689,16 @@ func getVerboseState() C.uint16_t {
 
 // Given the list of GPUs this instantiation is targeted for,
 // figure out the visible devices environment variable
-//
-// If different libraries are detected, the first one is what we use
-func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
+func (l GpuInfoList) GetVisibleDevicesEnv() []string {
 	if len(l) == 0 {
-		return "", ""
+		return nil
 	}
-	switch l[0].Library {
-	case "cuda":
-		return cudaGetVisibleDevicesEnv(l)
-	case "rocm":
-		return rocmGetVisibleDevicesEnv(l)
-	case "oneapi":
-		return oneapiGetVisibleDevicesEnv(l)
-	default:
-		slog.Debug("no filter required for library " + l[0].Library)
-		return "", ""
+	vd := []string{}
+	// Only filter the AMD GPUs at this level, let all NVIDIA devices through
+	if tmp := rocmGetVisibleDevicesEnv(l); tmp != "" {
+		vd = append(vd, tmp)
 	}
+	return vd
 }
 
 func GetSystemInfo() SystemInfo {
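Since GetVisibleDevicesEnv now hands back ready-made KEY=VALUE strings rather than a separate key and value, a caller can splice the result directly into a child process environment. A minimal sketch, assuming a hypothetical launchRunner helper:

    package main

    import (
        "os"
        "os/exec"
    )

    // launchRunner is hypothetical; it shows how []string entries such as
    // "ROCR_VISIBLE_DEVICES=1,2" append cleanly onto the inherited environment.
    func launchRunner(path string, visibleDevsEnv []string) error {
        cmd := exec.Command(path)
        cmd.Env = append(os.Environ(), visibleDevsEnv...)
        return cmd.Start()
    }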
@@ -62,9 +62,9 @@ func GetCPUMem() (memInfo, error) {
 	}, nil
 }
 
-func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
+func (l GpuInfoList) GetVisibleDevicesEnv() []string {
 	// No-op on darwin
-	return "", ""
+	return nil
 }
 
 func GetSystemInfo() SystemInfo {
@@ -27,12 +27,14 @@
 
 #endif
 
+#ifndef LOG
 #define LOG(verbose, ...) \
   do { \
     if (verbose) { \
      fprintf(stderr, __VA_ARGS__); \
     } \
   } while (0)
+#endif
 
 #ifdef __cplusplus
 extern "C" {
@@ -1,6 +1,7 @@
 #ifndef __APPLE__  // TODO - maybe consider nvidia support on intel macs?
 
 #include <string.h>
+#include <inttypes.h>
 #include "gpu_info_cudart.h"
 
 void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {

@@ -58,7 +59,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
     LOG(resp->ch.verbose, "cudaSetDevice err: %d\n", ret);
     UNLOAD_LIBRARY(resp->ch.handle);
     resp->ch.handle = NULL;
-    if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
+    if (ret == CUDART_ERROR_INSUFFICIENT_DRIVER) {
       resp->err = strdup("your nvidia driver is too old or missing. If you have a CUDA GPU please upgrade to run ollama");
       return;
     }

@@ -68,18 +69,15 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
   }
 
   int version = 0;
-  cudartDriverVersion_t driverVersion;
-  driverVersion.major = 0;
-  driverVersion.minor = 0;
 
   // Report driver version if we're in verbose mode, ignore errors
   ret = (*resp->ch.cudaDriverGetVersion)(&version);
   if (ret != CUDART_SUCCESS) {
     LOG(resp->ch.verbose, "cudaDriverGetVersion failed: %d\n", ret);
   } else {
-    driverVersion.major = version / 1000;
-    driverVersion.minor = (version - (driverVersion.major * 1000)) / 10;
-    LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", driverVersion.major, driverVersion.minor);
+    resp->ch.driver_major = version / 1000;
+    resp->ch.driver_minor = (version - (resp->ch.driver_major * 1000)) / 10;
+    LOG(resp->ch.verbose, "CUDA driver version: %d-%d\n", resp->ch.driver_major, resp->ch.driver_minor);
   }
 
   ret = (*resp->ch.cudaGetDeviceCount)(&resp->num_devices);

@@ -168,9 +166,9 @@ void cudart_bootstrap(cudart_handle_t h, int i, mem_info_t *resp) {
   resp->free = memInfo.free;
   resp->used = memInfo.used;
 
-  LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
-  LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
-  LOG(h.verbose, "[%s] CUDA usedMem %lu\n", resp->gpu_id, resp->used);
+  LOG(h.verbose, "[%s] CUDA totalMem %" PRId64 "\n", resp->gpu_id, resp->total);
+  LOG(h.verbose, "[%s] CUDA freeMem %" PRId64 "\n", resp->gpu_id, resp->free);
+  LOG(h.verbose, "[%s] CUDA usedMem %" PRId64 "\n", resp->gpu_id, resp->used);
   LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
 }
@@ -29,11 +29,6 @@ typedef struct cudartMemory_st {
   size_t used;
 } cudartMemory_t;
 
-typedef struct cudartDriverVersion {
-  int major;
-  int minor;
-} cudartDriverVersion_t;
-
 typedef struct cudaUUID {
     unsigned char bytes[16];
 } cudaUUID_t;

@@ -123,6 +118,8 @@ typedef struct cudaDeviceProp {
 typedef struct cudart_handle {
   void *handle;
   uint16_t verbose;
+  int driver_major;
+  int driver_minor;
   cudartReturn_t (*cudaSetDevice)(int device);
   cudartReturn_t (*cudaDeviceSynchronize)(void);
   cudartReturn_t (*cudaDeviceReset)(void);
@@ -1,6 +1,7 @@
 #ifndef __APPLE__  // TODO - maybe consider nvidia support on intel macs?
 
 #include <string.h>
+#include <inttypes.h>
 #include "gpu_info_nvcuda.h"
 
 void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {

@@ -193,8 +194,8 @@ void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
   resp->total = memInfo.total;
   resp->free = memInfo.free;
 
-  LOG(h.verbose, "[%s] CUDA totalMem %lu mb\n", resp->gpu_id, resp->total / 1024 / 1024);
-  LOG(h.verbose, "[%s] CUDA freeMem %lu mb\n", resp->gpu_id, resp->free / 1024 / 1024);
+  LOG(h.verbose, "[%s] CUDA totalMem %" PRId64 "mb\n", resp->gpu_id, resp->total / 1024 / 1024);
+  LOG(h.verbose, "[%s] CUDA freeMem %" PRId64 "mb\n", resp->gpu_id, resp->free / 1024 / 1024);
   LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
@@ -111,6 +111,7 @@ func GetCPUDetails() ([]CPU, error) {
 	if err != nil {
 		return nil, err
 	}
+	defer file.Close()
 	return linuxCPUDetails(file)
 }
 
@@ -168,13 +169,11 @@ func linuxCPUDetails(file io.Reader) ([]CPU, error) {
 	for id, s := range socketByID {
 		s.CoreCount = len(coreBySocket[id])
 		s.ThreadCount = 0
-		for _, tc := range threadsByCoreBySocket[id] {
-			s.ThreadCount += tc
-		}
 
 		// This only works if HT is enabled, consider a more reliable model, maybe cache size comparisons?
 		efficiencyCoreCount := 0
 		for _, threads := range threadsByCoreBySocket[id] {
+			s.ThreadCount += threads
 			if threads == 1 {
 				efficiencyCoreCount++
 			}
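The consolidated loop derives both counts in one pass over the per-core thread tallies. For a socket whose four cores report [2, 2, 1, 1] threads, the toy sketch below yields a thread count of 6 and two presumed efficiency cores:

    threadsByCore := []int{2, 2, 1, 1} // toy data: threads seen per physical core
    threadCount, efficiencyCoreCount := 0, 0
    for _, threads := range threadsByCore {
        threadCount += threads
        if threads == 1 {
            // heuristic from the code above: single-threaded cores are treated as E-cores
            efficiencyCoreCount++
        }
    }
    // threadCount == 6, efficiencyCoreCount == 2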
@@ -1,21 +0,0 @@
-//go:build linux || windows
-
-package discover
-
-import (
-	"log/slog"
-	"strings"
-)
-
-func oneapiGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
-	ids := []string{}
-	for _, info := range gpuInfo {
-		if info.Library != "oneapi" {
-			// TODO shouldn't happen if things are wired correctly...
-			slog.Debug("oneapiGetVisibleDevicesEnv skipping over non-sycl device", "library", info.Library)
-			continue
-		}
-		ids = append(ids, info.ID)
-	}
-	return "ONEAPI_DEVICE_SELECTOR", "level_zero:" + strings.Join(ids, ",")
-}
@@ -12,7 +12,7 @@ import (
 // '../lib/ollama' on Linux and the executable's directory on macOS
 // note: distribution builds, additional GPU-specific libraries are
 // found in subdirectories of the returned path, such as
-// 'cuda_v11', 'cuda_v12', 'rocm', etc.
+// 'cuda_v12', 'rocm', etc.
 var LibOllamaPath string = func() string {
 	exe, err := os.Executable()
 	if err != nil {
@@ -27,8 +27,8 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
 	// Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
 	DependencyPath []string `json:"lib_path,omitempty"`
 
-	// Extra environment variables specific to the GPU as list of [key,value]
-	EnvWorkarounds [][2]string `json:"envs,omitempty"`
+	// Extra environment variables specific to the GPU as list of [key=value]
+	EnvWorkarounds []string `json:"envs,omitempty"`
 
 	// Set to true if we can NOT reliably discover FreeMemory. A value of true indicates
 	// the FreeMemory is best effort, and may over or under report actual memory usage

@@ -36,9 +36,10 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
 	UnreliableFreeMemory bool
 
 	// GPU information
 	ID string `json:"gpu_id"` // string to use for selection of this specific GPU
+	filterID int //nolint:unused,nolintlint // AMD Workaround: The numeric ID of the device used to filter out other devices
 	Name string `json:"name"` // user friendly name if available
 	Compute string `json:"compute"` // Compute Capability or gfx
 
 	// Driver Information - TODO no need to put this on each GPU
 	DriverMajor int `json:"driver_major,omitempty"`

@@ -171,7 +172,8 @@ func (si SystemInfo) GetOptimalThreadCount() int {
 // For each GPU, check if it does NOT support flash attention
 func (l GpuInfoList) FlashAttentionSupported() bool {
 	for _, gpu := range l {
-		supportsFA := gpu.Library == "metal" ||
+		supportsFA := gpu.Library == "cpu" ||
+			gpu.Library == "metal" ||
 			(gpu.Library == "cuda" && gpu.DriverMajor >= 7) ||
 			gpu.Library == "rocm"
 
@@ -4,6 +4,7 @@
 * [Quickstart](../README.md#quickstart)
 * [Examples](./examples.md)
 * [Importing models](./import.md)
+* [MacOS Documentation](./macos.md)
 * [Linux Documentation](./linux.md)
 * [Windows Documentation](./windows.md)
 * [Docker Documentation](./docker.md)

docs/api.md (336 changed lines)
@@ -19,7 +19,7 @@

 ### Model names

-Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q4_1` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.
+Model names follow a `model:tag` format, where `model` can have an optional namespace such as `example/model`. Some examples are `orca-mini:3b-q8_0` and `llama3:70b`. The tag is optional and, if not provided, will default to `latest`. The tag is used to identify a specific version.

 ### Durations

@@ -43,6 +43,7 @@ Generate a response for a given prompt with a provided model. This is a streamin
 - `prompt`: the prompt to generate a response for
 - `suffix`: the text after the model response
 - `images`: (optional) a list of base64-encoded images (for multimodal models such as `llava`)
+- `think`: (for thinking models) should the model think before responding?

 Advanced parameters (optional):

@@ -173,7 +174,7 @@ curl http://localhost:11434/api/generate -d '{

 ##### Response

-```json
+```json5
 {
   "model": "codellama:code",
   "created_at": "2024-07-22T20:47:51.147561Z",
@@ -394,9 +395,6 @@ curl http://localhost:11434/api/generate -d '{
   "repeat_penalty": 1.2,
   "presence_penalty": 1.5,
   "frequency_penalty": 1.0,
-  "mirostat": 1,
-  "mirostat_tau": 0.8,
-  "mirostat_eta": 0.6,
   "penalize_newline": true,
   "stop": ["\n", "user:"],
   "numa": false,
@@ -404,10 +402,7 @@ curl http://localhost:11434/api/generate -d '{
   "num_batch": 2,
   "num_gpu": 1,
   "main_gpu": 0,
-  "low_vram": false,
-  "vocab_only": false,
   "use_mmap": true,
-  "use_mlock": false,
   "num_thread": 8
  }
}'
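For reference, the new `think` flag slots into an ordinary generate call like this (a minimal sketch; `deepseek-r1` is just an illustrative thinking-capable model, not something this diff prescribes):

```shell
curl http://localhost:11434/api/generate -d '{
  "model": "deepseek-r1",
  "prompt": "What is 17 * 23?",
  "think": true,
  "stream": false
}'
```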
@@ -496,13 +491,16 @@ Generate the next message in a chat with a provided model. This is a streaming e
 - `model`: (required) the [model name](#model-names)
 - `messages`: the messages of the chat, this can be used to keep a chat memory
 - `tools`: list of tools in JSON for the model to use if supported
+- `think`: (for thinking models) should the model think before responding?

 The `message` object has the following fields:

 - `role`: the role of the message, either `system`, `user`, `assistant`, or `tool`
 - `content`: the content of the message
+- `thinking`: (for thinking models) the model's thinking process
 - `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
 - `tool_calls` (optional): a list of tools in JSON that the model wants to use
+- `tool_name` (optional): add the name of the tool that was executed to inform the model of the result

 Advanced parameters (optional):

@@ -511,13 +509,21 @@ Advanced parameters (optional):
 - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
 - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)

+### Tool calling
+
+Tool calling is supported by providing a list of tools in the `tools` parameter. The model will generate a response that includes a list of tool calls. See the [Chat request (Streaming with tools)](#chat-request-streaming-with-tools) example below.
+
+Models can also explain the result of the tool call in the response. See the [Chat request (With history, with tools)](#chat-request-with-history-with-tools) example below.
+
+[See models with tool calling capabilities](https://ollama.com/search?c=tool).
+
 ### Structured outputs

 Structured outputs are supported by providing a JSON schema in the `format` parameter. The model will generate a response that matches the schema. See the [Chat request (Structured outputs)](#chat-request-structured-outputs) example below.

 ### Examples

-#### Chat Request (Streaming)
+#### Chat request (Streaming)

 ##### Request

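Since tool support varies by model, a quick local check before wiring up `tools` can save a round trip (a sketch; `llama3.2` stands in for whatever model you plan to use):

```shell
ollama show llama3.2
# tool-capable models list "tools" under Capabilities;
# /api/show returns the same list in its "capabilities" field
```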
@@ -558,6 +564,10 @@ Final response:
 {
   "model": "llama3.2",
   "created_at": "2023-08-04T19:22:45.499127Z",
+  "message": {
+    "role": "assistant",
+    "content": ""
+  },
   "done": true,
   "total_duration": 4883583458,
   "load_duration": 1334875,
@@ -568,6 +578,88 @@ Final response:
 }
 ```

+#### Chat request (Streaming with tools)
+
+##### Request
+
+```shell
+curl http://localhost:11434/api/chat -d '{
+  "model": "llama3.2",
+  "messages": [
+    {
+      "role": "user",
+      "content": "what is the weather in tokyo?"
+    }
+  ],
+  "tools": [
+    {
+      "type": "function",
+      "function": {
+        "name": "get_weather",
+        "description": "Get the weather in a given city",
+        "parameters": {
+          "type": "object",
+          "properties": {
+            "city": {
+              "type": "string",
+              "description": "The city to get the weather for"
+            }
+          },
+          "required": ["city"]
+        }
+      }
+    }
+  ],
+  "stream": true
+}'
+```
+
+##### Response
+
+A stream of JSON objects is returned:
+
+```json
+{
+  "model": "llama3.2",
+  "created_at": "2025-07-07T20:22:19.184789Z",
+  "message": {
+    "role": "assistant",
+    "content": "",
+    "tool_calls": [
+      {
+        "function": {
+          "name": "get_weather",
+          "arguments": {
+            "city": "Tokyo"
+          }
+        }
+      }
+    ]
+  },
+  "done": false
+}
+```
+
+Final response:
+
+```json
+{
+  "model": "llama3.2",
+  "created_at": "2025-07-07T20:22:19.19314Z",
+  "message": {
+    "role": "assistant",
+    "content": ""
+  },
+  "done_reason": "stop",
+  "done": true,
+  "total_duration": 182242375,
+  "load_duration": 41295167,
+  "prompt_eval_count": 169,
+  "prompt_eval_duration": 24573166,
+  "eval_count": 15,
+  "eval_duration": 115959084
+}
+```
+
 #### Chat request (No streaming)

 ##### Request
@@ -605,6 +697,74 @@ curl http://localhost:11434/api/chat -d '{
 }
 ```

+#### Chat request (No streaming, with tools)
+
+##### Request
+
+```shell
+curl http://localhost:11434/api/chat -d '{
+  "model": "llama3.2",
+  "messages": [
+    {
+      "role": "user",
+      "content": "what is the weather in tokyo?"
+    }
+  ],
+  "tools": [
+    {
+      "type": "function",
+      "function": {
+        "name": "get_weather",
+        "description": "Get the weather in a given city",
+        "parameters": {
+          "type": "object",
+          "properties": {
+            "city": {
+              "type": "string",
+              "description": "The city to get the weather for"
+            }
+          },
+          "required": ["city"]
+        }
+      }
+    }
+  ],
+  "stream": false
+}'
+```
+
+##### Response
+
+```json
+{
+  "model": "llama3.2",
+  "created_at": "2025-07-07T20:32:53.844124Z",
+  "message": {
+    "role": "assistant",
+    "content": "",
+    "tool_calls": [
+      {
+        "function": {
+          "name": "get_weather",
+          "arguments": {
+            "city": "Tokyo"
+          }
+        }
+      }
+    ]
+  },
+  "done_reason": "stop",
+  "done": true,
+  "total_duration": 3244883583,
+  "load_duration": 2969184542,
+  "prompt_eval_count": 169,
+  "prompt_eval_duration": 141656333,
+  "eval_count": 18,
+  "eval_duration": 133293625
+}
+```
+
 #### Chat request (Structured outputs)

 ##### Request
@@ -711,6 +871,87 @@ Final response:
 }
 ```

+#### Chat request (With history, with tools)
+
+##### Request
+
+```shell
+curl http://localhost:11434/api/chat -d '{
+  "model": "llama3.2",
+  "messages": [
+    {
+      "role": "user",
+      "content": "what is the weather in Toronto?"
+    },
+    // the message from the model appended to history
+    {
+      "role": "assistant",
+      "content": "",
+      "tool_calls": [
+        {
+          "function": {
+            "name": "get_weather",
+            "arguments": {
+              "city": "Toronto"
+            }
+          }
+        }
+      ]
+    },
+    // the tool call result appended to history
+    {
+      "role": "tool",
+      "content": "11 degrees celsius",
+      "tool_name": "get_weather"
+    }
+  ],
+  "stream": false,
+  "tools": [
+    {
+      "type": "function",
+      "function": {
+        "name": "get_weather",
+        "description": "Get the weather in a given city",
+        "parameters": {
+          "type": "object",
+          "properties": {
+            "city": {
+              "type": "string",
+              "description": "The city to get the weather for"
+            }
+          },
+          "required": ["city"]
+        }
+      }
+    }
+  ]
+}'
+```
+
+##### Response
+
+```json
+{
+  "model": "llama3.2",
+  "created_at": "2025-07-07T20:43:37.688511Z",
+  "message": {
+    "role": "assistant",
+    "content": "The current temperature in Toronto is 11°C."
+  },
+  "done_reason": "stop",
+  "done": true,
+  "total_duration": 890771750,
+  "load_duration": 707634750,
+  "prompt_eval_count": 94,
+  "prompt_eval_duration": 91703208,
+  "eval_count": 11,
+  "eval_duration": 90282125
+}
+```
+
 #### Chat request (with images)

 ##### Request
@@ -954,19 +1195,8 @@ If you are creating a model from a safetensors directory or from a GGUF file, yo

 | Type | Recommended |
 | --- | :-: |
-| q2_K | |
-| q3_K_L | |
-| q3_K_M | |
-| q3_K_S | |
-| q4_0 | |
-| q4_1 | |
 | q4_K_M | * |
 | q4_K_S | |
-| q5_0 | |
-| q5_1 | |
-| q5_K_M | |
-| q5_K_S | |
-| q6_K | |
 | q8_0 | * |

 ### Examples
@@ -1011,8 +1241,8 @@ Quantize a non-quantized model.

 ```shell
 curl http://localhost:11434/api/create -d '{
-  "model": "llama3.1:quantized",
-  "from": "llama3.1:8b-instruct-fp16",
+  "model": "llama3.2:quantized",
+  "from": "llama3.2:3b-instruct-fp16",
   "quantize": "q4_K_M"
 }'
 ```

@@ -1022,12 +1252,14 @@ curl http://localhost:11434/api/create -d '{
 A stream of JSON objects is returned:

 ```json
-{"status":"quantizing F16 model to Q4_K_M"}
-{"status":"creating new layer sha256:667b0c1932bc6ffc593ed1d03f895bf2dc8dc6df21db3042284a6f4416b06a29"}
-{"status":"using existing layer sha256:11ce4ee3e170f6adebac9a991c22e22ab3f8530e154ee669954c4bc73061c258"}
-{"status":"using existing layer sha256:0ba8f0e314b4264dfd19df045cde9d4c394a52474bf92ed6a3de22a4ca31a177"}
+{"status":"quantizing F16 model to Q4_K_M","digest":"0","total":6433687776,"completed":12302}
+{"status":"quantizing F16 model to Q4_K_M","digest":"0","total":6433687776,"completed":6433687552}
+{"status":"verifying conversion"}
+{"status":"creating new layer sha256:fb7f4f211b89c6c4928ff4ddb73db9f9c0cfca3e000c3e40d6cf27ddc6ca72eb"}
+{"status":"using existing layer sha256:966de95ca8a62200913e3f8bfbf84c8494536f1b94b49166851e76644e966396"}
+{"status":"using existing layer sha256:fcc5a6bec9daf9b561a68827b67ab6088e1dba9d1fa2a50d7bbcc8384e0a265d"}
+{"status":"using existing layer sha256:a70ff7e570d97baaf4e62ac6e6ad9975e04caa6d900d3742d37698494479e0cd"}
 {"status":"using existing layer sha256:56bb8bd477a519ffa694fc449c2413c6f0e1d3b1c88fa7e3c9d88d3ae49d4dcb"}
-{"status":"creating new layer sha256:455f34728c9b5dd3376378bfb809ee166c145b0b4c1f1a6feca069055066ef9a"}
 {"status":"writing manifest"}
 {"status":"success"}
 ```
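The same quantize flow is available from the CLI; a minimal sketch, assuming a local Modelfile whose `FROM` line points at the fp16 weights:

```shell
ollama create llama3.2:quantized --quantize q4_K_M -f Modelfile
```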
@@ -1165,29 +1397,37 @@ A single JSON object will be returned.
 {
   "models": [
     {
-      "name": "codellama:13b",
-      "modified_at": "2023-11-04T14:56:49.277302595-07:00",
-      "size": 7365960935,
-      "digest": "9f438cb9cd581fc025612d27f7c1a6669ff83a8bb0ed86c94fcf4c5440555697",
+      "name": "deepseek-r1:latest",
+      "model": "deepseek-r1:latest",
+      "modified_at": "2025-05-10T08:06:48.639712648-07:00",
+      "size": 4683075271,
+      "digest": "0a8c266910232fd3291e71e5ba1e058cc5af9d411192cf88b6d30e92b6e73163",
       "details": {
+        "parent_model": "",
         "format": "gguf",
-        "family": "llama",
-        "families": null,
-        "parameter_size": "13B",
-        "quantization_level": "Q4_0"
+        "family": "qwen2",
+        "families": [
+          "qwen2"
+        ],
+        "parameter_size": "7.6B",
+        "quantization_level": "Q4_K_M"
       }
     },
     {
-      "name": "llama3:latest",
-      "modified_at": "2023-12-07T09:32:18.757212583-08:00",
-      "size": 3825819519,
-      "digest": "fe938a131f40e6f6d40083c9f0f430a515233eb2edaa6d72eb85c50d64f2300e",
+      "name": "llama3.2:latest",
+      "model": "llama3.2:latest",
+      "modified_at": "2025-05-04T17:37:44.706015396-07:00",
+      "size": 2019393189,
+      "digest": "a80c4f17acd55265feec403c7aef86be0c25983ab279d83f3bcd3abbcb5b8b72",
       "details": {
+        "parent_model": "",
         "format": "gguf",
         "family": "llama",
-        "families": null,
-        "parameter_size": "7B",
-        "quantization_level": "Q4_0"
+        "families": [
+          "llama"
+        ],
+        "parameter_size": "3.2B",
+        "quantization_level": "Q4_K_M"
       }
     }
   ]
@@ -1213,13 +1453,13 @@ Show information about a model including details, modelfile, template, parameter

 ```shell
 curl http://localhost:11434/api/show -d '{
-  "model": "llama3.2"
+  "model": "llava"
 }'
 ```

 #### Response

-```json
+```json5
 {
   "modelfile": "# Modelfile generated by \"ollama show\"\n# To build a new Modelfile based on this one, replace the FROM line with:\n# FROM llava:latest\n\nFROM /Users/matt/.ollama/models/blobs/sha256:200765e1283640ffbd013184bf496e261032fa75b99498a9613be4e94d63ad52\nTEMPLATE \"\"\"{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: \"\"\"\nPARAMETER num_ctx 4096\nPARAMETER stop \"\u003c/s\u003e\"\nPARAMETER stop \"USER:\"\nPARAMETER stop \"ASSISTANT:\"",
   "parameters": "num_keep 24\nstop \"<|start_header_id|>\"\nstop \"<|end_header_id|>\"\nstop \"<|eot_id|>\"",
@@ -1256,7 +1496,11 @@ curl http://localhost:11434/api/show -d '{
   "tokenizer.ggml.pre": "llama-bpe",
   "tokenizer.ggml.token_type": [], // populates if `verbose=true`
   "tokenizer.ggml.tokens": [] // populates if `verbose=true`
-  }
+  },
+  "capabilities": [
+    "completion",
+    "vision"
+  ],
 }
 ```

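A quick way to pull just the new `capabilities` field out of `/api/show` (assuming `jq` is installed):

```shell
curl -s http://localhost:11434/api/show -d '{"model": "llava"}' | jq '.capabilities'
```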
@@ -1349,7 +1593,7 @@ Then there is a series of downloading responses. Until any of the download is co

 ```json
 {
-  "status": "downloading digestname",
+  "status": "pulling digestname",
   "digest": "digestname",
   "total": 2142590208,
   "completed": 241970
@@ -118,6 +118,35 @@ To run tests, use `go test`:
 go test ./...
 ```

+> NOTE: In rare circumstances, you may need to change a package using the new
+> "synctest" package in go1.24.
+>
+> If you do not have the "synctest" package enabled, you will not see build or
+> test failures resulting from your change(s), if any, locally, but CI will
+> break.
+>
+> If you see failures in CI, you can either keep pushing changes to see if the
+> CI build passes, or you can enable the "synctest" package locally to see the
+> failures before pushing.
+>
+> To enable the "synctest" package for testing, run the following command:
+>
+> ```shell
+> GOEXPERIMENT=synctest go test ./...
+> ```
+>
+> If you wish to enable synctest for all go commands, you can set the
+> `GOEXPERIMENT` environment variable in your shell profile or by using:
+>
+> ```shell
+> go env -w GOEXPERIMENT=synctest
+> ```
+>
+> This will enable the "synctest" package for all go commands without needing
+> to set it for each shell session.
+>
+> The synctest package is not required for production builds.

 ## Library detection

 Ollama looks for acceleration libraries in the following paths relative to the `ollama` executable:
docs/faq.md (44 lines changed)
@@ -20,7 +20,13 @@ Please refer to the [GPU docs](./gpu.md).

 ## How can I specify the context window size?

-By default, Ollama uses a context window size of 2048 tokens.
+By default, Ollama uses a context window size of 4096 tokens for most models. The `gpt-oss` model has a default context window size of 8192 tokens.
+
+This can be overridden in Settings in the Windows and macOS App, or with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
+
+```shell
+OLLAMA_CONTEXT_LENGTH=8192 ollama serve
+```

 To change this when using `ollama run`, use `/set parameter`:

@@ -40,6 +46,8 @@ curl http://localhost:11434/api/generate -d '{
 }'
 ```

+Setting the context length higher may cause the model to no longer fit onto the GPU, which makes the model run more slowly.
+
 ## How can I tell if my model was loaded onto the GPU?

 Use the `ollama ps` command to see what models are currently loaded into memory.
@@ -51,8 +59,8 @@ ollama ps
 > **Output**:
 >
 > ```
-> NAME          ID              SIZE     PROCESSOR    UNTIL
-> llama3:70b    bcfb190ca3a7    42 GB    100% GPU     4 minutes from now
+> NAME           ID              SIZE     PROCESSOR    CONTEXT    UNTIL
+> gpt-oss:20b    05afbac4bad6    16 GB    100% GPU     8192       4 minutes from now
 > ```

 The `Processor` column will show which memory the model was loaded in to:
@@ -142,9 +150,11 @@ docker build -t ollama-with-ca .
 docker run -d -e HTTPS_PROXY=https://my.proxy.example.com -p 11434:11434 ollama-with-ca
 ```

-## Does Ollama send my prompts and answers back to ollama.com?
+## Does Ollama send my prompts and responses back to ollama.com?

-No. Ollama runs locally, and conversation data does not leave your machine.
+If you're running a model locally, your prompts and responses will always stay on your machine. Ollama Turbo in the App allows you to run your queries on Ollama's servers if you don't have a powerful enough GPU. Web search lets a model query the web, giving you more accurate and up-to-date information. Both Turbo and web search require sending your prompts and responses to Ollama.com. This data is neither logged nor stored.
+
+If you don't want to see the Turbo and web search options in the app, you can disable them in Settings by turning on Airplane mode. In Airplane mode, all models will run locally, and your prompts and responses will stay on your machine.

 ## How can I expose Ollama on my network?

@@ -187,6 +197,13 @@ cloudflared tunnel --url http://localhost:11434 --http-host-header="localhost:11

 Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.

+For browser extensions, you'll need to explicitly allow the extension's origin pattern. Set `OLLAMA_ORIGINS` to include `chrome-extension://*`, `moz-extension://*`, and `safari-web-extension://*` if you wish to allow all browser extensions access, or specific extensions as needed:
+
+```
+# Allow all Chrome, Firefox, and Safari extensions
+OLLAMA_ORIGINS=chrome-extension://*,moz-extension://*,safari-web-extension://* ollama serve
+```
+
 Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.

 ## Where are models stored?
@@ -279,7 +296,7 @@ If too many requests are sent to the server, it will respond with a 503 error in

 ## How does Ollama handle concurrent requests?

-Ollama supports two levels of concurrent processing. If your system has sufficient available memory (system memory when using CPU inference, or VRAM for GPU inference) then multiple models can be loaded at the same time. For a given model, if there is sufficient available memory when the model is loaded, it is configured to allow parallel request processing.
+Ollama supports two levels of concurrent processing. If your system has sufficient available memory (system memory when using CPU inference, or VRAM for GPU inference) then multiple models can be loaded at the same time. For a given model, if there is sufficient available memory when the model is loaded, it can be configured to allow parallel request processing.

 If there is insufficient available memory to load a new model request while one or more models are already loaded, all new requests will be queued until the new model can be loaded. As prior models become idle, one or more will be unloaded to make room for the new model. Queued requests will be processed in order. When using GPU inference new models must be able to completely fit in VRAM to allow concurrent model loads.

@@ -288,7 +305,7 @@ Parallel request processing for a given model results in increasing the context
 The following server settings may be used to adjust how Ollama handles concurrent requests on most platforms:

 - `OLLAMA_MAX_LOADED_MODELS` - The maximum number of models that can be loaded concurrently provided they fit in available memory. The default is 3 * the number of GPUs or 3 for CPU inference.
-- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
+- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default is 1, and will handle 1 request per model at a time.
 - `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512

 Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
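On a systemd-managed Linux install, these knobs can be set persistently with a drop-in override, following the configure-server pattern the FAQ references above (the values here are only illustrative):

```shell
sudo systemctl edit ollama.service
# then add, in the [Service] section:
#   Environment="OLLAMA_NUM_PARALLEL=4"
#   Environment="OLLAMA_MAX_LOADED_MODELS=2"
sudo systemctl restart ollama
```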
@@ -320,3 +337,16 @@ The currently available K/V cache quantization types are:
 How much the cache quantization impacts the model's response quality will depend on the model and the task. Models that have a high GQA count (e.g. Qwen2) may see a larger impact on precision from quantization than models with a low GQA count.

 You may need to experiment with different quantization types to find the best balance between memory usage and quality.
+
+## How can I stop Ollama from starting when I login to my computer?
+
+Ollama for Windows and macOS registers as a login item during installation. You can disable this if you prefer not to have Ollama automatically start. Ollama will respect this setting across upgrades, unless you uninstall the application.
+
+**Windows**
+- Remove `%APPDATA%\Microsoft\Windows\Start Menu\Programs\Startup\Ollama.lnk`
+
+**MacOS Monterey (v12)**
+- Open `Settings` -> `Users & Groups` -> `Login Items` and find the `Ollama` entry, then click the `-` (minus) to remove
+
+**MacOS Ventura (v13) and later**
+- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background", then click the slider to disable.
@@ -1,12 +1,14 @@
 # GPU
 ## Nvidia
-Ollama supports Nvidia GPUs with compute capability 5.0+.
+Ollama supports Nvidia GPUs with compute capability 5.0+ and driver version 531 and newer.

 Check your compute compatibility to see if your card is supported:
 [https://developer.nvidia.com/cuda-gpus](https://developer.nvidia.com/cuda-gpus)

 | Compute Capability | Family              | Cards |
 | ------------------ | ------------------- | ----------------------------------------------------------------------------------------------------------- |
+| 12.0               | GeForce RTX 50xx    | `RTX 5060` `RTX 5060 Ti` `RTX 5070` `RTX 5070 Ti` `RTX 5080` `RTX 5090` |
+|                    | NVIDIA Professional | `RTX PRO 4000 Blackwell` `RTX PRO 4500 Blackwell` `RTX PRO 5000 Blackwell` `RTX PRO 6000 Blackwell` |
 | 9.0                | NVIDIA              | `H200` `H100` |
 | 8.9                | GeForce RTX 40xx    | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` |
 |                    | NVIDIA Professional | `L4` `L40` `RTX 6000` |
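To check the compute capability of an installed card directly, rather than looking it up in the table, `nvidia-smi` can report it (assuming a reasonably recent driver that supports the `compute_cap` query field):

```shell
nvidia-smi --query-gpu=name,compute_cap --format=csv
```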
@@ -53,6 +53,8 @@ FROM /path/to/safetensors/directory

 If you create the Modelfile in the same directory as the weights, you can use the command `FROM .`.

+If you do not create the Modelfile, ollama will act as if there was a Modelfile with the command `FROM .`.
+
 Now run the `ollama create` command from the directory where you created the `Modelfile`:

 ```shell
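Concretely, the minimal flow is a one-line Modelfile next to the weights (a sketch; `my-model` is an arbitrary name, and `ollama create` picks up `./Modelfile` by default):

```shell
cd /path/to/safetensors/directory
echo "FROM ." > Modelfile
ollama create my-model
```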
@@ -132,22 +134,12 @@ success

 ### Supported Quantizations

-- `q4_0`
-- `q4_1`
-- `q5_0`
-- `q5_1`
 - `q8_0`

 #### K-means Quantizations

-- `q3_K_S`
-- `q3_K_M`
-- `q3_K_L`
 - `q4_K_S`
 - `q4_K_M`
-- `q5_K_S`
-- `q5_K_M`
-- `q6_K`

 ## Sharing your model on ollama.com
@@ -16,7 +16,7 @@ curl -fsSL https://ollama.com/install.sh | sh
 Download and extract the package:

 ```shell
-curl -L https://ollama.com/download/ollama-linux-amd64.tgz -o ollama-linux-amd64.tgz
+curl -LO https://ollama.com/download/ollama-linux-amd64.tgz
 sudo tar -C /usr -xzf ollama-linux-amd64.tgz
 ```

@@ -34,7 +34,11 @@ ollama -v

 ### AMD GPU install

-If you have an AMD GPU, also download and extract the additional ROCm package:
+If you have an AMD GPU, **also** download and extract the additional ROCm package:
+
+> [!IMPORTANT]
+> The ROCm tgz contains only AMD dependent libraries. You must extract **both** `ollama-linux-amd64.tgz` and `ollama-linux-amd64-rocm.tgz` into the same location.

 ```shell
 curl -L https://ollama.com/download/ollama-linux-amd64-rocm.tgz -o ollama-linux-amd64-rocm.tgz
@@ -75,7 +79,7 @@ RestartSec=3
 Environment="PATH=$PATH"

 [Install]
-WantedBy=default.target
+WantedBy=multi-user.target
 ```

 Then start the service:

@@ -112,8 +116,8 @@ sudo systemctl status ollama
 > While AMD has contributed the `amdgpu` driver upstream to the official linux
 > kernel source, the version is older and may not support all ROCm features. We
 > recommend you install the latest driver from
-> https://www.amd.com/en/support/linux-drivers for best support of your Radeon
-> GPU.
+> [AMD](https://www.amd.com/en/support/download/linux-drivers.html) for best support
+> of your Radeon GPU.

 ## Customizing
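A quick sanity check after extracting both archives (a sketch assuming the default `/usr` prefix used above, where the tarballs place their libraries under `/usr/lib/ollama`; adjust if you extracted elsewhere):

```shell
ls /usr/lib/ollama
# a rocm/ subdirectory should now sit alongside the other GPU libraries
```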
docs/macos.md (new file, 42 lines)
@@ -0,0 +1,42 @@
+# Ollama for macOS
+
+## System Requirements
+
+* MacOS Monterey (v12) or newer
+* Apple M series (CPU and GPU support) or x86 (CPU only)
+
+## Filesystem Requirements
+
+The preferred method of installation is to mount the `ollama.dmg` and drag-and-drop the Ollama application to the system-wide `Applications` folder. Upon startup, the Ollama app will verify the `ollama` CLI is present in your PATH, and if not detected, will prompt for permission to create a link in `/usr/local/bin`.
+
+Once you've installed Ollama, you'll need additional space for storing the Large Language models, which can be tens to hundreds of GB in size. If your home directory doesn't have enough space, you can change where the binaries are installed, and where the models are stored.
+
+### Changing Install Location
+
+To install the Ollama application somewhere other than `Applications`, place the Ollama application in the desired location, and ensure the CLI `Ollama.app/Contents/Resources/ollama` or a sym-link to the CLI can be found in your path. Upon first start, decline the "Move to Applications?" request.
+
+## Troubleshooting
+
+Ollama on MacOS stores files in a few different locations.
+- `~/.ollama` contains models and configuration
+- `~/.ollama/logs` contains logs
+  - *app.log* contains the most recent logs from the GUI application
+  - *server.log* contains the most recent server logs
+- `<install location>/Ollama.app/Contents/Resources/ollama` the CLI binary
+
+## Uninstall
+
+To fully remove Ollama from your system, remove the following files and folders:
+
+```
+sudo rm -rf /Applications/Ollama.app
+sudo rm /usr/local/bin/ollama
+rm -rf "$HOME/Library/Application Support/Ollama"
+rm -rf "$HOME/Library/Saved Application State/com.electron.ollama.savedState"
+rm -rf ~/Library/Caches/com.electron.ollama/
+rm -rf ~/Library/Caches/ollama
+rm -rf ~/Library/WebKit/com.electron.ollama
+rm -rf ~/.ollama
+```
@@ -150,10 +150,7 @@ PARAMETER <parameter> <parametervalue>

 | Parameter      | Description | Value Type | Example Usage |
 | -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ---------- | -------------------- |
-| mirostat       | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | mirostat 0 |
-| mirostat_eta   | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
-| mirostat_tau   | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
-| num_ctx        | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
+| num_ctx        | Sets the size of the context window used to generate the next token. (Default: 4096) | int | num_ctx 4096 |
 | repeat_last_n  | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
 | repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
 | temperature    | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
@@ -12,7 +12,7 @@ A basic Go template consists of three main parts:

 Here's an example of a simple chat template:

-```gotmpl
+```go
 {{- range .Messages }}
 {{ .Role }}: {{ .Content }}
 {{- end }}
@@ -162,6 +162,6 @@ CodeLlama [7B](https://ollama.com/library/codellama:7b-code) and [13B](https://o

 Codestral [22B](https://ollama.com/library/codestral:22b) supports fill-in-middle.

-```gotmpl
+```go
 [SUFFIX]{{ .Suffix }}[PREFIX] {{ .Prompt }}
 ```
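To see a full template of this shape in the wild, you can print the one a locally pulled model ships with (a sketch; `llama3.2` stands in for any model you have):

```shell
ollama show llama3.2 --template
```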
@@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log
 On **Linux** systems with systemd, the logs can be found with this command:

 ```shell
-journalctl -u ollama --no-pager
+journalctl -u ollama --no-pager --follow --pager-end
 ```

 When you run Ollama in a **container**, the logs go to stdout/stderr in the container:
@@ -26,7 +26,6 @@ When you run Ollama on **Windows**, there are a few different locations. You can
 - `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
 - `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH)
 - `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored
-- `explorer %TEMP%` where temporary executable files are stored in one or more `ollama*` directories

 To enable additional debug logging to help troubleshoot problems, first **Quit the running app from the tray menu** then in a powershell terminal
@@ -39,12 +38,12 @@ Join the [Discord](https://discord.gg/ollama) for help interpreting the logs.

 ## LLM libraries

-Ollama includes multiple LLM libraries compiled for different GPUs and CPU vector features. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. `cpu_avx2` will perform the best, followed by `cpu_avx` an the slowest but most compatible is `cpu`. Rosetta emulation under MacOS will work with the `cpu` library.
+Ollama includes multiple LLM libraries compiled for different GPUs and CPU vector features. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. `cpu_avx2` will perform the best, followed by `cpu_avx` and the slowest but most compatible is `cpu`. Rosetta emulation under MacOS will work with the `cpu` library.

 In the server log, you will see a message that looks something like this (varies from release to release):

 ```
-Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v11 rocm_v5]
+Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v12 rocm_v5]
 ```

 **Experimental LLM Library Override**
@@ -69,10 +68,6 @@ If you run into problems on Linux and want to install an older version, or you'd
 curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.5.7 sh
 ```

-## Linux tmp noexec
-
-If your system is configured with the "noexec" flag where Ollama stores its temporary executable files, you can specify an alternate location by setting OLLAMA_TMPDIR to a location writable by the user ollama runs as. For example OLLAMA_TMPDIR=/usr/share/ollama/
-
 ## Linux docker

 If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker. Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration.
docs/turbo.md (new file, 107 lines)
@@ -0,0 +1,107 @@
+# Turbo
+
+> ⚠️ Turbo is in preview
+
+Ollama's [Turbo](https://ollama.com/turbo) is a new way to run open-source models with acceleration from datacenter-grade hardware.
+
+Currently, the following models are available in Turbo:
+
+- `gpt-oss:20b`
+- `gpt-oss:120b`
+
+## Get started
+
+### Ollama for macOS & Windows
+
+Download Ollama
+
+- Select a model such as `gpt-oss:20b` or `gpt-oss:120b`
+- Click on **Turbo**. You'll be prompted to create an account or sign in
+
+### Ollama's CLI
+
+- [Sign up](https://ollama.com/signup) for an Ollama account
+- Add your Ollama key [to ollama.com](https://ollama.com/settings/keys).
+
+On macOS and Linux:
+
+```shell
+cat ~/.ollama/id_ed25519.pub
+```
+
+On Windows:
+
+```
+type "%USERPROFILE%\.ollama\id_ed25519.pub"
+```
+
+- Then run a model setting `OLLAMA_HOST` to `ollama.com`:
+
+```shell
+OLLAMA_HOST=ollama.com ollama run gpt-oss:120b
+```
+
+### Ollama's Python library
+
+- Download Ollama's [Python library](https://github.com/ollama/ollama-python)
+- [Sign up](https://ollama.com/signup) for an Ollama account
+- Create an API key by visiting https://ollama.com/settings/keys
+
+```python
+from ollama import Client
+
+client = Client(
+    host="https://ollama.com",
+    headers={'Authorization': '<api key>'}
+)
+
+messages = [
+  {
+    'role': 'user',
+    'content': 'Why is the sky blue?',
+  },
+]
+
+for part in client.chat('gpt-oss:120b', messages=messages, stream=True):
+  print(part['message']['content'], end='', flush=True)
+```
+
+### Ollama's JavaScript library
+
+- Download Ollama's [JavaScript library](https://github.com/ollama/ollama-js)
+- [Sign up](https://ollama.com/signup) for an Ollama account
+- Create an API key by visiting https://ollama.com/settings/keys
+
+```typescript
+import { Ollama } from 'ollama';
+
+const ollama = new Ollama({
+  host: 'https://ollama.com',
+  headers: {
+    Authorization: "Bearer <api key>"
+  }
+});
+
+const response = await ollama.chat({
+  model: 'gpt-oss:120b',
+  messages: [{ role: 'user', content: 'Explain quantum computing' }],
+  stream: true
+});
+
+for await (const part of response) {
+  process.stdout.write(part.message.content)
+}
+```
+
+### Community integrations
+
+Turbo mode is also compatible with several community integrations.
+
+#### Open WebUI
+
+- Go to **settings** → **Admin settings** → **Connections**
+- Under **Ollama API,** click **+**
+- For the **URL** put `https://ollama.com`
+- For the **API key,** create an API key on https://ollama.com/settings/keys and add it.
+- Click **Save**
+
+Now, if you navigate to the model selector, Turbo models should be available under **External**.
@@ -30,20 +30,6 @@ To install the Ollama application in a location different than your home directo
 OllamaSetup.exe /DIR="d:\some\location"
 ```

-### Changing Model Location
-
-To change where Ollama stores the downloaded models instead of using your home directory, set the environment variable `OLLAMA_MODELS` in your user account.
-
-1. Start the Settings (Windows 11) or Control Panel (Windows 10) application and search for _environment variables_.
-
-2. Click on _Edit environment variables for your account_.
-
-3. Edit or create a new variable for your user account for `OLLAMA_MODELS` where you want the models stored
-
-4. Click OK/Apply to save.
-
-If Ollama is already running, Quit the tray application and relaunch it from the Start menu, or a new terminal started after you saved the environment variables.
-
 ## API Access

 Here's a quick example showing API access from `powershell`
@@ -62,7 +48,6 @@ the explorer window by hitting `<Ctrl>+R` and type in:
 - *upgrade.log* contains log output for upgrades
 - `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
 - `explorer %HOMEPATH%\.ollama` contains models and configuration
-- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories

 ## Uninstall

@@ -81,9 +66,11 @@ help you keep up to date.

 If you'd like to install or integrate Ollama as a service, a standalone
 `ollama-windows-amd64.zip` zip file is available containing only the Ollama CLI
-and GPU library dependencies for Nvidia and AMD. This allows for embedding
-Ollama in existing applications, or running it as a system service via `ollama
-serve` with tools such as [NSSM](https://nssm.cc/).
+and GPU library dependencies for Nvidia. If you have an AMD GPU, also download
+and extract the additional ROCm package `ollama-windows-amd64-rocm.zip` into the
+same directory. Both zip files are necessary for a complete AMD installation.
+This allows for embedding Ollama in existing applications, or running it as a
+system service via `ollama serve` with tools such as [NSSM](https://nssm.cc/).

 > [!NOTE]
 > If you are upgrading from a prior version, you should remove the old directories first.
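For the service route, the NSSM invocation is a one-liner (a sketch; `C:\ollama` is only an assumed extract location, so adjust to wherever you unzipped the package):

```shell
nssm install Ollama "C:\ollama\ollama.exe" serve
nssm start Ollama
```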
|
|||||||
@@ -149,9 +149,22 @@ func Bool(k string) func() bool {
 	}
 }
 
+// LogLevel returns the log level for the application.
+// Values are 0 or false INFO (Default), 1 or true DEBUG, 2 TRACE
+func LogLevel() slog.Level {
+	level := slog.LevelInfo
+	if s := Var("OLLAMA_DEBUG"); s != "" {
+		if b, _ := strconv.ParseBool(s); b {
+			level = slog.LevelDebug
+		} else if i, _ := strconv.ParseInt(s, 10, 64); i != 0 {
+			level = slog.Level(i * -4)
+		}
+	}
+
+	return level
+}
+
 var (
-	// Debug enabled additional debug information.
-	Debug = Bool("OLLAMA_DEBUG")
 	// FlashAttention enables the experimental flash attention feature.
 	FlashAttention = Bool("OLLAMA_FLASH_ATTENTION")
 	// KvCacheType is the quantization type for the K/V cache.
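
A note on the `i * -4` arithmetic above: `log/slog` spaces its built-in levels four units apart (Debug is -4, Info 0, Warn 4, Error 8), so each increment of `OLLAMA_DEBUG` moves one full level of verbosity. The standalone sketch below mirrors that mapping; the `parse` helper is illustrative and not part of the package:

```go
package main

import (
	"fmt"
	"log/slog"
	"strconv"
)

// parse mirrors the LogLevel logic for a single OLLAMA_DEBUG value.
func parse(s string) slog.Level {
	level := slog.LevelInfo
	if b, _ := strconv.ParseBool(s); b {
		level = slog.LevelDebug
	} else if i, _ := strconv.ParseInt(s, 10, 64); i != 0 {
		// slog levels step by 4: i=2 gives -8 (trace), i=-1 gives 4 (warn).
		level = slog.Level(i * -4)
	}
	return level
}

func main() {
	for _, s := range []string{"true", "1", "2", "-1", "-2"} {
		fmt.Printf("OLLAMA_DEBUG=%s -> %v\n", s, parse(s))
	}
}
```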
@@ -169,7 +182,11 @@ var (
 	// Enable the new Ollama engine
 	NewEngine = Bool("OLLAMA_NEW_ENGINE")
 	// ContextLength sets the default context length
-	ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 2048)
+	ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
+	// Auth enables authentication between the Ollama client and server
+	UseAuth = Bool("OLLAMA_AUTH")
+	// Enable the new memory estimation logic
+	NewMemoryEstimates = Bool("OLLAMA_NEW_ESTIMATES")
 )
 
 func String(s string) func() string {
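
Each of these settings is a getter closure rather than a plain value, so the environment is re-read on every call — which is what lets the tests further down flip values with `t.Setenv` and observe the change without re-initializing the package. A minimal sketch of the pattern (a standalone analog, not the package's exact code):

```go
package main

import (
	"fmt"
	"os"
	"strconv"
)

// uintVar returns a getter that re-reads the environment on each call,
// falling back to defaultValue when the variable is unset or invalid.
func uintVar(key string, defaultValue uint) func() uint {
	return func() uint {
		if s := os.Getenv(key); s != "" {
			if n, err := strconv.ParseUint(s, 10, 64); err == nil {
				return uint(n)
			}
		}
		return defaultValue
	}
}

func main() {
	contextLength := uintVar("OLLAMA_CONTEXT_LENGTH", 4096)
	fmt.Println(contextLength()) // 4096
	os.Setenv("OLLAMA_CONTEXT_LENGTH", "8192")
	fmt.Println(contextLength()) // 8192
}
```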
@@ -204,13 +221,11 @@ func Uint(key string, defaultValue uint) func() uint {
 
 var (
 	// NumParallel sets the number of parallel model requests. NumParallel can be configured via the OLLAMA_NUM_PARALLEL environment variable.
-	NumParallel = Uint("OLLAMA_NUM_PARALLEL", 0)
+	NumParallel = Uint("OLLAMA_NUM_PARALLEL", 1)
	// MaxRunners sets the maximum number of loaded models. MaxRunners can be configured via the OLLAMA_MAX_LOADED_MODELS environment variable.
 	MaxRunners = Uint("OLLAMA_MAX_LOADED_MODELS", 0)
 	// MaxQueue sets the maximum number of queued requests. MaxQueue can be configured via the OLLAMA_MAX_QUEUE environment variable.
 	MaxQueue = Uint("OLLAMA_MAX_QUEUE", 512)
-	// MaxVRAM sets a maximum VRAM override in bytes. MaxVRAM can be configured via the OLLAMA_MAX_VRAM environment variable.
-	MaxVRAM = Uint("OLLAMA_MAX_VRAM", 0)
 )
 
 func Uint64(key string, defaultValue uint64) func() uint64 {
@@ -238,7 +253,7 @@ type EnvVar struct {
 
 func AsMap() map[string]EnvVar {
 	ret := map[string]EnvVar{
-		"OLLAMA_DEBUG":           {"OLLAMA_DEBUG", Debug(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
+		"OLLAMA_DEBUG":           {"OLLAMA_DEBUG", LogLevel(), "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
 		"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention(), "Enabled flash attention"},
 		"OLLAMA_KV_CACHE_TYPE":   {"OLLAMA_KV_CACHE_TYPE", KvCacheType(), "Quantization type for the K/V cache (default: f16)"},
 		"OLLAMA_GPU_OVERHEAD":    {"OLLAMA_GPU_OVERHEAD", GpuOverhead(), "Reserve a portion of VRAM per GPU (bytes)"},
@@ -255,8 +270,9 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_ORIGINS":         {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
 		"OLLAMA_SCHED_SPREAD":    {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
 		"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
-		"OLLAMA_CONTEXT_LENGTH":  {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 2048)"},
+		"OLLAMA_CONTEXT_LENGTH":  {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
 		"OLLAMA_NEW_ENGINE":      {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
+		"OLLAMA_NEW_ESTIMATES":   {"OLLAMA_NEW_ESTIMATES", NewMemoryEstimates(), "Enable the new memory estimation logic"},
 
 		// Informational
 		"HTTP_PROXY":             {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
@@ -1,11 +1,13 @@
 package envconfig
 
 import (
+	"log/slog"
 	"math"
 	"testing"
 	"time"
 
 	"github.com/google/go-cmp/cmp"
+	"github.com/ollama/ollama/logutil"
 )
 
 func TestHost(t *testing.T) {
@@ -279,8 +281,8 @@ func TestVar(t *testing.T) {
 
 func TestContextLength(t *testing.T) {
 	cases := map[string]uint{
-		"":     2048,
-		"4096": 4096,
+		"":     4096,
+		"2048": 2048,
 	}
 
 	for k, v := range cases {
@@ -292,3 +294,34 @@ func TestContextLength(t *testing.T) {
 		})
 	}
 }
+
+func TestLogLevel(t *testing.T) {
+	cases := map[string]slog.Level{
+		// Default to INFO
+		"":      slog.LevelInfo,
+		"false": slog.LevelInfo,
+		"f":     slog.LevelInfo,
+		"0":     slog.LevelInfo,
+
+		// True values enable Debug
+		"true": slog.LevelDebug,
+		"t":    slog.LevelDebug,
+
+		// Positive values increase verbosity
+		"1": slog.LevelDebug,
+		"2": logutil.LevelTrace,
+
+		// Negative values decrease verbosity
+		"-1": slog.LevelWarn,
+		"-2": slog.LevelError,
+	}
+
+	for k, v := range cases {
+		t.Run(k, func(t *testing.T) {
+			t.Setenv("OLLAMA_DEBUG", k)
+			if i := LogLevel(); i != v {
+				t.Errorf("%s: expected %d, got %d", k, v, i)
+			}
+		})
+	}
+}
@@ -5,7 +5,7 @@ import (
 	"time"
 )
 
-func assertEqual(t *testing.T, a interface{}, b interface{}) {
+func assertEqual(t *testing.T, a any, b any) {
 	if a != b {
 		t.Errorf("Assert failed, expected %v, got %v", b, a)
 	}

fs/config.go (new file, 14 lines)
@@ -0,0 +1,14 @@
+package fs
+
+type Config interface {
+	Architecture() string
+	String(string, ...string) string
+	Uint(string, ...uint32) uint32
+	Float(string, ...float32) float32
+	Bool(string, ...bool) bool
+
+	Strings(string, ...[]string) []string
+	Ints(string, ...[]int32) []int32
+	Floats(string, ...[]float32) []float32
+	Bools(string, ...[]bool) []bool
+}
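
The `KV` methods added in the `fs/ggml` hunks below line up one-for-one with this interface, so a raw GGUF metadata map can flow anywhere an `fs.Config` is expected. A compile-time assertion is the idiomatic way to pin that down; the one below is illustrative and not part of the diff:

```go
package ggmlcheck

import (
	"github.com/ollama/ollama/fs"
	"github.com/ollama/ollama/fs/ggml"
)

// Illustrative compile-time check: this fails to build if ggml.KV
// ever stops satisfying fs.Config.
var _ fs.Config = (ggml.KV)(nil)
```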
fs/ggml/ggml.go (446 lines changed)
@@ -1,20 +1,24 @@
 package ggml
 
 import (
+	"cmp"
 	"encoding/binary"
 	"errors"
 	"fmt"
 	"io"
 	"log/slog"
+	"math"
 	"slices"
 	"strings"
 
+	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/fs/util/bufioutil"
 )
 
 type GGML struct {
 	container
 	model
+	Length int64
 }
 
 type model interface {
@@ -33,15 +37,16 @@ func (kv KV) Kind() string {
 }
 
 func (kv KV) ParameterCount() uint64 {
-	return keyValue[uint64](kv, "general.parameter_count")
+	val, _ := keyValue(kv, "general.parameter_count", uint64(0))
+	return val
 }
 
-func (kv KV) FileType() fileType {
+func (kv KV) FileType() FileType {
 	if t := kv.Uint("general.file_type"); t > 0 {
-		return fileType(t)
+		return FileType(t)
 	}
 
-	return fileTypeUnknown
+	return FileTypeUnknown
 }
 
 func (kv KV) BlockCount() uint64 {
@@ -52,16 +57,27 @@ func (kv KV) EmbeddingLength() uint64 {
 	return uint64(kv.Uint("embedding_length"))
 }
 
-func (kv KV) HeadCount() uint64 {
-	return uint64(kv.Uint("attention.head_count"))
+func (kv KV) HeadCountMax() uint64 {
+	// TODO(drifkin): using the max value can cause an overestimation. In the
+	// future if array values become more popular, we can adapt the more invasive
+	// <https://github.com/ollama/ollama/pull/10225>
+	return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1))
 }
 
-func (kv KV) HeadCountKV() uint64 {
-	return uint64(kv.Uint("attention.head_count_kv", 1))
+func (kv KV) HeadCountMin() uint64 {
+	return uint64(kv.UintOrMinArrayValue("attention.head_count", 1))
 }
 
-func (kv KV) EmbeddingHeadCount() uint64 {
-	if heads := kv.HeadCount(); heads > 0 {
+func (kv KV) HeadCountKVMax() uint64 {
+	return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1))
+}
+
+func (kv KV) HeadCountKVMin() uint64 {
+	return uint64(kv.UintOrMinArrayValue("attention.head_count_kv", 1))
+}
+
+func (kv KV) EmbeddingHeadCountMax() uint64 {
+	if heads := kv.HeadCountMin(); heads > 0 {
 		return kv.EmbeddingLength() / heads
 	}
 
@@ -69,15 +85,11 @@ func (kv KV) EmbeddingHeadCount() uint64 {
 }
 
 func (kv KV) EmbeddingHeadCountK() uint64 {
-	return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount())))
+	return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCountMax())))
 }
 
 func (kv KV) EmbeddingHeadCountV() uint64 {
-	return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount())))
-}
-
-func (kv KV) GQA() uint64 {
-	return kv.HeadCount() / kv.HeadCountKV()
+	return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCountMax())))
 }
 
 func (kv KV) ContextLength() uint64 {
@@ -89,52 +101,114 @@ func (kv KV) ChatTemplate() string {
 }
 
 func (kv KV) String(key string, defaultValue ...string) string {
-	return keyValue(kv, key, append(defaultValue, "")...)
+	val, _ := keyValue(kv, key, append(defaultValue, "")...)
+	return val
 }
 
 func (kv KV) Uint(key string, defaultValue ...uint32) uint32 {
-	return keyValue(kv, key, append(defaultValue, 0)...)
+	val, _ := keyValue(kv, key, append(defaultValue, 0)...)
+	return val
 }
 
 func (kv KV) Float(key string, defaultValue ...float32) float32 {
-	return keyValue(kv, key, append(defaultValue, 0)...)
+	val, _ := keyValue(kv, key, append(defaultValue, 0)...)
+	return val
 }
 
 func (kv KV) Bool(key string, defaultValue ...bool) bool {
-	return keyValue(kv, key, append(defaultValue, false)...)
+	val, _ := keyValue(kv, key, append(defaultValue, false)...)
+	return val
+}
+
+func (kv KV) UintOrMaxArrayValue(key string, defaultValue uint32) uint32 {
+	_, max := kv.UintOrArrayValue(key, defaultValue)
+	return max
+}
+
+func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 {
+	min, _ := kv.UintOrArrayValue(key, defaultValue)
+	return min
+}
+
+func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) {
+	if u32, ok := keyValue(kv, key, uint32(0)); ok {
+		return u32, u32
+	} else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok {
+		min := slices.Min(u32s.values)
+		max := slices.Max(u32s.values)
+		return min, max
+	} else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok {
+		min := slices.Min(i32s.values)
+		max := slices.Max(i32s.values)
+		if min < 0 || max < 0 {
+			slog.Warn("array values are unexpectedly negative", "key", key, "min", min, "max", max)
+		}
+		return uint32(min), uint32(max)
+	}
+
+	return defaultValue, defaultValue
 }
 
 func (kv KV) Strings(key string, defaultValue ...[]string) []string {
-	r := keyValue(kv, key, &array{})
-	s := make([]string, r.size)
-	for i := range r.size {
-		s[i] = r.values[i].(string)
-	}
-
-	return s
+	val, _ := keyValue(kv, key, &array[string]{values: append(defaultValue, []string(nil))[0]})
+	return val.values
+}
+
+func (kv KV) Ints(key string, defaultValue ...[]int32) []int32 {
+	val, _ := keyValue(kv, key, &array[int32]{values: append(defaultValue, []int32(nil))[0]})
+	return val.values
 }
 
 func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
-	r := keyValue(kv, key, &array{})
-	s := make([]uint32, r.size)
-	for i := range r.size {
-		s[i] = uint32(r.values[i].(int32))
-	}
-
-	return s
+	val, _ := keyValue(kv, key, &array[uint32]{values: append(defaultValue, []uint32(nil))[0]})
+	return val.values
+}
+
+func (kv KV) Floats(key string, defaultValue ...[]float32) []float32 {
+	val, _ := keyValue(kv, key, &array[float32]{values: append(defaultValue, []float32(nil))[0]})
+	return val.values
+}
+
+func (kv KV) Bools(key string, defaultValue ...[]bool) []bool {
+	val, _ := keyValue(kv, key, &array[bool]{values: append(defaultValue, []bool(nil))[0]})
+	return val.values
+}
+
+func (kv KV) OllamaEngineRequired() bool {
+	return slices.Contains([]string{
+		"gemma3",
+		"gemma3n",
+		"mistral3",
+		"llama4",
+		"mllama",
+		"qwen25vl",
+		"gptoss", "gpt-oss",
+	}, kv.Architecture())
+}
+
+type valueTypes interface {
+	uint8 | int8 | uint16 | int16 |
+		uint32 | int32 | uint64 | int64 |
+		string | float32 | float64 | bool
+}
+
+type arrayValueTypes interface {
+	*array[uint8] | *array[int8] | *array[uint16] | *array[int16] |
+		*array[uint32] | *array[int32] | *array[uint64] | *array[int64] |
+		*array[string] | *array[float32] | *array[float64] | *array[bool]
 }
 
-func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
+func keyValue[T valueTypes | arrayValueTypes](kv KV, key string, defaultValue ...T) (T, bool) {
 	if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
 		key = kv.Architecture() + "." + key
 	}
 
-	if val, ok := kv[key]; ok {
-		return val.(T)
+	if val, ok := kv[key].(T); ok {
+		return val, true
 	}
 
-	slog.Warn("key not found", "key", key, "default", defaultValue[0])
-	return defaultValue[0]
+	slog.Debug("key with type not found", "key", key, "default", defaultValue[0])
+	return defaultValue[0], false
 }
 
 type Tensors struct {
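
The `UintOrArrayValue` helper above papers over the fact that GGUF metadata such as `attention.head_count` may arrive as a scalar or as a per-layer array: worst-case buffer sizing takes the max, divisions take the min. A standalone analog of its contract, using a plain map in place of `KV` (illustrative only):

```go
package main

import (
	"fmt"
	"slices"
)

// uintOrArrayValue mirrors KV.UintOrArrayValue's contract: a scalar
// yields (v, v), an array yields (min, max), anything else the default.
func uintOrArrayValue(kv map[string]any, key string, def uint32) (uint32, uint32) {
	switch v := kv[key].(type) {
	case uint32:
		return v, v
	case []uint32:
		return slices.Min(v), slices.Max(v)
	default:
		return def, def
	}
}

func main() {
	kv := map[string]any{
		"scalar": uint32(8),
		"array":  []uint32{1, 5, 3, 4},
	}
	fmt.Println(uintOrArrayValue(kv, "scalar", 1))  // 8 8
	fmt.Println(uintOrArrayValue(kv, "array", 1))   // 1 5
	fmt.Println(uintOrArrayValue(kv, "missing", 1)) // 1 1
}
```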
@@ -203,32 +277,37 @@ type Tensor struct {
 
 func (t Tensor) block() (n int) {
 	if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil {
-		return -1
+		return math.MaxInt
 	}
 
 	return
 }
 
 func (t Tensor) blockSize() uint64 {
-	switch t.Kind {
+	return TensorType(t.Kind).BlockSize()
+}
+
+func (t TensorType) BlockSize() uint64 {
+	switch t {
 	case
-		0,  // F32
-		1,  // F16
-		24, // I8
-		25, // I16
-		26, // I32
-		27, // I64
-		28, // F64
-		30: // BF16
+		TensorTypeF32,
+		TensorTypeF16,
+		TensorTypeI8,
+		TensorTypeI16,
+		TensorTypeI32,
+		TensorTypeI64,
+		TensorTypeF64,
+		TensorTypeBF16:
 		return 1
 	case
-		2,  // Q4_0
-		3,  // Q4_1
-		6,  // Q5_0
-		7,  // Q5_1
-		8,  // Q8_0
-		9,  // Q8_1
-		20: // IQ4_NL
+		TensorTypeQ4_0,
+		TensorTypeQ4_1,
+		TensorTypeQ5_0,
+		TensorTypeQ5_1,
+		TensorTypeQ8_0,
+		TensorTypeQ8_1,
+		tensorTypeIQ4_NL,
+		4, TensorTypeMXFP4:
 		return 32
 	default:
 		return 256
@@ -236,73 +315,79 @@ func (t Tensor) blockSize() uint64 {
 }
 
 func (t Tensor) typeSize() uint64 {
-	blockSize := t.blockSize()
+	return TensorType(t.Kind).TypeSize()
+}
 
-	switch t.Kind {
-	case 0: // FP32
+func (t TensorType) TypeSize() uint64 {
+	blockSize := t.BlockSize()
+
+	switch t {
+	case TensorTypeF32:
 		return 4
-	case 1: // FP16
+	case TensorTypeF16:
 		return 2
-	case 2: // Q4_0
+	case TensorTypeQ4_0:
 		return 2 + blockSize/2
-	case 3: // Q4_1
+	case TensorTypeQ4_1:
 		return 2 + 2 + blockSize/2
-	case 6: // Q5_0
+	case TensorTypeQ5_0:
 		return 2 + 4 + blockSize/2
-	case 7: // Q5_1
+	case TensorTypeQ5_1:
 		return 2 + 2 + 4 + blockSize/2
-	case 8: // Q8_0
+	case TensorTypeQ8_0:
 		return 2 + blockSize
-	case 9: // Q8_1
+	case TensorTypeQ8_1:
 		return 2 + 2 + blockSize
-	case 10: // Q2_K
+	case TensorTypeQ2_K:
 		return blockSize/16 + blockSize/4 + 2 + 2
-	case 11: // Q3_K
+	case TensorTypeQ3_K:
 		return blockSize/8 + blockSize/4 + 12 + 2
-	case 12: // Q4_K
+	case TensorTypeQ4_K:
 		return 2 + 2 + 12 + blockSize/2
-	case 13: // Q5_K
+	case TensorTypeQ5_K:
 		return 2 + 2 + 12 + blockSize/8 + blockSize/2
-	case 14: // Q6_K
+	case TensorTypeQ6_K:
 		return blockSize/2 + blockSize/4 + blockSize/16 + 2
-	case 15: // Q8_K
+	case TensorTypeQ8_K:
 		return 4 + blockSize + 2*blockSize/16
-	case 16: // IQ2_XXS
+	case tensorTypeIQ2_XXS:
 		return 2 + 2*blockSize/8
-	case 17: // IQ2_XS
+	case tensorTypeIQ2_XS:
 		return 2 + 2*blockSize/8 + blockSize/32
-	case 18: // IQ3_XXS
+	case tensorTypeIQ3_XXS:
 		return 2 + blockSize/4 + blockSize/8
-	case 19: // IQ1_S
+	case tensorTypeIQ1_S:
 		return 2 + blockSize/8 + blockSize/16
-	case 20: // IQ4_NL
+	case tensorTypeIQ4_NL:
 		return 2 + blockSize/2
-	case 21: // IQ3_S
+	case tensorTypeIQ3_S:
 		return 2 + blockSize/4 + blockSize/8 + blockSize/32 + 4
-	case 22: // IQ2_S
+	case tensorTypeIQ2_S:
 		return 2 + blockSize/4 + blockSize/16
-	case 23: // IQ4_XS
+	case tensorTypeIQ4_XS:
 		return 2 + 2 + blockSize/2 + blockSize/64
-	case 24: // I8
+	case TensorTypeI8:
 		return 1
-	case 25: // I16
+	case TensorTypeI16:
 		return 2
-	case 26: // I32
+	case TensorTypeI32:
 		return 4
-	case 27: // I64
+	case TensorTypeI64:
 		return 8
-	case 28: // F64
+	case TensorTypeF64:
 		return 8
-	case 29: // IQ1_M
+	case tensorTypeIQ1_M:
 		return blockSize/8 + blockSize/16 + blockSize/32
-	case 30: // BF16
+	case TensorTypeBF16:
 		return 2
+	case 4, TensorTypeMXFP4:
+		return 1 + blockSize/2
 	default:
 		return 0
 	}
 }
 
-func (t Tensor) parameters() uint64 {
+func (t Tensor) Elements() uint64 {
 	var count uint64 = 1
 	for _, n := range t.Shape {
 		count *= n
@@ -311,7 +396,11 @@ func (t Tensor) parameters() uint64 {
 }
 
 func (t Tensor) Size() uint64 {
-	return t.parameters() * t.typeSize() / t.blockSize()
+	return t.Elements() * t.typeSize() / t.blockSize()
+}
+
+func (t Tensor) Type() string {
+	return TensorType(t.Kind).String()
 }
 
 type container interface {
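
`Size` ties the methods above together: bytes = Elements × TypeSize ÷ BlockSize. A worked example for a Q4_K tensor, whose 256-element blocks occupy 2+2+12+128 = 144 bytes (about 4.5 bits per weight):

```go
package main

import "fmt"

func main() {
	// Q4_K per the TypeSize switch above: 2 + 2 + 12 + blockSize/2 bytes
	// per 256-element block.
	const (
		blockSize = 256
		typeSize  = 2 + 2 + 12 + blockSize/2 // 144
	)
	elements := uint64(4096 * 4096) // a 4096x4096 weight matrix
	size := elements * typeSize / blockSize
	fmt.Println(size) // 9437184 bytes, i.e. exactly 9 MiB
}
```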
@@ -355,18 +444,13 @@ func DetectContentType(b []byte) string {
 // Decode decodes a GGML model from the given reader.
 //
 // It collects array values for arrays with a size less than or equal to
-// maxArraySize. If maxArraySize is 0, the default value of 1024 is used. If
-// the maxArraySize is negative, all arrays are collected.
-func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
-	if maxArraySize == 0 {
-		maxArraySize = 1024
-	}
-
+// maxArraySize. If the maxArraySize is negative, all arrays are collected.
+func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
 	rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
 
 	var magic uint32
 	if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
-		return nil, 0, err
+		return nil, err
 	}
 
 	var c container
@@ -376,43 +460,51 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
 	case FILE_MAGIC_GGUF_BE:
 		c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
 	default:
-		return nil, 0, errors.New("invalid file magic")
+		return nil, errors.New("invalid file magic")
 	}
 
 	model, err := c.Decode(rs)
 	if err != nil {
-		return nil, 0, err
+		return nil, err
 	}
 
 	offset, err := rs.Seek(0, io.SeekCurrent)
 	if err != nil {
-		return nil, 0, err
+		return nil, err
 	}
 
 	// final model type
 	return &GGML{
 		container: c,
 		model:     model,
-	}, offset, nil
+		Length:    offset,
+	}, nil
 }
 
-func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
-	embedding := f.KV().EmbeddingLength()
-	heads := f.KV().HeadCount()
-	headsKV := f.KV().HeadCountKV()
-	vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array).size)
-
-	embeddingHeads := f.KV().EmbeddingHeadCount()
+func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
+	context *= uint64(numParallel)
+
+	embedding := f.KV().EmbeddingLength()
+	heads := f.KV().HeadCountMax()
+	headsKV := f.KV().HeadCountKVMax()
+	vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
+
+	embeddingHeads := f.KV().EmbeddingHeadCountMax()
 	embeddingHeadsK := f.KV().EmbeddingHeadCountK()
 	embeddingHeadsV := f.KV().EmbeddingHeadCountV()
 
 	layers := f.Tensors().GroupLayers()
 
 	bytesPerElement := kvCacheBytesPerElement(kvCacheType)
-	kv = uint64(float64(context*f.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+	var kvTotal uint64
+	kv = make([]uint64, f.KV().BlockCount())
+	for i := range kv {
+		kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+		kvTotal += kv[i]
+	}
 
 	switch f.KV().Architecture() {
-	case "llama":
+	case "llama", "llama4":
 		fullOffload = max(
 			4*batch*(1+4*embedding+context*(1+heads)),
 			4*batch*(embedding+vocab),
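
The reworked `GraphSize` sizes the KV cache per layer instead of as a single lump, which is what lets the architecture-specific branches below shrink individual layers. Rough arithmetic for the per-layer term, with illustrative llama-7B-like numbers rather than values from the diff:

```go
package main

import "fmt"

func main() {
	const (
		context         = 4096
		embeddingHeadsK = 128
		embeddingHeadsV = 128
		headsKV         = 32
		bytesPerElement = 2.0 // f16 cache
		blockCount      = 32
	)
	perLayer := uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
	fmt.Println(perLayer)              // 67108864 bytes = 64 MiB per layer
	fmt.Println(perLayer * blockCount) // 2147483648 bytes = 2 GiB total
}
```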
@@ -426,7 +518,7 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
 
 		if ffnGateExpsWeight, ok := layers["blk.0"]["ffn_gate_exps.weight"]; ok {
 			// mixtral 8x22b
-			ff := uint64(f.KV()["llama.feed_forward_length"].(uint32))
+			ff := uint64(f.KV().Uint("feed_forward_length"))
 			partialOffload = max(
 				3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
 				4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
@@ -443,16 +535,14 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
 	case "mllama":
 		var visionTokens, tiles uint64 = 1601, 4
 
-		if crossAttentionLayers, ok := f.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
-			kv = headsKV *
-				(embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
-				(2* // sizeof(float16)
-					(f.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
-					context +
-					4* // sizeof(float32)
-						uint64(crossAttentionLayers.size)* // num cross attention layers
-						visionTokens*
-						tiles)
+		crossAttentionLayers := f.KV().Ints("attention.cross_attention_layers")
+		for i := range kv {
+			if slices.Contains(crossAttentionLayers, int32(i)) {
+				kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
+					4 * // sizeof(float32)
+					visionTokens *
+					tiles
+			}
 		}
 
 		fullOffload = max(
@@ -464,7 +554,7 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
 		var ropeFreqsCount uint64
 		if ropeFreqs, ok := f.Tensors().GroupLayers()["rope_freqs"]; ok {
 			if ropeFreqsWeights, ok := ropeFreqs["weights"]; ok {
-				ropeFreqsCount = ropeFreqsWeights.parameters()
+				ropeFreqsCount = ropeFreqsWeights.Elements()
 			}
 		}
 
@@ -476,7 +566,7 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
 			// vocab graph
 			4*batch*(embedding+vocab)+embedding*vocab*105/128,
 		)
-	case "gemma", "gemma2":
+	case "gemma", "gemma2", "gemma3", "gemma3n":
 		fullOffload = max(
 			4*batch*(embedding+vocab),
 			4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
@@ -488,6 +578,25 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
 			4*embeddingHeadsK*context*8+
 				embedding*embeddingHeadsK*heads*9/16,
 		)
+
+		if f.KV().Architecture() == "gemma3n" {
+			fullOffload *= 4
+			partialOffload *= 4
+		}
+
+		// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
+		// engine. Gemma3 always uses the Ollama engine.
+		if f.KV().Architecture() == "gemma3" {
+			const gemma3GlobalCacheCount = 6
+			slidingWindow := (uint64(numParallel) * uint64(f.KV().Uint("attention.sliding_window"))) + batch
+			for i := range kv {
+				// Every 6th layer is a global layer, which is the full context size that has already been set. The other
+				// layers are the smaller local (sliding) layers.
+				if (i+1)%gemma3GlobalCacheCount != 0 {
+					kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+				}
+			}
+		}
 	case "command-r":
 		fullOffload = max(
 			4*batch*(embedding+vocab),
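
A rough feel for what the gemma3 sliding-window branch saves: with a 4096-token context, a 1024-token window, and a 512-token batch, five of every six layers cache roughly 1536 positions instead of 4096. The numbers in this sketch are illustrative, not taken from any particular model:

```go
package main

import "fmt"

func main() {
	const (
		context                = 4096
		slidingWindow          = 1*1024 + 512 // numParallel * attention.sliding_window + batch
		gemma3GlobalCacheCount = 6
		blockCount             = 34
	)
	var global, local int
	for i := 0; i < blockCount; i++ {
		if (i+1)%gemma3GlobalCacheCount == 0 {
			global++ // full-context layer
		} else {
			local++ // sliding-window layer
		}
	}
	fmt.Println(global, local) // 5 full layers, 29 sliding layers
	// Cache footprint relative to keeping full context on every layer:
	fmt.Printf("%.2f\n", float64(global*context+local*slidingWindow)/float64(blockCount*context)) // 0.47
}
```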
@@ -560,13 +669,101 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
 				4*qkvBias.Shape[0],
 			)
 		}
+	case "gptoss", "gpt-oss":
+		kv = make([]uint64, f.KV().BlockCount())
+		for i := range kv {
+			kv[i] = uint64(float64((embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
+			if i%2 == 0 {
+				kv[i] *= (uint64(numParallel)*4096 + batch)
+			} else {
+				kv[i] *= context
+			}
+		}
+
+		partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
+		if useFlashAttention {
+			// rough estimate of graph size with flash attention on
+			partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
+		}
 	}
 
 	return
 }
 
+func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
+	if llm.KV().Uint("vision.block_count") == 0 {
+		return
+	}
+
+	for name, layer := range llm.Tensors().GroupLayers() {
+		if name == "v" || strings.HasPrefix(name, "v.") {
+			for _, tensor := range layer {
+				weights += tensor.Size()
+			}
+		}
+	}
+
+	imageSize := uint64(llm.KV().Uint("vision.image_size"))
+	patchSize := uint64(llm.KV().Uint("vision.patch_size"))
+	if patchSize == 0 {
+		slog.Warn("unknown patch size for vision model")
+		return
+	}
+
+	numChannels := uint64(llm.KV().Uint("vision.num_channels"))
+
+	numPatches := (imageSize / patchSize) * (imageSize / patchSize)
+	if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
+		numPatches++
+	}
+
+	headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
+	embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
+
+	switch llm.KV().Architecture() {
+	case "mllama":
+		numPaddedPatches := numPatches + 8 - (numPatches%8)%8
+
+		maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
+
+		graphSize = 4 * (8 +
+			imageSize*imageSize*numChannels*maxNumTiles +
+			embeddingLength*numPatches*maxNumTiles +
+			9*embeddingLength*numPaddedPatches*maxNumTiles +
+			numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
+	case "gemma3", "mistral3":
+		graphSize = 4 * (imageSize*imageSize*numChannels +
+			embeddingLength*patchSize +
+			numPatches*numPatches*headCount)
+	case "qwen25vl":
+		maxPixels := uint64(llm.KV().Uint("vision.max_pixels", 28*28*1280))
+
+		numPatches := maxPixels / (patchSize * patchSize)
+
+		graphSize = 4 * (maxPixels*numChannels + // Original image storage
+			// Normalized pixels
+			maxPixels*numChannels +
+			// Patches storage (numPatches * channels * patchSize^2)
+			numPatches*numChannels*patchSize*patchSize +
+			// Self-attention calculations
+			numPatches*numPatches*headCount +
+			// Additional buffer for processing
+			embeddingLength*numPatches)
+	case "llama4":
+		// vision graph is computed independently in the same schedule
+		// and is negligible compared to the worst case text graph
+	}
+
+	return weights, graphSize
+}
+
 // SupportsKVCacheType checks if the requested cache type is supported
 func (f GGML) SupportsKVCacheType(cacheType string) bool {
+	if arch := f.KV().Architecture(); slices.Contains([]string{"gptoss", "gpt-oss"}, arch) {
+		// gpt-oss uses attention with sinks which does not support quantized cache types
+		slog.Warn("model only supports non-quantized cache types", "model", arch)
+		return cacheType == "f16"
+	}
 	return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
 }
 
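
The vision estimates all start from the patch grid: an image of side `imageSize` cut into `patchSize`-pixel patches yields (imageSize/patchSize)² tokens, and the self-attention term grows with the square of that count. A worked sketch of the `gemma3, mistral3` branch with assumed numbers (an 896-pixel image, 14-pixel patches, 16 heads, 1152-wide embedding — illustrative, not from any particular model):

```go
package main

import "fmt"

func main() {
	const (
		imageSize       = 896
		patchSize       = 14
		numChannels     = 3
		headCount       = 16
		embeddingLength = 1152
	)
	numPatches := (imageSize / patchSize) * (imageSize / patchSize)
	fmt.Println(numPatches) // 64*64 = 4096 patches
	// Same shape as the "gemma3, mistral3" branch above:
	graphSize := 4 * (imageSize*imageSize*numChannels +
		embeddingLength*patchSize +
		numPatches*numPatches*headCount)
	fmt.Println(graphSize) // ~1.08 GB, dominated by the numPatches^2 attention term
}
```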
@@ -583,6 +780,13 @@ func (f GGML) SupportsFlashAttention() bool {
 	return headCountK != 0 && headCountV != 0 && headCountK == headCountV
 }
 
+// FlashAttention checks if the model should enable flash attention
+func (f GGML) FlashAttention() bool {
+	return slices.Contains([]string{
+		"gptoss", "gpt-oss",
+	}, f.KV().String("general.architecture"))
+}
+
 // kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
 func kvCacheBytesPerElement(cacheType string) float64 {
 	switch cacheType {
@@ -2,6 +2,7 @@ package ggml
 
 import (
 	"maps"
+	"math"
 	"slices"
 	"strconv"
 	"strings"
@@ -210,3 +211,91 @@ func TestTensorTypes(t *testing.T) {
 		})
 	}
 }
+
+func TestKeyValue(t *testing.T) {
+	kv := KV{
+		"general.architecture": "test",
+		"test.strings":         &array[string]{size: 3, values: []string{"a", "b", "c"}},
+		"test.float32s":        &array[float32]{size: 3, values: []float32{1.0, 2.0, 3.0}},
+		"test.int32s":          &array[int32]{size: 3, values: []int32{1, 2, 3}},
+		"test.uint32s":         &array[uint32]{size: 3, values: []uint32{1, 2, 3}},
+	}
+
+	if diff := cmp.Diff(kv.Strings("strings"), []string{"a", "b", "c"}); diff != "" {
+		t.Errorf("unexpected strings (-got +want):\n%s", diff)
+	}
+
+	if diff := cmp.Diff(kv.Strings("nonexistent.strings"), []string(nil)); diff != "" {
+		t.Errorf("unexpected strings (-got +want):\n%s", diff)
+	}
+
+	if diff := cmp.Diff(kv.Strings("default.strings", []string{"ollama"}), []string{"ollama"}); diff != "" {
+		t.Errorf("unexpected strings (-got +want):\n%s", diff)
+	}
+
+	if diff := cmp.Diff(kv.Floats("float32s"), []float32{1.0, 2.0, 3.0}); diff != "" {
+		t.Errorf("unexpected float32s (-got +want):\n%s", diff)
+	}
+
+	if diff := cmp.Diff(kv.Floats("nonexistent.float32s"), []float32(nil)); diff != "" {
+		t.Errorf("unexpected float32s (-got +want):\n%s", diff)
+	}
+
+	if diff := cmp.Diff(kv.Floats("default.float32s", []float32{math.MaxFloat32}), []float32{math.MaxFloat32}); diff != "" {
+		t.Errorf("unexpected float32s (-got +want):\n%s", diff)
+	}
+
+	if diff := cmp.Diff(kv.Ints("int32s"), []int32{1, 2, 3}); diff != "" {
+		t.Errorf("unexpected int32s (-got +want):\n%s", diff)
+	}
+
+	if diff := cmp.Diff(kv.Ints("nonexistent.int32s"), []int32(nil)); diff != "" {
+		t.Errorf("unexpected int32s (-got +want):\n%s", diff)
+	}
+
+	if diff := cmp.Diff(kv.Ints("default.int32s", []int32{math.MaxInt32}), []int32{math.MaxInt32}); diff != "" {
+		t.Errorf("unexpected int32s (-got +want):\n%s", diff)
+	}
+
+	if diff := cmp.Diff(kv.Uints("uint32s"), []uint32{1, 2, 3}); diff != "" {
+		t.Errorf("unexpected uint32s (-got +want):\n%s", diff)
+	}
+
+	if diff := cmp.Diff(kv.Uints("nonexistent.uint32s"), []uint32(nil)); diff != "" {
+		t.Errorf("unexpected uint32s (-got +want):\n%s", diff)
+	}
+
+	if diff := cmp.Diff(kv.Uints("default.uint32s", []uint32{math.MaxUint32}), []uint32{math.MaxUint32}); diff != "" {
+		t.Errorf("unexpected uint32s (-got +want):\n%s", diff)
+	}
+}
+
+func TestHeadCount(t *testing.T) {
+	valuesArray := []int32{1, 5, 3, 4}
+	cases := []struct {
+		kv   KV
+		want uint64
+	}{
+		{
+			kv: KV{
+				"general.architecture":     "abc",
+				"abc.attention.head_count": &array[int32]{values: valuesArray, size: len(valuesArray)},
+			},
+			want: uint64(5),
+		},
+		{
+			kv: KV{
+				"general.architecture":     "abc",
+				"abc.attention.head_count": uint32(3),
+			},
+			want: uint64(3),
+		},
+	}
+
+	for _, tt := range cases {
+		got := tt.kv.HeadCountMax()
+		if got != tt.want {
+			t.Errorf("unexpected max value: got=%d want=%d", got, tt.want)
+		}
+	}
+}

fs/ggml/gguf.go (346 lines changed)
@@ -9,8 +9,12 @@ import (
 	"io"
 	"log/slog"
 	"maps"
+	"os"
+	"runtime"
 	"slices"
 	"strings"
+
+	"golang.org/x/sync/errgroup"
 )
 
 type containerGGUF struct {
@@ -36,10 +40,6 @@ type containerGGUF struct {
 	maxArraySize int
 }
 
-func (c *containerGGUF) canCollectArray(size int) bool {
-	return c.maxArraySize < 0 || size <= c.maxArraySize
-}
-
 func (c *containerGGUF) Name() string {
 	return "gguf"
 }
@@ -229,16 +229,13 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
 		}
 
 		llm.tensors = append(llm.tensors, &tensor)
-		llm.parameters += tensor.parameters()
+		llm.parameters += tensor.Elements()
 	}
 
 	// patch KV with parameter count
 	llm.kv["general.parameter_count"] = llm.parameters
 
-	alignment, ok := llm.kv["general.alignment"].(uint32)
-	if !ok {
-		alignment = 32
-	}
+	alignment := llm.kv.Uint("general.alignment", 32)
 
 	offset, err := rs.Seek(0, io.SeekCurrent)
 	if err != nil {
@@ -298,6 +295,23 @@ func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
 	return b.String(), nil
 }
 
+func readGGUFV1StringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) {
+	for i := range a.size {
+		if a.values != nil {
+			e, err := readGGUFV1String(llm, r)
+			if err != nil {
+				return nil, err
+			}
+
+			a.values[i] = e
+		} else {
+			discardGGUFString(llm, r)
+		}
+	}
+
+	return a, nil
+}
+
 func discardGGUFString(llm *gguf, r io.Reader) error {
 	buf := llm.scratch[:8]
 	_, err := io.ReadFull(r, buf)
@@ -355,78 +369,44 @@ func writeGGUFString(w io.Writer, s string) error {
 	return err
 }
 
-type array struct {
-	size   int
-	values []any
-}
-
-func (a *array) MarshalJSON() ([]byte, error) {
-	return json.Marshal(a.values)
-}
-
-func readGGUFV1Array(llm *gguf, r io.Reader) (*array, error) {
-	t, err := readGGUF[uint32](llm, r)
-	if err != nil {
-		return nil, err
-	}
-
-	n, err := readGGUF[uint32](llm, r)
-	if err != nil {
-		return nil, err
-	}
-
-	a := &array{size: int(n)}
-	if llm.canCollectArray(int(n)) {
-		a.values = make([]any, 0, int(n))
-	}
-
-	for i := range n {
-		var e any
-		switch t {
-		case ggufTypeUint8:
-			e, err = readGGUF[uint8](llm, r)
-		case ggufTypeInt8:
-			e, err = readGGUF[int8](llm, r)
-		case ggufTypeUint16:
-			e, err = readGGUF[uint16](llm, r)
-		case ggufTypeInt16:
-			e, err = readGGUF[int16](llm, r)
-		case ggufTypeUint32:
-			e, err = readGGUF[uint32](llm, r)
-		case ggufTypeInt32:
-			e, err = readGGUF[int32](llm, r)
-		case ggufTypeUint64:
-			e, err = readGGUF[uint64](llm, r)
-		case ggufTypeInt64:
-			e, err = readGGUF[int64](llm, r)
-		case ggufTypeFloat32:
-			e, err = readGGUF[float32](llm, r)
-		case ggufTypeFloat64:
-			e, err = readGGUF[float64](llm, r)
-		case ggufTypeBool:
-			e, err = readGGUF[bool](llm, r)
-		case ggufTypeString:
-			e, err = readGGUFV1String(llm, r)
-		default:
-			return nil, fmt.Errorf("invalid array type: %d", t)
-		}
-		if err != nil {
-			return nil, err
-		}
-
+func readGGUFStringsData(llm *gguf, r io.Reader, a *array[string]) (any, error) {
+	for i := range a.size {
 		if a.values != nil {
+			e, err := readGGUFString(llm, r)
+			if err != nil {
+				return nil, err
+			}
+
 			a.values[i] = e
+		} else {
+			discardGGUFString(llm, r)
 		}
 	}
 
 	return a, nil
 }
 
-func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
-	if llm.Version == 1 {
-		return readGGUFV1Array(llm, r)
-	}
-
+type array[T any] struct {
+	// size is the actual size of the array
+	size int
+
+	// values is the array of values. this is nil if the array is larger than configured maxSize
+	values []T
+}
+
+func (a *array[T]) MarshalJSON() ([]byte, error) {
+	return json.Marshal(a.values)
+}
+
+func newArray[T any](size, maxSize int) *array[T] {
+	a := array[T]{size: size}
+	if maxSize < 0 || size <= maxSize {
+		a.values = make([]T, size)
+	}
+	return &a
+}
+
+func readGGUFArray(llm *gguf, r io.Reader) (any, error) {
 	t, err := readGGUF[uint32](llm, r)
 	if err != nil {
 		return nil, err
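
The switch from `array` holding `[]any` to the generic `array[T]` removes per-element boxing: every element of an `[]any` carries a two-word interface header plus a separately allocated value, while `[]uint32` is four flat bytes per element. A toy illustration (standalone; sizes assume a typical 64-bit platform):

```go
package main

import "fmt"

func main() {
	const n = 1 << 20

	// Boxed: each slot is a 16-byte interface header pointing at a
	// separately allocated uint32 - the shape of the old array type.
	boxed := make([]any, n)
	for i := range boxed {
		boxed[i] = uint32(i)
	}

	// Flat: 4 bytes per element in one allocation - the shape of array[uint32].
	flat := make([]uint32, n)
	for i := range flat {
		flat[i] = uint32(i)
	}

	fmt.Println(len(boxed), len(flat)) // same length, several times the memory for boxed
}
```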
@@ -437,45 +417,55 @@ func readGGUFArray(llm *gguf, r io.Reader) (*array, error) {
 		return nil, err
 	}
 
-	a := &array{size: int(n)}
-	if llm.canCollectArray(int(n)) {
-		a.values = make([]any, int(n))
-	}
-
-	for i := range n {
-		var e any
-		switch t {
-		case ggufTypeUint8:
-			e, err = readGGUF[uint8](llm, r)
-		case ggufTypeInt8:
-			e, err = readGGUF[int8](llm, r)
-		case ggufTypeUint16:
-			e, err = readGGUF[uint16](llm, r)
-		case ggufTypeInt16:
-			e, err = readGGUF[int16](llm, r)
-		case ggufTypeUint32:
-			e, err = readGGUF[uint32](llm, r)
-		case ggufTypeInt32:
-			e, err = readGGUF[int32](llm, r)
-		case ggufTypeUint64:
-			e, err = readGGUF[uint64](llm, r)
-		case ggufTypeInt64:
-			e, err = readGGUF[int64](llm, r)
-		case ggufTypeFloat32:
-			e, err = readGGUF[float32](llm, r)
-		case ggufTypeFloat64:
-			e, err = readGGUF[float64](llm, r)
-		case ggufTypeBool:
-			e, err = readGGUF[bool](llm, r)
-		case ggufTypeString:
-			if a.values != nil {
-				e, err = readGGUFString(llm, r)
-			} else {
-				err = discardGGUFString(llm, r)
-			}
-		default:
-			return nil, fmt.Errorf("invalid array type: %d", t)
-		}
+	switch t {
+	case ggufTypeUint8:
+		a := newArray[uint8](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeInt8:
+		a := newArray[int8](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeUint16:
+		a := newArray[uint16](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeInt16:
+		a := newArray[int16](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeUint32:
+		a := newArray[uint32](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeInt32:
+		a := newArray[int32](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeUint64:
+		a := newArray[uint64](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeInt64:
+		a := newArray[int64](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeFloat32:
+		a := newArray[float32](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeFloat64:
+		a := newArray[float64](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeBool:
+		a := newArray[bool](int(n), llm.maxArraySize)
+		return readGGUFArrayData(llm, r, a)
+	case ggufTypeString:
+		a := newArray[string](int(n), llm.maxArraySize)
+		if llm.Version == 1 {
+			return readGGUFV1StringsData(llm, r, a)
+		}
+
+		return readGGUFStringsData(llm, r, a)
+	default:
+		return nil, fmt.Errorf("invalid array type: %d", t)
+	}
+}
+
+func readGGUFArrayData[T any](llm *gguf, r io.Reader, a *array[T]) (any, error) {
+	for i := range a.size {
+		e, err := readGGUF[T](llm, r)
 		if err != nil {
 			return nil, err
 		}
@@ -502,62 +492,86 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
 		return err
 	}
 
+	if t == ggufTypeString {
+		for _, e := range any(s).([]string) {
+			if err := binary.Write(w, binary.LittleEndian, uint64(len(e))); err != nil {
+				return err
+			}
+
+			if err := binary.Write(w, binary.LittleEndian, []byte(e)); err != nil {
+				return err
+			}
+		}
+		return nil
+	}
+
 	return binary.Write(w, binary.LittleEndian, s)
 }
 
-func WriteGGUF(ws io.WriteSeeker, kv KV, ts []Tensor) error {
-	if err := binary.Write(ws, binary.LittleEndian, []byte("GGUF")); err != nil {
+func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
+	alignment := kv.Uint("general.alignment", 32)
+
+	if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
 		return err
 	}
 
-	if err := binary.Write(ws, binary.LittleEndian, uint32(3)); err != nil {
+	if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil {
 		return err
 	}
 
-	if err := binary.Write(ws, binary.LittleEndian, uint64(len(ts))); err != nil {
+	if err := binary.Write(f, binary.LittleEndian, uint64(len(ts))); err != nil {
 		return err
 	}
 
-	if err := binary.Write(ws, binary.LittleEndian, uint64(len(kv))); err != nil {
+	if err := binary.Write(f, binary.LittleEndian, uint64(len(kv))); err != nil {
 		return err
 	}
 
-	keys := slices.Collect(maps.Keys(kv))
-	slices.Sort(keys)
-
-	for _, key := range keys {
-		if err := ggufWriteKV(ws, key, kv[key]); err != nil {
+	for _, key := range slices.Sorted(maps.Keys(kv)) {
+		if err := ggufWriteKV(f, key, kv[key]); err != nil {
 			return err
 		}
 	}
 
-	slices.SortStableFunc(ts, func(a, b Tensor) int {
-		if i, j := a.block(), b.block(); i < 0 && j > 0 {
-			return 1
-		} else if i > 0 && j < 0 {
-			return -1
-		} else {
-			return cmp.Compare(i, j)
-		}
-	})
+	slices.SortStableFunc(
+		ts,
+		func(a, b *Tensor) int {
+			return cmp.Or(
+				cmp.Compare(a.block(), b.block()),
+				cmp.Compare(a.Name, b.Name),
+			)
+		},
+	)
 
 	var s uint64
-	for _, t := range ts {
-		t.Offset = s
-		if err := ggufWriteTensorInfo(ws, t); err != nil {
+	for i := range ts {
+		ts[i].Offset = s
+		if err := ggufWriteTensorInfo(f, ts[i]); err != nil {
 			return err
 		}
-		s += t.Size()
+		s += ts[i].Size()
+		s += uint64(ggufPadding(int64(s), int64(alignment)))
 	}
 
-	var alignment int64 = 32
+	offset, err := f.Seek(0, io.SeekCurrent)
+	if err != nil {
+		return err
+	}
+	offset += ggufPadding(offset, int64(alignment))
+
+	var g errgroup.Group
+	g.SetLimit(runtime.GOMAXPROCS(0))
+	// TODO consider reducing if tensors size * gomaxprocs is larger than free memory
 	for _, t := range ts {
-		if err := ggufWriteTensor(ws, t, alignment); err != nil {
+		t := t
+		w := io.NewOffsetWriter(f, offset+int64(t.Offset))
+		g.Go(func() error {
+			_, err := t.WriteTo(w)
 			return err
-		}
+		})
 	}
 
-	return nil
+	return g.Wait()
 }
 
 func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
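
The rewritten `WriteGGUF` precomputes each tensor's aligned offset before any data is written, so the tensor writes become independent and can run in parallel: every goroutine gets its own `io.NewOffsetWriter` window into the file and no shared file position is seeked. A stripped-down sketch of that pattern with hypothetical `chunks` data (not the GGUF writer itself):

```go
package main

import (
	"io"
	"os"
	"runtime"

	"golang.org/x/sync/errgroup"
)

func main() {
	f, err := os.Create("out.bin")
	if err != nil {
		panic(err)
	}
	defer f.Close()

	chunks := [][]byte{[]byte("aaaa"), []byte("bbbb"), []byte("cccc")}

	var g errgroup.Group
	g.SetLimit(runtime.GOMAXPROCS(0))
	offset := int64(0)
	for _, c := range chunks {
		w := io.NewOffsetWriter(f, offset) // independent window per chunk
		c := c
		g.Go(func() error {
			_, err := w.Write(c)
			return err
		})
		offset += int64(len(c))
	}
	if err := g.Wait(); err != nil {
		panic(err)
	}
}
```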
@@ -572,8 +586,10 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
 
 	var err error
 	switch v := v.(type) {
-	case uint32:
+	case uint32, FileType:
 		err = writeGGUF(ws, ggufTypeUint32, v)
+	case uint64:
+		err = writeGGUF(ws, ggufTypeUint64, v)
 	case float32:
 		err = writeGGUF(ws, ggufTypeFloat32, v)
 	case bool:
@@ -582,32 +598,24 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
 		err = writeGGUFString(ws, v)
 	case []int32:
 		err = writeGGUFArray(ws, ggufTypeInt32, v)
+	case *array[int32]:
+		err = writeGGUFArray(ws, ggufTypeInt32, v.values)
 	case []uint32:
 		err = writeGGUFArray(ws, ggufTypeUint32, v)
+	case *array[uint32]:
+		err = writeGGUFArray(ws, ggufTypeUint32, v.values)
 	case []float32:
 		err = writeGGUFArray(ws, ggufTypeFloat32, v)
+	case *array[float32]:
+		err = writeGGUFArray(ws, ggufTypeFloat32, v.values)
 	case []string:
-		if err := binary.Write(ws, binary.LittleEndian, ggufTypeArray); err != nil {
-			return err
-		}
-
-		if err := binary.Write(ws, binary.LittleEndian, ggufTypeString); err != nil {
-			return err
-		}
-
-		if err := binary.Write(ws, binary.LittleEndian, uint64(len(v))); err != nil {
-			return err
-		}
-
-		for _, e := range v {
-			if err := binary.Write(ws, binary.LittleEndian, uint64(len(e))); err != nil {
-				return err
-			}
-
-			if err := binary.Write(ws, binary.LittleEndian, []byte(e)); err != nil {
-				return err
-			}
-		}
+		err = writeGGUFArray(ws, ggufTypeString, v)
+	case *array[string]:
+		err = writeGGUFArray(ws, ggufTypeString, v.values)
+	case []bool:
+		err = writeGGUFArray(ws, ggufTypeBool, v)
+	case *array[bool]:
+		err = writeGGUFArray(ws, ggufTypeBool, v.values)
 	default:
 		return fmt.Errorf("improper type for '%s'", k)
 	}
@@ -615,7 +623,7 @@ func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
|
func ggufWriteTensorInfo(ws io.WriteSeeker, t *Tensor) error {
|
||||||
slog.Debug(t.Name, "kind", t.Kind, "shape", t.Shape, "offset", t.Offset)
|
slog.Debug(t.Name, "kind", t.Kind, "shape", t.Shape, "offset", t.Offset)
|
||||||
if err := binary.Write(ws, binary.LittleEndian, uint64(len(t.Name))); err != nil {
|
if err := binary.Write(ws, binary.LittleEndian, uint64(len(t.Name))); err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -629,8 +637,8 @@ func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range len(t.Shape) {
|
for _, n := range t.Shape {
|
||||||
if err := binary.Write(ws, binary.LittleEndian, t.Shape[len(t.Shape)-i-1]); err != nil {
|
if err := binary.Write(ws, binary.LittleEndian, n); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -642,20 +650,6 @@ func ggufWriteTensorInfo(ws io.WriteSeeker, t Tensor) error {
|
|||||||
return binary.Write(ws, binary.LittleEndian, t.Offset)
|
return binary.Write(ws, binary.LittleEndian, t.Offset)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ggufWriteTensor(ws io.WriteSeeker, t Tensor, alignment int64) error {
|
|
||||||
offset, err := ws.Seek(0, io.SeekCurrent)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := binary.Write(ws, binary.LittleEndian, bytes.Repeat([]byte{0}, int(ggufPadding(offset, alignment)))); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = t.WriteTo(ws)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func ggufPadding(offset, align int64) int64 {
|
func ggufPadding(offset, align int64) int64 {
|
||||||
return (align - offset%align) % align
|
return (align - offset%align) % align
|
||||||
}
|
}
|
||||||
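The rewritten writer above precomputes every tensor's offset, including the inter-tensor padding, before any payload bytes are written; that is what allows the payloads to be flushed concurrently through io.NewOffsetWriter at offset+t.Offset. A minimal, self-contained sketch of the offset arithmetic follows; the pad helper is a stand-in mirroring ggufPadding and the tensor sizes are hypothetical, not taken from the patch:

package main

import "fmt"

// pad mirrors ggufPadding: the number of bytes needed to round
// offset up to the next multiple of align.
func pad(offset, align int64) int64 {
	return (align - offset%align) % align
}

func main() {
	sizes := []int64{24, 24, 40} // hypothetical tensor byte sizes
	var offset int64
	for i, size := range sizes {
		fmt.Printf("tensor %d starts at %d\n", i, offset) // 0, 32, 64
		offset += size
		offset += pad(offset, 32) // same 32-byte default alignment
	}
}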
fs/ggml/gguf_test.go (new file)
@@ -0,0 +1,83 @@
package ggml

import (
	"bytes"
	"math/rand/v2"
	"os"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
)

func TestWriteGGUF(t *testing.T) {
	b := bytes.NewBuffer(make([]byte, 2*3))
	for range 8 {
		t.Run("shuffle", func(t *testing.T) {
			t.Parallel()

			ts := []*Tensor{
				{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: b},
				{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
				{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
				{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: b},
				{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
				{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: b},
				{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: b},
				{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: b},
				{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: b},
			}

			rand.Shuffle(len(ts), func(i, j int) {
				ts[i], ts[j] = ts[j], ts[i]
			})

			w, err := os.CreateTemp(t.TempDir(), strings.ReplaceAll(t.Name(), "/", "_")+"*.bin")
			if err != nil {
				t.Fatal(err)
			}
			defer w.Close()

			if err := WriteGGUF(w, KV{
				"general.alignment": uint32(16),
			}, ts); err != nil {
				t.Fatal(err)
			}

			r, err := os.Open(w.Name())
			if err != nil {
				t.Fatal(err)
			}
			defer r.Close()

			ff, err := Decode(r, 0)
			if err != nil {
				t.Fatal(err)
			}

			if diff := cmp.Diff(KV{
				"general.alignment":       uint32(16),
				"general.parameter_count": uint64(54),
			}, ff.KV()); diff != "" {
				t.Errorf("Mismatch (-want +got):\n%s", diff)
			}

			if diff := cmp.Diff(Tensors{
				Offset: 592,
				items: []*Tensor{
					{Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
					{Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
					{Name: "blk.0.ffn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
					{Name: "blk.1.ffn_down.weight", Offset: 96, Shape: []uint64{2, 3}},
					{Name: "blk.1.ffn_up.weight", Offset: 128, Shape: []uint64{2, 3}},
					{Name: "blk.2.ffn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
					{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
					{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
					{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
				},
			}, ff.Tensors(), cmp.AllowUnexported(Tensors{})); diff != "" {
				t.Errorf("Mismatch (-want +got):\n%s", diff)
			}
		})
	}
}
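The expected offsets in the test above fall directly out of the alignment rule: each 2x3 F32 tensor occupies 4*2*3 = 24 bytes, and with general.alignment set to 16 every tensor is padded out to a 32-byte stride, giving offsets 0, 32, 64, ..., 256. A standalone check of that arithmetic (a sketch, not part of the test file):

package main

import "fmt"

func main() {
	const align = 16
	size := int64(4 * 2 * 3) // 24 bytes per 2x3 float32 tensor
	var offset int64
	for i := 0; i < 9; i++ {
		fmt.Println(i, offset) // prints 0 0, 1 32, 2 64, ... 8 256
		offset += size
		offset += (align - offset%align) % align
	}
}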
fs/ggml/type.go
@@ -1,26 +1,31 @@
 package ggml

-import "fmt"
+import (
+	"fmt"
+	"log/slog"
+	"strings"
+)

-type fileType uint32
+// FileType is the Go equivalent to llama_ftype used for gguf file typing
+type FileType uint32

 const (
-	fileTypeF32 fileType = iota
-	fileTypeF16
+	FileTypeF32 FileType = iota
+	FileTypeF16
 	fileTypeQ4_0
 	fileTypeQ4_1
-	fileTypeQ4_1_F16
-	fileTypeQ4_2 // unused
-	fileTypeQ4_3 // unused
-	fileTypeQ8_0
+	fileTypeMXFP4 // originally fileTypeQ4_1_F16 // unused by GGML
+	fileTypeQ4_2  // unused by GGML
+	fileTypeQ4_3  // unused by GGML
+	FileTypeQ8_0
 	fileTypeQ5_0
 	fileTypeQ5_1
 	fileTypeQ2_K
 	fileTypeQ3_K_S
 	fileTypeQ3_K_M
 	fileTypeQ3_K_L
-	fileTypeQ4_K_S
-	fileTypeQ4_K_M
+	FileTypeQ4_K_S
+	FileTypeQ4_K_M
 	fileTypeQ5_K_S
 	fileTypeQ5_K_M
 	fileTypeQ6_K
@@ -37,93 +42,64 @@ const (
 	fileTypeIQ2_M
 	fileTypeIQ4_XS
 	fileTypeIQ1_M
-	fileTypeBF16
+	FileTypeBF16
+	fileTypeQ4_0_4_4 // unused by GGML
+	fileTypeQ4_0_4_8 // unused by GGML
+	fileTypeQ4_0_8_8 // unused by GGML
+	fileTypeTQ1_0
+	fileTypeTQ2_0

-	fileTypeUnknown
+	FileTypeUnknown = 1024
 )

-func ParseFileType(s string) (fileType, error) {
+// ParseFileType parses the provided GGUF file type
+// Only Ollama supported types are considered valid
+func ParseFileType(s string) (FileType, error) {
 	switch s {
 	case "F32":
-		return fileTypeF32, nil
+		return FileTypeF32, nil
 	case "F16":
-		return fileTypeF16, nil
-	case "Q4_0":
-		return fileTypeQ4_0, nil
-	case "Q4_1":
-		return fileTypeQ4_1, nil
-	case "Q4_1_F16":
-		return fileTypeQ4_1_F16, nil
+		return FileTypeF16, nil
 	case "Q8_0":
-		return fileTypeQ8_0, nil
-	case "Q5_0":
-		return fileTypeQ5_0, nil
-	case "Q5_1":
-		return fileTypeQ5_1, nil
-	case "Q2_K":
-		return fileTypeQ2_K, nil
-	case "Q3_K_S":
-		return fileTypeQ3_K_S, nil
-	case "Q3_K_M":
-		return fileTypeQ3_K_M, nil
-	case "Q3_K_L":
-		return fileTypeQ3_K_L, nil
+		return FileTypeQ8_0, nil
 	case "Q4_K_S":
-		return fileTypeQ4_K_S, nil
-	case "Q4_K_M":
-		return fileTypeQ4_K_M, nil
-	case "Q5_K_S":
-		return fileTypeQ5_K_S, nil
-	case "Q5_K_M":
-		return fileTypeQ5_K_M, nil
-	case "Q6_K":
-		return fileTypeQ6_K, nil
-	case "IQ2_XXS":
-		return fileTypeIQ2_XXS, nil
-	case "IQ2_XS":
-		return fileTypeIQ2_XS, nil
-	case "Q2_K_S":
-		return fileTypeQ2_K_S, nil
-	case "IQ3_XS":
-		return fileTypeIQ3_XS, nil
-	case "IQ3_XXS":
-		return fileTypeIQ3_XXS, nil
-	case "IQ1_S":
-		return fileTypeIQ1_S, nil
-	case "IQ4_NL":
-		return fileTypeIQ4_NL, nil
-	case "IQ3_S":
-		return fileTypeIQ3_S, nil
-	case "IQ3_M":
-		return fileTypeIQ3_M, nil
-	case "IQ2_S":
-		return fileTypeIQ2_S, nil
-	case "IQ2_M":
-		return fileTypeIQ2_M, nil
-	case "IQ4_XS":
-		return fileTypeIQ4_XS, nil
-	case "IQ1_M":
-		return fileTypeIQ1_M, nil
+		return FileTypeQ4_K_S, nil
+	case "Q4_K_M", "Q4_K":
+		return FileTypeQ4_K_M, nil
 	case "BF16":
-		return fileTypeBF16, nil
+		return FileTypeBF16, nil
 	default:
-		return fileTypeUnknown, fmt.Errorf("unknown fileType: %s", s)
+		supportedFileTypes := []FileType{
+			FileTypeF32,
+			FileTypeF16,
+			FileTypeQ4_K_S,
+			FileTypeQ4_K_M,
+			FileTypeQ8_0,
+			// fsggml.FileTypeBF16, // TODO
+		}
+		strs := make([]string, len(supportedFileTypes))
+		for i := range supportedFileTypes {
+			strs[i] = supportedFileTypes[i].String()
+		}
+
+		return FileTypeUnknown, fmt.Errorf("unsupported quantization type %s - supported types are %s", s, strings.Join(strs, ", "))
 	}
 }

-func (t fileType) String() string {
+func (t FileType) String() string {
+	// Note: this routine will return a broader set of file types for existing models
 	switch t {
-	case fileTypeF32:
+	case FileTypeF32:
 		return "F32"
-	case fileTypeF16:
+	case FileTypeF16:
 		return "F16"
 	case fileTypeQ4_0:
 		return "Q4_0"
 	case fileTypeQ4_1:
 		return "Q4_1"
-	case fileTypeQ4_1_F16:
-		return "Q4_1_F16"
-	case fileTypeQ8_0:
+	case fileTypeMXFP4:
+		return "MXFP4"
+	case FileTypeQ8_0:
 		return "Q8_0"
 	case fileTypeQ5_0:
 		return "Q5_0"
@@ -137,9 +113,9 @@ func (t fileType) String() string {
 		return "Q3_K_M"
 	case fileTypeQ3_K_L:
 		return "Q3_K_L"
-	case fileTypeQ4_K_S:
+	case FileTypeQ4_K_S:
 		return "Q4_K_S"
-	case fileTypeQ4_K_M:
+	case FileTypeQ4_K_M:
 		return "Q4_K_M"
 	case fileTypeQ5_K_S:
 		return "Q5_K_S"
@@ -147,39 +123,205 @@ func (t fileType) String() string {
 		return "Q5_K_M"
 	case fileTypeQ6_K:
 		return "Q6_K"
-	case fileTypeIQ2_XXS:
-		return "IQ2_XXS"
-	case fileTypeIQ2_XS:
-		return "IQ2_XS"
 	case fileTypeQ2_K_S:
 		return "Q2_K_S"
-	case fileTypeIQ3_XS:
-		return "IQ3_XS"
-	case fileTypeIQ3_XXS:
-		return "IQ3_XXS"
-	case fileTypeIQ1_S:
-		return "IQ1_S"
-	case fileTypeIQ4_NL:
-		return "IQ4_NL"
-	case fileTypeIQ3_S:
-		return "IQ3_S"
-	case fileTypeIQ3_M:
-		return "IQ3_M"
-	case fileTypeIQ2_S:
-		return "IQ2_S"
-	case fileTypeIQ4_XS:
-		return "IQ4_XS"
-	case fileTypeIQ2_M:
-		return "IQ2_M"
-	case fileTypeIQ1_M:
-		return "IQ1_M"
-	case fileTypeBF16:
+	case FileTypeBF16:
 		return "BF16"
 	default:
 		return "unknown"
 	}
 }

-func (t fileType) Value() uint32 {
+func (t FileType) Value() uint32 {
 	return uint32(t)
 }
+
+func (ftype FileType) ToTensorType() TensorType {
+	switch ftype {
+	case FileTypeF32:
+		return TensorTypeF32
+	case FileTypeF16:
+		return TensorTypeF16
+	case fileTypeQ4_0:
+		return TensorTypeQ4_0
+	case fileTypeQ4_1:
+		return TensorTypeQ4_1
+	case FileTypeQ8_0:
+		return TensorTypeQ8_0
+	case fileTypeQ5_0:
+		return TensorTypeQ5_0
+	case fileTypeQ5_1:
+		return TensorTypeQ5_1
+	case fileTypeQ2_K:
+		return TensorTypeQ2_K
+	case fileTypeQ3_K_S:
+		return TensorTypeQ3_K
+	case fileTypeQ3_K_M:
+		return TensorTypeQ3_K
+	case fileTypeQ3_K_L:
+		return TensorTypeQ3_K
+	case FileTypeQ4_K_S:
+		return TensorTypeQ4_K
+	case FileTypeQ4_K_M:
+		return TensorTypeQ4_K
+	case fileTypeQ5_K_S:
+		return TensorTypeQ5_K
+	case fileTypeQ5_K_M:
+		return TensorTypeQ5_K
+	case fileTypeQ6_K:
+		return TensorTypeQ6_K
+	case fileTypeQ2_K_S:
+		return TensorTypeQ2_K
+	case FileTypeBF16:
+		return TensorTypeBF16
+	case fileTypeMXFP4:
+		return TensorTypeMXFP4
+	default:
+		slog.Warn("unsupported file type", "type", ftype)
+		return 0 // F32
+	}
+}
+
+// TensorType is equivalent to ggml_type for individual tensor types
+// Note: these are not the same as FileType
+type TensorType uint32
+
+const (
+	TensorTypeF32 TensorType = iota
+	TensorTypeF16
+	TensorTypeQ4_0
+	TensorTypeQ4_1
+	tensorTypeQ4_2
+	tensorTypeQ4_3 // unused by GGML
+	TensorTypeQ5_0
+	TensorTypeQ5_1
+	TensorTypeQ8_0
+	TensorTypeQ8_1
+	TensorTypeQ2_K
+	TensorTypeQ3_K
+	TensorTypeQ4_K
+	TensorTypeQ5_K
+	TensorTypeQ6_K
+	TensorTypeQ8_K
+	tensorTypeIQ2_XXS // not supported by ollama
+	tensorTypeIQ2_XS  // not supported by ollama
+	tensorTypeIQ3_XXS // not supported by ollama
+	tensorTypeIQ1_S   // not supported by ollama
+	tensorTypeIQ4_NL  // not supported by ollama
+	tensorTypeIQ3_S   // not supported by ollama
+	tensorTypeIQ2_S   // not supported by ollama
+	tensorTypeIQ4_XS  // not supported by ollama
+	TensorTypeI8
+	TensorTypeI16
+	TensorTypeI32
+	TensorTypeI64
+	TensorTypeF64
+	tensorTypeIQ1_M // not supported by ollama
+	TensorTypeBF16
+	tensorTypeQ4_0_4_4 // unused by GGML
+	tensorTypeQ4_0_4_8 // unused by GGML
+	tensorTypeQ4_0_8_8 // unused by GGML
+	tensorTypeTQ1_0    // not supported by ollama
+	tensorTypeTQ2_0    // not supported by ollama
+	tensorTypeIQ4_NL_4_4 // unused by GGML
+	tensorTypeIQ4_NL_4_8 // unused by GGML
+	tensorTypeIQ4_NL_8_8 // unused by GGML
+	TensorTypeMXFP4
+)
+
+// ParseFileType parses the provided GGUF file type
+// Only Ollama supported types are considered valid
+func ParseTensorType(s string) (TensorType, error) {
+	switch s {
+	case "F32":
+		return TensorTypeF32, nil
+	case "F16":
+		return TensorTypeF16, nil
+	case "Q4_0":
+		return TensorTypeQ4_0, nil
+	case "Q4_1":
+		return TensorTypeQ4_1, nil
+	case "Q5_0":
+		return TensorTypeQ5_0, nil
+	case "Q5_1":
+		return TensorTypeQ5_1, nil
+	case "Q8_0":
+		return TensorTypeQ8_0, nil
+	case "Q8_1":
+		return TensorTypeQ8_1, nil
+	case "Q2_K":
+		return TensorTypeQ2_K, nil
+	case "Q3_K":
+		return TensorTypeQ3_K, nil
+	case "Q4_K":
+		return TensorTypeQ4_K, nil
+	case "Q5_K":
+		return TensorTypeQ5_K, nil
+	case "Q6_K":
+		return TensorTypeQ6_K, nil
+	case "Q8_K":
+		return TensorTypeQ8_K, nil
+	case "F64":
+		return TensorTypeF64, nil
+	case "BF16":
+		return TensorTypeBF16, nil
+	case "MXFP4":
+		return TensorTypeMXFP4, nil
+	default:
+		return 0, fmt.Errorf("unsupported quantization type %s", s)
+	}
+}
+
+func (t TensorType) IsQuantized() bool {
+	switch t {
+	case TensorTypeF32, TensorTypeF16, TensorTypeBF16:
+		return false
+	default:
+		return true
+	}
+}
+
+func (t TensorType) RowSize(ne uint64) uint64 {
+	return t.TypeSize() * ne / t.BlockSize()
+}
+
+func (t TensorType) String() string {
+	switch t {
+	case TensorTypeF32:
+		return "F32"
+	case TensorTypeF16:
+		return "F16"
+	case TensorTypeQ4_0:
+		return "Q4_0"
+	case TensorTypeQ4_1:
+		return "Q4_1"
+	case TensorTypeQ5_0:
+		return "Q5_0"
+	case TensorTypeQ5_1:
+		return "Q5_1"
+	case TensorTypeQ8_0:
+		return "Q8_0"
+	case TensorTypeQ8_1:
+		return "Q8_1"
+	case TensorTypeQ2_K:
+		return "Q2_K"
+	case TensorTypeQ3_K:
+		return "Q3_K"
+	case TensorTypeQ4_K:
+		return "Q4_K"
+	case TensorTypeQ5_K:
+		return "Q5_K"
+	case TensorTypeQ6_K:
+		return "Q6_K"
+	case TensorTypeQ8_K:
+		return "Q8_K"
+	case TensorTypeF64:
+		return "F64"
+	case TensorTypeBF16:
+		return "BF16"
+	case 4, TensorTypeMXFP4:
+		return "MXFP4"
+	default:
+		return "unknown"
+	}
+}
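With the exported/unexported split above, only Ollama's supported quantization targets can be requested by name, while String() still labels the legacy types that may appear in existing models. A usage sketch, assuming the github.com/ollama/ollama/fs/ggml import path used elsewhere in this change set:

package main

import (
	"fmt"

	"github.com/ollama/ollama/fs/ggml"
)

func main() {
	ft, err := ggml.ParseFileType("Q4_K_M")
	if err != nil {
		panic(err)
	}
	fmt.Println(ft)                // Q4_K_M
	fmt.Println(ft.ToTensorType()) // Q4_K

	// Unsupported names fail with the list of supported types.
	if _, err := ggml.ParseFileType("IQ2_XS"); err != nil {
		fmt.Println(err)
	}
}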
fs/gguf/gguf.go (new file)
@@ -0,0 +1,347 @@
package gguf

import (
	"bytes"
	"cmp"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"iter"
	"os"
	"slices"
	"strings"
)

const (
	typeUint8 uint32 = iota
	typeInt8
	typeUint16
	typeInt16
	typeUint32
	typeInt32
	typeFloat32
	typeBool
	typeString
	typeArray
	typeUint64
	typeInt64
	typeFloat64
)

var ErrUnsupported = errors.New("unsupported")

type File struct {
	Magic   [4]byte
	Version uint32

	keyValues *lazy[KeyValue]
	tensors   *lazy[TensorInfo]
	offset    int64

	file   *os.File
	reader *bufferedReader
	bts    []byte
}

func Open(path string) (f *File, err error) {
	f = &File{bts: make([]byte, 4096)}
	f.file, err = os.Open(path)
	if err != nil {
		return nil, err
	}

	f.reader = newBufferedReader(f.file, 32<<10)

	if err := binary.Read(f.reader, binary.LittleEndian, &f.Magic); err != nil {
		return nil, err
	}

	if bytes.Equal(f.Magic[:], []byte("gguf")) {
		return nil, fmt.Errorf("%w file type %v", ErrUnsupported, f.Magic)
	}

	if err := binary.Read(f.reader, binary.LittleEndian, &f.Version); err != nil {
		return nil, err
	}

	if f.Version < 2 {
		return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version)
	}

	f.tensors, err = newLazy(f, f.readTensor)
	if err != nil {
		return nil, err
	}

	f.tensors.successFunc = func() error {
		offset := f.reader.offset

		alignment := cmp.Or(f.KeyValue("general.alignment").Int(), 32)
		f.offset = offset + (alignment-offset%alignment)%alignment
		return nil
	}

	f.keyValues, err = newLazy(f, f.readKeyValue)
	if err != nil {
		return nil, err
	}

	return f, nil
}

func (f *File) readTensor() (TensorInfo, error) {
	name, err := readString(f)
	if err != nil {
		return TensorInfo{}, err
	}

	dims, err := read[uint32](f)
	if err != nil {
		return TensorInfo{}, err
	}

	shape := make([]uint64, dims)
	for i := range dims {
		shape[i], err = read[uint64](f)
		if err != nil {
			return TensorInfo{}, err
		}
	}

	type_, err := read[uint32](f)
	if err != nil {
		return TensorInfo{}, err
	}

	offset, err := read[uint64](f)
	if err != nil {
		return TensorInfo{}, err
	}

	return TensorInfo{
		Name:   name,
		Offset: offset,
		Shape:  shape,
		Type:   TensorType(type_),
	}, nil
}

func (f *File) readKeyValue() (KeyValue, error) {
	key, err := readString(f)
	if err != nil {
		return KeyValue{}, err
	}

	t, err := read[uint32](f)
	if err != nil {
		return KeyValue{}, err
	}

	value, err := func() (any, error) {
		switch t {
		case typeUint8:
			return read[uint8](f)
		case typeInt8:
			return read[int8](f)
		case typeUint16:
			return read[uint16](f)
		case typeInt16:
			return read[int16](f)
		case typeUint32:
			return read[uint32](f)
		case typeInt32:
			return read[int32](f)
		case typeUint64:
			return read[uint64](f)
		case typeInt64:
			return read[int64](f)
		case typeFloat32:
			return read[float32](f)
		case typeFloat64:
			return read[float64](f)
		case typeBool:
			return read[bool](f)
		case typeString:
			return readString(f)
		case typeArray:
			return readArray(f)
		default:
			return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
		}
	}()
	if err != nil {
		return KeyValue{}, err
	}

	return KeyValue{
		Key:   key,
		Value: Value{value},
	}, nil
}

func read[T any](f *File) (t T, err error) {
	err = binary.Read(f.reader, binary.LittleEndian, &t)
	return t, err
}

func readString(f *File) (string, error) {
	n, err := read[uint64](f)
	if err != nil {
		return "", err
	}

	if int(n) > len(f.bts) {
		f.bts = make([]byte, n)
	}

	bts := f.bts[:n]
	if _, err := io.ReadFull(f.reader, bts); err != nil {
		return "", err
	}
	defer clear(bts)

	return string(bts), nil
}

func readArray(f *File) (any, error) {
	t, err := read[uint32](f)
	if err != nil {
		return nil, err
	}

	n, err := read[uint64](f)
	if err != nil {
		return nil, err
	}

	switch t {
	case typeUint8:
		return readArrayData[uint8](f, n)
	case typeInt8:
		return readArrayData[int8](f, n)
	case typeUint16:
		return readArrayData[uint16](f, n)
	case typeInt16:
		return readArrayData[int16](f, n)
	case typeUint32:
		return readArrayData[uint32](f, n)
	case typeInt32:
		return readArrayData[int32](f, n)
	case typeUint64:
		return readArrayData[uint64](f, n)
	case typeInt64:
		return readArrayData[int64](f, n)
	case typeFloat32:
		return readArrayData[float32](f, n)
	case typeFloat64:
		return readArrayData[float64](f, n)
	case typeBool:
		return readArrayData[bool](f, n)
	case typeString:
		return readArrayString(f, n)
	default:
		return nil, fmt.Errorf("%w type %d", ErrUnsupported, t)
	}
}

func readArrayData[T any](f *File, n uint64) (s []T, err error) {
	s = make([]T, n)
	for i := range n {
		e, err := read[T](f)
		if err != nil {
			return nil, err
		}

		s[i] = e
	}

	return s, nil
}

func readArrayString(f *File, n uint64) (s []string, err error) {
	s = make([]string, n)
	for i := range n {
		e, err := readString(f)
		if err != nil {
			return nil, err
		}

		s[i] = e
	}

	return s, nil
}

func (f *File) Close() error {
	f.keyValues.stop()
	f.tensors.stop()
	return f.file.Close()
}

func (f *File) KeyValue(key string) KeyValue {
	if !strings.HasPrefix(key, "general.") && !strings.HasPrefix(key, "tokenizer.") {
		key = f.KeyValue("general.architecture").String() + "." + key
	}

	if index := slices.IndexFunc(f.keyValues.values, func(kv KeyValue) bool {
		return kv.Key == key
	}); index >= 0 {
		return f.keyValues.values[index]
	}

	for keyValue, ok := f.keyValues.next(); ok; keyValue, ok = f.keyValues.next() {
		if keyValue.Key == key {
			return keyValue
		}
	}

	return KeyValue{}
}

func (f *File) NumKeyValues() int {
	return int(f.keyValues.count)
}

func (f *File) KeyValues() iter.Seq2[int, KeyValue] {
	return f.keyValues.All()
}

func (f *File) TensorInfo(name string) TensorInfo {
	if index := slices.IndexFunc(f.tensors.values, func(t TensorInfo) bool {
		return t.Name == name
	}); index >= 0 {
		return f.tensors.values[index]
	}

	// fast-forward through key values if we haven't already
	_ = f.keyValues.rest()
	for tensor, ok := f.tensors.next(); ok; tensor, ok = f.tensors.next() {
		if tensor.Name == name {
			return tensor
		}
	}

	return TensorInfo{}
}

func (f *File) NumTensors() int {
	return int(f.tensors.count)
}

func (f *File) TensorInfos() iter.Seq2[int, TensorInfo] {
	// fast forward through key values if we haven't already
	f.keyValues.rest()
	return f.tensors.All()
}

func (f *File) TensorReader(name string) (TensorInfo, io.Reader, error) {
	t := f.TensorInfo(name)
	if t.NumBytes() == 0 {
		return TensorInfo{}, nil, fmt.Errorf("tensor %s not found", name)
	}

	// fast forward through tensor info if we haven't already
	_ = f.tensors.rest()
	return t, io.NewSectionReader(f.file, f.offset+int64(t.Offset), t.NumBytes()), nil
}
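The new fs/gguf reader decodes metadata lazily: key-values and tensor infos are pulled on demand and memoized, and tensor payloads are served straight from the underlying file through an io.SectionReader. A usage sketch under those assumptions; the model path is hypothetical:

package main

import (
	"fmt"
	"io"

	"github.com/ollama/ollama/fs/gguf"
)

func main() {
	f, err := gguf.Open("model.gguf") // hypothetical path
	if err != nil {
		panic(err)
	}
	defer f.Close()

	// Bare keys are resolved against general.architecture.
	fmt.Println(f.KeyValue("general.architecture").String())
	fmt.Println(f.KeyValue("block_count").Uint())

	ti, r, err := f.TensorReader("token_embd.weight")
	if err != nil {
		panic(err)
	}
	fmt.Println(ti.Name, ti.Shape, ti.NumBytes())
	io.Copy(io.Discard, r) // stream the tensor bytes
}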
fs/gguf/gguf_test.go (new file)
@@ -0,0 +1,249 @@
package gguf_test

import (
	"bytes"
	"os"
	"strconv"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/google/go-cmp/cmp/cmpopts"
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/fs/gguf"
)

func createBinFile(tb testing.TB) string {
	tb.Helper()
	f, err := os.CreateTemp(tb.TempDir(), "")
	if err != nil {
		tb.Fatal(err)
	}
	defer f.Close()

	kv := ggml.KV{
		"general.architecture":                   "llama",
		"llama.block_count":                      uint32(8),
		"llama.embedding_length":                 uint32(3),
		"llama.attention.head_count":             uint32(2),
		"llama.attention.head_count_kv":          uint32(2),
		"llama.attention.key_length":             uint32(3),
		"llama.rope.dimension_count":             uint32(4),
		"llama.rope.freq_base":                   float32(10000.0),
		"llama.rope.freq_scale":                  float32(1.0),
		"llama.attention.layer_norm_rms_epsilon": float32(1e-6),
		"tokenizer.ggml.eos_token_id":            uint32(0),
		"tokenizer.ggml.eos_token_ids":           []int32{1, 2, 3},
		"tokenizer.ggml.tokens":                  []string{"hello", "world"},
		"tokenizer.ggml.scores":                  []float32{0, 1},
	}

	tensors := []*ggml.Tensor{
		{
			Name:     "token_embd.weight",
			Kind:     0,
			Shape:    []uint64{2, 3},
			WriterTo: bytes.NewBuffer(make([]byte, 4*2*3)),
		},
		{
			Name:     "output.weight",
			Kind:     0,
			Shape:    []uint64{3, 2},
			WriterTo: bytes.NewBuffer(make([]byte, 4*3*2)),
		},
	}

	for i := range 8 {
		tensors = append(tensors, &ggml.Tensor{
			Name:     "blk." + strconv.Itoa(i) + ".attn_q.weight",
			Kind:     0,
			Shape:    []uint64{3, 3},
			WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
		}, &ggml.Tensor{
			Name:     "blk." + strconv.Itoa(i) + ".attn_k.weight",
			Kind:     0,
			Shape:    []uint64{3, 3},
			WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
		}, &ggml.Tensor{
			Name:     "blk." + strconv.Itoa(i) + ".attn_v.weight",
			Kind:     0,
			Shape:    []uint64{3, 3},
			WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
		}, &ggml.Tensor{
			Name:     "blk." + strconv.Itoa(i) + ".attn_output.weight",
			Kind:     0,
			Shape:    []uint64{3, 3},
			WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)),
		})
	}

	if err := ggml.WriteGGUF(f, kv, tensors); err != nil {
		tb.Fatal(err)
	}

	return f.Name()
}

func TestRead(t *testing.T) {
	f, err := gguf.Open(createBinFile(t))
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

	if got := f.KeyValue("does.not.exist").Valid(); got {
		t.Errorf(`KeyValue("does.not.exist").Exists() = %v, want false`, got)
	}

	if got := f.KeyValue("general.architecture").String(); got != "llama" {
		t.Errorf(`KeyValue("general.architecture").String() = %q, want %q`, got, "llama")
	}

	if got := f.TensorInfo("token_embd.weight"); got.Name != "token_embd.weight" {
		t.Errorf(`TensorInfo("token_embd.weight").Name = %q, want %q`, got.Name, "token_embd.weight")
	} else if diff := cmp.Diff(got.Shape, []uint64{2, 3}); diff != "" {
		t.Errorf(`TensorInfo("token_embd.weight").Shape mismatch (-got +want):\n%s`, diff)
	} else if got.Type != gguf.TensorTypeF32 {
		t.Errorf(`TensorInfo("token_embd.weight").Type = %d, want %d`, got.Type, gguf.TensorTypeF32)
	}

	if got := f.KeyValue("block_count").Uint(); got != 8 {
		t.Errorf(`KeyValue("block_count").Uint() = %d, want %d`, got, 8)
	}

	if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.tokens").Strings(), []string{"hello", "world"}); diff != "" {
		t.Errorf("KeyValue(\"tokenizer.ggml.tokens\").Strings() mismatch (-got +want):\n%s", diff)
	}

	if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.scores").Floats(), []float64{0, 1}); diff != "" {
		t.Errorf("KeyValue(\"tokenizer.ggml.scores\").Ints() mismatch (-got +want):\n%s", diff)
	}

	var kvs []string
	for _, kv := range f.KeyValues() {
		if !kv.Valid() {
			t.Error("found invalid key-value pair:", kv)
		}

		kvs = append(kvs, kv.Key)
	}

	if len(kvs) != f.NumKeyValues() {
		t.Errorf("iterated key count = %d, want %d", len(kvs), f.NumKeyValues())
	}

	if diff := cmp.Diff(kvs, []string{
		"general.architecture",
		"llama.block_count",
		"llama.embedding_length",
		"llama.attention.head_count",
		"llama.attention.head_count_kv",
		"llama.attention.key_length",
		"llama.rope.dimension_count",
		"llama.rope.freq_base",
		"llama.rope.freq_scale",
		"llama.attention.layer_norm_rms_epsilon",
		"tokenizer.ggml.eos_token_id",
		"tokenizer.ggml.eos_token_ids",
		"tokenizer.ggml.tokens",
		"tokenizer.ggml.scores",
	}, cmpopts.SortSlices(strings.Compare)); diff != "" {
		t.Errorf("KeyValues() mismatch (-got +want):\n%s", diff)
	}

	var tis []string
	for _, ti := range f.TensorInfos() {
		if !ti.Valid() {
			t.Error("found invalid tensor info:", ti)
		}

		tis = append(tis, ti.Name)
	}

	if len(tis) != f.NumTensors() {
		t.Errorf("iterated tensor count = %d, want %d", len(tis), f.NumTensors())
	}

	if diff := cmp.Diff(tis, []string{
		"token_embd.weight",
		"output.weight",
		"blk.0.attn_q.weight",
		"blk.0.attn_k.weight",
		"blk.0.attn_v.weight",
		"blk.0.attn_output.weight",
		"blk.1.attn_q.weight",
		"blk.1.attn_k.weight",
		"blk.1.attn_v.weight",
		"blk.1.attn_output.weight",
		"blk.2.attn_q.weight",
		"blk.2.attn_k.weight",
		"blk.2.attn_v.weight",
		"blk.2.attn_output.weight",
		"blk.3.attn_q.weight",
		"blk.3.attn_k.weight",
		"blk.3.attn_v.weight",
		"blk.3.attn_output.weight",
		"blk.4.attn_q.weight",
		"blk.4.attn_k.weight",
		"blk.4.attn_v.weight",
		"blk.4.attn_output.weight",
		"blk.5.attn_q.weight",
		"blk.5.attn_k.weight",
		"blk.5.attn_v.weight",
		"blk.5.attn_output.weight",
		"blk.6.attn_q.weight",
		"blk.6.attn_k.weight",
		"blk.6.attn_v.weight",
		"blk.6.attn_output.weight",
		"blk.7.attn_q.weight",
		"blk.7.attn_k.weight",
		"blk.7.attn_v.weight",
		"blk.7.attn_output.weight",
	}, cmpopts.SortSlices(strings.Compare)); diff != "" {
		t.Errorf("TensorInfos() mismatch (-got +want):\n%s", diff)
	}

	ti, r, err := f.TensorReader("output.weight")
	if err != nil {
		t.Fatalf(`TensorReader("output.weight") error: %v`, err)
	}

	if ti.Name != "output.weight" {
		t.Errorf(`TensorReader("output.weight").Name = %q, want %q`, ti.Name, "output.weight")
	} else if diff := cmp.Diff(ti.Shape, []uint64{3, 2}); diff != "" {
		t.Errorf(`TensorReader("output.weight").Shape mismatch (-got +want):\n%s`, diff)
	} else if ti.Type != gguf.TensorTypeF32 {
		t.Errorf(`TensorReader("output.weight").Type = %d, want %d`, ti.Type, gguf.TensorTypeF32)
	}

	var b bytes.Buffer
	if _, err := b.ReadFrom(r); err != nil {
		t.Fatalf(`ReadFrom TensorReader("output.weight") error: %v`, err)
	}

	if b.Len() != int(ti.NumBytes()) {
		t.Errorf(`ReadFrom TensorReader("output.weight") length = %d, want %d`, b.Len(), ti.NumBytes())
	}
}

func BenchmarkRead(b *testing.B) {
	b.ReportAllocs()

	p := createBinFile(b)
	for b.Loop() {
		f, err := gguf.Open(p)
		if err != nil {
			b.Fatal(err)
		}

		if got := f.KeyValue("general.architecture").String(); got != "llama" {
			b.Errorf("got = %q, want %q", got, "llama")
		}

		// Iterate through some tensors
		for range f.TensorInfos() {
		}

		f.Close()
	}
}
fs/gguf/keyvalue.go (new file)
@@ -0,0 +1,90 @@
package gguf

import (
	"reflect"
	"slices"
)

type KeyValue struct {
	Key string
	Value
}

func (kv KeyValue) Valid() bool {
	return kv.Key != "" && kv.Value.value != nil
}

type Value struct {
	value any
}

func value[T any](v Value, kinds ...reflect.Kind) (t T) {
	vv := reflect.ValueOf(v.value)
	if slices.Contains(kinds, vv.Kind()) {
		t = vv.Convert(reflect.TypeOf(t)).Interface().(T)
	}
	return
}

func values[T any](v Value, kinds ...reflect.Kind) (ts []T) {
	switch vv := reflect.ValueOf(v.value); vv.Kind() {
	case reflect.Slice:
		if slices.Contains(kinds, vv.Type().Elem().Kind()) {
			ts = make([]T, vv.Len())
			for i := range vv.Len() {
				ts[i] = vv.Index(i).Convert(reflect.TypeOf(ts[i])).Interface().(T)
			}
		}
	}
	return
}

// Int returns Value as a signed integer. If it is not a signed integer, it returns 0.
func (v Value) Int() int64 {
	return value[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
}

// Ints returns Value as a signed integer slice. If it is not a signed integer slice, it returns nil.
func (v Value) Ints() (i64s []int64) {
	return values[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64)
}

// Uint converts an unsigned integer value to uint64. If the value is not a unsigned integer, it returns 0.
func (v Value) Uint() uint64 {
	return value[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
}

// Uints returns Value as a unsigned integer slice. If it is not a unsigned integer slice, it returns nil.
func (v Value) Uints() (u64s []uint64) {
	return values[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64)
}

// Float returns Value as a float. If it is not a float, it returns 0.
func (v Value) Float() float64 {
	return value[float64](v, reflect.Float32, reflect.Float64)
}

// Floats returns Value as a float slice. If it is not a float slice, it returns nil.
func (v Value) Floats() (f64s []float64) {
	return values[float64](v, reflect.Float32, reflect.Float64)
}

// Bool returns Value as a boolean. If it is not a boolean, it returns false.
func (v Value) Bool() bool {
	return value[bool](v, reflect.Bool)
}

// Bools returns Value as a boolean slice. If it is not a boolean slice, it returns nil.
func (v Value) Bools() (bools []bool) {
	return values[bool](v, reflect.Bool)
}

// String returns Value as a string. If it is not a string, it returns an empty string.
func (v Value) String() string {
	return value[string](v, reflect.String)
}

// Strings returns Value as a string slice. If it is not a string slice, it returns nil.
func (v Value) Strings() (strings []string) {
	return values[string](v, reflect.String)
}
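The reflection-based accessors coerce any value of a matching kind and fall back to zero values instead of returning errors, so lookups never fail loudly. A small sketch of that behavior; the file path is hypothetical, and block_count is assumed to be stored as a uint32 as in the test fixture above:

package main

import (
	"fmt"

	"github.com/ollama/ollama/fs/gguf"
)

func main() {
	f, err := gguf.Open("model.gguf") // hypothetical path
	if err != nil {
		panic(err)
	}
	defer f.Close()

	kv := f.KeyValue("block_count")
	fmt.Println(kv.Uint()) // 8: uint32 coerces to uint64
	fmt.Println(kv.Int())  // 0: kind mismatch falls back to the zero value
}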
fs/gguf/keyvalue_test.go (new file)
@@ -0,0 +1,208 @@
package gguf

import (
	"testing"

	"github.com/google/go-cmp/cmp"
)

func split(name string, values map[string][]any) (matched []any, unmatched []any) {
	for key, value := range values {
		if key == name {
			matched = value
		} else {
			unmatched = append(unmatched, value...)
		}
	}
	return
}

func TestValue(t *testing.T) {
	values := map[string][]any{
		"int64":   {int(42), int8(42), int16(42), int32(42), int64(42)},
		"uint64":  {uint(42), uint8(42), uint16(42), uint32(42), uint64(42)},
		"float64": {float32(42), float64(42)},
		"string":  {"42", "hello"},
		"bool":    {true, false},
	}

	t.Run("int64", func(t *testing.T) {
		matched, unmatched := split("int64", values)
		for _, v := range matched {
			kv := KeyValue{"key", Value{v}}
			if i64 := kv.Int(); i64 != 42 {
				t.Errorf("expected 42, got %d", i64)
			}
		}

		for _, v := range unmatched {
			kv := KeyValue{"key", Value{v}}
			if i64 := kv.Int(); i64 != 0 {
				t.Errorf("expected 42, got %d", i64)
			}
		}
	})

	t.Run("uint64", func(t *testing.T) {
		matched, unmatched := split("uint64", values)
		for _, v := range matched {
			kv := KeyValue{"key", Value{v}}
			if u64 := kv.Uint(); u64 != 42 {
				t.Errorf("expected 42, got %d", u64)
			}
		}

		for _, v := range unmatched {
			kv := KeyValue{"key", Value{v}}
			if u64 := kv.Uint(); u64 != 0 {
				t.Errorf("expected 42, got %d", u64)
			}
		}
	})

	t.Run("float64", func(t *testing.T) {
		matched, unmatched := split("float64", values)
		for _, v := range matched {
			kv := KeyValue{"key", Value{v}}
			if f64 := kv.Float(); f64 != 42 {
				t.Errorf("expected 42, got %f", f64)
			}
		}

		for _, v := range unmatched {
			kv := KeyValue{"key", Value{v}}
			if f64 := kv.Float(); f64 != 0 {
				t.Errorf("expected 42, got %f", f64)
			}
		}
	})

	t.Run("string", func(t *testing.T) {
		matched, unmatched := split("string", values)
		for _, v := range matched {
			kv := KeyValue{"key", Value{v}}
			if s := kv.String(); s != v {
				t.Errorf("expected 42, got %s", s)
			}
		}

		for _, v := range unmatched {
			kv := KeyValue{"key", Value{v}}
			if s := kv.String(); s != "" {
				t.Errorf("expected 42, got %s", s)
			}
		}
	})

	t.Run("bool", func(t *testing.T) {
		matched, unmatched := split("bool", values)
		for _, v := range matched {
			kv := KeyValue{"key", Value{v}}
			if b := kv.Bool(); b != v {
				t.Errorf("expected true, got %v", b)
			}
		}

		for _, v := range unmatched {
			kv := KeyValue{"key", Value{v}}
			if b := kv.Bool(); b != false {
				t.Errorf("expected false, got %v", b)
			}
		}
	})
}

func TestValues(t *testing.T) {
	values := map[string][]any{
		"int64s":   {[]int{42}, []int8{42}, []int16{42}, []int32{42}, []int64{42}},
		"uint64s":  {[]uint{42}, []uint8{42}, []uint16{42}, []uint32{42}, []uint64{42}},
		"float64s": {[]float32{42}, []float64{42}},
		"strings":  {[]string{"42"}, []string{"hello"}},
		"bools":    {[]bool{true}, []bool{false}},
	}

	t.Run("int64s", func(t *testing.T) {
		matched, unmatched := split("int64s", values)
		for _, v := range matched {
			kv := KeyValue{"key", Value{v}}
			if diff := cmp.Diff(kv.Ints(), []int64{42}); diff != "" {
				t.Errorf("diff: %s", diff)
			}
		}

		for _, v := range unmatched {
			kv := KeyValue{"key", Value{v}}
			if i64s := kv.Ints(); i64s != nil {
				t.Errorf("expected nil, got %v", i64s)
			}
		}
	})

	t.Run("uint64s", func(t *testing.T) {
		matched, unmatched := split("uint64s", values)
		for _, v := range matched {
			kv := KeyValue{"key", Value{v}}
			if diff := cmp.Diff(kv.Uints(), []uint64{42}); diff != "" {
				t.Errorf("diff: %s", diff)
			}
		}

		for _, v := range unmatched {
			kv := KeyValue{"key", Value{v}}
			if u64s := kv.Uints(); u64s != nil {
				t.Errorf("expected nil, got %v", u64s)
			}
		}
	})

	t.Run("float64s", func(t *testing.T) {
		matched, unmatched := split("float64s", values)
		for _, v := range matched {
			kv := KeyValue{"key", Value{v}}
			if diff := cmp.Diff(kv.Floats(), []float64{42}); diff != "" {
				t.Errorf("diff: %s", diff)
			}
		}

		for _, v := range unmatched {
			kv := KeyValue{"key", Value{v}}
			if f64s := kv.Floats(); f64s != nil {
				t.Errorf("expected nil, got %v", f64s)
			}
		}
	})

	t.Run("strings", func(t *testing.T) {
		matched, unmatched := split("strings", values)
		for _, v := range matched {
			kv := KeyValue{"key", Value{v}}
			if diff := cmp.Diff(kv.Strings(), v); diff != "" {
				t.Errorf("diff: %s", diff)
			}
		}

		for _, v := range unmatched {
			kv := KeyValue{"key", Value{v}}
			if s := kv.Strings(); s != nil {
				t.Errorf("expected nil, got %v", s)
			}
		}
	})

	t.Run("bools", func(t *testing.T) {
		matched, unmatched := split("bools", values)
		for _, v := range matched {
			kv := KeyValue{"key", Value{v}}
			if diff := cmp.Diff(kv.Bools(), v); diff != "" {
				t.Errorf("diff: %s", diff)
			}
		}

		for _, v := range unmatched {
			kv := KeyValue{"key", Value{v}}
			if b := kv.Bools(); b != nil {
				t.Errorf("expected nil, got %v", b)
			}
		}
	})
}
fs/gguf/lazy.go (new file)
@@ -0,0 +1,89 @@
package gguf

import (
	"encoding/binary"
	"iter"
	"log/slog"
)

type lazy[T any] struct {
	count  uint64
	next   func() (T, bool)
	stop   func()
	values []T

	// successFunc is called when all values have been successfully read.
	successFunc func() error
}

func newLazy[T any](f *File, fn func() (T, error)) (*lazy[T], error) {
	it := lazy[T]{}
	if err := binary.Read(f.reader, binary.LittleEndian, &it.count); err != nil {
		return nil, err
	}

	it.values = make([]T, 0)
	it.next, it.stop = iter.Pull(func(yield func(T) bool) {
		for i := range it.count {
			t, err := fn()
			if err != nil {
				slog.Error("error reading tensor", "index", i, "error", err)
				return
			}

			it.values = append(it.values, t)
			if !yield(t) {
				break
			}
		}

		if it.successFunc != nil {
			it.successFunc()
		}
	})

	return &it, nil
}

func (g *lazy[T]) Values() iter.Seq[T] {
	return func(yield func(T) bool) {
		for _, v := range g.All() {
			if !yield(v) {
				break
			}
		}
	}
}

func (g *lazy[T]) All() iter.Seq2[int, T] {
	return func(yield func(int, T) bool) {
		for i := range int(g.count) {
			if i < len(g.values) {
				if !yield(i, g.values[i]) {
					break
				}
			} else {
				t, ok := g.next()
				if !ok {
					break
				}

				if !yield(i, t) {
					break
				}
			}
		}
	}
}

func (g *lazy[T]) rest() (collected bool) {
	for {
		_, ok := g.next()
		collected = collected || ok
		if !ok {
			break
		}
	}

	return collected
}
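lazy wraps iter.Pull so decoded entries are memoized as they stream in: All() replays the cached prefix before pulling more, and rest() drains whatever remains. A stripped-down sketch of the same pull-and-memoize pattern, independent of the gguf types:

package main

import (
	"fmt"
	"iter"
)

func main() {
	reads := 0
	next, stop := iter.Pull(func(yield func(int) bool) {
		for {
			reads++ // counts actual decodes
			if !yield(reads) {
				return
			}
		}
	})
	defer stop()

	var cache []int
	get := func(i int) int {
		// pull (and memoize) only as far as needed
		for len(cache) <= i {
			v, _ := next()
			cache = append(cache, v)
		}
		return cache[i]
	}

	fmt.Println(get(2)) // 3: pulls three values
	fmt.Println(get(1)) // 2: served from the cache
	fmt.Println(reads)  // 3: nothing was decoded twice
}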
fs/gguf/reader.go (new file)
@@ -0,0 +1,23 @@
package gguf

import (
	"bufio"
	"io"
)

type bufferedReader struct {
	offset int64
	*bufio.Reader
}

func newBufferedReader(rs io.ReadSeeker, size int) *bufferedReader {
	return &bufferedReader{
		Reader: bufio.NewReaderSize(rs, size),
	}
}

func (rs *bufferedReader) Read(p []byte) (n int, err error) {
	n, err = rs.Reader.Read(p)
	rs.offset += int64(n)
	return n, err
}
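The reader tracks its own offset because bufio.Reader reads ahead, so the underlying file's seek position overshoots the logical decode position; the successFunc in gguf.go depends on this logical offset to locate the aligned tensor-data section. The same pattern in a tiny standalone form:

package main

import (
	"fmt"
	"strings"
)

// countingReader counts bytes actually consumed by the caller,
// independent of any read-ahead a wrapped reader might perform.
type countingReader struct {
	offset int64
	r      *strings.Reader
}

func (c *countingReader) Read(p []byte) (int, error) {
	n, err := c.r.Read(p)
	c.offset += int64(n)
	return n, err
}

func main() {
	cr := &countingReader{r: strings.NewReader("GGUF\x03\x00\x00\x00")}
	buf := make([]byte, 4)
	cr.Read(buf)
	fmt.Println(string(buf), cr.offset) // GGUF 4
}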
fs/gguf/tensor.go (new file, 288 lines)
@@ -0,0 +1,288 @@
package gguf

import (
	"log/slog"
	"strings"
)

type TensorInfo struct {
	Name   string
	Offset uint64
	Shape  []uint64
	Type   TensorType
}

func (ti TensorInfo) Valid() bool {
	return ti.Name != "" && ti.NumBytes() > 0
}

func (ti TensorInfo) NumValues() int64 {
	var numItems int64 = 1
	for _, dim := range ti.Shape {
		numItems *= int64(dim)
	}
	return numItems
}

// NumBytes returns the number of bytes in the tensor.
func (ti TensorInfo) NumBytes() int64 {
	return int64(float64(ti.NumValues()) * ti.Type.NumBytes())
}

func (ti TensorInfo) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("name", ti.Name),
		slog.Int64("offset", int64(ti.Offset)),
		slog.Any("shape", ti.Shape),
		slog.Int64("num_values", ti.NumValues()),
		slog.Int64("num_bytes", ti.NumBytes()),
		slog.Any("type", ti.Type),
	)
}

type TensorType uint32

const (
	TensorTypeF32 TensorType = iota
	TensorTypeF16
	TensorTypeQ4_0
	TensorTypeQ4_1

	// unexported // unused in gguf
	tensorTypeQ4_2
	tensorTypeQ4_3

	TensorTypeQ5_0
	TensorTypeQ5_1
	TensorTypeQ8_0
	TensorTypeQ8_1
	TensorTypeQ2_K
	TensorTypeQ3_K
	TensorTypeQ4_K
	TensorTypeQ5_K
	TensorTypeQ6_K
	TensorTypeQ8_K

	// unexported // unquantizable by ollama
	tensorTypeIQ2_XXS
	tensorTypeIQ2_XS
	tensorTypeIQ3_XXS
	tensorTypeIQ1_S
	tensorTypeIQ4_NL
	tensorTypeIQ3_S
	tensorTypeIQ2_S
	tensorTypeIQ4_XS

	TensorTypeI8
	TensorTypeI16
	TensorTypeI32
	TensorTypeI64
	TensorTypeF64

	// unexported // unquantizable by ollama
	tensorTypeIQ1_M

	TensorTypeBF16

	// unexported // unused in gguf
	tensorTypeQ4_0_4_4
	tensorTypeQ4_0_4_8
	tensorTypeQ4_0_8_8

	// unexported // unquantizable by ollama
	tensorTypeTQ1_0
	tensorTypeTQ2_0

	// unexported // unused in gguf
	tensorTypeIQ4_NL_4_4
	tensorTypeIQ4_NL_4_8
	tensorTypeIQ4_NL_8_8
)

func (tt TensorType) NumBytes() float64 {
	return float64(tt.typeSize()) / float64(tt.blockSize())
}

func (tt TensorType) typeSize() int64 {
	switch tt {
	case TensorTypeF32:
		return 4
	case TensorTypeF16:
		return 2
	case TensorTypeQ4_0:
		return 2 + tt.blockSize()/2
	case TensorTypeQ4_1:
		return 2 + 2 + tt.blockSize()/2
	case TensorTypeQ5_0:
		return 2 + 4 + tt.blockSize()/2
	case TensorTypeQ5_1:
		return 2 + 2 + 4 + tt.blockSize()/2
	case TensorTypeQ8_0:
		return 2 + tt.blockSize()
	case TensorTypeQ8_1:
		return 2 + 2 + tt.blockSize()
	case TensorTypeQ2_K:
		return tt.blockSize()/16 + tt.blockSize()/4 + 2 + 2
	case TensorTypeQ3_K:
		return tt.blockSize()/8 + tt.blockSize()/4 + 12 + 2
	case TensorTypeQ4_K:
		return 2 + 2 + 12 + tt.blockSize()/2
	case TensorTypeQ5_K:
		return 2 + 2 + 12 + tt.blockSize()/8 + tt.blockSize()/2
	case TensorTypeQ6_K:
		return tt.blockSize()/2 + tt.blockSize()/4 + tt.blockSize()/16 + 2
	case TensorTypeQ8_K:
		return 4 + tt.blockSize() + 2*tt.blockSize()/16
	case tensorTypeIQ2_XXS:
		return 2 + 2*tt.blockSize()/8
	case tensorTypeIQ2_XS:
		return 2 + 2*tt.blockSize()/8 + tt.blockSize()/32
	case tensorTypeIQ3_XXS:
		return 2 + tt.blockSize()/4 + tt.blockSize()/8
	case tensorTypeIQ1_S:
		return 2 + tt.blockSize()/8 + tt.blockSize()/16
	case tensorTypeIQ4_NL:
		return 2 + tt.blockSize()/2
	case tensorTypeIQ3_S:
		return 2 + tt.blockSize()/4 + tt.blockSize()/8 + tt.blockSize()/32 + 4
	case tensorTypeIQ2_S:
		return 2 + tt.blockSize()/4 + tt.blockSize()/16
	case tensorTypeIQ4_XS:
		return 2 + 2 + tt.blockSize()/2 + tt.blockSize()/64
	case TensorTypeI8:
		return 1
	case TensorTypeI16:
		return 2
	case TensorTypeI32:
		return 4
	case TensorTypeI64:
		return 8
	case TensorTypeF64:
		return 8
	case tensorTypeIQ1_M:
		return tt.blockSize()/8 + tt.blockSize()/16 + tt.blockSize()/32
	case TensorTypeBF16:
		return 2
	default:
		return 0
	}
}

func (tt TensorType) blockSize() int64 {
	switch tt {
	case TensorTypeF32,
		TensorTypeF16,
		TensorTypeI8,
		TensorTypeI16,
		TensorTypeI32,
		TensorTypeI64,
		TensorTypeF64,
		TensorTypeBF16:
		return 1
	case TensorTypeQ4_0,
		TensorTypeQ4_1,
		TensorTypeQ5_0,
		TensorTypeQ5_1,
		TensorTypeQ8_0,
		TensorTypeQ8_1,
		tensorTypeIQ4_NL:
		return 32
	default:
		return 256
	}
}

func (tt TensorType) String() string {
	switch tt {
	case TensorTypeF32:
		return "f32"
	case TensorTypeF16:
		return "f16"
	case TensorTypeQ4_0:
		return "q4_0"
	case TensorTypeQ4_1:
		return "q4_1"
	case tensorTypeQ4_2:
		return "q4_2"
	case tensorTypeQ4_3:
		return "q4_3"
	case TensorTypeQ5_0:
		return "q5_0"
	case TensorTypeQ5_1:
		return "q5_1"
	case TensorTypeQ8_0:
		return "q8_0"
	case TensorTypeQ8_1:
		return "q8_1"
	case TensorTypeQ2_K:
		return "q2_k"
	case TensorTypeQ3_K:
		return "q3_k"
	case TensorTypeQ4_K:
		return "q4_k"
	case TensorTypeQ5_K:
		return "q5_k"
	case TensorTypeQ6_K:
		return "q6_k"
	case TensorTypeQ8_K:
		return "q8_k"
	case tensorTypeIQ2_XXS:
		return "iq2_xxs"
	case tensorTypeIQ2_XS:
		return "iq2_xs"
	case tensorTypeIQ3_XXS:
		return "iq3_xxs"
	case tensorTypeIQ1_S:
		return "iq1_s"
	case tensorTypeIQ4_NL:
		return "iq4_nl"
	case tensorTypeIQ3_S:
		return "iq3_s"
	case tensorTypeIQ2_S:
		return "iq2_s"
	case tensorTypeIQ4_XS:
		return "iq4_xs"
	case TensorTypeI8:
		return "i8"
	case TensorTypeI16:
		return "i16"
	case TensorTypeI32:
		return "i32"
	case TensorTypeI64:
		return "i64"
	case TensorTypeF64:
		return "f64"
	case tensorTypeIQ1_M:
		return "iq1_m"
	case TensorTypeBF16:
		return "bf16"
	case tensorTypeQ4_0_4_4:
		return "q4_0_4_4"
	case tensorTypeQ4_0_4_8:
		return "q4_0_4_8"
	case tensorTypeQ4_0_8_8:
		return "q4_0_8_8"
	case tensorTypeTQ1_0:
		return "tq1_0"
	case tensorTypeTQ2_0:
		return "tq2_0"
	case tensorTypeIQ4_NL_4_4:
		return "iq4_nl_4_4"
	case tensorTypeIQ4_NL_4_8:
		return "iq4_nl_4_8"
	case tensorTypeIQ4_NL_8_8:
		return "iq4_nl_8_8"
	default:
		return "unknown"
	}
}

func (tt TensorType) LogValue() slog.Value {
	return slog.GroupValue(
		slog.Uint64("value", uint64(tt)),
		slog.String("name", strings.ToUpper(tt.String())),
		slog.Int64("size", tt.typeSize()),
		slog.Int64("block_size", tt.blockSize()),
		slog.Float64("num_bytes", tt.NumBytes()),
	)
}
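As a reading aid for the size math above, here is a minimal sketch (not part of the diff; the tensor name and shape are made up) of how typeSize and blockSize compose. Q4_K packs each 256-value block into 2 + 2 + 12 + 256/2 = 144 bytes, i.e. 0.5625 bytes per value, so a 4096x4096 tensor occupies 16,777,216 * 0.5625 = 9,437,184 bytes:

package main

import (
	"fmt"

	"github.com/ollama/ollama/fs/gguf"
)

func main() {
	// Hypothetical tensor, for illustration only.
	ti := gguf.TensorInfo{
		Name:  "blk.0.attn_q.weight",
		Shape: []uint64{4096, 4096},
		Type:  gguf.TensorTypeQ4_K,
	}
	fmt.Println(ti.NumBytes()) // 9437184
}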
go.mod (17 changed lines)
@@ -11,7 +11,7 @@ require (
 	github.com/spf13/cobra v1.7.0
 	github.com/stretchr/testify v1.9.0
 	github.com/x448/float16 v0.8.4
-	golang.org/x/sync v0.11.0
+	golang.org/x/sync v0.12.0
 )
 
 require (
@@ -19,11 +19,12 @@
 	github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
 	github.com/dlclark/regexp2 v1.11.4
 	github.com/emirpasic/gods/v2 v2.0.0-alpha
-	github.com/google/go-cmp v0.6.0
+	github.com/google/go-cmp v0.7.0
 	github.com/mattn/go-runewidth v0.0.14
 	github.com/nlpodyssey/gopickle v0.3.0
 	github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
 	golang.org/x/image v0.22.0
+	golang.org/x/tools v0.30.0
 	gonum.org/v1/gonum v0.15.0
 )
@@ -69,12 +70,12 @@ require (
 	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 	github.com/ugorji/go/codec v1.2.12 // indirect
 	golang.org/x/arch v0.8.0 // indirect
-	golang.org/x/crypto v0.33.0
-	golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
-	golang.org/x/net v0.35.0 // indirect
-	golang.org/x/sys v0.30.0
-	golang.org/x/term v0.29.0
-	golang.org/x/text v0.22.0
+	golang.org/x/crypto v0.36.0
+	golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa // indirect
+	golang.org/x/net v0.38.0 // indirect
+	golang.org/x/sys v0.31.0
+	golang.org/x/term v0.30.0
+	golang.org/x/text v0.23.0
 	google.golang.org/protobuf v1.34.1
 	gopkg.in/yaml.v3 v3.0.1 // indirect
 )
go.sum (30 changed lines)
@@ -112,8 +112,8 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
 github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
-github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
-github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
+github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
+github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
@@ -214,8 +214,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
 golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
-golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
+golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
+golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
 golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -257,8 +257,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
 golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
 golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
 golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
-golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
-golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
+golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
+golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
 golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
 golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -268,8 +268,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
 golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
-golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
-golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
+golang.org/x/sync v0.12.0 h1:MHc5BpPuC30uJk597Ri8TV3CNZcTLu6B6z4lJy+g6Jw=
+golang.org/x/sync v0.12.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
 golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -285,17 +285,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
 golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
-golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
+golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
+golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
-golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
-golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
+golang.org/x/term v0.30.0 h1:PQ39fJZ+mfadBm0y5WlL4vlM7Sx1Hgf13sMIY2+QS9Y=
+golang.org/x/term v0.30.0/go.mod h1:NYYFdzHoI5wRh/h5tDMdMqCqPJZEuNqVR5xJLd/n67g=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
 golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
-golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
+golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
+golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
 golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -309,6 +309,8 @@ golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapK
 golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
 golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
 golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
+golang.org/x/tools v0.30.0 h1:BgcpHewrV5AUp2G9MebG4XPFI1E2W41zU1SaqVA9vJY=
+golang.org/x/tools v0.30.0/go.mod h1:c347cR/OJfw5TI+GfX7RUPNMdDRRbjvYTS0jPyvsVtY=
 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
harmony/harmonyparser.go (new file, 463 lines)
@@ -0,0 +1,463 @@
package harmony

import (
	"fmt"
	"log/slog"
	"strings"
	"unicode"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/logutil"
)

type harmonyParserState int

const (
	harmonyParserState_LookingForMessageStart harmonyParserState = iota
	harmonyParserState_ParsingHeader
	harmonyParserState_ParsingContent
)

func (s harmonyParserState) String() string {
	switch s {
	// we're looking for the message start tag
	case harmonyParserState_LookingForMessageStart:
		return "LookingForMessageStart"
	case harmonyParserState_ParsingHeader:
		return "ParsingHeader"
	case harmonyParserState_ParsingContent:
		return "ParsingContent"
	default:
		return "Unknown"
	}
}

type HarmonyParser struct {
	state           harmonyParserState
	MessageStartTag string
	MessageEndTag   string
	HeaderEndTag    string
	acc             strings.Builder
	lifetimeAcc     strings.Builder
}

type HarmonyEvent interface {
	isHarmonyEvent()
}

type HarmonyEventMessageStart struct{}

func (HarmonyEventMessageStart) isHarmonyEvent() {}

type HarmonyEventHeaderComplete struct {
	Header HarmonyHeader
}

func (HarmonyEventHeaderComplete) isHarmonyEvent() {}

type HarmonyEventContentEmitted struct {
	Content string
}

func (HarmonyEventContentEmitted) isHarmonyEvent() {}

type HarmonyEventMessageEnd struct{}

func (HarmonyEventMessageEnd) isHarmonyEvent() {}

type HarmonyHeader struct {
	Role      string
	Channel   string
	Recipient string
}

func (s *HarmonyParser) AddImplicitStart() {
	s.acc.WriteString("<|start|>assistant")
}

func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
	if lastMessage != nil && lastMessage.Role == "assistant" {
		// handle prefilling conditions
		if lastMessage.Content != "" {
			s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
			return
		} else if lastMessage.Thinking != "" {
			s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
			return
		}
	}
	s.AddImplicitStart()
}

func (s *HarmonyParser) AddContent(content string) []HarmonyEvent {
	s.lifetimeAcc.WriteString(content)
	s.acc.WriteString(content)

	var events []HarmonyEvent

	keepLooping := true
	// we loop because we might pass through multiple parsing states in a single
	// call to addContent, and we want to make sure callers don't have to wait for
	// data that's already unambiguous
	for keepLooping {
		var newEvents []HarmonyEvent
		newEvents, keepLooping = eat(s)
		events = append(events, newEvents...)
	}

	return events
}

// the additional bool return is true iff we should continue eating
func eat(s *HarmonyParser) ([]HarmonyEvent, bool) {
	switch s.state {
	case harmonyParserState_LookingForMessageStart:
		// does the acc contain the message start tag?
		if strings.Contains(s.acc.String(), s.MessageStartTag) {
			// split the acc into the message start tag and the rest
			split := strings.SplitN(s.acc.String(), s.MessageStartTag, 2)
			before := split[0]
			if before != "" {
				slog.Warn("harmony parser: found message start tag in the middle of the content", "content", s.acc.String())
			}
			after := split[1]
			s.acc.Reset()
			s.acc.WriteString(after)
			s.state = harmonyParserState_ParsingHeader
			return []HarmonyEvent{HarmonyEventMessageStart{}}, true
		}

		// no match, so we keep accumulating
		return nil, false
	case harmonyParserState_ParsingHeader:
		if strings.Contains(s.acc.String(), s.HeaderEndTag) {
			split := strings.SplitN(s.acc.String(), s.HeaderEndTag, 2)
			header := split[0]
			after := split[1]
			s.acc.Reset()
			s.acc.WriteString(after)
			s.state = harmonyParserState_ParsingContent
			return []HarmonyEvent{HarmonyEventHeaderComplete{Header: s.parseHeader(header)}}, true
		}
		return nil, false
	case harmonyParserState_ParsingContent:
		if strings.Contains(s.acc.String(), s.MessageEndTag) {
			// if we already have the message end tag, we can emit the content up to it
			split := strings.SplitN(s.acc.String(), s.MessageEndTag, 2)
			content := split[0]
			after := split[1]
			s.acc.Reset()
			s.acc.WriteString(after)
			s.state = harmonyParserState_LookingForMessageStart
			events := []HarmonyEvent{}
			if content != "" {
				events = append(events, HarmonyEventContentEmitted{Content: content})
			}
			events = append(events, HarmonyEventMessageEnd{})
			return events, true
		} else if overlapLen := overlap(s.acc.String(), s.MessageEndTag); overlapLen > 0 {
			// if our suffix contains the start of the message end tag, we can emit
			// the content up to the start of the message end tag
			content := s.acc.String()[:len(s.acc.String())-overlapLen]
			remaining := s.acc.String()[len(s.acc.String())-overlapLen:]
			s.acc.Reset()
			s.acc.WriteString(remaining)
			// emit the content we know isn't part of the message end tag, and keep
			// accumulating to disambiguate the rest
			if content == "" {
				return nil, false
			}
			return []HarmonyEvent{HarmonyEventContentEmitted{Content: content}}, false
		} else {
			// no end tag, so it's still normal content that we can immediately emit
			content := s.acc.String()
			if content == "" {
				return nil, false
			}
			s.acc.Reset()
			return []HarmonyEvent{HarmonyEventContentEmitted{Content: content}}, false
		}
	}

	return nil, false
}

func (s *HarmonyParser) parseHeader(raw string) HarmonyHeader {
	harmonyHeader := HarmonyHeader{}

	// if `<|constrain|>` is present, ensure it has a space before it so it gets
	// parsed as a separate token, even if the model didn't include the space
	if strings.Contains(raw, "<|constrain|>") {
		raw = strings.Replace(raw, "<|constrain|>", " <|constrain|>", 1)
		raw = strings.TrimSpace(raw)
	}

	// look for the optional channel tag, which is `<|channel|>` followed by the
	// channel name, all without any whitespace
	channelIndex := strings.Index(raw, "<|channel|>")
	if channelIndex != -1 {
		before := raw[:channelIndex]
		after := raw[channelIndex+len("<|channel|>"):]
		// the channel name is `after` all the way up to the first (if any) whitespace character
		idx := strings.IndexFunc(after, func(r rune) bool {
			return unicode.IsSpace(r)
		})
		if idx == -1 {
			idx = len(after)
		}
		harmonyHeader.Channel = after[:idx]
		after = after[idx:]
		// now we remove the channel tag from the raw string to further process
		raw = before + after
		raw = strings.TrimSpace(raw)
	}

	// split the header into whitespace-separated tokens
	tokens := strings.Fields(raw)

	// the first token is treated as the role
	if len(tokens) == 0 {
		slog.Error("harmony parser: missing role in header", "header", raw)
		return harmonyHeader
	}
	role := tokens[0]
	tokens = tokens[1:]
	// special case: if role starts with to= then it's a tool call
	if strings.HasPrefix(role, "to=") {
		harmonyHeader.Recipient = role[3:]
		harmonyHeader.Role = "tool"
	} else {
		harmonyHeader.Role = role
	}

	// the recipient (if any) can be specified before or after the channel tag, so
	// we check it at the end once we've already parsed the channel and role
	if harmonyHeader.Recipient == "" && len(tokens) > 0 && strings.HasPrefix(tokens[0], "to=") {
		harmonyHeader.Recipient = tokens[0][3:]
	}

	return harmonyHeader
}

// longest overlap between suffix of s and prefix of delim
func overlap(s, delim string) int {
	max := min(len(delim), len(s))
	for i := max; i > 0; i-- {
		if strings.HasSuffix(s, delim[:i]) {
			return i
		}
	}
	return 0
}

// harmonyMessageState represents the current state of message processing
type harmonyMessageState int

const (
	harmonyMessageState_Normal harmonyMessageState = iota
	harmonyMessageState_Thinking
	harmonyMessageState_ToolCalling
)

// HarmonyMessageHandler processes harmony events and accumulates content appropriately.
// This is a higher level interface that maps harmony concepts into ollama concepts
type HarmonyMessageHandler struct {
	state           harmonyMessageState
	HarmonyParser   *HarmonyParser
	FunctionNameMap *FunctionNameMap
}

// NewHarmonyMessageHandler creates a new message handler
func NewHarmonyMessageHandler() *HarmonyMessageHandler {
	return &HarmonyMessageHandler{
		state: harmonyMessageState_Normal,
		HarmonyParser: &HarmonyParser{
			MessageStartTag: "<|start|>",
			MessageEndTag:   "<|end|>",
			HeaderEndTag:    "<|message|>",
		},
		FunctionNameMap: NewFunctionNameMap(),
	}
}

// AddContent processes the content and returns the content, thinking, and tool content.
// content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser
func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) {
	contentSb := strings.Builder{}
	thinkingSb := strings.Builder{}
	toolContentSb := strings.Builder{}

	events := h.HarmonyParser.AddContent(content)
	for _, event := range events {
		switch event := event.(type) {
		case HarmonyEventHeaderComplete:
			logutil.Trace("harmony event header complete", "header", event.Header)
			switch event.Header.Channel {
			case "analysis":
				if event.Header.Recipient != "" {
					h.state = harmonyMessageState_ToolCalling
					// event.Header.Recipient is the tool name, something like
					// "browser.search" for a built-in, or "functions.calc" for a
					// custom one
					toolParser.SetToolName(event.Header.Recipient)
				} else {
					h.state = harmonyMessageState_Thinking
				}
			case "commentary":
				if event.Header.Recipient != "" {
					h.state = harmonyMessageState_ToolCalling
					toolParser.SetToolName(event.Header.Recipient)
				} else {
					h.state = harmonyMessageState_Normal
				}
			case "final":
				h.state = harmonyMessageState_Normal
			}
		case HarmonyEventContentEmitted:
			logutil.Trace("harmony event content", "content", event.Content, "state", h.state)
			if h.state == harmonyMessageState_Normal {
				contentSb.WriteString(event.Content)
			} else if h.state == harmonyMessageState_Thinking {
				thinkingSb.WriteString(event.Content)
			} else if h.state == harmonyMessageState_ToolCalling {
				toolContentSb.WriteString(event.Content)
			}
		case HarmonyEventMessageEnd:
			h.state = harmonyMessageState_Normal
		}
	}
	return contentSb.String(), thinkingSb.String(), toolContentSb.String()
}

func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator {
	return &HarmonyToolCallAccumulator{
		state:           harmonyToolCallState_Normal,
		currentToolName: nil,
	}
}

type harmonyToolCallState int

const (
	harmonyToolCallState_Normal harmonyToolCallState = iota
	harmonyToolCallState_ToolCalling
)

type HarmonyToolCallAccumulator struct {
	state           harmonyToolCallState
	acc             strings.Builder
	currentToolName *string
}

func (a *HarmonyToolCallAccumulator) SetToolName(toolName string) {
	a.currentToolName = &toolName
}

func (a *HarmonyToolCallAccumulator) Add(content string) {
	a.acc.WriteString(content)
}

func (a *HarmonyToolCallAccumulator) Drain() (*string, string) {
	str := a.acc.String()
	a.state = harmonyToolCallState_Normal
	a.acc.Reset()
	return a.currentToolName, str
}

func (a *HarmonyToolCallAccumulator) Content() string {
	return a.acc.String()
}

// FunctionNameMap maps a user-specified function name to a valid function
// name for harmony (which look like TypeScript identifiers). This is needed to
// transform user-specified function names, which might contain characters that
// are not allowed in TypeScript identifiers
type FunctionNameMap struct {
	userToHarmony map[string]string
	harmonyToUser map[string]string
}

func NewFunctionNameMap() *FunctionNameMap {
	return &FunctionNameMap{
		userToHarmony: make(map[string]string),
		harmonyToUser: make(map[string]string),
	}
}

func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string {
	harmonyFunctionName := m.deriveName(userFunctionName)
	m.userToHarmony[userFunctionName] = harmonyFunctionName
	m.harmonyToUser[harmonyFunctionName] = userFunctionName
	return harmonyFunctionName
}

// OriginalFromConverted looks up the reverse-mapping of a previously-converted
// user->harmony function name. To unmap reliably, the mapping must exist, as
// the conversion process is not reversible without the appropriate state
func (m *FunctionNameMap) OriginalFromConverted(harmonyFunctionName string) string {
	if userFunctionName, ok := m.harmonyToUser[harmonyFunctionName]; ok {
		return userFunctionName
	}
	slog.Warn("harmony parser: no reverse mapping found for function name", "harmonyFunctionName", harmonyFunctionName)
	// fallback to the original function name if we can't find a mapping
	return harmonyFunctionName
}

// convertToValidChars converts a user-specified function name to a valid
// TypeScript identifier.
//
// Limitations:
//
//   - This doesn't restrict reserved TypeScript keywords.
//   - We don't perform a real ID_Start/ID_Continue check, and instead use the more
//     restrictive unicode.IsLetter/unicode.IsDigit check. Unclear what kind of
//     identifiers these models were trained on, so in the end we might want to
//     convert unicode-heavy identifiers to their closest ASCII equivalents.
func (m *FunctionNameMap) convertToValidChars(userFunctionName string) string {
	mapper := func(r rune) rune {
		// first, replace certain characters with underscores
		if r == ' ' || r == '-' || r == '.' {
			return '_'
		}

		if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' {
			return r
		}

		// finally, remove any other characters
		return -1
	}
	candidate := strings.Map(mapper, userFunctionName)

	// set a default name if we end up with nothing left
	if candidate == "" {
		return "unnamed"
	}

	// if the candidate starts with a number, prepend an underscore to make it a
	// valid identifier
	if unicode.IsDigit(rune(candidate[0])) {
		candidate = "_" + candidate
	}

	return candidate
}

func (m *FunctionNameMap) deriveName(userFunctionName string) string {
	originalCandidate := m.convertToValidChars(userFunctionName)
	candidate := originalCandidate

	// Check for dupes, and if so, add a number to the end.
	// We start at 2 because if we have dupes and the first is never renamed, it
	// makes sense for them to be named, say, `f`, `f_2`, `f_3`
	count := 2
	for {
		if _, exists := m.harmonyToUser[candidate]; !exists {
			break
		}
		candidate = fmt.Sprintf("%s_%d", originalCandidate, count)
		count++
	}

	return candidate
}
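To make the event flow above concrete, here is a minimal sketch (not part of this diff; the chunk boundaries and tool arguments are invented) of driving HarmonyMessageHandler over a streamed completion and collecting a tool call:

package main

import (
	"fmt"

	"github.com/ollama/ollama/harmony"
)

func main() {
	h := harmony.NewHarmonyMessageHandler()
	h.HarmonyParser.AddImplicitStart()
	toolParser := h.CreateToolParser()

	// Chunks as a streaming runner might deliver them; boundaries are arbitrary.
	chunks := []string{
		"<|channel|>analysis<|message|>Need the weather tool.<|end|>",
		"<|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json",
		"<|message|>{\"location\":\"Paris\"}<|end|>",
	}
	for _, chunk := range chunks {
		// content and thinking come back fully parsed; tool content is
		// forwarded to the accumulator until the tool call is complete.
		content, thinking, toolContent := h.AddContent(chunk, toolParser)
		toolParser.Add(toolContent)
		fmt.Printf("content=%q thinking=%q\n", content, thinking)
	}
	name, args := toolParser.Drain()
	fmt.Println(*name, args) // functions.get_weather {"location":"Paris"}
}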
537
harmony/harmonyparser_test.go
Normal file
537
harmony/harmonyparser_test.go
Normal file
@@ -0,0 +1,537 @@
|
|||||||
|
package harmony
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHeaderParsing(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
in, wantRole, wantChannel, wantRecipient string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
in: "assistant<|channel|>analysis",
|
||||||
|
wantRole: "assistant",
|
||||||
|
wantChannel: "analysis",
|
||||||
|
wantRecipient: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
in: "assistant<|channel|>analysis to=functions.get_weather",
|
||||||
|
wantRole: "assistant",
|
||||||
|
wantChannel: "analysis",
|
||||||
|
wantRecipient: "functions.get_weather",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
in: "assistant to=functions.get_weather<|channel|>analysis",
|
||||||
|
wantRole: "assistant",
|
||||||
|
wantChannel: "analysis",
|
||||||
|
wantRecipient: "functions.get_weather",
|
||||||
|
},
|
||||||
|
// special case where the role is replaced by the recipient (matches reference code)
|
||||||
|
{
|
||||||
|
in: "to=functions.get_weather<|channel|>analysis",
|
||||||
|
wantRole: "tool",
|
||||||
|
wantChannel: "analysis",
|
||||||
|
wantRecipient: "functions.get_weather",
|
||||||
|
},
|
||||||
|
// extra token after the recipient is ignored
|
||||||
|
{
|
||||||
|
in: "assistant to=functions.get_weather abc<|channel|>analysis",
|
||||||
|
wantRole: "assistant",
|
||||||
|
wantChannel: "analysis",
|
||||||
|
wantRecipient: "functions.get_weather",
|
||||||
|
},
|
||||||
|
// with constrain tag, recipient after channel tag
|
||||||
|
{
|
||||||
|
in: "assistant<|channel|>commentary to=functions.get_weather <|constrain|>json",
|
||||||
|
wantRole: "assistant",
|
||||||
|
wantChannel: "commentary",
|
||||||
|
wantRecipient: "functions.get_weather",
|
||||||
|
},
|
||||||
|
// with constrain tag, recipient before channel tag
|
||||||
|
{
|
||||||
|
in: "assistant to=functions.get_weather<|channel|>commentary <|constrain|>json",
|
||||||
|
wantRole: "assistant",
|
||||||
|
wantChannel: "commentary",
|
||||||
|
wantRecipient: "functions.get_weather",
|
||||||
|
},
|
||||||
|
// constrain tag without space
|
||||||
|
{
|
||||||
|
in: "assistant<|channel|>commentary to=functions.get_weather<|constrain|>json",
|
||||||
|
wantRole: "assistant",
|
||||||
|
wantChannel: "commentary",
|
||||||
|
wantRecipient: "functions.get_weather",
|
||||||
|
},
|
||||||
|
// constrain tag without space, different order
|
||||||
|
{
|
||||||
|
in: "assistant to=functions.get_weather<|channel|>commentary<|constrain|>json",
|
||||||
|
wantRole: "assistant",
|
||||||
|
wantChannel: "commentary",
|
||||||
|
wantRecipient: "functions.get_weather",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for i, tt := range tests {
|
||||||
|
parser := HarmonyParser{
|
||||||
|
MessageStartTag: "<|start|>",
|
||||||
|
MessageEndTag: "<|end|>",
|
||||||
|
HeaderEndTag: "<|message|>",
|
||||||
|
}
|
||||||
|
header := parser.parseHeader(tt.in)
|
||||||
|
|
||||||
|
if header.Role != tt.wantRole {
|
||||||
|
t.Errorf("case %d: got role \"%s\", want \"%s\"", i, header.Role, tt.wantRole)
|
||||||
|
}
|
||||||
|
if header.Channel != tt.wantChannel {
|
||||||
|
t.Errorf("case %d: got channel \"%s\", want \"%s\"", i, header.Channel, tt.wantChannel)
|
||||||
|
}
|
||||||
|
if header.Recipient != tt.wantRecipient {
|
||||||
|
t.Errorf("case %d: got recipient \"%s\", want \"%s\"", i, header.Recipient, tt.wantRecipient)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHarmonyParserHeaderEvent(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
in, wantRole, wantChannel, wantRecipient string
|
||||||
|
implicitStart bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
in: "<|start|>user<|message|>What is 2 + 2?<|end|>",
|
||||||
|
wantRole: "user",
|
||||||
|
wantChannel: "",
|
||||||
|
wantRecipient: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
in: "<|start|>assistant<|channel|>analysis<|message|>What is 2 + 2?<|end|>",
|
||||||
|
wantRole: "assistant",
|
||||||
|
wantChannel: "analysis",
|
||||||
|
wantRecipient: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
in: "<|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{\"location\":\"San Francisco\"}<|call|><|start|>functions.get_weather to=assistant<|message|>{\"sunny\": true, \"temperature\": 20}<|end|>",
|
||||||
|
wantRole: "assistant",
|
||||||
|
wantChannel: "commentary",
|
||||||
|
wantRecipient: "functions.get_weather",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
in: "<|channel|>analysis<|message|>User asks weather in SF. We need location. Use get_current_weather with location \"San Francisco, CA\".<|end|><|start|>assistant<|channel|>commentary to=functions.get_current_weather <|constrain|>json<|message|>{\"location\":\"San Francisco, CA\"}<|call|>",
|
||||||
|
wantRole: "assistant",
|
||||||
|
wantChannel: "analysis",
|
||||||
|
wantRecipient: "",
|
||||||
|
implicitStart: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for i, tt := range tests {
|
||||||
|
parser := HarmonyParser{
|
||||||
|
MessageStartTag: "<|start|>",
|
||||||
|
MessageEndTag: "<|end|>",
|
||||||
|
HeaderEndTag: "<|message|>",
|
||||||
|
}
|
||||||
|
if tt.implicitStart {
|
||||||
|
parser.AddImplicitStart()
|
||||||
|
}
|
||||||
|
gotEvents := parser.AddContent(tt.in)
|
||||||
|
if len(gotEvents) == 0 {
|
||||||
|
t.Errorf("case %d: got no events, want at least one", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
var firstHeaderEvent *HarmonyEventHeaderComplete
|
||||||
|
// print events
|
||||||
|
for _, event := range gotEvents {
|
||||||
|
fmt.Printf("event: %+v\n", event)
|
||||||
|
}
|
||||||
|
for _, event := range gotEvents {
|
||||||
|
if event, ok := event.(HarmonyEventHeaderComplete); ok {
|
||||||
|
firstHeaderEvent = &event
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if firstHeaderEvent == nil {
|
||||||
|
t.Errorf("case %d: got no header complete event, want one", i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
gotHeader := firstHeaderEvent.Header
|
||||||
|
if gotHeader.Role != tt.wantRole || gotHeader.Channel != tt.wantChannel || gotHeader.Recipient != tt.wantRecipient {
|
||||||
|
t.Errorf("case %d: got header %+v, want role=%s channel=%s recipient=%s", i, gotHeader, tt.wantRole, tt.wantChannel, tt.wantRecipient)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHarmonyParserNonStreaming(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
in string
|
||||||
|
implicitStart bool
|
||||||
|
wantEvents []HarmonyEvent
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
in: "<|start|>user<|message|>What is 2 + 2?<|end|>",
|
||||||
|
wantEvents: []HarmonyEvent{
|
||||||
|
HarmonyEventMessageStart{},
|
||||||
|
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||||
|
HarmonyEventContentEmitted{Content: "What is 2 + 2?"},
|
||||||
|
HarmonyEventMessageEnd{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
in: "<|start|>assistant<|channel|>analysis<|message|>The answer is 4<|end|>",
|
||||||
|
wantEvents: []HarmonyEvent{
|
||||||
|
HarmonyEventMessageStart{},
|
||||||
|
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "analysis", Recipient: ""}},
|
||||||
|
HarmonyEventContentEmitted{Content: "The answer is 4"},
|
||||||
|
HarmonyEventMessageEnd{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
in: "<|start|>assistant<|channel|>commentary to=functions.calc<|message|>Computing...<|end|>",
|
||||||
|
wantEvents: []HarmonyEvent{
|
||||||
|
HarmonyEventMessageStart{},
|
||||||
|
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "commentary", Recipient: "functions.calc"}},
|
||||||
|
HarmonyEventContentEmitted{Content: "Computing..."},
|
||||||
|
HarmonyEventMessageEnd{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
in: "<|start|>user<|message|><|end|>",
|
||||||
|
wantEvents: []HarmonyEvent{
|
||||||
|
HarmonyEventMessageStart{},
|
||||||
|
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||||
|
HarmonyEventMessageEnd{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
in: "<|start|>user<|message|>Hello<|end|><|start|>assistant<|message|>Hi!<|end|>",
|
||||||
|
wantEvents: []HarmonyEvent{
|
||||||
|
HarmonyEventMessageStart{},
|
||||||
|
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||||
|
HarmonyEventContentEmitted{Content: "Hello"},
|
||||||
|
HarmonyEventMessageEnd{},
|
||||||
|
HarmonyEventMessageStart{},
|
||||||
|
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "", Recipient: ""}},
|
||||||
|
HarmonyEventContentEmitted{Content: "Hi!"},
|
||||||
|
HarmonyEventMessageEnd{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
in: "<|channel|>analysis<|message|>Thinking about the request<|end|>",
|
||||||
|
implicitStart: true,
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}, HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "analysis", Recipient: ""}}, HarmonyEventContentEmitted{Content: "Thinking about the request"}, HarmonyEventMessageEnd{}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for i, tt := range tests {
|
||||||
|
parser := HarmonyParser{
|
||||||
|
MessageStartTag: "<|start|>",
|
||||||
|
MessageEndTag: "<|end|>",
|
||||||
|
HeaderEndTag: "<|message|>",
|
||||||
|
}
|
||||||
|
if tt.implicitStart {
|
||||||
|
parser.AddImplicitStart()
|
||||||
|
}
|
||||||
|
gotEvents := parser.AddContent(tt.in)
|
||||||
|
if !reflect.DeepEqual(gotEvents, tt.wantEvents) {
|
||||||
|
t.Errorf("case %d: got events %#v, want %#v", i, gotEvents, tt.wantEvents)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHarmonyParserStreaming(t *testing.T) {
|
||||||
|
type step struct {
|
||||||
|
input string
|
||||||
|
wantEvents []HarmonyEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
desc string
|
||||||
|
implicitStart bool
|
||||||
|
steps []step
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
desc: "simple message streamed character by character",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "<",
|
||||||
|
wantEvents: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "|",
|
||||||
|
wantEvents: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "start|>u",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "ser<|mess",
|
||||||
|
wantEvents: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "age|>Hi",
|
||||||
|
wantEvents: []HarmonyEvent{
|
||||||
|
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||||
|
HarmonyEventContentEmitted{Content: "Hi"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: " there",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: " there"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "<|e",
|
||||||
|
wantEvents: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "nd|>",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "message with channel streamed",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "<|start|>assistant",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "<|chan",
|
||||||
|
wantEvents: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "nel|>analysis",
|
||||||
|
wantEvents: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "<|message|>",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "analysis", Recipient: ""}}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "Thinking",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "Thinking"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "...",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "..."}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "<|end|>",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "message with channel and recipient",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "<|start|>assistant<|channel|>commentary to=functions.calc<|message|>",
|
||||||
|
wantEvents: []HarmonyEvent{
|
||||||
|
HarmonyEventMessageStart{},
|
||||||
|
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "commentary", Recipient: "functions.calc"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "{\"x\": 5}",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "{\"x\": 5}"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "<|end|>",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "message with channel and recipient (receipient before channel)",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "<|start|>assistant to=functions.calc<|channel|>commentary<|message|>",
|
||||||
|
wantEvents: []HarmonyEvent{
|
||||||
|
HarmonyEventMessageStart{},
|
||||||
|
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "commentary", Recipient: "functions.calc"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "{\"x\": 5}",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "{\"x\": 5}"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "<|end|>",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "implicit start with channel",
|
||||||
|
implicitStart: true,
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "<|channel|>thinking",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "<|message|>",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "thinking", Recipient: ""}}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "Processing request",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "Processing request"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "<|end|>",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "multiple messages streamed",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "<|start|>user<|message|>Hello<|end|>",
|
||||||
|
wantEvents: []HarmonyEvent{
|
||||||
|
HarmonyEventMessageStart{},
|
||||||
|
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||||
|
HarmonyEventContentEmitted{Content: "Hello"},
|
||||||
|
HarmonyEventMessageEnd{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "<|start|>",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventMessageStart{}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "assistant<|message|>",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "assistant", Channel: "", Recipient: ""}}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "Hi!",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "Hi!"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "<|end|>",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "empty message",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "<|start|>system<|message|><|end|>",
|
||||||
|
wantEvents: []HarmonyEvent{
|
||||||
|
HarmonyEventMessageStart{},
|
||||||
|
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "system", Channel: "", Recipient: ""}},
|
||||||
|
HarmonyEventMessageEnd{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
desc: "partial tag that looks like end but isn't",
|
||||||
|
steps: []step{
|
||||||
|
{
|
||||||
|
input: "<|start|>user<|message|>test<|e",
|
||||||
|
wantEvents: []HarmonyEvent{
|
||||||
|
HarmonyEventMessageStart{},
|
||||||
|
HarmonyEventHeaderComplete{Header: HarmonyHeader{Role: "user", Channel: "", Recipient: ""}},
|
||||||
|
HarmonyEventContentEmitted{Content: "test"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "xample|>more",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventContentEmitted{Content: "<|example|>more"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
input: "<|end|>",
|
||||||
|
wantEvents: []HarmonyEvent{HarmonyEventMessageEnd{}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.desc, func(t *testing.T) {
|
||||||
|
parser := HarmonyParser{
|
||||||
|
MessageStartTag: "<|start|>",
|
||||||
|
MessageEndTag: "<|end|>",
|
||||||
|
HeaderEndTag: "<|message|>",
|
||||||
|
}
|
||||||
|
if tc.implicitStart {
|
||||||
|
parser.AddImplicitStart()
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, step := range tc.steps {
|
||||||
|
gotEvents := parser.AddContent(step.input)
|
||||||
|
if !reflect.DeepEqual(gotEvents, step.wantEvents) {
|
||||||
|
t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFunctionConvertToValidChars tests only FunctionNameMap.convert(), which doesn't
|
||||||
|
// handle any saving (and therefore no dupe handling)
|
||||||
|
func TestFunctionConvertToValidChars(t *testing.T) {
	tests := []struct {
		name string
		in   string
		want string
	}{
		{name: "replace spaces with underscores", in: "get weather", want: "get_weather"},
		{name: "replace hyphens with underscores", in: "get-weather", want: "get_weather"},
		{name: "replace periods with underscores", in: "get.weather", want: "get_weather"},
		{name: "disallow non-word characters", in: "get weather!", want: "get_weather"},
		{name: "strip out invalid non-alphanumeric unicode characters", in: "a🫠bc", want: "abc"},
		{name: "names that only contain invalid characters", in: "🫠", want: "unnamed"},
		{name: "leading number", in: "123", want: "_123"},
		{name: "$ allowed", in: "$", want: "$"},
		// show that we allow weird unicode letter characters, though we might want
		// to convert them to their closest ASCII equivalents in the future
		{name: "allow weird unicode letter characters", in: "𝓸𝓵𝓵𝓪𝓶𝓪", want: "𝓸𝓵𝓵𝓪𝓶𝓪"},
		// names that look like words but are invalid (i.e., not ID_Start/ID_Continue)
		{name: "disallow non-word characters that look like words", in: "ⓞⓛⓛⓐⓜⓐ123", want: "_123"},
	}

	for i, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			parser := NewFunctionNameMap()
			got := parser.convertToValidChars(tt.in)
			if got != tt.want {
				t.Errorf("case %d: got %q, want %q", i, got, tt.want)
			}
		})
	}
}
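
// Editorial sketch, not upstream code: an approximate re-implementation of
// the conversion rules the table above encodes. It uses unicode.IsLetter and
// unicode.IsDigit as a stand-in for the ID_Start/ID_Continue classes the
// comments above reference, assumes "strings" and "unicode" are imported,
// and sanitizeFunctionName is a hypothetical name.
func sanitizeFunctionName(in string) string {
	var b strings.Builder
	for _, r := range in {
		switch {
		case r == ' ' || r == '-' || r == '.':
			b.WriteRune('_') // separators become underscores
		case r == '_' || r == '$' || unicode.IsLetter(r) || unicode.IsDigit(r):
			b.WriteRune(r) // identifier-ish characters pass through
		}
		// everything else (emoji, enclosed letters, punctuation) is dropped
	}
	out := b.String()
	if out == "" {
		return "unnamed" // nothing valid survived
	}
	if r := []rune(out)[0]; unicode.IsDigit(r) {
		out = "_" + out // identifiers can't start with a digit
	}
	return out
}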

func TestFunctionConvertAndAdd(t *testing.T) {
	// make a fresh map for each test, but within a test use the same map so we can test for dupe handling
	tests := []struct {
		name string
		in   []string
		want []string
	}{
		{name: "basic dupe handling", in: []string{"get weather", "get weather"}, want: []string{"get_weather", "get_weather_2"}},
		{name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}},
		{name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}},
		{name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}},
	}

	for i, tt := range tests {
		parser := NewFunctionNameMap()
		t.Run(tt.name, func(t *testing.T) {
			for j, in := range tt.in {
				got := parser.ConvertAndAdd(in)
				want := tt.want[j]
				if got != want {
					t.Errorf("case %d: got %q, want %q", i, got, want)
				}
				// check that the maps are correct
				if parser.userToHarmony[in] != want {
					t.Errorf("case %d: userToHarmony[%q] = %q, want %q", i, in, parser.userToHarmony[in], want)
				}
				if parser.harmonyToUser[want] != in {
					t.Errorf("case %d: harmonyToUser[%q] = %q, want %q", i, want, parser.harmonyToUser[want], in)
				}
			}
		})
	}
}
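
// Editorial sketch, not upstream code, of the dedupe rule the expectations
// above encode: the first use of a sanitized base name is kept verbatim, and
// each later collision on the same base appends _2, _3, and so on. It
// deliberately ignores the corner case where a caller-supplied name collides
// with a generated suffix; nameDeduper is a hypothetical type, and "fmt" is
// assumed to be imported.
type nameDeduper struct {
	counts map[string]int // uses of each base name seen so far
}

func newNameDeduper() *nameDeduper {
	return &nameDeduper{counts: map[string]int{}}
}

func (d *nameDeduper) add(base string) string {
	d.counts[base]++
	if n := d.counts[base]; n > 1 {
		return fmt.Sprintf("%s_%d", base, n)
	}
	return base
}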

This directory contains integration tests to exercise Ollama end-to-end to verify behavior.

By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag: `go test -tags=integration ./...`. Some tests require additional tags in order to scope the run and keep its duration reasonable. For example, testing a broad set of models requires `-tags=integration,models` and a longer timeout (~60m or more, depending on the speed of your GPU). To view the current set of tag combinations, use `find integration -type f | xargs grep "go:build"`.

The integration tests have two modes of operation:

1. By default, they will start the server on a random port, run the tests, and then shut down the server.
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which may be remote, as determined by your `OLLAMA_HOST` environment variable.

> [!IMPORTANT]
> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree with `go build .`, along with GPU support via cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
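
For reference, the invocations described above look roughly like this (a sketch; the host, port, and timeout values are illustrative):

```shell
# default mode: build the binary the tests expect, then run them
go build .
go test -tags=integration ./integration/...

# run against an already-running (possibly remote) server
OLLAMA_TEST_EXISTING=1 OLLAMA_HOST=127.0.0.1:11434 go test -tags=integration ./integration/...

# broader model coverage needs an extra tag and a longer timeout
go test -tags=integration,models -timeout 60m ./integration/...
```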

integration/api_test.go (new file, 412 lines)

//go:build integration

package integration

import (
	"bytes"
	"context"
	"fmt"
	"math/rand"
	"strings"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

func TestAPIGenerate(t *testing.T) {
	initialTimeout := 60 * time.Second
	streamTimeout := 30 * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
	defer cancel()
	// Set up the test data
	req := api.GenerateRequest{
		Model:  smol,
		Prompt: "why is the sky blue? be brief",
		Options: map[string]interface{}{
			"temperature": 0,
			"seed":        123,
		},
	}
	anyResp := []string{"rayleigh", "scattering"}

	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()
	if err := PullIfMissing(ctx, client, req.Model); err != nil {
		t.Fatalf("pull failed %s", err)
	}

	tests := []struct {
		name   string
		stream bool
	}{
		{
			name:   "stream",
			stream: true,
		},
		{
			name:   "no_stream",
			stream: false,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			stallTimer := time.NewTimer(initialTimeout)
			var buf bytes.Buffer
			fn := func(response api.GenerateResponse) error {
				// Fields that must always be present
				if response.Model == "" {
					t.Errorf("response missing model: %#v", response)
				}
				if response.Done {
					// Required fields for final updates:
					if response.DoneReason == "" && *req.Stream {
						// TODO - is the lack of done reason on non-stream a bug?
						t.Errorf("final response missing done_reason: %#v", response)
					}
					if response.Metrics.TotalDuration == 0 {
						t.Errorf("final response missing total_duration: %#v", response)
					}
					if response.Metrics.LoadDuration == 0 {
						t.Errorf("final response missing load_duration: %#v", response)
					}
					if response.Metrics.PromptEvalDuration == 0 {
						t.Errorf("final response missing prompt_eval_duration: %#v", response)
					}
					if response.Metrics.EvalCount == 0 {
						t.Errorf("final response missing eval_count: %#v", response)
					}
					if response.Metrics.EvalDuration == 0 {
						t.Errorf("final response missing eval_duration: %#v", response)
					}
					if len(response.Context) == 0 {
						t.Errorf("final response missing context: %#v", response)
					}

					// Note: caching can result in no prompt eval count, so this can't be verified reliably
					// if response.Metrics.PromptEvalCount == 0 {
					// 	t.Errorf("final response missing prompt_eval_count: %#v", response)
					// }

				} // else incremental response, nothing to check right now...
				buf.Write([]byte(response.Response))
				if !stallTimer.Reset(streamTimeout) {
					return fmt.Errorf("stall was detected while streaming response, aborting")
				}
				return nil
			}

			done := make(chan int)
			var genErr error
			go func() {
				req.Stream = &test.stream
				req.Options["seed"] = rand.Int() // bust cache for prompt eval results
				genErr = client.Generate(ctx, &req, fn)
				done <- 0
			}()

			select {
			case <-stallTimer.C:
				if buf.Len() == 0 {
					t.Errorf("generate never started; timed out after %s", initialTimeout.String())
				} else {
					t.Errorf("generate stalled. Response so far: %s", buf.String())
				}
			case <-done:
				if genErr != nil {
					t.Fatalf("%s request with prompt %q failed: %s", req.Model, req.Prompt, genErr)
				}
				// Verify the response contains the expected data
				response := buf.String()
				atLeastOne := false
				for _, resp := range anyResp {
					if strings.Contains(strings.ToLower(response), resp) {
						atLeastOne = true
						break
					}
				}
				if !atLeastOne {
					t.Errorf("none of %v found in %s", anyResp, response)
				}
			case <-ctx.Done():
				t.Error("outer test context done while waiting for generate")
			}
		})
	}

	// Validate PS while we're at it...
	resp, err := client.ListRunning(ctx)
	if err != nil {
		t.Fatalf("list models API error: %s", err)
	}
	if resp == nil || len(resp.Models) == 0 {
		t.Fatalf("list models API returned empty list while model should still be loaded")
	}
	// Find the model we just loaded and verify some attributes
	found := false
	for _, model := range resp.Models {
		if strings.Contains(model.Name, req.Model) {
			found = true
			if model.Model == "" {
				t.Errorf("model field omitted: %#v", model)
			}
			if model.Size == 0 {
				t.Errorf("size omitted: %#v", model)
			}
			if model.Digest == "" {
				t.Errorf("digest omitted: %#v", model)
			}
			verifyModelDetails(t, model.Details)
			var nilTime time.Time
			if model.ExpiresAt == nilTime {
				t.Errorf("expires_at omitted: %#v", model)
			}
			// SizeVRAM could be zero.
		}
	}
	if !found {
		t.Errorf("unable to locate running model: %#v", resp)
	}
}

func TestAPIChat(t *testing.T) {
	initialTimeout := 60 * time.Second
	streamTimeout := 30 * time.Second
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
	defer cancel()
	// Set up the test data
	req := api.ChatRequest{
		Model: smol,
		Messages: []api.Message{
			{
				Role:    "user",
				Content: "why is the sky blue? be brief",
			},
		},
		Options: map[string]interface{}{
			"temperature": 0,
			"seed":        123,
		},
	}
	anyResp := []string{"rayleigh", "scattering"}

	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()
	if err := PullIfMissing(ctx, client, req.Model); err != nil {
		t.Fatalf("pull failed %s", err)
	}

	tests := []struct {
		name   string
		stream bool
	}{
		{
			name:   "stream",
			stream: true,
		},
		{
			name:   "no_stream",
			stream: false,
		},
	}

	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			stallTimer := time.NewTimer(initialTimeout)
			var buf bytes.Buffer
			fn := func(response api.ChatResponse) error {
				// Fields that must always be present
				if response.Model == "" {
					t.Errorf("response missing model: %#v", response)
				}
				if response.Done {
					// Required fields for final updates:
					var nilTime time.Time
					if response.CreatedAt == nilTime {
						t.Errorf("final response missing created_at: %#v", response)
					}
					if response.DoneReason == "" {
						t.Errorf("final response missing done_reason: %#v", response)
					}
					if response.Metrics.TotalDuration == 0 {
						t.Errorf("final response missing total_duration: %#v", response)
					}
					if response.Metrics.LoadDuration == 0 {
						t.Errorf("final response missing load_duration: %#v", response)
					}
					if response.Metrics.PromptEvalDuration == 0 {
						t.Errorf("final response missing prompt_eval_duration: %#v", response)
					}
					if response.Metrics.EvalCount == 0 {
						t.Errorf("final response missing eval_count: %#v", response)
					}
					if response.Metrics.EvalDuration == 0 {
						t.Errorf("final response missing eval_duration: %#v", response)
					}

					if response.Metrics.PromptEvalCount == 0 {
						t.Errorf("final response missing prompt_eval_count: %#v", response)
					}
				} // else incremental response, nothing to check right now...
				buf.Write([]byte(response.Message.Content))
				if !stallTimer.Reset(streamTimeout) {
					return fmt.Errorf("stall was detected while streaming response, aborting")
				}
				return nil
			}

			done := make(chan int)
			var genErr error
			go func() {
				req.Stream = &test.stream
				req.Options["seed"] = rand.Int() // bust cache for prompt eval results
				genErr = client.Chat(ctx, &req, fn)
				done <- 0
			}()

			select {
			case <-stallTimer.C:
				if buf.Len() == 0 {
					t.Errorf("chat never started; timed out after %s", initialTimeout.String())
				} else {
					t.Errorf("chat stalled. Response so far: %s", buf.String())
				}
			case <-done:
				if genErr != nil {
					t.Fatalf("%s request with messages %v failed: %s", req.Model, req.Messages, genErr)
				}
				// Verify the response contains the expected data
				response := buf.String()
				atLeastOne := false
				for _, resp := range anyResp {
					if strings.Contains(strings.ToLower(response), resp) {
						atLeastOne = true
						break
					}
				}
				if !atLeastOne {
					t.Errorf("none of %v found in %s", anyResp, response)
				}
			case <-ctx.Done():
				t.Error("outer test context done while waiting for chat")
			}
		})
	}
}
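
// Editorial sketch, not upstream code, of the stall-detection idiom both
// tests above use: one timer guards time-to-first-token, every streamed
// chunk pushes the deadline out, and a select distinguishes a stall from
// completion or outer-context cancellation. waitForStream is a hypothetical
// helper; the real tests inline this logic.
func waitForStream(ctx context.Context, start func(onChunk func()) error, initial, perChunk time.Duration) error {
	stall := time.NewTimer(initial)
	defer stall.Stop()
	done := make(chan error, 1)
	go func() {
		// onChunk must be called for every streamed piece of the response
		done <- start(func() { stall.Reset(perChunk) })
	}()
	select {
	case <-stall.C:
		return fmt.Errorf("stream stalled")
	case err := <-done:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}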

func TestAPIListModels(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	// Make sure we have at least one model so an empty list can be considered a failure
	if err := PullIfMissing(ctx, client, smol); err != nil {
		t.Fatalf("pull failed %s", err)
	}

	resp, err := client.List(ctx)
	if err != nil {
		t.Fatalf("unable to list models: %s", err)
	}
	if len(resp.Models) == 0 {
		t.Fatalf("list should not be empty")
	}
	model := resp.Models[0]
	if model.Name == "" {
		t.Errorf("first model name empty: %#v", model)
	}
	var nilTime time.Time
	if model.ModifiedAt == nilTime {
		t.Errorf("first model modified_at empty: %#v", model)
	}
	if model.Size == 0 {
		t.Errorf("first model size empty: %#v", model)
	}
	if model.Digest == "" {
		t.Errorf("first model digest empty: %#v", model)
	}
	verifyModelDetails(t, model.Details)
}

func verifyModelDetails(t *testing.T, details api.ModelDetails) {
	if details.Format == "" {
		t.Errorf("first model details.format empty: %#v", details)
	}
	if details.Family == "" {
		t.Errorf("first model details.family empty: %#v", details)
	}
	if details.ParameterSize == "" {
		t.Errorf("first model details.parameter_size empty: %#v", details)
	}
	if details.QuantizationLevel == "" {
		t.Errorf("first model details.quantization_level empty: %#v", details)
	}
}

func TestAPIShowModel(t *testing.T) {
	modelName := "llama3.2"
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	if err := PullIfMissing(ctx, client, modelName); err != nil {
		t.Fatalf("pull failed %s", err)
	}
	resp, err := client.Show(ctx, &api.ShowRequest{Name: modelName})
	if err != nil {
		t.Fatalf("unable to show model: %s", err)
	}
	if resp.License == "" {
		t.Errorf("%s missing license: %#v", modelName, resp)
	}
	if resp.Modelfile == "" {
		t.Errorf("%s missing modelfile: %#v", modelName, resp)
	}
	if resp.Parameters == "" {
		t.Errorf("%s missing parameters: %#v", modelName, resp)
	}
	if resp.Template == "" {
		t.Errorf("%s missing template: %#v", modelName, resp)
	}
	// llama3 omits system
	verifyModelDetails(t, resp.Details)
	// llama3 omits messages
	if len(resp.ModelInfo) == 0 {
		t.Errorf("%s missing model_info: %#v", modelName, resp)
	}
	// llama3 omits projectors
	var nilTime time.Time
	if resp.ModifiedAt == nilTime {
		t.Errorf("%s missing modified_at: %#v", modelName, resp)
	}
}

func TestAPIEmbeddings(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()
	req := api.EmbeddingRequest{
		Model:  libraryEmbedModels[0],
		Prompt: "why is the sky blue?",
		Options: map[string]interface{}{
			"temperature": 0,
			"seed":        123,
		},
	}

	if err := PullIfMissing(ctx, client, req.Model); err != nil {
		t.Fatalf("pull failed %s", err)
	}

	resp, err := client.Embeddings(ctx, &req)
	if err != nil {
		t.Fatalf("embeddings call failed %s", err)
	}
	if len(resp.Embedding) == 0 {
		t.Errorf("zero length embedding response")
	}
}